boltzkit.utils.molecular.augmentation
Functions
- boltzkit.utils.molecular.augmentation.create_symmetry_augmentation(sigma: float | None = None, rotation_augmentation: bool = True, COM_augmentation: bool = True) -> Callable[[GenericArrayType, bool | None, bool | None, bool | None], GenericArrayType]
Create a rotation and center-of-mass translation augmentation function.
This factory returns a callable that applies stochastic rigid-body transformations to molecular coordinate samples. The augmentation consists of:
Removing the center-of-mass (COM) from each sample.
Optionally applying a random 3D rotation.
Optionally applying a random COM translation sampled from a Gaussian.
The function supports both NumPy arrays and PyTorch tensors via a framework dispatch mechanism.
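The three steps can be sketched in plain NumPy as follows. This is an illustrative sketch, not the library's implementation: the names `random_rotations` and `augment_sketch` are hypothetical, the center of mass is taken as the unweighted mean over atoms (equal masses assumed), and rotations are drawn by QR decomposition of Gaussian matrices, which is one common way to sample them and may differ from what boltzkit does.

```python
import numpy as np


def random_rotations(batch_size, rng):
    """Draw proper 3D rotation matrices via QR of Gaussian matrices (assumed method)."""
    mats = []
    for _ in range(batch_size):
        q, r = np.linalg.qr(rng.normal(size=(3, 3)))
        # Remove the sign ambiguity of QR: scale column j by the sign of r[j, j].
        q = q * np.sign(np.diag(r))
        # Flip one axis if needed so that det(q) = +1 (rotation, not reflection).
        if np.linalg.det(q) < 0:
            q[:, 0] = -q[:, 0]
        mats.append(q)
    return np.stack(mats)


def augment_sketch(samples, sigma=None, rotate=True, translate=True, rng=None):
    """Apply COM removal, optional rotation, optional Gaussian COM shift."""
    rng = np.random.default_rng() if rng is None else rng
    batch_size, dim = samples.shape
    n_atoms = dim // 3
    x = samples.reshape(batch_size, n_atoms, 3)

    # Step 1: remove the center of mass (equal masses assumed).
    x = x - x.mean(axis=1, keepdims=True)

    # Step 2: rotate every atom of sample b by the same matrix rot[b].
    if rotate:
        rot = random_rotations(batch_size, rng)
        x = np.einsum("bij,baj->bai", rot, x)

    # Step 3: shift each sample by one Gaussian vector shared by all its atoms.
    if translate:
        if sigma is None:
            sigma = 1.0 / np.sqrt(n_atoms)  # default stated in the parameter docs
        x = x + rng.normal(scale=sigma, size=(batch_size, 1, 3))

    return x.reshape(batch_size, dim)
```

Because every step is a rigid-body transformation, interatomic distances within each sample are unchanged; only the orientation and position of the molecule vary.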
- Parameters:
sigma (float or None, optional) – Standard deviation of the Gaussian COM translation. If None, the default value 1 / sqrt(n_atoms) is used for each sample, where n_atoms is inferred from the input dimensionality.
rotation_augmentation (bool, default=True) – Whether to apply random 3D rotations to each sample.
COM_augmentation (bool, default=True) – Whether to apply a random center-of-mass translation after removing the original COM. If False, samples remain centered at the origin.
- Returns:
A function augment(samples, sigma_override=None, rotation_override=None, COM_override=None) that applies the augmentation. The function accepts both NumPy arrays and PyTorch tensors and returns an array of the same type.
- Return type:
Callable
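The factory-plus-override pattern suggested by the signature can be illustrated with a small closure. This sketch assumes that an override left at None falls back to the value given to the factory; the documentation above does not state this explicitly. The name `make_augment_sketch` is hypothetical, and the inner function only reports the resolved settings instead of transforming the samples.

```python
from typing import Callable, Optional


def make_augment_sketch(
    sigma: Optional[float] = None,
    rotation_augmentation: bool = True,
    COM_augmentation: bool = True,
) -> Callable:
    """Return a callable that remembers factory defaults (illustrative only)."""

    def augment(samples, sigma_override=None, rotation_override=None, COM_override=None):
        # Assumption: None means "use the value captured by the factory".
        eff_sigma = sigma if sigma_override is None else sigma_override
        eff_rotation = (
            rotation_augmentation if rotation_override is None else rotation_override
        )
        eff_com = COM_augmentation if COM_override is None else COM_override
        # A real implementation would transform `samples` here; this sketch
        # returns the resolved settings so the fallback logic is visible.
        return {"sigma": eff_sigma, "rotation": eff_rotation, "COM": eff_com}

    return augment


augment = make_augment_sketch(sigma=0.1)
defaults = augment(None)                          # factory settings apply
no_rotation = augment(None, rotation_override=False)  # per-call override
```

Under this reading, one augmentation function can be created once and then adjusted per call, for example to switch off rotations during evaluation.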
Notes
The input samples are assumed to represent flattened 3D coordinates of atoms in the form:
(batch_size, n_atoms * 3)
Internally the coordinates are reshaped to:
(batch_size, n_atoms, 3)
before applying the transformations.
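The layout can be checked with a short NumPy sketch. It assumes the default row-major reshape, so that entries 3*i to 3*i + 2 of a flat row hold the x, y, z coordinates of atom i; the helper name `unflatten` is hypothetical and not part of boltzkit.

```python
import numpy as np


def unflatten(samples):
    """Reshape (batch_size, n_atoms * 3) to (batch_size, n_atoms, 3)."""
    batch_size, dim = samples.shape
    if dim % 3 != 0:
        raise ValueError(f"last dimension {dim} is not a multiple of 3")
    n_atoms = dim // 3  # n_atoms is inferred from the input dimensionality
    return samples.reshape(batch_size, n_atoms, 3)


flat = np.arange(12, dtype=float).reshape(2, 6)  # 2 samples, 2 atoms each
coords = unflatten(flat)
print(coords.shape)   # (2, 2, 3)
print(coords[0, 1])   # [3. 4. 5.]  -> atom 1 of sample 0

# Flattening again restores the original array exactly.
restored = coords.reshape(flat.shape)
```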
This augmentation strategy is inspired by rigid-body data augmentation used in molecular generative modeling, e.g. in "Scalable Equilibrium Sampling with Sequential Boltzmann Generators".