Coverage for /builds/ase/ase/ase/calculators/kim/kimpy_wrappers.py : 74.53%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
"""
Wrappers that provide a minimal interface to kimpy methods and objects

Daniel S. Karls
University of Minnesota
"""
8from abc import ABC
9import functools
11import numpy as np
12import kimpy
14from .exceptions import (
15 KIMModelNotFound,
16 KIMModelInitializationError,
17 KimpyError,
18 KIMModelParameterError,
19)
# Function used for casting parameter/extent indices to C-compatible ints
c_int = np.intc

# Function used for casting floating point parameter values to C-compatible
# doubles
c_double = np.double


def c_int_args(func):
    """
    Decorator for instance methods that will cast all of the args passed,
    excluding the first (which corresponds to 'self'), to C-compatible
    integers.

    Keyword arguments are forwarded to ``func`` unchanged.
    """

    @functools.wraps(func)
    def myfunc(*args, **kwargs):
        # Keep 'self' as-is and cast every remaining positional argument
        args_cast = [args[0], *map(c_int, args[1:])]
        # Forward the cast arguments; passing the raw ``args`` here would
        # silently throw away the conversion performed above.
        return func(*args_cast, **kwargs)

    return myfunc
def check_call(f, *args, **kwargs):
    """Call a kimpy function using its arguments and, if a RuntimeError is
    raised, catch it and raise a KimpyError with the exception's
    message.

    (Starting with kimpy 2.0.0, a RuntimeError is the only exception
    type raised when something goes wrong.)

    Parameters
    ----------
    f : callable
        Function to invoke.
    *args, **kwargs
        Passed through to ``f`` unchanged.

    Returns
    -------
    object
        Whatever ``f`` returns.

    Raises
    ------
    KimpyError
        If ``f`` raises a RuntimeError. The original exception is attached
        as the cause. Any other exception type propagates untouched.
    """
    try:
        return f(*args, **kwargs)
    except RuntimeError as e:
        # Not every callable has ``__name__`` (functools.partial objects do
        # not), so fall back to the repr instead of masking the real failure
        # with an AttributeError raised while building the message.
        name = getattr(f, "__name__", repr(f))
        raise KimpyError(
            f'Calling kimpy function "{name}" failed:\n {str(e)}') from e
def check_call_wrapper(func):
    """Decorator that routes every call of ``func`` through ``check_call``.

    The returned callable keeps the metadata of ``func`` (via
    ``functools.wraps``) and forwards all positional and keyword arguments.
    """

    def guarded(*args, **kwargs):
        return check_call(func, *args, **kwargs)

    return functools.wraps(func)(guarded)
def _checked(kimpy_function):
    """Bind ``kimpy_function`` as the callable that ``check_call`` invokes."""
    return functools.partial(check_call, kimpy_function)


# kimpy methods
collections_create = _checked(kimpy.collections.create)
model_create = _checked(kimpy.model.create)
simulator_model_create = _checked(kimpy.simulator_model.create)
get_species_name = _checked(kimpy.species_name.get_species_name)
get_number_of_species_names = _checked(
    kimpy.species_name.get_number_of_species_names
)

# kimpy attributes (here to avoid importing kimpy in higher-level modules)
collection_item_type_portableModel = kimpy.collection_item_type.portableModel
class ModelCollections:
    """
    KIM Portable Models and Simulator Models are installed/managed into
    different "collections".  In order to search through the different
    KIM API model collections on the system, a corresponding object must
    be instantiated.  For more on model collections, see the KIM API's
    install file:
    https://github.com/openkim/kim-api/blob/master/INSTALL
    """

    def __init__(self):
        # Handle returned by kimpy.collections.create (via check_call)
        self.collection = collections_create()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, value, traceback):
        # Nothing is released here; exceptions are not suppressed
        pass

    def get_item_type(self, model_name):
        """Return the item type the collections object reports for a model.

        Parameters
        ----------
        model_name : str
            Name of the model to look up.

        Raises
        ------
        KIMModelNotFound
            If the underlying kimpy call fails.  The KimpyError that
            triggered it is attached as the cause.
        """
        try:
            return check_call(self.collection.get_item_type, model_name)
        except KimpyError as exc:
            msg = (
                "Could not find model {} installed in any of the KIM API "
                "model collections on this system.  See "
                "https://openkim.org/doc/usage/obtaining-models/ for "
                "instructions on installing models.".format(model_name)
            )
            # Chain explicitly so the underlying kimpy failure stays visible
            raise KIMModelNotFound(msg) from exc

    @property
    def initialized(self):
        """Whether the ``collection`` attribute has been set."""
        return hasattr(self, "collection")
class PortableModel:
    """Builds a KIM API Portable Model object and exposes a small
    interface to it"""

    def __init__(self, model_name, debug):
        self.model_name = model_name
        self.debug = debug

        # Create KIM API Model object
        units_accepted, self.kim_model = model_create(
            kimpy.numbering.zeroBased,
            kimpy.length_unit.A,
            kimpy.energy_unit.eV,
            kimpy.charge_unit.e,
            kimpy.temperature_unit.K,
            kimpy.time_unit.ps,
            self.model_name,
        )

        if not units_accepted:
            raise KIMModelInitializationError(
                "Requested units not accepted in kimpy.model.create"
            )

        if self.debug:
            self._print_units()

        self._create_parameters()

    def _print_units(self):
        """Print the units reported by the model, followed by a blank line."""
        l_unit, e_unit, c_unit, te_unit, ti_unit = check_call(
            self.kim_model.get_units
        )
        report = (
            ("Length unit is: {}", l_unit),
            ("Energy unit is: {}", e_unit),
            ("Charge unit is: {}", c_unit),
            ("Temperature unit is: {}", te_unit),
            ("Time unit is: {}", ti_unit),
        )
        for template, unit in report:
            print(template.format(unit))
        print()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, value, traceback):
        pass

    @check_call_wrapper
    def _get_number_of_parameters(self):
        return self.kim_model.get_number_of_parameters()

    def _create_parameters(self):
        """Populate ``self._parameters`` with one wrapper object per model
        parameter, keyed by parameter name."""

        def _kim_model_parameter(**kwargs):
            dtype = kwargs["dtype"]

            if dtype == "Integer":
                parameter_class = KIMModelParameterInteger
            elif dtype == "Double":
                parameter_class = KIMModelParameterDouble
            else:
                raise KIMModelParameterError(
                    f"Invalid model parameter type {dtype}. Supported types "
                    "'Integer' and 'Double'."
                )
            return parameter_class(**kwargs)

        self._parameters = {}
        for index_param in range(self._get_number_of_parameters()):
            metadata = self._get_one_parameter_metadata(index_param)
            name = metadata["name"]

            self._parameters[name] = _kim_model_parameter(
                kim_model=self.kim_model,
                dtype=metadata["dtype"],
                extent=metadata["extent"],
                name=name,
                description=metadata["description"],
                parameter_index=index_param,
            )

    def get_model_supported_species_and_codes(self):
        """Collect the species this model supports together with the
        integer codes the model assigns to them

        Returns
        -------
        species : list of str
            Abbreviated chemical symbols of all species the model
            supports (e.g. ["Mo", "S"])

        codes : list of int
            Integer codes used by the model for each species (order
            corresponds to the order of ``species``)
        """
        species, codes = [], []

        for i in range(get_number_of_species_names()):
            species_name = get_species_name(i)
            supported, code = self.get_species_support_and_code(species_name)

            if not supported:
                continue

            species.append(str(species_name))
            codes.append(code)

        return species, codes

    @check_call_wrapper
    def clear_then_refresh(self):
        self.kim_model.clear_then_refresh()

    @c_int_args
    def _get_parameter_metadata(self, index_parameter):
        """Return ``(dtype, extent, name, description)`` for the parameter
        at ``index_parameter``, raising KIMModelParameterError if the kimpy
        call fails."""
        try:
            metadata = check_call(
                self.kim_model.get_parameter_metadata, index_parameter
            )
        except KimpyError as error:
            raise KIMModelParameterError(
                "Failed to retrieve metadata for "
                f"parameter at index {index_parameter}"
            ) from error

        dtype, extent, name, description = metadata
        return dtype, extent, name, description

    def parameters_metadata(self):
        """Metadata of every model parameter.

        Returns
        -------
        dict
            Maps each parameter name to that parameter's metadata.
        """
        all_metadata = {}
        for param_name, param in self._parameters.items():
            all_metadata[param_name] = param.metadata
        return all_metadata

    def parameter_names(self):
        """Names of the model parameters held by this object.

        Returns
        -------
        tuple
            Names of model parameters registered in the KIM API
        """
        return tuple(self._parameters)

    def get_parameters(self, **kwargs):
        """
        Read components of one or more model parameter arrays.

        Each keyword names a model parameter; its value gives the index or
        indices of the components to read from that parameter's array.

        Parameters
        ----------
        **kwargs
            Names of the model parameters and the indices whose values should
            be retrieved.

        Returns
        -------
        dict
            The requested indices and the values of the model's parameters.

        Note
        ----
        The output of this method can be used as input of
        ``set_parameters``.

        Example
        -------
        To get `epsilons` and `sigmas` in the LJ universal model for Mo-Mo
        (index 4879), Mo-S (index 2006) and S-S (index 1980) interactions::

            >>> LJ = 'LJ_ElliottAkerson_2015_Universal__MO_959249795837_003'
            >>> calc = KIM(LJ)
            >>> calc.get_parameters(epsilons=[4879, 2006, 1980],
            ...                     sigmas=[4879, 2006, 1980])
            {'epsilons': [[4879, 2006, 1980],
                          [4.47499, 4.421814057295943, 4.36927]],
             'sigmas': [[4879, 2006, 1980],
                        [2.74397, 2.30743, 1.87089]]}
        """
        collected = {}
        for name, indices in kwargs.items():
            collected.update(self._get_one_parameter(name, indices))
        return collected

    def set_parameters(self, **kwargs):
        """
        Write components of one or more model parameter arrays.

        Each keyword names a model parameter; its value is a pair holding
        the indices to mutate and the values to store at those indices.

        Parameters
        ----------
        **kwargs
            Names of the model parameters to mutate and the corresponding
            indices and values to set.

        Returns
        -------
        dict
            The requested indices and the values of the model's parameters
            that were set.

        Example
        -------
        To set `epsilons` in the LJ universal model for Mo-Mo (index 4879),
        Mo-S (index 2006) and S-S (index 1980) interactions to 5.0, 4.5, and
        4.0, respectively::

            >>> LJ = 'LJ_ElliottAkerson_2015_Universal__MO_959249795837_003'
            >>> calc = KIM(LJ)
            >>> calc.set_parameters(epsilons=[[4879, 2006, 1980],
            ...                               [5.0, 4.5, 4.0]])
            {'epsilons': [[4879, 2006, 1980],
                          [5.0, 4.5, 4.0]]}
        """
        applied = {}
        for name, data in kwargs.items():
            indices, new_values = data
            self._set_one_parameter(name, indices, new_values)
            applied[name] = data

        return applied

    def _lookup_parameter(self, parameter_name):
        """Return the wrapper for ``parameter_name`` or raise
        KIMModelParameterError if the model has no such parameter."""
        if parameter_name not in self._parameters:
            raise KIMModelParameterError(
                f"Parameter '{parameter_name}' is not "
                "supported by this model. "
                "Please check that the parameter name is spelled correctly."
            )
        return self._parameters[parameter_name]

    def _get_one_parameter(self, parameter_name, index_range):
        """
        Read one or more components of a single model parameter array.

        Parameters
        ----------
        parameter_name : str
            Name of model parameter registered in the KIM API.
        index_range : int or list
            Zero-based index (int) or indices (list of int) specifying the
            component(s) of the corresponding model parameter array that are
            to be retrieved.

        Returns
        -------
        dict
            The requested indices and the corresponding values of the model
            parameter array.
        """
        return self._lookup_parameter(parameter_name).get_values(index_range)

    def _set_one_parameter(self, parameter_name, index_range, values):
        """
        Write one or more components of a single model parameter array.

        Parameters
        ----------
        parameter_name : str
            Name of model parameter registered in the KIM API.
        index_range : int or list
            Zero-based index (int) or indices (list of int) specifying the
            component(s) of the corresponding model parameter array that are
            to be mutated.
        values : int/float or list
            Value(s) to assign to the component(s) of the model parameter
            array specified by ``index_range``.
        """
        self._lookup_parameter(parameter_name).set_values(index_range, values)

    def _get_one_parameter_metadata(self, index_parameter):
        """
        Metadata of a single model parameter, as a dict.

        Parameters
        ----------
        index_parameter : int
            Zero-based index used by the KIM API to refer to this model
            parameter.

        Returns
        -------
        dict
            Metadata associated with the requested model parameter
            (keys "name", "dtype", "extent" and "description").
        """
        dtype, extent, name, description = self._get_parameter_metadata(
            index_parameter)
        return {
            "name": name,
            # Stored as the repr of the kimpy data type object
            "dtype": repr(dtype),
            "extent": extent,
            "description": description,
        }

    @check_call_wrapper
    def compute(self, compute_args_wrapped, release_GIL):
        return self.kim_model.compute(
            compute_args_wrapped.compute_args, release_GIL)

    @check_call_wrapper
    def get_species_support_and_code(self, species_name):
        return self.kim_model.get_species_support_and_code(species_name)

    @check_call_wrapper
    def get_influence_distance(self):
        return self.kim_model.get_influence_distance()

    @check_call_wrapper
    def get_neighbor_list_cutoffs_and_hints(self):
        return self.kim_model.get_neighbor_list_cutoffs_and_hints()

    def compute_arguments_create(self):
        """Return a new ComputeArguments wrapper tied to this model."""
        return ComputeArguments(self, self.debug)

    @property
    def initialized(self):
        """Whether the ``kim_model`` attribute has been set."""
        return hasattr(self, "kim_model")
class KIMModelParameter(ABC):
    """Base wrapper around a single parameter array of a KIM model.

    Subclasses provide two class attributes: ``_dtype_c``, the cast applied
    to a value before it is handed to kimpy, and ``_dtype_accessor``, the
    name of the ``kim_model`` method used to read a single component.
    """

    def __init__(self, kim_model, dtype, extent,
                 name, description, parameter_index):
        self._kim_model = kim_model
        self._dtype = dtype
        self._extent = extent
        self._name = name
        self._description = description

        # Ensure that parameter_index is cast to a C-compatible integer.  This
        # is necessary because this is passed to kimpy.
        self._parameter_index = c_int(parameter_index)

    @property
    def metadata(self):
        """dict with the "dtype", "extent", "name" and "description" given
        at construction."""
        return {
            "dtype": self._dtype,
            "extent": self._extent,
            "name": self._name,
            "description": self._description,
        }

    @c_int_args
    def _get_one_value(self, index_extent):
        """Read the component at ``index_extent`` of this parameter array.

        Raises
        ------
        KIMModelParameterError
            If the kimpy accessor fails.
        """
        get_parameter = getattr(self._kim_model, self._dtype_accessor)
        try:
            return check_call(
                get_parameter, self._parameter_index, index_extent)
        except KimpyError as exception:
            raise KIMModelParameterError(
                f"Failed to access component {index_extent} of model "
                f"parameter of type '{self._dtype}' at parameter index "
                f"{self._parameter_index}"
            ) from exception

    def _set_one_value(self, index_extent, value):
        """Store ``value`` (cast with ``_dtype_c``) at ``index_extent``.

        Raises
        ------
        KIMModelParameterError
            If the kimpy ``set_parameter`` call fails.
        """
        value_typecast = self._dtype_c(value)

        try:
            check_call(
                self._kim_model.set_parameter,
                self._parameter_index,
                c_int(index_extent),
                value_typecast,
            )
        except KimpyError as exception:
            # Chain explicitly, as _get_one_value does, so the underlying
            # kimpy failure is reported as the cause.
            raise KIMModelParameterError(
                f"Failed to set component {index_extent} at parameter index "
                f"{self._parameter_index} to {self._dtype} value "
                f"{value_typecast}"
            ) from exception

    def get_values(self, index_range):
        """Read one component (scalar index) or several (1-D sequence).

        Returns
        -------
        dict
            ``{name: [index_range, values]}`` where ``values`` is a single
            value for a scalar index and a list for a sequence of indices.

        Raises
        ------
        KIMModelParameterError
            If ``index_range`` has more than one dimension.
        """
        index_range_dim = np.ndim(index_range)
        if index_range_dim == 0:
            values = self._get_one_value(index_range)
        elif index_range_dim == 1:
            values = [self._get_one_value(idx) for idx in index_range]
        else:
            raise KIMModelParameterError(
                "Index range must be an integer or a list of integers"
            )
        return {self._name: [index_range, values]}

    def set_values(self, index_range, values):
        """Write one component (scalar index) or several (1-D sequence).

        Raises
        ------
        AssertionError
            If ``index_range`` and ``values`` differ in dimensionality or
            length.
        KIMModelParameterError
            If ``index_range`` has more than one dimension.
        """
        index_range_dim = np.ndim(index_range)
        values_dim = np.ndim(values)

        # Check the shape of index_range and values
        # NOTE(review): asserts are stripped under ``python -O``; they are
        # kept so the exception type seen by existing callers is unchanged.
        msg = "index_range and values must have the same shape"
        assert index_range_dim == values_dim, msg

        if index_range_dim == 0:
            self._set_one_value(index_range, values)
        elif index_range_dim == 1:
            assert len(index_range) == len(values), msg
            for idx, value in zip(index_range, values):
                self._set_one_value(idx, value)
        else:
            # A 1-D list of any length is accepted above, so describe the
            # accepted input the same way get_values does.
            raise KIMModelParameterError(
                "Index range must be an integer or a list of integers"
            )
class KIMModelParameterInteger(KIMModelParameter):
    """Model parameter whose values are cast with ``c_int``."""

    # Name of the kim_model method used to read one component
    _dtype_accessor = "get_parameter_int"
    # Cast applied to a value before it is passed to kimpy
    _dtype_c = c_int
class KIMModelParameterDouble(KIMModelParameter):
    """Model parameter whose values are cast with ``c_double``."""

    # Name of the kim_model method used to read one component
    _dtype_accessor = "get_parameter_double"
    # Cast applied to a value before it is passed to kimpy
    _dtype_c = c_double
class ComputeArguments:
    """Wraps a KIM API ComputeArguments object created from a KIM Portable
    Model and configures it for ASE.  A ComputeArguments object
    is associated with a KIM Portable Model and is used to inform the
    KIM API of what the model can compute.  It is also used to
    register the data arrays that allow the KIM API to pass the atomic
    coordinates to the model and retrieve the corresponding energy and
    forces, etc."""

    def __init__(self, kim_model_wrapped, debug):
        self.kim_model_wrapped = kim_model_wrapped
        self.debug = debug

        # Create KIM API ComputeArguments object
        self.compute_args = check_call(
            self.kim_model_wrapped.kim_model.compute_arguments_create
        )

        self._check_compute_arguments()
        self._check_compute_callbacks()

    def _check_compute_arguments(self):
        """Reject models that require any compute argument other than
        energy and forces."""
        arg_names = kimpy.compute_argument_name
        num_arguments = arg_names.get_number_of_compute_argument_names()
        if self.debug:
            print("Number of compute_args: {}".format(num_arguments))

        for i in range(num_arguments):
            name = check_call(arg_names.get_compute_argument_name, i)
            dtype = check_call(
                arg_names.get_compute_argument_data_type, name)

            arg_support = self.get_argument_support_status(name)

            if self.debug:
                print(
                    "Compute Argument name {:21} is of type {:7} "
                    "and has support "
                    "status {}".format(str(name), str(dtype), str(arg_support))
                )

            # See if the model demands that we ask it for anything
            # other than energy and forces.  If so, raise an
            # exception.
            if (
                arg_support == kimpy.support_status.required
                and name != kimpy.compute_argument_name.partialEnergy
                and name != kimpy.compute_argument_name.partialForces
            ):
                raise KIMModelInitializationError(
                    "Unsupported required ComputeArgument {}".format(name)
                )

    def _check_compute_callbacks(self):
        """Reject models that require any compute callback."""
        callback_names = kimpy.compute_callback_name
        num_callbacks = callback_names.get_number_of_compute_callback_names()
        if self.debug:
            print()
            print("Number of callbacks: {}".format(num_callbacks))

        for i in range(num_callbacks):
            name = check_call(callback_names.get_compute_callback_name, i)

            support_status = self.get_callback_support_status(name)

            if self.debug:
                print(
                    "Compute callback {:17} has support status {}".format(
                        str(name), support_status
                    )
                )

            # Cannot handle any "required" callbacks
            if support_status == kimpy.support_status.required:
                raise KIMModelInitializationError(
                    "Unsupported required ComputeCallback: {}".format(name)
                )

    @check_call_wrapper
    def set_argument_pointer(self, compute_arg_name, data_object):
        return self.compute_args.set_argument_pointer(
            compute_arg_name, data_object)

    @check_call_wrapper
    def get_argument_support_status(self, name):
        return self.compute_args.get_argument_support_status(name)

    @check_call_wrapper
    def get_callback_support_status(self, name):
        return self.compute_args.get_callback_support_status(name)

    @check_call_wrapper
    def set_callback(self, compute_callback_name,
                     callback_function, data_object):
        return self.compute_args.set_callback(
            compute_callback_name, callback_function, data_object
        )

    @check_call_wrapper
    def set_callback_pointer(
            self, compute_callback_name, callback, data_object):
        return self.compute_args.set_callback_pointer(
            compute_callback_name, callback, data_object
        )

    def update(
        self, num_particles, species_code, particle_contributing,
        coords, energy, forces
    ):
        """Register model input and output in the kim_model object."""
        arg = kimpy.compute_argument_name
        register = self.set_argument_pointer

        # Inputs handed to the model
        register(arg.numberOfParticles, num_particles)
        register(arg.particleSpeciesCodes, species_code)
        register(arg.particleContributing, particle_contributing)
        register(arg.coordinates, coords)
        # Outputs filled in by the model
        register(arg.partialEnergy, energy)
        register(arg.partialForces, forces)

        if self.debug:
            print("Debug: called update_kim")
            print()
class SimulatorModel:
    """Creates a KIM API Simulator Model object and provides a minimal
    interface to it.  This is only necessary in this package in order to
    extract any information about a given simulator model because it is
    generally embedded in a shared object.
    """

    def __init__(self, model_name):
        # Create a KIM API Simulator Model object for this model
        self.model_name = model_name
        self.simulator_model = simulator_model_create(self.model_name)

        # Need to close template map in order to access simulator
        # model metadata
        self.simulator_model.close_template_map()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, value, traceback):
        # Nothing is released here; exceptions are not suppressed
        pass

    @property
    def simulator_name(self):
        """Name part of the (name, version) pair reported by the model."""
        simulator_name, _ = \
            self.simulator_model.get_simulator_name_and_version()
        return simulator_name

    @property
    def num_supported_species(self):
        """Number of species the model reports as supported.

        Raises
        ------
        KIMModelInitializationError
            If the model reports zero supported species.
        """
        num_supported_species = \
            self.simulator_model.get_number_of_supported_species()
        if num_supported_species == 0:
            raise KIMModelInitializationError(
                "Unable to determine supported species of "
                "simulator model {}.".format(self.model_name)
            )
        return num_supported_species

    @property
    def supported_species(self):
        """Tuple of the species returned by the model, one per species
        code in ``range(num_supported_species)``."""
        return tuple(
            check_call(self.simulator_model.get_supported_species, spec_code)
            for spec_code in range(self.num_supported_species)
        )

    @property
    def num_metadata_fields(self):
        """Number of simulator fields reported by the model."""
        return self.simulator_model.get_number_of_simulator_fields()

    @property
    def metadata(self):
        """dict mapping each simulator field name to the list of its lines.

        The metadata is re-read from the model on every access.
        """
        sm_metadata_fields = {}
        for field in range(self.num_metadata_fields):
            extent, field_name = check_call(
                self.simulator_model.get_simulator_field_metadata, field
            )
            sm_metadata_fields[field_name] = [
                check_call(
                    self.simulator_model.get_simulator_field_line, field, ln
                )
                for ln in range(extent)
            ]

        return sm_metadata_fields

    @property
    def supported_units(self):
        """First line of the 'units' metadata field.

        Raises
        ------
        KIMModelInitializationError
            If there is no 'units' field or it has no lines.
        """
        try:
            return self.metadata["units"][0]
        except (KeyError, IndexError) as exc:
            # Chain explicitly so the missing key/index shows as the cause
            raise KIMModelInitializationError(
                "Unable to determine supported units of "
                "simulator model {}.".format(self.model_name)
            ) from exc

    @property
    def atom_style(self):
        """
        See if a 'model-init' field exists in the SM metadata and, if
        so, whether it contains any entries including an "atom_style"
        command.  This is specific to LAMMPS SMs and is only required
        for using the LAMMPSrun calculator because it uses
        lammps.inputwriter to create a data file.  All other content in
        'model-init', if it exists, is ignored.

        Returns the second whitespace-separated token of the last matching
        line, or None when no usable line is found.
        """
        atom_style = None
        for ln in self.metadata.get("model-init", []):
            if "atom_style" not in ln:
                continue
            tokens = ln.split()
            # A line that mentions atom_style but has no second token gives
            # nothing to extract; skip it instead of raising IndexError.
            if len(tokens) > 1:
                atom_style = tokens[1]

        return atom_style

    @property
    def model_defn(self):
        """Lines of the 'model-defn' metadata field (KeyError if absent)."""
        return self.metadata["model-defn"]

    @property
    def initialized(self):
        """Whether the ``simulator_model`` attribute has been set."""
        return hasattr(self, "simulator_model")