Coverage for /builds/ase/ase/ase/db/core.py : 85.68%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import functools
2import json
3import numbers
4import operator
5import os
6import re
7import warnings
8from time import time
9from typing import List, Dict, Any
11import numpy as np
13from ase.atoms import Atoms
14from ase.calculators.calculator import all_properties, all_changes
15from ase.data import atomic_numbers
16from ase.db.row import AtomsRow
17from ase.formula import Formula
18from ase.io.jsonio import create_ase_object
19from ase.parallel import world, DummyMPI, parallel_function, parallel_generator
20from ase.utils import Lock, PurePath
# Reference epoch for now(), which returns (time() - T2000) / YEAR.
# Note: 946681200.0 is 1999-12-31 23:00:00 UTC, i.e. midnight of
# January 1, 2000 in UTC+1 (not UTC midnight, which would be 946684800).
T2000 = 946681200.0  # January 1. 2000
# Length of a (Julian) year in seconds; the time unit used by now().
YEAR = 31557600.0  # 365.25 days
@functools.total_ordering
class KeyDescription:
    """Human-readable description of a database key (column).

    key: str
        The key itself.
    shortdesc: str or None
        Short description; defaults to *key*.
    longdesc: str or None
        Long description; defaults to *shortdesc*.
    unit: str
        Unit string.  Backtick-quoted sub/superscripts such as ``m_e``
        or ``Å^{3}`` (each wrapped in backticks) are converted to HTML
        ``<sub>``/``<sup>`` tags.

    Instances compare, sort and hash by ``shortdesc``.
    """

    # `a_b` -> a<sub>b</sub> (exactly one character on each side).
    _subscript = re.compile(r'`(.)_(.)`')
    # `a^b` or `a^{b}` -> a<sup>b</sup>.
    _superscript = re.compile(r'`(.*)\^\{?(.*?)\}?`')

    def __init__(self, key, shortdesc=None, longdesc=None, unit=''):
        self.key = key

        if shortdesc is None:
            shortdesc = key

        if longdesc is None:
            longdesc = shortdesc

        self.shortdesc = shortdesc
        self.longdesc = longdesc

        # Somewhat arbitrary that we do this conversion. Can we avoid that?
        # Previously done in create_key_descriptions().
        unit = self._subscript.sub(r'\1<sub>\2</sub>', unit)
        unit = self._superscript.sub(r'\1<sup>\2</sup>', unit)
        unit = unit.replace(r'\text{', '').replace('}', '')

        self.unit = unit

    def __repr__(self):
        cls = type(self).__name__
        return (f'{cls}({self.key!r}, {self.shortdesc!r}, {self.longdesc!r}, '
                f'unit={self.unit!r})')

    # The templates like to sort key descriptions by shortdesc.
    def __eq__(self, other):
        return self.shortdesc == getattr(other, 'shortdesc', None)

    def __lt__(self, other):
        return self.shortdesc < getattr(other, 'shortdesc', self.shortdesc)

    # Defining __eq__ implicitly sets __hash__ to None, which made instances
    # unusable in sets/dict keys.  Hash on the same attribute __eq__ uses so
    # that equal descriptions always hash equally.
    def __hash__(self):
        return hash(self.shortdesc)
def get_key_descriptions():
    """Return the descriptions of the built-in keys as a dict.

    The result maps each key name (e.g. ``'energy'``) to its
    ``KeyDescription``.  A fresh dict is built on every call.
    """
    KD = KeyDescription
    return {keydesc.key: keydesc for keydesc in [
        KD('id', 'ID', 'Unique row ID'),
        KD('age', 'Age', 'Time since creation'),
        KD('formula', 'Formula', 'Chemical formula'),
        KD('pbc', 'PBC', 'Periodic boundary conditions'),
        KD('user', 'Username'),
        KD('calculator', 'Calculator', 'ASE-calculator name'),
        KD('energy', 'Energy', 'Total energy', unit='eV'),
        KD('natoms', 'Number of atoms'),
        KD('fmax', 'Maximum force', unit='eV/Å'),
        KD('smax', 'Maximum stress', 'Maximum stress on unit cell',
           unit='eV/ų'),
        KD('charge', 'Charge', 'Net charge in unit cell', unit='|e|'),
        KD('mass', 'Mass', 'Sum of atomic masses in unit cell', unit='au'),
        KD('magmom', 'Magnetic moment', unit='μ_B'),
        KD('unique_id', 'Unique ID', 'Random (unique) ID'),
        KD('volume', 'Volume', 'Volume of unit cell', unit='ų')
    ]}
def now():
    """Current time expressed in years elapsed since the T2000 epoch."""
    elapsed_seconds = time() - T2000
    return elapsed_seconds / YEAR
# Number of seconds in each time unit.  The single-letter keys are the
# unit suffixes understood by time_string_to_float() and emitted by
# float_to_time_string(); 'M' (month) is YEAR / 12.
seconds = {'s': 1,
           'm': 60,
           'h': 3600,
           'd': 86400,
           'w': 604800,
           'M': 2629800,
           'y': YEAR}

# Singular English name of each unit in ``seconds``;
# float_to_time_string(long=True) appends an 's' to it.
longwords = {'s': 'second',
             'm': 'minute',
             'h': 'hour',
             'd': 'day',
             'w': 'week',
             'M': 'month',
             'y': 'year'}

# Comparison operators allowed in selections, mapped to the functions
# implementing them.
ops = {'<': operator.lt,
       '<=': operator.le,
       '=': operator.eq,
       '>=': operator.ge,
       '>': operator.gt,
       '!=': operator.ne}

# Complement of each operator ('<' -> '>=', '=' -> '!=', ...).
# parse_selection() applies it when it rewrites an 'age' constraint as a
# 'ctime' constraint (age grows as ctime shrinks).  NOTE(review): the exact
# mirror of '<' would be '>', so the two differ only at equality -- confirm
# this is intended.
invop = {'<': '>=', '<=': '>', '>=': '<', '>': '<=', '=': '!=', '!=': '='}

# Allowed key names: an identifier-like word.  NOTE(review): check() uses
# word.match(), and '$' also matches just before a trailing newline.
word = re.compile('[_a-zA-Z][_0-9a-zA-Z]*$')

# Names that may not be used as user keys (enforced by check()): everything
# in ``all_properties`` and ``all_changes``, all chemical symbols, and the
# built-in row columns listed explicitly.
reserved_keys = set(all_properties +
                    all_changes +
                    list(atomic_numbers) +
                    ['id', 'unique_id', 'ctime', 'mtime', 'user',
                     'fmax', 'smax',
                     'momenta', 'constraints', 'natoms', 'formula', 'age',
                     'calculator', 'calculator_parameters',
                     'key_value_pairs', 'data'])

# Keys whose selection values must be numbers (enforced by
# parse_selection()).
numeric_keys = set(['id', 'energy', 'magmom', 'charge', 'natoms'])
def check(key_value_pairs):
    """Validate user key-value pairs before they are written.

    key_value_pairs: dict
        Mapping from key to value.  The special key ``external_tables`` is
        skipped.

    Raises ValueError if a key is not an identifier-like word or is
    reserved, if a value is not a real number, string or numpy bool, or if
    a string value could be parsed as an int or a float.  Emits a warning
    when a key is also a valid chemical formula, because such a key would be
    interpreted as a formula by ``db.select()``.
    """
    for key, value in key_value_pairs.items():
        if key == "external_tables":
            # Checks for external_tables are not performed
            continue

        if not word.match(key) or key in reserved_keys:
            raise ValueError('Bad key: {}'.format(key))
        try:
            Formula(key, strict=True)
        except ValueError:
            pass
        else:
            warnings.warn(
                'It is best not to use keys ({0}) that are also a '
                'chemical formula. If you do a "db.select({0!r})", '
                'you will not find rows with your key. Instead, you will get '
                'rows containing the atoms in the formula!'.format(key))
        if not isinstance(value, (numbers.Real, str, np.bool_)):
            raise ValueError('Bad value for {!r}: {}'.format(key, value))
        if isinstance(value, str):
            # Numbers must be stored as numbers so that range queries work.
            for t in [int, float]:
                if str_represents(value, t):
                    tname = t.__name__
                    raise ValueError(
                        f'Value {value} is put in as string but can be '
                        f'interpreted as {tname}! Please convert to {tname} '
                        f'using {tname}(value) before writing to the '
                        'database OR change to a different string.')
def str_represents(value, t=int):
    """Tell whether ``t(value)`` succeeds.

    Only ValueError counts as "does not represent"; any other exception
    raised by the conversion propagates to the caller.
    """
    converter = t
    representable = True
    try:
        converter(value)
    except ValueError:
        representable = False
    return representable
def connect(name, type='extract_from_name', create_indices=True,
            use_lock_file=True, append=True, serial=False):
    """Create connection to database.

    name: str, PurePath, None or other object
        Filename or address of database.  ``None`` returns a bare
        ``Database()``; a non-string object is treated as a JSON database.
    type: str
        One of 'json', 'db', 'postgresql' or 'mysql' (JSON, SQLite,
        PostgreSQL, MySQL/MariaDB).  The default, 'extract_from_name',
        guesses the type from the URL prefix or the file extension.
    create_indices: bool
        Passed on to the SQLite backend.
    use_lock_file: bool
        Passed on to the JSON and SQLite backends.  You can turn this off
        if you know what you are doing ...
    append: bool
        Use append=False to start a new database (an existing file is
        removed by rank 0).
    serial: bool
        Passed on to the JSON and SQLite backends.
    """
    if isinstance(name, PurePath):
        name = str(name)

    if type == 'extract_from_name':
        if name is None:
            type = None
        elif not isinstance(name, str):
            type = 'json'
        elif name.startswith(('postgresql://', 'postgres://')):
            type = 'postgresql'
        elif name.startswith(('mysql://', 'mariadb://')):
            type = 'mysql'
        else:
            type = os.path.splitext(name)[1][1:]
            if not type:
                raise ValueError('No file extension or database type given')

    if type is None:
        return Database()

    is_path = isinstance(name, str)

    if not append and world.rank == 0 and is_path and os.path.isfile(name):
        os.remove(name)

    # Server URLs are left alone; file names are made absolute.
    if is_path and type not in ['postgresql', 'mysql']:
        name = os.path.abspath(name)

    if type == 'json':
        from ase.db.jsondb import JSONDatabase
        return JSONDatabase(name, use_lock_file=use_lock_file, serial=serial)
    elif type == 'db':
        from ase.db.sqlite import SQLite3Database
        return SQLite3Database(name, create_indices, use_lock_file,
                               serial=serial)
    elif type == 'postgresql':
        from ase.db.postgresql import PostgreSQLDatabase
        return PostgreSQLDatabase(name)
    elif type == 'mysql':
        from ase.db.mysql import MySQLDatabase
        return MySQLDatabase(name)

    raise ValueError('Unknown database type: ' + type)
def lock(method):
    """Decorator for using a lock-file.

    The wrapped method runs inside ``with self.lock:``.  When the instance's
    ``lock`` attribute is None, the method is called without locking.
    """
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if self.lock is not None:
            with self.lock:
                return method(self, *args, **kwargs)
        return method(self, *args, **kwargs)
    return wrapper
def convert_str_to_int_float_or_str(value):
    """Safe eval(): turn a string into an int, float or bool when possible.

    Tries ``int`` first, then ``float``.  If neither succeeds, 'True' and
    'False' map to the booleans and anything else is returned unchanged.
    """
    for converter in (int, float):
        try:
            return converter(value)
        except ValueError:
            pass
    booleans = {'True': True, 'False': False}
    return booleans.get(value, value)
def parse_selection(selection, **kwargs):
    """Split a selection into required keys and (key, op, value) comparisons.

    selection: None, int, str or list
        ``None`` or ``''`` selects everything.  An integer (any
        ``numbers.Integral``, e.g. a numpy integer) selects that row id.  A
        string is a comma-separated list of expressions such as
        ``'key'``, ``'key=value'``, ``'v1<key<v2'``, a chemical symbol or a
        formula.  A list may mix such strings with ``(key, op, value)``
        tuples.
    kwargs:
        Extra ``key=value`` equality constraints.

    Returns ``(keys, cmps)``.  ``keys`` lists keys that must be present.
    In ``cmps``:

    * ``age`` becomes a ``ctime`` comparison;
    * ``formula=...`` becomes per-element counts plus a ``natoms`` check;
    * chemical symbols are replaced by atomic numbers;
    * other string values are converted with
      ``convert_str_to_int_float_or_str``.

    Raises ValueError for a non-'=' formula comparison or a non-numeric
    value for a numeric key.
    """
    if selection is None:
        expressions = []
    elif isinstance(selection, numbers.Integral):
        # Accept numpy integers too, but store a plain int so that the
        # numeric-key type check below accepts it.
        expressions = [('id', '=', int(selection))]
    elif isinstance(selection, list):
        expressions = selection
    elif selection == '':
        expressions = []
    else:
        expressions = [w.strip() for w in selection.split(',')]

    keys = []
    comparisons = []
    for expression in expressions:
        if isinstance(expression, (list, tuple)):
            comparisons.append(expression)
            continue
        if expression.count('<') == 2:
            # Double-sided range "v1<key<v2" (or with "<="): record the
            # lower bound as "key>v1" here and let the loop below handle
            # the remaining "key<v2".
            value, expression = expression.split('<', 1)
            if expression[0] == '=':
                op = '>='
                expression = expression[1:]
            else:
                op = '>'
            key = expression.split('<', 1)[0]
            comparisons.append((key, op, value))
        # Two-character operators must be tried before '<', '>' and '='.
        for op in ['!=', '<=', '>=', '<', '>', '=']:
            if op in expression:
                break
        else:  # no break: a bare key, chemical symbol or formula
            if expression in atomic_numbers:
                comparisons.append((expression, '>', 0))
            else:
                try:
                    count = Formula(expression).count()
                except ValueError:
                    keys.append(expression)
                else:
                    comparisons.extend((symbol, '>', n - 1)
                                       for symbol, n in count.items())
            continue
        key, value = expression.split(op)
        comparisons.append((key, op, value))

    for key, value in kwargs.items():
        comparisons.append((key, '=', value))

    cmps = []
    for key, op, value in comparisons:
        if key == 'age':
            key = 'ctime'
            op = invop[op]
            value = now() - time_string_to_float(value)
        elif key == 'formula':
            if op != '=':
                raise ValueError('Use formula=...')
            f = Formula(value)
            count = f.count()
            cmps.extend((atomic_numbers[symbol], '=', n)
                        for symbol, n in count.items())
            key = 'natoms'
            value = len(f)
        elif key in atomic_numbers:
            key = atomic_numbers[key]
            value = int(value)
        elif isinstance(value, str):
            value = convert_str_to_int_float_or_str(value)
        if key in numeric_keys and not isinstance(value, (int, float)):
            msg = 'Wrong type for "{}{}{}" - must be a number'
            raise ValueError(msg.format(key, op, value))
        cmps.append((key, op, value))

    return keys, cmps
class Database:
    """Base class for all databases.

    The public API (write, reserve, get, select, count, update, ...) is
    implemented here on top of the hooks ``_write``, ``_select``,
    ``_get_row``, ``_update`` and ``delete``.  ``_select``, ``_get_row`` and
    ``_update`` are not defined in this class, and ``delete`` only raises
    NotImplementedError; presumably subclasses provide them.
    """
    def __init__(self, filename=None, create_indices=True,
                 use_lock_file=False, serial=False):
        """Database object.

        filename: str or None
            Filename or address; ``~`` is expanded for strings.
        create_indices: bool
            Stored as ``self.create_indices``.
        use_lock_file: bool
            When True and *filename* is a string, methods decorated with
            ``@lock`` hold ``<filename>.lock`` while they run.
        serial: bool
            Let someone else handle parallelization. Default behavior is
            to interact with the database on the master only and then
            distribute results to all slaves.
        """
        if isinstance(filename, str):
            filename = os.path.expanduser(filename)
        self.filename = filename
        self.create_indices = create_indices
        if use_lock_file and isinstance(filename, str):
            self.lock = Lock(filename + '.lock', world=DummyMPI())
        else:
            self.lock = None
        self.serial = serial

        # Description of columns and other stuff:
        self._metadata: Dict[str, Any] = None

    @property
    def metadata(self) -> Dict[str, Any]:
        raise NotImplementedError

    @parallel_function
    @lock
    def write(self, atoms, key_value_pairs=None, data=None, id=None,
              **kwargs):
        """Write atoms to database with key-value pairs.

        atoms: Atoms object or None
            Write atomic numbers, positions, unit cell and boundary
            conditions. If a calculator is attached, write also already
            calculated properties such as the energy and forces.
            ``None`` writes an empty ``Atoms()``.
        key_value_pairs: dict or None
            Dictionary of key-value pairs. Values must be strings or numbers.
        data: dict or None
            Extra stuff (not for searching).
        id: int
            Overwrite existing row.

        Key-value pairs can also be set using keyword arguments::

            connection.write(atoms, name='ABC', frequency=42.0)

        Returns integer id of the new row.
        """

        if atoms is None:
            atoms = Atoms()

        # Build a fresh dict so the caller's mapping is never modified.
        kvp = dict(key_value_pairs or {})
        kvp.update(kwargs)

        # None instead of a shared mutable default: each call gets its own
        # dict, so a backend that mutates *data* cannot leak state.
        if data is None:
            data = {}

        id = self._write(atoms, kvp, data, id)
        return id

    def _write(self, atoms, key_value_pairs, data, id=None):
        check(key_value_pairs)
        return 1

    @parallel_function
    @lock
    def reserve(self, **key_value_pairs):
        """Write empty row if not already present.

        Usage::

            id = conn.reserve(key1=value1, key2=value2, ...)

        Write an empty row with the given key-value pairs and
        return the integer id. If such a row already exists, don't write
        anything and return None.
        """

        for _ in self._select([],
                              [(key, '=', value)
                               for key, value in key_value_pairs.items()]):
            return None

        atoms = Atoms()

        calc_name = key_value_pairs.pop('calculator', None)

        if calc_name:
            # Allow use of calculator key
            assert calc_name.lower() == calc_name

            # Fake calculator class:
            class Fake:
                name = calc_name

                def todict(self):
                    return {}

                def check_state(self, atoms):
                    return ['positions']

            atoms.calc = Fake()

        id = self._write(atoms, key_value_pairs, {}, None)

        return id

    def __delitem__(self, id):
        self.delete([id])

    def get_atoms(self, selection=None,
                  add_additional_information=False, **kwargs):
        """Get Atoms object.

        selection: int, str or list
            See the select() method.
        add_additional_information: bool
            Put key-value pairs and data into Atoms.info dictionary.

        In addition, one can use keyword arguments to select specific
        key-value pairs.
        """

        row = self.get(selection, **kwargs)
        return row.toatoms(add_additional_information)

    def __getitem__(self, selection):
        return self.get(selection)

    def get(self, selection=None, **kwargs):
        """Select a single row and return it as a dictionary.

        selection: int, str or list
            See the select() method.

        Raises KeyError if nothing matches.  An AssertionError is raised
        if more than one row matches.
        """
        # limit=2 is enough to detect an ambiguous selection.
        rows = list(self.select(selection, limit=2, **kwargs))
        if not rows:
            raise KeyError('no match')
        assert len(rows) == 1, 'more than one row matched'
        return rows[0]

    @parallel_generator
    def select(self, selection=None, filter=None, explain=False,
               verbosity=1, limit=None, offset=0, sort=None,
               include_data=True, columns='all', **kwargs):
        """Select rows.

        Return AtomsRow iterator with results. Selection is done
        using key-value pairs and the special keys:

            formula, age, user, calculator, natoms, energy, magmom
            and/or charge.

        selection: int, str or list
            Can be:

            * an integer id
            * a string like 'key=value', where '=' can also be one of
              '<=', '<', '>', '>=' or '!='.
            * a string like 'key'
            * comma separated strings like 'key1<value1,key2=value2,key'
            * list of strings or tuples: [('charge', '=', 1)].
        filter: function
            A function that takes as input a row and returns True or False.
        explain: bool
            Explain query plan.
        verbosity: int
            Possible values: 0, 1 or 2.
        limit: int or None
            Limit selection.
        offset: int
            Offset into selected rows.
        sort: str
            Sort rows after key. Prepend with minus sign for a descending
            sort.
        include_data: bool
            Use include_data=False to skip reading data from rows.
        columns: 'all' or list of str
            Specify which columns from the SQL table to include.
            For example, if only the row id and the energy is needed,
            queries can be speeded up by setting columns=['id', 'energy'].
        """

        if sort:
            # 'age' is not a column: a larger age means a smaller ctime.
            if sort == 'age':
                sort = '-ctime'
            elif sort == '-age':
                sort = 'ctime'
            elif sort.lstrip('-') == 'user':
                sort += 'name'

        keys, cmps = parse_selection(selection, **kwargs)
        for row in self._select(keys, cmps, explain=explain,
                                verbosity=verbosity,
                                limit=limit, offset=offset, sort=sort,
                                include_data=include_data,
                                columns=columns):
            if filter is None or filter(row):
                yield row

    def count(self, selection=None, **kwargs):
        """Count rows.

        See the select() method for the selection syntax. Use db.count() or
        len(db) to count all rows.
        """
        return sum(1 for _ in self.select(selection, **kwargs))

    def __len__(self):
        return self.count()

    @parallel_function
    @lock
    def update(self, id, atoms=None, delete_keys=None, data=None,
               **add_key_value_pairs):
        """Update and/or delete key-value pairs of a row.

        id: int
            ID of row to update.
        atoms: Atoms object or None
            Optionally update the Atoms data (positions, cell, ...).
        data: dict
            Data dict to be added to the existing data.
        delete_keys: list of str or None
            Keys to remove.

        Use keyword arguments to add new key-value pairs.

        Returns number of key-value pairs added and removed.
        """

        if not isinstance(id, numbers.Integral):
            if isinstance(id, list):
                err = ('First argument must be an int and not a list.\n'
                       'Do something like this instead:\n\n'
                       'with db:\n'
                       '    for id in ids:\n'
                       '        db.update(id, ...)')
                raise ValueError(err)
            raise TypeError('id must be an int')

        check(add_key_value_pairs)

        row = self._get_row(id)
        kvp = row.key_value_pairs

        # n: number of keys actually removed.
        n = len(kvp)
        for key in delete_keys or ():
            kvp.pop(key, None)
        n -= len(kvp)
        # m: number of keys that were not present before.
        m = -len(kvp)
        kvp.update(add_key_value_pairs)
        m += len(kvp)

        moredata = data
        data = row.get('data', {})
        if moredata:
            data.update(moredata)
        if not data:
            data = None

        # Compare with None: an Atoms object with zero atoms has length 0
        # and is therefore falsy, but is still a requested replacement.
        replace_atoms = atoms is not None
        if replace_atoms:
            oldrow = row
            row = AtomsRow(atoms)
            # Copy over data, kvp, ctime, user and id
            row._data = oldrow._data
            row.__dict__.update(kvp)
            row._keys = list(kvp)
            row.ctime = oldrow.ctime
            row.user = oldrow.user
            row.id = id

        if replace_atoms or os.path.splitext(self.filename)[1] == '.json':
            self._write(row, kvp, data, row.id)
        else:
            self._update(row.id, kvp, data)
        return m, n

    def delete(self, ids):
        """Delete rows."""
        raise NotImplementedError
def time_string_to_float(s):
    """Convert a time string such as '2d' or '1w+3h' to years.

    An int or float is returned unchanged.  parse_selection() subtracts the
    result from now(), so such numbers are already taken to be in years.
    A string is a '+'-separated sum of terms, and spaces are ignored.  Each
    term is an (optionally decimal) number followed by one unit letter from
    ``seconds`` ('s', 'm', 'h', 'd', 'w', 'M' or 'y'), optionally followed
    by a plural 's'.  So '3hs' is three hours and '3ms' three minutes.

    Raises ValueError for a term that cannot be parsed.  Previously such
    input crashed with IndexError or KeyError.
    """
    if isinstance(s, (float, int)):
        return s
    s = s.replace(' ', '')
    if '+' in s:
        return sum(time_string_to_float(x) for x in s.split('+'))
    match = re.fullmatch(r'(\d+(?:\.\d*)?)([a-zA-Z])s?', s)
    if match is None or match.group(2) not in seconds:
        raise ValueError('Bad time string: {!r}'.format(s))
    number, unit = match.groups()
    # Keep integer arithmetic for integer amounts (as before).
    amount = int(number) if number.isdigit() else float(number)
    return seconds[unit] * amount / YEAR
def float_to_time_string(t, long=False):
    """Format a time *t* given in years using a human-friendly unit.

    Units are tried from largest ('y') to smallest ('s').  The first one in
    which the amount exceeds 5 is used; if none does, seconds are used.
    With ``long=True`` the result looks like '12.000 days', otherwise like
    '12d'.
    """
    total_seconds = t * YEAR
    for unit in 'yMwdhms':
        amount = total_seconds / seconds[unit]
        if amount > 5:
            break
    if not long:
        return '{:.0f}{}'.format(round(amount), unit)
    return '{:.3f} {}s'.format(amount, longwords[unit])
def object_to_bytes(obj: Any) -> bytes:
    """Serialize Python object to bytes.

    The layout is:

    * an 8-byte int64 header (stored little-endian) holding the offset
      of the JSON text;
    * the raw array buffers that o2b() appends;
    * the compact JSON encoding of the converted object.
    """
    chunks = [b'12345678']  # placeholder, replaced by the real header
    payload = o2b(obj, chunks)
    header = np.array(sum(map(len, chunks)), np.int64)
    if not np.little_endian:
        header.byteswap(True)
    chunks[0] = header.tobytes()
    encoded = json.dumps(payload, separators=(',', ':')).encode()
    chunks.append(encoded)
    return b''.join(chunks)
def bytes_to_object(b: bytes) -> Any:
    """Deserialize bytes to Python object (inverse of object_to_bytes()).

    The first 8 bytes hold the little-endian int64 offset of the JSON text.
    The decoded JSON is then passed to b2o() together with the whole
    buffer so that arrays can be rebuilt.
    """
    header = np.frombuffer(b[:8], np.int64)
    if not np.little_endian:
        header = header.byteswap()
    json_start = header.item()
    decoded = json.loads(b[json_start:].decode())
    return b2o(decoded, b)
def o2b(obj: Any, parts: List[bytes]):
    """Convert *obj* to a JSON-serializable structure.

    Helper for object_to_bytes(); *parts* is the list of byte chunks built
    so far.

    * Scalars (int, float, bool, str, None) are returned as-is.
    * NumPy bool/number scalars become the equivalent Python scalar.
    * Dicts and lists are converted recursively; tuples become lists.
    * Complex numbers become ``{'__complex__': [real, imag]}``.
    * NumPy arrays are not embedded in the JSON.  Their raw
      (little-endian) bytes are appended to *parts* and replaced by
      ``{'__ndarray__': [shape, dtype_name, offset]}``, where *offset* is
      the position of the buffer in ``b''.join(parts)``.
    * Objects with a truthy ``ase_objtype`` attribute are stored as their
      converted ``todict()`` plus an ``'__ase_objtype__'`` tag.

    Raises ValueError for any other object.  An AssertionError is raised
    for object-dtype arrays.
    """
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj
    if isinstance(obj, (np.bool_, np.number)):
        # np.int64, np.bool_, np.float32, ... are not subclasses of the
        # Python scalar types; convert them to the equivalent Python value.
        return o2b(obj.item(), parts)
    if isinstance(obj, dict):
        return {key: o2b(value, parts) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [o2b(value, parts) for value in obj]
    if isinstance(obj, np.ndarray):
        assert obj.dtype != object, \
            'Cannot convert ndarray of type "object" to bytes.'
        offset = sum(len(part) for part in parts)
        if not np.little_endian:
            obj = obj.byteswap()
        parts.append(obj.tobytes())
        return {'__ndarray__': [obj.shape,
                                obj.dtype.name,
                                offset]}
    if isinstance(obj, complex):
        return {'__complex__': [obj.real, obj.imag]}
    # Use a default so that unsupported objects reach the ValueError below
    # instead of failing with an AttributeError.
    objtype = getattr(obj, 'ase_objtype', None)
    if objtype:
        dct = o2b(obj.todict(), parts)
        dct['__ase_objtype__'] = objtype
        return dct
    raise ValueError('Objects of type {type} not allowed'
                     .format(type=type(obj)))
def b2o(obj: Any, b: bytes) -> Any:
    """Rebuild Python objects from decoded JSON (inverse of o2b()).

    obj: Value decoded from the JSON part (scalar, list or dict).
    b: The complete serialized byte string; ``__ndarray__`` entries refer
       to byte offsets inside it.

    Arrays are created with ``np.frombuffer`` on *b*.  Dicts tagged with
    ``'__ase_objtype__'`` are passed to ``create_ase_object``.
    """
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj

    if isinstance(obj, list):
        return [b2o(item, b) for item in obj]

    assert isinstance(obj, dict)

    real_imag = obj.get('__complex__')
    if real_imag is not None:
        return complex(*real_imag)

    spec = obj.get('__ndarray__')
    if spec is not None:
        shape, dtype_name, offset = spec
        dtype = np.dtype(dtype_name)
        nbytes = dtype.itemsize * np.prod(shape).astype(int)
        array = np.frombuffer(b[offset:offset + nbytes], dtype)
        array.shape = shape
        if not np.little_endian:
            # Buffers are stored little-endian; swap to native order.
            array = array.byteswap()
        return array

    restored = {key: b2o(value, b) for key, value in obj.items()}
    objtype = restored.pop('__ase_objtype__', None)
    if objtype is not None:
        return create_ase_object(objtype, restored)
    return restored