Coverage for /builds/ase/ase/ase/db/cli.py : 62.04%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
import json
import sys
from collections import defaultdict
from contextlib import contextmanager
from pathlib import Path
from typing import Iterable, Iterator, Optional

import ase.io
from ase.db import connect
from ase.db.core import convert_str_to_int_float_or_str
from ase.db.row import row2dct
from ase.db.table import Table, all_columns
from ase.utils import plural
def count_keys(db, query):
    """Print, for every key, how many rows matching *query* carry it.

    Keys come from each row's ``_keys`` and are printed in first-seen
    order, one per line, as ``key: count`` with the counts aligned in a
    column.  Nothing is printed when no matching row has any key.

    Parameters
    ----------
    db:
        Database object providing ``select(query)``.
    query:
        Selection passed unchanged to ``db.select``.
    """
    keys = defaultdict(int)
    for row in db.select(query):
        for key in row._keys:
            keys[key] += 1

    # default=0: with no keys at all (e.g. an empty selection) a bare max()
    # over an empty generator would raise ValueError.
    n = max((len(key) for key in keys), default=0) + 1
    for key, number in keys.items():
        print('{:{}} {}'.format(key + ':', n, number))
def main(args):
    """Run the single database action requested by the namespace *args*.

    The branches below are checked in order and the first one that applies
    performs its action and returns: analyse, show keys, show values, add
    from file, count, insert into another database, explain, show/set
    metadata, add/delete key-value pairs, delete rows, plot, JSON dump,
    long listing, web browser.  If none applies, a table of the matching
    rows is printed (as CSV when ``args.csv`` is set).

    Raises
    ------
    ValueError
        If an entry of ``args.add_key_value_pairs`` has no ``=``.
    """
    verbosity = 1 - args.quiet + args.verbose
    query = ','.join(args.query)

    if args.sort.endswith('-'):
        # Allow using "key-" instead of "-key" for reverse sorting
        args.sort = '-' + args.sort[:-1]

    if query.isdigit():
        # A purely numeric query is passed on as an int (presumably a row
        # id for db.select/db.get -- confirm against ase.db.core).
        query = int(query)

    add_key_value_pairs = {}
    if args.add_key_value_pairs:
        for pair in args.add_key_value_pairs.split(','):
            # Split on the first '=' only, so values may themselves
            # contain '=' (a plain split('=') would fail to unpack).
            key, sep, value = pair.partition('=')
            if not sep:
                raise ValueError(
                    'Expected key=value, got {!r}'.format(pair))
            add_key_value_pairs[key] = convert_str_to_int_float_or_str(value)

    if args.delete_keys:
        delete_keys = args.delete_keys.split(',')
    else:
        delete_keys = []

    db = connect(args.database, use_lock_file=not args.no_lock_file)

    def out(*msgs):
        # Informational output, silenced when verbosity drops to 0 or below.
        if verbosity > 0:
            print(*msgs)

    if args.analyse:
        db.analyse()
        return

    if args.show_keys:
        count_keys(db, query)
        return

    if args.show_values:
        keys = args.show_values.split(',')
        values = {key: defaultdict(int) for key in keys}
        numbers = set()  # keys that had at least one non-string value
        for row in db.select(query, include_data=False):
            kvp = row.key_value_pairs
            for key in keys:
                value = kvp.get(key)
                if value is not None:
                    values[key][value] += 1
                    if not isinstance(value, str):
                        numbers.add(key)

        n = max(len(key) for key in keys) + 1
        for key in keys:
            vals = values[key]
            if key in numbers:
                # Numeric keys: show the range instead of every value.
                print('{:{}} [{}..{}]'
                      .format(key + ':', n, min(vals), max(vals)))
            else:
                print('{:{}} {}'
                      .format(key + ':', n,
                              ', '.join('{}({})'.format(v, count)
                                        for v, count in vals.items())))
        return

    if args.add_from_file:
        filename = args.add_from_file
        configs = ase.io.read(filename)
        if not isinstance(configs, list):
            configs = [configs]
        for atoms in configs:
            db.write(atoms, key_value_pairs=add_key_value_pairs)
        out('Added ' + plural(len(configs), 'row'))
        return

    if args.count:
        n = db.count(query)
        print('%s' % plural(n, 'row'))
        return

    if args.insert_into:
        if args.limit == -1:
            args.limit = 0

        progressbar = no_progressbar
        length = None

        if args.progress_bar:
            # Try to import the one from click.
            # People using ase.db will most likely have flask installed
            # and therefore also click.
            try:
                from click import progressbar
            except ImportError:
                pass
            else:
                length = db.count(query)

        nkvp = 0  # net number of key-value pairs added to copied rows
        nrows = 0
        with connect(args.insert_into,
                     use_lock_file=not args.no_lock_file) as db2:
            with progressbar(db.select(query,
                                       sort=args.sort,
                                       limit=args.limit,
                                       offset=args.offset),
                             length=length) as rows:
                for row in rows:
                    kvp = row.get('key_value_pairs', {})
                    nkvp -= len(kvp)
                    kvp.update(add_key_value_pairs)
                    nkvp += len(kvp)
                    if args.strip_data:
                        db2.write(row.toatoms(), **kvp)
                    else:
                        db2.write(row, data=row.get('data'), **kvp)
                    nrows += 1

        out('Added %s (%s updated)' %
            (plural(nkvp, 'key-value pair'),
             plural(len(add_key_value_pairs) * nrows - nkvp, 'pair')))
        out('Inserted %s' % plural(nrows, 'row'))
        return

    if args.limit == -1:
        args.limit = 20

    if args.explain:
        for row in db.select(query, explain=True,
                             verbosity=verbosity,
                             limit=args.limit, offset=args.offset):
            print(row['explain'])
        return

    if args.show_metadata:
        print(json.dumps(db.metadata, sort_keys=True, indent=4))
        return

    if args.set_metadata:
        with open(args.set_metadata) as fd:
            db.metadata = json.load(fd)
        return

    if add_key_value_pairs or delete_keys:
        # Only the ids are needed, so skip loading the data blobs.
        ids = [row['id'] for row in db.select(query, include_data=False)]
        M = 0  # key-value pairs added
        N = 0  # key-value pairs removed
        with db:
            for id_ in ids:
                m, n = db.update(id_, delete_keys=delete_keys,
                                 **add_key_value_pairs)
                M += m
                N += n
        out('Added %s (%s updated)' %
            (plural(M, 'key-value pair'),
             plural(len(add_key_value_pairs) * len(ids) - M, 'pair')))
        out('Removed', plural(N, 'key-value pair'))

        return

    if args.delete:
        ids = [row['id'] for row in db.select(query, include_data=False)]
        if ids and not args.yes:
            msg = 'Delete %s? (yes/No): ' % plural(len(ids), 'row')
            if input(msg).lower() != 'yes':
                return
        db.delete(ids)
        out('Deleted %s' % plural(len(ids), 'row'))
        return

    if args.plot:
        # Format: "tag1,tag2:xkey,ykey1,ykey2" or just "xkey,ykey1,...".
        if ':' in args.plot:
            tags, keys = args.plot.split(':')
            tags = tags.split(',')
        else:
            tags = []
            keys = args.plot
        keys = keys.split(',')
        plots = defaultdict(list)
        X = {}  # string x-values mapped to integer positions
        labels = []
        for row in db.select(query, sort=args.sort, include_data=False):
            name = ','.join(str(row[tag]) for tag in tags)
            x = row.get(keys[0])
            if x is not None:
                if isinstance(x, str):
                    if x not in X:
                        X[x] = len(X)
                        labels.append(x)
                    x = X[x]
                plots[name].append([x] + [row.get(key) for key in keys[1:]])
        import matplotlib.pyplot as plt
        for name, plot in plots.items():
            xyy = list(zip(*plot))
            x = xyy[0]
            for y, key in zip(xyy[1:], keys[1:]):
                plt.plot(x, y, label=name + ':' + key)
        if X:
            plt.xticks(range(len(labels)), labels, rotation=90)
        plt.legend()
        plt.show()
        return

    if args.json:
        row = db.get(query)
        db2 = connect(sys.stdout, 'json', use_lock_file=False)
        kvp = row.get('key_value_pairs', {})
        db2.write(row, data=row.get('data'), **kvp)
        return

    if args.long:
        row = db.get(query)
        print(row2str(row))
        return

    if args.open_web_browser:
        try:
            import flask  # noqa
        except ImportError:
            print('Please install Flask: python3 -m pip install flask')
            return
        check_jsmol()
        import ase.db.app as app
        app.add_project(db)
        # NOTE(review): debug=True on host 0.0.0.0 exposes Flask's
        # interactive debugger to the whole network; the Flask docs warn
        # this allows arbitrary code execution.  Consider 127.0.0.1.
        app.app.run(host='0.0.0.0', debug=True)
        return

    columns = list(all_columns)
    c = args.columns
    if c and c.startswith('++'):
        # "++" adds every key found in the selected rows.
        keys = set()
        for row in db.select(query,
                             limit=args.limit, offset=args.offset,
                             include_data=False):
            keys.update(row._keys)
        columns.extend(keys)
        if c[2:3] == ',':
            c = c[3:]
        else:
            c = ''
    if c:
        # Leading '+'/'-' modifies the default columns; otherwise the
        # given list replaces them.
        if c[0] == '+':
            c = c[1:]
        elif c[0] != '-':
            columns = []
        for col in c.split(','):
            if not col:
                # Tolerate stray commas ("a,,b", "a,", "+"); indexing
                # col[0] on an empty entry used to raise IndexError.
                continue
            if col.startswith('-'):
                columns.remove(col[1:])
            else:
                columns.append(col.lstrip('+'))

    table = Table(db, verbosity=verbosity, cut=args.cut)
    table.select(query, columns, args.sort, args.limit, args.offset)
    if args.csv:
        table.write_csv()
    else:
        table.write(query)
def row2str(row) -> str:
    """Render a database row as multi-line text (used by ``main`` for
    ``args.long``).

    The text starts with the formula and a unit-cell table with one line
    per axis: periodicity, cell vector, length and angle.  It then has
    stress, dipole, constraints and data sections when ``row2dct``
    provides them, and ends with a Key | Description | Value table built
    from ``t['table']``.
    """
    t = row2dct(row, key_descriptions={})
    # The header columns must be exactly as wide as the fields of ``fmt``:
    # 4 (axis), 8 (periodic), 11 (x, y, z) and 10 (length, angle).
    S = [t['formula'] + ':',
         'Unit cell in Ang:',
         'axis|periodic|          x|          y|          z|' +
         '    length|     angle']
    fmt = ('   {0}|     {1}|{2[0]:>11}|{2[1]:>11}|{2[2]:>11}|' +
           '{3:>10}|{4:>10}')
    cell_rows = zip(row.pbc, t['cell'], t['lengths'], t['angles'])
    for c, (p, axis, L, A) in enumerate(cell_rows, start=1):
        S.append(fmt.format(c, [' no', 'yes'][p], axis, L, A))
    S.append('')

    if 'stress' in t:
        S += ['Stress tensor (xx, yy, zz, zy, zx, yx) in eV/Ang^3:',
              ' {}\n'.format(t['stress'])]

    if 'dipole' in t:
        S.append('Dipole moment in e*Ang: ({})\n'.format(t['dipole']))

    if 'constraints' in t:
        S.append('Constraints: {}\n'.format(t['constraints']))

    if 'data' in t:
        S.append('Data: {}\n'.format(t['data']))

    table = t['table']
    # When ``t['table']`` is empty a bare max() over the entry widths would
    # raise ValueError; fall back to the widths of the header labels.
    width0 = max(3, max((len(entry[0]) for entry in table), default=0))
    width1 = max(11, max((len(entry[1]) for entry in table), default=0))
    S.append('{:{}} | {:{}} | Value'
             .format('Key', width0, 'Description', width1))
    for key, desc, value in table:
        S.append('{:{}} | {:{}} | {}'
                 .format(key, width0, desc, width1, value))
    return '\n'.join(S)
@contextmanager
def no_progressbar(iterable: Iterable,
                   length: Optional[int] = None) -> Iterator[Iterable]:
    """Stand-in progress bar that displays nothing.

    It takes the same ``(iterable, length=...)`` arguments that ``main``
    passes to its progress bar.  The iterable is yielded unchanged and
    *length* is ignored.
    """
    yield iterable
def check_jsmol():
    """Warn on stderr if the JSmol viewer is missing next to this module.

    Checks for ``static/jsmol/JSmol.min.js`` relative to this file.  If it
    is not a regular file, prints installation instructions to
    ``sys.stderr``.  Returns nothing; the check is advisory only.
    """
    import textwrap

    static = Path(__file__).parent / 'static'
    if not (static / 'jsmol/JSmol.min.js').is_file():
        # dedent() strips the source-code indentation so the message is
        # printed flush-left instead of inheriting this block's indent.
        message = textwrap.dedent("""
            WARNING:
                You don't have jsmol on your system.

                Download Jmol-*-binary.tar.gz from
                https://sourceforge.net/projects/jmol/files/Jmol/,
                extract jsmol.zip, unzip it and create a soft-link:

                    $ tar -xf Jmol-*-binary.tar.gz
                    $ unzip jmol-*/jsmol.zip
                    $ ln -s $PWD/jsmol {static}/jsmol
            """).format(static=static)
        print(message, file=sys.stderr)