Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import os 

2import sys 

3from abc import ABC, abstractmethod 

4import pickle 

5from subprocess import Popen, PIPE 

6from ase.calculators.calculator import Calculator, all_properties 

7 

8 

class PackedCalculator(ABC):
    """Portable description of a calculator living in another process.

    A packed calculator is handed to a PythonSubProcessCalculator, which
    creates and talks to the real calculator inside a different process,
    possibly started through MPI or srun.

    This is meant for workflows where ASE runs mostly in serial while
    selected calculations run in a parallel Python environment.

    Most existing calculators can be used through the
    NamedPackedCalculator implementation.  For calculators that need
    different behaviour, write a custom subclass of this class.

    Example::

        from ase.build import bulk

        atoms = bulk('Au')
        pack = NamedPackedCalculator('emt')

        with pack.calculator() as atoms.calc:
            energy = atoms.get_potential_energy()

    The computation happens inside a subprocess whose lifetime is that
    of the with statement.
    """

    @abstractmethod
    def unpack_calculator(self) -> Calculator:
        """Build and return the calculator described by this object.

        This method is invoked inside the subprocess that performs the
        computations."""

    def calculator(self, mpi_command=None) -> 'PythonSubProcessCalculator':
        """Create a PythonSubProcessCalculator wrapping this object.

        The returned calculator wraps a subprocess holding the actual
        calculator; all computations are carried out in that subprocess.
        ``mpi_command`` is forwarded unchanged to the wrapper."""
        wrapper = PythonSubProcessCalculator(self, mpi_command=mpi_command)
        return wrapper

50 

51 

class NamedPackedCalculator(PackedCalculator):
    """PackedCalculator for calculators that can be looked up by name.

    The name is resolved through
    ase.calculators.calculator.get_calculator_class."""

    def __init__(self, name, kwargs=None):
        """Remember the calculator name and its constructor keywords.

        ``kwargs`` defaults to an empty dict when omitted."""
        self._name = name
        self._kwargs = {} if kwargs is None else kwargs

    def unpack_calculator(self):
        """Look up the calculator class by name and instantiate it."""
        # Imported lazily, at call time, exactly as the lookup requires.
        from ase.calculators.calculator import get_calculator_class
        return get_calculator_class(self._name)(**self._kwargs)

    def __repr__(self):
        return f'{self.__class__.__name__}({self._name}, {self._kwargs})'

70 

71 

class MPICommand:
    """Command line used to launch the subprocess hosting a calculator.

    ``argv`` is the full argument list handed to ``Popen`` by
    :meth:`execute`.
    """

    def __init__(self, argv):
        self.argv = argv

    @classmethod
    def python_argv(cls):
        """Return the argv prefix that runs this module with the current
        Python interpreter (``sys.executable -m ...``)."""
        return [sys.executable, '-m', 'ase.calculators.subprocesscalculator']

    @classmethod
    def parallel(cls, nprocs, mpi_argv=tuple()):
        """Return a command that runs ``nprocs`` processes under mpiexec.

        ``mpi_argv`` holds extra arguments inserted right after
        ``mpiexec -n <nprocs>``.  The subprocess is started in the
        'mpi4py' run mode."""
        return cls(['mpiexec', '-n', str(nprocs)]
                   + list(mpi_argv)
                   + cls.python_argv()
                   + ['mpi4py'])

    @classmethod
    def serial(cls):
        """Return a command that runs a single process in the
        'standard' run mode."""
        # Use cls (not the hard-coded base class) so that subclasses get
        # an instance of their own type, consistent with parallel().
        return cls(cls.python_argv() + ['standard'])

    def execute(self):
        """Start the subprocess with piped stdin/stdout and return the
        ``Popen`` object."""
        # On this computer (Ubuntu 20.04 + OpenMPI) the subprocess crashes
        # without output during startup if os.environ is not passed along.
        # Hence we pass os.environ.  Not sure if this is a machine thing
        # or in general.  --askhl
        return Popen(self.argv, stdout=PIPE,
                     stdin=PIPE, env=os.environ)

98 

99 

def gpaw_process(ncores=1, **kwargs):
    """Return a PythonSubProcessCalculator running GPAW in a subprocess.

    ``kwargs`` are packed as the keyword arguments for the 'gpaw'
    calculator; ``ncores`` is passed to ``gpaw -P`` on the command line.
    """
    gpaw_pack = NamedPackedCalculator('gpaw', kwargs)
    # Launch line: <python> -m gpaw -P <ncores> python -m <this module> standard
    argv = [sys.executable, '-m', 'gpaw', '-P', str(ncores)]
    argv += ['python', '-m', 'ase.calculators.subprocesscalculator']
    argv.append('standard')
    return PythonSubProcessCalculator(gpaw_pack, MPICommand(argv))

107 

108 

class PythonSubProcessCalculator(Calculator):
    """Calculator for running calculations in external processes.

    TODO: This should work with arbitrary commands including MPI stuff.

    This calculator runs a subprocess wherein it sets up an
    actual calculator.  Calculations are forwarded through pickle
    to that calculator, which returns results through pickle.

    The subprocess is started on ``__enter__`` and stopped on
    ``__exit__``, so the calculator must be used as a context manager.
    """
    implemented_properties = list(all_properties)

    def __init__(self, calc_input, mpi_command=None):
        """Store what to run; the subprocess is not started here.

        ``calc_input`` is the object pickled to the subprocess right
        after startup.  ``mpi_command`` provides ``execute()`` to launch
        the subprocess and defaults to ``MPICommand.serial()``.
        """
        super().__init__()

        # Handle of the running subprocess; None while not running.
        self.proc = None
        self.calc_input = calc_input
        if mpi_command is None:
            mpi_command = MPICommand.serial()
        self.mpi_command = mpi_command

    def set(self, **kwargs):
        """Refuse parameter changes once ``proc`` exists as an attribute."""
        # 'proc' is only assigned after super().__init__() returns, so
        # this presumably lets a set() call made by the base-class
        # constructor pass silently -- TODO confirm against Calculator.
        if hasattr(self, 'proc'):
            raise RuntimeError('No setting things for now, thanks')

    def _send(self, obj):
        """Pickle ``obj`` to the subprocess's stdin and flush."""
        pickle.dump(obj, self.proc.stdin)
        self.proc.stdin.flush()

    def _recv(self):
        """Read one pickled ``(response_type, value)`` pair from the
        subprocess's stdout.

        Returns ``value`` for a 'return' response and raises ``value``
        for a 'raise' response.
        """
        response_type, value = pickle.load(self.proc.stdout)

        if response_type == 'raise':
            raise value

        assert response_type == 'return'
        return value

    def __repr__(self):
        return '{}({})'.format(type(self).__name__,
                               self.calc_input)

    def __enter__(self):
        """Start the subprocess and send it ``calc_input``."""
        assert self.proc is None
        self.proc = self.mpi_command.execute()
        self._send(self.calc_input)
        return self

    def __exit__(self, *args):
        """Tell the subprocess to stop and wait for it to finish."""
        try:
            self._send('stop')
        finally:
            # Always reap the subprocess and forget the handle, even if
            # the stop message could not be delivered (e.g. the child
            # already died and the pipe is broken).  Otherwise the
            # process would be leaked and a later __enter__ would fail
            # its "proc is None" check.
            proc, self.proc = self.proc, None
            proc.communicate()

    def _run_calculation(self, atoms, properties, system_changes):
        """Send a 'calculate' instruction followed by its arguments."""
        self._send('calculate')
        self._send((atoms, properties, system_changes))

    def calculate(self, atoms, properties, system_changes):
        """Run the calculation in the subprocess and merge the returned
        results into ``self.results``."""
        Calculator.calculate(self, atoms, properties, system_changes)
        # We send a pickle of self.atoms because this is a fresh copy
        # of the input, but without an unpicklable calculator:
        self._run_calculation(self.atoms.copy(), properties, system_changes)
        results = self._recv()
        self.results.update(results)

    def backend(self):
        """Return a ParallelBackendInterface bound to this calculator."""
        return ParallelBackendInterface(self)

174 

175 

class MockMethod:
    """Callable stand-in for a named method, executed via an interface.

    Calling the object sends a 'callmethod' instruction plus the method
    name and arguments through ``interface._send`` and returns whatever
    ``interface._recv`` yields.
    """

    def __init__(self, name, interface):
        self.name, self.interface = name, interface

    def __call__(self, *args, **kwargs):
        # Two messages: the instruction keyword, then its payload.
        self.interface._send('callmethod')
        self.interface._send([self.name, args, kwargs])
        return self.interface._recv()

186 

187 

class ParallelBackendInterface:
    """Proxy whose missing attributes resolve to MockMethod objects.

    Any attribute not found on the instance is returned as a MockMethod
    carrying that attribute name and the wrapped interface.
    """

    def __init__(self, interface):
        self.interface = interface

    def __getattr__(self, name):
        # Only reached for attributes that normal lookup did not find.
        proxy = MockMethod(name, self.interface)
        return proxy

194 

195 

# Run modes accepted as sys.argv[1] when this module is executed as a script.
run_modes = {'standard', 'mpi4py'}

197 

198 

def callmethod(calc, attrname, args, kwargs):
    """Call ``calc.<attrname>(*args, **kwargs)`` and return the result."""
    return getattr(calc, attrname)(*args, **kwargs)

203 

204 

def calculate(calc, atoms, properties, system_changes):
    """Run ``calc.calculate`` from a clean slate and return ``calc.results``."""
    # Results/outputs still lack a formal description and there is no
    # programmatic way to list all available properties, so for now we
    # resort to a wild hack and wipe the previous results first.
    #
    # Without the clear() the caching is broken!  For stress, that is,
    # but not for forces.  What dark magic from the depths of the
    # underworld is at play here?
    calc.results.clear()
    calc.calculate(
        atoms=atoms,
        properties=properties,
        system_changes=system_changes,
    )
    return calc.results

217 

218 

def bad_mode():
    """Build (but do not raise) the SystemExit for a bad run mode."""
    message = f'sys.argv[1] must be one of {run_modes}'
    return SystemExit(message)

221 

222 

def main():
    """Serve calculator requests inside the subprocess.

    The run mode is taken from sys.argv[1].  A packed calculator is read
    from stdin first; after that the loop reads pickled instructions
    ('calculate', 'callmethod' or 'stop') and writes pickled
    ``(response_type, value)`` tuples back on the original stdout.
    """
    # A missing or unknown run mode aborts with the same SystemExit.
    if len(sys.argv) < 2 or sys.argv[1] not in run_modes:
        raise bad_mode()

    if sys.argv[1] == 'mpi4py':
        # mpi4py has to be imported before the rest of ASE, otherwise
        # world will not be correctly initialized.
        import mpi4py  # noqa

    # Keep the real binary stdout for replies and point sys.stdout at
    # stderr so stray print statements won't interfere with outputs:
    reply_stream = sys.stdout.buffer
    sys.stdout = sys.stderr

    from ase.parallel import world, broadcast

    def recv():
        # Only rank 0 reads from stdin; the object is then broadcast
        # from rank 0 over world.
        message = pickle.load(sys.stdin.buffer) if world.rank == 0 else None
        return broadcast(message, 0, world)

    def send(obj):
        # Only rank 0 writes the reply.
        if world.rank != 0:
            return
        pickle.dump(obj, reply_stream)
        reply_stream.flush()

    calc = recv().unpack_calculator()

    while True:
        instruction = recv()
        if instruction == 'stop':
            return

        if instruction == 'callmethod':
            function = callmethod
        elif instruction == 'calculate':
            function = calculate
        else:
            raise RuntimeError(f'Bad instruction: {instruction}')

        instruction_data = recv()

        # Exceptions from the calculator are shipped back to the parent
        # rather than terminating this process.
        try:
            response = ('return', function(calc, *instruction_data))
        except Exception as ex:
            response = ('raise', ex)

        send(response)

283 

284 

# Run the subprocess main loop only when executed as a script.
if __name__ == '__main__':
    main()