Source code for cebmf_torch.cebnm.emdn

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from cebmf_torch.utils.distribution_operation import get_data_loglik_normal_torch
from cebmf_torch.utils.posterior import posterior_mean_norm
from cebmf_torch.utils.standard_scaler import standard_scale


# -------------------------
# Dataset (expects tensors already on correct device/dtype)
# -------------------------
class DensityRegressionDataset(Dataset):
    """
    Map-style dataset bundling covariates with their effect estimates.

    Item ``i`` is the triple ``(X[i], betahat[i], sebetahat[i])``. The inputs
    are stored as given (no copy, no device or dtype conversion), and the
    dataset length is the size of the first dimension of ``X``.

    Parameters
    ----------
    X : torch.Tensor
        Covariates, indexed along the first dimension.
    betahat : torch.Tensor
        Observed effect estimates, indexed in step with ``X``.
    sebetahat : torch.Tensor
        Standard errors of the estimates, indexed in step with ``X``.
    """

    def __init__(self, X: torch.Tensor, betahat: torch.Tensor, sebetahat: torch.Tensor):
        # Only references are kept; the caller is expected to have placed the
        # tensors on the desired device/dtype beforehand.
        self.X, self.betahat, self.sebetahat = X, betahat, sebetahat

    def __len__(self):
        n_rows = self.X.shape[0]
        return n_rows

    def __getitem__(self, idx):
        covariates = self.X[idx]
        estimate = self.betahat[idx]
        std_error = self.sebetahat[idx]
        return covariates, estimate, std_error


# -------------------------
# Mixture Density Network
# -------------------------
class MDN(nn.Module):
    """
    Mixture Density Network: an MLP trunk with three linear output heads.

    The trunk is an input projection followed by ``n_layers`` square linear
    layers, each followed by a ReLU. The heads map the final hidden
    representation to mixture weights, component means and component log
    standard deviations.
    """

    def __init__(self, input_dim, hidden_dim, n_gaussians, n_layers=4):
        """
        Build the trunk and the three output heads.

        Parameters
        ----------
        input_dim : int
            Number of input features.
        hidden_dim : int
            Width of every hidden layer.
        n_gaussians : int
            Number of Gaussian components in the mixture.
        n_layers : int, optional
            Number of hidden-to-hidden layers placed after the input
            projection (default is 4).
        """
        super().__init__()
        # Submodule names and creation order are kept stable so that saved
        # state_dicts keep loading and seeded initialisation is unchanged.
        self.fc_in = nn.Linear(input_dim, hidden_dim)
        trunk = []
        for _ in range(n_layers):
            trunk.append(nn.Linear(hidden_dim, hidden_dim))
        self.hidden_layers = nn.ModuleList(trunk)
        self.pi = nn.Linear(hidden_dim, n_gaussians)
        self.mu = nn.Linear(hidden_dim, n_gaussians)
        self.log_sigma = nn.Linear(hidden_dim, n_gaussians)

    def forward(self, x):
        """
        Map covariates to per-observation mixture parameters.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (N, input_dim).

        Returns
        -------
        pi : torch.Tensor
            Mixture weights, shape (N, K); softmax over dim 1, so each row
            sums to one.
        mu : torch.Tensor
            Component means, shape (N, K).
        log_sigma : torch.Tensor
            Component log standard deviations, shape (N, K).
        """
        hidden = self.fc_in(x).relu()
        for layer in self.hidden_layers:
            hidden = layer(hidden).relu()

        mixture_logits = self.pi(hidden)
        weights = mixture_logits.softmax(dim=1)
        means = self.mu(hidden)
        log_scales = self.log_sigma(hidden)
        return weights, means, log_scales


# -------------------------
# Loss function
# -------------------------
def mdn_loss_with_varying_noise(pi, mu, log_sigma, betahat, sebetahat, eps=1e-12):
    """
    Compute the negative log-likelihood loss for a mixture density network with varying noise.

    Each observation is modelled as a Gaussian mixture whose k-th component has
    mean ``mu[:, k]`` and standard deviation
    ``sqrt(exp(log_sigma[:, k])**2 + sebetahat**2)``, i.e. the component
    variance plus the observation-specific noise variance.

    Parameters
    ----------
    pi : torch.Tensor
        Mixture weights, shape (N, K).
    mu : torch.Tensor
        Means for each component, shape (N, K).
    log_sigma : torch.Tensor
        Log standard deviations for each component, shape (N, K).
    betahat : torch.Tensor
        Observed effect estimates, shape (N,).
    sebetahat : torch.Tensor
        Standard errors of the effect estimates, shape (N,).
    eps : float, optional
        Lower bound applied to ``pi`` before taking its logarithm
        (default 1e-12).

    Returns
    -------
    torch.Tensor
        The computed loss (scalar): minus the mean over observations of the
        log mixture density.
    """
    sigma = torch.exp(log_sigma)
    total_sigma = torch.sqrt(sigma**2 + sebetahat.unsqueeze(1) ** 2)
    dist = torch.distributions.Normal(mu, total_sigma)

    # A weight that underflows to exactly 0 would give log(0) = -inf, whose
    # backward pass evaluates 0 * inf and poisons the gradients with NaN.
    # Clamping keeps both the value and the gradient finite.
    log_pi = torch.log(pi.clamp_min(eps))

    log_probs = dist.log_prob(betahat.unsqueeze(1)) + log_pi
    return -torch.logsumexp(log_probs, dim=1).mean()


# -------------------------
# Result container
# -------------------------
class EmdnPosteriorMeanNorm:
    """
    Plain attribute container for the output of the EMDN solver.

    Every constructor argument is stored unchanged on the instance under the
    same name.
    """

    def __init__(
        self,
        post_mean,
        post_mean2,
        post_sd,
        location,
        pi_np,
        scale,
        loss=0,
        model_param=None,
    ):
        """
        Store the fitted quantities.

        Parameters
        ----------
        post_mean : torch.Tensor
            Posterior mean of each observation.
        post_mean2 : torch.Tensor
            Posterior second moment of each observation.
        post_sd : torch.Tensor
            Posterior standard deviation of each observation.
        location : np.ndarray or torch.Tensor
            Per-observation mixture component means.
        pi_np : np.ndarray or torch.Tensor
            Per-observation mixture weights.
        scale : np.ndarray or torch.Tensor
            Per-observation mixture component standard deviations.
        loss : float, optional
            Final loss value reported by the solver (default 0).
        model_param : dict, optional
            State dict of the trained network (default None).
        """
        # Posterior summaries.
        self.post_mean, self.post_mean2, self.post_sd = post_mean, post_mean2, post_sd
        # Fitted mixture description.
        self.location, self.pi_np, self.scale = location, pi_np, scale
        # Training artefacts.
        self.loss, self.model_param = loss, model_param


# -------------------------
# Main solver (GPU-native)
# -------------------------
def emdn_posterior_means(
    X,
    betahat,
    sebetahat,
    n_epochs=50,
    n_layers=4,
    n_gaussians=5,
    hidden_dim=64,
    batch_size=512,
    lr=1e-3,
    model_param=None,
    device: torch.device | None = None,
):
    """
    Fit a Mixture Density Network (MDN) to estimate the prior distribution of effects.

    In the EBNM problem, we observe estimates `betahat` with standard errors
    `sebetahat` and want to estimate the prior distribution of the true
    effects. The prior is modeled as a mixture of Gaussians with parameters
    predicted by a neural network from the covariates `X`.

    Parameters
    ----------
    X : torch.Tensor or np.ndarray
        Covariates for each observation, shape (n_samples, n_features).
        A 1-D input is treated as a single feature column.
    betahat : torch.Tensor or np.ndarray
        Observed effect estimates, shape (n_samples,).
    sebetahat : torch.Tensor or np.ndarray
        Standard errors of the effect estimates, shape (n_samples,).
    n_epochs : int, optional
        Number of training epochs (default=50).
    n_layers : int, optional
        Number of hidden layers in the neural network (default=4).
    n_gaussians : int, optional
        Number of Gaussian components in the mixture (default=5).
    hidden_dim : int, optional
        Number of hidden units in each layer (default=64).
    batch_size : int, optional
        Batch size for training (default=512).
    lr : float, optional
        Learning rate for the Adam optimizer (default=1e-3).
    model_param : dict, optional
        Pre-trained model parameters (state_dict) to initialize the network.
    device : torch.device, optional
        Target device for tensors/models. When omitted, the device of
        `betahat` is used if `betahat` is a tensor; otherwise CUDA if
        available, else CPU.

    Returns
    -------
    EmdnPosteriorMeanNorm
        Container with the posterior mean, second moment and standard
        deviation, the fitted per-observation mixture parameters
        (`location`, `pi_np`, `scale`), `loss` set to the negative full
        marginal log-likelihood summed over observations, and the trained
        model's state_dict in `model_param`.
    """
    # ---- device: inherit from the input tensor if available, else fall back
    # to CUDA-or-CPU. Inheriting avoids silent device hops when the caller
    # (e.g. cEBMF) is on CPU/MPS but CUDA is also visible on the host.
    if device is None:
        device = (
            betahat.device
            if isinstance(betahat, torch.Tensor)
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )

    # ---- tensors on device (float32 throughout)
    X = torch.as_tensor(X, dtype=torch.float32, device=device)
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    betahat = torch.as_tensor(betahat, dtype=torch.float32, device=device)
    sebetahat = torch.as_tensor(sebetahat, dtype=torch.float32, device=device)

    # ---- standardize on device
    X_scaled = standard_scale(X)

    # ---- dataset / loader (tensors may live on CUDA => num_workers=0)
    dataset = DensityRegressionDataset(X_scaled, betahat, sebetahat)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # ---- model / optimizer on device
    model = MDN(
        input_dim=X_scaled.shape[1],
        hidden_dim=hidden_dim,
        n_gaussians=n_gaussians,
        n_layers=n_layers,
    ).to(device)
    if model_param is not None:
        # Warm start from previously trained weights.
        model.load_state_dict(model_param)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # ---- training
    for epoch in range(n_epochs):
        model.train()
        epoch_loss = 0.0
        for inputs, targets, noise_std in dataloader:
            optimizer.zero_grad(set_to_none=True)
            pi, mu, log_sigma = model(inputs)
            loss = mdn_loss_with_varying_noise(pi, mu, log_sigma, targets, noise_std)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        if (epoch + 1) % 10 == 0:
            print(f"[EMDN] Epoch {epoch + 1}/{n_epochs}, Loss: {epoch_loss / max(1, len(dataloader)):.4f}")

    # ---- prediction for all data (no grad).
    # The scaled covariates are fed to the network in one call. Routing them
    # through a DataLoader with batch_size=len(dataset) yields the same single
    # batch but indexes and re-stacks every sample individually.
    # NOTE(review): assumes standard_scale returns a torch.Tensor — confirm.
    model.eval()
    with torch.no_grad():
        pi, mu, log_sigma = model(X_scaled)
        sigma = torch.exp(log_sigma)  # (N, K)

    # ---- Vectorised posterior over all observations.
    # The MDN emits per-observation mixture parameters (pi, mu, sigma) of
    # shape (N, K); they are passed as-is so the posterior is computed for
    # all observations at once instead of looping over them.
    eps = 1e-12
    data_loglik = get_data_loglik_normal_torch(
        betahat=betahat,
        sebetahat=sebetahat,
        location=mu,  # (N, K) per-observation
        scale=sigma,  # (N, K) per-observation
    )  # (N, K)
    # Clamp before the log so that a weight underflowing to 0 cannot yield -inf.
    log_pi_full = torch.log(pi.clamp_min(eps))  # (N, K)

    result = posterior_mean_norm(
        betahat=betahat,
        sebetahat=sebetahat,
        log_pi=log_pi_full,
        data_loglik=data_loglik,
        location=mu,
        scale=sigma,
    )
    post_mean = result.post_mean
    post_mean2 = result.post_mean2
    post_sd = result.post_sd

    # ---- Full marginal log-likelihood (no penalty).
    # log N(b_i ; mu_{ik}, sqrt(se_i^2 + sigma_{ik}^2)) in log-domain.
    total_sigma = torch.sqrt(sigma**2 + sebetahat.unsqueeze(1) ** 2)  # (N, K)
    z = (betahat.unsqueeze(1) - mu) / total_sigma  # (N, K)
    log_sqrt_2pi = 0.5 * torch.log(torch.tensor(2.0 * torch.pi, device=betahat.device, dtype=betahat.dtype))
    log_comp = -0.5 * z.pow(2) - torch.log(total_sigma) - log_sqrt_2pi  # (N, K)
    log_mix = torch.logsumexp(log_pi_full + log_comp, dim=1)  # (N,)
    full_marginal_ll = float(log_mix.sum().item())

    return EmdnPosteriorMeanNorm(
        post_mean=post_mean,
        post_mean2=post_mean2,
        post_sd=post_sd,
        location=mu,
        pi_np=pi,
        scale=sigma,
        loss=-full_marginal_ll,
        model_param=model.state_dict(),
    )