Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ packages = [
{ include = "parallax", from = "src" },
{ include = "scheduling", from = "src" },
{ include = "parallax_utils", from = "src" },
{ include = "parallax_extensions", from = "src" },
]

dependencies = [
Expand Down
85 changes: 85 additions & 0 deletions src/parallax_extensions/CMakelists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Build script for the Parallax MLX kernel extension and its Python module.
cmake_minimum_required(VERSION 3.27)

project(
  _ext
  LANGUAGES CXX)

# ----------------------------- Setup -----------------------------
# Project-wide compiler defaults: C++17 is mandatory (no silent fallback to an
# older standard) and every object is built as position-independent code.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# Selects the default library type for add_library() calls that do not name
# one: ON (default) builds shared libraries, OFF builds static ones.
option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)

# ----------------------------- Dependencies -----------------------------
# Python interpreter plus the development files needed to build an extension
# module (3.10 or newer).
find_package(
  Python 3.10
  COMPONENTS Interpreter Development.Module
  REQUIRED)

# Ask the nanobind package installed for this interpreter where its CMake
# config lives. COMMAND_ERROR_IS_FATAL stops the configure step immediately if
# the query fails (e.g. the module is not installed), instead of continuing
# with an empty nanobind_ROOT and failing later with a less helpful message.
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE nanobind_ROOT
  COMMAND_ERROR_IS_FATAL ANY)
find_package(nanobind CONFIG REQUIRED)

# Same lookup for MLX; note that MLX spells the flag with a dash.
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE MLX_ROOT
  COMMAND_ERROR_IS_FATAL ANY)
find_package(MLX CONFIG REQUIRED)

# ----------------------------- Extensions -----------------------------

# Library holding the host-side kernel code. No type is given, so it is built
# shared or static according to BUILD_SHARED_LIBS.
add_library(parallax_ext)

# Sources are PRIVATE: they are compiled into parallax_ext only. Listing .cpp
# files as PUBLIC would also add them to INTERFACE_SOURCES, so every target
# that links parallax_ext (here the _ext Python module) would compile its own
# second copy of the same translation units.
target_sources(
  parallax_ext
  PRIVATE ${CMAKE_CURRENT_LIST_DIR}/kernels/paged_attention.cpp
          ${CMAKE_CURRENT_LIST_DIR}/kernels/reshape_and_cache.cpp
          ${CMAKE_CURRENT_LIST_DIR}/kernels/utils.cpp)

# Headers are included relative to this directory (e.g. "kernels/..."); PUBLIC
# so that consumers of parallax_ext get the same include path.
target_include_directories(parallax_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR})

# Link to mlx; PUBLIC so that consumers of parallax_ext link it as well.
target_link_libraries(parallax_ext PUBLIC mlx)

# ----------------------------- Metal -----------------------------

# Compile the Metal shader sources into a metallib titled "parallax_ext", but
# only when MLX_BUILD_METAL is set.
# NOTE(review): MLX_BUILD_METAL and mlx_build_metallib() are not defined in
# this file; presumably both come from the MLX package found above - confirm.
if(MLX_BUILD_METAL)
  mlx_build_metallib(
    TARGET parallax_ext_metallib
    TITLE parallax_ext
    SOURCES
      ${CMAKE_CURRENT_LIST_DIR}/kernels/utils.metal
      ${CMAKE_CURRENT_LIST_DIR}/kernels/float8.metal
      ${CMAKE_CURRENT_LIST_DIR}/kernels/paged_attention.metal
      ${CMAKE_CURRENT_LIST_DIR}/kernels/reshape_and_cache.metal
    INCLUDE_DIRS
      ${PROJECT_SOURCE_DIR}
      ${MLX_INCLUDE_DIRS}
    OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})

  # Build-order edge only: the metallib target is built before parallax_ext.
  add_dependencies(parallax_ext parallax_ext_metallib)
endif()

# ----------------------------- Python Bindings -----------------------------
# Python extension module `_ext`, built from bindings.cpp with nanobind and
# linked against the kernel library.
nanobind_add_module(
  _ext
  NB_STATIC STABLE_ABI LTO NOMINSIZE
  NB_DOMAIN mlx
  ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp)
target_link_libraries(_ext PRIVATE parallax_ext)

# With a shared parallax_ext, add the module's own directory to its runtime
# search path so the library can be located next to the module at load time.
if(BUILD_SHARED_LIBS)
  target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
endif()
38 changes: 38 additions & 0 deletions src/parallax_extensions/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
## Parallax MLX Kernel Extensions
Extended kernels built for the MLX backend.
MLX official instructions for custom extensions: https://ml-explore.github.io/mlx/build/html/dev/extensions.html

### Directory Structure
```bash
.
├── __init__.py
├── bindings.cpp # Nanobind
├── CMakelists.txt
├── lib
│   ├── _ext.cpython-311-darwin.so # Python Binding
│   ├── libparallax_ext.dylib # C++ extension library
│   └── parallax_ext.metallib # Metal library
├── paged_attention_v1 # Kernel Source Code Directories
│   ├── float8.metal
│   ├── paged_attention.cpp
│   ├── paged_attention.h
│   ├── paged_attention.metal
│   ├── reshape_and_cache.metal
│   └── utils.metal
├── README.md
└── setup.py # Setup Tools Script
```

### Package Build and Install
Build inplace for development using:
```sh
python setup.py build_ext -j8 --inplace
```
You can then install the package from this directory with `python -m pip install .`.
A pre-built package should already be installed as part of the Parallax project.

### Usage Example
```python
import mlx.core as mx
from parallax_extensions import paged_attention_v1
```
4 changes: 4 additions & 0 deletions src/parallax_extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# autoflake: skip_file
# NOTE(review): `mx` is not referenced in this file. The import is presumably
# kept so that MLX is loaded before the compiled extension below - confirm.
# The autoflake directive above keeps the unused import from being removed.
import mlx.core as mx

# Re-export the kernel entry points defined by the compiled `_ext` module.
from .lib._ext import paged_attention_v1, reshape_and_cache
68 changes: 68 additions & 0 deletions src/parallax_extensions/bindings.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/variant.h>

#include "kernels/paged_attention.h"
#include "kernels/reshape_and_cache.h"

namespace nb = nanobind;
using namespace nb::literals;

// Defines the Python module `_ext` and registers the two Parallax kernels.
// Both functions take their array arguments positionally or by keyword and an
// optional keyword-only `stream` argument that defaults to None.
NB_MODULE(_ext, m) {
  // Docstrings are kept in named constants so the registrations below stay
  // short; their text is attached verbatim to the Python functions.
  static constexpr const char* kPagedAttentionV1Doc = R"(
vLLM PagedAttentionV1 operation

Args:
query (array): Input array [num_seqs, num_heads, head_size].
key_cache (array): Input array [num_blocks, num_heads, head_size/x, block_size, x].
value_cache (array): Input array [num_blocks, num_heads, head_size, block_size].
block_tables (array): Input array [num_seqs, max_num_blocks_per_seq].
seq_lens (array): Input array [num_seqs].
num_kv_heads (int): Input parameter.
block_size (int): Input parameter.
max_seq_len (int): Input parameter.
scale (float): Input parameter.

Returns:
array: ``Paged attention result``
)";

  static constexpr const char* kReshapeAndCacheDoc = R"(
vLLM ReshapeAndCache operation

Args:
key (array): Input array [num_tokens, num_heads, head_size].
value (array): Input array [num_tokens, num_heads, head_size].
key_cache (array): Input array [num_blocks, num_heads, head_size/x, block_size, x].
value_cache (array): Input array [num_blocks, num_heads, head_size, block_size].
slot_mapping (array): Input array [num_tokens].

Returns:
array: ``Dummy output``
)";

  m.doc() = "Parallax extensions";

  // paged_attention_v1(query, key_cache, value_cache, block_tables, seq_lens,
  //                    num_kv_heads, block_size, max_seq_len, scale, *, stream=None)
  m.def(
      "paged_attention_v1",
      &parallax_ext::paged_attention_v1,
      "query"_a,
      "key_cache"_a,
      "value_cache"_a,
      "block_tables"_a,
      "seq_lens"_a,
      "num_kv_heads"_a,
      "block_size"_a,
      "max_seq_len"_a,
      "scale"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      kPagedAttentionV1Doc);

  // reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, *, stream=None)
  m.def(
      "reshape_and_cache",
      &parallax_ext::reshape_and_cache,
      "key"_a,
      "value"_a,
      "key_cache"_a,
      "value_cache"_a,
      "slot_mapping"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      kReshapeAndCacheDoc);
}
122 changes: 122 additions & 0 deletions src/parallax_extensions/kernels/float8.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#include <metal_stdlib>
using namespace metal;

// Helpers ------------------------------------------------------------
// Returns the raw 32-bit pattern of `x` (a reinterpretation, not a numeric cast).
static inline uint as_bits(float x) {
  return as_type<uint>(x);
}
// Returns the float whose raw 32-bit pattern is `b` (a reinterpretation, not a numeric cast).
static inline float from_bits(uint b) {
  return as_type<float>(b);
}

// -------------------------------------------------------------------
// FP8 E4M3 (bias = 7)
// -------------------------------------------------------------------
// Decodes one FP8 E4M3 byte (1 sign | 4 exponent | 3 mantissa bits, bias 7)
// into a float.
//
// The all-ones exponent is decoded IEEE-style: mantissa 0 is +/-Inf and any
// other mantissa is NaN. (This is not the E4M3FN convention, which reserves
// only a single NaN mantissa pattern and has no infinities.)
inline float fp8_e4m3_to_float(uchar v) {
  const bool negative = (v & 0x80) != 0;
  const uint exponent = (v >> 3) & 0xF;
  const uint fraction = v & 0x7;

  float magnitude;
  if (exponent == 0xF) {
    // Inf / NaN; NaN is returned as-is, without applying the sign.
    if (fraction != 0) {
      return NAN;
    }
    magnitude = INFINITY;
  } else if (exponent == 0) {
    // Zero, or a sub-normal: (fraction / 8) * 2^(1 - bias) = fraction * 2^-9.
    magnitude = (fraction == 0) ? 0.f : ldexp(float(fraction) / 8.f, 1 - 7);
  } else {
    // Normal number: (1 + fraction / 8) * 2^(exponent - bias).
    magnitude = ldexp(1.f + float(fraction) / 8.f, int(exponent) - 7);
  }

  return negative ? -magnitude : magnitude;
}

// -------------------------------------------------------------------
// FP8 E5M2 (bias = 15)
// -------------------------------------------------------------------
// Decodes one FP8 E5M2 byte (1 sign | 5 exponent | 2 mantissa bits, bias 15)
// into a float. The all-ones exponent encodes +/-Inf (mantissa 0) or NaN.
inline float fp8_e5m2_to_float(uchar v) {
  const bool negative = (v & 0x80) != 0;
  const uint exponent = (v >> 2) & 0x1F;
  const uint fraction = v & 0x3;

  float magnitude;
  if (exponent == 0x1F) {
    // Inf / NaN; NaN is returned as-is, without applying the sign.
    if (fraction != 0) {
      return NAN;
    }
    magnitude = INFINITY;
  } else if (exponent == 0) {
    // Zero, or a sub-normal: (fraction / 4) * 2^(1 - bias) = fraction * 2^-16.
    magnitude = (fraction == 0) ? 0.f : ldexp(float(fraction) / 4.f, 1 - 15);
  } else {
    // Normal number: (1 + fraction / 4) * 2^(exponent - bias).
    magnitude = ldexp(1.f + float(fraction) / 4.f, int(exponent) - 15);
  }

  return negative ? -magnitude : magnitude;
}

// -------------------------------------------------------------------
// Encoding helpers (round-to-nearest-even, gradual underflow, saturate to +/-Inf)
// -------------------------------------------------------------------
namespace detail {
// Converts a 32-bit float to an 8-bit float laid out as
// sign | EXP_BITS exponent bits | MAN_BITS mantissa bits, with exponent bias BIAS.
// The all-ones exponent is reserved for Inf / NaN.
//
//  - NaN            -> exponent all ones, mantissa 1 (sign preserved)
//  - +/-Inf         -> exponent all ones, mantissa 0
//  - finite values too large for the format saturate to +/-Inf
//  - normal and sub-normal results are rounded to nearest, ties to even; a
//    sub-normal that rounds up past the largest sub-normal becomes the
//    smallest normal number
//  - magnitudes up to and including half of the smallest sub-normal -> +/-0
template <int EXP_BITS, int MAN_BITS, int BIAS>
inline uchar fp32_to_fp8(float f) {
  const uint bits = as_bits(f);
  const uint s = bits >> 31;
  const uint magnitude = bits & 0x7FFFFFFF;
  const uint sign_bits = s << 7;
  const uint inf_bits = ((1u << EXP_BITS) - 1u) << MAN_BITS;

  // NaN propagates (mantissa 1), Inf stays Inf (mantissa 0)
  if (magnitude >= 0x7F800000u) {
    return uchar(sign_bits | inf_bits | uint(magnitude != 0x7F800000u));
  }

  const int e = int((magnitude >> 23) & 0xFF) - 127; // unbiased exponent
  const uint m = magnitude & 0x7FFFFFu;              // 23-bit mantissa
  const int EXP_MAX = (1 << EXP_BITS) - 2;           // last finite exponent

  int e_fp8 = e + BIAS;

  // ---------- Overflow -----------------------------------------------------
  // Finite but beyond the largest exponent of the target format. Without this
  // check such inputs would reach the sub-normal path below with a meaningless
  // (possibly negative) shift count.
  if (e_fp8 > EXP_MAX)
    return uchar(sign_bits | inf_bits);

  // ---------- Normal path --------------------------------------------------
  if (e_fp8 >= 1) {
    // round-to-nearest-even on the bits that do not fit into MAN_BITS
    const int shift = 23 - MAN_BITS;
    uint mant = m >> shift;
    const uint lsb = mant & 1u;
    const uint round_bit = (m >> (shift - 1)) & 1u;
    const uint sticky = (m & ((1u << (shift - 1)) - 1u)) != 0u;
    mant += (round_bit & (sticky | lsb));
    if (mant >> MAN_BITS) { // mantissa overflow: carry into the exponent
      mant = 0;
      ++e_fp8;
      if (e_fp8 > EXP_MAX)
        return uchar(sign_bits | inf_bits); // rounded up to Inf
    }
    return uchar(sign_bits | (uint(e_fp8) << MAN_BITS) |
                 (mant & ((1u << MAN_BITS) - 1u)));
  }

  // ---------- Sub-normal / under-flow ---------------------------------------
  // Here e_fp8 <= 0. Below e_fp8 == -MAN_BITS the value is smaller than half
  // of the smallest sub-normal and always rounds to +/-0. (This also covers
  // float zero and float sub-normals, whose exponent field is 0.)
  if (e_fp8 < -MAN_BITS)
    return uchar(sign_bits);

  // Shift the 24-bit significand (with its implicit 1) so that it is expressed
  // in units of the smallest sub-normal; rshift lies in [24 - MAN_BITS, 24].
  const int rshift = (1 - e_fp8) + (23 - MAN_BITS);
  const uint significand = 0x800000u | m;
  uint quotient = significand >> rshift;
  const uint lsb = quotient & 1u;
  const uint round_bit = (significand >> (rshift - 1)) & 1u;
  const uint sticky = (significand & ((1u << (rshift - 1)) - 1u)) != 0u;
  quotient += (round_bit & (sticky | lsb)); // round-to-nearest-even

  // quotient is at most 1 << MAN_BITS. That value must not be masked away: in
  // this bit layout it is exactly exponent field 1 with mantissa 0, i.e. the
  // smallest normal number. A quotient of 0 yields +/-0.
  return uchar(sign_bits | quotient);
}
} // namespace detail

// Encodes `f` as an FP8 E4M3 byte: 4 exponent bits, 3 mantissa bits, bias 7.
inline uchar float_to_fp8_e4m3(float f) {
  constexpr int kExponentBits = 4;
  constexpr int kMantissaBits = 3;
  constexpr int kBias = 7;
  return detail::fp32_to_fp8<kExponentBits, kMantissaBits, kBias>(f);
}
// Encodes `f` as an FP8 E5M2 byte: 5 exponent bits, 2 mantissa bits, bias 15.
inline uchar float_to_fp8_e5m2(float f) {
  constexpr int kExponentBits = 5;
  constexpr int kMantissaBits = 2;
  constexpr int kBias = 15;
  return detail::fp32_to_fp8<kExponentBits, kMantissaBits, kBias>(f);
}
Loading