# Source code for cebmf_torch.cebnm.cash_solver

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from cebmf_torch.utils.distribution_operation import get_data_loglik_normal_torch
from cebmf_torch.utils.mixture import autoselect_scales_mix_norm
from cebmf_torch.utils.posterior import posterior_mean_norm
from cebmf_torch.utils.standard_scaler import standard_scale


# ---- Dataset: assumes tensors already on correct device/dtype
class DensityRegressionDataset(Dataset):
    """
    Map-style dataset yielding ``(X[i], betahat[i], sebetahat[i])`` triples.

    The three inputs are stored as given (no copy, no device or dtype
    conversion), so callers are expected to pass tensors that are already
    on the desired device. The dataset length is the size of the first
    dimension of ``X``.
    """

    def __init__(self, X: torch.Tensor, betahat: torch.Tensor, sebetahat: torch.Tensor):
        # Keep references only; indexing happens lazily in __getitem__.
        self.X, self.betahat, self.sebetahat = X, betahat, sebetahat

    def __len__(self):
        n_rows = self.X.shape[0]
        return n_rows

    def __getitem__(self, idx):
        covariates = self.X[idx]
        estimate = self.betahat[idx]
        std_error = self.sebetahat[idx]
        return covariates, estimate, std_error


class CashNet(nn.Module):
    """
    Fully connected network mapping covariates to mixture weights for CASH
    (Covariate Adaptive Shrinkage).

    Architecture: one input projection, ``n_layers`` hidden-to-hidden linear
    layers (each followed by ReLU), and a linear output layer whose result is
    passed through a softmax over the component dimension.
    """

    def __init__(self, input_dim, hidden_dim, num_classes, n_layers):
        """
        Build the layers of the network.

        Parameters
        ----------
        input_dim : int
            Number of input features.
        hidden_dim : int
            Width of every hidden layer.
        num_classes : int
            Number of mixture components (size of the output).
        n_layers : int
            Number of ``hidden_dim -> hidden_dim`` layers placed after the
            input layer.
        """
        super().__init__()
        # Layers are created in forward order so that parameter registration
        # (and therefore state_dict key order and seeded init) is stable.
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        hidden_stack = []
        for _ in range(n_layers):
            hidden_stack.append(nn.Linear(hidden_dim, hidden_dim))
        self.hidden_layers = nn.ModuleList(hidden_stack)
        self.output_layer = nn.Linear(hidden_dim, num_classes)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """
        Compute per-observation mixture weights.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (N, input_dim).

        Returns
        -------
        torch.Tensor
            Tensor of shape (N, num_classes); each row is a softmax output.
        """
        activation = x
        # The input projection and every hidden layer share the same
        # "linear then ReLU" step, so they are applied in one pass.
        for linear in (self.input_layer, *self.hidden_layers):
            activation = self.relu(linear(activation))
        logits = self.output_layer(activation)
        return self.softmax(logits)


# Custom loss function
def pen_loglik_loss(pred_pi, marginal_log_lik, penalty=1.5, epsilon=1e-10):
    """
    Negative penalised mixture log-likelihood summed over a batch.

    For every observation the mixture likelihood ``sum_k pi_k * L_k`` is
    evaluated in log space as ``logsumexp_k(log pi_k + log L_k)``. Working
    in log space avoids the underflow of ``exp(marginal_log_lik)`` for
    strongly negative log-likelihoods, which would otherwise pin the loss at
    ``log(epsilon)`` and zero out the gradient for those observations.

    Parameters
    ----------
    pred_pi : torch.Tensor
        Predicted mixture weights, shape (B, K).
    marginal_log_lik : torch.Tensor
        Per-component marginal log-likelihoods, shape (B, K).
    penalty : float, optional
        When greater than 1, adds ``(penalty - 1) * sum(log pi_0)`` to the
        log-likelihood (Dirichlet-like prior on component 0). Values <= 1
        disable the penalty.
    epsilon : float, optional
        Lower bound applied to the mixture weights before taking the log.

    Returns
    -------
    torch.Tensor
        Scalar tensor: the negated (penalised) log-likelihood.
    """
    # Floor the weights so that log() stays finite for zero probabilities.
    log_pi = torch.log(torch.clamp(pred_pi, min=epsilon))  # (B, K)
    per_obs_loglik = torch.logsumexp(log_pi + marginal_log_lik, dim=1)  # (B,)
    first_sum = torch.sum(per_obs_loglik)  # scalar

    if penalty > 1:
        # penalize per-gene spike probability (Dirichlet-like prior on component 0)
        penalized_log_likelihood_value = first_sum + (penalty - 1) * torch.sum(log_pi[:, 0])
    else:
        penalized_log_likelihood_value = first_sum

    return -penalized_log_likelihood_value


class cash_PosteriorMeanNorm:
    """
    Plain result container for the CASH posterior computation.

    Every constructor argument is stored unchanged on the instance under the
    attribute of the same name.
    """

    def __init__(self, post_mean, post_mean2, post_sd, pi_np, scale, loss=0, model_param=None):
        """
        Store the outputs of a CASH fit.

        Parameters
        ----------
        post_mean : torch.Tensor
            Posterior means for each observation.
        post_mean2 : torch.Tensor
            Posterior second moments for each observation.
        post_sd : torch.Tensor
            Posterior standard deviations for each observation.
        pi_np : torch.Tensor
            Mixture weights for each observation.
        scale : torch.Tensor
            Mixture component scales.
        loss : float, optional
            Final training loss or log-likelihood (default 0).
        model_param : dict, optional
            Trained model parameters (state_dict).
        """
        # Posterior summaries.
        self.post_mean, self.post_mean2, self.post_sd = post_mean, post_mean2, post_sd
        # Fitted prior (weights and scales) and the objective value.
        self.pi_np, self.loss, self.scale = pi_np, loss, scale
        # Network weights, kept so a later call can warm-start from them.
        self.model_param = model_param


# cash_PosteriorMeanNorm (defined above) stores the results returned by the solver below.


def cash_posterior_means(
    X,
    betahat,
    sebetahat,
    n_epochs=20,
    n_layers=4,
    num_classes=20,
    hidden_dim=64,
    batch_size=128,
    lr=0.001,
    model_param=None,
    penalty=1.5,
    device: torch.device | None = None,
):
    """
    GPU-native CASH training and posterior computation.

    Fit a CASH (Covariate Adaptive Shrinkage) model and compute posterior
    means, second moments, and standard deviations.

    Parameters
    ----------
    X : torch.Tensor or np.ndarray
        Covariates for each observation, shape (n_samples, n_features).
        A 1-D input is reshaped to a single-feature column.
    betahat : torch.Tensor or np.ndarray
        Observed effect estimates, shape (n_samples,).
    sebetahat : torch.Tensor or np.ndarray
        Standard errors of the effect estimates, shape (n_samples,).
    n_epochs : int, optional
        Number of training epochs (default=20).
    n_layers : int, optional
        Number of hidden layers in the neural network (default=4).
    num_classes : int, optional
        Number of mixture components (default=20).
    hidden_dim : int, optional
        Number of hidden units in each layer (default=64).
    batch_size : int, optional
        Batch size for training (default=128).
    lr : float, optional
        Learning rate for the optimizer (default=0.001).
    model_param : dict, optional
        Pre-trained model parameters to initialize the network.
    penalty : float, optional
        Penalty for spike probability (default=1.5).
    device : torch.device, optional
        Target device for tensors and the model. If ``None``, inherits from
        ``betahat`` when it is already a tensor; otherwise falls back to CUDA
        (if available) or CPU. Inheriting from the input avoids silent
        cross-device hops when the caller (e.g. ``cEBMF``) is on CPU/MPS but
        CUDA is also visible on the host.

    Returns
    -------
    cash_PosteriorMeanNorm
        Container with posterior means, second moments, standard deviations,
        per-observation mixture weights, mixture scales, the negated
        unpenalised marginal log-likelihood (``loss``), and the trained
        model parameters.
    """
    if device is None:
        device = (
            betahat.device
            if isinstance(betahat, torch.Tensor)
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )

    # ---- to tensors on device
    X = torch.as_tensor(X, dtype=torch.float32, device=device)
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    betahat = torch.as_tensor(betahat, dtype=torch.float32, device=device)
    sebetahat = torch.as_tensor(sebetahat, dtype=torch.float32, device=device)

    # ---- standardize on device
    X_scaled = standard_scale(X)

    # ---- mixture scales (ensure tensor on device)
    # NOTE(review): the network below emits ``num_classes`` weights; this
    # presumably requires autoselect_scales_mix_norm to return exactly
    # ``num_classes`` scales -- confirm against that helper.
    scale = autoselect_scales_mix_norm(betahat=betahat, sebetahat=sebetahat, max_class=num_classes)
    if not isinstance(scale, torch.Tensor):
        scale = torch.as_tensor(scale, dtype=torch.float32, device=device)
    else:
        scale = scale.to(device=device, dtype=torch.float32)

    # Shared zero location for every mixture component. Built with zeros_like
    # rather than ``0 * scale`` so a non-finite scale cannot turn it into NaN,
    # and created once instead of once per batch.
    location = torch.zeros_like(scale)

    # ---- dataset / loader (CUDA tensors => num_workers must be 0)
    dataset = DensityRegressionDataset(X_scaled, betahat, sebetahat)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # ---- model / optimizer on device
    input_dim = X_scaled.shape[1]
    model_cash = CashNet(
        input_dim=input_dim, hidden_dim=hidden_dim, num_classes=num_classes, n_layers=n_layers
    ).to(device)
    if model_param is not None:
        model_cash.load_state_dict(model_param)
    optimizer_cash = optim.Adam(model_cash.parameters(), lr=lr)

    # ---- training
    for epoch in range(n_epochs):
        epoch_loss = 0.0
        for inputs, targets, noise_std in dataloader:
            # Compute (log)likelihood for this batch and current global scales
            batch_loglik = get_data_loglik_normal_torch(
                betahat=targets, sebetahat=noise_std, location=location, scale=scale
            )
            optimizer_cash.zero_grad()
            outputs = model_cash(inputs)
            cash_loss = pen_loglik_loss(pred_pi=outputs, marginal_log_lik=batch_loglik, penalty=penalty)
            cash_loss.backward()
            optimizer_cash.step()
            epoch_loss += cash_loss.item()
        if (epoch + 1) % 10 == 0:
            print(f"[CASH] Epoch {epoch + 1}/{n_epochs} | Loss: {epoch_loss / max(1, len(dataloader)):.4f}")

    # ---- full-batch inference (no grad)
    model_cash.eval()
    with torch.no_grad():
        # One forward pass over all rows. Going through a DataLoader with
        # batch_size=len(dataset) would index and re-stack every row
        # individually only to rebuild X_scaled.
        all_pi_values = model_cash(X_scaled)  # (N, K)

        data_loglik = get_data_loglik_normal_torch(
            betahat=betahat, sebetahat=sebetahat, location=location, scale=scale
        )  # (N, K)

        # Floor the weights before the log. The floor must be representable
        # in the tensor dtype: a literal such as 1e-300 becomes 0 in float32
        # and would let log(0) = -inf through.
        tiny = torch.finfo(all_pi_values.dtype).tiny
        log_pi_full = torch.log(all_pi_values.clamp_min(tiny))  # (N, K)

        # ---- Vectorised posterior over all observations in a single call;
        # ``posterior_mean_norm`` is given the (N, K) ``log_pi`` directly.
        result = posterior_mean_norm(
            betahat=betahat,
            sebetahat=sebetahat,
            log_pi=log_pi_full,
            data_loglik=data_loglik,
            location=location,  # shared (K,) zero location
            scale=scale,
        )
        post_mean = result.post_mean
        post_mean2 = result.post_mean2
        post_sd = result.post_sd

        # ---- Full marginal log-likelihood (no penalty). This matches the
        # convention expected by ``cebmf.py``'s
        # ``self.kl_l[k] = (-resL.loss) - nm_ll_L`` accumulator.
        log_marginal_per_obs = torch.logsumexp(data_loglik + log_pi_full, dim=1)  # (N,)
        full_marginal_ll = float(log_marginal_per_obs.sum().item())

    return cash_PosteriorMeanNorm(
        post_mean=post_mean,
        post_mean2=post_mean2,
        post_sd=post_sd,
        pi_np=all_pi_values,  # (N, K) on device
        loss=-full_marginal_ll,
        scale=scale,  # (K,) on device
        model_param=model_cash.state_dict(),
    )