Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import functools 

2import json 

3import numbers 

4import operator 

5import os 

6import re 

7import warnings 

8from time import time 

9from typing import List, Dict, Any 

10 

11import numpy as np 

12 

13from ase.atoms import Atoms 

14from ase.calculators.calculator import all_properties, all_changes 

15from ase.data import atomic_numbers 

16from ase.db.row import AtomsRow 

17from ase.formula import Formula 

18from ase.io.jsonio import create_ase_object 

19from ase.parallel import world, DummyMPI, parallel_function, parallel_generator 

20from ase.utils import Lock, PurePath 

21 

22 

# Reference epoch for now(), as a Unix timestamp in seconds.
# NOTE(review): 946681200 is 1999-12-31 23:00:00 UTC (midnight UTC would be
# 946684800) -- presumably local midnight in UTC+1.  Values derived from
# now() depend on this constant, so it is left untouched.
T2000 = 946681200.0  # January 1. 2000
# Length of one year in seconds; ages are expressed in these units.
YEAR = 31557600.0  # 365.25 days


# Format of key description: ('short', 'long', 'unit')
default_key_descriptions = {
    'id': ('ID', 'Unique row ID', ''),
    'age': ('Age', 'Time since creation', ''),
    'formula': ('Formula', 'Chemical formula', ''),
    'pbc': ('PBC', 'Periodic boundary conditions', ''),
    'user': ('Username', '', ''),
    'calculator': ('Calculator', 'ASE-calculator name', ''),
    'energy': ('Energy', 'Total energy', 'eV'),
    'natoms': ('Number of atoms', '', ''),
    'fmax': ('Maximum force', '', 'eV/Ang'),
    'smax': ('Maximum stress', 'Maximum stress on unit cell',
             '`\\text{eV/Ang}^3`'),
    'charge': ('Charge', 'Net charge in unit cell', '|e|'),
    'mass': ('Mass', 'Sum of atomic masses in unit cell', 'au'),
    'magmom': ('Magnetic moment', '', 'μ_B'),
    'unique_id': ('Unique ID', 'Random (unique) ID', ''),
    'volume': ('Volume', 'Volume of unit cell', '`\\text{Ang}^3`')}

45 

46 

def now():
    """Return the current time, in years elapsed since the T2000 epoch."""
    elapsed_seconds = time() - T2000
    return elapsed_seconds / YEAR

50 

51 

# Number of seconds in each time unit, keyed by its one-letter code.
# Note that 'm' is minute and 'M' is month.
seconds = dict(s=1,
               m=60,
               h=3600,
               d=86400,
               w=604800,
               M=2629800,
               y=YEAR)

# Singular English name of each one-letter time unit.
longwords = dict(s='second',
                 m='minute',
                 h='hour',
                 d='day',
                 w='week',
                 M='month',
                 y='year')

67 

# Comparison operators accepted in selections, mapped to their functions.
ops = {
    '<': operator.lt,
    '<=': operator.le,
    '=': operator.eq,
    '>=': operator.ge,
    '>': operator.gt,
    '!=': operator.ne,
}

# Logical negation of each comparison operator.
invop = {
    '<': '>=',
    '<=': '>',
    '>=': '<',
    '>': '<=',
    '=': '!=',
    '!=': '=',
}

# A valid key looks like a Python identifier (ASCII only).
word = re.compile('[_a-zA-Z][_0-9a-zA-Z]*$')

78 

# Names that may not be used as user-defined keys: calculator property and
# change names, everything in atomic_numbers, and the built-in row columns.
reserved_keys = {
    *all_properties,
    *all_changes,
    *atomic_numbers,
    'id', 'unique_id', 'ctime', 'mtime', 'user',
    'fmax', 'smax',
    'momenta', 'constraints', 'natoms', 'formula', 'age',
    'calculator', 'calculator_parameters',
    'key_value_pairs', 'data',
}

87 

# Built-in keys whose comparison value must be a number.
numeric_keys = {'id', 'energy', 'magmom', 'charge', 'natoms'}

89 

90 

def check(key_value_pairs):
    """Validate user-supplied key-value pairs.

    key_value_pairs: dict
        Mapping from key to value.  The key 'external_tables' is skipped
        without any checks.

    Raises ValueError if a key is not an identifier-like word, if a key
    is reserved, if a value is not a real number, string or numpy bool,
    or if a string value could be parsed as an int or a float.

    A warning is issued for keys that also parse as a chemical formula.
    """
    for key, value in key_value_pairs.items():
        if key == "external_tables":
            # Checks for external_tables are not performed
            continue

        if not word.match(key) or key in reserved_keys:
            raise ValueError('Bad key: {}'.format(key))

        # A key that is also a valid formula is legal but ambiguous in
        # selections, so warn rather than fail.
        try:
            Formula(key, strict=True)
        except ValueError:
            pass
        else:
            warnings.warn(
                'It is best not to use keys ({0}) that are also a '
                'chemical formula. If you do a "db.select({0!r})", '
                'you will not find rows with your key. Instead, you will get '
                'rows containing the atoms in the formula!'.format(key))

        if not isinstance(value, (numbers.Real, str, np.bool_)):
            raise ValueError('Bad value for {!r}: {}'.format(key, value))

        if isinstance(value, str):
            # Number-like strings are rejected; they must be stored as
            # real numbers instead.
            for t in [int, float]:
                if str_represents(value, t):
                    raise ValueError(
                        'Value {value} is put in as string '
                        'but can be interpreted as '
                        '{t}! Please convert '
                        'to {t} using '
                        '{t}(value) before '
                        'writing to the database OR change '
                        'to a different string.'
                        .format(value=value, t=t.__name__))

123 

124 

def str_represents(value, t=int):
    """Return True if ``t(value)`` succeeds, False if it raises ValueError.

    value: str
        Text to test.
    t: callable
        Converter to try, ``int`` by default.
    """
    try:
        t(value)
    except ValueError:
        return False
    else:
        return True

131 

132 

def connect(name, type='extract_from_name', create_indices=True,
            use_lock_file=True, append=True, serial=False):
    """Open a database and return the matching backend object.

    name: str, path object or other
        Filename or address of the database.  Path objects are converted
        to str.  ``None`` gives a plain ``Database()`` when the type is
        guessed.
    type: str
        One of 'json', 'db', 'postgresql' or 'mysql'
        (JSON, SQLite, PostgreSQL, MySQL).  The default,
        'extract_from_name', guesses from *name*: a non-string name means
        'json', a 'postgresql://' or 'postgres://' prefix means
        'postgresql', a 'mysql://' or 'mariadb://' prefix means 'mysql',
        and otherwise the file extension is used.
    create_indices: bool
        Forwarded to the SQLite backend.
    use_lock_file: bool
        Forwarded to the JSON and SQLite backends.  You can turn this
        off if you know what you are doing ...
    append: bool
        Use append=False to start a new database: an existing file called
        *name* is removed first (on rank 0 only).
    serial: bool
        Forwarded to the JSON and SQLite backends.

    Raises ValueError if the type cannot be guessed because *name* has no
    extension, or if the type is unknown.
    """

    if isinstance(name, PurePath):
        name = str(name)

    name_is_str = isinstance(name, str)

    if type == 'extract_from_name':
        if name is None:
            type = None
        elif not name_is_str:
            type = 'json'
        elif name.startswith(('postgresql://', 'postgres://')):
            type = 'postgresql'
        elif name.startswith(('mysql://', 'mariadb://')):
            type = 'mysql'
        else:
            type = os.path.splitext(name)[1][1:]
            if type == '':
                raise ValueError('No file extension or database type given')

    if type is None:
        return Database()

    # Starting from scratch: only one rank removes the old file.
    if (not append and world.rank == 0
            and name_is_str and os.path.isfile(name)):
        os.remove(name)

    # Server addresses are left alone; file names are made absolute.
    if name_is_str and type not in ('postgresql', 'mysql'):
        name = os.path.abspath(name)

    # Backends are imported lazily, only when they are requested.
    if type == 'json':
        from ase.db.jsondb import JSONDatabase
        return JSONDatabase(name, use_lock_file=use_lock_file, serial=serial)
    if type == 'db':
        from ase.db.sqlite import SQLite3Database
        return SQLite3Database(name, create_indices, use_lock_file,
                               serial=serial)
    if type == 'postgresql':
        from ase.db.postgresql import PostgreSQLDatabase
        return PostgreSQLDatabase(name)
    if type == 'mysql':
        from ase.db.mysql import MySQLDatabase
        return MySQLDatabase(name)
    raise ValueError('Unknown database type: ' + type)

193 

194 

def lock(method):
    """Decorator that runs a method while holding ``self.lock``.

    If ``self.lock`` is None the method is called directly; otherwise the
    lock is used as a context manager around the call.
    """
    @functools.wraps(method)
    def locked_method(self, *args, **kwargs):
        if self.lock is not None:
            with self.lock:
                return method(self, *args, **kwargs)
        return method(self, *args, **kwargs)
    return locked_method

205 

206 

def convert_str_to_int_float_or_str(value):
    """Safe eval(): turn text into an int, float or bool where possible.

    Tries ``int`` and then ``float``.  If both raise ValueError, the
    strings 'True' and 'False' become the corresponding bool and anything
    else is returned unchanged.
    """
    for convert in (int, float):
        try:
            return convert(value)
        except ValueError:
            continue
    booleans = {'True': True, 'False': False}
    return booleans.get(value, value)

217 

218 

def parse_selection(selection, **kwargs):
    """Translate a selection into keys and comparisons.

    selection: None, int, str or list
        * None or '' selects everything.
        * An int is treated as ``id=<int>``.
        * A str is split on ',' into expressions such as 'key=value',
          'key' or '1<key<=5'.
        * A list may contain such strings and/or (key, op, value) tuples.
    kwargs:
        Each ``key=value`` keyword adds the comparison (key, '=', value).

    Returns a tuple ``(keys, cmps)``: *keys* holds the bare expressions
    that are neither an element symbol nor a parsable formula, and *cmps*
    is a list of (key, op, value) tuples.  In *cmps*, element symbols are
    replaced by atomic numbers, 'age' becomes a 'ctime' comparison with
    the inverted operator, and 'formula' becomes one '=' comparison per
    element plus a 'natoms' comparison.

    Raises ValueError for 'formula' with an operator other than '=', and
    for a non-numeric value given to one of the numeric keys.
    """
    if selection is None or selection == '':
        expressions = []
    elif isinstance(selection, int):
        expressions = [('id', '=', selection)]
    elif isinstance(selection, list):
        expressions = selection
    else:
        expressions = [w.strip() for w in selection.split(',')]
    keys = []
    comparisons = []
    for expression in expressions:
        if isinstance(expression, (list, tuple)):
            # Already a (key, op, value) comparison.
            comparisons.append(expression)
            continue
        if expression.count('<') == 2:
            # Range such as 'a<key<b' or 'a<=key<b': record the lower
            # bound here; the remaining 'key<b' is handled below.
            value, expression = expression.split('<', 1)
            if expression[0] == '=':
                op = '>='
                expression = expression[1:]
            else:
                op = '>'
            key = expression.split('<', 1)[0]
            comparisons.append((key, op, value))
        # Two-character operators are tried first so that '<=' is not
        # mistaken for '<'.
        for op in ['!=', '<=', '>=', '<', '>', '=']:
            if op in expression:
                break
        else:  # no break
            if expression in atomic_numbers:
                comparisons.append((expression, '>', 0))
            else:
                try:
                    count = Formula(expression).count()
                except ValueError:
                    keys.append(expression)
                else:
                    # At least n atoms of each element in the formula.
                    comparisons.extend((symbol, '>', n - 1)
                                       for symbol, n in count.items())
            continue
        key, value = expression.split(op)
        comparisons.append((key, op, value))

    cmps = []
    for key, value in kwargs.items():
        comparisons.append((key, '=', value))

    for key, op, value in comparisons:
        if key == 'age':
            # An age limit is a limit on creation time in the opposite
            # direction.
            key = 'ctime'
            op = invop[op]
            value = now() - time_string_to_float(value)
        elif key == 'formula':
            if op != '=':
                raise ValueError('Use formula=...')
            f = Formula(value)
            count = f.count()
            cmps.extend((atomic_numbers[symbol], '=', n)
                        for symbol, n in count.items())
            key = 'natoms'
            value = len(f)
        elif key in atomic_numbers:
            key = atomic_numbers[key]
            value = int(value)
        elif isinstance(value, str):
            value = convert_str_to_int_float_or_str(value)
        if key in numeric_keys and not isinstance(value, (int, float)):
            msg = 'Wrong type for "{}{}{}" - must be a number'
            raise ValueError(msg.format(key, op, value))
        cmps.append((key, op, value))

    return keys, cmps

290 

291 

class Database:
    """Base class for all databases.

    Holds the backend-independent logic (argument checks, selection
    parsing, locking).  The hooks ``_select``, ``_get_row`` and
    ``_update`` are called here but not defined in this class.
    """

    def __init__(self, filename=None, create_indices=True,
                 use_lock_file=False, serial=False):
        """Database object.

        filename: str or other
            A str has '~' expanded; anything else is stored as given.
        create_indices: bool
            Stored on the instance as ``create_indices``.
        use_lock_file: bool
            If true and *filename* is a str, a lock on
            ``filename + '.lock'`` is created; otherwise no lock is used.
        serial: bool
            Let someone else handle parallelization. Default behavior is
            to interact with the database on the master only and then
            distribute results to all slaves.
        """
        filename_is_str = isinstance(filename, str)
        if filename_is_str:
            filename = os.path.expanduser(filename)
        self.filename = filename
        self.create_indices = create_indices
        self.serial = serial

        self.lock = None
        if use_lock_file and filename_is_str:
            self.lock = Lock(filename + '.lock', world=DummyMPI())

        # Description of columns and other stuff:
        self._metadata: Dict[str, Any] = None

    @property
    def metadata(self) -> Dict[str, Any]:
        """Metadata dictionary; not implemented in the base class."""
        raise NotImplementedError

    @parallel_function
    @lock
    def write(self, atoms, key_value_pairs={}, data={}, id=None, **kwargs):
        """Write atoms to database with key-value pairs.

        atoms: Atoms object
            Write atomic numbers, positions, unit cell and boundary
            conditions. If a calculator is attached, write also already
            calculated properties such as the energy and forces.
            ``None`` is replaced by an empty Atoms object.
        key_value_pairs: dict
            Dictionary of key-value pairs. Values must be strings or
            numbers.
        data: dict
            Extra stuff (not for searching).
        id: int
            Overwrite existing row.

        Key-value pairs can also be set using keyword arguments::

            connection.write(atoms, name='ABC', frequency=42.0)

        Returns integer id of the new row.
        """
        if atoms is None:
            atoms = Atoms()

        # Work on a merged copy so the caller's dictionary is untouched.
        merged = dict(key_value_pairs, **kwargs)
        return self._write(atoms, merged, data, id)

    def _write(self, atoms, key_value_pairs, data, id=None):
        """Validate the key-value pairs and return 1 (nothing is stored)."""
        check(key_value_pairs)
        return 1

    @parallel_function
    @lock
    def reserve(self, **key_value_pairs):
        """Write empty row if not already present.

        Usage::

            id = conn.reserve(key1=value1, key2=value2, ...)

        Write an empty row with the given key-value pairs and
        return the integer id. If such a row already exists, don't write
        anything and return None.
        """
        wanted = [(key, '=', value)
                  for key, value in key_value_pairs.items()]
        for _ in self._select([], wanted):
            # At least one matching row exists already.
            return None

        atoms = Atoms()

        calc_name = key_value_pairs.pop('calculator', None)
        if calc_name:
            # Allow use of calculator key
            assert calc_name.lower() == calc_name

            # Fake calculator class:
            class Fake:
                name = calc_name

                def todict(self):
                    return {}

                def check_state(self, atoms):
                    return ['positions']

            atoms.calc = Fake()

        return self._write(atoms, key_value_pairs, {}, None)

    def __delitem__(self, id):
        """Delete the row with the given id."""
        self.delete([id])

    def get_atoms(self, selection=None,
                  add_additional_information=False, **kwargs):
        """Get Atoms object.

        selection: int, str or list
            See the select() method.
        add_additional_information: bool
            Put key-value pairs and data into Atoms.info dictionary.

        In addition, one can use keyword arguments to select specific
        key-value pairs.
        """
        matched_row = self.get(selection, **kwargs)
        return matched_row.toatoms(add_additional_information)

    def __getitem__(self, selection):
        """Shorthand for ``get(selection)``."""
        return self.get(selection)

    def get(self, selection=None, **kwargs):
        """Select a single row and return it.

        selection: int, str or list
            See the select() method.

        Raises KeyError if nothing matches; asserts that no more than
        one row matches.
        """
        # Two rows are enough to tell whether the match is unique.
        matches = list(self.select(selection, limit=2, **kwargs))
        if len(matches) == 0:
            raise KeyError('no match')
        assert len(matches) == 1, 'more than one row matched'
        return matches[0]

    @parallel_generator
    def select(self, selection=None, filter=None, explain=False,
               verbosity=1, limit=None, offset=0, sort=None,
               include_data=True, columns='all', **kwargs):
        """Select rows.

        Return AtomsRow iterator with results. Selection is done
        using key-value pairs and the special keys:

            formula, age, user, calculator, natoms, energy, magmom
            and/or charge.

        selection: int, str or list
            Can be:

            * an integer id
            * a string like 'key=value', where '=' can also be one of
              '<=', '<', '>', '>=' or '!='.
            * a string like 'key'
            * comma separated strings like 'key1<value1,key2=value2,key'
            * list of strings or tuples: [('charge', '=', 1)].
        filter: function
            A function that takes as input a row and returns True or
            False.
        explain: bool
            Explain query plan.
        verbosity: int
            Possible values: 0, 1 or 2.
        limit: int or None
            Limit selection.
        offset: int
            Offset into selected rows.
        sort: str
            Sort rows after key. Prepend with minus sign for a descending
            sort.
        include_data: bool
            Use include_data=False to skip reading data from rows.
        columns: 'all' or list of str
            Specify which columns from the SQL table to include.
            For example, if only the row id and the energy is needed,
            queries can be speeded up by setting columns=['id', 'energy'].
        """
        if sort:
            # Age runs opposite to creation time; 'user' is stored under
            # the name 'username'.
            if sort == 'age':
                sort = '-ctime'
            elif sort == '-age':
                sort = 'ctime'
            elif sort.lstrip('-') == 'user':
                sort += 'name'

        keys, cmps = parse_selection(selection, **kwargs)
        candidates = self._select(keys, cmps, explain=explain,
                                  verbosity=verbosity,
                                  limit=limit, offset=offset, sort=sort,
                                  include_data=include_data,
                                  columns=columns)
        for row in candidates:
            if filter is None or filter(row):
                yield row

    def count(self, selection=None, **kwargs):
        """Count rows.

        See the select() method for the selection syntax. Use db.count()
        or len(db) to count all rows.
        """
        return sum(1 for _ in self.select(selection, **kwargs))

    def __len__(self):
        """Number of rows in the database."""
        return self.count()

    @parallel_function
    @lock
    def update(self, id, atoms=None, delete_keys=[], data=None,
               **add_key_value_pairs):
        """Update and/or delete key-value pairs of row(s).

        id: int
            ID of row to update.
        atoms: Atoms object
            Optionally update the Atoms data (positions, cell, ...).
        data: dict
            Data dict to be added to the existing data.
        delete_keys: list of str
            Keys to remove.

        Use keyword arguments to add new key-value pairs.

        Returns number of key-value pairs added and removed.

        Raises ValueError if *id* is a list and TypeError if it is any
        other non-integer.
        """
        if not isinstance(id, numbers.Integral):
            if isinstance(id, list):
                err = ('First argument must be an int and not a list.\n'
                       'Do something like this instead:\n\n'
                       'with db:\n'
                       ' for id in ids:\n'
                       ' db.update(id, ...)')
                raise ValueError(err)
            raise TypeError('id must be an int')

        check(add_key_value_pairs)

        row = self._get_row(id)
        kvp = row.key_value_pairs

        # Count removals and additions from the size of the dictionary
        # before and after each step.
        size_before = len(kvp)
        for key in delete_keys:
            kvp.pop(key, None)
        size_after_delete = len(kvp)
        kvp.update(add_key_value_pairs)
        n_removed = size_before - size_after_delete
        n_added = len(kvp) - size_after_delete

        # Merge the new data into what the row already has.
        merged_data = row.get('data', {})
        if data:
            merged_data.update(data)
        if not merged_data:
            merged_data = None

        if atoms:
            old_row = row
            row = AtomsRow(atoms)
            # Copy over data, kvp, ctime, user and id
            row._data = old_row._data
            row.__dict__.update(kvp)
            row._keys = list(kvp)
            row.ctime = old_row.ctime
            row.user = old_row.user
            row.id = id

        # New atoms, or a file named '*.json', means the whole row is
        # rewritten; otherwise only key-value pairs and data are updated.
        if atoms or os.path.splitext(self.filename)[1] == '.json':
            self._write(row, kvp, merged_data, row.id)
        else:
            self._update(row.id, kvp, merged_data)
        return n_added, n_removed

    def delete(self, ids):
        """Delete rows."""
        raise NotImplementedError

575 

576 

def time_string_to_float(s):
    """Convert a time string such as '3d' or '1w+2d' to years.

    s: float, int or str
        Numbers are returned unchanged.  A string is a sum of terms
        separated by '+'; each term is an integer followed by one of the
        unit codes in ``seconds`` ('s', 'm', 'h', 'd', 'w', 'M', 'y').
        Spaces are ignored and an 's' following a letter at the end of a
        term is dropped (so '3ds' means '3d').

    Returns the duration in units of YEAR.

    Raises ValueError if a term has no valid integer or unit.
    """
    if isinstance(s, (float, int)):
        return s
    s = s.replace(' ', '')
    if '+' in s:
        return sum(time_string_to_float(x) for x in s.split('+'))

    term = s
    # Drop a trailing plural 's', but keep a lone 's' (seconds) after digits.
    if len(s) >= 2 and s[-2].isalpha() and s[-1] == 's':
        s = s[:-1]

    # The first character always belongs to the number (it may be a sign);
    # the unit starts at the first non-digit after it.  The bound check
    # keeps the scan from running off the end of a string with no unit.
    i = 1
    while i < len(s) and s[i].isdigit():
        i += 1

    try:
        return seconds[s[i:]] * int(s[:i]) / YEAR
    except (KeyError, ValueError) as err:
        raise ValueError('Bad time string: {!r}'.format(term)) from err

589 

590 

def float_to_time_string(t, long=False):
    """Format a time given in years using the largest suitable unit.

    t: float
        Time in units of YEAR.
    long: bool
        If true, return e.g. '7.250 days'; otherwise e.g. '7d'.

    Units are tried from years down to seconds and the first one giving
    a value above 5 is used; seconds are used if none qualifies.
    """
    total_seconds = t * YEAR
    for unit in 'yMwdhms':
        amount = total_seconds / seconds[unit]
        if amount > 5:
            break
    if not long:
        return '{:.0f}{}'.format(round(amount), unit)
    return '{:.3f} {}s'.format(amount, longwords[unit])

601 

602 

def object_to_bytes(obj: Any) -> bytes:
    """Serialize Python object to bytes.

    Layout of the result: an 8-byte little-endian int64 holding the
    offset of the JSON text, then the raw bytes collected by o2b(), and
    finally the compact JSON text describing the object.
    """
    # The first part is a placeholder with the size of the header.
    chunks = [b'12345678']
    skeleton = o2b(obj, chunks)
    json_offset = np.array(sum(len(chunk) for chunk in chunks), np.int64)
    if not np.little_endian:
        json_offset.byteswap(True)
    chunks[0] = json_offset.tobytes()
    json_text = json.dumps(skeleton, separators=(',', ':'))
    chunks.append(json_text.encode())
    return b''.join(chunks)

614 

615 

def bytes_to_object(b: bytes) -> Any:
    """Deserialize bytes to Python object.

    The first 8 bytes are read as a little-endian int64 giving the
    position of the JSON text; the decoded JSON is then expanded by
    b2o(), which reads array data from *b*.
    """
    header = np.frombuffer(b[:8], np.int64)
    if not np.little_endian:
        header = header.byteswap()
    json_start = header.item()
    skeleton = json.loads(b[json_start:].decode())
    return b2o(skeleton, b)

624 

625 

def o2b(obj: Any, parts: List[bytes]):
    """Convert *obj* to a JSON-compatible structure.

    obj:
        int, float, bool, str, None, complex, dict, list, tuple,
        ndarray, or an object with a truthy ``ase_objtype`` attribute
        and a ``todict()`` method.
    parts: list of bytes
        Raw array data is appended here.  Each ndarray is replaced by
        ``{'__ndarray__': [shape, dtype name, offset]}`` where offset is
        the total length of *parts* before its bytes were appended.

    Tuples come back as lists and complex numbers as
    ``{'__complex__': [real, imag]}``.

    Raises ValueError for objects of any other type.
    """
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj
    if isinstance(obj, dict):
        return {key: o2b(value, parts) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [o2b(value, parts) for value in obj]
    if isinstance(obj, np.ndarray):
        assert obj.dtype != object, \
            'Cannot convert ndarray of type "object" to bytes.'
        offset = sum(len(part) for part in parts)
        if not np.little_endian:
            obj = obj.byteswap()
        parts.append(obj.tobytes())
        return {'__ndarray__': [obj.shape,
                                obj.dtype.name,
                                offset]}
    if isinstance(obj, complex):
        return {'__complex__': [obj.real, obj.imag]}
    # Default to None so that unsupported objects reach the ValueError
    # below instead of failing with AttributeError.
    objtype = getattr(obj, 'ase_objtype', None)
    if objtype:
        dct = o2b(obj.todict(), parts)
        dct['__ase_objtype__'] = objtype
        return dct
    raise ValueError('Objects of type {type} not allowed'
                     .format(type=type(obj)))

652 

653 

def b2o(obj: Any, b: bytes) -> Any:
    """Rebuild a Python object from its JSON-compatible structure.

    obj:
        Decoded JSON: int, float, bool, str, None, list or dict.
    b: bytes
        Buffer that ``'__ndarray__'`` entries point into.

    A dict with a ``'__complex__'`` entry becomes a complex number, one
    with an ``'__ndarray__'`` entry ``[shape, dtype name, offset]``
    becomes an array read from *b*, and one with an ``'__ase_objtype__'``
    entry is passed to create_ase_object().  Other dicts and lists are
    converted element by element.
    """
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj

    if isinstance(obj, list):
        return [b2o(item, b) for item in obj]

    assert isinstance(obj, dict)

    complex_parts = obj.get('__complex__')
    if complex_parts is not None:
        return complex(*complex_parts)

    array_spec = obj.get('__ndarray__')
    if array_spec is not None:
        shape, name, offset = array_spec
        dtype = np.dtype(name)
        nbytes = dtype.itemsize * np.prod(shape).astype(int)
        flat = np.frombuffer(b[offset:offset + nbytes], dtype)
        array = flat.reshape(shape)
        if not np.little_endian:
            array = array.byteswap()
        return array

    dct = {key: b2o(value, b) for key, value in obj.items()}
    objtype = dct.pop('__ase_objtype__', None)
    if objtype is not None:
        return create_ase_object(objtype, dct)
    return dct

682 return create_ase_object(objtype, dct)