Coverage for /builds/ase/ase/ase/spectrum/doscollection.py : 97.83%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import collections
2from functools import reduce, singledispatch
3from typing import (Any, Dict, Iterable, List, Optional,
4 overload, Sequence, TypeVar, Union)
6import numpy as np
7from ase.spectrum.dosdata import DOSData, RawDOSData, GridDOSData, Info, Floats
8from ase.utils.plotting import SimplePlottingAxes
10# This import is for the benefit of type-checking / mypy
11if False:
12 import matplotlib.axes
class DOSCollection(collections.abc.Sequence):
    """Base class for a collection of DOSData objects

    The collection behaves as an immutable sequence: it can be indexed with
    integers (returning the stored DOSData) or slices (returning a new
    collection of the same type), iterated over, and joined to other
    collections or single DOSData objects with '+'.
    """
    def __init__(self, dos_series: Iterable[DOSData]) -> None:
        self._data = list(dos_series)

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        This samples the underlying DOS data in the same way as the .sample()
        method of those DOSData items, returning a 2-D array with columns
        corresponding to x and rows corresponding to the collected data series.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns:
            Weights sampled from a broadened DOS at values corresponding to x,
            in rows corresponding to DOSData entries contained in this object

        Raises:
            IndexError: if the collection is empty
        """

        if len(self) == 0:
            raise IndexError("No data to sample")

        return np.asarray(
            [data._sample(energies, width=width, smearing=smearing)
             for data in self])

    def plot(self,
             npts: int = 1000,
             xmin: Optional[float] = None,
             xmax: Optional[float] = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: Optional[str] = None,
             mplargs: Optional[dict] = None) -> 'matplotlib.axes.Axes':
        """Simple plot of collected DOS data, resampled onto a grid

        The data is resampled with self.sample_grid() and the resulting
        GridDOSCollection is plotted. Line labels are only seen if a legend
        is added to the plot (i.e. by calling `ax.legend()`).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """
        # sample_grid() applies the requested broadening exactly once. The
        # gridded result is then plotted as stored: forwarding npts/width to
        # GridDOSCollection.plot() would trigger a second resampling pass and
        # broaden the data twice.
        grid_dos = self.sample_grid(npts,
                                    xmin=xmin, xmax=xmax,
                                    width=width, smearing=smearing)
        return grid_dos.plot(xmin=xmin, xmax=xmax,
                             ax=ax, show=show, filename=filename,
                             mplargs=mplargs)

    def sample_grid(self,
                    npts: int,
                    xmin: Optional[float] = None,
                    xmax: Optional[float] = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSCollection':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled energy value; if unspecified, a default is
                chosen
            xmax: Maximum sampled energy value; if unspecified, a default is
                chosen
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel, passed to the sample_grid()
                method of each contained DOSData
            smearing: selection of broadening kernel, passed to the
                sample_grid() method of each contained DOSData

        Returns:
            GridDOSCollection holding every entry resampled on a common grid

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sample")

        # Default limits span the energies of *all* entries so that every
        # series ends up on the same grid.
        if xmin is None:
            xmin = (min(min(data.get_energies()) for data in self)
                    - (padding * width))
        if xmax is None:
            xmax = (max(max(data.get_energies()) for data in self)
                    + (padding * width))

        return GridDOSCollection(
            [data.sample_grid(npts, xmin=xmin, xmax=xmax, width=width,
                              smearing=smearing)
             for data in self])

    @classmethod
    def from_data(cls,
                  energies: Floats,
                  weights: Sequence[Floats],
                  info: Optional[Sequence[Info]] = None) -> 'DOSCollection':
        """Create a DOSCollection from data sharing a common set of energies

        This is a convenience method to be used when all the DOS data in the
        collection has a common energy axis. There is no performance advantage
        in using this method for the generic DOSCollection, but for
        GridDOSCollection it is more efficient.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns:
            Collection of DOS data (in RawDOSData format)

        Raises:
            ValueError: if info is given and its length differs from the
                number of rows in weights
        """

        info = cls._check_weights_and_info(weights, info)

        return cls(RawDOSData(energies, row_weights, row_info)
                   for row_weights, row_info in zip(weights, info))

    @staticmethod
    def _check_weights_and_info(weights: Sequence[Floats],
                                info: Union[Sequence[Info], None],
                                ) -> Sequence[Info]:
        """Return one info dict per weights row, creating empty ones if needed

        Raises:
            ValueError: if info is given and its length differs from the
                number of rows in weights
        """
        if info is None:
            info = [{} for _ in range(len(weights))]
        else:
            if len(info) != len(weights):
                raise ValueError("Length of info must match number of rows in "
                                 "weights")
        return info

    @overload
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'DOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        # NumPy integers (e.g. results of np.argmax) are accepted alongside
        # built-in ints; they are valid list indices.
        if isinstance(item, (int, np.integer)):
            return self._data[item]
        elif isinstance(item, slice):
            return type(self)(self._data[item])
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    def __len__(self) -> int:
        return len(self._data)

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSCollection for testing purposes"""
        if not isinstance(other, type(self)):
            return False
        elif len(self) != len(other):
            return False
        else:
            return all(a._almost_equals(b) for a, b in zip(self, other))

    def total(self) -> DOSData:
        """Sum all the DOSData in this Collection and label it as 'Total'"""
        data = self.sum_all()
        data.info.update({'label': 'Total'})
        return data

    def sum_all(self) -> DOSData:
        """Sum all the DOSData contained in this Collection

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sum")
        elif len(self) == 1:
            # Copy so that the caller never receives the stored object itself
            data = self[0].copy()
        else:
            data = reduce(lambda x, y: x + y, self)
        return data

    # Forward-reference bound; only used for static type checking
    D = TypeVar('D', bound='DOSData')

    @staticmethod
    def _select_to_list(dos_collection: Sequence[D],  # Bug in flakes
                        info_selection: Dict[str, str],  # misses 'D' def
                        negative: bool = False) -> List[D]:  # noqa: F821
        """List the entries whose info contains (or, if negative, does not
        contain) every key/value pair of info_selection"""
        query = set(info_selection.items())

        if negative:
            return [data for data in dos_collection
                    if not query.issubset(set(data.info.items()))]
        else:
            return [data for data in dos_collection
                    if query.issubset(set(data.info.items()))]

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items with specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        """

        matches = self._select_to_list(self, info_selection)
        return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items without specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        return type(self)(matches)

    def sum_by(self, *info_keys: str) -> 'DOSCollection':
        """Return a DOSCollection with some data summed by common attributes

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                              DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        then ::

          dc.sum_by('b')

        will return a collection equivalent to ::

          DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'})
                         + DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                         DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        The info of each summed entry is determined by DOSData addition.

        dc.sum_by('a', 'b') on the other hand would return the full three-entry
        collection, as none of the entries have common 'a' *and* 'b' info.

        Entries that lack some of the requested keys are grouped by the keys
        they do have; every entry contributes to exactly one sum.

        """

        def _matching_info_tuples(data: DOSData):
            """Get relevant dict entries in tuple form

            e.g. if data.info = {'a': 1, 'b': 2, 'c': 3}
                 and info_keys = ('a', 'c')

                 then return (('a', 1), ('c', 3))
            """
            matched_keys = set(info_keys) & set(data.info)
            return tuple(sorted([(key, data.info[key])
                                 for key in matched_keys]))

        # Group entries by their exact key/value combination. Grouping
        # directly (rather than re-querying with select(), which matches by
        # subset) guarantees an entry missing one of the keys is not also
        # swallowed into, or inflated by, a more specific group.
        # Sorting inside info matching makes equivalent combinations identical.
        grouped: Dict[tuple, List[DOSData]] = {}
        for data in self:
            grouped.setdefault(_matching_info_tuples(data), []).append(data)

        # Combos are sorted to ensure consistent output across sessions.
        collection_data = [type(self)(grouped[combo]).sum_all()
                           for combo in sorted(grouped)]
        return type(self)(collection_data)

    def __add__(self, other: Union['DOSCollection', DOSData]
                ) -> 'DOSCollection':
        """Join entries between two DOSCollection objects of the same type

        It is also possible to add a single DOSData object without wrapping it
        in a new collection: i.e. ::

          DOSCollection([dosdata1]) + DOSCollection([dosdata2])

        or ::

          DOSCollection([dosdata1]) + dosdata2

        will return ::

          DOSCollection([dosdata1, dosdata2])

        """
        return _add_to_collection(other, self)
@singledispatch
def _add_to_collection(other: Union[DOSData, DOSCollection],
                       collection: DOSCollection) -> DOSCollection:
    """Join 'other' onto 'collection' (fallback for non-DOSData operands)

    Dispatch happens on the type of 'other'. This default implementation
    handles the case where 'other' is a collection of exactly the same kind
    as 'collection'; anything else is rejected.

    Raises:
        TypeError: if 'other' is a different kind of DOSCollection, or is not
            a DOSCollection at all
    """
    collection_type = type(collection)

    if not isinstance(other, collection_type):
        if isinstance(other, DOSCollection):
            raise TypeError("Only DOSCollection objects of the same type may "
                            "be joined with '+'.")
        raise TypeError("DOSCollection may only be joined to DOSData or "
                        "DOSCollection objects with '+'.")

    return collection_type([*collection, *other])
@_add_to_collection.register(DOSData)
def _add_data(other: DOSData, collection: DOSCollection) -> DOSCollection:
    """Return a new DOSCollection with an additional DOSData item

    The result has the same type as 'collection'; 'collection' itself is
    left unmodified.
    """
    extended = [*collection, other]
    return type(collection)(extended)
class RawDOSCollection(DOSCollection):
    """DOSCollection restricted to RawDOSData entries"""
    def __init__(self, dos_series: Iterable[RawDOSData]) -> None:
        """Store the series, then verify that every entry is a RawDOSData

        Raises:
            TypeError: if any entry is not a RawDOSData instance
        """
        super().__init__(dos_series)
        if not all(isinstance(entry, RawDOSData) for entry in self):
            raise TypeError("RawDOSCollection can only store "
                            "RawDOSData objects.")
class GridDOSCollection(DOSCollection):
    """Collection of GridDOSData objects sharing one energy axis

    The data is held as a single energy array, a 2-D weights array with one
    row per entry, and a list of info dicts; GridDOSData objects are created
    on demand when the collection is indexed.
    """
    def __init__(self, dos_series: Iterable[GridDOSData],
                 energies: Optional[Floats] = None) -> None:
        """Collect GridDOSData entries on a common energy axis

        Args:
            dos_series: GridDOSData objects to store
            energies: energy axis of the collection. If omitted, it is taken
                from the first entry of dos_series.

        Raises:
            ValueError: if neither energies nor any DOS data is provided, or
                if an entry's energy axis differs from the collection's
            TypeError: if an entry is not a GridDOSData instance
        """
        dos_list = list(dos_series)
        if energies is None:
            if len(dos_list) == 0:
                raise ValueError("Must provide energies to create a "
                                 "GridDOSCollection without any DOS data.")
            # Check the type before relying on the GridDOSData interface, so
            # that a wrong input yields the intended TypeError rather than
            # an AttributeError.
            if not isinstance(dos_list[0], GridDOSData):
                raise TypeError("GridDOSCollection can only store "
                                "GridDOSData objects.")
            self._energies = dos_list[0].get_energies()
        else:
            self._energies = np.asarray(energies)

        self._weights = np.empty((len(dos_list), len(self._energies)), float)
        self._info = []

        for i, dos_data in enumerate(dos_list):
            if not isinstance(dos_data, GridDOSData):
                raise TypeError("GridDOSCollection can only store "
                                "GridDOSData objects.")
            # Shapes are compared first: np.allclose would broadcast or fail
            # on arrays of different length.
            if (dos_data.get_energies().shape != self._energies.shape
                or not np.allclose(dos_data.get_energies(),
                                   self._energies)):
                raise ValueError("All GridDOSData objects in GridDOSCollection"
                                 " must have the same energy axis.")
            self._weights[i, :] = dos_data.get_weights()
            self._info.append(dos_data.info)

    def _new_from_entries(self, entries: List[GridDOSData]
                          ) -> 'GridDOSCollection':
        """Build a same-type collection, keeping the energy axis if empty

        An empty GridDOSCollection cannot infer its energies from its
        contents, so they are handed over explicitly in that case.
        """
        if len(entries) == 0:
            return type(self)([], energies=self._energies)
        return type(self)(entries)

    def get_energies(self) -> Floats:
        """Return a copy of the common energy axis"""
        return self._energies.copy()

    def get_all_weights(self) -> Union[Sequence[Floats], np.ndarray]:
        """Return a copy of the weights array (one row per entry)"""
        return self._weights.copy()

    def __len__(self) -> int:
        return self._weights.shape[0]

    @overload  # noqa F811
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'GridDOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        # NumPy integers are accepted alongside built-in ints
        if isinstance(item, (int, np.integer)):
            return GridDOSData(self._energies, self._weights[item, :],
                               info=self._info[item])
        elif isinstance(item, slice):
            # A slice may select nothing (e.g. collection[0:0]); the helper
            # preserves the energy axis so this yields an empty collection
            # instead of raising.
            return self._new_from_entries(
                [self[i] for i in range(len(self))[item]])
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    @classmethod
    def from_data(cls,
                  energies: Floats,
                  weights: Sequence[Floats],
                  info: Optional[Sequence[Info]] = None
                  ) -> 'GridDOSCollection':
        """Create a GridDOSCollection from data with a common set of energies

        This convenience method may also be more efficient as it limits
        redundant copying/checking of the data.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns:
            Collection of DOS data (in GridDOSData format)

        Raises:
            IndexError: if weights is not 2-D, is empty, or its row length
                differs from the number of energies
            ValueError: if info is given and its length differs from the
                number of rows in weights
        """

        weights_array = np.asarray(weights, dtype=float)
        if len(weights_array.shape) != 2:
            raise IndexError("Weights must be a 2-D array or nested sequence")
        if weights_array.shape[0] < 1:
            raise IndexError("Weights cannot be empty")
        if weights_array.shape[1] != len(energies):
            raise IndexError("Length of weights rows must equal size of x")

        info = cls._check_weights_and_info(weights, info)

        # Initialise from the first row only (which sets up the energy axis),
        # then swap in the full weights and info directly to avoid building
        # and re-checking one GridDOSData per row.
        dos_collection = cls([GridDOSData(energies, weights_array[0])])
        dos_collection._weights = weights_array
        dos_collection._info = list(info)

        return dos_collection

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow GridDOSCollection to items with specified info

        For example, if ::

          dc = GridDOSCollection([GridDOSData(x, y1,
                                              info={'a': '1', 'b': '1'}),
                                  GridDOSData(x, y2,
                                              info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        If nothing matches, an empty collection with the same energy axis
        is returned.

        """

        matches = self._select_to_list(self, info_selection)
        return self._new_from_entries(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow GridDOSCollection to items without specified info

        For example, if ::

          dc = GridDOSCollection([GridDOSData(x, y1,
                                              info={'a': '1', 'b': '1'}),
                                  GridDOSData(x, y2,
                                              info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        If every item is excluded, an empty collection with the same energy
        axis is returned.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        return self._new_from_entries(matches)

    def plot(self,
             npts: int = 0,
             xmin: Optional[float] = None,
             xmax: Optional[float] = None,
             width: Optional[float] = None,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: Optional[str] = None,
             mplargs: Optional[dict] = None) -> 'matplotlib.axes.Axes':
        """Simple plot of collected DOS data, optionally resampled onto a grid

        Each line is labelled using the info of the corresponding entry and
        a legend is added to the axes.

        Args:
            npts:
                Number of points in resampled x-axis. If set to zero (default),
                no resampling is performed and the stored data is plotted
                directly.
            xmin, xmax:
                range of the resampled data, passed to self.sample_grid();
                only used when resampling is performed
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel, passed to
                self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        # Apply defaults if necessary
        npts, width = GridDOSData._interpret_smearing_args(npts, width)

        if npts:
            assert isinstance(width, float)
            dos = self.sample_grid(npts,
                                   xmin=xmin, xmax=xmax,
                                   width=width, smearing=smearing)
        else:
            dos = self

        energies, all_y = dos._energies, dos._weights

        all_labels = [DOSData.label_from_info(data.info) for data in self]

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            self._plot_broadened(ax, energies, all_y, all_labels, mplargs)

        return ax

    @staticmethod
    def _plot_broadened(ax: 'matplotlib.axes.Axes',
                        energies: Floats,
                        all_y: np.ndarray,
                        all_labels: Sequence[str],
                        mplargs: Union[Dict, None]):
        """Plot DOS data with labels to axes

        This is separated into another function so that subclasses can
        manipulate broadening, labels etc in their plot() method."""
        if mplargs is None:
            mplargs = {}

        # all_y holds one series per row; transposing lets a single plot()
        # call draw one line per series.
        all_lines = ax.plot(energies, all_y.T, **mplargs)
        for line, label in zip(all_lines, all_labels):
            line.set_label(label)
        ax.legend()

        ax.set_xlim(left=min(energies), right=max(energies))
        ax.set_ylim(bottom=0)