92 changes: 82 additions & 10 deletions guides/ipynb/quantization_overview.ipynb
@@ -122,14 +122,23 @@
"x_train = rng.random((100, 10)).astype(\"float32\")\n",
"y_train = rng.random((100, 1)).astype(\"float32\")\n",
"\n",
"\n",
"# Build the model.\n",
"def get_model():\n",
" \"\"\"\n",
" Helper to build a simple sequential model.\n",
" \"\"\"\n",
" return keras.Sequential(\n",
" [\n",
" keras.Input(shape=(10,)),\n",
" keras.layers.Dense(32, activation=\"relu\", name=\"dense_1\"),\n",
" keras.layers.Dense(1, name=\"output_head\"),\n",
" ]\n",
" )\n",
"\n",
"\n",
"# Build the model.\n",
"model = keras.Sequential(\n",
" [\n",
" keras.Input(shape=(10,)),\n",
" keras.layers.Dense(32, activation=\"relu\"),\n",
" keras.layers.Dense(1),\n",
" ]\n",
")\n",
"model = get_model()\n",
"\n",
"# Compile and fit the model.\n",
"model.compile(optimizer=\"adam\", loss=\"mean_squared_error\")\n",
@@ -147,7 +156,70 @@
"source": [
"**What this does:** Quantizes the weights of the supported layers, and re-wires their forward paths to be compatible with the quantized kernels and quantization scales.\n",
"\n",
"**Note**: Throughput gains depend on backend/hardware kernels; in cases where kernels fall back to dequantized matmul, you still get memory savings but smaller speedups.\n",
"**Note**: Throughput gains depend on backend/hardware kernels; in cases where kernels fall back to dequantized matmul, you still get memory savings but smaller speedups."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"### Selective Quantization\n",
"\n",
"You can quantize only a subset of the model's layers by passing a `filters` argument to `quantize()`. This argument can be a single regex string, a list of regex strings, or a callable that takes a layer instance and returns a boolean.\n",
"\n",
"**Using Regex Filters:**"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"# Quantize only layers with \"dense\" in the name, but skip \"output\"\n",
"model = get_model()\n",
"model.quantize(\"int8\", filters=[\"dense\", \"^((?!output).)*$\"])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"**Using Callable Filters:**"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"def my_filter(layer):\n",
" # Only quantize Dense layers that aren't the output\n",
" return isinstance(layer, keras.layers.Dense) and layer.name != \"output_head\"\n",
"\n",
"\n",
"model = get_model()\n",
"model.quantize(\"int8\", filters=my_filter)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"This is particularly useful when you want to avoid quantizing sensitive layers (like the first or last layers of a network) to preserve accuracy.\n",
"\n",
"### Layer-wise Quantization\n",
"\n",
@@ -201,7 +273,7 @@
"\n",
"Any composite layers that are built from the above (for example, `MultiHeadAttention`, `GroupedQueryAttention`, feed-forward blocks in Transformers) inherit quantization support by construction. This covers the majority of modern encoder-only and decoder-only Transformer architectures.\n",
"\n",
"Since all KerasHub models subclass `keras.Model`, they automatically support the `model.quantize(...)` API. In practice, this means you can take a popular LLM preset, call a single function to obtain an int8/int4/GPTQ-quantized variant, and then save or serve it—without changing your training code.\n",
"Since all KerasHub models subclass `keras.Model`, they automatically support the `model.quantize(...)` API. In practice, this means you can take a popular LLM preset, call a single function to obtain an int8/int4/GPTQ-quantized variant, and then save or serve it\u2014without changing your training code.\n",
"\n",
"## Practical guidance\n",
"\n",
@@ -239,4 +311,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}
53 changes: 46 additions & 7 deletions guides/md/quantization_overview.md
@@ -94,14 +94,23 @@ rng = np.random.default_rng()
x_train = rng.random((100, 10)).astype("float32")
y_train = rng.random((100, 1)).astype("float32")


# Build the model.
def get_model():
"""
Helper to build a simple sequential model.
"""
return keras.Sequential(
[
keras.Input(shape=(10,)),
keras.layers.Dense(32, activation="relu", name="dense_1"),
keras.layers.Dense(1, name="output_head"),
]
)


# Build the model.
model = keras.Sequential(
[
keras.Input(shape=(10,)),
keras.layers.Dense(32, activation="relu"),
keras.layers.Dense(1),
]
)
model = get_model()

# Compile and fit the model.
model.compile(optimizer="adam", loss="mean_squared_error")
@@ -115,6 +124,36 @@ model.quantize("int8")

**Note**: Throughput gains depend on backend/hardware kernels; in cases where kernels fall back to dequantized matmul, you still get memory savings but smaller speedups.

### Selective Quantization

You can quantize only a subset of the model's layers by passing a `filters` argument to `quantize()`. This argument can be a single regex string, a list of regex strings, or a callable that takes a layer instance and returns a boolean.

**Using Regex Filters:**


```python

# Quantize only layers with "dense" in the name, but skip "output"
model = get_model()
model.quantize("int8", filters=["dense", "^((?!output).)*$"])
```

**Using Callable Filters:**


```python

def my_filter(layer):
# Only quantize Dense layers that aren't the output
return isinstance(layer, keras.layers.Dense) and layer.name != "output_head"


model = get_model()
model.quantize("int8", filters=my_filter)
```

This is particularly useful when you want to avoid quantizing sensitive layers (like the first or last layers of a network) to preserve accuracy.

### Layer-wise Quantization

The Keras quantization framework allows you to quantize each layer separately, without having to quantize the entire model using the same unified API.
54 changes: 47 additions & 7 deletions guides/quantization_overview.py
@@ -92,14 +92,23 @@
x_train = rng.random((100, 10)).astype("float32")
y_train = rng.random((100, 1)).astype("float32")


# Build the model.
def get_model():
"""
Helper to build a simple sequential model.
"""
return keras.Sequential(
[
keras.Input(shape=(10,)),
keras.layers.Dense(32, activation="relu", name="dense_1"),
keras.layers.Dense(1, name="output_head"),
]
)


# Build the model.
model = keras.Sequential(
[
keras.Input(shape=(10,)),
keras.layers.Dense(32, activation="relu"),
keras.layers.Dense(1),
]
)
model = get_model()

# Compile and fit the model.
model.compile(optimizer="adam", loss="mean_squared_error")
@@ -113,6 +122,37 @@
**Note**: Throughput gains depend on backend/hardware kernels; in cases where kernels fall back to dequantized matmul, you still get memory savings but smaller speedups.
"""

Review comment (Contributor, severity: medium): This empty line between two consecutive docstring blocks will be interpreted as an empty code cell by the tutobooks.py script, resulting in an unnecessary empty cell in the generated Jupyter notebook. Please remove this empty line to avoid this.

"""
### Selective Quantization
You can quantize only a subset of the model's layers by passing a `filters` argument to `quantize()`. This argument can be a single regex string, a list of regex strings, or a callable that takes a layer instance and returns a boolean.
**Using Regex Filters:**
"""


# Quantize only layers with "dense" in the name, but skip "output"
model = get_model()
model.quantize("int8", filters=["dense", "^((?!output).)*$"])

"""
**Using Callable Filters:**
"""


def my_filter(layer):
# Only quantize Dense layers that aren't the output
return isinstance(layer, keras.layers.Dense) and layer.name != "output_head"


model = get_model()
model.quantize("int8", filters=my_filter)

"""
This is particularly useful when you want to avoid quantizing sensitive layers (like the first or last layers of a network) to preserve accuracy.
### Layer-wise Quantization
The Keras quantization framework allows you to quantize each layer separately, without having to quantize the entire model using the same unified API.
4 changes: 2 additions & 2 deletions scripts/tutobooks.py
@@ -224,7 +224,7 @@ def nb_to_md(nb_path, md_path, img_dir, working_dir=None):
del_working_dir = False
if working_dir is None:
del_working_dir = True
working_dir = "tmp_" + str(random.randint(1e6, 1e7))
working_dir = "tmp_" + str(random.randint(10**6, 10**7))
if not os.path.exists(working_dir):
os.makedirs(working_dir)
print("Using working_dir:", working_dir)
@@ -335,7 +335,7 @@ def validate(py):
# Validate style with black

tmp = tempfile.gettempdir()
fpath = os.path.join(tmp, str(random.randint(1e6, 1e7)) + ".py")
fpath = os.path.join(tmp, str(random.randint(10**6, 10**7)) + ".py")
f = open(fpath, "w")
pre_formatting = "\n".join(lines)
f.write(pre_formatting)
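The tutobooks.py change swaps float literals for integer ones in `random.randint`. A short sketch of why this matters: `1e6` is a float, and integral floats were accepted by `random.randrange`/`randint` only via an implicit conversion that was deprecated in Python 3.10 and removed in 3.12, where such calls raise an error. Integer literals work on every version:

```python
import random

# Old: random.randint(1e6, 1e7) -- float args; deprecated in 3.10,
# rejected on Python 3.12+.
# New: integer exponentiation, valid everywhere.
working_dir = "tmp_" + str(random.randint(10**6, 10**7))
print(working_dir)  # e.g. "tmp_4821337"
```

The generated name is always `"tmp_"` followed by a seven- or eight-digit integer, matching the original intent.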