Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from pathlib import Path 

2import json 

3from collections.abc import MutableMapping, Mapping 

4from contextlib import contextmanager 

5from ase.io.jsonio import read_json, write_json 

6from ase.io.jsonio import encode as encode_json 

7from ase.io.ulm import ulmopen, NDArrayReader, Writer, InvalidULMFileError 

8from ase.utils import opencew 

9from ase.parallel import world 

10 

11 

def missing(key):
    """Signal a missing cache entry by raising KeyError for *key*."""
    raise KeyError(key)

14 

15 

class Locked(Exception):
    """Raised when a cache entry is exclusively held by another writer."""

18 

19 

20# Note: 

21# 

22# The communicator handling is a complete hack. 

23# We should entirely remove communicators from these objects. 

24# (Actually: opencew() should not know about communicators.) 

25# Then the caller is responsible for handling parallelism, 

26# which makes life simpler for both the caller and us! 

27# 

28# Also, things like clean()/__del__ are not correctly implemented 

29# in parallel. The reason why it currently "works" is that 

30# we don't call those functions from Vibrations etc., or they do so 

31# only for rank==0. 

32 

33 

class JSONBackend:
    """Storage backend that serializes cache values as JSON files."""

    extension = '.json'
    DecodeError = json.decoder.JSONDecodeError

    @staticmethod
    def open_for_writing(path, comm):
        # Exclusive-create open; returns None when the file already exists.
        return opencew(path, world=comm)

    @staticmethod
    def read(fname):
        """Load and return the value stored in *fname*."""
        return read_json(fname, always_array=False)

    @staticmethod
    def open_and_write(target, data, comm):
        # Only the master rank performs the actual write.
        if comm.rank == 0:
            write_json(target, data)

    @staticmethod
    def write(fd, value):
        """Encode *value* as JSON and write it to the open file *fd*."""
        payload = encode_json(value).encode('utf-8')
        fd.write(payload)

    @classmethod
    def dump_cache(cls, path, dct, comm):
        """Write *dct* as a single combined JSON cache at *path*."""
        return CombinedJSONCache.dump_cache(path, dct, comm)

    @classmethod
    def create_multifile_cache(cls, directory, comm):
        """Return an empty one-file-per-key JSON cache in *directory*."""
        return MultiFileJSONCache(directory, comm=comm)

62 

63 

class ULMBackend:
    """Storage backend that serializes cache values in ASE's ULM format."""

    extension = '.ulm'
    DecodeError = InvalidULMFileError

    @staticmethod
    def open_for_writing(path, comm):
        # Exclusive-create open; None means another process holds the file.
        fd = opencew(path, world=comm)
        if fd is not None:
            return Writer(fd, 'w', '')

    @staticmethod
    def read(fname):
        """Load and return the value stored under 'cache' in *fname*."""
        with ulmopen(fname, 'r') as reader:
            data = reader._data['cache']
            # Array data is read lazily; materialize it while the
            # underlying file is still open.
            if isinstance(data, NDArrayReader):
                return data.read()
        return data

    @staticmethod
    def open_and_write(target, data, comm):
        # Only the master rank performs the actual write.
        if comm.rank == 0:
            with ulmopen(target, 'w') as writer:
                writer.write('cache', data)

    @staticmethod
    def write(fd, value):
        """Store *value* under the 'cache' key of the open ULM writer."""
        fd.write('cache', value)

    @classmethod
    def dump_cache(cls, path, dct, comm):
        """Write *dct* as a single combined ULM cache at *path*."""
        return CombinedULMCache.dump_cache(path, dct, comm)

    @classmethod
    def create_multifile_cache(cls, directory, comm):
        """Return an empty one-file-per-key ULM cache in *directory*."""
        return MultiFileULMCache(directory, comm=comm)

99 

100 

class CacheLock:
    """Holds an exclusively opened cache file awaiting a value for one key."""

    def __init__(self, fd, key, backend):
        self.backend = backend
        self.key = key
        self.fd = fd

    def save(self, value):
        """Write *value* through the backend, then close the file.

        The file is closed no matter what; any write failure is wrapped
        in a RuntimeError with the original exception chained.
        """
        try:
            self.backend.write(self.fd, value)
        except Exception as ex:
            raise RuntimeError(f'Failed to save {value} to cache') from ex
        finally:
            self.fd.close()

114 

115 

class _MultiFileCacheTemplate(MutableMapping):
    """Mutable cache that keeps each key in its own file under a directory.

    The concrete on-disk format comes from the ``backend`` class
    attribute that subclasses provide.
    """

    writable = True

    def __init__(self, directory, comm=world):
        self.directory = Path(directory)
        self.comm = comm

    def _filename(self, key):
        # Entries live in files named cache.<key><extension>.
        return self.directory / (f'cache.{key}' + self.backend.extension)

    def _glob(self):
        return self.directory.glob('cache.*' + self.backend.extension)

    def __iter__(self):
        for entry in self._glob():
            prefix, key = entry.stem.split('.', 1)
            if prefix == 'cache':
                yield key

    def __len__(self):
        # Very inefficient, but not a big usecase.
        return sum(1 for _ in self._glob())

    @contextmanager
    def lock(self, key):
        """Yield a CacheLock for *key*, or None if it is already locked."""
        if self.comm.rank == 0:
            self.directory.mkdir(exist_ok=True, parents=True)
        fd = self.backend.open_for_writing(self._filename(key), self.comm)
        try:
            yield None if fd is None else CacheLock(fd, key, self.backend)
        finally:
            if fd is not None:
                fd.close()

    def __setitem__(self, key, value):
        with self.lock(key) as handle:
            if handle is None:
                raise Locked(key)
            handle.save(value)

    def __getitem__(self, key):
        target = self._filename(key)
        try:
            return self.backend.read(target)
        except FileNotFoundError:
            missing(key)
        except self.backend.DecodeError:
            # May be partially written, which typically means empty
            # because the file was locked with exclusive-write-open.
            #
            # Since we decide what keys we have based on which files exist,
            # we are obligated to return a value for this case too.
            # So we return None.
            return None

    def __delitem__(self, key):
        try:
            self._filename(key).unlink()
        except FileNotFoundError:
            missing(key)

    def combine(self):
        """Merge all entries into one combined cache file and return it."""
        combined = self.backend.dump_cache(self.directory, dict(self),
                                           comm=self.comm)
        assert set(combined) == set(self)
        self.clear()
        assert len(self) == 0
        return combined

    def split(self):
        # Already split across files; nothing to do.
        return self

    def filecount(self):
        return len(self)

    def strip_empties(self):
        """Delete all entries whose value is None; return how many."""
        blank_keys = [key for key, value in self.items() if value is None]
        for key in blank_keys:
            del self[key]
        return len(blank_keys)

201 

202 

class _CombinedCacheTemplate(Mapping):
    """Read-only cache holding all key/value pairs in one combined file.

    The concrete on-disk format comes from the ``backend`` class
    attribute that subclasses provide.
    """

    writable = False

    def __init__(self, directory, dct, comm=world):
        self.directory = Path(directory)
        self._dct = dict(dct)
        self.comm = comm

    @property
    def _filename(self):
        return self.directory / ('combined' + self.backend.extension)

    def filecount(self):
        # Either the combined file exists (1) or it does not (0).
        return int(self._filename.is_file())

    def __len__(self):
        return len(self._dct)

    def __iter__(self):
        return iter(self._dct)

    def __getitem__(self, index):
        return self._dct[index]

    def _dump(self):
        target = self._filename
        if target.exists():
            raise RuntimeError(f'Already exists: {target}')
        self.directory.mkdir(exist_ok=True, parents=True)
        self.backend.open_and_write(target, self._dct, comm=self.comm)

    @classmethod
    def dump_cache(cls, path, dct, comm=world):
        """Create a combined cache at *path* holding *dct* and write it."""
        combined = cls(path, dct, comm=comm)
        combined._dump()
        return combined

    @classmethod
    def load(cls, path, comm):
        """Read an existing combined cache file from *path*."""
        # XXX Very hacky this one
        combined = cls(path, {}, comm=comm)
        combined._dct.update(cls.backend.read(combined._filename))
        return combined

    def clear(self):
        self._filename.unlink()
        self._dct.clear()

    def combine(self):
        # Already combined; nothing to do.
        return self

    def split(self):
        """Explode into a one-file-per-key cache; removes the combined file."""
        multi = self.backend.create_multifile_cache(self.directory,
                                                    comm=self.comm)
        assert len(multi) == 0
        multi.update(self)
        assert set(multi) == set(self)
        self.clear()
        return multi

263 

264 

class MultiFileJSONCache(_MultiFileCacheTemplate):
    """Writable one-file-per-key cache storing values as JSON."""

    backend = JSONBackend()

267 

268 

class MultiFileULMCache(_MultiFileCacheTemplate):
    """Writable one-file-per-key cache storing values in ULM format."""

    backend = ULMBackend()

271 

272 

class CombinedJSONCache(_CombinedCacheTemplate):
    """Read-only single-file cache storing all values as JSON."""

    backend = JSONBackend()

275 

276 

class CombinedULMCache(_CombinedCacheTemplate):
    """Read-only single-file cache storing all values in ULM format."""

    backend = ULMBackend()

279 

280 

def get_json_cache(directory, comm=world):
    """Return the JSON cache stored in *directory*.

    A combined single-file cache is used when one exists; otherwise a
    one-file-per-key cache over the same directory is returned.
    """
    try:
        combined = CombinedJSONCache.load(directory, comm=comm)
    except FileNotFoundError:
        return MultiFileJSONCache(directory, comm=comm)
    return combined