API Reference
This section documents the API for cebmf_torch.
Main Package
- class cebmf_torch.cEBMF(data, K=5, prior_L='norm', prior_F='norm', internal_epoch=10, prior_L_kwargs=None, prior_F_kwargs=None, allow_backfitting=True, prune_thresh=0.999, noise_type=NoiseType.CONSTANT, S=None, X_l=None, X_f=None, self_row_cov=False, self_col_cov=False, device=None)
Bases: object
Pure-PyTorch Empirical Bayes Matrix Factorization (EBMF) with NaN handling.
Features
Observed-mask weighting in lhat/fhat and their standard errors.
Constant or structured noise precision (scalar or per-row/column tau).
User-supplied (fixed) standard errors via S. This is useful for z-scores (pass S=1.0) or for pre-computed standard errors of effect-size estimates (pass an (N, P) tensor). When S is provided, the noise variance is treated as known and is not re-estimated during fitting.
Mini-batch optimization for mixture weights inside ash().
Modular prior and covariate support.
- fit(maxit=50)
Fit the cEBMF model for a specified number of iterations.
- Parameters:
maxit (int, optional) – Number of iterations to run. Default is 50.
- Returns:
Result container with fitted factors, noise, and objective history.
- Return type:
CEBMFResult
- initialise_factors(method='svd', *, L=None, F=None)
Initialize factor matrices using the specified method, or user-provided initial factors.
- Parameters:
method (str) – Initialization method (‘svd’, ‘random’, or ‘zero’). Default is ‘svd’. Ignored if L and F are provided.
L (Tensor or None, optional) – User-provided initial factor matrix (N, K). Ignored unless F is also provided.
F (Tensor or None, optional) – User-provided initial factor matrix (P, K). Ignored unless L is also provided.
- iter_once()
Perform one iteration of the cEBMF update (update all factors and noise).
- update_tau()
Update the noise precision parameter(s) according to the noise model.
Behaviour by noise type:
CONSTANT -> scalar tau; also provides tau_map (N, P) if you need it.
ROW_WISE -> tau_row (N,), with tau_map broadcast to (N, P).
COLUMN_WISE -> tau_col (P,), with tau_map broadcast to (N, P).
KNOWN -> no-op; tau_map was built once from the user-supplied S.
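The four noise models differ only in the shape of the precision before it is broadcast to an (N, P) map. The sketch below illustrates that broadcasting in plain Python, so it runs without PyTorch. The function name `build_tau_map` is hypothetical, and so are the string labels that stand in for `NoiseType` members. For KNOWN, it assumes the map is the elementwise precision 1/S², which is the usual relationship between a standard error and a precision.

```python
def build_tau_map(noise_type, value, N, P):
    """Hypothetical helper: broadcast a noise precision to an N x P grid.

    noise_type is a string standing in for a NoiseType member (assumption).
    """
    if noise_type == "CONSTANT":       # value: scalar tau
        return [[value] * P for _ in range(N)]
    if noise_type == "ROW_WISE":       # value: tau_row, length N
        return [[value[i]] * P for i in range(N)]
    if noise_type == "COLUMN_WISE":    # value: tau_col, length P
        return [list(value) for _ in range(N)]
    if noise_type == "KNOWN":          # value: user-supplied S (scalar or N x P)
        # Assumption: precision is the elementwise inverse variance 1 / S^2
        if isinstance(value, (int, float)):
            return [[1.0 / value ** 2] * P for _ in range(N)]
        return [[1.0 / s ** 2 for s in row] for row in value]
    raise ValueError(f"unknown noise type: {noise_type!r}")


# z-scores: S = 1.0 gives unit precision everywhere, never re-estimated
tau_map = build_tau_map("KNOWN", 1.0, N=2, P=3)
```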
- Parameters (cEBMF constructor):
data (Tensor)
K (int)
prior_L (str)
prior_F (str)
internal_epoch (int)
prior_L_kwargs (dict | None)
prior_F_kwargs (dict | None)
allow_backfitting (bool)
prune_thresh (float)
noise_type (NoiseType)
S (float | Tensor | None)
X_l (Tensor | None)
X_f (Tensor | None)
self_row_cov (bool)
self_col_cov (bool)
device (device | None)
- cebmf_torch.ash(x, s, prior=PriorType.NORM, mult=1.4142135623730951, penalty=10.0, verbose=False, threshold_loglikelihood=-300.0, mode=0.0, *, batch_size=128, shuffle=False, seed=None, optimizer='em')
Adaptive shrinkage with mixture priors (“norm” or “exp”) in pure PyTorch.
Uses EM for π by default (mini-batch capable via batch_size). Set optimizer="lbfgs" to use L-BFGS with softmax reparameterisation, which produces sparse solutions matching R ashr’s convex optimiser.
- Parameters:
x (torch.Tensor) – Observed data.
s (torch.Tensor) – Standard errors of the observed data.
prior (PriorType, optional) – Type of prior to use (default: PriorType.NORM).
mult (float, optional) – Multiplier for scale grid (default: sqrt(2.0)).
penalty (float, optional) – Penalty for mixture weights (default: 10.0).
verbose (bool, optional) – Verbosity flag (default: False).
threshold_loglikelihood (float, optional) – Minimum log-likelihood threshold (default: -300.0).
mode (float, optional) – Mode parameter (for normal prior only, default: 0.0).
batch_size (int or None, optional) – Batch size for EM updates (default: 128).
shuffle (bool, optional) – Whether to shuffle data in EM (default: False).
seed (int or None, optional) – Random seed for reproducibility.
optimizer (str, optional) – "em" (default) or "lbfgs". L-BFGS produces sparse solutions matching R ashr’s convex optimiser.
- Returns:
Result object containing posterior summaries and model parameters.
- Return type:
ASHResult
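With prior="norm", we assume ash() works with a mixture of normals centred at mode, with standard deviations on a multiplicative grid (step mult). The plain-Python sketch below builds the (n, K) log-likelihood matrix such a prior implies, logL[j][k] = log N(x_j; mode, s_j² + σ_k²). This is the kind of input the π optimisers take. The grid is illustrative and the function name is hypothetical; ash() builds its own grid internally.

```python
import math


def normal_mixture_loglik(x, s, scales, mode=0.0):
    """Return the n x K matrix logL[j][k] = log N(x_j; mode, s_j^2 + scales[k]^2).

    A component with scale 0 is a point mass at `mode` (the spike), so its
    marginal variance reduces to s_j^2.
    """
    logL = []
    for xj, sj in zip(x, s):
        row = []
        for sk in scales:
            var = sj ** 2 + sk ** 2
            row.append(-0.5 * math.log(2.0 * math.pi * var) - (xj - mode) ** 2 / (2.0 * var))
        logL.append(row)
    return logL


# Illustrative grid: a spike at zero plus slab SDs growing by mult = sqrt(2)
mult = math.sqrt(2.0)
scales = [0.0] + [0.5 * mult ** k for k in range(4)]
x = [0.1, -0.3, 2.5]
s = [1.0, 1.0, 1.0]
logL = normal_mixture_loglik(x, s, scales)
```

Large observations such as x = 2.5 get higher log-likelihood under the wider components, which is what lets the optimiser shift weight away from the spike.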
- cebmf_torch.ebnm_point_exp(x, s, par_init=None, fix_par=(False, False, True), max_iter=20, tol=1e-06, a_bounds=(0.01, 100.0), loga_l2=0.001, tresh_pi0=0.001, eps=1e-12)
Direct maximization (no EM) of the observed marginal log-likelihood for a point-Exponential EBNM.
Prior on θ: (1 - pi_slab) δ_μ + pi_slab [μ + Exp(a)], with support θ ≥ μ. Returns pure marginal log-likelihood (no penalties).
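The objective that ebnm_point_exp maximises has a closed form, because convolving an exponential with a normal yields a normal-CDF term: log a − a z + (a s)²/2 + log Φ(z/s − a s), where z = x − μ. The plain-Python sketch below evaluates that pure marginal log-likelihood for given (pi_slab, a, mu). It assumes a is the *rate* of the exponential; the docstring does not say whether a is a rate or a scale. The sketch only evaluates the objective and does not maximise it. The function names are hypothetical.

```python
import math


def log_norm_cdf(z):
    # log Phi(z); erfc keeps moderately negative z accurate
    return math.log(0.5 * math.erfc(-z / math.sqrt(2.0)))


def point_exp_loglik(x, s, pi_slab, a, mu=0.0):
    """sum_j log p(x_j) for theta ~ (1-pi) delta_mu + pi [mu + Exp(rate=a)] (a as rate: assumption)."""
    total = 0.0
    for xj, sj in zip(x, s):
        z = xj - mu
        # Spike: x ~ N(mu, s^2)
        log_spike = -0.5 * math.log(2 * math.pi * sj ** 2) - z ** 2 / (2 * sj ** 2)
        # Slab: Exp(rate=a) convolved with N(0, s^2), closed form
        log_slab = math.log(a) - a * z + 0.5 * (a * sj) ** 2 + log_norm_cdf(z / sj - a * sj)
        # log((1-pi) e^{log_spike} + pi e^{log_slab}), shifted by the max for stability
        m = max(log_spike, log_slab)
        total += m + math.log((1 - pi_slab) * math.exp(log_spike - m)
                              + pi_slab * math.exp(log_slab - m))
    return total
```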
- cebmf_torch.ebnm_point_laplace(x, s, par_init=None, fix_par=(False, False, True), max_iter=20, tol=1e-06, a_bounds=(0.01, 100.0), loga_l2=0.0, tresh_pi0=0.001, eps=1e-12, pen_pi0=0.0, use_adam_warmstart=False, adam_steps=8, adam_lr=0.01, weight_decay=0.0)
Efficient direct maximisation of the observed marginal log-likelihood for a point-Laplace EBNM. GPU-resident throughout — no host<->device syncs in either the inner objective or the posterior summary.
Optimizer: LBFGS only (an AdamW warm start is optional and short). The four scalar fields of the result (pi_slab, a, mu, log_lik) are returned as 0-d tensors on the input device, so the cEBMF ELBO accumulator can fold them in without forcing a host sync per factor update. Call float(field) at your own boundary if you need a Python scalar.
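For a point-Laplace prior (1 − pi_slab) δ_μ + pi_slab Laplace(μ, a), the exponential-normal convolution applies on each side of μ, so the marginal log-likelihood is again in closed form. The sketch below evaluates that objective in plain Python. It assumes a is the Laplace rate, with density (a/2)·e^{−a|θ−μ|}. It works on Python floats, whereas the library keeps everything as tensors on the input device. All names are hypothetical.

```python
import math


def log_norm_cdf(z):
    # log Phi(z); erfc keeps moderately negative z accurate
    return math.log(0.5 * math.erfc(-z / math.sqrt(2.0)))


def logaddexp(p, q):
    m = max(p, q)
    return m + math.log(math.exp(p - m) + math.exp(q - m))


def point_laplace_loglik(x, s, pi_slab, a, mu=0.0):
    """sum_j log p(x_j) for theta ~ (1-pi) delta_mu + pi Laplace(mu, rate=a) (assumed form)."""
    total = 0.0
    for xj, sj in zip(x, s):
        z = xj - mu
        log_spike = -0.5 * math.log(2 * math.pi * sj ** 2) - z ** 2 / (2 * sj ** 2)
        c = math.log(a / 2.0) + 0.5 * (a * sj) ** 2
        log_right = c - a * z + log_norm_cdf(z / sj - a * sj)   # theta > mu half
        log_left = c + a * z + log_norm_cdf(-z / sj - a * sj)   # theta < mu half
        log_slab = logaddexp(log_right, log_left)
        m = max(log_spike, log_slab)
        total += m + math.log((1 - pi_slab) * math.exp(log_spike - m)
                              + pi_slab * math.exp(log_slab - m))
    return total
```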
Core Classes
EBNM Solvers
Empirical Bayes Normal Means (EBNM) solvers.
- cebmf_torch.ebnm.ash(x, s, prior=PriorType.NORM, mult=1.4142135623730951, penalty=10.0, verbose=False, threshold_loglikelihood=-300.0, mode=0.0, *, batch_size=128, shuffle=False, seed=None, optimizer='em')
Adaptive shrinkage with mixture priors (“norm” or “exp”) in pure PyTorch.
Uses EM for π by default (mini-batch capable via batch_size). Set optimizer="lbfgs" to use L-BFGS with softmax reparameterisation, which produces sparse solutions matching R ashr’s convex optimiser.
- Parameters:
x (torch.Tensor) – Observed data.
s (torch.Tensor) – Standard errors of the observed data.
prior (PriorType, optional) – Type of prior to use (default: PriorType.NORM).
mult (float, optional) – Multiplier for scale grid (default: sqrt(2.0)).
penalty (float, optional) – Penalty for mixture weights (default: 10.0).
verbose (bool, optional) – Verbosity flag (default: False).
threshold_loglikelihood (float, optional) – Minimum log-likelihood threshold (default: -300.0).
mode (float, optional) – Mode parameter (for normal prior only, default: 0.0).
batch_size (int or None, optional) – Batch size for EM updates (default: 128).
shuffle (bool, optional) – Whether to shuffle data in EM (default: False).
seed (int or None, optional) – Random seed for reproducibility.
optimizer (str, optional) – "em" (default) or "lbfgs". L-BFGS produces sparse solutions matching R ashr’s convex optimiser.
- Returns:
Result object containing posterior summaries and model parameters.
- Return type:
ASHResult
- cebmf_torch.ebnm.ebnm_point_exp(x, s, par_init=None, fix_par=(False, False, True), max_iter=20, tol=1e-06, a_bounds=(0.01, 100.0), loga_l2=0.001, tresh_pi0=0.001, eps=1e-12)
Direct maximization (no EM) of the observed marginal log-likelihood for a point-Exponential EBNM.
Prior on θ: (1 - pi_slab) δ_μ + pi_slab [μ + Exp(a)], with support θ ≥ μ. Returns pure marginal log-likelihood (no penalties).
- cebmf_torch.ebnm.ebnm_point_laplace(x, s, par_init=None, fix_par=(False, False, True), max_iter=20, tol=1e-06, a_bounds=(0.01, 100.0), loga_l2=0.0, tresh_pi0=0.001, eps=1e-12, pen_pi0=0.0, use_adam_warmstart=False, adam_steps=8, adam_lr=0.01, weight_decay=0.0)
Efficient direct maximisation of the observed marginal log-likelihood for a point-Laplace EBNM. GPU-resident throughout — no host<->device syncs in either the inner objective or the posterior summary.
Optimizer: LBFGS only (an AdamW warm start is optional and short). The four scalar fields of the result (pi_slab, a, mu, log_lik) are returned as 0-d tensors on the input device, so the cEBMF ELBO accumulator can fold them in without forcing a host sync per factor update. Call float(field) at your own boundary if you need a Python scalar.
- cebmf_torch.ebnm.ebnm_gb(x, s, omega=0.2, par_init_mu=1.0, par_init_pi=0.2, max_em=200, tol_em=1e-05, max_lbfgs=200, tol_lbfgs=1e-06, eps=1e-12)
EBNM with a Generalized Binary prior:
θ ~ (1 - π) δ_0 + π N_+(μ, σ²), with σ = ω μ, μ ≥ 0, and ω fixed.
Follows the EM scheme in the Supplementary Note (eqs. (16)-(27)).
- Return type:
EBNMGBResult
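Suppose N_+(μ, σ²) is read as a normal truncated to θ ≥ 0. This reading is our assumption, consistent with μ ≥ 0. Under it, the slab's marginal density has a closed form: the untruncated convolution N(x; μ, s² + σ²), times the posterior mass above zero, divided by the prior's truncation constant Φ(μ/σ). The plain-Python sketch below evaluates that marginal. It does not implement the cited EM scheme, and the function names are hypothetical.

```python
import math


def norm_cdf(z):
    return 0.5 * math.erfc(-z / math.sqrt(2.0))


def gb_slab_density(x, s, mu, omega):
    """p(x) for theta ~ N(mu, sigma^2) truncated to theta >= 0, sigma = omega * mu (mu > 0)."""
    sigma = omega * mu
    tot_var = s ** 2 + sigma ** 2
    # Untruncated convolution: x ~ N(mu, s^2 + sigma^2)
    base = math.exp(-(x - mu) ** 2 / (2 * tot_var)) / math.sqrt(2 * math.pi * tot_var)
    # Posterior of theta given x before truncation is N(m, v)
    v = sigma ** 2 * s ** 2 / tot_var
    m = (mu * s ** 2 + x * sigma ** 2) / tot_var
    # Keep the theta >= 0 mass and renormalise by the prior's truncation constant
    return base * norm_cdf(m / math.sqrt(v)) / norm_cdf(mu / sigma)


def gb_marginal(x, s, pi, mu, omega):
    """Spike at zero with weight 1 - pi, truncated-normal slab with weight pi."""
    spike = math.exp(-x ** 2 / (2 * s ** 2)) / math.sqrt(2 * math.pi * s ** 2)
    return (1 - pi) * spike + pi * gb_slab_density(x, s, mu, omega)
```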
Covariate-Enhanced EBNM
Covariate-Enhanced Empirical Bayes Normal Means (cEBNM) solvers.
- cebmf_torch.cebnm.cash_posterior_means(X, betahat, sebetahat, n_epochs=20, n_layers=4, num_classes=20, hidden_dim=64, batch_size=128, lr=0.001, model_param=None, penalty=1.5, device=None)
GPU-native CASH training and posterior computation.
Fit a CASH (Covariate Adaptive Shrinkage) model and compute posterior means, second moments, and standard deviations.
- Parameters:
X (torch.Tensor or np.ndarray) – Covariates for each observation, shape (n_samples, n_features).
betahat (torch.Tensor or np.ndarray) – Observed effect estimates, shape (n_samples,).
sebetahat (torch.Tensor or np.ndarray) – Standard errors of the effect estimates, shape (n_samples,).
n_epochs (int, optional) – Number of training epochs (default=20).
n_layers (int, optional) – Number of hidden layers in the neural network (default=4).
num_classes (int, optional) – Number of mixture components (default=20).
hidden_dim (int, optional) – Number of hidden units in each layer (default=64).
batch_size (int, optional) – Batch size for training (default=128).
lr (float, optional) – Learning rate for the optimizer (default=0.001).
model_param (dict, optional) – Pre-trained model parameters to initialize the network.
penalty (float, optional) – Penalty for spike probability (default=1.5).
device (torch.device, optional) – Target device for tensors and the model. If None, inherits from betahat when it is already a tensor; otherwise falls back to CUDA (if available) or CPU. Inheriting from the input avoids silent cross-device hops when the caller (e.g. cEBMF) is on CPU/MPS but CUDA is also visible on the host.
- Returns:
Container with posterior means, standard deviations, and model parameters.
- Return type:
cash_PosteriorMeanNorm
- cebmf_torch.cebnm.cgb_posterior_means(X, betahat, sebetahat, n_epochs=50, n_layers=2, hidden_dim=32, batch_size=128, lr=0.001, penalty=1.5, model_param=None, device=None)
Fit a Covariate Generalized-Binary (CGB) model to estimate the prior distribution of effects.
- Parameters:
X (torch.Tensor or np.ndarray) – Covariates for each observation, shape (n_samples, n_features).
betahat (torch.Tensor or np.ndarray) – Observed effect estimates, shape (n_samples,).
sebetahat (torch.Tensor or np.ndarray) – Standard errors of the effect estimates, shape (n_samples,).
n_epochs (int, optional) – Number of training epochs (default=50).
n_layers (int, optional) – Number of hidden layers in the neural network (default=2).
hidden_dim (int, optional) – Number of hidden units in each layer (default=32).
batch_size (int, optional) – Batch size for training (default=128).
lr (float, optional) – Learning rate for the optimizer (default=1e-3).
penalty (float, optional) – Penalty for spike probability (default=1.5).
model_param (dict, optional) – Pre-trained model parameters to initialize the network.
device (device | None)
- Returns:
Container with posterior means, standard deviations, and model parameters.
- Return type:
CgbPosteriorResult
- cebmf_torch.cebnm.emdn_posterior_means(X, betahat, sebetahat, n_epochs=50, n_layers=4, n_gaussians=5, hidden_dim=64, batch_size=512, lr=0.001, model_param=None, device=None)
Fit a Mixture Density Network (MDN) to estimate the prior distribution of effects.
In the EBNM problem, we observe estimates betahat with standard errors sebetahat and want to estimate the prior distribution of the true effects. The prior is modeled as a mixture of Gaussians with parameters predicted by a neural network.
- Parameters:
X (torch.Tensor or np.ndarray) – Covariates for each observation, shape (n_samples, n_features).
betahat (torch.Tensor or np.ndarray) – Observed effect estimates, shape (n_samples,).
sebetahat (torch.Tensor or np.ndarray) – Standard errors of the effect estimates, shape (n_samples,).
n_epochs (int, optional) – Number of training epochs (default=50).
n_layers (int, optional) – Number of hidden layers in the neural network (default=4).
n_gaussians (int, optional) – Number of Gaussian components in the mixture (default=5).
hidden_dim (int, optional) – Number of hidden units in each layer (default=64).
batch_size (int, optional) – Batch size for training (default=512).
lr (float, optional) – Learning rate for the optimizer (default=1e-3).
model_param (dict, optional) – Pre-trained model parameters to initialize the network.
device (torch.device, optional) – Target device for tensors/models. Defaults to CUDA if available, else CPU.
- Returns:
Container with posterior means, standard deviations, and model parameters.
- Return type:
EmdnPosteriorMeanNorm
- cebmf_torch.cebnm.lcash_posterior_means(X, betahat, sebetahat, n_epochs=200, batch_size=512, lr=0.001, weight_decay=0.001, penalty=1.5, mult=1.4142135623730951, ash_init=True, ash_threshold=1e-06, model_param=None, device=None, verbose=True, seed=42)
LC-ASH: linear covariate-modulated mixture weights.
- Parameters:
X (tensor (G, F)) – Feature matrix. Standardised internally with NaN-aware statistics (mean/std computed on non-NaN values, NaN positions zero-filled). Pre-standardisation is not required.
betahat (tensor (G,)) – Effect estimates.
sebetahat (tensor (G,)) – Standard errors.
n_epochs (int or None) – Training epochs. Inside cEBMF, overridden by internal_epoch.
batch_size (int) – Mini-batch size for Adam.
lr (float) – Learning rate.
weight_decay (float) – L2 penalty on feature coefficients only (not bias).
penalty (float) – Dirichlet spike penalty (lambda_pen). 1.0 = no penalty.
mult (float) – Multiplicative step between mixture grid SDs. Smaller values give a finer grid with more components. Default sqrt(2) matches R ashr and gives ~27 components before pruning.
ash_init (bool) – If True (default), run ash internally (L-BFGS optimizer) to prune the grid to active components and initialise the bias from the ash weights, so the model starts at the exchangeable ash solution when all feature coefficients are zero. If False, use the full grid with uniform bias initialisation.
ash_threshold (float) – Pruning threshold: components with pi <= threshold are dropped. Only used when ash_init=True.
model_param (dict or None) – State dict from a previous call, for warm-starting.
device (torch.device or None) – Compute device. Defaults to CUDA if available.
verbose (bool) – If True (default), print training progress every 50 epochs.
seed (int) – Random seed for weight initialisation and batch ordering.
- Returns:
Container with post_mean, post_mean2, post_sd, pi_np (G, K), scale (K,), loss, model_param (state dict for warm-starting).
- Return type:
cash_PosteriorMeanNorm
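The ash_init behaviour can be seen directly in a softmax-linear model of the weights, π_i = softmax(W x_i + b), with K × F feature coefficients W and K biases b. If b is set to log π_ash and W is zero, every row gets exactly the ash weights, which is the exchangeable starting point. The sketch below is a hypothetical plain-Python illustration of that parameterisation, not the library's training code.

```python
import math


def softmax(logits):
    m = max(logits)
    e = [math.exp(t - m) for t in logits]
    total = sum(e)
    return [t / total for t in e]


def lcash_weights(x_i, W, b):
    """Mixture weights for one row: softmax(W @ x_i + b); W is K x F, b has length K."""
    logits = [bk + sum(w * xf for w, xf in zip(wk, x_i)) for wk, bk in zip(W, b)]
    return softmax(logits)


pi_ash = [0.6, 0.3, 0.1]                 # illustrative pruned ash weights
b0 = [math.log(p) for p in pi_ash]       # bias initialised from the ash weights
W0 = [[0.0, 0.0] for _ in pi_ash]        # all feature coefficients start at zero

# Any (standardised) feature vector reproduces the ash weights at initialisation
start = lcash_weights([1.3, -0.7], W0, b0)
```

Training then moves W away from zero. The weight_decay penalty acts on those coefficients only, which pulls the model back toward the exchangeable solution.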
- cebmf_torch.cebnm.po_lcash_posterior_means(X, betahat, sebetahat, n_epochs=200, batch_size=512, lr=0.001, weight_decay=0.001, penalty=1.5, mult=1.4142135623730951, ash_init=True, ash_threshold=1e-06, model_param=None, device=None, verbose=True, seed=42)
Proportional odds LC-ASH: ordered logistic covariate-modulated weights.
A shared weight vector maps features to a scalar signal strength s_i = x_i^T w. K-1 ordered cut-points convert s_i to mixture weights via cumulative logistic probabilities. This has F + K - 1 parameters (vs K * F for softmax LC-ASH), making it more parsimonious when K is large relative to F.
When ash_init=True, the grid is pruned to ash’s active components and the cut-points are initialised from the ash weights, so the model starts at the exchangeable ash solution.
- Parameters:
X (tensor (G, F)) – Feature matrix. Standardised internally with NaN-aware statistics.
betahat (Tensor)
sebetahat (Tensor)
batch_size (int)
lr (float)
weight_decay (float)
penalty (float)
mult (float) – See lcash_posterior_means.
ash_init (bool) – See lcash_posterior_means.
ash_threshold (float) – See lcash_posterior_means.
model_param (dict | None) – See lcash_posterior_means.
device (device | None) – See lcash_posterior_means.
verbose (bool) – See lcash_posterior_means.
seed (int) – See lcash_posterior_means.
- Returns:
Container with post_mean, post_mean2, post_sd, pi_np (G, K), scale (K,), loss, model_param (state dict for warm-starting).
- Return type:
cash_PosteriorMeanNorm
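The proportional-odds construction can be sketched in a few lines. Cumulative probabilities P(k_i ≤ k) = σ(c_k − s_i) are taken over ordered cut-points, and the mixture weights are their successive differences. Setting c_k = logit(π_1 + ... + π_k) from the ash weights reproduces those weights exactly when s_i = 0, matching the ash_init description. The sign convention (c_k − s_i, so a larger signal shifts mass toward later components) and all names below are our assumptions, not the library's code.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def po_weights(signal, cuts):
    """Weights over K ordered components from K-1 increasing cut-points and a scalar signal."""
    cum = [sigmoid(c - signal) for c in cuts] + [1.0]   # P(component <= k), assumed sign convention
    return [cum[0]] + [cum[k] - cum[k - 1] for k in range(1, len(cum))]


def cuts_from_weights(pi):
    """Cut-points c_k = logit(pi_1 + ... + pi_k), so a zero signal returns pi."""
    cuts, acc = [], 0.0
    for p in pi[:-1]:
        acc += p
        cuts.append(math.log(acc / (1.0 - acc)))
    return cuts


pi_ash = [0.5, 0.3, 0.2]
cuts = cuts_from_weights(pi_ash)
# s_i = x_i^T w; w has F entries, so the model has F + (K - 1) parameters in total
w, x_i = [0.4, -0.2], [2.0, 1.0]
s_i = sum(wf * xf for wf, xf in zip(w, x_i))
```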
- cebmf_torch.cebnm.sharp_cgb_posterior_means(X, betahat, sebetahat, n_epochs=50, n_layers=2, omega=0.02, hidden_dim=32, batch_size=128, lr=0.001, penalty=1.5, model_param=None, eps=1e-08, device=None)
Fit a Covariate Generalized-Binary (CGB) model to estimate the prior distribution of effects.
- Parameters:
X (torch.Tensor or np.ndarray) – Covariates for each observation, shape (n_samples, n_features).
betahat (torch.Tensor or np.ndarray) – Observed effect estimates, shape (n_samples,).
sebetahat (torch.Tensor or np.ndarray) – Standard errors of the effect estimates, shape (n_samples,).
n_epochs (int, optional) – Number of training epochs (default=50).
n_layers (int, optional) – Number of hidden layers in the neural network (default=2).
hidden_dim (int, optional) – Number of hidden units in each layer (default=32).
batch_size (int, optional) – Batch size for training (default=128).
lr (float, optional) – Learning rate for the optimizer (default=1e-3).
penalty (float, optional) – Penalty for spike probability (default=1.5).
model_param (dict, optional) – Pre-trained model parameters to initialize the network.
device (device | None)
- Returns:
Container with posterior means, standard deviations, and model parameters.
- Return type:
CgbPosteriorResult
- cebmf_torch.cebnm.spiked_emdn_posterior_means(X, betahat, sebetahat, n_epochs=50, n_layers=4, n_gaussians=5, hidden_dim=64, batch_size=512, lr=0.001, model_param=None, *, penalty=1.5, beta_prior=None, print_every=10, device=None)
Fit a Mixture Density Network (MDN) with spike and slab components to estimate the prior distribution of effects.
In the EBNM problem, we observe estimates betahat with standard errors sebetahat and want to estimate the prior distribution of the true effects. The prior is modeled as a mixture of Gaussians (slabs) plus a point mass at zero (spike), with mixture parameters predicted by a neural network.
- Parameters:
X (torch.Tensor or np.ndarray) – Covariates for each observation, shape (n_samples, n_features).
betahat (torch.Tensor or np.ndarray) – Observed effect estimates, shape (n_samples,).
sebetahat (torch.Tensor or np.ndarray) – Standard errors of the effect estimates, shape (n_samples,).
n_epochs (int, optional) – Number of training epochs (default=50).
n_layers (int, optional) – Number of hidden layers in the neural network (default=4).
n_gaussians (int, optional) – Number of Gaussian components in the mixture (default=5).
hidden_dim (int, optional) – Number of hidden units in each layer (default=64).
batch_size (int, optional) – Batch size for training (default=512).
lr (float, optional) – Learning rate for the optimizer (default=1e-3).
model_param (dict, optional) – Pre-trained model parameters to initialize the network.
penalty (float, optional) – Values > 1 encourage the spike; 1 is neutral (default=1.5).
beta_prior (tuple, optional) – (alpha0, beta0) for Beta prior on pi_spike.
print_every (int, optional) – Print training loss every this many epochs (default=10).
device (torch.device, optional) – Target device for tensors/models. Defaults to CUDA if available, else CPU.
- Returns:
Container with posterior means, standard deviations, and model parameters.
- Return type:
EmdnPosteriorMeanNorm
Utilities
Mathematical utilities and helper functions.
- cebmf_torch.utils.get_device(prefer_gpu=True)
Get the best available device.
Priority order:
1. CUDA (NVIDIA GPUs)
2. MPS (Apple Silicon GPUs)
3. CPU (fallback)
- Parameters:
prefer_gpu (bool) – Whether to prefer a GPU over the CPU.
- Returns:
The selected device.
- Return type:
torch.device
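The priority order above amounts to a short decision rule. The sketch below writes it as a pure function of two availability flags, so it can be tested without hardware. In PyTorch, those flags would come from torch.cuda.is_available() and torch.backends.mps.is_available(). The function name is hypothetical. Treating prefer_gpu=False as "always CPU" is our reading of the parameter.

```python
def pick_device(cuda_available, mps_available, prefer_gpu=True):
    """Return a device string following CUDA > MPS > CPU (hypothetical helper)."""
    if prefer_gpu:
        if cuda_available:
            return "cuda"
        if mps_available:
            return "mps"
    return "cpu"   # fallback, or the caller opted out of GPUs


# With PyTorch installed, one might write:
#   torch.device(pick_device(torch.cuda.is_available(), torch.backends.mps.is_available()))
```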
- cebmf_torch.utils.my_e2truncnorm(a, b, mean=0.0, sd=1.0, precision='auto')
Compute E[Z^2 | a < Z < b] for Z ~ N(mean, sd^2), the second moment of a truncated normal.
- Parameters:
a (float or torch.Tensor) – Lower truncation bound.
b (float or torch.Tensor) – Upper truncation bound.
mean (float or torch.Tensor, optional) – Mean of the normal distribution. Default is 0.0.
sd (float or torch.Tensor, optional) – Standard deviation of the normal distribution. Default is 1.0.
precision (str, optional) – Internal compute dtype. See my_etruncnorm() for the contract. "auto" (default) keeps the caller’s dtype (fast on CUDA); "float64" matches the pre-2026 behaviour.
- Returns:
Second moment of the truncated normal distribution. Dtype matches the chosen precision.
- Return type:
torch.Tensor
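A scalar reference built from the textbook moments of a standard truncated normal is handy when testing the tensor implementation, for example precision="auto" against "float64". Write Z = mean + sd·U, with U truncated to (α, β) and D = Φ(β) − Φ(α). Then E[U] = (φ(α) − φ(β))/D and E[U²] = 1 + (αφ(α) − βφ(β))/D, so E[Z²] = mean² + 2·mean·sd·E[U] + sd²·E[U²]. The plain-Python sketch below is our own reference, not the library's code.

```python
import math


def _phi(z):
    # standard normal pdf, with phi(+-inf) = 0
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi) if math.isfinite(z) else 0.0


def _Phi(z):
    return 0.5 * math.erfc(-z / math.sqrt(2.0))


def e2truncnorm(a, b, mean=0.0, sd=1.0):
    """E[Z^2 | a < Z < b] for Z ~ N(mean, sd^2); scalar reference version."""
    alpha, beta = (a - mean) / sd, (b - mean) / sd
    D = _Phi(beta) - _Phi(alpha)
    # z * phi(z) -> 0 as |z| -> inf; guard to avoid inf * 0 = nan
    aphi = alpha * _phi(alpha) if math.isfinite(alpha) else 0.0
    bphi = beta * _phi(beta) if math.isfinite(beta) else 0.0
    EU = (_phi(alpha) - _phi(beta)) / D
    EU2 = 1.0 + (aphi - bphi) / D
    return mean ** 2 + 2.0 * mean * sd * EU + sd ** 2 * EU2
```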
- cebmf_torch.utils.my_etruncnorm(a, b, mean=0.0, sd=1.0, precision='auto')
Compute E[Z | a < Z < b] for Z ~ N(mean, sd^2), the mean of a truncated normal.
- Parameters:
a (float or torch.Tensor) – Lower truncation bound.
b (float or torch.Tensor) – Upper truncation bound.
mean (float or torch.Tensor, optional) – Mean of the normal distribution. Default is 0.0.
sd (float or torch.Tensor, optional) – Standard deviation of the normal distribution. Default is 1.0.
precision (str, optional) – Internal compute dtype. "auto" (default) follows the input dtype (typically float32), which is fast on CUDA. "float64" forces double precision, which is slower on CUDA but matches the pre-2026 behaviour.
- Returns:
Mean of the truncated normal distribution. Dtype matches the chosen precision.
- Return type:
torch.Tensor
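The truncated-normal mean is E[Z] = mean + sd·(φ(α) − φ(β))/(Φ(β) − Φ(α)), with α = (a − mean)/sd and β = (b − mean)/sd. The scalar sketch below is our own plain-Python reference, one you might compare the tensor implementation against. Its comment flags the cancellation in Φ(β) − Φ(α) when both bounds lie far in the upper tail, one situation where the compute dtype matters.

```python
import math


def _phi(z):
    # standard normal pdf, with phi(+-inf) = 0
    return 0.0 if math.isinf(z) else math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def _Phi(z):
    return 0.5 * math.erfc(-z / math.sqrt(2.0))


def etruncnorm(a, b, mean=0.0, sd=1.0):
    """E[Z | a < Z < b] for Z ~ N(mean, sd^2); scalar reference version."""
    alpha, beta = (a - mean) / sd, (b - mean) / sd
    # Both CDFs round to 1.0 when alpha and beta sit deep in the upper tail
    D = _Phi(beta) - _Phi(alpha)
    return mean + sd * (_phi(alpha) - _phi(beta)) / D
```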
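- cebmf_torch.utils.safe_tensor_to_float(value, null_value=-inf, reduction='min')
Convert a tensor, float, or None to a float with safe handling.
- Parameters:
value (Tensor, float, or None) – Value to convert.
null_value (float, optional) – Default is -inf.
reduction (str, optional) – Default is 'min'.
- Returns:
Converted float value.
- Return type:
float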
- cebmf_torch.utils.autoselect_scales_mix_norm(betahat, sebetahat, max_class=None, mult=2.0)
Automatically select scales for a normal mixture prior.
- Parameters:
betahat (torch.Tensor) – Observed effect size estimates.
sebetahat (torch.Tensor) – Standard errors of the effect size estimates.
max_class (int or None, optional) – If provided, number of mixture components (scales) to return.
mult (float, optional) – Multiplicative step between scales. Default is 2.0.
- Returns:
1D tensor of selected scales.
- Return type:
torch.Tensor
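R's ashr picks such a grid with a common heuristic (autoselect.mixsd). The grid runs from a tenth of the smallest standard error up to twice the largest signal in excess of noise, in multiplicative steps of mult. The sketch below follows that heuristic in plain Python. Whether autoselect_scales_mix_norm uses the same bounds is an assumption. The max_class option is not reproduced, and the function name is hypothetical.

```python
import math


def autoselect_scales(betahat, sebetahat, mult=2.0):
    """Geometric grid of prior SDs, modelled on ashr's autoselect.mixsd heuristic."""
    sigma_min = min(sebetahat) / 10.0          # finest scale: well below the noise level
    excess = max(b * b - s * s for b, s in zip(betahat, sebetahat))
    # coarsest scale: large enough to cover the biggest signal beyond the noise
    sigma_max = 2.0 * math.sqrt(excess) if excess > 0 else 8.0 * sigma_min
    npoint = math.ceil(math.log(sigma_max / sigma_min) / math.log(mult))
    # sigma_max * mult^(-npoint), ..., sigma_max * mult^0, in increasing order
    return [sigma_max * mult ** (-k) for k in range(npoint, -1, -1)]


grid = autoselect_scales([0.0, 3.0], [1.0, 1.0], mult=2.0)
```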
- cebmf_torch.utils.optimize_pi_logL(logL, penalty, max_iters=100, tol=1e-06, verbose=False, batch_size=None, shuffle=False, seed=None, check_every=10)
EM algorithm for optimizing mixture weights on the simplex given a log-likelihood matrix.
- Parameters:
logL (torch.Tensor) – (n, K) tensor with entries logL[j, k] = log l_{jk}.
penalty (float or torch.Tensor) – Dirichlet pseudo-count alpha_1 on component 0 (or a length-K vector). In the original code, vec_pen[0] = penalty and others = 1.
max_iters (int, optional) – Number of EM epochs. Default is 100.
tol (float, optional) – L2 tolerance on pi change for convergence. Default is 1e-6.
verbose (bool, optional) – Print convergence message if True. Default is False (cEBMF calls this in its inner loop and the convergence prints would spam stdout).
batch_size (int or None, optional) – If None, do full-batch; else iterate over mini-batches each epoch. Default is None.
shuffle (bool, optional) – Whether to shuffle rows each epoch when using batches. Default is False.
seed (int or None, optional) – RNG seed used when shuffle=True. Default is None.
check_every (int, optional) – How many EM steps to run between convergence checks. Each check forces a host sync, so checking every step (the previous behaviour) cost up to max_iters syncs per call, which is expensive inside cEBMF’s hot loop. Default is 10.
- Returns:
(K,) tensor of optimized mixture weights on the simplex.
- Return type:
torch.Tensor
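The full-batch version of the penalised EM iteration has two steps. The E-step computes responsibilities w_jk ∝ π_k l_jk in log space. The M-step sets π_k ∝ Σ_j w_jk + α_k − 1, with α_0 = penalty and α_k = 1 otherwise, which is the Dirichlet MAP update. The code is a plain-Python sketch of the textbook algorithm under those assumptions. It leaves out the mini-batching, shuffling and check_every logic of optimize_pi_logL, and the function name is hypothetical.

```python
import math


def em_mixture_weights(logL, penalty, max_iters=100, tol=1e-6):
    """Penalised EM for mixture weights; logL is an n x K list of log-likelihoods."""
    K = len(logL[0])
    alpha = [penalty] + [1.0] * (K - 1)   # Dirichlet pseudo-counts, spike (component 0) first
    pi = [1.0 / K] * K
    for _ in range(max_iters):
        counts = [0.0] * K
        for row in logL:
            # E-step: responsibilities w_jk proportional to pi_k * l_jk, in log space
            logs = [math.log(p) + l if p > 0 else -math.inf for p, l in zip(pi, row)]
            m = max(logs)
            w = [math.exp(t - m) for t in logs]
            z = sum(w)
            for k in range(K):
                counts[k] += w[k] / z
        # M-step: pi_k proportional to counts_k + alpha_k - 1 (clamped at zero)
        new = [max(c + a - 1.0, 0.0) for c, a in zip(counts, alpha)]
        total = sum(new)
        new = [v / total for v in new]
        change = math.sqrt(sum((u - v) ** 2 for u, v in zip(new, pi)))
        pi = new
        if change < tol:                  # L2 change in pi, as in the tol description
            break
    return pi
```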
- cebmf_torch.utils.optimize_pi_logL_lbfgs(logL, penalty, zero_threshold=1e-06)
Optimize mixture weights via L-BFGS with softmax reparameterisation.
An alternative to optimize_pi_logL() (EM) that produces sparse solutions matching R ashr’s convex optimizer (mixsqp) to 3-5 significant figures.
The simplex constraint is eliminated by optimising an unconstrained z, where x = softmax(z). L-BFGS builds a positive-definite BFGS Hessian approximation from the gradient history, avoiding the near-singular exact Hessian that arises from aliased mixture components.
The Dirichlet penalty is encoded as pseudo-observations following R ashr’s convention: the likelihood matrix is augmented with identity rows weighted by (alpha_k - 1).
- Parameters:
logL (torch.Tensor) – (n, K) tensor of log-likelihoods.
penalty (float or torch.Tensor) – Dirichlet pseudo-count, either a scalar or a (K,) tensor.
zero_threshold (float) – Components with weight below this are set to exactly zero. Default 1e-6: a component at this weight contributes at most one-millionth of the mixture density for any observation, which is negligible for posterior computation. L-BFGS drives inactive components to ~1e-30 via softmax, so any threshold between 1e-10 and 1e-3 gives the same active set.
- Returns:
(K,) tensor of optimized mixture weights on the simplex.
- Return type:
torch.Tensor
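Two ingredients described above can be checked numerically: the softmax reparameterisation and the pseudo-observation form of the Dirichlet penalty. Appending an identity row e_k with weight (alpha_k − 1) adds (alpha_k − 1)·log x_k to the objective. The augmented log-likelihood therefore equals the direct penalised one for any unconstrained z, and L-BFGS would maximise this function of z. The sketch is plain Python with hypothetical names and stops short of running an optimiser. The `threshold` helper only illustrates the zero_threshold rule; it does not renormalise, because that step is not specified here.

```python
import math


def softmax(z):
    m = max(z)
    e = [math.exp(t - m) for t in z]
    total = sum(e)
    return [t / total for t in e]


def penalised_objective(z, logL, alpha):
    """Direct form: sum_j log sum_k x_k l_jk + sum_k (alpha_k - 1) log x_k, x = softmax(z)."""
    x = softmax(z)
    data = sum(math.log(sum(xk * math.exp(l) for xk, l in zip(x, row))) for row in logL)
    prior = sum((a - 1.0) * math.log(xk) for a, xk in zip(alpha, x))
    return data + prior


def augmented_objective(z, logL, alpha):
    """Same value via pseudo-observations: identity rows e_k weighted by (alpha_k - 1)."""
    x = softmax(z)
    K = len(x)
    rows = [(1.0, [math.exp(l) for l in row]) for row in logL]
    rows += [(alpha[k] - 1.0, [1.0 if i == k else 0.0 for i in range(K)]) for k in range(K)]
    return sum(w * math.log(sum(xk * lk for xk, lk in zip(x, lik))) for w, lik in rows)


def threshold(x, zero_threshold=1e-6):
    # Components below the threshold are set to exactly zero (no renormalisation here)
    return [v if v >= zero_threshold else 0.0 for v in x]
```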
- cebmf_torch.utils.posterior_mean_exp(betahat, sebetahat, log_pi, scale)
Vectorised posterior mean and second moment for a spike+exponential mixture prior.
Replaces the previous per-observation Python loop with a single (J, K) tensor pipeline. The math is unchanged: the prior is
theta ~ pi_0 * delta_0 + sum_{k>=1} pi_k * Exp(rate=1/scale[k]),
and the likelihood is
x | theta ~ N(theta, s^2).
- Parameters:
betahat (torch.Tensor) – Observed effect-size estimates, shape (J,).
sebetahat (torch.Tensor) – Standard errors of the effect-size estimates, shape (J,).
log_pi (torch.Tensor) – Log mixture weights, shape (K,). The mixture is shared across observations.
scale (torch.Tensor) – Mixture scales, shape (K,). scale[0] should be 0 (spike) and scale[1:] > 0 are the Exp scales (rate = 1/scale).
- Returns:
Container with posterior mean, second moment, and standard deviation (each of shape (J,)).
- Return type:
PosteriorMean
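Under this prior, each exponential component's posterior is a normal N(x − s²/scale_k, s²) truncated to θ > 0. The overall posterior moments are responsibility-weighted averages of truncated-normal moments, with E[θ | k] = m + s·r and E[θ² | k] = m² + s² + m·s·r, where r = φ(m/s)/Φ(m/s). The scalar sketch below spells out that math for one observation in plain Python. The library vectorises it over a (J, K) tensor, and the function name is hypothetical.

```python
import math


def _log_norm_pdf(z):
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)


def _log_norm_cdf(z):
    return math.log(0.5 * math.erfc(-z / math.sqrt(2.0)))


def posterior_moments_exp_scalar(x, s, pi, scale):
    """Posterior E[theta] and E[theta^2] for one observation x ~ N(theta, s^2).

    Prior: pi[0] * delta_0 + sum_k pi[k] * Exp(rate=1/scale[k]), scale[0] == 0.
    All pi[k] must be > 0 in this sketch.
    """
    log_w, m1, m2 = [], [], []
    for pk, sk in zip(pi, scale):
        if sk == 0:                                  # spike at zero
            log_w.append(math.log(pk) + _log_norm_pdf(x / s) - math.log(s))
            m1.append(0.0)
            m2.append(0.0)
        else:
            lam = 1.0 / sk
            m = x - lam * s * s                      # posterior is N(m, s^2) truncated to theta > 0
            log_w.append(math.log(pk) + math.log(lam) - lam * x
                         + 0.5 * (lam * s) ** 2 + _log_norm_cdf(m / s))
            r = math.exp(_log_norm_pdf(m / s) - _log_norm_cdf(m / s))  # inverse Mills ratio
            m1.append(m + s * r)
            m2.append(m * m + s * s + m * s * r)
    top = max(log_w)
    w = [math.exp(t - top) for t in log_w]
    total = sum(w)
    mean = sum(wk * v for wk, v in zip(w, m1)) / total
    second = sum(wk * v for wk, v in zip(w, m2)) / total
    return mean, second
```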
- cebmf_torch.utils.posterior_mean_norm(betahat, sebetahat, log_pi, data_loglik, scale, location=None)[source]
Compute posterior mean and second moment for a normal mixture prior.
All mixture parameters (log_pi, scale, location) may be either shared across observations (1D, shape (K,)) or per-observation (2D, shape (J, K)). This is what makes the function usable both for classical ASH (one shared prior over a batch) and for covariate-adaptive methods like CASH/EMDN, where the neural network emits per-observation mixture parameters.
- Parameters:
betahat (torch.Tensor) – Observed effect size estimates, shape (J,).
sebetahat (torch.Tensor) – Standard errors of the effect size estimates, shape (J,).
log_pi (torch.Tensor) – Log mixture weights, shape (K,) or (J, K).
data_loglik (torch.Tensor) – Data log-likelihood matrix, shape (J, K).
scale (torch.Tensor) – Prior standard deviations, shape (K,) or (J, K). Components with scale == 0 are treated as a point mass at location.
location (torch.Tensor or None, optional) – Prior means, shape (K,) or (J, K). If None, uses zeros.
- Returns:
Container with posterior mean, second moment, and standard deviation (each of shape (J,)).
- Return type:
PosteriorMean
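As a scalar reference for one observation with a shared prior, the normal-mixture posterior follows from conjugacy. Each component N(location_k, scale_k^2) combined with the likelihood N(theta, s^2) gives a normal posterior, and the component weights are proportional to pi_k * N(x; location_k, scale_k^2 + s^2). The sketch assumes `data_loglik` holds exactly those log marginal likelihoods (the source does not say how it is built), and `posterior_norm_scalar` is a hypothetical helper.

```python
import math


def log_norm_pdf(x, loc, var):
    # log N(x; loc, var), parameterised by variance
    return -0.5 * (x - loc) ** 2 / var - 0.5 * math.log(2 * math.pi * var)


def posterior_norm_scalar(x, s, log_pi, scale, location=None):
    """Posterior (mean, second moment) of theta for ONE observation."""
    if location is None:
        location = [0.0] * len(scale)  # mirrors the "None -> zeros" default
    # Assumed form of data_loglik: log N(x; location_k, scale_k^2 + s^2).
    data_loglik = [log_norm_pdf(x, mu, sd * sd + s * s) for mu, sd in zip(location, scale)]
    log_w = [lp + ll for lp, ll in zip(log_pi, data_loglik)]
    top = max(log_w)
    w = [math.exp(v - top) for v in log_w]
    total = sum(w)
    mean = second = 0.0
    for wk, mu, sd in zip(w, location, scale):
        wk /= total
        if sd == 0:
            # Point mass at location: the component posterior is exactly mu.
            m_k, v_k = mu, 0.0
        else:
            v_k = 1.0 / (1.0 / sd**2 + 1.0 / s**2)  # posterior variance
            m_k = v_k * (mu / sd**2 + x / s**2)     # precision-weighted mean
        mean += wk * m_k
        second += wk * (m_k * m_k + v_k)
    return mean, second


print(posterior_norm_scalar(2.0, 1.0, [0.0], [1.0]))  # (1.0, 1.5)
```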
Mathematical Functions
- cebmf_torch.utils.maths.log_norm_pdf(x, loc, scale)[source]
Backward-compatible alias for _logpdf_normal() with an epsilon-padded scale. Adds a tiny epsilon to scale before the log and division to tolerate degenerate inputs (e.g., scale == 0). When the caller can guarantee a positive scale (clamping it upstream if necessary), prefer _logpdf_normal().
- Return type:
Tensor
- Parameters:
x (Tensor)
loc (Tensor)
scale (Tensor)
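The epsilon padding can be sketched in pure Python as follows. The value `eps=1e-12` and the name `log_norm_pdf_padded` are assumptions for illustration; the source does not state which epsilon the library uses.

```python
import math


def log_norm_pdf_padded(x, loc, scale, eps=1e-12):
    # eps is a hypothetical padding value, not necessarily the library's.
    s = scale + eps
    z = (x - loc) / s
    return -0.5 * z * z - math.log(s) - 0.5 * math.log(2 * math.pi)


print(log_norm_pdf_padded(0.0, 0.0, 1.0))  # about -0.9189, i.e. -0.5 * log(2*pi)
print(log_norm_pdf_padded(0.0, 0.0, 0.0))  # finite (about 26.71) instead of an error
```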
- cebmf_torch.utils.maths.norm_cdf(x)[source]
Compute the standard normal cumulative distribution function (CDF).
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
CDF evaluated at x.
- Return type:
torch.Tensor
- cebmf_torch.utils.maths.norm_pdf(x)[source]
Compute the standard normal probability density function (PDF).
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
PDF evaluated at x.
- Return type:
torch.Tensor
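Both functions have closed forms: Phi(x) = (1 + erf(x / sqrt(2))) / 2 and phi(x) = exp(-x^2 / 2) / sqrt(2 pi). A scalar sketch using Python's `math` module follows; the library's own versions operate on tensors and may differ in detail.

```python
import math


def norm_cdf(x):
    # Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def norm_pdf(x):
    # phi(x) = exp(-x^2 / 2) / sqrt(2 * pi)
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


print(norm_cdf(0.0))   # 0.5
print(norm_cdf(1.96))  # about 0.975
print(norm_pdf(0.0))   # about 0.3989
```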
- cebmf_torch.utils.maths.logsumexp(x, dim=-1, keepdim=False)[source]
Compute the log of the sum of exponentials of input elements along a given dimension.
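The usual way to make this stable is to factor out the maximum: log sum_i exp(x_i) = m + log sum_i exp(x_i - m) with m = max_i x_i. A scalar sketch over a Python list (the tensor version reduces along `dim` instead):

```python
import math


def logsumexp(xs):
    # Shift by the max so the largest term becomes exp(0) = 1 and nothing overflows.
    m = max(xs)
    if math.isinf(m):
        return m  # all entries -inf, or some entry +inf: the result is m itself
    return m + math.log(sum(math.exp(x - m) for x in xs))


print(logsumexp([1000.0, 1000.0]))  # about 1000.6931; naive math.exp(1000.0) overflows
```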
- cebmf_torch.utils.maths.safe_log(x, eps=1e-12)[source]
Compute the logarithm of x with clamping for numerical stability.
- Parameters:
x (torch.Tensor) – Input tensor.
eps (float, optional) – Minimum value to clamp x to. Default is 1e-12.
- Returns:
Logarithm of clamped x.
- Return type:
torch.Tensor
- cebmf_torch.utils.maths.softmax(x, dim=-1)[source]
Compute the softmax of input tensor along the specified dimension.
- Parameters:
x (torch.Tensor) – Input tensor.
dim (int, optional) – Dimension along which softmax will be computed. Default is -1.
- Returns:
Softmax of the input tensor.
- Return type:
torch.Tensor
- cebmf_torch.utils.maths.logphi(z)[source]
Compute the log of the standard normal PDF φ(z) = exp(-z^2/2)/√(2π).
- Parameters:
z (torch.Tensor) – Input tensor.
- Returns:
Log PDF evaluated at z.
- Return type:
torch.Tensor
- cebmf_torch.utils.maths.logPhi(z)[source]
Compute the stable log CDF of the standard normal distribution Φ(z).
- Parameters:
z (torch.Tensor) – Input tensor.
- Returns:
Log CDF evaluated at z.
- Return type:
torch.Tensor
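The hard case for a log CDF is the far left tail, where Phi(z) underflows to 0 in double precision (below roughly z = -38) and a naive log(Phi(z)) breaks down. One standard approach, sketched below in pure Python, switches to the asymptotic Mills-ratio expansion log Phi(z) ≈ log phi(z) - log(-z) + log(1 - 1/z^2 + 3/z^4 - 15/z^6) for very negative z. The cutoff of -20 is an arbitrary choice for this sketch, and the library may use a different method.

```python
import math

_LOG_SQRT_2PI = 0.5 * math.log(2.0 * math.pi)


def log_Phi(z, cutoff=-20.0):
    if z > cutoff:
        # erfc keeps relative accuracy in the left tail, unlike 1 + erf(...)
        return math.log(0.5 * math.erfc(-z / math.sqrt(2.0)))
    # Asymptotic series: Phi(z) ~ phi(z) / (-z) * (1 - 1/z^2 + 3/z^4 - 15/z^6)
    z2 = z * z
    log_phi = -0.5 * z2 - _LOG_SQRT_2PI
    return log_phi - math.log(-z) + math.log1p(-1.0 / z2 + 3.0 / z2**2 - 15.0 / z2**3)


print(log_Phi(0.0))    # log(0.5), about -0.6931
print(log_Phi(-50.0))  # about -1254.83, although Phi(-50) itself underflows
```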
- cebmf_torch.utils.maths.logscale_sub(logx, logy)[source]
Compute log(exp(logx) - exp(logy)) in a numerically stable way.
Requires logx >= logy.
- Parameters:
logx (torch.Tensor) – Logarithm of x.
logy (torch.Tensor) – Logarithm of y.
- Returns:
Logarithm of (exp(logx) - exp(logy)).
- Return type:
torch.Tensor
- cebmf_torch.utils.maths.logscale_add(logx, logy)[source]
Compute log(exp(logx) + exp(logy)) in a numerically stable way.
- Parameters:
logx (torch.Tensor) – Logarithm of x.
logy (torch.Tensor) – Logarithm of y.
- Returns:
Logarithm of (exp(logx) + exp(logy)).
- Return type:
torch.Tensor
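Both helpers follow from factoring out the larger exponent: log(e^a + e^b) = max(a, b) + log1p(e^{-|a - b|}) and log(e^a - e^b) = a + log1p(-e^{b - a}) for a >= b. A pure-Python sketch; raising an error when logx < logy is a choice made here, since the source only says the condition is required.

```python
import math


def logscale_add(logx, logy):
    # log(e^a + e^b) = hi + log1p(e^(lo - hi))
    hi, lo = max(logx, logy), min(logx, logy)
    if lo == float("-inf"):
        return hi
    return hi + math.log1p(math.exp(lo - hi))


def logscale_sub(logx, logy):
    # log(e^a - e^b) = a + log1p(-e^(b - a)); requires logx >= logy
    if logy > logx:
        raise ValueError("logscale_sub requires logx >= logy")
    if logx == logy:
        return float("-inf")  # exp(a) - exp(b) == 0
    if logy == float("-inf"):
        return logx
    return logx + math.log1p(-math.exp(logy - logx))


print(logscale_add(1000.0, 1000.0))                 # about 1000.6931
print(logscale_sub(math.log(5.0), math.log(3.0)))   # about log(2) = 0.6931
```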
- cebmf_torch.utils.maths.do_truncnorm_argchecks(a, b)[source]
Clamp and sanity-check the bounds for truncated normal arguments.
- Parameters:
a (torch.Tensor) – Lower bound(s).
b (torch.Tensor) – Upper bound(s).
- Returns:
(a, b) after checks.
- Return type:
tuple of torch.Tensor
- cebmf_torch.utils.maths.safe_tensor_to_float(value, null_value=-inf, reduction='min')[source]
Convert tensor, float, or None to float with safe handling.
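- Parameters:
value (torch.Tensor, float, or None) – Value to convert.
null_value (float, optional) – Default is -inf.
reduction (str, optional) – Default is 'min'.
- Returns:
Converted float value.
- Return type:
float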
- cebmf_torch.utils.maths.my_etruncnorm(a, b, mean=0.0, sd=1.0, precision='auto')[source]
Compute E[Z | a < Z < b] for Z ~ N(mean, sd^2), the mean of a truncated normal.
- Parameters:
a (float or torch.Tensor) – Lower truncation bound.
b (float or torch.Tensor) – Upper truncation bound.
mean (float or torch.Tensor, optional) – Mean of the normal distribution. Default is 0.0.
sd (float or torch.Tensor, optional) – Standard deviation of the normal distribution. Default is 1.0.
precision (str, optional) – Internal compute dtype. "auto" (default) follows the input dtype (typically float32), which is fast on CUDA; "float64" forces double precision, which is slower on CUDA but matches the pre-2026 behaviour.
- Returns:
Mean of the truncated normal distribution. Dtype matches the chosen precision.
- Return type:
torch.Tensor
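The quantity has a closed form. With alpha = (a - mean) / sd and beta = (b - mean) / sd, E[Z | a < Z < b] = mean + sd * (phi(alpha) - phi(beta)) / (Phi(beta) - Phi(alpha)). A naive pure-Python sketch follows. It does not model the `precision` argument (Python floats are already double precision) and it loses accuracy when both bounds lie far in the same tail, where the denominator suffers cancellation.

```python
import math


def _phi(t):
    return math.exp(-0.5 * t * t) / math.sqrt(2.0 * math.pi)


def _Phi(t):
    return 0.5 * math.erfc(-t / math.sqrt(2.0))


def etruncnorm(a, b, mean=0.0, sd=1.0):
    # Standardise the bounds; infinite bounds give phi -> 0 and Phi -> 0 or 1.
    alpha = (a - mean) / sd
    beta = (b - mean) / sd
    mass = _Phi(beta) - _Phi(alpha)
    return mean + sd * (_phi(alpha) - _phi(beta)) / mass


print(etruncnorm(0.0, math.inf))  # half-normal mean sqrt(2/pi), about 0.7979
print(etruncnorm(-1.0, 1.0))      # symmetric window: 0.0
```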
- cebmf_torch.utils.maths.my_e2truncnorm(a, b, mean=0.0, sd=1.0, precision='auto')[source]
Compute E[Z^2 | a < Z < b] for Z ~ N(mean, sd^2), the second moment of a truncated normal.
- Parameters:
a (float or torch.Tensor) – Lower truncation bound.
b (float or torch.Tensor) – Upper truncation bound.
mean (float or torch.Tensor, optional) – Mean of the normal distribution. Default is 0.0.
sd (float or torch.Tensor, optional) – Standard deviation of the normal distribution. Default is 1.0.
precision (str, optional) – Internal compute dtype. See my_etruncnorm() for the contract. "auto" (default) keeps the caller’s dtype (fast on CUDA); "float64" matches the pre-2026 behaviour.
- Returns:
Second moment of the truncated normal distribution. Dtype matches the chosen precision.
- Return type:
torch.Tensor
Device Management
- cebmf_torch.utils.device.get_device(prefer_gpu=True)[source]
Get the best available device.
Priority order:
1. CUDA (NVIDIA GPUs)
2. MPS (Apple Silicon GPUs)
3. CPU (fallback)
- Parameters:
prefer_gpu (bool) – Whether to prefer GPU over CPU.
- Returns:
The selected device.
- Return type:
torch.device
Prior Registry
Prior registry and prior classes for EBMF.