# Source code for cebmf_torch.cebmf.cebmf

import math
from dataclasses import dataclass
from enum import StrEnum, auto
from warnings import warn

import torch
from torch import Tensor

from cebmf_torch.cebmf._initialisation import INIT_STRATEGIES, user_provided_factors
from cebmf_torch.priors import PRIOR_REGISTRY
from cebmf_torch.utils.device import get_device
from cebmf_torch.utils.maths import safe_tensor_to_float

# Small positive floor applied (via ``clamp_min``) to denominators and variances
# so that divisions and logarithms in this module never see an exact zero.
NUMERICAL_EPS = 1e-12
# Default π0 threshold: a factor whose π0 is at or above this value is pruned.
DEFAULT_PRUNE_THRESH = 1 - 1e-3


@dataclass
class CEBMFResult:
    """
    Container for the output of :meth:`cEBMF.fit`.

    Attributes
    ----------
    L : Tensor
        Left (row) factor matrix of posterior means, shape ``(N, K)``.
    F : Tensor
        Right (column) factor matrix of posterior means, shape ``(P, K)``.
    tau : Tensor
        Noise precision(s): a 0-d tensor for constant noise, otherwise an
        ``(N, P)`` precision map (row-wise, column-wise or user-supplied).
    history_obj : list
        Objective (negative ELBO) recorded after each iteration. The list is
        restarted whenever factors are pruned.
    """

    L: Tensor
    F: Tensor
    tau: Tensor
    history_obj: list


class NoiseType(StrEnum):
    CONSTANT = auto()
    ROW_WISE = auto()
    COLUMN_WISE = auto()
    KNOWN = auto()  # user-supplied standard errors; variance is fixed (not learned)


@dataclass
class ModelParams:
    """
    Factor-model settings held by :class:`cEBMF` as ``self.model``.

    ``K`` is mutable state: it is decreased when factors are pruned.
    """

    # Current number of factors (columns of L and F).
    K: int = 5
    # Name of the prior looked up in PRIOR_REGISTRY for the row factors.
    prior_L: str = "norm"
    # Name of the prior looked up in PRIOR_REGISTRY for the column factors.
    prior_F: str = "norm"
    # When True (and K > 1), factors may be pruned after each iteration.
    allow_backfitting: bool = True
    # π0 threshold at or above which a factor is pruned.
    prune_thresh: float = DEFAULT_PRUNE_THRESH


@dataclass
class NoiseParams:
    """Noise-model settings held by :class:`cEBMF` as ``self.noise``."""

    # Which noise structure is in force (constant, row-wise, column-wise or known).
    type: NoiseType = NoiseType.CONSTANT


@dataclass
class CovariateParams:
    """Covariate settings held by :class:`cEBMF` as ``self.covariate``."""

    # External covariates passed to the row-factor prior (or None).
    X_l: Tensor | None = None
    # External covariates passed to the column-factor prior (or None).
    X_f: Tensor | None = None
    # If True, the other columns of L are appended as covariates when fitting L[:, k].
    self_row_cov: bool = False
    # If True, the other columns of F are appended as covariates when fitting F[:, k].
    self_col_cov: bool = False


[docs] class cEBMF: """ Pure-PyTorch Empirical Bayes Matrix Factorization (EBMF) with NaN handling. Features -------- - Observed-mask weighting in lhat/fhat and their standard errors. - Constant or structured noise precision (scalar or per-row/column tau). - User-supplied (fixed) standard errors via ``S`` — useful for z-scores (pass ``S=1.0``) or pre-computed standard errors of effect-size estimates (pass an ``(N, P)`` tensor). When ``S`` is provided, the noise variance is treated as known and is *not* re-estimated during fitting. - Mini-batch optimization for mixture weights inside ash(). - Modular prior and covariate support. """ def __init__( self, data: torch.Tensor, K: int = 5, prior_L: str = "norm", prior_F: str = "norm", internal_epoch: int = 10, prior_L_kwargs: dict | None = None, prior_F_kwargs: dict | None = None, allow_backfitting: bool = True, prune_thresh: float = DEFAULT_PRUNE_THRESH, noise_type: NoiseType = NoiseType.CONSTANT, S: torch.Tensor | float | int | None = None, X_l: torch.Tensor | None = None, X_f: torch.Tensor | None = None, self_row_cov: bool = False, self_col_cov: bool = False, device: torch.device | None = None, ): """ Parameters ---------- data : torch.Tensor Observed data matrix of shape (N, P). NaN entries are treated as missing. K : int, optional Initial number of factors. Default 5. prior_L, prior_F : str, optional Prior names to use for the row/column factors. internal_epoch : int, optional Number of inner epochs for the prior fitting routine. prior_L_kwargs, prior_F_kwargs : dict or None, optional Extra keyword arguments forwarded to the prior builders. allow_backfitting : bool, optional If True, allow factor pruning between iterations. prune_thresh : float, optional π0 threshold above which a factor is pruned. noise_type : NoiseType, optional Structure of the (learned) noise variance. Ignored if ``S`` is provided. S : torch.Tensor, float, int, or None, optional User-supplied standard errors. 
When provided, the noise variance is taken as fixed (not estimated) and ``noise_type`` is forced to :attr:`NoiseType.KNOWN`. ``S`` may be: * a scalar — typical for z-scores (``S=1.0``) — applied to every entry; * a tensor broadcastable to ``data.shape`` — for example, ``(N, P)`` if every observation has its own standard error, or ``(P,)`` / ``(N, 1)`` for column- or row-wise known SEs. Entries of ``S`` that are NaN, infinite, or non-positive are folded into the missing-data mask (those observations are dropped from the likelihood). A warning is raised if any such entries align with observed values in ``data``. X_l, X_f : torch.Tensor or None, optional External covariates for the row/column factors. self_row_cov, self_col_cov : bool, optional Whether to use other factors as self-covariates. device : torch.device or None, optional Target device. Defaults to the result of :func:`get_device`. """ self.data = data self.device = device or get_device() # If the user supplied S, the noise variance is fixed and any explicit # ``noise_type`` other than the default CONSTANT is silently overridden. 
if S is not None: if noise_type != NoiseType.CONSTANT: warn( f"`S` was provided so the noise variance is treated as known/fixed; " f"the requested noise_type={noise_type!r} will be ignored.", stacklevel=2, ) noise_type = NoiseType.KNOWN # Build config objects internally self.model = ModelParams( K=K, prior_L=prior_L, prior_F=prior_F, allow_backfitting=allow_backfitting, prune_thresh=prune_thresh ) self.noise = NoiseParams(type=noise_type) # Move covariates to device (if provided) to avoid later CPU↔GPU hops self.covariate = CovariateParams( X_l=(X_l.to(self.device) if X_l is not None else None), X_f=(X_f.to(self.device) if X_f is not None else None), self_row_cov=self_row_cov, self_col_cov=self_col_cov, ) if prior_L_kwargs is None: prior_L_kwargs = {} if prior_F_kwargs is None: prior_F_kwargs = {} # Stash raw S input; normalised to an (N, P) tensor inside _initialise_tensors self._S_input = S self._validate_inputs() self.Y = self.data.to(self.device).float() self.N, self.P = self.Y.shape self._initialise_priors(prior_L_kwargs=prior_L_kwargs, prior_F_kwargs=prior_F_kwargs) self._initialise_tensors() self.internal_epoch = internal_epoch self._factors_initialised = False
[docs] @torch.no_grad() def fit(self, maxit: int = 50): """ Fit the cEBMF model for a specified number of iterations. Parameters ---------- maxit : int, optional Number of iterations to run. Default is 50. Returns ------- CEBMFResult Result container with fitted factors, noise, and objective history. """ if not self._factors_initialised: warn("Factors not initialized; using SVD initialization.", stacklevel=2) self.initialise_factors() for _ in range(maxit): self.iter_once() return CEBMFResult(self.L, self.F, self.tau, self.obj)
[docs] @torch.no_grad() def initialise_factors(self, method: str = "svd", *, L: Tensor | None = None, F: Tensor | None = None): """ Initialize factor matrices using the specified method, or user-provided initial factors. Parameters ---------- method : str Initialization method ('svd', 'random', or 'zero'). Default is 'svd'. Ignored if L and F are provided. L : Tensor or None, optional User-provided initial factor matrix (N, K). Ignored if F not also provided. F : Tensor or None, optional User-provided initial factor matrix (P, K). Ignored if L not also provided. """ def _use_strategy(method: str): initialise_fn = INIT_STRATEGIES[method] self.L, self.F = initialise_fn(self.Y, self.N, self.P, self.model.K, self.device) if L is None and F is not None: warn("Provided F without L; ignoring F and using svd for initialization.", stacklevel=2) _use_strategy("svd") elif L is not None and F is None: warn("Provided L without F; ignoring L and using svd for initialization.", stacklevel=2) _use_strategy("svd") elif L is not None and F is not None: self.L, self.F = user_provided_factors(L, F, self.N, self.P, self.model.K, self.device) elif method not in INIT_STRATEGIES: raise ValueError(f"Unknown initialization method '{method}'. Available: {list(INIT_STRATEGIES.keys())}") else: _use_strategy(method) self.L2 = self.L * self.L self.F2 = self.F * self.F self.R = self.Y0 - self.L @ self.F.T self.R.mul_(self.mask) self.R.nan_to_num_(nan=0.0) self.update_tau() self._factors_initialised = True
[docs] @torch.no_grad() def iter_once(self): """ Perform one iteration of the cEBMF update (update all factors and noise). """ tau_map = None if self.noise.type == NoiseType.CONSTANT else self.tau_map for k in range(self.model.K): self._update_factors(k, tau_map=tau_map, eps=NUMERICAL_EPS) self.update_tau() self._backfit() self._cal_obj()
[docs] @torch.no_grad() def update_tau(self): """ Update the noise precision parameter(s) according to the noise model. Behaviour by noise type: - ``CONSTANT`` -> scalar tau; also provides tau_map (N,P) if you need it - ``ROW_WISE`` -> tau_row (N,), tau_map broadcast to (N,P) - ``COLUMN_WISE`` -> tau_col (P,), tau_map broadcast to (N,P) - ``KNOWN`` -> no-op; tau_map was built once from the user-supplied ``S``. """ if self.noise.type == NoiseType.KNOWN: # Variance is fixed by the user; nothing to update. return R2 = self._expected_residuals_squared() # (N,P), zeros at missing match self.noise.type: case NoiseType.CONSTANT: dim = None case NoiseType.COLUMN_WISE: dim = 0 case NoiseType.ROW_WISE: dim = 1 case _: raise ValueError("type_noise must be 'constant', 'row_wise', 'column_wise', or 'known'") self._update_tau(R2, dim=dim)
# ========================================================================= # Private Methods - Internal Implementation Details # ========================================================================= @torch.no_grad() def _update_factors(self, k: int, tau_map: Tensor | None = None, eps: float = NUMERICAL_EPS) -> None: """Update L[:,k] then F[:,k] using stable, non-inplace residualization.""" # Full residual with current factors (exclude all k implicitly) self._recompute_residual() # --- L update uses Rk = R + L_k F_k^T (add back k's current contrib) Rk = self.R + torch.outer(self.L[:, k], self.F[:, k]) Rk.mul_(self.mask) self._update_L_factor(k, Rk, tau_map, eps) # After updating L, refresh residual with new L_k self._recompute_residual() # --- F update with updated L: again build Rk Rk = self.R + torch.outer(self.L[:, k], self.F[:, k]) Rk.mul_(self.mask) self._update_F_factor(k, Rk, tau_map, eps) # Final residual for next factor self._recompute_residual() @torch.no_grad() def _update_L_factor(self, k: int, Rk: Tensor, tau_map: Tensor | None, eps: float) -> None: """Update L[:,k] and L2[:,k] using provided Rk (residual with k added back).""" mask_f = self.mask if self.mask.dtype.is_floating_point else self.mask.to(self.L.dtype) Fk = self.F[:, k] Fk2 = self.F2[:, k] if tau_map is None: denom_l = (mask_f @ Fk2).clamp_min(eps) # (N,) num_l = Rk @ Fk # (N,) se_l = torch.sqrt(1.0 / (self.tau * denom_l)) else: W = tau_map * mask_f denom_l = (W @ Fk2).clamp_min(eps) # (N,) num_l = (W * Rk) @ Fk # (N,) se_l = torch.sqrt(1.0 / denom_l) lhat = num_l / denom_l # fit prior for L X_model = self._build_covariate_matrix( external_cov=self.covariate.X_l, self_cov_enabled=self.covariate.self_row_cov, factors=self.L, k=k, dim_size=self.N, ) with torch.enable_grad(): resL = self.prior_L_fn.fit( X=X_model, betahat=lhat, internal_epoch=self.internal_epoch, sebetahat=se_l, model_param=self.model_state_L[k], device=self.device, ) # write back self.model_state_L[k] = resL.model_param 
self.L[:, k] = resL.post_mean self.L2[:, k] = resL.post_mean2 nm_ll_L = normal_means_loglik(x=lhat, s=se_l, Et=resL.post_mean, Et2=resL.post_mean2) self.kl_l[k] = torch.as_tensor((-resL.loss) - nm_ll_L, device=self.device, dtype=self.L.dtype) self.pi0_L[k] = resL.pi0_null @torch.no_grad() def _update_F_factor(self, k: int, Rk: Tensor, tau_map: Tensor | None, eps: float) -> None: """Update F[:,k] and F2[:,k] using provided Rk (residual with k added back).""" mask_f = self.mask if self.mask.dtype.is_floating_point else self.mask.to(self.L.dtype) Lk = self.L[:, k] Lk2 = self.L2[:, k] if tau_map is None: denom_f = (mask_f.T @ Lk2).clamp_min(eps) # (P,) num_f = Rk.T @ Lk # (P,) se_f = torch.sqrt(1.0 / (self.tau * denom_f)) else: W = tau_map * mask_f denom_f = (W.T @ Lk2).clamp_min(eps) num_f = (W * Rk).transpose(0, 1) @ Lk se_f = torch.sqrt(1.0 / denom_f) fhat = num_f / denom_f # fit prior for F X_model = self._build_covariate_matrix( external_cov=self.covariate.X_f, self_cov_enabled=self.covariate.self_col_cov, factors=self.F, k=k, dim_size=self.P, ) with torch.enable_grad(): resF = self.prior_F_fn.fit( X=X_model, betahat=fhat, sebetahat=se_f, internal_epoch=self.internal_epoch, model_param=self.model_state_F[k], device=self.device, ) # write back self.model_state_F[k] = resF.model_param self.F[:, k] = resF.post_mean self.F2[:, k] = resF.post_mean2 nm_ll_F = normal_means_loglik(x=fhat, s=se_f, Et=resF.post_mean, Et2=resF.post_mean2) self.kl_f[k] = torch.as_tensor((-resF.loss) - nm_ll_F, device=self.device, dtype=self.F.dtype) self.pi0_F[k] = resF.pi0_null @torch.no_grad() def _cal_obj(self): # Data term ER2 = self._expected_residuals_squared() if self.noise.type == NoiseType.CONSTANT: ll = self._compute_constant_loglik(ER2) else: ll = self._compute_elementwise_loglik(ER2) KL = self.kl_l.sum() + self.kl_f.sum() loss = (-ll + KL).item() # minimize this (negative ELBO) self.obj.append(loss) @torch.no_grad() def _compute_constant_loglik(self, ER2: Tensor) -> Tensor: m = 
self.mask.sum().clamp_min(1.0) c2pi = ER2.new_tensor(math.log(2.0 * math.pi)) # tau is scalar-precision in this branch return -0.5 * (m * (c2pi - torch.log(self.tau)) + self.tau * ER2.sum()) @torch.no_grad() def _compute_elementwise_loglik(self, ER2: Tensor) -> Tensor: obs = self.mask.bool() c2pi = ER2.new_tensor(math.log(2.0 * math.pi)) return -0.5 * (c2pi * obs.sum() - torch.log(self.tau_map[obs]).sum() + (self.tau_map * ER2)[obs].sum()) @torch.no_grad() def _backfit(self): if not (self.model.allow_backfitting and self.model.K > 1): return to_drop = [k for k in range(self.model.K) if self._should_prune_factor(k)] if len(to_drop) >= self.model.K: keep_one = min(to_drop) to_drop = [k for k in range(self.model.K) if k != keep_one] # drop highest indices first to avoid reindex churn to_drop_sorted = sorted(to_drop, reverse=True) self._prune_indices(to_drop_sorted) @torch.no_grad() def _update_fitted_value(self): self.Y_fit = self.L @ self.F.T @torch.no_grad() def _expected_residuals_squared(self): """ E[(Y - sum_k L_k F_k)^2] on observed entries. Uses: (Y - E[Y])^2 - sum_k (E[L]^2)(E[F]^2)^T + sum_k E[L^2] E[F^2]^T """ Yfit = self.L @ self.F.T # (N,P) resid_mean_sq = (self.Y0 - Yfit).pow(2) # (N,P) first_moment_sq = (self.L.pow(2)) @ (self.F.pow(2)).T # Σ_k (E[L]^2)(E[F]^2)^T second_moment = self.L2 @ self.F2.T # Σ_k E[L^2] E[F^2]^T # R2 = resid_mean_sq - first_moment_sq + second_moment # R2 = (R2 * self.mask).clamp_min(0.0) # zero where missing; no negatives variance_term = second_moment - first_moment_sq R2 = resid_mean_sq + variance_term # NOT minus! R2 = (R2 * self.mask).clamp_min(0.0) return R2 @torch.no_grad() def _validate_inputs(self) -> None: if self.model.K < 1: raise ValueError(f"K must be >= 1, got {self.model.K}") if torch.isnan(self.data).all(): raise ValueError("Data cannot be all NaN") # More validation... 
@torch.no_grad() def _initialise_priors(self, prior_L_kwargs: dict, prior_F_kwargs: dict) -> None: self.prior_L_fn = PRIOR_REGISTRY.get_builder(self.model.prior_L) self.prior_L_fn.set_kwargs(**prior_L_kwargs) self.prior_F_fn = PRIOR_REGISTRY.get_builder(self.model.prior_F) self.prior_F_fn.set_kwargs(**prior_F_kwargs) self.model_state_L = [None] * self.model.K self.model_state_F = [None] * self.model.K @torch.no_grad() def _initialise_tensors(self): self.mask = (~torch.isnan(self.Y)).float() # 1 where observed, 0 where NaN self.Y0 = torch.nan_to_num(self.Y, nan=0.0) # zeros where missing self.L = torch.zeros(self.N, self.model.K, device=self.device) self.L2 = torch.zeros(self.N, self.model.K, device=self.device) self.F = torch.zeros(self.P, self.model.K, device=self.device) self.F2 = torch.zeros(self.P, self.model.K, device=self.device) if self.noise.type == NoiseType.KNOWN: # Build tau_map from user-supplied S; may further reduce self.mask / self.Y0. self._setup_known_variance() else: # Initial precision guess for learned-noise modes; refined by update_tau(). self.S = None self.tau = torch.tensor(1.0, device=self.device) self.kl_l = torch.zeros(self.model.K, device=self.device) self.kl_f = torch.zeros(self.model.K, device=self.device) self.pi0_L: list[Tensor | float | None] = [ None ] * self.model.K # store latest pi0 for L[:,k]; scalar or Tensor or None self.pi0_F: list[Tensor | float | None] = [None] * self.model.K self.obj = [] @torch.no_grad() def _setup_known_variance(self) -> None: """ Materialise the user-supplied standard-error matrix into ``self.S`` and the corresponding precision matrix ``self.tau_map = 1 / S**2``. Accepts a scalar (broadcast to ``(N, P)``) or any tensor broadcastable to ``(N, P)``. Entries of S that are NaN, infinite, or non-positive are folded into ``self.mask`` (treated as missing) and replaced with 1 internally so downstream multiplications by ``mask=0`` cleanly zero out their contribution. 
""" S_in = self._S_input target_shape = (self.N, self.P) target_dtype = self.Y.dtype # --- Normalise S to an (N, P) tensor on the right device/dtype. if isinstance(S_in, (int, float)): S_val = float(S_in) if not math.isfinite(S_val) or S_val <= 0.0: raise ValueError(f"Scalar S must be a positive, finite number, got {S_in!r}") S = torch.full(target_shape, S_val, device=self.device, dtype=target_dtype) else: if not isinstance(S_in, torch.Tensor): S = torch.as_tensor(S_in, dtype=target_dtype, device=self.device) else: S = S_in.to(device=self.device, dtype=target_dtype) if S.shape != target_shape: try: S = S.expand(target_shape).contiguous() except RuntimeError as e: raise ValueError( f"S has shape {tuple(S.shape)}, which is not broadcastable to data shape {target_shape}" ) from e else: S = S.contiguous() # --- Fold non-finite / non-positive entries into the mask. bad = ~torch.isfinite(S) | (S <= 0) if bool(bad.any()): # Only warn about bad-S entries that align with currently-observed data: # those are the ones that actually change the likelihood. observed_and_bad = bad & self.mask.bool() n_observed_bad = int(observed_and_bad.sum().item()) if n_observed_bad > 0: warn( f"S has {n_observed_bad} non-finite or non-positive entries " f"at observed positions in `data`; these entries will be treated " f"as missing.", stacklevel=3, ) # Drop them from the mask and zero the corresponding Y0 entries. good_f = (~bad).to(self.mask.dtype) self.mask = self.mask * good_f self.Y0 = self.Y0 * self.mask # Replace bad S entries with 1.0 internally (any finite positive value works # because mask=0 there annihilates their contribution). S = torch.where(bad, torch.ones_like(S), S) self.S = S self.tau_map = 1.0 / (S * S) # Downstream code (e.g. _update_factors elementwise branch, _cal_obj # elementwise branch) reads self.tau_map; set self.tau to the same tensor # for consistency with the structured-noise convention. 
self.tau = self.tau_map @torch.no_grad() def _update_tau(self, R2: Tensor, dim: None | int) -> None: if dim not in (None, 0, 1): raise ValueError("dim must be None, 0, or 1") m = self.mask.sum(dim=dim).clamp_min(1.0) mean_R2 = R2.sum(dim=dim) / m tau = 1.0 / (mean_R2.clamp_min(NUMERICAL_EPS)) if dim is None: # Scalar precision. Keep `self.tau` as a 0-d tensor (the loglik branch # for CONSTANT noise uses it directly) and expose a broadcast-only # tau_map for convenience. `expand` shares storage and avoids both # the per-iteration host sync from `tau.item()` and the (N, P) # materialisation that the previous `torch.full` triggered. self.tau = tau # 0-d tensor self.tau_map = tau.view(1, 1).expand(self.N, self.P) # (N, P) view, no copy return view = (-1, 1) if dim == 1 else (1, -1) self.tau_map = tau.view(*view).expand(self.N, self.P) # (N,P) self.tau = self.tau_map # downstream uses elementwise in structured noise @torch.no_grad() def _partial_residual_masked(self, k: int) -> Tensor: # Rk for observed entries only recon = self.L @ self.F.T k_contrib = torch.outer(self.L[:, k], self.F[:, k]) Rk = (self.Y0 - (recon - k_contrib)) * self.mask return Rk @torch.no_grad() def _should_prune_factor(self, k: int) -> bool: """ Remove factor k if we have π₀ info and the smallest π₀ across coordinates (for any side that provided it) is ≥ thresh. (Your spec: use the *lowest* π₀.) """ pi0_min_L = safe_tensor_to_float(self.pi0_L[k]) pi0_min_F = safe_tensor_to_float(self.pi0_F[k]) # if neither side provided π0, don't prune if pi0_min_L == float("-inf") and pi0_min_F == float("-inf"): return False # If either side indicates "all near spike", prune. 
thresh = self.model.prune_thresh return (pi0_min_L >= thresh) or (pi0_min_F >= thresh) @torch.no_grad() def _prune_indices(self, idxs: list[int]) -> None: """In-place prune of K and all factor-aligned structures.""" if not idxs: return keep = [i for i in range(self.model.K) if i not in idxs] self.L = self.L[:, keep] self.L2 = self.L2[:, keep] self.F = self.F[:, keep] self.F2 = self.F2[:, keep] self.kl_l = self.kl_l[keep] self.kl_f = self.kl_f[keep] self.model_state_L = [self.model_state_L[i] for i in keep] self.model_state_F = [self.model_state_F[i] for i in keep] self.pi0_L = [self.pi0_L[i] for i in keep] self.pi0_F = [self.pi0_F[i] for i in keep] self.model.K = len(keep) self.obj = [] @torch.no_grad() def _build_covariate_matrix( self, external_cov: Tensor | None, self_cov_enabled: bool, factors: Tensor, k: int, dim_size: int ) -> Tensor | None: """Build covariate matrix combining external and self-covariates.""" if external_cov is not None and external_cov.device != self.device: external_cov = external_cov.to(self.device) if not self_cov_enabled: return external_cov # Get other factors (excluding k) if self.model.K > 1: others = factors[:, torch.arange(self.model.K, device=self.device) != k] if external_cov is None: return others return torch.hstack((external_cov, others)) # K=1 case: return external covariates or intercept return external_cov if external_cov is not None else factors.new_ones(dim_size, 1) @torch.no_grad() def _recompute_residual(self) -> None: """R := (Y0 - L F^T) ⊙ mask, NaNs -> 0.""" self.R = self.Y0 - self.L @ self.F.T self.R.mul_(self.mask) self.R.nan_to_num_(nan=0.0)
def normal_means_loglik(
    x: torch.Tensor,
    s: torch.Tensor,
    Et: torch.Tensor,
    Et2: torch.Tensor,
    mask: torch.Tensor | None = None,
    reduce: str = "sum",
    eps: float = NUMERICAL_EPS,
) -> torch.Tensor:
    """
    Compute the expected normal-means log-likelihood.

    Evaluates ``E_q[ log N(x | theta, s^2) ]`` where ``q`` is summarised by its
    first and second moments ``Et`` and ``Et2``.

    Parameters
    ----------
    x : torch.Tensor
        Observed data (broadcastable with s, Et, Et2).
    s : torch.Tensor
        Standard errors (broadcastable).
    Et : torch.Tensor
        Posterior means (broadcastable).
    Et2 : torch.Tensor
        Posterior second moments (broadcastable).
    mask : torch.Tensor or None, optional
        Optional bool mask; True = include entry.
    reduce : str, optional
        'sum' (default), 'mean', or 'none' (per-element with NaNs for excluded).
    eps : float, optional
        Numerical floor for variance.

    Returns
    -------
    torch.Tensor
        Scalar if reduce in {'sum','mean'}, else elementwise tensor.

    Raises
    ------
    ValueError
        If ``reduce`` is not 'sum', 'mean', or 'none'.
    """
    # Bring all four inputs to a common shape.
    x, s, Et, Et2 = torch.broadcast_tensors(x, s, Et, Et2)

    # An entry counts only if every input is finite and the standard error is positive.
    valid = torch.isfinite(x) & torch.isfinite(s) & torch.isfinite(Et) & torch.isfinite(Et2) & (s > 0)
    if mask is not None:
        valid = valid & mask.to(dtype=torch.bool, device=x.device)

    def _neutralise(t: torch.Tensor, fill: float) -> torch.Tensor:
        # Swap excluded entries for a harmless value so log/division stay finite.
        return torch.where(valid, t, torch.full_like(t, fill))

    s_ok = _neutralise(s, 1.0)
    x_ok = _neutralise(x, 0.0)
    m1_ok = _neutralise(Et, 0.0)
    m2_ok = _neutralise(Et2, 0.0)

    variance = (s_ok * s_ok).clamp_min(eps)
    log_2pi = x.new_tensor(math.log(2.0 * math.pi))  # stays on same device/dtype

    # E[(x - theta)^2] = Et2 - 2*x*Et + x^2
    expected_sq_err = m2_ok - 2.0 * x_ok * m1_ok + x_ok * x_ok
    per_entry = -0.5 * (log_2pi + torch.log(variance) + expected_sq_err / variance)

    # Branchless masking: excluded entries contribute exactly 0 to sum/mean.
    # If every entry is excluded, 'sum' gives 0 and 'mean' gives NaN via 0/0.
    weight = valid.to(per_entry.dtype)
    contributions = per_entry * weight

    if reduce == "sum":
        return contributions.sum()
    if reduce == "mean":
        return contributions.sum() / weight.sum()
    if reduce == "none":
        elementwise = torch.full_like(x, float("nan"))
        elementwise[valid] = per_entry[valid]
        return elementwise
    raise ValueError("reduce must be 'sum', 'mean', or 'none'")