boltzkit.utils.framework
Functions
- detect_framework
- create_pytorch_value_and_grad_fn: Computes per-sample gradients of the first output of value_fn with respect to its first input.
- to_numpy
- to_numpy_recursive
- from_numpy
- from_numpy_recursive
- make_agnostic_simple
- make_agnostic: Make a function agnostic to NumPy, JAX, or PyTorch.
Classes

- FrameworkAgnosticFunction
- boltzkit.utils.framework.detect_framework(x: GenericArrayType) → Literal['numpy', 'jax', 'pytorch'][source]
- boltzkit.utils.framework.create_pytorch_value_and_grad_fn(value_fn: Callable[[torch.Tensor], torch.Tensor])[source]
Computes per-sample gradients of the first output of value_fn with respect to its first input.
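For intuition, a minimal sketch of the per-sample gradient idea in plain PyTorch (not this helper's implementation): when value_fn maps a (batch, dim) tensor to a (batch,) tensor and each output depends only on its own sample, the gradient of the summed output with respect to the input recovers one gradient row per sample.

    import torch

    def value_fn(x: torch.Tensor) -> torch.Tensor:
        # one scalar value per sample
        return -(x**2).sum(dim=-1)

    x = torch.randn(4, 3, requires_grad=True)
    values = value_fn(x)
    # summing decouples across samples, so grads has shape (4, 3)
    (grads,) = torch.autograd.grad(values.sum(), x)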
- boltzkit.utils.framework.to_numpy(x: GenericArrayType, source_framework: Literal['numpy', 'jax', 'pytorch']) → ndarray[source]
- boltzkit.utils.framework.to_numpy_recursive(x: Any, source_framework: Literal['numpy', 'jax', 'pytorch'])[source]
- boltzkit.utils.framework.from_numpy(x: ndarray, target_framework: Literal['numpy', 'jax', 'pytorch'], style=None)[source]
- boltzkit.utils.framework.from_numpy_recursive(x: GenericArrayType | Any, target_framework: Literal['numpy', 'jax', 'pytorch'], style=None)[source]
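A short usage sketch of the conversion helpers above, based only on the signatures shown (the optional style argument is left at its default):

    import torch
    from boltzkit.utils.framework import detect_framework, to_numpy, from_numpy

    x_torch = torch.zeros(3)
    assert detect_framework(x_torch) == "pytorch"

    x_np = to_numpy(x_torch, source_framework="pytorch")  # torch.Tensor -> np.ndarray
    x_jax = from_numpy(x_np, target_framework="jax")      # np.ndarray -> JAX array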
- class boltzkit.utils.framework.FrameworkAgnosticFunction[source]
Bases: object
- __init__(impl_framework: Literal['numpy', 'jax', 'pytorch'], value_fn: Callable[[GenericArrayType], GenericArrayType], grad_fn: None | Callable[[GenericArrayType], GenericArrayType] = None, value_and_grad_fn: None | Callable[[GenericArrayType], tuple[GenericArrayType, GenericArrayType]] = None)[source]
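A hypothetical construction sketch based on the __init__ signature above, bundling a NumPy value function and its analytical gradient into a single framework-agnostic object:

    import numpy as np
    from boltzkit.utils.framework import FrameworkAgnosticFunction

    fn = FrameworkAgnosticFunction(
        impl_framework="numpy",
        value_fn=lambda x: -np.sum(x**2, axis=-1),  # per-sample value
        grad_fn=lambda x: -2 * x,                   # analytical gradient
    )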
- boltzkit.utils.framework.make_agnostic_simple(*, implementation: Literal['numpy', 'jax', 'pytorch'])[source]
- boltzkit.utils.framework.make_agnostic(*, implementation: Literal['numpy', 'jax', 'pytorch'], grad_fn: None | Callable[[GenericArrayType], GenericArrayType] = None, value_and_grad_fn: None | Callable[[GenericArrayType], tuple[GenericArrayType, GenericArrayType]] = None, value_fn: None | Callable[[GenericArrayType], GenericArrayType] = None)[source]
Make a function agnostic to NumPy, JAX, or PyTorch.
This function takes a value_fn already implemented in one of these frameworks and returns a uniform, framework-agnostic wrapper. Optionally, a gradient function (grad_fn) or a combined value-and-gradient function (value_and_grad_fn) can be provided, which is useful when:

- Automatic differentiation is not available (e.g., NumPy), or
- An analytical gradient exists, which may be more efficient than computing it automatically.
The resulting wrapper can be used either as a decorator (requires value_fn to be None):

    import numpy as np

    def log_prob_grad(x: np.ndarray):
        return -2 * x

    @make_agnostic(implementation="numpy", grad_fn=log_prob_grad)
    def log_prob(x: np.ndarray):
        return -np.sum(x**2, axis=-1)
or as a factory/wrapper function (requires value_fn to be specified):

    def log_prob_grad(x: np.ndarray):
        return -2 * x

    def log_prob_val(x: np.ndarray):
        return -np.sum(x**2, axis=-1)

    log_prob = make_agnostic(
        implementation="numpy", value_fn=log_prob_val, grad_fn=log_prob_grad
    )
Parameters:

- implementation (FrameworkName): The framework in which the function is implemented. Supported values: "numpy", "jax", or "pytorch".
- grad_fn (callable, optional): A function that computes the gradient of value_fn. Signature should be: grad_fn(x: Array) -> Array.
- value_and_grad_fn (callable, optional): A function returning both the value and gradient. Signature should be: value_and_grad_fn(x: Array) -> Tuple[Array, Array]. Providing this may be more efficient than computing value and gradient separately.
- value_fn (callable, optional): The primary function to wrap. Signature should be: value_fn(x: Array) -> Array. If value_fn is provided directly as an argument, make_agnostic acts as a factory/wrapper and returns the wrapped function immediately. If value_fn is None (default), make_agnostic acts as a decorator that can be applied to a function later.

Returns:

- FrameworkAgnosticFunction: A framework-agnostic wrapper around the provided functions. If value_fn is provided directly, returns the wrapped function immediately; otherwise, returns a decorator that can be applied to a function.
Framework-specific caveats:

- NumPy: Automatic differentiation is not available. The user should provide grad_fn, value_and_grad_fn, or both if gradients are needed.
- JAX: Functions are automatically JIT-compiled and vectorized when possible. They should therefore be provided in a non-vectorized form.
- PyTorch: First-order automatic differentiation is supported. Gradients can flow through value_fn even if it is implemented in another framework, provided the input is a torch.Tensor. The same does not hold for JAX inputs, since jax.grad(f) requires f to be a non-batched function; computing per-sample gradients with jax.vmap(jax.grad(f))(batch) would then result in many individual calls to value_fn.
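As an illustration of the JAX caveat, a minimal sketch of supplying value_fn in non-vectorized form, so that the wrapper can apply JIT compilation and vectorization itself (the function body here is only an example):

    import jax.numpy as jnp

    @make_agnostic(implementation="jax")
    def log_prob(x: jnp.ndarray):
        # operates on a single sample; no batch axis handled here
        return -jnp.sum(x**2)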