use crate::{AvgLLH, Error, Parametrizable};
/// A latent-variable model that can be combined with the per-component
/// likelihoods produced by the mixables of a [`Mixture`].
pub trait Latent<T: Parametrizable> {
    /// Combine `likelihood` (as produced by the mixables' own `expect` on
    /// `data`) with this latent model.
    ///
    /// Returns the combined likelihood together with an [`AvgLLH`] value, or an
    /// [`Error`] if the computation fails.
    fn expect(
        &self,
        data: &T::DataIn<'_>,
        likelihood: &T::Likelihood,
    ) -> Result<(T::Likelihood, AvgLLH), Error>;
}
/// Mixture components that can produce an output once a latent weighting
/// (`latent_likelihood`) for the input data is known.
pub trait Mixable<T: Parametrizable> {
    /// Produce a prediction for `data` weighted by `latent_likelihood`.
    ///
    /// `latent_likelihood` is taken by value. Inside `Mixture::predict` it is the
    /// latent model's own expectation on `data`.
    fn predict(
        &self,
        latent_likelihood: T::Likelihood,
        data: &T::DataIn<'_>,
    ) -> Result<T::DataOut, Error>;
}
/// A mixture model made of a set of mixable components and a latent model
/// that weights them.
///
/// The struct itself carries no trait bounds. Every `impl` that needs
/// `Parametrizable`/`Latent` states its own bounds, so the type can be named,
/// cloned and debug-printed without having to re-prove them everywhere.
#[derive(Clone, Debug)]
pub struct Mixture<T, L> {
    /// The mixture components.
    pub mixables: T,
    /// The latent model weighting the components.
    pub latent: L,
}
impl<T, L> Mixture<T, L>
where
    T: Parametrizable<Likelihood = L::Likelihood>,
    L: Parametrizable + Latent<L>,
{
    /// Build a mixture from its components (`mixables`) and its latent model
    /// (`latent`).
    pub fn new(mixables: T, latent: L) -> Self {
        Mixture { latent, mixables }
    }
}
/// EM interface for a mixture. Each step is delegated to the latent model
/// and/or the mixables, and their results are paired together.
impl<T, L> Parametrizable for Mixture<T, L>
where
    T: for<'a> Parametrizable<Likelihood = L::Likelihood, DataIn<'a> = L::DataIn<'a>> + Mixable<T>,
    L: Parametrizable + Latent<L>,
{
    /// Sufficient statistics of the latent model (`.0`) and of the mixables (`.1`).
    type SufficientStatistics = (L::SufficientStatistics, T::SufficientStatistics);
    type Likelihood = T::Likelihood;
    type DataIn<'a> = T::DataIn<'a>;
    type DataOut = T::DataOut;

    /// E-step: evaluate the mixables on `data`, then let the latent model
    /// combine those likelihoods.
    ///
    /// The mixables' own `AvgLLH` is discarded. The returned one comes from
    /// [`Latent::expect`].
    fn expect(&self, data: &Self::DataIn<'_>) -> Result<(Self::Likelihood, AvgLLH), Error> {
        let (component_likelihood, _) = self.mixables.expect(data)?;
        Latent::expect(&self.latent, data, &component_likelihood)
    }

    /// Compute sufficient statistics for both parts from the same
    /// `responsibilities`. The latent model is evaluated first.
    fn compute(
        &self,
        data: &Self::DataIn<'_>,
        responsibilities: &Self::Likelihood,
    ) -> Result<Self::SufficientStatistics, Error> {
        Ok((
            self.latent.compute(data, responsibilities)?,
            self.mixables.compute(data, responsibilities)?,
        ))
    }

    /// M-step: maximize the latent model, then the mixables.
    ///
    /// Stops at the first error. In that case the latent model may already have
    /// been updated.
    fn maximize(
        &mut self,
        sufficient_statistics: &Self::SufficientStatistics,
    ) -> Result<(), Error> {
        self.latent.maximize(&sufficient_statistics.0)?;
        self.mixables.maximize(&sufficient_statistics.1)?;
        Ok(())
    }

    /// Predict outputs for `data` using the latent model's own expectation as
    /// the weighting passed to [`Mixable::predict`].
    // NOTE(review): this uses `Parametrizable::expect` on the latent alone (not
    // the data-conditioned `Self::expect`). Confirm that this is the intended
    // weighting for prediction.
    fn predict(&self, data: &Self::DataIn<'_>) -> Result<Self::DataOut, Error> {
        let (latent_likelihood, _) = Parametrizable::expect(&self.latent, data)?;
        Mixable::predict(&self.mixables, latent_likelihood, data)
    }

    /// Apply a weighted update to the latent model, then to the mixables.
    ///
    /// Stops at the first error. In that case the latent model may already have
    /// been updated.
    fn update(
        &mut self,
        sufficient_statistics: &Self::SufficientStatistics,
        weight: f64,
    ) -> Result<(), Error> {
        self.latent.update(&sufficient_statistics.0, weight)?;
        self.mixables.update(&sufficient_statistics.1, weight)?;
        Ok(())
    }

    /// Merge several paired statistics by splitting them into their latent and
    /// mixable halves, merging each half with the same `weights`, and pairing
    /// the results back up.
    fn merge(
        sufficient_statistics: &[&Self::SufficientStatistics],
        weights: &[f64],
    ) -> Result<Self::SufficientStatistics, Error> {
        let latent_stats: Vec<_> = sufficient_statistics.iter().map(|s| &s.0).collect();
        let mixable_stats: Vec<_> = sufficient_statistics.iter().map(|s| &s.1).collect();
        Ok((
            L::merge(&latent_stats[..], weights)?,
            T::merge(&mixable_stats[..], weights)?,
        ))
    }

    /// Random initialization of the responsibilities, delegated to the latent
    /// model.
    fn expect_rand(&self, data: &Self::DataIn<'_>, k: usize) -> Result<Self::Likelihood, Error> {
        self.latent.expect_rand(data, k)
    }
}
#[cfg(all(test, feature = "ndarray"))]
mod tests {
    use super::*;
    use crate::backend::ndarray::{
        finite::Finite,
        gaussian::{sort_parameters, Gaussian},
        utils::{generate_random_expections, generate_samples},
    };
    use tracing::info;
    use tracing_test::traced_test;

    /// Fit a Gaussian mixture with a `Finite` latent model to data from
    /// `generate_samples`.
    ///
    /// The fit alternates `compute` + `maximize` with `expect`, starting from
    /// random responsibilities. It then checks that the sorted fitted means
    /// match the generating means to within 1e-2.
    #[traced_test]
    #[test]
    fn em_step() {
        let k = 3;
        let (data, _, means, _covariances) = generate_samples(&[5000, 10000, 15000], 2);
        info!(%means);
        let gaussian = Gaussian::new();
        let categorial = Finite::new(None);
        let mut mixture = Mixture {
            mixables: gaussian,
            latent: categorial,
        };
        let mut likelihood: AvgLLH;
        let mut responsibilities = generate_random_expections(&data.view(), k).unwrap();
        for _ in 1..20 {
            let stat = mixture.compute(&data.view(), &responsibilities).unwrap();
            mixture.maximize(&stat).unwrap();
            info!(%mixture.mixables.means);
            info!(%mixture.latent.pmf);
            (responsibilities, likelihood) = mixture.expect(&data.view()).unwrap();
            info!("likelihood {}", likelihood.0);
        }
        info!(%means);
        // Component order after EM is arbitrary, so sort before comparing.
        let (means_sorted, _) = sort_parameters(&mixture.mixables, &mixture.latent.pmf.view());
        info!(%means_sorted);
        info!("{}", &means_sorted - &means);
        assert!(means.abs_diff_eq(&means_sorted, 1e-2));
    }
}