# Source code for cebmf_torch.cebnm.cov_gb_prior

# ============================================================
# Covariate Generalized-Binary Prior Solver (CGB Solver, Torch-only)
# ============================================================

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from cebmf_torch.utils.posterior import posterior_point_mass_normal
from cebmf_torch.utils.standard_scaler import standard_scale


# -------------------------
# Dataset
# -------------------------
class DensityRegressionDataset(Dataset):
    """
    Index-aligned view over covariates, effect estimates and standard errors.

    The three containers are stored exactly as given (no copy, no conversion)
    and are indexed in lockstep: item ``i`` is
    ``(X[i], betahat[i], sebetahat[i])``.
    """

    def __init__(self, X, betahat, sebetahat):
        self.X, self.betahat, self.sebetahat = X, betahat, sebetahat

    def __len__(self):
        # The sample count is taken from the covariate container alone.
        return len(self.X)

    def __getitem__(self, idx):
        return tuple(field[idx] for field in (self.X, self.betahat, self.sebetahat))


# -------------------------
# MDN Model: π₂(x) + global μ₂
# -------------------------
class CgbNet(nn.Module):
    """
    Feed-forward network for the Covariate Generalized-Binary (CGB) prior.

    Maps covariates to a per-observation slab weight π₂(x) through a sigmoid
    head, and carries one learnable scalar ``mu_2`` shared by all observations.
    """

    def __init__(self, input_dim, hidden_dim=32, n_layers=2):
        """
        Build the network.

        Parameters
        ----------
        input_dim : int
            Number of input features.
        hidden_dim : int, optional
            Width of every hidden layer (default: 32).
        n_layers : int, optional
            Number of hidden-to-hidden linear layers stacked after the input
            layer (default: 2).
        """
        super().__init__()
        # Sub-modules are created in a fixed order so that parameter
        # initialisation consumes the RNG in the same sequence every time.
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.hidden_layers = nn.ModuleList(
            nn.Linear(hidden_dim, hidden_dim) for _ in range(n_layers)
        )
        self.output_layer = nn.Linear(hidden_dim, 1)  # logit for π₂(x)
        self.mu_2 = nn.Parameter(torch.tensor(0.0))  # global mean of slab

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        Compute the mixture weights for a batch of covariates.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (N, input_dim).

        Returns
        -------
        pi_1 : torch.Tensor
            Spike weight per observation, ``1 - pi_2``.
        pi_2 : torch.Tensor
            Slab weight per observation (sigmoid of the output logit, with the
            trailing singleton dimension removed).
        mu_2 : torch.nn.Parameter
            The shared slab mean. This is the live parameter object, not a
            detached copy.
        """
        hidden = self.relu(self.input_layer(x))
        for block in self.hidden_layers:
            hidden = self.relu(block(hidden))

        slab_logit = self.output_layer(hidden).squeeze(-1)  # (N,)
        pi_2 = self.sigmoid(slab_logit)
        pi_1 = 1.0 - pi_2
        return pi_1, pi_2, self.mu_2


# -------------------------
# Loss (mixture NLL, stable)
# -------------------------
def cgb_loss(pi_1, pi_2, mu_2, sigma2_sq, targets, se, penalty=1.5, eps=1e-8):
    """
    Mean negative log-likelihood of a two-component Gaussian mixture.

    The spike component is N(0, se²) and the slab component is
    N(mu_2, sigma2_sq + se²). Mixture weights are clamped from below at
    ``eps`` before taking logs, and the two branches are combined with
    ``torch.logaddexp``.

    Parameters
    ----------
    pi_1, pi_2 : torch.Tensor
        Spike and slab weights per observation.
    mu_2 : torch.Tensor or float
        Slab mean.
    sigma2_sq : torch.Tensor or float
        Slab prior variance.
    targets : torch.Tensor
        Observed values.
    se : torch.Tensor
        Standard errors of the observations.
    penalty : float, optional
        When greater than 1, ``(penalty - 1) * log(pi_1)`` is added to each
        observation's log-likelihood before averaging (default: 1.5).
    eps : float, optional
        Lower clamp applied to the weights before the log (default: 1e-8).

    Returns
    -------
    torch.Tensor
        Scalar: minus the mean of the (optionally penalised) log-mixture.
    """

    def _gauss_logpdf(value, mean, variance):
        # Log-density of N(mean, variance) evaluated at value.
        return -0.5 * ((value - mean) ** 2 / variance + torch.log(2 * torch.pi * variance))

    spike_ll = _gauss_logpdf(targets, 0.0, se**2)
    slab_ll = _gauss_logpdf(targets, mu_2, sigma2_sq + se**2)

    log_w_spike = torch.log(pi_1.clamp_min(eps))
    log_w_slab = torch.log(pi_2.clamp_min(eps))
    per_obs = torch.logaddexp(log_w_spike + spike_ll, log_w_slab + slab_ll)

    if penalty > 1.0:
        # Penalize per-observation spike probability
        per_obs = per_obs + (penalty - 1.0) * log_w_spike

    return -(per_obs.mean())


# -------------------------
# E-step responsibilities (γ₂)
# -------------------------
def compute_responsibilities(pi_1, pi_2, mu_2, sigma2_sq, targets, se):
    """
    Posterior probability γ₂ that each observation came from the slab.

    The spike is N(0, se²) and the slab is N(mu_2, sigma2_sq + se²). The ratio
    is formed in log space (weights clamped at 1e-12, denominator via
    ``torch.logaddexp``) and exponentiated at the end.

    Parameters
    ----------
    pi_1, pi_2 : torch.Tensor
        Spike and slab weights per observation.
    mu_2 : torch.Tensor or float
        Slab mean.
    sigma2_sq : torch.Tensor or float
        Slab prior variance.
    targets : torch.Tensor
        Observed values.
    se : torch.Tensor
        Standard errors of the observations.

    Returns
    -------
    torch.Tensor
        Slab responsibilities, one per observation.
    """
    spike_var = se**2
    slab_var = sigma2_sq + se**2

    spike_ll = -0.5 * ((targets - 0.0) ** 2 / spike_var + torch.log(2 * torch.pi * spike_var))
    slab_ll = -0.5 * ((targets - mu_2) ** 2 / slab_var + torch.log(2 * torch.pi * slab_var))

    weighted_slab = torch.log(pi_2.clamp_min(1e-12)) + slab_ll
    weighted_spike = torch.log(pi_1.clamp_min(1e-12)) + spike_ll
    log_evidence = torch.logaddexp(weighted_spike, weighted_slab)
    return (weighted_slab - log_evidence).exp()


# -------------------------
# M-step for σ₂²
# -------------------------
def m_step_sigma2(gamma2, mu2, targets, se):
    """
    Responsibility-weighted update of the slab variance σ₂².

    Averages ``(targets - mu2)² - se²`` with weights ``gamma2``. The weight
    total is floored at 1e-8 to avoid dividing by zero, and the result is
    floored at 1e-6 so the returned variance is always positive.

    Parameters
    ----------
    gamma2 : torch.Tensor
        Slab responsibilities per observation.
    mu2 : torch.Tensor or float
        Slab mean.
    targets : torch.Tensor
        Observed values.
    se : torch.Tensor
        Standard errors of the observations.

    Returns
    -------
    torch.Tensor
        Scalar variance estimate, at least 1e-6.
    """
    excess_var = (targets - mu2) ** 2 - se**2
    weighted_total = (gamma2 * excess_var).sum()
    weight_mass = gamma2.sum().clamp_min(1e-8)
    return (weighted_total / weight_mass).clamp(min=1e-6)


# -------------------------
# Result container
# -------------------------
class CgbPosteriorResult:
    """Plain attribute bag holding the outputs of the CGB solver."""

    def __init__(self, post_mean, post_mean2, post_sd, pi, mu_2, sigma_2, loss, model_param):
        """
        Store the solver outputs unchanged.

        Parameters
        ----------
        post_mean : torch.Tensor
            Posterior mean of each observation.
        post_mean2 : torch.Tensor
            Posterior second moment of each observation.
        post_sd : torch.Tensor
            Posterior standard deviation of each observation.
        pi : torch.Tensor
            Spike weight π₀(x) of each observation.
        mu_2 : float
            Global mean of the slab component.
        sigma_2 : float
            Global standard deviation of the slab component.
        loss : float
            Scalar objective reported by the solver; the solver in this module
            stores the negative full-data marginal log-likelihood here.
        model_param : dict
            Trained network parameters (a ``state_dict``).
        """
        # Per-observation posterior summaries.
        self.post_mean, self.post_mean2, self.post_sd = post_mean, post_mean2, post_sd
        # Fitted prior: π₀(x) spike weight, then the slab's global mean and sd.
        self.pi = pi
        self.mu_2, self.sigma_2 = mu_2, sigma_2
        # Fit diagnostics and the weights needed to warm-start a later fit.
        self.loss, self.model_param = loss, model_param


@torch.no_grad()
def compute_marginal_loglik_full(model, X, betahat, se, sigma2_sq, eps=1e-12):
    """
    Unpenalised marginal log-likelihood summed over every observation passed in.

    Switches ``model`` to eval mode (it is not switched back), evaluates it
    once on ``X``, and scores ``betahat`` under the mixture of a N(0, se²)
    spike and a N(mu2, se² + sigma2_sq) slab. Runs with gradients disabled.

    Parameters
    ----------
    model : callable with ``eval()``
        Returns ``(pi1, pi2, mu2)`` when called on ``X``.
    X : torch.Tensor
        Covariates fed to ``model``.
    betahat : torch.Tensor
        Observed values.
    se : torch.Tensor
        Standard errors of the observations.
    sigma2_sq : torch.Tensor or float
        Slab prior variance.
    eps : float, optional
        Lower clamp applied to the mixture weights before the log
        (default: 1e-12).

    Returns
    -------
    torch.Tensor
        Scalar sum of per-observation log mixture densities.
    """
    model.eval()
    w_spike, w_slab, slab_mean = model(X)

    spike_var = se**2
    slab_var = se**2 + sigma2_sq

    # Gaussian log-densities of the two components.
    spike_ll = -0.5 * ((betahat - 0.0) ** 2 / spike_var + torch.log(2 * torch.pi * spike_var))
    slab_ll = -0.5 * ((betahat - slab_mean) ** 2 / slab_var + torch.log(2 * torch.pi * slab_var))

    # Combine in log space so tiny densities do not underflow.
    spike_term = w_spike.clamp_min(eps).log() + spike_ll
    slab_term = w_slab.clamp_min(eps).log() + slab_ll
    return torch.logaddexp(spike_term, slab_term).sum()  # scalar


# -------------------------
# Main solver
# -------------------------
def cgb_posterior_means(
    X,
    betahat,
    sebetahat,
    n_epochs=50,
    n_layers=2,
    hidden_dim=32,
    batch_size=128,
    lr=1e-3,
    penalty: float = 1.5,
    model_param=None,
    device: torch.device | None = None,
):
    """
    Fit a Covariate Generalized-Binary (CGB) model to estimate the prior distribution of effects.

    Each epoch first refreshes the global slab variance with one E/M step on
    the full dataset, then updates the network (and the slab mean) by
    mini-batch Adam steps on the penalised mixture loss with that variance
    held fixed. Progress is printed every 10th epoch.

    Parameters
    ----------
    X : torch.Tensor or np.ndarray
        Covariates for each observation, shape (n_samples, n_features).
        A 1-D input is treated as a single feature column.
    betahat : torch.Tensor or np.ndarray
        Observed effect estimates, shape (n_samples,).
    sebetahat : torch.Tensor or np.ndarray
        Standard errors of the effect estimates, shape (n_samples,).
    n_epochs : int, optional
        Number of training epochs (default=50).
    n_layers : int, optional
        Number of hidden layers in the neural network (default=2).
    hidden_dim : int, optional
        Number of hidden units in each layer (default=32).
    batch_size : int, optional
        Batch size for training (default=128).
    lr : float, optional
        Learning rate for the optimizer (default=1e-3).
    penalty : float, optional
        Penalty for spike probability (default=1.5).
    model_param : dict, optional
        Pre-trained model parameters to initialize the network.
    device : torch.device, optional
        Device on which all tensors and the model are placed. When omitted,
        the device of ``betahat`` is used if it is a tensor; otherwise CUDA
        is used when available and CPU if not.

    Returns
    -------
    CgbPosteriorResult
        Container with posterior means, second moments and standard
        deviations, the per-observation spike weights, the fitted slab mean
        and standard deviation, the negative marginal log-likelihood
        (as ``loss``), and the trained model parameters.
    """
    # Inherit from input tensor when available; avoids silent device hops if
    # the caller (e.g. cEBMF) is on CPU/MPS but CUDA is also visible.
    if device is None:
        device = (
            betahat.device
            if isinstance(betahat, torch.Tensor)
            else (torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        )

    # ---- to tensor on device
    X = torch.as_tensor(X, dtype=torch.float32, device=device)
    betahat = torch.as_tensor(betahat, dtype=torch.float32, device=device)
    sebetahat = torch.as_tensor(sebetahat, dtype=torch.float32, device=device)
    if X.ndim == 1:
        X = X.reshape(-1, 1)

    # ---- scale on device
    X_scaled = standard_scale(X)  # stays on device

    # ---- dataset / loader (GPU tensors, keep num_workers=0)
    dataset = DensityRegressionDataset(X_scaled, betahat, sebetahat)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # ---- model / optimizer on device
    model = CgbNet(input_dim=X_scaled.shape[1], hidden_dim=hidden_dim, n_layers=n_layers).to(device)
    if model_param is not None:
        model.load_state_dict(model_param)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    sigma2_sq = torch.tensor(1.0, dtype=torch.float32, device=device)

    # ---- training
    for epoch in range(n_epochs):
        # 1. EM UPDATE (global slab variance, full dataset)
        # The whole update stays under no_grad: the model hands back its live
        # mu_2 parameter, and sigma2_sq must enter the loss below as a plain
        # constant rather than as part of an autograd graph.
        model.eval()
        with torch.no_grad():
            full_pi1, full_pi2, full_mu2 = model(dataset.X)

            # E-Step: Compute responsibilities across the entire dataset
            gamma2 = compute_responsibilities(
                full_pi1, full_pi2, full_mu2, sigma2_sq, dataset.betahat, dataset.sebetahat
            )

            # M-Step: Update global variance
            sigma2_sq = m_step_sigma2(gamma2, full_mu2, dataset.betahat, dataset.sebetahat)

        # 2. GRADIENT DESCENT (Neural Net update over batches)
        model.train()
        total_loss = 0.0
        for xb, xhat, se in dataloader:  # already device tensors
            pi1, pi2, mu2 = model(xb)

            # Calculate loss using the fixed, global sigma2_sq
            loss = cgb_loss(pi1, pi2, mu2, sigma2_sq, xhat, se, penalty=penalty)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if (epoch + 1) % 10 == 0:
            # mu2 shown here is the value read at the start of the epoch,
            # i.e. before this epoch's gradient steps.
            print(
                f"[CGB] Epoch {epoch + 1}/{n_epochs}, "
                f"Loss={total_loss / len(dataloader):.4f}, "
                f"mu2={full_mu2.item():.3f}, sigma2={sigma2_sq.sqrt().item():.3f}"
            )

    # ---- posterior inference
    model.eval()
    with torch.no_grad():
        pi1, pi2, mu2 = model(dataset.X)
        post_mean, post_var = posterior_point_mass_normal(
            betahat=dataset.betahat,
            sebetahat=dataset.sebetahat,
            pi=pi1,
            mu0=0.0,
            mu1=mu2.item(),
            sigma_0=sigma2_sq.sqrt().item(),
        )
        post_mean2 = post_var + post_mean**2
        # Clamp guards the square root against tiny negative variances.
        post_sd = torch.sqrt(torch.clamp(post_var, min=0.0))

    log_marginal = compute_marginal_loglik_full(
        model,
        X=dataset.X,  # X_scaled (full)
        betahat=dataset.betahat,  # full
        se=dataset.sebetahat,  # full
        sigma2_sq=sigma2_sq,
    )

    return CgbPosteriorResult(
        post_mean=post_mean,
        post_mean2=post_mean2,
        post_sd=post_sd,
        pi=pi1,
        mu_2=mu2.item(),
        sigma_2=sigma2_sq.sqrt().item(),
        loss=-float(log_marginal.item()),
        model_param=model.state_dict(),
    )