Coverage for /builds/ase/ase/ase/spectrum/dosdata.py : 100.00%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Refactor of DOS-like data objects
2# towards replacing ase.dft.dos and ase.dft.pdos
3from abc import ABCMeta, abstractmethod
4import warnings
5from typing import Any, Dict, Sequence, Tuple, TypeVar, Union
7import numpy as np
8from ase.utils.plotting import SimplePlottingAxes
# This import is for the benefit of type-checking / mypy only: the guard is
# False at runtime, so this module never imports matplotlib by itself. The
# name is only needed for the quoted 'matplotlib.axes.Axes' annotations.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import matplotlib.axes

# For now we will be strict about Info and say it has to be str->str. Perhaps
# later we will allow other types that have reliable comparison operations.
Info = Dict[str, str]

# Still no good solution to type checking with arrays.
Floats = Union[Sequence[float], np.ndarray]
class DOSData(metaclass=ABCMeta):
    """Abstract base class for a single series of DOS-like data

    Only the 'info' is a mutable attribute; DOS data is set at init.

    Concrete subclasses must implement get_energies(), get_weights() and
    copy(); sampling, plotting and comparison are built on those accessors.
    """
    def __init__(self,
                 info: Info = None) -> None:
        """Store the metadata dict

        Args:
            info: metadata for this series. If None, a new empty dict is
                created; a dict is stored as-is (by reference, not copied).

        Raises:
            TypeError: if info is neither a dict nor None
        """
        if info is None:
            self.info = {}
        elif isinstance(info, dict):
            self.info = info
        else:
            raise TypeError("Info must be a dict or None")

    @abstractmethod
    def get_energies(self) -> Floats:
        """Get energy data stored in this object"""

    @abstractmethod
    def get_weights(self) -> Floats:
        """Get DOS weights stored in this object"""

    @abstractmethod
    def copy(self) -> 'DOSData':
        """Returns a copy in which info dict can be safely mutated"""

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        Note that no correction is made here for the sampling bin width; total
        intensity will vary with sampling density.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns:
            Weights sampled from a broadened DOS at the values given in
            ``energies``

        Raises:
            ValueError: if width is not positive, or if the smearing type is
                not recognized while broadening a stored data point
        """

        self._check_positive_width(width)
        weights_grid = np.zeros(len(energies), float)
        weights = self.get_weights()
        energies = np.asarray(energies, float)

        # Accumulate one broadened peak per stored data point. This keeps the
        # working memory at the size of the output grid.
        for i, raw_energy in enumerate(self.get_energies()):
            delta = self._delta(energies, raw_energy, width, smearing=smearing)
            weights_grid += weights[i] * delta
        return weights_grid

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSData for testing purposes

        Two objects are considered almost equal if ``other`` is an instance of
        this object's type, the info dicts are equal, and the weights and
        energies have the same shape and are element-wise close.

        Args:
            other: object to compare with

        Returns:
            True if the objects are equivalent within np.allclose tolerances
        """
        if not isinstance(other, type(self)):
            return False
        if self.info != other.info:
            return False

        def same_values(own: Floats, theirs: Floats) -> bool:
            own, theirs = np.asarray(own), np.asarray(theirs)
            # np.allclose broadcasts its arguments, so without the explicit
            # shape check a one-point series would match a longer series and
            # other length mismatches would raise instead of returning False.
            return (own.shape == theirs.shape
                    and bool(np.allclose(own, theirs)))

        return (same_values(self.get_weights(), other.get_weights())
                and same_values(self.get_energies(), other.get_energies()))

    @staticmethod
    def _delta(x: np.ndarray,
               x0: float,
               width: float,
               smearing: str = 'Gauss') -> np.ndarray:
        """Return a delta-function centered at 'x0'.

        This function is used with numpy broadcasting; if x is a row and x0 is
        a column vector, the returned data will be a 2D array with each row
        corresponding to a different delta center.

        Args:
            x: positions at which the broadened peak is evaluated
            x0: peak center(s)
            width: standard deviation of the Gaussian kernel
            smearing: kernel name, compared case-insensitively; only 'Gauss'
                is implemented

        Returns:
            Normalized Gaussian evaluated at x

        Raises:
            ValueError: if the smearing type is not recognized
        """
        if smearing.lower() == 'gauss':
            x1 = -0.5 * ((x - x0) / width)**2
            return np.exp(x1) / (np.sqrt(2 * np.pi) * width)
        else:
            msg = 'Requested smearing type not recognized. Got {}'.format(
                smearing)
            raise ValueError(msg)

    @staticmethod
    def _check_positive_width(width):
        """Raise ValueError unless the broadening width is greater than zero"""
        if width <= 0.0:
            msg = 'Cannot add 0 or negative width smearing'
            raise ValueError(msg)

    def sample_grid(self,
                    npts: int,
                    xmin: float = None,
                    xmax: float = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSData':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled x value; if unspecified, a default is chosen
            xmax: Maximum sampled x value; if unspecified, a default is chosen
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only 'Gauss' is
                implemented)

        Returns:
            GridDOSData holding the energy grid and the sampled DOS, with a
            copy of this object's info dict
        """

        if xmin is None:
            xmin = min(self.get_energies()) - (padding * width)
        if xmax is None:
            xmax = max(self.get_energies()) + (padding * width)
        energies_grid = np.linspace(xmin, xmax, npts)
        weights_grid = self._sample(energies_grid, width=width,
                                    smearing=smearing)

        return GridDOSData(energies_grid, weights_grid, info=self.info.copy())

    def plot(self,
             npts: int = 1000,
             xmin: float = None,
             xmax: float = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> 'matplotlib.axes.Axes':
        """Simple 1-D plot of DOS data, resampled onto a grid

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel for self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The dict passed
                in is not modified.

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        # Work on a private copy so that adding the automatic label does not
        # leak into a dict the caller may reuse for other plots.
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        return self.sample_grid(npts, xmin=xmin, xmax=xmax,
                                width=width,
                                smearing=smearing
                                ).plot(ax=ax, xmin=xmin, xmax=xmax,
                                       show=show, filename=filename,
                                       mplargs=mplargs)

    @staticmethod
    def label_from_info(info: Dict[str, str]):
        """Generate an automatic legend label from info dict

        The value stored under 'label' is used if present; otherwise all
        entries are joined as "key: value" pairs separated by "; ".
        """
        if 'label' in info:
            return info['label']
        else:
            return '; '.join('{}: {}'.format(key, value)
                             for key, value in info.items())
class GeneralDOSData(DOSData):
    """Concrete storage for a single series of DOS-like data

    Only the 'info' is a mutable attribute; DOS data is set at init

    Shared base for DOSData classes that are constructed from separate
    "energies" and "weights" sequences of equal length. Both are kept in one
    private float array and handed out as copies.

    """
    def __init__(self,
                 energies: Floats,
                 weights: Floats,
                 info: Info = None) -> None:
        """Validate and store the energy/weight series

        Args:
            energies: energy values
            weights: one weight per energy value
            info: metadata dict, forwarded to DOSData

        Raises:
            ValueError: if energies and weights differ in length
        """
        super().__init__(info=info)

        n_entries = len(energies)
        if n_entries != len(weights):
            raise ValueError("Energies and weights must be the same length")

        # A single C-ordered float array holds both series:
        # row 0 is the energies, row 1 is the weights.
        self._data = np.empty((2, n_entries), dtype=float, order='C')
        for row, values in enumerate((energies, weights)):
            self._data[row, :] = values

    def get_energies(self) -> np.ndarray:
        """Return a copy of the stored energies"""
        energy_row = self._data[0, :]
        return energy_row.copy()

    def get_weights(self) -> np.ndarray:
        """Return a copy of the stored weights"""
        weight_row = self._data[1, :]
        return weight_row.copy()

    D = TypeVar('D', bound='GeneralDOSData')

    def copy(self: D) -> D:  # noqa F821
        """Return a new object of the same class with copied data and info"""
        cls = type(self)
        return cls(self.get_energies(), self.get_weights(),
                   info=self.info.copy())
class RawDOSData(GeneralDOSData):
    """A collection of weighted delta functions which sum to form a DOS

    Use this container for density-of-states (DOS) or spectral data whose
    energy values do not form a known regular grid. The sample_grid() and
    plot() methods resample or plot the data for further analysis. Weights
    that share an energy value are *only* combined in such output data; the
    data stored in RawDOSData is never resampled. The raw data itself can be
    plotted with plot_deltas().

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When RawDOSData objects are combined with the addition operator::

      big_dos = raw_dos_1 + raw_dos_2

    the energy and weights data is *concatenated* (i.e. combined without
    sorting or replacement) and the new info dictionary consists of the
    *intersection* of the inputs: only key-value pairs that were common to both
    of the input objects will be retained in the new combined object. For
    example::

      (RawDOSData([x1], [y1], info={'symbol': 'O', 'index': '1'})
       + RawDOSData([x2], [y2], info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      RawDOSData([x1, x2], [y1, y2], info={'symbol': 'O'})

    """

    def __add__(self, other: 'RawDOSData') -> 'RawDOSData':
        """Concatenate two data sets, keeping only their shared metadata

        Raises:
            TypeError: if other is not a RawDOSData
        """
        if not isinstance(other, RawDOSData):
            raise TypeError("RawDOSData can only be combined with other "
                            "RawDOSData objects")

        # Only (key, value) pairs present in both info dicts survive
        own_items = set(self.info.items())
        other_items = set(other.info.items())
        shared_info = dict(own_items & other_items)

        # Start from an empty container and attach the joined
        # (energies; weights) array directly.
        combined = RawDOSData([], [], info=shared_info)
        combined._data = np.concatenate((self._data, other._data), axis=1)

        return combined

    def plot_deltas(self,
                    ax: 'matplotlib.axes.Axes' = None,
                    show: bool = False,
                    filename: str = None,
                    mplargs: dict = None) -> 'matplotlib.axes.Axes':
        """Simple plot of sparse DOS data as a set of delta functions

        Each stored point is drawn as a vertical line from zero up to its
        weight. Items at the same x-value can overlap and will not be summed
        together.

        Args:
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib Axes.vlines
                command (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        line_options = {} if mplargs is None else mplargs

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            positions = self.get_energies()
            heights = self.get_weights()
            ax.vlines(positions, 0, heights, **line_options)

        return ax
class GridDOSData(GeneralDOSData):
    """A collection of regularly-sampled data which represents a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the intensity values form a regular grid. This
    is generally the result of sampling or integrating into discrete
    bins, rather than a collection of unique states. The data may be
    plotted or resampled for further analysis using the sample_grid()
    and plot() methods.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When GridDOSData objects are combined with the addition operator::

      big_dos = grid_dos_1 + grid_dos_2

    the weights data is *summed* (requiring a consistent energy grid) and the
    new info dictionary consists of the *intersection* of the inputs: only
    key-value pairs that were common to both of the input objects will be
    retained in the new combined object. For example::

      (GridDOSData([0.1, 0.2, 0.3], [y1, y2, y3],
                   info={'symbol': 'O', 'index': '1'})
       + GridDOSData([0.1, 0.2, 0.3], [y4, y5, y6],
                     info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      GridDOSData([0.1, 0.2, 0.3], [y1+y4, y2+y5, y3+y6], info={'symbol': 'O'})

    """
    def __init__(self,
                 energies: Floats,
                 weights: Floats,
                 info: Info = None) -> None:
        """Validate the grid and store the data

        Args:
            energies: evenly-spaced energy values. The first and last
                entries are read directly, so an empty sequence is not
                supported.
            weights: one weight per energy value
            info: metadata dict, forwarded to GeneralDOSData

        Raises:
            ValueError: if energies are not evenly spaced, or if energies and
                weights differ in length
        """
        n_entries = len(energies)
        if not np.allclose(energies,
                           np.linspace(energies[0], energies[-1], n_entries)):
            raise ValueError("Energies must be an evenly-spaced 1-D grid")

        if len(weights) != n_entries:
            raise ValueError("Energies and weights must be the same length")

        super().__init__(energies, weights, info=info)
        self.sigma_cutoff = 3

    def _check_spacing(self, width) -> float:
        """Return the grid spacing, warning if width is too small for it

        The spacing is taken from the first two grid points, so the grid must
        contain at least two points. A warning is issued when width is less
        than twice the spacing.
        """
        current_spacing = self._data[0, 1] - self._data[0, 0]
        if width < (2 * current_spacing):
            warnings.warn(
                "The broadening width is small compared to the original "
                "sampling density. The results are unlikely to be smooth.")
        return current_spacing

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the broadened data, scaled by the existing grid spacing

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel

        Returns:
            Result of DOSData._sample multiplied by the spacing of the
            stored grid
        """
        current_spacing = self._check_spacing(width)
        return super()._sample(energies=energies,
                               width=width, smearing=smearing
                               ) * current_spacing

    def __add__(self, other: 'GridDOSData') -> 'GridDOSData':
        """Sum the weights of two data sets defined on the same energy grid

        Raises:
            TypeError: if other is not a GridDOSData
            ValueError: if the energy grids differ in length or in values
        """
        # This method uses direct access to the mutable energy and weights data
        # (self._data, other._data) to avoid redundant copying operations. The
        # __init__ method of GridDOSData will write this to a new array, so on
        # this occasion it is safe to pass references to the mutable data.

        if not isinstance(other, GridDOSData):
            raise TypeError("GridDOSData can only be combined with other "
                            "GridDOSData objects")
        if self._data.shape[1] != other._data.shape[1]:
            raise ValueError("Cannot add GridDOSData objects with different-"
                             "length energy grids.")

        if not np.allclose(self._data[0, :], other._data[0, :]):
            raise ValueError("Cannot add GridDOSData objects with different "
                             "energy grids.")

        # Take intersection of metadata (i.e. only common entries are retained)
        new_info = dict(set(self.info.items()) & set(other.info.items()))

        # Sum the weights; the addition allocates a new array
        new_weights = self._data[1, :] + other._data[1, :]

        new_object = GridDOSData(self._data[0, :], new_weights,
                                 info=new_info)
        return new_object

    @staticmethod
    def _interpret_smearing_args(npts: int,
                                 width: float = None,
                                 default_npts: int = 1000,
                                 default_width: float = 0.1
                                 ) -> Tuple[int, Union[float, None]]:
        """Figure out what the user intended: resample if width provided

        Args:
            npts: requested number of points (0 or None if not requested)
            width: requested broadening width, or None if not requested
            default_npts: number of points used when only width is given
            default_width: width used when only npts is given

        Returns:
            (npts, width) to resample with, or (0, None) if neither npts nor
            width was requested, meaning no resampling.
        """
        if width is None and not npts:
            return (0, None)
        return (npts or default_npts,
                default_width if width is None else float(width))

    def plot(self,
             npts: int = 0,
             xmin: float = None,
             xmax: float = None,
             width: float = None,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> 'matplotlib.axes.Axes':
        """Simple 1-D plot of DOS data

        Data will be resampled onto a grid with `npts` points unless `npts` is
        set to zero, in which case:

        - no resampling takes place
        - `width` and `smearing` are ignored
        - `xmin` and `xmax` affect the axis limits of the plot, not the
          underlying data.

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid().
                If no npts was set but width is set, npts will be set to 1000.
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The dict passed
                in is not modified.

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        npts, width = self._interpret_smearing_args(npts, width)

        # Work on a private copy so that adding the automatic label does not
        # leak into a dict the caller may reuse for other plots.
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        if npts:
            # Narrows Union[float, None] for the type checker
            assert isinstance(width, float)
            dos = self.sample_grid(npts, xmin=xmin,
                                   xmax=xmax, width=width,
                                   smearing=smearing)
        else:
            dos = self

        energies, intensity = dos.get_energies(), dos.get_weights()

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            ax.plot(energies, intensity, **mplargs)
            ax.set_xlim(left=xmin, right=xmax)

        return ax