Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# Refactor of DOS-like data objects 

2# towards replacing ase.dft.dos and ase.dft.pdos 

3from abc import ABCMeta, abstractmethod 

4import warnings 

5from typing import Any, Dict, Sequence, Tuple, TypeVar, Union 

6 

7import numpy as np 

8from ase.utils.plotting import SimplePlottingAxes 

9 

10# This import is for the benefit of type-checking / mypy 

11if False: 

12 import matplotlib.axes 

13 

14# For now we will be strict about Info and say it has to be str->str. Perhaps 

15# later we will allow other types that have reliable comparison operations. 

16Info = Dict[str, str] 

17 

18# Still no good solution to type checking with arrays. 

19Floats = Union[Sequence[float], np.ndarray] 

20 

21 

22class DOSData(metaclass=ABCMeta): 

23 """Abstract base class for a single series of DOS-like data 

24 

25 Only the 'info' is a mutable attribute; DOS data is set at init""" 

26 def __init__(self, 

27 info: Info = None) -> None: 

28 if info is None: 

29 self.info = {} 

30 elif isinstance(info, dict): 

31 self.info = info 

32 else: 

33 raise TypeError("Info must be a dict or None") 

34 

35 @abstractmethod 

36 def get_energies(self) -> Floats: 

37 """Get energy data stored in this object""" 

38 

39 @abstractmethod 

40 def get_weights(self) -> Floats: 

41 """Get DOS weights stored in this object""" 

42 

43 @abstractmethod 

44 def copy(self) -> 'DOSData': 

45 """Returns a copy in which info dict can be safely mutated""" 

46 

47 def _sample(self, 

48 energies: Floats, 

49 width: float = 0.1, 

50 smearing: str = 'Gauss') -> np.ndarray: 

51 """Sample the DOS data at chosen points, with broadening 

52 

53 Note that no correction is made here for the sampling bin width; total 

54 intensity will vary with sampling density. 

55 

56 Args: 

57 energies: energy values for sampling 

58 width: Width of broadening kernel 

59 smearing: selection of broadening kernel (only "Gauss" is currently 

60 supported) 

61 

62 Returns: 

63 Weights sampled from a broadened DOS at values corresponding to x 

64 """ 

65 

66 self._check_positive_width(width) 

67 weights_grid = np.zeros(len(energies), float) 

68 weights = self.get_weights() 

69 energies = np.asarray(energies, float) 

70 

71 for i, raw_energy in enumerate(self.get_energies()): 

72 delta = self._delta(energies, raw_energy, width, smearing=smearing) 

73 weights_grid += weights[i] * delta 

74 return weights_grid 

75 

76 def _almost_equals(self, other: Any) -> bool: 

77 """Compare with another DOSData for testing purposes""" 

78 if not isinstance(other, type(self)): 

79 return False 

80 if self.info != other.info: 

81 return False 

82 if not np.allclose(self.get_weights(), other.get_weights()): 

83 return False 

84 return np.allclose(self.get_energies(), other.get_energies()) 

85 

86 @staticmethod 

87 def _delta(x: np.ndarray, 

88 x0: float, 

89 width: float, 

90 smearing: str = 'Gauss') -> np.ndarray: 

91 """Return a delta-function centered at 'x0'. 

92 

93 This function is used with numpy broadcasting; if x is a row and x0 is 

94 a column vector, the returned data will be a 2D array with each row 

95 corresponding to a different delta center. 

96 """ 

97 if smearing.lower() == 'gauss': 

98 x1 = -0.5 * ((x - x0) / width)**2 

99 return np.exp(x1) / (np.sqrt(2 * np.pi) * width) 

100 else: 

101 msg = 'Requested smearing type not recognized. Got {}'.format( 

102 smearing) 

103 raise ValueError(msg) 

104 

105 @staticmethod 

106 def _check_positive_width(width): 

107 if width <= 0.0: 

108 msg = 'Cannot add 0 or negative width smearing' 

109 raise ValueError(msg) 

110 

111 def sample_grid(self, 

112 npts: int, 

113 xmin: float = None, 

114 xmax: float = None, 

115 padding: float = 3, 

116 width: float = 0.1, 

117 smearing: str = 'Gauss', 

118 ) -> 'GridDOSData': 

119 """Sample the DOS data on an evenly-spaced energy grid 

120 

121 Args: 

122 npts: Number of sampled points 

123 xmin: Minimum sampled x value; if unspecified, a default is chosen 

124 xmax: Maximum sampled x value; if unspecified, a default is chosen 

125 padding: If xmin/xmax is unspecified, default value will be padded 

126 by padding * width to avoid cutting off peaks. 

127 width: Width of broadening kernel 

128 smearing: selection of broadening kernel (only 'Gauss' is 

129 implemented) 

130 

131 Returns: 

132 (energy values, sampled DOS) 

133 """ 

134 

135 if xmin is None: 

136 xmin = min(self.get_energies()) - (padding * width) 

137 if xmax is None: 

138 xmax = max(self.get_energies()) + (padding * width) 

139 energies_grid = np.linspace(xmin, xmax, npts) 

140 weights_grid = self._sample(energies_grid, width=width, 

141 smearing=smearing) 

142 

143 return GridDOSData(energies_grid, weights_grid, info=self.info.copy()) 

144 

145 def plot(self, 

146 npts: int = 1000, 

147 xmin: float = None, 

148 xmax: float = None, 

149 width: float = 0.1, 

150 smearing: str = 'Gauss', 

151 ax: 'matplotlib.axes.Axes' = None, 

152 show: bool = False, 

153 filename: str = None, 

154 mplargs: dict = None) -> 'matplotlib.axes.Axes': 

155 """Simple 1-D plot of DOS data, resampled onto a grid 

156 

157 If the special key 'label' is present in self.info, this will be set 

158 as the label for the plotted line (unless overruled in mplargs). The 

159 label is only seen if a legend is added to the plot (i.e. by calling 

160 ``ax.legend()``). 

161 

162 Args: 

163 npts, xmin, xmax: output data range, as passed to self.sample_grid 

164 width: Width of broadening kernel for self.sample_grid() 

165 smearing: selection of broadening kernel for self.sample_grid() 

166 ax: existing Matplotlib axes object. If not provided, a new figure 

167 with one set of axes will be created using Pyplot 

168 show: show the figure on-screen 

169 filename: if a path is given, save the figure to this file 

170 mplargs: additional arguments to pass to matplotlib plot command 

171 (e.g. {'linewidth': 2} for a thicker line). 

172 

173 

174 Returns: 

175 Plotting axes. If "ax" was set, this is the same object. 

176 """ 

177 

178 if mplargs is None: 

179 mplargs = {} 

180 if 'label' not in mplargs: 

181 mplargs.update({'label': self.label_from_info(self.info)}) 

182 

183 return self.sample_grid(npts, xmin=xmin, xmax=xmax, 

184 width=width, 

185 smearing=smearing 

186 ).plot(ax=ax, xmin=xmin, xmax=xmax, 

187 show=show, filename=filename, 

188 mplargs=mplargs) 

189 

190 @staticmethod 

191 def label_from_info(info: Dict[str, str]): 

192 """Generate an automatic legend label from info dict""" 

193 if 'label' in info: 

194 return info['label'] 

195 else: 

196 return '; '.join(map(lambda x: '{}: {}'.format(x[0], x[1]), 

197 info.items())) 

198 

199 

200class GeneralDOSData(DOSData): 

201 """Base class for a single series of DOS-like data 

202 

203 Only the 'info' is a mutable attribute; DOS data is set at init 

204 

205 This is the base class for DOSData objects that accept/set seperate 

206 "energies" and "weights" sequences of equal length at init. 

207 

208 """ 

209 def __init__(self, 

210 energies: Floats, 

211 weights: Floats, 

212 info: Info = None) -> None: 

213 super().__init__(info=info) 

214 

215 n_entries = len(energies) 

216 if len(weights) != n_entries: 

217 raise ValueError("Energies and weights must be the same length") 

218 

219 # Internally store the data as a np array with two rows; energy, weight 

220 self._data = np.empty((2, n_entries), dtype=float, order='C') 

221 self._data[0, :] = energies 

222 self._data[1, :] = weights 

223 

224 def get_energies(self) -> np.ndarray: 

225 return self._data[0, :].copy() 

226 

227 def get_weights(self) -> np.ndarray: 

228 return self._data[1, :].copy() 

229 

230 D = TypeVar('D', bound='GeneralDOSData') 

231 

232 def copy(self: D) -> D: # noqa F821 

233 return type(self)(self.get_energies(), self.get_weights(), 

234 info=self.info.copy()) 

235 

236 

237class RawDOSData(GeneralDOSData): 

238 """A collection of weighted delta functions which sum to form a DOS 

239 

240 This is an appropriate data container for density-of-states (DOS) or 

241 spectral data where the energy data values not form a known regular 

242 grid. The data may be plotted or resampled for further analysis using the 

243 sample_grid() and plot() methods. Multiple weights at the same 

244 energy value will *only* be combined in output data, and data stored in 

245 RawDOSData is never resampled. A plot_deltas() function is also provided 

246 which plots the raw data. 

247 

248 Metadata may be stored in the info dict, in which keys and values must be 

249 strings. This data is used for selecting and combining multiple DOSData 

250 objects in a DOSCollection object. 

251 

252 When RawDOSData objects are combined with the addition operator:: 

253 

254 big_dos = raw_dos_1 + raw_dos_2 

255 

256 the energy and weights data is *concatenated* (i.e. combined without 

257 sorting or replacement) and the new info dictionary consists of the 

258 *intersection* of the inputs: only key-value pairs that were common to both 

259 of the input objects will be retained in the new combined object. For 

260 example:: 

261 

262 (RawDOSData([x1], [y1], info={'symbol': 'O', 'index': '1'}) 

263 + RawDOSData([x2], [y2], info={'symbol': 'O', 'index': '2'})) 

264 

265 will yield the equivalent of:: 

266 

267 RawDOSData([x1, x2], [y1, y2], info={'symbol': 'O'}) 

268 

269 """ 

270 

271 def __add__(self, other: 'RawDOSData') -> 'RawDOSData': 

272 if not isinstance(other, RawDOSData): 

273 raise TypeError("RawDOSData can only be combined with other " 

274 "RawDOSData objects") 

275 

276 # Take intersection of metadata (i.e. only common entries are retained) 

277 new_info = dict(set(self.info.items()) & set(other.info.items())) 

278 

279 # Concatenate the energy/weight data 

280 new_data = np.concatenate((self._data, other._data), axis=1) 

281 

282 new_object = RawDOSData([], [], info=new_info) 

283 new_object._data = new_data 

284 

285 return new_object 

286 

287 def plot_deltas(self, 

288 ax: 'matplotlib.axes.Axes' = None, 

289 show: bool = False, 

290 filename: str = None, 

291 mplargs: dict = None) -> 'matplotlib.axes.Axes': 

292 """Simple plot of sparse DOS data as a set of delta functions 

293 

294 Items at the same x-value can overlap and will not be summed together 

295 

296 Args: 

297 ax: existing Matplotlib axes object. If not provided, a new figure 

298 with one set of axes will be created using Pyplot 

299 show: show the figure on-screen 

300 filename: if a path is given, save the figure to this file 

301 mplargs: additional arguments to pass to matplotlib Axes.vlines 

302 command (e.g. {'linewidth': 2} for a thicker line). 

303 

304 Returns: 

305 Plotting axes. If "ax" was set, this is the same object. 

306 """ 

307 

308 if mplargs is None: 

309 mplargs = {} 

310 

311 with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax: 

312 ax.vlines(self.get_energies(), 0, self.get_weights(), **mplargs) 

313 

314 return ax 

315 

316 

317class GridDOSData(GeneralDOSData): 

318 """A collection of regularly-sampled data which represents a DOS 

319 

320 This is an appropriate data container for density-of-states (DOS) or 

321 spectral data where the intensity values form a regular grid. This 

322 is generally the result of sampling or integrating into discrete 

323 bins, rather than a collection of unique states. The data may be 

324 plotted or resampled for further analysis using the sample_grid() 

325 and plot() methods. 

326 

327 Metadata may be stored in the info dict, in which keys and values must be 

328 strings. This data is used for selecting and combining multiple DOSData 

329 objects in a DOSCollection object. 

330 

331 When RawDOSData objects are combined with the addition operator:: 

332 

333 big_dos = raw_dos_1 + raw_dos_2 

334 

335 the weights data is *summed* (requiring a consistent energy grid) and the 

336 new info dictionary consists of the *intersection* of the inputs: only 

337 key-value pairs that were common to both of the input objects will be 

338 retained in the new combined object. For example:: 

339 

340 (GridDOSData([0.1, 0.2, 0.3], [y1, y2, y3], 

341 info={'symbol': 'O', 'index': '1'}) 

342 + GridDOSData([0.1, 0.2, 0.3], [y4, y5, y6], 

343 info={'symbol': 'O', 'index': '2'})) 

344 

345 will yield the equivalent of:: 

346 

347 GridDOSData([0.1, 0.2, 0.3], [y1+y4, y2+y5, y3+y6], info={'symbol': 'O'}) 

348 

349 """ 

350 def __init__(self, 

351 energies: Floats, 

352 weights: Floats, 

353 info: Info = None) -> None: 

354 n_entries = len(energies) 

355 if not np.allclose(energies, 

356 np.linspace(energies[0], energies[-1], n_entries)): 

357 raise ValueError("Energies must be an evenly-spaced 1-D grid") 

358 

359 if len(weights) != n_entries: 

360 raise ValueError("Energies and weights must be the same length") 

361 

362 super().__init__(energies, weights, info=info) 

363 self.sigma_cutoff = 3 

364 

365 def _check_spacing(self, width) -> float: 

366 current_spacing = self._data[0, 1] - self._data[0, 0] 

367 if width < (2 * current_spacing): 

368 warnings.warn( 

369 "The broadening width is small compared to the original " 

370 "sampling density. The results are unlikely to be smooth.") 

371 return current_spacing 

372 

373 def _sample(self, 

374 energies: Floats, 

375 width: float = 0.1, 

376 smearing: str = 'Gauss') -> np.ndarray: 

377 current_spacing = self._check_spacing(width) 

378 return super()._sample(energies=energies, 

379 width=width, smearing=smearing 

380 ) * current_spacing 

381 

382 def __add__(self, other: 'GridDOSData') -> 'GridDOSData': 

383 # This method uses direct access to the mutable energy and weights data 

384 # (self._data) to avoid redundant copying operations. The __init__ 

385 # method of GridDOSData will write this to a new array, so on this 

386 # occasion it is safe to pass references to the mutable data. 

387 

388 if not isinstance(other, GridDOSData): 

389 raise TypeError("GridDOSData can only be combined with other " 

390 "GridDOSData objects") 

391 if len(self._data[0, :]) != len(other.get_energies()): 

392 raise ValueError("Cannot add GridDOSData objects with different-" 

393 "length energy grids.") 

394 

395 if not np.allclose(self._data[0, :], other.get_energies()): 

396 raise ValueError("Cannot add GridDOSData objects with different " 

397 "energy grids.") 

398 

399 # Take intersection of metadata (i.e. only common entries are retained) 

400 new_info = dict(set(self.info.items()) & set(other.info.items())) 

401 

402 # Sum the energy/weight data 

403 new_weights = self._data[1, :] + other.get_weights() 

404 

405 new_object = GridDOSData(self._data[0, :], new_weights, 

406 info=new_info) 

407 return new_object 

408 

409 @staticmethod 

410 def _interpret_smearing_args(npts: int, 

411 width: float = None, 

412 default_npts: int = 1000, 

413 default_width: float = 0.1 

414 ) -> Tuple[int, Union[float, None]]: 

415 """Figure out what the user intended: resample if width provided""" 

416 if width is not None: 

417 if npts: 

418 return (npts, float(width)) 

419 else: 

420 return (default_npts, float(width)) 

421 else: 

422 if npts: 

423 return (npts, default_width) 

424 else: 

425 return (0, None) 

426 

427 def plot(self, 

428 npts: int = 0, 

429 xmin: float = None, 

430 xmax: float = None, 

431 width: float = None, 

432 smearing: str = 'Gauss', 

433 ax: 'matplotlib.axes.Axes' = None, 

434 show: bool = False, 

435 filename: str = None, 

436 mplargs: dict = None) -> 'matplotlib.axes.Axes': 

437 """Simple 1-D plot of DOS data 

438 

439 Data will be resampled onto a grid with `npts` points unless `npts` is 

440 set to zero, in which case: 

441 

442 - no resampling takes place 

443 - `width` and `smearing` are ignored 

444 - `xmin` and `xmax` affect the axis limits of the plot, not the 

445 underlying data. 

446 

447 If the special key 'label' is present in self.info, this will be set 

448 as the label for the plotted line (unless overruled in mplargs). The 

449 label is only seen if a legend is added to the plot (i.e. by calling 

450 ``ax.legend()``). 

451 

452 Args: 

453 npts, xmin, xmax: output data range, as passed to self.sample_grid 

454 width: Width of broadening kernel, passed to self.sample_grid(). 

455 If no npts was set but width is set, npts will be set to 1000. 

456 smearing: selection of broadening kernel for self.sample_grid() 

457 ax: existing Matplotlib axes object. If not provided, a new figure 

458 with one set of axes will be created using Pyplot 

459 show: show the figure on-screen 

460 filename: if a path is given, save the figure to this file 

461 mplargs: additional arguments to pass to matplotlib plot command 

462 (e.g. {'linewidth': 2} for a thicker line). 

463 

464 Returns: 

465 Plotting axes. If "ax" was set, this is the same object. 

466 """ 

467 

468 npts, width = self._interpret_smearing_args(npts, width) 

469 

470 if mplargs is None: 

471 mplargs = {} 

472 if 'label' not in mplargs: 

473 mplargs.update({'label': self.label_from_info(self.info)}) 

474 

475 if npts: 

476 assert isinstance(width, float) 

477 dos = self.sample_grid(npts, xmin=xmin, 

478 xmax=xmax, width=width, 

479 smearing=smearing) 

480 else: 

481 dos = self 

482 

483 energies, intensity = dos.get_energies(), dos.get_weights() 

484 

485 with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax: 

486 ax.plot(energies, intensity, **mplargs) 

487 ax.set_xlim(left=xmin, right=xmax) 

488 

489 return ax