import torch
from torch import nn
# Public API: export only the undecorated wrapper; the TorchScript-compiled
# helper ``_scatter_add`` is kept private (leading underscore).
__all__ = ["scatter_add"]
def scatter_add(
    x: torch.Tensor, idx_i: torch.Tensor, dim_size: int, dim: int = 0
) -> torch.Tensor:
    """
    Sum over values with the same indices.

    Slices of ``x`` along ``dim`` that share the same value in ``idx_i`` are
    added together into one output slice. Output slices that no index points
    to are left as zero.

    Args:
        x: input values
        idx_i: index of center atom i; one integer index per slice of ``x``
            along ``dim``, each expected to lie in ``[0, dim_size)``
        dim_size: size of the dimension after reduction
        dim: the dimension to reduce

    Returns:
        reduced input, with the same shape, dtype and device as ``x`` except
        that dimension ``dim`` has size ``dim_size``
    """
    # Delegate to the TorchScript-compiled implementation defined in this module.
    return _scatter_add(x, idx_i, dim_size, dim)
@torch.jit.script
def _scatter_add(
x: torch.Tensor, idx_i: torch.Tensor, dim_size: int, dim: int = 0
) -> torch.Tensor:
shape = list(x.shape)
shape[dim] = dim_size
tmp = torch.zeros(shape, dtype=x.dtype, device=x.device)
y = tmp.index_add(dim, idx_i, x)
return y