Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import os 

2import atexit 

3import functools 

4import pickle 

5import sys 

6import time 

7 

8import numpy as np 

9 

10 

def get_txt(txt, rank):
    """Resolve *txt* into a writable file-like object for this rank.

    A file-like input (anything with a ``write`` method) is passed
    through unchanged on every rank.  Otherwise rank 0 maps ``None`` to
    the null device, ``'-'`` to ``sys.stdout`` and any other value to a
    line-buffered file opened for writing; all other ranks get the null
    device so only the master produces output.
    """
    if hasattr(txt, 'write'):
        # Note: User-supplied object might write to files from many ranks.
        return txt
    if rank != 0 or txt is None:
        return open(os.devnull, 'w')
    if txt == '-':
        return sys.stdout
    return open(txt, 'w', 1)

24 

25 

def paropen(name, mode='r', buffering=-1, encoding=None, comm=None):
    """MPI-safe version of open function.

    In read mode, the file is opened on all nodes.  In write and
    append mode, the file is opened on the master only, and /dev/null
    is opened on all other nodes.
    """
    if comm is None:
        comm = world
    # Anything other than pure reading would have every rank clobber the
    # same file, so non-master ranks are redirected to the null device.
    writing = mode[0] != 'r'
    if writing and comm.rank > 0:
        name = os.devnull
    return open(name, mode, buffering, encoding)

38 

39 

def parprint(*args, **kwargs):
    """MPI-safe print - prints only from master. """
    if world.rank != 0:
        return
    print(*args, **kwargs)

44 

45 

class DummyMPI:
    """Serial stand-in for an MPI communicator: one process, rank 0.

    Mimics the real interface: operations on scalars return the scalar,
    operations on arrays are no-ops done "in place" and return None.
    """

    rank = 0
    size = 1

    def _returnval(self, a, root=-1):
        # MPI interface works either on numbers, in which case a number
        # is returned, or on arrays, in-place.  With a single process
        # there is nothing to reduce, so only validate the input kind.
        if np.isscalar(a):
            return a
        arr = a.__array__() if hasattr(a, '__array__') else a
        assert isinstance(arr, np.ndarray)
        return None

    def sum(self, a, root=-1):
        return self._returnval(a)

    def product(self, a, root=-1):
        return self._returnval(a)

    def broadcast(self, a, root):
        assert root == 0
        return self._returnval(a)

    def barrier(self):
        pass

72 

73 

class MPI:
    """Wrapper for MPI world object.

    Decides at runtime (after all imports) which one to use:

    * MPI4Py
    * GPAW
    * a dummy implementation for serial runs

    """

    def __init__(self):
        # The real communicator is resolved lazily on first attribute
        # access, so that all MPI-providing modules have a chance to be
        # imported first.
        self.comm = None

    def __getattr__(self, name):
        # Pickling of objects that carry instances of MPI class
        # (e.g. NEB) raises RecursionError since it tries to access
        # the optional __setstate__ method (which we do not implement)
        # when unpickling.  Raising AttributeError here prevents the
        # RecursionError.  This also affects modules that use pickling,
        # e.g. multiprocessing.  For more details see:
        # https://gitlab.com/ase/ase/-/merge_requests/2695
        if name == '__setstate__':
            raise AttributeError(name)

        comm = self.comm
        if comm is None:
            comm = self.comm = _get_comm()
        return getattr(comm, name)

102 

103 

def _get_comm():
    """Get the correct MPI world object.

    Preference order: mpi4py, then GPAW's and Asap's C communicators;
    falls back to the serial dummy when no MPI module was imported.
    """
    if 'mpi4py' in sys.modules:
        return MPI4PY()
    for modname in ('_gpaw', '_asap'):
        if modname in sys.modules:
            mod = __import__(modname)
            # Old module versions may lack the Communicator class.
            if hasattr(mod, 'Communicator'):
                return mod.Communicator()
    return DummyMPI()

117 

118 

class MPI4PY:
    """Adapter translating ASE's communicator calls onto mpi4py."""

    def __init__(self, mpi4py_comm=None):
        # Import mpi4py lazily so serial runs never need it installed.
        if mpi4py_comm is None:
            from mpi4py import MPI
            mpi4py_comm = MPI.COMM_WORLD
        self.comm = mpi4py_comm

    @property
    def rank(self):
        return self.comm.rank

    @property
    def size(self):
        return self.comm.size

    def _returnval(self, a, b):
        """Behave correctly when working on scalars/arrays.

        mpi4py returns a fresh object *b*; ASE's convention is to
        return scalar results directly but to write array results
        in-place into *a* and return None."""
        if not np.isscalar(a):
            assert not np.isscalar(b)
            a[:] = b
            return None
        assert np.isscalar(b)
        return b

    def sum(self, a, root=-1):
        """Sum over all ranks (allreduce), or reduce onto *root*."""
        reduced = (self.comm.allreduce(a) if root == -1
                   else self.comm.reduce(a, root))
        return self._returnval(a, reduced)

    def split(self, split_size=None):
        """Divide the communicator into *split_size* subgroups."""
        if not split_size:
            split_size = self.size
        group_width = self.size / split_size
        color = int(self.rank // group_width)  # subgroup id
        key = int(self.rank % group_width)  # rank within the subgroup
        return MPI4PY(self.comm.Split(color, key))

    def barrier(self):
        self.comm.barrier()

    def abort(self, code):
        self.comm.Abort(code)

    def broadcast(self, a, root):
        b = self.comm.bcast(a, root=root)
        if self.rank != root:
            return self._returnval(a, b)
        # The root already holds the data: return scalars, leave
        # arrays untouched (None), matching _returnval's convention.
        return a if np.isscalar(a) else None

179 

180 

# The world communicator used throughout this module; filled in below
# once we know which MPI backend (if any) is present.
world = None

# Check for special MPI-enabled Python interpreters:
if '_gpaw' in sys.builtin_module_names:
    # _gpaw compiled into the interpreter itself.
    # http://wiki.fysik.dtu.dk/gpaw
    import _gpaw
    world = _gpaw.Communicator()
elif '_asap' in sys.builtin_module_names:
    # Modern version of Asap
    # http://wiki.fysik.dtu.dk/asap
    # We cannot import asap3.mpi here, as that creates an import deadlock
    import _asap
    world = _asap.Communicator()

# Check if MPI implementation has been imported already:
elif '_gpaw' in sys.modules:
    # Same thing as above but for the module version
    import _gpaw
    try:
        world = _gpaw.Communicator()
    except AttributeError:
        # Older _gpaw builds without a Communicator class: keep looking.
        pass
elif '_asap' in sys.modules:
    import _asap
    try:
        world = _asap.Communicator()
    except AttributeError:
        pass
elif 'mpi4py' in sys.modules:
    world = MPI4PY()

if world is None:
    # Nothing detected yet: use the lazy MPI wrapper, which defers the
    # backend decision until the first attribute access (so modules
    # imported after this one can still provide a real communicator).
    world = MPI()

214 

215 

def barrier():
    """Synchronize: block until all ranks of the world have arrived."""
    world.barrier()

218 

219 

def broadcast(obj, root=0, comm=None):
    """Broadcast a Python object across an MPI communicator and return it.

    The object is pickled on *root*; the byte count is broadcast first
    so receivers can size their buffers, then the pickled payload is
    broadcast and unpickled on every non-root rank.

    Parameters:

    obj: any picklable object (only significant on *root*)
    root: rank that supplies the object
    comm: communicator; defaults to ``world``.  The default is resolved
        at call time (instead of being captured at import time as a
        default-argument value) so a replaced or monkeypatched ``world``
        is honoured.
    """
    if comm is None:
        comm = world
    if comm.rank == root:
        string = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        n = np.array([len(string)], int)
    else:
        string = None
        n = np.empty(1, int)
    # First tell every rank how many bytes to expect:
    comm.broadcast(n, root)
    if comm.rank == root:
        # Wrap the pickle bytes as an int8 array so the communicator
        # can ship it like any other numeric buffer.
        string = np.frombuffer(string, np.int8)
    else:
        string = np.zeros(n, np.int8)
    comm.broadcast(string, root)
    if comm.rank == root:
        return obj
    else:
        return pickle.loads(string.tobytes())

238 

239 

def parallel_function(func):
    """Decorator for broadcasting from master to slaves using MPI.

    Disable by passing parallel=False to the function.  For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.
    """

    @functools.wraps(func)
    def new_func(*args, **kwargs):
        # Always consume the 'parallel' keyword so it can never leak
        # through to *func*.  (Previously it was popped only when the
        # earlier conditions were false, so e.g. calling with
        # parallel=True on a serial world forwarded the keyword and
        # raised TypeError for functions not accepting it.)
        parallel = kwargs.pop('parallel', True)
        if (world.size == 1 or
                args and getattr(args[0], 'serial', False) or
                not parallel):
            # Disabled: every rank calls the function directly.
            return func(*args, **kwargs)

        # Master evaluates the function (or captures its exception);
        # the outcome is broadcast so all ranks either return the same
        # result or raise the same exception.
        ex = None
        result = None
        if world.rank == 0:
            try:
                result = func(*args, **kwargs)
            except Exception as x:
                ex = x
        ex, result = broadcast((ex, result))
        if ex is not None:
            raise ex
        return result

    return new_func

269 

270 

def parallel_generator(generator):
    """Decorator for broadcasting yields from master to slaves using MPI.

    Disable by passing parallel=False to the function.  For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.
    """

    @functools.wraps(generator)
    def new_generator(*args, **kwargs):
        # NOTE(review): due to short-circuiting, 'parallel' is popped
        # only when the first two conditions are false, so on a serial
        # world the keyword would be forwarded to *generator* — confirm
        # this is intended.
        if (world.size == 1 or
                args and getattr(args[0], 'serial', False) or
                not kwargs.pop('parallel', True)):
            # Disable: every rank iterates the generator itself.
            for result in generator(*args, **kwargs):
                yield result
            return

        if world.rank == 0:
            # Master ships each item as (None, result); an exception is
            # shipped as (ex, None), and (None, None) tells the slaves
            # to stop.  NOTE(review): a generator that legitimately
            # yields None would be mistaken for the stop sentinel on
            # the slaves — verify callers never yield None.
            try:
                for result in generator(*args, **kwargs):
                    broadcast((None, result))
                    yield result
            except Exception as ex:
                broadcast((ex, None))
                raise ex
            broadcast((None, None))
        else:
            # Slaves receive (exception, item) pairs until the sentinel;
            # any exception from the master is re-raised locally.
            ex2, result = broadcast((None, None))
            if ex2 is not None:
                raise ex2
            while result is not None:
                yield result
                ex2, result = broadcast((None, None))
                if ex2 is not None:
                    raise ex2

    return new_generator

309 

310 

def register_parallel_cleanup_function():
    """Call MPI_Abort if python crashes.

    This will terminate the processes on the other nodes."""
    if world.size == 1:
        # Serial run: there is nothing to abort.
        return

    def cleanup(sys=sys, time=time, world=world):
        # The default arguments bind the modules now, so they remain
        # reachable while the interpreter is shutting down.
        error = getattr(sys, 'last_type', None)
        if not error:
            return
        sys.stdout.flush()
        sys.stderr.write('ASE CLEANUP (node %d): %s occurred. '
                         'Calling MPI_Abort!\n' % (world.rank, error))
        sys.stderr.flush()
        # Give other nodes a moment to crash by themselves (perhaps
        # producing helpful error messages):
        time.sleep(3)
        world.abort(42)

    atexit.register(cleanup)

332 

333 

def distribute_cpus(size, comm):
    """Distribute cpus to tasks and calculators.

    Input:
    size: number of nodes per calculator
    comm: total communicator object

    Output:
    communicator for this rank, number of calculators, index for this rank
    """
    assert size <= comm.size
    assert comm.size % size == 0

    # Consecutive blocks of *size* ranks form one calculator group;
    # this rank belongs to group number task_index.
    task_index = comm.rank // size
    first = task_index * size
    group = np.arange(first, first + size)
    return comm.new_communicator(group), comm.size // size, task_index

355 

356 

def myslice(ntotal, comm):
    """Return the slice of your tasks for ntotal jobs"""
    # Ceil-divide so the last rank gets the (possibly short) remainder.
    per_rank = -(-ntotal // comm.size)
    start = per_rank * comm.rank
    return slice(start, start + per_rank)