use std::f64::consts::PI;
use crate::{AvgLLH, Error, Mixable, Parametrizable};
use itertools::{izip, Itertools};
use ndarray::parallel::prelude::*;
use ndarray::prelude::*;
use super::utils::{
    get_det_spd, get_shape2, get_shape3, get_weighted_means, get_weighted_sum, invert_spd,
};
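/// Multivariate normal parameters for the components of a Gaussian mixture,
/// estimated with an EM-style expect/compute/maximize cycle.
///
/// Shapes, for `k` components in `d` dimensions:
/// * `means`: `(k, d)`
/// * `covariances`, `precisions`: `(k, d, d)`
/// * `summands`: `(k)`, the per-component log-normalizers used in `expect`.
///
/// A minimal usage sketch (illustrative only; the `data` and `responsibilities`
/// arrays and the crate path are assumed, hence the example is not compiled):
///
/// ```ignore
/// let mut gaussian = Gaussian::new();
/// // `data`: (n, d) samples; `responsibilities`: (n, k) weights per sample and component.
/// let stats = gaussian.compute(&data.view(), &responsibilities).unwrap();
/// gaussian.maximize(&stats).unwrap();
/// let (log_densities, _avg_llh) = gaussian.expect(&data.view()).unwrap();
/// ```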
#[derive(Default, Debug, Clone)]
pub struct Gaussian {
    pub means: Array2<f64>,
    pub covariances: Array3<f64>,
    pub precisions: Array3<f64>,
    pub summands: Array1<f64>,
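    /// Running sufficient statistics maintained by `update`.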
    sufficient_statistics: <Gaussian as Parametrizable>::SufficientStatistics,
}
impl Gaussian {
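    /// Creates a `Gaussian` with empty parameter arrays; they are populated by `maximize`.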
    pub fn new() -> Gaussian {
        Gaussian::default()
    }
}
impl Parametrizable for Gaussian {
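    // Sufficient statistics: (sum_n r_nk, sum_n r_nk * x_n, sum_n r_nk * x_n * x_n^T),
    // with shapes (k), (k, d) and (k, d, d).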
    type SufficientStatistics = (Array1<f64>, Array2<f64>, Array3<f64>);
    type Likelihood = Array2<f64>;
    type DataIn<'a> = ArrayView2<'a, f64>;
    type DataOut = Array2<f64>;
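    /// Computes the log-density of every sample under every component, shape `(n, k)`.
    /// Normalization into responsibilities is left to the caller; the average
    /// log-likelihood is not accumulated here and is returned as a `NaN` placeholder.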
    fn expect(&self, data: &Self::DataIn<'_>) -> Result<(Self::Likelihood, AvgLLH), Error> {
        let [k, _d, _] = get_shape3(&self.covariances)?;
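        // Broadcast-subtract every component mean from every sample: shape (n, k, d).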
        let adjusted = &data.slice(s![.., NewAxis, ..]) - &self.means.slice(s![NewAxis, .., ..]);
        let [n, _d] = get_shape2(data)?;
        let mut responsibilities = Array2::<f64>::default((n, k));
        (
            adjusted.axis_iter(Axis(1)),
            responsibilities.axis_iter_mut(Axis(1)),
            self.precisions.axis_iter(Axis(0)),
            self.summands.axis_iter(Axis(0)),
        )
            .into_par_iter()
            .for_each(|(samples, mut rsp, precision, summand)| {
                izip!(samples.axis_iter(Axis(0)), rsp.axis_iter_mut(Axis(0)))
                    .for_each(|(x, mut r)| {
                        let x = x.slice(s![.., NewAxis]);
                        // Quadratic form (x - mu)^T * P * (x - mu) of the Gaussian exponent.
                        let q = x.t().dot(&precision).dot(&x).into_shape(()).unwrap();
                        // log N(x; mu, Sigma) = log-normalizer - 0.5 * quadratic form.
                        let log_density = &summand - &(q * 0.5);
                        r.assign(&log_density);
                    })
            });
        Ok((responsibilities, AvgLLH(f64::NAN)))
    }
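    /// Accumulates the sufficient statistics from the data and the per-sample
    /// responsibilities, to be consumed by `maximize`, `update` or `merge`.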
    fn compute(
        &self,
        data: &Self::DataIn<'_>,
        responsibilities: &Self::Likelihood,
    ) -> Result<Self::SufficientStatistics, Error> {
        let sum_responsibilities = responsibilities.sum_axis(Axis(0));
        let weighted_sum = get_weighted_sum(&data, &responsibilities);
        let [k, d] = get_shape2(&weighted_sum.view())?;
        let mut covs = Array3::<f64>::zeros((k, d, d));
        (
            covs.axis_iter_mut(Axis(0)),
            responsibilities.axis_iter(Axis(1)),
        )
            .into_par_iter()
            .for_each(|(mut cov, resp)| {
                data.axis_iter(Axis(0))
                    .zip(resp.axis_iter(Axis(0)))
                    .for_each(|(x, r)| {
                        let x = x.slice(s![NewAxis, ..]);
                        let x = &r.slice(s![NewAxis, NewAxis]) * &x.t().dot(&x);
                        cov += &x;
                    });
            });
        Ok((sum_responsibilities, weighted_sum, covs))
    }
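    /// The M-step: recomputes means, covariances, precisions and the per-component
    /// log-normalizers from the given sufficient statistics.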
    fn maximize(
        &mut self,
        sufficient_statistics: &Self::SufficientStatistics,
    ) -> Result<(), Error> {
        self.means = get_weighted_means(&sufficient_statistics.1, &sufficient_statistics.0);
        let [k, d, _d] = get_shape3(&sufficient_statistics.2)?;
        let mut product = Array3::<f64>::zeros((k, d, d));
        (
            product.axis_iter_mut(Axis(0)),
            self.means.axis_iter(Axis(0)),
        )
            .into_par_iter()
            .for_each(|(mut prod, mean)| {
                prod.assign(
                    &mean
                        .slice(s![.., NewAxis])
                        .dot(&mean.slice(s![NewAxis, ..])),
                )
            });
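        // Sigma_k = (sum_n r_nk x_n x_n^T) / (sum_n r_nk) - mu_k mu_k^T.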
        self.covariances = &sufficient_statistics.2
            / &sufficient_statistics.0.slice(s![.., NewAxis, NewAxis])
            - &product;
        self.precisions = Array3::<f64>::zeros((k, d, d));
        self.summands = Array1::<f64>::zeros(k);
        (
            self.covariances.axis_iter(Axis(0)),
            self.precisions.axis_iter_mut(Axis(0)),
            self.summands.axis_iter_mut(Axis(0)),
        )
            .into_par_iter()
            .for_each(|(cov, mut prec, mut summand)| {
                prec.assign(&invert_spd(&cov).unwrap());
                // Log-normalizer of a d-dimensional normal:
                // -d/2 * ln(2 * pi) - 1/2 * ln(det(Sigma)).
                summand.assign(&arr0(
                    -(d as f64) / 2.0 * (2.0 * PI).ln()
                        - 0.5 * get_det_spd(&cov).unwrap().ln(),
                ))
            });
        Ok(())
    }
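    /// Blends the stored sufficient statistics with the incoming ones as an
    /// exponential moving average: `new = (1 - weight) * old + weight * incoming`.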
    fn update(
        &mut self,
        sufficient_statistics: &Self::SufficientStatistics,
        weight: f64,
    ) -> Result<(), Error> {
        self.sufficient_statistics.0 =
            &self.sufficient_statistics.0 * (1.0 - weight) + &sufficient_statistics.0 * weight;
        self.sufficient_statistics.1 =
            &self.sufficient_statistics.1 * (1.0 - weight) + &sufficient_statistics.1 * weight;
        self.sufficient_statistics.2 =
            &self.sufficient_statistics.2 * (1.0 - weight) + &sufficient_statistics.2 * weight;
        Ok(())
    }
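    /// Combines several sets of sufficient statistics by a weighted sum, e.g.
    /// statistics computed on disjoint partitions of the data.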
    fn merge(
        sufficient_statistics: &[&Self::SufficientStatistics],
        weights: &[f64],
    ) -> Result<Self::SufficientStatistics, Error> {
        Ok(sufficient_statistics
            .iter()
            .zip(weights.iter())
            .map(|(s, w)| (&s.0 * *w, &s.1 * *w, &s.2 * *w))
            .reduce(|s1, s2| (&s1.0 + &s2.0, &s1.1 + &s2.1, &s1.2 + &s2.2))
            .unwrap())
    }
    fn predict(&self, _data: &Self::DataIn<'_>) -> Result<Self::DataOut, Error> {
        Err(Error::NotImplemented)
    }
}
impl Mixable<Gaussian> for Gaussian {
    fn predict(
        &self,
        _latent_likelihood: <Gaussian as Parametrizable>::Likelihood,
        _data: &<Gaussian as Parametrizable>::DataIn<'_>,
    ) -> Result<<Gaussian as Parametrizable>::DataOut, Error> {
        Err(Error::NotImplemented)
    }
}
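/// Returns copies of the component means and covariances reordered by ascending
/// mixture weight `pmf`.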
pub fn sort_parameters(gmm: &Gaussian, pmf: &ArrayView1<f64>) -> (Array2<f64>, Array3<f64>) {
    let mut means = gmm.means.clone();
    let mut covariances = gmm.covariances.clone();
    pmf.as_slice()
        .unwrap()
        .iter()
        .enumerate()
        .sorted_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
        .enumerate()
        .for_each(|(i, (j, _))| {
            means
                .slice_mut(s![i, ..])
                .assign(&gmm.means.slice(s![j, ..]));
            covariances
                .slice_mut(s![i, .., ..])
                .assign(&gmm.covariances.slice(s![j, .., ..]));
        });
    (means, covariances)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::ndarray::utils::{filter_data, generate_samples};
    use tracing::info;
    use tracing_test::traced_test;
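    // Maximizing from the ground-truth responsibilities should recover the
    // generating covariances.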
    #[traced_test]
    #[test]
    fn check_maximization() {
        let (data, responsibilities, _, covariances) =
            generate_samples(&[100000, 100000, 100000], 2);
        let mut gaussian = Gaussian::new();
        let sufficient_statistics = gaussian.compute(&data.view(), &responsibilities).unwrap();
        gaussian.maximize(&sufficient_statistics).unwrap();
        assert!(covariances.abs_diff_eq(&gaussian.covariances, 1e-1))
    }
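    // One additional expect/compute/maximize round starting from the fitted
    // parameters should stay close to the generating covariances.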
    #[traced_test]
    #[test]
    fn check_expectation() {
        let (data, responsibilities, _, covariances) =
            generate_samples(&[200000, 200000, 200000], 2);
        let mut gaussian = Gaussian::new();
        let sufficient_statistics = gaussian.compute(&data.view(), &responsibilities).unwrap();
        gaussian.maximize(&sufficient_statistics).unwrap();
        let (responsibilities, _) = gaussian.expect(&data.view()).unwrap();
        let responsibilities = responsibilities.map(|x| x.exp());
        let responsibilities =
            &responsibilities / &responsibilities.sum_axis(Axis(1)).slice(s![.., NewAxis]);
        let sufficient_statistics = gaussian.compute(&data.view(), &responsibilities).unwrap();
        gaussian.maximize(&sufficient_statistics).unwrap();
        info!(%covariances);
        info!(%gaussian.covariances);
        let diff = &covariances - &gaussian.covariances;
        info!(%diff);
        assert!(covariances.abs_diff_eq(&gaussian.covariances, 2e-1));
    }
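    // Statistics computed on two disjoint splits and merged with equal weights
    // should reproduce the full-data estimate; each split on its own should not.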
    #[traced_test]
    #[test]
    fn check_merge() {
        let (data, responsibilities, _, covariances) = generate_samples(&[10000, 10000, 10000], 2);
        let (data_1, responsibilities_1) =
            filter_data(&data.view(), &responsibilities.view(), |x, _y| x[1] > 0.5).unwrap();
        let (data_2, responsibilities_2) =
            filter_data(&data.view(), &responsibilities.view(), |x, _y| x[1] <= 0.5).unwrap();
        let mut gaussian = Gaussian::new();
        let sufficient_statistics_1 = gaussian
            .compute(&data_1.view(), &responsibilities_1)
            .unwrap();
        let sufficient_statistics_2 = gaussian
            .compute(&data_2.view(), &responsibilities_2)
            .unwrap();
        gaussian.maximize(&sufficient_statistics_1).unwrap();
        info!("{}", covariances.abs_diff_eq(&gaussian.covariances, 1e-3));
        assert!(!covariances.abs_diff_eq(&gaussian.covariances, 1e-3));
        gaussian.maximize(&sufficient_statistics_2).unwrap();
        assert!(!covariances.abs_diff_eq(&gaussian.covariances, 1e-3));
        let sufficient_statistics = Gaussian::merge(
            &[&sufficient_statistics_1, &sufficient_statistics_2],
            &[0.5, 0.5],
        )
        .unwrap();
        gaussian.maximize(&sufficient_statistics).unwrap();
        info!(%gaussian.covariances);
        info!(%covariances);
        assert!(covariances.abs_diff_eq(&gaussian.covariances, 1e-3));
    }
}