Batch-wise Structure Relaxation

In this tutorial, we show how to use the ASEBatchwiseLBFGS optimizer. It relaxes structures in a batch-wise manner, i.e., it optimizes multiple structures in parallel. This is particularly useful when many relatively similar structures need to be relaxed in as little simulation time as possible. Similar structures take a similar number of steps to converge.
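The following sketch illustrates the idea in plain Python. It is not SchNetPack's implementation. It uses simple steepest descent instead of L-BFGS, and the toy harmonic forces and the idx_m bookkeeping list are assumptions made only for illustration. The atoms of all structures share one flat list, and forces for the whole batch are computed in a single call. Each structure stops moving as soon as its largest atomic force drops below fmax.

```python
import math


def toy_forces(positions, targets, k=1.0):
    # Hypothetical harmonic toy potential: every atom is pulled towards its target.
    return [[-k * (p - t) for p, t in zip(pos, tgt)]
            for pos, tgt in zip(positions, targets)]


def batchwise_steepest_descent(positions, targets, idx_m, fmax=1e-3,
                               step=0.5, max_steps=1000):
    """Relax several structures at once.

    The atoms of all structures share one flat list; idx_m[i] is the index of
    the structure that atom i belongs to. Returns the final positions and the
    number of steps each structure needed (None if it did not converge).
    """
    n_structures = max(idx_m) + 1
    converged = [False] * n_structures
    steps_needed = [None] * n_structures
    positions = [list(p) for p in positions]
    for it in range(max_steps):
        # A single force evaluation covers every atom of every structure.
        forces = toy_forces(positions, targets)
        # Largest per-atom force norm within each structure.
        struct_fmax = [0.0] * n_structures
        for f, m in zip(forces, idx_m):
            struct_fmax[m] = max(struct_fmax[m], math.sqrt(sum(c * c for c in f)))
        for m in range(n_structures):
            if not converged[m] and struct_fmax[m] < fmax:
                converged[m] = True
                steps_needed[m] = it
        if all(converged):
            break
        # Only atoms of structures that have not converged yet are moved.
        for i, (f, m) in enumerate(zip(forces, idx_m)):
            if not converged[m]:
                positions[i] = [p + step * c for p, c in zip(positions[i], f)]
    return positions, steps_needed


# Two toy structures: structure 0 has two atoms, structure 1 has one atom.
start = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.1, 0.0, 0.0]]
targets = [[0.0, 0.0, 0.0] for _ in start]
idx_m = [0, 0, 1]
final_positions, steps = batchwise_steepest_descent(start, targets, idx_m)
print(steps)  # [10, 7]
```

Structure 1 starts closer to its minimum. It is frozen after 7 steps, while structure 0 needs 10. The longer the spread in convergence times within a batch, the longer already-converged structures sit idle. This is why batches of similar structures pay off most.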

[1]:
import os
import shutil

import torch
from ase.io import read

import schnetpack as spk
from schnetpack import properties
from schnetpack.interfaces.ase_interface import AtomsConverter
from schnetpack.interfaces.batchwise_optimization import ASEBatchwiseLBFGS, BatchwiseCalculator

First, we load the force field model that provides the forces for the relaxation. Here, the loaded model is only used to read its cutoff for the neighbor list. The calculator loads the model itself from model_path. Next, we define the atoms converter, which converts ASE Atoms objects into SchNetPack inputs. Finally, we initialize the calculator, which loads the model and computes energies and forces for the respective structures. Note that batch-wise relaxations run significantly faster on a CUDA device; this example uses the CPU.
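Conceptually, converting a batch means concatenating the atoms of all structures and recording which atom belongs to which structure. The plain-Python sketch below shows only this bookkeeping. It is not the AtomsConverter code, which presumably also applies the neighbor list and returns torch tensors on the chosen device. The output names n_atoms and idx_m are chosen to loosely mirror SchNetPack's naming, as an assumption.

```python
def collate_structures(structures):
    """Concatenate several structures into flat batch lists.

    Each structure is an (atomic_numbers, positions) pair. Returns the
    concatenated atomic numbers and positions, the number of atoms per
    structure, and idx_m, which maps every atom to the index of its structure.
    """
    numbers_all, positions_all, n_atoms, idx_m = [], [], [], []
    for m, (numbers, positions) in enumerate(structures):
        if len(numbers) != len(positions):
            raise ValueError(f"structure {m}: {len(numbers)} atomic numbers "
                             f"but {len(positions)} positions")
        numbers_all.extend(numbers)
        positions_all.extend(list(p) for p in positions)
        n_atoms.append(len(numbers))
        idx_m.extend([m] * len(numbers))
    return numbers_all, positions_all, n_atoms, idx_m


# Usage: a water molecule and a hydrogen molecule in one batch (positions in Angstrom).
water = ([8, 1, 1], [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]])
hydrogen = ([1, 1], [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]])
Z, R, n_atoms, idx_m = collate_structures([water, hydrogen])
print(n_atoms, idx_m)  # [3, 2] [0, 0, 0, 1, 1]
```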

[2]:
model_path = "../../tests/testdata/md_ethanol.model"

# set device
device = torch.device("cpu")

# load model
# load the full model object (PyTorch >= 2.6 defaults to weights_only=True)
model = torch.load(model_path, map_location=device, weights_only=False)

# define neighbor list
cutoff = model.representation.cutoff.item()
nbh_list = spk.transform.MatScipyNeighborList(cutoff=cutoff)

# build atoms converter
atoms_converter = AtomsConverter(
    neighbor_list=nbh_list,
    device=device,
)

# build calculator
calculator = BatchwiseCalculator(
    model=model_path,
    atoms_converter=atoms_converter,
    device=device,
    energy_unit="kcal/mol",
    position_unit="Ang",
)
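The energy_unit and position_unit arguments tell the calculator which units the model uses. ASE works in eV and Å. The relaxed ethanol energy printed at the end of this tutorial (about -4215) appears to be in eV. Assuming the calculator rescales the model outputs accordingly, the energy conversion is a single multiplicative factor. Below it is derived from the exact SI values of the elementary charge and the Avogadro constant. The function names are hypothetical.

```python
# Exact SI values: elementary charge (C) and Avogadro constant (1/mol).
ELEMENTARY_CHARGE = 1.602176634e-19
AVOGADRO = 6.02214076e23
KJ_PER_KCAL = 4.184  # thermochemical kilocalorie

# 1 eV per particle expressed in kJ/mol (about 96.485 kJ/mol).
KJ_PER_MOL_PER_EV = ELEMENTARY_CHARGE * AVOGADRO / 1000.0


def kcal_per_mol_to_ev(value):
    """Convert an energy (or a force in kcal/mol/Ang) to eV (or eV/Ang)."""
    return value * KJ_PER_KCAL / KJ_PER_MOL_PER_EV


def ev_to_kcal_per_mol(value):
    """Inverse conversion: eV (or eV/Ang) to kcal/mol (or kcal/mol/Ang)."""
    return value * KJ_PER_MOL_PER_EV / KJ_PER_KCAL


print(kcal_per_mol_to_ev(1.0))  # roughly 0.04336
```

Because the position unit stays Å, forces convert with the same factor as energies.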

Subsequently, we load the batch of initial structures using ASE, which supports xyz, db, and many other file formats.

[3]:
input_structure_file = "../../tests/testdata/md_ethanol.xyz"

# load initial structures
ats = read(input_structure_file, index=":")
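With index=":", ase.io.read returns every frame of the file as a list of Atoms objects, so ats holds the whole batch. To give a sense of what a multi-frame XYZ file contains, here is a minimal pure-Python parser for plain XYZ. The helper parse_xyz_frames is hypothetical. It ignores extended-XYZ metadata in the comment line and is no replacement for ase.io.read.

```python
def parse_xyz_frames(text):
    """Parse a plain multi-frame XYZ string into (symbols, positions) frames.

    Each frame consists of an atom-count line, a comment line, and one
    "symbol x y z" line per atom. Extended-XYZ metadata is ignored.
    """
    lines = text.splitlines()
    frames = []
    i = 0
    while i < len(lines):
        if not lines[i].strip():  # tolerate blank lines between frames
            i += 1
            continue
        n = int(lines[i])
        body = lines[i + 2:i + 2 + n]  # skip the comment line
        if len(body) != n:
            raise ValueError(f"frame starting at line {i + 1} is truncated")
        symbols, positions = [], []
        for line in body:
            fields = line.split()
            symbols.append(fields[0])
            positions.append([float(x) for x in fields[1:4]])
        frames.append((symbols, positions))
        i += 2 + n
    return frames


xyz = """2
frame 0
H 0.0 0.0 0.0
H 0.80 0.0 0.0
2
frame 1
H 0.0 0.0 0.0
H 0.75 0.0 0.0
"""
frames = parse_xyz_frames(xyz)
print(len(frames))  # 2
```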

For some systems, it helps to fix the positions of certain atoms during the relaxation. This can be achieved by providing a boolean mask to ASEBatchwiseLBFGS. The mask has one entry per atom, i.e., \(n_\text{atoms}\) entries per structure. True marks an atom whose position is fixed during the relaxation. Since the same atoms are fixed in every input structure, the single-structure mask is repeated once per structure. Here, we do not fix any atoms; hence, the mask only contains False.

[4]:
# define structure mask for optimization (True for fixed, False for non-fixed)
n_atoms = len(ats[0].get_atomic_numbers())
single_structure_mask = [False for _ in range(n_atoms)]
# expand mask by number of input structures (fixed atoms are equivalent for all input structures)
mask = single_structure_mask * len(ats)
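For a case where some atoms are actually fixed, the sketch below builds such a mask and uses it to zero the forces on fixed atoms. Zeroing the forces is one common way to keep atoms in place. Whether ASEBatchwiseLBFGS uses exactly this mechanism internally is an assumption. The helper names and example values are hypothetical.

```python
def build_fixed_atoms_mask(n_atoms, fixed_indices, n_structures):
    """Per-structure boolean mask (True = fixed), repeated for every structure."""
    fixed = set(fixed_indices)
    single_structure_mask = [i in fixed for i in range(n_atoms)]
    return single_structure_mask * n_structures


def zero_fixed_forces(forces, mask):
    """Return a copy of the forces in which fixed atoms feel no force."""
    if len(forces) != len(mask):
        raise ValueError("mask must have one entry per atom in the batch")
    return [[0.0, 0.0, 0.0] if is_fixed else list(f)
            for f, is_fixed in zip(forces, mask)]


# Hypothetical example: fix atom 0 in each of two 3-atom structures.
example_mask = build_fixed_atoms_mask(n_atoms=3, fixed_indices=[0], n_structures=2)
example_forces = [[0.1, 0.0, 0.0], [0.0, 0.2, 0.0], [0.0, 0.0, 0.3]] * 2
print(example_mask)  # [True, False, False, True, False, False]
masked_forces = zero_fixed_forces(example_forces, example_mask)
```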

Finally, we run the optimization:

[5]:
results_dir = "./howto_batchwise_relaxations_outputs"
os.makedirs(results_dir, exist_ok=True)

# Initialize optimizer
optimizer = ASEBatchwiseLBFGS(
    calculator=calculator,
    atoms=ats,
    trajectory=os.path.join(results_dir, "relax_traj"),
)

# run optimization
optimizer.run(fmax=0.0005, steps=1000)
                   Step     Time         fmax
ASEBatchwiseLBFGS:    0 08:45:30       1.4290
ASEBatchwiseLBFGS:   24 08:45:30       0.0004
[5]:
True
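As its name suggests, ASEBatchwiseLBFGS is a batch-wise variant of the limited-memory BFGS (L-BFGS) algorithm. L-BFGS approximates the inverse Hessian from the most recent position and gradient differences instead of storing it explicitly. The sketch below shows the textbook two-loop recursion on a single toy quadratic in plain Python, with a simple backtracking line search. This is an illustration, not the SchNetPack implementation. ASE's own LBFGS, for example, limits each step with maxstep and does not use a line search by default. The helper names and the toy energy are hypothetical.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def lbfgs_direction(grad, s_hist, y_hist):
    """Two-loop recursion: approximate -H^{-1} grad from the stored (s, y) pairs."""
    q = list(grad)
    history = []
    # First loop: newest to oldest pair.
    for s, y in zip(reversed(s_hist), reversed(y_hist)):
        rho = 1.0 / dot(y, s)
        alpha = rho * dot(s, q)
        history.append((s, y, rho, alpha))
        q = [qi - alpha * yi for qi, yi in zip(q, y)]
    # Initial inverse-Hessian guess: identity scaled by the newest pair.
    gamma = dot(s_hist[-1], y_hist[-1]) / dot(y_hist[-1], y_hist[-1]) if s_hist else 1.0
    r = [gamma * qi for qi in q]
    # Second loop: oldest to newest pair.
    for s, y, rho, alpha in reversed(history):
        beta = rho * dot(y, r)
        r = [ri + (alpha - beta) * si for ri, si in zip(r, s)]
    return [-ri for ri in r]


def lbfgs_minimize(f, grad_f, x0, memory=5, tol=1e-8, max_iter=200):
    x = list(x0)
    g = grad_f(x)
    s_hist, y_hist = [], []
    for it in range(max_iter):
        if math.sqrt(dot(g, g)) < tol:
            return x, it
        d = lbfgs_direction(g, s_hist, y_hist)
        # Backtracking line search until the Armijo sufficient-decrease test holds.
        t, fx, slope = 1.0, f(x), dot(g, d)
        while f([xi + t * di for xi, di in zip(x, d)]) > fx + 1e-4 * t * slope and t > 1e-10:
            t *= 0.5
        x_new = [xi + t * di for xi, di in zip(x, d)]
        g_new = grad_f(x_new)
        s = [a - b for a, b in zip(x_new, x)]
        y = [a - b for a, b in zip(g_new, g)]
        if dot(s, y) > 1e-20:  # store only pairs with positive curvature
            s_hist.append(s)
            y_hist.append(y)
            if len(s_hist) > memory:  # limited memory: drop the oldest pair
                s_hist.pop(0)
                y_hist.pop(0)
        x, g = x_new, g_new
    return x, max_iter


# Toy problem: an anisotropic quadratic with its minimum at the origin.
C = [1.0, 4.0, 10.0]


def energy(x):
    return 0.5 * sum(ci * xi * xi for ci, xi in zip(C, x))


def gradient(x):
    return [ci * xi for ci, xi in zip(C, x)]


x_min, n_iter = lbfgs_minimize(energy, gradient, [1.0, 1.0, 1.0])
print(n_iter, x_min)
```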
[6]:
if os.path.exists(results_dir):
    shutil.rmtree(results_dir)

Optimized structures (as ASE Atoms objects) and their properties can be obtained with the get_relaxation_results method.

[7]:
# get list of optimized structures and properties
opt_atoms, opt_props = optimizer.get_relaxation_results()

for oatoms in opt_atoms:
    print(oatoms.get_positions())

print(opt_props)
[[-4.92641876  1.53844337 -0.06455453]
 [-3.41313584  1.45339782 -0.13812974]
 [-5.23321471  2.29324893  0.67332542]
 [-5.35420631  0.56978832  0.23023495]
 [-5.34598031  1.81654815 -1.04170372]
 [-2.99841446  1.18484446  0.85213332]
 [-2.99016388  2.43538579 -0.42366396]
 [-3.07344388  0.45518253 -1.11438772]
 [-2.10454178  0.39966062 -1.16255401]]
{'energy': array([-4215.403], dtype=float32), 'forces': array([[-1.65276186e-04,  3.42524727e-05,  4.74809749e-05],
       [-3.66076507e-04,  9.05731285e-05,  8.92290846e-05],
       [ 1.93987056e-04, -9.31423201e-05, -9.11004099e-05],
       [ 1.65586345e-04,  1.69980340e-04, -9.95497976e-05],
       [ 1.58866125e-04, -9.32974071e-05,  1.63952820e-04],
       [-7.66312442e-05,  1.80355331e-04,  2.37999357e-05],
       [-7.63004064e-05,  3.43868742e-05,  1.79512717e-04],
       [ 9.85908700e-05, -2.08471727e-04, -2.03509102e-04],
       [ 6.73159811e-05, -1.14621194e-04, -1.09686996e-04]], dtype=float32)}
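In ASE optimizers, fmax is compared with the largest per-atom force norm, and the log above presumably reports the same quantity. As a quick consistency check, this value can be recomputed from the forces returned above, which presumably are in eV/Å. The result is about 3.9e-4, found on atom 1. It is below the requested threshold of 0.0005 and matches the 0.0004 logged at the final step.

```python
import math

# Forces printed by get_relaxation_results() above (presumably in eV/Ang).
forces = [
    [-1.65276186e-04, 3.42524727e-05, 4.74809749e-05],
    [-3.66076507e-04, 9.05731285e-05, 8.92290846e-05],
    [1.93987056e-04, -9.31423201e-05, -9.11004099e-05],
    [1.65586345e-04, 1.69980340e-04, -9.95497976e-05],
    [1.58866125e-04, -9.32974071e-05, 1.63952820e-04],
    [-7.66312442e-05, 1.80355331e-04, 2.37999357e-05],
    [-7.63004064e-05, 3.43868742e-05, 1.79512717e-04],
    [9.85908700e-05, -2.08471727e-04, -2.03509102e-04],
    [6.73159811e-05, -1.14621194e-04, -1.09686996e-04],
]

# fmax = largest per-atom force norm (Euclidean norm of each force vector).
norms = [math.sqrt(sum(c * c for c in f)) for f in forces]
fmax = max(norms)
worst_atom = norms.index(fmax)
print(f"fmax = {fmax:.1e} on atom {worst_atom}, converged: {fmax < 0.0005}")
```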