Coverage for /builds/ase/ase/ase/calculators/socketio.py : 91.88%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import os
2import socket
3from subprocess import Popen, PIPE
4from contextlib import contextmanager
6import numpy as np
8from ase.calculators.calculator import (Calculator, all_changes,
9 PropertyNotImplementedError)
10import ase.units as units
11from ase.utils import IOContext
12from ase.stress import full_3x3_to_voigt_6_stress
def actualunixsocketname(name):
    """Return the filesystem path used for the UNIX socket called *name*.

    The path is simply *name* prefixed with ``/tmp/ipi_``.
    """
    return f'/tmp/ipi_{name}'
class SocketClosed(OSError):
    """Signals that a read found the socket closed by the other side."""
class IPIProtocol:
    """Communication using IPI protocol.

    Wraps an already connected socket and provides the message-level
    operations used by both ends of the connection: fixed-width
    12-byte ASCII headers (``sendmsg``/``recvmsg``) and raw numpy
    buffers (``send``/``recv``).

    Parameters:

    socket: socket object
        Connected socket providing ``sendall()`` and ``recv()``.
    txt: file object or None
        If given, every protocol event is printed to this file,
        prefixed with ``Driver:``, and the file is flushed.
    """

    def __init__(self, socket, txt=None):
        self.socket = socket

        if txt is None:
            def log(*args):
                pass
        else:
            def log(*args):
                print('Driver:', *args, file=txt)
                txt.flush()
        self.log = log

    def sendmsg(self, msg):
        """Send header string *msg*, space-padded to 12 bytes."""
        self.log(' sendmsg', repr(msg))
        msg = msg.encode('ascii').ljust(12)
        self.socket.sendall(msg)

    def _recvall(self, nbytes):
        """Repeatedly read chunks until we have nbytes.

        Normally we get all bytes in one read, but that is not guaranteed.
        Raises SocketClosed if the socket yields no data before nbytes
        have been read."""
        remaining = nbytes
        chunks = []
        while remaining > 0:
            chunk = self.socket.recv(remaining)
            if len(chunk) == 0:
                # (If socket is still open, recv returns at least one byte)
                raise SocketClosed()
            chunks.append(chunk)
            remaining -= len(chunk)
        msg = b''.join(chunks)
        assert len(msg) == nbytes and remaining == 0
        return msg

    def recvmsg(self):
        """Receive one 12-byte header and return it as a stripped string.

        Raises SocketClosed (from ``_recvall``) if the connection was
        closed before a complete header arrived."""
        # _recvall() either returns exactly 12 bytes or raises, so no
        # separate check for an empty message is needed here.
        msg = self._recvall(12)
        assert len(msg) == 12, msg
        msg = msg.rstrip().decode('ascii')
        self.log(' recvmsg', repr(msg))
        return msg

    def send(self, a, dtype):
        """Send *a*, converted to an array of *dtype*, as raw bytes."""
        buf = np.asarray(a, dtype).tobytes()
        self.log(' send {} bytes of {}'.format(len(buf), dtype))
        self.socket.sendall(buf)

    def recv(self, shape, dtype):
        """Receive an array of the given *shape* and *dtype*.

        All received values must be finite."""
        a = np.empty(shape, dtype)
        # Plain Python int: itemsize times the number of elements.
        nbytes = a.nbytes
        buf = self._recvall(nbytes)
        assert len(buf) == nbytes, (len(buf), nbytes)
        self.log(' recv {} bytes of {}'.format(len(buf), dtype))
        a.flat[:] = np.frombuffer(buf, dtype=dtype)
        assert np.isfinite(a).all()
        return a

    def sendposdata(self, cell, icell, positions):
        """Send a POSDATA message with cell, inverse cell and positions.

        Lengths are divided by ``units.Bohr`` before sending (and the
        inverse cell multiplied by it); both cell matrices are sent
        transposed."""
        assert cell.size == 9
        assert icell.size == 9
        assert positions.size % 3 == 0

        self.log(' sendposdata')
        self.sendmsg('POSDATA')
        self.send(cell.T / units.Bohr, np.float64)
        self.send(icell.T * units.Bohr, np.float64)
        self.send(len(positions), np.int32)
        self.send(positions / units.Bohr, np.float64)

    def recvposdata(self):
        """Receive the payload of a POSDATA message.

        Returns the tuple (cell, icell, positions), applying the
        inverse of the scaling done by ``sendposdata``."""
        cell = self.recv((3, 3), np.float64).T.copy()
        icell = self.recv((3, 3), np.float64).T.copy()
        # Index explicitly: int() on a 1-element array is deprecated
        # in recent numpy versions.
        natoms = int(self.recv(1, np.int32)[0])
        positions = self.recv((natoms, 3), np.float64)
        return cell * units.Bohr, icell / units.Bohr, positions * units.Bohr

    def sendrecv_force(self):
        """Send GETFORCE and receive the FORCEREADY reply.

        Returns the tuple (energy, forces, virial, morebytes) where
        morebytes is an array of np.byte, or b'' if the peer announced
        no extra bytes."""
        self.log(' sendrecv_force')
        self.sendmsg('GETFORCE')
        msg = self.recvmsg()
        assert msg == 'FORCEREADY', msg
        e = self.recv(1, np.float64)[0]
        natoms = int(self.recv(1, np.int32)[0])
        assert natoms >= 0
        forces = self.recv((natoms, 3), np.float64)
        virial = self.recv((3, 3), np.float64).T.copy()
        nmorebytes = int(self.recv(1, np.int32)[0])
        if nmorebytes > 0:
            # Receiving 0 bytes will block forever on python2.
            morebytes = self.recv(nmorebytes, np.byte)
        else:
            morebytes = b''
        return (e * units.Ha, (units.Ha / units.Bohr) * forces,
                units.Ha * virial, morebytes)

    def sendforce(self, energy, forces, virial,
                  morebytes=np.zeros(1, dtype=np.byte)):
        """Send a FORCEREADY message with energy, forces and virial.

        *morebytes* is sent after the virial, preceded by its length.
        The default array is only read, never modified."""
        assert np.array([energy]).size == 1
        assert forces.shape[1] == 3
        assert virial.shape == (3, 3)

        self.log(' sendforce')
        self.sendmsg('FORCEREADY')  # mind the units
        self.send(np.array([energy / units.Ha]), np.float64)
        natoms = len(forces)
        self.send(np.array([natoms]), np.int32)
        self.send(units.Bohr / units.Ha * forces, np.float64)
        self.send(1.0 / units.Ha * virial.T, np.float64)
        # We prefer to always send at least one byte due to trouble with
        # empty messages.  Reading a closed socket yields 0 bytes
        # and thus can be confused with a 0-length bytestring.
        self.send(np.array([len(morebytes)]), np.int32)
        self.send(morebytes, np.byte)

    def status(self):
        """Send STATUS and return the peer's reply string."""
        self.log(' status')
        self.sendmsg('STATUS')
        msg = self.recvmsg()
        return msg

    def end(self):
        """Send the EXIT message."""
        self.log(' end')
        self.sendmsg('EXIT')

    def recvinit(self):
        """Receive the payload of an INIT message.

        Returns (bead_index, initbytes): a 1-element int32 array and
        an array of np.byte."""
        self.log(' recvinit')
        bead_index = self.recv(1, np.int32)
        # Use a plain int as the shape of the following read.
        nbytes = int(self.recv(1, np.int32)[0])
        initbytes = self.recv(nbytes, np.byte)
        return bead_index, initbytes

    def sendinit(self):
        """Send an INIT message with bead index 0 and a one-byte string."""
        # XXX Not sure what this function is supposed to send.
        # It 'works' with QE, but for now we try not to call it.
        self.log(' sendinit')
        self.sendmsg('INIT')
        self.send(0, np.int32)  # 'bead index' always zero for now
        # We send one byte, which is zero, since things may not work
        # with 0 bytes.  Apparently implementations ignore the
        # initialization string anyway.
        self.send(1, np.int32)
        self.send(np.zeros(1), np.byte)  # initialization string

    @staticmethod
    def _has_extra_bytes(morebytes):
        """Return True if *morebytes* contains at least one non-zero byte.

        A plain truth test on the array would raise ValueError as soon
        as it holds more than one element.  The single zero byte sent
        by default from ``sendforce`` counts as 'nothing'."""
        return len(morebytes) > 0 and bool(np.any(morebytes))

    def calculate(self, positions, cell):
        """Run one full exchange: send geometry, receive results.

        Returns a dict with keys 'energy', 'forces' and 'virial', plus
        'morebytes' if the peer sent any non-zero extra bytes."""
        self.log('calculate')
        msg = self.status()
        # We don't know how NEEDINIT is supposed to work, but some codes
        # seem to be okay if we skip it and send the positions instead.
        if msg == 'NEEDINIT':
            self.sendinit()
            msg = self.status()
        assert msg == 'READY', msg
        icell = np.linalg.pinv(cell).transpose()
        self.sendposdata(cell, icell, positions)
        msg = self.status()
        assert msg == 'HAVEDATA', msg
        e, forces, virial, morebytes = self.sendrecv_force()
        r = dict(energy=e,
                 forces=forces,
                 virial=virial)
        if self._has_extra_bytes(morebytes):
            r['morebytes'] = morebytes
        return r
@contextmanager
def bind_unixsocket(socketfile):
    """Context manager yielding a UNIX socket bound to *socketfile*.

    *socketfile* must start with ``/tmp/ipi_``.  On exit the socket is
    closed and the socket file is removed.

    Raises OSError, mentioning the socket file name, if binding fails.
    """
    assert socketfile.startswith('/tmp/ipi_'), socketfile
    serversocket = socket.socket(socket.AF_UNIX)
    try:
        serversocket.bind(socketfile)
    except OSError as err:
        # Nobody else holds a reference to the socket yet, so release
        # it here instead of leaking the file descriptor.
        serversocket.close()
        raise OSError('{}: {}'.format(err, repr(socketfile))) from err

    try:
        with serversocket:
            yield serversocket
    finally:
        os.unlink(socketfile)
@contextmanager
def bind_inetsocket(port):
    """Context manager yielding an INET socket bound to *port*.

    The socket is bound on all interfaces with SO_REUSEADDR set, and
    is closed on exit.  It is also closed if configuring or binding
    it fails, so no file descriptor is leaked in that case.
    """
    with socket.socket(socket.AF_INET) as serversocket:
        serversocket.setsockopt(socket.SOL_SOCKET,
                                socket.SO_REUSEADDR, 1)
        serversocket.bind(('', port))
        yield serversocket
class FileIOSocketClientLauncher:
    """Callable which writes a calculator's input and starts it as client.

    Calling the launcher writes the input files through the wrapped
    calculator and returns the Popen object of the started process.
    """

    def __init__(self, calc):
        self.calc = calc

    def __call__(self, atoms, properties=None, port=None, unixsocket=None):
        calc = self.calc
        assert calc is not None
        workdir = calc.directory
        profile = getattr(calc, 'profile', None)

        if profile is None:
            # Old FileIOCalculator: the command is a shell string in
            # which PREFIX and the {port}/{unixsocket} fields are
            # filled in.
            calc.write_input(atoms, properties=properties,
                             system_changes=all_changes)
            command = calc.command.replace('PREFIX', calc.prefix)
            command = command.format(port=port, unixsocket=unixsocket)
            return Popen(command, shell=True, cwd=workdir)

        # New GenericFileIOCalculator: the profile builds the argv.
        calc.write_inputfiles(atoms, properties)
        if unixsocket is None:
            argv = profile.socketio_argv_inet(port=port)
        else:
            argv = profile.socketio_argv_unix(socket=unixsocket)
        return Popen(argv, cwd=workdir, env=os.environ)
class SocketServer(IOContext):
    """Server side of the socket connection.

    Binds a listening socket on creation, accepts one client on the
    first call to calculate(), and talks to it through IPIProtocol.
    """
    default_port = 31415

    def __init__(self, port=None, unixsocket=None, timeout=None,
                 log=None):
        """Create server and listen for connections.

        The server does not start a client process itself.  Once
        calculate() is called, the server will block to wait for the
        client.

        Parameters:

        port: integer or None
            Port on which to listen for INET connections.  Defaults
            to 31415 if neither this nor unixsocket is specified.
        unixsocket: string or None
            Filename for unix socket.
        timeout: float or None
            timeout in seconds, or unlimited by default.
            This parameter is passed to the Python socket object; see
            documentation thereof
        log: file object or None
            useful debug messages are written to this.

        Raises ValueError if both port and unixsocket are given."""

        if unixsocket is None and port is None:
            port = self.default_port
        elif unixsocket is not None and port is not None:
            raise ValueError('Specify only one of unixsocket and port')

        self.port = port
        self.unixsocket = unixsocket
        self.timeout = timeout
        self._closed = False

        # Assign everything close() reads before binding, so that the
        # object is in a consistent state even if binding fails.
        self.log = log
        self.proc = None
        self.protocol = None
        self.clientsocket = None
        self.address = None

        if unixsocket is not None:
            actualsocket = actualunixsocketname(unixsocket)
            conn_name = 'UNIX-socket {}'.format(actualsocket)
            socket_context = bind_unixsocket(actualsocket)
        else:
            conn_name = 'INET port {}'.format(port)
            socket_context = bind_inetsocket(port)

        self.serversocket = self.closelater(socket_context)

        if log:
            print('Accepting clients on {}'.format(conn_name), file=log)

        self.serversocket.settimeout(timeout)
        self.serversocket.listen(1)

    def _accept(self):
        """Wait for client and establish connection.

        Raises OSError if a subprocess was registered in ``self.proc``
        and it terminates before connecting."""
        # It should perhaps be possible for process to be launched by user
        log = self.log
        if log:
            print('Awaiting client', file=self.log)

        # If we launched the subprocess, the process may crash.
        # We want to detect this, using loop with timeouts, and
        # raise an error rather than blocking forever.
        if self.proc is not None:
            self.serversocket.settimeout(1.0)

        while True:
            try:
                self.clientsocket, self.address = self.serversocket.accept()
                self.closelater(self.clientsocket)
            except socket.timeout:
                if self.proc is not None:
                    status = self.proc.poll()
                    if status is not None:
                        raise OSError('Subprocess terminated unexpectedly'
                                      ' with status {}'.format(status))
            else:
                break

        # Restore the user-specified timeout on both sockets.
        self.serversocket.settimeout(self.timeout)
        self.clientsocket.settimeout(self.timeout)

        if log:
            # For unix sockets the address is empty; accept both the
            # bytes and the str form so the log shows 'client' instead
            # of an empty string.
            source = ('client' if self.address in (b'', '')
                      else self.address)
            print('Accepted connection from {}'.format(source), file=log)

        self.protocol = IPIProtocol(self.clientsocket, txt=log)

    def close(self):
        """Close the sockets and wait for the subprocess, if any.

        Calling close() more than once is harmless.  A warning is
        issued if the subprocess exits with a non-zero status."""
        if self._closed:
            return

        super().close()

        if self.log:
            print('Close socket server', file=self.log)
        self._closed = True

        # Proper way to close sockets?
        # And indeed i-pi connections...
        # (No end-of-communication string is sent to the client.)
        self.protocol = None
        if self.proc is not None:
            exitcode = self.proc.wait()
            if exitcode != 0:
                import warnings
                # Quantum Espresso seems to always exit with status 128,
                # even if successful.
                # Should investigate at some point
                warnings.warn('Subprocess exited with status {}'
                              .format(exitcode))

    def calculate(self, atoms):
        """Send geometry to client and return calculated things as dict.

        This will block until client has established connection, then
        wait for the client to finish the calculation."""
        assert not self._closed

        # If we have not established connection yet, we must block
        # until the client catches up:
        if self.protocol is None:
            self._accept()
        return self.protocol.calculate(atoms.positions, atoms.cell)
class SocketClient:
    """Client side of the socket connection.

    Receives geometries from a server, evaluates energy, forces and
    (optionally) stress through the calculator attached to the atoms,
    and sends the results back.
    """

    def __init__(self, host='localhost', port=None,
                 unixsocket=None, timeout=None, log=None, comm=None):
        """Create client and connect to server.

        Parameters:

        host: string
            Hostname of server.  Defaults to localhost
        port: integer or None
            Port to which to connect.  By default 31415.
        unixsocket: string or None
            If specified, use corresponding UNIX socket.
            See documentation of unixsocket for SocketIOCalculator.
        timeout: float or None
            See documentation of timeout for SocketIOCalculator.
        log: file object or None
            Log events to this file
        comm: communicator or None
            MPI communicator object.  Defaults to ase.parallel.world.
            When ASE runs in parallel, only the process with world.rank == 0
            will communicate over the socket.  The received information
            will then be broadcast on the communicator.  The SocketClient
            must be created on all ranks of world, and will see the same
            Atoms objects."""
        if comm is None:
            from ase.parallel import world
            comm = world

        # Only rank0 actually does the socket work.
        # The other ranks only need to follow.
        #
        # Note: We actually refrain from assigning all the
        # socket-related things except on master
        self.comm = comm

        if self.comm.rank == 0:
            if unixsocket is not None:
                sock = socket.socket(socket.AF_UNIX)
                address = actualunixsocketname(unixsocket)
            else:
                if port is None:
                    port = SocketServer.default_port
                sock = socket.socket(socket.AF_INET)
                address = (host, port)

            try:
                sock.connect(address)
            except Exception:
                # Do not leak the descriptor if the server is not there.
                sock.close()
                raise

            # The timeout is applied only after the connection is made.
            sock.settimeout(timeout)
            self.host = host
            self.port = port
            self.unixsocket = unixsocket

            self.protocol = IPIProtocol(sock, txt=log)
            self.log = self.protocol.log
            self.closed = False

            self.bead_index = 0
            self.bead_initbytes = b''
            self.state = 'READY'

    def close(self):
        """Close the socket unless it has been closed already."""
        if not self.closed:
            self.log('Close SocketClient')
            self.closed = True
            self.protocol.socket.close()

    def calculate(self, atoms, use_stress):
        """Synchronize geometry over comm and evaluate the atoms.

        Returns (energy, forces, virial).  The virial is minus the
        volume times the 3x3 stress if use_stress is true, else zeros."""
        # We should also broadcast the bead index, once we support doing
        # multiple beads.
        self.comm.broadcast(atoms.positions, 0)
        self.comm.broadcast(np.ascontiguousarray(atoms.cell), 0)

        energy = atoms.get_potential_energy()
        forces = atoms.get_forces()
        if use_stress:
            stress = atoms.get_stress(voigt=False)
            virial = -atoms.get_volume() * stress
        else:
            virial = np.zeros((3, 3))
        return energy, forces, virial

    def irun(self, atoms, use_stress=None):
        """Return a generator which yields once per calculation.

        If use_stress is None, stress is used when any direction of
        the atoms is periodic."""
        if use_stress is None:
            use_stress = any(atoms.pbc)

        my_irun = self.irun_rank0 if self.comm.rank == 0 else self.irun_rankN
        return my_irun(atoms, use_stress)

    def irun_rankN(self, atoms, use_stress=True):
        """Generator for ranks other than 0: follow rank 0's decisions."""
        stop_criterion = np.zeros(1, bool)
        while True:
            # Rank 0 broadcasts True to stop and False to calculate.
            self.comm.broadcast(stop_criterion, 0)
            if stop_criterion[0]:
                return

            self.calculate(atoms, use_stress)
            yield

    def irun_rank0(self, atoms, use_stress=True):
        """Generator for rank 0: serve messages from the server.

        Yields after each calculation and closes the socket when the
        generator finishes, normally or through an exception.
        Raises KeyError on an unknown message."""
        # For every step we either calculate or quit.  We need to
        # tell other MPI processes (if this is MPI-parallel) whether they
        # should calculate or quit.
        try:
            while True:
                try:
                    msg = self.protocol.recvmsg()
                except SocketClosed:
                    # Server closed the connection, but we want to
                    # exit gracefully anyway
                    msg = 'EXIT'

                if msg == 'EXIT':
                    # Send stop signal to clients:
                    self.comm.broadcast(np.ones(1, bool), 0)
                    # (When otherwise exiting, things crashed and we should
                    # let MPI_ABORT take care of the mess instead of trying
                    # to synchronize the exit)
                    return
                elif msg == 'STATUS':
                    self.protocol.sendmsg(self.state)
                elif msg == 'POSDATA':
                    assert self.state == 'READY'
                    cell, icell, positions = self.protocol.recvposdata()
                    atoms.cell[:] = cell
                    atoms.positions[:] = positions

                    # User may wish to do something with the atoms object now.
                    # Should we provide option to yield here?
                    #
                    # (In that case we should MPI-synchronize *before*
                    #  whereas now we do it after.)

                    # Send signal for other ranks to proceed with calculation:
                    self.comm.broadcast(np.zeros(1, bool), 0)
                    energy, forces, virial = self.calculate(atoms, use_stress)

                    self.state = 'HAVEDATA'
                    yield
                elif msg == 'GETFORCE':
                    # The results were stored by the preceding POSDATA step.
                    assert self.state == 'HAVEDATA', self.state
                    self.protocol.sendforce(energy, forces, virial)
                    self.state = 'NEEDINIT'
                elif msg == 'INIT':
                    assert self.state == 'NEEDINIT'
                    bead_index, initbytes = self.protocol.recvinit()
                    self.bead_index = bead_index
                    self.bead_initbytes = initbytes
                    self.state = 'READY'
                else:
                    raise KeyError('Bad message', msg)
        finally:
            self.close()

    def run(self, atoms, use_stress=False):
        """Serve calculations until the server ends the session."""
        for _ in self.irun(atoms, use_stress=use_stress):
            pass
class SocketIOCalculator(Calculator, IOContext):
    """Calculator which obtains its results from a client over a socket."""

    implemented_properties = ['energy', 'free_energy', 'forces', 'stress']
    supported_changes = {'positions', 'cell'}

    def __init__(self, calc=None, port=None,
                 unixsocket=None, timeout=None, log=None, *,
                 launch_client=None):
        """Initialize socket I/O calculator.

        This calculator launches a server which passes atomic
        coordinates and unit cells to an external code via a socket,
        and receives energy, forces, and stress in return.

        ASE integrates this with the Quantum Espresso, FHI-aims and
        Siesta calculators.  This works with any external code that
        supports running as a client over the i-PI protocol.

        Parameters:

        calc: calculator or None

            If calc is not None, a client process will be launched
            using calc.command, and the input file will be generated
            using ``calc.write_input()``.  Otherwise only the server will
            run, and it is up to the user to launch a compliant client
            process.

        port: integer

            port number for socket.  Should normally be between 1025
            and 65535.  Typical ports are 31415 (default) or 3141.

        unixsocket: str or None

            if not None, ignore host and port, creating instead a
            unix socket using this name prefixed with ``/tmp/ipi_``.
            The socket is deleted when the calculator is closed.

        timeout: float >= 0 or None

            timeout for connection, by default infinite.  See
            documentation of Python sockets.  For longer jobs it is
            recommended to set a timeout in case of undetected
            client-side failure.

        log: file object or None (default)

            logfile for communication over socket.  For debugging or
            the curious.

        launch_client: callable or None (keyword-only)

            called as ``launch_client(atoms, properties, port=...,
            unixsocket=...)`` on the first calculation; its return
            value is stored as the server's subprocess.  Cannot be
            combined with calc.

        In order to correctly close the sockets, it is
        recommended to use this class within a with-block:

        >>> with SocketIOCalculator(...) as calc:
        ...    atoms.calc = calc
        ...    atoms.get_forces()
        ...    atoms.rattle()
        ...    atoms.get_forces()

        It is also possible to call calc.close() after
        use.  This is best done in a finally-block."""

        Calculator.__init__(self)

        if calc is not None and launch_client is not None:
            raise ValueError('Cannot pass both calc and launch_client')
        if calc is not None:
            launch_client = FileIOSocketClientLauncher(calc)

        self.launch_client = launch_client
        self.timeout = timeout
        self.server = None

        self.log = self.openfile(log)

        # Kept only to be handed to the server; either may be None.
        self._port = port
        self._unixsocket = unixsocket

        # With a client launcher, the server is started in calculate(),
        # since we must start the external process before blocking.
        # Without one, start listening right away:
        if self.launch_client is None:
            self.server = self.launch_server()

    def todict(self):
        """Return a minimal dictionary description of this calculator."""
        return {'type': 'calculator',
                'name': 'socket-driver'}

    def launch_server(self):
        """Create the SocketServer and register it for later closing."""
        server = SocketServer(port=self._port,
                              unixsocket=self._unixsocket,
                              timeout=self.timeout,
                              log=self.log)
        return self.closelater(server)

    def calculate(self, atoms=None, properties=['energy'],
                  system_changes=all_changes):
        """Send *atoms* to the client and store the returned results.

        Raises PropertyNotImplementedError if, after the first call,
        anything other than positions or cell has changed."""
        unsupported = [change for change in system_changes
                       if change not in self.supported_changes]

        # On the very first call system_changes is all_changes; later
        # on only positions and cell are allowed to differ.
        if self.atoms is not None and any(unsupported):
            culprit = (unsupported if len(unsupported) > 1
                       else unsupported[0])
            raise PropertyNotImplementedError(
                'Cannot change {} through IPI protocol.  '
                'Please create new socket calculator.'
                .format(culprit))

        self.atoms = atoms.copy()

        if self.server is None:
            self.server = self.launch_server()
            # XXX nasty hack: hand the process over to the server
            self.server.proc = self.launch_client(
                atoms, properties,
                port=self._port,
                unixsocket=self._unixsocket)

        results = self.server.calculate(atoms)
        results['free_energy'] = results['energy']
        virial = results.pop('virial')
        has_stress = self.atoms.cell.rank == 3 and any(self.atoms.pbc)
        if has_stress:
            volume = atoms.get_volume()
            results['stress'] = -full_3x3_to_voigt_6_stress(virial) / volume
        self.results.update(results)

    def close(self):
        """Forget the server and close everything registered for closing."""
        self.server = None
        super().close()
class PySocketIOClient:
    """Launcher which runs a Python calculator as client in a subprocess.

    The connection details, a copy of the atoms and the calculator
    factory are pickled and piped to ``python -m ase.calculators.socketio``,
    whose ``main()`` unpickles them and runs a SocketClient.
    """

    def __init__(self, calculator_factory):
        self._calculator_factory = calculator_factory

    def __call__(self, atoms, properties=None, port=None, unixsocket=None):
        import pickle
        import sys

        # Serialize up front, so that nothing can go wrong with the
        # payload once the process has been started.
        socketinfo = dict(unixsocket=unixsocket, port=port)
        payload = pickle.dumps([
            socketinfo,
            atoms.copy(),
            self._calculator_factory,
        ])

        command = [sys.executable, '-m', 'ase.calculators.socketio']
        proc = Popen(command, stdin=PIPE)

        stream = proc.stdin
        stream.write(payload)
        stream.close()
        return proc

    @staticmethod
    def main():
        import pickle
        import sys

        # NOTE: unpickling executes arbitrary code; stdin is expected
        # to be the pipe written by __call__ in the parent process.
        socketinfo, atoms, get_calculator = pickle.load(sys.stdin.buffer)
        atoms.calc = get_calculator()
        client = SocketClient(host='localhost',
                              unixsocket=socketinfo.get('unixsocket'),
                              port=socketinfo.get('port'))
        # XXX In principle the stress need not be calculated before
        # somebody asks for it, which would make the use_stress
        # boolean unnecessary.
        client.run(atoms, use_stress=True)
# Script entry point: PySocketIOClient.__call__ starts this module with
# ``python -m ase.calculators.socketio`` and pipes the pickled job to
# stdin, which main() then reads and runs.
if __name__ == '__main__':
    PySocketIOClient.main()