"""Point-Laplace EBNM solver (module ``cebmf_torch.ebnm.point_laplace``)."""

import math
from dataclasses import dataclass

import torch
from torch import Tensor

from cebmf_torch.utils.maths import (
    _LOG_SQRT_2PI,
    logPhi,
    my_e2truncnorm,
    my_etruncnorm,
    safe_log,
)


def _const_like(x: Tensor, val) -> Tensor:
    return torch.as_tensor(val, device=x.device, dtype=x.dtype)


def logg_laplace_convolved_with_normal(x: Tensor, s: Tensor, a: Tensor) -> Tensor:
    """
    Log of the Laplace(0, rate=a) density convolved with N(0, s^2), evaluated at x:

        log(a/2) + (s*a)^2 / 2
        + log( Φ((x - s^2 a)/s) * e^{-a x} + Φ(-(x + s^2 a)/s) * e^{a x} )

    The two Φ·exp terms are summed in log space (logaddexp) for stability.
    No clamping or device/dtype conversion happens here: x, s and a must already
    agree on device/dtype, and s must already be clamped away from zero by the caller.
    """
    shift = (s * s) * a  # s^2 * a, shared by both half-line terms
    log_right = logPhi((x - shift) / s) - a * x
    log_left = logPhi(-(x + shift) / s) + a * x
    log_mix = torch.logaddexp(log_right, log_left)
    return safe_log(a / x.new_tensor(2.0)) + x.new_tensor(0.5) * (s * a) ** 2 + log_mix


@dataclass
class EBNMLaplaceResult:
    """Fit returned by ``ebnm_point_laplace``.

    ``post_mean``, ``post_mean2`` and ``post_sd`` hold per-observation posterior
    summaries. ``pi_slab``, ``a``, ``mu`` and ``log_lik`` are 0-d tensors left on
    the input device, so storing them causes no host sync.
    """

    # --- per-observation posterior summaries ---
    post_mean: Tensor  # posterior mean of theta
    post_mean2: Tensor  # posterior second moment of theta
    post_sd: Tensor  # posterior standard deviation of theta
    # --- fitted prior parameters (0-d tensors) ---
    pi_slab: Tensor  # weight of the Laplace (slab) component; the point mass gets 1 - pi_slab
    a: Tensor  # rate of the Laplace slab (1 / scale)
    mu: Tensor  # shared location of the point mass and the Laplace slab
    log_lik: Tensor  # marginal log-likelihood, excluding any optimisation penalties


def _slab_log_terms(xc: Tensor, s: Tensor, s2: Tensor, a: Tensor):
    """Log-density pieces of the Laplace(0, rate=a) slab convolved with N(0, s^2), at ``xc``.

    Returns ``(lg1, lsum, lg)`` where
      * ``lg1``  = -a*xc + log Φ((xc - s2*a)/s)   (the theta > 0 half),
      * ``lsum`` = logaddexp(lg1, lg2) with lg2 = a*xc + log Φ(-(xc + s2*a)/s),
      * ``lg``   = log(a/2) + (s*a)^2/2 + lsum     (full slab log marginal).
    ``lg1`` and ``lsum`` are exposed because the posterior needs exp(lg1 - lsum).
    ``s2`` must equal ``s * s``; it is passed in so callers can precompute it once.
    """
    z1 = (xc - s2 * a) / s
    z2 = -(xc + s2 * a) / s
    lg1 = -a * xc + logPhi(z1)
    lg2 = a * xc + logPhi(z2)
    lsum = torch.logaddexp(lg1, lg2)
    lg = safe_log(a / 2.0) + 0.5 * (s * a) ** 2 + lsum
    return lg1, lsum, lg


def ebnm_point_laplace(
    x: Tensor,
    s: Tensor,
    par_init=None,  # None by default; choose safely inside
    fix_par=(False, False, True),  # [w_logit, a_logit, mu]; mu fixed (at par_init[2], 0 by default)
    max_iter: int = 20,
    tol: float = 1e-6,
    a_bounds=(1e-2, 1e2),  # bounds for Laplace rate a
    loga_l2: float = 0.0,  # ridge on a's unconstrained logit (optimization only; 0=off)
    tresh_pi0: float = 1e-3,  # legacy name; slab-weight threshold for spike-only shortcut
    eps: float = 1e-12,
    pen_pi0: float = 0.0,  # legacy name; optional symmetric prior on slab-weight parameter w
    use_adam_warmstart: bool = False,  # default OFF for speed; set True to enable short warm-up
    adam_steps: int = 8,
    adam_lr: float = 1e-2,
    weight_decay: float = 0.0,
) -> EBNMLaplaceResult:
    """
    Efficient direct maximisation of the observed marginal log-likelihood for a
    point-Laplace EBNM.

    Model: x ~ N(theta, s^2) with prior
    theta ~ (1 - w) * delta(mu) + w * (mu + Laplace(0, rate=a)).

    GPU-resident: the objective evaluated by the optimisers and the posterior
    summary create no per-call host<->device transfers (the optimisers' own step
    logic is outside this guarantee). Optimizer: LBFGS-only (AdamW warm-start
    optional and short).

    Parameters
    ----------
    x, s : Tensor
        Observations and their noise standard deviations (s is clamped below at
        1e-6). Both are detached, so no gradient flows back to the caller.
    par_init : sequence of 3, optional
        Initial ``(logit(w), log(a), mu)``; defaults to ``(2.0, 0.0, 0.0)``.
        Tensor entries are copied, never modified in place.
    fix_par : sequence of 3 bools
        Which of ``(w_logit, a_logit, mu)`` to keep fixed.
    max_iter, tol :
        LBFGS iteration cap and gradient/change tolerances.
    a_bounds : (float, float)
        ``a = a_lo + (a_hi - a_lo) * sigmoid(a_logit)``.
    loga_l2, pen_pi0 :
        Optional penalties used during optimisation only.
    tresh_pi0 : float
        If the fitted slab weight falls below this, spike-only summaries are returned.
    eps : float
        w is clamped to ``[eps, 1 - eps]``.
    use_adam_warmstart, adam_steps, adam_lr, weight_decay :
        Optional short AdamW warm start before LBFGS.

    Returns
    -------
    EBNMLaplaceResult
        The four scalar fields of the result (``pi_slab``, ``a``, ``mu``,
        ``log_lik``) are returned as 0-d tensors on the input device so the
        cEBMF ELBO accumulator can fold them in without forcing a host sync per
        factor update. Call ``float(field)`` at your own boundary if you need a
        Python scalar.
    """
    # ---- setup & hoisted constants (device/dtype safe) ----
    device, dtype = x.device, x.dtype
    # Detach the data: the optimisers call backward() on a loss built from x and s,
    # which would otherwise push gradients into any autograd graph the caller
    # attached to them. Every returned quantity is computed under no_grad anyway.
    x = torch.as_tensor(x, device=device, dtype=dtype).detach()
    s = torch.as_tensor(s, device=device, dtype=dtype).detach().clamp_min(1e-6)
    log_s = torch.log(s)  # precomputed once; reused by every spike log-lik evaluation
    s2 = s * s

    zero = _const_like(x, 0.0)
    one = _const_like(x, 1.0)
    eps_t = _const_like(x, eps)
    thresh_pi_slab_t = _const_like(x, tresh_pi0)
    # Always move/cast the normalising constant: a pre-built tensor may live on a
    # different device or have a different dtype than x.
    c_norm = _const_like(x, _LOG_SQRT_2PI)

    # bounds for 'a' via smooth map a = a_lo + (a_hi - a_lo) * sigmoid(v)
    a_lo, a_hi = a_bounds
    a_lo_t = _const_like(x, a_lo)
    a_span = _const_like(x, a_hi) - a_lo_t

    # ---- initial values in the unconstrained parameterisation ----
    if par_init is None:
        par_init = (2.0, 0.0, 0.0)  # (logit(w), log(a_init), mu)
    # map provided log(a_init) into v0 for the bounded-sigmoid parameterization
    a_init = float(min(max(math.exp(float(par_init[1])), a_lo), a_hi))
    if a_hi != a_lo:
        r = (a_init - a_lo) / (a_hi - a_lo)
        r = min(max(r, 1e-8), 1 - 1e-8)
    else:
        r = 0.5  # degenerate bounds pin a to a_lo whatever the logit; avoid 0/0
    v0 = math.log(r) - math.log(1 - r)

    def _as_param(value, trainable: bool) -> torch.nn.Parameter:
        # detach().clone(): torch.as_tensor returns a matching tensor unchanged, so
        # without a copy the Parameter would share the caller's storage and the
        # optimisers' in-place updates would overwrite the caller's tensor.
        init = torch.as_tensor(value, device=device, dtype=dtype).detach().clone()
        return torch.nn.Parameter(init, requires_grad=trainable)

    w_logit = _as_param(par_init[0], not fix_par[0])
    a_logit = _as_param(v0, not fix_par[1])
    mu = _as_param(par_init[2], not fix_par[2])
    params = [p for p in (w_logit, a_logit, mu) if p.requires_grad]

    def _constrained():
        """Map the free parameters to (w in [eps, 1-eps], a in [a_lo, a_hi])."""
        w = torch.sigmoid(w_logit).clamp(eps_t, one - eps_t)
        a = a_lo_t + a_span * torch.sigmoid(a_logit)
        return w, a

    def _branch_log_terms(a):
        """Centred data plus per-observation spike and slab log-likelihood pieces."""
        xc = x - mu
        lf = -(0.5 * (xc / s) ** 2) - log_s - c_norm  # log N(xc; 0, s^2): the spike at mu
        lg1, lsum, lg = _slab_log_terms(xc, s, s2, a)
        return xc, lf, lg1, lsum, lg

    def _penalized_neg_loglik():
        """Scalar optimisation loss: -(marginal log-lik) plus the optional penalties."""
        w, a = _constrained()
        _, lf, _, _, lg = _branch_log_terms(a)
        loss = -torch.logaddexp(torch.log1p(-w) + lf, torch.log(w) + lg).sum()
        # Python-float coefficients: no per-call host->device tensor construction.
        if loga_l2 != 0.0:
            loss = loss + loga_l2 * a_logit**2
        if pen_pi0 != 0.0:
            loss = loss - pen_pi0 * (torch.log(w) + torch.log1p(-w))
        return loss

    def _make_lbfgs(ps):
        return torch.optim.LBFGS(
            ps,
            max_iter=max_iter,
            tolerance_grad=tol,
            tolerance_change=tol,
            line_search_fn="strong_wolfe",
            history_size=10,
        )

    # ---- optional warm-start (kept very short) ----
    if use_adam_warmstart and params:
        opt_adam = torch.optim.AdamW(params, lr=adam_lr, betas=(0.9, 0.999), weight_decay=weight_decay)
        for _ in range(int(adam_steps)):
            opt_adam.zero_grad(set_to_none=True)
            _penalized_neg_loglik().backward()
            opt_adam.step()

    # ---- LBFGS polish (closure computes ONLY the scalar loss) ----
    if params:
        opt_lbfgs = _make_lbfgs(params)

        def closure():
            # opt_lbfgs owns every free parameter, so this also clears the grads of
            # the fallback optimiser below (whose parameters are a subset).
            opt_lbfgs.zero_grad(set_to_none=True)
            # nan_to_num takes Python numbers for the replacements; 0-d tensors here
            # would be converted on the host, i.e. a device sync per closure call.
            loss = torch.nan_to_num(_penalized_neg_loglik(), nan=1e30, posinf=1e30, neginf=1e30)
            loss.backward()
            return loss

        try:
            opt_lbfgs.step(closure)
        except RuntimeError:
            # fallback: freeze 'a' if line search fails and re-run on the rest.
            # NOTE: when 'a' is already fixed there is no fallback; the error is
            # swallowed and the current parameter values are kept.
            if a_logit.requires_grad:
                a_logit.requires_grad_(False)
                params2 = [p for p in (w_logit, mu) if p.requires_grad]
                if params2:
                    _make_lbfgs(params2).step(closure)

    # ---- posterior & summaries (no penalties; single no_grad block) ----
    with torch.no_grad():
        w, a = _constrained()
        xc, lf, lg1, lsum, lg = _branch_log_terms(a)

        # posterior inclusion prob (slab); log_denom is also the per-observation
        # marginal log-likelihood
        log_num = torch.log(w) + lg
        log_denom = torch.logaddexp(torch.log1p(-w) + lf, log_num)
        gamma = torch.exp(log_num - log_denom).clamp(zero, one)

        # sign-mixture inside slab: weight of the theta - mu > 0 half (0.5 if lsum is not finite)
        lam = torch.exp(lg1 - lsum)
        lam = torch.where(torch.isfinite(lsum), lam, torch.full_like(lsum, 0.5))

        # truncated-normal moments of each half of the slab posterior
        m_pos = xc - s2 * a
        m_neg = xc + s2 * a
        infp = torch.full_like(x, float("inf"))
        infn = -infp
        EX_pos = my_etruncnorm(zero, infp, mean=m_pos, sd=s)
        EX2_pos = my_e2truncnorm(zero, infp, mean=m_pos, sd=s)
        EX_neg = my_etruncnorm(infn, zero, mean=m_neg, sd=s)
        EX2_neg = my_e2truncnorm(infn, zero, mean=m_neg, sd=s)
        EX = lam * EX_pos + (one - lam) * EX_neg
        EX2 = lam * EX2_pos + (one - lam) * EX2_neg

        post_mean = gamma * (EX + mu) + (one - gamma) * mu
        post_mean2 = gamma * (EX2 + 2.0 * mu * EX + mu * mu) + (one - gamma) * (mu * mu)

        # PURE marginal log-likelihood (w is already clamped to >= eps)
        llik = log_denom.sum()

        # spike-only shortcut if slab weight is tiny — branchless so we never
        # block on a host-side comparison.
        spike_only = w < thresh_pi_slab_t  # 0-d bool tensor
        post_mean = torch.where(spike_only, torch.zeros_like(x) + mu, post_mean)
        post_mean2 = torch.where(spike_only, torch.zeros_like(x) + mu * mu + 1e-4, post_mean2)
        post_sd = (post_mean2 - post_mean**2).clamp_min(zero).sqrt()
        llik = torch.where(spike_only, lf.sum(), llik)

    return EBNMLaplaceResult(
        post_mean=post_mean,
        post_mean2=post_mean2,
        post_sd=post_sd,
        pi_slab=w.detach(),  # mixture weight of the Laplace branch (slab)
        a=a.detach(),
        mu=mu.detach(),
        log_lik=llik.detach(),
    )