from copy import copy
from typing import Dict
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks import ModelCheckpoint as BaseModelCheckpoint
from torch_ema import ExponentialMovingAverage as EMA
import torch
import os
from pytorch_lightning.callbacks import BasePredictionWriter
from typing import List, Any
from schnetpack.task import AtomisticTask
from schnetpack import properties
from collections import defaultdict
__all__ = ["ModelCheckpoint", "PredictionWriter", "ExponentialMovingAverage"]
class PredictionWriter(BasePredictionWriter):
    """
    Prediction callback that persists model outputs with ``torch.save``.

    Per-batch results are written to ``<output_dir>/<dataloader_idx>/<batch_idx>.pt``.
    Per-epoch results are merged property-wise and written to
    ``<output_dir>/predictions.pt``.
    """

    def __init__(
        self,
        output_dir: str,
        write_interval: str,
        write_idx: bool = False,
    ):
        """
        Args:
            output_dir: directory the prediction files are written to; it is
                created if it does not exist yet
            write_interval: one of ["batch", "epoch", "batch_and_epoch"]
            write_idx: if True, the molecule index of every atom
                (``properties.idx_m``) is kept in the epoch-level output. This is
                needed for atomic properties like forces.
        """
        super().__init__(write_interval)
        self.output_dir = output_dir
        self.write_idx = write_idx
        os.makedirs(output_dir, exist_ok=True)

    def write_on_batch_end(
        self,
        trainer,
        pl_module: AtomisticTask,
        prediction: Any,
        batch_indices: List[int],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ):
        """
        Store the prediction of a single batch.

        The file is named after ``batch_idx`` and placed in a sub-directory named
        after ``dataloader_idx``. ``prediction`` is saved as-is.
        """
        target_dir = os.path.join(self.output_dir, str(dataloader_idx))
        os.makedirs(target_dir, exist_ok=True)
        target_file = os.path.join(target_dir, f"{batch_idx}.pt")
        torch.save(prediction, target_file)

    def write_on_epoch_end(
        self,
        trainer,
        pl_module: AtomisticTask,
        predictions: List[Any],
        batch_indices: List[Any],
    ):
        """
        Merge the batch predictions of an epoch and store them in one file.

        Only ``predictions[0]`` is read; each of its entries is expected to be a
        mapping from property name to tensor. Tensors belonging to the same
        property are concatenated in batch order.
        """
        skip_idx = not self.write_idx

        # group the per-batch tensors by property name
        collected: Dict[str, List[Any]] = {}
        for batch_result in predictions[0]:
            for key, value in batch_result.items():
                if skip_idx and key == properties.idx_m:
                    # molecule indices were not requested
                    continue
                collected.setdefault(key, []).append(value)

        merged = {key: torch.concat(chunks) for key, chunks in collected.items()}

        # one file holding all properties of the epoch
        torch.save(merged, os.path.join(self.output_dir, "predictions.pt"))
class ModelCheckpoint(BaseModelCheckpoint):
    """
    Like the PyTorch Lightning ModelCheckpoint callback,
    but also saves the best inference model, by default with activated
    post-processing.
    """

    def __init__(self, model_path: str, do_postprocessing=True, *args, **kwargs):
        """
        Args:
            model_path: path the best inference model is saved to
            do_postprocessing: forwarded to ``save_model`` of the task when the
                best inference model is stored
            *args: positional arguments of the Lightning ``ModelCheckpoint``
            **kwargs: keyword arguments of the Lightning ``ModelCheckpoint``
        """
        super().__init__(*args, **kwargs)
        self.model_path = model_path
        self.do_postprocessing = do_postprocessing
        # populated in `on_validation_end`
        self.trainer = None
        self.task = None

    def on_validation_end(self, trainer, pl_module: AtomisticTask) -> None:
        """Remember trainer and task, then run the regular checkpointing logic."""
        self.trainer = trainer
        self.task = pl_module
        super().on_validation_end(trainer, pl_module)

    def _update_best_and_save(
        self, current: torch.Tensor, trainer, monitor_candidates: Dict[str, Any]
    ):
        """
        Save the regular checkpoint and, if ``current`` is the new best score,
        additionally store the inference model at ``self.model_path``.

        Args:
            current: value of the monitored metric
            trainer: the running trainer
            monitor_candidates: metrics available for checkpoint naming;
                passed through to the base class
        """
        # save model checkpoint
        super()._update_best_and_save(current, trainer, monitor_candidates)

        # A NaN score is replaced by the worst possible value of the monitored
        # mode, so that the comparison with the best score below is well-defined.
        if isinstance(current, torch.Tensor) and torch.isnan(current):
            current = torch.tensor(float("inf" if self.mode == "min" else "-inf"))

        # save best inference model
        if current != self.best_model_score:
            return

        # only the process with local rank 0 writes the model file
        if trainer.strategy.local_rank != 0:
            return

        # The task is cached in `on_validation_end`. If the checkpoint was
        # triggered by another hook, fall back to the module attached to the
        # trainer.
        task = self.task if self.task is not None else trainer.lightning_module
        task.save_model(self.model_path, do_postprocessing=self.do_postprocessing)
class ExponentialMovingAverage(Callback):
    """
    Callback that keeps an exponential moving average (EMA) of the model
    parameters with ``torch_ema``.

    The averaged parameters are copied into the model before validation and the
    stored training parameters are put back at the start of the next training
    epoch.
    """

    def __init__(self, decay, *args, **kwargs):
        """
        Args:
            decay: decay passed to ``torch_ema.ExponentialMovingAverage``
            *args: accepted but not used
            **kwargs: accepted but not used
        """
        super().__init__()
        self.decay = decay
        # created in `on_fit_start`, once the model parameters are available
        self.ema = None
        # EMA state from a checkpoint that was loaded before `self.ema` existed
        self._to_load = None

    def on_fit_start(self, trainer, pl_module: AtomisticTask):
        """Create the EMA object (if needed) and apply a pending checkpoint state."""
        if self.ema is None:
            self.ema = EMA(pl_module.model.parameters(), decay=self.decay)
        if self._to_load is not None:
            self.ema.load_state_dict(self._to_load)
            self._to_load = None

        # load average parameters, to have same starting point as after validation
        self.ema.store()
        self.ema.copy_to()

    def on_train_epoch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        """Put the stored parameters back into the model before training continues."""
        self.ema.restore()

    def on_train_batch_end(self, trainer, pl_module: AtomisticTask, *args, **kwargs):
        """Update the moving average after every training batch."""
        self.ema.update()

    def on_validation_epoch_start(
        self, trainer: "pl.Trainer", pl_module: AtomisticTask, *args, **kwargs
    ):
        """Store the current parameters and switch the model to the averaged ones."""
        self.ema.store()
        self.ema.copy_to()

    def load_state_dict(self, state_dict):
        """
        Load the EMA state from ``state_dict["ema"]``, if present.

        If the EMA object does not exist yet, the state is kept and applied in
        ``on_fit_start``.
        """
        if "ema" in state_dict:
            if self.ema is None:
                self._to_load = state_dict["ema"]
            else:
                self.ema.load_state_dict(state_dict["ema"])

    def state_dict(self):
        """
        Return the callback state for checkpointing.

        Returns:
            ``{"ema": ...}`` with the state of the EMA object. If the EMA object
            has not been created yet, the state that is still waiting to be
            loaded is returned instead. An empty dict is returned if there is
            neither.
        """
        if self.ema is not None:
            return {"ema": self.ema.state_dict()}
        if self._to_load is not None:
            # keep a loaded but not yet applied state when re-saving
            return {"ema": self._to_load}
        return {}