task.ModelOutput

class task.ModelOutput(*args: Any, **kwargs: Any)

Defines an output of a model, together with the loss function and loss weight used to train on it and the metrics logged for it.
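To illustrate how these pieces fit together, here is a minimal, dependency-free sketch. It is a hypothetical, simplified analogue and not the library class: `SimpleModelOutput`, `calculate_loss`, `calculate_metrics`, `mse` and `mae` are illustrative names, and plain lists of floats stand in for tensors. It shows the output name being used to look up the prediction, the target name falling back to the output name, the weighted loss, and the logged metrics.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional

# Hypothetical stand-ins: plain lists of floats instead of tensors.
Values = List[float]
Results = Dict[str, Values]
LossFn = Callable[[Values, Values], float]


def mse(pred: Values, target: Values) -> float:
    """Mean squared error of two equal-length lists."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def mae(pred: Values, target: Values) -> float:
    """Mean absolute error of two equal-length lists."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


@dataclass
class SimpleModelOutput:
    """Hypothetical, simplified analogue of ModelOutput (not the library class)."""

    name: str  # key of the prediction in the results dict
    loss_fn: Optional[LossFn] = None
    loss_weight: float = 1.0
    metrics: Dict[str, LossFn] = field(default_factory=dict)
    target_property: Optional[str] = None

    def __post_init__(self) -> None:
        # Without an explicit target name, the output name doubles as the target name.
        if self.target_property is None:
            self.target_property = self.name

    def calculate_loss(self, pred: Results, batch: Results) -> float:
        # An output without a loss function (or with zero weight) adds nothing.
        if self.loss_fn is None or self.loss_weight == 0.0:
            return 0.0
        return self.loss_weight * self.loss_fn(pred[self.name], batch[self.target_property])

    def calculate_metrics(self, pred: Results, batch: Results) -> Dict[str, float]:
        # Each metric compares the prediction with its target, keyed by metric name.
        return {key: fn(pred[self.name], batch[self.target_property]) for key, fn in self.metrics.items()}


energy = SimpleModelOutput(name="energy", loss_fn=mse, loss_weight=0.5, metrics={"MAE": mae})
pred = {"energy": [1.0, 2.0]}
batch = {"energy": [1.0, 4.0]}
print(energy.target_property)                 # energy
print(energy.calculate_loss(pred, batch))     # 1.0
print(energy.calculate_metrics(pred, batch))  # {'MAE': 1.0}
```

Here the squared errors are 0 and 4, so the MSE is 2.0 and the weighted loss is 0.5 × 2.0 = 1.0.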

Parameters:
  • name – key of the output in the model's results dictionary

  • target_property – Name of the target in the training batch. Only required for supervised training. If not given, the output name is also used as the target name.

  • loss_fn – function to compute the loss

  • loss_weight – weight of this output's loss in the composite loss $l = w_1 l_1 + \dots + w_n l_n$

  • metrics – dictionary of metrics with names as keys

  • constraints – constraint classes that specify how the model output is used in the loss function and logged metrics without changing the model output itself. In effect, constraints are postprocessing transforms that leave the model output untouched and only change the loss value. For example, constraints can be used to ignore or reweight some atomic forces in the loss function, which may be useful when training on systems where only some forces are crucial for the dynamics.
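The composite loss and the role of constraints can be sketched as follows. This is a hypothetical, dependency-free illustration, not the library's implementation: `ConstrainedOutput`, `composite_loss`, `mask_forces` and the `force_mask` batch key are assumed names, and each atom's force is reduced to a single float for brevity. The constraint drops masked atoms from copies of the prediction and target before the loss is computed, so the model output itself stays unchanged.

```python
import math
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical stand-ins: one float per atom instead of 3D force tensors.
Results = Dict[str, List[float]]
Constraint = Callable[[Results, Results], Tuple[Results, Results]]


def mse(pred: List[float], target: List[float]) -> float:
    """Mean squared error of two equal-length lists."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def mask_forces(pred: Results, batch: Results) -> Tuple[Results, Results]:
    """Hypothetical constraint: keep only atoms whose batch["force_mask"] entry is nonzero."""
    keep = [i for i, m in enumerate(batch["force_mask"]) if m]
    # dict(...) builds new dicts, so the caller's predictions are never modified.
    new_pred = dict(pred, forces=[pred["forces"][i] for i in keep])
    new_batch = dict(batch, forces=[batch["forces"][i] for i in keep])
    return new_pred, new_batch


@dataclass
class ConstrainedOutput:
    name: str
    loss_fn: Callable[[List[float], List[float]], float]
    loss_weight: float = 1.0
    constraints: List[Constraint] = field(default_factory=list)

    def weighted_loss(self, pred: Results, batch: Results) -> float:
        # Constraints act only on the values fed to the loss, not on the model output.
        for constraint in self.constraints:
            pred, batch = constraint(pred, batch)
        return self.loss_weight * self.loss_fn(pred[self.name], batch[self.name])


def composite_loss(outputs: Sequence[ConstrainedOutput], pred: Results, batch: Results) -> float:
    """l = w_1 l_1 + ... + w_n l_n over all outputs."""
    return sum(out.weighted_loss(pred, batch) for out in outputs)


outputs = [
    ConstrainedOutput("energy", mse, loss_weight=0.1),
    ConstrainedOutput("forces", mse, loss_weight=0.9, constraints=[mask_forces]),
]
pred = {"energy": [1.0], "forces": [0.0, 1.0, 5.0]}
batch = {"energy": [3.0], "forces": [0.0, 2.0, 0.0], "force_mask": [1.0, 1.0, 0.0]}

loss = composite_loss(outputs, pred, batch)
print(math.isclose(loss, 0.1 * 4.0 + 0.9 * 0.5))  # True
print(pred["forces"])                              # [0.0, 1.0, 5.0]
```

With the mask, the third atom's large error (25) is excluded, so the force term is 0.9 × 0.5 instead of 0.9 × 26/3, while the predicted forces themselves are left as they were.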