Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import os 

2import sys 

3from abc import ABC, abstractmethod 

4import pickle 

5from subprocess import Popen, PIPE 

6from ase.calculators.calculator import Calculator, all_properties 

7 

8 

class PackedCalculator(ABC):
    """Picklable recipe for a calculator that lives in another process.

    A PackedCalculator is pickled and sent to a worker process (which may
    be launched through MPI or srun) where it is turned into a real
    calculator.  This lets a workflow run mostly serially in ASE while
    individual calculations run in a parallel Python environment.

    NamedPackedCalculator covers most existing calculators.  To customize
    how a calculator is constructed, subclass this class directly.

    Example::

        from ase.build import bulk

        atoms = bulk('Au')
        pack = NamedPackedCalculator('emt')

        with pack.calculator() as atoms.calc:
            energy = atoms.get_potential_energy()

    The worker process is started when the ``with`` statement is entered
    and stopped when it is left.
    """

    @abstractmethod
    def unpack_calculator(self) -> Calculator:
        """Build and return the wrapped calculator.

        This is invoked inside the worker process that performs the
        computations."""

    def calculator(self, mpi_command=None) -> 'PythonSubProcessCalculator':
        """Wrap this pack in a PythonSubProcessCalculator.

        The returned calculator owns a worker process hosting the actual
        calculator, and computations are delegated to that process.
        ``mpi_command`` selects how the worker is launched."""
        wrapper = PythonSubProcessCalculator(self, mpi_command=mpi_command)
        return wrapper

50 

51 

class NamedPackedCalculator(PackedCalculator):
    """PackedCalculator for calculators that ASE knows by name.

    The calculator class is looked up by name with
    ``ase.calculators.calculator.get_calculator_class``. It is then
    instantiated with the stored keyword arguments."""

    def __init__(self, name, kwargs=None):
        self._name = name
        self._kwargs = {} if kwargs is None else kwargs

    def unpack_calculator(self):
        # The class lookup happens at unpack time, i.e. inside the worker.
        from ase.calculators.calculator import get_calculator_class
        calc_class = get_calculator_class(self._name)
        return calc_class(**self._kwargs)

    def __repr__(self):
        clsname = type(self).__name__
        return f'{clsname}({self._name}, {self._kwargs})'

70 

71 

class MPICommand:
    """Command line used to launch the calculator worker process.

    ``argv`` is the complete argument list passed to ``subprocess.Popen``.
    """

    def __init__(self, argv):
        self.argv = argv

    @classmethod
    def python_argv(cls):
        """Return the argv prefix that runs this module with this Python."""
        return [sys.executable, '-m', 'ase.calculators.subprocesscalculator']

    @classmethod
    def parallel(cls, nprocs, mpi_argv=tuple()):
        """Return a command running the worker via ``mpiexec -n nprocs``.

        ``mpi_argv`` holds extra arguments inserted after ``-n nprocs``.
        The worker is started in ``mpi4py`` run mode."""
        return cls(['mpiexec', '-n', str(nprocs)]
                   + list(mpi_argv)
                   + cls.python_argv()
                   + ['mpi4py'])

    @classmethod
    def serial(cls):
        """Return a command running the worker as a single process.

        The worker is started in ``standard`` run mode."""
        # Use cls rather than a hard-coded MPICommand so that subclasses get
        # instances of themselves, consistent with parallel().
        return cls(cls.python_argv() + ['standard'])

    def execute(self):
        """Start the worker and return its Popen object.

        stdin and stdout are pipes carrying the pickle protocol.  stderr
        is not redirected, so it is inherited from the parent."""
        # On this computer (Ubuntu 20.04 + OpenMPI) the subprocess crashes
        # without output during startup if os.environ is not passed along.
        # Hence we pass os.environ. Not sure if this is a machine thing
        # or in general. --askhl
        return Popen(self.argv, stdout=PIPE,
                     stdin=PIPE, env=os.environ)

98 

99 

def gpaw_process(ncores=1, **kwargs):
    """Return a PythonSubProcessCalculator hosting a GPAW calculator.

    ``kwargs`` are forwarded to the calculator registered as ``'gpaw'``
    when it is constructed inside the worker.  The worker is launched as
    ``<python> -m gpaw -P <ncores> python -m
    ase.calculators.subprocesscalculator standard``.  Presumably ``-P``
    sets gpaw's process count; confirm against gpaw's CLI."""
    packed = NamedPackedCalculator('gpaw', kwargs)
    worker_argv = [sys.executable, '-m', 'gpaw', '-P', str(ncores),
                   'python', '-m', 'ase.calculators.subprocesscalculator',
                   'standard']
    return PythonSubProcessCalculator(packed, MPICommand(worker_argv))

107 

108 

class PythonSubProcessCalculator(Calculator):
    """Calculator for running calculations in external processes.

    TODO: This should work with arbitrary commands including MPI stuff.

    This calculator runs a subprocess wherein it sets up an
    actual calculator. Calculations are forwarded through pickle
    to that calculator, which returns results through pickle.

    Use it as a context manager.  ``__enter__`` starts the subprocess and
    ``__exit__`` stops it, and calculations are only possible in between.
    """
    implemented_properties = list(all_properties)

    def __init__(self, calc_input, mpi_command=None):
        """Create the calculator without starting the subprocess.

        calc_input: object with ``unpack_calculator()`` (a PackedCalculator)
            that is pickled and sent to the subprocess on startup.
        mpi_command: MPICommand used to launch the subprocess.  It defaults
            to ``MPICommand.serial()``.
        """
        super().__init__()

        self.calc_input = calc_input
        if mpi_command is None:
            mpi_command = MPICommand.serial()
        self.mpi_command = mpi_command

        # Protocol for the running subprocess; None while not started.
        self.protocol = None

    def set(self, **kwargs):
        """Reject parameter changes while the subprocess is running.

        Outside the ``with`` block this is a no-op.  The calculator's
        parameters are fixed by ``calc_input``."""
        # A running subprocess is marked by self.protocol being set.
        # getattr() is needed because Calculator.__init__ may call set()
        # before self.protocol has been assigned.
        if getattr(self, 'protocol', None) is not None:
            raise RuntimeError('No setting things for now, thanks')

    def __repr__(self):
        return '{}({})'.format(type(self).__name__,
                               self.calc_input)

    def __enter__(self):
        assert self.protocol is None
        proc = self.mpi_command.execute()
        self.protocol = Protocol(proc)
        try:
            self.protocol.send(self.calc_input)
        except BaseException:
            # Do not leave an orphaned subprocess behind if the packed
            # calculator cannot be sent (e.g. it is not picklable).
            proc.kill()
            proc.communicate()
            self.protocol = None
            raise
        return self

    def __exit__(self, *args):
        self.protocol.send('stop')
        self.protocol.proc.communicate()
        self.protocol = None

    def _run_calculation(self, atoms, properties, system_changes):
        self.protocol.send('calculate')
        self.protocol.send((atoms, properties, system_changes))

    def calculate(self, atoms, properties, system_changes):
        """Forward the calculation to the subprocess and store its results.

        Raises RuntimeError when called outside the ``with`` block.
        Exceptions raised in the subprocess are re-raised here."""
        if self.protocol is None:
            raise RuntimeError(
                f'{type(self).__name__} must be used inside a "with" '
                'block so that its subprocess is running')
        Calculator.calculate(self, atoms, properties, system_changes)
        # We send a pickle of self.atoms because this is a fresh copy
        # of the input, but without an unpicklable calculator:
        self._run_calculation(self.atoms.copy(), properties, system_changes)
        results = self.protocol.recv()
        self.results.update(results)

    def backend(self):
        """Return a proxy whose method calls run on the remote calculator."""
        return ParallelBackendInterface(self)

164 

165 

class Protocol:
    """Parent-side end of the pickle protocol with a worker subprocess.

    Messages are pickled to the worker's stdin.  Replies are read from its
    stdout as ``(response_type, value)`` pairs."""

    def __init__(self, proc):
        # Popen-like object whose stdin/stdout are binary streams.
        self.proc = proc

    def send(self, obj):
        """Pickle ``obj`` to the worker and flush so it is not left buffered."""
        pickle.dump(obj, self.proc.stdin)
        self.proc.stdin.flush()

    def recv(self):
        """Read one reply from the worker.

        Returns the value of a ``'return'`` reply.  Re-raises the exception
        carried by a ``'raise'`` reply.

        Raises:
            RuntimeError: if the reply has any other response type.
        """
        # NOTE: pickle.load can execute arbitrary code from the stream.  This
        # is acceptable only because the stream comes from our own child.
        response_type, value = pickle.load(self.proc.stdout)

        if response_type == 'raise':
            raise value

        # Explicit check rather than assert, which is stripped under -O.
        if response_type != 'return':
            raise RuntimeError(
                f'Unexpected response type from subprocess: {response_type!r}')
        return value

182 

183 

class MockMethod:
    """Callable proxy for method ``name`` of the remote calculator.

    Calling it sends ``'callmethod'`` and then ``[name, args, kwargs]``
    over ``calc.protocol``, and returns the received reply."""

    def __init__(self, name, calc):
        self.name = name
        self.calc = calc

    def __call__(self, *args, **kwargs):
        proto = self.calc.protocol
        payload = [self.name, args, kwargs]
        for message in ('callmethod', payload):
            proto.send(message)
        return proto.recv()

194 

195 

class ParallelBackendInterface:
    """Proxy exposing the remote calculator's methods as attributes.

    ``backend.foo(*args, **kwargs)`` calls ``foo`` on the calculator that
    lives in the subprocess of ``calc`` and returns the result."""

    def __init__(self, calc):
        self.calc = calc

    def __getattr__(self, name):
        # __getattr__ only runs for attributes not found normally.  Dunder
        # names come from protocol probes (copy, pickle, hasattr checks),
        # not remote calls.  Forwarding them would send spurious remote
        # calls, or recurse forever through self.calc on an instance that
        # has no 'calc' yet (as happens inside copy.copy).
        if name == 'calc' or (name.startswith('__') and name.endswith('__')):
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}')
        return MockMethod(name, self.calc)

202 

203 

# Run modes accepted by the worker process as sys.argv[1]
# (checked in parallel_startup).
run_modes = {'mpi4py', 'standard'}

205 

206 

def callmethod(calc, attrname, args, kwargs):
    """Call ``calc.<attrname>(*args, **kwargs)`` and return its result."""
    return getattr(calc, attrname)(*args, **kwargs)

211 

212 

def callfunction(func, args, kwargs):
    """Return ``func(*args, **kwargs)``."""
    result = func(*args, **kwargs)
    return result

215 

216 

def calculate(calc, atoms, properties, system_changes):
    """Run ``calc.calculate`` on ``atoms`` and return ``calc.results``.

    ``calc.results`` is emptied first, so stale entries from earlier
    calculations are not returned.  The returned object is
    ``calc.results`` itself, not a copy."""
    # There is no formal description of results/outputs yet, and no way
    # to enumerate all available properties programmatically.  Clearing is
    # a workaround: without it, caching was observed to break for stress,
    # though not for forces.  The cause is not understood.
    calc.results.clear()
    calc.calculate(atoms=atoms, properties=properties,
                   system_changes=system_changes)
    return calc.results

229 

230 

def bad_mode():
    """Build (but do not raise) the SystemExit for an invalid run mode."""
    message = f'sys.argv[1] must be one of {run_modes}'
    return SystemExit(message)

233 

234 

def parallel_startup():
    """Validate the run mode in sys.argv[1] and return a Client on stdio.

    Raises SystemExit if the mode is missing or not in ``run_modes``.  In
    ``mpi4py`` mode, mpi4py is imported before the Client is created.
    Afterwards ``sys.stdout`` refers to ``sys.stderr``; the original
    binary stdout is reserved for protocol messages."""
    run_mode = sys.argv[1] if len(sys.argv) > 1 else None
    if run_mode not in run_modes:
        raise bad_mode()

    if run_mode == 'mpi4py':
        # mpi4py has to be imported before the rest of ASE, otherwise
        # world is not initialized correctly.
        import mpi4py  # noqa

    # Keep the real binary stdout for the protocol and divert text
    # printed to sys.stdout (stray prints) to stderr.
    protocol_out = sys.stdout.buffer
    sys.stdout = sys.stderr

    return Client(input_fd=sys.stdin.buffer, output_fd=protocol_out)

255 

256 

class Client:
    """Worker-side end of the pickle protocol.

    Rank 0 of ``ase.parallel.world`` reads messages from ``input_fd`` and
    broadcasts them to every rank.  Only rank 0 writes replies to
    ``output_fd``."""

    def __init__(self, input_fd, output_fd):
        from ase.parallel import world
        self._world = world
        self.input_fd = input_fd
        self.output_fd = output_fd

    def recv(self):
        """Read one object on rank 0 and broadcast it to all ranks."""
        from ase.parallel import broadcast
        if self._world.rank == 0:
            obj = pickle.load(self.input_fd)
        else:
            obj = None

        obj = broadcast(obj, 0, self._world)
        return obj

    def send(self, obj):
        """Pickle ``obj`` to the parent (rank 0 only) and flush."""
        if self._world.rank == 0:
            pickle.dump(obj, self.output_fd)
            self.output_fd.flush()

    def mainloop(self, calc):
        """Serve instructions for ``calc`` until ``'stop'`` is received.

        Every other instruction is followed by exactly one data message
        and answered with one ``(response_type, value)`` reply."""
        while True:
            instruction = self.recv()
            if instruction == 'stop':
                return

            instruction_data = self.recv()

            response_type, value = self.process_instruction(
                calc, instruction, instruction_data)
            self.send((response_type, value))

    def process_instruction(self, calc, instruction, instruction_data):
        """Execute one instruction and return ``(response_type, value)``.

        ``response_type`` is ``'return'`` (value is the result) or
        ``'raise'`` (value is the exception).  An unknown instruction is
        also reported as a ``'raise'`` reply carrying a RuntimeError.  The
        worker then keeps serving, and the parent receives an error
        instead of a closed pipe."""
        try:
            function, args = self._resolve_instruction(
                calc, instruction, instruction_data)
            value = function(*args)
        except Exception as ex:
            import traceback
            traceback.print_exc()
            return 'raise', ex
        return 'return', value

    @staticmethod
    def _resolve_instruction(calc, instruction, instruction_data):
        """Map an instruction name to the function and its argument tuple."""
        if instruction == 'callmethod':
            return callmethod, (calc, *instruction_data)
        if instruction == 'calculate':
            return calculate, (calc, *instruction_data)
        if instruction == 'callfunction':
            return callfunction, instruction_data
        raise RuntimeError(f'Bad instruction: {instruction}')

315 

316 

class ParallelDispatch:
    """Run plain functions in a (possibly parallel) worker process.

    Usage::

        with ParallelDispatch(mpicommand) as parallel:
            result = parallel.call(function, *args, **kwargs)

    The function and its arguments are pickled.  They must therefore be
    picklable (e.g. module-level functions).
    """

    def __init__(self, mpicommand):
        self._mpicommand = mpicommand
        self._protocol = None

    def call(self, func, *args, **kwargs):
        """Run ``func(*args, **kwargs)`` in the worker and return the result.

        Exceptions raised in the worker are re-raised here."""
        proto = self._protocol
        proto.send('callfunction')
        proto.send((func, args, kwargs))
        return proto.recv()

    def __enter__(self):
        assert self._protocol is None
        proc = self._mpicommand.execute()
        self._protocol = Protocol(proc)

        # The worker's main() always receives and unpacks a calculator
        # first, so one must be sent even though none is used here.
        # (We should get rid of that requirement.)
        self._protocol.send(NamedPackedCalculator('emt', {}))
        return self

    def __exit__(self, *args):
        proto = self._protocol
        proto.send('stop')
        proto.proc.communicate()
        self._protocol = None

348 

349 

def main():
    """Worker entry point: receive the packed calculator and serve it."""
    client = parallel_startup()
    packed = client.recv()
    client.mainloop(packed.unpack_calculator())

355 

356 

if __name__ == '__main__':
    # Started as a worker, e.g.
    # ``python -m ase.calculators.subprocesscalculator standard``.
    main()