Coverage for /builds/ase/ase/ase/utils/filecache.py : 95.92%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from pathlib import Path
2import json
3from collections.abc import MutableMapping, Mapping
4from contextlib import contextmanager
5from ase.io.jsonio import read_json, write_json
6from ase.io.jsonio import encode as encode_json
7from ase.io.ulm import ulmopen, NDArrayReader, Writer, InvalidULMFileError
8from ase.utils import opencew
9from ase.parallel import world
def missing(key):
    """Signal a missing cache entry by raising ``KeyError(key)``."""
    raise KeyError(key)
class Locked(Exception):
    """Raised when a cache key is currently locked by another writer."""
20# Note:
21#
22# The communicator handling is a complete hack.
23# We should entirely remove communicators from these objects.
24# (Actually: opencew() should not know about communicators.)
25# Then the caller is responsible for handling parallelism,
26# which makes life simpler for both the caller and us!
27#
28# Also, things like clean()/__del__ are not correctly implemented
29# in parallel. The reason why it currently "works" is that
30# we don't call those functions from Vibrations etc., or they do so
31# only for rank==0.
class JSONBackend:
    """Serialization backend storing cache values as JSON files.

    Plugged into the cache classes below via their ``backend``
    attribute; the ULM backend provides the same interface.
    """

    extension = '.json'
    # Raised when a cache file exists but cannot be decoded —
    # typically an empty, exclusively-created (locked) file.
    DecodeError = json.decoder.JSONDecodeError

    @staticmethod
    def open_for_writing(path, comm):
        """Exclusively create *path*; return None if it already exists."""
        return opencew(path, world=comm)

    @staticmethod
    def read(fname):
        """Return the value stored as JSON in *fname*."""
        return read_json(fname, always_array=False)

    @staticmethod
    def open_and_write(target, data, comm):
        """Dump *data* as JSON to *target* (rank 0 only)."""
        if comm.rank == 0:
            write_json(target, data)

    @staticmethod
    def write(fd, value):
        """JSON-encode *value* and write the bytes to open file *fd*."""
        fd.write(encode_json(value).encode('utf-8'))

    @classmethod
    def dump_cache(cls, path, dct, comm):
        """Write all of *dct* into a single combined JSON file at *path*."""
        return CombinedJSONCache.dump_cache(path, dct, comm)

    @classmethod
    def create_multifile_cache(cls, directory, comm):
        """Return an empty one-file-per-key JSON cache in *directory*."""
        return MultiFileJSONCache(directory, comm=comm)
class ULMBackend:
    """Serialization backend storing cache values as ULM files.

    Same interface as JSONBackend, but the data lives under the
    'cache' key of a ULM file.
    """

    extension = '.ulm'
    # Raised when a cache file exists but is not valid ULM —
    # typically an empty, exclusively-created (locked) file.
    DecodeError = InvalidULMFileError

    @staticmethod
    def open_for_writing(path, comm):
        """Exclusively create *path*; return a Writer, or None if taken."""
        fd = opencew(path, world=comm)
        if fd is None:
            return None
        return Writer(fd, 'w', '')

    @staticmethod
    def read(fname):
        """Return the value stored under 'cache' in ULM file *fname*."""
        with ulmopen(fname, 'r') as reader:
            data = reader._data['cache']
            # Arrays are lazy; materialize while the file is still open.
            if isinstance(data, NDArrayReader):
                data = data.read()
        return data

    @staticmethod
    def open_and_write(target, data, comm):
        """Dump *data* to ULM file *target* (rank 0 only)."""
        if comm.rank == 0:
            with ulmopen(target, 'w') as writer:
                writer.write('cache', data)

    @staticmethod
    def write(fd, value):
        """Write *value* under 'cache' to the open Writer *fd*."""
        fd.write('cache', value)

    @classmethod
    def dump_cache(cls, path, dct, comm):
        """Write all of *dct* into a single combined ULM file at *path*."""
        return CombinedULMCache.dump_cache(path, dct, comm)

    @classmethod
    def create_multifile_cache(cls, directory, comm):
        """Return an empty one-file-per-key ULM cache in *directory*."""
        return MultiFileULMCache(directory, comm=comm)
class CacheLock:
    """Write handle for an exclusively opened cache file.

    Yielded by the multi-file caches while they hold the file-based
    lock on *key*; ``save()`` consumes the handle by writing the value
    and closing the file.
    """

    def __init__(self, fd, key, backend):
        self.fd = fd
        self.key = key
        self.backend = backend

    def save(self, value):
        """Write *value* via the backend; the file is closed no matter what."""
        try:
            self.backend.write(self.fd, value)
        except Exception as ex:
            raise RuntimeError(f'Failed to save {value} to cache') from ex
        finally:
            self.fd.close()
class _MultiFileCacheTemplate(MutableMapping):
    """Mutable cache storing each key as its own file in a directory.

    Subclasses must provide a ``backend`` class attribute (JSON or ULM)
    supplying the file extension and the (de)serialization routines.
    Writes are guarded by a file-based lock: the exclusively created
    file itself acts as the lock for its key.
    """
    writable = True

    def __init__(self, directory, comm=world):
        self.directory = Path(directory)
        self.comm = comm

    def _filename(self, key):
        # One file per key: cache.<key><extension> inside the directory.
        return self.directory / (f'cache.{key}' + self.backend.extension)

    def _glob(self):
        # All candidate cache files; keys are recovered in __iter__.
        return self.directory.glob('cache.*' + self.backend.extension)

    def __iter__(self):
        for path in self._glob():
            # Split only on the first dot so keys may contain dots.
            cache, key = path.stem.split('.', 1)
            if cache != 'cache':
                continue
            yield key

    def __len__(self):
        # Very inefficient this, but not a big usecase.
        return len(list(self._glob()))

    @contextmanager
    def lock(self, key):
        """Context manager yielding a CacheLock for *key*.

        Yields None when the key is already locked: opencew() returns
        None if the file exists, so file creation doubles as locking.
        """
        if self.comm.rank == 0:
            self.directory.mkdir(exist_ok=True, parents=True)
        path = self._filename(key)
        fd = self.backend.open_for_writing(path, self.comm)
        try:
            if fd is None:
                yield None
            else:
                yield CacheLock(fd, key, self.backend)
        finally:
            # Close unconditionally; CacheLock.save() may already have
            # closed fd, in which case this is a redundant close.
            if fd is not None:
                fd.close()

    def __setitem__(self, key, value):
        with self.lock(key) as handle:
            if handle is None:
                # Someone else holds the lock on this key.
                raise Locked(key)
            handle.save(value)

    def __getitem__(self, key):
        path = self._filename(key)
        try:
            return self.backend.read(path)
        except FileNotFoundError:
            missing(key)
        except self.backend.DecodeError:
            # May be partially written, which typically means empty
            # because the file was locked with exclusive-write-open.
            #
            # Since we decide what keys we have based on which files exist,
            # we are obligated to return a value for this case too.
            # So we return None.
            return None

    def __delitem__(self, key):
        try:
            self._filename(key).unlink()
        except FileNotFoundError:
            missing(key)

    def combine(self):
        """Merge all entries into a single combined cache file and
        delete the per-key files; return the combined cache."""
        cache = self.backend.dump_cache(self.directory, dict(self),
                                        comm=self.comm)
        assert set(cache) == set(self)
        self.clear()
        assert len(self) == 0
        return cache

    def split(self):
        # Already one file per key — nothing to do.
        return self

    def filecount(self):
        """Return the number of cache files on disk."""
        return len(self)

    def strip_empties(self):
        """Delete all entries whose value is None (e.g. leftover lock
        files); return how many were removed."""
        empties = [key for key, value in self.items() if value is None]
        for key in empties:
            del self[key]
        return len(empties)
class _CombinedCacheTemplate(Mapping):
    """Read-only cache keeping every key/value pair in one file.

    Subclasses supply a ``backend`` class attribute (JSON or ULM)
    defining the on-disk format; lookups are served from an
    in-memory dict.
    """
    writable = False

    def __init__(self, directory, dct, comm=world):
        self.directory = Path(directory)
        self.comm = comm
        self._dct = dict(dct)

    @property
    def _filename(self):
        return self.directory / ('combined' + self.backend.extension)

    def filecount(self):
        """Return 1 if the combined file exists on disk, else 0."""
        return int(self._filename.is_file())

    def __len__(self):
        return len(self._dct)

    def __iter__(self):
        return iter(self._dct)

    def __getitem__(self, index):
        return self._dct[index]

    def _dump(self):
        # Refuse to clobber an existing combined file.
        target = self._filename
        if target.exists():
            raise RuntimeError(f'Already exists: {target}')
        self.directory.mkdir(exist_ok=True, parents=True)
        self.backend.open_and_write(target, self._dct, comm=self.comm)

    @classmethod
    def dump_cache(cls, path, dct, comm=world):
        """Create a combined cache holding *dct* and write it to disk."""
        cache = cls(path, dct, comm=comm)
        cache._dump()
        return cache

    @classmethod
    def load(cls, path, comm):
        # XXX Very hacky this one
        cache = cls(path, {}, comm=comm)
        cache._dct.update(cls.backend.read(cache._filename))
        return cache

    def clear(self):
        """Delete the combined file and drop all in-memory entries."""
        self._filename.unlink()
        self._dct.clear()

    def combine(self):
        # Already combined — nothing to do.
        return self

    def split(self):
        """Convert back into a one-file-per-key cache; this combined
        file is removed afterwards."""
        cache = self.backend.create_multifile_cache(self.directory,
                                                    comm=self.comm)
        assert len(cache) == 0
        cache.update(self)
        assert set(cache) == set(self)
        self.clear()
        return cache
class MultiFileJSONCache(_MultiFileCacheTemplate):
    """One-file-per-key cache using the JSON backend."""
    backend = JSONBackend()
class MultiFileULMCache(_MultiFileCacheTemplate):
    """One-file-per-key cache using the ULM backend."""
    backend = ULMBackend()
class CombinedJSONCache(_CombinedCacheTemplate):
    """Single-file combined cache using the JSON backend."""
    backend = JSONBackend()
class CombinedULMCache(_CombinedCacheTemplate):
    """Single-file combined cache using the ULM backend."""
    backend = ULMBackend()
def get_json_cache(directory, comm=world):
    """Return the JSON cache stored in *directory*.

    Prefers a combined (single-file) cache if one exists on disk;
    otherwise falls back to an (empty or partially filled)
    one-file-per-key cache.
    """
    try:
        return CombinedJSONCache.load(directory, comm=comm)
    except FileNotFoundError:
        return MultiFileJSONCache(directory, comm=comm)