boltzkit.utils.framework

Functions

create_dispatch([impl_np, impl_torch, ...])

create_pytorch_value_and_grad_fn(value_fn)

Computes per-sample gradients of the first output of value_fn with respect to its first input.

detect_framework(x)

from_numpy(x, target_framework[, style])

from_numpy_recursive(x, target_framework[, ...])

is_jax_array(x)

is_torch_tensor(x)

make_agnostic(*, implementation[, grad_fn, ...])

Make a function agnostic to NumPy, JAX, or PyTorch.

make_agnostic_simple(*, implementation)

to_numpy(x, source_framework)

to_numpy_recursive(x, source_framework)

try_jit_jax(f)

Classes

FrameworkAgnosticFunction

boltzkit.utils.framework.is_torch_tensor(x: GenericArrayType) -> bool
boltzkit.utils.framework.is_jax_array(x: GenericArrayType) -> bool
boltzkit.utils.framework.detect_framework(x: GenericArrayType) -> Literal['numpy', 'jax', 'pytorch']
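
The signatures above do not say how the framework is detected. One common approach, shown here as a minimal sketch and not as boltzkit's actual implementation, is to inspect the module path of the array's type, which avoids importing PyTorch or JAX just to run an isinstance check. The name detect_framework_sketch is hypothetical.

```python
import numpy as np


def detect_framework_sketch(x) -> str:
    """Guess the array framework from the module path of type(x).

    Hypothetical stand-in for illustration; boltzkit's detect_framework
    may work differently.
    """
    module = type(x).__module__
    if module.startswith("torch"):
        return "pytorch"
    if module.startswith(("jax", "jaxlib")):
        return "jax"
    if isinstance(x, np.ndarray):
        return "numpy"
    raise TypeError(f"Unsupported array type: {type(x)!r}")


framework = detect_framework_sketch(np.zeros(3))  # "numpy"
```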
boltzkit.utils.framework.create_pytorch_value_and_grad_fn(value_fn: Callable[[torch.Tensor], torch.Tensor])

Computes per-sample gradients of the first output of value_fn with respect to its first input.

boltzkit.utils.framework.to_numpy(x: GenericArrayType, source_framework: Literal['numpy', 'jax', 'pytorch']) -> ndarray
boltzkit.utils.framework.to_numpy_recursive(x: Any, source_framework: Literal['numpy', 'jax', 'pytorch'])
boltzkit.utils.framework.from_numpy(x: ndarray, target_framework: Literal['numpy', 'jax', 'pytorch'], style=None)
boltzkit.utils.framework.from_numpy_recursive(x: GenericArrayType | Any, target_framework: Literal['numpy', 'jax', 'pytorch'], style=None)
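
The *_recursive variants accept arbitrary objects (x: Any), which suggests they walk nested containers and convert the array leaves. The following is a minimal sketch of that idea under this assumption; map_arrays_recursive is a hypothetical helper, and the containers boltzkit actually supports are not documented here.

```python
import numpy as np


def map_arrays_recursive(obj, convert):
    """Apply `convert` to every ndarray leaf of nested dicts, lists, and tuples.

    Hypothetical helper for illustration. Non-array leaves are returned
    unchanged. Plain lists and tuples are rebuilt with their own type;
    namedtuples are not handled by this sketch.
    """
    if isinstance(obj, np.ndarray):
        return convert(obj)
    if isinstance(obj, dict):
        return {key: map_arrays_recursive(val, convert) for key, val in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(map_arrays_recursive(val, convert) for val in obj)
    return obj


nested = {"x": np.arange(3), "meta": ("label", [np.ones(2)])}
converted = map_arrays_recursive(nested, lambda a: a.astype(np.float32))
```

In a real conversion, `convert` would be the framework-specific step, for example a call that turns a NumPy array into a tensor of the target framework.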
class boltzkit.utils.framework.FrameworkAgnosticFunction

Bases: object

__init__(impl_framework: Literal['numpy', 'jax', 'pytorch'], value_fn: Callable[[GenericArrayType], GenericArrayType], grad_fn: None | Callable[[GenericArrayType], GenericArrayType] = None, value_and_grad_fn: None | Callable[[GenericArrayType], tuple[GenericArrayType, GenericArrayType]] = None)
get_value(x: GenericArrayType) -> GenericArrayType
get_grad(x: GenericArrayType) -> GenericArrayType
get_value_and_grad(x: GenericArrayType) -> tuple[GenericArrayType, GenericArrayType]
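
To illustrate how the three accessors could relate to the optional constructor arguments, here is a minimal NumPy-only sketch. It assumes that a missing value_and_grad_fn is assembled from value_fn and grad_fn, and that a missing grad_fn falls back to value_and_grad_fn. The class AgnosticFunctionSketch is hypothetical; the real class also handles framework conversion and automatic differentiation, which are omitted.

```python
import numpy as np


class AgnosticFunctionSketch:
    """Hypothetical, simplified stand-in for FrameworkAgnosticFunction."""

    def __init__(self, value_fn, grad_fn=None, value_and_grad_fn=None):
        self._value_fn = value_fn
        self._grad_fn = grad_fn
        self._value_and_grad_fn = value_and_grad_fn

    def get_value(self, x):
        return self._value_fn(x)

    def get_grad(self, x):
        if self._grad_fn is not None:
            return self._grad_fn(x)
        if self._value_and_grad_fn is not None:
            # Reuse the combined function and keep only the gradient.
            return self._value_and_grad_fn(x)[1]
        raise NotImplementedError("No gradient function was provided.")

    def get_value_and_grad(self, x):
        if self._value_and_grad_fn is not None:
            return self._value_and_grad_fn(x)
        # Fall back to two separate calls.
        return self.get_value(x), self.get_grad(x)


f = AgnosticFunctionSketch(
    value_fn=lambda x: -np.sum(x**2, axis=-1),
    grad_fn=lambda x: -2 * x,
)
batch = np.array([[1.0, 2.0], [3.0, 4.0]])
value, grad = f.get_value_and_grad(batch)
```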
boltzkit.utils.framework.make_agnostic_simple(*, implementation: Literal['numpy', 'jax', 'pytorch'])
boltzkit.utils.framework.try_jit_jax(f: T) -> T
boltzkit.utils.framework.make_agnostic(*, implementation: Literal['numpy', 'jax', 'pytorch'], grad_fn: None | Callable[[GenericArrayType], GenericArrayType] = None, value_and_grad_fn: None | Callable[[GenericArrayType], tuple[GenericArrayType, GenericArrayType]] = None, value_fn: None | Callable[[GenericArrayType], GenericArrayType] = None)

Make a function agnostic to NumPy, JAX, or PyTorch.

This function takes a value_fn already implemented in one of these frameworks and returns a uniform, framework-agnostic wrapper. Optionally, a gradient function (grad_fn) or a combined value-and-gradient function (value_and_grad_fn) can be provided, which is useful when:

  • Automatic differentiation is not available (e.g., NumPy), or

  • An analytical gradient exists, which may be more efficient than computing it automatically.

make_agnostic can be used either as a decorator (value_fn must be left as None):

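```python
import numpy as np

from boltzkit.utils.framework import make_agnostic


def log_prob_grad(x: np.ndarray):
    # Analytical gradient of -sum(x**2) with respect to x
    return -2 * x


@make_agnostic(implementation="numpy", grad_fn=log_prob_grad)
def log_prob(x: np.ndarray):
    return -np.sum(x**2, axis=-1)
```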

or as a factory/wrapper function (requires value_fn to be specified):

```python
import numpy as np

from boltzkit.utils.framework import make_agnostic


def log_prob_grad(x: np.ndarray):
    # Analytical gradient of -sum(x**2) with respect to x
    return -2 * x


def log_prob_val(x: np.ndarray):
    return -np.sum(x**2, axis=-1)


log_prob = make_agnostic(
    implementation="numpy", value_fn=log_prob_val, grad_fn=log_prob_grad
)
```

Parameters:

  • implementation (FrameworkName): The framework in which the function is implemented. Supported values: "numpy", "jax", or "pytorch".

  • grad_fn (callable, optional): A function that computes the gradient of value_fn. Signature should be: grad_fn(x: Array) -> Array.

  • value_and_grad_fn (callable, optional): A function returning both the value and the gradient. Signature should be: value_and_grad_fn(x: Array) -> Tuple[Array, Array]. Providing this may be more efficient than computing value and gradient separately.

  • value_fn (callable, optional): The primary function to wrap. Signature should be: value_fn(x: Array) -> Array. If value_fn is passed directly as an argument, make_agnostic acts as a factory/wrapper and returns the wrapped function immediately. If value_fn is None (the default), make_agnostic acts as a decorator that can be applied to a function later.

Returns:

A framework-agnostic wrapper around the provided functions. If value_fn is provided directly, returns the wrapped function immediately; otherwise, returns a decorator that can be applied to a function.

Return type:

FrameworkAgnosticFunction
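
The dual decorator/factory behaviour described for value_fn follows a common Python pattern: a keyword-only function that either wraps immediately or returns a decorator. The sketch below shows only that pattern. make_wrapper_sketch and its tag argument are hypothetical and carry none of the framework logic of make_agnostic.

```python
import functools


def make_wrapper_sketch(*, tag, value_fn=None):
    """Hypothetical illustration of the decorator-or-factory pattern."""

    def wrap(fn):
        @functools.wraps(fn)
        def wrapped(x):
            return fn(x)

        wrapped.tag = tag  # stand-in for the stored configuration
        return wrapped

    if value_fn is not None:
        # Factory use: the function was passed in, so wrap it right away.
        return wrap(value_fn)
    # Decorator use: return the wrapper to be applied later.
    return wrap


@make_wrapper_sketch(tag="numpy")
def square(x):
    return x * x


double = make_wrapper_sketch(tag="numpy", value_fn=lambda x: 2 * x)
```

Making all arguments keyword-only, as the make_agnostic signature does, keeps the two call styles unambiguous: a positional function argument can never be mistaken for a configuration value.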

Framework-specific caveats

  • NumPy: Automatic differentiation is not available. Provide grad_fn, value_and_grad_fn, or both if gradients are needed.

  • JAX: Functions are automatically JIT-compiled and vectorized when possible. They should therefore be provided in a non-vectorized form.

  • PyTorch: First-order automatic differentiation is supported. Gradients can flow through value_fn even if it is implemented in another framework, provided the input is a torch.Tensor. This does not carry over to JAX inputs: jax.grad(f) requires f to be a non-batched function, so jax.vmap(jax.grad(f))(batch) would result in many individual calls to value_fn.
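
Because a NumPy implementation relies on a hand-written grad_fn, it can be worth checking that gradient numerically before wrapping it. The following is a minimal sketch using central finite differences on the log_prob example from above. finite_difference_grad is a hypothetical helper and not part of boltzkit.

```python
import numpy as np


def log_prob(x: np.ndarray):
    return -np.sum(x**2, axis=-1)


def log_prob_grad(x: np.ndarray):
    return -2 * x


def finite_difference_grad(f, x, eps=1e-6):
    """Central-difference gradient of a batched scalar function f.

    Hypothetical helper: x has shape (..., dim) and f maps it to shape (...).
    """
    grad = np.zeros_like(x, dtype=float)
    for i in range(x.shape[-1]):
        step = np.zeros(x.shape[-1])
        step[i] = eps
        grad[..., i] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad


x = np.random.default_rng(0).normal(size=(4, 3))
numeric = finite_difference_grad(log_prob, x)
analytic = log_prob_grad(x)
gradients_match = np.allclose(numeric, analytic, atol=1e-5)
```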

boltzkit.utils.framework.create_dispatch(impl_np: Callable | None = None, impl_torch: Callable | None = None, impl_jax: Callable | None = None, vmap_jax: bool = False, use_jit: bool = False)