boltzkit.evaluation.eval

Functions

get_histograms(data)

get_pdfs(data)

get_scalar_metrics(data)

make_wandb_compatible(data[, dpi, update_keys])

Convert every value in the dict into a wandb-compatible item (e.g., a PDF stored as a binary buffer becomes a wandb.Image).

run_eval(data, *[, evals, skip_on_missing_data])

update_dict_with_id(target, new_data, idx)

Update target with new_data.

Classes

EnergyHistEval

EvalData

Container for all possible evaluation inputs.

Evaluation

ModelShannonEntropyEval

NllEval

ReverseLogWeightsEval

class boltzkit.evaluation.eval.EvalData

Bases: object

Container for all possible evaluation inputs.

samples_true: ndarray | None = None
samples_pred: ndarray | None = None
true_samples_target_log_prob: ndarray | None = None
pred_samples_target_log_prob: ndarray | None = None
true_samples_model_log_prob: ndarray | None = None
pred_samples_model_log_prob: ndarray | None = None
fits_requirements(requirements: list[Literal['samples_true', 'samples_pred', 'true_samples_target_log_prob', 'pred_samples_target_log_prob', 'true_samples_model_log_prob', 'pred_samples_model_log_prob']]) -> bool
get_missing_requirements(requirements: list[Literal['samples_true', 'samples_pred', 'true_samples_target_log_prob', 'pred_samples_target_log_prob', 'true_samples_model_log_prob', 'pred_samples_model_log_prob']]) -> list[Literal['samples_true', 'samples_pred', 'true_samples_target_log_prob', 'pred_samples_target_log_prob', 'true_samples_model_log_prob', 'pred_samples_model_log_prob']]
get_required_fields(requirements: list[str])
copy_required(requirements: list[str], eval_cls: type[Evaluation])
__init__(_restricted_access: bool = False, _eval_cls: Evaluation | None = None, samples_true: ndarray | None = None, samples_pred: ndarray | None = None, true_samples_target_log_prob: ndarray | None = None, pred_samples_target_log_prob: ndarray | None = None, true_samples_model_log_prob: ndarray | None = None, pred_samples_model_log_prob: ndarray | None = None) -> None
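The requirement checks on EvalData can be pictured with the minimal sketch below. It is not boltzkit's implementation: EvalDataSketch is a hypothetical stand-in whose fields mirror the documented attributes, and it assumes a requirement counts as satisfied whenever the corresponding field is not None.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class EvalDataSketch:
    # Hypothetical stand-in for EvalData; field names mirror the documented attributes.
    samples_true: Optional[Any] = None
    samples_pred: Optional[Any] = None
    true_samples_target_log_prob: Optional[Any] = None
    pred_samples_target_log_prob: Optional[Any] = None
    true_samples_model_log_prob: Optional[Any] = None
    pred_samples_model_log_prob: Optional[Any] = None

    def get_missing_requirements(self, requirements: list[str]) -> list[str]:
        # Assumption: a requirement is "missing" when its field is still None.
        return [name for name in requirements if getattr(self, name) is None]

    def fits_requirements(self, requirements: list[str]) -> bool:
        # All requirements are met exactly when none are missing.
        return not self.get_missing_requirements(requirements)


data = EvalDataSketch(pred_samples_model_log_prob=[-1.2, -0.8])
print(data.fits_requirements(["pred_samples_model_log_prob"]))  # True
print(data.get_missing_requirements(
    ["pred_samples_target_log_prob", "pred_samples_model_log_prob"]))  # ['pred_samples_target_log_prob']
```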
class boltzkit.evaluation.eval.Evaluation

Bases: ABC

requirements: list[Literal['samples_true', 'samples_pred', 'true_samples_target_log_prob', 'pred_samples_target_log_prob', 'true_samples_model_log_prob', 'pred_samples_model_log_prob']] = []
__init__()
eval(data: EvalData, skip_on_missing_data: bool = False)
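The interplay between a subclass's requirements list and the skip_on_missing_data flag might look roughly like this. The sketch is hypothetical: EvaluationSketch, the _compute hook, returning an empty dict for a skipped evaluation, and raising ValueError otherwise are all assumptions, not documented boltzkit behavior. Inputs are passed as a plain dict to keep the example self-contained.

```python
from abc import ABC, abstractmethod
from typing import Any


class EvaluationSketch(ABC):
    # Hypothetical stand-in for Evaluation; subclasses list the inputs they need.
    requirements: list[str] = []

    def eval(self, data: dict[str, Any], skip_on_missing_data: bool = False) -> dict[str, Any]:
        missing = [name for name in self.requirements if data.get(name) is None]
        if missing:
            if skip_on_missing_data:
                # Assumption: a skipped evaluation contributes no metrics.
                return {}
            raise ValueError(f"{type(self).__name__} is missing inputs: {missing}")
        return self._compute(data)

    @abstractmethod
    def _compute(self, data: dict[str, Any]) -> dict[str, Any]:
        """Hypothetical hook that holds the actual metric computation."""


class CountPredSamples(EvaluationSketch):
    # Toy evaluation that only needs the predicted samples.
    requirements = ["samples_pred"]

    def _compute(self, data: dict[str, Any]) -> dict[str, Any]:
        return {"num_pred_samples": len(data["samples_pred"])}


print(CountPredSamples().eval({"samples_pred": [0.1, 0.2, 0.3]}))  # {'num_pred_samples': 3}
print(CountPredSamples().eval({}, skip_on_missing_data=True))      # {}
```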
class boltzkit.evaluation.eval.EnergyHistEval

Bases: Evaluation

requirements: list[Literal['samples_true', 'samples_pred', 'true_samples_target_log_prob', 'pred_samples_target_log_prob', 'true_samples_model_log_prob', 'pred_samples_model_log_prob']] = ['true_samples_target_log_prob', 'pred_samples_target_log_prob']
__init__(include_pdf: bool = True, include_pred_histogram: bool = True, include_true_histogram: bool = True)
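Given its two requirements, EnergyHistEval plausibly compares the energy distributions of true and predicted samples on a common binning. The sketch below assumes energy = -log p, where log p is the (possibly unnormalized) target log-probability in reduced units, and uses equal-width bins over the combined range; the function name and binning scheme are hypothetical.

```python
def shared_energy_histograms(true_log_prob, pred_log_prob, n_bins=10):
    # Assumption: energies are negative target log-probabilities (reduced units).
    e_true = [-lp for lp in true_log_prob]
    e_pred = [-lp for lp in pred_log_prob]
    lo = min(min(e_true), min(e_pred))
    hi = max(max(e_true), max(e_pred))
    width = (hi - lo) / n_bins or 1.0  # guard against a zero-width range
    edges = [lo + i * width for i in range(n_bins + 1)]

    def counts(energies):
        c = [0] * n_bins
        for e in energies:
            # Clamp so the maximum energy lands in the last bin instead of overflowing.
            c[min(int((e - lo) / width), n_bins - 1)] += 1
        return c

    # Shared edges make the true and predicted histograms directly comparable.
    return edges, counts(e_true), counts(e_pred)


edges, c_true, c_pred = shared_energy_histograms(
    [-1.0, -2.0, -3.0, -4.0], [-1.5, -2.5, -2.5, -3.5], n_bins=3)
print(edges)           # [1.0, 2.0, 3.0, 4.0]
print(c_true, c_pred)  # [1, 1, 2] [1, 2, 1]
```

Sharing one set of edges is the design point: histograms binned separately could not be overlaid or compared bin by bin.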
class boltzkit.evaluation.eval.NllEval

Bases: Evaluation

requirements: list[Literal['samples_true', 'samples_pred', 'true_samples_target_log_prob', 'pred_samples_target_log_prob', 'true_samples_model_log_prob', 'pred_samples_model_log_prob']] = ['true_samples_model_log_prob']
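NllEval's only requirement is true_samples_model_log_prob, which suggests the usual Monte Carlo estimate NLL = -(1/N) * sum_i log q(x_i) over reference samples x_i. A sketch under that assumption (whether boltzkit averages or sums is not stated on this page):

```python
def nll(true_samples_model_log_prob):
    # Average negative model log-likelihood of the reference (true) samples.
    return -sum(true_samples_model_log_prob) / len(true_samples_model_log_prob)


print(nll([-1.0, -2.0, -3.0]))  # 2.0
```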
class boltzkit.evaluation.eval.ModelShannonEntropyEval

Bases: Evaluation

requirements: list[Literal['samples_true', 'samples_pred', 'true_samples_target_log_prob', 'pred_samples_target_log_prob', 'true_samples_model_log_prob', 'pred_samples_model_log_prob']] = ['pred_samples_model_log_prob']
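ModelShannonEntropyEval needs only the model's log-probabilities of its own samples, which matches the standard estimate H[q] = -E_q[log q(x)] ≈ -(1/N) * sum_i log q(x_i) with x_i drawn from q. The sketch assumes that estimator and checks it against a standard normal, whose exact entropy is 0.5 * log(2 * pi * e).

```python
import math
import random


def shannon_entropy_estimate(pred_samples_model_log_prob):
    # Valid as an entropy estimate only because the x_i were drawn from q itself.
    return -sum(pred_samples_model_log_prob) / len(pred_samples_model_log_prob)


# Sanity check with a standard normal "model" q = N(0, 1).
rng = random.Random(0)
xs = [rng.gauss(0.0, 1.0) for _ in range(100_000)]
log_q = [-0.5 * x * x - 0.5 * math.log(2 * math.pi) for x in xs]
estimate = shannon_entropy_estimate(log_q)
exact = 0.5 * math.log(2 * math.pi * math.e)
print(f"estimate={estimate:.3f}, exact={exact:.3f}")
```

The formula is the same average as the NLL, but applied to model samples rather than reference samples, which turns it into an entropy estimate.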
class boltzkit.evaluation.eval.ReverseLogWeightsEval

Bases: Evaluation

requirements: list[Literal['samples_true', 'samples_pred', 'true_samples_target_log_prob', 'pred_samples_target_log_prob', 'true_samples_model_log_prob', 'pred_samples_model_log_prob']] = ['pred_samples_target_log_prob', 'pred_samples_model_log_prob']
__init__(include_logZ: bool = True, include_ess: bool = True, include_iw_kl: bool = True)
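ReverseLogWeightsEval needs target and model log-probabilities of model samples, i.e. reverse importance log-weights log w_i = log p(x_i) - log q(x_i) with x_i drawn from q. The sketch below computes two standard quantities under that reading: a log Z estimate and the normalized (Kish) effective sample size. Whether boltzkit normalizes ESS to (0, 1] this way is an assumption, and the importance-weighted KL toggled by include_iw_kl is left out.

```python
import math


def logsumexp(values):
    # Numerically stable log(sum(exp(v))).
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def reverse_log_weight_metrics(pred_samples_target_log_prob, pred_samples_model_log_prob):
    log_w = [lp - lq for lp, lq in zip(pred_samples_target_log_prob, pred_samples_model_log_prob)]
    n = len(log_w)
    # log Z estimate: log of the mean importance weight.
    log_z = logsumexp(log_w) - math.log(n)
    # Normalized ESS = (sum w)^2 / (n * sum w^2), computed in log space.
    ess = math.exp(2 * logsumexp(log_w) - logsumexp([2 * v for v in log_w]) - math.log(n))
    return {"logZ": log_z, "ess": ess}


# If the target equals the model up to a constant c, every weight is exp(c):
# logZ recovers c and the ESS is 1.
print(reverse_log_weight_metrics([1.5, 0.5, -0.5], [-1.0, -2.0, -3.0]))
```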
boltzkit.evaluation.eval.update_dict_with_id(target: dict, new_data: dict, idx: int) -> dict

Update target with new_data. If a key already exists in target, append idx to the key.
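A minimal sketch of this collision rule, assuming the suffix format "<key>_<idx>" and in-place mutation of target; both are guesses, and update_dict_with_id_sketch is a hypothetical name, not the library function.

```python
def update_dict_with_id_sketch(target: dict, new_data: dict, idx: int) -> dict:
    # Assumption: colliding keys become "<key>_<idx>"; target is updated in place.
    for key, value in new_data.items():
        new_key = f"{key}_{idx}" if key in target else key
        target[new_key] = value
    return target


metrics = {"nll": 2.0}
update_dict_with_id_sketch(metrics, {"nll": 1.7, "ess": 0.8}, idx=1)
print(metrics)  # {'nll': 2.0, 'nll_1': 1.7, 'ess': 0.8}
```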

boltzkit.evaluation.eval.run_eval(data: EvalData, *, evals: list[Evaluation | tuple[Evaluation]] = [], skip_on_missing_data: bool = True) -> dict[str, float | int | PdfBuffer | Histogram1D | Histogram2D | Any]
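At a high level, run_eval plausibly loops over the evaluations, skips (or rejects) those whose requirements are not met, and merges every metric dict into a single result. The sketch assumes that flow; representing each evaluation as a hypothetical (requirements, compute) pair and suffixing colliding keys with the evaluation index are illustrative choices, not boltzkit's API.

```python
from typing import Any, Callable


def run_eval_sketch(
    data: dict[str, Any],
    evals: list[tuple[list[str], Callable[[dict[str, Any]], dict[str, Any]]]],
    skip_on_missing_data: bool = True,
) -> dict[str, Any]:
    results: dict[str, Any] = {}
    for idx, (requirements, compute) in enumerate(evals):
        missing = [name for name in requirements if data.get(name) is None]
        if missing:
            if skip_on_missing_data:
                continue  # silently drop evaluations whose inputs are absent
            raise ValueError(f"evaluation {idx} is missing inputs: {missing}")
        for key, value in compute(data).items():
            # Assumption: colliding keys are disambiguated with the evaluation index.
            results[f"{key}_{idx}" if key in results else key] = value
    return results


data = {"true_samples_model_log_prob": [-1.0, -3.0]}
evals = [
    (["true_samples_model_log_prob"],
     lambda d: {"nll": -sum(d["true_samples_model_log_prob"]) / len(d["true_samples_model_log_prob"])}),
    (["pred_samples_model_log_prob"], lambda d: {"entropy": 0.0}),  # skipped: input missing
]
print(run_eval_sketch(data, evals))  # {'nll': 2.0}
```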
boltzkit.evaluation.eval.make_wandb_compatible(data: dict[str, float | int | PdfBuffer | Histogram1D | Histogram2D | Any], dpi: int = 100, update_keys: bool = True)

Convert every value in the dict into a wandb-compatible item (e.g., a PDF stored as a binary buffer becomes a wandb.Image). Requires the wandb package (pip install wandb).

boltzkit.evaluation.eval.get_scalar_metrics(data: dict[str, float | int | PdfBuffer | Histogram1D | Histogram2D | Any])
boltzkit.evaluation.eval.get_histograms(data: dict[str, float | int | PdfBuffer | Histogram1D | Histogram2D | Any]) -> dict[str, Histogram1D | Histogram2D]
boltzkit.evaluation.eval.get_pdfs(data: dict[str, float | int | PdfBuffer | Histogram1D | Histogram2D | Any]) -> dict[str, PdfBuffer]
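get_scalar_metrics, get_histograms, and get_pdfs each pick one kind of value out of a run_eval result dict. Filtering by type is one way to do this; the sketch below assumes it and uses hypothetical stub classes in place of boltzkit's PdfBuffer and Histogram1D (Histogram2D is omitted for brevity).

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class PdfBufferStub:
    # Hypothetical stand-in for boltzkit's PdfBuffer.
    data: bytes


@dataclass
class Histogram1DStub:
    # Hypothetical stand-in for boltzkit's Histogram1D.
    edges: list
    counts: list


def split_results(results: dict[str, Any]) -> tuple[dict, dict, dict]:
    # Partition the result dict by value type: scalars, histograms, PDFs.
    scalars = {k: v for k, v in results.items() if isinstance(v, (int, float))}
    histograms = {k: v for k, v in results.items() if isinstance(v, Histogram1DStub)}
    pdfs = {k: v for k, v in results.items() if isinstance(v, PdfBufferStub)}
    return scalars, histograms, pdfs


results = {
    "nll": 2.0,
    "ess": 0.8,
    "energy_hist": Histogram1DStub(edges=[0.0, 1.0, 2.0], counts=[3, 4]),
    "energy_pdf": PdfBufferStub(data=b"%PDF-"),
}
scalars, hists, pdfs = split_results(results)
print(sorted(scalars))         # ['ess', 'nll']
print(list(hists), list(pdfs))  # ['energy_hist'] ['energy_pdf']
```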