use ndarray::{prelude::*, Shape};
/// Iterator over every multi-dimensional index of a shape.
///
/// Each item is an `Array1<usize>` with one entry per axis. Indices are
/// produced by decoding a running flat counter against the axis lengths, with
/// the last axis varying fastest. Only the axis lengths of the stored shape
/// are consulted when producing indices.
#[derive(Debug, Clone)]
pub struct NdIndexIterator<D: Dimension> {
    /// Shape whose index space is being enumerated.
    shape: Shape<D>,
    /// Flat position of the next index to yield; starts at 0 and iteration
    /// stops once it reaches the number of elements in `shape`.
    counter: usize,
}
impl<D> Iterator for NdIndexIterator<D>
where
D: Dimension,
{
type Item = Array1<usize>;
fn next(&mut self) -> Option<Self::Item> {
let max = self.shape.raw_dim().size();
if self.counter < max {
let dimensions = self.shape.raw_dim().slice();
let n_dimensions = dimensions.len();
let mut result = Array::<usize, Ix1>::zeros(n_dimensions);
let mut counter = self.counter;
for (i, d) in dimensions.iter().rev().enumerate() {
result[n_dimensions - i - 1] = counter % d;
counter /= d;
}
self.counter += 1;
Some(result)
} else {
None
}
}
}
/// Creates an iterator over all indices of `shape`.
///
/// `shape` may be anything implementing `ShapeBuilder`; it is converted into
/// a `Shape` and iteration starts from the first index (flat counter 0).
pub fn ndindex<Sh>(shape: Sh) -> NdIndexIterator<Sh::Dim>
where
    Sh: ShapeBuilder,
{
    NdIndexIterator {
        shape: shape.into_shape(),
        counter: 0,
    }
}
/// Builds a table of every index of `shape` as floating-point values.
///
/// The returned matrix has one row per element of the shape and one column
/// per axis; row `k` holds the `k`-th index produced by `NdIndexIterator`,
/// with each component cast to `f64`.
pub fn get_ndindex_array<D>(shape: &Shape<D>) -> Array2<f64>
where
    D: Dimension,
{
    let dim = shape.raw_dim();
    let n_rows = dim.size();
    let n_cols = dim.slice().len();

    let indices = NdIndexIterator {
        shape: shape.clone(),
        counter: 0,
    };

    let mut table = Array2::<f64>::zeros((n_rows, n_cols));
    // Rows and indices are produced in lockstep; each row has exactly as many
    // cells as an index has components.
    for (mut row, index) in table.outer_iter_mut().zip(indices) {
        for (cell, &component) in row.iter_mut().zip(index.iter()) {
            *cell = component as f64;
        }
    }
    table
}
#[cfg(test)]
mod tests {
    /// Placeholder sanity check; it only verifies basic integer arithmetic
    /// and does not exercise any item defined in this file.
    #[test]
    fn it_works() {
        assert_eq!(2 + 2, 4);
    }
}