API Reference

This section documents the API for cebmf_torch.

Main Package

class cebmf_torch.cEBMF(data, K=5, prior_L='norm', prior_F='norm', internal_epoch=10, prior_L_kwargs=None, prior_F_kwargs=None, allow_backfitting=True, prune_thresh=0.999, noise_type=NoiseType.CONSTANT, S=None, X_l=None, X_f=None, self_row_cov=False, self_col_cov=False, device=None)

Bases: object

Pure-PyTorch Empirical Bayes Matrix Factorization (EBMF) with NaN handling.

Features

  • Observed-mask weighting in lhat/fhat and their standard errors.

  • Constant or structured noise precision (scalar or per-row/column tau).

  • User-supplied (fixed) standard errors via S — useful for z-scores (pass S=1.0) or pre-computed standard errors of effect-size estimates (pass an (N, P) tensor). When S is provided, the noise variance is treated as known and is not re-estimated during fitting.

  • Mini-batch optimization for mixture weights inside ash().

  • Modular prior and covariate support.

fit(maxit=50)

Fit the cEBMF model for a specified number of iterations.

Parameters:

maxit (int, optional) – Number of iterations to run. Default is 50.

Returns:

Result container with fitted factors, noise, and objective history.

Return type:

CEBMFResult

initialise_factors(method='svd', *, L=None, F=None)

Initialize factor matrices using the specified method, or user-provided initial factors.

Parameters:
  • method (str) – Initialization method (‘svd’, ‘random’, or ‘zero’). Default is ‘svd’. Ignored if L and F are provided.

  • L (Tensor or None, optional) – User-provided initial factor matrix (N, K). Ignored unless F is also provided.

  • F (Tensor or None, optional) – User-provided initial factor matrix (P, K). Ignored unless L is also provided.

iter_once()

Perform one iteration of the cEBMF update (update all factors and noise).

update_tau()

Update the noise precision parameter(s) according to the noise model.

Behaviour by noise type:

  • CONSTANT -> scalar tau; also provides tau_map (N,P) if you need it

  • ROW_WISE -> tau_row (N,), tau_map broadcast to (N,P)

  • COLUMN_WISE -> tau_col (P,), tau_map broadcast to (N,P)

  • KNOWN -> no-op; tau_map was built once from the user-supplied S.
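The four behaviours can be sketched in plain Python as below. This is an illustrative approximation, not the library's code: it assumes the residual matrix R = Y - L F^T has already been formed with NaN marking missing entries, and it uses plain squared residuals where EBMF would use expected squared residuals (which also include posterior variance terms). The function name update_tau_sketch is hypothetical.

```python
import math


def update_tau_sketch(resid, noise_type, S=None):
    """Return an N x P precision map (tau_map) for one of the four noise models.

    resid: N rows of P residuals, with math.nan marking missing entries.
    S: known standard errors (scalar or N x P nested list); used only for "KNOWN".
    """
    n, p = len(resid), len(resid[0])

    def precision(cells):
        # 1 / mean squared residual over the observed (non-NaN) cells.
        sq = [r * r for r in cells if not math.isnan(r)]
        return len(sq) / sum(sq)

    if noise_type == "KNOWN":
        # tau = 1 / S^2, built once from S and never re-estimated.
        if isinstance(S, (int, float)):
            return [[1.0 / S ** 2] * p for _ in range(n)]
        return [[1.0 / S[i][j] ** 2 for j in range(p)] for i in range(n)]
    if noise_type == "CONSTANT":
        tau = precision([r for row in resid for r in row])  # one scalar
        return [[tau] * p for _ in range(n)]
    if noise_type == "ROW_WISE":
        tau_row = [precision(row) for row in resid]  # shape (N,)
        return [[t] * p for t in tau_row]  # broadcast to (N, P)
    if noise_type == "COLUMN_WISE":
        tau_col = [precision([row[j] for row in resid]) for j in range(p)]  # shape (P,)
        return [list(tau_col) for _ in range(n)]  # broadcast to (N, P)
    raise ValueError(f"unknown noise_type: {noise_type!r}")
```

Passing S=1.0 with "KNOWN" reproduces the z-score case from the feature list: every entry of tau_map is 1.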

Parameters (cEBMF constructor):
  • data (Tensor)

  • K (int)

  • prior_L (str)

  • prior_F (str)

  • internal_epoch (int)

  • prior_L_kwargs (dict | None)

  • prior_F_kwargs (dict | None)

  • allow_backfitting (bool)

  • prune_thresh (float)

  • noise_type (NoiseType)

  • S (Tensor | float | int | None)

  • X_l (Tensor | None)

  • X_f (Tensor | None)

  • self_row_cov (bool)

  • self_col_cov (bool)

  • device (device | None)

cebmf_torch.ash(x, s, prior=PriorType.NORM, mult=1.4142135623730951, penalty=10.0, verbose=False, threshold_loglikelihood=-300.0, mode=0.0, *, batch_size=128, shuffle=False, seed=None, optimizer='em')

Adaptive shrinkage with mixture priors (“norm” or “exp”) in pure PyTorch.

Uses EM for π by default (mini-batch capable via batch_size). Set optimizer="lbfgs" to use L-BFGS with softmax reparameterisation, which produces sparse solutions matching R ashr’s convex optimiser.

Parameters:
  • x (torch.Tensor) – Observed data.

  • s (torch.Tensor) – Standard errors of the observed data.

  • prior (PriorType, optional) – Type of prior to use (default: PriorType.NORM).

  • mult (float, optional) – Multiplier for scale grid (default: sqrt(2.0)).

  • penalty (float, optional) – Penalty for mixture weights (default: 10.0).

  • verbose (bool, optional) – Verbosity flag (default: False).

  • threshold_loglikelihood (float, optional) – Minimum log-likelihood threshold (default: -300.0).

  • mode (float, optional) – Mode parameter (for normal prior only, default: 0.0).

  • batch_size (int or None, optional) – Batch size for EM updates (default: 128).

  • shuffle (bool, optional) – Whether to shuffle data in EM (default: False).

  • seed (int or None, optional) – Random seed for reproducibility.

  • optimizer (str, optional) – "em" (default) or "lbfgs". L-BFGS produces sparse solutions matching R ashr’s convex optimiser.

Returns:

Result object containing posterior summaries and model parameters.

Return type:

ASHResult
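Under a normal mixture prior centred at mode, the likelihood of observation j under component k is N(x_j; mode, s_j^2 + sigma_k^2), and the mixture weights π are fitted to this n-by-K matrix. The plain-Python sketch below builds that log-likelihood matrix. It assumes a hand-picked geometric grid with ratio sqrt(2) (plus a 0 scale for the point mass) rather than the grid ash selects automatically, and the function name is hypothetical.

```python
import math


def normal_mixture_loglik(x, s, scales, mode=0.0):
    """logL[j][k] = log N(x_j; mode, s_j^2 + scales_k^2).

    A scale of 0 gives the point mass at `mode` (the spike).
    """
    logL = []
    for xj, sj in zip(x, s):
        row = []
        for sk in scales:
            var = sj * sj + sk * sk  # noise variance + prior component variance
            row.append(-0.5 * math.log(2 * math.pi * var) - 0.5 * (xj - mode) ** 2 / var)
        logL.append(row)
    return logL


grid = [0.0] + [0.1 * math.sqrt(2) ** k for k in range(8)]  # mult = sqrt(2)
logL = normal_mixture_loglik([0.1, 3.0], [1.0, 1.0], grid)
```

For the large observation x = 3.0, the widest component explains the data better than the spike, which is what lets the fitted π put weight on wide components when there are real signals.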

cebmf_torch.ebnm_point_exp(x, s, par_init=None, fix_par=(False, False, True), max_iter=20, tol=1e-06, a_bounds=(0.01, 100.0), loga_l2=0.001, tresh_pi0=0.001, eps=1e-12)

Direct maximization (no EM) of the observed marginal log-likelihood for a point-Exponential EBNM.

Prior on θ: (1 - pi_slab) δ_μ + pi_slab [μ + Exp(a)], with support θ ≥ μ. The returned log-likelihood is the pure marginal log-likelihood, without any penalty terms.

Return type:

EBNMPointExp
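The marginal log-likelihood being maximised has a closed form: convolving Exp(rate a) with N(0, s^2) gives density a · exp(-a y + a² s²/2) · Φ(y/s - a s), with y = x - μ. The plain-Python sketch below evaluates the full spike-plus-slab marginal for one observation. It assumes a is the rate of the exponential (the docstring writes Exp(a) without saying rate or scale) and 0 < pi_slab < 1; it is a reference check, not the library's implementation.

```python
import math


def log_norm_cdf(z):
    # log Phi(z); switch to the asymptotic tail expansion where erfc would underflow.
    if z < -30.0:
        return -0.5 * z * z - math.log(-z) - 0.5 * math.log(2 * math.pi)
    return math.log(0.5 * math.erfc(-z / math.sqrt(2.0)))


def point_exp_loglik(x, s, pi_slab, a, mu=0.0):
    """log p(x) for x ~ N(theta, s^2), theta ~ (1 - pi_slab) delta_mu + pi_slab [mu + Exp(rate=a)]."""
    y = x - mu
    # Spike: x ~ N(mu, s^2).
    log_spike = math.log(1.0 - pi_slab) - 0.5 * math.log(2 * math.pi * s * s) - 0.5 * (y / s) ** 2
    # Slab: a * exp(-a*y + a^2 s^2 / 2) * Phi(y/s - a*s).
    log_slab = math.log(pi_slab) + math.log(a) - a * y + 0.5 * (a * s) ** 2 + log_norm_cdf(y / s - a * s)
    m = max(log_spike, log_slab)  # log-sum-exp for numerical stability
    return m + math.log(math.exp(log_spike - m) + math.exp(log_slab - m))
```

Working in log space keeps both terms finite far from μ, where the raw densities would underflow.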

cebmf_torch.ebnm_point_laplace(x, s, par_init=None, fix_par=(False, False, True), max_iter=20, tol=1e-06, a_bounds=(0.01, 100.0), loga_l2=0.0, tresh_pi0=0.001, eps=1e-12, pen_pi0=0.0, use_adam_warmstart=False, adam_steps=8, adam_lr=0.01, weight_decay=0.0)

Efficient direct maximisation of the observed marginal log-likelihood for a point-Laplace EBNM. GPU-resident throughout — no host<->device syncs in either the inner objective or the posterior summary.

Optimizer: LBFGS-only (AdamW warm-start optional and short). The four scalar fields of the result (pi_slab, a, mu, log_lik) are returned as 0-d tensors on the input device so the cEBMF ELBO accumulator can fold them in without forcing a host sync per factor update. Call float(field) at your own boundary if you need a Python scalar.

Return type:

EBNMLaplaceResult
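For reference, the point-Laplace marginal also has a closed form: a Laplace slab is half an exponential on each side of μ, so each half convolves with the noise like the point-exponential slab. The plain-Python sketch below assumes the slab is Laplace(μ, rate a) with density (a/2) exp(-a|θ - μ|); the parameterisation used by cebmf_torch is not stated here, and this scalar sketch does not reflect the GPU-resident implementation.

```python
import math


def log_norm_cdf(z):
    # log Phi(z); asymptotic tail expansion where erfc would underflow.
    if z < -30.0:
        return -0.5 * z * z - math.log(-z) - 0.5 * math.log(2 * math.pi)
    return math.log(0.5 * math.erfc(-z / math.sqrt(2.0)))


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def point_laplace_loglik(x, s, pi_slab, a, mu=0.0):
    """log p(x) for x ~ N(theta, s^2), theta ~ (1 - pi_slab) delta_mu + pi_slab Laplace(mu, rate=a)."""
    y = x - mu
    log_spike = math.log(1.0 - pi_slab) - 0.5 * math.log(2 * math.pi * s * s) - 0.5 * (y / s) ** 2
    common = math.log(pi_slab) + math.log(a / 2.0) + 0.5 * (a * s) ** 2
    log_right = common - a * y + log_norm_cdf(y / s - a * s)   # theta > mu half
    log_left = common + a * y + log_norm_cdf(-y / s - a * s)   # theta < mu half (mirror image)
    return logsumexp([log_spike, log_right, log_left])
```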


EBNM Solvers

Empirical Bayes Normal Means (EBNM) solvers.

cebmf_torch.ebnm.ash(x, s, prior=PriorType.NORM, mult=1.4142135623730951, penalty=10.0, verbose=False, threshold_loglikelihood=-300.0, mode=0.0, *, batch_size=128, shuffle=False, seed=None, optimizer='em')

Adaptive shrinkage with mixture priors (“norm” or “exp”) in pure PyTorch.

Uses EM for π by default (mini-batch capable via batch_size). Set optimizer="lbfgs" to use L-BFGS with softmax reparameterisation, which produces sparse solutions matching R ashr’s convex optimiser.

Parameters:
  • x (torch.Tensor) – Observed data.

  • s (torch.Tensor) – Standard errors of the observed data.

  • prior (PriorType, optional) – Type of prior to use (default: PriorType.NORM).

  • mult (float, optional) – Multiplier for scale grid (default: sqrt(2.0)).

  • penalty (float, optional) – Penalty for mixture weights (default: 10.0).

  • verbose (bool, optional) – Verbosity flag (default: False).

  • threshold_loglikelihood (float, optional) – Minimum log-likelihood threshold (default: -300.0).

  • mode (float, optional) – Mode parameter (for normal prior only, default: 0.0).

  • batch_size (int or None, optional) – Batch size for EM updates (default: 128).

  • shuffle (bool, optional) – Whether to shuffle data in EM (default: False).

  • seed (int or None, optional) – Random seed for reproducibility.

  • optimizer (str, optional) – "em" (default) or "lbfgs". L-BFGS produces sparse solutions matching R ashr’s convex optimiser.

Returns:

Result object containing posterior summaries and model parameters.

Return type:

ASHResult

cebmf_torch.ebnm.ebnm_point_exp(x, s, par_init=None, fix_par=(False, False, True), max_iter=20, tol=1e-06, a_bounds=(0.01, 100.0), loga_l2=0.001, tresh_pi0=0.001, eps=1e-12)

Direct maximization (no EM) of the observed marginal log-likelihood for a point-Exponential EBNM.

Prior on θ: (1 - pi_slab) δ_μ + pi_slab [μ + Exp(a)], with support θ ≥ μ. The returned log-likelihood is the pure marginal log-likelihood, without any penalty terms.

Return type:

EBNMPointExp

cebmf_torch.ebnm.ebnm_point_laplace(x, s, par_init=None, fix_par=(False, False, True), max_iter=20, tol=1e-06, a_bounds=(0.01, 100.0), loga_l2=0.0, tresh_pi0=0.001, eps=1e-12, pen_pi0=0.0, use_adam_warmstart=False, adam_steps=8, adam_lr=0.01, weight_decay=0.0)

Efficient direct maximisation of the observed marginal log-likelihood for a point-Laplace EBNM. GPU-resident throughout — no host<->device syncs in either the inner objective or the posterior summary.

Optimizer: LBFGS-only (AdamW warm-start optional and short). The four scalar fields of the result (pi_slab, a, mu, log_lik) are returned as 0-d tensors on the input device so the cEBMF ELBO accumulator can fold them in without forcing a host sync per factor update. Call float(field) at your own boundary if you need a Python scalar.

Return type:

EBNMLaplaceResult


cebmf_torch.ebnm.ebnm_gb(x, s, omega=0.2, par_init_mu=1.0, par_init_pi=0.2, max_em=200, tol_em=1e-05, max_lbfgs=200, tol_lbfgs=1e-06, eps=1e-12)

EBNM with Generalized Binary prior:

θ ~ (1-π) δ_0 + π N_+(μ, σ^2), with σ = ω μ, μ ≥ 0, ω fixed.

Follows the EM scheme in Supplementary Note (eqs. (16)-(27)).

Return type:

EBNMGBResult
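To make the prior concrete, the sketch below draws from it with the standard library. It is an illustration, not part of cebmf_torch: N_+ is read as a normal truncated to θ ≥ 0 (an assumption), sampled by simple rejection, and the function name is hypothetical. Because σ = ωμ, the truncation point always sits 1/ω slab standard deviations below μ, so for small ω almost no draws are rejected and the slab is concentrated near μ.

```python
import random


def sample_gb_prior(n, pi, mu, omega, seed=0):
    """Draw n values from (1 - pi) * delta_0 + pi * N_+(mu, (omega * mu)^2)."""
    rng = random.Random(seed)
    sigma = omega * mu
    draws = []
    for _ in range(n):
        if rng.random() >= pi:
            draws.append(0.0)  # spike at zero, probability 1 - pi
            continue
        theta = rng.gauss(mu, sigma)
        while theta < 0.0:  # rejection step for the truncation at 0
            theta = rng.gauss(mu, sigma)
        draws.append(theta)
    return draws


theta = sample_gb_prior(10_000, pi=0.2, mu=1.0, omega=0.2)
```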


Covariate-Enhanced EBNM

Covariate-Enhanced Empirical Bayes Normal Means (cEBNM) solvers.

cebmf_torch.cebnm.cash_posterior_means(X, betahat, sebetahat, n_epochs=20, n_layers=4, num_classes=20, hidden_dim=64, batch_size=128, lr=0.001, model_param=None, penalty=1.5, device=None)

GPU-native CASH training and posterior computation.

Fit a CASH (Covariate Adaptive Shrinkage) model and compute posterior means, second moments, and standard deviations.

Parameters:
  • X (torch.Tensor or np.ndarray) – Covariates for each observation, shape (n_samples, n_features).

  • betahat (torch.Tensor or np.ndarray) – Observed effect estimates, shape (n_samples,).

  • sebetahat (torch.Tensor or np.ndarray) – Standard errors of the effect estimates, shape (n_samples,).

  • n_epochs (int, optional) – Number of training epochs (default=20).

  • n_layers (int, optional) – Number of hidden layers in the neural network (default=4).

  • num_classes (int, optional) – Number of mixture components (default=20).

  • hidden_dim (int, optional) – Number of hidden units in each layer (default=64).

  • batch_size (int, optional) – Batch size for training (default=128).

  • lr (float, optional) – Learning rate for the optimizer (default=0.001).

  • model_param (dict, optional) – Pre-trained model parameters to initialize the network.

  • penalty (float, optional) – Penalty for spike probability (default=1.5).

  • device (torch.device, optional) – Target device for tensors and the model. If None, inherits from betahat when it is already a tensor; otherwise falls back to CUDA (if available) or CPU. Inheriting from the input avoids silent cross-device hops when the caller (e.g. cEBMF) is on CPU/MPS but CUDA is also visible on the host.

Returns:

Container with posterior means, standard deviations, and model parameters.

Return type:

cash_PosteriorMeanNorm

cebmf_torch.cebnm.cgb_posterior_means(X, betahat, sebetahat, n_epochs=50, n_layers=2, hidden_dim=32, batch_size=128, lr=0.001, penalty=1.5, model_param=None, device=None)

Fit a Covariate Generalized-Binary (CGB) model to estimate the prior distribution of effects.

Parameters:
  • X (torch.Tensor or np.ndarray) – Covariates for each observation, shape (n_samples, n_features).

  • betahat (torch.Tensor or np.ndarray) – Observed effect estimates, shape (n_samples,).

  • sebetahat (torch.Tensor or np.ndarray) – Standard errors of the effect estimates, shape (n_samples,).

  • n_epochs (int, optional) – Number of training epochs (default=50).

  • n_layers (int, optional) – Number of hidden layers in the neural network (default=2).

  • hidden_dim (int, optional) – Number of hidden units in each layer (default=32).

  • batch_size (int, optional) – Batch size for training (default=128).

  • lr (float, optional) – Learning rate for the optimizer (default=1e-3).

  • penalty (float, optional) – Penalty for spike probability (default=1.5).

  • model_param (dict, optional) – Pre-trained model parameters to initialize the network.

  • device (torch.device, optional) – Target device for tensors and the model.

Returns:

Container with posterior means, standard deviations, and model parameters.

Return type:

CgbPosteriorResult

cebmf_torch.cebnm.emdn_posterior_means(X, betahat, sebetahat, n_epochs=50, n_layers=4, n_gaussians=5, hidden_dim=64, batch_size=512, lr=0.001, model_param=None, device=None)

Fit a Mixture Density Network (MDN) to estimate the prior distribution of effects.

In the EBNM problem, we observe estimates betahat with standard errors sebetahat and want to estimate the prior distribution of the true effects. The prior is modeled as a mixture of Gaussians with parameters predicted by a neural network.

Parameters:
  • X (torch.Tensor or np.ndarray) – Covariates for each observation, shape (n_samples, n_features).

  • betahat (torch.Tensor or np.ndarray) – Observed effect estimates, shape (n_samples,).

  • sebetahat (torch.Tensor or np.ndarray) – Standard errors of the effect estimates, shape (n_samples,).

  • n_epochs (int, optional) – Number of training epochs (default=50).

  • n_layers (int, optional) – Number of hidden layers in the neural network (default=4).

  • n_gaussians (int, optional) – Number of Gaussian components in the mixture (default=5).

  • hidden_dim (int, optional) – Number of hidden units in each layer (default=64).

  • batch_size (int, optional) – Batch size for training (default=512).

  • lr (float, optional) – Learning rate for the optimizer (default=1e-3).

  • model_param (dict, optional) – Pre-trained model parameters to initialize the network.

  • device (torch.device, optional) – Target device for tensors/models. Defaults to CUDA if available, else CPU.

Returns:

Container with posterior means, standard deviations, and model parameters.

Return type:

EmdnPosteriorMeanNorm

cebmf_torch.cebnm.lcash_posterior_means(X, betahat, sebetahat, n_epochs=200, batch_size=512, lr=0.001, weight_decay=0.001, penalty=1.5, mult=1.4142135623730951, ash_init=True, ash_threshold=1e-06, model_param=None, device=None, verbose=True, seed=42)

LC-ASH: linear covariate-modulated mixture weights.

Parameters:
  • X (tensor (G, F)) – Feature matrix. Standardised internally with NaN-aware statistics (mean/std computed on non-NaN values, NaN positions zero-filled). Pre-standardisation is not required.

  • betahat (tensor (G,)) – Effect estimates.

  • sebetahat (tensor (G,)) – Standard errors.

  • n_epochs (int or None) – Training epochs. Inside cEBMF, overridden by internal_epoch.

  • batch_size (int) – Mini-batch size for Adam.

  • lr (float) – Learning rate.

  • weight_decay (float) – L2 penalty on feature coefficients only (not bias).

  • penalty (float) – Dirichlet spike penalty (lambda_pen). 1.0 = no penalty.

  • mult (float) – Multiplicative step between mixture grid SDs. Smaller values give a finer grid with more components. Default sqrt(2) matches R ashr and gives ~27 components before pruning.

  • ash_init (bool) – If True (default), run ash internally (L-BFGS optimizer) to prune the grid to active components and initialise the bias from the ash weights, so the model starts at the exchangeable ash solution when all feature coefficients are zero. If False, use the full grid with uniform bias initialisation.

  • ash_threshold (float) – Pruning threshold: components with pi <= threshold are dropped. Only used when ash_init=True.

  • model_param (dict or None) – State dict from a previous call, for warm-starting.

  • device (torch.device or None) – Compute device. Defaults to CUDA if available.

  • verbose (bool) – If True (default), print training progress every 50 epochs.

  • seed (int) – Random seed for weight initialisation and batch ordering.

Returns:

Container with post_mean, post_mean2, post_sd, pi_np (G, K), scale (K,), loss, model_param (state dict for warm-starting).

Return type:

cash_PosteriorMeanNorm
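The covariate-modulated weights can be pictured as a per-observation softmax over the K grid components, π_i = softmax(W x_i + b). The plain-Python sketch below assumes that form (it is not taken from the library source) and shows why ash_init starts at the exchangeable ash solution: with every feature coefficient in W at zero and b = log(π_ash), each observation receives exactly the ash weights. The function names and the example weights are hypothetical.

```python
import math


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]


def lcash_weights(x_i, W, b):
    """pi_i = softmax(W x_i + b); W has one row of F feature coefficients per component."""
    logits = [sum(w * xf for w, xf in zip(W_k, x_i)) + b_k for W_k, b_k in zip(W, b)]
    return softmax(logits)


pi_ash = [0.7, 0.2, 0.1]             # hypothetical ash weights on a pruned grid (K = 3)
b = [math.log(p) for p in pi_ash]    # bias initialised from the ash weights
W = [[0.0, 0.0] for _ in pi_ash]     # F = 2 feature coefficients per component, all zero
pi_i = lcash_weights([1.5, -0.3], W, b)  # equals pi_ash up to floating-point rounding
```

Training then moves W away from zero only where the covariates help explain which components observations fall in.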

cebmf_torch.cebnm.po_lcash_posterior_means(X, betahat, sebetahat, n_epochs=200, batch_size=512, lr=0.001, weight_decay=0.001, penalty=1.5, mult=1.4142135623730951, ash_init=True, ash_threshold=1e-06, model_param=None, device=None, verbose=True, seed=42)

Proportional odds LC-ASH: ordered logistic covariate-modulated weights.

A shared weight vector maps features to a scalar signal strength s_i = x_i^T w. K-1 ordered cut-points convert s_i to mixture weights via cumulative logistic probabilities. This has F + K - 1 parameters (vs K * F for softmax LC-ASH), making it more parsimonious when K is large relative to F.

When ash_init=True, the grid is pruned to ash’s active components and the cut-points are initialised from the ash weights, so the model starts at the exchangeable ash solution.
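The proportional-odds construction can be sketched as follows, under the common cumulative-logit convention P(z_i ≤ k) = sigmoid(c_k - s_i) with increasing cut-points c_1 < ... < c_{K-1}. The sign convention and names in cebmf_torch are assumptions here. Component weights are differences of consecutive cumulative probabilities, so they are non-negative and sum to one, and a larger signal s_i shifts mass toward later components.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def po_weights(x_i, w, cuts):
    """Proportional-odds mixture weights for one observation.

    w: shared weight vector of length F; cuts: K-1 increasing cut-points.
    Total parameter count is F + K - 1, whatever the number of observations.
    """
    s_i = sum(wf * xf for wf, xf in zip(w, x_i))    # scalar signal strength x_i^T w
    cdf = [sigmoid(c - s_i) for c in cuts] + [1.0]  # P(z_i <= k), k = 1..K
    pi, prev = [], 0.0
    for c in cdf:
        pi.append(c - prev)                         # P(z_i = k)
        prev = c
    return pi


cuts = [-1.0, 0.5, 2.0]                             # K = 4 components
low = po_weights([0.0, 0.0], [1.0, -0.5], cuts)
high = po_weights([3.0, 0.0], [1.0, -0.5], cuts)
```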

Parameters:
  • X (tensor (G, F)) – Feature matrix. Standardised internally with NaN-aware statistics.

  • betahat (Tensor) – Effect estimates, shape (G,).

  • sebetahat (Tensor) – Standard errors, shape (G,).

  • n_epochs (int | None) – See lcash_posterior_means.

  • batch_size (int) – See lcash_posterior_means.

  • lr (float) – See lcash_posterior_means.

  • weight_decay (float) – See lcash_posterior_means.

  • penalty (float) – See lcash_posterior_means.

  • mult (float) – See lcash_posterior_means.

  • ash_init (bool) – See lcash_posterior_means.

  • ash_threshold (float) – See lcash_posterior_means.

  • model_param (dict | None) – See lcash_posterior_means.

  • device (device | None) – See lcash_posterior_means.

  • verbose (bool) – See lcash_posterior_means.

  • seed (int) – See lcash_posterior_means.

Returns:

Container with post_mean, post_mean2, post_sd, pi_np (G, K), scale (K,), loss, model_param (state dict for warm-starting).

Return type:

cash_PosteriorMeanNorm

cebmf_torch.cebnm.sharp_cgb_posterior_means(X, betahat, sebetahat, n_epochs=50, n_layers=2, omega=0.02, hidden_dim=32, batch_size=128, lr=0.001, penalty=1.5, model_param=None, eps=1e-08, device=None)

Fit a Covariate Generalized-Binary (CGB) model to estimate the prior distribution of effects.

Parameters:
  • X (torch.Tensor or np.ndarray) – Covariates for each observation, shape (n_samples, n_features).

  • betahat (torch.Tensor or np.ndarray) – Observed effect estimates, shape (n_samples,).

  • sebetahat (torch.Tensor or np.ndarray) – Standard errors of the effect estimates, shape (n_samples,).

  • n_epochs (int, optional) – Number of training epochs (default=50).

  • n_layers (int, optional) – Number of hidden layers in the neural network (default=2).

  • hidden_dim (int, optional) – Number of hidden units in each layer (default=32).

  • batch_size (int, optional) – Batch size for training (default=128).

  • lr (float, optional) – Learning rate for the optimizer (default=1e-3).

  • penalty (float, optional) – Penalty for spike probability (default=1.5).

  • model_param (dict, optional) – Pre-trained model parameters to initialize the network.

  • device (torch.device, optional) – Target device for tensors and the model.

Returns:

Container with posterior means, standard deviations, and model parameters.

Return type:

CgbPosteriorResult

cebmf_torch.cebnm.spiked_emdn_posterior_means(X, betahat, sebetahat, n_epochs=50, n_layers=4, n_gaussians=5, hidden_dim=64, batch_size=512, lr=0.001, model_param=None, *, penalty=1.5, beta_prior=None, print_every=10, device=None)

Fit a Mixture Density Network (MDN) with spike and slab components to estimate the prior distribution of effects.

In the EBNM problem, we observe estimates betahat with standard errors sebetahat and want to estimate the prior distribution of the true effects. The prior is modeled as a mixture of Gaussians (slabs) plus a point mass at zero (spike), with mixture parameters predicted by a neural network.

Parameters:
  • X (torch.Tensor or np.ndarray) – Covariates for each observation, shape (n_samples, n_features).

  • betahat (torch.Tensor or np.ndarray) – Observed effect estimates, shape (n_samples,).

  • sebetahat (torch.Tensor or np.ndarray) – Standard errors of the effect estimates, shape (n_samples,).

  • n_epochs (int, optional) – Number of training epochs (default=50).

  • n_layers (int, optional) – Number of hidden layers in the neural network (default=4).

  • n_gaussians (int, optional) – Number of Gaussian components in the mixture (default=5).

  • hidden_dim (int, optional) – Number of hidden units in each layer (default=64).

  • batch_size (int, optional) – Batch size for training (default=512).

  • lr (float, optional) – Learning rate for the optimizer (default=1e-3).

  • model_param (dict, optional) – Pre-trained model parameters to initialize the network.

  • penalty (float, optional) – Spike penalty: values > 1 encourage the spike; 1 is neutral (default=1.5).

  • beta_prior (tuple, optional) – (alpha0, beta0) for Beta prior on pi_spike.

  • print_every (int, optional) – Print training loss every this many epochs (default=10).

  • device (torch.device, optional) – Target device for tensors/models. Defaults to CUDA if available, else CPU.

Returns:

Container with posterior means, standard deviations, and model parameters.

Return type:

EmdnPosteriorMeanNorm

Utilities

Mathematical utilities and helper functions.

cebmf_torch.utils.get_device(prefer_gpu=True)

Get the best available device.

Priority order:

  1. CUDA (NVIDIA GPUs)

  2. MPS (Apple Silicon GPUs)

  3. CPU (fallback)

Parameters:

prefer_gpu (bool) – Whether to prefer a GPU over the CPU. Default is True.

Returns:

The selected device

Return type:

torch.device
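The documented priority order amounts to the decision below. The sketch takes the two availability flags as arguments so it runs without PyTorch; in real code they would typically come from torch.cuda.is_available() and torch.backends.mps.is_available(). The helper name is hypothetical, and it does not claim to reproduce get_device's internals.

```python
def pick_device_name(cuda_available, mps_available, prefer_gpu=True):
    """Mirror the documented priority: CUDA, then MPS, then CPU."""
    if prefer_gpu and cuda_available:
        return "cuda"
    if prefer_gpu and mps_available:
        return "mps"
    return "cpu"
```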

cebmf_torch.utils.my_e2truncnorm(a, b, mean=0.0, sd=1.0, precision='auto')

Compute E[Z^2 | a < Z < b] for Z ~ N(mean, sd^2), the second moment of a truncated normal.

Parameters:
  • a (float or torch.Tensor) – Lower truncation bound.

  • b (float or torch.Tensor) – Upper truncation bound.

  • mean (float or torch.Tensor, optional) – Mean of the normal distribution. Default is 0.0.

  • sd (float or torch.Tensor, optional) – Standard deviation of the normal distribution. Default is 1.0.

  • precision (str, optional) – Internal compute dtype. See my_etruncnorm() for the contract. "auto" (default) keeps the caller’s dtype (fast on CUDA); "float64" matches the pre-2026 behaviour.

Returns:

Second moment of the truncated normal distribution. Dtype matches the chosen precision.

Return type:

torch.Tensor

cebmf_torch.utils.my_etruncnorm(a, b, mean=0.0, sd=1.0, precision='auto')

Compute E[Z | a < Z < b] for Z ~ N(mean, sd^2), the mean of a truncated normal.

Parameters:
  • a (float or torch.Tensor) – Lower truncation bound.

  • b (float or torch.Tensor) – Upper truncation bound.

  • mean (float or torch.Tensor, optional) – Mean of the normal distribution. Default is 0.0.

  • sd (float or torch.Tensor, optional) – Standard deviation of the normal distribution. Default is 1.0.

  • precision (str, optional) – Internal compute dtype. "auto" (default) follows the input dtype (typically float32), which is fast on CUDA. "float64" forces double precision — slower on CUDA but matches the pre-2026 behaviour.

Returns:

Mean of the truncated normal distribution. Dtype matches the chosen precision.

Return type:

torch.Tensor
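The truncated-normal mean is mean + sd·(φ(α) - φ(β))/(Φ(β) - Φ(α)) with α and β the standardised bounds. The scalar plain-Python sketch below evaluates that formula as a reference check; it is not the library routine. When both bounds are deep in the same tail the denominator cancels catastrophically; a careful implementation works in log space or uses tail expansions, which this sketch does not.

```python
import math


def _phi(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def _Phi(z):
    return 0.5 * math.erfc(-z / math.sqrt(2.0))


def etruncnorm(a, b, mean=0.0, sd=1.0):
    """E[X | a < X < b] for X ~ N(mean, sd^2), scalar inputs."""
    alpha, beta = (a - mean) / sd, (b - mean) / sd
    # mean + sd * (phi(alpha) - phi(beta)) / (Phi(beta) - Phi(alpha))
    return mean + sd * (_phi(alpha) - _phi(beta)) / (_Phi(beta) - _Phi(alpha))
```

For example, the half-normal mean etruncnorm(0, inf) is sqrt(2/π) ≈ 0.798, and symmetric bounds around the mean leave the mean unchanged.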

cebmf_torch.utils.safe_tensor_to_float(value, null_value=-inf, reduction='min')

Convert tensor, float, or None to float with safe handling.

Parameters:
  • value (torch.Tensor, float, or None) – Value to convert.

  • null_value (float, optional) – Value to return if input is None or empty. Default is -inf.

  • reduction (str, optional) – Reduction to apply if input is a tensor (“min” or “max”). Default is “min”.

Returns:

Converted float value.

Return type:

float

cebmf_torch.utils.autoselect_scales_mix_norm(betahat, sebetahat, max_class=None, mult=2.0)

Automatically select scales for a normal mixture prior.

Parameters:
  • betahat (torch.Tensor) – Observed effect size estimates.

  • sebetahat (torch.Tensor) – Standard errors of the effect size estimates.

  • max_class (int or None, optional) – If provided, number of mixture components (scales) to return.

  • mult (float, optional) – Multiplicative step between scales. Default is 2.0.

Returns:

1D tensor of selected scales.

Return type:

torch.Tensor
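The docstring does not state the selection rule. For orientation, the sketch below implements the heuristic used by R's ashr (autoselect.mixsd), as we understand it: the smallest scale is a tenth of the smallest positive standard error, the largest is twice the square root of the largest excess of betahat² over sebetahat², and the grid is geometric with ratio mult. Whether autoselect_scales_mix_norm follows exactly this rule, and how max_class changes it, is an assumption; the sketch ignores max_class and prepends a 0 scale for the point mass.

```python
import math


def ashr_style_grid(betahat, sebetahat, mult=2.0):
    """[0.0] followed by sigma_max * mult**(-m) for m = npoint, ..., 0 (increasing)."""
    sigma_min = min(s for s in sebetahat if s > 0) / 10.0
    excess = [b * b - s * s for b, s in zip(betahat, sebetahat)]
    if max(excess) <= 0:
        sigma_max = 8.0 * sigma_min  # no apparent signal: keep the grid narrow
    else:
        sigma_max = 2.0 * math.sqrt(max(excess))
    npoint = math.ceil(math.log2(sigma_max / sigma_min) / math.log2(mult))
    return [0.0] + [sigma_max * mult ** (-m) for m in range(npoint, -1, -1)]


grid = ashr_style_grid([0.5, 4.0], [1.0, 1.0])
```

A smaller mult gives a finer grid with more components, which is the trade-off the mult parameter controls.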

cebmf_torch.utils.optimize_pi_logL(logL, penalty, max_iters=100, tol=1e-06, verbose=False, batch_size=None, shuffle=False, seed=None, check_every=10)

EM algorithm for optimizing mixture weights on the simplex given a log-likelihood matrix.

Parameters:
  • logL (torch.Tensor) – (n, K) tensor with entries logL[j, k] = log l_{jk}.

  • penalty (float or torch.Tensor) – Dirichlet pseudo-count alpha_1 on component 0, or a length-K vector of pseudo-counts. A scalar is applied to component 0 only and every other component gets 1 (vec_pen[0] = penalty, others = 1, as in the original code).

  • max_iters (int, optional) – Number of EM epochs. Default is 100.

  • tol (float, optional) – L2 tolerance on pi change for convergence. Default is 1e-6.

  • verbose (bool, optional) – Print convergence message if True. Default is False (cEBMF calls this in its inner loop and the convergence prints would spam stdout).

  • batch_size (int or None, optional) – If None, do full-batch; else iterate over mini-batches each epoch. Default is None.

  • shuffle (bool, optional) – Whether to shuffle rows each epoch when using batches. Default is False.

  • seed (int or None, optional) – RNG seed used when shuffle=True. Default is None.

  • check_every (int, optional) – Number of EM steps between convergence checks. Each check forces a host sync, so checking after every step (the previous behaviour) cost up to max_iters syncs per call, which is expensive inside cEBMF’s hot loop. Default is 10.

Returns:

(K,) tensor of optimized mixture weights on the simplex.

Return type:

torch.Tensor
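A full-batch version of the iteration can be sketched in plain Python, assuming the standard MAP-EM update for a Dirichlet penalty: responsibilities in the E-step, then expected counts plus alpha_k - 1 pseudo-counts, normalised, in the M-step. The sketch omits the mini-batching, shuffling and check_every logic of optimize_pi_logL, checks convergence after every step, and assumes penalty ≥ 1 so no pseudo-count is negative. The function name is hypothetical.

```python
import math


def em_mixture_weights(logL, penalty, max_iters=100, tol=1e-6):
    """Full-batch EM for pi on the simplex with Dirichlet pseudo-counts [penalty, 1, ..., 1]."""
    K = len(logL[0])
    alpha = [penalty] + [1.0] * (K - 1)
    pi = [1.0 / K] * K
    for _ in range(max_iters):
        counts = [a - 1.0 for a in alpha]  # Dirichlet pseudo-counts (0 for alpha_k = 1)
        for row in logL:
            # E-step: w_jk proportional to pi_k * l_jk, computed in log space.
            logw = [math.log(pi[k]) + row[k] if pi[k] > 0 else -math.inf for k in range(K)]
            m = max(logw)
            w = [math.exp(v - m) for v in logw]
            total = sum(w)
            for k in range(K):
                counts[k] += w[k] / total
        # M-step: normalise expected counts plus pseudo-counts.
        total_counts = sum(counts)
        new_pi = [c / total_counts for c in counts]
        change = math.sqrt(sum((p - q) ** 2 for p, q in zip(new_pi, pi)))  # L2 change
        pi = new_pi
        if change < tol:
            break
    return pi
```

With 50 observations that all favour component 1, penalty=1 drives π_0 to essentially zero, while penalty=10 keeps π_0 near 9/59 ≈ 0.15: the pseudo-counts hold weight on the spike.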

cebmf_torch.utils.optimize_pi_logL_lbfgs(logL, penalty, zero_threshold=1e-06)

Optimize mixture weights via L-BFGS with softmax reparameterisation.

An alternative to optimize_pi_logL() (EM) that produces sparse solutions matching R ashr’s convex optimizer (mixsqp) to 3-5 significant figures.

The simplex constraint is eliminated by optimising unconstrained z where x = softmax(z). L-BFGS builds a positive-definite BFGS Hessian approximation from gradient history, avoiding the near-singular exact Hessian that arises from aliased mixture components.

The Dirichlet penalty is encoded as pseudo-observations following R ashr’s convention: the likelihood matrix is augmented with identity rows weighted by (alpha_k - 1).

Parameters:
  • logL (torch.Tensor) – (n, K) tensor of log-likelihoods.

  • penalty (float or torch.Tensor) – Dirichlet pseudo-count, either a scalar or a (K,) tensor.

  • zero_threshold (float) – Components with weight below this are set to exactly zero. Default is 1e-6: a component at this weight contributes at most 1e-6 times its own likelihood to the mixture density of any observation, which is negligible for posterior computation. L-BFGS drives inactive components to ~1e-30 via softmax, so any threshold between 1e-10 and 1e-3 gives the same active set.

Returns:

(K,) tensor of optimized mixture weights on the simplex.

Return type:

torch.Tensor
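The objective that L-BFGS minimises can be written down directly. The plain-Python sketch below (hypothetical names softmax and penalized_negloglik, taking non-log likelihoods) shows the two ingredients described above: x = softmax(z) removes the simplex constraint, and appending identity rows e_k weighted by (alpha_k - 1) reproduces the Dirichlet log-prior term sum_k (alpha_k - 1) log x_k. It does not implement L-BFGS itself; in PyTorch one could hand such an objective to torch.optim.LBFGS, but whether cebmf_torch uses that class is not stated here.

```python
import math


def softmax(z):
    """Map unconstrained z to the simplex (max-shifted for stability)."""
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def penalized_negloglik(z, L, alpha):
    """Negative penalized log-likelihood as a function of unconstrained z.

    L: rows of (non-log) likelihoods l_{jk}; alpha: K Dirichlet pseudo-counts.
    The penalty enters as identity rows e_k with weight (alpha_k - 1).
    """
    K = len(alpha)
    x = softmax(z)
    rows = [(row, 1.0) for row in L]
    rows += [([1.0 if i == k else 0.0 for i in range(K)], alpha[k] - 1.0) for k in range(K)]
    # Each row contributes weight * log(sum_k l_k x_k); identity rows give log x_k
    return -sum(wt * math.log(sum(l * xk for l, xk in zip(row, x)))
                for row, wt in rows if wt != 0.0)
```

Because softmax(z) is strictly positive, log x_k is always finite, which is why inactive components drift toward tiny weights rather than exactly zero and need zero_threshold to be pruned.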

cebmf_torch.utils.posterior_mean_exp(betahat, sebetahat, log_pi, scale)[source]

Vectorised posterior mean and second moment for a spike+exponential mixture prior.

Computes all observations at once with a single (J, K) tensor pipeline rather than a per-observation Python loop. The prior is

theta ~ pi_0 * delta_0 + sum_{k>=1} pi_k * Exp(rate=1/scale[k]),

and the likelihood is betahat | theta ~ N(theta, sebetahat^2).

Parameters:
  • betahat (torch.Tensor) – Observed effect-size estimates, shape (J,).

  • sebetahat (torch.Tensor) – Standard errors of the effect-size estimates, shape (J,).

  • log_pi (torch.Tensor) – Log mixture weights, shape (K,). The mixture is shared across observations.

  • scale (torch.Tensor) – Mixture scales, shape (K,). scale[0] should be 0 (spike) and scale[1:] > 0 are the Exp scales (rate = 1/scale).

Returns:

Container with posterior mean, second moment, and standard deviation (each of shape (J,)).

Return type:

PosteriorMean
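For a single observation, the spike+exponential posterior has a closed form: under an Exp(rate r) component, the posterior of theta is N(m, s^2) truncated to theta > 0 with m = x - r s^2, and the component's marginal likelihood is r exp(-r x + r^2 s^2 / 2) Phi(m / s). The scalar sketch below (hypothetical name posterior_spike_exp, plain Python rather than the library's tensor code) combines those per-component moments with the mixture responsibilities. Its log_Phi helper is not stable for very negative arguments.

```python
import math

LOG_SQRT_2PI = 0.5 * math.log(2.0 * math.pi)


def log_norm_pdf(x, mu, s):
    return -LOG_SQRT_2PI - math.log(s) - 0.5 * ((x - mu) / s) ** 2


def log_Phi(z):
    # Adequate for moderate z only; not stable for z << 0
    return math.log(0.5 * math.erfc(-z / math.sqrt(2.0)))


def posterior_spike_exp(x, s, pi, scale):
    """Hypothetical scalar sketch: (mean, second moment, sd) of theta given x."""
    log_w, m1, m2 = [], [], []
    for p, sc in zip(pi, scale):
        lp = math.log(p) if p > 0 else -math.inf
        if sc == 0.0:
            # Spike at 0: marginal N(x; 0, s^2), posterior is a point mass at 0
            log_w.append(lp + log_norm_pdf(x, 0.0, s))
            m1.append(0.0)
            m2.append(0.0)
        else:
            r = 1.0 / sc
            m = x - r * s * s  # posterior is N(m, s^2) truncated to theta > 0
            z = m / s
            log_w.append(lp + math.log(r) - r * x + 0.5 * (r * s) ** 2 + log_Phi(z))
            lam = math.exp(-LOG_SQRT_2PI - 0.5 * z * z - log_Phi(z))  # phi(z) / Phi(z)
            m1.append(m + s * lam)
            m2.append(m * m + s * s + m * s * lam)
    # Normalise responsibilities in log space
    top = max(log_w)
    w = [math.exp(v - top) for v in log_w]
    tot = sum(w)
    mean = sum(wi * a for wi, a in zip(w, m1)) / tot
    second = sum(wi * b for wi, b in zip(w, m2)) / tot
    return mean, second, math.sqrt(max(second - mean * mean, 0.0))
```

With all weight on one Exp(scale=1) component and x = 10, s = 1, the posterior mean is close to m = x - r s^2 = 9, since the truncation at zero is then far in the tail.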

cebmf_torch.utils.posterior_mean_norm(betahat, sebetahat, log_pi, data_loglik, scale, location=None)[source]

Compute posterior mean and second moment for a normal mixture prior.

All mixture parameters (log_pi, scale, location) may be either shared across observations (1D, shape (K,)) or per-observation (2D, shape (J, K)). This is what makes the function usable both for classical ASH (one shared prior over a batch) and for covariate-adaptive methods like CASH/EMDN where the neural network emits per-observation mixture parameters.

Parameters:
  • betahat (torch.Tensor) – Observed effect size estimates, shape (J,).

  • sebetahat (torch.Tensor) – Standard errors of the effect size estimates, shape (J,).

  • log_pi (torch.Tensor) – Log mixture weights, shape (K,) or (J, K).

  • data_loglik (torch.Tensor) – Data log-likelihood matrix, shape (J, K).

  • scale (torch.Tensor) – Prior standard deviations, shape (K,) or (J, K). Components with scale == 0 are treated as a point mass at location.

  • location (torch.Tensor or None, optional) – Prior means, shape (K,) or (J, K). If None, uses zeros.

Returns:

Container with posterior mean, second moment, and standard deviation (each of shape (J,)).

Return type:

PosteriorMean
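Each normal component admits the standard conjugate update, and a scale of 0 falls out naturally as a point mass because the posterior variance becomes 0 and the posterior mean becomes the location. The scalar sketch below (hypothetical name posterior_normal_mixture) processes one observation; it computes the data log-likelihood log N(x; location_k, scale_k^2 + s^2) internally, whereas posterior_mean_norm() receives it as data_loglik. Passing different pi/scale/location per call corresponds to the per-observation (J, K) layout.

```python
import math


def posterior_normal_mixture(x, s, pi, scale, location=None):
    """Hypothetical scalar sketch: (mean, second moment, sd) for one observation."""
    if location is None:
        location = [0.0] * len(pi)
    log_w, m1, m2 = [], [], []
    for p, sd, mu in zip(pi, scale, location):
        v = sd * sd + s * s  # marginal variance of x under this component
        ll = -0.5 * math.log(2.0 * math.pi * v) - 0.5 * (x - mu) ** 2 / v
        log_w.append((math.log(p) if p > 0 else -math.inf) + ll)
        # Conjugate update; sd == 0 gives mean mu and variance 0 (point mass)
        post_mean = (x * sd * sd + mu * s * s) / v
        post_var = sd * sd * s * s / v
        m1.append(post_mean)
        m2.append(post_var + post_mean ** 2)
    # Responsibilities, normalised in log space
    top = max(log_w)
    w = [math.exp(a - top) for a in log_w]
    tot = sum(w)
    mean = sum(wi * a for wi, a in zip(w, m1)) / tot
    second = sum(wi * b for wi, b in zip(w, m2)) / tot
    return mean, second, math.sqrt(max(second - mean ** 2, 0.0))
```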

Mathematical Functions

cebmf_torch.utils.maths.log_norm_pdf(x, loc, scale)[source]

Backward-compatible alias for _logpdf_normal() with epsilon-padded scale.

Adds a tiny epsilon to scale before taking the log and dividing, so degenerate inputs (e.g., scale == 0) stay finite. Prefer _logpdf_normal() when the caller can guarantee a positive scale, for example by clamping it upstream.

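Parameters:
  • x (Tensor) – Points at which to evaluate the log density.

  • loc (Tensor) – Mean of the normal distribution.

  • scale (Tensor) – Standard deviation of the normal distribution; epsilon-padded before use.

Return type:

Tensor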
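A scalar sketch of epsilon padding follows. The helper name log_norm_pdf_padded and the eps value 1e-12 are assumptions for illustration; the library's actual epsilon is not documented here.

```python
import math


def log_norm_pdf_padded(x, loc, scale, eps=1e-12):
    """Hypothetical scalar analogue: log N(x; loc, (scale + eps)^2)."""
    sc = scale + eps  # padding keeps log(sc) and the division finite at scale == 0
    return -0.5 * math.log(2.0 * math.pi) - math.log(sc) - 0.5 * ((x - loc) / sc) ** 2
```

At scale == 0 the result is finite but extreme (large and positive when x == loc, hugely negative otherwise), which is one reason to clamp scale upstream when positivity can be guaranteed.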

cebmf_torch.utils.maths.norm_cdf(x)[source]

Compute the standard normal cumulative distribution function (CDF).

Parameters:

x (torch.Tensor) – Input tensor.

Returns:

CDF evaluated at x.

Return type:

torch.Tensor

cebmf_torch.utils.maths.norm_pdf(x)[source]

Compute the standard normal probability density function (PDF).

Parameters:

x (torch.Tensor) – Input tensor.

Returns:

PDF evaluated at x.

Return type:

torch.Tensor

cebmf_torch.utils.maths.logsumexp(x, dim=-1, keepdim=False)[source]

Compute the log of the sum of exponentials of input elements along a given dimension.

Parameters:
  • x (torch.Tensor) – Input tensor.

  • dim (int, optional) – Dimension along which to operate. Default is -1.

  • keepdim (bool, optional) – Whether the output tensor has dim retained or not. Default is False.

Returns:

Result of log-sum-exp operation.

Return type:

torch.Tensor
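The usual way to compute log-sum-exp without overflow is to factor out the maximum. The plain-Python sketch below shows that trick on a list; the tensor version in cebmf_torch may differ (PyTorch also ships torch.logsumexp).

```python
import math


def logsumexp(xs):
    """Sketch of the max-shift trick: log(sum(exp(x_i)))."""
    m = max(xs)
    if math.isinf(m):
        return m  # all entries -inf (zero mass) or some entry +inf
    # exp(x - m) <= 1, so nothing overflows even for very large x
    return m + math.log(sum(math.exp(x - m) for x in xs))
```

A naive sum of exp(1000.0) overflows, while the shifted form returns 1000 + log 2.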

cebmf_torch.utils.maths.safe_log(x, eps=1e-12)[source]

Compute the logarithm of x with clamping for numerical stability.

Parameters:
  • x (torch.Tensor) – Input tensor.

  • eps (float, optional) – Minimum value to clamp x to. Default is 1e-12.

Returns:

Logarithm of clamped x.

Return type:

torch.Tensor

cebmf_torch.utils.maths.softmax(x, dim=-1)[source]

Compute the softmax of input tensor along the specified dimension.

Parameters:
  • x (torch.Tensor) – Input tensor.

  • dim (int, optional) – Dimension along which softmax will be computed. Default is -1.

Returns:

Softmax of the input tensor.

Return type:

torch.Tensor

cebmf_torch.utils.maths.logphi(z)[source]

Compute the log of the standard normal PDF φ(z) = exp(-z^2/2)/√(2π).

Parameters:

z (torch.Tensor) – Input tensor.

Returns:

Log PDF evaluated at z.

Return type:

torch.Tensor

cebmf_torch.utils.maths.logPhi(z)[source]

Compute the log of the standard normal CDF, log Φ(z), in a numerically stable way.

Parameters:

z (torch.Tensor) – Input tensor.

Returns:

Log CDF evaluated at z.

Return type:

torch.Tensor
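One common way to keep log Φ(z) accurate is sketched below in plain Python: erfc for moderate z (it keeps relative accuracy in the lower tail), log1p for z > 0, and the Mills-ratio asymptotic series once erfc would underflow. The name log_Phi and the switch point of -20 are assumptions; the method logPhi() actually uses is not documented here. (PyTorch provides torch.special.log_ndtr for the same quantity.)

```python
import math

LOG_SQRT_2PI = 0.5 * math.log(2.0 * math.pi)


def log_Phi(z, switch=-20.0):
    """Sketch of a numerically stable log standard-normal CDF."""
    if z > 0.0:
        # Phi(z) = 1 - Phi(-z); log1p avoids rounding log(1 - tiny) to 0
        return math.log1p(-0.5 * math.erfc(z / math.sqrt(2.0)))
    if z > switch:
        # erfc keeps relative accuracy in the lower tail, unlike 1 - erf
        return math.log(0.5 * math.erfc(-z / math.sqrt(2.0)))
    # Asymptotic series for z << 0: Phi(z) ~ phi(z)/(-z) * (1 - 1/z^2 + 3/z^4 - 15/z^6)
    z2 = z * z
    series = 1.0 - 1.0 / z2 + 3.0 / z2 ** 2 - 15.0 / z2 ** 3
    return -LOG_SQRT_2PI - 0.5 * z2 - math.log(-z) + math.log(series)
```

At z = -40, 0.5 * erfc(40 / sqrt(2)) underflows to 0, so a direct log would fail; the asymptotic branch still returns a value near -804.6.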

cebmf_torch.utils.maths.logscale_sub(logx, logy)[source]

Compute log(exp(logx) - exp(logy)) in a numerically stable way.

Requires logx >= logy.

Parameters:
  • logx (torch.Tensor) – Logarithm of x.

  • logy (torch.Tensor) – Logarithm of y.

Returns:

Logarithm of (exp(logx) - exp(logy)).

Return type:

torch.Tensor

cebmf_torch.utils.maths.logscale_add(logx, logy)[source]

Compute log(exp(logx) + exp(logy)) in a numerically stable way.

Parameters:
  • logx (torch.Tensor) – Logarithm of x.

  • logy (torch.Tensor) – Logarithm of y.

Returns:

Logarithm of (exp(logx) + exp(logy)).

Return type:

torch.Tensor
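Both log-scale helpers rely on the same idea: factor out the larger exponent and use log1p on a quantity of magnitude at most 1. The scalar sketch below illustrates the technique; edge-case handling in the tensor versions may differ.

```python
import math


def logscale_add(logx, logy):
    """log(exp(logx) + exp(logy)) without overflow."""
    hi, lo = max(logx, logy), min(logx, logy)
    if math.isinf(hi):
        return hi  # both terms zero (-inf) or an infinite term (+inf)
    return hi + math.log1p(math.exp(lo - hi))


def logscale_sub(logx, logy):
    """log(exp(logx) - exp(logy)); requires logx >= logy."""
    if logx < logy:
        raise ValueError("logscale_sub requires logx >= logy")
    if logx == logy:
        return -math.inf  # exact cancellation: log(0)
    return logx + math.log1p(-math.exp(logy - logx))
```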

cebmf_torch.utils.maths.do_truncnorm_argchecks(a, b)[source]

Clamp and sanity-check the lower and upper bounds passed to the truncated normal routines.

Parameters:
  • a (torch.Tensor) – Lower bound(s).

  • b (torch.Tensor) – Upper bound(s).

Returns:

(a, b) after checks.

Return type:

tuple of torch.Tensor

cebmf_torch.utils.maths.safe_tensor_to_float(value, null_value=-inf, reduction='min')[source]

Convert tensor, float, or None to float with safe handling.

Parameters:
  • value (torch.Tensor, float, or None) – Value to convert.

  • null_value (float, optional) – Value to return if input is None or empty. Default is -inf.

  • reduction (str, optional) – Reduction to apply if input is a tensor (“min” or “max”). Default is “min”.

Returns:

Converted float value.

Return type:

float

cebmf_torch.utils.maths.my_etruncnorm(a, b, mean=0.0, sd=1.0, precision='auto')[source]

Compute E[Z | a < Z < b] for Z ~ N(mean, sd^2), the mean of a truncated normal.

Parameters:
  • a (float or torch.Tensor) – Lower truncation bound.

  • b (float or torch.Tensor) – Upper truncation bound.

  • mean (float or torch.Tensor, optional) – Mean of the normal distribution. Default is 0.0.

  • sd (float or torch.Tensor, optional) – Standard deviation of the normal distribution. Default is 1.0.

  • precision (str, optional) – Internal compute dtype. "auto" (default) follows the input dtype (typically float32), which is fast on CUDA. "float64" forces double precision, which is slower on CUDA but matches the pre-2026 behaviour.

Returns:

Mean of the truncated normal distribution. Dtype matches the chosen precision.

Return type:

torch.Tensor

cebmf_torch.utils.maths.my_e2truncnorm(a, b, mean=0.0, sd=1.0, precision='auto')[source]

Compute E[Z^2 | a < Z < b] for Z ~ N(mean, sd^2), the second moment of a truncated normal.

Parameters:
  • a (float or torch.Tensor) – Lower truncation bound.

  • b (float or torch.Tensor) – Upper truncation bound.

  • mean (float or torch.Tensor, optional) – Mean of the normal distribution. Default is 0.0.

  • sd (float or torch.Tensor, optional) – Standard deviation of the normal distribution. Default is 1.0.

  • precision (str, optional) – Internal compute dtype. See my_etruncnorm() for the contract. "auto" (default) keeps the caller’s dtype (fast on CUDA); "float64" matches the pre-2026 behaviour.

Returns:

Second moment of the truncated normal distribution. Dtype matches the chosen precision.

Return type:

torch.Tensor
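The two truncated-normal moments follow from the standard formulas. With alpha = (a - mean)/sd, beta = (b - mean)/sd, Z = Φ(beta) - Φ(alpha) and d = (φ(alpha) - φ(beta))/Z, they are E[X] = mean + sd·d and E[X^2] = mean^2 + sd^2 + 2·mean·sd·d + sd^2·(alpha·φ(alpha) - beta·φ(beta))/Z. The plain-Python sketch below (hypothetical name truncnorm_moments) evaluates them directly. It loses accuracy when the interval lies far in a tail and Z underflows; a robust version would typically work in log space, for example with logPhi() and logscale_sub(), but how my_etruncnorm() handles this is not documented here.

```python
import math


def _phi(t):
    return math.exp(-0.5 * t * t) / math.sqrt(2.0 * math.pi)


def _Phi(t):
    return 0.5 * math.erfc(-t / math.sqrt(2.0))


def _t_phi(t):
    # t * phi(t) with its limit 0 at t = +/-inf (avoids inf * 0 = nan)
    return 0.0 if math.isinf(t) else t * _phi(t)


def truncnorm_moments(a, b, mean=0.0, sd=1.0):
    """Hypothetical sketch: (E[X | a < X < b], E[X^2 | a < X < b]) for X ~ N(mean, sd^2)."""
    alpha, beta = (a - mean) / sd, (b - mean) / sd
    Z = _Phi(beta) - _Phi(alpha)  # probability mass of the interval
    d = (_phi(alpha) - _phi(beta)) / Z
    m1 = mean + sd * d
    m2 = mean ** 2 + sd ** 2 + 2.0 * mean * sd * d + sd ** 2 * (_t_phi(alpha) - _t_phi(beta)) / Z
    return m1, m2
```

For a = 0, b = inf with a standard normal, this recovers the half-normal moments sqrt(2/pi) and 1.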

Device Management

cebmf_torch.utils.device.get_device(prefer_gpu=True)[source]

Get the best available device.

Priority order:
  1. CUDA (NVIDIA GPUs)
  2. MPS (Apple Silicon GPUs)
  3. CPU (fallback)

Parameters:

prefer_gpu (bool) – Whether to prefer a GPU over the CPU. Default is True.

Returns:

The selected device.

Return type:

torch.device
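The documented priority order can be expressed as a small pure function. The helper pick_device_name is hypothetical and takes availability flags as arguments so it runs without a GPU; in real code those flags would come from torch.cuda.is_available() and torch.backends.mps.is_available(). That prefer_gpu=False always yields the CPU is an assumption about get_device(), not something stated above.

```python
def pick_device_name(prefer_gpu=True, cuda_available=False, mps_available=False):
    """Hypothetical helper mirroring the CUDA > MPS > CPU priority order."""
    if prefer_gpu and cuda_available:
        return "cuda"  # 1. NVIDIA GPUs
    if prefer_gpu and mps_available:
        return "mps"  # 2. Apple Silicon GPUs
    return "cpu"  # 3. fallback (assumed also when prefer_gpu is False)
```

The returned string could then be wrapped as torch.device(name) before moving tensors with to_device().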

cebmf_torch.utils.device.to_device(x, device=None)[source]

Move a tensor or module to the specified device.

Parameters:
  • x (object) – Tensor, module, or object supporting the .to() method.

  • device (torch.device or None, optional) – Target device. If None, uses the default device from get_device().

Returns:

The input object moved to the specified device, or unchanged if not supported.

Return type:

object

Prior Registry

Prior registry and prior classes for EBMF.