"""cebmf_torch.ebnm.generalized_binary: empirical Bayes normal means (EBNM) with a generalized-binary prior."""

from dataclasses import dataclass

import torch
from torch import Tensor

from cebmf_torch.utils.maths import (
    _LOG_SQRT_2PI,  # log(sqrt(2π))
    logPhi,  # stable log Φ
    my_e2truncnorm,  # E[X^2 | a < X < b] for Normal(mean, sd) via truncation
    my_etruncnorm,  # E[X | a < X < b] for Normal(mean, sd) via truncation
)


@dataclass
class EBNMGBResult:
    """Result of a generalized-binary EBNM fit (as returned by ``ebnm_gb``).

    Per-observation fields have the shape of the (broadcast) input. Scalar
    fields are stored as 0-d tensors on the input device rather than Python
    floats, so building the result does not itself force a host sync.
    """

    post_mean: Tensor  # posterior mean E[θ | x], per observation
    post_mean2: Tensor  # posterior second moment E[θ² | x], per observation
    post_sd: Tensor  # sqrt(max(E[θ²] - E[θ]², 0)), per observation
    pi_slab: Tensor  # slab weight π (= 1 - pi_null)
    mode: Tensor  # learned μ ≥ 0
    # ebnm_gb stores 1 / (ω + 1e-8) here, i.e. the reciprocal of the fixed
    # ω in σ = ω μ. NOTE(review): this field was previously documented as ω
    # itself; confirm which convention callers expect.
    scale: Tensor
    log_lik: Tensor  # marginal log-likelihood summed over observations


def _log_normal_pdf(x: Tensor, mean: Tensor, sd: Tensor) -> Tensor:
    """Elementwise log density of Normal(mean, sd²) evaluated at ``x``.

    ``sd`` is floored at 1e-12 so that a zero standard deviation yields a
    large finite value instead of inf/NaN. Inputs broadcast as usual.
    """
    safe_sd = torch.clamp(sd, min=1e-12)
    standardized = (x - mean) / safe_sd
    log_kernel = -0.5 * standardized**2
    return log_kernel - torch.log(safe_sd) - _LOG_SQRT_2PI


def _gb_inv_softplus(y: Tensor) -> Tensor:
    """Numerically stable inverse of softplus for y > 0.

    softplus^{-1}(y) = log(expm1(y)) = y + log(-expm1(-y)); the second form
    does not overflow for large y, whereas expm1(y) reaches inf once exp(y)
    does (around y ≈ 88 in float32).
    """
    return y + torch.log(-torch.expm1(-y))


def _gb_slab_terms(
    x: Tensor, s: Tensor, mu: Tensor, omega_t: Tensor, eps: float
) -> tuple[Tensor, Tensor, Tensor]:
    """Untruncated pieces of the slab marginal for σ = ω μ (eqs. (18)-(20)).

    Returns ``(lg0, mu_tilde, sig_tilde2)`` with
      lg0        = log N(x; μ, σ² + s²)
      mu_tilde   = (s² μ + σ² x) / (s² + σ²)
      sig_tilde2 = 1 / (1/(σ² + eps) + 1/s²)
    """
    sigma = omega_t * mu
    sigma2 = sigma * sigma
    s2 = s * s
    var_sum = s2 + sigma2
    lg0 = _log_normal_pdf(x, mu, var_sum.sqrt())
    sig_tilde2 = 1.0 / (1.0 / (sigma2 + eps) + 1.0 / s2)
    mu_tilde = (s2 * mu + sigma2 * x) / var_sum
    return lg0, mu_tilde, sig_tilde2


def _gb_slab_log_marginal(
    x: Tensor,
    s: Tensor,
    mu: Tensor,
    omega_t: Tensor,
    log_phi_inv_omega: Tensor,
    eps: float,
) -> tuple[Tensor, Tensor, Tensor]:
    """Per-observation log slab marginal (eq. (18)) plus the posterior pieces.

    log slab_i = log N(x_i; μ, σ² + s_i²) + log Φ(μ̃_i/σ̃_i) − log Φ(μ/σ),
    where μ/σ = 1/ω, so the last term is the precomputed constant
    ``log_phi_inv_omega``. Returns ``(lg, mu_tilde, sig_tilde2)``.
    """
    lg0, mu_tilde, sig_tilde2 = _gb_slab_terms(x, s, mu, omega_t, eps)
    lg = lg0 + logPhi(mu_tilde / sig_tilde2.sqrt()) - log_phi_inv_omega
    return lg, mu_tilde, sig_tilde2


def _gb_responsibilities(lf: Tensor, lg: Tensor, pi: Tensor, eps: float) -> Tensor:
    """E-step: posterior slab probabilities ζ_i, computed in log space."""
    log_num = torch.log(pi.clamp_min(eps)) + lg
    log_den = torch.logaddexp(torch.log1p(-pi).clamp_min(-50) + lf, log_num)
    return torch.exp(log_num - log_den).clamp(0.0, 1.0)


def _gb_log_lik(lf: Tensor, lg: Tensor, pi: Tensor) -> Tensor:
    """Marginal log-likelihood: sum_i log[(1-π) N(x_i; 0, s_i²) + π slab_i]."""
    return torch.logaddexp(torch.log1p(-pi) + lf, torch.log(pi) + lg).sum()


def _gb_mu_loss(
    eta: Tensor, zeta: Tensor, x: Tensor, s: Tensor, omega_t: Tensor, eps: float
) -> Tensor:
    """Negative expected complete log-likelihood in μ = softplus(η) (eq. (23)).

    The constant −log Φ(1/ω) does not depend on μ and is dropped.
    """
    mu = torch.nn.functional.softplus(eta)
    lg0, mu_tilde, sig_tilde2 = _gb_slab_terms(x, s, mu, omega_t, eps)
    return -(zeta * (lg0 + logPhi(mu_tilde / sig_tilde2.sqrt()))).sum()


def _gb_optimize_mu(
    zeta: Tensor,
    mu_init: Tensor,
    x: Tensor,
    s: Tensor,
    omega_t: Tensor,
    eps: float,
    max_lbfgs: int,
    tol_lbfgs: float,
) -> Tensor:
    """M-step for μ: minimize ``_gb_mu_loss`` over unconstrained η.

    Uses LBFGS with strong-Wolfe line search. If LBFGS raises a RuntimeError,
    it restarts from the initial η with 200 Adam steps. If the result is
    non-finite, the previous μ is kept (chosen with ``torch.where``, so no
    host sync is needed).
    """
    eta0 = _gb_inv_softplus(mu_init.detach().clamp_min(1e-8))
    # Gradients are needed here even if the caller runs under torch.no_grad().
    with torch.enable_grad():
        eta = torch.nn.Parameter(eta0.clone())
        opt = torch.optim.LBFGS(
            [eta],
            max_iter=max_lbfgs,
            tolerance_grad=tol_lbfgs,
            tolerance_change=tol_lbfgs,
            line_search_fn="strong_wolfe",
            history_size=20,
        )

        def closure() -> Tensor:
            opt.zero_grad(set_to_none=True)
            loss = _gb_mu_loss(eta, zeta, x, s, omega_t, eps)
            # Keep the line search away from non-finite objective values.
            loss = torch.nan_to_num(loss, nan=1e30, posinf=1e30, neginf=1e30)
            loss.backward()
            return loss

        try:
            opt.step(closure)
        except RuntimeError:
            # LBFGS may have left η at a half-evaluated trial point; restart.
            with torch.no_grad():
                eta.copy_(eta0)
            adam = torch.optim.Adam([eta], lr=1e-2)
            for _ in range(200):
                adam.zero_grad(set_to_none=True)
                _gb_mu_loss(eta, zeta, x, s, omega_t, eps).backward()
                adam.step()

    with torch.no_grad():
        mu_new = torch.nn.functional.softplus(eta).clamp_min(1e-8)
        return torch.where(torch.isfinite(mu_new), mu_new, mu_init)


def ebnm_gb(
    x: Tensor,
    s: Tensor,
    omega: float = 0.2,  # fixed ω (σ = ω μ), typically small
    par_init_mu: float = 1.0,  # μ initialization (on the original scale)
    par_init_pi: float = 0.2,  # π initialization
    max_em: int = 200,
    tol_em: float = 1e-5,
    max_lbfgs: int = 200,
    tol_lbfgs: float = 1e-6,
    eps: float = 1e-12,
) -> EBNMGBResult:
    """EBNM with a generalized-binary prior, fitted by EM.

    Model: x_i | θ_i ~ N(θ_i, s_i²) and
    θ_i ~ (1-π) δ_0 + π N_+(μ, σ²), with σ = ω μ, μ ≥ 0 and ω fixed, where
    N_+ is the normal truncated to [0, ∞). π and μ are estimated with the EM
    scheme of the Supplementary Note (eqs. (16)-(27)). The μ M-step has no
    closed form and is solved with LBFGS over η, where μ = softplus(η).

    Args:
        x: Observations. Non-floating dtypes (e.g. integer counts) are
            promoted to ``torch.get_default_dtype()``. Gradients are not
            propagated back to ``x``.
        s: Standard errors, expected to broadcast with ``x``. Moved to
            ``x``'s device and dtype, and clamped to ≥ 1e-6.
        omega: Fixed ratio ω = σ/μ. Must be > 0.
        par_init_mu: Initial μ (floored at 1e-6).
        par_init_pi: Initial slab weight π (clipped to [1e-6, 1 - 1e-6]).
        max_em: Maximum number of EM iterations.
        tol_em: EM stops when the marginal log-likelihood, evaluated every
            5 iterations, improves by less than this (including a decrease).
        max_lbfgs: ``max_iter`` for each LBFGS μ update.
        tol_lbfgs: Gradient and change tolerance for LBFGS.
        eps: Small constant guarding σ² and log π.

    Returns:
        EBNMGBResult with posterior moments per observation and 0-d tensors
        for π, μ, scale and the marginal log-likelihood. ``scale`` holds
        1 / (ω + 1e-8). For empty ``x``, the posteriors are empty, π and μ
        are their initial values, and ``log_lik`` is 0.

    Raises:
        ValueError: If ``omega`` is not strictly positive.
    """
    if not omega > 0:  # also rejects NaN
        raise ValueError(f"omega must be positive, got {omega!r}")

    device = x.device
    # Integer inputs would otherwise make μ, π and ω integer tensors, which
    # truncates π to 0 and breaks the LBFGS parameter.
    dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    # Only μ is differentiated below. A graph hanging off x or s would be
    # back-propagated through repeatedly by LBFGS and fail.
    x = x.detach().to(dtype)
    s = torch.clamp(s.detach().to(device=device, dtype=dtype), min=1e-6)

    # Initialize hyperparameters.
    mu = torch.tensor(float(max(par_init_mu, 1e-6)), device=device, dtype=dtype)
    pi = torch.tensor(
        float(min(max(par_init_pi, 1e-6), 1 - 1e-6)), device=device, dtype=dtype
    )
    omega_t = torch.tensor(float(omega), device=device, dtype=dtype)
    # NOTE(review): this stores 1/ω even though σ = ω μ; kept for backward
    # compatibility with existing callers.
    scale_t = torch.as_tensor(1.0 / (omega + 1e-8), device=device, dtype=dtype)

    if x.numel() == 0:
        empty = torch.empty_like(x)
        return EBNMGBResult(
            post_mean=empty,
            post_mean2=empty.clone(),
            post_sd=empty.clone(),
            pi_slab=pi,
            mode=mu,
            scale=scale_t,
            log_lik=torch.zeros((), device=device, dtype=dtype),
        )

    # log Φ(μ/σ) = log Φ(1/ω) does not depend on μ: compute it once.
    log_phi_inv_omega = logPhi(
        torch.tensor(1.0 / float(omega), device=device, dtype=dtype)
    )
    # Spike log-likelihood log N(x; 0, s²) is fixed throughout.
    lf = _log_normal_pdf(x, torch.zeros_like(x), s)

    # ---- EM loop ----
    # The marginal log-likelihood is only evaluated (and synced to the host)
    # every `check_every` iterations.
    check_every = 5
    prev_ll = -float("inf")
    for it in range(max_em):
        # E-step: posterior slab probabilities ζ.
        with torch.no_grad():
            lg, _, _ = _gb_slab_log_marginal(
                x, s, mu, omega_t, log_phi_inv_omega, eps
            )
            zeta = _gb_responsibilities(lf, lg, pi, eps)

        # M-step π (eq. (22)): average ζ.
        pi = zeta.mean().clamp(1e-8, 1 - 1e-8)
        # M-step μ (eq. (23)): numerical optimization.
        mu = _gb_optimize_mu(zeta, mu, x, s, omega_t, eps, max_lbfgs, tol_lbfgs)

        if (it + 1) % check_every == 0:
            with torch.no_grad():
                lg, _, _ = _gb_slab_log_marginal(
                    x, s, mu, omega_t, log_phi_inv_omega, eps
                )
                ll = _gb_log_lik(lf, lg, pi).item()
            if ll - prev_ll < tol_em:
                break
            prev_ll = ll

    # ---- Posterior moments (eqs. (25)-(27)) ----
    with torch.no_grad():
        lg, mu_tilde, sig_tilde2 = _gb_slab_log_marginal(
            x, s, mu, omega_t, log_phi_inv_omega, eps
        )
        zeta = _gb_responsibilities(lf, lg, pi, eps)

        # Posterior over θ: (1-ζ) δ_0 + ζ N_+(μ̃, σ̃²). Moments of the slab
        # part come from the normal truncated to [0, ∞).
        sd_tilde = sig_tilde2.sqrt()
        lower = torch.zeros_like(x)
        upper = torch.full_like(x, float("inf"))
        ex = my_etruncnorm(lower, upper, mean=mu_tilde, sd=sd_tilde)
        ex2 = my_e2truncnorm(lower, upper, mean=mu_tilde, sd=sd_tilde)
        post_mean = zeta * ex
        post_mean2 = zeta * ex2
        post_sd = (post_mean2 - post_mean**2).clamp_min(0).sqrt()

        # Final marginal log-likelihood (0-d tensor, kept on device).
        log_lik = _gb_log_lik(lf, lg, pi)

    return EBNMGBResult(
        post_mean=post_mean,
        post_mean2=post_mean2,
        post_sd=post_sd,
        pi_slab=pi.detach(),  # `pi` is the slab weight
        mode=mu.detach(),
        scale=scale_t,
        log_lik=log_lik.detach(),
    )