From 4b6217ddb2dd5254f99c7a0b61dae4cfdef327f5 Mon Sep 17 00:00:00 2001
From: Ian
Date: Thu, 3 Jul 2025 12:43:16 +0000
Subject: [PATCH 1/5] added new top n genes function

---
 src/sparse/csc.rs | 36 ++++++++++++++++++++++++++++++++++++
 src/sparse/csr.rs | 31 +++++++++++++++++++++++++++++++
 src/sparse/mod.rs |  8 ++++++++
 3 files changed, 75 insertions(+)

diff --git a/src/sparse/csc.rs b/src/sparse/csc.rs
index 2b27a3b..8825f07 100644
--- a/src/sparse/csc.rs
+++ b/src/sparse/csc.rs
@@ -7,6 +7,7 @@ use std::iter::Sum;
 use std::ops::Add;
 use std::ops::AddAssign;
 
+use crate::sparse::MatrixNTop;
 use crate::utils::Normalize;
 
 use super::{
@@ -1027,6 +1028,41 @@ impl<M> BatchMatrixMean for CscMatrix<M> {
     }
 }
 
+impl<M: NumCast + Copy> MatrixNTop for CscMatrix<M> {
+    type Item = M;
+
+    fn sum_row_n_top<T>(&self, n: usize) -> anyhow::Result<Vec<T>>
+    where
+        T: Float + NumCast + AddAssign + Sum {
+        let mut result = vec![T::zero(); self.nrows()];
+
+        let mut row_values: Vec<Vec<T>> = vec![Vec::new(); self.nrows()];
+
+        for col_idx in 0..self.ncols() {
+            let col_start = self.col_offsets()[col_idx];
+            let col_end = self.col_offsets()[col_idx + 1];
+
+            for idx in col_start..col_end {
+                let row_idx = self.row_indices()[idx];
+                if let Some(val) = T::from(self.values()[idx]) {
+                    row_values[row_idx].push(val);
+                }
+            }
+        }
+
+        for (row_idx, mut values) in row_values.into_iter().enumerate() {
+            if values.len() <= n {
+                result[row_idx] = values.into_iter().sum();
+            } else {
+                values.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
+                result[row_idx] = values.into_iter().take(n).sum();
+            }
+        }
+
+        Ok(result)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use Direction;
diff --git a/src/sparse/csr.rs b/src/sparse/csr.rs
index b67b9f0..b605b0c 100644
--- a/src/sparse/csr.rs
+++ b/src/sparse/csr.rs
@@ -6,6 +6,7 @@ use std::ops::{Add, AddAssign};
 use super::{
     BatchMatrixMean, BatchMatrixVariance, MatrixMinMax, MatrixNonZero, MatrixSum, MatrixVariance,
 };
+use crate::sparse::MatrixNTop;
 use crate::utils::Normalize;
 use crate::utils::{BatchIdentifier, Log1P};
 use anyhow::{anyhow, Ok};
@@ -1033,6 +1034,36 @@ impl<M> BatchMatrixMean for CsrMatrix<M> {
     }
 }
 
+impl<M: NumCast + Copy> MatrixNTop for CsrMatrix<M> {
+    type Item = M;
+
+    fn sum_row_n_top<T>(&self, n: usize) -> anyhow::Result<Vec<T>>
+    where
+        T: Float + NumCast + AddAssign + Sum {
+        let mut result = vec![T::zero(); self.nrows()];
+
+        for row_idx in 0..self.nrows() {
+            let row_start = self.row_offsets()[row_idx];
+            let row_end = self.row_offsets()[row_idx + 1];
+
+            let mut row_values: Vec<T> = Vec::new();
+            for idx in row_start..row_end {
+                if let Some(val) = T::from(self.values()[idx]) {
+                    row_values.push(val);
+                }
+            }
+
+            if row_values.len() <= n {
+                result[row_idx] = row_values.into_iter().sum();
+            } else {
+                row_values.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
+                result[row_idx] = row_values.into_iter().take(n).sum();
+            }
+        }
+        Ok(result)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use Direction;
diff --git a/src/sparse/mod.rs b/src/sparse/mod.rs
index df39848..d89925c 100644
--- a/src/sparse/mod.rs
+++ b/src/sparse/mod.rs
@@ -164,4 +164,12 @@
+pub trait MatrixNTop {
+    type Item: NumCast;
+
+    fn sum_row_n_top<T>(&self, n: usize) -> anyhow::Result<Vec<T>>
+    where
+        T: Float + NumCast + AddAssign + std::iter::Sum;
+}
+

From bc711b249aee67af0b8f84ea68ef04b23c0b45f9 Mon Sep 17 00:00:00 2001
From: Ian
Date: Thu, 3 Jul 2025 12:44:29 +0000
Subject: [PATCH 2/5] version bump

---
 Cargo.lock | 2 +-
 Cargo.toml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index effb019..ac08a1a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1744,7 +1744,7 @@ dependencies = [
 
 [[package]]
 name = "single_algebra"
-version = "0.7.0"
+version = "0.8.0"
 dependencies = [
  "anyhow",
  "approx",
diff --git a/Cargo.toml b/Cargo.toml
index 7d0085a..7d33b76 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "single_algebra"
-version = "0.7.0"
+version = "0.8.0"
 edition = "2021"
 license-file = "LICENSE.md"
 description = "A linear algebra convenience library for the single-rust library. Can be used externally as well."

From 5055e02396da4349378f739f3211887aeed67251 Mon Sep 17 00:00:00 2001
From: Ian
Date: Thu, 3 Jul 2025 16:09:16 +0000
Subject: [PATCH 3/5] improved performance of matrix sum and non-zero operations. Version bump

---
 Cargo.lock        |  15 +-
 Cargo.toml        |   4 +-
 src/sparse/csc.rs |  30 ++-
 src/sparse/csr.rs | 559 +++++++++++++++++++++++++++++++++++++---------
 src/sparse/mod.rs |  79 ++++---
 5 files changed, 521 insertions(+), 166 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index ac08a1a..4f89e76 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1729,7 +1729,7 @@ dependencies = [
  "rand 0.9.0",
  "rand_distr 0.5.1",
  "rayon",
- "single-utilities",
+ "single-utilities 0.6.0",
  "thiserror 2.0.9",
 ]
 
@@ -1742,9 +1742,18 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "single-utilities"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bd8b90e8aa6128de9b26903484c2cee7556cf342b190ea3570396e293d1991de"
+dependencies = [
+ "num-traits",
+]
+
 [[package]]
 name = "single_algebra"
-version = "0.8.0"
+version = "0.8.1"
 dependencies = [
  "anyhow",
  "approx",
@@ -1761,7 +1770,7 @@ dependencies = [
  "rayon",
  "simba",
  "single-svdlib",
- "single-utilities",
+ "single-utilities 0.7.0",
  "smartcore",
 ]
 
diff --git a/Cargo.toml b/Cargo.toml
index 7d33b76..ebed76e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "single_algebra"
-version = "0.8.0"
+version = "0.8.1"
 edition = "2021"
 license-file = "LICENSE.md"
 description = "A linear algebra convenience library for the single-rust library. Can be used externally as well."
@@ -37,7 +37,7 @@ simba = { version = "0.9.0", optional = true }
 smartcore = { version = "0.4", features = ["ndarray-bindings"], optional = true }
 single-svdlib = { version = "1.0.4" }
 rand = "0.9.0"
-single-utilities = "0.6.0"
+single-utilities = "0.7.0"
 
 [dev-dependencies]
 criterion = "0.5.1"
diff --git a/src/sparse/csc.rs b/src/sparse/csc.rs
index 8825f07..937f863 100644
--- a/src/sparse/csc.rs
+++ b/src/sparse/csc.rs
@@ -2,9 +2,7 @@ use nalgebra_sparse::CscMatrix;
 use num_traits::{Float, NumCast, PrimInt, Unsigned, Zero};
 use single_utilities::types::Direction;
 use std::collections::{HashMap, HashSet};
-use std::hash::Hash;
 use std::iter::Sum;
-use std::ops::Add;
 use std::ops::AddAssign;
 
 use crate::sparse::MatrixNTop;
@@ -360,8 +358,8 @@ where
 
     fn var_col<I, T>(&self) -> anyhow::Result<Vec<T>>
     where
-        I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
-        T: Float + NumCast + AddAssign + std::iter::Sum,
+        I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
+        T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync,
         Self::Item: NumCast,
     {
         let sum: Vec<T> = self.sum_col()?;
@@ -386,8 +384,8 @@ where
 
     fn var_row<I, T>(&self) -> anyhow::Result<Vec<T>>
     where
-        I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
-        T: Float + NumCast + AddAssign + std::iter::Sum,
+        I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
+        T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync,
         Self::Item: NumCast,
     {
         let sum: Vec<T> = self.sum_row()?;
@@ -411,8 +409,8 @@ where
 
     fn var_col_chunk<I, T>(&self, reference: &mut [T]) -> anyhow::Result<()>
     where
-        I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
-        T: Float + NumCast + AddAssign + std::iter::Sum,
+        I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
+        T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync,
         Self::Item: NumCast,
     {
         // Validate input slice length matches number of columns
@@ -450,8 +448,8 @@ where
 
     fn var_row_chunk<I, T>(&self, reference: &mut [T]) -> anyhow::Result<()>
     where
-        I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
-        T: Float + NumCast + AddAssign + std::iter::Sum,
+        I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
+        T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync,
         Self::Item: NumCast,
     {
         // Validate input slice length matches number of rows
@@ -489,8 +487,8 @@ where
 
     fn var_col_masked<I, T>(&self, mask: &[bool]) -> anyhow::Result<Vec<T>>
     where
-        I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
-        T: Float + NumCast + AddAssign + Sum,
+        I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
+        T: Float + NumCast + AddAssign + Sum + Send + Sync,
     {
         // Validate mask length
         if mask.len() < self.nrows() {
@@ -538,8 +536,8 @@ where
 
     fn var_row_masked<I, T>(&self, mask: &[bool]) -> anyhow::Result<Vec<T>>
     where
-        I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
-        T: Float + NumCast + AddAssign + Sum,
+        I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
+        T: Float + NumCast + AddAssign + Sum + Send + Sync
     {
         // Validate mask length
         if mask.len() < self.ncols() {
@@ -591,7 +589,7 @@ impl<Item> MatrixMinMax<Item> for CscMatrix<Item>
 
     fn min_max_col(&self) -> anyhow::Result<(Vec<Item>, Vec<Item>)>
     where
-        Item: NumCast + Copy + PartialOrd + NumericOps,
+        Item: NumCast + Copy + PartialOrd + NumericOps + Send + Sync,
     {
         let mut min: Vec<Item> = vec![Item::max_value(); self.ncols()];
         let mut max: Vec<Item> = vec![Item::min_value(); self.ncols()];
@@ -602,7 +600,7 @@ impl<Item> MatrixMinMax<Item> for CscMatrix<Item>
 
     fn min_max_row(&self) -> anyhow::Result<(Vec<Item>, Vec<Item>)>
     where
-        Item: NumCast + Copy + PartialOrd + NumericOps,
+        Item: NumCast + Copy + PartialOrd + NumericOps + Send + Sync,
     {
         let mut
min: Vec = vec![Item::max_value(); self.nrows()]; let mut max: Vec = vec![Item::min_value(); self.nrows()]; diff --git a/src/sparse/csr.rs b/src/sparse/csr.rs index b605b0c..bb426e9 100644 --- a/src/sparse/csr.rs +++ b/src/sparse/csr.rs @@ -1,7 +1,6 @@ use std::collections::HashMap; -use std::hash::Hash; use std::iter::Sum; -use std::ops::{Add, AddAssign}; +use std::ops::AddAssign; use super::{ BatchMatrixMean, BatchMatrixVariance, MatrixMinMax, MatrixNonZero, MatrixSum, MatrixVariance, @@ -11,19 +10,29 @@ use crate::utils::Normalize; use crate::utils::{BatchIdentifier, Log1P}; use anyhow::{anyhow, Ok}; use nalgebra_sparse::CsrMatrix; -use num_traits::{Float, NumCast, One, PrimInt, Unsigned, Zero}; +use num_traits::{Float, NumCast, PrimInt, Unsigned, Zero}; +use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; +use rayon::slice::ParallelSliceMut; use single_utilities::traits::{FloatOpsTS, NumericOps}; use single_utilities::types::Direction; +const PARALLEL_THRESHOLD: usize = 200_000; +const CHUNK_SIZE: usize = 512; + impl MatrixNonZero for CsrMatrix { fn nonzero_col(&self) -> anyhow::Result> where T: PrimInt + Unsigned + Zero + AddAssign, { let mut result = vec![T::zero(); self.ncols()]; - for &col_index in self.col_indices() { - result[col_index] += T::one(); + let col_indices = self.col_indices(); + + for chunk in col_indices.chunks(CHUNK_SIZE) { + for &col_index in chunk { + result[col_index] += T::one(); + } } + Ok(result) } @@ -31,20 +40,95 @@ impl MatrixNonZero for CsrMatrix { where T: PrimInt + Unsigned + Zero + AddAssign, { - let data = self - .row_offsets() - .windows(2) - .map(|window| { - let diff = window[1] - .checked_sub(window[0]) - .ok_or_else(|| anyhow!("Subtraction overflow")) - .expect("Subtraction overflow"); - T::from(diff) - .ok_or_else(|| anyhow!("Failed to convert to target type")) - .expect("Failed to convert to target type") - }) - .collect(); - Ok(data) + let row_offsets = self.row_offsets(); + let n_rows = self.nrows(); + + // Early return for empty matrix + if n_rows == 0 { + return Ok(Vec::new()); + } + // Estimate total non-zeros using sampling for larger matrices + let estimated_nonzeros = { + let sample_size = n_rows.min(100); + let sample_sum: usize = (0..sample_size) + .map(|i| row_offsets[i + 1] - row_offsets[i]) + .sum(); + (sample_sum * n_rows) / sample_size + }; + + if estimated_nonzeros >= PARALLEL_THRESHOLD { + // Parallel implementation + let n_cores = rayon::current_num_threads(); + let chunk_size = (n_rows / (n_cores * 4)).max(1024); + + // Pre-allocate result + let mut result = vec![T::zero(); n_rows]; + + // Use par_chunks_mut for correct forward iteration + result.rchunks_mut(chunk_size).enumerate().try_for_each( + |(chunk_idx, chunk)| -> anyhow::Result<()> { + let start_idx = chunk_idx * chunk_size; + + for (i, item) in chunk.iter_mut().enumerate() { + let row_idx = start_idx + i; + let count = row_offsets[row_idx + 1] - row_offsets[row_idx]; + *item = T::from(count).ok_or_else(|| { + anyhow::anyhow!("Count {} exceeds target type capacity", count) + })?; + } + Ok(()) + }, + )?; + + Ok(result) + } else { + // Sequential implementation for medium-sized matrices + let mut result = Vec::with_capacity(n_rows); + + let chunks = n_rows / 8; + let remainder = n_rows % 8; + + unsafe { + let offsets_ptr = row_offsets.as_ptr(); + + for chunk_idx in 0..chunks { + let base = chunk_idx * 8; + + // Prefetch on x86_64 + #[cfg(target_arch = "x86_64")] + { + use std::arch::x86_64::_mm_prefetch; + if base + 16 < row_offsets.len() 
{ + _mm_prefetch(offsets_ptr.add(base + 16) as *const i8, 1); + } + } + + for i in 0..8 { + let idx = base + i; + let start = *offsets_ptr.add(idx); + let end = *offsets_ptr.add(idx + 1); + let count = end - start; + + result.push(T::from(count).ok_or_else(|| { + anyhow::anyhow!("Count {} exceeds target type capacity", count) + })?); + } + } + + for i in 0..remainder { + let idx = chunks * 8 + i; + let start = *offsets_ptr.add(idx); + let end = *offsets_ptr.add(idx + 1); + let count = end - start; + + result.push(T::from(count).ok_or_else(|| { + anyhow::anyhow!("Count {} exceeds target type capacity", count) + })?); + } + } + + Ok(result) + } } fn nonzero_col_chunk(&self, reference: &mut [T]) -> anyhow::Result<()> @@ -115,7 +199,6 @@ impl MatrixNonZero for CsrMatrix { where T: PrimInt + Unsigned + Zero + AddAssign, { - // Validate mask length if mask.len() < self.ncols() { return Err(anyhow::anyhow!( "Mask length ({}) is less than number of columns ({})", @@ -124,31 +207,131 @@ impl MatrixNonZero for CsrMatrix { )); } - let mut result = vec![T::zero(); self.nrows()]; + let n_rows = self.nrows(); + let row_offsets = self.row_offsets(); + let col_indices = self.col_indices(); - // Process each row - for row in 0..self.nrows() { - let row_start = self.row_offsets()[row]; - let row_end = self.row_offsets()[row + 1]; + // Early return for empty matrix + if n_rows == 0 { + return Ok(Vec::new()); + } - // Count non-zero elements in this row that are in masked-in columns - for idx in row_start..row_end { - let col = self.col_indices()[idx]; + // Quick check: if very few columns are masked in, sequential might be better + let masked_in_count = mask.iter().filter(|&&m| m).count(); + if masked_in_count == 0 { + // All columns masked out - return zeros + return Ok(vec![T::zero(); n_rows]); + } + + let total_nnz = self.nnz(); + + if total_nnz < PARALLEL_THRESHOLD || masked_in_count < 10 { + let mut result = vec![T::zero(); n_rows]; + + // If most columns are masked in, use branch-free counting + if masked_in_count > self.ncols() * 3 / 4 { + // Branch-free version for mostly-true masks + unsafe { + let offsets_ptr = row_offsets.as_ptr(); + let indices_ptr = col_indices.as_ptr(); + let mask_ptr = mask.as_ptr(); + + for row in 0..n_rows { + let start = *offsets_ptr.add(row); + let end = *offsets_ptr.add(row + 1); + let mut count = T::zero(); + + for idx in start..end { + let col = *indices_ptr.add(idx); + // Branch-free increment: convert bool to 0 or 1 + let increment = *mask_ptr.add(col) as usize; + if increment > 0 { + count = count + T::one(); + } + } - // Skip this column if masked out - if !mask[col] { - continue; + result[row] = count; + } } + } else { + // Standard version with branching for sparse masks + for row in 0..n_rows { + let row_start = row_offsets[row]; + let row_end = row_offsets[row + 1]; + let mut count = T::zero(); + + // Unroll by 4 for better performance + let chunk_end = row_start + ((row_end - row_start) / 4) * 4; + + for idx in (row_start..chunk_end).step_by(4) { + // Process 4 elements at a time + if mask[col_indices[idx]] { + count = count + T::one(); + } + if mask[col_indices[idx + 1]] { + count = count + T::one(); + } + if mask[col_indices[idx + 2]] { + count = count + T::one(); + } + if mask[col_indices[idx + 3]] { + count = count + T::one(); + } + } + + // Handle remainder + for idx in chunk_end..row_end { + if mask[col_indices[idx]] { + count = count + T::one(); + } + } - result[row] += T::one(); + result[row] = count; + } } - } - Ok(result) + Ok(result) + } else { + let 
chunk_size = (n_rows / (rayon::current_num_threads() * 4)).max(1024); + + // Pre-allocate result + let mut result = vec![T::zero(); n_rows]; + + // Process rows in parallel chunks + result + .rchunks_mut(chunk_size) + .enumerate() + .for_each(|(chunk_idx, result_chunk)| { + let start_row = chunk_idx * chunk_size; + + for (i, count) in result_chunk.iter_mut().enumerate() { + let row = start_row + i; + if row >= n_rows { + break; + } + + let row_start = row_offsets[row]; + let row_end = row_offsets[row + 1]; + let mut local_count = T::zero(); + + // For better cache locality, process in smaller sub-chunks + for idx in row_start..row_end { + let col = col_indices[idx]; + if mask[col] { + local_count += T::one(); + } + } + + *count = local_count; + } + }); + + Ok(result) + } } } -impl MatrixSum for CsrMatrix { +impl MatrixSum for CsrMatrix { type Item = M; fn sum_col(&self) -> anyhow::Result> @@ -179,56 +362,135 @@ impl MatrixSum for CsrMatrix { fn sum_row(&self) -> anyhow::Result> where - T: Float + NumCast + AddAssign + std::iter::Sum, + T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync, Self::Item: NumCast, { let nrows = self.nrows(); - let mut result = vec![T::zero(); nrows]; let values = self.values(); let row_offsets = self.row_offsets(); - // Process in chunks of 4 rows when possible - let chunk_size = 4; - let chunks = nrows / chunk_size; - let remainder = nrows % chunk_size; - - // Process chunks - for chunk in 0..chunks { - let base = chunk * chunk_size; - let mut sums = [M::zero(); 4]; - - // Process 4 rows at once to improve instruction-level parallelism - (0..chunk_size).enumerate().for_each(|(i, offset)| { - let row = base + offset; - let start = row_offsets[row]; - let end = row_offsets[row + 1]; - - // Direct sum in original type - for &val in &values[start..end] { - sums[i] += val; + if nrows == 0 { + return Ok(Vec::new()); + } + + let total_nnz = values.len(); + + if total_nnz >= PARALLEL_THRESHOLD { + let sums: Vec = (0..nrows) + .into_par_iter() + .map(|row| { + let start = row_offsets[row]; + let end = row_offsets[row + 1]; + + if end - start > 16 { + let mut sum = M::zero(); + let chunk_size = 8; + let chunks = (end - start) / chunk_size; + let remainder = (end - start) % chunk_size; + + for c in 0..chunks { + let base = start + c * chunk_size; + let mut chunk_sum = M::zero(); + + chunk_sum = chunk_sum + values[base]; + chunk_sum = chunk_sum + values[base + 1]; + chunk_sum = chunk_sum + values[base + 2]; + chunk_sum = chunk_sum + values[base + 3]; + chunk_sum = chunk_sum + values[base + 4]; + chunk_sum = chunk_sum + values[base + 5]; + chunk_sum = chunk_sum + values[base + 6]; + chunk_sum = chunk_sum + values[base + 7]; + + sum = sum + chunk_sum; + } + + for i in 0..remainder { + sum = sum + values[start + chunks * chunk_size + i]; + } + + T::from(sum).unwrap() + } else { + let mut sum = M::zero(); + for i in start..end { + sum = sum + values[i]; + } + T::from(sum).unwrap() + } + }) + .collect(); + Ok(sums) + } else { + let mut result = vec![T::zero(); nrows]; + + let chunk_size = 8; + let chunks = nrows / chunk_size; + + unsafe { + let values_ptr = values.as_ptr(); + let offsets_ptr = row_offsets.as_ptr(); + let result_ptr = result.as_mut_ptr(); + + for chunk in 0..chunks { + let base = chunk * chunk_size; + + #[cfg(target_arch = "x86_64")] + { + use std::arch::x86_64::_mm_prefetch; + if base + chunk_size < nrows { + let next_start = *offsets_ptr.add(base + chunk_size); + if next_start < values.len() { + _mm_prefetch(values_ptr.add(next_start) as *const i8, 
1); + } + } + } + + let mut sums = [M::zero(); 8]; + + for i in 0..8 { + let row = base + i; + let start = *offsets_ptr.add(row); + let end = *offsets_ptr.add(row + 1); + + let mut sum = M::zero(); + + let inner_chunks = (end - start) / 4; + let inner_remainder = (end - start) % 4; + + for j in 0..inner_chunks { + let idx = start + j * 4; + sum = sum + *values_ptr.add(idx); + sum = sum + *values_ptr.add(idx + 1); + sum = sum + *values_ptr.add(idx + 2); + sum = sum + *values_ptr.add(idx + 3); + } + + for j in 0..inner_remainder { + sum = sum + *values_ptr.add(start + inner_chunks * 4 + j); + } + + sums[i] = sum; + } + + for i in 0..8 { + *result_ptr.add(base + i) = T::from(sums[i]).unwrap(); + } } - }); - // Convert results for the chunk - sums.iter().enumerate().for_each(|(i, &sum)| { - result[base + i] = T::from(sum).unwrap(); - }); - } + for row in (chunks * chunk_size)..nrows { + let start = *offsets_ptr.add(row); + let end = *offsets_ptr.add(row + 1); + let mut sum = M::zero(); - // Handle remaining rows - let base = chunks * chunk_size; - for row in base..nrows { - let start = row_offsets[row]; - let end = row_offsets[row + 1]; - let mut sum = M::zero(); + for idx in start..end { + sum = sum + *values_ptr.add(idx); + } - for &val in &values[start..end] { - sum += val; + *result_ptr.add(row) = T::from(sum).unwrap(); + } } - result[row] = T::from(sum).unwrap(); - } - Ok(result) + Ok(result) + } } fn sum_col_chunk(&self, reference: &mut [T]) -> anyhow::Result<()> @@ -293,9 +555,8 @@ impl MatrixSum for CsrMatrix { fn sum_row_masked(&self, mask: &[bool]) -> anyhow::Result> where - T: Float + NumCast + AddAssign + Sum, + T: Float + NumCast + AddAssign + Sum + Send + Sync, { - // Validate mask length if mask.len() < self.ncols() { return Err(anyhow::anyhow!( "Mask length ({}) is less than number of columns ({})", @@ -304,28 +565,117 @@ impl MatrixSum for CsrMatrix { )); } - let mut result = vec![T::zero(); self.nrows()]; + let n_rows = self.nrows(); + let row_offsets = self.row_offsets(); + let col_indices = self.col_indices(); + let values = self.values(); - // Process each row - for row in 0..self.nrows() { - let row_start = self.row_offsets()[row]; - let row_end = self.row_offsets()[row + 1]; + if n_rows == 0 { + return Ok(Vec::new()); + } - // Process all non-zero elements in this row - for idx in row_start..row_end { - let col = self.col_indices()[idx]; + let masked_in_count = mask.iter().filter(|&&m| m).count(); + if masked_in_count == 0 { + return Ok(vec![T::zero(); n_rows]); + } - // Skip this column if masked out - if !mask[col] { - continue; + let total_nnz = self.nnz(); + const PARALLEL_THRESHOLD: usize = 10000; + + if total_nnz < PARALLEL_THRESHOLD || masked_in_count < 10 { + let mut result = vec![T::zero(); n_rows]; + + if masked_in_count > self.ncols() * 3 / 4 { + unsafe { + let offsets_ptr = row_offsets.as_ptr(); + let indices_ptr = col_indices.as_ptr(); + let values_ptr = values.as_ptr(); + let mask_ptr = mask.as_ptr(); + + for row in 0..n_rows { + let start = *offsets_ptr.add(row); + let end = *offsets_ptr.add(row + 1); + let mut sum = T::zero(); + + for idx in start..end { + let col = *indices_ptr.add(idx); + let increment = *mask_ptr.add(col) as usize; + if increment > 0 { + sum += T::from(*values_ptr.add(idx)).unwrap(); + } + } + + result[row] = sum; + } } + } else { + for row in 0..n_rows { + let row_start = row_offsets[row]; + let row_end = row_offsets[row + 1]; + let mut sum = T::zero(); - let value = T::from(self.values()[idx]).unwrap(); - result[row] += value; + let 
chunk_end = row_start + ((row_end - row_start) / 4) * 4; + + for idx in (row_start..chunk_end).step_by(4) { + if mask[col_indices[idx]] { + sum += T::from(values[idx]).unwrap(); + } + if mask[col_indices[idx + 1]] { + sum += T::from(values[idx + 1]).unwrap(); + } + if mask[col_indices[idx + 2]] { + sum += T::from(values[idx + 2]).unwrap(); + } + if mask[col_indices[idx + 3]] { + sum += T::from(values[idx + 3]).unwrap(); + } + } + + for idx in chunk_end..row_end { + if mask[col_indices[idx]] { + sum += T::from(values[idx]).unwrap(); + } + } + + result[row] = sum; + } } - } - Ok(result) + Ok(result) + } else { + let chunk_size = (n_rows / (rayon::current_num_threads() * 4)).max(1024); + + let mut result = vec![T::zero(); n_rows]; + + result + .par_chunks_mut(chunk_size) + .enumerate() + .for_each(|(chunk_idx, result_chunk)| { + let start_row = chunk_idx * chunk_size; + + for (i, sum) in result_chunk.iter_mut().enumerate() { + let row = start_row + i; + if row >= n_rows { + break; + } + + let row_start = row_offsets[row]; + let row_end = row_offsets[row + 1]; + let mut local_sum = T::zero(); + + for idx in row_start..row_end { + let col = col_indices[idx]; + if mask[col] { + local_sum += T::from(values[idx]).unwrap(); + } + } + + *sum = local_sum; + } + }); + + Ok(result) + } } fn sum_col_squared(&self) -> anyhow::Result> @@ -367,7 +717,7 @@ where fn var_col(&self) -> anyhow::Result> where I: PrimInt + Unsigned + Zero + AddAssign + Into, - T: Float + NumCast + AddAssign + Sum, + T: Float + NumCast + AddAssign + Sum + Send + Sync, { let sum: Vec = self.sum_col()?; let squared_sums: Vec = self.sum_col_squared()?; @@ -392,7 +742,7 @@ where fn var_row(&self) -> anyhow::Result> where I: PrimInt + Unsigned + Zero + AddAssign + Into, - T: Float + NumCast + AddAssign + Sum, + T: Float + NumCast + AddAssign + Sum + Send + Sync, Self::Item: NumCast, { let sum: Vec = self.sum_row()?; @@ -417,8 +767,8 @@ where /// Calculate column-wise variance and store results in the provided slice fn var_col_chunk(&self, reference: &mut [T]) -> anyhow::Result<()> where - I: PrimInt + Unsigned + Zero + AddAssign + Into, - T: Float + NumCast + AddAssign + Sum, + I: PrimInt + Unsigned + Zero + AddAssign + Into + Send + Sync, + T: Float + NumCast + AddAssign + Sum + Send + Sync, Self::Item: NumCast, { let ncols = self.ncols(); @@ -456,8 +806,8 @@ where fn var_row_chunk(&self, reference: &mut [T]) -> anyhow::Result<()> where - I: PrimInt + Unsigned + Zero + AddAssign + Into, - T: Float + NumCast + AddAssign + Sum, + I: PrimInt + Unsigned + Zero + AddAssign + Into + Send + Sync, + T: Float + NumCast + AddAssign + Sum + Send + Sync, Self::Item: NumCast, { let nrows = self.nrows(); @@ -504,8 +854,8 @@ where fn var_col_masked(&self, mask: &[bool]) -> anyhow::Result> where - I: PrimInt + Unsigned + Zero + AddAssign + Into, - T: Float + NumCast + AddAssign + Sum, + I: PrimInt + Unsigned + Zero + AddAssign + Into + Send + Sync, + T: Float + NumCast + AddAssign + Sum + Send + Sync, { // Validate mask length if mask.len() < self.nrows() { @@ -553,8 +903,8 @@ where fn var_row_masked(&self, mask: &[bool]) -> anyhow::Result> where - I: PrimInt + Unsigned + Zero + AddAssign + Into, - T: Float + NumCast + AddAssign + Sum, + I: PrimInt + Unsigned + Zero + AddAssign + Into + Send + Sync, + T: Float + NumCast + AddAssign + Sum + Send + Sync, { // Validate mask length if mask.len() < self.ncols() { @@ -606,7 +956,7 @@ impl MatrixMinMax for CsrMatrix fn min_max_col(&self) -> anyhow::Result<(Vec, Vec)> where - Item: NumCast + Copy + 
PartialOrd + NumericOps,
+        Item: NumCast + Copy + PartialOrd + NumericOps + Send + Sync,
     {
         let mut min: Vec<Item> = vec![Item::max_value(); self.ncols()];
         let mut max: Vec<Item> = vec![Item::min_value(); self.ncols()];
@@ -617,7 +967,7 @@ impl<Item> MatrixMinMax<Item> for CsrMatrix<Item>
 
     fn min_max_row(&self) -> anyhow::Result<(Vec<Item>, Vec<Item>)>
     where
-        Item: NumCast + Copy + PartialOrd + NumericOps,
+        Item: NumCast + Copy + PartialOrd + NumericOps + Send + Sync,
     {
         let mut min: Vec<Item> = vec![Item::max_value(); self.nrows()];
         let mut max: Vec<Item> = vec![Item::min_value(); self.nrows()];
@@ -628,7 +978,7 @@ impl<Item> MatrixMinMax<Item> for CsrMatrix<Item>
 
     fn min_max_col_chunk(&self, reference: (&mut [Item], &mut [Item])) -> anyhow::Result<()>
     where
-        Item: NumCast + Copy + PartialOrd + NumericOps,
+        Item: NumCast + Copy + PartialOrd + NumericOps + Send + Sync,
     {
         let (min_vals, max_vals) = reference;
 
@@ -1039,7 +1389,8 @@ impl<M: NumCast + Copy> MatrixNTop for CsrMatrix<M> {
 
     fn sum_row_n_top<T>(&self, n: usize) -> anyhow::Result<Vec<T>>
     where
-        T: Float + NumCast + AddAssign + Sum {
+        T: Float + NumCast + AddAssign + Sum,
+    {
         let mut result = vec![T::zero(); self.nrows()];
 
         for row_idx in 0..self.nrows() {
diff --git a/src/sparse/mod.rs b/src/sparse/mod.rs
index d89925c..0dc2e8b 100644
--- a/src/sparse/mod.rs
+++ b/src/sparse/mod.rs
@@ -1,10 +1,9 @@
 use std::collections::HashMap;
-use std::hash::Hash;
 use std::ops::AddAssign;
 
+use crate::utils::BatchIdentifier;
 use num_traits::{Float, NumCast, PrimInt, Unsigned, Zero};
 use single_utilities::traits::NumericOps;
-use crate::utils::BatchIdentifier;
 
 pub mod csc;
 pub mod csr;
@@ -12,29 +11,29 @@ pub trait MatrixNonZero {
     fn nonzero_col<T>(&self) -> anyhow::Result<Vec<T>>
     where
-        T: PrimInt + Unsigned + Zero + AddAssign;
+        T: PrimInt + Unsigned + Zero + AddAssign + Send + Sync;
 
     fn nonzero_row<T>(&self) -> anyhow::Result<Vec<T>>
     where
-        T: PrimInt + Unsigned + Zero + AddAssign;
+        T: PrimInt + Unsigned + Zero + AddAssign + Send + Sync;
 
     fn nonzero_col_chunk<T>(&self, reference: &mut [T]) -> anyhow::Result<()>
     where
-        T: PrimInt + Unsigned + Zero + AddAssign;
+        T: PrimInt + Unsigned + Zero + AddAssign + Send + Sync;
 
     fn nonzero_row_chunk<T>(&self, reference: &mut [T]) -> anyhow::Result<()>
     where
-        T: PrimInt + Unsigned + Zero + AddAssign;
+        T: PrimInt + Unsigned + Zero + AddAssign + Send + Sync;
 
     /// Calculate masked non-zero counts for columns
     fn nonzero_col_masked<T>(&self, mask: &[bool]) -> anyhow::Result<Vec<T>>
     where
-        T: PrimInt + Unsigned + Zero + AddAssign;
+        T: PrimInt + Unsigned + Zero + AddAssign + Send + Sync;
 
     /// Calculate masked non-zero counts for rows
     fn nonzero_row_masked<T>(&self, mask: &[bool]) -> anyhow::Result<Vec<T>>
     where
-        T: PrimInt + Unsigned + Zero + AddAssign;
+        T: PrimInt + Unsigned + Zero + AddAssign + Send + Sync;
 }
 
 pub trait MatrixSum {
@@ -42,36 +41,36 @@
     fn sum_col<T>(&self) -> anyhow::Result<Vec<T>>
     where
-        T: Float + num_traits::NumCast + AddAssign + std::iter::Sum;
+        T: Float + num_traits::NumCast + AddAssign + std::iter::Sum + Send + Sync;
 
     fn sum_row<T>(&self) -> anyhow::Result<Vec<T>>
     where
-        T: Float + num_traits::NumCast + AddAssign + std::iter::Sum;
+        T: Float + num_traits::NumCast + AddAssign + std::iter::Sum + Send + Sync;
 
     fn sum_col_chunk<T>(&self, reference: &mut [T]) -> anyhow::Result<()>
     where
-        T: Float + num_traits::NumCast + AddAssign + std::iter::Sum;
+        T: Float + num_traits::NumCast + AddAssign + std::iter::Sum + Send + Sync;
 
     fn sum_row_chunk<T>(&self, reference: &mut [T]) -> anyhow::Result<()>
     where
-        T: Float + num_traits::NumCast + AddAssign + std::iter::Sum;
+        T: Float + num_traits::NumCast + AddAssign + std::iter::Sum + Send + Sync;
 
     fn sum_col_masked<T>(&self, mask: &[bool]) -> anyhow::Result<Vec<T>>
     where
-        T: Float + NumCast + AddAssign + std::iter::Sum;
+        T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync;
 
     /// Calculate masked sum for rows
     fn sum_row_masked<T>(&self, mask: &[bool]) -> anyhow::Result<Vec<T>>
     where
-        T: Float + NumCast + AddAssign + std::iter::Sum;
-
-    fn sum_col_squared<T>(&self) -> anyhow::Result<Vec<T>>
+        T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync;
+
+    fn sum_col_squared<T>(&self) -> anyhow::Result<Vec<T>>
     where
-        T: Float + NumCast + AddAssign + std::iter::Sum;
+        T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync;
 
     fn sum_row_squared<T>(&self) -> anyhow::Result<Vec<T>>
     where
-        T: Float + NumCast + AddAssign + std::iter::Sum;
+        T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync;
 }
 
 pub trait MatrixVariance {
@@ -79,35 +78,35 @@
     fn var_col<I, T>(&self) -> anyhow::Result<Vec<T>>
     where
-        I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
-        T: Float + num_traits::NumCast + AddAssign + std::iter::Sum;
+        I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
+        T: Float + num_traits::NumCast + AddAssign + std::iter::Sum + Send + Sync;
 
     fn var_row<I, T>(&self) -> anyhow::Result<Vec<T>>
     where
-        I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
-        T: Float + num_traits::NumCast + AddAssign + std::iter::Sum;
+        I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
+        T: Float + num_traits::NumCast + AddAssign + std::iter::Sum + Send + Sync;
 
     fn var_col_chunk<I, T>(&self, reference: &mut [T]) -> anyhow::Result<()>
     where
-        I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
-        T: Float + num_traits::NumCast + AddAssign + std::iter::Sum;
+        I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
+        T: Float + num_traits::NumCast + AddAssign + std::iter::Sum + Send + Sync;
 
     fn var_row_chunk<I, T>(&self, reference: &mut [T]) -> anyhow::Result<()>
     where
-        I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
-        T: Float + num_traits::NumCast + AddAssign + std::iter::Sum;
+        I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
+        T: Float + num_traits::NumCast + AddAssign + std::iter::Sum + Send + Sync;
 
     /// Calculate masked variance for columns
     fn var_col_masked<I, T>(&self, mask: &[bool]) -> anyhow::Result<Vec<T>>
     where
-        I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
-        T: Float + NumCast + AddAssign + std::iter::Sum;
+        I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
+        T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync;
 
     /// Calculate masked variance for rows
     fn var_row_masked<I, T>(&self, mask: &[bool]) -> anyhow::Result<Vec<T>>
     where
-        I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
-        T: Float + NumCast + AddAssign + std::iter::Sum;
+        I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
+        T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync;
 }
 
 pub trait MatrixMinMax<Item> {
@@ -115,19 +114,19 @@
     fn min_max_col(&self) -> anyhow::Result<(Vec<Item>, Vec<Item>)>
     where
-        Item: NumCast + Copy + PartialOrd + NumericOps;
+        Item: NumCast + Copy + PartialOrd + NumericOps + Send + Sync;
 
     fn min_max_row(&self) -> anyhow::Result<(Vec<Item>, Vec<Item>)>
     where
-        Item: NumCast + Copy + PartialOrd + NumericOps;
+        Item: NumCast + Copy + PartialOrd + NumericOps + Send + Sync;
 
     fn min_max_col_chunk(&self, reference: (&mut [Item], &mut [Item])) -> anyhow::Result<()>
     where
-        Item: NumCast + Copy + PartialOrd + NumericOps;
+        Item: NumCast + Copy + PartialOrd + NumericOps + Send + Sync;
 
     fn min_max_row_chunk(&self, reference: (&mut [Item], &mut [Item])) -> anyhow::Result<()>
     where
-        Item: NumCast + Copy + PartialOrd + NumericOps;
+        Item: NumCast + Copy + PartialOrd + NumericOps + Send + Sync;
 }
 
 pub trait BatchMatrixVariance {
@@ -137,14 +136,14 @@
     fn var_batch_row<I, T, B>(&self, batches: &[B]) -> anyhow::Result<HashMap<B, Vec<T>>>
     where
         I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
-        T: Float + NumCast + AddAssign + std::iter::Sum,
+        T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync,
         B: BatchIdentifier;
 
     /// Calculate column-wise variance for each batch
     fn var_batch_col<I, T, B>(&self, batches: &[B]) -> anyhow::Result<HashMap<B, Vec<T>>>
     where
         I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
-        T: Float + NumCast + AddAssign + std::iter::Sum,
+        T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync,
         B: BatchIdentifier;
 }
 
 pub trait BatchMatrixMean {
@@ -154,13 +153,13 @@
     /// Calculate row-wise mean for each batch
     fn mean_batch_row<T, B>(&self, batches: &[B]) -> anyhow::Result<HashMap<B, Vec<T>>>
     where
-        T: Float + NumCast + AddAssign + std::iter::Sum,
+        T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync,
         B: BatchIdentifier;
 
     /// Calculate column-wise mean for each batch
     fn mean_batch_col<T, B>(&self, batches: &[B]) -> anyhow::Result<HashMap<B, Vec<T>>>
     where
-        T: Float + NumCast + AddAssign + std::iter::Sum,
+        T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync,
         B: BatchIdentifier;
 }
 
@@ -169,7 +168,5 @@ pub trait MatrixNTop {
 
     fn sum_row_n_top<T>(&self, n: usize) -> anyhow::Result<Vec<T>>
     where
-        T: Float + NumCast + AddAssign + std::iter::Sum;
+        T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync;
 }
-
-

From 8914b476e41f18086fad687218457cb2015fec81 Mon Sep 17 00:00:00 2001
From: Ian
Date: Thu, 3 Jul 2025 23:08:22 +0000
Subject: [PATCH 4/5] optimized functions, version bump

---
 Cargo.lock        |   2 +-
 Cargo.toml        |   2 +-
 src/sparse/csr.rs | 452 ++++++++++++++++++++++++++++++++++++++++++----
 3 files changed, 416 insertions(+), 40 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 4f89e76..5e5e7be 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1753,7 +1753,7 @@ dependencies = [
 
 [[package]]
 name = "single_algebra"
-version = "0.8.1"
+version = "0.8.2"
 dependencies = [
  "anyhow",
  "approx",
diff --git a/Cargo.toml b/Cargo.toml
index ebed76e..d4a3fef 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "single_algebra"
-version = "0.8.1"
+version = "0.8.2"
 edition = "2021"
 license-file = "LICENSE.md"
 description = "A linear algebra convenience library for the single-rust library. Can be used externally as well."
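
For reference, the public traits in src/sparse/mod.rs keep the same call signatures across this series; only their implementations change. A minimal usage sketch of that API (assuming the traits are exported as single_algebra::sparse::{MatrixNonZero, MatrixSum, MatrixNTop}; the matrix data below is illustrative only):

    use anyhow::Result;
    use nalgebra_sparse::CsrMatrix;
    use single_algebra::sparse::{MatrixNTop, MatrixNonZero, MatrixSum};

    fn main() -> Result<()> {
        // 2 x 3 matrix with 4 stored values:
        // [1.0  .   2.0]
        // [ .   3.0 4.0]
        let matrix = CsrMatrix::try_from_csr_data(
            2,                            // number of rows
            3,                            // number of columns
            vec![0, 2, 4],                // row offsets
            vec![0, 2, 1, 2],             // column indices
            vec![1.0_f64, 2.0, 3.0, 4.0], // stored values
        )
        .map_err(|e| anyhow::anyhow!("invalid CSR data: {e}"))?;

        let row_sums: Vec<f64> = matrix.sum_row()?;        // [3.0, 7.0]
        let nnz_per_row: Vec<u32> = matrix.nonzero_row()?; // [2, 2]
        let top1: Vec<f64> = matrix.sum_row_n_top(1)?;     // [2.0, 4.0]

        assert_eq!(row_sums, vec![3.0, 7.0]);
        assert_eq!(nnz_per_row, vec![2, 2]);
        assert_eq!(top1, vec![2.0, 4.0]);
        Ok(())
    }

Because each method is generic over its output type, the caller chooses the accumulator precision (here f64 for sums, u32 for counts) independently of the stored element type, which is also why the Send + Sync bounds added in PATCH 3/5 land on those type parameters rather than on the matrix itself.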
diff --git a/src/sparse/csr.rs b/src/sparse/csr.rs index bb426e9..a66dbc1 100644 --- a/src/sparse/csr.rs +++ b/src/sparse/csr.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use std::iter::Sum; use std::ops::AddAssign; +use std::sync::Mutex; use super::{ BatchMatrixMean, BatchMatrixVariance, MatrixMinMax, MatrixNonZero, MatrixSum, MatrixVariance, @@ -22,18 +23,132 @@ const CHUNK_SIZE: usize = 512; impl MatrixNonZero for CsrMatrix { fn nonzero_col(&self) -> anyhow::Result> where - T: PrimInt + Unsigned + Zero + AddAssign, + T: PrimInt + Unsigned + Zero + AddAssign + Send + Sync, { - let mut result = vec![T::zero(); self.ncols()]; + let n_cols = self.ncols(); let col_indices = self.col_indices(); + let total_nnz = col_indices.len(); - for chunk in col_indices.chunks(CHUNK_SIZE) { - for &col_index in chunk { - result[col_index] += T::one(); - } + if total_nnz == 0 || n_cols == 0 { + return Ok(vec![T::zero(); n_cols]); } - Ok(result) + if total_nnz < PARALLEL_THRESHOLD { + let mut result = vec![T::zero(); n_cols]; + + for chunk_start in (0..total_nnz).step_by(CHUNK_SIZE) { + let chunk_end = (chunk_start + CHUNK_SIZE).min(total_nnz); + + #[cfg(target_arch = "x86_64")] + { + use std::arch::x86_64::_mm_prefetch; + if chunk_end < total_nnz { + unsafe { + let next_start = chunk_end; + let prefetch_end = (next_start + 64).min(total_nnz); + for i in (next_start..prefetch_end).step_by(8) { + _mm_prefetch(col_indices.as_ptr().add(i) as *const i8, 1); + } + } + } + } + + let unroll_end = chunk_start + ((chunk_end - chunk_start) / 8) * 8; + + unsafe { + let col_ptr = col_indices.as_ptr().add(chunk_start); + let result_ptr = result.as_mut_ptr(); + + for i in (0..(unroll_end - chunk_start)).step_by(8) { + let col0 = *col_ptr.add(i); + let col1 = *col_ptr.add(i + 1); + let col2 = *col_ptr.add(i + 2); + let col3 = *col_ptr.add(i + 3); + let col4 = *col_ptr.add(i + 4); + let col5 = *col_ptr.add(i + 5); + let col6 = *col_ptr.add(i + 6); + let col7 = *col_ptr.add(i + 7); + + *result_ptr.add(col0) = *result_ptr.add(col0) + T::one(); + *result_ptr.add(col1) = *result_ptr.add(col1) + T::one(); + *result_ptr.add(col2) = *result_ptr.add(col2) + T::one(); + *result_ptr.add(col3) = *result_ptr.add(col3) + T::one(); + *result_ptr.add(col4) = *result_ptr.add(col4) + T::one(); + *result_ptr.add(col5) = *result_ptr.add(col5) + T::one(); + *result_ptr.add(col6) = *result_ptr.add(col6) + T::one(); + *result_ptr.add(col7) = *result_ptr.add(col7) + T::one(); + } + } + + for idx in unroll_end..chunk_end { + result[col_indices[idx]] += T::one(); + } + } + + Ok(result) + } else { + let n_threads = rayon::current_num_threads(); + + let thread_results: Vec>> = (0..n_threads) + .map(|_| Mutex::new(vec![T::zero(); n_cols])) + .collect(); + + let chunk_size = (total_nnz / (n_threads * 8)).max(8192); + + (0..total_nnz) + .into_par_iter() + .chunks(chunk_size) + .enumerate() + .for_each(|(thread_id, chunk)| { + let thread_id = thread_id % n_threads; + let mut local_result = vec![T::zero(); n_cols]; + + for idx_group in chunk.chunks(4) { + match idx_group.len() { + 4 => unsafe { + let idx0 = idx_group[0]; + let idx1 = idx_group[1]; + let idx2 = idx_group[2]; + let idx3 = idx_group[3]; + + if idx3 < total_nnz { + let col0 = *col_indices.get_unchecked(idx0); + let col1 = *col_indices.get_unchecked(idx1); + let col2 = *col_indices.get_unchecked(idx2); + let col3 = *col_indices.get_unchecked(idx3); + + *local_result.get_unchecked_mut(col0) += T::one(); + *local_result.get_unchecked_mut(col1) += T::one(); + 
*local_result.get_unchecked_mut(col2) += T::one(); + *local_result.get_unchecked_mut(col3) += T::one(); + } + }, + _ => { + for &idx in idx_group { + if idx < total_nnz { + local_result[col_indices[idx]] += T::one(); + } + } + } + } + } + + let mut thread_result = thread_results[thread_id].lock().unwrap(); + for col in 0..n_cols { + thread_result[col] += local_result[col]; + } + }); + + let mut final_result = vec![T::zero(); n_cols]; + for thread_result in thread_results { + let result = thread_result.into_inner().unwrap(); + for col in 0..n_cols { + final_result[col] += result[col]; + } + } + + Ok(final_result) + } } fn nonzero_row(&self) -> anyhow::Result> @@ -336,28 +451,152 @@ impl MatrixSum for CsrMatrix { fn sum_col(&self) -> anyhow::Result> where - T: Float + NumCast + AddAssign + std::iter::Sum, + T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync, Self::Item: NumCast, { - let mut result = vec![T::zero(); self.ncols()]; + let n_cols = self.ncols(); let col_indices = self.col_indices(); let values = self.values(); + let total_nnz = values.len(); - // Process values directly in chunks to better utilize cache - const CHUNK_SIZE: usize = 256; - for chunk_start in (0..values.len()).step_by(CHUNK_SIZE) { - let chunk_end = (chunk_start + CHUNK_SIZE).min(values.len()); + if total_nnz == 0 || n_cols == 0 { + return Ok(vec![T::zero(); n_cols]); + } - // Direct accumulation without temporary storage - for (&col_idx, &value) in col_indices[chunk_start..chunk_end] - .iter() - .zip(&values[chunk_start..chunk_end]) - { - result[col_idx] += T::from(value).unwrap(); + if total_nnz < PARALLEL_THRESHOLD { + let mut result = vec![T::zero(); n_cols]; + + const CHUNK_SIZE: usize = 512; + + for chunk_start in (0..total_nnz).step_by(CHUNK_SIZE) { + let chunk_end = (chunk_start + CHUNK_SIZE).min(total_nnz); + + #[cfg(target_arch = "x86_64")] + { + use std::arch::x86_64::_mm_prefetch; + if chunk_end < total_nnz { + unsafe { + let next_start = chunk_end; + let prefetch_end = (next_start + 64).min(total_nnz); + for i in (next_start..prefetch_end).step_by(8) { + _mm_prefetch(col_indices.as_ptr().add(i) as *const i8, 1); + _mm_prefetch(values.as_ptr().add(i) as *const i8, 1); + } + } + } + } + + let unroll_end = chunk_start + ((chunk_end - chunk_start) / 8) * 8; + + unsafe { + let col_ptr = col_indices.as_ptr().add(chunk_start); + let val_ptr = values.as_ptr().add(chunk_start); + let result_ptr = result.as_mut_ptr(); + + for i in (0..(unroll_end - chunk_start)).step_by(8) { + let col0 = *col_ptr.add(i); + let col1 = *col_ptr.add(i + 1); + let col2 = *col_ptr.add(i + 2); + let col3 = *col_ptr.add(i + 3); + let col4 = *col_ptr.add(i + 4); + let col5 = *col_ptr.add(i + 5); + let col6 = *col_ptr.add(i + 6); + let col7 = *col_ptr.add(i + 7); + + let val0 = T::from(*val_ptr.add(i)).unwrap(); + let val1 = T::from(*val_ptr.add(i + 1)).unwrap(); + let val2 = T::from(*val_ptr.add(i + 2)).unwrap(); + let val3 = T::from(*val_ptr.add(i + 3)).unwrap(); + let val4 = T::from(*val_ptr.add(i + 4)).unwrap(); + let val5 = T::from(*val_ptr.add(i + 5)).unwrap(); + let val6 = T::from(*val_ptr.add(i + 6)).unwrap(); + let val7 = T::from(*val_ptr.add(i + 7)).unwrap(); + + *result_ptr.add(col0) = *result_ptr.add(col0) + val0; + *result_ptr.add(col1) = *result_ptr.add(col1) + val1; + *result_ptr.add(col2) = *result_ptr.add(col2) + val2; + *result_ptr.add(col3) = *result_ptr.add(col3) + val3; + *result_ptr.add(col4) = *result_ptr.add(col4) + val4; + *result_ptr.add(col5) = *result_ptr.add(col5) + val5; + *result_ptr.add(col6) = 
*result_ptr.add(col6) + val6; + *result_ptr.add(col7) = *result_ptr.add(col7) + val7; + } + } + + for idx in unroll_end..chunk_end { + result[col_indices[idx]] += T::from(values[idx]).unwrap(); + } } - } - Ok(result) + Ok(result) + } else { + let n_threads = rayon::current_num_threads(); + + let thread_results: Vec>> = (0..n_threads) + .map(|_| Mutex::new(vec![T::zero(); n_cols])) + .collect(); + + let chunk_size = (total_nnz / (n_threads * 8)).max(8192); + + (0..total_nnz) + .into_par_iter() + .chunks(chunk_size) + .enumerate() + .for_each(|(thread_id, chunk)| { + let thread_id = thread_id % n_threads; + let mut local_result = vec![T::zero(); n_cols]; + + for idx in chunk { + if idx >= total_nnz { + break; + } + + if idx + 3 < total_nnz { + let remaining = total_nnz - idx; + if remaining >= 4 { + unsafe { + let col0 = *col_indices.get_unchecked(idx); + let col1 = *col_indices.get_unchecked(idx + 1); + let col2 = *col_indices.get_unchecked(idx + 2); + let col3 = *col_indices.get_unchecked(idx + 3); + + let val0 = T::from(*values.get_unchecked(idx)).unwrap(); + let val1 = T::from(*values.get_unchecked(idx + 1)).unwrap(); + let val2 = T::from(*values.get_unchecked(idx + 2)).unwrap(); + let val3 = T::from(*values.get_unchecked(idx + 3)).unwrap(); + + *local_result.get_unchecked_mut(col0) += val0; + *local_result.get_unchecked_mut(col1) += val1; + *local_result.get_unchecked_mut(col2) += val2; + *local_result.get_unchecked_mut(col3) += val3; + } + } else { + for i in 0..remaining { + local_result[col_indices[idx + i]] += + T::from(values[idx + i]).unwrap(); + } + } + } else { + local_result[col_indices[idx]] += T::from(values[idx]).unwrap(); + } + } + + let mut thread_result = thread_results[thread_id].lock().unwrap(); + for (col, value) in local_result.into_iter().enumerate() { + thread_result[col] += value; + } + }); + + let mut final_result = vec![T::zero(); n_cols]; + for thread_result in thread_results { + let result = thread_result.into_inner().unwrap(); + for (col, value) in result.into_iter().enumerate() { + final_result[col] += value; + } + } + + Ok(final_result) + } } fn sum_row(&self) -> anyhow::Result> @@ -519,9 +758,8 @@ impl MatrixSum for CsrMatrix { fn sum_col_masked(&self, mask: &[bool]) -> anyhow::Result> where - T: Float + NumCast + AddAssign + Sum, + T: Float + NumCast + AddAssign + Sum + Send + Sync, { - // Validate mask length if mask.len() < self.nrows() { return Err(anyhow::anyhow!( "Mask length ({}) is less than number of rows ({})", @@ -530,27 +768,165 @@ impl MatrixSum for CsrMatrix { )); } - let mut result = vec![T::zero(); self.ncols()]; + let n_rows = self.nrows(); + let n_cols = self.ncols(); + let row_offsets = self.row_offsets(); + let col_indices = self.col_indices(); + let values = self.values(); - // Process each row - for row in 0..self.nrows() { - // Skip this row if masked out - if !mask[row] { - continue; + if n_rows == 0 || n_cols == 0 { + return Ok(vec![T::zero(); n_cols]); + } + + let masked_in_count = mask.iter().filter(|&&m| m).count(); + if masked_in_count == 0 { + return Ok(vec![T::zero(); n_cols]); + } + + let estimated_work = if masked_in_count > n_rows / 2 { + self.nnz() + } else { + let sample_size = masked_in_count.min(10); + let mut work_estimate = 0; + let mut sampled = 0; + for (row, &is_masked) in mask.iter().enumerate().take(n_rows.min(100)) { + if is_masked && sampled < sample_size { + work_estimate += row_offsets[row + 1] - row_offsets[row]; + sampled += 1; + } } + (work_estimate * masked_in_count) / sample_size.max(1) + }; - let 
row_start = self.row_offsets()[row]; - let row_end = self.row_offsets()[row + 1]; + const PARALLEL_THRESHOLD: usize = 50000; - // Process all non-zero elements in this row - for idx in row_start..row_end { - let col = self.col_indices()[idx]; - let value = T::from(self.values()[idx]).unwrap(); - result[col] += value; + if estimated_work < PARALLEL_THRESHOLD || masked_in_count < 100 { + let mut result = vec![T::zero(); n_cols]; + + if masked_in_count > n_rows * 3 / 4 { + unsafe { + let offsets_ptr = row_offsets.as_ptr(); + let indices_ptr = col_indices.as_ptr(); + let values_ptr = values.as_ptr(); + let mask_ptr = mask.as_ptr(); + let result_ptr = result.as_mut_ptr(); + + for row in 0..n_rows { + let is_included = *mask_ptr.add(row) as usize; + if is_included > 0 { + let start = *offsets_ptr.add(row); + let end = *offsets_ptr.add(row + 1); + + let chunk_end = start + ((end - start) / 4) * 4; + + for idx in (start..chunk_end).step_by(4) { + let col0 = *indices_ptr.add(idx); + let col1 = *indices_ptr.add(idx + 1); + let col2 = *indices_ptr.add(idx + 2); + let col3 = *indices_ptr.add(idx + 3); + + let val0 = T::from(*values_ptr.add(idx)).unwrap(); + let val1 = T::from(*values_ptr.add(idx + 1)).unwrap(); + let val2 = T::from(*values_ptr.add(idx + 2)).unwrap(); + let val3 = T::from(*values_ptr.add(idx + 3)).unwrap(); + + *result_ptr.add(col0) = *result_ptr.add(col0) + val0; + *result_ptr.add(col1) = *result_ptr.add(col1) + val1; + *result_ptr.add(col2) = *result_ptr.add(col2) + val2; + *result_ptr.add(col3) = *result_ptr.add(col3) + val3; + } + + for idx in chunk_end..end { + let col = *indices_ptr.add(idx); + let val = T::from(*values_ptr.add(idx)).unwrap(); + *result_ptr.add(col) = *result_ptr.add(col) + val; + } + } + } + } + } else { + for (row, &is_included) in mask.iter().enumerate().take(n_rows) { + if !is_included { + continue; + } + + let row_start = row_offsets[row]; + let row_end = row_offsets[row + 1]; + + if row_start == row_end { + continue; + } + + let chunk_end = row_start + ((row_end - row_start) / 4) * 4; + + for idx in (row_start..chunk_end).step_by(4) { + result[col_indices[idx]] += T::from(values[idx]).unwrap(); + result[col_indices[idx + 1]] += T::from(values[idx + 1]).unwrap(); + result[col_indices[idx + 2]] += T::from(values[idx + 2]).unwrap(); + result[col_indices[idx + 3]] += T::from(values[idx + 3]).unwrap(); + } + + for idx in chunk_end..row_end { + result[col_indices[idx]] += T::from(values[idx]).unwrap(); + } + } } - } - Ok(result) + Ok(result) + } else { + let n_threads = rayon::current_num_threads(); + let thread_results: Vec>> = (0..n_threads) + .map(|_| Mutex::new(vec![T::zero(); n_cols])) + .collect(); + + let chunk_size = (n_rows / (n_threads * 4)).max(256); + + (0..n_rows) + .into_par_iter() + .chunks(chunk_size) + .enumerate() + .for_each(|(thread_id, chunk)| { + let thread_id = thread_id % n_threads; + let mut local_result = vec![T::zero(); n_cols]; + + for row in chunk { + if row >= n_rows || !mask[row] { + continue; + } + + let row_start = row_offsets[row]; + let row_end = row_offsets[row + 1]; + + let chunk_end = row_start + ((row_end - row_start) / 4) * 4; + + for idx in (row_start..chunk_end).step_by(4) { + local_result[col_indices[idx]] += T::from(values[idx]).unwrap(); + local_result[col_indices[idx + 1]] += T::from(values[idx + 1]).unwrap(); + local_result[col_indices[idx + 2]] += T::from(values[idx + 2]).unwrap(); + local_result[col_indices[idx + 3]] += T::from(values[idx + 3]).unwrap(); + } + + for idx in chunk_end..row_end { + 
+                        local_result[col_indices[idx]] += T::from(values[idx]).unwrap();
+                    }
+                }
+
+                let mut thread_result = thread_results[thread_id].lock().unwrap();
+                for (col, &value) in local_result.iter().enumerate() {
+                    thread_result[col] += value;
+                }
+            });
+
+        let mut final_result = vec![T::zero(); n_cols];
+        for thread_result in thread_results {
+            let result = thread_result.into_inner().unwrap();
+            for (col, value) in result.into_iter().enumerate() {
+                final_result[col] += value;
+            }
+        }
+
+        Ok(final_result)
+    }
     }
 
     fn sum_row_masked<T>(&self, mask: &[bool]) -> anyhow::Result<Vec<T>>
     where

From e6d2ff5f5cb54cf01d3d2b439ee74fb0045d7cbc Mon Sep 17 00:00:00 2001
From: Ian
Date: Fri, 4 Jul 2025 09:38:56 +0000
Subject: [PATCH 5/5] fixed some performance bugs, version bump

---
 Cargo.lock        |   2 +-
 Cargo.toml        |   2 +-
 src/sparse/csr.rs | 922 +++++++++++-----------------------------------
 3 files changed, 214 insertions(+), 712 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 5e5e7be..17c9895 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1753,7 +1753,7 @@ dependencies = [
 
 [[package]]
 name = "single_algebra"
-version = "0.8.2"
+version = "0.8.3"
 dependencies = [
  "anyhow",
  "approx",
diff --git a/Cargo.toml b/Cargo.toml
index d4a3fef..d9d9c20 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "single_algebra"
-version = "0.8.2"
+version = "0.8.3"
 edition = "2021"
 license-file = "LICENSE.md"
 description = "A linear algebra convenience library for the single-rust library. Can be used externally as well."
diff --git a/src/sparse/csr.rs b/src/sparse/csr.rs
index a66dbc1..6cb070a 100644
--- a/src/sparse/csr.rs
+++ b/src/sparse/csr.rs
@@ -13,7 +13,7 @@ use anyhow::{anyhow, Ok};
 use nalgebra_sparse::CsrMatrix;
 use num_traits::{Float, NumCast, PrimInt, Unsigned, Zero};
 use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
-use rayon::slice::ParallelSliceMut;
+use rayon::slice::ParallelSlice;
 use single_utilities::traits::{FloatOpsTS, NumericOps};
 use single_utilities::types::Direction;
 
@@ -29,220 +29,96 @@ impl<M> MatrixNonZero for CsrMatrix<M> {
         let col_indices = self.col_indices();
         let total_nnz = col_indices.len();
 
+        // Early return for empty matrix
         if total_nnz == 0 || n_cols == 0 {
             return Ok(vec![T::zero(); n_cols]);
         }
 
+        if let Some(&max_col) = col_indices.iter().max() {
+            if max_col >= n_cols {
+                return Err(anyhow::anyhow!(
+                    "Invalid column index {} exceeds matrix column count {}",
+                    max_col,
+                    n_cols
+                ));
+            }
+        }
+
         if total_nnz < PARALLEL_THRESHOLD {
+            // Sequential implementation
             let mut result = vec![T::zero(); n_cols];
 
-            for chunk_start in (0..total_nnz).step_by(CHUNK_SIZE) {
-                let chunk_end = (chunk_start + CHUNK_SIZE).min(total_nnz);
-
-                #[cfg(target_arch = "x86_64")]
-                {
-                    use std::arch::x86_64::_mm_prefetch;
-                    if chunk_end < total_nnz {
-                        unsafe {
-                            let next_start = chunk_end;
-                            let prefetch_end = (next_start + 64).min(total_nnz);
-                            for i in (next_start..prefetch_end).step_by(8) {
-                                _mm_prefetch(col_indices.as_ptr().add(i) as *const i8, 1);
-                            }
-                        }
-                    }
-                }
-
-                let unroll_end = chunk_start + ((chunk_end - chunk_start) / 8) * 8;
-
-                unsafe {
-                    let col_ptr = col_indices.as_ptr().add(chunk_start);
-                    let result_ptr = result.as_mut_ptr();
-
-                    for i in (0..(unroll_end - chunk_start)).step_by(8) {
-                        let col0 = *col_ptr.add(i);
-                        let col1 = *col_ptr.add(i + 1);
-                        let col2 = *col_ptr.add(i + 2);
-                        let col3 = *col_ptr.add(i + 3);
-                        let col4 = *col_ptr.add(i + 4);
-                        let col5 = *col_ptr.add(i + 5);
-                        let col6 = *col_ptr.add(i + 6);
-                        let col7 = *col_ptr.add(i + 7);
-
-
*result_ptr.add(col0) = *result_ptr.add(col0) + T::one(); - *result_ptr.add(col1) = *result_ptr.add(col1) + T::one(); - *result_ptr.add(col2) = *result_ptr.add(col2) + T::one(); - *result_ptr.add(col3) = *result_ptr.add(col3) + T::one(); - *result_ptr.add(col4) = *result_ptr.add(col4) + T::one(); - *result_ptr.add(col5) = *result_ptr.add(col5) + T::one(); - *result_ptr.add(col6) = *result_ptr.add(col6) + T::one(); - *result_ptr.add(col7) = *result_ptr.add(col7) + T::one(); - } - } - - for idx in unroll_end..chunk_end { - result[col_indices[idx]] += T::one(); - } + for &col_idx in col_indices { + result[col_idx] += T::one(); } Ok(result) } else { - let n_threads = rayon::current_num_threads(); - - let thread_results: Vec>> = (0..n_threads) - .map(|_| Mutex::new(vec![T::zero(); n_cols])) - .collect(); - - let chunk_size = (total_nnz / (n_threads * 8)).max(8192); - - (0..total_nnz) - .into_par_iter() - .chunks(chunk_size) - .enumerate() - .for_each(|(thread_id, chunk)| { - let thread_id = thread_id % n_threads; - let mut local_result = vec![T::zero(); n_cols]; - - for idx_group in chunk.chunks(4) { - match idx_group.len() { - 4 => unsafe { - let idx0 = idx_group[0]; - let idx1 = idx_group[1]; - let idx2 = idx_group[2]; - let idx3 = idx_group[3]; - - if idx3 < total_nnz { - let col0 = *col_indices.get_unchecked(idx0); - let col1 = *col_indices.get_unchecked(idx1); - let col2 = *col_indices.get_unchecked(idx2); - let col3 = *col_indices.get_unchecked(idx3); - - *local_result.get_unchecked_mut(col0) += T::one(); - *local_result.get_unchecked_mut(col1) += T::one(); - *local_result.get_unchecked_mut(col2) += T::one(); - *local_result.get_unchecked_mut(col3) += T::one(); - } - }, - _ => { - for &idx in idx_group { - if idx < total_nnz { - local_result[col_indices[idx]] += T::one(); - } - } - } - } - } - - let mut thread_result = thread_results[thread_id].lock().unwrap(); - for col in 0..n_cols { - thread_result[col] += local_result[col]; + let result = col_indices + .par_chunks(8192) + .map(|chunk| { + let mut local_counts = vec![T::zero(); n_cols]; + for &col_idx in chunk { + local_counts[col_idx] += T::one(); } - }); - - let mut final_result = vec![T::zero(); n_cols]; - for thread_result in thread_results { - let result = thread_result.into_inner().unwrap(); - for col in 0..n_cols { - final_result[col] += result[col]; - } - } + local_counts + }) + .reduce( + || vec![T::zero(); n_cols], + |mut acc, local| { + for (i, count) in local.into_iter().enumerate() { + acc[i] += count; + } + acc + }, + ); - Ok(final_result) + Ok(result) } } fn nonzero_row(&self) -> anyhow::Result> where - T: PrimInt + Unsigned + Zero + AddAssign, + T: PrimInt + Unsigned + Zero + AddAssign + Send + Sync, { let row_offsets = self.row_offsets(); let n_rows = self.nrows(); - // Early return for empty matrix if n_rows == 0 { return Ok(Vec::new()); } - // Estimate total non-zeros using sampling for larger matrices - let estimated_nonzeros = { - let sample_size = n_rows.min(100); - let sample_sum: usize = (0..sample_size) - .map(|i| row_offsets[i + 1] - row_offsets[i]) - .sum(); - (sample_sum * n_rows) / sample_size - }; - - if estimated_nonzeros >= PARALLEL_THRESHOLD { - // Parallel implementation - let n_cores = rayon::current_num_threads(); - let chunk_size = (n_rows / (n_cores * 4)).max(1024); - - // Pre-allocate result - let mut result = vec![T::zero(); n_rows]; - - // Use par_chunks_mut for correct forward iteration - result.rchunks_mut(chunk_size).enumerate().try_for_each( - |(chunk_idx, chunk)| -> anyhow::Result<()> { - 
-                    let start_idx = chunk_idx * chunk_size;
-
-                    for (i, item) in chunk.iter_mut().enumerate() {
-                        let row_idx = start_idx + i;
-                        let count = row_offsets[row_idx + 1] - row_offsets[row_idx];
-                        *item = T::from(count).ok_or_else(|| {
-                            anyhow::anyhow!("Count {} exceeds target type capacity", count)
-                        })?;
-                    }
-                    Ok(())
-                },
-            )?;
-
-            Ok(result)
-        } else {
-            // Sequential implementation for medium-sized matrices
-            let mut result = Vec::with_capacity(n_rows);
-
-            let chunks = n_rows / 8;
-            let remainder = n_rows % 8;
-
-            unsafe {
-                let offsets_ptr = row_offsets.as_ptr();
-                for chunk_idx in 0..chunks {
-                    let base = chunk_idx * 8;
-
-                    // Prefetch on x86_64
-                    #[cfg(target_arch = "x86_64")]
-                    {
-                        use std::arch::x86_64::_mm_prefetch;
-                        if base + 16 < row_offsets.len() {
-                            _mm_prefetch(offsets_ptr.add(base + 16) as *const i8, 1);
-                        }
-                    }
-
-                    for i in 0..8 {
-                        let idx = base + i;
-                        let start = *offsets_ptr.add(idx);
-                        let end = *offsets_ptr.add(idx + 1);
-                        let count = end - start;
-
-                        result.push(T::from(count).ok_or_else(|| {
-                            anyhow::anyhow!("Count {} exceeds target type capacity", count)
-                        })?);
-                    }
-                }
+        if row_offsets.len() != n_rows + 1 {
+            return Err(anyhow::anyhow!(
+                "Invalid row offsets: expected {} elements, got {}",
+                n_rows + 1,
+                row_offsets.len()
+            ));
+        }

-                for i in 0..remainder {
-                    let idx = chunks * 8 + i;
-                    let start = *offsets_ptr.add(idx);
-                    let end = *offsets_ptr.add(idx + 1);
-                    let count = end - start;
+        if n_rows < PARALLEL_THRESHOLD {
+            let mut result = Vec::with_capacity(n_rows);

-                    result.push(T::from(count).ok_or_else(|| {
-                        anyhow::anyhow!("Count {} exceeds target type capacity", count)
-                    })?);
-                }
+            for row in 0..n_rows {
+                let count = row_offsets[row + 1] - row_offsets[row];
+                result.push(T::from(count).ok_or_else(|| {
+                    anyhow::anyhow!("Count {} exceeds target type capacity", count)
+                })?);
             }

             Ok(result)
+        } else {
+            let result: Result<Vec<T>, anyhow::Error> = (0..n_rows)
+                .into_par_iter()
+                .map(|row| {
+                    let count = row_offsets[row + 1] - row_offsets[row];
+                    T::from(count).ok_or_else(|| {
+                        anyhow::anyhow!("Count {} exceeds target type capacity", count)
+                    })
+                })
+                .collect();
+
+            result
         }
     }

@@ -312,7 +188,7 @@ impl MatrixNonZero for CsrMatrix {

     fn nonzero_row_masked(&self, mask: &[bool]) -> anyhow::Result<Vec<T>>
     where
-        T: PrimInt + Unsigned + Zero + AddAssign,
+        T: PrimInt + Unsigned + Zero + AddAssign + Send + Sync,
     {
         if mask.len() < self.ncols() {
             return Err(anyhow::anyhow!(
@@ -326,122 +202,54 @@ impl MatrixNonZero for CsrMatrix {
         let row_offsets = self.row_offsets();
         let col_indices = self.col_indices();

-        // Early return for empty matrix
         if n_rows == 0 {
             return Ok(Vec::new());
         }

-        // Quick check: if very few columns are masked in, sequential might be better
-        let masked_in_count = mask.iter().filter(|&&m| m).count();
-        if masked_in_count == 0 {
-            // All columns masked out - return zeros
+        let masked_count = mask.iter().filter(|&&m| m).count();
+        if masked_count == 0 {
             return Ok(vec![T::zero(); n_rows]);
         }

         let total_nnz = self.nnz();

-        if total_nnz < PARALLEL_THRESHOLD || masked_in_count < 10 {
-            let mut result = vec![T::zero(); n_rows];
-
-            // If most columns are masked in, use branch-free counting
-            if masked_in_count > self.ncols() * 3 / 4 {
-                // Branch-free version for mostly-true masks
-                unsafe {
-                    let offsets_ptr = row_offsets.as_ptr();
-                    let indices_ptr = col_indices.as_ptr();
-                    let mask_ptr = mask.as_ptr();
-
-                    for row in 0..n_rows {
-                        let start = *offsets_ptr.add(row);
-                        let end = *offsets_ptr.add(row + 1);
-                        let mut count = T::zero();
-
-                        for idx in start..end {
-                            let col = *indices_ptr.add(idx);
-                            // Branch-free increment: convert bool to 0 or 1
-                            let increment = *mask_ptr.add(col) as usize;
-                            if increment > 0 {
-                                count = count + T::one();
-                            }
-                        }
-
-                        result[row] = count;
-                    }
-                }
-            } else {
-                // Standard version with branching for sparse masks
-                for row in 0..n_rows {
-                    let row_start = row_offsets[row];
-                    let row_end = row_offsets[row + 1];
-                    let mut count = T::zero();
-
-                    // Unroll by 4 for better performance
-                    let chunk_end = row_start + ((row_end - row_start) / 4) * 4;
-
-                    for idx in (row_start..chunk_end).step_by(4) {
-                        // Process 4 elements at a time
-                        if mask[col_indices[idx]] {
-                            count = count + T::one();
-                        }
-                        if mask[col_indices[idx + 1]] {
-                            count = count + T::one();
-                        }
-                        if mask[col_indices[idx + 2]] {
-                            count = count + T::one();
-                        }
-                        if mask[col_indices[idx + 3]] {
-                            count = count + T::one();
-                        }
-                    }
-
-                    // Handle remainder
-                    for idx in chunk_end..row_end {
-                        if mask[col_indices[idx]] {
-                            count = count + T::one();
-                        }
-                    }
-
-                    result[row] = count;
-                }
-            }
+        if total_nnz < PARALLEL_THRESHOLD {
+            let mut result = Vec::with_capacity(n_rows);
+
+            for row in 0..n_rows {
+                let start = row_offsets[row];
+                let end = row_offsets[row + 1];
+                let mut count = T::zero();
+
+                for idx in start..end {
+                    if mask[col_indices[idx]] {
+                        count += T::one();
+                    }
+                }
+
+                result.push(count);
+            }

             Ok(result)
         } else {
-            let chunk_size = (n_rows / (rayon::current_num_threads() * 4)).max(1024);
-
-            // Pre-allocate result
-            let mut result = vec![T::zero(); n_rows];
-
-            // Process rows in parallel chunks
-            result
-                .rchunks_mut(chunk_size)
-                .enumerate()
-                .for_each(|(chunk_idx, result_chunk)| {
-                    let start_row = chunk_idx * chunk_size;
-
-                    for (i, count) in result_chunk.iter_mut().enumerate() {
-                        let row = start_row + i;
-                        if row >= n_rows {
-                            break;
-                        }
-
-                        let row_start = row_offsets[row];
-                        let row_end = row_offsets[row + 1];
-                        let mut local_count = T::zero();
-
-                        // For better cache locality, process in smaller sub-chunks
-                        for idx in row_start..row_end {
-                            let col = col_indices[idx];
-                            if mask[col] {
-                                local_count += T::one();
-                            }
-                        }
-
-                        *count = local_count;
-                    }
-                });
-
-            Ok(result)
+            let counts: Vec<T> = (0..n_rows)
+                .into_par_iter()
+                .map(|row| {
+                    let start = row_offsets[row];
+                    let end = row_offsets[row + 1];
+                    let mut count = T::zero();
+
+                    for idx in start..end {
+                        if mask[col_indices[idx]] {
+                            count += T::one();
+                        }
+                    }
+
+                    count
+                })
+                .collect();
+
+            Ok(counts)
         }
     }
 }
@@ -466,136 +274,41 @@ impl MatrixSum for CsrMatrix {

         if total_nnz < PARALLEL_THRESHOLD {
             let mut result = vec![T::zero(); n_cols];

-            const CHUNK_SIZE: usize = 512;
-
             for chunk_start in (0..total_nnz).step_by(CHUNK_SIZE) {
                 let chunk_end = (chunk_start + CHUNK_SIZE).min(total_nnz);

-                #[cfg(target_arch = "x86_64")]
-                {
-                    use std::arch::x86_64::_mm_prefetch;
-                    if chunk_end < total_nnz {
-                        unsafe {
-                            let next_start = chunk_end;
-                            let prefetch_end = (next_start + 64).min(total_nnz);
-                            for i in (next_start..prefetch_end).step_by(8) {
-                                _mm_prefetch(col_indices.as_ptr().add(i) as *const i8, 1);
-                                _mm_prefetch(values.as_ptr().add(i) as *const i8, 1);
-                            }
-                        }
-                    }
-                }
-
-                let unroll_end = chunk_start + ((chunk_end - chunk_start) / 8) * 8;
-
-                unsafe {
-                    let col_ptr = col_indices.as_ptr().add(chunk_start);
-                    let val_ptr = values.as_ptr().add(chunk_start);
-                    let result_ptr = result.as_mut_ptr();
-
-                    for i in (0..(unroll_end - chunk_start)).step_by(8) {
-                        let col0 = *col_ptr.add(i);
-                        let col1 = *col_ptr.add(i + 1);
-                        let col2 = *col_ptr.add(i + 2);
-                        let col3 = *col_ptr.add(i + 3);
-                        let col4 = *col_ptr.add(i + 4);
-                        let col5 = *col_ptr.add(i + 5);
-                        let col6 = *col_ptr.add(i + 6);
-                        let col7 = *col_ptr.add(i + 7);
-
-                        let val0 = T::from(*val_ptr.add(i)).unwrap();
-                        let val1 = T::from(*val_ptr.add(i + 1)).unwrap();
-                        let val2 = T::from(*val_ptr.add(i + 2)).unwrap();
-                        let val3 = T::from(*val_ptr.add(i + 3)).unwrap();
-                        let val4 = T::from(*val_ptr.add(i + 4)).unwrap();
-                        let val5 = T::from(*val_ptr.add(i + 5)).unwrap();
-                        let val6 = T::from(*val_ptr.add(i + 6)).unwrap();
-                        let val7 = T::from(*val_ptr.add(i + 7)).unwrap();
-
-                        *result_ptr.add(col0) = *result_ptr.add(col0) + val0;
-                        *result_ptr.add(col1) = *result_ptr.add(col1) + val1;
-                        *result_ptr.add(col2) = *result_ptr.add(col2) + val2;
-                        *result_ptr.add(col3) = *result_ptr.add(col3) + val3;
-                        *result_ptr.add(col4) = *result_ptr.add(col4) + val4;
-                        *result_ptr.add(col5) = *result_ptr.add(col5) + val5;
-                        *result_ptr.add(col6) = *result_ptr.add(col6) + val6;
-                        *result_ptr.add(col7) = *result_ptr.add(col7) + val7;
-                    }
-                }
-
-                for idx in unroll_end..chunk_end {
+                for idx in chunk_start..chunk_end {
                     result[col_indices[idx]] += T::from(values[idx]).unwrap();
                 }
             }

             Ok(result)
         } else {
-            let n_threads = rayon::current_num_threads();
-
-            let thread_results: Vec<Mutex<Vec<T>>> = (0..n_threads)
-                .map(|_| Mutex::new(vec![T::zero(); n_cols]))
-                .collect();
-
-            let chunk_size = (total_nnz / (n_threads * 8)).max(8192);
-
-            (0..total_nnz)
+            let result = (0..total_nnz)
                 .into_par_iter()
-                .chunks(chunk_size)
-                .enumerate()
-                .for_each(|(thread_id, chunk)| {
-                    let thread_id = thread_id % n_threads;
-                    let mut local_result = vec![T::zero(); n_cols];
-
-                    for idx in chunk {
-                        if idx >= total_nnz {
-                            break;
-                        }
+                .chunks(8192)
+                .map(|chunk_indices| {
+                    let mut local_sums = vec![T::zero(); n_cols];

-                        if idx + 3 < total_nnz {
-                            let remaining = total_nnz - idx;
-                            if remaining >= 4 {
-                                unsafe {
-                                    let col0 = *col_indices.get_unchecked(idx);
-                                    let col1 = *col_indices.get_unchecked(idx + 1);
-                                    let col2 = *col_indices.get_unchecked(idx + 2);
-                                    let col3 = *col_indices.get_unchecked(idx + 3);
-
-                                    let val0 = T::from(*values.get_unchecked(idx)).unwrap();
-                                    let val1 = T::from(*values.get_unchecked(idx + 1)).unwrap();
-                                    let val2 = T::from(*values.get_unchecked(idx + 2)).unwrap();
-                                    let val3 = T::from(*values.get_unchecked(idx + 3)).unwrap();
-
-                                    *local_result.get_unchecked_mut(col0) += val0;
-                                    *local_result.get_unchecked_mut(col1) += val1;
-                                    *local_result.get_unchecked_mut(col2) += val2;
-                                    *local_result.get_unchecked_mut(col3) += val3;
-                                }
-                            } else {
-                                for i in 0..remaining {
-                                    local_result[col_indices[idx + i]] +=
-                                        T::from(values[idx + i]).unwrap();
-                                }
-                            }
-                        } else {
-                            local_result[col_indices[idx]] += T::from(values[idx]).unwrap();
+                    for idx in chunk_indices {
+                        if idx < total_nnz {
+                            local_sums[col_indices[idx]] += T::from(values[idx]).unwrap();
                         }
                     }

-                    let mut thread_result = thread_results[thread_id].lock().unwrap();
-                    for (col, value) in local_result.into_iter().enumerate() {
-                        thread_result[col] += value;
-                    }
-                });
-
-            let mut final_result = vec![T::zero(); n_cols];
-            for thread_result in thread_results {
-                let result = thread_result.into_inner().unwrap();
-                for (col, value) in result.into_iter().enumerate() {
-                    final_result[col] += value;
-                }
-            }
+                    local_sums
+                })
+                .reduce(
+                    || vec![T::zero(); n_cols],
+                    |mut acc, local| {
+                        for (i, val) in local.into_iter().enumerate() {
+                            acc[i] += val;
+                        }
+                        acc
+                    },
+                );

-            Ok(final_result)
+            Ok(result)
         }
     }

@@ -614,121 +327,68 @@ impl MatrixSum for CsrMatrix {

         let total_nnz = values.len();

-        if total_nnz >= PARALLEL_THRESHOLD {
-            let sums: Vec<T> = (0..nrows)
-                .into_par_iter()
-                .map(|row| {
-                    let start = row_offsets[row];
-                    let end = row_offsets[row + 1];
-
-                    if end - start > 16 {
-                        let mut sum = M::zero();
-                        let chunk_size = 8;
-                        let chunks = (end - start) / chunk_size;
-                        let remainder = (end - start) % chunk_size;
-
-                        for c in 0..chunks {
-                            let base = start + c * chunk_size;
-                            let mut chunk_sum = M::zero();
-
-                            chunk_sum = chunk_sum + values[base];
-                            chunk_sum = chunk_sum + values[base + 1];
-                            chunk_sum = chunk_sum + values[base + 2];
-                            chunk_sum = chunk_sum + values[base + 3];
-                            chunk_sum = chunk_sum + values[base + 4];
-                            chunk_sum = chunk_sum + values[base + 5];
-                            chunk_sum = chunk_sum + values[base + 6];
-                            chunk_sum = chunk_sum + values[base + 7];
-
-                            sum = sum + chunk_sum;
-                        }
-
-                        for i in 0..remainder {
-                            sum = sum + values[start + chunks * chunk_size + i];
-                        }
-
-                        T::from(sum).unwrap()
-                    } else {
-                        let mut sum = M::zero();
-                        for i in start..end {
-                            sum = sum + values[i];
-                        }
-                        T::from(sum).unwrap()
-                    }
-                })
-                .collect();
-            Ok(sums)
-        } else {
-            let mut result = vec![T::zero(); nrows];
-
-            let chunk_size = 8;
-            let chunks = nrows / chunk_size;
-
-            unsafe {
-                let values_ptr = values.as_ptr();
-                let offsets_ptr = row_offsets.as_ptr();
-                let result_ptr = result.as_mut_ptr();
-
-                for chunk in 0..chunks {
-                    let base = chunk * chunk_size;
-
-                    #[cfg(target_arch = "x86_64")]
-                    {
-                        use std::arch::x86_64::_mm_prefetch;
-                        if base + chunk_size < nrows {
-                            let next_start = *offsets_ptr.add(base + chunk_size);
-                            if next_start < values.len() {
-                                _mm_prefetch(values_ptr.add(next_start) as *const i8, 1);
-                            }
-                        }
-                    }
-
-                    let mut sums = [M::zero(); 8];
-
-                    for i in 0..8 {
-                        let row = base + i;
-                        let start = *offsets_ptr.add(row);
-                        let end = *offsets_ptr.add(row + 1);
-
-                        let mut sum = M::zero();
-
-                        let inner_chunks = (end - start) / 4;
-                        let inner_remainder = (end - start) % 4;
-
-                        for j in 0..inner_chunks {
-                            let idx = start + j * 4;
-                            sum = sum + *values_ptr.add(idx);
-                            sum = sum + *values_ptr.add(idx + 1);
-                            sum = sum + *values_ptr.add(idx + 2);
-                            sum = sum + *values_ptr.add(idx + 3);
-                        }
-
-                        for j in 0..inner_remainder {
-                            sum = sum + *values_ptr.add(start + inner_chunks * 4 + j);
-                        }
-
-                        sums[i] = sum;
-                    }
-
-                    for i in 0..8 {
-                        *result_ptr.add(base + i) = T::from(sums[i]).unwrap();
-                    }
-                }
+        if total_nnz < PARALLEL_THRESHOLD {
+            let mut result = Vec::with_capacity(nrows);
+
+            const ROW_CHUNK: usize = 64;
+
+            for row_chunk in (0..nrows).step_by(ROW_CHUNK) {
+                let chunk_end = (row_chunk + ROW_CHUNK).min(nrows);
+
+                for row in row_chunk..chunk_end {
+                    let start = row_offsets[row];
+                    let end = row_offsets[row + 1];
+                    let row_values = &values[start..end];
+                    let sum = if row_values.len() < 16 {
+                        row_values.iter().map(|&v| T::from(v).unwrap()).sum::<T>()
+                    } else {
+                        let mut sum = T::zero();
+                        let chunks = row_values.chunks_exact(4);
+                        let remainder = chunks.remainder();
+
+                        for chunk in chunks {
+                            sum += T::from(chunk[0]).unwrap();
+                            sum += T::from(chunk[1]).unwrap();
+                            sum += T::from(chunk[2]).unwrap();
+                            sum += T::from(chunk[3]).unwrap();
+                        }
+
+                        for &val in remainder {
+                            sum += T::from(val).unwrap();
+                        }
+
+                        sum
+                    };
+
+                    result.push(sum);
+                }
+            }
+
+            Ok(result)
+        } else {
+            let sums: Vec<T> = (0..nrows)
+                .into_par_iter()
+                .map(|row| {
+                    let start = row_offsets[row];
+                    let end = row_offsets[row + 1];
+                    let row_values = &values[start..end];
+
+                    if row_values.len() < 32 {
+                        row_values.iter().map(|&v| T::from(v).unwrap()).sum::<T>()
+                    } else {
+                        let mut sum = T::zero();
+
+                        for chunk in row_values.chunks(8) {
+                            let chunk_sum: T = chunk.iter().map(|&v| T::from(v).unwrap()).sum();
+                            sum += chunk_sum;
+                        }
+
+                        sum
+                    }
+                })
+                .collect();
-                for row in (chunks * chunk_size)..nrows {
-                    let start = *offsets_ptr.add(row);
-                    let end = *offsets_ptr.add(row + 1);
-                    let mut sum = M::zero();
-
-                    for idx in start..end {
-                        sum = sum + *values_ptr.add(idx);
-                    }
-
-                    *result_ptr.add(row) = T::from(sum).unwrap();
-                }
-            }
-
-            Ok(result)
+            Ok(sums)
         }
     }

@@ -768,105 +428,26 @@ impl MatrixSum for CsrMatrix {
             ));
         }

-        let n_rows = self.nrows();
         let n_cols = self.ncols();
         let row_offsets = self.row_offsets();
         let col_indices = self.col_indices();
         let values = self.values();

-        if n_rows == 0 || n_cols == 0 {
-            return Ok(vec![T::zero(); n_cols]);
-        }
+        let masked_count = mask.iter().filter(|&&m| m).count();

-        let masked_in_count = mask.iter().filter(|&&m| m).count();
-        if masked_in_count == 0 {
+        if masked_count == 0 {
             return Ok(vec![T::zero(); n_cols]);
         }

-        let estimated_work = if masked_in_count > n_rows / 2 {
-            self.nnz()
-        } else {
-            let sample_size = masked_in_count.min(10);
-            let mut work_estimate = 0;
-            let mut sampled = 0;
-            for (row, &is_masked) in mask.iter().enumerate().take(n_rows.min(100)) {
-                if is_masked && sampled < sample_size {
-                    work_estimate += row_offsets[row + 1] - row_offsets[row];
-                    sampled += 1;
-                }
-            }
-            (work_estimate * masked_in_count) / sample_size.max(1)
-        };
-
-        const PARALLEL_THRESHOLD: usize = 50000;
-
-        if estimated_work < PARALLEL_THRESHOLD || masked_in_count < 100 {
+        if masked_count < PARALLEL_THRESHOLD {
             let mut result = vec![T::zero(); n_cols];

-            if masked_in_count > n_rows * 3 / 4 {
-                unsafe {
-                    let offsets_ptr = row_offsets.as_ptr();
-                    let indices_ptr = col_indices.as_ptr();
-                    let values_ptr = values.as_ptr();
-                    let mask_ptr = mask.as_ptr();
-                    let result_ptr = result.as_mut_ptr();
-
-                    for row in 0..n_rows {
-                        let is_included = *mask_ptr.add(row) as usize;
-                        if is_included > 0 {
-                            let start = *offsets_ptr.add(row);
-                            let end = *offsets_ptr.add(row + 1);
-
-                            let chunk_end = start + ((end - start) / 4) * 4;
-
-                            for idx in (start..chunk_end).step_by(4) {
-                                let col0 = *indices_ptr.add(idx);
-                                let col1 = *indices_ptr.add(idx + 1);
-                                let col2 = *indices_ptr.add(idx + 2);
-                                let col3 = *indices_ptr.add(idx + 3);
-
-                                let val0 = T::from(*values_ptr.add(idx)).unwrap();
-                                let val1 = T::from(*values_ptr.add(idx + 1)).unwrap();
-                                let val2 = T::from(*values_ptr.add(idx + 2)).unwrap();
-                                let val3 = T::from(*values_ptr.add(idx + 3)).unwrap();
-
-                                *result_ptr.add(col0) = *result_ptr.add(col0) + val0;
-                                *result_ptr.add(col1) = *result_ptr.add(col1) + val1;
-                                *result_ptr.add(col2) = *result_ptr.add(col2) + val2;
-                                *result_ptr.add(col3) = *result_ptr.add(col3) + val3;
-                            }
-
-                            for idx in chunk_end..end {
-                                let col = *indices_ptr.add(idx);
-                                let val = T::from(*values_ptr.add(idx)).unwrap();
-                                *result_ptr.add(col) = *result_ptr.add(col) + val;
-                            }
-                        }
-                    }
-                }
-            } else {
-                for (row, &is_included) in mask.iter().enumerate().take(n_rows) {
-                    if !is_included {
-                        continue;
-                    }
-
-                    let row_start = row_offsets[row];
-                    let row_end = row_offsets[row + 1];
-
-                    if row_start == row_end {
-                        continue;
-                    }
-
-                    let chunk_end = row_start + ((row_end - row_start) / 4) * 4;
-
-                    for idx in (row_start..chunk_end).step_by(4) {
-                        result[col_indices[idx]] += T::from(values[idx]).unwrap();
-                        result[col_indices[idx + 1]] += T::from(values[idx + 1]).unwrap();
-                        result[col_indices[idx + 2]] += T::from(values[idx + 2]).unwrap();
-                        result[col_indices[idx + 3]] += T::from(values[idx + 3]).unwrap();
-                    }
+            for (row, &is_included) in mask.iter().enumerate() {
+                if is_included {
+                    let start = row_offsets[row];
+                    let end = row_offsets[row + 1];
+
+                    for idx in start..end {
                         result[col_indices[idx]] += T::from(values[idx]).unwrap();
                     }
                 }
@@ -874,58 +455,36 @@ impl MatrixSum for CsrMatrix {
             }

             Ok(result)
         } else {
-            let n_threads = rayon::current_num_threads();
-            let thread_results: Vec<Mutex<Vec<T>>> = (0..n_threads)
-                .map(|_| Mutex::new(vec![T::zero(); n_cols]))
-                .collect();
-
-            let chunk_size = (n_rows / (n_threads * 4)).max(256);
-
-            (0..n_rows)
+            let result = (0..mask.len())
                 .into_par_iter()
-                .chunks(chunk_size)
-                .enumerate()
-                .for_each(|(thread_id, chunk)| {
-                    let thread_id = thread_id % n_threads;
-                    let mut local_result = vec![T::zero(); n_cols];
-
-                    for row in chunk {
-                        if row >= n_rows || !mask[row] {
-                            continue;
-                        }
-
-                        let row_start = row_offsets[row];
-                        let row_end = row_offsets[row + 1];
+                .chunks(256)
+                .map(|row_chunk| {
+                    let mut local_sums = vec![T::zero(); n_cols];

-                        let chunk_end = row_start + ((row_end - row_start) / 4) * 4;
-
-                        for idx in (row_start..chunk_end).step_by(4) {
-                            local_result[col_indices[idx]] += T::from(values[idx]).unwrap();
-                            local_result[col_indices[idx + 1]] += T::from(values[idx + 1]).unwrap();
-                            local_result[col_indices[idx + 2]] += T::from(values[idx + 2]).unwrap();
-                            local_result[col_indices[idx + 3]] += T::from(values[idx + 3]).unwrap();
-                        }
+                    for row in row_chunk {
+                        if mask[row] {
+                            let start = row_offsets[row];
+                            let end = row_offsets[row + 1];

-                        for idx in chunk_end..row_end {
-                            local_result[col_indices[idx]] += T::from(values[idx]).unwrap();
+                            for idx in start..end {
+                                local_sums[col_indices[idx]] += T::from(values[idx]).unwrap();
+                            }
                         }
                     }

-                    let mut thread_result = thread_results[thread_id].lock().unwrap();
-                    for (col, &value) in local_result.iter().enumerate() {
-                        thread_result[col] += value;
-                    }
-                });
-
-            let mut final_result = vec![T::zero(); n_cols];
-            for thread_result in thread_results {
-                let result = thread_result.into_inner().unwrap();
-                for (col, value) in result.into_iter().enumerate() {
-                    final_result[col] += value;
-                }
-            }
+                    local_sums
+                })
+                .reduce(
+                    || vec![T::zero(); n_cols],
+                    |mut acc, local| {
+                        for (i, val) in local.into_iter().enumerate() {
+                            acc[i] += val;
+                        }
+                        acc
+                    },
+                );

-            Ok(final_result)
+            Ok(result)
         }
     }

@@ -950,107 +509,50 @@ impl MatrixSum for CsrMatrix {
             return Ok(Vec::new());
         }

-        let masked_in_count = mask.iter().filter(|&&m| m).count();
-        if masked_in_count == 0 {
+        let masked_count = mask.iter().filter(|&&m| m).count();
+        if masked_count == 0 {
             return Ok(vec![T::zero(); n_rows]);
         }

         let total_nnz = self.nnz();

-        const PARALLEL_THRESHOLD: usize = 10000;
-        if total_nnz < PARALLEL_THRESHOLD || masked_in_count < 10 {
-            let mut result = vec![T::zero(); n_rows];
-
-            if masked_in_count > self.ncols() * 3 / 4 {
-                unsafe {
-                    let offsets_ptr = row_offsets.as_ptr();
-                    let indices_ptr = col_indices.as_ptr();
-                    let values_ptr = values.as_ptr();
-                    let mask_ptr = mask.as_ptr();
-
-                    for row in 0..n_rows {
-                        let start = *offsets_ptr.add(row);
-                        let end = *offsets_ptr.add(row + 1);
-                        let mut sum = T::zero();
+        if total_nnz < PARALLEL_THRESHOLD {
+            let mut result = Vec::with_capacity(n_rows);

-                        for idx in start..end {
-                            let col = *indices_ptr.add(idx);
-                            let increment = *mask_ptr.add(col) as usize;
-                            if increment > 0 {
-                                sum += T::from(*values_ptr.add(idx)).unwrap();
-                            }
-                        }
+            for row in 0..n_rows {
+                let start = row_offsets[row];
+                let end = row_offsets[row + 1];
+                let mut sum = T::zero();

-                        result[row] = sum;
+                for idx in start..end {
+                    if mask[col_indices[idx]] {
+                        sum += T::from(values[idx]).unwrap();
                     }
                 }
-            } else {
-                for row in 0..n_rows {
-                    let row_start = row_offsets[row];
-                    let row_end = row_offsets[row + 1];
-                    let mut sum = T::zero();
-
-                    let chunk_end = row_start + ((row_end - row_start) / 4) * 4;
-                    for idx in (row_start..chunk_end).step_by(4) {
-                        if mask[col_indices[idx]] {
-                            sum += T::from(values[idx]).unwrap();
-                        }
-                        if mask[col_indices[idx + 1]] {
-                            sum += T::from(values[idx + 1]).unwrap();
-                        }
-                        if mask[col_indices[idx + 2]] {
-                            sum += T::from(values[idx + 2]).unwrap();
-                        }
-                        if mask[col_indices[idx + 3]] {
-                            sum += T::from(values[idx + 3]).unwrap();
-                        }
-                    }
-
-                    for idx in chunk_end..row_end {
-                        if mask[col_indices[idx]] {
-                            sum += T::from(values[idx]).unwrap();
-                        }
-                    }
-
-                    result[row] = sum;
-                }
+                result.push(sum);
             }

             Ok(result)
         } else {
-            let chunk_size = (n_rows / (rayon::current_num_threads() * 4)).max(1024);
-
-            let mut result = vec![T::zero(); n_rows];
-
-            result
-                .par_chunks_mut(chunk_size)
-                .enumerate()
-                .for_each(|(chunk_idx, result_chunk)| {
-                    let start_row = chunk_idx * chunk_size;
-
-                    for (i, sum) in result_chunk.iter_mut().enumerate() {
-                        let row = start_row + i;
-                        if row >= n_rows {
-                            break;
-                        }
-
-                        let row_start = row_offsets[row];
-                        let row_end = row_offsets[row + 1];
-                        let mut local_sum = T::zero();
+            let sums: Vec<T> = (0..n_rows)
+                .into_par_iter()
+                .map(|row| {
+                    let start = row_offsets[row];
+                    let end = row_offsets[row + 1];
+                    let mut sum = T::zero();

-                        for idx in row_start..row_end {
-                            let col = col_indices[idx];
-                            if mask[col] {
-                                local_sum += T::from(values[idx]).unwrap();
-                            }
+                    for idx in start..end {
+                        if mask[col_indices[idx]] {
+                            sum += T::from(values[idx]).unwrap();
                         }
-
-                        *sum = local_sum;
                     }
-                });

-            Ok(result)
+                    sum
+                })
+                .collect();
+
+            Ok(sums)
         }
     }
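
Note on the shape used throughout PATCH 5: every parallel path now builds a private accumulator per rayon chunk and merges the accumulators with reduce, instead of sharing Vec<Mutex<Vec<T>>> state across threads. Below is a minimal standalone sketch of that pattern, assuming rayon 1.x; nnz_per_column is a hypothetical free function written for illustration, not an API of this crate.

    use rayon::prelude::*;

    // Count how often each column index occurs, in parallel. Each chunk fills
    // its own local histogram; `reduce` then folds the per-chunk histograms
    // together, so no locks are taken on the hot path.
    fn nnz_per_column(col_indices: &[usize], n_cols: usize) -> Vec<u64> {
        col_indices
            .par_chunks(8192)
            .map(|chunk| {
                let mut local = vec![0u64; n_cols];
                for &col in chunk {
                    local[col] += 1;
                }
                local
            })
            .reduce(
                || vec![0u64; n_cols],
                |mut acc, local| {
                    for (i, v) in local.into_iter().enumerate() {
                        acc[i] += v;
                    }
                    acc
                },
            )
    }

The trade-off is that each chunk allocates an n_cols-sized buffer, so the chunk size (8192 here, as in the patch) should stay large relative to n_cols for the reduction cost to remain small.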