diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9d1f864..f245dbf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,7 +22,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 - - run: cargo test --workspace --exclude tropical-gemm-cuda + - run: cargo test --workspace --exclude tropical-gemm-cuda --exclude tropical-gemm-metal clippy: name: Clippy @@ -33,5 +33,5 @@ jobs: with: components: clippy - uses: Swatinem/rust-cache@v2 - - run: cargo clippy --workspace --exclude tropical-gemm-cuda --all-targets -- -W clippy::all + - run: cargo clippy --workspace --exclude tropical-gemm-cuda --exclude tropical-gemm-metal --all-targets -- -W clippy::all diff --git a/Cargo.lock b/Cargo.lock index 68bf154..b0a7ce1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -44,12 +44,24 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + [[package]] name = "bitflags" version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" +[[package]] +name = "block" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" + [[package]] name = "bumpalo" version = "3.19.1" @@ -126,6 +138,33 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "core-graphics-types" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "libc", +] + [[package]] name = "criterion" version = "0.5.1" @@ -230,6 +269,33 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foreign-types" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" +dependencies = [ + "foreign-types-macros", + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-macros" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "foreign-types-shared" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" + [[package]] name = "getrandom" version = "0.3.4" @@ -317,12 +383,42 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "malloc_buf" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62bb907fe88d54d8d9ce32a3cceab4218ed2f6b7d35617cafe9adf84e43919cb" +dependencies = [ + "libc", +] + [[package]] name = "memchr" version = "2.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" +[[package]] +name = "metal" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21" +dependencies = [ + "bitflags 2.10.0", + "block", + "core-graphics-types", + "foreign-types", + "log", + "objc", + "paste", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -332,6 +428,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "objc" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1" +dependencies = [ + "malloc_buf", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -344,6 +449,12 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "plotters" version = "0.3.7" @@ -398,7 +509,7 @@ checksum = "bee689443a2bd0a16ab0348b52ee43e3b2d1b1f931c8aa5c9f8de4c86fbe8c40" dependencies = [ "bit-set", "bit-vec", - "bitflags", + "bitflags 2.10.0", "num-traits", "rand", "rand_chacha", @@ -523,7 +634,7 @@ version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" dependencies = [ - "bitflags", + "bitflags 2.10.0", "errno", "libc", "linux-raw-sys", @@ -696,6 +807,18 @@ dependencies = [ "tropical-types", ] +[[package]] +name = "tropical-gemm-metal" +version = "0.1.0" +dependencies = [ + "block", + "metal", + "objc", + "proptest", + "thiserror", + "tropical-types", +] + [[package]] name = "tropical-gemm-simd" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 9c77181..b29366e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "crates/tropical-gemm-core", "crates/tropical-gemm-simd", "crates/tropical-gemm-cuda", + "crates/tropical-gemm-metal", "crates/tropical-gemm", ] @@ -21,6 +22,7 @@ tropical-types = { path = "crates/tropical-types" } tropical-gemm-core = { path = "crates/tropical-gemm-core" } tropical-gemm-simd = { path = "crates/tropical-gemm-simd" } tropical-gemm-cuda = { path = "crates/tropical-gemm-cuda" } +tropical-gemm-metal = { path = "crates/tropical-gemm-metal" } tropical-gemm = { path = "crates/tropical-gemm" } # External dependencies diff --git a/benchmarks/bench_metal.sh b/benchmarks/bench_metal.sh new file mode 100755 index 0000000..69eb5cb --- /dev/null +++ b/benchmarks/bench_metal.sh @@ -0,0 +1,41 @@ +#!/bin/bash +# Metal GPU Benchmark Script for macOS +# +# Usage: ./bench_metal.sh [options] +# --quick Run quick benchmarks (smaller sizes) + +set -e + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_DIR="$(dirname "$SCRIPT_DIR")" + +echo "========================================================================" +echo "Tropical GEMM Metal Benchmark (macOS)" +echo "========================================================================" +echo "" +echo "Project: $PROJECT_DIR" +echo "Date: $(date)" +echo "" + +# Create results directory +RESULTS_DIR="$SCRIPT_DIR/results" +mkdir -p "$RESULTS_DIR" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +cd "$PROJECT_DIR" + +# Build release +echo "Building release binary..." 
+cargo build --release --example bench_comparison -p tropical-gemm-metal + +# Run benchmark +echo "" +echo "Running Metal GPU benchmark..." +OUTPUT="$RESULTS_DIR/metal_${TIMESTAMP}.txt" +cargo run --release --example bench_comparison -p tropical-gemm-metal 2>&1 | tee "$OUTPUT" + +echo "" +echo "========================================================================" +echo "Benchmark Complete" +echo "========================================================================" +echo "Results saved to: $OUTPUT" diff --git a/crates/tropical-gemm-metal/Cargo.toml b/crates/tropical-gemm-metal/Cargo.toml new file mode 100644 index 0000000..89822db --- /dev/null +++ b/crates/tropical-gemm-metal/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "tropical-gemm-metal" +version.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true +repository.workspace = true +description = "Metal backend for tropical matrix multiplication on Apple GPUs" +keywords = ["tropical", "matrix", "gpu", "metal", "apple"] +categories = ["science", "mathematics"] + +[dependencies] +tropical-types = { workspace = true } +thiserror = { workspace = true } +metal = "0.29" +objc = "0.2" +block = "0.1" + +[dev-dependencies] +proptest = { workspace = true } diff --git a/crates/tropical-gemm-metal/examples/bench_comparison.rs b/crates/tropical-gemm-metal/examples/bench_comparison.rs new file mode 100644 index 0000000..994f789 --- /dev/null +++ b/crates/tropical-gemm-metal/examples/bench_comparison.rs @@ -0,0 +1,110 @@ +//! Comprehensive benchmark for Metal GPU vs CPU comparison. 
+ +use std::time::Instant; +use tropical_gemm_metal::{tropical_matmul_gpu, MetalContext}; +use tropical_types::{TropicalMaxPlus, TropicalMinPlus, TropicalMaxMul}; + +fn bench_cpu(a: &[f32], m: usize, k: usize, b: &[f32], n: usize) -> Vec { + // Simple CPU implementation for comparison + let mut c = vec![f32::NEG_INFINITY; m * n]; + for i in 0..m { + for j in 0..n { + for kk in 0..k { + let val = a[i * k + kk] + b[kk * n + j]; + if val > c[i * n + j] { + c[i * n + j] = val; + } + } + } + } + c +} + +fn main() { + println!("Tropical GEMM: Metal vs CPU Benchmark"); + println!("======================================\n"); + + let ctx = match MetalContext::new() { + Ok(c) => c, + Err(e) => { + eprintln!("Metal not available: {:?}", e); + return; + } + }; + println!("GPU: {}", ctx.device_name()); + println!(); + + // GPU vs CPU comparison + println!("### GPU vs CPU Performance\n"); + println!("| Size | CPU (ms) | Metal GPU (ms) | Speedup |"); + println!("|------|----------|----------------|---------|"); + + let sizes = [256, 512, 1024, 2048]; + + for &n in &sizes { + let a: Vec = (0..n * n).map(|i| (i % 100) as f32).collect(); + let b: Vec = (0..n * n).map(|i| ((i + 50) % 100) as f32).collect(); + + // Warmup GPU + let _ = tropical_matmul_gpu::>(&a, n, n, &b, n); + + // Benchmark GPU + let gpu_iters = if n <= 512 { 10 } else { 3 }; + let start = Instant::now(); + for _ in 0..gpu_iters { + let _ = tropical_matmul_gpu::>(&a, n, n, &b, n); + } + let gpu_ms = start.elapsed().as_secs_f64() * 1000.0 / gpu_iters as f64; + + // Benchmark CPU (fewer iterations for large sizes) + let cpu_iters = if n <= 512 { 3 } else { 1 }; + let start = Instant::now(); + for _ in 0..cpu_iters { + let _ = bench_cpu(&a, n, n, &b, n); + } + let cpu_ms = start.elapsed().as_secs_f64() * 1000.0 / cpu_iters as f64; + + let speedup = cpu_ms / gpu_ms; + println!("| {:>4} | {:>8.1} | {:>14.3} | **{:.0}x** |", + n, cpu_ms, gpu_ms, speedup); + } + + // All semirings benchmark + println!("\n### All Semirings 
(Metal GPU Kernel Time)\n"); + println!("| Size | MaxPlus (ms) | MinPlus (ms) | MaxMul (ms) |"); + println!("|------|--------------|--------------|-------------|"); + + for &n in &sizes { + let a: Vec = (0..n * n).map(|i| (i % 100) as f32 + 1.0).collect(); + let b: Vec = (0..n * n).map(|i| ((i + 50) % 100) as f32 + 1.0).collect(); + + let iters = if n <= 512 { 10 } else { 3 }; + + // MaxPlus + let _ = tropical_matmul_gpu::>(&a, n, n, &b, n); + let start = Instant::now(); + for _ in 0..iters { + let _ = tropical_matmul_gpu::>(&a, n, n, &b, n); + } + let maxplus_ms = start.elapsed().as_secs_f64() * 1000.0 / iters as f64; + + // MinPlus + let _ = tropical_matmul_gpu::>(&a, n, n, &b, n); + let start = Instant::now(); + for _ in 0..iters { + let _ = tropical_matmul_gpu::>(&a, n, n, &b, n); + } + let minplus_ms = start.elapsed().as_secs_f64() * 1000.0 / iters as f64; + + // MaxMul + let _ = tropical_matmul_gpu::>(&a, n, n, &b, n); + let start = Instant::now(); + for _ in 0..iters { + let _ = tropical_matmul_gpu::>(&a, n, n, &b, n); + } + let maxmul_ms = start.elapsed().as_secs_f64() * 1000.0 / iters as f64; + + println!("| {:>4} | {:>12.3} | {:>12.3} | {:>11.3} |", + n, maxplus_ms, minplus_ms, maxmul_ms); + } +} diff --git a/crates/tropical-gemm-metal/examples/bench_metal.rs b/crates/tropical-gemm-metal/examples/bench_metal.rs new file mode 100644 index 0000000..fe48a88 --- /dev/null +++ b/crates/tropical-gemm-metal/examples/bench_metal.rs @@ -0,0 +1,48 @@ +//! Benchmark comparing Metal GPU vs CPU performance. 
+ +use std::time::Instant; +use tropical_gemm_metal::{tropical_matmul_gpu, MetalContext}; +use tropical_types::TropicalMaxPlus; + +fn main() { + println!("Tropical GEMM Metal Benchmark"); + println!("==============================\n"); + + // Check Metal availability + let ctx = match MetalContext::new() { + Ok(c) => c, + Err(e) => { + eprintln!("Metal not available: {:?}", e); + return; + } + }; + println!("Device: {}\n", ctx.device_name()); + + let sizes = [256, 512, 1024, 2048]; + + println!("{:>6} {:>12} {:>12}", "Size", "GPU (ms)", "GFLOPS"); + println!("{:-<6} {:-<12} {:-<12}", "", "", ""); + + for &n in &sizes { + let a: Vec = (0..n * n).map(|i| (i % 100) as f32).collect(); + let b: Vec = (0..n * n).map(|i| ((i + 50) % 100) as f32).collect(); + + // Warmup + let _ = tropical_matmul_gpu::>(&a, n, n, &b, n); + + // Benchmark + let iterations = if n <= 512 { 10 } else { 3 }; + let start = Instant::now(); + for _ in 0..iterations { + let _ = tropical_matmul_gpu::>(&a, n, n, &b, n); + } + let elapsed = start.elapsed(); + let avg_ms = elapsed.as_secs_f64() * 1000.0 / iterations as f64; + + // Calculate GFLOPS (2 ops per element: add + max) + let ops = 2.0 * (n as f64).powi(3); + let gflops = ops / (avg_ms / 1000.0) / 1e9; + + println!("{:>6} {:>12.3} {:>12.1}", n, avg_ms, gflops); + } +} diff --git a/crates/tropical-gemm-metal/shaders/tropical_gemm.metal b/crates/tropical-gemm-metal/shaders/tropical_gemm.metal new file mode 100644 index 0000000..e84df3e --- /dev/null +++ b/crates/tropical-gemm-metal/shaders/tropical_gemm.metal @@ -0,0 +1,307 @@ +// Tropical GEMM Metal Shaders +// High-performance tropical matrix multiplication for Apple GPUs + +#include +using namespace metal; + +// Blocking parameters +constant int BLOCK_SIZE_M = 32; +constant int BLOCK_SIZE_N = 32; +constant int BLOCK_SIZE_K = 32; +constant int THREAD_SIZE_M = 4; +constant int THREAD_SIZE_N = 4; + +// Helper macro for column-major indexing +#define OFFSET_COL(row, col, ld) ((col) * (ld) + (row)) 
+ +// ============================================================================ +// TropicalMaxPlus: C[i,j] = max_k(A[i,k] + B[k,j]) +// ============================================================================ + +kernel void tropical_maxplus_f32( + device const float* A [[buffer(0)]], + device const float* B [[buffer(1)]], + device float* C [[buffer(2)]], + constant int& M [[buffer(3)]], + constant int& N [[buffer(4)]], + constant int& K [[buffer(5)]], + uint2 gid [[threadgroup_position_in_grid]], + uint2 tid [[thread_position_in_threadgroup]] +) { + const int threads_per_group_m = BLOCK_SIZE_M / THREAD_SIZE_M; + const int threads_per_group_n = BLOCK_SIZE_N / THREAD_SIZE_N; + + // Shared memory for tiles + threadgroup float As[BLOCK_SIZE_M * BLOCK_SIZE_K]; + threadgroup float Bs[BLOCK_SIZE_K * BLOCK_SIZE_N]; + + // Accumulators in registers + float accum[THREAD_SIZE_M * THREAD_SIZE_N]; + float regs_a[THREAD_SIZE_M]; + float regs_b[THREAD_SIZE_N]; + + // Initialize with -infinity (tropical zero for MaxPlus) + for (int i = 0; i < THREAD_SIZE_M * THREAD_SIZE_N; ++i) { + accum[i] = -INFINITY; + } + + const int linear_tid = tid.y * threads_per_group_m + tid.x; + const int THREAD_NUM = threads_per_group_m * threads_per_group_n; + + const int A_TILE_ROW = linear_tid % BLOCK_SIZE_M; + const int A_TILE_COL = linear_tid / BLOCK_SIZE_M; + const int B_TILE_ROW = linear_tid % BLOCK_SIZE_K; + const int B_TILE_COL = linear_tid / BLOCK_SIZE_K; + + const int A_STRIDE = THREAD_NUM / BLOCK_SIZE_M; + const int B_STRIDE = THREAD_NUM / BLOCK_SIZE_K; + + // Loop over K dimension tiles + for (int tile_idx = 0; tile_idx < K; tile_idx += BLOCK_SIZE_K) { + // Load A tile cooperatively + for (int i = 0; i < BLOCK_SIZE_K; i += A_STRIDE) { + int row = BLOCK_SIZE_M * gid.x + A_TILE_ROW; + int col = A_TILE_COL + i + tile_idx; + float val = -INFINITY; + if (row < M && col < K) { + val = A[OFFSET_COL(row, col, M)]; + } + As[OFFSET_COL(A_TILE_ROW, i + A_TILE_COL, BLOCK_SIZE_M)] = val; + } + + 
// Load B tile cooperatively + for (int i = 0; i < BLOCK_SIZE_N; i += B_STRIDE) { + int row = tile_idx + B_TILE_ROW; + int col = BLOCK_SIZE_N * gid.y + i + B_TILE_COL; + float val = -INFINITY; + if (row < K && col < N) { + val = B[OFFSET_COL(row, col, K)]; + } + Bs[OFFSET_COL(B_TILE_ROW, i + B_TILE_COL, BLOCK_SIZE_K)] = val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Compute tile contribution + for (int k = 0; k < BLOCK_SIZE_K; ++k) { + // Load A values into registers + for (int tm = 0; tm < THREAD_SIZE_M; ++tm) { + regs_a[tm] = As[OFFSET_COL(tid.x * THREAD_SIZE_M + tm, k, BLOCK_SIZE_M)]; + } + + // Load B values into registers + for (int tn = 0; tn < THREAD_SIZE_N; ++tn) { + regs_b[tn] = Bs[OFFSET_COL(k, tid.y * THREAD_SIZE_N + tn, BLOCK_SIZE_K)]; + } + + // Compute outer product with tropical operations + for (int tm = 0; tm < THREAD_SIZE_M; ++tm) { + for (int tn = 0; tn < THREAD_SIZE_N; ++tn) { + float prod = regs_a[tm] + regs_b[tn]; + int idx = OFFSET_COL(tm, tn, THREAD_SIZE_M); + accum[idx] = max(accum[idx], prod); + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Store results + for (int tm = 0; tm < THREAD_SIZE_M; ++tm) { + for (int tn = 0; tn < THREAD_SIZE_N; ++tn) { + int row = BLOCK_SIZE_M * gid.x + THREAD_SIZE_M * tid.x + tm; + int col = BLOCK_SIZE_N * gid.y + THREAD_SIZE_N * tid.y + tn; + if (row < M && col < N) { + C[OFFSET_COL(row, col, M)] = accum[OFFSET_COL(tm, tn, THREAD_SIZE_M)]; + } + } + } +} + +// ============================================================================ +// TropicalMinPlus: C[i,j] = min_k(A[i,k] + B[k,j]) +// ============================================================================ + +kernel void tropical_minplus_f32( + device const float* A [[buffer(0)]], + device const float* B [[buffer(1)]], + device float* C [[buffer(2)]], + constant int& M [[buffer(3)]], + constant int& N [[buffer(4)]], + constant int& K [[buffer(5)]], + uint2 gid [[threadgroup_position_in_grid]], + 
uint2 tid [[thread_position_in_threadgroup]] +) { + const int threads_per_group_m = BLOCK_SIZE_M / THREAD_SIZE_M; + const int threads_per_group_n = BLOCK_SIZE_N / THREAD_SIZE_N; + + threadgroup float As[BLOCK_SIZE_M * BLOCK_SIZE_K]; + threadgroup float Bs[BLOCK_SIZE_K * BLOCK_SIZE_N]; + + float accum[THREAD_SIZE_M * THREAD_SIZE_N]; + float regs_a[THREAD_SIZE_M]; + float regs_b[THREAD_SIZE_N]; + + // Initialize with +infinity (tropical zero for MinPlus) + for (int i = 0; i < THREAD_SIZE_M * THREAD_SIZE_N; ++i) { + accum[i] = INFINITY; + } + + const int linear_tid = tid.y * threads_per_group_m + tid.x; + const int THREAD_NUM = threads_per_group_m * threads_per_group_n; + + const int A_TILE_ROW = linear_tid % BLOCK_SIZE_M; + const int A_TILE_COL = linear_tid / BLOCK_SIZE_M; + const int B_TILE_ROW = linear_tid % BLOCK_SIZE_K; + const int B_TILE_COL = linear_tid / BLOCK_SIZE_K; + + const int A_STRIDE = THREAD_NUM / BLOCK_SIZE_M; + const int B_STRIDE = THREAD_NUM / BLOCK_SIZE_K; + + for (int tile_idx = 0; tile_idx < K; tile_idx += BLOCK_SIZE_K) { + for (int i = 0; i < BLOCK_SIZE_K; i += A_STRIDE) { + int row = BLOCK_SIZE_M * gid.x + A_TILE_ROW; + int col = A_TILE_COL + i + tile_idx; + float val = INFINITY; + if (row < M && col < K) { + val = A[OFFSET_COL(row, col, M)]; + } + As[OFFSET_COL(A_TILE_ROW, i + A_TILE_COL, BLOCK_SIZE_M)] = val; + } + + for (int i = 0; i < BLOCK_SIZE_N; i += B_STRIDE) { + int row = tile_idx + B_TILE_ROW; + int col = BLOCK_SIZE_N * gid.y + i + B_TILE_COL; + float val = INFINITY; + if (row < K && col < N) { + val = B[OFFSET_COL(row, col, K)]; + } + Bs[OFFSET_COL(B_TILE_ROW, i + B_TILE_COL, BLOCK_SIZE_K)] = val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int k = 0; k < BLOCK_SIZE_K; ++k) { + for (int tm = 0; tm < THREAD_SIZE_M; ++tm) { + regs_a[tm] = As[OFFSET_COL(tid.x * THREAD_SIZE_M + tm, k, BLOCK_SIZE_M)]; + } + for (int tn = 0; tn < THREAD_SIZE_N; ++tn) { + regs_b[tn] = Bs[OFFSET_COL(k, tid.y * THREAD_SIZE_N + tn, 
BLOCK_SIZE_K)]; + } + for (int tm = 0; tm < THREAD_SIZE_M; ++tm) { + for (int tn = 0; tn < THREAD_SIZE_N; ++tn) { + float prod = regs_a[tm] + regs_b[tn]; + int idx = OFFSET_COL(tm, tn, THREAD_SIZE_M); + accum[idx] = min(accum[idx], prod); + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + for (int tm = 0; tm < THREAD_SIZE_M; ++tm) { + for (int tn = 0; tn < THREAD_SIZE_N; ++tn) { + int row = BLOCK_SIZE_M * gid.x + THREAD_SIZE_M * tid.x + tm; + int col = BLOCK_SIZE_N * gid.y + THREAD_SIZE_N * tid.y + tn; + if (row < M && col < N) { + C[OFFSET_COL(row, col, M)] = accum[OFFSET_COL(tm, tn, THREAD_SIZE_M)]; + } + } + } +} + +// ============================================================================ +// TropicalMaxMul: C[i,j] = max_k(A[i,k] * B[k,j]) +// ============================================================================ + +kernel void tropical_maxmul_f32( + device const float* A [[buffer(0)]], + device const float* B [[buffer(1)]], + device float* C [[buffer(2)]], + constant int& M [[buffer(3)]], + constant int& N [[buffer(4)]], + constant int& K [[buffer(5)]], + uint2 gid [[threadgroup_position_in_grid]], + uint2 tid [[thread_position_in_threadgroup]] +) { + const int threads_per_group_m = BLOCK_SIZE_M / THREAD_SIZE_M; + const int threads_per_group_n = BLOCK_SIZE_N / THREAD_SIZE_N; + + threadgroup float As[BLOCK_SIZE_M * BLOCK_SIZE_K]; + threadgroup float Bs[BLOCK_SIZE_K * BLOCK_SIZE_N]; + + float accum[THREAD_SIZE_M * THREAD_SIZE_N]; + float regs_a[THREAD_SIZE_M]; + float regs_b[THREAD_SIZE_N]; + + // Initialize with 0 (tropical zero for MaxMul) + for (int i = 0; i < THREAD_SIZE_M * THREAD_SIZE_N; ++i) { + accum[i] = 0.0f; + } + + const int linear_tid = tid.y * threads_per_group_m + tid.x; + const int THREAD_NUM = threads_per_group_m * threads_per_group_n; + + const int A_TILE_ROW = linear_tid % BLOCK_SIZE_M; + const int A_TILE_COL = linear_tid / BLOCK_SIZE_M; + const int B_TILE_ROW = linear_tid % BLOCK_SIZE_K; + const int B_TILE_COL 
= linear_tid / BLOCK_SIZE_K; + + const int A_STRIDE = THREAD_NUM / BLOCK_SIZE_M; + const int B_STRIDE = THREAD_NUM / BLOCK_SIZE_K; + + for (int tile_idx = 0; tile_idx < K; tile_idx += BLOCK_SIZE_K) { + for (int i = 0; i < BLOCK_SIZE_K; i += A_STRIDE) { + int row = BLOCK_SIZE_M * gid.x + A_TILE_ROW; + int col = A_TILE_COL + i + tile_idx; + float val = 0.0f; + if (row < M && col < K) { + val = A[OFFSET_COL(row, col, M)]; + } + As[OFFSET_COL(A_TILE_ROW, i + A_TILE_COL, BLOCK_SIZE_M)] = val; + } + + for (int i = 0; i < BLOCK_SIZE_N; i += B_STRIDE) { + int row = tile_idx + B_TILE_ROW; + int col = BLOCK_SIZE_N * gid.y + i + B_TILE_COL; + float val = 0.0f; + if (row < K && col < N) { + val = B[OFFSET_COL(row, col, K)]; + } + Bs[OFFSET_COL(B_TILE_ROW, i + B_TILE_COL, BLOCK_SIZE_K)] = val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int k = 0; k < BLOCK_SIZE_K; ++k) { + for (int tm = 0; tm < THREAD_SIZE_M; ++tm) { + regs_a[tm] = As[OFFSET_COL(tid.x * THREAD_SIZE_M + tm, k, BLOCK_SIZE_M)]; + } + for (int tn = 0; tn < THREAD_SIZE_N; ++tn) { + regs_b[tn] = Bs[OFFSET_COL(k, tid.y * THREAD_SIZE_N + tn, BLOCK_SIZE_K)]; + } + for (int tm = 0; tm < THREAD_SIZE_M; ++tm) { + for (int tn = 0; tn < THREAD_SIZE_N; ++tn) { + // MaxMul: tropical_mul = *, tropical_add = max + float prod = regs_a[tm] * regs_b[tn]; + int idx = OFFSET_COL(tm, tn, THREAD_SIZE_M); + accum[idx] = max(accum[idx], prod); + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + for (int tm = 0; tm < THREAD_SIZE_M; ++tm) { + for (int tn = 0; tn < THREAD_SIZE_N; ++tn) { + int row = BLOCK_SIZE_M * gid.x + THREAD_SIZE_M * tid.x + tm; + int col = BLOCK_SIZE_N * gid.y + THREAD_SIZE_N * tid.y + tn; + if (row < M && col < N) { + C[OFFSET_COL(row, col, M)] = accum[OFFSET_COL(tm, tn, THREAD_SIZE_M)]; + } + } + } +} diff --git a/crates/tropical-gemm-metal/src/context.rs b/crates/tropical-gemm-metal/src/context.rs new file mode 100644 index 0000000..901473c --- /dev/null +++ 
b/crates/tropical-gemm-metal/src/context.rs @@ -0,0 +1,110 @@ +//! Metal context and device management. + +use crate::error::{MetalError, Result}; +use metal::{Device, ComputePipelineState, CommandQueue}; +use std::collections::HashMap; + +/// Metal shader source code. +const SHADER_SOURCE: &str = include_str!("../shaders/tropical_gemm.metal"); + +/// Blocking parameters for f32 kernels. +pub const BLOCK_SIZE_M_F32: u32 = 32; +pub const BLOCK_SIZE_N_F32: u32 = 32; +pub const THREAD_SIZE_M: u32 = 4; +pub const THREAD_SIZE_N: u32 = 4; + +/// Kernel function names. +const KERNEL_NAMES: &[&str] = &[ + "tropical_maxplus_f32", + "tropical_minplus_f32", + "tropical_maxmul_f32", +]; + +/// Metal context for tropical GEMM operations. +/// +/// Manages device selection, shader compilation, and pipeline caching. +pub struct MetalContext { + device: Device, + command_queue: CommandQueue, + pipelines: HashMap<&'static str, ComputePipelineState>, +} + +impl MetalContext { + /// Create a new Metal context on the default device. + pub fn new() -> Result { + let device = Device::system_default().ok_or(MetalError::NoDevice)?; + Self::from_device(device) + } + + /// Create a context from an existing device. 
+ pub fn from_device(device: Device) -> Result { + let command_queue = device + .new_command_queue(); + + // Compile shaders + let options = metal::CompileOptions::new(); + let library = device + .new_library_with_source(SHADER_SOURCE, &options) + .map_err(|e| MetalError::ShaderCompilation(e.to_string()))?; + + // Create compute pipelines for each kernel + let mut pipelines = HashMap::new(); + for name in KERNEL_NAMES { + let function = library + .get_function(name, None) + .map_err(|_| MetalError::KernelNotFound(name.to_string()))?; + + let pipeline = device + .new_compute_pipeline_state_with_function(&function) + .map_err(|e| MetalError::PipelineCreation(e.to_string()))?; + + pipelines.insert(*name, pipeline); + } + + Ok(Self { + device, + command_queue, + pipelines, + }) + } + + /// Get the underlying Metal device. + pub fn device(&self) -> &Device { + &self.device + } + + /// Get the command queue. + pub fn command_queue(&self) -> &CommandQueue { + &self.command_queue + } + + /// Get a compute pipeline by kernel name. + pub fn get_pipeline(&self, name: &'static str) -> Result<&ComputePipelineState> { + self.pipelines + .get(name) + .ok_or_else(|| MetalError::KernelNotFound(name.to_string())) + } + + /// Get GPU device name. + pub fn device_name(&self) -> String { + self.device.name().to_string() + } + + /// Calculate threadgroup size for a kernel. + pub fn threadgroup_size(&self) -> metal::MTLSize { + let threads_per_group_m = BLOCK_SIZE_M_F32 / THREAD_SIZE_M; + let threads_per_group_n = BLOCK_SIZE_N_F32 / THREAD_SIZE_N; + metal::MTLSize::new( + threads_per_group_m as u64, + threads_per_group_n as u64, + 1, + ) + } + + /// Calculate grid size for given matrix dimensions. 
+ pub fn grid_size(&self, m: usize, n: usize) -> metal::MTLSize { + let grid_x = (m as u64 + BLOCK_SIZE_M_F32 as u64 - 1) / BLOCK_SIZE_M_F32 as u64; + let grid_y = (n as u64 + BLOCK_SIZE_N_F32 as u64 - 1) / BLOCK_SIZE_N_F32 as u64; + metal::MTLSize::new(grid_x, grid_y, 1) + } +} diff --git a/crates/tropical-gemm-metal/src/error.rs b/crates/tropical-gemm-metal/src/error.rs new file mode 100644 index 0000000..a10b55e --- /dev/null +++ b/crates/tropical-gemm-metal/src/error.rs @@ -0,0 +1,42 @@ +//! Error types for Metal operations. + +use thiserror::Error; + +/// Errors that can occur during Metal operations. +#[derive(Debug, Error)] +pub enum MetalError { + /// No Metal device available. + #[error("No Metal device available")] + NoDevice, + + /// Failed to create command queue. + #[error("Failed to create command queue")] + CommandQueueCreation, + + /// Failed to create compute pipeline. + #[error("Failed to create compute pipeline: {0}")] + PipelineCreation(String), + + /// Failed to create buffer. + #[error("Failed to create buffer")] + BufferCreation, + + /// Dimension mismatch. + #[error("Dimension mismatch: {0}")] + DimensionMismatch(String), + + /// Kernel not found. + #[error("Kernel not found: {0}")] + KernelNotFound(String), + + /// Shader compilation error. + #[error("Shader compilation error: {0}")] + ShaderCompilation(String), + + /// Command buffer error. + #[error("Command buffer error: {0}")] + CommandBuffer(String), +} + +/// Result type for Metal operations. +pub type Result = std::result::Result; diff --git a/crates/tropical-gemm-metal/src/kernels.rs b/crates/tropical-gemm-metal/src/kernels.rs new file mode 100644 index 0000000..870245f --- /dev/null +++ b/crates/tropical-gemm-metal/src/kernels.rs @@ -0,0 +1,98 @@ +//! Metal kernel trait and implementations. 
+ +use crate::context::MetalContext; +use crate::error::Result; +use crate::memory::GpuMatrix; +use tropical_types::{TropicalMaxMul, TropicalMaxPlus, TropicalMinPlus, TropicalSemiring}; + +/// Trait for types that can be computed on Metal GPU. +pub trait MetalKernel: TropicalSemiring +where + Self::Scalar: Clone + Default + Copy, +{ + /// Kernel function name. + const KERNEL_NAME: &'static str; + + /// Execute the tropical GEMM kernel. + fn launch_gemm( + ctx: &MetalContext, + a: &GpuMatrix, + b: &GpuMatrix, + c: &mut GpuMatrix, + ) -> Result<()>; +} + +/// Helper function to launch a Metal compute kernel. +fn launch_kernel_impl( + ctx: &MetalContext, + kernel_name: &'static str, + a: &GpuMatrix, + b: &GpuMatrix, + c: &mut GpuMatrix, +) -> Result<()> { + let m = a.rows() as i32; + let k = a.cols() as i32; + let n = b.cols() as i32; + + let pipeline = ctx.get_pipeline(kernel_name)?; + let command_buffer = ctx.command_queue().new_command_buffer(); + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.set_compute_pipeline_state(pipeline); + encoder.set_buffer(0, Some(a.buffer()), 0); + encoder.set_buffer(1, Some(b.buffer()), 0); + encoder.set_buffer(2, Some(c.buffer()), 0); + encoder.set_bytes(3, std::mem::size_of::() as u64, &m as *const i32 as *const _); + encoder.set_bytes(4, std::mem::size_of::() as u64, &n as *const i32 as *const _); + encoder.set_bytes(5, std::mem::size_of::() as u64, &k as *const i32 as *const _); + + let grid_size = ctx.grid_size(a.rows(), b.cols()); + let threadgroup_size = ctx.threadgroup_size(); + + encoder.dispatch_thread_groups(grid_size, threadgroup_size); + encoder.end_encoding(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + Ok(()) +} + +impl MetalKernel for TropicalMaxPlus { + const KERNEL_NAME: &'static str = "tropical_maxplus_f32"; + + fn launch_gemm( + ctx: &MetalContext, + a: &GpuMatrix, + b: &GpuMatrix, + c: &mut GpuMatrix, + ) -> Result<()> { + launch_kernel_impl(ctx, 
Self::KERNEL_NAME, a, b, c) + } +} + +impl MetalKernel for TropicalMinPlus { + const KERNEL_NAME: &'static str = "tropical_minplus_f32"; + + fn launch_gemm( + ctx: &MetalContext, + a: &GpuMatrix, + b: &GpuMatrix, + c: &mut GpuMatrix, + ) -> Result<()> { + launch_kernel_impl(ctx, Self::KERNEL_NAME, a, b, c) + } +} + +impl MetalKernel for TropicalMaxMul { + const KERNEL_NAME: &'static str = "tropical_maxmul_f32"; + + fn launch_gemm( + ctx: &MetalContext, + a: &GpuMatrix, + b: &GpuMatrix, + c: &mut GpuMatrix, + ) -> Result<()> { + launch_kernel_impl(ctx, Self::KERNEL_NAME, a, b, c) + } +} diff --git a/crates/tropical-gemm-metal/src/lib.rs b/crates/tropical-gemm-metal/src/lib.rs new file mode 100644 index 0000000..eb21c7b --- /dev/null +++ b/crates/tropical-gemm-metal/src/lib.rs @@ -0,0 +1,210 @@ +//! Metal backend for tropical matrix multiplication. +//! +//! This crate provides GPU-accelerated tropical GEMM operations using Metal on Apple GPUs. +//! +//! # Quick Start +//! +//! ```ignore +//! use tropical_gemm_metal::{tropical_matmul_gpu, MetalContext}; +//! use tropical_types::TropicalMaxPlus; +//! +//! // Simple one-shot API +//! let a = vec![1.0f32; 1024 * 1024]; +//! let b = vec![1.0f32; 1024 * 1024]; +//! let c = tropical_matmul_gpu::>(&a, 1024, 1024, &b, 1024)?; +//! ``` +//! +//! # Persistent Context +//! +//! For multiple operations, reuse the Metal context: +//! +//! ```ignore +//! use tropical_gemm_metal::{MetalContext, GpuMatrix, tropical_gemm_gpu}; +//! use tropical_types::TropicalMaxPlus; +//! +//! let ctx = MetalContext::new()?; +//! +//! let a_gpu = GpuMatrix::from_host_row_major(&ctx, &a, m, k)?; +//! let b_gpu = GpuMatrix::from_host_row_major(&ctx, &b, k, n)?; +//! let mut c_gpu = GpuMatrix::alloc(&ctx, m, n)?; +//! +//! tropical_gemm_gpu::>(&ctx, &a_gpu, &b_gpu, &mut c_gpu)?; +//! +//! let c = c_gpu.to_host_row_major(); +//! 
+
+mod context;
+mod error;
+mod kernels;
+mod memory;
+
+pub use context::MetalContext;
+pub use error::{MetalError, Result};
+pub use kernels::MetalKernel;
+pub use memory::GpuMatrix;
+
+/// One-shot tropical matrix multiplication on GPU.
+///
+/// This function handles all GPU memory management automatically.
+/// For repeated operations, use `tropical_gemm_gpu` with a persistent context.
+///
+/// # Arguments
+///
+/// * `a` - Matrix A in row-major order, dimensions m×k
+/// * `m` - Number of rows in A
+/// * `k` - Number of columns in A / rows in B
+/// * `b` - Matrix B in row-major order, dimensions k×n
+/// * `n` - Number of columns in B
+///
+/// # Returns
+///
+/// Result matrix C in row-major order, dimensions m×n
+///
+/// # Example
+///
+/// ```ignore
+/// use tropical_gemm_metal::tropical_matmul_gpu;
+/// use tropical_types::TropicalMaxPlus;
+///
+/// let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
+/// let b = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2
+///
+/// let c = tropical_matmul_gpu::<TropicalMaxPlus<f32>>(&a, 2, 3, &b, 2)?;
+/// // c is 2x2, row-major
+/// ```
+pub fn tropical_matmul_gpu<T>(
+    a: &[T::Scalar],
+    m: usize,
+    k: usize,
+    b: &[T::Scalar],
+    n: usize,
+) -> Result<Vec<T::Scalar>>
+where
+    T: MetalKernel,
+    T::Scalar: Clone + Default + Copy,
+{
+    if a.len() != m * k {
+        return Err(MetalError::DimensionMismatch(format!(
+            "A: expected {} elements, got {}",
+            m * k,
+            a.len()
+        )));
+    }
+    if b.len() != k * n {
+        return Err(MetalError::DimensionMismatch(format!(
+            "B: expected {} elements, got {}",
+            k * n,
+            b.len()
+        )));
+    }
+
+    let ctx = MetalContext::new()?;
+
+    let a_gpu = GpuMatrix::from_host_row_major(&ctx, a, m, k)?;
+    let b_gpu = GpuMatrix::from_host_row_major(&ctx, b, k, n)?;
+    let mut c_gpu = GpuMatrix::alloc(&ctx, m, n)?;
+
+    T::launch_gemm(&ctx, &a_gpu, &b_gpu, &mut c_gpu)?;
+
+    Ok(c_gpu.to_host_row_major())
+}
+
+/// Tropical matrix multiplication with persistent context.
+///
+/// Use this function when performing multiple GPU operations to avoid
+/// repeated context initialization and kernel compilation.
+///
+/// # Arguments
+///
+/// * `ctx` - Metal context
+/// * `a` - Matrix A on GPU
+/// * `b` - Matrix B on GPU
+/// * `c` - Output matrix C on GPU (will be overwritten)
+pub fn tropical_gemm_gpu<T>(
+    ctx: &MetalContext,
+    a: &GpuMatrix<T::Scalar>,
+    b: &GpuMatrix<T::Scalar>,
+    c: &mut GpuMatrix<T::Scalar>,
+) -> Result<()>
+where
+    T: MetalKernel,
+    T::Scalar: Clone + Default + Copy,
+{
+    if a.cols() != b.rows() {
+        return Err(MetalError::DimensionMismatch(format!(
+            "A.cols ({}) != B.rows ({})",
+            a.cols(),
+            b.rows()
+        )));
+    }
+    if c.rows() != a.rows() || c.cols() != b.cols() {
+        return Err(MetalError::DimensionMismatch(format!(
+            "C dimensions ({}, {}) don't match A×B ({}, {})",
+            c.rows(),
+            c.cols(),
+            a.rows(),
+            b.cols()
+        )));
+    }
+
+    T::launch_gemm(ctx, a, b, c)
+}
+
+/// Tropical matrix multiplication with context, returning a new GPU matrix.
+///
+/// Allocates the output matrix automatically.
+pub fn tropical_matmul_gpu_with_ctx<T>(
+    ctx: &MetalContext,
+    a: &GpuMatrix<T::Scalar>,
+    b: &GpuMatrix<T::Scalar>,
+) -> Result<GpuMatrix<T::Scalar>>
+where
+    T: MetalKernel,
+    T::Scalar: Clone + Default + Copy,
+{
+    if a.cols() != b.rows() {
+        return Err(MetalError::DimensionMismatch(format!(
+            "A.cols ({}) != B.rows ({})",
+            a.cols(),
+            b.rows()
+        )));
+    }
+
+    let mut c = GpuMatrix::alloc(ctx, a.rows(), b.cols())?;
+    T::launch_gemm(ctx, a, b, &mut c)?;
+    Ok(c)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use tropical_types::TropicalMaxPlus;
+
+    #[test]
+    fn test_tropical_matmul_gpu_small() {
+        // Skip if Metal not available
+        let _ctx = match MetalContext::new() {
+            Ok(c) => c,
+            Err(_) => {
+                println!("Metal not available, skipping test");
+                return;
+            }
+        };
+
+        // 2x3 matrix A
+        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+        // 3x2 matrix B
+        let b = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+
+        let c = tropical_matmul_gpu::<TropicalMaxPlus<f32>>(&a, 2, 3, &b, 2).unwrap();
+
+        // C[0,0] = max(1+1, 2+3, 3+5) = 8
+        assert!((c[0] - 8.0).abs() < 1e-5, "C[0,0] = {}, expected 8", c[0]);
+        // C[0,1] = max(1+2, 2+4, 3+6) = 9
+        assert!((c[1] - 9.0).abs() < 1e-5, "C[0,1] = {}, expected 9", c[1]);
+        // C[1,0] = max(4+1, 5+3, 6+5) = 11
+        assert!((c[2] - 11.0).abs() < 1e-5, "C[1,0] = {}, expected 11", c[2]);
+        // C[1,1] = max(4+2, 5+4, 6+6) = 12
+        assert!((c[3] - 12.0).abs() < 1e-5, "C[1,1] = {}, expected 12", c[3]);
+    }
+}
diff --git a/crates/tropical-gemm-metal/src/memory.rs b/crates/tropical-gemm-metal/src/memory.rs
new file mode 100644
index 0000000..63ddcf7
--- /dev/null
+++ b/crates/tropical-gemm-metal/src/memory.rs
@@ -0,0 +1,145 @@
+//! GPU memory management for matrices.
+
+use crate::context::MetalContext;
+use crate::error::{MetalError, Result};
+use metal::{Buffer, MTLResourceOptions};
+use std::marker::PhantomData;
+
+/// A matrix stored in GPU memory.
+///
+/// Data is stored in column-major order for compatibility with BLAS conventions.
+pub struct GpuMatrix<T> {
+    buffer: Buffer,
+    rows: usize,
+    cols: usize,
+    _marker: PhantomData<T>,
+}
+
+impl<T: Copy + Default> GpuMatrix<T> {
+    /// Create a GPU matrix from host data in row-major order.
+    ///
+    /// The input data will be transposed to column-major for GPU storage.
+    pub fn from_host_row_major(
+        ctx: &MetalContext,
+        data: &[T],
+        rows: usize,
+        cols: usize,
+    ) -> Result<Self> {
+        if data.len() != rows * cols {
+            return Err(MetalError::DimensionMismatch(format!(
+                "Expected {} elements, got {}",
+                rows * cols,
+                data.len()
+            )));
+        }
+
+        // Transpose to column-major
+        let mut col_major = vec![T::default(); rows * cols];
+        for i in 0..rows {
+            for j in 0..cols {
+                col_major[j * rows + i] = data[i * cols + j];
+            }
+        }
+
+        let byte_size = (rows * cols * std::mem::size_of::<T>()) as u64;
+        let buffer = ctx.device().new_buffer_with_data(
+            col_major.as_ptr() as *const _,
+            byte_size,
+            MTLResourceOptions::StorageModeShared,
+        );
+
+        Ok(Self {
+            buffer,
+            rows,
+            cols,
+            _marker: PhantomData,
+        })
+    }
+
+    /// Create a GPU matrix from column-major host data (no transpose).
+    pub fn from_host_col_major(
+        ctx: &MetalContext,
+        data: &[T],
+        rows: usize,
+        cols: usize,
+    ) -> Result<Self> {
+        if data.len() != rows * cols {
+            return Err(MetalError::DimensionMismatch(format!(
+                "Expected {} elements, got {}",
+                rows * cols,
+                data.len()
+            )));
+        }
+
+        let byte_size = (rows * cols * std::mem::size_of::<T>()) as u64;
+        let buffer = ctx.device().new_buffer_with_data(
+            data.as_ptr() as *const _,
+            byte_size,
+            MTLResourceOptions::StorageModeShared,
+        );
+
+        Ok(Self {
+            buffer,
+            rows,
+            cols,
+            _marker: PhantomData,
+        })
+    }
+
+    /// Allocate a zeroed GPU matrix.
+    pub fn alloc(ctx: &MetalContext, rows: usize, cols: usize) -> Result<Self> {
+        let byte_size = (rows * cols * std::mem::size_of::<T>()) as u64;
+        let buffer = ctx.device().new_buffer(
+            byte_size,
+            MTLResourceOptions::StorageModeShared,
+        );
+
+        Ok(Self {
+            buffer,
+            rows,
+            cols,
+            _marker: PhantomData,
+        })
+    }
+
+    /// Copy GPU data back to host in row-major order.
+    pub fn to_host_row_major(&self) -> Vec<T> {
+        let ptr = self.buffer.contents() as *const T;
+        let col_major: Vec<T> = unsafe {
+            std::slice::from_raw_parts(ptr, self.rows * self.cols).to_vec()
+        };
+
+        // Transpose from column-major to row-major
+        let mut row_major = vec![T::default(); self.rows * self.cols];
+        for i in 0..self.rows {
+            for j in 0..self.cols {
+                row_major[i * self.cols + j] = col_major[j * self.rows + i];
+            }
+        }
+
+        row_major
+    }
+
+    /// Copy GPU data back to host in column-major order.
+    pub fn to_host_col_major(&self) -> Vec<T> {
+        let ptr = self.buffer.contents() as *const T;
+        unsafe {
+            std::slice::from_raw_parts(ptr, self.rows * self.cols).to_vec()
+        }
+    }
+
+    /// Get the number of rows.
+    pub fn rows(&self) -> usize {
+        self.rows
+    }
+
+    /// Get the number of columns.
+    pub fn cols(&self) -> usize {
+        self.cols
+    }
+
+    /// Get the underlying Metal buffer.
+    pub fn buffer(&self) -> &Buffer {
+        &self.buffer
+    }
+}
diff --git a/tarpaulin.toml b/tarpaulin.toml
index c36cb2d..2d745f2 100644
--- a/tarpaulin.toml
+++ b/tarpaulin.toml
@@ -11,8 +11,8 @@ engine = "Llvm"
 # Run for whole workspace
 workspace = true
 
-# Exclude CUDA crate from testing (requires GPU)
-exclude = ["tropical-gemm-cuda"]
+# Exclude GPU crates from testing (requires specific hardware)
+exclude = ["tropical-gemm-cuda", "tropical-gemm-metal"]
 
 # Note: exclude_files doesn't work reliably with LLVM engine in config file.
 # Use CLI flag --exclude-files "crates/tropical-gemm-cuda/*" for accurate totals.