Source code for cebmf_torch.cebnm.lcash

"""LC-ASH: Linear Covariate Adaptive Shrinkage.

Two parameterisations:
  - Softmax (multinomial logistic): K independent logit vectors, K*F params.
  - Proportional odds (ordered logistic): shared weight vector, F+K-1 params.

Both map gene features to mixture weights.  A linear alternative to the
MLP-based CASH solver, with ash-based bias/cut-point initialisation and
grid pruning.
"""

import warnings

import torch
import torch.nn as nn

from cebmf_torch.cebnm.cash_solver import (
    cash_PosteriorMeanNorm,
    pen_loglik_loss,
)
from cebmf_torch.ebnm.ash import PriorType, ash
from cebmf_torch.utils.distribution_operation import get_data_loglik_normal_torch
from cebmf_torch.utils.mixture import autoselect_scales_mix_norm

# ============================================================
# Model classes
# ============================================================


class LcashNet(nn.Module):
    """Multinomial logistic regression: features -> mixture weights.

    A single nn.Linear(F, K) followed by softmax. Equivalent to
    multinomial logistic regression with K classes and F features.

    Parameters
    ----------
    input_dim : int
        Number of input features.
    num_classes : int
        Number of mixture components (output classes).
    log_pi_init : torch.Tensor or None
        If provided, (K,) tensor of centred log-weights from a global ash
        fit. Used to initialise the bias so that softmax(bias) approximates
        the global ash pi when all feature coefficients are zero.
    """

    def __init__(
        self,
        input_dim: int,
        num_classes: int,
        log_pi_init: torch.Tensor | None = None,
        generator: torch.Generator | None = None,
    ):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)
        # Small random perturbation breaks symmetry across features.
        # Starting from exact zeros leads Adam to different local
        # optima on high-dimensional feature sets (F > 100).
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.01, generator=generator)
        if log_pi_init is not None:
            with torch.no_grad():
                self.linear.bias.copy_(log_pi_init)
        else:
            nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.softmax(self.linear(x), dim=1)


class PropOddsLcashNet(nn.Module):
    """Proportional odds (ordered logistic) mapping: features -> mixture weights.

    A single shared weight vector maps features to a scalar signal
    strength s_i = x_i^T w.  K-1 ordered cut-points convert s_i to
    mixture weights via cumulative logistic probabilities.

    Parameters
    ----------
    input_dim : int
        Number of input features.
    num_classes : int
        Number of mixture components (K).
    log_pi_init : torch.Tensor or None
        If provided, (K,) tensor of centred log-weights from a global ash
        fit.  Used to initialise ordered cut-points so that the model
        recovers the global ash pi when all feature coefficients are zero.
    """

    def __init__(
        self,
        input_dim: int,
        num_classes: int,
        log_pi_init: torch.Tensor | None = None,
        generator: torch.Generator | None = None,
    ):
        super().__init__()
        K = num_classes

        # Shared feature weights: initialised near zero so the model
        # starts close to the exchangeable prior.
        self.w = nn.Parameter(torch.empty(input_dim))
        nn.init.normal_(self.w, mean=0.0, std=0.01, generator=generator)

        # Cut-point parameterisation: delta_1 (free), delta_2..K-1 (gaps)
        if log_pi_init is not None and K > 1:
            init_cuts = self._init_cutpoints_from_pi(log_pi_init, K)
        else:
            init_cuts = torch.linspace(-2.0, 2.0, K - 1)

        self.delta_1 = nn.Parameter(init_cuts[0:1])  # (1,)
        if K > 2:
            gaps = torch.log(torch.clamp(init_cuts[1:] - init_cuts[:-1], min=1e-6))
            self.delta_gaps = nn.Parameter(gaps)  # (K-2,)
        else:
            self.delta_gaps = None

        self._K = K

    @staticmethod
    def _init_cutpoints_from_pi(log_pi_init: torch.Tensor, K: int) -> torch.Tensor:
        """Initialise cut-points so that sigma(theta_k) approx cumprob_k.

        At initialisation w ~ 0, so s_i ~ 0 for all genes.  Then
        pi_k = sigma(theta_{k+1}) - sigma(theta_k), so we need
        sigma(theta_k) = sum_{j<k} pi_j, i.e. theta_k = logit(cumprob_k).
        """
        pi = torch.exp(log_pi_init - log_pi_init.max())
        pi = pi / pi.sum()
        cumprob = torch.cumsum(pi, dim=0)[:-1]  # K-1 values
        cumprob = torch.clamp(cumprob, 1e-6, 1 - 1e-6)
        cuts = torch.log(cumprob / (1 - cumprob))
        return cuts

    def _get_cutpoints(self) -> torch.Tensor:
        """Reconstruct ordered cut-points from unconstrained parameters."""
        # K=1: degenerate case, all weight on the single component.
        if self._K == 1:
            return torch.empty(0, device=self.delta_1.device)
        if self.delta_gaps is not None:
            gaps = torch.exp(self.delta_gaps)
            return torch.cat([self.delta_1, self.delta_1 + torch.cumsum(gaps, dim=0)])
        return self.delta_1  # K = 2: single cut-point

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute mixture weights pi_k for each gene.

        Parameters
        ----------
        x : tensor (G, F)
            Feature matrix.

        Returns
        -------
        tensor (G, K)
            Per-gene mixture weights.
        """
        s = x @ self.w  # (G,)
        theta = self._get_cutpoints()  # (K-1,)

        # Cumulative probabilities: P(category <= k) = sigma(theta_k - s)
        cum_probs = torch.sigmoid(theta.unsqueeze(0) - s.unsqueeze(1))  # (G, K-1)

        # Convert cumulative to category probabilities
        ones = torch.ones(s.shape[0], 1, device=x.device)
        zeros = torch.zeros(s.shape[0], 1, device=x.device)
        cum_ext = torch.cat([zeros, cum_probs, ones], dim=1)  # (G, K+1)
        pi = cum_ext[:, 1:] - cum_ext[:, :-1]  # (G, K)

        # Numerical safety: clamp small negatives from floating-point
        pi = torch.clamp(pi, min=1e-10)
        pi = pi / pi.sum(dim=1, keepdim=True)

        return pi


# ============================================================
# Shared helpers
# ============================================================


def _prepare_inputs(
    X: torch.Tensor,
    betahat: torch.Tensor,
    sebetahat: torch.Tensor,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Cast inputs to float32 on ``device`` and standardise the features.

    A 1-D ``X`` is treated as a single covariate and reshaped to (G, 1).
    Features go through ``_nanstandardise``: column statistics ignore NaN
    entries, and NaN positions become 0 after standardisation.  A zero
    feature adds nothing to the linear predictor, so a missing covariate
    falls back to the intercept / global prior.

    Returns
    -------
    (X_scaled, betahat, sebetahat) as float32 tensors on ``device``.
    """

    def to_f32(value):
        return torch.as_tensor(value, dtype=torch.float32, device=device)

    features = to_f32(X)
    if features.ndim == 1:
        features = features.unsqueeze(1)  # single covariate -> (G, 1)
    effects = to_f32(betahat)
    std_errors = to_f32(sebetahat)
    return _nanstandardise(features), effects, std_errors


def _nanstandardise(X: torch.Tensor) -> torch.Tensor:
    """Standardise columns using non-NaN values, then zero-fill NaN.

    Vectorised implementation. For each column, compute mean and
    population std on observed (non-NaN) entries, standardise observed
    values, and set NaN positions to 0. Columns with zero std
    (constant or all-NaN) are set to 0.
    """
    mask = ~torch.isnan(X)
    counts = mask.sum(dim=0)  # (F,)

    # Replace NaN with 0 for safe summation
    X_filled = torch.where(mask, X, torch.zeros_like(X))

    # Mean on observed values
    safe_counts = counts.clamp(min=1)
    mu = X_filled.sum(dim=0) / safe_counts  # (F,)

    # Population std on observed values
    diff = torch.where(mask, X - mu, torch.zeros_like(X))
    var = (diff**2).sum(dim=0) / safe_counts  # (F,)
    sd = var.sqrt()

    # Standardise observed, zero-fill missing
    safe_sd = torch.where((sd > 0) & (counts > 1), sd, torch.ones_like(sd))
    X_out = torch.where(
        mask & (sd > 0).unsqueeze(0) & (counts > 1).unsqueeze(0),
        diff / safe_sd,
        torch.zeros_like(X),
    )
    return X_out


def _select_grid(
    betahat: torch.Tensor,
    sebetahat: torch.Tensor,
    mult: float,
    ash_init: bool,
    ash_threshold: float,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Choose the mixture grid and, optionally, an ash-based initialisation.

    * ``ash_init=False``: the grid is ``autoselect_scales_mix_norm(mult=mult)``
      and no initial weights are returned.
    * ``ash_init=True``: a full ``ash`` fit (normal prior, ``optimizer="lbfgs"``,
      same ``mult``) is run; its grid is pruned to components with
      ``pi > ash_threshold`` and the surviving weights seed the bias /
      cut-points.  (Presumably ash builds its grid with the same autoselect
      routine; confirm in ``cebmf_torch.ebnm.ash``.)

    Parameters
    ----------
    mult : float
        Multiplicative step between grid SDs; smaller means a finer grid
        (roughly 27 components at sqrt(2) and 15 at 2.0 for typical data).
    ash_init : bool
        Whether to run ash for pruning and initialisation.
    ash_threshold : float
        Components with ``pi <= ash_threshold`` are dropped (ash path only).

    Returns
    -------
    scale : tensor (K,)
        Component standard deviations, float32 on ``device``.
    log_pi_init : tensor (K,) or None
        Mean-centred log-weights of the kept components, or None when
        ``ash_init=False``.
    """
    if not ash_init:
        grid = autoselect_scales_mix_norm(betahat=betahat, sebetahat=sebetahat, mult=mult)
        if isinstance(grid, torch.Tensor):
            grid = grid.to(device=device, dtype=torch.float32)
        else:
            grid = torch.as_tensor(grid, dtype=torch.float32, device=device)
        return grid, None

    fit = ash(betahat, sebetahat, prior=PriorType.NORM, verbose=False, optimizer="lbfgs", mult=mult)
    weights = fit.pi
    keep = weights > ash_threshold

    if keep.sum() < 2:
        # Too few active components: keep index 0 (assumed to be the spike)
        # plus the single heaviest remaining component.
        keep = torch.zeros_like(weights, dtype=torch.bool)
        keep[0] = True
        masked = weights.clone()
        masked[0] = -1.0
        keep[masked.argmax()] = True

    grid = fit.scale[keep].to(device=device, dtype=torch.float32)
    log_w = torch.log(weights[keep].clamp(min=1e-30))
    log_w = (log_w - log_w.mean()).to(device=device, dtype=torch.float32)
    return grid, log_w


def _train_model(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    X_scaled: torch.Tensor,
    betahat: torch.Tensor,
    sebetahat: torch.Tensor,
    scale: torch.Tensor,
    n_epochs: int,
    batch_size: int,
    penalty: float,
    verbose: bool,
    label: str,
    seed: int = 42,
) -> float:
    """Run the training loop. Returns the final-epoch total loss.

    Pre-computes the (G, K) log-likelihood matrix once rather than
    recomputing per mini-batch (logL is constant during training).
    Batch ordering is seeded for reproducibility.
    """
    model.train()
    device = X_scaled.device

    # Pre-compute log-likelihood matrix (constant during training).
    loc = torch.zeros_like(scale)
    with torch.no_grad():
        logL_all = get_data_loglik_normal_torch(
            betahat=betahat,
            sebetahat=sebetahat,
            location=loc,
            scale=scale,
        )

    # Seeded manual batching (5x faster than DataLoader due to
    # avoiding per-sample __getitem__ and collation overhead).
    # Generator and permutation are created on the same device as the
    # data to avoid CPU/GPU device mismatches.
    g = torch.Generator(device=device)
    g.manual_seed(seed)
    n = X_scaled.shape[0]
    n_batches = max(1, (n + batch_size - 1) // batch_size)

    final_epoch_loss = 0.0
    for epoch in range(n_epochs):
        epoch_loss = 0.0
        perm = torch.randperm(n, generator=g, device=device)
        for start in range(0, n, batch_size):
            idx = perm[start : start + batch_size]
            pi_pred = model(X_scaled[idx])
            loss = pen_loglik_loss(pi_pred, logL_all[idx], penalty=penalty)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        final_epoch_loss = epoch_loss
        if verbose and (epoch + 1) % 50 == 0:
            print(f"[{label}] Epoch {epoch + 1}/{n_epochs} | Loss: {epoch_loss / n_batches:.4f}")

    return final_epoch_loss


def _compute_posteriors(
    model: nn.Module,
    X_scaled: torch.Tensor,
    betahat: torch.Tensor,
    sebetahat: torch.Tensor,
    scale: torch.Tensor,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """Posterior moments under a per-observation (covariate-driven) prior.

    Every mixture component is a zero-mean normal with SD ``scale[k]``
    (``scale[k] == 0`` is handled as a point mass at 0, i.e. the spike), and
    each observation ``g`` gets its own weights ``pi[g]`` from ``model``.

    Returns
    -------
    post_mean, post_mean2, post_sd : tensors
        Posterior mean, second moment and SD of each effect.
    all_pi_values : tensor (G, K)
        Per-observation mixture weights predicted by ``model``.
    marginal_loglik : float
        Full-data marginal log-likelihood
        ``sum_g logsumexp_k(log pi_{g,k} + log p(beta_g | 0, sqrt(se_g^2 + scale_k^2)))``
        with no spike Dirichlet penalty.  This is the quantity the cEBMF
        caller (``cebmf.py``, ``kl_l[k] = (-resL.loss) - nm_ll_L``) expects,
        negated, in the result's ``loss`` field.
    """
    model.eval()
    zero_loc = torch.zeros_like(scale)
    with torch.no_grad():
        all_pi_values = model(X_scaled)  # (G, K)
        data_loglik = get_data_loglik_normal_torch(
            betahat=betahat, sebetahat=sebetahat, location=zero_loc, scale=scale
        )  # (G, K)

        # Floor the weights at the dtype's smallest positive normal value so
        # that log() never yields -inf (matters in float32).
        floor = torch.finfo(all_pi_values.dtype).tiny
        log_joint = data_loglik + all_pi_values.clamp(min=floor).log()  # (G, K)
        log_evidence = torch.logsumexp(log_joint, dim=1, keepdim=True)  # (G, 1)
        # Per-row evidence summed over rows = full-data marginal log-lik.
        marginal_loglik = float(log_evidence.sum().item())
        resp = (log_joint - log_evidence).exp()  # (G, K) responsibilities

        obs_var = sebetahat.pow(2).unsqueeze(1)  # (G, 1)
        prior_var = scale.pow(2).unsqueeze(0)  # (1, K)
        slab = prior_var > 0  # False only for point-mass components

        # Normal-normal conjugate update per slab component; the point mass
        # has zero posterior variance and zero posterior mean.
        precision = (1.0 / obs_var) + torch.where(slab, 1.0 / prior_var, torch.zeros_like(prior_var))
        comp_var = torch.where(slab, 1.0 / precision, torch.zeros_like(precision))  # (G, K)
        comp_mean = torch.where(
            slab,
            comp_var * (betahat.unsqueeze(1) / obs_var),
            torch.zeros(1, device=device),
        )  # (G, K)

        post_mean = (resp * comp_mean).sum(dim=1)
        post_mean2 = (resp * (comp_var + comp_mean.pow(2))).sum(dim=1)
        post_sd = (post_mean2 - post_mean.pow(2)).clamp(min=0.0).sqrt()

    return post_mean, post_mean2, post_sd, all_pi_values, marginal_loglik


def _warm_start(
    model: nn.Module,
    model_param: dict | None,
    label: str,
) -> None:
    """Load state dict with a guard against architecture mismatch."""
    if model_param is not None:
        try:
            model.load_state_dict(model_param)
        except RuntimeError:
            warnings.warn(
                f"{label} warm-start skipped: grid size changed between iterations",
                stacklevel=3,
            )


def _build_lcash_optimizer(
    model: nn.Module,
    lr: float,
    weight_decay: float,
) -> torch.optim.Optimizer:
    """Build Adam with L2 weight decay on the feature coefficients only.

    The intercept-like parameters (``LcashNet`` bias, ``PropOddsLcashNet``
    cut-points) hold the ash initialisation and are left unpenalised.
    """
    if isinstance(model, LcashNet):
        feature_params = [model.linear.weight]
        intercept_params = [model.linear.bias]
    else:  # PropOddsLcashNet
        feature_params = [model.w]
        intercept_params = [model.delta_1]
        if model.delta_gaps is not None:
            intercept_params.append(model.delta_gaps)
    param_groups = [
        {"params": feature_params, "weight_decay": weight_decay},
        {"params": intercept_params, "weight_decay": 0.0},
    ]
    return torch.optim.Adam(param_groups, lr=lr)


def _fit_lcash(
    X: torch.Tensor,
    betahat: torch.Tensor,
    sebetahat: torch.Tensor,
    model_class: type,
    label: str,
    n_epochs: int = 200,
    batch_size: int = 512,
    lr: float = 1e-3,
    weight_decay: float = 1e-3,
    penalty: float = 1.5,
    mult: float = 1.4142135623730951,
    ash_init: bool = True,
    ash_threshold: float = 1e-6,
    model_param: dict | None = None,
    device: torch.device | None = None,
    verbose: bool = True,
    seed: int = 42,
) -> cash_PosteriorMeanNorm:
    """Shared implementation for softmax and proportional-odds LC-ASH.

    Pipeline: prepare/standardise inputs -> select (and optionally prune)
    the mixture grid -> build and optionally warm-start the model -> train
    with Adam -> compute posterior moments.

    Parameters
    ----------
    model_class : type
        Either ``LcashNet`` or ``PropOddsLcashNet``.
    label : str
        Prefix for verbose logging and warnings (e.g. "LC-ASH", "PO-LC-ASH").

    See ``lcash_posterior_means`` for the remaining parameters.
    """
    # Inherit the device from the input tensor when available; avoids silent
    # device hops if the caller (e.g. cEBMF) is on CPU/MPS while CUDA is visible.
    if device is None:
        device = (
            betahat.device
            if isinstance(betahat, torch.Tensor)
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
    if n_epochs is None:
        n_epochs = 200

    X_scaled, betahat, sebetahat = _prepare_inputs(X, betahat, sebetahat, device)
    scale, log_pi_init = _select_grid(betahat, sebetahat, mult, ash_init, ash_threshold, device)

    # Local, seeded RNG for the weight initialisation.  The model constructors
    # allocate their parameters on PyTorch's default device (CPU unless
    # overridden) and the model is only moved to `device` afterwards.  The
    # generator must therefore live on that allocation device; a CUDA/MPS
    # generator makes `nn.init.normal_` on a CPU tensor raise.  Note that
    # nn.Linear's own default init inside LcashNet still draws from the
    # global RNG before the seeded re-initialisation.  Batch ordering uses a
    # separate generator inside `_train_model`.
    init_rng = torch.Generator(device=torch.empty(0).device)
    init_rng.manual_seed(seed)

    K = scale.shape[0]
    model = model_class(X_scaled.shape[1], K, log_pi_init=log_pi_init, generator=init_rng).to(device)
    _warm_start(model, model_param, label)

    optimizer = _build_lcash_optimizer(model, lr, weight_decay)

    _train_model(
        model,
        optimizer,
        X_scaled,
        betahat,
        sebetahat,
        scale,
        n_epochs,
        batch_size,
        penalty,
        verbose,
        label,
        seed=seed,
    )

    post_mean, post_mean2, post_sd, all_pi_values, marginal_loglik = _compute_posteriors(
        model,
        X_scaled,
        betahat,
        sebetahat,
        scale,
        device,
    )

    # `loss` is the negative full-data marginal log-likelihood under the
    # fitted prior, *without* the spike Dirichlet penalty.  This matches
    # `cebnm/emdn.py` and is what `cebmf.py` needs for its per-factor
    # `kl_l[k] = (-loss) - nm_ll_L`.  The final-epoch training loss (which
    # includes the penalty) is deliberately not used here.
    return cash_PosteriorMeanNorm(
        post_mean=post_mean,
        post_mean2=post_mean2,
        post_sd=post_sd,
        pi_np=all_pi_values,
        loss=-marginal_loglik,
        scale=scale,
        model_param=model.state_dict(),
    )


# ============================================================
# Public entry points
# ============================================================


def lcash_posterior_means(
    X: torch.Tensor,
    betahat: torch.Tensor,
    sebetahat: torch.Tensor,
    n_epochs: int | None = 200,
    batch_size: int = 512,
    lr: float = 1e-3,
    weight_decay: float = 1e-3,
    penalty: float = 1.5,
    mult: float = 1.4142135623730951,
    ash_init: bool = True,
    ash_threshold: float = 1e-6,
    model_param: dict | None = None,
    device: torch.device | None = None,
    verbose: bool = True,
    seed: int = 42,
) -> cash_PosteriorMeanNorm:
    """LC-ASH: mixture weights modulated linearly by covariates (softmax).

    Parameters
    ----------
    X : tensor (G, F)
        Feature matrix.  Standardised internally with NaN-aware statistics
        (mean/std over non-NaN values, NaN positions set to 0), so no
        pre-standardisation is needed.
    betahat : tensor (G,)
        Effect estimates.
    sebetahat : tensor (G,)
        Standard errors.
    n_epochs : int or None
        Training epochs; None means 200.  Inside cEBMF this is overridden
        by ``internal_epoch``.
    batch_size : int
        Mini-batch size for Adam.
    lr : float
        Learning rate.
    weight_decay : float
        L2 penalty on the feature coefficients only (not the bias).
    penalty : float
        Dirichlet spike penalty (lambda_pen); 1.0 means no penalty.
    mult : float
        Multiplicative step between mixture grid SDs; smaller gives a finer
        grid.  The default sqrt(2) matches R ashr (about 27 components
        before pruning for typical data).
    ash_init : bool
        If True (default), run ash internally (L-BFGS) to prune the grid to
        its active components and initialise the bias from the ash weights,
        so the model starts at the exchangeable ash solution when all
        feature coefficients are zero.  If False, use the full grid with a
        zero (uniform-weight) bias.
    ash_threshold : float
        Components with ``pi <= ash_threshold`` are dropped.  Only used
        when ``ash_init=True``.
    model_param : dict or None
        State dict from a previous call, for warm-starting.  Ignored (with a
        warning) if it no longer matches the model's shape.
    device : torch.device or None
        Compute device.  Defaults to ``betahat``'s device when it is a
        tensor, otherwise CUDA if available, else CPU.
    verbose : bool
        If True (default), print training progress every 50 epochs.
    seed : int
        Random seed for weight initialisation and batch ordering.

    Returns
    -------
    cash_PosteriorMeanNorm
        Container with ``post_mean``, ``post_mean2``, ``post_sd``,
        ``pi_np`` (G, K), ``scale`` (K,), ``loss`` (negative marginal
        log-likelihood, no spike penalty) and ``model_param`` (state dict
        for warm-starting).
    """
    return _fit_lcash(
        X=X,
        betahat=betahat,
        sebetahat=sebetahat,
        model_class=LcashNet,
        label="LC-ASH",
        n_epochs=n_epochs,
        batch_size=batch_size,
        lr=lr,
        weight_decay=weight_decay,
        penalty=penalty,
        mult=mult,
        ash_init=ash_init,
        ash_threshold=ash_threshold,
        model_param=model_param,
        device=device,
        verbose=verbose,
        seed=seed,
    )
def po_lcash_posterior_means(
    X: torch.Tensor,
    betahat: torch.Tensor,
    sebetahat: torch.Tensor,
    n_epochs: int | None = 200,
    batch_size: int = 512,
    lr: float = 1e-3,
    weight_decay: float = 1e-3,
    penalty: float = 1.5,
    mult: float = 1.4142135623730951,
    ash_init: bool = True,
    ash_threshold: float = 1e-6,
    model_param: dict | None = None,
    device: torch.device | None = None,
    verbose: bool = True,
    seed: int = 42,
) -> cash_PosteriorMeanNorm:
    """Proportional-odds LC-ASH: ordered-logistic covariate-modulated weights.

    A shared weight vector maps features to a scalar signal strength
    ``s_i = x_i^T w``; K-1 ordered cut-points turn ``s_i`` into mixture
    weights through cumulative logistic probabilities.  This uses F + K - 1
    parameters (vs K * F for softmax LC-ASH), which is more parsimonious
    when K is large relative to F.

    With ``ash_init=True`` the grid is pruned to ash's active components and
    the cut-points are initialised from the ash weights, so the model starts
    at the exchangeable ash solution.  With ``ash_init=False`` the full grid
    is used and the cut-points start evenly spaced on [-2, 2].

    Parameters
    ----------
    X : tensor (G, F)
        Feature matrix.  Standardised internally with NaN-aware statistics.
    betahat, sebetahat, n_epochs, batch_size, lr, weight_decay, penalty,
    mult, ash_init, ash_threshold, model_param, device, verbose, seed :
        As in ``lcash_posterior_means``; ``weight_decay`` applies to the
        shared feature weights only, not to the cut-points.

    Returns
    -------
    cash_PosteriorMeanNorm
        Container with ``post_mean``, ``post_mean2``, ``post_sd``,
        ``pi_np`` (G, K), ``scale`` (K,), ``loss`` (negative marginal
        log-likelihood, no spike penalty) and ``model_param`` (state dict
        for warm-starting).
    """
    return _fit_lcash(
        X=X,
        betahat=betahat,
        sebetahat=sebetahat,
        model_class=PropOddsLcashNet,
        label="PO-LC-ASH",
        n_epochs=n_epochs,
        batch_size=batch_size,
        lr=lr,
        weight_decay=weight_decay,
        penalty=penalty,
        mult=mult,
        ash_init=ash_init,
        ash_threshold=ash_threshold,
        model_param=model_param,
        device=device,
        verbose=verbose,
        seed=seed,
    )