Coverage for /builds/ase/ase/ase/calculators/subprocesscalculator.py : 91.37%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import os
2import sys
3from abc import ABC, abstractmethod
4import pickle
5from subprocess import Popen, PIPE
6from ase.calculators.calculator import Calculator, all_properties
class PackedCalculator(ABC):
    """Portable description of a calculator that runs in another process.

    A packed calculator is what PythonSubProcessCalculator sends to its
    subprocess (which may be started through MPI or srun).  Inside that
    subprocess, unpack_calculator() is called to create the calculator
    that actually performs the computations.

    This is meant for workflows where ASE mostly runs in serial but some
    calculations must run in a parallel Python environment.

    NamedPackedCalculator handles most existing calculators.  For
    calculators that need different behaviour, write a subclass of
    this class.

    Example::

        from ase.build import bulk

        atoms = bulk('Au')
        pack = NamedPackedCalculator('emt')

        with pack.calculator() as atoms.calc:
            energy = atoms.get_potential_energy()

    The subprocess performing the computation lives exactly as long as
    the with statement.
    """

    @abstractmethod
    def unpack_calculator(self) -> Calculator:
        """Create and return the calculator described by this object.

        This is called inside the subprocess that does the
        computations."""

    def calculator(self, mpi_command=None) -> 'PythonSubProcessCalculator':
        """Wrap this packed calculator in a PythonSubProcessCalculator.

        The returned calculator owns a subprocess holding the actual
        calculator; all computations take place in that subprocess.
        mpi_command is forwarded unchanged to PythonSubProcessCalculator."""
        subprocess_calc = PythonSubProcessCalculator(
            self, mpi_command=mpi_command)
        return subprocess_calc
class NamedPackedCalculator(PackedCalculator):
    """PackedCalculator implementation which works with standard calculators.

    The calculator is identified by a name that
    ase.calculators.calculator.get_calculator_class() can resolve, plus
    the keyword arguments to pass to that class."""

    def __init__(self, name, kwargs=None):
        """Store the calculator name and its constructor keyword arguments.

        name is passed to get_calculator_class() when unpacking.
        kwargs is a dict of keyword arguments for the calculator class;
        None means no keyword arguments."""
        self._name = name
        if kwargs is None:
            kwargs = {}
        self._kwargs = kwargs

    def unpack_calculator(self):
        """Look up the calculator class by name and instantiate it."""
        from ase.calculators.calculator import get_calculator_class
        cls = get_calculator_class(self._name)
        return cls(**self._kwargs)

    def __repr__(self):
        # Use repr() of the name so the output reads like a constructor
        # call, e.g. NamedPackedCalculator('emt', {}) rather than
        # NamedPackedCalculator(emt, {}).
        return f'{type(self).__name__}({self._name!r}, {self._kwargs!r})'
class MPICommand:
    """Command line used to launch the subprocess hosting a calculator."""

    def __init__(self, argv):
        # Complete argument list handed to Popen by execute().
        self.argv = argv

    @classmethod
    def python_argv(cls):
        """Return the argv prefix that runs this module as a script
        with the current Python interpreter."""
        return [sys.executable, '-m', 'ase.calculators.subprocesscalculator']

    @classmethod
    def parallel(cls, nprocs, mpi_argv=tuple()):
        """Return a command running the module under mpiexec.

        nprocs is passed to ``mpiexec -n``; mpi_argv holds extra
        arguments inserted after it.  The module is started in
        'mpi4py' run mode."""
        return cls(['mpiexec', '-n', str(nprocs)]
                   + list(mpi_argv)
                   + cls.python_argv()
                   + ['mpi4py'])

    @classmethod
    def serial(cls):
        """Return a command running the module directly in 'standard'
        run mode (no MPI launcher)."""
        # Instantiate cls rather than MPICommand so that subclasses get
        # instances of themselves, consistent with parallel().
        return cls(cls.python_argv() + ['standard'])

    def execute(self):
        """Start the command and return the Popen object.

        stdin and stdout of the subprocess are pipes."""
        # On this computer (Ubuntu 20.04 + OpenMPI) the subprocess crashes
        # without output during startup if os.environ is not passed along.
        # Hence we pass os.environ.  Not sure if this is a machine thing
        # or in general.  --askhl
        return Popen(self.argv, stdout=PIPE,
                     stdin=PIPE, env=os.environ)
def gpaw_process(ncores=1, **kwargs):
    """Return a PythonSubProcessCalculator that runs a 'gpaw' calculator.

    The subprocess is launched as ``python -m gpaw -P <ncores> python -m
    ase.calculators.subprocesscalculator standard``.  Remaining keyword
    arguments become the constructor arguments of the packed 'gpaw'
    calculator.

    NOTE(review): ncores is presumably the number of processes gpaw
    starts for -P; confirm against the gpaw command-line documentation.
    """
    launcher_argv = [
        sys.executable, '-m', 'gpaw', '-P', str(ncores), 'python', '-m',
        'ase.calculators.subprocesscalculator', 'standard',
    ]
    return PythonSubProcessCalculator(
        NamedPackedCalculator('gpaw', kwargs),
        MPICommand(launcher_argv))
class PythonSubProcessCalculator(Calculator):
    """Calculator for running calculations in external processes.

    TODO: This should work with arbitrary commands including MPI stuff.

    This calculator runs a subprocess wherein it sets up an
    actual calculator.  Calculations are forwarded through pickle
    to that calculator, which returns results through pickle.

    The subprocess exists only while the calculator is used as a
    context manager (between __enter__ and __exit__)."""

    implemented_properties = list(all_properties)

    def __init__(self, calc_input, mpi_command=None):
        """Remember what to run; the subprocess is started by __enter__.

        calc_input is the object sent to the subprocess first (a packed
        calculator).  mpi_command provides execute() to start the
        subprocess; None selects MPICommand.serial()."""
        super().__init__()

        self.calc_input = calc_input
        if mpi_command is None:
            mpi_command = MPICommand.serial()
        self.mpi_command = mpi_command

        # Protocol connected to the running subprocess, or None when
        # no subprocess is running.
        self.protocol = None

    def set(self, **kwargs):
        # NOTE(review): nothing visible in this file assigns a `client`
        # attribute, so this guard appears never to trigger and set()
        # silently ignores its arguments -- confirm the intent.
        if hasattr(self, 'client'):
            raise RuntimeError('No setting things for now, thanks')

    def __repr__(self):
        return '{}({})'.format(type(self).__name__,
                               self.calc_input)

    def __enter__(self):
        """Start the subprocess and send it the packed calculator."""
        assert self.protocol is None
        proc = self.mpi_command.execute()
        self.protocol = Protocol(proc)
        self.protocol.send(self.calc_input)
        return self

    def __exit__(self, *args):
        """Ask the subprocess to stop and wait for it to finish.

        If sending the stop message fails, the exception still
        propagates, but the process is waited for and this calculator
        is left in a state where it can be entered again."""
        protocol = self.protocol
        # Forget the protocol first so a failed shutdown cannot leave a
        # dead connection behind.
        self.protocol = None
        try:
            protocol.send('stop')
        finally:
            # Always reap the process, even if the pipe was broken.
            protocol.proc.communicate()

    def _run_calculation(self, atoms, properties, system_changes):
        """Send a 'calculate' instruction followed by its arguments."""
        self.protocol.send('calculate')
        self.protocol.send((atoms, properties, system_changes))

    def calculate(self, atoms, properties, system_changes):
        """Run the calculation in the subprocess and store its results."""
        Calculator.calculate(self, atoms, properties, system_changes)
        # We send a pickle of self.atoms because this is a fresh copy
        # of the input, but without an unpicklable calculator:
        self._run_calculation(self.atoms.copy(), properties, system_changes)
        results = self.protocol.recv()
        self.results.update(results)

    def backend(self):
        """Return a proxy whose method calls are executed on the
        calculator inside the subprocess."""
        return ParallelBackendInterface(self)
class Protocol:
    """Parent-side pickle channel to a subprocess.

    Outgoing objects are pickled to the process's stdin; replies are
    unpickled from its stdout.

    NOTE: replies are loaded with pickle, so the subprocess must be
    trusted.
    """

    def __init__(self, proc):
        # Process object exposing binary stdin and stdout streams.
        self.proc = proc

    def send(self, obj):
        """Pickle obj to the subprocess and flush immediately."""
        stream = self.proc.stdin
        pickle.dump(obj, stream)
        stream.flush()

    def recv(self):
        """Read one ``(response_type, value)`` reply.

        Returns value for a 'return' reply and raises value for a
        'raise' reply."""
        reply = pickle.load(self.proc.stdout)
        kind, payload = reply

        if kind == 'raise':
            raise payload

        assert kind == 'return'
        return payload
class MockMethod:
    """Callable stand-in for a method of the calculator in the subprocess.

    Calling it sends a 'callmethod' instruction through the owning
    calculator's protocol and returns whatever the protocol receives.
    """

    def __init__(self, name, calc):
        self.name = name  # name of the remote method
        self.calc = calc  # object providing the `protocol` attribute

    def __call__(self, *args, **kwargs):
        channel = self.calc.protocol
        # Instruction first, then its payload.
        for message in ('callmethod', [self.name, args, kwargs]):
            channel.send(message)
        return channel.recv()
class ParallelBackendInterface:
    """Proxy that turns attribute access into remote method handles.

    Any attribute not found on the proxy itself is returned as a
    MockMethod bound to the wrapped calculator.
    """

    def __init__(self, calc):
        self.calc = calc

    def __getattr__(self, name):
        # Only reached for names that normal attribute lookup misses.
        remote_method = MockMethod(name, self.calc)
        return remote_method
# Run modes accepted as the first command-line argument of this module.
run_modes = {
    'standard',
    'mpi4py',
}
def callmethod(calc, attrname, args, kwargs):
    """Call the method named attrname on calc and return its result.

    args is a sequence of positional arguments and kwargs a mapping of
    keyword arguments for the call."""
    return getattr(calc, attrname)(*args, **kwargs)
def callfunction(func, args, kwargs):
    """Call func with the given positional and keyword arguments.

    args is a sequence and kwargs a mapping; the function's return
    value is passed through unchanged."""
    outcome = func(*args, **kwargs)
    return outcome
def calculate(calc, atoms, properties, system_changes):
    """Run calc.calculate() from a clean slate and return calc.results.

    The returned object is the calculator's own results dictionary,
    not a copy."""
    # There is no formal way yet to ask a calculator for all of its
    # available outputs, so this remains a hack: wipe the previous
    # results before calculating.  The original author found that
    # without the clear() caching was broken for stress (though not
    # for forces), for reasons that were never understood.
    calc.results.clear()
    calc.calculate(
        atoms=atoms,
        properties=properties,
        system_changes=system_changes,
    )
    return calc.results
def bad_mode():
    """Build (but do not raise) the SystemExit for an invalid run mode."""
    message = f'sys.argv[1] must be one of {run_modes}'
    return SystemExit(message)
def parallel_startup():
    """Validate the run mode, redirect stdout and return a Client.

    The run mode is read from sys.argv[1]; a missing or unknown mode
    raises the SystemExit built by bad_mode().  The returned Client
    reads from the binary stdin and writes to the original binary
    stdout."""
    arguments = sys.argv
    if len(arguments) < 2 or arguments[1] not in run_modes:
        raise bad_mode()

    if arguments[1] == 'mpi4py':
        # mpi4py has to be imported before the rest of ASE; otherwise
        # world will not be correctly initialized.
        import mpi4py  # noqa

    # Keep the real stdout for protocol traffic only and send anything
    # printed by ordinary code to stderr, so stray print statements
    # cannot interfere with the outputs.
    protocol_output = sys.stdout.buffer
    sys.stdout = sys.stderr

    return Client(input_fd=sys.stdin.buffer, output_fd=protocol_output)
class Client:
    """Subprocess-side end of the pickle protocol.

    Rank 0 of the communicator reads instructions from input_fd and
    writes responses to output_fd; received objects are broadcast so
    every rank sees the same instruction.

    NOTE: incoming data is loaded with pickle, so the process feeding
    input_fd must be trusted.
    """

    def __init__(self, input_fd, output_fd):
        """Store the binary input/output streams and the ASE world."""
        from ase.parallel import world
        self._world = world
        self.input_fd = input_fd
        self.output_fd = output_fd

    def recv(self):
        """Unpickle one object on rank 0 and broadcast it to all ranks."""
        from ase.parallel import broadcast
        if self._world.rank == 0:
            obj = pickle.load(self.input_fd)
        else:
            obj = None

        obj = broadcast(obj, 0, self._world)
        return obj

    def send(self, obj):
        """Pickle obj to the output stream; only rank 0 writes."""
        if self._world.rank == 0:
            pickle.dump(obj, self.output_fd)
            self.output_fd.flush()

    def mainloop(self, calc):
        """Serve instructions for calc until 'stop' is received.

        Every other instruction is followed by one data object; the
        ``(response_type, value)`` pair produced for it is sent back."""
        while True:
            instruction = self.recv()
            if instruction == 'stop':
                return

            instruction_data = self.recv()

            response_type, value = self.process_instruction(
                calc, instruction, instruction_data)
            self.send((response_type, value))

    def process_instruction(self, calc, instruction, instruction_data):
        """Execute one instruction and return ``(response_type, value)``.

        response_type is 'return' with the result as value, or 'raise'
        with the caught exception as value.  An unknown instruction
        raises RuntimeError instead of producing a response."""
        if instruction == 'callmethod':
            function = callmethod
            args = (calc, *instruction_data)
        elif instruction == 'calculate':
            function = calculate
            args = (calc, *instruction_data)
        elif instruction == 'callfunction':
            function = callfunction
            args = instruction_data
        else:
            raise RuntimeError(f'Bad instruction: {instruction}')

        try:
            value = function(*args)
        except Exception as ex:
            # Show the traceback on this side as well; the exception
            # object itself is shipped back to be raised by the caller.
            import traceback
            traceback.print_exc()
            response_type = 'raise'
            value = ex
        else:
            response_type = 'return'
        return response_type, value
class ParallelDispatch:
    """Utility class to run functions in parallel.

    with ParallelDispatch(...) as parallel:
        parallel.call(function, *args, **kwargs)

    """

    def __init__(self, mpicommand):
        """Store the command used by __enter__ to start the subprocess.

        mpicommand must provide execute() returning the process."""
        self._mpicommand = mpicommand
        # Protocol connected to the running subprocess, or None when
        # no subprocess is running.
        self._protocol = None

    def call(self, func, *args, **kwargs):
        """Run ``func(*args, **kwargs)`` in the subprocess.

        Sends a 'callfunction' instruction followed by
        ``(func, args, kwargs)`` and returns what the protocol
        receives in reply."""
        self._protocol.send('callfunction')
        self._protocol.send((func, args, kwargs))
        return self._protocol.recv()

    def __enter__(self):
        """Start the subprocess and perform the initial handshake."""
        assert self._protocol is None
        self._protocol = Protocol(self._mpicommand.execute())

        # Even if we are not using a calculator, we have to send one:
        pack = NamedPackedCalculator('emt', {})
        self._protocol.send(pack)
        # (We should get rid of that requirement.)

        return self

    def __exit__(self, *args):
        """Ask the subprocess to stop and wait for it to finish.

        If sending the stop message fails, the exception still
        propagates, but the process is waited for and this object is
        left in a state where it can be entered again."""
        protocol = self._protocol
        # Forget the protocol first so a failed shutdown cannot leave a
        # dead connection behind.
        self._protocol = None
        try:
            protocol.send('stop')
        finally:
            # Always reap the process, even if the pipe was broken.
            protocol.proc.communicate()
def main():
    """Entry point of the subprocess: serve one calculator until stopped.

    The first object received is the packed calculator; it is unpacked
    and then handed to the client's main loop."""
    client = parallel_startup()
    packed = client.recv()
    client.mainloop(packed.unpack_calculator())
# Run the subprocess main loop only when executed as a script
# (e.g. via ``python -m``), not when imported.
if __name__ == '__main__':
    main()