1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
//! Adds functions that extend 2D float arrays so that they behave as point sets (one point per row; see [`PointSet`]).
use ndarray::{prelude::*, Data};
use num_traits::Float;

/// Computes the Euclidean (L2) norm of every row of a 2D array.
///
/// Each row is treated as one point, so the result has one entry per row:
/// `result[i] = sqrt(sum_j points[i, j]^2)`.
///
/// Elements are squared directly before summation (no rescaling as in
/// `hypot`), so rows containing very large magnitudes can overflow to
/// infinity.
///
/// # Examples
///
/// ```
/// use ndarray::array;
/// use som_rs::ndarray::point_set::row_norm_l2;
///
/// let points = array![[3.0, 4.0], [0.0, 0.0], [-6.0, 8.0]];
/// assert_eq!(row_norm_l2(&points), array![5.0, 0.0, 10.0]);
/// ```
pub fn row_norm_l2<A, S>(points: &ArrayBase<S, Ix2>) -> Array1<A>
where
    S: Data<Elem = A>,
    A: Float,
{
    // Square element-wise, sum across the columns of each row, then take the root.
    points.mapv(|e| e.powi(2)).sum_axis(Axis(1)).mapv(A::sqrt)
}

/// Extension trait that lets a 2D array be treated as a set of points,
/// one point per row.
///
/// This module implements it for every 2D `ndarray` array (owned arrays,
/// views, ...) whose elements implement `num_traits::Float`.
pub trait PointSet<A> {
    /// Computes the difference `row - point` for every row of the point set.
    ///
    /// Returns a new 2D array whose row `i` is the element-wise difference
    /// between row `i` of `self` and `point`.
    ///
    /// # Panics
    ///
    /// The provided `ArrayBase` implementation panics if `point` cannot be
    /// broadcast against the rows of `self`; normally `point.len()` should
    /// equal the number of columns.
    ///
    /// # Examples
    ///
    /// ```
    /// use ndarray::array;
    /// use som_rs::ndarray::point_set::PointSet;
    ///
    /// let points = array![[1.0, 2.0], [3.0, 5.0]];
    /// let differences = points.get_differences(&array![1.0, 1.0]);
    /// assert_eq!(differences, array![[0.0, 1.0], [2.0, 4.0]]);
    /// ```
    fn get_differences<S>(&self, point: &ArrayBase<S, Ix1>) -> Array2<A>
    where
        S: Data<Elem = A>,
        A: Float;

    /// Computes the Euclidean distance of every row to a given `point`.
    ///
    /// Returns a 1D array with one distance per row of `self`.
    ///
    /// # Panics
    ///
    /// The provided `ArrayBase` implementation panics under the same shape
    /// conditions as [`PointSet::get_differences`].
    ///
    /// # Examples
    ///
    /// ```
    /// use ndarray::array;
    /// use som_rs::ndarray::point_set::PointSet;
    ///
    /// let points = array![[4.0, 6.0], [1.0, 2.0]];
    /// assert_eq!(points.get_distances(&array![1.0, 2.0]), array![5.0, 0.0]);
    /// ```
    fn get_distances<S>(&self, point: &ArrayBase<S, Ix1>) -> Array1<A>
    where
        S: Data<Elem = A>,
        A: Float;
}

impl<A, T> PointSet<A> for ArrayBase<T, Ix2>
where
    T: Data<Elem = A>,
    A: Float,
{
    /// Views `point` as a `1 x d` row matrix and subtracts it from `self`,
    /// relying on ndarray broadcasting to repeat it over every row.
    fn get_differences<S>(&self, point: &ArrayBase<S, Ix1>) -> Array2<A>
    where
        S: Data<Elem = A>,
    {
        let point_as_row = point.view().insert_axis(Axis(0));
        self - &point_as_row
    }

    /// Takes the row-wise L2 norm of the per-row differences to `point`.
    fn get_distances<S>(&self, point: &ArrayBase<S, Ix1>) -> Array1<A>
    where
        S: Data<Elem = A>,
        A: Float,
    {
        let differences = self.get_differences(point);
        row_norm_l2(&differences)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array2};

    #[test]
    fn row_norm_l2_returns_euclidean_length_of_each_row() {
        let points = array![[3.0, 4.0], [0.0, 0.0], [-6.0, 8.0]];
        assert_eq!(row_norm_l2(&points), array![5.0, 0.0, 10.0]);
    }

    #[test]
    fn row_norm_l2_of_empty_point_set_is_empty() {
        let points = Array2::<f64>::zeros((0, 3));
        assert!(row_norm_l2(&points).is_empty());
    }

    #[test]
    fn get_differences_subtracts_point_from_every_row() {
        let points = array![[1.0, 2.0], [3.0, 5.0]];
        let point = array![1.0, 1.0];
        assert_eq!(
            points.get_differences(&point),
            array![[0.0, 1.0], [2.0, 4.0]]
        );
    }

    #[test]
    fn get_distances_works_on_views() {
        let points = array![[4.0, 6.0], [1.0, 2.0]];
        let point = array![1.0, 2.0];
        assert_eq!(
            points.view().get_distances(&point.view()),
            array![5.0, 0.0]
        );
    }

    #[test]
    #[should_panic]
    fn get_differences_panics_on_dimension_mismatch() {
        // A 2-column point set cannot be broadcast against a 3-element point.
        let points = array![[1.0, 2.0], [3.0, 5.0]];
        let point = array![1.0, 2.0, 3.0];
        let _ = points.get_differences(&point);
    }
}