Source code for cebmf_torch.utils.maths

import math

import torch
from torch import Tensor

# Shared numeric constants, computed once at import time.
_TWOPI = 2.0 * math.pi  # 2*pi
_SQRT_2PI = math.sqrt(_TWOPI)  # sqrt(2*pi); divisor used by norm_pdf
_EPS = 1e-12  # default clamp floor for safe_log
_LOG_2PI = math.log(_TWOPI)  # log(2*pi)
_LOG_SQRT_2PI = 0.5 * _LOG_2PI  # log(sqrt(2*pi)) = 0.5*log(2*pi); additive constant of the normal log-density


def _like(x: Tensor, val) -> Tensor:
    """Return `val` as a tensor that shares the device and dtype of `x`."""
    placement = {"device": x.device, "dtype": x.dtype}
    return torch.as_tensor(val, **placement)


def _logpdf_normal(x: Tensor, loc: Tensor, scale: Tensor) -> Tensor:
    """
    Log-density of a normal distribution, evaluated elementwise.

    log N(x | loc, scale^2) = -0.5*((x-loc)/scale)^2 - log(scale) - 0.5*log(2*pi)

    No guarding is applied here: callers must supply ``scale > 0`` (typically
    by clamping standard errors upstream). This is the canonical
    implementation used across the package; the module-level aliases in
    ``utils/posterior.py`` and ``utils/distribution_operation.py`` re-export
    it so that copies do not drift apart.
    """
    standardized = (x - loc) / scale
    quadratic = -0.5 * standardized.pow(2)
    return quadratic - torch.log(scale) - _LOG_SQRT_2PI


def log_norm_pdf(x: Tensor, loc: Tensor, scale: Tensor) -> Tensor:
    """
    Normal log-density with an epsilon-padded scale (backward-compatible alias).

    A tiny constant (1e-32, created on ``scale``'s device/dtype) is added to
    ``scale`` before delegating to :func:`_logpdf_normal`, so degenerate
    inputs such as ``scale == 0`` do not hit ``log(0)`` or a division by zero.
    Prefer :func:`_logpdf_normal` with an upstream clamp when the caller
    already guarantees a positive scale.
    """
    padded_scale = scale + _like(scale, 1e-32)
    return _logpdf_normal(x, loc, padded_scale)
def _logcdf_normal(z: Tensor) -> Tensor:
    """Log of the standard normal CDF, log Φ(z); delegates to ``torch.special.log_ndtr``."""
    log_cdf = torch.special.log_ndtr(z)
    return log_cdf
def norm_cdf(x: Tensor) -> Tensor:
    """
    Compute the standard normal cumulative distribution function (CDF).

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.

    Returns
    -------
    torch.Tensor
        CDF evaluated elementwise at x. NaN entries propagate as NaN.
    """
    # torch.special.ndtr is the elementwise standard-normal CDF. Building a
    # torch.distributions.Normal on every call instead runs argument/sample
    # validation, which reduces tensors to Python booleans (a host sync on
    # CUDA) and raises ValueError on NaN input rather than propagating it.
    return torch.special.ndtr(x)
def norm_pdf(x: Tensor) -> Tensor:
    """
    Evaluate the standard normal probability density function (PDF).

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.

    Returns
    -------
    torch.Tensor
        exp(-x^2 / 2) / sqrt(2*pi), evaluated elementwise.
    """
    kernel = torch.exp(-0.5 * x * x)
    normalizer = _like(x, _SQRT_2PI)
    return kernel / normalizer
def logsumexp(x: Tensor, dim: int = -1, keepdim: bool = False) -> Tensor:
    """
    Reduce `x` along `dim` with log(sum(exp(x))).

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    dim : int, optional
        Dimension to reduce over. Default is -1.
    keepdim : bool, optional
        Keep the reduced dimension (with size 1) in the output. Default is False.

    Returns
    -------
    torch.Tensor
        Log-sum-exp of `x` along `dim`.
    """
    return x.logsumexp(dim=dim, keepdim=keepdim)
def safe_log(x: Tensor, eps: float = _EPS) -> Tensor:
    """
    Take the logarithm of `x` after flooring it at `eps`.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    eps : float, optional
        Lower bound applied to `x` before the log. Default is 1e-12.

    Returns
    -------
    torch.Tensor
        log(max(x, eps)), elementwise.
    """
    floor = _like(x, eps)
    clamped = torch.clamp(x, min=floor)
    return torch.log(clamped)
def softmax(x: Tensor, dim: int = -1) -> Tensor:
    """
    Apply softmax to `x` along one dimension.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    dim : int, optional
        Dimension that is normalised to sum to one. Default is -1.

    Returns
    -------
    torch.Tensor
        Softmax of the input tensor.
    """
    return x.softmax(dim=dim)
# ------------------------
# helpers
# ------------------------
def logphi(z: torch.Tensor) -> torch.Tensor:
    """
    Log of the standard normal PDF, log φ(z) with φ(z) = exp(-z^2/2)/√(2π).

    Parameters
    ----------
    z : torch.Tensor
        Input tensor.

    Returns
    -------
    torch.Tensor
        -z^2/2 - log(√(2π)), evaluated elementwise.
    """
    log_normalizer = _like(z, _LOG_SQRT_2PI)
    quadratic = -0.5 * z.pow(2)
    return quadratic - log_normalizer
def logPhi(z: torch.Tensor) -> torch.Tensor:
    """
    Log of the standard normal CDF, log Φ(z), via ``torch.special.log_ndtr``.

    Parameters
    ----------
    z : torch.Tensor
        Input tensor.

    Returns
    -------
    torch.Tensor
        Log CDF evaluated elementwise at z.
    """
    log_cdf = torch.special.log_ndtr(z)
    return log_cdf
def logscale_sub(logx: torch.Tensor, logy: torch.Tensor) -> torch.Tensor:
    """
    Compute log(exp(logx) - exp(logy)) in a numerically stable way.

    Requires logx >= logy; where logx < logy the difference is negative and
    the result is NaN.

    Parameters
    ----------
    logx : torch.Tensor
        Logarithm of x.
    logy : torch.Tensor
        Logarithm of y.

    Returns
    -------
    torch.Tensor
        Logarithm of (exp(logx) - exp(logy)). In particular
        ``logscale_sub(-inf, -inf)`` is ``-inf`` (log of 0 - 0) and
        ``logscale_sub(+inf, finite)`` is ``+inf``.
    """
    shift = torch.maximum(logx, logy)
    # Shifting by an infinite value would evaluate inf - inf = NaN below
    # (e.g. both inputs -inf, which happens for log-densities at infinite
    # bounds). No rescaling is needed in that case, so shift by 0 instead.
    # A NaN shift also maps to 0; the NaN still propagates through logx/logy.
    shift = torch.where(torch.isfinite(shift), shift, torch.zeros_like(shift))
    return shift + torch.log(torch.exp(logx - shift) - torch.exp(logy - shift))
def logscale_add(logx: Tensor, logy: Tensor) -> Tensor:
    """
    Add two quantities given on the log scale: log(exp(logx) + exp(logy)).

    Parameters
    ----------
    logx : torch.Tensor
        Logarithm of x.
    logy : torch.Tensor
        Logarithm of y.

    Returns
    -------
    torch.Tensor
        Logarithm of (exp(logx) + exp(logy)).
    """
    return logx.logaddexp(logy)
def do_truncnorm_argchecks(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Coerce truncated-normal bounds to tensors sharing one device and dtype.

    Only coercion is performed: the bounds are neither clamped nor checked
    for ``a <= b``.

    Parameters
    ----------
    a : torch.Tensor
        Lower bound(s). Anything accepted by ``torch.as_tensor`` works.
    b : torch.Tensor
        Upper bound(s). Anything accepted by ``torch.as_tensor`` works.

    Returns
    -------
    tuple of torch.Tensor
        (a, b), with ``b`` on ``a``'s device and both promoted to their
        common dtype.
    """
    a = torch.as_tensor(a)
    b = torch.as_tensor(b, device=a.device)
    # Promote to the common dtype instead of forcing b into a's dtype:
    # casting b to an integer a (e.g. a=0, b=1.5 or b=inf) would truncate or
    # corrupt the upper bound, and casting a float64 b to a float32 a would
    # silently drop precision.
    common_dtype = torch.promote_types(a.dtype, b.dtype)
    return a.to(dtype=common_dtype), b.to(dtype=common_dtype)
[docs] def safe_tensor_to_float( value: torch.Tensor | float | None, null_value: float = float("-inf"), reduction: str = "min" ) -> float: """ Convert tensor, float, or None to float with safe handling. Parameters ---------- value : torch.Tensor, float, or None Value to convert. null_value : float, optional Value to return if input is None or empty. Default is -inf. reduction : str, optional Reduction to apply if input is a tensor ("min" or "max"). Default is "min". Returns ------- float Converted float value. """ if value is None: return null_value if isinstance(value, torch.Tensor): if value.numel() == 0: return null_value if reduction == "min": return float(value.min().item()) elif reduction == "max": return float(value.max().item()) return float(value)
# ------------------------
# E[Z | a<Z<b] and E[Z^2 | a<Z<b] for Z~N(mean, sd^2)
# ------------------------
def _resolve_truncnorm_dtype(precision: str, *tensors: torch.Tensor) -> torch.dtype:
    """Choose the dtype in which the truncnorm helpers do their arithmetic.

    ``precision`` must be one of:

    - ``"auto"`` (default): take the highest-precision floating dtype found
      among ``tensors``, never going below ``float32``. This keeps everything
      on the caller's device/dtype, which on CUDA consumer cards is ~30x
      faster than float64 with no observable accuracy loss for typical EBNM
      workloads.
    - ``"float64"``: always ``torch.float64``. Slower on CUDA but matches the
      historical behaviour of these helpers exactly. Use this when you have
      extreme bounds (``|alpha|, |beta| > 10``) that benefit from the extra
      precision.

    Any other value raises ``ValueError``.
    """
    if precision not in ("auto", "float64"):
        raise ValueError(f"precision must be 'auto' or 'float64', got {precision!r}")
    if precision == "float64":
        return torch.float64

    # Precision ranking; dtypes missing from the table count as float32-level.
    rank = {torch.float16: 0, torch.bfloat16: 0, torch.float32: 1, torch.float64: 2}
    # float32 goes first so that ties (and an empty input) resolve to it:
    # max() returns the first maximal element.
    candidates = [torch.float32]
    candidates.extend(t.dtype for t in tensors if t.dtype.is_floating_point)
    return max(candidates, key=lambda dtype: rank.get(dtype, 1))
def my_etruncnorm(a, b, mean=0.0, sd=1.0, precision: str = "auto"):
    """
    Mean of a truncated normal: E[Z | a < Z < b] for Z ~ N(mean, sd^2).

    Parameters
    ----------
    a : float or torch.Tensor
        Lower truncation bound.
    b : float or torch.Tensor
        Upper truncation bound.
    mean : float or torch.Tensor, optional
        Mean of the normal distribution. Default is 0.0.
    sd : float or torch.Tensor, optional
        Standard deviation of the normal distribution. Default is 1.0.
    precision : str, optional
        Internal compute dtype. ``"auto"`` (default) follows the input dtype
        (typically ``float32``), which is fast on CUDA. ``"float64"`` forces
        double precision — slower on CUDA but matches the pre-2026 behaviour.

    Returns
    -------
    torch.Tensor
        Mean of the truncated normal distribution. Dtype matches the chosen
        ``precision``.
    """
    a, b = do_truncnorm_argchecks(torch.as_tensor(a), torch.as_tensor(b))
    mean_in = torch.as_tensor(mean)
    sd_in = torch.as_tensor(sd)
    work_dtype = _resolve_truncnorm_dtype(precision, a, b, mean_in, sd_in)
    placement = {"dtype": work_dtype, "device": a.device}
    mean = mean_in.to(**placement)
    sd = sd_in.to(**placement)

    # Standardised bounds.
    lo = (a.to(**placement) - mean) / sd
    hi = (b.to(**placement) - mean) / sd

    # Reflect the interval about zero where needed (both bounds positive, or
    # the upper bound dominating in magnitude); the sign is restored below.
    flip = ((lo > 0) & (hi > 0)) | (hi > lo.abs())
    alpha = torch.where(flip, -hi, lo)
    beta = torch.where(flip, -lo, hi)

    # Standardised mean is -(phi(beta) - phi(alpha)) / (Phi(beta) - Phi(alpha)),
    # evaluated on the log scale.
    log_pdf_gap = logscale_sub(logphi(beta), logphi(alpha))
    log_cdf_gap = logscale_sub(logPhi(beta), logPhi(alpha))
    std_mean = -torch.exp(torch.clamp(log_pdf_gap - log_cdf_gap, max=700.0))

    # An infinite log-CDF gap means the interval carries no mass (endpoints
    # equal): use the midpoint.
    midpoint = (alpha + beta) / 2
    std_mean = torch.where(torch.isinf(log_cdf_gap), midpoint, std_mean)

    # For negative beta, replace results that fall outside [lower_bd, beta]
    # by the lower bound.
    lower_bd = torch.maximum(beta + 1.0 / beta, midpoint)
    outside = (std_mean < lower_bd) | (std_mean > beta)
    needs_fix = (~torch.isnan(beta)) & (beta < 0) & outside
    std_mean = torch.where(needs_fix, lower_bd, std_mean)

    std_mean = torch.where(flip, -std_mean, std_mean)
    res = mean + sd * std_mean

    # Branchless degenerate-sd handling: always applied rather than guarded by
    # `if (sd == 0).any():`, because that guard forces a host sync on every
    # call while the where-chain is a no-op when no entry has sd == 0.
    return _apply_degenerate_sd_first_moment(res, a, b, mean, sd)
def _apply_degenerate_sd_first_moment(
    res: torch.Tensor, a: torch.Tensor, b: torch.Tensor, mean: torch.Tensor, sd: torch.Tensor
) -> torch.Tensor:
    """Overwrite the sd==0 entries of `res` with the limiting first moment.

    Where ``sd == 0`` the value becomes ``b`` if ``b <= mean``, ``a`` if
    ``a >= mean``, and ``mean`` if ``a < mean < b``. All other entries of
    `res` are returned unchanged.
    """
    lower = a.expand_as(res).to(res)
    upper = b.expand_as(res).to(res)
    center = mean.expand_as(res)

    degenerate = sd == 0
    at_upper = degenerate & (upper <= center)
    at_lower = degenerate & (lower >= center)
    inside = degenerate & (lower < center) & (upper > center)

    # Applied in this order so that a later mask wins where two overlap.
    for mask, replacement in ((at_upper, upper), (at_lower, lower), (inside, center)):
        res = torch.where(mask, replacement, res)
    return res
def my_e2truncnorm(a, b, mean=0.0, sd=1.0, precision: str = "auto"):
    """
    Compute E[Z^2 | a < Z < b] for Z ~ N(mean, sd^2), the second moment of a truncated normal.

    Parameters
    ----------
    a : float or torch.Tensor
        Lower truncation bound.
    b : float or torch.Tensor
        Upper truncation bound.
    mean : float or torch.Tensor, optional
        Mean of the normal distribution. Default is 0.0.
    sd : float or torch.Tensor, optional
        Standard deviation of the normal distribution. Default is 1.0.
    precision : str, optional
        Internal compute dtype. See :func:`my_etruncnorm` for the contract.
        ``"auto"`` (default) keeps the caller's dtype (fast on CUDA);
        ``"float64"`` matches the pre-2026 behaviour.

    Returns
    -------
    torch.Tensor
        Second moment of the truncated normal distribution. Dtype matches the
        chosen ``precision``.
    """
    a, b = do_truncnorm_argchecks(torch.as_tensor(a), torch.as_tensor(b))
    device = a.device

    mean_t = torch.as_tensor(mean)
    sd_t = torch.as_tensor(sd)
    work_dtype = _resolve_truncnorm_dtype(precision, a, b, mean_t, sd_t)
    mean = mean_t.to(dtype=work_dtype, device=device)
    sd = sd_t.to(dtype=work_dtype, device=device)

    alpha = (a.to(dtype=work_dtype, device=device) - mean) / sd
    beta = (b.to(dtype=work_dtype, device=device) - mean) / sd

    # When both standardised bounds are positive, work with the reflected
    # variable W' = -W on (-beta, -alpha), which sits in the lower tail.
    flip = (alpha > 0) & (beta > 0)
    orig_alpha = alpha
    alpha = torch.where(flip, -beta, alpha)
    beta = torch.where(flip, -orig_alpha, beta)

    # Reflection flips the sign of E[W], so the cross term 2*mean*sd*E[W]
    # needs `mean` negated exactly where `flip` is set. Taking mean.abs()
    # is only right when sign(mean) happens to agree with the flip
    # (e.g. a=2, b=3, mean=1 would otherwise give ~0.21 instead of ~5.75).
    # Branchless, so no host sync is introduced.
    signed_mean = torch.where(flip, -mean, mean)

    pnorm_diff = logscale_sub(logPhi(beta), logPhi(alpha))
    alpha_frac = alpha * torch.exp(torch.clamp(logphi(alpha) - pnorm_diff, max=300.0))
    beta_frac = beta * torch.exp(torch.clamp(logphi(beta) - pnorm_diff, max=300.0))

    # Non-finite boundary terms (e.g. inf * 0 at an infinite bound) contribute nothing.
    alpha_frac = torch.where(~torch.isfinite(alpha_frac), torch.zeros_like(alpha_frac), alpha_frac)
    beta_frac = torch.where(~torch.isfinite(beta_frac), torch.zeros_like(beta_frac), beta_frac)

    # E[W^2] = 1 + (alpha*phi(alpha) - beta*phi(beta)) / (Phi(beta) - Phi(alpha)),
    # adding each boundary term only for a finite bound.
    scaled_res = torch.ones_like(alpha)
    alpha_idx = torch.isfinite(alpha)
    scaled_res = torch.where(alpha_idx, 1 + alpha_frac, scaled_res)
    beta_idx = torch.isfinite(beta)
    scaled_res = torch.where(beta_idx, scaled_res - beta_frac, scaled_res)

    # An infinite log-CDF gap means the interval carries no mass (endpoints
    # equal): use the squared midpoint.
    endpts_equal = torch.isinf(pnorm_diff)
    scaled_res = torch.where(endpts_equal, ((alpha + beta) ** 2) / 4, scaled_res)

    # For negative beta, replace results outside [beta^2, upper_bd] by the upper bound.
    upper_bd1 = beta**2 + 2 * (1 + 1 / beta**2)
    upper_bd2 = (alpha**2 + alpha * beta + beta**2) / 3
    upper_bd = torch.minimum(upper_bd1, upper_bd2)
    bad_idx = (~torch.isnan(beta)) & (beta < 0) & ((scaled_res < beta**2) | (scaled_res > upper_bd))
    scaled_res = torch.where(bad_idx, upper_bd, scaled_res)

    # NOTE: my_etruncnorm expects (a, b, mean, sd). For standardized alpha/beta, use mean=0, sd=1.
    # Forward the same ``precision`` so the inner call doesn't silently re-upcast back to float64.
    first_moment = my_etruncnorm(alpha, beta, 0.0, 1.0, precision=precision)
    res = signed_mean**2 + 2 * signed_mean * sd * first_moment + sd**2 * scaled_res

    # Branchless degenerate-sd handling (no data-dependent Python branch, hence
    # no host sync per call). The limiting point mass is located by comparing
    # the caller's bounds with the caller's mean, so the unreflected `mean`
    # is used here rather than `signed_mean`.
    a_rep = a.expand_as(res).to(res)
    b_rep = b.expand_as(res).to(res)
    mean_rep = mean.expand_as(res)
    sd_zero = sd == 0
    cond1 = sd_zero & (b_rep <= mean_rep)
    cond2 = sd_zero & (a_rep >= mean_rep)
    cond3 = sd_zero & (a_rep < mean_rep) & (b_rep > mean_rep)
    res = torch.where(cond1, b_rep**2, res)
    res = torch.where(cond2, a_rep**2, res)
    res = torch.where(cond3, mean_rep**2, res)

    return res