Coverage for /builds/ase/ase/ase/parallel.py: 50.47%

import os
import atexit
import functools
import pickle
import sys
import time

import numpy as np

def get_txt(txt, rank):
    """Return a file-like object for text output.

    Real streams are only handed out on rank 0; other ranks get /dev/null.
    """
    if hasattr(txt, 'write'):
        # Note: User-supplied object might write to files from many ranks.
        return txt
    elif rank == 0:
        if txt is None:
            return open(os.devnull, 'w')
        elif txt == '-':
            return sys.stdout
        else:
            return open(txt, 'w', 1)
    else:
        return open(os.devnull, 'w')
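
# Example (a sketch): rank 0 gets a real stream, every other rank /dev/null.
#
#     >>> get_txt('-', rank=0) is sys.stdout
#     True
#     >>> get_txt('-', rank=1).name == os.devnull
#     True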

def paropen(name, mode='r', buffering=-1, encoding=None, comm=None):
    """MPI-safe version of the open function.

    In read mode, the file is opened on all nodes.  In write and
    append mode, the file is opened on the master only, and /dev/null
    is opened on all other nodes.
    """
    if comm is None:
        comm = world
    if comm.rank > 0 and mode[0] != 'r':
        name = os.devnull
    return open(name, mode, buffering, encoding)
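
# Example usage (a sketch; the filename is hypothetical).  The same call
# works in serial and MPI runs, and only rank 0 touches the real file:
#
#     >>> from ase.parallel import paropen
#     >>> with paropen('results.txt', 'w') as fd:
#     ...     print('energy: -1.23 eV', file=fd)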

def parprint(*args, **kwargs):
    """MPI-safe print - prints only from master."""
    if world.rank == 0:
        print(*args, **kwargs)
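
# Example: prints exactly once per job, regardless of the number of ranks.
#
#     >>> from ase.parallel import parprint
#     >>> parprint('relaxation converged')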

class DummyMPI:
    """Fallback communicator for serial runs (a single process)."""

    rank = 0
    size = 1

    def _returnval(self, a, root=-1):
        # MPI interface works either on numbers, in which case a number is
        # returned, or on arrays, in-place.
        if np.isscalar(a):
            return a
        if hasattr(a, '__array__'):
            a = a.__array__()
        assert isinstance(a, np.ndarray)
        return None

    def sum(self, a, root=-1):
        return self._returnval(a)

    def product(self, a, root=-1):
        return self._returnval(a)

    def broadcast(self, a, root):
        assert root == 0
        return self._returnval(a)

    def barrier(self):
        pass
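
# A sketch of the scalar/array convention shared by all communicators here:
# scalars are reduced by return value, arrays in place (returning None).
#
#     >>> comm = DummyMPI()
#     >>> comm.sum(3.0)
#     3.0
#     >>> a = np.ones(3)
#     >>> comm.sum(a)      # in-place; unchanged on a single rank, returns None
#     >>> a
#     array([1., 1., 1.])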

class MPI:
    """Wrapper for MPI world object.

    Decides at runtime (after all imports) which one to use:

    * MPI4Py
    * GPAW
    * a dummy implementation for serial runs
    """

    def __init__(self):
        self.comm = None

    def __getattr__(self, name):
        # Pickling of objects that carry instances of MPI class
        # (e.g. NEB) raises RecursionError since it tries to access
        # the optional __setstate__ method (which we do not implement)
        # when unpickling.  The two lines below prevent the
        # RecursionError.  This also affects modules that use pickling,
        # e.g. multiprocessing.  For more details see:
        # https://gitlab.com/ase/ase/-/merge_requests/2695
        if name == '__setstate__':
            raise AttributeError(name)

        if self.comm is None:
            self.comm = _get_comm()
        return getattr(self.comm, name)
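
# A serial sketch of the lazy resolution: the backend is only chosen on the
# first attribute access and then cached on self.comm.
#
#     >>> w = MPI()
#     >>> w.comm is None
#     True
#     >>> w.rank                  # triggers _get_comm(); DummyMPI in serial
#     0
#     >>> isinstance(w.comm, DummyMPI)
#     True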

def _get_comm():
    """Get the correct MPI world object."""
    if 'mpi4py' in sys.modules:
        return MPI4PY()
    if '_gpaw' in sys.modules:
        import _gpaw
        if hasattr(_gpaw, 'Communicator'):
            return _gpaw.Communicator()
    if '_asap' in sys.modules:
        import _asap
        if hasattr(_asap, 'Communicator'):
            return _asap.Communicator()
    return DummyMPI()

class MPI4PY:
    """Wrapper around an mpi4py communicator, following this module's
    scalar/array calling conventions."""

    def __init__(self, mpi4py_comm=None):
        if mpi4py_comm is None:
            from mpi4py import MPI
            mpi4py_comm = MPI.COMM_WORLD
        self.comm = mpi4py_comm

    @property
    def rank(self):
        return self.comm.rank

    @property
    def size(self):
        return self.comm.size

    def _returnval(self, a, b):
        """Behave correctly when working on scalars/arrays.

        Either input is an array and we in-place write b (output from
        mpi4py) back into a, or input is a scalar and we return the
        corresponding output scalar."""
        if np.isscalar(a):
            assert np.isscalar(b)
            return b
        else:
            assert not np.isscalar(b)
            a[:] = b
            return None

    def sum(self, a, root=-1):
        if root == -1:
            b = self.comm.allreduce(a)
        else:
            b = self.comm.reduce(a, root)
        return self._returnval(a, b)

    def split(self, split_size=None):
        """Divide the communicator."""
        # color - subgroup id
        # key - new subgroup rank
        if not split_size:
            split_size = self.size
        color = int(self.rank // (self.size / split_size))
        key = int(self.rank % (self.size / split_size))
        comm = self.comm.Split(color, key)
        return MPI4PY(comm)

    def barrier(self):
        self.comm.barrier()

    def abort(self, code):
        self.comm.Abort(code)

    def broadcast(self, a, root):
        b = self.comm.bcast(a, root=root)
        if self.rank == root:
            if np.isscalar(a):
                return a
            return
        return self._returnval(a, b)
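
# A sketch of MPI4PY's reduction semantics, assuming mpi4py is installed
# and the script is launched with e.g. ``mpiexec -n 4 python script.py``:
#
#     >>> import mpi4py  # noqa: F401  (makes this module pick MPI4PY)
#     >>> from ase.parallel import world
#     >>> world.sum(1)                     # scalar: every rank gets 4
#     >>> a = np.zeros(2) + world.rank
#     >>> world.sum(a)                     # array: a becomes [6., 6.]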

world = None

# Check for special MPI-enabled Python interpreters:
if '_gpaw' in sys.builtin_module_names:
    # http://wiki.fysik.dtu.dk/gpaw
    import _gpaw
    world = _gpaw.Communicator()
elif '_asap' in sys.builtin_module_names:
    # Modern version of Asap
    # http://wiki.fysik.dtu.dk/asap
    # We cannot import asap3.mpi here, as that creates an import deadlock
    import _asap
    world = _asap.Communicator()

# Check if an MPI implementation has been imported already:
elif '_gpaw' in sys.modules:
    # Same thing as above but for the module version
    import _gpaw
    try:
        world = _gpaw.Communicator()
    except AttributeError:
        pass
elif '_asap' in sys.modules:
    import _asap
    try:
        world = _asap.Communicator()
    except AttributeError:
        pass
elif 'mpi4py' in sys.modules:
    world = MPI4PY()

if world is None:
    # No MPI backend found yet: fall back to the lazy wrapper, which
    # picks a backend on first use (see MPI.__getattr__ above).
    world = MPI()
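
# In a plain serial session none of the branches above match, so ``world``
# is the lazy MPI wrapper and behaves like DummyMPI:
#
#     >>> from ase.parallel import world
#     >>> world.rank, world.size
#     (0, 1)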

def barrier():
    """Block until all ranks in the world communicator reach this point."""
    world.barrier()

def broadcast(obj, root=0, comm=world):
    """Broadcast a Python object across an MPI communicator and return it."""
    if comm.rank == root:
        # Serialize the object and announce its size first:
        string = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        n = np.array([len(string)], int)
    else:
        string = None
        n = np.empty(1, int)
    comm.broadcast(n, root)
    if comm.rank == root:
        string = np.frombuffer(string, np.int8)
    else:
        string = np.zeros(n, np.int8)
    comm.broadcast(string, root)
    if comm.rank == root:
        return obj
    else:
        return pickle.loads(string.tobytes())
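
# Example (a sketch): any picklable object can be shared this way.
#
#     >>> from ase.parallel import broadcast, world
#     >>> if world.rank == 0:
#     ...     data = {'energy': -1.23}
#     ... else:
#     ...     data = None
#     >>> data = broadcast(data, root=0)   # identical dict on every rank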

def parallel_function(func):
    """Decorator for broadcasting from master to slaves using MPI.

    Disable by passing parallel=False to the function.  For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.
    """
    @functools.wraps(func)
    def new_func(*args, **kwargs):
        if (world.size == 1 or
                args and getattr(args[0], 'serial', False) or
                not kwargs.pop('parallel', True)):
            # Disabled: call the function directly on every rank.
            return func(*args, **kwargs)

        ex = None
        result = None
        if world.rank == 0:
            try:
                result = func(*args, **kwargs)
            except Exception as x:
                ex = x
        # Share both the result and any exception so all ranks fail
        # consistently:
        ex, result = broadcast((ex, result))
        if ex is not None:
            raise ex
        return result

    return new_func
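
# Example (a sketch; read_data is a hypothetical, I/O-heavy function):
#
#     >>> @parallel_function
#     ... def read_data(filename):
#     ...     ...                                   # executed on rank 0 only
#     >>> data = read_data('big.json')              # result broadcast to all
#     >>> data = read_data('big.json', parallel=False)  # run on every rank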

def parallel_generator(generator):
    """Decorator for broadcasting yields from master to slaves using MPI.

    Disable by passing parallel=False to the function.  For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.
    """
    @functools.wraps(generator)
    def new_generator(*args, **kwargs):
        if (world.size == 1 or
                args and getattr(args[0], 'serial', False) or
                not kwargs.pop('parallel', True)):
            # Disabled: iterate directly on every rank.
            for result in generator(*args, **kwargs):
                yield result
            return

        if world.rank == 0:
            try:
                for result in generator(*args, **kwargs):
                    broadcast((None, result))
                    yield result
            except Exception as ex:
                broadcast((ex, None))
                raise ex
            # (None, None) signals the end of iteration to the other ranks:
            broadcast((None, None))
        else:
            ex2, result = broadcast((None, None))
            if ex2 is not None:
                raise ex2
            while result is not None:
                yield result
                ex2, result = broadcast((None, None))
                if ex2 is not None:
                    raise ex2

    return new_generator
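
# Example (a sketch; analyze is a hypothetical per-item function).  Rank 0
# runs the generator and broadcasts each yield; the other ranks replay it:
#
#     >>> @parallel_generator
#     ... def scan(items):
#     ...     for item in items:
#     ...         yield analyze(item)
#     >>> for result in scan(range(5)):   # same sequence on every rank
#     ...     pass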

def register_parallel_cleanup_function():
    """Call MPI_Abort if Python crashes.

    This will terminate the processes on the other nodes."""
    if world.size == 1:
        return

    def cleanup(sys=sys, time=time, world=world):
        error = getattr(sys, 'last_type', None)
        if error:
            sys.stdout.flush()
            sys.stderr.write(('ASE CLEANUP (node %d): %s occurred. ' +
                              'Calling MPI_Abort!\n') % (world.rank, error))
            sys.stderr.flush()
            # Give other nodes a moment to crash by themselves (perhaps
            # producing helpful error messages):
            time.sleep(3)
            world.abort(42)

    atexit.register(cleanup)
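
# Typical use (a sketch): call once near the top of a parallel script so an
# uncaught exception on one rank aborts the whole job instead of leaving the
# other ranks hanging:
#
#     >>> from ase.parallel import register_parallel_cleanup_function
#     >>> register_parallel_cleanup_function()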

def distribute_cpus(size, comm):
    """Distribute cpus to tasks and calculators.

    Input:
    size: number of nodes per calculator
    comm: total communicator object

    Output:
    communicator for this rank, number of calculators, index for this rank
    """
    assert size <= comm.size
    assert comm.size % size == 0

    tasks_rank = comm.rank // size

    r0 = tasks_rank * size
    ranks = np.arange(r0, r0 + size)
    mycomm = comm.new_communicator(ranks)

    return mycomm, comm.size // size, tasks_rank
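
# Worked example (a sketch; requires a communicator implementing
# new_communicator, e.g. GPAW's): with comm.size == 4 and size == 2,
# ranks 0-1 form subgroup 0 and ranks 2-3 form subgroup 1:
#
#     >>> mycomm, ncalculators, index = distribute_cpus(2, comm)
#     >>> ncalculators                 # comm.size // size
#     2
#     >>> index                        # 0 on ranks 0-1, 1 on ranks 2-3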

def myslice(ntotal, comm):
    """Return the slice of your tasks for ntotal jobs."""
    n = -(-ntotal // comm.size)  # ceil divide
    return slice(n * comm.rank, n * (comm.rank + 1))
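
# Worked example: ntotal == 10 jobs on 4 ranks gives a chunk size of
# ceil(10 / 4) == 3, so rank 3 gets slice(9, 12), i.e. the last job only.
#
#     >>> class FakeComm:              # hypothetical stand-in for testing
#     ...     rank, size = 3, 4
#     >>> myslice(10, FakeComm())
#     slice(9, 12)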