Source code for data.sampler

from typing import Iterator, List, Callable

import numpy as np
from torch.utils.data import Sampler, WeightedRandomSampler

from schnetpack import properties
from schnetpack.data import BaseAtomsData


# Public API of this module: the names exported by a star-import.
__all__ = [
    "StratifiedSampler",
    "NumberOfAtomsCriterion",
    "PropertyCriterion",
]


class NumberOfAtomsCriterion:
    """
    Partition criterion that reports the atom count of every sample.

    Calling an instance on a dataset yields one entry per sample, in index
    order, taken from the sample's ``properties.n_atoms`` field.
    """

    def __call__(self, dataset):
        """
        Collect the number of atoms of each sample.

        Args:
            dataset: object supporting ``len()`` and integer indexing, whose
                items are subscriptable by ``properties.n_atoms`` (presumably
                holding a one-element tensor -- ``.item()`` is called on it).

        Returns:
            list with ``dataset[i][properties.n_atoms].item()`` for every
            index ``i`` from ``0`` to ``len(dataset) - 1``.
        """
        return [
            dataset[position][properties.n_atoms].item()
            for position in range(len(dataset))
        ]
class PropertyCriterion:
    """
    Partition criterion that reports one property value per sample.

    The selected property must be a scalar value, since ``.item()`` is
    called on it for every sample.
    """

    def __init__(self, property_key: str = properties.energy):
        """
        Args:
            property_key: key read from every sample of the dataset.
                Defaults to ``properties.energy``.
        """
        self.property_key = property_key

    def __call__(self, dataset):
        """
        Collect the configured scalar property of each sample.

        Args:
            dataset: object supporting ``len()`` and integer indexing, whose
                items are subscriptable by ``self.property_key``.

        Returns:
            list with ``dataset[i][self.property_key].item()`` for every
            index ``i`` from ``0`` to ``len(dataset) - 1``.
        """
        return [
            dataset[position][self.property_key].item()
            for position in range(len(dataset))
        ]
class StratifiedSampler(WeightedRandomSampler):
    """
    A custom sampler that performs stratified sampling based on a partition criterion.

    The values returned by ``partition_criterion`` are split into ``num_bins``
    equal-width histogram bins (``np.histogram``). Each sample gets the weight
    ``smallest_nonempty_bin_count / count_of_its_own_bin``. Every non-empty bin
    therefore carries the same total weight, and samples from sparsely
    populated bins are drawn more often.

    Note:
        Make sure that num_bins is chosen sufficiently small to avoid too many
        empty bins.
    """

    def __init__(
        self,
        data_source: "BaseAtomsData",
        partition_criterion: Callable[["BaseAtomsData"], List],
        num_samples: int,
        num_bins: int = 10,
        replacement: bool = True,
        verbose: bool = True,
    ) -> None:
        """
        Args:
            data_source: The data source to be sampled from.
            partition_criterion: A callable function that takes a data source and
                returns a list of values (one scalar per sample) used for partitioning.
            num_samples: The total number of samples to be drawn from the data source.
            num_bins: The number of bins to divide the partitioned values into.
                Defaults to 10.
            replacement: Whether to sample with replacement or without replacement.
                Defaults to True.
            verbose: Stored as ``self.verbose``; this class does not currently
                read it or print anything. Defaults to True.

        Raises:
            ValueError: if ``partition_criterion`` returns no values.
        """
        self.data_source = data_source
        self.num_bins = num_bins
        self.verbose = verbose
        weights = self.calculate_weights(partition_criterion)
        super().__init__(
            weights=weights, num_samples=num_samples, replacement=replacement
        )

    def calculate_weights(self, partition_criterion):
        """
        Calculates the weights for each sample based on the partition criterion.

        Args:
            partition_criterion: callable applied to ``self.data_source`` that
                returns one scalar value per sample.

        Returns:
            numpy float array with one weight per value returned by the
            criterion, in the same order.

        Raises:
            ValueError: if the criterion returns no values.
        """
        feature_values = np.asarray(partition_criterion(self.data_source))
        if feature_values.size == 0:
            raise ValueError(
                "partition_criterion returned no values; cannot compute sampling weights."
            )

        bin_counts, bin_edges = np.histogram(feature_values, bins=self.num_bins)

        # Map every value to its histogram bin using only the interior edges.
        # Values below the first interior edge go to bin 0. Values at or above
        # the last interior edge go to the last bin, which, like np.histogram's
        # last bin, is closed on the right, so the maximum value can never fall
        # out of range. (A fixed "+0.1" on the last edge breaks for values whose
        # float spacing exceeds 0.1.)
        bin_indices = np.searchsorted(bin_edges[1:-1], feature_values, side="right")

        occupied = bin_counts > 0
        min_count = bin_counts[occupied].min()

        # Divide only for non-empty bins so no divide-by-zero warning is raised;
        # empty bins keep weight 0 (no sample is mapped to them anyway).
        bin_weights = np.zeros(bin_counts.shape, dtype=float)
        np.divide(min_count, bin_counts, out=bin_weights, where=occupied)

        return bin_weights[bin_indices]