Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import functools 

2import json 

3import numbers 

4import operator 

5import os 

6import re 

7import warnings 

8from time import time 

9from typing import List, Dict, Any 

10 

11import numpy as np 

12 

13from ase.atoms import Atoms 

14from ase.calculators.calculator import all_properties, all_changes 

15from ase.data import atomic_numbers 

16from ase.db.row import AtomsRow 

17from ase.formula import Formula 

18from ase.io.jsonio import create_ase_object 

19from ase.parallel import world, DummyMPI, parallel_function, parallel_generator 

20from ase.utils import Lock, PurePath 

21 

22 

# Reference point "January 1. 2000", in seconds on the time.time() scale.
T2000 = 946681200.0

# Length of one year (365.25 days) in seconds.
YEAR = 365.25 * 86400.0

25 

26 

@functools.total_ordering
class KeyDescription:
    """Description of a database key: short/long text and a unit.

    key: str
        Name of the key.
    shortdesc: str
        Short description.  Defaults to *key*.
    longdesc: str
        Long description.  Defaults to *shortdesc*.
    unit: str
        Unit string.  Backtick-quoted markup of the form ```a_b``` and
        ```a^{b}``` is rewritten to HTML ``<sub>``/``<sup>`` tags, and
        ``\\text{`` as well as every ``}`` is stripped.

    Instances compare, sort and hash by ``shortdesc`` only.
    """

    _subscript = re.compile(r'`(.)_(.)`')
    _superscript = re.compile(r'`(.*)\^\{?(.*?)\}?`')

    def __init__(self, key, shortdesc=None, longdesc=None, unit=''):
        self.key = key

        if shortdesc is None:
            shortdesc = key

        if longdesc is None:
            longdesc = shortdesc

        self.shortdesc = shortdesc
        self.longdesc = longdesc

        # Somewhat arbitrary that we do this conversion.  Can we avoid that?
        # Previously done in create_key_descriptions().
        unit = self._subscript.sub(r'\1<sub>\2</sub>', unit)
        unit = self._superscript.sub(r'\1<sup>\2</sup>', unit)
        unit = unit.replace(r'\text{', '').replace('}', '')

        self.unit = unit

    def __repr__(self):
        cls = type(self).__name__
        return (f'{cls}({self.key!r}, {self.shortdesc!r}, {self.longdesc!r}, '
                f'unit={self.unit!r})')

    # The templates like to sort key descriptions by shortdesc.
    def __eq__(self, other):
        return self.shortdesc == getattr(other, 'shortdesc', None)

    def __hash__(self):
        # Defining __eq__ alone would make instances unhashable; hash on the
        # same attribute that __eq__ compares so equal objects hash equally.
        return hash(self.shortdesc)

    def __lt__(self, other):
        # Objects without a shortdesc are never "greater": fall back to our
        # own shortdesc so that the comparison is False.
        return self.shortdesc < getattr(other, 'shortdesc', self.shortdesc)

63 

64 

def get_key_descriptions():
    """Return descriptions of the built-in keys.

    Returns a dict mapping each key name to its KeyDescription object.
    """
    KD = KeyDescription
    descriptions = [
        KD('id', 'ID', 'Unique row ID'),
        KD('age', 'Age', 'Time since creation'),
        KD('formula', 'Formula', 'Chemical formula'),
        KD('pbc', 'PBC', 'Periodic boundary conditions'),
        KD('user', 'Username'),
        KD('calculator', 'Calculator', 'ASE-calculator name'),
        KD('energy', 'Energy', 'Total energy', unit='eV'),
        KD('natoms', 'Number of atoms'),
        KD('fmax', 'Maximum force', unit='eV/Å'),
        KD('smax', 'Maximum stress', 'Maximum stress on unit cell',
           unit='eV/ų'),
        KD('charge', 'Charge', 'Net charge in unit cell', unit='|e|'),
        KD('mass', 'Mass', 'Sum of atomic masses in unit cell', unit='au'),
        KD('magmom', 'Magnetic moment', unit='μ_B'),
        KD('unique_id', 'Unique ID', 'Random (unique) ID'),
        KD('volume', 'Volume', 'Volume of unit cell', unit='ų'),
    ]
    return {keydesc.key: keydesc for keydesc in descriptions}

85 

86 

def now():
    """Return the current time, in years elapsed since T2000."""
    elapsed_seconds = time() - T2000
    return elapsed_seconds / YEAR

90 

91 

# Length in seconds of each time unit, keyed by its one-letter code.
seconds = dict(s=1, m=60, h=3600, d=86400, w=604800, M=2629800, y=YEAR)

99 

# Singular English name of each one-letter time-unit code.
longwords = dict(s='second',
                 m='minute',
                 h='hour',
                 d='day',
                 w='week',
                 M='month',
                 y='year')

# Comparison function implementing each operator symbol.
ops = {
    '<': operator.lt,
    '<=': operator.le,
    '=': operator.eq,
    '>=': operator.ge,
    '>': operator.gt,
    '!=': operator.ne,
}

# Inverted operator symbol for each operator symbol (e.g. '<' -> '>=').
invop = {
    '<': '>=',
    '<=': '>',
    '>=': '<',
    '>': '<=',
    '=': '!=',
    '!=': '=',
}

116 

# Pattern that a key name must match: a letter or underscore followed by
# letters, digits or underscores.
word = re.compile(r'[_a-zA-Z][_0-9a-zA-Z]*$')

# Names that check() rejects as user-supplied keys.
reserved_keys = {
    *all_properties,
    *all_changes,
    *atomic_numbers,
    'id', 'unique_id', 'ctime', 'mtime', 'user',
    'fmax', 'smax',
    'momenta', 'constraints', 'natoms', 'formula', 'age',
    'calculator', 'calculator_parameters',
    'key_value_pairs', 'data',
}

# Keys whose comparison values must be numbers (see parse_selection()).
numeric_keys = {'id', 'energy', 'magmom', 'charge', 'natoms'}

129 

130 

def check(key_value_pairs):
    """Validate key-value pairs.

    key_value_pairs: dict
        Mapping from key name to value.

    The key "external_tables" is skipped.  Every other key must match the
    ``word`` pattern and must not be in ``reserved_keys``; every value must
    be a real number, a str or a numpy bool; and a str value must not be
    parsable as an int or a float.  A key that is also a valid chemical
    formula only triggers a warning.

    Raises ValueError for a bad key or a bad value.
    """
    for key, value in key_value_pairs.items():
        if key == "external_tables":
            # Checks for external_tables are not performed
            continue

        if not word.match(key) or key in reserved_keys:
            raise ValueError('Bad key: {}'.format(key))

        try:
            Formula(key, strict=True)
        except ValueError:
            pass  # not a chemical formula - nothing to warn about
        else:
            warnings.warn(
                'It is best not to use keys ({0}) that are also a '
                'chemical formula. If you do a "db.select({0!r})", '
                'you will not find rows with your key. Instead, you will get '
                'rows containing the atoms in the formula!'.format(key))

        if not isinstance(value, (numbers.Real, str, np.bool_)):
            raise ValueError('Bad value for {!r}: {}'.format(key, value))

        if isinstance(value, str):
            # A string that looks like a number would be ambiguous.
            for t in [int, float]:
                if str_represents(value, t):
                    name = t.__name__
                    raise ValueError(
                        f'Value {value} is put in as string '
                        f'but can be interpreted as '
                        f'{name}! Please convert '
                        f'to {name} using '
                        f'{name}(value) before '
                        'writing to the database OR change '
                        'to a different string.')

163 

164 

def str_represents(value, t=int):
    """Return True if ``t(value)`` does not raise ValueError, else False."""
    try:
        t(value)
    except ValueError:
        convertible = False
    else:
        convertible = True
    return convertible

171 

172 

def connect(name, type='extract_from_name', create_indices=True,
            use_lock_file=True, append=True, serial=False):
    """Create connection to database.

    name: str
        Filename or address of database.
    type: str
        One of 'json', 'db', 'postgresql', 'mysql'
        (JSON, SQLite, PostgreSQL, MySQL).
        Default is 'extract_from_name', which will guess the type
        from the name.
    create_indices: bool
        Passed on to the SQLite database.
    use_lock_file: bool
        You can turn this off if you know what you are doing ...
    append: bool
        Use append=False to start a new database.
    serial: bool
        Passed on to the JSON and SQLite databases.

    Raises ValueError if the type cannot be guessed from the name or is
    not one of the known types.
    """

    if isinstance(name, PurePath):
        name = str(name)

    if type == 'extract_from_name':
        if name is None:
            type = None
        elif not isinstance(name, str):
            type = 'json'
        elif name.startswith(('postgresql://', 'postgres://')):
            type = 'postgresql'
        elif name.startswith(('mysql://', 'mariadb://')):
            type = 'mysql'
        else:
            # Use the file extension (without the dot) as the type.
            type = os.path.splitext(name)[1][1:]
            if type == '':
                raise ValueError('No file extension or database type given')

    if type is None:
        return Database()

    # Starting a new database: only rank 0 removes an existing file.
    if (not append and world.rank == 0 and
            isinstance(name, str) and os.path.isfile(name)):
        os.remove(name)

    if type not in ('postgresql', 'mysql') and isinstance(name, str):
        name = os.path.abspath(name)

    # The backends are imported lazily, only when they are requested.
    if type == 'json':
        from ase.db.jsondb import JSONDatabase
        return JSONDatabase(name, use_lock_file=use_lock_file, serial=serial)

    if type == 'db':
        from ase.db.sqlite import SQLite3Database
        return SQLite3Database(name, create_indices, use_lock_file,
                               serial=serial)

    if type == 'postgresql':
        from ase.db.postgresql import PostgreSQLDatabase
        return PostgreSQLDatabase(name)

    if type == 'mysql':
        from ase.db.mysql import MySQLDatabase
        return MySQLDatabase(name)

    raise ValueError('Unknown database type: ' + type)

233 

234 

def lock(method):
    """Decorator for using a lock-file.

    The wrapped method runs inside ``with self.lock:`` unless ``self.lock``
    is None, in which case it is called directly.
    """
    @functools.wraps(method)
    def new_method(self, *args, **kwargs):
        if self.lock is not None:
            with self.lock:
                return method(self, *args, **kwargs)
        return method(self, *args, **kwargs)
    return new_method

245 

246 

def convert_str_to_int_float_or_str(value):
    """Safe eval()

    Return *value* converted to an int if possible, otherwise to a float.
    If neither conversion works, 'True' and 'False' become the matching
    bool and anything else is returned unchanged.
    """
    for convert in (int, float):
        try:
            return convert(value)
        except ValueError:
            continue
    booleans = {'True': True, 'False': False}
    return booleans.get(value, value)

257 

258 

def parse_selection(selection, **kwargs):
    """Parse a selection into keys and comparisons.

    selection: None, int, str or list
        * None or '': no expressions
        * int: shorthand for ``id=<int>``
        * str: comma-separated expressions such as
          'key1<value1,key2=value2,key'
        * list: expressions given as strings and/or (key, op, value)
          tuples, which are used as they are
    kwargs:
        Each ``key=value`` is added as the comparison (key, '=', value).

    Returns ``(keys, cmps)`` where *keys* is a list of bare key names and
    *cmps* is a list of (key, op, value) tuples.  In *cmps*, 'age' is
    translated to a comparison on 'ctime', 'formula' to per-element counts
    plus 'natoms', and chemical symbols to atomic numbers.

    Raises ValueError if 'formula' is used with another operator than '='
    or if a key in ``numeric_keys`` is compared to a non-number.
    """
    if selection is None or selection == '':
        expressions = []
    elif isinstance(selection, int):
        expressions = [('id', '=', selection)]
    elif isinstance(selection, list):
        expressions = selection
    else:
        expressions = [w.strip() for w in selection.split(',')]

    keys = []
    comparisons = []
    for expression in expressions:
        if isinstance(expression, (list, tuple)):
            # Already in (key, op, value) form.
            comparisons.append(expression)
            continue

        if expression.count('<') == 2:
            # Range expression like 'a<key<b' or 'a<=key<b': emit the
            # lower bound here; the upper bound is handled below.
            value, expression = expression.split('<', 1)
            if expression[0] == '=':
                op = '>='
                expression = expression[1:]
            else:
                op = '>'
            key = expression.split('<', 1)[0]
            comparisons.append((key, op, value))

        # Two-character operators are tried first so that e.g. '<=' is
        # not mistaken for '<'.
        for op in ['!=', '<=', '>=', '<', '>', '=']:
            if op in expression:
                break
        else:  # no break
            if expression in atomic_numbers:
                # A bare chemical symbol: at least one such atom.
                comparisons.append((expression, '>', 0))
            else:
                try:
                    count = Formula(expression).count()
                except ValueError:
                    # Not a formula: a bare key name.
                    keys.append(expression)
                else:
                    # A formula: at least n atoms of each symbol.
                    comparisons.extend((symbol, '>', n - 1)
                                       for symbol, n in count.items())
            continue

        key, value = expression.split(op)
        comparisons.append((key, op, value))

    cmps = []
    for key, value in kwargs.items():
        comparisons.append((key, '=', value))

    for key, op, value in comparisons:
        if key == 'age':
            # Age is stored as creation time: invert the operator and
            # convert the age into a point in time.
            key = 'ctime'
            op = invop[op]
            value = now() - time_string_to_float(value)
        elif key == 'formula':
            if op != '=':
                raise ValueError('Use formula=...')
            f = Formula(value)
            count = f.count()
            cmps.extend((atomic_numbers[symbol], '=', n)
                        for symbol, n in count.items())
            key = 'natoms'
            value = len(f)
        elif key in atomic_numbers:
            key = atomic_numbers[key]
            value = int(value)
        elif isinstance(value, str):
            value = convert_str_to_int_float_or_str(value)

        if key in numeric_keys and not isinstance(value, (int, float)):
            msg = 'Wrong type for "{}{}{}" - must be a number'
            raise ValueError(msg.format(key, op, value))
        cmps.append((key, op, value))

    return keys, cmps

330 

331 

class Database:
    """Base class for all databases."""

    def __init__(self, filename=None, create_indices=True,
                 use_lock_file=False, serial=False):
        """Database object.

        filename: str or other
            Location of the database.  For a string, a leading ``~`` is
            expanded.
        create_indices: bool
            Stored on the instance as ``create_indices``.
        use_lock_file: bool
            Use the lock file ``filename + '.lock'`` for locked methods
            (only when *filename* is a string).
        serial: bool
            Let someone else handle parallelization.  Default behavior is
            to interact with the database on the master only and then
            distribute results to all slaves.
        """
        if isinstance(filename, str):
            filename = os.path.expanduser(filename)
        self.filename = filename
        self.create_indices = create_indices
        if use_lock_file and isinstance(filename, str):
            self.lock = Lock(filename + '.lock', world=DummyMPI())
        else:
            self.lock = None
        self.serial = serial

        # Description of columns and other stuff:
        self._metadata: Dict[str, Any] = None

    @property
    def metadata(self) -> Dict[str, Any]:
        """Metadata dictionary; not implemented in the base class."""
        raise NotImplementedError

    @parallel_function
    @lock
    def write(self, atoms, key_value_pairs=None, data=None, id=None,
              **kwargs):
        """Write atoms to database with key-value pairs.

        atoms: Atoms object
            Write atomic numbers, positions, unit cell and boundary
            conditions.  If a calculator is attached, write also already
            calculated properties such as the energy and forces.
            None means an empty Atoms object.
        key_value_pairs: dict
            Dictionary of key-value pairs.  Values must be strings or
            numbers.  Default is no key-value pairs.
        data: dict
            Extra stuff (not for searching).  Default is an empty dict.
        id: int
            Overwrite existing row.

        Key-value pairs can also be set using keyword arguments::

            connection.write(atoms, name='ABC', frequency=42.0)

        Returns integer id of the new row.
        """

        if atoms is None:
            atoms = Atoms()

        # Work on a copy so that the caller's dict is never modified.
        kvp = {} if key_value_pairs is None else dict(key_value_pairs)
        kvp.update(kwargs)

        # A fresh dict per call: a shared mutable default could leak
        # between calls if a backend's _write() modified it.
        if data is None:
            data = {}

        id = self._write(atoms, kvp, data, id)
        return id

    def _write(self, atoms, key_value_pairs, data, id=None):
        """Validate the key-value pairs and return 1 (nothing is stored)."""
        check(key_value_pairs)
        return 1

    @parallel_function
    @lock
    def reserve(self, **key_value_pairs):
        """Write empty row if not already present.

        Usage::

            id = conn.reserve(key1=value1, key2=value2, ...)

        Write an empty row with the given key-value pairs and
        return the integer id.  If such a row already exists, don't write
        anything and return None.
        """

        for dct in self._select([],
                                [(key, '=', value)
                                 for key, value in key_value_pairs.items()]):
            # At least one matching row exists.
            return None

        atoms = Atoms()

        calc_name = key_value_pairs.pop('calculator', None)

        if calc_name:
            # Allow use of calculator key
            assert calc_name.lower() == calc_name

            # Fake calculator class:
            class Fake:
                name = calc_name

                def todict(self):
                    return {}

                def check_state(self, atoms):
                    return ['positions']

            atoms.calc = Fake()

        id = self._write(atoms, key_value_pairs, {}, None)

        return id

    def __delitem__(self, id):
        self.delete([id])

    def get_atoms(self, selection=None,
                  add_additional_information=False, **kwargs):
        """Get Atoms object.

        selection: int, str or list
            See the select() method.
        add_additional_information: bool
            Put key-value pairs and data into Atoms.info dictionary.

        In addition, one can use keyword arguments to select specific
        key-value pairs.
        """

        row = self.get(selection, **kwargs)
        return row.toatoms(add_additional_information)

    def __getitem__(self, selection):
        return self.get(selection)

    def get(self, selection=None, **kwargs):
        """Select a single row and return it.

        selection: int, str or list
            See the select() method.

        Raises KeyError if no row matches; asserts that at most one row
        matches.
        """
        # limit=2 is enough to tell "exactly one" from "more than one".
        rows = list(self.select(selection, limit=2, **kwargs))
        if not rows:
            raise KeyError('no match')
        assert len(rows) == 1, 'more than one row matched'
        return rows[0]

    @parallel_generator
    def select(self, selection=None, filter=None, explain=False,
               verbosity=1, limit=None, offset=0, sort=None,
               include_data=True, columns='all', **kwargs):
        """Select rows.

        Return AtomsRow iterator with results.  Selection is done
        using key-value pairs and the special keys:

            formula, age, user, calculator, natoms, energy, magmom
            and/or charge.

        selection: int, str or list
            Can be:

            * an integer id
            * a string like 'key=value', where '=' can also be one of
              '<=', '<', '>', '>=' or '!='.
            * a string like 'key'
            * comma separated strings like 'key1<value1,key2=value2,key'
            * list of strings or tuples: [('charge', '=', 1)].
        filter: function
            A function that takes as input a row and returns True or False.
        explain: bool
            Explain query plan.
        verbosity: int
            Possible values: 0, 1 or 2.
        limit: int or None
            Limit selection.
        offset: int
            Offset into selected rows.
        sort: str
            Sort rows after key.  Prepend with minus sign for a descending
            sort.
        include_data: bool
            Use include_data=False to skip reading data from rows.
        columns: 'all' or list of str
            Specify which columns from the SQL table to include.
            For example, if only the row id and the energy is needed,
            queries can be speeded up by setting columns=['id', 'energy'].
        """

        if sort:
            # Translate the special sort keys to the names passed on to
            # _select(): age maps to ctime with the direction reversed,
            # user to username.
            if sort == 'age':
                sort = '-ctime'
            elif sort == '-age':
                sort = 'ctime'
            elif sort.lstrip('-') == 'user':
                sort += 'name'

        keys, cmps = parse_selection(selection, **kwargs)
        for row in self._select(keys, cmps, explain=explain,
                                verbosity=verbosity,
                                limit=limit, offset=offset, sort=sort,
                                include_data=include_data,
                                columns=columns):
            if filter is None or filter(row):
                yield row

    def count(self, selection=None, **kwargs):
        """Count rows.

        See the select() method for the selection syntax.  Use db.count() or
        len(db) to count all rows.
        """
        return sum(1 for _ in self.select(selection, **kwargs))

    def __len__(self):
        return self.count()

    @parallel_function
    @lock
    def update(self, id, atoms=None, delete_keys=(), data=None,
               **add_key_value_pairs):
        """Update and/or delete key-value pairs of row(s).

        id: int
            ID of row to update.
        atoms: Atoms object
            Optionally update the Atoms data (positions, cell, ...).
        data: dict
            Data dict to be added to the existing data.
        delete_keys: list of str
            Keys to remove.

        Use keyword arguments to add new key-value pairs.

        Returns number of key-value pairs added and removed.

        Raises ValueError if *id* is a list and TypeError if it is any
        other non-integer.
        """

        if not isinstance(id, numbers.Integral):
            if isinstance(id, list):
                err = ('First argument must be an int and not a list.\n'
                       'Do something like this instead:\n\n'
                       'with db:\n'
                       '    for id in ids:\n'
                       '        db.update(id, ...)')
                raise ValueError(err)
            raise TypeError('id must be an int')

        check(add_key_value_pairs)

        row = self._get_row(id)
        kvp = row.key_value_pairs

        # n: number of pairs removed, m: number of pairs added.
        n = len(kvp)
        for key in delete_keys:
            kvp.pop(key, None)
        n -= len(kvp)
        m = -len(kvp)
        kvp.update(add_key_value_pairs)
        m += len(kvp)

        moredata = data
        data = row.get('data', {})
        if moredata:
            data.update(moredata)
        if not data:
            data = None

        if atoms:
            oldrow = row
            row = AtomsRow(atoms)
            # Copy over data, kvp, ctime, user and id
            row._data = oldrow._data
            row.__dict__.update(kvp)
            row._keys = list(kvp)
            row.ctime = oldrow.ctime
            row.user = oldrow.user
            row.id = id

        # New atoms, or a '.json' file: rewrite the whole row; otherwise
        # update it in place.
        # NOTE(review): splitext() assumes self.filename is a str whenever
        # atoms is not given - confirm for databases opened on non-str names.
        if atoms or os.path.splitext(self.filename)[1] == '.json':
            self._write(row, kvp, data, row.id)
        else:
            self._update(row.id, kvp, data)
        return m, n

    def delete(self, ids):
        """Delete rows."""
        raise NotImplementedError

615 

616 

def time_string_to_float(s):
    """Convert a time string to a number of years.

    A float or int is returned unchanged.  Otherwise spaces are removed,
    '+'-separated terms are converted one by one and summed, and a single
    term is read as an integer followed by a unit code that is looked up
    in ``seconds``.  A trailing 's' directly after a letter is dropped.
    """
    if isinstance(s, (float, int)):
        return s

    text = s.replace(' ', '')
    if '+' in text:
        return sum(time_string_to_float(term) for term in text.split('+'))

    has_plural_s = text[-2].isalpha() and text[-1] == 's'
    if has_plural_s:
        text = text[:-1]

    # The first character is always part of the number; find where the
    # run of digits after it ends.
    end = 1
    while text[end].isdigit():
        end += 1

    number, unit = text[:end], text[end:]
    return seconds[unit] * int(number) / YEAR

629 

630 

def float_to_time_string(t, long=False):
    """Convert a time in years to a string with a unit.

    The units y, M, w, d, h, m, s are tried in that order and the first one
    giving a value above 5 is used (seconds if none does).

    long: bool
        If true, return three decimals and the unit spelled out, like
        '6.000 days'; otherwise a rounded number and the one-letter code,
        like '6d'.
    """
    t *= YEAR
    for unit in 'yMwdhms':
        amount = t / seconds[unit]
        if amount > 5:
            break
    if not long:
        return f'{round(amount):.0f}{unit}'
    return f'{amount:.3f} {longwords[unit]}s'

641 

642 

def object_to_bytes(obj: Any) -> bytes:
    """Serialize Python object to bytes.

    Layout of the result: an 8-byte little-endian int64 holding the offset
    of the JSON part, then the raw bytes collected by o2b(), then the
    JSON text.
    """
    # Eight placeholder bytes, replaced by the real header below.
    chunks = [b'12345678']
    jsonable = o2b(obj, chunks)

    json_offset = sum(len(chunk) for chunk in chunks)
    header = np.array(json_offset, np.int64)
    if not np.little_endian:
        header.byteswap(True)
    chunks[0] = header.tobytes()

    json_text = json.dumps(jsonable, separators=(',', ':'))
    chunks.append(json_text.encode())
    return b''.join(chunks)

654 

655 

def bytes_to_object(b: bytes) -> Any:
    """Deserialize bytes to Python object.

    The first 8 bytes hold the int64 offset of the JSON text, which is
    decoded and then converted with b2o().
    """
    header = np.frombuffer(b[:8], np.int64)
    if not np.little_endian:
        header = header.byteswap()
    json_offset = header.item()
    json_text = b[json_offset:].decode()
    return b2o(json.loads(json_text), b)

664 

665 

def o2b(obj: Any, parts: List[bytes]):
    """Convert *obj* to something JSON-serializable.

    obj:
        int, float, bool, str, None, complex, ndarray, dict, list, tuple
        or an object with a truthy ``ase_objtype`` attribute and a
        ``todict()`` method.
    parts: list of bytes
        Raw ndarray bytes are appended to this list; the returned structure
        refers to them by their offset into the concatenated parts.

    Raises ValueError for objects of any other type.
    """
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj
    if isinstance(obj, dict):
        return {key: o2b(value, parts) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [o2b(value, parts) for value in obj]
    if isinstance(obj, np.ndarray):
        assert obj.dtype != object, \
            'Cannot convert ndarray of type "object" to bytes.'
        # Offset of this array's bytes = total size of what came before.
        offset = sum(len(part) for part in parts)
        if not np.little_endian:
            obj = obj.byteswap()
        parts.append(obj.tobytes())
        return {'__ndarray__': [obj.shape,
                                obj.dtype.name,
                                offset]}
    if isinstance(obj, complex):
        return {'__complex__': [obj.real, obj.imag]}
    # Default of None: objects without the attribute must reach the
    # ValueError below instead of failing with an AttributeError.
    objtype = getattr(obj, 'ase_objtype', None)
    if objtype:
        dct = o2b(obj.todict(), parts)
        dct['__ase_objtype__'] = objtype
        return dct
    raise ValueError('Objects of type {type} not allowed'
                     .format(type=type(obj)))

692 

693 

def b2o(obj: Any, b: bytes) -> Any:
    """Rebuild a Python object from decoded JSON and the raw bytes.

    obj:
        Decoded JSON: int, float, bool, str, None, list or dict.
    b: bytes
        Buffer that the offsets in '__ndarray__' entries refer to.

    Dicts with a '__complex__' or '__ndarray__' entry become a complex
    number or an ndarray; dicts with an '__ase_objtype__' entry are turned
    into objects with create_ase_object(); other dicts stay dicts.
    """
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj

    if isinstance(obj, list):
        return [b2o(item, b) for item in obj]

    assert isinstance(obj, dict)

    complex_parts = obj.get('__complex__')
    if complex_parts is not None:
        return complex(*complex_parts)

    array_spec = obj.get('__ndarray__')
    if array_spec is not None:
        shape, dtype_name, start = array_spec
        dtype = np.dtype(dtype_name)
        nbytes = dtype.itemsize * np.prod(shape).astype(int)
        array = np.frombuffer(b[start:start + nbytes], dtype)
        array.shape = shape
        if not np.little_endian:
            array = array.byteswap()
        return array

    converted = {key: b2o(value, b) for key, value in obj.items()}
    objtype = converted.pop('__ase_objtype__', None)
    if objtype is not None:
        return create_ase_object(objtype, converted)
    return converted