Gaussian mixture models can be trained with EM
High-dimensional clustering (Warning: $\mathcal O(n^3)$!)
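For reference, the model and the quantities EM alternates over (standard definitions, matching the shapes used in the code below):

$$p(x) = \sum_{k} w_k \, \mathcal N(x \mid \mu_k, \Sigma_k), \qquad \gamma_{nk} = \frac{w_k \, \mathcal N(x_n \mid \mu_k, \Sigma_k)}{\sum_j w_j \, \mathcal N(x_n \mid \mu_j, \Sigma_j)}$$

The E-step computes the responsibilities $\gamma_{nk}$; the M-step re-estimates $w_k$, $\mu_k$, and $\Sigma_k$ from them.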
$ # Python packaging and dependency management
$ curl -LsSf https://astral.sh/uv/install.sh | sh
$ # Get the code
$ git clone https://github.com/StefanUlbrich/PyCon2024.git
$ cd PyCon2024 && git checkout skeleton
PyCon2024$ # Create virtual environment and install dependencies
PyCon2024$ uv venv
PyCon2024$ uv pip sync requirements.txt
PyCon2024$ . ./venv/bin/activate
from dataclasses import dataclass

import numpy as np
from numpy.typing import NDArray

@dataclass
class GaussianMixtureModel:
    means: NDArray[np.float64]    # k x d
    covs: NDArray[np.float64]     # k x d x d
    weights: NDArray[np.float64]  # k
def expect(
    gmm: GaussianMixtureModel,
    data: NDArray[np.float64],  # n x d
) -> NDArray[np.float64]:  # n x k
    ...
def maximize(
    gmm: GaussianMixtureModel,
    responsibilities: NDArray[np.float64],  # n x k
    data: NDArray[np.float64],  # n x d
) -> None:
    ...
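The skeleton leaves expect unimplemented. A minimal sketch of the E-step, assuming SciPy is available (expect_sketch is an illustrative name, not the repo's code):

from scipy.stats import multivariate_normal

def expect_sketch(
    gmm: GaussianMixtureModel,
    data: NDArray[np.float64],  # n x d
) -> NDArray[np.float64]:  # n x k
    # Weighted component densities, one column per component.
    likelihoods = np.stack(
        [
            w * multivariate_normal.pdf(data, mean=m, cov=c)
            for m, c, w in zip(gmm.means, gmm.covs, gmm.weights)
        ],
        axis=1,
    )
    # Normalize each row so the responsibilities sum to one per sample.
    return likelihoods / likelihoods.sum(axis=1, keepdims=True)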
An output index may not repeat in einsum, so the two data axes need distinct labels:

knd, kn, kne -> kde
einstein_sum_notation('knd, kn, kne -> kde', data, responsibilities, data)
einsum('knd, kn, kne -> kde', data, responsibilities, data)
np.einsum('knd, kn, kne -> kde', data, responsibilities, data)
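To see what this signature computes, here is an equivalent (slow) loop form; the shapes, rng seed, and variable names are illustrative only:

import numpy as np

rng = np.random.default_rng(0)
k, n, d = 3, 5, 2
x = rng.normal(size=(k, n, d))  # k x n x d
r = rng.random(size=(k, n))     # k x n

fast = np.einsum("knd, kn, kne -> kde", x, r, x)

# The same contraction written as explicit loops over k and n:
# for each component, sum the responsibility-weighted outer products.
slow = np.zeros((k, d, d))
for ki in range(k):
    for ni in range(n):
        slow[ki] += r[ki, ni] * np.outer(x[ki, ni], x[ki, ni])

assert np.allclose(fast, slow)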
def maximize_(
    gmm: GaussianMixtureModel,
    responsibilities: NDArray[np.float64],  # n x k
    data: NDArray[np.float64],  # n x d
) -> None:
    """Maximization step."""
    assert responsibilities.shape[0] == data.shape[0]
    sum_responsibilities = responsibilities.sum(axis=0)  # k
    gmm.means = (
        np.sum(data[:, np.newaxis, :] * responsibilities[:, :, np.newaxis], axis=0)
        / sum_responsibilities[:, np.newaxis]
    )  # k x d
    centered = data[:, np.newaxis, :] - gmm.means[np.newaxis, :, :]  # n x k x d
    gmm.covs = (
        np.einsum("nkd, nk, nke -> kde", centered, responsibilities, centered)
        / sum_responsibilities[:, np.newaxis, np.newaxis]
    )
    gmm.weights = sum_responsibilities / sum_responsibilities.sum()
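With both steps in place, training is just alternation; a minimal sketch of a driver loop (fit and n_iter are illustrative names, convergence checks omitted):

def fit(
    gmm: GaussianMixtureModel,
    data: NDArray[np.float64],
    n_iter: int = 100,
) -> None:
    # Alternate E-step and M-step for a fixed number of iterations.
    for _ in range(n_iter):
        responsibilities = expect(gmm, data)
        maximize_(gmm, responsibilities, data)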
Rust is in its seventh year as the most loved language, with 87% of developers saying they want to continue using it. Rust also ties with Python as the most wanted technology, with TypeScript running a close second. (Stack Overflow Developer Survey 2022)
pub struct NewsArticle {
    pub headline: String,
    pub location: String,
    pub author: String,
    pub content: String,
}

pub trait Summary {
    fn summarize(&self) -> String;
}

impl Summary for NewsArticle {
    fn summarize(&self) -> String {
        format!("{}, by {} ({})", self.headline, self.author, self.location)
    }
}
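Calling the trait method on a value (the example values follow the Rust book):

let article = NewsArticle {
    headline: String::from("Penguins win the Stanley Cup Championship!"),
    location: String::from("Pittsburgh, PA, USA"),
    author: String::from("Iceburgh"),
    content: String::from("The Pittsburgh Penguins once again are the best hockey team in the NHL."),
};
println!("New article available! {}", article.summarize());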
let mut s = String::from("hello");

{
    let r1 = &mut s;
} // r1 goes out of scope here, so we can make a new reference with no problems.

let r2 = &mut s;
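For contrast, two simultaneous mutable borrows are rejected at compile time (the counterexample from the Rust book):

let mut s = String::from("hello");

let r1 = &mut s;
let r2 = &mut s; // error[E0499]: cannot borrow `s` as mutable more than once at a time
println!("{}, {}", r1, r2);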
enum Message {
    Quit,
    Move { x: i32, y: i32 },
    Write(String),
    ChangeColor(i32, i32, i32),
}
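An enum like this is typically consumed with match; a small illustrative handler (process is a made-up name):

fn process(msg: Message) {
    match msg {
        Message::Quit => println!("quit"),
        Message::Move { x, y } => println!("move to ({x}, {y})"),
        Message::Write(text) => println!("write: {text}"),
        Message::ChangeColor(r, g, b) => println!("color ({r}, {g}, {b})"),
    }
}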
$ # Installation
$ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
$ rustup update # Optional: Update the tool chain
$ cd PyCon2024 && git checkout rust-examples
PyCon2024$ # git checkout rust_test # or rust_skeleton
use ndarray::Array2;
use ndarray_npy::read_npy;

fn main() {
    let data: Array2<f64> = read_npy("data/data.npy").unwrap();
    println!("{}", data);

    let responsibilities: Array2<f64> = read_npy("data/responsibilities.npy").unwrap();
    println!("{}", responsibilities);

    let means: Array2<f64> = read_npy("data/means.npy").unwrap();
    println!("{}", means);
}
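This assumes the ndarray and ndarray-npy crates are declared in Cargo.toml (version numbers illustrative):

[dependencies]
ndarray = "0.15"
ndarray-npy = "0.8"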
use ndarray::prelude::*;
pub fn foo(data: Array2<f64>) -> Array2<f64> { Array2::<f64>::zeros((0, 0)) }
use ndarray::prelude::*;
pub fn foo(data: &Array2<f64>) -> Array2<f64> { Array2::<f64>::zeros((0, 0)) }
use ndarray::prelude::*;
pub fn foo(data: &mut Array2<f64>, other: ArrayView2<f64>) {
    data.assign(&other);
}
Rust:
    let sum_responsibilities = responsibilities.sum_axis(Axis(0));
NumPy:
    sum_responsibilities = responsibilities.sum(axis=0)

Rust:
    let x = (&responsibilities.slice(s![.., .., NewAxis]) * &data.slice(s![.., NewAxis, ..]))
        .sum_axis(Axis(0));
NumPy:
    x = np.sum(responsibilities[:, :, np.newaxis] * data[:, np.newaxis, :], axis=0)

Rust:
    let cov = x.t().dot(&y);
NumPy:
    covs = x.T @ y
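Putting these pieces together, a sketch of the means update in Rust (maximize_means is an illustrative name, not the repo's code; co-broadcasting as in ndarray 0.15+ assumed):

use ndarray::prelude::*;

// M-step update for the means only: responsibility-weighted
// sums of the data, normalized per component.
pub fn maximize_means(
    responsibilities: &Array2<f64>, // n x k
    data: &Array2<f64>,             // n x d
) -> Array2<f64> {
    let sum_responsibilities = responsibilities.sum_axis(Axis(0)); // k
    let weighted = (&responsibilities.slice(s![.., .., NewAxis])
        * &data.slice(s![.., NewAxis, ..]))
    .sum_axis(Axis(0)); // k x d
    &weighted / &sum_responsibilities.slice(s![.., NewAxis]) // k x d
}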
ndarray has an interface that is reminiscent of NumPy.
einsum is available in Rust as well.
criterion.rs provides benchmarking.
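A criterion benchmark skeleton looks like this (the standard pattern from criterion's documented API; bench_maximize and the closure body are placeholders):

use criterion::{criterion_group, criterion_main, Criterion};

fn bench_maximize(c: &mut Criterion) {
    c.bench_function("maximize", |b| {
        b.iter(|| {
            // call the code under test here
        })
    });
}

criterion_group!(benches, bench_maximize);
criterion_main!(benches);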
PyCon2024$ git checkout bindings
PyCon2024$ # git checkout benchmarks # spoiler alert!
PyCon2024$ maturin develop -r --strip # Builds the extension and installs it into the venv
PyCon2024$ maturin build -r --strip # Creates a binary wheel
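The bindings themselves follow PyO3's usual pattern; a minimal sketch (the module name gmm is an assumption about the repo's layout, and hello is a placeholder for the real EM functions):

use pyo3::prelude::*;

// Placeholder function; the real bindings wrap the ndarray-based EM steps.
#[pyfunction]
fn hello() -> String {
    "hello from Rust".to_string()
}

#[pymodule]
fn gmm(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(hello, m)?)?;
    Ok(())
}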
data, _ = gmm.make_blobs(n_samples=10000, centers=20, n_features=2, random_state=7)
model = gmm.initialize(data, 20)
print(",",model.means)
r = gmm.expect(model, data)
einsum:
13 ms ± 369 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.37 ms ± 194 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.49 ms ± 23.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)