
Conversation

@OmarPavel

Summary:
Tune the max segment length per CTA in the Triton table-batched embeddings kernel, and expose the parameter via the CLI.
Improves backward performance by ~2% on B200.

We were using a hardcoded value of 1024, tuned for H100.
This change sets the default for B200 to 4096 (after testing smaller and larger values).

It also exposes the flag on the command line: when used together with --no-deterministic, you can add --max-cta-segment-length 4096 to override the value for whatever CUDA device you're building for.

Tested at values of 512/1024/2048/4096/8192; 4096 outperforms the others at both tested batch sizes, 256 and 512.
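As a rough illustration of the flag described above, here is a minimal sketch of how such a parameter might be wired into a benchmark CLI with argparse. The flag names mirror this PR description, but the parser itself is illustrative, not the actual FBGEMM implementation:

```python
# Hedged sketch: wiring a --max-cta-segment-length flag into a benchmark CLI.
# Flag names follow the PR description; the surrounding parser is illustrative.
import argparse

def build_parser(default_len=1024):
    p = argparse.ArgumentParser(description="TBE benchmark (sketch)")
    p.add_argument("--no-deterministic", action="store_true",
                   help="allow non-deterministic kernels")
    p.add_argument("--max-cta-segment-length", type=int, default=default_len,
                   help="max segment length handled per CTA in the Triton "
                        "table-batched embedding backward kernel")
    return p

# Example invocation matching the PR description:
args = build_parser().parse_args(
    ["--no-deterministic", "--max-cta-segment-length", "4096"])
print(args.max_cta_segment_length)  # 4096
```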

Embedding Dim: 256
┌─────────────────────┬───────────────────┬──────────────┐
│ CTA Segment Length  │ Backward Time (μs)│ vs 1024      │
├─────────────────────┼───────────────────┼──────────────┤
│ 512                 │ 24,651            │ -4.1% slower │
│ 1024 (default)      │ 23,680            │ baseline     │
│ 4096                │ 23,118            │ +2.4% faster │
│ 8192                │ 28,698            │ -21.2% slower│
└─────────────────────┴───────────────────┴──────────────┘

Embedding Dim: 512
┌─────────────────────┬───────────────────┬──────────────┐
│ CTA Segment Length  │ Backward Time (μs)│ vs 1024      │
├─────────────────────┼───────────────────┼──────────────┤
│ 1024 (default)      │ 42,067            │ baseline     │
│ 2048                │ 41,471            │ +1.4% faster │
│ 4096                │ 41,235            │ +2.0% faster │
│ 8192                │ 41,414            │ +1.6% faster │
└─────────────────────┴───────────────────┴──────────────┘
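The "vs 1024" columns above are signed percent changes relative to the 1024 baseline time. A quick sanity check of those numbers:

```python
# Sanity-check the speedup percentages in the tables above.
# Backward times (in microseconds) are taken from the tables; the baseline
# is the old default of 1024.
def vs_baseline(baseline_us, t_us):
    """Signed percent change vs baseline: positive = faster, negative = slower."""
    return 100.0 * (baseline_us - t_us) / baseline_us

# Embedding dim 256 (baseline 23,680 us)
assert round(vs_baseline(23680, 24651), 1) == -4.1   # 512
assert round(vs_baseline(23680, 23118), 1) == 2.4    # 4096
assert round(vs_baseline(23680, 28698), 1) == -21.2  # 8192

# Embedding dim 512 (baseline 42,067 us)
assert round(vs_baseline(42067, 41471), 1) == 1.4    # 2048
assert round(vs_baseline(42067, 41235), 1) == 2.0    # 4096
assert round(vs_baseline(42067, 41414), 1) == 1.6    # 8192
```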

Reviewed By: stashuk-olek

Differential Revision: D89695609


…ngs" (pytorch#5274)

Summary:
X-link: facebookresearch/FBGEMM#2267


meta-codesync bot commented Dec 23, 2025

@OmarPavel has exported this pull request. If you are a Meta employee, you can view the originating Diff in D89695609.
