# Source code for cebmf_torch.cebnm.spiked_emdn

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from cebmf_torch.utils.distribution_operation import get_data_loglik_normal_torch
from cebmf_torch.utils.posterior import posterior_mean_norm
from cebmf_torch.utils.standard_scaler import standard_scale


# -------------------------
# Dataset (expects tensors already on correct device/dtype)
# -------------------------
class DensityRegressionDataset(Dataset):
    """Row-aligned view over covariates, effect estimates and standard errors.

    The three tensors are stored as given (no copy, no device or dtype
    conversion); sample ``i`` is the ``i``-th entry of each of them.
    """

    def __init__(self, X: torch.Tensor, betahat: torch.Tensor, sebetahat: torch.Tensor):
        """
        Store the three tensors backing the dataset.

        Parameters
        ----------
        X : torch.Tensor
            Covariates; the first dimension indexes samples.
        betahat : torch.Tensor
            Observed effect estimates, indexed like ``X``.
        sebetahat : torch.Tensor
            Standard errors of the estimates, indexed like ``X``.
        """
        self.X, self.betahat, self.sebetahat = X, betahat, sebetahat

    def __len__(self) -> int:
        """Return the number of samples, i.e. the size of ``X``'s first dimension."""
        n_samples = self.X.shape[0]
        return n_samples

    def __getitem__(self, idx):
        """Return the ``(X, betahat, sebetahat)`` triple stored at position ``idx``."""
        fields = (self.X, self.betahat, self.sebetahat)
        return tuple(field[idx] for field in fields)


# -------------------------
# Mixture Density Network (spike + slabs)
#   pi: (N, K) for [spike, slabs...]
#   mu/log_sigma: (N, K-1) for slabs only
# -------------------------
class MDN(nn.Module):
    """Mixture density network with one spike component and ``K - 1`` slabs.

    A ReLU multilayer perceptron feeds three linear heads: mixture weights for
    all ``K`` components (column 0 is the spike), and means / log standard
    deviations for the slab components only.
    """

    def __init__(self, input_dim, hidden_dim, n_gaussians, n_layers=4):
        """
        Build the network trunk and its three output heads.

        Parameters
        ----------
        input_dim : int
            Number of input features.
        hidden_dim : int
            Width of every hidden layer.
        n_gaussians : int
            Total number of mixture components ``K``; at least 2 (one spike
            plus one slab).
        n_layers : int, optional
            Number of hidden-to-hidden layers that follow the input layer
            (default is 4).
        """
        super().__init__()
        assert n_gaussians >= 2, "Need at least 1 spike + 1 slab."
        n_slabs = n_gaussians - 1

        # Trunk: input projection followed by `n_layers` square layers.
        self.fc_in = nn.Linear(input_dim, hidden_dim)
        self.hidden_layers = nn.ModuleList()
        for _ in range(n_layers):
            self.hidden_layers.append(nn.Linear(hidden_dim, hidden_dim))

        # Heads: weights cover spike + slabs, location/scale cover slabs only.
        self.pi = nn.Linear(hidden_dim, n_gaussians)
        self.mu = nn.Linear(hidden_dim, n_slabs)
        self.log_sigma = nn.Linear(hidden_dim, n_slabs)

        self.point_mass = 0.0

    def forward(self, x):
        """
        Map covariates to per-observation mixture parameters.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (N, input_dim).

        Returns
        -------
        pi : torch.Tensor
            Softmax-normalised mixture weights, shape (N, K); column 0 is the
            spike.
        mu : torch.Tensor
            Slab means, shape (N, K-1).
        log_sigma : torch.Tensor
            Slab log standard deviations, shape (N, K-1).
        """
        hidden = self.fc_in(x).relu()
        for block in self.hidden_layers:
            hidden = block(hidden).relu()

        mix_weights = self.pi(hidden).softmax(dim=1)  # (N, K)
        slab_means = self.mu(hidden)  # (N, K-1)

        # softplus keeps the slab sd positive; the 1e-6 offset keeps the log finite.
        slab_sd = nn.functional.softplus(self.log_sigma(hidden)) + 1e-6  # (N, K-1)
        return mix_weights, slab_means, slab_sd.log()


# -------------------------
# Loss: correct spike+slabs mixture + steerable spike penalty
# -------------------------
def mdn_spike_loss_with_varying_noise(
    pi,
    mu,
    log_sigma,
    betahat,
    sebetahat,
    *,
    penalty: float = 1.5,
    beta_prior: tuple | None = None,
    eps: float = 1e-8,
):
    """
    Compute the negative log-likelihood loss for a spike-and-slab mixture model with optional
    spike penalty and Beta prior.

    Parameters
    ----------
    pi : torch.Tensor
        Mixture weights for spike and slabs, shape (N, K).
    mu : torch.Tensor
        Means for slab components, shape (N, K-1).
    log_sigma : torch.Tensor
        Log standard deviations for slab components, shape (N, K-1).
    betahat : torch.Tensor
        Observed effect estimates, shape (N,).
    sebetahat : torch.Tensor
        Standard errors of the effect estimates, shape (N,).
    penalty : float, optional
        >1 encourages spike; =1 neutral (default is 1.5).
    beta_prior : tuple, optional
        (alpha0, beta0) for Beta prior on pi_spike.
    eps : float, optional
        Small value to avoid log(0) (default is 1e-8).

    Returns
    -------
    torch.Tensor
        The computed loss (scalar).
    """
    # Spike likelihood: mean=0, total var = se^2
    var_spike = sebetahat**2  # (N,)
    logp_spike = -0.5 * ((betahat**2) / var_spike + torch.log(2 * torch.pi * var_spike))  # (N,)

    # Slab likelihoods: mean=mu_j, total sd = sqrt(prior_sd^2 + se^2)
    sigma_slab = torch.exp(log_sigma)  # (N, K-1) prior sd
    total_sigma_slab = torch.sqrt(sigma_slab**2 + sebetahat.unsqueeze(1) ** 2)  # (N, K-1)
    dist_slab = torch.distributions.Normal(mu, total_sigma_slab)
    logp_slabs = dist_slab.log_prob(betahat.unsqueeze(1))  # (N, K-1)

    # Mixture log-likelihood = logsumexp over [spike, slabs...]
    log_terms_spike = torch.log(pi[:, :1].clamp_min(eps)) + logp_spike.unsqueeze(1)  # (N, 1)
    log_terms_slabs = torch.log(pi[:, 1:].clamp_min(eps)) + logp_slabs  # (N, K-1)
    all_log_terms = torch.cat([log_terms_spike, log_terms_slabs], dim=1)  # (N, K)
    nll = -torch.logsumexp(all_log_terms, dim=1).mean()

    # (A) simple steer: penalty>1 encourages spike
    reg_simple = 0.0
    if penalty != 1.0:
        lam = float(penalty) - 1.0  # >0 encourages spike
        reg_simple = -(lam) * torch.log(pi[:, 0].clamp_min(eps)).mean()

    # (B) optional Beta(alpha0, beta0) prior on pi_spike
    reg_beta = 0.0
    if beta_prior is not None:
        a0, b0 = map(float, beta_prior)
        one_minus_pi0 = (1.0 - pi[:, 0]).clamp_min(eps)
        reg_beta = -((a0 - 1.0) * torch.log(pi[:, 0].clamp_min(eps)) + (b0 - 1.0) * torch.log(one_minus_pi0)).mean()

    return nll + reg_simple + reg_beta


# -------------------------
# Result container
# -------------------------
class EmdnPosteriorMeanNorm:
    """Plain result holder for a spiked-EMDN fit; stores its arguments unchanged."""

    def __init__(
        self,
        post_mean,
        post_mean2,
        post_sd,
        location,
        pi_np,
        scale,
        loss=0,
        model_param=None,
    ):
        """
        Bundle posterior summaries, fitted mixture parameters and fit metadata.

        Parameters
        ----------
        post_mean : torch.Tensor
            Posterior means for each observation.
        post_mean2 : torch.Tensor
            Posterior second moments for each observation.
        post_sd : torch.Tensor
            Posterior standard deviations for each observation.
        location : array-like
            Mixture component means for each observation (the solver in this
            module passes a torch tensor).
        pi_np : array-like
            Mixture weights for each observation (despite the name, the solver
            in this module passes a torch tensor).
        scale : array-like
            Mixture component standard deviations for each observation (the
            solver in this module passes a torch tensor).
        loss : float, optional
            Scalar fit criterion (default is 0); the solver in this module
            stores the negative marginal log-likelihood here.
        model_param : dict, optional
            Trained model parameters (state_dict).
        """
        # Posterior summaries.
        self.post_mean, self.post_mean2, self.post_sd = post_mean, post_mean2, post_sd
        # Fitted prior: per-observation mixture means, weights and scales.
        self.location, self.pi_np, self.scale = location, pi_np, scale
        # Fit metadata.
        self.loss, self.model_param = loss, model_param


# -------------------------
# Main solver (GPU-native)
# -------------------------
def spiked_emdn_posterior_means(
    X,
    betahat,
    sebetahat,
    n_epochs=50,
    n_layers=4,
    n_gaussians=5,
    hidden_dim=64,
    batch_size=512,
    lr=1e-3,
    model_param=None,
    *,
    penalty: float = 1.5,  # >1 encourages spike; =1 neutral
    beta_prior: tuple | None = None,  # e.g. (17., 5.) => target pi_spike ~ 0.77
    print_every=10,
    device: torch.device | None = None,
):
    """
    Fit a Mixture Density Network (MDN) with spike and slab components to estimate the prior
    distribution of effects.

    In the EBNM problem, we observe estimates `betahat` with standard errors `sebetahat` and want
    to estimate the prior distribution of the true effects. The prior is modeled as a mixture of
    Gaussians (slabs) plus a point mass at zero (spike), with mixture parameters predicted by a
    neural network from the covariates `X`.

    Parameters
    ----------
    X : torch.Tensor or np.ndarray
        Covariates for each observation, shape (n_samples, n_features). A 1-D input is treated
        as a single feature column.
    betahat : torch.Tensor or np.ndarray
        Observed effect estimates, shape (n_samples,).
    sebetahat : torch.Tensor or np.ndarray
        Standard errors of the effect estimates, shape (n_samples,).
    n_epochs : int, optional
        Number of training epochs (default=50).
    n_layers : int, optional
        Number of hidden layers in the neural network (default=4).
    n_gaussians : int, optional
        Number of mixture components, spike included (default=5).
    hidden_dim : int, optional
        Number of hidden units in each layer (default=64).
    batch_size : int, optional
        Batch size for training (default=512).
    lr : float, optional
        Learning rate for the Adam optimizer (default=1e-3).
    model_param : dict, optional
        Pre-trained model parameters (state_dict) used to initialize the network.
    penalty : float, optional
        >1 encourages spike; =1 neutral (default=1.5).
    beta_prior : tuple, optional
        (alpha0, beta0) for Beta prior on pi_spike.
    print_every : int, optional
        Print the mean training loss every this many epochs (default=10). A falsy value
        (``0`` or ``None``) disables progress printing.
    device : torch.device, optional
        Target device for tensors/models. Defaults to the device of `betahat` when it is a
        tensor, otherwise CUDA if available, else CPU.

    Returns
    -------
    EmdnPosteriorMeanNorm
        Container with posterior means, second moments and standard deviations, the fitted
        per-observation mixture parameters, the negative (unpenalised) marginal log-likelihood
        in ``loss``, and the trained model parameters.
    """
    # ---- device: inherit from input tensor if available, else fall back to
    # CUDA-or-CPU. Inheriting avoids silent device hops when the caller (e.g.
    # cEBMF) is on CPU/MPS but CUDA is also visible on the host.
    if device is None:
        device = (
            betahat.device
            if isinstance(betahat, torch.Tensor)
            else (torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        )

    # ---- tensors on device
    X = torch.as_tensor(X, dtype=torch.float32, device=device)
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    betahat = torch.as_tensor(betahat, dtype=torch.float32, device=device)
    sebetahat = torch.as_tensor(sebetahat, dtype=torch.float32, device=device)

    # ---- standardize on device
    X_scaled = standard_scale(X)

    # ---- data
    dataset = DensityRegressionDataset(X_scaled, betahat, sebetahat)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # ---- model
    model = MDN(
        input_dim=X_scaled.shape[1],
        hidden_dim=hidden_dim,
        n_gaussians=n_gaussians,
        n_layers=n_layers,
    ).to(device)
    if model_param is not None:
        model.load_state_dict(model_param)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # ---- train
    n_batches = max(1, len(dataloader))  # guards the reported mean against an empty loader
    for epoch in range(n_epochs):
        model.train()
        epoch_loss = 0.0
        for inputs, targets, noise_std in dataloader:
            optimizer.zero_grad(set_to_none=True)
            pi, mu, log_sigma = model(inputs)
            loss = mdn_spike_loss_with_varying_noise(
                pi,
                mu,
                log_sigma,
                targets,
                noise_std,
                penalty=penalty,
                beta_prior=beta_prior,
            )
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # A falsy `print_every` silences logging instead of failing on the modulo.
        if print_every and (epoch + 1) % print_every == 0:
            print(f"[Spiked-EMDN] Epoch {epoch + 1}/{n_epochs}, Loss: {epoch_loss / n_batches:.4f}")

    # ---- predict (all data)
    # One forward pass over the full standardized design matrix. This yields the
    # same tensors as a single full-size, unshuffled DataLoader batch without the
    # per-sample indexing and re-stacking done by the default collate function.
    model.eval()
    with torch.no_grad():
        pi_pred, mu_pred, log_sigma_pred = model(X_scaled)

    # Build full mixture params including the spike at 0 (prior sd=0 for spike)
    mu_full = torch.cat([torch.zeros_like(mu_pred[:, :1]), mu_pred], dim=1)  # (N, K)
    sigma_full = torch.cat([torch.zeros_like(log_sigma_pred[:, :1]), torch.exp(log_sigma_pred)], dim=1)  # (N, K)

    # ---- Vectorised posterior over all observations, as a single (N, K)
    # tensor pipeline. ``posterior_mean_norm`` and the convolved log-likelihood
    # routine are given per-observation (N, K) ``location`` and ``scale``;
    # ``sigma_full[:, 0] = 0`` keeps column 0 as a point mass at
    # ``mu_full[:, 0] = 0``.
    #
    # The clamp floor must be representable in the tensor dtype: a literal such
    # as 1e-300 rounds to 0 in float32, which would turn the clamp into a no-op
    # and let log(pi) reach -inf when the softmax underflows.
    eps = torch.finfo(pi_pred.dtype).tiny
    data_loglik = get_data_loglik_normal_torch(
        betahat=betahat,
        sebetahat=sebetahat,
        location=mu_full,  # (N, K) per-observation
        scale=sigma_full,  # (N, K) - 0 in column 0 for spike
    )  # (N, K)
    log_pi_full = torch.log(pi_pred.clamp_min(eps))  # (N, K)

    result = posterior_mean_norm(
        betahat=betahat,
        sebetahat=sebetahat,
        log_pi=log_pi_full,
        data_loglik=data_loglik,
        location=mu_full,
        scale=sigma_full,
    )
    post_mean = result.post_mean
    post_mean2 = result.post_mean2
    post_sd = result.post_sd

    # ---- Full marginal log-likelihood (no penalty).
    # total sd per obs/component: sqrt(se_i^2 + sigma_{ik}^2); spike has
    # sigma=0 => total sd = se.
    total_sigma = torch.sqrt(sigma_full**2 + sebetahat.unsqueeze(1) ** 2)  # (N, K)
    z = (betahat.unsqueeze(1) - mu_full) / total_sigma  # (N, K)
    log_sqrt_2pi = 0.5 * torch.log(torch.tensor(2.0 * torch.pi, device=betahat.device, dtype=betahat.dtype))
    log_comp = -0.5 * z.pow(2) - torch.log(total_sigma) - log_sqrt_2pi  # (N, K)
    log_mix = torch.logsumexp(log_pi_full + log_comp, dim=1)  # (N,)
    full_marginal_ll = float(log_mix.sum().item())

    return EmdnPosteriorMeanNorm(
        post_mean=post_mean,
        post_mean2=post_mean2,
        post_sd=post_sd,
        location=mu_full,
        pi_np=pi_pred,
        scale=sigma_full,
        loss=-full_marginal_ll,
        model_param=model.state_dict(),
    )