import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from cebmf_torch.utils.distribution_operation import get_data_loglik_normal_torch
from cebmf_torch.utils.posterior import posterior_mean_norm
from cebmf_torch.utils.standard_scaler import standard_scale
# -------------------------
# Dataset (expects tensors already on correct device/dtype)
# -------------------------
class DensityRegressionDataset(Dataset):
    """Map-style dataset yielding ``(covariates, estimate, standard error)`` triples.

    The three containers are stored exactly as given (no copy and no
    device/dtype conversion) and are indexed in lockstep along their first
    axis.
    """

    def __init__(self, X: torch.Tensor, betahat: torch.Tensor, sebetahat: torch.Tensor):
        """
        Keep references to the three aligned tensors.

        Parameters
        ----------
        X : torch.Tensor
            Covariates; its first axis indexes observations.
        betahat : torch.Tensor
            Observed effect estimates, indexed like ``X``.
        sebetahat : torch.Tensor
            Standard errors of the estimates, indexed like ``X``.
        """
        self.X, self.betahat, self.sebetahat = X, betahat, sebetahat

    def __len__(self) -> int:
        """Return the number of observations (size of ``X`` along axis 0)."""
        n_obs = self.X.shape[0]
        return n_obs

    def __getitem__(self, idx):
        """Return the tuple ``(X[idx], betahat[idx], sebetahat[idx])``."""
        fields = (self.X, self.betahat, self.sebetahat)
        return tuple(field[idx] for field in fields)
# -------------------------
# Mixture Density Network (spike + slabs)
# pi: (N, K) for [spike, slabs...]
# mu/log_sigma: (N, K-1) for slabs only
# -------------------------
class MDN(nn.Module):
    """Fully connected mixture density network with one spike and ``K - 1`` slabs.

    A ReLU trunk (input layer followed by ``n_layers`` hidden layers) feeds
    three linear heads: mixture weights over all ``K`` components, and
    means / log standard deviations for the ``K - 1`` slab components.
    """

    def __init__(self, input_dim, hidden_dim, n_gaussians, n_layers=4):
        """
        Build the trunk and the three output heads.

        Parameters
        ----------
        input_dim : int
            Number of input features.
        hidden_dim : int
            Width of every trunk layer.
        n_gaussians : int
            Total number of mixture components ``K`` (spike included);
            must be at least 2.
        n_layers : int, optional
            Number of hidden ``hidden_dim -> hidden_dim`` layers that follow
            the input layer (default is 4).
        """
        super().__init__()
        assert n_gaussians >= 2, "Need at least 1 spike + 1 slab."
        n_slabs = n_gaussians - 1

        # Sub-modules are created in a fixed order under fixed attribute
        # names so that seeded initialisation and ``state_dict`` keys are
        # stable across runs and loadable from earlier checkpoints.
        self.fc_in = nn.Linear(input_dim, hidden_dim)
        self.hidden_layers = nn.ModuleList()
        for _ in range(n_layers):
            self.hidden_layers.append(nn.Linear(hidden_dim, hidden_dim))

        self.pi = nn.Linear(hidden_dim, n_gaussians)  # weights: spike is column 0
        self.mu = nn.Linear(hidden_dim, n_slabs)  # slab means
        self.log_sigma = nn.Linear(hidden_dim, n_slabs)  # slab scales (pre-activation)
        self.point_mass = 0.0

    def forward(self, x):
        """
        Map covariates to per-observation mixture parameters.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (N, input_dim).

        Returns
        -------
        pi : torch.Tensor
            Softmax mixture weights over [spike, slabs...], shape (N, K).
        mu : torch.Tensor
            Slab means, shape (N, K-1).
        log_sigma : torch.Tensor
            Log of the slab standard deviations, shape (N, K-1).
        """
        hidden = self.fc_in(x).relu()
        for layer in self.hidden_layers:
            hidden = layer(hidden).relu()

        weights = self.pi(hidden).softmax(dim=1)
        means = self.mu(hidden)
        # softplus keeps the standard deviation positive; the small offset
        # keeps the subsequent log finite.
        slab_sd = torch.nn.functional.softplus(self.log_sigma(hidden)) + 1e-6
        return weights, means, slab_sd.log()
# -------------------------
# Loss: correct spike+slabs mixture + steerable spike penalty
# -------------------------
def mdn_spike_loss_with_varying_noise(
pi,
mu,
log_sigma,
betahat,
sebetahat,
*,
penalty: float = 1.5,
beta_prior: tuple | None = None,
eps: float = 1e-8,
):
"""
Compute the negative log-likelihood loss for a spike-and-slab mixture model with optional
spike penalty and Beta prior.
Parameters
----------
pi : torch.Tensor
Mixture weights for spike and slabs, shape (N, K).
mu : torch.Tensor
Means for slab components, shape (N, K-1).
log_sigma : torch.Tensor
Log standard deviations for slab components, shape (N, K-1).
betahat : torch.Tensor
Observed effect estimates, shape (N,).
sebetahat : torch.Tensor
Standard errors of the effect estimates, shape (N,).
penalty : float, optional
>1 encourages spike; =1 neutral (default is 1.5).
beta_prior : tuple, optional
(alpha0, beta0) for Beta prior on pi_spike.
eps : float, optional
Small value to avoid log(0) (default is 1e-8).
Returns
-------
torch.Tensor
The computed loss (scalar).
"""
# Spike likelihood: mean=0, total var = se^2
var_spike = sebetahat**2 # (N,)
logp_spike = -0.5 * ((betahat**2) / var_spike + torch.log(2 * torch.pi * var_spike)) # (N,)
# Slab likelihoods: mean=mu_j, total sd = sqrt(prior_sd^2 + se^2)
sigma_slab = torch.exp(log_sigma) # (N, K-1) prior sd
total_sigma_slab = torch.sqrt(sigma_slab**2 + sebetahat.unsqueeze(1) ** 2) # (N, K-1)
dist_slab = torch.distributions.Normal(mu, total_sigma_slab)
logp_slabs = dist_slab.log_prob(betahat.unsqueeze(1)) # (N, K-1)
# Mixture log-likelihood = logsumexp over [spike, slabs...]
log_terms_spike = torch.log(pi[:, :1].clamp_min(eps)) + logp_spike.unsqueeze(1) # (N, 1)
log_terms_slabs = torch.log(pi[:, 1:].clamp_min(eps)) + logp_slabs # (N, K-1)
all_log_terms = torch.cat([log_terms_spike, log_terms_slabs], dim=1) # (N, K)
nll = -torch.logsumexp(all_log_terms, dim=1).mean()
# (A) simple steer: penalty>1 encourages spike
reg_simple = 0.0
if penalty != 1.0:
lam = float(penalty) - 1.0 # >0 encourages spike
reg_simple = -(lam) * torch.log(pi[:, 0].clamp_min(eps)).mean()
# (B) optional Beta(alpha0, beta0) prior on pi_spike
reg_beta = 0.0
if beta_prior is not None:
a0, b0 = map(float, beta_prior)
one_minus_pi0 = (1.0 - pi[:, 0]).clamp_min(eps)
reg_beta = -((a0 - 1.0) * torch.log(pi[:, 0].clamp_min(eps)) + (b0 - 1.0) * torch.log(one_minus_pi0)).mean()
return nll + reg_simple + reg_beta
# -------------------------
# Result container
# -------------------------
class EmdnPosteriorMeanNorm:
    """Plain result holder for the spiked EMDN posterior computation.

    The constructor performs no validation or conversion; every argument is
    stored under the attribute of the same name.
    """

    def __init__(
        self,
        post_mean,
        post_mean2,
        post_sd,
        location,
        pi_np,
        scale,
        loss=0,
        model_param=None,
    ):
        """
        Store the posterior summaries and fitted mixture parameters.

        Parameters
        ----------
        post_mean : torch.Tensor
            Posterior means for each observation.
        post_mean2 : torch.Tensor
            Posterior second moments for each observation.
        post_sd : torch.Tensor
            Posterior standard deviations for each observation.
        location : array-like
            Mixture component means for each observation (the solver in this
            module passes a ``torch.Tensor``).
        pi_np : array-like
            Mixture weights for each observation (the solver in this module
            passes a ``torch.Tensor`` despite the ``_np`` suffix).
        scale : array-like
            Mixture component standard deviations for each observation (the
            solver in this module passes a ``torch.Tensor``).
        loss : float, optional
            Final loss value reported by the solver (default is 0).
        model_param : dict, optional
            Trained model parameters (state_dict).
        """
        # Posterior summaries.
        self.post_mean, self.post_mean2, self.post_sd = post_mean, post_mean2, post_sd
        # Fitted per-observation mixture description.
        self.location, self.pi_np, self.scale = location, pi_np, scale
        # Fit diagnostics / warm-start state.
        self.loss, self.model_param = loss, model_param
# -------------------------
# Main solver (GPU-native)
# -------------------------
# (Stray "[docs]" link marker left over from the rendered HTML source view removed; it is not Python.)
def spiked_emdn_posterior_means(
    X,
    betahat,
    sebetahat,
    n_epochs=50,
    n_layers=4,
    n_gaussians=5,
    hidden_dim=64,
    batch_size=512,
    lr=1e-3,
    model_param=None,
    *,
    penalty: float = 1.5,  # >1 encourages spike; =1 neutral
    beta_prior: tuple | None = None,  # e.g. (17., 5.) => target pi_spike ~ 0.77
    print_every=10,
    device: torch.device | None = None,
):
    """
    Fit a Mixture Density Network (MDN) with spike and slab components to estimate the prior distribution of effects.

    In the EBNM problem, we observe estimates `betahat` with standard errors `sebetahat` and want to estimate
    the prior distribution of the true effects. The prior is modeled as a mixture of Gaussians (slabs)
    plus a point mass at zero (spike), with mixture parameters predicted by a neural network.

    All inputs are converted to float32 tensors on ``device``; ``X`` is
    standardised with ``standard_scale`` before being fed to the network.

    Parameters
    ----------
    X : torch.Tensor or np.ndarray
        Covariates for each observation, shape (n_samples, n_features).
        A 1-D input is treated as a single feature column.
    betahat : torch.Tensor or np.ndarray
        Observed effect estimates, shape (n_samples,).
    sebetahat : torch.Tensor or np.ndarray
        Standard errors of the effect estimates, shape (n_samples,).
    n_epochs : int, optional
        Number of training epochs (default=50).
    n_layers : int, optional
        Number of hidden layers in the neural network (default=4).
    n_gaussians : int, optional
        Number of mixture components, spike included (default=5).
    hidden_dim : int, optional
        Number of hidden units in each layer (default=64).
    batch_size : int, optional
        Batch size for training (default=512).
    lr : float, optional
        Learning rate for the Adam optimizer (default=1e-3).
    model_param : dict, optional
        Pre-trained model parameters (state_dict) to initialize the network.
    penalty : float, optional
        >1 encourages spike; =1 neutral (default=1.5).
    beta_prior : tuple, optional
        (alpha0, beta0) for Beta prior on pi_spike.
    print_every : int, optional
        Print the mean training loss every this many epochs (default=10).
        ``0`` or ``None`` disables progress output.
    device : torch.device, optional
        Target device for tensors/models. Defaults to the device of
        ``betahat`` when it is a tensor, otherwise CUDA if available, else CPU.

    Returns
    -------
    EmdnPosteriorMeanNorm
        Container with the posterior mean, second moment and standard
        deviation, the per-observation mixture parameters (``location``,
        ``pi_np``, ``scale``; column 0 is the spike), ``loss`` set to the
        negative un-penalised marginal log-likelihood summed over
        observations, and the trained ``state_dict`` in ``model_param``.
    """
    # ---- device: inherit from input tensor if available, else fall back to
    # CUDA-or-CPU. Inheriting avoids silent device hops when the caller (e.g.
    # cEBMF) is on CPU/MPS but CUDA is also visible on the host.
    if device is None:
        if isinstance(betahat, torch.Tensor):
            device = betahat.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ---- tensors on device
    X = torch.as_tensor(X, dtype=torch.float32, device=device)
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    betahat = torch.as_tensor(betahat, dtype=torch.float32, device=device)
    sebetahat = torch.as_tensor(sebetahat, dtype=torch.float32, device=device)

    # ---- standardize on device
    X_scaled = standard_scale(X)

    # ---- data
    dataset = DensityRegressionDataset(X_scaled, betahat, sebetahat)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # ---- model
    model = MDN(
        input_dim=X_scaled.shape[1],
        hidden_dim=hidden_dim,
        n_gaussians=n_gaussians,
        n_layers=n_layers,
    ).to(device)
    if model_param is not None:
        model.load_state_dict(model_param)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # ---- train
    for epoch in range(n_epochs):
        model.train()
        epoch_loss = 0.0
        for inputs, targets, noise_std in dataloader:
            optimizer.zero_grad(set_to_none=True)
            pi, mu, log_sigma = model(inputs)
            loss = mdn_spike_loss_with_varying_noise(
                pi,
                mu,
                log_sigma,
                targets,
                noise_std,
                penalty=penalty,
                beta_prior=beta_prior,
            )
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # A falsy ``print_every`` (0 / None) silences logging instead of
        # failing on the modulo.
        if print_every and (epoch + 1) % print_every == 0:
            print(f"[Spiked-EMDN] Epoch {epoch + 1}/{n_epochs}, Loss: {epoch_loss / max(1, len(dataloader)):.4f}")

    # ---- predict (all data)
    # The whole standardised design matrix is already a single tensor on the
    # right device, so it is fed to the network directly rather than being
    # re-collated row by row through a full-batch DataLoader.
    model.eval()
    with torch.no_grad():
        pi_pred, mu_pred, log_sigma_pred = model(X_scaled)

    # Build full mixture params including the spike at 0 (prior sd=0 for spike)
    mu_full = torch.cat([torch.zeros_like(mu_pred[:, :1]), mu_pred], dim=1)  # (N, K)
    sigma_full = torch.cat([torch.zeros_like(log_sigma_pred[:, :1]), torch.exp(log_sigma_pred)], dim=1)  # (N, K)

    # ---- Vectorised posterior over all observations.
    # ``location`` and ``scale`` are passed per observation as (N, K) tensors;
    # ``sigma_full[:, 0] = 0`` and ``mu_full[:, 0] = 0`` encode the spike as a
    # point mass at zero in column 0.
    data_loglik = get_data_loglik_normal_torch(
        betahat=betahat,
        sebetahat=sebetahat,
        location=mu_full,  # (N, K) per-observation
        scale=sigma_full,  # (N, K) — 0 in column 0 for spike
    )  # (N, K)

    # Floor the weights before the log. The floor must be representable in the
    # tensor dtype: a literal such as 1e-300 rounds to 0 in float32 and would
    # leave log(0) = -inf in place.
    pi_floor = torch.finfo(pi_pred.dtype).tiny
    log_pi_full = torch.log(pi_pred.clamp_min(pi_floor))  # (N, K)

    result = posterior_mean_norm(
        betahat=betahat,
        sebetahat=sebetahat,
        log_pi=log_pi_full,
        data_loglik=data_loglik,
        location=mu_full,
        scale=sigma_full,
    )
    post_mean = result.post_mean
    post_mean2 = result.post_mean2
    post_sd = result.post_sd

    # ---- Full marginal log-likelihood (no penalty).
    # total sd per obs/component: sqrt(se_i^2 + sigma_{ik}^2); spike has
    # sigma=0 ⇒ total sd = se.
    total_sigma = torch.sqrt(sigma_full**2 + sebetahat.unsqueeze(1) ** 2)  # (N, K)
    z = (betahat.unsqueeze(1) - mu_full) / total_sigma  # (N, K)
    log_sqrt_2pi = 0.5 * torch.log(torch.tensor(2.0 * torch.pi, device=betahat.device, dtype=betahat.dtype))
    log_comp = -0.5 * z.pow(2) - torch.log(total_sigma) - log_sqrt_2pi  # (N, K)
    log_mix = torch.logsumexp(log_pi_full + log_comp, dim=1)  # (N,)
    full_marginal_ll = float(log_mix.sum().item())

    return EmdnPosteriorMeanNorm(
        post_mean=post_mean,
        post_mean2=post_mean2,
        post_sd=post_sd,
        location=mu_full,
        pi_np=pi_pred,
        scale=sigma_full,
        loss=-full_marginal_ll,
        model_param=model.state_dict(),
    )