1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
//! Functions and a `PointSet` trait that extend 2D float arrays so they behave as point sets, with one point per row.
use ndarray::{prelude::*, Data};
use num_traits::Float;
/// Computes the Euclidean (L2) norm of every row of a 2D array.
///
/// Each row is treated as one point, so the result has one entry per row:
/// `result[i] = sqrt(sum_j points[i, j]^2)`. A row with zero columns has norm `0`.
///
/// Squares are accumulated row by row, so no temporary array the size of the
/// input is allocated.
///
/// # Examples
///
/// ```
/// use ndarray::array;
/// use som_rs::ndarray::point_set::row_norm_l2;
///
/// let points = array![[3.0, 4.0], [0.0, 0.0]];
/// assert_eq!(row_norm_l2(&points), array![5.0, 0.0]);
/// ```
pub fn row_norm_l2<A, S>(points: &ArrayBase<S, Ix2>) -> Array1<A>
where
    S: Data<Elem = A>,
    A: Float,
{
    // Fold each row (Axis(1) lanes) into its sum of squares, then take the root.
    points.map_axis(Axis(1), |row| {
        row.fold(A::zero(), |acc, &x| acc + x * x).sqrt()
    })
}
/// Operations that treat a 2D array as a set of points, one point per row.
pub trait PointSet<A> {
    /// Computes the difference between each row and a given `point` (1D).
    ///
    /// Row `i` of the result is `self.row(i) - point`, so the result has the
    /// same shape as `self`.
    ///
    /// # Panics
    ///
    /// The implementation for `ndarray::ArrayBase` panics when `point.len()`
    /// does not match the number of columns, because the row-wise broadcast fails.
    ///
    /// # Examples
    ///
    /// ```
    /// use ndarray::array;
    /// use som_rs::ndarray::point_set::PointSet;
    ///
    /// let points = array![[1.0, 2.0], [3.0, 4.0]];
    /// let differences = points.get_differences(&array![1.0, 1.0]);
    /// assert_eq!(differences, array![[0.0, 1.0], [2.0, 3.0]]);
    /// ```
    fn get_differences<S>(&self, point: &ArrayBase<S, Ix1>) -> Array2<A>
    where
        S: Data<Elem = A>,
        A: Float;

    /// Computes the Euclidean distance between each row and a given `point` (1D).
    ///
    /// Entry `i` of the result is the L2 norm of `self.row(i) - point`.
    ///
    /// # Panics
    ///
    /// The implementation for `ndarray::ArrayBase` panics when `point.len()`
    /// does not match the number of columns.
    ///
    /// # Examples
    ///
    /// ```
    /// use ndarray::array;
    /// use som_rs::ndarray::point_set::PointSet;
    ///
    /// let points = array![[4.0, 6.0], [1.0, 2.0]];
    /// let distances = points.get_distances(&array![1.0, 2.0]);
    /// assert_eq!(distances, array![5.0, 0.0]);
    /// ```
    fn get_distances<S>(&self, point: &ArrayBase<S, Ix1>) -> Array1<A>
    where
        S: Data<Elem = A>,
        A: Float;
}
/// `PointSet` for any 2D `ndarray` array of floats, whatever its storage (`T: Data`).
impl<A, T> PointSet<A> for ArrayBase<T, Ix2>
where
    T: Data<Elem = A>,
    A: Float,
{
    fn get_differences<S>(&self, point: &ArrayBase<S, Ix1>) -> Array2<A>
    where
        S: Data<Elem = A>,
    {
        // View `point` as a 1 x n row so the subtraction broadcasts it over every row.
        let as_row = point.view().insert_axis(Axis(0));
        self - &as_row
    }

    fn get_distances<S>(&self, point: &ArrayBase<S, Ix1>) -> Array1<A>
    where
        S: Data<Elem = A>,
        A: Float,
    {
        // The distance to `point` is the L2 norm of each row's difference vector.
        let differences = self.get_differences(point);
        row_norm_l2(&differences)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array2};

    #[test]
    fn row_norm_l2_computes_per_row_norms() {
        let points = array![[3.0, 4.0], [0.0, 0.0], [-6.0, 8.0]];
        assert_eq!(row_norm_l2(&points), array![5.0, 0.0, 10.0]);
    }

    #[test]
    fn row_norm_l2_accepts_views_and_f32() {
        let points = array![[3.0f32, 4.0], [5.0, 12.0]];
        assert_eq!(row_norm_l2(&points.view()), array![5.0f32, 13.0]);
    }

    #[test]
    fn row_norm_l2_handles_empty_shapes() {
        let no_rows = Array2::<f64>::zeros((0, 3));
        assert_eq!(row_norm_l2(&no_rows).len(), 0);

        // Rows without columns have norm zero.
        let no_cols = Array2::<f64>::zeros((2, 0));
        assert_eq!(row_norm_l2(&no_cols), array![0.0, 0.0]);
    }

    #[test]
    fn get_differences_subtracts_point_from_every_row() {
        let points = array![[1.0, 2.0], [3.0, 4.0]];
        let differences = points.get_differences(&array![1.0, 1.0]);
        assert_eq!(differences, array![[0.0, 1.0], [2.0, 3.0]]);
    }

    #[test]
    fn get_distances_returns_euclidean_distance_per_row() {
        let points = array![[4.0, 6.0], [1.0, 2.0]];
        let point = array![1.0, 2.0];
        assert_eq!(points.get_distances(&point.view()), array![5.0, 0.0]);
    }

    #[test]
    #[should_panic]
    fn get_differences_panics_on_dimension_mismatch() {
        let points = array![[1.0, 2.0], [3.0, 4.0]];
        let _ = points.get_differences(&array![1.0, 2.0, 3.0]);
    }
}