boltzkit.utils.shape_utils

Functions

get_balanced_grid(n)

Given n subplots, returns a (rows, cols) tuple for a balanced layout.

squeeze_last_dim(x)

Squeeze the last dimension of an array if it has size 1.

boltzkit.utils.shape_utils.squeeze_last_dim(x: ndarray) → ndarray

Squeeze the last dimension of an array if it has size 1.

Converts arrays of shape (batch, 1) to (batch,) and leaves arrays of shape (batch,) unchanged.

Parameters:

x (np.ndarray) – Input array of shape (batch,) or (batch, 1).

Returns:

Array of shape (batch,), with any trailing size-1 dimension removed.

Return type:

np.ndarray
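
The documented behavior can be illustrated with plain NumPy. The following is a minimal sketch, not boltzkit's source: squeeze_last_dim_sketch is a hypothetical stand-in name, and the choice to leave a one-element 1-D array of shape (1,) untouched is an assumption inferred from "leaves (batch,) unchanged".

```python
import numpy as np


def squeeze_last_dim_sketch(x: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in that mirrors the documented behavior."""
    # Drop the trailing axis only when there is more than one axis
    # and the last one has size 1 (assumption: (1,) stays as (1,)).
    if x.ndim > 1 and x.shape[-1] == 1:
        return np.squeeze(x, axis=-1)
    return x


column = np.zeros((4, 1))  # shape (batch, 1)
vector = np.zeros(4)       # shape (batch,)

print(squeeze_last_dim_sketch(column).shape)  # (4,)
print(squeeze_last_dim_sketch(vector).shape)  # (4,)
```

Passing axis=-1 to np.squeeze keeps the operation from collapsing other size-1 axes, such as a batch dimension of 1.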

boltzkit.utils.shape_utils.get_balanced_grid(n: int) → tuple[int, int]

Given n subplots, returns a (rows, cols) tuple for a balanced layout.

  • Tries to make it roughly square.

  • If the layout cannot be square, prefers one fewer row than columns.
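
One way to satisfy both rules above is to take the column count as the ceiling of the square root of n and then use just enough rows to hold n cells. The sketch below is an illustration under that assumption, not boltzkit's implementation: balanced_grid_sketch is a hypothetical name, the rejection of n < 1 is an assumption, and the real function may return a different layout for some n.

```python
import math


def balanced_grid_sketch(n: int) -> tuple[int, int]:
    """Hypothetical sketch returning (rows, cols) for n subplots."""
    if n < 1:
        # Assumption: the documentation does not say how n < 1 is handled.
        raise ValueError("n must be a positive integer")
    # Ceiling of sqrt(n), computed with integers to avoid float rounding.
    cols = math.isqrt(n - 1) + 1
    # Ceiling division: the fewest rows that still hold n cells.
    rows = (n + cols - 1) // cols
    return rows, cols


for n in (1, 5, 7, 9):
    print(n, balanced_grid_sketch(n))
# 1 (1, 1)
# 5 (2, 3)
# 7 (3, 3)
# 9 (3, 3)
```

With this choice, rows always equals either cols or cols - 1, so the grid is never taller than it is wide.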