From ef265c4f266d203bb00e3c19ca9d9a86d1f0a31e Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Thu, 15 Jan 2026 12:07:53 +0800 Subject: [PATCH] feat: add Metal backend for Apple GPU support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a new `tropical-gemm-metal` crate that provides GPU-accelerated tropical matrix multiplication using Metal on macOS. This addresses issue #1 by enabling GPU acceleration on Apple Silicon. Features: - Support for MaxPlus, MinPlus, and MaxMul semirings (f32) - Argmax tracking for backpropagation - Backward pass kernels for gradient computation - Global context caching for efficient repeated operations - Row-major to column-major layout conversion The implementation follows the same architecture as the CUDA backend: - MetalContext: Shader compilation and kernel management - GpuMatrix: GPU memory management with layout conversion - MetalKernel trait: Type-safe kernel dispatch Closes #1 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- Cargo.lock | 127 +++- Cargo.toml | 2 + crates/tropical-gemm-metal/Cargo.toml | 21 + .../shaders/tropical_gemm.metal | 696 ++++++++++++++++++ crates/tropical-gemm-metal/src/context.rs | 124 ++++ crates/tropical-gemm-metal/src/error.rs | 42 ++ crates/tropical-gemm-metal/src/gpu_mat.rs | 225 ++++++ crates/tropical-gemm-metal/src/kernels.rs | 268 +++++++ crates/tropical-gemm-metal/src/lib.rs | 396 ++++++++++ 9 files changed, 1899 insertions(+), 2 deletions(-) create mode 100644 crates/tropical-gemm-metal/Cargo.toml create mode 100644 crates/tropical-gemm-metal/shaders/tropical_gemm.metal create mode 100644 crates/tropical-gemm-metal/src/context.rs create mode 100644 crates/tropical-gemm-metal/src/error.rs create mode 100644 crates/tropical-gemm-metal/src/gpu_mat.rs create mode 100644 crates/tropical-gemm-metal/src/kernels.rs create mode 100644 crates/tropical-gemm-metal/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index fbc35a7..d2e042b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -44,12 +44,24 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + [[package]] name = "bitflags" version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" +[[package]] +name = "block" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" + [[package]] name = "bumpalo" version = "3.19.1" @@ -126,6 +138,33 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = 
"core-graphics-types" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "libc", +] + [[package]] name = "criterion" version = "0.5.1" @@ -230,6 +269,33 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foreign-types" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" +dependencies = [ + "foreign-types-macros", + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-macros" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "foreign-types-shared" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" + [[package]] name = "getrandom" version = "0.3.4" @@ -332,6 +398,21 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "malloc_buf" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62bb907fe88d54d8d9ce32a3cceab4218ed2f6b7d35617cafe9adf84e43919cb" +dependencies = [ + "libc", +] + [[package]] name = "matrixmultiply" version = "0.3.10" @@ -357,6 +438,21 @@ dependencies = [ "autocfg", ] +[[package]] +name = "metal" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21" +dependencies = [ + "bitflags 2.10.0", + "block", + "core-graphics-types", + "foreign-types", + "log", + "objc", + "paste", +] + [[package]] name = "ndarray" version = "0.16.1" @@ -414,6 +510,15 @@ dependencies = [ "rustc-hash", ] +[[package]] +name = "objc" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1" +dependencies = [ + "malloc_buf", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -426,6 +531,12 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "plotters" version = "0.3.7" @@ -495,7 +606,7 @@ checksum = "bee689443a2bd0a16ab0348b52ee43e3b2d1b1f931c8aa5c9f8de4c86fbe8c40" dependencies = [ "bit-set", "bit-vec", - "bitflags", + "bitflags 2.10.0", "num-traits", "rand", "rand_chacha", @@ -695,7 +806,7 @@ version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" dependencies = [ - 
"bitflags", + "bitflags 2.10.0", "errno", "libc", "linux-raw-sys", @@ -863,6 +974,18 @@ dependencies = [ "tropical-gemm", ] +[[package]] +name = "tropical-gemm-metal" +version = "0.1.0" +dependencies = [ + "block", + "metal", + "objc", + "once_cell", + "thiserror", + "tropical-gemm", +] + [[package]] name = "tropical-gemm-python" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 549163e..e7f7609 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ resolver = "2" members = [ "crates/tropical-gemm-cuda", + "crates/tropical-gemm-metal", "crates/tropical-gemm", "crates/tropical-gemm-python", ] @@ -16,6 +17,7 @@ repository = "https://github.com/TensorBFS/tropical-gemm" [workspace.dependencies] # Internal crates tropical-gemm-cuda = { version = "0.1.0", path = "crates/tropical-gemm-cuda" } +tropical-gemm-metal = { version = "0.1.0", path = "crates/tropical-gemm-metal" } tropical-gemm = { version = "0.1.0", path = "crates/tropical-gemm" } # External dependencies diff --git a/crates/tropical-gemm-metal/Cargo.toml b/crates/tropical-gemm-metal/Cargo.toml new file mode 100644 index 0000000..9d8cac3 --- /dev/null +++ b/crates/tropical-gemm-metal/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "tropical-gemm-metal" +version.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true +repository.workspace = true +description = "Metal backend for tropical matrix multiplication on Apple Silicon" +keywords = ["tropical", "metal", "gpu", "matrix", "gemm", "apple"] +categories = ["algorithms", "mathematics", "science"] +readme = "../../README.md" + +[dependencies] +tropical-gemm = { workspace = true } +metal = "0.29" +objc = "0.2" +thiserror = { workspace = true } +once_cell = "1.19" +block = "0.1" + +[dev-dependencies] diff --git a/crates/tropical-gemm-metal/shaders/tropical_gemm.metal b/crates/tropical-gemm-metal/shaders/tropical_gemm.metal new file mode 100644 index 0000000..c8499b8 --- /dev/null +++ b/crates/tropical-gemm-metal/shaders/tropical_gemm.metal @@ -0,0 +1,696 @@ +// Tropical GEMM Metal Shaders +// DRY implementation using Metal shader functions +// Adapted from CUDA implementation for Apple Silicon + +#include +using namespace metal; + +// ============================================================================ +// CONSTANTS AND UTILITIES +// ============================================================================ + +// Memory layout helpers (column-major) +#define OFFSET_COL(row, col, ld) ((col) * (ld) + (row)) + +// Integer "infinity" constants (sentinel values for tropical zero) +constant int INF_I32 = 46340; +constant int NEG_INF_I32 = -46340; + +// Saturating addition for MaxPlus (propagates -infinity) +inline int saturating_add_maxplus_i32(int a, int b) { + if (a == NEG_INF_I32 || b == NEG_INF_I32) return NEG_INF_I32; + return a + b; +} + +// Saturating addition for MinPlus (propagates +infinity) +inline int saturating_add_minplus_i32(int a, int b) { + if (a == INF_I32 || b == INF_I32) return INF_I32; + return a + b; +} + +// ============================================================================ +// F32 MAXPLUS GEMM KERNEL +// ============================================================================ +// Block sizes for f32: 64x32x64, Thread sizes: 4x4 + +kernel void tropical_maxplus_f32_nn( + device const float* A [[buffer(0)]], + device const float* B [[buffer(1)]], + device float* C [[buffer(2)]], + constant uint& M [[buffer(3)]], + constant uint& N [[buffer(4)]], + constant uint& K [[buffer(5)]], + uint2 tid 
[[thread_position_in_threadgroup]], + uint2 gid [[threadgroup_position_in_grid]] +) { + const uint BLOCK_SIZE_M = 64; + const uint BLOCK_SIZE_K = 32; + const uint BLOCK_SIZE_N = 64; + const uint THREAD_SIZE_M = 4; + const uint THREAD_SIZE_N = 4; + + const uint bszm = BLOCK_SIZE_M / THREAD_SIZE_M; + const uint bszn = BLOCK_SIZE_N / THREAD_SIZE_N; + const uint THREAD_NUM_PER_BLOCK = bszm * bszn; + + uint BLOCK_IDX = gid.x; + uint BLOCK_IDY = gid.y; + + const uint thread_id = tid.y * bszm + tid.x; + + threadgroup float As[BLOCK_SIZE_M * BLOCK_SIZE_K]; + threadgroup float Bs[BLOCK_SIZE_K * BLOCK_SIZE_N]; + + float accum[THREAD_SIZE_M * THREAD_SIZE_N]; + float regs_a[THREAD_SIZE_M]; + float regs_b[THREAD_SIZE_N]; + + for (uint i = 0; i < THREAD_SIZE_M * THREAD_SIZE_N; ++i) { + accum[i] = -INFINITY; + } + + const uint A_TILE_COL = thread_id / BLOCK_SIZE_M; + const uint A_TILE_ROW = thread_id % BLOCK_SIZE_M; + const uint B_TILE_COL = thread_id / BLOCK_SIZE_K; + const uint B_TILE_ROW = thread_id % BLOCK_SIZE_K; + const uint A_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_M; + const uint B_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_K; + + for (uint tile_idx = 0; tile_idx < K; tile_idx += BLOCK_SIZE_K) { + for (uint i = 0; i < BLOCK_SIZE_K; i += A_TILE_COL_STRIDE) { + uint row = BLOCK_SIZE_M * BLOCK_IDX + A_TILE_ROW; + uint col = A_TILE_COL + i + tile_idx; + float val = -INFINITY; + if (row < M && col < K) { + val = A[OFFSET_COL(row, col, M)]; + } + As[OFFSET_COL(A_TILE_ROW, i + A_TILE_COL, BLOCK_SIZE_M)] = val; + } + + for (uint i = 0; i < BLOCK_SIZE_N; i += B_TILE_COL_STRIDE) { + uint row = tile_idx + B_TILE_ROW; + uint col = BLOCK_SIZE_N * BLOCK_IDY + i + B_TILE_COL; + float val = -INFINITY; + if (row < K && col < N) { + val = B[OFFSET_COL(row, col, K)]; + } + Bs[OFFSET_COL(B_TILE_ROW, i + B_TILE_COL, BLOCK_SIZE_K)] = val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint k = 0; k < BLOCK_SIZE_K; ++k) { + for (uint tm = 0; tm < THREAD_SIZE_M; ++tm) { + regs_a[tm] = As[OFFSET_COL(tid.x * THREAD_SIZE_M + tm, k, BLOCK_SIZE_M)]; + } + for (uint tn = 0; tn < THREAD_SIZE_N; ++tn) { + regs_b[tn] = Bs[OFFSET_COL(k, tid.y * THREAD_SIZE_N + tn, BLOCK_SIZE_K)]; + } + for (uint tm = 0; tm < THREAD_SIZE_M; ++tm) { + for (uint tn = 0; tn < THREAD_SIZE_N; ++tn) { + float prod = regs_a[tm] + regs_b[tn]; + uint idx = OFFSET_COL(tm, tn, THREAD_SIZE_M); + accum[idx] = max(accum[idx], prod); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + for (uint tm = 0; tm < THREAD_SIZE_M; ++tm) { + for (uint tn = 0; tn < THREAD_SIZE_N; ++tn) { + uint row = BLOCK_SIZE_M * BLOCK_IDX + THREAD_SIZE_M * tid.x + tm; + uint col = BLOCK_SIZE_N * BLOCK_IDY + THREAD_SIZE_N * tid.y + tn; + if (row < M && col < N) { + C[OFFSET_COL(row, col, M)] = accum[OFFSET_COL(tm, tn, THREAD_SIZE_M)]; + } + } + } +} + +// ============================================================================ +// F32 MINPLUS GEMM KERNEL +// ============================================================================ + +kernel void tropical_minplus_f32_nn( + device const float* A [[buffer(0)]], + device const float* B [[buffer(1)]], + device float* C [[buffer(2)]], + constant uint& M [[buffer(3)]], + constant uint& N [[buffer(4)]], + constant uint& K [[buffer(5)]], + uint2 tid [[thread_position_in_threadgroup]], + uint2 gid [[threadgroup_position_in_grid]] +) { + const uint BLOCK_SIZE_M = 64; + const uint BLOCK_SIZE_K = 32; + const uint BLOCK_SIZE_N = 64; + const uint THREAD_SIZE_M = 4; + const uint 
THREAD_SIZE_N = 4; + + const uint bszm = BLOCK_SIZE_M / THREAD_SIZE_M; + const uint bszn = BLOCK_SIZE_N / THREAD_SIZE_N; + const uint THREAD_NUM_PER_BLOCK = bszm * bszn; + + uint BLOCK_IDX = gid.x; + uint BLOCK_IDY = gid.y; + + const uint thread_id = tid.y * bszm + tid.x; + + threadgroup float As[BLOCK_SIZE_M * BLOCK_SIZE_K]; + threadgroup float Bs[BLOCK_SIZE_K * BLOCK_SIZE_N]; + + float accum[THREAD_SIZE_M * THREAD_SIZE_N]; + float regs_a[THREAD_SIZE_M]; + float regs_b[THREAD_SIZE_N]; + + for (uint i = 0; i < THREAD_SIZE_M * THREAD_SIZE_N; ++i) { + accum[i] = INFINITY; + } + + const uint A_TILE_COL = thread_id / BLOCK_SIZE_M; + const uint A_TILE_ROW = thread_id % BLOCK_SIZE_M; + const uint B_TILE_COL = thread_id / BLOCK_SIZE_K; + const uint B_TILE_ROW = thread_id % BLOCK_SIZE_K; + const uint A_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_M; + const uint B_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_K; + + for (uint tile_idx = 0; tile_idx < K; tile_idx += BLOCK_SIZE_K) { + for (uint i = 0; i < BLOCK_SIZE_K; i += A_TILE_COL_STRIDE) { + uint row = BLOCK_SIZE_M * BLOCK_IDX + A_TILE_ROW; + uint col = A_TILE_COL + i + tile_idx; + float val = INFINITY; + if (row < M && col < K) { + val = A[OFFSET_COL(row, col, M)]; + } + As[OFFSET_COL(A_TILE_ROW, i + A_TILE_COL, BLOCK_SIZE_M)] = val; + } + + for (uint i = 0; i < BLOCK_SIZE_N; i += B_TILE_COL_STRIDE) { + uint row = tile_idx + B_TILE_ROW; + uint col = BLOCK_SIZE_N * BLOCK_IDY + i + B_TILE_COL; + float val = INFINITY; + if (row < K && col < N) { + val = B[OFFSET_COL(row, col, K)]; + } + Bs[OFFSET_COL(B_TILE_ROW, i + B_TILE_COL, BLOCK_SIZE_K)] = val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint k = 0; k < BLOCK_SIZE_K; ++k) { + for (uint tm = 0; tm < THREAD_SIZE_M; ++tm) { + regs_a[tm] = As[OFFSET_COL(tid.x * THREAD_SIZE_M + tm, k, BLOCK_SIZE_M)]; + } + for (uint tn = 0; tn < THREAD_SIZE_N; ++tn) { + regs_b[tn] = Bs[OFFSET_COL(k, tid.y * THREAD_SIZE_N + tn, BLOCK_SIZE_K)]; + } + for (uint tm = 0; tm < THREAD_SIZE_M; ++tm) { + for (uint tn = 0; tn < THREAD_SIZE_N; ++tn) { + float prod = regs_a[tm] + regs_b[tn]; + uint idx = OFFSET_COL(tm, tn, THREAD_SIZE_M); + accum[idx] = min(accum[idx], prod); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + for (uint tm = 0; tm < THREAD_SIZE_M; ++tm) { + for (uint tn = 0; tn < THREAD_SIZE_N; ++tn) { + uint row = BLOCK_SIZE_M * BLOCK_IDX + THREAD_SIZE_M * tid.x + tm; + uint col = BLOCK_SIZE_N * BLOCK_IDY + THREAD_SIZE_N * tid.y + tn; + if (row < M && col < N) { + C[OFFSET_COL(row, col, M)] = accum[OFFSET_COL(tm, tn, THREAD_SIZE_M)]; + } + } + } +} + +// ============================================================================ +// F32 MAXMUL GEMM KERNEL +// ============================================================================ + +kernel void tropical_maxmul_f32_nn( + device const float* A [[buffer(0)]], + device const float* B [[buffer(1)]], + device float* C [[buffer(2)]], + constant uint& M [[buffer(3)]], + constant uint& N [[buffer(4)]], + constant uint& K [[buffer(5)]], + uint2 tid [[thread_position_in_threadgroup]], + uint2 gid [[threadgroup_position_in_grid]] +) { + const uint BLOCK_SIZE_M = 64; + const uint BLOCK_SIZE_K = 32; + const uint BLOCK_SIZE_N = 64; + const uint THREAD_SIZE_M = 4; + const uint THREAD_SIZE_N = 4; + + const uint bszm = BLOCK_SIZE_M / THREAD_SIZE_M; + const uint bszn = BLOCK_SIZE_N / THREAD_SIZE_N; + const uint THREAD_NUM_PER_BLOCK = bszm * bszn; + + uint BLOCK_IDX = gid.x; + uint BLOCK_IDY = gid.y; + + const uint 
thread_id = tid.y * bszm + tid.x; + + threadgroup float As[BLOCK_SIZE_M * BLOCK_SIZE_K]; + threadgroup float Bs[BLOCK_SIZE_K * BLOCK_SIZE_N]; + + float accum[THREAD_SIZE_M * THREAD_SIZE_N]; + float regs_a[THREAD_SIZE_M]; + float regs_b[THREAD_SIZE_N]; + + for (uint i = 0; i < THREAD_SIZE_M * THREAD_SIZE_N; ++i) { + accum[i] = 0.0f; + } + + const uint A_TILE_COL = thread_id / BLOCK_SIZE_M; + const uint A_TILE_ROW = thread_id % BLOCK_SIZE_M; + const uint B_TILE_COL = thread_id / BLOCK_SIZE_K; + const uint B_TILE_ROW = thread_id % BLOCK_SIZE_K; + const uint A_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_M; + const uint B_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_K; + + for (uint tile_idx = 0; tile_idx < K; tile_idx += BLOCK_SIZE_K) { + for (uint i = 0; i < BLOCK_SIZE_K; i += A_TILE_COL_STRIDE) { + uint row = BLOCK_SIZE_M * BLOCK_IDX + A_TILE_ROW; + uint col = A_TILE_COL + i + tile_idx; + float val = 0.0f; + if (row < M && col < K) { + val = A[OFFSET_COL(row, col, M)]; + } + As[OFFSET_COL(A_TILE_ROW, i + A_TILE_COL, BLOCK_SIZE_M)] = val; + } + + for (uint i = 0; i < BLOCK_SIZE_N; i += B_TILE_COL_STRIDE) { + uint row = tile_idx + B_TILE_ROW; + uint col = BLOCK_SIZE_N * BLOCK_IDY + i + B_TILE_COL; + float val = 0.0f; + if (row < K && col < N) { + val = B[OFFSET_COL(row, col, K)]; + } + Bs[OFFSET_COL(B_TILE_ROW, i + B_TILE_COL, BLOCK_SIZE_K)] = val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint k = 0; k < BLOCK_SIZE_K; ++k) { + for (uint tm = 0; tm < THREAD_SIZE_M; ++tm) { + regs_a[tm] = As[OFFSET_COL(tid.x * THREAD_SIZE_M + tm, k, BLOCK_SIZE_M)]; + } + for (uint tn = 0; tn < THREAD_SIZE_N; ++tn) { + regs_b[tn] = Bs[OFFSET_COL(k, tid.y * THREAD_SIZE_N + tn, BLOCK_SIZE_K)]; + } + for (uint tm = 0; tm < THREAD_SIZE_M; ++tm) { + for (uint tn = 0; tn < THREAD_SIZE_N; ++tn) { + float prod = regs_a[tm] * regs_b[tn]; + uint idx = OFFSET_COL(tm, tn, THREAD_SIZE_M); + accum[idx] = max(accum[idx], prod); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + for (uint tm = 0; tm < THREAD_SIZE_M; ++tm) { + for (uint tn = 0; tn < THREAD_SIZE_N; ++tn) { + uint row = BLOCK_SIZE_M * BLOCK_IDX + THREAD_SIZE_M * tid.x + tm; + uint col = BLOCK_SIZE_N * BLOCK_IDY + THREAD_SIZE_N * tid.y + tn; + if (row < M && col < N) { + C[OFFSET_COL(row, col, M)] = accum[OFFSET_COL(tm, tn, THREAD_SIZE_M)]; + } + } + } +} + +// ============================================================================ +// F32 MAXPLUS GEMM WITH ARGMAX KERNEL +// ============================================================================ + +kernel void tropical_maxplus_f32_nn_with_argmax( + device const float* A [[buffer(0)]], + device const float* B [[buffer(1)]], + device float* C [[buffer(2)]], + device int* argmax_out [[buffer(3)]], + constant uint& M [[buffer(4)]], + constant uint& N [[buffer(5)]], + constant uint& K [[buffer(6)]], + uint2 tid [[thread_position_in_threadgroup]], + uint2 gid [[threadgroup_position_in_grid]] +) { + const uint BLOCK_SIZE_M = 64; + const uint BLOCK_SIZE_K = 32; + const uint BLOCK_SIZE_N = 64; + const uint THREAD_SIZE_M = 4; + const uint THREAD_SIZE_N = 4; + + const uint bszm = BLOCK_SIZE_M / THREAD_SIZE_M; + const uint bszn = BLOCK_SIZE_N / THREAD_SIZE_N; + const uint THREAD_NUM_PER_BLOCK = bszm * bszn; + + uint BLOCK_IDX = gid.x; + uint BLOCK_IDY = gid.y; + + const uint thread_id = tid.y * bszm + tid.x; + + threadgroup float As[BLOCK_SIZE_M * BLOCK_SIZE_K]; + threadgroup float Bs[BLOCK_SIZE_K * BLOCK_SIZE_N]; + + float accum[THREAD_SIZE_M * 
THREAD_SIZE_N]; + int accum_idx[THREAD_SIZE_M * THREAD_SIZE_N]; + float regs_a[THREAD_SIZE_M]; + float regs_b[THREAD_SIZE_N]; + + for (uint i = 0; i < THREAD_SIZE_M * THREAD_SIZE_N; ++i) { + accum[i] = -INFINITY; + accum_idx[i] = 0; + } + + const uint A_TILE_COL = thread_id / BLOCK_SIZE_M; + const uint A_TILE_ROW = thread_id % BLOCK_SIZE_M; + const uint B_TILE_COL = thread_id / BLOCK_SIZE_K; + const uint B_TILE_ROW = thread_id % BLOCK_SIZE_K; + const uint A_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_M; + const uint B_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_K; + + for (uint tile_idx = 0; tile_idx < K; tile_idx += BLOCK_SIZE_K) { + for (uint i = 0; i < BLOCK_SIZE_K; i += A_TILE_COL_STRIDE) { + uint row = BLOCK_SIZE_M * BLOCK_IDX + A_TILE_ROW; + uint col = A_TILE_COL + i + tile_idx; + float val = -INFINITY; + if (row < M && col < K) { + val = A[OFFSET_COL(row, col, M)]; + } + As[OFFSET_COL(A_TILE_ROW, i + A_TILE_COL, BLOCK_SIZE_M)] = val; + } + + for (uint i = 0; i < BLOCK_SIZE_N; i += B_TILE_COL_STRIDE) { + uint row = tile_idx + B_TILE_ROW; + uint col = BLOCK_SIZE_N * BLOCK_IDY + i + B_TILE_COL; + float val = -INFINITY; + if (row < K && col < N) { + val = B[OFFSET_COL(row, col, K)]; + } + Bs[OFFSET_COL(B_TILE_ROW, i + B_TILE_COL, BLOCK_SIZE_K)] = val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint k = 0; k < BLOCK_SIZE_K; ++k) { + int global_k = tile_idx + k; + for (uint tm = 0; tm < THREAD_SIZE_M; ++tm) { + regs_a[tm] = As[OFFSET_COL(tid.x * THREAD_SIZE_M + tm, k, BLOCK_SIZE_M)]; + } + for (uint tn = 0; tn < THREAD_SIZE_N; ++tn) { + regs_b[tn] = Bs[OFFSET_COL(k, tid.y * THREAD_SIZE_N + tn, BLOCK_SIZE_K)]; + } + for (uint tm = 0; tm < THREAD_SIZE_M; ++tm) { + for (uint tn = 0; tn < THREAD_SIZE_N; ++tn) { + float prod = regs_a[tm] + regs_b[tn]; + uint idx = OFFSET_COL(tm, tn, THREAD_SIZE_M); + if (prod > accum[idx]) { + accum[idx] = prod; + accum_idx[idx] = global_k; + } + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + for (uint tm = 0; tm < THREAD_SIZE_M; ++tm) { + for (uint tn = 0; tn < THREAD_SIZE_N; ++tn) { + uint row = BLOCK_SIZE_M * BLOCK_IDX + THREAD_SIZE_M * tid.x + tm; + uint col = BLOCK_SIZE_N * BLOCK_IDY + THREAD_SIZE_N * tid.y + tn; + if (row < M && col < N) { + uint out_idx = OFFSET_COL(row, col, M); + uint local_idx = OFFSET_COL(tm, tn, THREAD_SIZE_M); + C[out_idx] = accum[local_idx]; + argmax_out[out_idx] = accum_idx[local_idx]; + } + } + } +} + +// ============================================================================ +// F32 MINPLUS GEMM WITH ARGMAX KERNEL +// ============================================================================ + +kernel void tropical_minplus_f32_nn_with_argmax( + device const float* A [[buffer(0)]], + device const float* B [[buffer(1)]], + device float* C [[buffer(2)]], + device int* argmax_out [[buffer(3)]], + constant uint& M [[buffer(4)]], + constant uint& N [[buffer(5)]], + constant uint& K [[buffer(6)]], + uint2 tid [[thread_position_in_threadgroup]], + uint2 gid [[threadgroup_position_in_grid]] +) { + const uint BLOCK_SIZE_M = 64; + const uint BLOCK_SIZE_K = 32; + const uint BLOCK_SIZE_N = 64; + const uint THREAD_SIZE_M = 4; + const uint THREAD_SIZE_N = 4; + + const uint bszm = BLOCK_SIZE_M / THREAD_SIZE_M; + const uint bszn = BLOCK_SIZE_N / THREAD_SIZE_N; + const uint THREAD_NUM_PER_BLOCK = bszm * bszn; + + uint BLOCK_IDX = gid.x; + uint BLOCK_IDY = gid.y; + + const uint thread_id = tid.y * bszm + tid.x; + + threadgroup float As[BLOCK_SIZE_M * BLOCK_SIZE_K]; + threadgroup 
float Bs[BLOCK_SIZE_K * BLOCK_SIZE_N]; + + float accum[THREAD_SIZE_M * THREAD_SIZE_N]; + int accum_idx[THREAD_SIZE_M * THREAD_SIZE_N]; + float regs_a[THREAD_SIZE_M]; + float regs_b[THREAD_SIZE_N]; + + for (uint i = 0; i < THREAD_SIZE_M * THREAD_SIZE_N; ++i) { + accum[i] = INFINITY; + accum_idx[i] = 0; + } + + const uint A_TILE_COL = thread_id / BLOCK_SIZE_M; + const uint A_TILE_ROW = thread_id % BLOCK_SIZE_M; + const uint B_TILE_COL = thread_id / BLOCK_SIZE_K; + const uint B_TILE_ROW = thread_id % BLOCK_SIZE_K; + const uint A_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_M; + const uint B_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_K; + + for (uint tile_idx = 0; tile_idx < K; tile_idx += BLOCK_SIZE_K) { + for (uint i = 0; i < BLOCK_SIZE_K; i += A_TILE_COL_STRIDE) { + uint row = BLOCK_SIZE_M * BLOCK_IDX + A_TILE_ROW; + uint col = A_TILE_COL + i + tile_idx; + float val = INFINITY; + if (row < M && col < K) { + val = A[OFFSET_COL(row, col, M)]; + } + As[OFFSET_COL(A_TILE_ROW, i + A_TILE_COL, BLOCK_SIZE_M)] = val; + } + + for (uint i = 0; i < BLOCK_SIZE_N; i += B_TILE_COL_STRIDE) { + uint row = tile_idx + B_TILE_ROW; + uint col = BLOCK_SIZE_N * BLOCK_IDY + i + B_TILE_COL; + float val = INFINITY; + if (row < K && col < N) { + val = B[OFFSET_COL(row, col, K)]; + } + Bs[OFFSET_COL(B_TILE_ROW, i + B_TILE_COL, BLOCK_SIZE_K)] = val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint k = 0; k < BLOCK_SIZE_K; ++k) { + int global_k = tile_idx + k; + for (uint tm = 0; tm < THREAD_SIZE_M; ++tm) { + regs_a[tm] = As[OFFSET_COL(tid.x * THREAD_SIZE_M + tm, k, BLOCK_SIZE_M)]; + } + for (uint tn = 0; tn < THREAD_SIZE_N; ++tn) { + regs_b[tn] = Bs[OFFSET_COL(k, tid.y * THREAD_SIZE_N + tn, BLOCK_SIZE_K)]; + } + for (uint tm = 0; tm < THREAD_SIZE_M; ++tm) { + for (uint tn = 0; tn < THREAD_SIZE_N; ++tn) { + float prod = regs_a[tm] + regs_b[tn]; + uint idx = OFFSET_COL(tm, tn, THREAD_SIZE_M); + if (prod < accum[idx]) { + accum[idx] = prod; + accum_idx[idx] = global_k; + } + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + for (uint tm = 0; tm < THREAD_SIZE_M; ++tm) { + for (uint tn = 0; tn < THREAD_SIZE_N; ++tn) { + uint row = BLOCK_SIZE_M * BLOCK_IDX + THREAD_SIZE_M * tid.x + tm; + uint col = BLOCK_SIZE_N * BLOCK_IDY + THREAD_SIZE_N * tid.y + tn; + if (row < M && col < N) { + uint out_idx = OFFSET_COL(row, col, M); + uint local_idx = OFFSET_COL(tm, tn, THREAD_SIZE_M); + C[out_idx] = accum[local_idx]; + argmax_out[out_idx] = accum_idx[local_idx]; + } + } + } +} + +// ============================================================================ +// F32 MAXMUL GEMM WITH ARGMAX KERNEL +// ============================================================================ + +kernel void tropical_maxmul_f32_nn_with_argmax( + device const float* A [[buffer(0)]], + device const float* B [[buffer(1)]], + device float* C [[buffer(2)]], + device int* argmax_out [[buffer(3)]], + constant uint& M [[buffer(4)]], + constant uint& N [[buffer(5)]], + constant uint& K [[buffer(6)]], + uint2 tid [[thread_position_in_threadgroup]], + uint2 gid [[threadgroup_position_in_grid]] +) { + const uint BLOCK_SIZE_M = 64; + const uint BLOCK_SIZE_K = 32; + const uint BLOCK_SIZE_N = 64; + const uint THREAD_SIZE_M = 4; + const uint THREAD_SIZE_N = 4; + + const uint bszm = BLOCK_SIZE_M / THREAD_SIZE_M; + const uint bszn = BLOCK_SIZE_N / THREAD_SIZE_N; + const uint THREAD_NUM_PER_BLOCK = bszm * bszn; + + uint BLOCK_IDX = gid.x; + uint BLOCK_IDY = gid.y; + + const uint thread_id = tid.y * bszm + tid.x; + 
+ threadgroup float As[BLOCK_SIZE_M * BLOCK_SIZE_K]; + threadgroup float Bs[BLOCK_SIZE_K * BLOCK_SIZE_N]; + + float accum[THREAD_SIZE_M * THREAD_SIZE_N]; + int accum_idx[THREAD_SIZE_M * THREAD_SIZE_N]; + float regs_a[THREAD_SIZE_M]; + float regs_b[THREAD_SIZE_N]; + + for (uint i = 0; i < THREAD_SIZE_M * THREAD_SIZE_N; ++i) { + accum[i] = 0.0f; + accum_idx[i] = 0; + } + + const uint A_TILE_COL = thread_id / BLOCK_SIZE_M; + const uint A_TILE_ROW = thread_id % BLOCK_SIZE_M; + const uint B_TILE_COL = thread_id / BLOCK_SIZE_K; + const uint B_TILE_ROW = thread_id % BLOCK_SIZE_K; + const uint A_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_M; + const uint B_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_K; + + for (uint tile_idx = 0; tile_idx < K; tile_idx += BLOCK_SIZE_K) { + for (uint i = 0; i < BLOCK_SIZE_K; i += A_TILE_COL_STRIDE) { + uint row = BLOCK_SIZE_M * BLOCK_IDX + A_TILE_ROW; + uint col = A_TILE_COL + i + tile_idx; + float val = 0.0f; + if (row < M && col < K) { + val = A[OFFSET_COL(row, col, M)]; + } + As[OFFSET_COL(A_TILE_ROW, i + A_TILE_COL, BLOCK_SIZE_M)] = val; + } + + for (uint i = 0; i < BLOCK_SIZE_N; i += B_TILE_COL_STRIDE) { + uint row = tile_idx + B_TILE_ROW; + uint col = BLOCK_SIZE_N * BLOCK_IDY + i + B_TILE_COL; + float val = 0.0f; + if (row < K && col < N) { + val = B[OFFSET_COL(row, col, K)]; + } + Bs[OFFSET_COL(B_TILE_ROW, i + B_TILE_COL, BLOCK_SIZE_K)] = val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint k = 0; k < BLOCK_SIZE_K; ++k) { + int global_k = tile_idx + k; + for (uint tm = 0; tm < THREAD_SIZE_M; ++tm) { + regs_a[tm] = As[OFFSET_COL(tid.x * THREAD_SIZE_M + tm, k, BLOCK_SIZE_M)]; + } + for (uint tn = 0; tn < THREAD_SIZE_N; ++tn) { + regs_b[tn] = Bs[OFFSET_COL(k, tid.y * THREAD_SIZE_N + tn, BLOCK_SIZE_K)]; + } + for (uint tm = 0; tm < THREAD_SIZE_M; ++tm) { + for (uint tn = 0; tn < THREAD_SIZE_N; ++tn) { + float prod = regs_a[tm] * regs_b[tn]; + uint idx = OFFSET_COL(tm, tn, THREAD_SIZE_M); + if (prod > accum[idx]) { + accum[idx] = prod; + accum_idx[idx] = global_k; + } + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + for (uint tm = 0; tm < THREAD_SIZE_M; ++tm) { + for (uint tn = 0; tn < THREAD_SIZE_N; ++tn) { + uint row = BLOCK_SIZE_M * BLOCK_IDX + THREAD_SIZE_M * tid.x + tm; + uint col = BLOCK_SIZE_N * BLOCK_IDY + THREAD_SIZE_N * tid.y + tn; + if (row < M && col < N) { + uint out_idx = OFFSET_COL(row, col, M); + uint local_idx = OFFSET_COL(tm, tn, THREAD_SIZE_M); + C[out_idx] = accum[local_idx]; + argmax_out[out_idx] = accum_idx[local_idx]; + } + } + } +} + +// ============================================================================ +// BACKWARD PASS KERNELS +// ============================================================================ + +kernel void tropical_backward_a_f32( + device const float* grad_c [[buffer(0)]], + device const int* argmax_in [[buffer(1)]], + device atomic_float* grad_a [[buffer(2)]], + constant uint& M [[buffer(3)]], + constant uint& N [[buffer(4)]], + constant uint& K [[buffer(5)]], + uint idx [[thread_position_in_grid]] +) { + uint total = M * N; + if (idx < total) { + uint i = idx % M; + int k = argmax_in[idx]; + if (k >= 0 && (uint)k < K) { + atomic_fetch_add_explicit(&grad_a[i + k * M], grad_c[idx], memory_order_relaxed); + } + } +} + +kernel void tropical_backward_b_f32( + device const float* grad_c [[buffer(0)]], + device const int* argmax_in [[buffer(1)]], + device atomic_float* grad_b [[buffer(2)]], + constant uint& M [[buffer(3)]], + constant uint& N [[buffer(4)]], 
+ constant uint& K [[buffer(5)]], + uint idx [[thread_position_in_grid]] +) { + uint total = M * N; + if (idx < total) { + uint j = idx / M; + int k = argmax_in[idx]; + if (k >= 0 && (uint)k < K) { + atomic_fetch_add_explicit(&grad_b[k + j * K], grad_c[idx], memory_order_relaxed); + } + } +} diff --git a/crates/tropical-gemm-metal/src/context.rs b/crates/tropical-gemm-metal/src/context.rs new file mode 100644 index 0000000..7f3e0ae --- /dev/null +++ b/crates/tropical-gemm-metal/src/context.rs @@ -0,0 +1,124 @@ +//! Metal context and kernel management. + +use crate::error::{MetalError, Result}; +use metal::{CommandQueue, ComputePipelineState, Device, MTLSize}; +use std::collections::HashMap; + +/// Metal kernel source code. +const KERNEL_SOURCE: &str = include_str!("../shaders/tropical_gemm.metal"); + +/// Blocking parameters for f32 kernels. +pub const BLOCK_SIZE_M_F32: u32 = 64; +pub const BLOCK_SIZE_N_F32: u32 = 64; +pub const THREAD_SIZE_M: u32 = 4; +pub const THREAD_SIZE_N: u32 = 4; + +/// Kernel function names. +const KERNEL_NAMES: &[&str] = &[ + // Standard GEMM kernels (f32) + "tropical_maxplus_f32_nn", + "tropical_minplus_f32_nn", + "tropical_maxmul_f32_nn", + // GEMM with argmax kernels (f32) + "tropical_maxplus_f32_nn_with_argmax", + "tropical_minplus_f32_nn_with_argmax", + "tropical_maxmul_f32_nn_with_argmax", + // Backward pass kernels (f32) + "tropical_backward_a_f32", + "tropical_backward_b_f32", +]; + +/// Metal context for tropical GEMM operations. +/// +/// Manages device selection, shader compilation, and caching. +pub struct MetalContext { + device: Device, + command_queue: CommandQueue, + pipelines: HashMap<&'static str, ComputePipelineState>, +} + +impl MetalContext { + /// Create a new Metal context on the default system device. + pub fn new() -> Result { + let device = Device::system_default().ok_or(MetalError::NoDevice)?; + Self::from_device(device) + } + + /// Create a context from an existing device. + pub fn from_device(device: Device) -> Result { + // Create command queue + let command_queue = device.new_command_queue(); + + // Compile shaders from source + let options = metal::CompileOptions::new(); + let library = device + .new_library_with_source(KERNEL_SOURCE, &options) + .map_err(|e| MetalError::ShaderCompile(e.to_string()))?; + + // Create compute pipelines for each kernel + let mut pipelines = HashMap::new(); + for name in KERNEL_NAMES { + let function = library + .get_function(name, None) + .map_err(|e| MetalError::KernelNotFound(format!("{}: {}", name, e)))?; + let pipeline = device + .new_compute_pipeline_state_with_function(&function) + .map_err(|e| MetalError::Library(format!("Pipeline for {}: {}", name, e)))?; + pipelines.insert(*name, pipeline); + } + + Ok(Self { + device, + command_queue, + pipelines, + }) + } + + /// Get the underlying Metal device. + pub fn device(&self) -> &Device { + &self.device + } + + /// Get the command queue. + pub fn command_queue(&self) -> &CommandQueue { + &self.command_queue + } + + /// Get a compute pipeline by kernel name. + pub fn get_pipeline(&self, name: &'static str) -> Result<&ComputePipelineState> { + self.pipelines + .get(name) + .ok_or_else(|| MetalError::KernelNotFound(name.to_string())) + } + + /// Get GPU device name. + pub fn device_name(&self) -> String { + self.device.name().to_string() + } + + /// Calculate grid dimensions for a given matrix size. 
+    pub fn grid_size_f32(m: usize, n: usize) -> MTLSize {
+        let grid_x = ((m as u64) + BLOCK_SIZE_M_F32 as u64 - 1) / BLOCK_SIZE_M_F32 as u64;
+        let grid_y = ((n as u64) + BLOCK_SIZE_N_F32 as u64 - 1) / BLOCK_SIZE_N_F32 as u64;
+        MTLSize::new(grid_x, grid_y, 1)
+    }
+
+    /// Threadgroup size for f32 kernels.
+    pub fn threadgroup_size_f32() -> MTLSize {
+        let bszm = BLOCK_SIZE_M_F32 as u64 / THREAD_SIZE_M as u64;
+        let bszn = BLOCK_SIZE_N_F32 as u64 / THREAD_SIZE_N as u64;
+        MTLSize::new(bszm, bszn, 1)
+    }
+
+    /// Calculate grid dimensions (in threadgroups) for backward pass kernels.
+    pub fn grid_size_backward(total_elements: usize) -> MTLSize {
+        let threads_per_group = 256u64;
+        let num_groups = ((total_elements as u64) + threads_per_group - 1) / threads_per_group;
+        MTLSize::new(num_groups, 1, 1)
+    }
+
+    /// Threadgroup size for backward kernels.
+    pub fn threadgroup_size_backward() -> MTLSize {
+        MTLSize::new(256, 1, 1)
+    }
+}
diff --git a/crates/tropical-gemm-metal/src/error.rs b/crates/tropical-gemm-metal/src/error.rs
new file mode 100644
index 0000000..4589b3c
--- /dev/null
+++ b/crates/tropical-gemm-metal/src/error.rs
@@ -0,0 +1,42 @@
+//! Error types for Metal operations.
+
+use thiserror::Error;
+
+/// Errors that can occur during Metal operations.
+#[derive(Debug, Error)]
+pub enum MetalError {
+    /// No Metal device available.
+    #[error("No Metal device available")]
+    NoDevice,
+
+    /// Metal shader compilation error.
+    #[error("Metal shader compilation error: {0}")]
+    ShaderCompile(String),
+
+    /// Metal library creation error.
+    #[error("Metal library error: {0}")]
+    Library(String),
+
+    /// Kernel function not found.
+    #[error("Kernel not found: {0}")]
+    KernelNotFound(String),
+
+    /// Buffer creation error.
+    #[error("Buffer creation error: {0}")]
+    BufferCreation(String),
+
+    /// Dimension mismatch.
+    #[error("Dimension mismatch: {0}")]
+    DimensionMismatch(String),
+
+    /// Command buffer error.
+    #[error("Command buffer error: {0}")]
+    CommandBuffer(String),
+
+    /// Execution error.
+    #[error("Execution error: {0}")]
+    Execution(String),
+}
+
+/// Result type for Metal operations.
+pub type Result<T> = std::result::Result<T, MetalError>;
diff --git a/crates/tropical-gemm-metal/src/gpu_mat.rs b/crates/tropical-gemm-metal/src/gpu_mat.rs
new file mode 100644
index 0000000..2fa2158
--- /dev/null
+++ b/crates/tropical-gemm-metal/src/gpu_mat.rs
@@ -0,0 +1,225 @@
+//! GPU memory management for matrices on Metal.
+
+use crate::context::MetalContext;
+use crate::error::{MetalError, Result};
+use metal::{Buffer, MTLResourceOptions};
+use std::mem;
+
+/// Type alias for argmax indices (k-index that produced each C[i,j]).
+pub type ArgmaxIndex = i32;
+
+/// A matrix stored in GPU memory.
+///
+/// Data is stored in column-major order (Fortran order) for compatibility
+/// with BLAS conventions.
+pub struct GpuMatrix<T> {
+    buffer: Buffer,
+    rows: usize,
+    cols: usize,
+    _marker: std::marker::PhantomData<T>,
+}
+
+impl<T: Copy + Default> GpuMatrix<T> {
+    /// Create a GPU matrix from host data.
+    ///
+    /// The input data should be in row-major order. It will be transposed
+    /// to column-major for GPU storage.
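+    ///
+    /// For example, the 2×3 matrix [[1, 2, 3], [4, 5, 6]] arrives as the
+    /// row-major slice [1, 2, 3, 4, 5, 6] and is stored on the GPU as the
+    /// column-major slice [1, 4, 2, 5, 3, 6].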
+    pub fn from_host_row_major(
+        ctx: &MetalContext,
+        data: &[T],
+        rows: usize,
+        cols: usize,
+    ) -> Result<Self> {
+        if data.len() != rows * cols {
+            return Err(MetalError::DimensionMismatch(format!(
+                "Expected {} elements, got {}",
+                rows * cols,
+                data.len()
+            )));
+        }
+
+        // Transpose to column-major
+        let mut col_major = vec![T::default(); rows * cols];
+        for i in 0..rows {
+            for j in 0..cols {
+                col_major[j * rows + i] = data[i * cols + j];
+            }
+        }
+
+        let byte_len = col_major.len() * mem::size_of::<T>();
+        let buffer = ctx.device().new_buffer_with_data(
+            col_major.as_ptr() as *const _,
+            byte_len as u64,
+            MTLResourceOptions::StorageModeShared,
+        );
+
+        Ok(Self {
+            buffer,
+            rows,
+            cols,
+            _marker: std::marker::PhantomData,
+        })
+    }
+
+    /// Create a GPU matrix from column-major host data (no transpose).
+    pub fn from_host_col_major(
+        ctx: &MetalContext,
+        data: &[T],
+        rows: usize,
+        cols: usize,
+    ) -> Result<Self> {
+        if data.len() != rows * cols {
+            return Err(MetalError::DimensionMismatch(format!(
+                "Expected {} elements, got {}",
+                rows * cols,
+                data.len()
+            )));
+        }
+
+        let byte_len = data.len() * mem::size_of::<T>();
+        let buffer = ctx.device().new_buffer_with_data(
+            data.as_ptr() as *const _,
+            byte_len as u64,
+            MTLResourceOptions::StorageModeShared,
+        );
+
+        Ok(Self {
+            buffer,
+            rows,
+            cols,
+            _marker: std::marker::PhantomData,
+        })
+    }
+
+    /// Allocate a zeroed GPU matrix.
+    pub fn alloc(ctx: &MetalContext, rows: usize, cols: usize) -> Result<Self> {
+        let byte_len = (rows * cols * mem::size_of::<T>()) as u64;
+        let buffer = ctx.device().new_buffer(
+            byte_len,
+            MTLResourceOptions::StorageModeShared,
+        );
+
+        // Zero the buffer
+        unsafe {
+            std::ptr::write_bytes(buffer.contents() as *mut u8, 0, byte_len as usize);
+        }
+
+        Ok(Self {
+            buffer,
+            rows,
+            cols,
+            _marker: std::marker::PhantomData,
+        })
+    }
+
+    /// Copy GPU data back to host in row-major order.
+    pub fn to_host_row_major(&self, _ctx: &MetalContext) -> Result<Vec<T>> {
+        // Read column-major data from buffer
+        let col_major = self.read_buffer();
+
+        // Transpose from column-major to row-major
+        let mut row_major = vec![T::default(); self.rows * self.cols];
+        for i in 0..self.rows {
+            for j in 0..self.cols {
+                row_major[i * self.cols + j] = col_major[j * self.rows + i];
+            }
+        }
+
+        Ok(row_major)
+    }
+
+    /// Copy GPU data back to host in column-major order.
+    pub fn to_host_col_major(&self, _ctx: &MetalContext) -> Result<Vec<T>> {
+        Ok(self.read_buffer())
+    }
+
+    /// Read buffer contents into a Vec.
+    fn read_buffer(&self) -> Vec<T> {
+        let len = self.rows * self.cols;
+        let mut data = vec![T::default(); len];
+        unsafe {
+            std::ptr::copy_nonoverlapping(
+                self.buffer.contents() as *const T,
+                data.as_mut_ptr(),
+                len,
+            );
+        }
+        data
+    }
+
+    /// Get the number of rows.
+    pub fn rows(&self) -> usize {
+        self.rows
+    }
+
+    /// Get the number of columns.
+    pub fn cols(&self) -> usize {
+        self.cols
+    }
+
+    /// Get the leading dimension (number of rows for column-major).
+    pub fn ld(&self) -> usize {
+        self.rows
+    }
+
+    /// Get the underlying Metal buffer (for kernel launches).
+    pub fn as_buffer(&self) -> &Buffer {
+        &self.buffer
+    }
+
+    /// Get a mutable reference to the underlying Metal buffer.
+    pub fn as_buffer_mut(&mut self) -> &mut Buffer {
+        &mut self.buffer
+    }
+}
+
+/// A GPU matrix paired with argmax indices (for backward propagation).
+///
+/// This stores both the result of a tropical GEMM and the k-indices
+/// that produced each optimal value in C[i,j]. Used for gradient computation.
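+///
+/// For max-plus, C[i,j] = max_k (A[i,k] + B[k,j]) and argmax[i,j] records the
+/// winning k, so the backward pass can route grad_C[i,j] to A[i, argmax[i,j]]
+/// and B[argmax[i,j], j] without recomputing the forward pass.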
+pub struct GpuMatrixWithArgmax<T> {
+    /// The result matrix C.
+    pub matrix: GpuMatrix<T>,
+    /// The argmax indices: argmax[i,j] = k such that C[i,j] = A[i,k] ⊗ B[k,j].
+    pub argmax: GpuMatrix<ArgmaxIndex>,
+}
+
+impl<T: Copy + Default> GpuMatrixWithArgmax<T> {
+    /// Allocate a zeroed GPU matrix with argmax indices.
+    pub fn alloc(ctx: &MetalContext, rows: usize, cols: usize) -> Result<Self> {
+        let matrix = GpuMatrix::alloc(ctx, rows, cols)?;
+        let argmax = GpuMatrix::alloc(ctx, rows, cols)?;
+
+        Ok(Self { matrix, argmax })
+    }
+
+    /// Get the number of rows.
+    pub fn rows(&self) -> usize {
+        self.matrix.rows()
+    }
+
+    /// Get the number of columns.
+    pub fn cols(&self) -> usize {
+        self.matrix.cols()
+    }
+
+    /// Copy the result matrix back to host in row-major order.
+    pub fn matrix_to_host_row_major(&self, ctx: &MetalContext) -> Result<Vec<T>> {
+        self.matrix.to_host_row_major(ctx)
+    }
+
+    /// Copy the argmax indices back to host in row-major order.
+    pub fn argmax_to_host_row_major(&self, ctx: &MetalContext) -> Result<Vec<ArgmaxIndex>> {
+        self.argmax.to_host_row_major(ctx)
+    }
+
+    /// Copy the result matrix back to host in column-major order.
+    pub fn matrix_to_host_col_major(&self, ctx: &MetalContext) -> Result<Vec<T>> {
+        self.matrix.to_host_col_major(ctx)
+    }
+
+    /// Copy the argmax indices back to host in column-major order.
+    pub fn argmax_to_host_col_major(&self, ctx: &MetalContext) -> Result<Vec<ArgmaxIndex>> {
+        self.argmax.to_host_col_major(ctx)
+    }
+}
diff --git a/crates/tropical-gemm-metal/src/kernels.rs b/crates/tropical-gemm-metal/src/kernels.rs
new file mode 100644
index 0000000..fdcaef7
--- /dev/null
+++ b/crates/tropical-gemm-metal/src/kernels.rs
@@ -0,0 +1,268 @@
+//! Metal kernel trait and implementations.
+
+use crate::context::MetalContext;
+use crate::error::Result;
+use crate::gpu_mat::{ArgmaxIndex, GpuMatrix, GpuMatrixWithArgmax};
+use metal::MTLSize;
+use tropical_gemm::types::{TropicalMaxMul, TropicalMaxPlus, TropicalMinPlus, TropicalSemiring};
+
+/// Trait for types that can be computed on Metal GPU.
+pub trait MetalKernel: TropicalSemiring
+where
+    Self::Scalar: Copy + Default,
+{
+    /// Kernel function name.
+    const KERNEL_NAME: &'static str;
+
+    /// Execute the tropical GEMM kernel.
+    ///
+    /// Computes C = A ⊗ B where ⊗ is tropical matrix multiplication.
+    fn launch_gemm(
+        ctx: &MetalContext,
+        a: &GpuMatrix<Self::Scalar>,
+        b: &GpuMatrix<Self::Scalar>,
+        c: &mut GpuMatrix<Self::Scalar>,
+    ) -> Result<()>;
+}
+
+/// Helper function to launch a Metal kernel with given grid/threadgroup dimensions.
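+///
+/// Buffer bindings match the shader signatures: buffers 0–2 hold A, B, and C,
+/// and the M, N, K dimensions are passed inline via `set_bytes` at indices 3–5.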
+fn launch_kernel_impl<T: Copy + Default>(
+    ctx: &MetalContext,
+    kernel_name: &'static str,
+    a: &GpuMatrix<T>,
+    b: &GpuMatrix<T>,
+    c: &mut GpuMatrix<T>,
+    grid_size: MTLSize,
+    threadgroup_size: MTLSize,
+) -> Result<()> {
+    let m = a.rows() as u32;
+    let k = a.cols() as u32;
+    let n = b.cols() as u32;
+
+    let pipeline = ctx.get_pipeline(kernel_name)?;
+    let command_buffer = ctx.command_queue().new_command_buffer();
+    let encoder = command_buffer.new_compute_command_encoder();
+
+    encoder.set_compute_pipeline_state(pipeline);
+    encoder.set_buffer(0, Some(a.as_buffer()), 0);
+    encoder.set_buffer(1, Some(b.as_buffer()), 0);
+    encoder.set_buffer(2, Some(c.as_buffer()), 0);
+    encoder.set_bytes(3, std::mem::size_of::<u32>() as u64, &m as *const u32 as *const _);
+    encoder.set_bytes(4, std::mem::size_of::<u32>() as u64, &n as *const u32 as *const _);
+    encoder.set_bytes(5, std::mem::size_of::<u32>() as u64, &k as *const u32 as *const _);
+
+    encoder.dispatch_thread_groups(grid_size, threadgroup_size);
+    encoder.end_encoding();
+
+    command_buffer.commit();
+    command_buffer.wait_until_completed();
+
+    Ok(())
+}
+
+/// Macro to implement MetalKernel for f32 types.
+macro_rules! impl_metal_kernel_f32 {
+    ($($semiring:ty => $kernel_name:literal),* $(,)?) => {
+        $(
+            impl MetalKernel for $semiring {
+                const KERNEL_NAME: &'static str = $kernel_name;
+
+                fn launch_gemm(
+                    ctx: &MetalContext,
+                    a: &GpuMatrix<Self::Scalar>,
+                    b: &GpuMatrix<Self::Scalar>,
+                    c: &mut GpuMatrix<Self::Scalar>,
+                ) -> Result<()> {
+                    let grid_size = MetalContext::grid_size_f32(a.rows(), b.cols());
+                    let threadgroup_size = MetalContext::threadgroup_size_f32();
+                    launch_kernel_impl(ctx, Self::KERNEL_NAME, a, b, c, grid_size, threadgroup_size)
+                }
+            }
+        )*
+    };
+}
+
+impl_metal_kernel_f32! {
+    TropicalMaxPlus<f32> => "tropical_maxplus_f32_nn",
+    TropicalMinPlus<f32> => "tropical_minplus_f32_nn",
+    TropicalMaxMul<f32> => "tropical_maxmul_f32_nn",
+}
+
+// ============================================================================
+// MetalKernelWithArgmax - for path reconstruction
+// ============================================================================
+
+/// Trait for tropical GEMM with argmax tracking (for backward propagation).
+///
+/// This computes both C[i,j] and the k-index that produced each C[i,j],
+/// which is needed for gradient computation in tropical neural networks.
+pub trait MetalKernelWithArgmax: TropicalSemiring
+where
+    Self::Scalar: Copy + Default,
+{
+    /// Kernel function name for the argmax variant.
+    const ARGMAX_KERNEL_NAME: &'static str;
+
+    /// Execute the tropical GEMM kernel with argmax tracking.
+    ///
+    /// Computes C = A ⊗ B and also records argmax[i,j] = k such that
+    /// C[i,j] = A[i,k] ⊗ B[k,j] was the winning value.
+    fn launch_gemm_with_argmax(
+        ctx: &MetalContext,
+        a: &GpuMatrix<Self::Scalar>,
+        b: &GpuMatrix<Self::Scalar>,
+        c: &mut GpuMatrixWithArgmax<Self::Scalar>,
+    ) -> Result<()>;
+}
+
+/// Helper function to launch an argmax Metal kernel.
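+///
+/// Same binding layout as `launch_kernel_impl`, except buffer 3 holds the
+/// argmax output, shifting the M, N, K byte arguments to indices 4–6.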
+fn launch_kernel_with_argmax_impl<T: Copy + Default>(
+    ctx: &MetalContext,
+    kernel_name: &'static str,
+    a: &GpuMatrix<T>,
+    b: &GpuMatrix<T>,
+    c: &mut GpuMatrixWithArgmax<T>,
+    grid_size: MTLSize,
+    threadgroup_size: MTLSize,
+) -> Result<()> {
+    let m = a.rows() as u32;
+    let k = a.cols() as u32;
+    let n = b.cols() as u32;
+
+    let pipeline = ctx.get_pipeline(kernel_name)?;
+    let command_buffer = ctx.command_queue().new_command_buffer();
+    let encoder = command_buffer.new_compute_command_encoder();
+
+    encoder.set_compute_pipeline_state(pipeline);
+    encoder.set_buffer(0, Some(a.as_buffer()), 0);
+    encoder.set_buffer(1, Some(b.as_buffer()), 0);
+    encoder.set_buffer(2, Some(c.matrix.as_buffer()), 0);
+    encoder.set_buffer(3, Some(c.argmax.as_buffer()), 0);
+    encoder.set_bytes(4, std::mem::size_of::<u32>() as u64, &m as *const u32 as *const _);
+    encoder.set_bytes(5, std::mem::size_of::<u32>() as u64, &n as *const u32 as *const _);
+    encoder.set_bytes(6, std::mem::size_of::<u32>() as u64, &k as *const u32 as *const _);
+
+    encoder.dispatch_thread_groups(grid_size, threadgroup_size);
+    encoder.end_encoding();
+
+    command_buffer.commit();
+    command_buffer.wait_until_completed();
+
+    Ok(())
+}
+
+/// Macro to implement MetalKernelWithArgmax for f32 types.
+macro_rules! impl_metal_kernel_with_argmax_f32 {
+    ($($semiring:ty => $kernel_name:literal),* $(,)?) => {
+        $(
+            impl MetalKernelWithArgmax for $semiring {
+                const ARGMAX_KERNEL_NAME: &'static str = $kernel_name;
+
+                fn launch_gemm_with_argmax(
+                    ctx: &MetalContext,
+                    a: &GpuMatrix<Self::Scalar>,
+                    b: &GpuMatrix<Self::Scalar>,
+                    c: &mut GpuMatrixWithArgmax<Self::Scalar>,
+                ) -> Result<()> {
+                    let grid_size = MetalContext::grid_size_f32(a.rows(), b.cols());
+                    let threadgroup_size = MetalContext::threadgroup_size_f32();
+                    launch_kernel_with_argmax_impl(ctx, Self::ARGMAX_KERNEL_NAME, a, b, c, grid_size, threadgroup_size)
+                }
+            }
+        )*
+    };
+}
+
+impl_metal_kernel_with_argmax_f32! {
+    TropicalMaxPlus<f32> => "tropical_maxplus_f32_nn_with_argmax",
+    TropicalMinPlus<f32> => "tropical_minplus_f32_nn_with_argmax",
+    TropicalMaxMul<f32> => "tropical_maxmul_f32_nn_with_argmax",
+}
+
+// ============================================================================
+// Backward pass kernels
+// ============================================================================
+
+/// Launch the backward pass kernel for gradient w.r.t. A.
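+///
+/// One thread per element of grad_c: each (i, j) reads k = argmax[i, j] and
+/// accumulates grad_a[i, k] += grad_c[i, j] with a relaxed atomic add, so
+/// grad_a must start zeroed (GpuMatrix::alloc already zeroes its buffer).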
+#[allow(dead_code)]
+pub fn launch_backward_a(
+    ctx: &MetalContext,
+    grad_c: &GpuMatrix<f32>,
+    argmax: &GpuMatrix<ArgmaxIndex>,
+    grad_a: &mut GpuMatrix<f32>,
+    m: usize,
+    n: usize,
+    k: usize,
+) -> Result<()> {
+    let total = m * n;
+    let pipeline = ctx.get_pipeline("tropical_backward_a_f32")?;
+    let command_buffer = ctx.command_queue().new_command_buffer();
+    let encoder = command_buffer.new_compute_command_encoder();
+
+    let m_u32 = m as u32;
+    let n_u32 = n as u32;
+    let k_u32 = k as u32;
+
+    encoder.set_compute_pipeline_state(pipeline);
+    encoder.set_buffer(0, Some(grad_c.as_buffer()), 0);
+    encoder.set_buffer(1, Some(argmax.as_buffer()), 0);
+    encoder.set_buffer(2, Some(grad_a.as_buffer()), 0);
+    encoder.set_bytes(3, std::mem::size_of::<u32>() as u64, &m_u32 as *const u32 as *const _);
+    encoder.set_bytes(4, std::mem::size_of::<u32>() as u64, &n_u32 as *const u32 as *const _);
+    encoder.set_bytes(5, std::mem::size_of::<u32>() as u64, &k_u32 as *const u32 as *const _);
+
+    let threads_per_group = 256u64;
+    let num_groups = ((total as u64) + threads_per_group - 1) / threads_per_group;
+    let grid_size = MTLSize::new(num_groups, 1, 1);
+    let threadgroup_size = MTLSize::new(threads_per_group, 1, 1);
+
+    encoder.dispatch_thread_groups(grid_size, threadgroup_size);
+    encoder.end_encoding();
+
+    command_buffer.commit();
+    command_buffer.wait_until_completed();
+
+    Ok(())
+}
+
+/// Launch the backward pass kernel for gradient w.r.t. B.
+#[allow(dead_code)]
+pub fn launch_backward_b(
+    ctx: &MetalContext,
+    grad_c: &GpuMatrix<f32>,
+    argmax: &GpuMatrix<ArgmaxIndex>,
+    grad_b: &mut GpuMatrix<f32>,
+    m: usize,
+    n: usize,
+    k: usize,
+) -> Result<()> {
+    let total = m * n;
+    let pipeline = ctx.get_pipeline("tropical_backward_b_f32")?;
+    let command_buffer = ctx.command_queue().new_command_buffer();
+    let encoder = command_buffer.new_compute_command_encoder();
+
+    let m_u32 = m as u32;
+    let n_u32 = n as u32;
+    let k_u32 = k as u32;
+
+    encoder.set_compute_pipeline_state(pipeline);
+    encoder.set_buffer(0, Some(grad_c.as_buffer()), 0);
+    encoder.set_buffer(1, Some(argmax.as_buffer()), 0);
+    encoder.set_buffer(2, Some(grad_b.as_buffer()), 0);
+    encoder.set_bytes(3, std::mem::size_of::<u32>() as u64, &m_u32 as *const u32 as *const _);
+    encoder.set_bytes(4, std::mem::size_of::<u32>() as u64, &n_u32 as *const u32 as *const _);
+    encoder.set_bytes(5, std::mem::size_of::<u32>() as u64, &k_u32 as *const u32 as *const _);
+
+    let threads_per_group = 256u64;
+    let num_groups = ((total as u64) + threads_per_group - 1) / threads_per_group;
+    let grid_size = MTLSize::new(num_groups, 1, 1);
+    let threadgroup_size = MTLSize::new(threads_per_group, 1, 1);
+
+    encoder.dispatch_thread_groups(grid_size, threadgroup_size);
+    encoder.end_encoding();
+
+    command_buffer.commit();
+    command_buffer.wait_until_completed();
+
+    Ok(())
+}
diff --git a/crates/tropical-gemm-metal/src/lib.rs b/crates/tropical-gemm-metal/src/lib.rs
new file mode 100644
index 0000000..d6124b9
--- /dev/null
+++ b/crates/tropical-gemm-metal/src/lib.rs
@@ -0,0 +1,396 @@
+//! Metal backend for tropical matrix multiplication.
+//!
+//! This crate provides GPU-accelerated tropical GEMM operations using Metal
+//! on Apple Silicon and other macOS GPUs.
+//!
+//! # Quick Start
+//!
+//! ```ignore
+//! use tropical_gemm_metal::{tropical_matmul_metal, MetalContext};
+//! use tropical_gemm::types::TropicalMaxPlus;
+//!
+//! // Simple one-shot API (uses cached global context for performance)
+//! let a = vec![1.0f32; 1024 * 1024];
+//! let b = vec![1.0f32; 1024 * 1024];
+//! let c = tropical_matmul_metal::<TropicalMaxPlus<f32>>(&a, 1024, 1024, &b, 1024)?;
+//! ```
+//!
+//! # Persistent Context
+//!
+//! For explicit context management:
+//!
+//! ```ignore
+//! use tropical_gemm_metal::{MetalContext, GpuMatrix, tropical_gemm_metal};
+//! use tropical_gemm::types::TropicalMaxPlus;
+//!
+//! let ctx = MetalContext::new()?;
+//!
+//! let a_gpu = GpuMatrix::from_host_row_major(&ctx, &a, m, k)?;
+//! let b_gpu = GpuMatrix::from_host_row_major(&ctx, &b, k, n)?;
+//! let mut c_gpu = GpuMatrix::alloc(&ctx, m, n)?;
+//!
+//! tropical_gemm_metal::<TropicalMaxPlus<f32>>(&ctx, &a_gpu, &b_gpu, &mut c_gpu)?;
+//!
+//! let c = c_gpu.to_host_row_major(&ctx)?;
+//! ```
+//!
+//! # Performance
+//!
+//! The convenience functions (`tropical_matmul_metal`, etc.) use a lazily-initialized
+//! global context that persists across calls. This avoids the shader compilation
+//! overhead on each call.
+
+mod context;
+mod error;
+mod gpu_mat;
+mod kernels;
+
+use once_cell::sync::OnceCell;
+use std::sync::Mutex;
+
+/// Global Metal context for convenience functions.
+/// Lazily initialized on first use, persists for process lifetime.
+static GLOBAL_CONTEXT: OnceCell<MetalContext> = OnceCell::new();
+
+/// Mutex to ensure only one thread initializes the context.
+static INIT_MUTEX: Mutex<()> = Mutex::new(());
+
+/// Get or initialize the global Metal context.
+///
+/// This function is thread-safe and will only initialize the context once.
+/// Subsequent calls return the cached context.
+///
+/// # Errors
+///
+/// Returns an error if Metal initialization fails (no device, etc.)
+pub fn get_global_context() -> Result<&'static MetalContext> {
+    // Fast path: already initialized
+    if let Some(ctx) = GLOBAL_CONTEXT.get() {
+        return Ok(ctx);
+    }
+
+    // Slow path: need to initialize
+    let _lock = INIT_MUTEX.lock().unwrap();
+
+    // Double-check after acquiring lock
+    if let Some(ctx) = GLOBAL_CONTEXT.get() {
+        return Ok(ctx);
+    }
+
+    // Initialize and store
+    let ctx = MetalContext::new()?;
+    let _ = GLOBAL_CONTEXT.set(ctx);
+
+    Ok(GLOBAL_CONTEXT.get().unwrap())
+}
+
+pub use context::MetalContext;
+pub use error::{MetalError, Result};
+pub use gpu_mat::{ArgmaxIndex, GpuMatrix, GpuMatrixWithArgmax};
+pub use kernels::{MetalKernel, MetalKernelWithArgmax};
+
+/// One-shot tropical matrix multiplication on Metal GPU.
+///
+/// This function handles all GPU memory management automatically.
+/// For repeated operations, use `tropical_gemm_metal` with a persistent context.
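+///
+/// Each call performs two host-to-GPU copies (with row-major to column-major
+/// conversion), one kernel dispatch, and one GPU-to-host copy back, so the
+/// persistent-context API amortizes transfer costs across repeated operations.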
+///
+/// # Arguments
+///
+/// * `a` - Matrix A in row-major order, dimensions m×k
+/// * `m` - Number of rows in A
+/// * `k` - Number of columns in A / rows in B
+/// * `b` - Matrix B in row-major order, dimensions k×n
+/// * `n` - Number of columns in B
+///
+/// # Returns
+///
+/// Result matrix C in row-major order, dimensions m×n
+///
+/// # Example
+///
+/// ```ignore
+/// use tropical_gemm_metal::tropical_matmul_metal;
+/// use tropical_gemm::types::TropicalMaxPlus;
+///
+/// let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
+/// let b = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2
+///
+/// let c = tropical_matmul_metal::<TropicalMaxPlus<f32>>(&a, 2, 3, &b, 2)?;
+/// // c is 2x2, row-major
+/// ```
+pub fn tropical_matmul_metal<T>(
+    a: &[T::Scalar],
+    m: usize,
+    k: usize,
+    b: &[T::Scalar],
+    n: usize,
+) -> Result<Vec<T::Scalar>>
+where
+    T: MetalKernel,
+    T::Scalar: Copy + Default,
+{
+    if a.len() != m * k {
+        return Err(MetalError::DimensionMismatch(format!(
+            "A: expected {} elements, got {}",
+            m * k,
+            a.len()
+        )));
+    }
+    if b.len() != k * n {
+        return Err(MetalError::DimensionMismatch(format!(
+            "B: expected {} elements, got {}",
+            k * n,
+            b.len()
+        )));
+    }
+
+    // Use global cached context to avoid shader recompilation
+    let ctx = get_global_context()?;
+
+    let a_gpu = GpuMatrix::from_host_row_major(ctx, a, m, k)?;
+    let b_gpu = GpuMatrix::from_host_row_major(ctx, b, k, n)?;
+    let mut c_gpu = GpuMatrix::alloc(ctx, m, n)?;
+
+    T::launch_gemm(ctx, &a_gpu, &b_gpu, &mut c_gpu)?;
+
+    c_gpu.to_host_row_major(ctx)
+}
+
+/// Tropical matrix multiplication with persistent context.
+///
+/// Use this function when performing multiple GPU operations to avoid
+/// repeated context initialization and kernel compilation.
+///
+/// # Arguments
+///
+/// * `ctx` - Metal context
+/// * `a` - Matrix A on GPU
+/// * `b` - Matrix B on GPU
+/// * `c` - Output matrix C on GPU (will be overwritten)
+pub fn tropical_gemm_metal<T>(
+    ctx: &MetalContext,
+    a: &GpuMatrix<T::Scalar>,
+    b: &GpuMatrix<T::Scalar>,
+    c: &mut GpuMatrix<T::Scalar>,
+) -> Result<()>
+where
+    T: MetalKernel,
+    T::Scalar: Copy + Default,
+{
+    if a.cols() != b.rows() {
+        return Err(MetalError::DimensionMismatch(format!(
+            "A.cols ({}) != B.rows ({})",
+            a.cols(),
+            b.rows()
+        )));
+    }
+    if c.rows() != a.rows() || c.cols() != b.cols() {
+        return Err(MetalError::DimensionMismatch(format!(
+            "C dimensions ({}, {}) don't match A×B ({}, {})",
+            c.rows(),
+            c.cols(),
+            a.rows(),
+            b.cols()
+        )));
+    }
+
+    T::launch_gemm(ctx, a, b, c)
+}
+
+/// One-shot tropical matrix multiplication with argmax on Metal GPU.
+///
+/// This function handles all GPU memory management automatically.
+/// Returns both the result matrix and argmax indices for backpropagation.
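+///
+/// The returned argmax vector is row-major with the same m×n shape as C;
+/// entry (i, j) is the k-index (in 0..k) that produced C[i,j].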
+///
+/// # Arguments
+///
+/// * `a` - Matrix A in row-major order, dimensions m×k
+/// * `m` - Number of rows in A
+/// * `k` - Number of columns in A / rows in B
+/// * `b` - Matrix B in row-major order, dimensions k×n
+/// * `n` - Number of columns in B
+///
+/// # Returns
+///
+/// Tuple of (result matrix C, argmax indices) in row-major order
+pub fn tropical_matmul_metal_with_argmax<T>(
+    a: &[T::Scalar],
+    m: usize,
+    k: usize,
+    b: &[T::Scalar],
+    n: usize,
+) -> Result<(Vec<T::Scalar>, Vec<ArgmaxIndex>)>
+where
+    T: MetalKernelWithArgmax,
+    T::Scalar: Copy + Default,
+{
+    if a.len() != m * k {
+        return Err(MetalError::DimensionMismatch(format!(
+            "A: expected {} elements, got {}",
+            m * k,
+            a.len()
+        )));
+    }
+    if b.len() != k * n {
+        return Err(MetalError::DimensionMismatch(format!(
+            "B: expected {} elements, got {}",
+            k * n,
+            b.len()
+        )));
+    }
+
+    let ctx = get_global_context()?;
+
+    let a_gpu = GpuMatrix::from_host_row_major(ctx, a, m, k)?;
+    let b_gpu = GpuMatrix::from_host_row_major(ctx, b, k, n)?;
+    let mut c_gpu = GpuMatrixWithArgmax::alloc(ctx, m, n)?;
+
+    T::launch_gemm_with_argmax(ctx, &a_gpu, &b_gpu, &mut c_gpu)?;
+
+    let c_data = c_gpu.matrix_to_host_row_major(ctx)?;
+    let argmax_data = c_gpu.argmax_to_host_row_major(ctx)?;
+
+    Ok((c_data, argmax_data))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use tropical_gemm::types::{TropicalMaxPlus, TropicalMinPlus, TropicalMaxMul};
+
+    #[test]
+    fn test_maxplus_basic() {
+        let _ctx = match MetalContext::new() {
+            Ok(c) => c,
+            Err(_) => {
+                println!("Metal not available, skipping test");
+                return;
+            }
+        };
+
+        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
+        let b = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2
+
+        let c = tropical_matmul_metal::<TropicalMaxPlus<f32>>(&a, 2, 3, &b, 2).unwrap();
+
+        // C[0,0] = max(1+1, 2+3, 3+5) = 8
+        assert!((c[0] - 8.0).abs() < 1e-5);
+        // C[0,1] = max(1+2, 2+4, 3+6) = 9
+        assert!((c[1] - 9.0).abs() < 1e-5);
+        // C[1,0] = max(4+1, 5+3, 6+5) = 11
+        assert!((c[2] - 11.0).abs() < 1e-5);
+        // C[1,1] = max(4+2, 5+4, 6+6) = 12
+        assert!((c[3] - 12.0).abs() < 1e-5);
+    }
+
+    #[test]
+    fn test_minplus_basic() {
+        let _ctx = match MetalContext::new() {
+            Ok(c) => c,
+            Err(_) => {
+                println!("Metal not available, skipping test");
+                return;
+            }
+        };
+
+        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
+        let b = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2
+
+        let c = tropical_matmul_metal::<TropicalMinPlus<f32>>(&a, 2, 3, &b, 2).unwrap();
+
+        // C[0,0] = min(1+1, 2+3, 3+5) = 2
+        assert!((c[0] - 2.0).abs() < 1e-5);
+        // C[0,1] = min(1+2, 2+4, 3+6) = 3
+        assert!((c[1] - 3.0).abs() < 1e-5);
+        // C[1,0] = min(4+1, 5+3, 6+5) = 5
+        assert!((c[2] - 5.0).abs() < 1e-5);
+        // C[1,1] = min(4+2, 5+4, 6+6) = 6
+        assert!((c[3] - 6.0).abs() < 1e-5);
+    }
+
+    #[test]
+    fn test_maxmul_basic() {
+        let _ctx = match MetalContext::new() {
+            Ok(c) => c,
+            Err(_) => {
+                println!("Metal not available, skipping test");
+                return;
+            }
+        };
+
+        let a = vec![1.0f32, 2.0, 3.0, 4.0]; // 2x2
+        let b = vec![2.0f32, 3.0, 4.0, 5.0]; // 2x2
+
+        let c = tropical_matmul_metal::<TropicalMaxMul<f32>>(&a, 2, 2, &b, 2).unwrap();
+
+        // C[0,0] = max(1*2, 2*4) = 8
+        assert!((c[0] - 8.0).abs() < 1e-5);
+        // C[0,1] = max(1*3, 2*5) = 10
+        assert!((c[1] - 10.0).abs() < 1e-5);
+        // C[1,0] = max(3*2, 4*4) = 16
+        assert!((c[2] - 16.0).abs() < 1e-5);
+        // C[1,1] = max(3*3, 4*5) = 20
+        assert!((c[3] - 20.0).abs() < 1e-5);
+    }
+
+    #[test]
+    fn test_maxplus_with_argmax() {
+        let _ctx = match MetalContext::new() {
+            Ok(c) => c,
+            Err(_) => {
+                println!("Metal not available, skipping test");
+                return;
+            }
+        };
+
+        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
+        let b = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2
+
+        let (c, argmax) = tropical_matmul_metal_with_argmax::<TropicalMaxPlus<f32>>(&a, 2, 3, &b, 2).unwrap();
+
+        // C[0,0] = max(1+1, 2+3, 3+5) = 8, argmax = 2
+        assert!((c[0] - 8.0).abs() < 1e-5);
+        assert_eq!(argmax[0], 2);
+
+        // C[1,1] = max(4+2, 5+4, 6+6) = 12, argmax = 2
+        assert!((c[3] - 12.0).abs() < 1e-5);
+        assert_eq!(argmax[3], 2);
+    }
+
+    #[test]
+    fn test_larger_matrix() {
+        let _ctx = match MetalContext::new() {
+            Ok(c) => c,
+            Err(_) => {
+                println!("Metal not available, skipping test");
+                return;
+            }
+        };
+
+        let m = 64;
+        let k = 128;
+        let n = 64;
+
+        let a: Vec<f32> = (0..m*k).map(|i| i as f32 * 0.01).collect();
+        let b: Vec<f32> = (0..k*n).map(|i| i as f32 * 0.01).collect();
+
+        let c = tropical_matmul_metal::<TropicalMaxPlus<f32>>(&a, m, k, &b, n).unwrap();
+
+        assert_eq!(c.len(), m * n);
+    }
+
+    #[test]
+    fn test_device_name() {
+        let ctx = match MetalContext::new() {
+            Ok(c) => c,
+            Err(_) => {
+                println!("Metal not available, skipping test");
+                return;
+            }
+        };
+
+        let name = ctx.device_name();
+        println!("Metal device: {}", name);
+        assert!(!name.is_empty());
+    }
+}