use crate::{AvgLLH, Error, Parametrizable};
/// Interface of the latent model used by a [`Mixture`].
///
/// `T` is the [`Parametrizable`] model whose `DataIn` and `Likelihood`
/// associated types make up the signature of [`Latent::expect`].
pub trait Latent<T: Parametrizable> {
    /// Derives a new likelihood, together with an [`AvgLLH`], from `data` and
    /// an already computed `likelihood`.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] when the implementation cannot produce a result.
    fn expect(
        &self,
        data: &T::DataIn<'_>,
        likelihood: &T::Likelihood,
    ) -> Result<(T::Likelihood, AvgLLH), Error>;
}
/// Interface of the component models held by a [`Mixture`].
///
/// `T` is the [`Parametrizable`] model whose `DataIn`, `DataOut` and
/// `Likelihood` associated types make up the signature of
/// [`Mixable::predict`].
pub trait Mixable<T: Parametrizable> {
    /// Produces a `T::DataOut` from `data`, given a likelihood supplied by the
    /// latent model. The likelihood is taken by value.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] when the implementation cannot produce a result.
    fn predict(
        &self,
        latent_likelihood: T::Likelihood,
        data: &T::DataIn<'_>,
    ) -> Result<T::DataOut, Error>;
}
/// A model composed of mixable component models and a latent model.
///
/// The two parts must agree on their `Likelihood` type, so that a likelihood
/// produced by one can be handed to the other.
#[derive(Debug, Clone)]
pub struct Mixture<T, L>
where
    L: Parametrizable + Latent<L>,
    T: Parametrizable<Likelihood = L::Likelihood>,
{
    /// The component model(s) being mixed.
    pub mixables: T,
    /// The latent model paired with the mixables.
    pub latent: L,
}
impl<T, L> Mixture<T, L>
where
    T: Parametrizable<Likelihood = L::Likelihood>,
    L: Parametrizable + Latent<L>,
{
    /// Builds a mixture from its component model(s) and its latent model.
    ///
    /// Both arguments are moved into the returned value unchanged.
    pub fn new(mixables: T, latent: L) -> Self {
        Self { mixables, latent }
    }
}
/// [`Parametrizable`] implementation that forwards each operation to the
/// latent model and to the mixables.
///
/// The sufficient statistics are the pair
/// `(latent statistics, mixable statistics)`; the likelihood, input and
/// output types are those of the mixables.
impl<T, L> Parametrizable for Mixture<T, L>
where
    T: for<'a> Parametrizable<Likelihood = L::Likelihood, DataIn<'a> = L::DataIn<'a>> + Mixable<T>,
    L: Parametrizable + Latent<L>,
{
    type SufficientStatistics = (L::SufficientStatistics, T::SufficientStatistics);
    type Likelihood = T::Likelihood;
    type DataIn<'a> = T::DataIn<'a>;
    type DataOut = T::DataOut;

    /// Evaluates the mixables on `data` and passes the resulting likelihood to
    /// the latent model's [`Latent::expect`], whose result is returned.
    ///
    /// # Errors
    ///
    /// Propagates any [`Error`] raised by the mixables or by the latent model.
    fn expect(&self, data: &Self::DataIn<'_>) -> Result<(Self::Likelihood, AvgLLH), Error> {
        // Only the likelihood of the mixables is forwarded; their `AvgLLH` is discarded.
        let (mixable_likelihood, _) = self.mixables.expect(data)?;
        Latent::expect(&self.latent, data, &mixable_likelihood)
    }

    /// Computes the sufficient statistics of the latent model and of the
    /// mixables from the same `data` and `responsibilities`.
    ///
    /// # Errors
    ///
    /// Propagates the first [`Error`] raised; the latent model is evaluated first.
    fn compute(
        &self,
        data: &Self::DataIn<'_>,
        responsibilities: &Self::Likelihood,
    ) -> Result<Self::SufficientStatistics, Error> {
        // `data` is already a reference, so it is passed on without re-borrowing.
        let latent_stats = self.latent.compute(data, responsibilities)?;
        let mixable_stats = self.mixables.compute(data, responsibilities)?;
        Ok((latent_stats, mixable_stats))
    }

    /// Calls `maximize` on the latent model with the first element of
    /// `sufficient_statistics` and on the mixables with the second.
    ///
    /// # Errors
    ///
    /// Propagates the first [`Error`] raised. If the latent model fails, the
    /// mixables are not touched.
    fn maximize(
        &mut self,
        sufficient_statistics: &Self::SufficientStatistics,
    ) -> Result<(), Error> {
        let (latent_stats, mixable_stats) = sufficient_statistics;
        self.latent.maximize(latent_stats)?;
        self.mixables.maximize(mixable_stats)?;
        Ok(())
    }

    /// Asks the latent model for a likelihood via its own
    /// [`Parametrizable::expect`] (not the mixture's `expect`) and hands it,
    /// together with `data`, to [`Mixable::predict`] of the mixables.
    ///
    /// # Errors
    ///
    /// Propagates any [`Error`] raised by the latent model or by the mixables.
    fn predict(&self, data: &Self::DataIn<'_>) -> Result<Self::DataOut, Error> {
        // Fully qualified calls: `L` has two `expect` methods and `T` two `predict` methods.
        let (latent_likelihood, _) = Parametrizable::expect(&self.latent, data)?;
        Mixable::predict(&self.mixables, latent_likelihood, data)
    }

    /// Calls `update` on the latent model with the first element of
    /// `sufficient_statistics` and on the mixables with the second, passing
    /// the same `weight` to both.
    ///
    /// # Errors
    ///
    /// Propagates the first [`Error`] raised. If the latent model fails, the
    /// mixables are not touched.
    fn update(
        &mut self,
        sufficient_statistics: &Self::SufficientStatistics,
        weight: f64,
    ) -> Result<(), Error> {
        let (latent_stats, mixable_stats) = sufficient_statistics;
        self.latent.update(latent_stats, weight)?;
        self.mixables.update(mixable_stats, weight)?;
        Ok(())
    }

    /// Merges several sufficient statistics by splitting each pair into its
    /// latent and mixable halves and delegating to `L::merge` and `T::merge`
    /// with the same `weights`.
    ///
    /// # Errors
    ///
    /// Propagates the first [`Error`] raised; `L::merge` runs first.
    fn merge(
        sufficient_statistics: &[&Self::SufficientStatistics],
        weights: &[f64],
    ) -> Result<Self::SufficientStatistics, Error> {
        // One pass over the input yields both lists of borrowed halves.
        let (latent_stats, mixable_stats): (Vec<_>, Vec<_>) = sufficient_statistics
            .iter()
            .map(|stats| (&stats.0, &stats.1))
            .unzip();
        let merged_latent = L::merge(&latent_stats, weights)?;
        let merged_mixables = T::merge(&mixable_stats, weights)?;
        Ok((merged_latent, merged_mixables))
    }

    /// Delegates to `expect_rand` of the latent model; the mixables are not
    /// consulted.
    ///
    /// # Errors
    ///
    /// Propagates any [`Error`] raised by the latent model.
    fn expect_rand(&self, data: &Self::DataIn<'_>, k: usize) -> Result<Self::Likelihood, Error> {
        self.latent.expect_rand(data, k)
    }
}
#[cfg(all(test, feature = "ndarray"))]
mod tests {
    use super::*;
    use crate::backend::ndarray::{
        finite::Finite,
        gaussian::{sort_parameters, Gaussian},
        utils::{generate_random_expections, generate_samples},
    };
    use tracing::info;
    use tracing_test::traced_test;

    /// Fits a `Gaussian`/`Finite` mixture to generated samples and checks that
    /// the sorted fitted means match the means returned by `generate_samples`
    /// to within `1e-2`.
    #[traced_test]
    #[test]
    fn em_step() {
        let k = 3;
        // NOTE(review): presumably per-component sample counts and the data
        // dimension — confirm against `generate_samples`.
        let (data, _, means, _covariances) = generate_samples(&[5000, 10000, 15000], 2);
        info!(%means);

        let mut mixture = Mixture::new(Gaussian::new(), Finite::new(None));

        // Start from random responsibilities, then alternate
        // compute/maximize/expect. `1..20` yields 19 iterations.
        let mut responsibilities = generate_random_expections(&data.view(), k).unwrap();
        for _ in 1..20 {
            let stat = mixture.compute(&data.view(), &responsibilities).unwrap();
            mixture.maximize(&stat).unwrap();
            info!(%mixture.mixables.means);
            info!(%mixture.latent.pmf);

            let (updated, likelihood) = mixture.expect(&data.view()).unwrap();
            responsibilities = updated;
            info!("likelihood {}", likelihood.0);
        }

        info!(%means);
        // Sort the fitted parameters before comparing them with `means`.
        let (means_sorted, _) = sort_parameters(&mixture.mixables, &mixture.latent.pmf.view());
        info!(%means_sorted);
        info!("{}", &means_sorted - &means);
        assert!(means.abs_diff_eq(&means_sorted, 1e-2));
    }
}