Tiled clustering model

This example is similar to the example presented in Figure 1 of *Covariate-moderated Empirical Bayes Matrix Factorization* (Denault et al., 2025).

Model description:

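We simulate a matrix \(Z = LF^\top\) in which \(L\) (but not \(F\)) depends on the 2-D locations of the data points. Specifically, we generate a periodic tiling of \([0,1] \times [0,1]\), randomly labeling each tile 1, 2, or 3. For each data point \(i\), we set \(\ell_{ik} = f(x,y)\) if \(i\) is in the tile with label \(k\), and \(\ell_{ik} = 0\) otherwise. The \(F\) matrix, by contrast, is simulated from a simple scale mixture of normals, \(f_{jk} \sim \pi_0 \delta_0 + \sum_{m=1}^M \pi_m N(0, \sigma_m^2)\). We simulate homoskedastic noise with \(\tau_{ij} = 0.1\). The code below implements a simplified version of this design: three fixed rectangular regions (cut points 0.33 and 0.66) instead of a random periodic tiling, loadings \(\ell_{ik} = \sin(x_i)\), factor entries that are either 0 or standard normal, and Gaussian noise with standard deviation 1.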

import torch
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# =========================
# Config & Reproducibility
# =========================
SEED = 1                    # global RNG seed; every random draw below depends on it
N = 2000                    # number of sampled (x, y) locations (rows of Z)
M = 200                     # number of features (columns of f and Z)
BOUNDARIES = (0.33, 0.66)   # cut points on both axes: define the regions and guide lines
NOISE_STD = 1.0             # standard deviation of the Gaussian noise added to Z

torch.manual_seed(SEED)

# =========================
# 1) Random uniform data
# =========================
# x is drawn before y, so the seeded random stream is consumed in a fixed order.
x, y = torch.rand(N), torch.rand(N)
# (N, 2) covariate matrix; passed later to cEBMF as side information (X_l=X).
X = torch.stack((x, y), dim=1)

# Quick look at the raw locations: uniform draws over the unit square.
fig, ax = plt.subplots(figsize=(7, 5))
ax.scatter(x.numpy(), y.numpy(), alpha=0.5, s=12)
ax.set(xlabel="x", ylabel="y", title="Scatter: x vs y")
fig.tight_layout()
plt.show()

# =========================
# 2) Generate f (3 x M)
#    f[0,i] ~ t1_i * N(0,1), f[1,i], f[2,i] ~ t2_i * N(0,1)
# =========================
# Per-column on/off masks with values in {0, 1} (each with probability 1/2),
# stored as float32 so they can multiply the Gaussian draws directly.
t1 = torch.randint(0, 2, (M,), dtype=torch.float32)
t2 = torch.randint(0, 2, (M,), dtype=torch.float32)

# Rows 1 and 2 share the mask t2: same zero pattern, independent Gaussian values.
# The generator draws torch.randn(M) for f0, then f1, then f2.
f0, f1, f2 = (mask * torch.randn(M) for mask in (t1, t2, t2))
f = torch.stack((f0, f1, f2))  # (3, M): rows stacked along a new leading dim

# =========================
# 3) Build L (N x 3) + factor labels
# =========================
b1, b2 = BOUNDARIES

# Three disjoint regions that together cover every point:
#   region 1: x < b1
#   region 2: b1 <= x < b2 and y < b1
#   region 3: everything else
mask1 = x < b1
mask2 = (~mask1) & (x < b2) & (y < b1)
mask3 = ~(mask1 | mask2)

# Integer factor label in {1, 2, 3} (int64): 1 by default,
# +1 for points in region 2 and +2 for points in region 3.
factor = 1 + mask2.long() + 2 * mask3.long()

# Each point has a non-zero loading only in column (label - 1), equal to sin(x);
# every other entry of its row stays zero.
L = torch.zeros(N, 3, dtype=torch.float32)
L[torch.arange(N), factor - 1] = torch.sin(x)

# =========================
# 4) Factor visualization
# =========================
colors = ["#D41159", "#1A85FF", "#40B0A6"]  # red, blue, teal
color_map = dict(zip((1, 2, 3), colors))    # factor label (1-based) -> hex colour
# factor.tolist() already yields Python ints, so labels index color_map directly.
point_colors = [color_map[label] for label in factor.tolist()]

# Keyword arguments shared by all dashed guide lines at the region cut points.
guide_style = dict(color="black", linestyle="--", linewidth=1)

plt.figure(figsize=(7.5, 6))
plt.scatter(x.numpy(), y.numpy(), c=point_colors, s=18, alpha=0.85)
for v in (b1, b2):
    plt.axhline(v, **guide_style)
    plt.axvline(v, **guide_style)
plt.title("Factor assignment by region")
plt.xlabel("x")
plt.ylabel("y")

# One legend patch per factor; colour j marks points with non-zero L[:, j].
handles = [
    mpatches.Patch(color=hex_color, label=f"factor {j + 1} (non-zero L[:,{j}])")
    for j, hex_color in enumerate(colors)
]
plt.legend(handles=handles, title="Groups", loc="best", frameon=True)
plt.tight_layout()
plt.show()

# =========================
# 5) Individual factor heatmaps (L1, L2, L3 as color)
# =========================
# One figure per column i of L: each point is coloured by its loading L[:, i],
# so the region where factor i is active is visible against the zero background.
# Iterate over L's actual column count rather than a hard-coded 3.
for i in range(L.shape[1]):
    plt.figure(figsize=(7.5, 6))
    plt.scatter(
        x.numpy(),
        y.numpy(),
        c=L[:, i].numpy(),
        cmap="coolwarm",  # shows magnitude of L[:,i]
        s=18,
    )
    # Dashed guide lines at the region cut points, on both axes.
    for v in (b1, b2):
        plt.axhline(v, color="black", linestyle="--", linewidth=1)
        plt.axvline(v, color="black", linestyle="--", linewidth=1)
    # Title typo fixed ("vaues" -> "values").
    plt.title(f"Factor {i+1} values (L[:,{i}])")
    plt.xlabel("x")
    plt.ylabel("y")
    cbar = plt.colorbar()
    cbar.set_label(f"L{i+1}")
    plt.tight_layout()
    plt.show()

# =========================
# 6) Generate observations Z = L @ f + noise
# =========================
# i.i.d. Gaussian noise, one draw per entry of Z, scaled to standard deviation NOISE_STD.
noise = torch.randn(N, M) * NOISE_STD
# Signal of rank at most 3, (N x 3) @ (3 x M), plus noise.
Z = torch.matmul(L, f) + noise  # (N, M)

# Optional sanity checks on the simulated shapes.
assert L.shape == (N, 3)
assert f.shape == (3, M)
assert Z.shape == (N, M)
../_images/10837862fa93c8ed3a1d6122ada2634a3d9db4ed428237a7930f15feb5e42c58.png ../_images/69dac0c3cbe8a5ee8332d3972f394ee31dcd6966a81ce9a0ed51123db3c4f1d3.png ../_images/3615be8f0b03465aaee48426f9f137a8cb42b01fc3cbb13cc2a7a54f05b81f1f.png ../_images/16876330221b0d6490cf0411160d5aa2ec7cd6929477751d5b2da8478bde696d.png ../_images/e037a5dc31d4349e0fddf429949ecc52b2a4e527bd722bf847e278cdf2127c54.png

We will factorize the matrix Z using 4 different models.

  • EBMF using a point-mass Laplace prior, without side information

  • cEBMF using a “covariate moderated adaptive shrinkage prior” (cash) \( g(x,y) = \pi_0(x,y) \delta_0 + \sum_{m=1}^M \pi_m(x,y) N(0, \sigma_m^2) \)

  • cEBMF using a “covariate moderated generalized binary prior” (cgb) \( g(x,y) = \pi_0(x,y) \delta_0 + (1-\pi_0(x,y)) N(\mu, \sigma^2) \)

  • cEBMF using an “empirical mixture of density networks” (emdn) \( g(x,y) = \sum_{m=1}^M \pi_m(x,y) N(\mu(x,y), \sigma_m(x,y)^2) \)

Fitting the model without accounting for the 2-D side information

from cebmf_torch import cEBMF
# Baseline fit on Z alone: no covariates (X_l is not supplied).
# The returned result (last expression of the cell) is displayed below.
mycebmf = cEBMF(data=Z)
mycebmf.initialise_factors()
mycebmf.fit(10)
CEBMFResult(L=tensor([[-6.7598e+00, -1.4332e-02,  1.8322e+00,  7.3646e-09, -1.6800e-07],
        [-1.2728e-01, -8.6281e-02,  2.0286e+00,  6.0323e-08, -6.9698e-08],
        [-1.6566e+00, -1.5823e-02, -1.1584e-02, -3.5549e-08, -4.5253e-07],
        ...,
        [-6.3495e+00, -9.7502e-01,  1.8474e-02, -1.5775e-08,  5.1910e-07],
        [-1.1016e-01,  6.2751e-02,  2.8443e+00,  7.0515e-08, -1.3485e-08],
        [-9.9256e-02,  1.2566e-01,  1.3082e+00, -2.2851e-08,  6.4656e-08]]), F=tensor([[-1.3574e-01,  1.4970e-01, -8.4925e-02,  1.7957e-20, -9.2348e-20],
        [ 1.6497e-01,  1.4324e-01,  3.9473e-04,  6.2244e-22, -4.8140e-20],
        [-5.1115e-04, -6.7892e-04,  5.7604e-02,  1.2628e-20, -2.6080e-19],
        [-5.5547e-02,  1.2525e-01,  6.6623e-02, -6.9375e-20,  2.1519e-20],
        [ 1.5923e-05, -2.8374e-02,  6.6553e-02, -1.1059e-19, -7.3195e-20],
        [ 1.1018e-03,  1.7563e-03, -4.5966e-04,  6.3001e-21,  1.4011e-20],
        [-1.2125e-05, -9.3930e-03, -5.1896e-04, -1.2304e-19, -1.7500e-20],
        [-1.9283e-01,  1.5602e-01,  2.0437e-03,  4.2131e-21, -5.1623e-20],
        [-4.8056e-05,  1.5003e-04, -9.4235e-02,  4.6035e-20, -6.9588e-20],
        [-8.5753e-02,  2.3534e-01,  1.0137e-02, -4.5239e-20,  1.7128e-19],
        [-1.0425e-01,  6.4188e-02,  1.4812e-01,  2.7006e-21,  7.2587e-20],
        [ 9.7487e-02,  1.2954e-01,  6.4201e-03,  1.2872e-20, -4.2374e-19],
        [-2.5056e-04, -4.1524e-04, -1.2789e-01, -6.0438e-20,  7.8870e-20],
        [ 1.2146e-03,  1.4316e-03,  7.5495e-04, -7.7692e-20, -3.4504e-20],
        [ 8.1914e-05,  2.8076e-04, -9.5339e-04, -3.9487e-20,  8.2315e-20],
        [-2.1534e-03,  6.6904e-04, -1.1697e-04,  2.6439e-20, -1.4832e-20],
        [-3.2322e-04, -2.6461e-03, -1.3059e-03,  2.3863e-20,  8.3820e-21],
        [ 5.9707e-05, -1.2688e-03,  7.0687e-04, -4.9283e-20,  5.2254e-20],
        [ 1.0802e-01, -5.8955e-02,  7.6502e-02,  1.7518e-19, -6.2849e-20],
        [ 1.0104e-03,  1.5032e-04, -1.5265e-03, -2.1690e-20,  1.1288e-19],
        [ 2.4492e-05, -8.0278e-04,  4.7480e-04,  2.7265e-20, -2.7038e-21],
        [ 2.0939e-02, -1.6759e-01,  8.5444e-02,  3.1712e-20,  2.7484e-20],
        [-1.6948e-04, -6.7839e-03,  3.3583e-03,  2.7840e-20,  1.9376e-20],
        [ 6.6897e-04,  3.2620e-02,  6.6062e-05,  4.9348e-20,  6.4941e-20],
        [-8.5586e-02, -1.0361e-03, -1.1501e-02, -7.4976e-20,  3.2870e-20],
        [-5.2643e-02, -1.0521e-01,  9.0646e-03, -5.8664e-20,  1.6953e-19],
        [-3.2699e-04, -1.1781e-03, -1.6563e-02,  4.2264e-20,  1.2057e-19],
        [ 5.6935e-03,  1.6503e-03, -1.7652e-02,  5.5405e-20, -4.8137e-20],
        [-4.7457e-03,  6.3929e-04, -3.7393e-04,  1.9772e-20,  1.3109e-19],
        [-1.1232e-05,  2.2910e-03, -1.4161e-03, -9.1443e-20,  5.3034e-20],
        [ 1.7971e-04, -1.7047e-03,  7.1284e-04,  3.5941e-19, -1.4803e-19],
        [-6.5873e-04, -1.0879e-03,  5.2048e-04, -1.2074e-20,  1.0733e-19],
        [ 2.5404e-05,  8.9711e-04,  6.7480e-03,  6.8623e-20, -7.4276e-20],
        [-1.7032e-04,  3.3755e-03, -1.3659e-01,  2.6300e-20,  7.2140e-20],
        [-2.9819e-03, -2.9035e-03, -4.1955e-03, -1.1254e-21,  4.3051e-20],
        [-8.9145e-06,  8.1518e-06,  1.6263e-03, -2.6513e-20, -2.9795e-20],
        [ 1.6561e-05,  2.3494e-02, -1.1243e-03, -4.4496e-21, -2.2639e-19],
        [-1.4284e-04, -7.7837e-05,  4.1827e-05,  3.2581e-20,  6.6319e-20],
        [-6.4429e-02, -1.4210e-03,  7.8048e-04, -8.5266e-21, -2.2341e-20],
        [ 8.9116e-02,  9.7028e-02, -6.2227e-05,  8.0645e-21, -5.5536e-20],
        [ 2.5038e-04,  1.2387e-03, -1.6571e-04,  1.1721e-19, -8.4799e-20],
        [-2.8748e-02, -8.8919e-03, -5.1223e-02, -7.9312e-20, -6.0307e-20],
        [-6.8683e-02,  4.1812e-02,  1.5548e-01,  9.8535e-20,  1.2217e-19],
        [ 6.2853e-02,  6.6321e-02,  1.1107e-03,  1.8183e-19, -4.3965e-20],
        [-1.7326e-01,  1.6959e-01,  3.0756e-03, -4.1344e-21,  9.5618e-20],
        [-2.8148e-02,  5.2795e-04, -7.4519e-02,  2.2935e-20,  3.8780e-20],
        [ 5.3138e-02, -8.1542e-02,  4.8963e-04, -2.3322e-20, -5.3431e-20],
        [-6.3550e-02,  1.1984e-02, -8.8323e-02, -1.4449e-19, -5.8307e-20],
        [ 1.2018e-04, -3.1058e-03, -8.3748e-04, -1.8781e-20,  5.8181e-20],
        [ 2.9697e-02, -6.7459e-02, -2.7770e-05,  1.0262e-19, -5.7506e-22],
        [-1.2052e-02, -1.1672e-03, -1.5784e-03,  3.6772e-20, -4.4685e-20],
        [-1.5714e-04, -9.4734e-04, -5.7397e-03, -1.8741e-19, -4.7225e-20],
        [ 3.7574e-02,  8.4881e-02, -4.6692e-04,  6.4711e-20, -4.2124e-20],
        [ 2.0667e-04,  3.1218e-04, -7.0962e-04,  4.0085e-20,  9.5358e-20],
        [ 9.4132e-05,  7.9334e-04, -2.9669e-03,  2.4522e-20, -5.2126e-20],
        [ 4.6602e-03, -2.9170e-02, -1.1420e-04,  2.8458e-20,  5.6024e-21],
        [ 2.0284e-01, -1.6733e-02, -4.5719e-03, -2.2902e-20, -5.1409e-20],
        [ 3.3062e-05,  1.4732e-04,  2.8841e-03,  1.6721e-20,  2.1748e-19],
        [-1.1423e-01, -1.0494e-01,  9.0751e-02,  8.6146e-20,  1.7157e-19],
        [ 1.2127e-01,  1.1655e-03, -8.9261e-04, -4.1644e-20, -5.9347e-20],
        [ 5.5725e-04,  6.7578e-04, -1.1144e-01, -1.2188e-20, -2.0762e-20],
        [-2.8548e-05,  8.6150e-04, -1.9644e-03,  9.4290e-20,  7.6416e-20],
        [ 1.0297e-01, -6.6430e-02,  4.1622e-02,  2.7657e-20, -8.4981e-20],
        [ 2.6432e-05,  2.8337e-03, -4.5584e-03, -3.5397e-21,  4.7593e-20],
        [-8.1629e-02,  1.7757e-01,  8.5043e-05, -2.1618e-21, -4.3551e-20],
        [ 1.4527e-01, -7.6396e-02, -4.1422e-03, -2.6837e-20,  1.5157e-19],
        [-1.4587e-02,  3.0583e-03,  7.6059e-02, -1.1696e-20,  2.7135e-21],
        [ 9.9179e-03,  4.4399e-03, -6.7361e-02,  2.1204e-20, -9.1404e-20],
        [-6.3266e-02,  6.8211e-02, -1.4403e-02, -3.6769e-20, -1.1678e-19],
        [ 4.7350e-05,  1.9682e-04,  1.3569e-01, -1.6477e-20, -1.0486e-19],
        [-4.1270e-04,  1.1289e-01,  5.9134e-02,  1.6769e-21, -1.2953e-19],
        [-1.9875e-01, -5.6012e-02, -1.1427e-03, -1.1654e-20,  6.6316e-20],
        [-3.7248e-02, -5.8525e-02,  9.9596e-03,  1.3011e-21, -8.4257e-20],
        [-1.1176e-01,  8.1333e-02, -9.1811e-02,  7.4592e-20, -1.3733e-19],
        [ 6.6000e-06, -2.3799e-04,  1.0924e-01, -2.5927e-19,  4.0836e-20],
        [-1.5488e-01, -4.1751e-02, -4.1433e-02,  4.6822e-20, -4.6347e-20],
        [-7.6759e-02, -2.1057e-03, -1.5066e-02, -1.9332e-20,  1.0531e-19],
        [ 1.1012e-04,  9.4715e-04, -6.4788e-04,  1.2415e-19, -2.8956e-20],
        [ 1.3318e-01, -1.6172e-01, -1.2019e-01,  3.3497e-20,  6.7146e-20],
        [ 5.0234e-03,  8.4604e-04,  5.0094e-05,  1.0853e-20, -1.1801e-19],
        [ 1.4696e-01,  1.2214e-01, -1.0594e-02,  2.5783e-20,  1.7960e-19],
        [ 2.6767e-04, -7.7141e-04, -7.6745e-02,  1.2329e-19, -7.3658e-20],
        [ 4.2664e-02,  1.4677e-01,  1.3615e-03,  4.7083e-20, -9.8512e-20],
        [ 1.5075e-04,  3.6862e-03, -1.9620e-03, -1.7414e-20,  3.4449e-20],
        [-1.3717e-01,  4.6420e-03, -7.8725e-03,  2.3696e-20,  1.0641e-19],
        [-2.1072e-04,  4.0278e-02,  2.4932e-04,  7.3123e-21, -1.0653e-20],
        [ 1.4175e-01, -1.3205e-01,  3.0771e-04,  3.8740e-20, -8.7470e-20],
        [ 1.8058e-01,  3.4278e-02, -6.9091e-02,  7.2686e-21,  9.4165e-21],
        [ 3.2196e-02, -1.5235e-04,  3.6735e-03,  1.5579e-20,  4.4532e-20],
        [-3.0246e-05, -1.9188e-03, -8.3603e-04, -6.6410e-21,  1.7751e-20],
        [ 4.9456e-03, -5.2106e-02,  2.4975e-03, -9.3873e-22,  3.2638e-20],
        [ 4.1365e-02, -8.0753e-02,  6.8802e-02, -1.0998e-19,  6.9506e-20],
        [ 1.3813e-03,  1.2615e-01, -2.1755e-02, -4.6238e-20,  9.8431e-20],
        [ 3.9827e-05,  1.3619e-03,  2.3809e-02,  4.2057e-21,  1.6576e-19],
        [ 4.6455e-05, -5.0966e-04, -8.7614e-05, -5.0321e-20,  1.2135e-20],
        [-1.1070e-03,  2.1406e-04,  7.6964e-03, -2.3209e-20,  5.2304e-20],
        [-2.3022e-01,  4.9325e-04, -6.9926e-04,  2.4548e-21,  3.0765e-20],
        [ 1.6395e-04, -5.1521e-03, -3.5700e-03, -3.9834e-20,  1.1533e-20],
        [ 6.3880e-02, -1.1139e-01,  9.9010e-02,  4.6091e-20,  1.8586e-21],
        [-2.7444e-02,  9.7814e-02,  9.1273e-04,  2.6157e-20,  1.9627e-19],
        [ 1.2399e-01,  1.4725e-01,  2.2021e-03,  9.3068e-20,  2.5964e-20],
        [-1.6294e-04, -5.9606e-03, -1.9510e-03, -1.6733e-20, -6.4012e-20],
        [ 2.4203e-05,  1.6424e-03,  6.5780e-02, -8.4543e-20, -1.8039e-19],
        [ 1.4941e-01,  9.2675e-02, -6.5342e-02, -6.7748e-20, -6.1357e-20],
        [-6.6192e-02,  7.1228e-02,  1.4926e-03,  3.5037e-20,  9.9564e-20],
        [ 1.2750e-03,  6.6668e-03, -2.4562e-02, -5.5742e-20,  2.4914e-20],
        [ 3.7080e-05, -6.9708e-03, -2.8969e-03, -8.1183e-20,  4.6137e-20],
        [ 2.7126e-04, -1.7510e-03, -8.2468e-03, -3.1495e-20,  1.9494e-20],
        [-2.5660e-05,  2.4905e-04,  4.2778e-04, -1.0878e-19, -2.4776e-20],
        [-1.2209e-03,  1.0914e-03, -9.9911e-03, -3.7312e-20,  1.1629e-19],
        [-7.3981e-05, -1.4915e-03, -2.5981e-03,  1.3408e-20, -3.4378e-19],
        [ 1.5387e-04, -5.2371e-04, -1.0244e-01, -8.9938e-20, -1.2913e-19],
        [-1.9930e-04, -2.6572e-03, -3.2436e-03, -5.8330e-20, -2.9037e-19],
        [-6.1172e-02,  5.6656e-04,  5.0731e-03,  1.8865e-20,  8.2624e-22],
        [ 5.9226e-04,  1.9041e-04,  6.4124e-02, -2.6613e-20, -3.6980e-20],
        [ 1.2639e-01,  1.3893e-01,  6.9401e-06, -5.7744e-20, -6.0878e-20],
        [-1.3055e-02,  6.8740e-02,  5.2309e-03,  1.4092e-19,  1.4347e-19],
        [-4.9370e-02, -9.3990e-02, -5.6556e-04,  1.0511e-20, -2.2804e-19],
        [-1.6047e-01, -1.5990e-01, -5.8939e-03,  4.3257e-20, -1.4271e-20],
        [-1.4485e-04,  1.0923e-04, -5.9271e-03, -5.9698e-20,  7.1553e-20],
        [-3.7363e-02,  5.0166e-02,  7.4344e-02,  3.3906e-21,  1.8119e-20],
        [ 1.7148e-01, -7.1842e-03,  5.9119e-04, -2.4416e-20,  2.1797e-20],
        [ 6.9052e-05,  3.0905e-02, -9.2587e-04,  4.4565e-20,  1.3367e-19],
        [ 5.2679e-04,  1.3215e-03, -2.1567e-03,  1.8480e-19,  1.0815e-20],
        [ 2.8240e-04, -1.9167e-05,  6.9585e-02, -1.9480e-19, -4.2075e-20],
        [-2.0245e-04, -2.0796e-03,  2.6284e-03, -7.7238e-21,  1.2098e-20],
        [-6.6652e-02, -4.2250e-02, -3.8903e-05, -9.7684e-20,  1.4876e-19],
        [ 3.2806e-05, -3.3750e-03,  1.6077e-01,  4.0891e-20,  5.7368e-20],
        [ 1.2903e-01,  7.5245e-04,  6.2636e-02,  4.9273e-20,  1.6656e-19],
        [ 7.9751e-05,  7.0579e-03, -8.6073e-05, -4.8833e-20,  1.4288e-19],
        [ 3.5647e-02, -4.5339e-03,  2.1242e-02, -4.0459e-20,  1.4849e-19],
        [ 9.2582e-05,  3.3758e-03,  4.1491e-04,  1.6547e-20,  1.5457e-20],
        [ 4.7011e-04,  3.5311e-05,  3.5238e-03,  3.5965e-20, -1.2555e-20],
        [ 1.5816e-01, -1.1602e-01, -1.2148e-03,  9.8463e-20,  1.0757e-19],
        [-7.5358e-02,  4.6213e-02,  2.1600e-02, -7.5578e-20,  2.1740e-20],
        [-5.0513e-02, -1.0189e-01,  1.2459e-03,  2.3330e-20, -8.1476e-22],
        [-6.1296e-02, -6.8779e-02, -2.6078e-03,  7.3913e-20, -4.9808e-19],
        [-9.1678e-06, -1.7387e-04, -3.6712e-01, -8.3164e-20,  8.9258e-20],
        [-3.3464e-03, -1.2212e-02, -8.0465e-03,  1.9317e-21,  4.4672e-20],
        [ 4.8707e-02, -8.9462e-02,  2.1994e-03, -4.7879e-20,  1.4077e-19],
        [ 2.2906e-04, -5.2223e-05,  1.6396e-05, -2.8768e-20,  3.1513e-20],
        [ 1.2482e-01,  1.5578e-01, -3.9855e-03,  2.8976e-20, -6.5741e-21],
        [ 1.0774e-04,  2.8850e-03, -4.9775e-02,  1.9075e-20, -3.0432e-19],
        [ 1.0257e-04,  1.3225e-03,  9.9998e-02,  2.5279e-19, -5.5564e-20],
        [ 5.7752e-04,  7.2657e-03,  2.4098e-03,  5.0168e-22,  1.5699e-20],
        [-2.3378e-03, -2.7706e-04,  8.5375e-02, -5.7947e-20,  2.2098e-20],
        [ 1.7217e-02, -3.3900e-03,  5.2110e-03,  4.6660e-20,  3.7703e-20],
        [-2.0348e-03, -6.7304e-02, -3.1854e-04, -6.2453e-22,  9.7908e-21],
        [-9.5689e-05,  1.8653e-03,  6.5173e-02, -7.8662e-21, -1.2387e-19],
        [ 3.7915e-05, -4.8745e-03,  4.7713e-03, -5.7752e-20,  2.1748e-19],
        [ 4.9027e-05, -2.9088e-03, -1.6838e-01, -4.4638e-21, -4.7255e-20],
        [ 8.0039e-02,  1.8518e-01, -4.3702e-03,  2.1775e-20, -1.9687e-21],
        [ 3.7804e-06,  1.6377e-03, -4.5465e-03, -1.0555e-20,  2.5027e-20],
        [-4.8711e-04, -1.2619e-04, -9.8955e-02,  3.3229e-20, -8.9859e-20],
        [ 2.5767e-04, -6.2112e-03, -1.3157e-03,  8.3707e-20,  1.0760e-20],
        [ 2.5815e-03,  7.4457e-03, -1.0166e-01,  2.7471e-20,  8.6166e-20],
        [ 1.0316e-01,  7.9118e-02, -2.3165e-01, -4.0759e-20,  5.8373e-22],
        [ 1.1366e-01,  1.3669e-01,  9.4570e-02,  6.0392e-20,  3.4592e-20],
        [-7.1087e-02, -1.5672e-01, -6.7876e-02,  8.9611e-20, -6.1600e-21],
        [ 4.0076e-04,  2.3893e-02,  3.1557e-03, -6.7195e-20, -2.7785e-20],
        [ 8.6295e-02,  8.6670e-02,  6.8344e-02, -5.5726e-20, -6.8579e-21],
        [ 8.5633e-02,  8.1874e-03, -8.6012e-04,  6.7476e-20, -3.7875e-20],
        [ 6.7064e-02,  5.5136e-02, -3.9121e-02,  9.2221e-20, -7.5987e-20],
        [-8.2361e-02, -3.2802e-03,  3.6983e-02,  5.3970e-21, -1.6647e-19],
        [-1.7319e-01,  3.6090e-03,  9.3339e-04,  8.7241e-20, -3.2022e-19],
        [ 2.5873e-04,  2.9838e-03, -1.1698e-01, -8.5598e-21, -6.6861e-20],
        [-1.6293e-04, -8.6426e-04, -4.5295e-03, -6.6619e-20,  1.1342e-19],
        [ 4.7647e-02, -1.2652e-01, -5.1250e-02,  5.3053e-21,  5.0110e-20],
        [ 6.0304e-03, -3.4626e-04,  7.1935e-04, -5.4878e-20, -7.3634e-21],
        [ 3.5268e-02, -1.5980e-01,  1.0941e-01, -7.0961e-20, -1.1220e-19],
        [ 4.6034e-04,  1.5411e-03,  3.1873e-04,  1.7885e-19, -2.0980e-20],
        [-9.5121e-05,  1.8075e-03, -2.7455e-03, -9.4216e-21, -1.4182e-19],
        [-3.5339e-03,  1.1650e-02,  5.6150e-04,  6.1783e-20,  7.7005e-20],
        [ 6.7477e-02,  4.2563e-02,  4.0133e-02, -8.7712e-20, -6.9315e-21],
        [ 6.1806e-05, -5.1524e-03,  5.0531e-03, -6.4487e-20, -5.0113e-20],
        [ 2.1208e-01,  1.4762e-01,  7.0882e-03, -1.0602e-20,  3.4126e-21],
        [-8.0582e-02, -1.3780e-01, -1.3138e-02,  4.6199e-20,  2.2650e-19],
        [ 5.2848e-04,  4.2760e-02,  9.9693e-02,  9.4784e-21, -6.6116e-20],
        [ 4.7488e-02, -9.4620e-02,  4.4557e-03, -6.4812e-20, -6.2005e-20],
        [ 2.6911e-02, -3.1162e-03,  1.5122e-03,  5.2048e-20,  1.5221e-19],
        [-4.6469e-02, -3.0442e-04, -2.2434e-03,  2.4934e-20, -1.5814e-19],
        [-3.1792e-04, -1.6583e-02,  2.1632e-03,  5.8754e-20,  7.4531e-20],
        [-7.7037e-04,  9.3134e-04,  7.8220e-02,  4.2813e-21,  2.9729e-21],
        [-9.2464e-02, -6.8965e-02, -7.6467e-02,  1.6397e-20, -1.2892e-19],
        [ 1.6834e-01, -4.4833e-02, -1.0006e-03,  1.2667e-20, -6.8240e-21],
        [ 9.8550e-02,  3.7965e-02,  2.8003e-03, -4.2913e-20, -1.3931e-19],
        [-1.0635e-01, -3.3404e-02,  7.8223e-02, -4.1015e-20, -5.1745e-20],
        [-8.6311e-02,  2.4094e-04, -4.2853e-04, -8.4809e-21,  1.0048e-19],
        [-2.4338e-04, -3.9936e-03,  9.5450e-02, -4.7913e-20, -1.3756e-20],
        [ 8.6942e-04,  8.4796e-03,  4.8227e-04,  4.0352e-20,  7.6861e-20],
        [-3.7148e-02, -6.4875e-02, -3.2040e-03,  3.2432e-21,  2.7712e-21],
        [-1.4216e-04, -7.3936e-04, -4.9316e-03,  2.3428e-20,  5.9712e-20],
        [-1.8043e-04, -6.2064e-04, -1.0897e-01,  1.0966e-19,  2.0354e-21],
        [-2.9563e-05,  2.4076e-03, -2.2604e-03, -3.7806e-20,  1.2923e-20],
        [-1.3280e-01, -8.1931e-02,  2.7425e-04,  3.6754e-20, -6.7006e-21],
        [-7.2912e-02,  1.0503e-02, -5.1707e-03,  1.0337e-19,  4.4369e-20],
        [-1.6319e-04,  7.5262e-04,  2.2184e-03, -5.1575e-20,  2.2770e-20],
        [ 7.4864e-02,  3.1517e-02, -1.0441e-01, -2.9784e-20, -1.7040e-20],
        [-2.5764e-03,  2.2624e-04,  7.4085e-02, -1.2229e-20, -3.2523e-20],
        [ 7.8175e-06,  6.4782e-04,  1.2907e-01, -5.3282e-21, -1.8154e-19]]), tau=tensor(1.0036), history_obj=[557786.5625, 559040.625, 559512.1875, 559871.9375, 560230.125, 560637.625, 560872.375, 560886.25, 560887.3125, 560887.9375])
# Covariate-moderated fit: the (x, y) locations X inform the prior on L,
# which uses the "cgb_sharp" prior; backfitting is disabled.
mycebmf31 = cEBMF(
    data=Z,
    X_l=X,
    prior_L="cgb_sharp",
    allow_backfitting=False,
)
mycebmf31.initialise_factors()
mycebmf31.fit(10)

# NOTE(review): mycebmf31.obj is assumed to hold the per-iteration objective
# trace (compare history_obj in the baseline result) — confirm against cEBMF.
plt.plot(mycebmf31.obj)
[CGB] Epoch   1/50 | Avg NLL=15.180039 | mu=0.9905 | sigma=0.0198 | mean π0=0.7333
[CGB] Epoch  10/50 | Avg NLL=14.846239 | mu=0.9718 | sigma=0.0194 | mean π0=0.9987
[CGB] Epoch  20/50 | Avg NLL=14.845418 | mu=0.9715 | sigma=0.0194 | mean π0=0.9997
[CGB] Epoch  30/50 | Avg NLL=14.845271 | mu=0.9714 | sigma=0.0194 | mean π0=0.9999
[CGB] Epoch  40/50 | Avg NLL=14.845224 | mu=0.9714 | sigma=0.0194 | mean π0=0.9999
[CGB] Epoch  50/50 | Avg NLL=14.845204 | mu=0.9713 | sigma=0.0194 | mean π0=1.0000
[CGB] Epoch   1/50 | Avg NLL=2.783078 | mu=0.9902 | sigma=0.0198 | mean π0=0.7016
[CGB] Epoch  10/50 | Avg NLL=2.670080 | mu=0.9723 | sigma=0.0194 | mean π0=0.9883
[CGB] Epoch  20/50 | Avg NLL=2.668837 | mu=0.9750 | sigma=0.0195 | mean π0=0.9682
[CGB] Epoch  30/50 | Avg NLL=2.668136 | mu=0.9744 | sigma=0.0195 | mean π0=0.9682
[CGB] Epoch  40/50 | Avg NLL=2.667827 | mu=0.9731 | sigma=0.0195 | mean π0=0.9699
[CGB] Epoch  50/50 | Avg NLL=2.667539 | mu=0.9719 | sigma=0.0194 | mean π0=0.9650
[CGB] Epoch   1/50 | Avg NLL=1.956530 | mu=1.0102 | sigma=0.0202 | mean π0=0.4965
[CGB] Epoch  10/50 | Avg NLL=1.792126 | mu=1.1064 | sigma=0.0221 | mean π0=0.6358
[CGB] Epoch  20/50 | Avg NLL=1.755698 | mu=1.2144 | sigma=0.0243 | mean π0=0.6531
[CGB] Epoch  30/50 | Avg NLL=1.731302 | mu=1.3182 | sigma=0.0264 | mean π0=0.6626
[CGB] Epoch  40/50 | Avg NLL=1.713162 | mu=1.4160 | sigma=0.0283 | mean π0=0.6672
[CGB] Epoch  50/50 | Avg NLL=1.699273 | mu=1.5052 | sigma=0.0301 | mean π0=0.6667
[CGB] Epoch   1/50 | Avg NLL=1.789250 | mu=0.9915 | sigma=0.0198 | mean π0=0.6959
[CGB] Epoch  10/50 | Avg NLL=1.752941 | mu=1.0454 | sigma=0.0209 | mean π0=0.8384
[CGB] Epoch  20/50 | Avg NLL=1.748639 | mu=1.1327 | sigma=0.0227 | mean π0=0.8517
[CGB] Epoch  30/50 | Avg NLL=1.745464 | mu=1.2229 | sigma=0.0245 | mean π0=0.8482
[CGB] Epoch  40/50 | Avg NLL=1.742696 | mu=1.3103 | sigma=0.0262 | mean π0=0.8703
[CGB] Epoch  50/50 | Avg NLL=1.740762 | mu=1.3921 | sigma=0.0278 | mean π0=0.8822
[CGB] Epoch   1/50 | Avg NLL=1.775420 | mu=0.9949 | sigma=0.0199 | mean π0=0.7884
[CGB] Epoch  10/50 | Avg NLL=1.752033 | mu=1.0609 | sigma=0.0212 | mean π0=0.8513
[CGB] Epoch  20/50 | Avg NLL=1.748081 | mu=1.1514 | sigma=0.0230 | mean π0=0.8615
[CGB] Epoch  30/50 | Avg NLL=1.744852 | mu=1.2431 | sigma=0.0249 | mean π0=0.8818
[CGB] Epoch  40/50 | Avg NLL=1.742722 | mu=1.3310 | sigma=0.0266 | mean π0=0.8852
[CGB] Epoch  50/50 | Avg NLL=1.739838 | mu=1.4152 | sigma=0.0283 | mean π0=0.8878
[CGB] Epoch   1/50 | Avg NLL=1.408971 | mu=0.9900 | sigma=0.0198 | mean π0=0.6438
[CGB] Epoch  10/50 | Avg NLL=1.315276 | mu=0.9638 | sigma=0.0193 | mean π0=0.9988
[CGB] Epoch  20/50 | Avg NLL=1.315041 | mu=0.9634 | sigma=0.0193 | mean π0=0.9998
[CGB] Epoch  30/50 | Avg NLL=1.315019 | mu=0.9633 | sigma=0.0193 | mean π0=0.9999
[CGB] Epoch  40/50 | Avg NLL=1.315013 | mu=0.9632 | sigma=0.0193 | mean π0=1.0000
[CGB] Epoch  50/50 | Avg NLL=1.315011 | mu=0.9632 | sigma=0.0193 | mean π0=1.0000
[CGB] Epoch   1/50 | Avg NLL=0.875749 | mu=0.9900 | sigma=0.0198 | mean π0=0.5419
[CGB] Epoch  10/50 | Avg NLL=0.288641 | mu=0.9167 | sigma=0.0183 | mean π0=0.4928
[CGB] Epoch  20/50 | Avg NLL=0.251207 | mu=0.8778 | sigma=0.0176 | mean π0=0.4805
[CGB] Epoch  30/50 | Avg NLL=0.240159 | mu=0.8678 | sigma=0.0174 | mean π0=0.4810
[CGB] Epoch  40/50 | Avg NLL=0.232173 | mu=0.8651 | sigma=0.0173 | mean π0=0.4756
[CGB] Epoch  50/50 | Avg NLL=0.227098 | mu=0.8650 | sigma=0.0173 | mean π0=0.4730
[CGB] Epoch   1/50 | Avg NLL=1.549575 | mu=1.0080 | sigma=0.0202 | mean π0=0.5213
[CGB] Epoch  10/50 | Avg NLL=1.386977 | mu=1.1017 | sigma=0.0220 | mean π0=0.6459
[CGB] Epoch  20/50 | Avg NLL=1.362506 | mu=1.1992 | sigma=0.0240 | mean π0=0.6584
[CGB] Epoch  30/50 | Avg NLL=1.352055 | mu=1.2794 | sigma=0.0256 | mean π0=0.6617
[CGB] Epoch  40/50 | Avg NLL=1.346330 | mu=1.3414 | sigma=0.0268 | mean π0=0.6659
[CGB] Epoch  50/50 | Avg NLL=1.344117 | mu=1.3825 | sigma=0.0277 | mean π0=0.6666
[CGB] Epoch   1/50 | Avg NLL=0.869128 | mu=0.9901 | sigma=0.0198 | mean π0=0.6553
[CGB] Epoch  10/50 | Avg NLL=0.693668 | mu=0.9200 | sigma=0.0184 | mean π0=0.7766
[CGB] Epoch  20/50 | Avg NLL=0.666320 | mu=0.8447 | sigma=0.0169 | mean π0=0.7342
[CGB] Epoch  30/50 | Avg NLL=0.648962 | mu=0.7809 | sigma=0.0156 | mean π0=0.6992
[CGB] Epoch  40/50 | Avg NLL=0.637415 | mu=0.7295 | sigma=0.0146 | mean π0=0.6918
[CGB] Epoch  50/50 | Avg NLL=0.629867 | mu=0.6900 | sigma=0.0138 | mean π0=0.6763
[CGB] Epoch   1/50 | Avg NLL=0.834696 | mu=0.9901 | sigma=0.0198 | mean π0=0.7196
[CGB] Epoch  10/50 | Avg NLL=0.685016 | mu=0.9623 | sigma=0.0192 | mean π0=0.9136
[CGB] Epoch  20/50 | Avg NLL=0.671587 | mu=0.9260 | sigma=0.0185 | mean π0=0.9004
[CGB] Epoch  30/50 | Avg NLL=0.663554 | mu=0.8715 | sigma=0.0174 | mean π0=0.8757
[CGB] Epoch  40/50 | Avg NLL=0.657818 | mu=0.8136 | sigma=0.0163 | mean π0=0.8616
[CGB] Epoch  50/50 | Avg NLL=0.653152 | mu=0.7568 | sigma=0.0151 | mean π0=0.8327
[CGB] Epoch   1/50 | Avg NLL=1.671999 | mu=0.9901 | sigma=0.0198 | mean π0=0.6774
[CGB] Epoch  10/50 | Avg NLL=1.622321 | mu=0.9656 | sigma=0.0193 | mean π0=0.9989
[CGB] Epoch  20/50 | Avg NLL=1.622244 | mu=0.9650 | sigma=0.0193 | mean π0=0.9996
[CGB] Epoch  30/50 | Avg NLL=1.622221 | mu=0.9647 | sigma=0.0193 | mean π0=0.9998
[CGB] Epoch  40/50 | Avg NLL=1.622211 | mu=0.9644 | sigma=0.0193 | mean π0=0.9999
[CGB] Epoch  50/50 | Avg NLL=1.622206 | mu=0.9643 | sigma=0.0193 | mean π0=0.9999
[CGB] Epoch   1/50 | Avg NLL=0.350778 | mu=0.9899 | sigma=0.0198 | mean π0=0.4536
[CGB] Epoch  10/50 | Avg NLL=-0.366707 | mu=0.9143 | sigma=0.0183 | mean π0=0.4778
[CGB] Epoch  20/50 | Avg NLL=-0.427352 | mu=0.8692 | sigma=0.0174 | mean π0=0.4712
[CGB] Epoch  30/50 | Avg NLL=-0.439887 | mu=0.8525 | sigma=0.0171 | mean π0=0.4789
[CGB] Epoch  40/50 | Avg NLL=-0.446119 | mu=0.8480 | sigma=0.0170 | mean π0=0.4681
[CGB] Epoch  50/50 | Avg NLL=-0.451640 | mu=0.8475 | sigma=0.0170 | mean π0=0.4682
[CGB] Epoch   1/50 | Avg NLL=1.398015 | mu=1.0053 | sigma=0.0201 | mean π0=0.5402
[CGB] Epoch  10/50 | Avg NLL=1.208552 | mu=1.0932 | sigma=0.0219 | mean π0=0.6408
[CGB] Epoch  20/50 | Avg NLL=1.188575 | mu=1.1735 | sigma=0.0235 | mean π0=0.6610
[CGB] Epoch  30/50 | Avg NLL=1.180704 | mu=1.2301 | sigma=0.0246 | mean π0=0.6622
[CGB] Epoch  40/50 | Avg NLL=1.177048 | mu=1.2642 | sigma=0.0253 | mean π0=0.6627
[CGB] Epoch  50/50 | Avg NLL=1.175546 | mu=1.2824 | sigma=0.0256 | mean π0=0.6668
[CGB] Epoch   1/50 | Avg NLL=0.784880 | mu=0.9901 | sigma=0.0198 | mean π0=0.6053
[CGB] Epoch  10/50 | Avg NLL=0.577059 | mu=0.9110 | sigma=0.0182 | mean π0=0.6724
[CGB] Epoch  20/50 | Avg NLL=0.528655 | mu=0.8428 | sigma=0.0169 | mean π0=0.6481
[CGB] Epoch  30/50 | Avg NLL=0.510267 | mu=0.7963 | sigma=0.0159 | mean π0=0.6219
[CGB] Epoch  40/50 | Avg NLL=0.500302 | mu=0.7672 | sigma=0.0153 | mean π0=0.6131
[CGB] Epoch  50/50 | Avg NLL=0.495283 | mu=0.7502 | sigma=0.0150 | mean π0=0.6186
[CGB] Epoch   1/50 | Avg NLL=1.006676 | mu=0.9903 | sigma=0.0198 | mean π0=0.6549
[CGB] Epoch  10/50 | Avg NLL=0.839356 | mu=0.9533 | sigma=0.0191 | mean π0=0.8824
[CGB] Epoch  20/50 | Avg NLL=0.820964 | mu=0.9083 | sigma=0.0182 | mean π0=0.8583
[CGB] Epoch  30/50 | Avg NLL=0.812704 | mu=0.8593 | sigma=0.0172 | mean π0=0.8476
[CGB] Epoch  40/50 | Avg NLL=0.807991 | mu=0.8136 | sigma=0.0163 | mean π0=0.8216
[CGB] Epoch  50/50 | Avg NLL=0.805180 | mu=0.7737 | sigma=0.0155 | mean π0=0.8037
[CGB] Epoch   1/50 | Avg NLL=2.593431 | mu=0.9901 | sigma=0.0198 | mean π0=0.6873
[CGB] Epoch  10/50 | Avg NLL=2.585954 | mu=0.9663 | sigma=0.0193 | mean π0=0.9989
[CGB] Epoch  20/50 | Avg NLL=2.585942 | mu=0.9656 | sigma=0.0193 | mean π0=0.9996
[CGB] Epoch  30/50 | Avg NLL=2.585939 | mu=0.9653 | sigma=0.0193 | mean π0=0.9998
[CGB] Epoch  40/50 | Avg NLL=2.585937 | mu=0.9651 | sigma=0.0193 | mean π0=0.9999
[CGB] Epoch  50/50 | Avg NLL=2.585936 | mu=0.9649 | sigma=0.0193 | mean π0=0.9999
[CGB] Epoch   1/50 | Avg NLL=0.454545 | mu=0.9899 | sigma=0.0198 | mean π0=0.4988
[CGB] Epoch  10/50 | Avg NLL=-0.267123 | mu=0.9121 | sigma=0.0182 | mean π0=0.4743
[CGB] Epoch  20/50 | Avg NLL=-0.363124 | mu=0.8610 | sigma=0.0172 | mean π0=0.4721
[CGB] Epoch  30/50 | Avg NLL=-0.385322 | mu=0.8385 | sigma=0.0168 | mean π0=0.4631
[CGB] Epoch  40/50 | Avg NLL=-0.393875 | mu=0.8305 | sigma=0.0166 | mean π0=0.4667
[CGB] Epoch  50/50 | Avg NLL=-0.398044 | mu=0.8284 | sigma=0.0166 | mean π0=0.4680
[CGB] Epoch   1/50 | Avg NLL=1.390411 | mu=1.0094 | sigma=0.0202 | mean π0=0.5625
[CGB] Epoch  10/50 | Avg NLL=1.209554 | mu=1.1008 | sigma=0.0220 | mean π0=0.6530
[CGB] Epoch  20/50 | Avg NLL=1.178019 | mu=1.1914 | sigma=0.0238 | mean π0=0.7067
[CGB] Epoch  30/50 | Avg NLL=1.125575 | mu=1.3109 | sigma=0.0262 | mean π0=0.7398
[CGB] Epoch  40/50 | Avg NLL=1.098838 | mu=1.4172 | sigma=0.0283 | mean π0=0.7449
[CGB] Epoch  50/50 | Avg NLL=1.086520 | mu=1.5017 | sigma=0.0300 | mean π0=0.7564
[CGB] Epoch   1/50 | Avg NLL=0.785650 | mu=0.9912 | sigma=0.0198 | mean π0=0.6274
[CGB] Epoch  10/50 | Avg NLL=0.576280 | mu=0.9191 | sigma=0.0184 | mean π0=0.6420
[CGB] Epoch  20/50 | Avg NLL=0.545808 | mu=0.8762 | sigma=0.0175 | mean π0=0.6378
[CGB] Epoch  30/50 | Avg NLL=0.536944 | mu=0.8611 | sigma=0.0172 | mean π0=0.6387
[CGB] Epoch  40/50 | Avg NLL=0.531563 | mu=0.8576 | sigma=0.0172 | mean π0=0.6475
[CGB] Epoch  50/50 | Avg NLL=0.526696 | mu=0.8553 | sigma=0.0171 | mean π0=0.6425
[CGB] Epoch   1/50 | Avg NLL=1.123965 | mu=0.9900 | sigma=0.0198 | mean π0=0.6727
[CGB] Epoch  10/50 | Avg NLL=0.939816 | mu=0.9356 | sigma=0.0187 | mean π0=0.8187
[CGB] Epoch  20/50 | Avg NLL=0.924955 | mu=0.8779 | sigma=0.0176 | mean π0=0.7852
[CGB] Epoch  30/50 | Avg NLL=0.917897 | mu=0.8301 | sigma=0.0166 | mean π0=0.7702
[CGB] Epoch  40/50 | Avg NLL=0.913451 | mu=0.7931 | sigma=0.0159 | mean π0=0.7542
[CGB] Epoch  50/50 | Avg NLL=0.909635 | mu=0.7685 | sigma=0.0154 | mean π0=0.7498
[CGB] Epoch   1/50 | Avg NLL=2.617689 | mu=0.9901 | sigma=0.0198 | mean π0=0.7311
[CGB] Epoch  10/50 | Avg NLL=2.611397 | mu=0.9663 | sigma=0.0193 | mean π0=0.9986
[CGB] Epoch  20/50 | Avg NLL=2.611382 | mu=0.9654 | sigma=0.0193 | mean π0=0.9995
[CGB] Epoch  30/50 | Avg NLL=2.611377 | mu=0.9649 | sigma=0.0193 | mean π0=0.9997
[CGB] Epoch  40/50 | Avg NLL=2.611376 | mu=0.9646 | sigma=0.0193 | mean π0=0.9998
[CGB] Epoch  50/50 | Avg NLL=2.611375 | mu=0.9643 | sigma=0.0193 | mean π0=0.9999
[CGB] Epoch   1/50 | Avg NLL=0.585902 | mu=0.9899 | sigma=0.0198 | mean π0=0.4983
[CGB] Epoch  10/50 | Avg NLL=-0.210082 | mu=0.9108 | sigma=0.0182 | mean π0=0.4724
[CGB] Epoch  20/50 | Avg NLL=-0.346875 | mu=0.8542 | sigma=0.0171 | mean π0=0.4699
[CGB] Epoch  30/50 | Avg NLL=-0.380051 | mu=0.8250 | sigma=0.0165 | mean π0=0.4667
[CGB] Epoch  40/50 | Avg NLL=-0.390548 | mu=0.8124 | sigma=0.0162 | mean π0=0.4682
[CGB] Epoch  50/50 | Avg NLL=-0.391743 | mu=0.8076 | sigma=0.0162 | mean π0=0.4580
[CGB] Epoch   1/50 | Avg NLL=1.287629 | mu=1.0073 | sigma=0.0201 | mean π0=0.5758
[CGB] Epoch  10/50 | Avg NLL=1.104572 | mu=1.0958 | sigma=0.0219 | mean π0=0.7080
[CGB] Epoch  20/50 | Avg NLL=1.032895 | mu=1.2103 | sigma=0.0242 | mean π0=0.7385
[CGB] Epoch  30/50 | Avg NLL=1.002350 | mu=1.3117 | sigma=0.0262 | mean π0=0.7419
[CGB] Epoch  40/50 | Avg NLL=0.988856 | mu=1.3910 | sigma=0.0278 | mean π0=0.7589
[CGB] Epoch  50/50 | Avg NLL=0.981959 | mu=1.4510 | sigma=0.0290 | mean π0=0.7590
[CGB] Epoch   1/50 | Avg NLL=0.831175 | mu=0.9905 | sigma=0.0198 | mean π0=0.6154
[CGB] Epoch  10/50 | Avg NLL=0.606671 | mu=0.9247 | sigma=0.0185 | mean π0=0.6417
[CGB] Epoch  20/50 | Avg NLL=0.576308 | mu=0.9025 | sigma=0.0180 | mean π0=0.6451
[CGB] Epoch  30/50 | Avg NLL=0.559290 | mu=0.9054 | sigma=0.0181 | mean π0=0.6403
[CGB] Epoch  40/50 | Avg NLL=0.550534 | mu=0.9095 | sigma=0.0182 | mean π0=0.6492
[CGB] Epoch  50/50 | Avg NLL=0.545238 | mu=0.9126 | sigma=0.0183 | mean π0=0.6445
[CGB] Epoch   1/50 | Avg NLL=1.167029 | mu=0.9903 | sigma=0.0198 | mean π0=0.6506
[CGB] Epoch  10/50 | Avg NLL=0.934088 | mu=0.9426 | sigma=0.0189 | mean π0=0.7450
[CGB] Epoch  20/50 | Avg NLL=0.918413 | mu=0.9025 | sigma=0.0181 | mean π0=0.7348
[CGB] Epoch  30/50 | Avg NLL=0.910262 | mu=0.8788 | sigma=0.0176 | mean π0=0.7277
[CGB] Epoch  40/50 | Avg NLL=0.904783 | mu=0.8673 | sigma=0.0173 | mean π0=0.7239
[CGB] Epoch  50/50 | Avg NLL=0.899463 | mu=0.8631 | sigma=0.0173 | mean π0=0.7207
[CGB] Epoch   1/50 | Avg NLL=2.746982 | mu=0.9900 | sigma=0.0198 | mean π0=0.6678
[CGB] Epoch  10/50 | Avg NLL=2.741418 | mu=0.9648 | sigma=0.0193 | mean π0=0.9985
[CGB] Epoch  20/50 | Avg NLL=2.741406 | mu=0.9640 | sigma=0.0193 | mean π0=0.9994
[CGB] Epoch  30/50 | Avg NLL=2.741403 | mu=0.9636 | sigma=0.0193 | mean π0=0.9997
[CGB] Epoch  40/50 | Avg NLL=2.741401 | mu=0.9633 | sigma=0.0193 | mean π0=0.9998
[CGB] Epoch  50/50 | Avg NLL=2.741400 | mu=0.9631 | sigma=0.0193 | mean π0=0.9999
[CGB] Epoch   1/50 | Avg NLL=0.647832 | mu=0.9899 | sigma=0.0198 | mean π0=0.4613
[CGB] Epoch  10/50 | Avg NLL=-0.181087 | mu=0.9099 | sigma=0.0182 | mean π0=0.4707
[CGB] Epoch  20/50 | Avg NLL=-0.354839 | mu=0.8497 | sigma=0.0170 | mean π0=0.4657
[CGB] Epoch  30/50 | Avg NLL=-0.405913 | mu=0.8152 | sigma=0.0163 | mean π0=0.4615
[CGB] Epoch  40/50 | Avg NLL=-0.421457 | mu=0.7980 | sigma=0.0160 | mean π0=0.4627
[CGB] Epoch  50/50 | Avg NLL=-0.426067 | mu=0.7904 | sigma=0.0158 | mean π0=0.4596
[CGB] Epoch   1/50 | Avg NLL=1.232012 | mu=1.0059 | sigma=0.0201 | mean π0=0.5853
[CGB] Epoch  10/50 | Avg NLL=1.055557 | mu=1.0959 | sigma=0.0219 | mean π0=0.7111
[CGB] Epoch  20/50 | Avg NLL=0.988356 | mu=1.2039 | sigma=0.0241 | mean π0=0.7507
[CGB] Epoch  30/50 | Avg NLL=0.951031 | mu=1.3022 | sigma=0.0260 | mean π0=0.7498
[CGB] Epoch  40/50 | Avg NLL=0.937855 | mu=1.3749 | sigma=0.0275 | mean π0=0.7659
[CGB] Epoch  50/50 | Avg NLL=0.931758 | mu=1.4252 | sigma=0.0285 | mean π0=0.7710
[CGB] Epoch   1/50 | Avg NLL=0.860352 | mu=0.9912 | sigma=0.0198 | mean π0=0.5730
[CGB] Epoch  10/50 | Avg NLL=0.616255 | mu=0.9376 | sigma=0.0188 | mean π0=0.6308
[CGB] Epoch  20/50 | Avg NLL=0.589981 | mu=0.9225 | sigma=0.0184 | mean π0=0.6285
[CGB] Epoch  30/50 | Avg NLL=0.572921 | mu=0.9249 | sigma=0.0185 | mean π0=0.6338
[CGB] Epoch  40/50 | Avg NLL=0.563654 | mu=0.9317 | sigma=0.0186 | mean π0=0.6251
[CGB] Epoch  50/50 | Avg NLL=0.556297 | mu=0.9355 | sigma=0.0187 | mean π0=0.6392
[CGB] Epoch   1/50 | Avg NLL=1.181196 | mu=0.9902 | sigma=0.0198 | mean π0=0.5353
[CGB] Epoch  10/50 | Avg NLL=0.883311 | mu=0.9809 | sigma=0.0196 | mean π0=0.7183
[CGB] Epoch  20/50 | Avg NLL=0.870489 | mu=0.9905 | sigma=0.0198 | mean π0=0.7118
[CGB] Epoch  30/50 | Avg NLL=0.864816 | mu=0.9922 | sigma=0.0198 | mean π0=0.7079
[CGB] Epoch  40/50 | Avg NLL=0.860032 | mu=0.9901 | sigma=0.0198 | mean π0=0.7118
[CGB] Epoch  50/50 | Avg NLL=0.857510 | mu=0.9908 | sigma=0.0198 | mean π0=0.7121
[CGB] Epoch   1/50 | Avg NLL=2.755591 | mu=0.9901 | sigma=0.0198 | mean π0=0.7085
[CGB] Epoch  10/50 | Avg NLL=2.750350 | mu=0.9675 | sigma=0.0194 | mean π0=0.9992
[CGB] Epoch  20/50 | Avg NLL=2.750344 | mu=0.9670 | sigma=0.0193 | mean π0=0.9996
[CGB] Epoch  30/50 | Avg NLL=2.750342 | mu=0.9667 | sigma=0.0193 | mean π0=0.9998
[CGB] Epoch  40/50 | Avg NLL=2.750341 | mu=0.9665 | sigma=0.0193 | mean π0=0.9999
[CGB] Epoch  50/50 | Avg NLL=2.750340 | mu=0.9664 | sigma=0.0193 | mean π0=0.9999
[CGB] Epoch   1/50 | Avg NLL=0.701133 | mu=0.9900 | sigma=0.0198 | mean π0=0.4833
[CGB] Epoch  10/50 | Avg NLL=-0.161388 | mu=0.9096 | sigma=0.0182 | mean π0=0.4713
[CGB] Epoch  20/50 | Avg NLL=-0.357579 | mu=0.8475 | sigma=0.0170 | mean π0=0.4628
[CGB] Epoch  30/50 | Avg NLL=-0.421565 | mu=0.8103 | sigma=0.0162 | mean π0=0.4578
[CGB] Epoch  40/50 | Avg NLL=-0.440704 | mu=0.7903 | sigma=0.0158 | mean π0=0.4611
[CGB] Epoch  50/50 | Avg NLL=-0.445684 | mu=0.7808 | sigma=0.0156 | mean π0=0.4666
[CGB] Epoch   1/50 | Avg NLL=1.200726 | mu=1.0020 | sigma=0.0200 | mean π0=0.6014
[CGB] Epoch  10/50 | Avg NLL=1.016947 | mu=1.0853 | sigma=0.0217 | mean π0=0.7063
[CGB] Epoch  20/50 | Avg NLL=0.950660 | mu=1.1876 | sigma=0.0238 | mean π0=0.7490
[CGB] Epoch  30/50 | Avg NLL=0.911476 | mu=1.2821 | sigma=0.0256 | mean π0=0.7682
[CGB] Epoch  40/50 | Avg NLL=0.898462 | mu=1.3499 | sigma=0.0270 | mean π0=0.7743
[CGB] Epoch  50/50 | Avg NLL=0.893547 | mu=1.3940 | sigma=0.0279 | mean π0=0.7673
[CGB] Epoch   1/50 | Avg NLL=0.822898 | mu=0.9945 | sigma=0.0199 | mean π0=0.5971
[CGB] Epoch  10/50 | Avg NLL=0.615699 | mu=0.9398 | sigma=0.0188 | mean π0=0.6087
[CGB] Epoch  20/50 | Avg NLL=0.577777 | mu=0.9368 | sigma=0.0187 | mean π0=0.6079
[CGB] Epoch  30/50 | Avg NLL=0.562640 | mu=0.9427 | sigma=0.0189 | mean π0=0.6315
[CGB] Epoch  40/50 | Avg NLL=0.556291 | mu=0.9476 | sigma=0.0190 | mean π0=0.6359
[CGB] Epoch  50/50 | Avg NLL=0.552082 | mu=0.9517 | sigma=0.0190 | mean π0=0.6302
[CGB] Epoch   1/50 | Avg NLL=1.137766 | mu=1.0007 | sigma=0.0200 | mean π0=0.5771
[CGB] Epoch  10/50 | Avg NLL=0.851282 | mu=1.0613 | sigma=0.0212 | mean π0=0.7064
[CGB] Epoch  20/50 | Avg NLL=0.839621 | mu=1.0915 | sigma=0.0218 | mean π0=0.7147
[CGB] Epoch  30/50 | Avg NLL=0.834353 | mu=1.0967 | sigma=0.0219 | mean π0=0.7111
[CGB] Epoch  40/50 | Avg NLL=0.828371 | mu=1.0980 | sigma=0.0220 | mean π0=0.7118
[CGB] Epoch  50/50 | Avg NLL=0.823510 | mu=1.0977 | sigma=0.0220 | mean π0=0.7083
[CGB] Epoch   1/50 | Avg NLL=2.577604 | mu=0.9901 | sigma=0.0198 | mean π0=0.6998
[CGB] Epoch  10/50 | Avg NLL=2.570279 | mu=0.9654 | sigma=0.0193 | mean π0=0.9987
[CGB] Epoch  20/50 | Avg NLL=2.570264 | mu=0.9647 | sigma=0.0193 | mean π0=0.9995
[CGB] Epoch  30/50 | Avg NLL=2.570260 | mu=0.9643 | sigma=0.0193 | mean π0=0.9998
[CGB] Epoch  40/50 | Avg NLL=2.570258 | mu=0.9641 | sigma=0.0193 | mean π0=0.9998
[CGB] Epoch  50/50 | Avg NLL=2.570257 | mu=0.9639 | sigma=0.0193 | mean π0=0.9999
[CGB] Epoch   1/50 | Avg NLL=0.735835 | mu=0.9900 | sigma=0.0198 | mean π0=0.4765
[CGB] Epoch  10/50 | Avg NLL=-0.142963 | mu=0.9094 | sigma=0.0182 | mean π0=0.4636
[CGB] Epoch  20/50 | Avg NLL=-0.350647 | mu=0.8467 | sigma=0.0169 | mean π0=0.4628
[CGB] Epoch  30/50 | Avg NLL=-0.422231 | mu=0.8082 | sigma=0.0162 | mean π0=0.4642
[CGB] Epoch  40/50 | Avg NLL=-0.445951 | mu=0.7870 | sigma=0.0157 | mean π0=0.4557
[CGB] Epoch  50/50 | Avg NLL=-0.451789 | mu=0.7767 | sigma=0.0155 | mean π0=0.4621
[CGB] Epoch   1/50 | Avg NLL=1.173112 | mu=1.0003 | sigma=0.0200 | mean π0=0.6235
[CGB] Epoch  10/50 | Avg NLL=0.992972 | mu=1.0830 | sigma=0.0217 | mean π0=0.7202
[CGB] Epoch  20/50 | Avg NLL=0.943293 | mu=1.1759 | sigma=0.0235 | mean π0=0.7473
[CGB] Epoch  30/50 | Avg NLL=0.893108 | mu=1.2671 | sigma=0.0253 | mean π0=0.7589
[CGB] Epoch  40/50 | Avg NLL=0.877948 | mu=1.3340 | sigma=0.0267 | mean π0=0.7683
[CGB] Epoch  50/50 | Avg NLL=0.871828 | mu=1.3747 | sigma=0.0275 | mean π0=0.7748
[CGB] Epoch   1/50 | Avg NLL=0.870577 | mu=0.9902 | sigma=0.0198 | mean π0=0.5274
[CGB] Epoch  10/50 | Avg NLL=0.609930 | mu=0.9467 | sigma=0.0189 | mean π0=0.5906
[CGB] Epoch  20/50 | Avg NLL=0.578576 | mu=0.9488 | sigma=0.0190 | mean π0=0.6087
[CGB] Epoch  30/50 | Avg NLL=0.558950 | mu=0.9605 | sigma=0.0192 | mean π0=0.6193
[CGB] Epoch  40/50 | Avg NLL=0.550004 | mu=0.9654 | sigma=0.0193 | mean π0=0.6079
[CGB] Epoch  50/50 | Avg NLL=0.544479 | mu=0.9672 | sigma=0.0193 | mean π0=0.6232
[CGB] Epoch   1/50 | Avg NLL=1.158275 | mu=1.0026 | sigma=0.0201 | mean π0=0.5252
[CGB] Epoch  10/50 | Avg NLL=0.836175 | mu=1.0940 | sigma=0.0219 | mean π0=0.6960
[CGB] Epoch  20/50 | Avg NLL=0.818744 | mu=1.1565 | sigma=0.0231 | mean π0=0.6948
[CGB] Epoch  30/50 | Avg NLL=0.813472 | mu=1.1753 | sigma=0.0235 | mean π0=0.7111
[CGB] Epoch  40/50 | Avg NLL=0.809224 | mu=1.1800 | sigma=0.0236 | mean π0=0.7120
[CGB] Epoch  50/50 | Avg NLL=0.804999 | mu=1.1809 | sigma=0.0236 | mean π0=0.7123
[CGB] Epoch   1/50 | Avg NLL=2.663827 | mu=0.9901 | sigma=0.0198 | mean π0=0.7126
[CGB] Epoch  10/50 | Avg NLL=2.657800 | mu=0.9652 | sigma=0.0193 | mean π0=0.9985
[CGB] Epoch  20/50 | Avg NLL=2.657785 | mu=0.9643 | sigma=0.0193 | mean π0=0.9994
[CGB] Epoch  30/50 | Avg NLL=2.657781 | mu=0.9638 | sigma=0.0193 | mean π0=0.9997
[CGB] Epoch  40/50 | Avg NLL=2.657779 | mu=0.9635 | sigma=0.0193 | mean π0=0.9998
[CGB] Epoch  50/50 | Avg NLL=2.657777 | mu=0.9633 | sigma=0.0193 | mean π0=0.9999
[CGB] Epoch   1/50 | Avg NLL=0.694985 | mu=0.9899 | sigma=0.0198 | mean π0=0.5258
[CGB] Epoch  10/50 | Avg NLL=-0.145711 | mu=0.9091 | sigma=0.0182 | mean π0=0.4702
[CGB] Epoch  20/50 | Avg NLL=-0.351548 | mu=0.8461 | sigma=0.0169 | mean π0=0.4551
[CGB] Epoch  30/50 | Avg NLL=-0.419886 | mu=0.8077 | sigma=0.0162 | mean π0=0.4589
[CGB] Epoch  40/50 | Avg NLL=-0.441069 | mu=0.7867 | sigma=0.0157 | mean π0=0.4667
[CGB] Epoch  50/50 | Avg NLL=-0.444165 | mu=0.7764 | sigma=0.0155 | mean π0=0.4598
[CGB] Epoch   1/50 | Avg NLL=1.156677 | mu=1.0019 | sigma=0.0200 | mean π0=0.6075
[CGB] Epoch  10/50 | Avg NLL=0.956058 | mu=1.0881 | sigma=0.0218 | mean π0=0.7252
[CGB] Epoch  20/50 | Avg NLL=0.886029 | mu=1.1901 | sigma=0.0238 | mean π0=0.7627
[CGB] Epoch  30/50 | Avg NLL=0.858461 | mu=1.2704 | sigma=0.0254 | mean π0=0.7694
[CGB] Epoch  40/50 | Avg NLL=0.849524 | mu=1.3209 | sigma=0.0264 | mean π0=0.7719
[CGB] Epoch  50/50 | Avg NLL=0.845747 | mu=1.3510 | sigma=0.0270 | mean π0=0.7720
[CGB] Epoch   1/50 | Avg NLL=0.871928 | mu=0.9911 | sigma=0.0198 | mean π0=0.5179
[CGB] Epoch  10/50 | Avg NLL=0.593765 | mu=0.9621 | sigma=0.0192 | mean π0=0.6046
[CGB] Epoch  20/50 | Avg NLL=0.563499 | mu=0.9679 | sigma=0.0194 | mean π0=0.6107
[CGB] Epoch  30/50 | Avg NLL=0.548581 | mu=0.9742 | sigma=0.0195 | mean π0=0.5992
[CGB] Epoch  40/50 | Avg NLL=0.541365 | mu=0.9776 | sigma=0.0196 | mean π0=0.6140
[CGB] Epoch  50/50 | Avg NLL=0.536368 | mu=0.9801 | sigma=0.0196 | mean π0=0.6198
[CGB] Epoch   1/50 | Avg NLL=1.166712 | mu=1.0091 | sigma=0.0202 | mean π0=0.5564
[CGB] Epoch  10/50 | Avg NLL=0.826018 | mu=1.1018 | sigma=0.0220 | mean π0=0.7015
[CGB] Epoch  20/50 | Avg NLL=0.798490 | mu=1.1764 | sigma=0.0235 | mean π0=0.7002
[CGB] Epoch  30/50 | Avg NLL=0.789356 | mu=1.2196 | sigma=0.0244 | mean π0=0.7068
[CGB] Epoch  40/50 | Avg NLL=0.782785 | mu=1.2377 | sigma=0.0248 | mean π0=0.7042
[CGB] Epoch  50/50 | Avg NLL=0.778260 | mu=1.2433 | sigma=0.0249 | mean π0=0.7128
[CGB] Epoch   1/50 | Avg NLL=2.555595 | mu=0.9900 | sigma=0.0198 | mean π0=0.6521
[CGB] Epoch  10/50 | Avg NLL=2.547132 | mu=0.9626 | sigma=0.0193 | mean π0=0.9981
[CGB] Epoch  20/50 | Avg NLL=2.547100 | mu=0.9618 | sigma=0.0192 | mean π0=0.9996
[CGB] Epoch  30/50 | Avg NLL=2.547096 | mu=0.9616 | sigma=0.0192 | mean π0=0.9999
[CGB] Epoch  40/50 | Avg NLL=2.547094 | mu=0.9615 | sigma=0.0192 | mean π0=0.9999
[CGB] Epoch  50/50 | Avg NLL=2.547094 | mu=0.9614 | sigma=0.0192 | mean π0=1.0000
[CGB] Epoch   1/50 | Avg NLL=0.721028 | mu=0.9899 | sigma=0.0198 | mean π0=0.4338
[CGB] Epoch  10/50 | Avg NLL=-0.139705 | mu=0.9091 | sigma=0.0182 | mean π0=0.4648
[CGB] Epoch  20/50 | Avg NLL=-0.338156 | mu=0.8467 | sigma=0.0169 | mean π0=0.4593
[CGB] Epoch  30/50 | Avg NLL=-0.404405 | mu=0.8085 | sigma=0.0162 | mean π0=0.4569
[CGB] Epoch  40/50 | Avg NLL=-0.422484 | mu=0.7876 | sigma=0.0158 | mean π0=0.4617
[CGB] Epoch  50/50 | Avg NLL=-0.431881 | mu=0.7774 | sigma=0.0155 | mean π0=0.4624
[CGB] Epoch   1/50 | Avg NLL=1.134098 | mu=1.0062 | sigma=0.0201 | mean π0=0.6365
[CGB] Epoch  10/50 | Avg NLL=0.948594 | mu=1.0924 | sigma=0.0218 | mean π0=0.7280
[CGB] Epoch  20/50 | Avg NLL=0.889866 | mu=1.1861 | sigma=0.0237 | mean π0=0.7617
[CGB] Epoch  30/50 | Avg NLL=0.850099 | mu=1.2676 | sigma=0.0254 | mean π0=0.7634
[CGB] Epoch  40/50 | Avg NLL=0.837479 | mu=1.3182 | sigma=0.0264 | mean π0=0.7674
[CGB] Epoch  50/50 | Avg NLL=0.833480 | mu=1.3468 | sigma=0.0269 | mean π0=0.7803
[CGB] Epoch   1/50 | Avg NLL=0.841442 | mu=0.9975 | sigma=0.0199 | mean π0=0.5296
[CGB] Epoch  10/50 | Avg NLL=0.579206 | mu=0.9780 | sigma=0.0196 | mean π0=0.5883
[CGB] Epoch  20/50 | Avg NLL=0.542866 | mu=0.9875 | sigma=0.0197 | mean π0=0.5969
[CGB] Epoch  30/50 | Avg NLL=0.529416 | mu=0.9934 | sigma=0.0199 | mean π0=0.5936
[CGB] Epoch  40/50 | Avg NLL=0.522961 | mu=0.9952 | sigma=0.0199 | mean π0=0.5981
[CGB] Epoch  50/50 | Avg NLL=0.519772 | mu=0.9971 | sigma=0.0199 | mean π0=0.6036
[CGB] Epoch   1/50 | Avg NLL=1.176625 | mu=1.0094 | sigma=0.0202 | mean π0=0.5611
[CGB] Epoch  10/50 | Avg NLL=0.820941 | mu=1.1022 | sigma=0.0220 | mean π0=0.6919
[CGB] Epoch  20/50 | Avg NLL=0.787833 | mu=1.1854 | sigma=0.0237 | mean π0=0.6976
[CGB] Epoch  30/50 | Avg NLL=0.776344 | mu=1.2412 | sigma=0.0248 | mean π0=0.6904
[CGB] Epoch  40/50 | Avg NLL=0.770323 | mu=1.2716 | sigma=0.0254 | mean π0=0.7093
[CGB] Epoch  50/50 | Avg NLL=0.764893 | mu=1.2856 | sigma=0.0257 | mean π0=0.7013
[<matplotlib.lines.Line2D at 0x1b7d61e9310>]
../_images/f45698148d49b8745abc408103931e93cc57697c1fd0466179ddcbbe7ef23cbc.png
# Plot the objective trace stored in `mycebmf31.obj` (one value per fit
# iteration — presumably the ELBO; TODO confirm against the cEBMF class).
plt.plot(mycebmf31.obj)
[<matplotlib.lines.Line2D at 0x1b7d6222110>]
../_images/f45698148d49b8745abc408103931e93cc57697c1fd0466179ddcbbe7ef23cbc.png

Fitting the cEBMF model

Here we pass the side information for the rows of \(L\), stored in the matrix \(X\), through the argument `X_l` (covariates for \(L\)). As the prior on \(L\) we use a covariate-moderated adaptive shrinkage prior ("cash"): \( g(x,y) = \pi_0(x,y)\, \delta_0 + \sum_{m=1}^M \pi_m(x,y)\, N(0, \sigma_m^2) \).

# Fit cEBMF on Z, passing the row covariates X (used to moderate the prior
# on L) and the covariate-moderated adaptive shrinkage prior "cash" for L.
# Backfitting is switched off for this run.
mycebmf11 = cEBMF(
    data=Z,
    X_l=X,
    prior_L="cash",
    allow_backfitting=False,
)
mycebmf11.initialise_factors()  # initialise factors before the first fit iteration
mycebmf11.fit(10)  # 10 iterations; kept as the last expression so the notebook echoes the result
[CASH] Epoch 10/20 | Loss: 318.0065
[CASH] Epoch 20/20 | Loss: 314.4136
[CASH] Epoch 10/20 | Loss: 209.1351
[CASH] Epoch 20/20 | Loss: 203.4585
[CASH] Epoch 10/20 | Loss: 221.2009
[CASH] Epoch 20/20 | Loss: 214.8821
[CASH] Epoch 10/20 | Loss: 209.0129
[CASH] Epoch 20/20 | Loss: 209.0653
[CASH] Epoch 10/20 | Loss: 208.9766
[CASH] Epoch 20/20 | Loss: 209.0096
[CASH] Epoch 10/20 | Loss: 314.7596
[CASH] Epoch 20/20 | Loss: 311.3474
[CASH] Epoch 10/20 | Loss: 202.0447
[CASH] Epoch 20/20 | Loss: 196.4391
[CASH] Epoch 10/20 | Loss: 213.8132
[CASH] Epoch 20/20 | Loss: 210.5392
[CASH] Epoch 10/20 | Loss: 205.9917
[CASH] Epoch 20/20 | Loss: 206.0365
[CASH] Epoch 10/20 | Loss: 206.5518
[CASH] Epoch 20/20 | Loss: 206.2175
[CASH] Epoch 10/20 | Loss: 315.2998
[CASH] Epoch 20/20 | Loss: 312.0461
[CASH] Epoch 10/20 | Loss: 201.6886
[CASH] Epoch 20/20 | Loss: 195.6338
[CASH] Epoch 10/20 | Loss: 214.9021
[CASH] Epoch 20/20 | Loss: 207.1794
[CASH] Epoch 10/20 | Loss: 210.5850
[CASH] Epoch 20/20 | Loss: 210.3396
[CASH] Epoch 10/20 | Loss: 206.0460
[CASH] Epoch 20/20 | Loss: 205.9410
[CASH] Epoch 10/20 | Loss: 316.4176
[CASH] Epoch 20/20 | Loss: 310.6652
[CASH] Epoch 10/20 | Loss: 202.5271
[CASH] Epoch 20/20 | Loss: 195.8371
[CASH] Epoch 10/20 | Loss: 209.3030
[CASH] Epoch 20/20 | Loss: 205.2607
[CASH] Epoch 10/20 | Loss: 208.8357
[CASH] Epoch 20/20 | Loss: 208.7954
[CASH] Epoch 10/20 | Loss: 204.1910
[CASH] Epoch 20/20 | Loss: 203.8816
[CASH] Epoch 10/20 | Loss: 314.9395
[CASH] Epoch 20/20 | Loss: 310.7118
[CASH] Epoch 10/20 | Loss: 201.8327
[CASH] Epoch 20/20 | Loss: 195.2398
[CASH] Epoch 10/20 | Loss: 212.2016
[CASH] Epoch 20/20 | Loss: 206.3348
[CASH] Epoch 10/20 | Loss: 212.3192
[CASH] Epoch 20/20 | Loss: 212.2892
[CASH] Epoch 10/20 | Loss: 206.8687
[CASH] Epoch 20/20 | Loss: 206.5055
[CASH] Epoch 10/20 | Loss: 314.9159
[CASH] Epoch 20/20 | Loss: 311.2849
[CASH] Epoch 10/20 | Loss: 201.6779
[CASH] Epoch 20/20 | Loss: 195.7501
[CASH] Epoch 10/20 | Loss: 211.8555
[CASH] Epoch 20/20 | Loss: 205.5993
[CASH] Epoch 10/20 | Loss: 210.5579
[CASH] Epoch 20/20 | Loss: 210.4857
[CASH] Epoch 10/20 | Loss: 208.8424
[CASH] Epoch 20/20 | Loss: 208.7929
[CASH] Epoch 10/20 | Loss: 312.5203
[CASH] Epoch 20/20 | Loss: 310.0354
[CASH] Epoch 10/20 | Loss: 202.4267
[CASH] Epoch 20/20 | Loss: 195.5463
[CASH] Epoch 10/20 | Loss: 210.6670
[CASH] Epoch 20/20 | Loss: 205.2538
[CASH] Epoch 10/20 | Loss: 212.9779
[CASH] Epoch 20/20 | Loss: 212.9712
[CASH] Epoch 10/20 | Loss: 206.7036
[CASH] Epoch 20/20 | Loss: 206.4169
[CASH] Epoch 10/20 | Loss: 315.9410
[CASH] Epoch 20/20 | Loss: 311.3899
[CASH] Epoch 10/20 | Loss: 201.9404
[CASH] Epoch 20/20 | Loss: 195.6765
[CASH] Epoch 10/20 | Loss: 210.7105
[CASH] Epoch 20/20 | Loss: 207.3624
[CASH] Epoch 10/20 | Loss: 219.5773
[CASH] Epoch 20/20 | Loss: 219.4080
[CASH] Epoch 10/20 | Loss: 206.5024
[CASH] Epoch 20/20 | Loss: 206.2402
[CASH] Epoch 10/20 | Loss: 315.9301
[CASH] Epoch 20/20 | Loss: 311.1770
[CASH] Epoch 10/20 | Loss: 200.9448
[CASH] Epoch 20/20 | Loss: 194.6749
[CASH] Epoch 10/20 | Loss: 211.5682
[CASH] Epoch 20/20 | Loss: 208.3621
[CASH] Epoch 10/20 | Loss: 220.4680
[CASH] Epoch 20/20 | Loss: 220.4219
[CASH] Epoch 10/20 | Loss: 211.9366
[CASH] Epoch 20/20 | Loss: 211.7063
[CASH] Epoch 10/20 | Loss: 314.9283
[CASH] Epoch 20/20 | Loss: 311.3136
[CASH] Epoch 10/20 | Loss: 200.2954
[CASH] Epoch 20/20 | Loss: 194.2835
[CASH] Epoch 10/20 | Loss: 215.9843
[CASH] Epoch 20/20 | Loss: 212.7378
[CASH] Epoch 10/20 | Loss: 220.4476
[CASH] Epoch 20/20 | Loss: 220.3254
[CASH] Epoch 10/20 | Loss: 210.9125
[CASH] Epoch 20/20 | Loss: 210.6961
CEBMFResult(L=tensor([[-6.8917e+00, -1.2949e-09,  2.0583e-12,  8.1348e-02, -4.0405e-01],
        [-1.0151e-03, -1.5498e-11,  2.3386e+00,  3.6892e-01, -1.0618e-01],
        [-1.9792e+00, -3.0538e-19, -5.8073e-06, -1.6054e-01, -1.9785e-01],
        ...,
        [-6.5323e+00, -1.3508e-12,  1.2190e-12, -1.5553e-01,  8.1315e-01],
        [-5.2091e-04,  1.0466e-07,  2.8986e+00,  1.6640e-01,  1.6414e-01],
        [-7.3548e-05,  2.6724e-07,  1.7624e+00, -4.1566e-03,  1.1402e-01]]), F=tensor([[-1.3455e-01,  1.5683e-01, -1.0078e-01,  5.4707e-04, -1.8289e-03],
        [ 1.6455e-01,  1.5537e-01,  1.1040e-03,  1.2531e-04, -2.5863e-04],
        [-4.2631e-04, -1.0676e-03,  5.0055e-02,  9.1029e-04, -2.2025e-01],
        [-5.5217e-02,  1.2823e-01,  9.2370e-02, -2.7585e-03, -7.5741e-04],
        [ 1.0919e-05, -3.8014e-02,  7.7560e-02, -3.0726e-03, -2.8143e-03],
        [ 9.6189e-04,  2.7726e-03,  3.5946e-04, -2.2664e-04,  3.2895e-04],
        [-1.2602e-05, -1.2349e-02, -7.2410e-04, -1.3481e-01,  1.8135e-04],
        [-1.9133e-01,  1.5694e-01,  1.2352e-02,  2.3390e-05,  2.2564e-04],
        [-2.8899e-05, -1.1632e-04, -9.3146e-02,  1.4921e-03, -2.3305e-04],
        [-8.4444e-02,  2.5098e-01,  1.3247e-02, -1.8078e-03,  2.7631e-03],
        [-1.0314e-01,  6.1477e-02,  1.4891e-01,  1.1334e-03,  2.5769e-03],
        [ 9.7915e-02,  1.3537e-01,  5.4219e-03,  1.4024e-04, -1.9312e-01],
        [-2.2469e-04, -5.3401e-04, -1.3197e-01, -3.7513e-04,  1.0633e-03],
        [ 9.5868e-04,  1.0237e-03,  1.1618e-03, -1.4823e-03,  3.9404e-05],
        [ 9.7674e-05,  1.1560e-03, -2.0406e-03, -3.2058e-04,  1.2611e-03],
        [-2.2106e-03,  1.1932e-03, -5.3719e-04,  2.8196e-04,  5.3465e-04],
        [-3.1299e-04, -3.1580e-03, -1.2570e-03,  7.1330e-04, -1.0620e-03],
        [ 8.1609e-05, -1.6827e-03,  5.3312e-04, -6.6481e-04,  9.5792e-04],
        [ 1.0698e-01, -6.3028e-02,  7.2250e-02,  2.0501e-01, -4.1481e-05],
        [ 1.2694e-03,  7.8259e-04, -4.5950e-03, -2.2859e-04,  2.5605e-03],
        [ 2.6283e-05, -1.1668e-03,  2.3181e-03,  6.7269e-04,  2.4109e-04],
        [ 2.0595e-02, -1.7727e-01,  9.0996e-02,  1.8682e-04, -9.6341e-04],
        [-2.2741e-04, -5.1742e-03,  3.6805e-04,  8.9467e-04, -1.6966e-03],
        [ 7.5874e-04,  2.7115e-02,  6.3791e-04,  1.2070e-03,  1.5316e-01],
        [-8.5458e-02, -4.5917e-04, -4.3846e-03, -7.7824e-03,  3.6539e-04],
        [-5.2755e-02, -1.1163e-01,  8.7752e-03, -5.6469e-04,  2.0626e-03],
        [-3.1357e-04, -1.7750e-03, -1.9801e-02,  1.2068e-03,  5.2312e-04],
        [ 5.0777e-03,  1.6010e-03, -1.1433e-02,  9.3595e-04, -2.3114e-03],
        [-3.4797e-03,  2.0201e-03, -1.0582e-03,  4.1396e-04,  2.3957e-03],
        [-1.7724e-05,  1.5895e-03, -2.2057e-03, -1.1838e-01, -3.8310e-04],
        [ 1.8924e-04, -1.5656e-03,  1.2910e-03,  2.6164e-01, -1.2380e-03],
        [-6.2749e-04, -6.7965e-04,  2.6867e-03, -4.3594e-04,  3.8910e-04],
        [ 4.7491e-05,  7.8230e-04,  8.3347e-03,  7.0714e-02,  2.4099e-05],
        [-1.7133e-04,  5.3821e-03, -1.6361e-01,  1.3963e-03,  8.8814e-05],
        [-2.1128e-03, -2.7215e-03, -4.6410e-03, -2.9911e-04, -1.3386e-04],
        [-5.9496e-06, -2.2756e-05,  8.1893e-04, -3.1060e-04, -1.4751e-03],
        [ 2.5378e-05,  1.7330e-02, -1.4603e-03, -2.4206e-04, -1.5709e-01],
        [-1.9114e-04, -4.5442e-04,  3.2933e-05,  3.2857e-04,  1.4909e-04],
        [-6.4521e-02, -1.6726e-03,  5.7655e-04,  2.1816e-04,  5.3708e-04],
        [ 8.9600e-02,  1.0050e-01, -2.1618e-04,  3.8761e-04, -4.5478e-04],
        [ 2.6958e-04,  1.1643e-03, -4.6450e-04,  9.7289e-02, -4.5296e-03],
        [-2.8335e-02, -1.0818e-02, -4.3812e-02, -1.3265e-01, -8.8151e-04],
        [-6.7964e-02,  4.8164e-02,  1.7357e-01,  1.4196e-02,  1.1990e-01],
        [ 6.2218e-02,  7.1431e-02, -9.7483e-04,  1.9775e-01, -6.1427e-04],
        [-1.7277e-01,  1.7073e-01,  5.1682e-03, -9.3832e-04,  4.5902e-04],
        [-2.8129e-02,  5.5377e-04, -5.8386e-02,  8.1212e-05,  8.4953e-04],
        [ 5.2774e-02, -8.2989e-02,  1.0650e-03,  1.3261e-04, -1.0330e-03],
        [-6.2797e-02,  7.7497e-03, -9.4238e-02, -1.3086e-01, -1.2811e-01],
        [ 8.4397e-05, -6.5308e-03, -1.0294e-03, -7.5887e-04,  2.3308e-03],
        [ 2.9668e-02, -7.5195e-02, -9.1662e-04,  1.3419e-01,  1.7802e-03],
        [-1.1663e-02, -2.5245e-03, -8.1086e-04, -2.3202e-04, -4.8036e-04],
        [-1.8378e-04, -7.1523e-04, -6.8924e-03, -1.4746e-01,  7.7498e-04],
        [ 3.7786e-02,  9.2512e-02, -8.6447e-04,  1.7484e-03, -3.2181e-03],
        [ 1.5820e-04,  2.1356e-04, -3.8773e-03,  9.8688e-04,  2.9399e-03],
        [ 6.9984e-05,  1.0921e-03, -3.6499e-03, -1.5638e-04, -5.7012e-04],
        [ 3.1151e-03, -2.8473e-02, -3.2025e-03,  1.7703e-03,  1.5925e-03],
        [ 2.0189e-01, -1.3361e-02, -1.2101e-03, -1.0242e-03, -9.4304e-04],
        [ 2.0052e-05,  2.6429e-04,  2.4441e-03,  2.1542e-04,  1.2909e-01],
        [-1.1470e-01, -1.1325e-01,  9.6648e-02,  1.3994e-01,  1.3006e-01],
        [ 1.2153e-01,  2.0061e-03,  4.9259e-04, -9.8681e-04, -3.9660e-04],
        [ 5.4428e-04,  1.5735e-03, -1.2716e-01, -5.9337e-04,  3.1609e-04],
        [-1.4091e-05,  4.0975e-04, -4.3997e-03,  4.5066e-03,  2.2787e-03],
        [ 1.0259e-01, -6.8006e-02,  3.6758e-02, -2.9653e-04, -3.4700e-03],
        [ 3.2372e-05,  2.3057e-03, -1.1596e-02,  9.6374e-05,  3.9002e-04],
        [-8.0644e-02,  1.8564e-01,  5.8162e-04, -1.3583e-04, -1.2022e-03],
        [ 1.4435e-01, -7.2302e-02, -1.1383e-02, -6.4907e-04,  4.5844e-03],
        [-1.4503e-02,  2.4223e-03,  7.1172e-02, -2.1214e-04, -5.1861e-04],
        [ 8.6452e-03,  4.4537e-03, -7.6112e-02,  5.3205e-04, -9.3121e-04],
        [-6.2801e-02,  6.6597e-02, -1.0121e-02, -5.1010e-04, -4.2749e-04],
        [ 5.3491e-05,  2.1861e-04,  1.4793e-01, -7.1362e-04, -2.4912e-03],
        [-2.8221e-04,  1.1670e-01,  7.4481e-02, -9.4804e-05, -2.1149e-03],
        [-1.9818e-01, -6.4765e-02, -1.8048e-03,  7.6312e-04, -8.8231e-04],
        [-3.6840e-02, -6.4817e-02,  1.0309e-02, -4.9142e-04, -1.7635e-04],
        [-1.1082e-01,  8.7785e-02, -1.0208e-01,  2.7409e-03, -9.9287e-05],
        [ 1.1738e-05, -5.9704e-04,  9.9560e-02, -2.1012e-01,  1.5970e-03],
        [-1.5437e-01, -4.7047e-02, -2.6951e-02,  3.9756e-04, -6.8408e-04],
        [-7.6359e-02, -2.3779e-03, -2.6652e-02, -1.0465e-03,  5.4379e-04],
        [ 1.0007e-04,  2.7346e-04, -1.5346e-03,  8.6951e-02, -4.2033e-04],
        [ 1.3220e-01, -1.6441e-01, -1.3671e-01,  7.3870e-04,  8.5092e-04],
        [ 4.4896e-03,  7.1619e-04,  2.4366e-04, -2.2825e-04, -9.7926e-04],
        [ 1.4599e-01,  1.3814e-01, -1.8286e-02,  5.7566e-04,  3.3133e-04],
        [ 2.3161e-04, -1.1811e-03, -7.0447e-02,  5.2150e-03, -5.1384e-04],
        [ 4.3406e-02,  1.5624e-01,  1.5565e-03,  1.0569e-03, -3.1350e-04],
        [ 1.3108e-04,  5.2846e-03, -2.6057e-04, -4.4524e-04,  6.7929e-04],
        [-1.3634e-01,  3.4975e-03, -1.2151e-02, -6.3867e-04,  2.9116e-03],
        [-1.9540e-04,  3.7710e-02, -3.1813e-04,  6.0541e-04,  4.9092e-04],
        [ 1.4052e-01, -1.3724e-01, -4.2685e-04,  8.9577e-04, -2.1151e-03],
        [ 1.8066e-01,  5.0627e-02, -7.9241e-02,  6.1328e-04,  1.9895e-03],
        [ 3.2494e-02, -3.9606e-04,  3.9668e-03,  2.5259e-05,  9.3067e-04],
        [-3.4908e-05, -1.9959e-03, -2.9495e-03,  1.2923e-04,  7.7471e-04],
        [ 4.9360e-03, -5.1881e-02,  5.9042e-04, -1.9789e-04, -1.2475e-03],
        [ 4.0886e-02, -8.3671e-02,  8.1537e-02, -1.4425e-01, -2.9201e-04],
        [ 1.8816e-03,  1.3223e-01, -3.4830e-02, -1.3371e-03,  1.0668e-03],
        [ 3.6680e-05,  8.6677e-04,  1.7802e-02,  6.9025e-04,  7.7455e-02],
        [ 6.0846e-05, -2.7904e-04,  3.6675e-04, -1.1555e-03, -2.8486e-04],
        [-8.5211e-04,  2.9585e-04,  6.2442e-03,  1.4889e-04, -7.4618e-04],
        [-2.3001e-01, -8.1732e-04, -3.4379e-06,  2.1536e-04,  2.3743e-03],
        [ 1.2152e-04, -4.3873e-03, -1.7621e-02, -3.4300e-04, -2.5384e-05],
        [ 6.3431e-02, -1.2262e-01,  9.8919e-02,  1.2738e-03,  1.6842e-04],
        [-2.7086e-02,  1.0267e-01,  1.9000e-04,  1.1020e-04,  9.4390e-03],
        [ 1.2369e-01,  1.5744e-01,  2.7554e-03,  1.3673e-01, -8.9183e-04],
        [-1.8383e-04, -8.7770e-03, -3.9286e-03, -6.2124e-04, -7.2588e-04],
        [ 3.0982e-05,  1.1691e-03,  7.7194e-02, -1.3209e-01, -2.0949e-01],
        [ 1.4870e-01,  1.0619e-01, -7.8572e-02, -1.2752e-01, -1.7887e-03],
        [-6.5460e-02,  7.4858e-02,  3.8665e-03,  1.7599e-03,  6.9144e-04],
        [ 1.5956e-03,  8.1535e-03, -3.4698e-02, -7.0614e-04, -5.3961e-04],
        [ 4.6204e-05, -5.4209e-03, -5.4349e-03, -1.0193e-01,  8.7209e-04],
        [ 2.8597e-04, -1.4572e-03, -1.1840e-02, -3.9239e-04,  1.2352e-03],
        [-3.5382e-05,  3.8468e-04,  1.0726e-03, -1.0514e-01, -1.1274e-03],
        [-1.5672e-03,  8.5694e-04, -1.8346e-02, -1.7560e-03,  4.0627e-04],
        [-6.9290e-05, -1.7972e-03, -6.0449e-03, -6.5456e-04, -1.5342e-01],
        [ 1.7140e-04, -1.9600e-04, -1.1217e-01, -3.3885e-03, -8.1709e-03],
        [-1.6117e-04, -3.3309e-03, -4.9670e-03, -1.4666e-03, -1.6987e-01],
        [-6.1426e-02,  9.1234e-04,  6.8927e-03, -3.9051e-05,  2.6842e-05],
        [ 6.6322e-04, -1.5524e-04,  7.5674e-02, -1.3549e-03, -3.4012e-02],
        [ 1.2675e-01,  1.4860e-01, -1.1655e-03, -1.8698e-03, -3.5485e-04],
        [-1.2185e-02,  7.0259e-02,  5.8445e-03,  1.4595e-01,  1.4485e-03],
        [-4.9818e-02, -9.9803e-02, -2.4084e-03,  1.1264e-03, -5.2007e-03],
        [-1.6100e-01, -1.7623e-01, -3.6347e-03,  1.0666e-03, -6.1424e-04],
        [-1.5170e-04,  1.1399e-04, -5.7292e-03, -1.2829e-03,  8.5774e-04],
        [-3.6313e-02,  4.1535e-02,  8.6916e-02, -3.2395e-04,  9.5720e-05],
        [ 1.7126e-01, -4.9345e-03,  2.1864e-03, -1.4620e-03,  9.3274e-04],
        [ 8.8838e-05,  2.3974e-02, -8.4300e-04,  3.2591e-04,  2.4475e-03],
        [ 6.0291e-04,  1.3217e-03, -3.2275e-03,  1.4952e-01,  1.8532e-04],
        [ 3.0631e-04,  9.1334e-06,  8.1736e-02, -1.1009e-01, -1.9790e-02],
        [-1.9964e-04, -3.0306e-03,  1.7115e-03,  1.3300e-04,  9.2849e-04],
        [-6.6863e-02, -3.9844e-02, -9.6963e-04, -1.1469e-01,  1.4987e-01],
        [ 5.6801e-05, -6.9600e-03,  1.8396e-01,  2.3097e-03,  6.7973e-04],
        [ 1.2814e-01,  1.2370e-03,  6.0272e-02,  9.9899e-04,  2.2954e-03],
        [ 7.1562e-05,  6.2057e-03, -4.1528e-04, -1.4963e-03,  5.1993e-03],
        [ 3.5230e-02, -5.2817e-03,  2.8734e-02, -1.5012e-03,  1.7679e-01],
        [ 8.5526e-05,  4.9701e-03,  5.1950e-04,  2.8819e-04,  2.1791e-04],
        [ 5.1837e-04, -1.3653e-04,  2.3516e-03,  5.4210e-04, -1.1702e-01],
        [ 1.5719e-01, -1.1699e-01, -1.5743e-03,  3.5568e-03,  1.0882e-01],
        [-7.5326e-02,  4.8934e-02,  5.8113e-02, -1.0770e-01, -3.8962e-04],
        [-5.0871e-02, -1.0907e-01,  2.1128e-03, -1.7239e-04,  7.9414e-04],
        [-6.1390e-02, -7.9150e-02, -2.4045e-03,  1.0825e-01, -2.3275e-01],
        [ 1.0141e-05,  6.9531e-04, -3.8105e-01, -8.3822e-03,  8.6811e-04],
        [-3.3427e-03, -1.1199e-02, -1.6178e-02,  3.3264e-04,  1.2736e-04],
        [ 4.8314e-02, -8.7687e-02,  2.3178e-03, -1.1490e-03,  3.3050e-03],
        [ 2.2546e-04,  9.2306e-04, -3.1452e-05, -3.2265e-04,  1.6003e-03],
        [ 1.2507e-01,  1.6876e-01, -4.7744e-03,  8.0320e-04, -1.0736e-03],
        [ 1.2607e-04,  2.1120e-03, -3.9956e-02,  1.3473e-03, -1.3179e-01],
        [ 7.9236e-05,  2.0863e-03,  1.1361e-01,  1.8635e-01, -6.8411e-04],
        [ 5.0096e-04,  9.2486e-03,  5.1662e-03,  6.1923e-04,  1.0202e-03],
        [-2.4618e-03, -2.4375e-04,  9.9052e-02, -6.2521e-04, -2.3511e-04],
        [ 1.7592e-02, -2.4502e-03,  1.1835e-02,  6.3680e-04,  1.3003e-03],
        [-1.9014e-03, -7.2463e-02, -2.3835e-04, -4.2418e-04, -1.3672e-04],
        [-5.7524e-05,  1.5555e-03,  8.1403e-02,  1.4109e-05, -3.2037e-04],
        [ 3.3211e-05, -4.7257e-03,  5.3184e-03, -1.4583e-03,  1.2136e-01],
        [ 9.4139e-05, -6.1465e-03, -1.9728e-01, -6.4948e-04, -9.6800e-04],
        [ 8.0279e-02,  2.0114e-01, -2.3149e-02,  5.3314e-04, -9.5663e-04],
        [-9.4057e-06,  9.3017e-04, -6.3619e-03, -8.9478e-04,  6.8945e-04],
        [-4.0757e-04, -2.6289e-04, -1.0739e-01, -6.4301e-06, -6.4756e-04],
        [ 1.9800e-04, -7.5303e-03, -1.4169e-03,  7.0478e-02,  9.7779e-04],
        [ 2.5162e-03,  1.1748e-02, -1.1759e-01,  8.5675e-05,  3.8758e-02],
        [ 1.0333e-01,  8.3930e-02, -2.7763e-01, -1.8450e-03, -9.4792e-04],
        [ 1.1369e-01,  1.4739e-01,  9.3549e-02,  3.0639e-03, -3.6555e-04],
        [-7.1042e-02, -1.6552e-01, -6.5801e-02,  1.8529e-03, -3.4879e-04],
        [ 3.6944e-04,  2.1824e-02,  1.4615e-03, -2.6478e-03,  2.8615e-03],
        [ 8.6565e-02,  9.3828e-02,  7.3508e-02, -1.4559e-03,  1.4645e-04],
        [ 8.4967e-02,  1.8918e-02, -2.6207e-03,  1.0162e-03,  4.5810e-05],
        [ 6.6856e-02,  6.1064e-02, -2.8467e-02,  2.1627e-03, -3.5404e-04],
        [-8.1901e-02, -5.4220e-03,  3.4173e-02,  1.6867e-03, -1.3530e-03],
        [-1.7310e-01,  1.9890e-03,  2.4033e-03,  1.4348e-01, -1.9746e-01],
        [ 2.6985e-04,  4.0074e-03, -1.2633e-01, -1.1019e-03, -1.3738e-04],
        [-1.6513e-04, -3.2432e-04, -5.5851e-03, -1.2390e-03,  1.5313e-01],
        [ 4.6928e-02, -1.2585e-01, -6.1553e-02,  6.0248e-04,  1.3525e-01],
        [ 7.2692e-03, -9.6169e-04,  3.6073e-04, -5.6002e-04, -4.9393e-04],
        [ 3.3977e-02, -1.6764e-01,  1.2505e-01, -3.8051e-03, -6.7468e-04],
        [ 5.5181e-04,  2.3308e-03, -1.2515e-04,  1.3825e-01,  1.0268e-04],
        [-9.1427e-05,  1.7960e-03, -7.2494e-04,  7.3418e-05, -1.3692e-03],
        [-3.3211e-03,  1.7818e-02,  5.0075e-04,  1.9982e-03,  7.3755e-04],
        [ 6.6795e-02,  4.0416e-02,  4.2386e-02, -1.6749e-03, -1.3686e-03],
        [ 6.3355e-05, -3.9757e-03,  4.8202e-03, -7.4913e-04,  2.4455e-04],
        [ 2.1193e-01,  1.6540e-01,  8.0607e-03, -5.9517e-05, -7.0154e-04],
        [-8.1117e-02, -1.4417e-01, -1.6001e-02,  9.1730e-04,  1.6705e-01],
        [ 3.7545e-04,  4.1846e-02,  1.0624e-01,  9.3050e-05, -6.1875e-04],
        [ 4.6913e-02, -9.5452e-02,  6.0052e-03, -1.4269e-03, -6.2730e-04],
        [ 2.6614e-02, -2.9062e-03,  1.2468e-03,  9.3307e-04,  1.4782e-01],
        [-4.6565e-02, -1.2932e-03, -6.8698e-03,  3.1506e-04, -1.7370e-01],
        [-4.4658e-04, -1.1241e-02,  2.3371e-03,  1.5563e-03,  1.3638e-03],
        [-8.7541e-04,  3.1728e-04,  8.2786e-02,  2.0641e-04,  9.9914e-04],
        [-9.1960e-02, -7.5928e-02, -7.0665e-02, -4.9700e-04, -1.2949e-02],
        [ 1.6783e-01, -4.8480e-02, -3.5405e-03, -1.9673e-04, -1.3228e-03],
        [ 9.8401e-02,  4.8366e-02,  3.9594e-03, -7.7564e-04, -4.9369e-03],
        [-1.0665e-01, -4.5527e-02,  9.8746e-02, -1.4136e-03,  1.7882e-04],
        [-8.6737e-02,  2.0106e-04,  1.1155e-03, -2.3710e-04,  6.2419e-04],
        [-2.2532e-04, -4.8149e-03,  9.3750e-02, -1.0814e-03, -8.9537e-05],
        [ 7.4749e-04,  6.5410e-03,  2.0409e-03,  4.3415e-04,  1.0154e-01],
        [-3.6311e-02, -7.2775e-02, -5.8592e-03, -2.1231e-04,  1.1161e-03],
        [-1.4456e-04, -1.1161e-03, -7.9437e-03,  1.8548e-03,  8.5985e-04],
        [-1.7488e-04, -6.2023e-04, -1.1652e-01,  1.1431e-01,  7.1053e-04],
        [-2.7854e-05,  1.9083e-03, -6.1105e-03,  4.7669e-05,  2.5440e-04],
        [-1.3315e-01, -8.8718e-02,  1.9735e-04,  9.9303e-04, -5.9006e-04],
        [-7.2631e-02,  1.0339e-02, -8.4829e-03,  7.3527e-02,  4.3565e-04],
        [-1.3954e-04,  1.2043e-04,  1.2219e-03, -8.6147e-04, -5.1474e-04],
        [ 7.4800e-02,  4.0765e-02, -1.1845e-01, -9.5920e-04, -5.6222e-04],
        [-2.9311e-03, -2.1366e-04,  7.3735e-02,  2.3452e-04,  1.7151e-04],
        [ 1.8567e-05,  1.5184e-03,  1.4320e-01, -8.0148e-04, -5.3764e-04]]), tau=tensor(1.0063), history_obj=[560282.875, 561110.8125, 561661.0, 561621.1875, 561736.5, 561817.8125, 561573.5625, 561566.0625, 561606.625, 561645.875])

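The ELBO of cEBMF is not strictly monotonic here because the covariate-moderated prior is learned with stochastic gradient descent (SGD), so each prior update only approximately maximizes the objective.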

# Trace of the objective (ELBO) recorded across the fit iterations of mycebmf11.
# The curve is not expected to be strictly monotone: the prior is fitted with
# SGD, so a prior update can slightly decrease the objective.
plt.figure(figsize=(7, 5))
plt.plot(mycebmf11.obj)
plt.xlabel("Iteration")
plt.ylabel("ELBO")
plt.title("cEBMF objective trace")
plt.tight_layout()
plt.show()
[<matplotlib.lines.Line2D at 0x1b7d3a30950>]
../_images/9d41eccfb9121b555f3d50069b1179a4db240698f5c0a63de24fba6905671409.png

Fitting the cEBMF model using a “covariate-moderated generalized binary” prior (cgb), \( g(x,y) = \pi_0(x,y)\, \delta_0 + (1-\pi_0(x,y))\, N(\mu, \sigma^2) \)

# Fit cEBMF with a covariate-moderated generalized binary ("cgb") prior on L,
# passing the 2-d point locations X as the covariates for L; backfitting is
# disabled.
mycebmf12 = cEBMF(
    data=Z,
    X_l=X,
    prior_L="cgb",
    allow_backfitting=False,
)
mycebmf12.initialise_factors()
mycebmf12.fit(10)

# Objective (ELBO) trace across the fit iterations.
plt.figure(figsize=(7, 5))
plt.plot(mycebmf12.obj)
plt.xlabel("Iteration")
plt.ylabel("ELBO")
plt.title("cEBMF objective trace (cgb prior)")
plt.tight_layout()
plt.show()
[CGB] Epoch 10/50, Loss=2.9914, mu2=-0.154, sigma2=6.250
[CGB] Epoch 20/50, Loss=2.9477, mu2=-0.311, sigma2=6.293
[CGB] Epoch 30/50, Loss=2.9301, mu2=-0.472, sigma2=5.891
[CGB] Epoch 40/50, Loss=2.9060, mu2=-0.635, sigma2=5.893
[CGB] Epoch 50/50, Loss=2.8905, mu2=-0.800, sigma2=6.193
[CGB] Epoch 10/50, Loss=1.7650, mu2=-0.108, sigma2=4.740
[CGB] Epoch 20/50, Loss=1.7063, mu2=-0.213, sigma2=4.881
[CGB] Epoch 30/50, Loss=1.6880, mu2=-0.334, sigma2=3.477
[CGB] Epoch 40/50, Loss=1.6836, mu2=-0.465, sigma2=3.696
[CGB] Epoch 50/50, Loss=1.6777, mu2=-0.605, sigma2=3.878
[CGB] Epoch 10/50, Loss=1.9227, mu2=0.126, sigma2=2.389
[CGB] Epoch 20/50, Loss=1.8660, mu2=0.271, sigma2=2.315
[CGB] Epoch 30/50, Loss=1.8181, mu2=0.423, sigma2=2.216
[CGB] Epoch 40/50, Loss=1.8040, mu2=0.588, sigma2=2.432
[CGB] Epoch 50/50, Loss=1.7860, mu2=0.760, sigma2=2.230
[CGB] Epoch 10/50, Loss=1.7700, mu2=-0.012, sigma2=2.474
[CGB] Epoch 20/50, Loss=1.7631, mu2=-0.029, sigma2=2.319
[CGB] Epoch 30/50, Loss=1.7678, mu2=-0.046, sigma2=2.199
[CGB] Epoch 40/50, Loss=1.7630, mu2=-0.068, sigma2=1.503
[CGB] Epoch 50/50, Loss=1.7678, mu2=-0.092, sigma2=2.219
[CGB] Epoch 10/50, Loss=1.7674, mu2=-0.006, sigma2=1.876
[CGB] Epoch 20/50, Loss=1.7633, mu2=-0.006, sigma2=1.883
[CGB] Epoch 30/50, Loss=1.7553, mu2=-0.008, sigma2=2.136
[CGB] Epoch 40/50, Loss=1.7565, mu2=-0.011, sigma2=1.289
[CGB] Epoch 50/50, Loss=1.7564, mu2=-0.019, sigma2=2.286
[CGB] Epoch 10/50, Loss=2.9744, mu2=-0.155, sigma2=6.679
[CGB] Epoch 20/50, Loss=2.9478, mu2=-0.313, sigma2=6.946
[CGB] Epoch 30/50, Loss=2.9191, mu2=-0.475, sigma2=6.265
[CGB] Epoch 40/50, Loss=2.8866, mu2=-0.639, sigma2=6.347
[CGB] Epoch 50/50, Loss=2.8648, mu2=-0.804, sigma2=5.922
[CGB] Epoch 10/50, Loss=1.7079, mu2=-0.111, sigma2=3.723
[CGB] Epoch 20/50, Loss=1.6719, mu2=-0.229, sigma2=4.322
[CGB] Epoch 30/50, Loss=1.6449, mu2=-0.361, sigma2=4.328
[CGB] Epoch 40/50, Loss=1.6207, mu2=-0.505, sigma2=4.474
[CGB] Epoch 50/50, Loss=1.6118, mu2=-0.657, sigma2=4.835
[CGB] Epoch 10/50, Loss=1.8643, mu2=0.119, sigma2=2.623
[CGB] Epoch 20/50, Loss=1.7850, mu2=0.252, sigma2=2.425
[CGB] Epoch 30/50, Loss=1.7549, mu2=0.402, sigma2=2.753
[CGB] Epoch 40/50, Loss=1.7415, mu2=0.567, sigma2=2.096
[CGB] Epoch 50/50, Loss=1.7218, mu2=0.741, sigma2=2.655
[CGB] Epoch 10/50, Loss=1.6386, mu2=-0.017, sigma2=0.576
[CGB] Epoch 20/50, Loss=1.6378, mu2=-0.021, sigma2=0.001
[CGB] Epoch 30/50, Loss=1.6375, mu2=-0.024, sigma2=1.427
[CGB] Epoch 40/50, Loss=1.6374, mu2=-0.028, sigma2=0.616
[CGB] Epoch 50/50, Loss=1.6354, mu2=-0.033, sigma2=2.327
[CGB] Epoch 10/50, Loss=1.6494, mu2=-0.006, sigma2=3.014
[CGB] Epoch 20/50, Loss=1.6527, mu2=-0.005, sigma2=1.913
[CGB] Epoch 30/50, Loss=1.6465, mu2=-0.002, sigma2=0.743
[CGB] Epoch 40/50, Loss=1.6473, mu2=0.004, sigma2=0.628
[CGB] Epoch 50/50, Loss=1.6449, mu2=0.010, sigma2=1.455
[CGB] Epoch 10/50, Loss=2.9547, mu2=-0.153, sigma2=6.957
[CGB] Epoch 20/50, Loss=2.9246, mu2=-0.310, sigma2=6.490
[CGB] Epoch 30/50, Loss=2.9025, mu2=-0.471, sigma2=5.973
[CGB] Epoch 40/50, Loss=2.8780, mu2=-0.634, sigma2=5.917
[CGB] Epoch 50/50, Loss=2.8670, mu2=-0.799, sigma2=6.366
[CGB] Epoch 10/50, Loss=1.7060, mu2=-0.108, sigma2=4.446
[CGB] Epoch 20/50, Loss=1.6580, mu2=-0.214, sigma2=4.303
[CGB] Epoch 30/50, Loss=1.6334, mu2=-0.339, sigma2=5.012
[CGB] Epoch 40/50, Loss=1.6258, mu2=-0.474, sigma2=3.879
[CGB] Epoch 50/50, Loss=1.6238, mu2=-0.617, sigma2=3.395
[CGB] Epoch 10/50, Loss=1.8857, mu2=0.114, sigma2=2.898
[CGB] Epoch 20/50, Loss=1.8047, mu2=0.245, sigma2=2.877
[CGB] Epoch 30/50, Loss=1.7730, mu2=0.397, sigma2=2.346
[CGB] Epoch 40/50, Loss=1.7528, mu2=0.561, sigma2=2.010
[CGB] Epoch 50/50, Loss=1.7377, mu2=0.734, sigma2=2.211
[CGB] Epoch 10/50, Loss=5.0259, mu2=0.000, sigma2=0.001
[CGB] Epoch 20/50, Loss=5.0251, mu2=-0.000, sigma2=0.001
[CGB] Epoch 30/50, Loss=5.0250, mu2=-0.000, sigma2=0.001
[CGB] Epoch 40/50, Loss=5.0249, mu2=-0.000, sigma2=0.001
[CGB] Epoch 50/50, Loss=5.0249, mu2=-0.000, sigma2=0.001
[CGB] Epoch 10/50, Loss=5.2795, mu2=-0.002, sigma2=0.001
[CGB] Epoch 20/50, Loss=5.2786, mu2=-0.002, sigma2=0.001
[CGB] Epoch 30/50, Loss=5.2784, mu2=-0.003, sigma2=0.001
[CGB] Epoch 40/50, Loss=5.2783, mu2=-0.003, sigma2=0.001
[CGB] Epoch 50/50, Loss=5.2783, mu2=-0.003, sigma2=0.001
[CGB] Epoch 10/50, Loss=2.9623, mu2=-0.153, sigma2=6.433
[CGB] Epoch 20/50, Loss=2.9247, mu2=-0.310, sigma2=6.484
[CGB] Epoch 30/50, Loss=2.8960, mu2=-0.471, sigma2=6.200
[CGB] Epoch 40/50, Loss=2.8829, mu2=-0.635, sigma2=5.890
[CGB] Epoch 50/50, Loss=2.8676, mu2=-0.800, sigma2=6.037
[CGB] Epoch 10/50, Loss=1.6928, mu2=-0.111, sigma2=4.313
[CGB] Epoch 20/50, Loss=1.6200, mu2=-0.221, sigma2=4.693
[CGB] Epoch 30/50, Loss=1.6111, mu2=-0.347, sigma2=4.313
[CGB] Epoch 40/50, Loss=1.6008, mu2=-0.485, sigma2=4.103
[CGB] Epoch 50/50, Loss=1.5979, mu2=-0.630, sigma2=3.929
[CGB] Epoch 10/50, Loss=1.8478, mu2=0.117, sigma2=2.382
[CGB] Epoch 20/50, Loss=1.7637, mu2=0.250, sigma2=2.617
[CGB] Epoch 30/50, Loss=1.7412, mu2=0.400, sigma2=2.169
[CGB] Epoch 40/50, Loss=1.7271, mu2=0.563, sigma2=2.399
[CGB] Epoch 50/50, Loss=1.7142, mu2=0.735, sigma2=2.092
[CGB] Epoch 10/50, Loss=-5.2608, mu2=0.000, sigma2=0.001
[CGB] Epoch 20/50, Loss=-5.2614, mu2=0.000, sigma2=0.001
[CGB] Epoch 30/50, Loss=-5.2615, mu2=0.000, sigma2=0.001
[CGB] Epoch 40/50, Loss=-5.2615, mu2=0.000, sigma2=0.001
[CGB] Epoch 50/50, Loss=-5.2615, mu2=0.000, sigma2=0.001
[CGB] Epoch 10/50, Loss=-3.7984, mu2=-0.000, sigma2=0.001
[CGB] Epoch 20/50, Loss=-3.7993, mu2=-0.000, sigma2=0.001
[CGB] Epoch 30/50, Loss=-3.7994, mu2=-0.000, sigma2=0.001
[CGB] Epoch 40/50, Loss=-3.7995, mu2=-0.000, sigma2=0.001
[CGB] Epoch 50/50, Loss=-3.7995, mu2=-0.000, sigma2=0.001
[CGB] Epoch 10/50, Loss=2.9592, mu2=-0.155, sigma2=6.462
[CGB] Epoch 20/50, Loss=2.9205, mu2=-0.312, sigma2=6.522
[CGB] Epoch 30/50, Loss=2.8928, mu2=-0.473, sigma2=6.365
[CGB] Epoch 40/50, Loss=2.8785, mu2=-0.636, sigma2=6.244
[CGB] Epoch 50/50, Loss=2.8600, mu2=-0.801, sigma2=6.118
[CGB] Epoch 10/50, Loss=1.6758, mu2=-0.108, sigma2=4.043
[CGB] Epoch 20/50, Loss=1.6148, mu2=-0.215, sigma2=3.591
[CGB] Epoch 30/50, Loss=1.6008, mu2=-0.339, sigma2=4.009
[CGB] Epoch 40/50, Loss=1.5972, mu2=-0.475, sigma2=4.091
[CGB] Epoch 50/50, Loss=1.5858, mu2=-0.618, sigma2=3.520
[CGB] Epoch 10/50, Loss=1.8157, mu2=0.119, sigma2=2.590
[CGB] Epoch 20/50, Loss=1.7436, mu2=0.256, sigma2=2.264
[CGB] Epoch 30/50, Loss=1.7139, mu2=0.412, sigma2=2.515
[CGB] Epoch 40/50, Loss=1.6919, mu2=0.581, sigma2=1.550
[CGB] Epoch 50/50, Loss=1.6763, mu2=0.761, sigma2=1.178
[CGB] Epoch 10/50, Loss=-5.4230, mu2=0.000, sigma2=0.001
[CGB] Epoch 20/50, Loss=-5.4238, mu2=0.000, sigma2=0.001
[CGB] Epoch 30/50, Loss=-5.4240, mu2=0.000, sigma2=0.001
[CGB] Epoch 40/50, Loss=-5.4240, mu2=0.000, sigma2=0.001
[CGB] Epoch 50/50, Loss=-5.4241, mu2=0.000, sigma2=0.001
[CGB] Epoch 10/50, Loss=-4.3343, mu2=0.000, sigma2=0.001
[CGB] Epoch 20/50, Loss=-4.3348, mu2=0.000, sigma2=0.001
[CGB] Epoch 30/50, Loss=-4.3348, mu2=0.000, sigma2=0.001
[CGB] Epoch 40/50, Loss=-4.3348, mu2=0.000, sigma2=0.001
[CGB] Epoch 50/50, Loss=-4.3348, mu2=0.000, sigma2=0.001
[CGB] Epoch 10/50, Loss=2.9614, mu2=-0.154, sigma2=6.120
[CGB] Epoch 20/50, Loss=2.9307, mu2=-0.312, sigma2=6.071
[CGB] Epoch 30/50, Loss=2.9031, mu2=-0.473, sigma2=6.124
[CGB] Epoch 40/50, Loss=2.8870, mu2=-0.636, sigma2=6.207
[CGB] Epoch 50/50, Loss=2.8630, mu2=-0.801, sigma2=6.039
[CGB] Epoch 10/50, Loss=1.6578, mu2=-0.106, sigma2=4.356
[CGB] Epoch 20/50, Loss=1.5944, mu2=-0.208, sigma2=4.416
[CGB] Epoch 30/50, Loss=1.5807, mu2=-0.328, sigma2=3.472
[CGB] Epoch 40/50, Loss=1.5771, mu2=-0.459, sigma2=4.192
[CGB] Epoch 50/50, Loss=1.5752, mu2=-0.600, sigma2=3.517
[CGB] Epoch 10/50, Loss=1.5944, mu2=0.121, sigma2=1.971
[CGB] Epoch 20/50, Loss=1.5318, mu2=0.259, sigma2=2.084
[CGB] Epoch 30/50, Loss=1.4881, mu2=0.415, sigma2=1.818
[CGB] Epoch 40/50, Loss=1.4740, mu2=0.588, sigma2=1.642
[CGB] Epoch 50/50, Loss=1.4472, mu2=0.773, sigma2=1.127
[CGB] Epoch 10/50, Loss=-4.6559, mu2=0.000, sigma2=0.001
[CGB] Epoch 20/50, Loss=-4.6567, mu2=-0.000, sigma2=0.001
[CGB] Epoch 30/50, Loss=-4.6568, mu2=-0.000, sigma2=0.001
[CGB] Epoch 40/50, Loss=-4.6568, mu2=-0.000, sigma2=0.001
[CGB] Epoch 50/50, Loss=-4.6569, mu2=-0.000, sigma2=0.001
[CGB] Epoch 10/50, Loss=-5.0545, mu2=0.000, sigma2=0.001
[CGB] Epoch 20/50, Loss=-5.0557, mu2=-0.000, sigma2=0.001
[CGB] Epoch 30/50, Loss=-5.0559, mu2=-0.000, sigma2=0.001
[CGB] Epoch 40/50, Loss=-5.0559, mu2=-0.000, sigma2=0.001
[CGB] Epoch 50/50, Loss=-5.0559, mu2=-0.000, sigma2=0.001
[CGB] Epoch 10/50, Loss=2.9738, mu2=-0.153, sigma2=6.485
[CGB] Epoch 20/50, Loss=2.9150, mu2=-0.309, sigma2=6.509
[CGB] Epoch 30/50, Loss=2.8949, mu2=-0.469, sigma2=5.719
[CGB] Epoch 40/50, Loss=2.8811, mu2=-0.631, sigma2=6.220
[CGB] Epoch 50/50, Loss=2.8548, mu2=-0.796, sigma2=5.999
[CGB] Epoch 10/50, Loss=1.7062, mu2=-0.110, sigma2=4.248
[CGB] Epoch 20/50, Loss=1.5955, mu2=-0.212, sigma2=3.033
[CGB] Epoch 30/50, Loss=1.5736, mu2=-0.329, sigma2=3.805
[CGB] Epoch 40/50, Loss=1.5649, mu2=-0.458, sigma2=4.764
[CGB] Epoch 50/50, Loss=1.5579, mu2=-0.597, sigma2=3.823
[CGB] Epoch 10/50, Loss=1.4787, mu2=0.127, sigma2=1.766
[CGB] Epoch 20/50, Loss=1.3956, mu2=0.273, sigma2=1.842
[CGB] Epoch 30/50, Loss=1.3610, mu2=0.439, sigma2=1.561
[CGB] Epoch 40/50, Loss=1.3366, mu2=0.618, sigma2=1.356
[CGB] Epoch 50/50, Loss=1.3104, mu2=0.808, sigma2=1.107
[CGB] Epoch 10/50, Loss=-4.5778, mu2=-0.000, sigma2=0.001
[CGB] Epoch 20/50, Loss=-4.5785, mu2=-0.000, sigma2=0.001
[CGB] Epoch 30/50, Loss=-4.5786, mu2=-0.000, sigma2=0.001
[CGB] Epoch 40/50, Loss=-4.5786, mu2=-0.000, sigma2=0.001
[CGB] Epoch 50/50, Loss=-4.5787, mu2=-0.000, sigma2=0.001
[CGB] Epoch 10/50, Loss=-4.5474, mu2=-0.000, sigma2=0.001
[CGB] Epoch 20/50, Loss=-4.5482, mu2=-0.000, sigma2=0.001
[CGB] Epoch 30/50, Loss=-4.5483, mu2=-0.000, sigma2=0.001
[CGB] Epoch 40/50, Loss=-4.5483, mu2=-0.000, sigma2=0.001
[CGB] Epoch 50/50, Loss=-4.5483, mu2=-0.000, sigma2=0.001
[CGB] Epoch 10/50, Loss=2.9687, mu2=-0.154, sigma2=6.300
[CGB] Epoch 20/50, Loss=2.9227, mu2=-0.311, sigma2=5.944
[CGB] Epoch 30/50, Loss=2.8937, mu2=-0.471, sigma2=5.873
[CGB] Epoch 40/50, Loss=2.8786, mu2=-0.634, sigma2=6.266
[CGB] Epoch 50/50, Loss=2.8480, mu2=-0.799, sigma2=5.879
[CGB] Epoch 10/50, Loss=1.6371, mu2=-0.108, sigma2=4.051
[CGB] Epoch 20/50, Loss=1.5804, mu2=-0.214, sigma2=4.103
[CGB] Epoch 30/50, Loss=1.5704, mu2=-0.338, sigma2=3.702
[CGB] Epoch 40/50, Loss=1.5591, mu2=-0.473, sigma2=4.064
[CGB] Epoch 50/50, Loss=1.5502, mu2=-0.617, sigma2=4.128
[CGB] Epoch 10/50, Loss=1.4261, mu2=0.115, sigma2=1.632
[CGB] Epoch 20/50, Loss=1.3385, mu2=0.251, sigma2=1.312
[CGB] Epoch 30/50, Loss=1.2960, mu2=0.407, sigma2=1.501
[CGB] Epoch 40/50, Loss=1.2767, mu2=0.580, sigma2=1.217
[CGB] Epoch 50/50, Loss=1.2487, mu2=0.765, sigma2=0.822
[CGB] Epoch 10/50, Loss=-4.6675, mu2=-0.000, sigma2=0.001
[CGB] Epoch 20/50, Loss=-4.6685, mu2=0.000, sigma2=0.001
[CGB] Epoch 30/50, Loss=-4.6687, mu2=0.000, sigma2=0.001
[CGB] Epoch 40/50, Loss=-4.6687, mu2=0.000, sigma2=0.001
[CGB] Epoch 50/50, Loss=-4.6687, mu2=0.000, sigma2=0.001
[CGB] Epoch 10/50, Loss=-6.3708, mu2=0.000, sigma2=0.001
[CGB] Epoch 20/50, Loss=-6.3721, mu2=0.000, sigma2=0.001
[CGB] Epoch 30/50, Loss=-6.3723, mu2=0.000, sigma2=0.001
[CGB] Epoch 40/50, Loss=-6.3724, mu2=0.000, sigma2=0.001
[CGB] Epoch 50/50, Loss=-6.3724, mu2=0.000, sigma2=0.001
[CGB] Epoch 10/50, Loss=2.9718, mu2=-0.153, sigma2=7.058
[CGB] Epoch 20/50, Loss=2.9328, mu2=-0.310, sigma2=6.743
[CGB] Epoch 30/50, Loss=2.9059, mu2=-0.471, sigma2=6.190
[CGB] Epoch 40/50, Loss=2.8756, mu2=-0.634, sigma2=5.544
[CGB] Epoch 50/50, Loss=2.8530, mu2=-0.799, sigma2=5.660
[CGB] Epoch 10/50, Loss=1.6464, mu2=-0.111, sigma2=3.986
[CGB] Epoch 20/50, Loss=1.5881, mu2=-0.221, sigma2=4.012
[CGB] Epoch 30/50, Loss=1.5708, mu2=-0.348, sigma2=3.813
[CGB] Epoch 40/50, Loss=1.5631, mu2=-0.486, sigma2=3.857
[CGB] Epoch 50/50, Loss=1.5556, mu2=-0.631, sigma2=3.586
[CGB] Epoch 10/50, Loss=1.2801, mu2=0.120, sigma2=1.324
[CGB] Epoch 20/50, Loss=1.2095, mu2=0.265, sigma2=1.170
[CGB] Epoch 30/50, Loss=1.1603, mu2=0.427, sigma2=1.237
[CGB] Epoch 40/50, Loss=1.1214, mu2=0.608, sigma2=0.868
[CGB] Epoch 50/50, Loss=1.0847, mu2=0.800, sigma2=0.512
[CGB] Epoch 10/50, Loss=-4.4266, mu2=0.000, sigma2=0.001
[CGB] Epoch 20/50, Loss=-4.4279, mu2=0.000, sigma2=0.001
[CGB] Epoch 30/50, Loss=-4.4280, mu2=0.000, sigma2=0.001
[CGB] Epoch 40/50, Loss=-4.4280, mu2=0.000, sigma2=0.001
[CGB] Epoch 50/50, Loss=-4.4280, mu2=0.000, sigma2=0.001
[CGB] Epoch 10/50, Loss=-5.3345, mu2=-0.000, sigma2=0.001
[CGB] Epoch 20/50, Loss=-5.3353, mu2=0.000, sigma2=0.001
[CGB] Epoch 30/50, Loss=-5.3354, mu2=0.000, sigma2=0.001
[CGB] Epoch 40/50, Loss=-5.3355, mu2=0.000, sigma2=0.001
[CGB] Epoch 50/50, Loss=-5.3355, mu2=0.000, sigma2=0.001
[CGB] Epoch 10/50, Loss=2.9414, mu2=-0.155, sigma2=6.429
[CGB] Epoch 20/50, Loss=2.9119, mu2=-0.312, sigma2=6.303
[CGB] Epoch 30/50, Loss=2.8803, mu2=-0.474, sigma2=5.968
[CGB] Epoch 40/50, Loss=2.8595, mu2=-0.637, sigma2=6.197
[CGB] Epoch 50/50, Loss=2.8419, mu2=-0.802, sigma2=5.705
[CGB] Epoch 10/50, Loss=1.6190, mu2=-0.112, sigma2=4.571
[CGB] Epoch 20/50, Loss=1.5712, mu2=-0.230, sigma2=3.611
[CGB] Epoch 30/50, Loss=1.5563, mu2=-0.363, sigma2=5.201
[CGB] Epoch 40/50, Loss=1.5511, mu2=-0.506, sigma2=3.617
[CGB] Epoch 50/50, Loss=1.5429, mu2=-0.656, sigma2=3.724
[CGB] Epoch 10/50, Loss=1.0602, mu2=0.124, sigma2=1.138
[CGB] Epoch 20/50, Loss=0.9781, mu2=0.273, sigma2=1.095
[CGB] Epoch 30/50, Loss=0.9209, mu2=0.448, sigma2=0.617
[CGB] Epoch 40/50, Loss=0.8796, mu2=0.642, sigma2=0.553
[CGB] Epoch 50/50, Loss=0.8333, mu2=0.828, sigma2=0.401
[CGB] Epoch 10/50, Loss=-5.1318, mu2=0.000, sigma2=0.001
[CGB] Epoch 20/50, Loss=-5.1329, mu2=-0.000, sigma2=0.001
[CGB] Epoch 30/50, Loss=-5.1331, mu2=-0.000, sigma2=0.001
[CGB] Epoch 40/50, Loss=-5.1331, mu2=-0.000, sigma2=0.001
[CGB] Epoch 50/50, Loss=-5.1331, mu2=-0.000, sigma2=0.001
[CGB] Epoch 10/50, Loss=-4.7668, mu2=-0.000, sigma2=0.001
[CGB] Epoch 20/50, Loss=-4.7682, mu2=-0.000, sigma2=0.001
[CGB] Epoch 30/50, Loss=-4.7683, mu2=-0.000, sigma2=0.001
[CGB] Epoch 40/50, Loss=-4.7684, mu2=-0.000, sigma2=0.001
[CGB] Epoch 50/50, Loss=-4.7684, mu2=-0.000, sigma2=0.001
[<matplotlib.lines.Line2D at 0x1b7d61e3810>]
../_images/b9d36728b89603200be43be1348915cfa61dcef547bb898530efc47cd0cbe2c2.png

Fitting the cEBMF model using an “empirical mixture density network” prior (emdn), \( g(x,y) = \sum_{m=1}^M \pi_m(x,y)\, N(\mu(x,y), \sigma_m(x,y)^2) \)

# Fit cEBMF with an empirical mixture density network ("emdn") prior on L,
# passing the 2-d point locations X as the covariates for L; backfitting is
# disabled.
mycebmf13 = cEBMF(
    data=Z,
    X_l=X,
    prior_L="emdn",
    allow_backfitting=False,
)
mycebmf13.initialise_factors()
mycebmf13.fit(10)
[EMDN] Epoch 10/50, Loss: 2.9986
[EMDN] Epoch 20/50, Loss: 1.8607
[EMDN] Epoch 30/50, Loss: 1.5856
[EMDN] Epoch 40/50, Loss: 1.5462
[EMDN] Epoch 50/50, Loss: 1.5173
[EMDN] Epoch 10/50, Loss: 1.7203
[EMDN] Epoch 20/50, Loss: 1.5484
[EMDN] Epoch 30/50, Loss: 1.5139
[EMDN] Epoch 40/50, Loss: 1.4961
[EMDN] Epoch 50/50, Loss: 1.4789
[EMDN] Epoch 10/50, Loss: 1.6969
[EMDN] Epoch 20/50, Loss: 1.6024
[EMDN] Epoch 30/50, Loss: 1.5047
[EMDN] Epoch 40/50, Loss: 1.4901
[EMDN] Epoch 50/50, Loss: 1.4911
[EMDN] Epoch 10/50, Loss: 1.6820
[EMDN] Epoch 20/50, Loss: 1.6806
[EMDN] Epoch 30/50, Loss: 1.6784
[EMDN] Epoch 40/50, Loss: 1.6768
[EMDN] Epoch 50/50, Loss: 1.6774
[EMDN] Epoch 10/50, Loss: 1.6802
[EMDN] Epoch 20/50, Loss: 1.6765
[EMDN] Epoch 30/50, Loss: 1.6739
[EMDN] Epoch 40/50, Loss: 1.6725
[EMDN] Epoch 50/50, Loss: 1.6703
[EMDN] Epoch 10/50, Loss: 2.8551
[EMDN] Epoch 20/50, Loss: 2.1658
[EMDN] Epoch 30/50, Loss: 1.6062
[EMDN] Epoch 40/50, Loss: 1.5610
[EMDN] Epoch 50/50, Loss: 1.5293
[EMDN] Epoch 10/50, Loss: 1.6767
[EMDN] Epoch 20/50, Loss: 1.5076
[EMDN] Epoch 30/50, Loss: 1.4652
[EMDN] Epoch 40/50, Loss: 1.4466
[EMDN] Epoch 50/50, Loss: 1.4415
[EMDN] Epoch 10/50, Loss: 1.7402
[EMDN] Epoch 20/50, Loss: 1.6256
[EMDN] Epoch 30/50, Loss: 1.5049
[EMDN] Epoch 40/50, Loss: 1.4958
[EMDN] Epoch 50/50, Loss: 1.4893
[EMDN] Epoch 10/50, Loss: 1.6841
[EMDN] Epoch 20/50, Loss: 1.6820
[EMDN] Epoch 30/50, Loss: 1.6798
[EMDN] Epoch 40/50, Loss: 1.6786
[EMDN] Epoch 50/50, Loss: 1.6787
[EMDN] Epoch 10/50, Loss: 1.6784
[EMDN] Epoch 20/50, Loss: 1.6735
[EMDN] Epoch 30/50, Loss: 1.6694
[EMDN] Epoch 40/50, Loss: 1.6672
[EMDN] Epoch 50/50, Loss: 1.6627
[EMDN] Epoch 10/50, Loss: 2.7908
[EMDN] Epoch 20/50, Loss: 2.0262
[EMDN] Epoch 30/50, Loss: 1.5761
[EMDN] Epoch 40/50, Loss: 1.5321
[EMDN] Epoch 50/50, Loss: 1.4976
[EMDN] Epoch 10/50, Loss: 1.7055
[EMDN] Epoch 20/50, Loss: 1.5182
[EMDN] Epoch 30/50, Loss: 1.4781
[EMDN] Epoch 40/50, Loss: 1.4596
[EMDN] Epoch 50/50, Loss: 1.4470
[EMDN] Epoch 10/50, Loss: 1.7470
[EMDN] Epoch 20/50, Loss: 1.6596
[EMDN] Epoch 30/50, Loss: 1.5373
[EMDN] Epoch 40/50, Loss: 1.5019
[EMDN] Epoch 50/50, Loss: 1.4936
[EMDN] Epoch 10/50, Loss: 1.6760
[EMDN] Epoch 20/50, Loss: 1.6742
[EMDN] Epoch 30/50, Loss: 1.6729
[EMDN] Epoch 40/50, Loss: 1.6712
[EMDN] Epoch 50/50, Loss: 1.6692
[EMDN] Epoch 10/50, Loss: 1.6539
[EMDN] Epoch 20/50, Loss: 1.6511
[EMDN] Epoch 30/50, Loss: 1.6465
[EMDN] Epoch 40/50, Loss: 1.6428
[EMDN] Epoch 50/50, Loss: 1.6392
[EMDN] Epoch 10/50, Loss: 2.8626
[EMDN] Epoch 20/50, Loss: 2.2547
[EMDN] Epoch 30/50, Loss: 1.5728
[EMDN] Epoch 40/50, Loss: 1.5169
[EMDN] Epoch 50/50, Loss: 1.4798
[EMDN] Epoch 10/50, Loss: 1.6830
[EMDN] Epoch 20/50, Loss: 1.5067
[EMDN] Epoch 30/50, Loss: 1.4867
[EMDN] Epoch 40/50, Loss: 1.4822
[EMDN] Epoch 50/50, Loss: 1.4648
[EMDN] Epoch 10/50, Loss: 1.8107
[EMDN] Epoch 20/50, Loss: 1.7149
[EMDN] Epoch 30/50, Loss: 1.5854
[EMDN] Epoch 40/50, Loss: 1.5687
[EMDN] Epoch 50/50, Loss: 1.5604
[EMDN] Epoch 10/50, Loss: 1.6773
[EMDN] Epoch 20/50, Loss: 1.6754
[EMDN] Epoch 30/50, Loss: 1.6723
[EMDN] Epoch 40/50, Loss: 1.6710
[EMDN] Epoch 50/50, Loss: 1.6693
[EMDN] Epoch 10/50, Loss: 1.6416
[EMDN] Epoch 20/50, Loss: 1.6395
[EMDN] Epoch 30/50, Loss: 1.6356
[EMDN] Epoch 40/50, Loss: 1.6289
[EMDN] Epoch 50/50, Loss: 1.6257
[EMDN] Epoch 10/50, Loss: 2.8940
[EMDN] Epoch 20/50, Loss: 2.0248
[EMDN] Epoch 30/50, Loss: 1.5652
[EMDN] Epoch 40/50, Loss: 1.5365
[EMDN] Epoch 50/50, Loss: 1.5012
[EMDN] Epoch 10/50, Loss: 1.7102
[EMDN] Epoch 20/50, Loss: 1.5159
[EMDN] Epoch 30/50, Loss: 1.4958
[EMDN] Epoch 40/50, Loss: 1.4815
[EMDN] Epoch 50/50, Loss: 1.4746
[EMDN] Epoch 10/50, Loss: 1.7975
[EMDN] Epoch 20/50, Loss: 1.7271
[EMDN] Epoch 30/50, Loss: 1.6045
[EMDN] Epoch 40/50, Loss: 1.5800
[EMDN] Epoch 50/50, Loss: 1.5737
[EMDN] Epoch 10/50, Loss: 1.6770
[EMDN] Epoch 20/50, Loss: 1.6731
[EMDN] Epoch 30/50, Loss: 1.6714
[EMDN] Epoch 40/50, Loss: 1.6696
[EMDN] Epoch 50/50, Loss: 1.6692
[EMDN] Epoch 10/50, Loss: 1.6377
[EMDN] Epoch 20/50, Loss: 1.6333
[EMDN] Epoch 30/50, Loss: 1.6284
[EMDN] Epoch 40/50, Loss: 1.6255
[EMDN] Epoch 50/50, Loss: 1.6200
[EMDN] Epoch 10/50, Loss: 2.8780
[EMDN] Epoch 20/50, Loss: 1.8526
[EMDN] Epoch 30/50, Loss: 1.5655
[EMDN] Epoch 40/50, Loss: 1.5289
[EMDN] Epoch 50/50, Loss: 1.4830
[EMDN] Epoch 10/50, Loss: 1.6534
[EMDN] Epoch 20/50, Loss: 1.5179
[EMDN] Epoch 30/50, Loss: 1.4912
[EMDN] Epoch 40/50, Loss: 1.4775
[EMDN] Epoch 50/50, Loss: 1.4673
[EMDN] Epoch 10/50, Loss: 1.7795
[EMDN] Epoch 20/50, Loss: 1.6799
[EMDN] Epoch 30/50, Loss: 1.5620
[EMDN] Epoch 40/50, Loss: 1.5506
[EMDN] Epoch 50/50, Loss: 1.5471
[EMDN] Epoch 10/50, Loss: 1.6759
[EMDN] Epoch 20/50, Loss: 1.6732
[EMDN] Epoch 30/50, Loss: 1.6711
[EMDN] Epoch 40/50, Loss: 1.6691
[EMDN] Epoch 50/50, Loss: 1.6687
[EMDN] Epoch 10/50, Loss: 1.6128
[EMDN] Epoch 20/50, Loss: 1.6074
[EMDN] Epoch 30/50, Loss: 1.6024
[EMDN] Epoch 40/50, Loss: 1.5992
[EMDN] Epoch 50/50, Loss: 1.5969
[EMDN] Epoch 10/50, Loss: 2.8109
[EMDN] Epoch 20/50, Loss: 1.7667
[EMDN] Epoch 30/50, Loss: 1.5862
[EMDN] Epoch 40/50, Loss: 1.5500
[EMDN] Epoch 50/50, Loss: 1.5119
[EMDN] Epoch 10/50, Loss: 1.7226
[EMDN] Epoch 20/50, Loss: 1.5295
[EMDN] Epoch 30/50, Loss: 1.5031
[EMDN] Epoch 40/50, Loss: 1.4906
[EMDN] Epoch 50/50, Loss: 1.4811
[EMDN] Epoch 10/50, Loss: 1.8252
[EMDN] Epoch 20/50, Loss: 1.7161
[EMDN] Epoch 30/50, Loss: 1.6055
[EMDN] Epoch 40/50, Loss: 1.5872
[EMDN] Epoch 50/50, Loss: 1.5833
[EMDN] Epoch 10/50, Loss: 1.6769
[EMDN] Epoch 20/50, Loss: 1.6748
[EMDN] Epoch 30/50, Loss: 1.6732
[EMDN] Epoch 40/50, Loss: 1.6711
[EMDN] Epoch 50/50, Loss: 1.6695
[EMDN] Epoch 10/50, Loss: 1.5937
[EMDN] Epoch 20/50, Loss: 1.5900
[EMDN] Epoch 30/50, Loss: 1.5848
[EMDN] Epoch 40/50, Loss: 1.5780
[EMDN] Epoch 50/50, Loss: 1.5758
[EMDN] Epoch 10/50, Loss: 2.8036
[EMDN] Epoch 20/50, Loss: 2.4109
[EMDN] Epoch 30/50, Loss: 1.6157
[EMDN] Epoch 40/50, Loss: 1.5495
[EMDN] Epoch 50/50, Loss: 1.5238
[EMDN] Epoch 10/50, Loss: 1.7260
[EMDN] Epoch 20/50, Loss: 1.5298
[EMDN] Epoch 30/50, Loss: 1.5047
[EMDN] Epoch 40/50, Loss: 1.4902
[EMDN] Epoch 50/50, Loss: 1.4780
[EMDN] Epoch 10/50, Loss: 1.7769
[EMDN] Epoch 20/50, Loss: 1.6921
[EMDN] Epoch 30/50, Loss: 1.5571
[EMDN] Epoch 40/50, Loss: 1.5407
[EMDN] Epoch 50/50, Loss: 1.5421
[EMDN] Epoch 10/50, Loss: 1.6671
[EMDN] Epoch 20/50, Loss: 1.6625
[EMDN] Epoch 30/50, Loss: 1.6590
[EMDN] Epoch 40/50, Loss: 1.6589
[EMDN] Epoch 50/50, Loss: 1.6581
[EMDN] Epoch 10/50, Loss: 1.5781
[EMDN] Epoch 20/50, Loss: 1.5723
[EMDN] Epoch 30/50, Loss: 1.5672
[EMDN] Epoch 40/50, Loss: 1.5618
[EMDN] Epoch 50/50, Loss: 1.5567
[EMDN] Epoch 10/50, Loss: 2.9899
[EMDN] Epoch 20/50, Loss: 2.1702
[EMDN] Epoch 30/50, Loss: 1.5850
[EMDN] Epoch 40/50, Loss: 1.5329
[EMDN] Epoch 50/50, Loss: 1.5059
[EMDN] Epoch 10/50, Loss: 1.6684
[EMDN] Epoch 20/50, Loss: 1.5415
[EMDN] Epoch 30/50, Loss: 1.5086
[EMDN] Epoch 40/50, Loss: 1.5033
[EMDN] Epoch 50/50, Loss: 1.4886
[EMDN] Epoch 10/50, Loss: 1.8199
[EMDN] Epoch 20/50, Loss: 1.7027
[EMDN] Epoch 30/50, Loss: 1.6191
[EMDN] Epoch 40/50, Loss: 1.5988
[EMDN] Epoch 50/50, Loss: 1.5973
[EMDN] Epoch 10/50, Loss: 1.6603
[EMDN] Epoch 20/50, Loss: 1.6596
[EMDN] Epoch 30/50, Loss: 1.6566
[EMDN] Epoch 40/50, Loss: 1.6554
[EMDN] Epoch 50/50, Loss: 1.6536
[EMDN] Epoch 10/50, Loss: 1.5598
[EMDN] Epoch 20/50, Loss: 1.5558
[EMDN] Epoch 30/50, Loss: 1.5478
[EMDN] Epoch 40/50, Loss: 1.5443
[EMDN] Epoch 50/50, Loss: 1.5412
[EMDN] Epoch 10/50, Loss: 2.9222
[EMDN] Epoch 20/50, Loss: 2.3725
[EMDN] Epoch 30/50, Loss: 1.6483
[EMDN] Epoch 40/50, Loss: 1.5886
[EMDN] Epoch 50/50, Loss: 1.5636
[EMDN] Epoch 10/50, Loss: 1.7672
[EMDN] Epoch 20/50, Loss: 1.5496
[EMDN] Epoch 30/50, Loss: 1.5170
[EMDN] Epoch 40/50, Loss: 1.5058
[EMDN] Epoch 50/50, Loss: 1.4955
[EMDN] Epoch 10/50, Loss: 1.8075
[EMDN] Epoch 20/50, Loss: 1.7381
[EMDN] Epoch 30/50, Loss: 1.6439
[EMDN] Epoch 40/50, Loss: 1.5699
[EMDN] Epoch 50/50, Loss: 1.5661
[EMDN] Epoch 10/50, Loss: 1.6571
[EMDN] Epoch 20/50, Loss: 1.6543
[EMDN] Epoch 30/50, Loss: 1.6522
[EMDN] Epoch 40/50, Loss: 1.6510
[EMDN] Epoch 50/50, Loss: 1.6486
[EMDN] Epoch 10/50, Loss: 1.5500
[EMDN] Epoch 20/50, Loss: 1.5448
[EMDN] Epoch 30/50, Loss: 1.5436
[EMDN] Epoch 40/50, Loss: 1.5349
[EMDN] Epoch 50/50, Loss: 1.5319
CEBMFResult(L=tensor([[-7.0260e+00, -1.4371e-01,  1.0476e-01, -1.5912e-01, -1.1556e-01],
        [-2.5117e-01, -7.2868e-02,  3.4810e+00,  2.6240e-01, -1.6113e-01],
        [-2.9664e+00, -7.4841e-02,  1.7804e-03, -2.4686e-01, -1.9230e-01],
        ...,
        [-6.5060e+00, -1.1645e-01,  9.7392e-02, -2.8837e-01,  9.8725e-02],
        [-1.2500e-01, -2.3029e-02,  3.6450e+00,  1.0461e-01,  4.5507e-01],
        [-1.1618e-01, -1.1522e-02,  3.1444e+00, -1.1501e-01,  1.5473e-01]]), F=tensor([[-1.3705e-01,  1.5808e-01, -9.6476e-02,  9.6620e-04, -1.0204e-03],
        [ 1.6217e-01,  1.3871e-01,  1.5000e-02, -4.8870e-04, -3.3395e-03],
        [-5.0391e-04, -1.1573e-03,  5.3035e-02,  4.3421e-04, -2.5842e-01],
        [-5.6655e-02,  1.3307e-01,  8.2241e-02, -8.0455e-04, -4.0792e-03],
        [ 8.7007e-05, -3.4626e-02,  7.6109e-02, -6.3931e-04, -1.0573e-03],
        [ 1.1407e-03,  3.8601e-03,  8.9573e-04,  7.0435e-04,  3.1298e-04],
        [ 3.4677e-05, -1.3276e-02, -2.7620e-03, -1.5574e-01,  2.9780e-03],
        [-1.9293e-01,  1.6401e-01,  2.7110e-02,  4.7656e-04,  1.1465e-03],
        [-1.5419e-04, -4.3064e-05, -8.9515e-02,  9.5348e-04,  1.1948e-03],
        [-8.7520e-02,  2.4486e-01,  1.2626e-02, -1.7567e-03,  3.1252e-03],
        [-1.0249e-01,  6.9055e-02,  1.3277e-01,  1.0145e-03,  9.6483e-03],
        [ 9.5366e-02,  1.1659e-01,  1.3955e-02, -3.4557e-04, -2.6936e-01],
        [-4.3487e-04, -6.8412e-04, -1.3041e-01, -4.5255e-04,  3.1407e-03],
        [ 6.3867e-04,  7.3969e-04, -7.9459e-04, -1.9713e-03, -6.4756e-04],
        [ 8.7077e-05,  1.9922e-03, -4.3536e-03, -5.3377e-04,  6.9672e-03],
        [-2.0813e-03,  1.0097e-03, -5.4393e-04, -3.6138e-05,  2.5691e-03],
        [-2.9092e-04, -2.2299e-03,  3.3444e-04,  2.8868e-04,  1.0806e-04],
        [ 5.8534e-05, -2.5119e-03,  7.6592e-04, -4.8842e-04,  1.8544e-03],
        [ 1.0734e-01, -6.1556e-02,  8.1142e-02,  2.4080e-01, -2.8891e-03],
        [ 1.4874e-03,  7.1745e-04, -9.5134e-03, -1.0482e-04,  1.5797e-03],
        [ 2.3299e-06, -5.1971e-04,  2.9701e-03,  5.4532e-04,  2.7483e-03],
        [ 2.3372e-02, -1.6988e-01,  8.9488e-02,  8.1160e-04, -3.3494e-03],
        [-2.3562e-04, -2.6700e-03,  5.3118e-04,  7.1274e-04, -4.2417e-03],
        [ 7.9458e-04,  2.9934e-02,  6.6020e-04,  7.2202e-04,  1.8951e-01],
        [-8.4293e-02,  1.0424e-03, -2.1502e-03, -2.8660e-02,  3.2554e-03],
        [-5.0073e-02, -1.0497e-01,  3.5575e-03, -2.1195e-04,  8.1002e-02],
        [-3.8466e-04, -2.1696e-03, -1.2981e-02,  4.0335e-04,  2.7871e-03],
        [ 3.8522e-03,  1.5906e-03, -8.1443e-03,  3.9795e-04, -7.4683e-03],
        [-2.7504e-03,  3.4240e-03, -2.4047e-03, -1.1414e-03,  1.1838e-01],
        [-3.0288e-05,  1.6219e-03, -8.9141e-03, -9.7578e-02, -3.7692e-04],
        [ 2.0739e-05, -2.1580e-03,  1.3404e-03,  2.8110e-01, -1.3584e-03],
        [-4.9053e-04, -4.3310e-04,  1.6581e-03, -6.4742e-04,  2.0914e-04],
        [ 7.3391e-05,  8.8280e-04,  4.7076e-02,  2.6500e-03, -2.8230e-03],
        [-2.3954e-04,  3.7137e-03, -1.6540e-01,  2.4150e-04,  9.3647e-04],
        [-2.5151e-03, -1.8402e-03, -1.1540e-02, -9.3109e-05, -4.7172e-04],
        [ 2.4433e-06, -3.4044e-05, -7.9456e-04, -2.1411e-03, -3.7009e-04],
        [ 3.3675e-05,  1.3379e-02, -3.3243e-03, -2.3951e-04, -3.8608e-02],
        [-2.3595e-04, -8.9144e-05, -3.4521e-03,  2.0994e-04,  1.8630e-03],
        [-6.3911e-02, -7.9076e-04,  5.3869e-05,  6.3467e-04,  7.7776e-04],
        [ 8.8773e-02,  8.5261e-02,  8.3653e-04,  4.1366e-04,  2.7946e-04],
        [ 2.6761e-04,  1.1202e-03, -8.1126e-05,  1.6349e-02, -9.1255e-03],
        [-2.8080e-02, -1.3595e-02, -4.1978e-02, -1.4053e-01, -5.6636e-03],
        [-6.5849e-02,  5.6611e-02,  1.8100e-01,  1.9327e-03,  1.7478e-01],
        [ 6.1030e-02,  6.3409e-02,  2.7864e-04,  2.1031e-01,  1.3903e-04],
        [-1.7485e-01,  1.7573e-01,  1.8064e-03, -1.3580e-03,  3.0777e-03],
        [-2.8429e-02,  9.3042e-04, -5.4016e-02,  4.0523e-04,  3.0890e-04],
        [ 5.3415e-02, -7.8520e-02,  1.2758e-03,  1.0662e-04, -3.4664e-03],
        [-6.3786e-02,  1.1911e-02, -1.0383e-01, -4.1280e-02, -1.8879e-01],
        [ 4.4884e-05, -7.4126e-03, -1.7329e-03, -1.6137e-03,  5.1869e-04],
        [ 2.9950e-02, -7.8738e-02,  1.2871e-03,  1.1233e-01,  2.9529e-03],
        [-1.2249e-02, -2.9224e-03, -1.0605e-03,  3.6787e-04, -6.8346e-05],
        [-5.0360e-05, -7.0478e-04, -9.2020e-03, -1.6245e-01,  6.1653e-03],
        [ 3.5276e-02,  8.6031e-02, -1.5234e-04, -4.5905e-04, -5.8321e-03],
        [ 2.0931e-04,  1.3313e-04, -8.5212e-04,  5.2876e-04,  9.9520e-04],
        [ 6.6807e-05,  1.3028e-03, -1.0456e-03, -2.6515e-04, -9.8260e-02],
        [ 4.7789e-03, -2.6204e-02, -4.3018e-03,  4.1458e-04,  1.4712e-04],
        [ 2.0054e-01, -4.6217e-02,  1.7757e-03, -1.0936e-03, -4.6498e-03],
        [ 1.4249e-05, -5.4715e-04,  2.5293e-03, -2.2778e-04,  1.0495e-02],
        [-1.1257e-01, -1.0398e-01,  8.2578e-02,  1.6310e-01,  1.4866e-02],
        [ 1.2164e-01, -2.8072e-04,  9.0082e-04, -3.3663e-04,  4.9650e-04],
        [ 3.6086e-04,  2.4370e-03, -1.3273e-01, -1.0340e-04, -1.3272e-04],
        [-8.3673e-05, -1.8485e-04, -5.0948e-03,  1.6455e-01,  4.4941e-03],
        [ 1.0365e-01, -7.1510e-02,  2.2002e-02,  2.8613e-04, -1.3553e-03],
        [ 4.0249e-05,  5.4363e-03, -1.2248e-02, -2.7702e-04, -8.9786e-05],
        [-8.3212e-02,  1.8288e-01, -2.6684e-03,  4.2552e-04, -1.4679e-03],
        [ 1.4414e-01, -8.2341e-02, -1.2236e-02, -4.4333e-04,  1.7248e-03],
        [-1.4593e-02,  1.5030e-03,  7.6478e-02, -9.4612e-04, -6.7436e-04],
        [ 8.0151e-03,  4.2162e-03, -8.2550e-02,  1.7215e-03, -1.8912e-03],
        [-6.4839e-02,  6.6163e-02, -2.1626e-02,  1.4121e-04, -7.1275e-05],
        [ 1.9104e-04, -1.8103e-04,  1.5424e-01, -2.6647e-05, -1.2646e-01],
        [-3.7003e-04,  1.1727e-01,  7.9931e-02, -6.4797e-04, -1.6701e-03],
        [-1.9766e-01, -4.6560e-02, -1.0987e-02,  6.0774e-04,  5.4420e-04],
        [-3.5649e-02, -5.8467e-02,  6.5424e-03, -7.3521e-04, -2.5401e-03],
        [-1.1292e-01,  8.8566e-02, -9.5930e-02,  2.0026e-03,  9.7871e-04],
        [ 1.3952e-04, -6.5877e-04,  1.0065e-01, -2.2888e-01,  1.4245e-03],
        [-1.5501e-01, -2.6148e-02, -2.9767e-02,  4.5823e-04,  6.6349e-04],
        [-7.7235e-02, -1.1034e-03, -3.3175e-02, -6.3763e-03,  5.0366e-03],
        [ 7.8313e-05,  1.4045e-03, -1.4810e-03,  1.5977e-01,  2.3330e-04],
        [ 1.3243e-01, -1.7022e-01, -1.3627e-01,  8.0558e-04, -2.1138e-03],
        [ 5.9182e-03,  5.8875e-05, -5.3900e-04, -3.9268e-04,  3.9066e-04],
        [ 1.4321e-01,  1.2902e-01, -1.9839e-02,  6.0513e-04,  1.9092e-03],
        [ 1.3804e-04, -1.3419e-03, -7.6774e-02,  2.7962e-03, -3.0685e-03],
        [ 4.1131e-02,  1.5265e-01,  1.5600e-03,  7.5686e-03, -1.8557e-03],
        [ 1.7708e-04,  6.5116e-03,  9.7636e-04, -2.1751e-04,  2.5654e-03],
        [-1.3541e-01,  1.2738e-02, -1.6516e-02, -4.2248e-04,  2.1895e-03],
        [-1.4222e-04,  2.6713e-02,  5.9579e-04,  5.7631e-04,  2.1144e-03],
        [ 1.4169e-01, -1.4103e-01, -5.2205e-05,  7.0594e-04, -4.1160e-03],
        [ 1.7779e-01,  2.9425e-02, -6.9323e-02, -5.5131e-04,  1.1577e-03],
        [ 3.2314e-02, -8.3596e-04,  5.5490e-03, -1.8967e-04,  8.6611e-04],
        [-3.7997e-05, -2.6730e-03, -5.5713e-03,  3.0256e-04,  5.6823e-03],
        [ 5.6370e-03, -4.9762e-02, -8.2914e-04, -8.1641e-04, -3.1418e-02],
        [ 4.4649e-02, -8.7330e-02,  7.4682e-02, -2.2952e-01, -8.5109e-04],
        [ 8.5480e-04,  1.2758e-01, -3.5459e-02, -1.3738e-03,  8.8914e-04],
        [ 7.0690e-05,  3.2763e-04,  2.1355e-02,  5.1597e-04,  3.8252e-03],
        [ 6.1356e-05,  3.8577e-04,  1.2446e-03, -1.6628e-03,  1.7118e-03],
        [-7.4260e-04,  4.2620e-04,  5.4651e-03,  3.8997e-04, -5.6254e-03],
        [-2.2946e-01,  1.8672e-03, -3.3561e-03,  1.0084e-03,  2.1848e-03],
        [ 1.6984e-04, -4.5857e-03, -3.9533e-02, -1.1249e-03,  6.0446e-04],
        [ 6.5846e-02, -1.2533e-01,  9.6385e-02,  5.2668e-04,  3.9333e-05],
        [-2.8360e-02,  1.0220e-01,  2.0228e-03,  2.0818e-04,  1.3877e-01],
        [ 1.2126e-01,  1.4335e-01,  4.2381e-03,  3.6036e-02, -1.2455e-03],
        [-2.5371e-04, -1.0653e-02, -1.0160e-02, -3.2928e-04, -7.3653e-04],
        [ 1.6921e-04,  6.1956e-04,  7.4704e-02, -1.5231e-01, -2.3185e-01],
        [ 1.4753e-01,  9.6646e-02, -7.9980e-02, -1.5438e-01, -1.3378e-01],
        [-6.5778e-02,  7.5245e-02,  1.4589e-02,  7.1759e-05,  3.4577e-05],
        [ 1.0575e-03,  9.0590e-03, -3.4179e-02, -1.6866e-03, -2.0811e-03],
        [ 4.4900e-05, -4.1905e-03, -1.1028e-02, -3.5402e-03,  2.1415e-03],
        [ 3.9800e-04, -6.8560e-04, -6.6686e-03, -6.6402e-05, -8.1571e-04],
        [-1.9030e-05,  2.6896e-04,  4.1929e-04, -2.6097e-03, -4.8344e-02],
        [-1.9377e-03,  1.1413e-03, -5.2130e-03, -8.0894e-04, -1.4528e-04],
        [-1.0023e-04, -1.8603e-03, -1.6008e-02, -5.2036e-04, -2.6085e-01],
        [ 5.5255e-05, -6.6471e-04, -1.0836e-01, -1.0596e-02, -5.0444e-03],
        [-3.5115e-04, -3.6099e-03, -5.7961e-03, -2.5202e-03, -1.5840e-01],
        [-6.0586e-02,  3.5739e-03,  1.5331e-02,  8.8149e-06,  2.1134e-03],
        [ 4.4819e-04, -5.4469e-04,  7.3456e-02, -5.5196e-04, -2.9054e-02],
        [ 1.2417e-01,  1.3854e-01, -5.5463e-04, -3.0681e-04, -1.6620e-03],
        [-1.3927e-02,  7.0628e-02,  1.5371e-02,  1.7926e-01,  4.1797e-03],
        [-4.8387e-02, -9.5973e-02, -8.0225e-03, -4.3047e-05, -1.3066e-02],
        [-1.5881e-01, -1.5801e-01, -2.5991e-03,  2.8289e-03, -8.0360e-04],
        [-1.3375e-04,  1.9465e-04, -9.0101e-03, -1.1000e-03, -9.8720e-04],
        [-3.5817e-02,  4.3058e-02,  9.2912e-02, -2.5383e-04,  2.0604e-02],
        [ 1.7154e-01, -1.6374e-02,  6.0955e-03, -1.0169e-03,  3.2510e-04],
        [ 5.2460e-05,  2.7241e-02, -2.1431e-03,  9.2925e-04,  4.7197e-03],
        [ 2.5991e-04,  2.5985e-03, -1.0018e-02,  1.8053e-01,  8.2958e-04],
        [ 3.8946e-04,  1.7484e-04,  8.5901e-02, -1.3667e-01, -6.2018e-03],
        [-2.0645e-04, -5.9891e-03, -1.6684e-04, -2.8283e-04,  5.7210e-04],
        [-6.5935e-02, -2.0001e-02, -2.7638e-03, -1.1301e-01,  1.8593e-01],
        [ 1.6186e-04, -3.6743e-03,  1.8196e-01,  4.7377e-04,  2.9106e-03],
        [ 1.2833e-01,  1.5893e-04,  6.6296e-02,  1.3331e-03,  4.1108e-03],
        [ 4.9187e-05,  8.2871e-03, -1.6631e-04, -4.9732e-04,  4.8027e-02],
        [ 3.5863e-02, -4.8656e-03,  3.0453e-02, -1.8167e-03,  1.7167e-01],
        [ 4.3732e-05,  4.4371e-03,  1.9819e-03,  1.9910e-04,  3.2761e-05],
        [ 3.9739e-04, -2.1498e-04,  1.9583e-03,  4.6614e-04, -5.1916e-02],
        [ 1.5859e-01, -1.1989e-01, -2.5898e-05,  9.4304e-04,  1.1708e-01],
        [-7.6412e-02,  5.5497e-02,  2.8286e-02, -1.0079e-03, -1.3689e-03],
        [-4.8604e-02, -1.0594e-01,  4.8979e-03, -9.4401e-04,  3.5228e-03],
        [-6.1636e-02, -7.9144e-02, -2.6402e-03,  1.6424e-01, -2.7006e-01],
        [-1.9614e-04,  2.2529e-04, -3.9905e-01, -9.2042e-04,  9.5076e-04],
        [-5.2462e-03, -1.0133e-02, -4.5018e-02,  2.4914e-04,  5.9312e-06],
        [ 4.9395e-02, -8.2566e-02,  7.4409e-03, -6.1470e-04,  1.3498e-01],
        [ 2.1092e-04,  6.4171e-04, -1.8598e-03,  1.3597e-04,  2.4093e-03],
        [ 1.2171e-01,  1.5813e-01, -3.5217e-03,  6.2300e-04,  1.6517e-05],
        [ 8.7339e-06,  2.2399e-03, -3.5966e-02,  4.4159e-04, -1.0105e-02],
        [ 1.6279e-04,  4.3018e-03,  1.2808e-01,  1.8253e-01, -8.1481e-04],
        [ 5.6720e-04,  1.1463e-02,  6.0243e-03,  3.3243e-03,  1.6404e-03],
        [-1.3170e-03, -6.2730e-04,  9.8800e-02, -9.6375e-04, -2.2709e-03],
        [ 1.8043e-02, -3.5957e-03,  4.1260e-02,  8.8588e-04,  1.6308e-03],
        [-1.0825e-03, -6.9838e-02,  2.0741e-03, -8.2246e-04, -1.4125e-04],
        [-4.3856e-05,  1.6560e-03,  8.6171e-02, -1.3395e-04, -2.0818e-03],
        [ 5.7330e-05, -2.4430e-03,  2.0231e-03, -1.3592e-03,  1.3166e-01],
        [ 2.2092e-05, -9.0947e-03, -1.8522e-01, -9.2153e-04,  5.6016e-05],
        [ 7.7272e-02,  1.8834e-01, -6.5637e-03,  1.7473e-03, -1.4129e-03],
        [-2.1545e-05,  5.5098e-04, -8.4751e-03, -3.9871e-04,  5.4278e-05],
        [-4.2073e-04, -6.6997e-04, -1.0905e-01,  1.3851e-04, -3.2177e-03],
        [ 1.3572e-04, -7.7767e-03, -3.6110e-03,  6.6387e-03,  6.5498e-03],
        [ 1.6706e-03,  1.8472e-02, -1.2629e-01,  2.7953e-04,  1.1496e-01],
        [ 9.9152e-02,  7.5185e-02, -2.8204e-01, -5.4432e-04,  1.3610e-03],
        [ 1.1264e-01,  1.3753e-01,  1.0441e-01,  1.8169e-03,  6.2239e-04],
        [-7.0975e-02, -1.5386e-01, -8.7617e-02,  2.2533e-03, -1.4573e-03],
        [ 2.6992e-04,  1.7281e-02,  4.6906e-04, -2.0860e-03,  1.1132e-01],
        [ 8.5579e-02,  8.3005e-02,  7.7967e-02, -8.6651e-04, -1.7808e-03],
        [ 8.5174e-02,  1.4323e-02, -1.0755e-03,  3.0686e-03, -2.7051e-04],
        [ 6.6548e-02,  5.2503e-02, -9.4088e-03,  1.6115e-03, -2.7302e-03],
        [-8.0979e-02, -5.7979e-03,  4.2853e-02,  2.0923e-03, -5.0688e-02],
        [-1.7340e-01,  6.3816e-03, -3.6276e-04,  1.6401e-01, -1.5187e-01],
        [ 1.1887e-04,  3.8780e-03, -1.3135e-01, -7.6554e-04, -5.9816e-05],
        [-1.4067e-04, -7.1354e-05, -6.3970e-03, -5.9128e-04,  1.7076e-01],
        [ 4.8107e-02, -1.2531e-01, -5.7435e-02,  9.3506e-04,  1.7396e-01],
        [ 8.4685e-03, -1.2586e-03, -3.9774e-04, -1.1737e-04, -1.1513e-03],
        [ 3.6491e-02, -1.6276e-01,  1.2794e-01, -3.6620e-03, -1.9677e-04],
        [ 3.4594e-04,  2.0345e-03, -4.5161e-04,  1.9727e-01, -2.3857e-03],
        [-3.3328e-05,  2.1760e-03, -5.2975e-04,  5.8824e-04, -1.3855e-03],
        [-3.2982e-03,  1.5229e-02,  1.4794e-03,  1.5787e-03,  1.0752e-03],
        [ 6.5676e-02,  3.6469e-02,  4.4772e-02, -3.7206e-03, -3.9598e-04],
        [ 5.1944e-05, -2.6897e-03,  5.1148e-03, -4.0004e-04,  1.5052e-03],
        [ 2.0946e-01,  1.4705e-01,  3.6125e-02, -5.6984e-04, -7.8634e-03],
        [-7.8954e-02, -1.3464e-01, -3.8022e-02,  8.1325e-04,  2.1898e-01],
        [ 4.3976e-04,  4.4679e-02,  1.0681e-01, -5.3495e-04, -2.5120e-03],
        [ 4.7950e-02, -1.0008e-01,  8.7266e-03, -7.2865e-04, -2.8699e-03],
        [ 2.6896e-02, -2.6083e-03,  1.9745e-03,  7.7265e-04,  4.2420e-02],
        [-4.6769e-02, -3.5376e-04, -8.4169e-03, -1.6733e-04, -1.2833e-01],
        [-4.9823e-04, -9.7589e-03,  1.7998e-03,  1.9380e-03, -2.3757e-05],
        [-3.9719e-04,  7.6873e-04,  7.8181e-02, -9.0880e-04,  7.6812e-04],
        [-9.1259e-02, -6.7199e-02, -7.7797e-02,  1.9825e-04, -1.6353e-01],
        [ 1.6824e-01, -6.1151e-02, -4.5834e-03, -2.5037e-04,  1.9517e-04],
        [ 9.7248e-02,  3.3808e-02,  5.2512e-03, -8.1043e-04, -1.4518e-01],
        [-1.0556e-01, -4.1623e-02,  9.1604e-02, -1.1604e-01,  1.4324e-03],
        [-8.5998e-02,  1.4190e-03,  7.0659e-04, -4.1477e-04,  5.7058e-04],
        [-1.6077e-04, -6.4752e-03,  9.8442e-02, -7.7824e-04,  2.0550e-03],
        [ 7.1222e-04,  8.1199e-03,  4.7921e-03, -4.5638e-04,  1.3073e-01],
        [-3.6186e-02, -6.7596e-02, -5.6218e-03,  1.2786e-03,  6.1535e-03],
        [-1.6230e-04, -1.3655e-03, -4.0056e-03,  4.0315e-04,  1.0117e-03],
        [-2.5753e-04,  8.2016e-05, -1.0630e-01,  5.7151e-02,  7.4079e-04],
        [-6.1165e-06,  1.7430e-03, -7.0623e-03, -5.6294e-04,  4.8096e-04],
        [-1.3195e-01, -8.2270e-02, -8.7772e-04,  4.8869e-04, -3.5061e-03],
        [-7.1076e-02,  2.1807e-02, -1.8480e-02,  1.3873e-03,  1.9878e-03],
        [-1.6684e-04, -5.6737e-05,  1.5694e-03, -1.9120e-03, -2.5743e-03],
        [ 7.2425e-02,  2.5261e-02, -1.2781e-01, -2.3544e-03, -1.1822e-03],
        [-1.3402e-03, -2.4944e-04,  7.5778e-02,  5.4978e-04,  2.8440e-04],
        [ 1.1237e-04,  1.2946e-03,  1.3590e-01, -1.0746e-03, -3.4625e-03]]), tau=tensor(1.0062), history_obj=[562380.4375, 563685.8125, 563900.375, 564120.4375, 564208.0625, 564242.1875, 564203.5, 564215.625, 564260.875, 564259.4375])

We can visualize the corresponding values of the ELBO across iterations (higher is better).

import numpy as np
import matplotlib.pyplot as plt
import torch

def to1d(a):
    """Return *a* as a flat, one-dimensional NumPy array.

    Torch tensors are detached from the autograd graph and copied to the CPU
    before conversion; any other input (list, scalar, ndarray, ...) is passed
    through ``np.asarray`` and then flattened.
    """
    if not isinstance(a, torch.Tensor):
        return np.asarray(a).ravel()
    host_tensor = a.detach().cpu()
    return host_tensor.reshape(-1).numpy()

# Overlay the per-iteration objective (the ELBO; higher is better) of every
# fitted model on a single set of axes, one line per prior.
objective_runs = (
    (mycebmf11, "cash"),
    (mycebmf12, "cgb"),
    (mycebmf13, "emdn prior"),
    (mycebmf31, "shap cgb prior"),
    (mycebmf, "point Laplace"),
)

fig, ax = plt.subplots(figsize=(8, 5))
for run_model, run_label in objective_runs:
    ax.plot(to1d(run_model.obj), label=run_label)

ax.set(title="Objective vs Iteration", xlabel="Iteration", ylabel="Objective")
ax.legend(title="Method", loc="best")
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
../_images/8a6b0cab689d4d8bbf16f5c06d129897dc1bc9f7efdab4b9bf8e29e9921da1f8.png

This result is not surprising: in this case only the emdn prior can recover the data-generating process, which leads to a better fit.

The worst fit is achieved by the EBMF model with the point-Laplace prior, which is the least flexible model.

Now we compare the fitted values of each model with the true noise-free signal \(LF^t\).

# Compare each model's fitted matrix Y_fit with the noise-free simulated
# signal L @ f. The printed metric is the mean squared error (no square root
# is taken), so it is labelled "MSE".
fitted_models = {
    "Point Laplace": mycebmf,
    "cash": mycebmf11,
    "cgb": mycebmf12,
    "emdn": mycebmf13,
    "shap cgb": mycebmf31,
}

# Refresh Y_fit on every model before reading it (private CEBMF helper).
for model in fitted_models.values():
    model._update_fitted_value()

Z_signal = L @ f  # computed once instead of once per comparison

for label, model in fitted_models.items():
    mse = torch.mean((model.Y_fit - Z_signal).pow(2))
    print(f"MSE {label}", mse)

# Fitted vs. true values for the least (point Laplace) and most (emdn)
# flexible priors; points close to the diagonal indicate a good fit.
truth_np = Z_signal.detach().cpu().numpy().ravel()
plt.figure()
for label in ("Point Laplace", "emdn"):
    fitted_np = fitted_models[label].Y_fit.detach().cpu().numpy().ravel()
    plt.scatter(fitted_np, truth_np, label=label)
plt.xlabel("fitted value")
plt.ylabel("true signal (L @ f)")
plt.legend(title="Method", loc="best")
plt.show()
RMSE Point Laplace tensor(0.0087)
RMSE cash tensor(0.0066)
RMSE cgb tensor(0.0051)
RMSE emdn tensor(0.0031)
RMSE emdn tensor(0.0072)
../_images/8a79842f45723f36d23cf0ab6518a281bc6e5d8a491142a292f97b19c2ebb890.png

Fitted factors (emdn prior)

# Spatial map of the first three loading columns fitted with the emdn prior,
# coloured by loading value at each (x, y) point. Dashed lines mark the guide
# boundaries from the config (BOUNDARIES).
x_np = x.detach().cpu().numpy()
y_np = y.detach().cpu().numpy()
for i in range(3):
    vals = mycebmf13.L[:, i].detach().cpu().numpy()
    # Symmetric colour limits so that a loading of exactly 0 falls on the
    # neutral midpoint of the diverging "coolwarm" colormap.
    lim = float(np.abs(vals).max())
    plt.figure(figsize=(8, 6))
    plt.scatter(
        x_np,
        y_np,
        c=vals,
        cmap="coolwarm",
        vmin=-lim,
        vmax=lim,
        s=50,
    )
    for v in BOUNDARIES:
        plt.axhline(v, color="black", linestyle="--")
        plt.axvline(v, color="black", linestyle="--")
    plt.title(f"Factor {i+1}")
    plt.xlabel("x"); plt.ylabel("y")
    plt.colorbar(label=f"L{i+1}")
    plt.show()
../_images/9bcf5f67c16c382633983c16503190a207a782b6fbfdb25d50204c73bdd9a1d6.png ../_images/a2488224f09c9c39ca0f35d1952e2d442dcdce7030cbe8a7f6ba365c07675c2d.png ../_images/5055729a663afc96c15fd465e2b1a320bbb117579caf54251ff7a97f1760574f.png

Fitted factors (point Laplace)

# Spatial map of the first five loading columns fitted with the point-Laplace
# prior, coloured by loading value at each (x, y) point. Dashed lines mark the
# guide boundaries from the config (BOUNDARIES).
x_np = x.detach().cpu().numpy()
y_np = y.detach().cpu().numpy()
for i in range(5):
    vals = mycebmf.L[:, i].detach().cpu().numpy()
    # Symmetric colour limits so that a loading of exactly 0 falls on the
    # neutral midpoint of the diverging "coolwarm" colormap.
    lim = float(np.abs(vals).max())
    plt.figure(figsize=(8, 6))
    plt.scatter(
        x_np,
        y_np,
        c=vals,
        cmap="coolwarm",
        vmin=-lim,
        vmax=lim,
        s=50,
    )
    for v in BOUNDARIES:
        plt.axhline(v, color="black", linestyle="--")
        plt.axvline(v, color="black", linestyle="--")
    plt.title(f"Factor {i+1}")
    plt.xlabel("x"); plt.ylabel("y")
    plt.colorbar(label=f"L{i+1}")
    plt.show()
../_images/6077cfeeb7b7cb5f3808e23ac8f693711ac79603ff275c6dc5802a10a14c2f85.png ../_images/f009b7ab50dde29147455f63f630de93eb39d35d1a931c1e53472e37ada000f7.png ../_images/da23ed5ca49b1a53373b0c5e2094a5ef5c21203b96d96b9cc71d720bcd7ec367.png ../_images/77492c1051c02e8403335bb9ea3be6166570e90b234734bf7a4e168af1593f84.png ../_images/01759e6d922be0cd3ba76fe679b3a751023aa63e3afb62edae4f24d36e482d55.png

We can visualize a fitted factor for the different priors.

import matplotlib.pyplot as plt

# Zero-based column of L to compare across priors. k = 2 is the third factor;
# the plot title uses 1-based numbering ("Factor k+1") like the factor maps.
k = 2

# NOTE(review): factorizations are identified only up to column order (and,
# depending on the prior, sign/scale), so column k of different fits need not
# describe the same latent factor; check alignment before comparing curves.
x_np = x.detach().cpu().numpy()  # shared x-coordinates, converted once

fig, ax = plt.subplots(figsize=(7, 5))
for model, label in (
    (mycebmf, "point Laplace"),
    (mycebmf11, "cash"),
    (mycebmf12, "cgb"),
    (mycebmf13, "emdn"),
):
    ax.scatter(x_np, model.L[:, k].detach().cpu().numpy(), s=10, label=label)

ax.set_title(f"Fitted Factor {k + 1}")
ax.set_xlabel("x"); ax.set_ylabel(f"L[:, {k}]")
ax.legend(title="Method", loc="best")
plt.tight_layout()
plt.show()
../_images/c0dc7d6f1a8f45f39b0b8a618c2669a5d03f9967dc97273d7246e73eff138557.png