boltzkit.utils.dataloader

Functions

cache_load_sample_derived_data(samples, ...)

Load or compute derived data for a set of samples with optional caching.

load_from_file(path, data_type, ...)

Load data from a file and validate/reshape it according to its type.

load_tica_model(path)

load_topology(path)

Classes

CacheLoadingArgs

Configuration options controlling dataset cache loading and automatic generation of sample-related quantities (log-probs/energies, scores/forces).

CachedRepoDatasetLoader

DatasetLoader

DomainScaledDatasetLoader

boltzkit.utils.dataloader.load_from_file(path: str | pathlib.Path, data_type: Literal['log_probs', 'samples'], n_samples: int | None = None, dtype: type | numpy.dtype = numpy.float32) → ndarray

Load data from a file and validate/reshape it according to its type.

This function supports PyTorch (.pt, .pth) and NumPy (.npy, .npz) files. The loaded data is converted to a NumPy array and reshaped based on data_type.

Parameters:
  • path (str or pathlib.Path) – Path to the file to load. Must exist and have a supported extension.

  • data_type ({"log_probs", "samples"}) –

    Specifies the type of data being loaded, which determines shape validation:

      - “log_probs”: expects data of shape (batch,) or (batch, 1) and flattens it to (batch,).
      - “samples”: expects data of shape (batch,), (batch, dim), or (batch, n_nodes, 3); 3D molecular data is flattened to (batch, n_nodes*3).

  • dtype (type or np.dtype, optional, default=np.float32) – Desired floating-point type for the loaded data. The data is converted to this type after loading.

Returns:

Loaded data as a NumPy array with appropriate shape for the given data_type.

Return type:

np.ndarray

Raises:
  • FileNotFoundError – If path does not exist.

  • ImportError – If the file format requires PyTorch or NumPy and the library is not installed.

  • RuntimeError – If the file could not be loaded.

  • TypeError – If the loaded object is not of the expected type (torch.Tensor for PyTorch, np.ndarray for NumPy).

  • ValueError – If the file extension is unsupported or if the loaded data has an invalid shape for the specified data_type.

Examples

>>> from pathlib import Path
>>> from boltzkit.utils.dataloader import load_from_file
>>> data = load_from_file(Path("predictions.npy"), data_type="log_probs")
>>> print(data.shape)
(1000,)
>>> # If samples.pt stores molecular coordinates of shape (1000, 66, 3),
>>> # they are flattened to (batch, n_nodes * 3) = (1000, 198).
>>> data = load_from_file(Path("samples.pt"), data_type="samples")
>>> print(data.shape)
(1000, 198)

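The shape rules for data_type can be sketched in plain NumPy. This is a hypothetical re-implementation (the helper name reshape_for_data_type is ours, not part of boltzkit) that only mirrors the documented behavior; the real function may differ in details such as how (batch,) samples are handled or which error messages it raises.

```python
import numpy as np


def reshape_for_data_type(arr, data_type, dtype=np.float32):
    """Hypothetical helper mirroring the documented shape rules of load_from_file."""
    arr = np.asarray(arr)
    if data_type == "log_probs":
        # Accept (batch,) or (batch, 1) and flatten to (batch,).
        if arr.ndim == 1 or (arr.ndim == 2 and arr.shape[1] == 1):
            return arr.reshape(-1).astype(dtype)
        raise ValueError(f"log_probs must be (batch,) or (batch, 1), got {arr.shape}")
    if data_type == "samples":
        # (batch,) and (batch, dim) are kept as-is in this sketch.
        if arr.ndim in (1, 2):
            return arr.astype(dtype)
        # Molecular coordinates (batch, n_nodes, 3) -> (batch, n_nodes * 3).
        if arr.ndim == 3 and arr.shape[2] == 3:
            return arr.reshape(arr.shape[0], -1).astype(dtype)
        raise ValueError(
            f"samples must be (batch,), (batch, dim) or (batch, n_nodes, 3), got {arr.shape}"
        )
    raise ValueError(f"unknown data_type: {data_type!r}")


coords = np.zeros((1000, 66, 3))
print(reshape_for_data_type(coords, "samples").shape)  # (1000, 198)
print(reshape_for_data_type(np.zeros((1000, 1)), "log_probs").shape)  # (1000,)
```

Casting after reshaping means the dtype argument applies uniformly to every accepted shape, matching the documented "converted to this type after loading".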
boltzkit.utils.dataloader.load_tica_model(path: str | Path)

boltzkit.utils.dataloader.load_topology(path: str | Path)

boltzkit.utils.dataloader.cache_load_sample_derived_data(samples: ndarray, data_fpath: Path | None, data_cache_fpath: Path | None = None, data_eval_fn: Callable[[ndarray], ndarray] | None = None, allow_autogen: bool = False, cache_data: bool = False) → ndarray

Load or compute derived data for a set of samples with optional caching.

The function attempts, in order, to load data from a primary file, fall back to a cache file, or generate missing data using a provided evaluation function. Generated data can optionally be cached.

Logic priority:

  1. Load from the primary data_fpath if it exists.
  2. Otherwise, load from data_cache_fpath if it exists (requires cache_data to be True).
  3. Otherwise, if allow_autogen is True, compute the missing data using data_eval_fn.
  4. If cache_data is True, save the computed results to data_cache_fpath.

Parameters:
  • samples (numpy.ndarray) – Input samples of shape (n_samples, …).

  • data_fpath (pathlib.Path or None) – Path to the primary data file to load.

  • data_cache_fpath (pathlib.Path or None) – Path to the cache file for loading/saving data.

  • data_eval_fn (Callable[[numpy.ndarray], numpy.ndarray] or None) – Function to compute derived data from samples (e.g., log_probs or scores).

  • allow_autogen (bool) – If True, compute missing data when not available.

  • cache_data (bool) – If True, enable loading from and saving to cache.

Returns:

Array of derived data aligned with samples.

Return type:

numpy.ndarray

Raises:
  • ValueError – If autogeneration is enabled but no evaluation function is provided.

  • RuntimeError – If data cannot be loaded or generated.
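The lookup order above can be sketched as follows. This is a hypothetical stand-in (load_or_compute is our name, not boltzkit's), assuming the primary and cache files are plain .npy files; the real function may support other formats and may validate its arguments differently.

```python
import tempfile
from pathlib import Path
from typing import Callable, Optional

import numpy as np


def load_or_compute(
    samples: np.ndarray,
    data_fpath: Optional[Path],
    data_cache_fpath: Optional[Path] = None,
    data_eval_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
    allow_autogen: bool = False,
    cache_data: bool = False,
) -> np.ndarray:
    """Hypothetical sketch of the documented priority order."""
    # 1. The primary data file takes precedence.
    if data_fpath is not None and data_fpath.exists():
        return np.load(data_fpath)
    # 2. The cache file is consulted only when caching is enabled.
    if cache_data and data_cache_fpath is not None and data_cache_fpath.exists():
        return np.load(data_cache_fpath)
    # 3. Compute the data from the samples if autogeneration is allowed.
    if allow_autogen:
        if data_eval_fn is None:
            raise ValueError("allow_autogen=True requires data_eval_fn")
        data = np.asarray(data_eval_fn(samples))
        # 4. Optionally persist the result so the next call hits step 2.
        if cache_data and data_cache_fpath is not None:
            with open(data_cache_fpath, "wb") as f:
                np.save(f, data)  # a file handle stops np.save from appending ".npy"
        return data
    raise RuntimeError("data is neither available on disk nor allowed to be generated")


def log_prob(s: np.ndarray) -> np.ndarray:
    # Example evaluation function: standard Gaussian log-density up to a constant.
    return -0.5 * (s ** 2).sum(axis=1)


with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp) / "samples.npy_log_probs.npy"
    x = np.random.default_rng(0).normal(size=(4, 3))
    # First call: nothing on disk, so the data is computed and cached.
    first = load_or_compute(x, None, cache, log_prob, allow_autogen=True, cache_data=True)
    # Second call: no eval function and no autogen; the cached file is returned.
    second = load_or_compute(x, None, cache, allow_autogen=False, cache_data=True)
```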

class boltzkit.utils.dataloader.DatasetLoader

Bases: ABC

load_dataset(type: Literal['train', 'val', 'test'], length: int, *, include_samples: bool = True, include_log_probs: bool = False, include_scores: bool = False, **kwargs) → Dataset

Load the dataset of the specified split.

This method retrieves samples and optionally associated log_probs/energies and scores/forces.

Parameters:
  • type (Literal["train", "val", "test"]) – Dataset split to load.

  • length (int) – Maximum number of samples to load. If -1, all available samples are used.

  • T (float | int | None, passed via **kwargs) – Temperature (in Kelvin) identifying the dataset. Integers are cast to float. If None, the target’s temperature is used.

  • include_samples (bool, default=True) – Whether to return samples.

  • include_log_probs (bool, default=False) – Whether to include energy values for each sample. Fails if no energies are available and allow_autogen is False.

  • include_scores (bool, default=False) – Whether to include force values for each sample. Fails if no forces are available and allow_autogen is False.

Return type:

Dataset

Raises:

ValueError | NotImplementedError | Exception – If the dataset configuration is missing or the data cannot be computed or retrieved.

try_load_dataset(*args, **kwargs) → Dataset | str

Takes the same arguments as load_dataset, but instead of raising when the dataset is missing, returns the error message as a str.
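The load_dataset / try_load_dataset contract can be sketched with a small abstract base class. Everything here is hypothetical: SketchDatasetLoader and OnlyTrainLoader are illustrative names, a dict stands in for the Dataset type, and catching Exception broadly is our assumption based on the documented Raises clause.

```python
from abc import ABC, abstractmethod
from typing import Literal, Union


class SketchDatasetLoader(ABC):
    """Hypothetical stand-in for DatasetLoader; a dict plays the role of Dataset."""

    @abstractmethod
    def load_dataset(self, type: Literal["train", "val", "test"], length: int, *,
                     include_samples: bool = True, include_log_probs: bool = False,
                     include_scores: bool = False, **kwargs) -> dict:
        ...

    def try_load_dataset(self, *args, **kwargs) -> Union[dict, str]:
        # Same arguments as load_dataset, but a failure is returned as its message.
        try:
            return self.load_dataset(*args, **kwargs)
        except Exception as e:  # broad on purpose, mirroring "ValueError | ... | Exception"
            return str(e)


class OnlyTrainLoader(SketchDatasetLoader):
    # `type` shadows the builtin only to match the documented parameter name.
    def load_dataset(self, type, length, *, include_samples=True,
                     include_log_probs=False, include_scores=False, **kwargs):
        if type != "train":
            raise ValueError(f"no {type!r} split available")
        n = 10 if length == -1 else min(length, 10)  # -1 means "all available"
        return {"samples": list(range(n))} if include_samples else {}


loader = OnlyTrainLoader()
ok = loader.try_load_dataset("train", -1)  # a dataset (dict)
err = loader.try_load_dataset("test", 5)   # the error message (str)
```

Returning the message rather than raising lets a caller probe several splits or temperatures and report all missing ones at once.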

boltzkit.utils.dataloader._get_cache_path(samples_fpath: Path, cache_data_type: Literal['log_probs', 'scores'] | str) → Path

Return the cache path for a derived quantity, placed next to the samples file, e.g. ‘samples.npy’ → ‘samples.npy_log_probs.npy’.
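The naming rule in the example can be reproduced with pathlib. This is a hypothetical re-implementation (get_cache_path is our name), assuming the cache file always gets a ".npy" suffix appended to the full samples file name, as the example suggests.

```python
from pathlib import Path


def get_cache_path(samples_fpath: Path, cache_data_type: str) -> Path:
    # Keep the directory, append "_<type>.npy" to the complete file name
    # (including its original extension).
    return samples_fpath.with_name(f"{samples_fpath.name}_{cache_data_type}.npy")


print(get_cache_path(Path("data/samples.npy"), "log_probs").name)  # samples.npy_log_probs.npy
print(get_cache_path(Path("data/samples.npy"), "scores").name)     # samples.npy_scores.npy
```

Keeping the original extension in the name means that, for example, samples.npy and samples.pt in the same directory get distinct cache files.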

class boltzkit.utils.dataloader.CacheLoadingArgs

Bases: object

Configuration options controlling dataset cache loading and automatic generation of sample-related quantities (log-probs/energies, scores/forces).

Parameters:
  • allow_autogen (bool, optional, default=True) – If True, missing quantities (e.g., log-probs/energies, scores/forces) may be computed automatically online if possible.

  • cache_log_probs (bool, optional, default=True) – Whether log-probs/energies may be cached after online computation (allow_autogen=True) or loaded from cache files if available.

  • cache_scores (bool, optional, default=False) – Whether scores/forces may be cached after online computation (allow_autogen=True) or loaded from cache files if available.

allow_autogen: bool = True
cache_log_probs: bool = True
cache_scores: bool = False
__init__(allow_autogen: bool = True, cache_log_probs: bool = True, cache_scores: bool = False) → None

class boltzkit.utils.dataloader.CachedRepoDatasetLoader

Bases: DatasetLoader

__init__(kB_T: float, cached_repo: CachedRepo, T: float, log_prob_fn: Callable[[ndarray], ndarray] | None, score_fn: Callable[[ndarray], ndarray] | None, caching_args: CacheLoadingArgs | dict | None = None)

load_dataset(type, length, *, include_samples=True, include_log_probs=False, include_scores=False, **kwargs)

Load the requested split from the cached repository, assuming the repository follows a specific file layout.
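CacheLoadingArgs appears to be a dataclass (class-level defaults plus a generated __init__), and caching_args accepts a CacheLoadingArgs instance, a dict, or None. A plausible normalization step is sketched below; CacheLoadingArgsSketch and as_cache_loading_args are hypothetical names, and we assume None means "use the defaults" and a dict holds keyword arguments.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class CacheLoadingArgsSketch:
    # Mirrors the documented fields and defaults of CacheLoadingArgs.
    allow_autogen: bool = True
    cache_log_probs: bool = True
    cache_scores: bool = False


def as_cache_loading_args(
    caching_args: Union[CacheLoadingArgsSketch, dict, None],
) -> CacheLoadingArgsSketch:
    # None -> defaults, dict -> keyword arguments, instance -> returned unchanged.
    if caching_args is None:
        return CacheLoadingArgsSketch()
    if isinstance(caching_args, dict):
        return CacheLoadingArgsSketch(**caching_args)  # unknown keys raise TypeError
    return caching_args


args = as_cache_loading_args({"cache_scores": True})
print(args)  # CacheLoadingArgsSketch(allow_autogen=True, cache_log_probs=True, cache_scores=True)
```

Accepting a plain dict keeps configuration files simple, while the dataclass still catches misspelled keys at construction time.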

class boltzkit.utils.dataloader.DomainScaledDatasetLoader

Bases: DatasetLoader

__init__(dataset_loader: DatasetLoader, length_scale: float)

load_dataset(type, length, *, include_samples=True, include_log_probs=False, include_scores=False, **kwargs)

Load the dataset of the specified split. Takes the same parameters, returns the same type (Dataset), and raises the same exceptions as DatasetLoader.load_dataset.
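The page does not state which quantities DomainScaledDatasetLoader rescales. The sketch below assumes the wrapper maps samples to x' = length_scale * x; under that change of variables the chain rule gives scores ∇_{x'} log p = (1/length_scale) ∇_x log p, while pointwise values of an unnormalized log-density (such as negative energies) are unchanged. ArrayLoader and ScaledLoaderSketch are hypothetical, and a dict of arrays stands in for Dataset.

```python
import numpy as np


class ArrayLoader:
    """Hypothetical base loader returning a dict of arrays in place of a Dataset."""

    def load_dataset(self, type, length, *, include_samples=True,
                     include_log_probs=False, include_scores=False, **kwargs):
        x = np.array([[1.0, 2.0], [3.0, 4.0]])
        out = {}
        if include_samples:
            out["samples"] = x
        if include_log_probs:
            out["log_probs"] = -0.5 * (x ** 2).sum(axis=1)
        if include_scores:
            out["scores"] = -x  # gradient of the log-prob above
        return out


class ScaledLoaderSketch:
    """Hypothetical wrapper expressing samples in coordinates scaled by length_scale."""

    def __init__(self, dataset_loader, length_scale: float):
        self.dataset_loader = dataset_loader
        self.length_scale = length_scale

    def load_dataset(self, type, length, **kwargs):
        data = self.dataset_loader.load_dataset(type, length, **kwargs)
        s = self.length_scale
        if "samples" in data:
            data["samples"] = data["samples"] * s  # x' = s * x
        if "scores" in data:
            data["scores"] = data["scores"] / s  # chain rule: d/dx' = (1/s) d/dx
        # log_probs are left unchanged (unnormalized log-density values).
        return data


scaled = ScaledLoaderSketch(ArrayLoader(), length_scale=10.0).load_dataset(
    "train", -1, include_scores=True)
```

Wrapping another loader, rather than subclassing a concrete one, lets the same rescaling apply to any DatasetLoader implementation.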