boltzkit.utils.langevin
Functions
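| integrate_langevin(score_fn, x0, stepsize[, ...]) | Batched Langevin integrator. |
| integrate_langevin_middle(score_fn, x0, stepsize[, ...]) | Batched Langevin integrator using the "middle" (Strang splitting) scheme. |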
- boltzkit.utils.langevin.integrate_langevin(score_fn: Callable[[ndarray], ndarray], x0: ndarray, stepsize: float, n_steps: int | None = None, callback: Callable[[ndarray, int], None] | None = None, callback_every: int = 1) -> ndarray
Batched Langevin integrator.
- Parameters:
score_fn – Function mapping (batch, dim) -> (batch, dim), returning the gradient of the log-probability (the score).
x0 – Initial samples of shape (batch, dim).
stepsize – Integration stepsize.
n_steps – Number of steps to run (if None, runs indefinitely).
callback – Optional function called as callback(x, step).
callback_every – Call callback once every callback_every steps.
- Returns:
Final samples of shape (batch, dim).
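A minimal sketch of what a batched Langevin loop with this interface might look like, assuming the common overdamped Euler-Maruyama update x <- x + stepsize * score_fn(x) + sqrt(2 * stepsize) * xi with xi ~ N(0, I). The name langevin_euler_sketch, the extra rng argument, and the 1-based step index passed to callback are hypothetical; this is not the boltzkit source, which may scale the drift or noise differently.

```python
from typing import Callable, Optional

import numpy as np


def langevin_euler_sketch(
    score_fn: Callable[[np.ndarray], np.ndarray],
    x0: np.ndarray,
    stepsize: float,
    n_steps: Optional[int] = None,
    callback: Optional[Callable[[np.ndarray, int], None]] = None,
    callback_every: int = 1,
    rng: Optional[np.random.Generator] = None,  # hypothetical extra argument
) -> np.ndarray:
    """Hypothetical overdamped Langevin sampler (Euler-Maruyama), not boltzkit's code."""
    rng = np.random.default_rng() if rng is None else rng
    # Work on a float copy so the caller's x0 is left untouched.
    x = np.array(x0, dtype=float, copy=True)
    step = 0
    # n_steps=None loops forever, mirroring the documented behaviour.
    while n_steps is None or step < n_steps:
        noise = rng.standard_normal(x.shape)
        # Drift along the score plus Gaussian noise with variance 2 * stepsize.
        x = x + stepsize * score_fn(x) + np.sqrt(2.0 * stepsize) * noise
        step += 1
        # Assumed convention: step is 1-based when passed to the callback.
        if callback is not None and step % callback_every == 0:
            callback(x, step)
    return x
```

With a standard normal target (score_fn = lambda x: -x), the batch relaxes toward mean 0 and variance close to 1; the residual bias (stationary variance 2 / (2 - stepsize) for this target) is a known property of the Euler-Maruyama discretization.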
- boltzkit.utils.langevin.integrate_langevin_middle(score_fn: Callable[[ndarray], ndarray], x0: ndarray, stepsize: float, n_steps: int | None = None, callback: Callable[[ndarray, int], None] | None = None, callback_every: int = 1) -> ndarray
Batched Langevin integrator using the “middle” (Strang splitting) scheme.
- Parameters:
score_fn – Function mapping (batch, dim) -> (batch, dim), returning the gradient of the log-probability (the score).
x0 – Initial samples of shape (batch, dim).
stepsize – Integration stepsize.
n_steps – Number of steps to run (if None, runs indefinitely).
callback – Optional function called as callback(x, step).
callback_every – Call callback once every callback_every steps.
- Returns:
Final samples of shape (batch, dim).
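One way to read "middle" (Strang splitting) for an overdamped sampler is the symmetric composition of a half drift step, a full noise step, and another half drift step. The sketch below assumes that ordering and a noise scale of sqrt(2 * stepsize); boltzkit's actual splitting (for example, an underdamped scheme with auxiliary velocities) is not specified on this page, and the name langevin_strang_sketch and its rng argument are hypothetical.

```python
from typing import Callable, Optional

import numpy as np


def langevin_strang_sketch(
    score_fn: Callable[[np.ndarray], np.ndarray],
    x0: np.ndarray,
    stepsize: float,
    n_steps: Optional[int] = None,
    callback: Optional[Callable[[np.ndarray, int], None]] = None,
    callback_every: int = 1,
    rng: Optional[np.random.Generator] = None,  # hypothetical extra argument
) -> np.ndarray:
    """Hypothetical drift/noise/drift Strang-split Langevin sampler, not boltzkit's code."""
    rng = np.random.default_rng() if rng is None else rng
    x = np.array(x0, dtype=float, copy=True)
    half = 0.5 * stepsize
    noise_scale = np.sqrt(2.0 * stepsize)
    step = 0
    while n_steps is None or step < n_steps:
        x = x + half * score_fn(x)  # first half drift step (explicit Euler)
        x = x + noise_scale * rng.standard_normal(x.shape)  # full noise step in the middle
        x = x + half * score_fn(x)  # second half drift step
        step += 1
        # Assumed convention: step is 1-based when passed to the callback.
        if callback is not None and step % callback_every == 0:
            callback(x, step)
    return x
```

For the Gaussian target score_fn = lambda x: -x, each step multiplies x by (1 - stepsize/2)^2 before adding scaled noise, so the chain still contracts toward the target; the symmetric splitting costs two score evaluations per step.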