59 changes: 34 additions & 25 deletions contrib/models/AFM-4.5B-Base/README.md
@@ -1,37 +1,43 @@
# Contrib Model: AFM 4.5B Base
# Contrib Model: AFM 4.5B Base (Arcee)

NeuronX Distributed Inference implementation of AFM 4.5B Base.
NeuronX Distributed Inference implementation of AFM 4.5B Base (Arcee architecture).

## Model Information

- **HuggingFace ID:** `AFM-4.5B-Base`
- **HuggingFace ID:** `arcee-ai/AFM-4.5B-Base`
- **Model Type:** Decoder-only transformer
- **Architecture:** Arcee (similar to LLaMA with YaRN RoPE and ReLU² activation)
- **License:** Check HuggingFace model card

## Architecture Details

- **Hidden Size:** 2560
- **Num Layers:** 36
- **Attention Heads:** 20 (Q), 4 (KV) - Grouped Query Attention
- **Head Dim:** 128
- **Intermediate Size:** 18432
- **Vocab Size:** 128004
- **Max Position Embeddings:** 65536
- **RoPE Scaling:** YaRN (factor=20, original_max_pos=4096)
- **Activation:** ReLU² (relu(x).pow(2))
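
As a quick sanity check on the figures above, the sketch below derives the projection widths and the GQA group size from the listed values. The variable names are illustrative and are not taken from the model's config file, and treating `factor * original_max_pos` as the span YaRN targets is a rough assumption rather than a statement about the implementation.

```python
# Values copied from the Architecture Details list; names are illustrative.
hidden_size = 2560
num_q_heads = 20
num_kv_heads = 4
head_dim = 128
yarn_factor = 20
original_max_pos = 4096
max_position_embeddings = 65536

# Query projection width: 20 heads * 128 dims, which equals the hidden size.
q_width = num_q_heads * head_dim

# Key and value projections are narrower because only 4 KV heads exist.
kv_width = num_kv_heads * head_dim

# Each KV head is shared by this many query heads (Grouped Query Attention).
group_size = num_q_heads // num_kv_heads

# Output width of a fused QKV projection (Q + K + V).
fused_qkv_width = q_width + 2 * kv_width

# Rough span covered by YaRN scaling (assumption: factor * original length).
yarn_span = yarn_factor * original_max_pos

print(q_width, kv_width, group_size, fused_qkv_width, yarn_span)
```

With these numbers the query width matches the hidden size, five query heads share each KV head, and the scaled span exceeds the 65536 maximum position embeddings.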

## Validation Results

**Validated:** 2026-01-29
**Configuration:** TP=32, batch_size=None, seq_len=None, None
**Validated:** 2026-02-06
**Configuration:** TP=1, batch_size=1, seq_len=128, bfloat16

### Test Results

| Test | Status | Result |
|------|--------|--------|
| Smoke Test | ✅ PASS | Model loads successfully |
| Token Matching | ⚠️ LOW | **41.0% match** |
| Throughput | ⚠️ SLOW | 8.10 tok/s (threshold: 10 tok/s) |
| Token Matching | ✅ PASS | **100% match** |
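
The README does not define how the token match percentage is computed. A minimal sketch, assuming it is the fraction of positions at which the Neuron output agrees with a reference token sequence, could look like this. The helper `token_match_rate` is hypothetical and not part of the integration test.

```python
def token_match_rate(reference_ids, candidate_ids):
    """Fraction of reference positions where the candidate has the same token ID.

    Positions missing from a shorter candidate count as mismatches.
    """
    if not reference_ids:
        raise ValueError("reference sequence is empty")
    matches = sum(r == c for r, c in zip(reference_ids, candidate_ids))
    return matches / len(reference_ids)


# Toy token IDs, not real model output.
reference = [101, 7, 42, 42, 9]
neuron = [101, 7, 42, 13, 9]
print(f"{token_match_rate(reference, neuron):.1%}")  # 80.0%
```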

### Performance Metrics
### Key Implementation Notes

| Metric | Value |
|--------|-------|
| Throughput | 8.10 tokens/s |


**Status:** ⚠️ VALIDATED
1. **YaRN RoPE Scaling:** Implements the YaRN (Yet another RoPE extensioN) mechanism for extended context support (65k tokens)
2. **ReLU² Activation:** Uses `relu(x).pow(2)` instead of SwiGLU - only `up_proj` and `down_proj` (no `gate_proj`)
3. **State Dict Conversion:** QKV projections are separate in HF, combined into `qkv_proj` for Neuron
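
Notes 2 and 3 can be illustrated with a small pure-Python sketch. It is not the code from this PR: weights are plain nested lists rather than tensors, the helper names are hypothetical, and the state dict key names (`q_proj.weight`, `qkv_proj.weight`, and so on) are assumed for illustration only.

```python
def relu_squared(x):
    """ReLU² activation: relu(x) ** 2."""
    return max(x, 0.0) ** 2


def matvec(weight, vec):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]


def relu2_mlp(x, up_proj, down_proj):
    """MLP with only up_proj and down_proj; there is no gate_proj branch."""
    hidden = [relu_squared(h) for h in matvec(up_proj, x)]
    return matvec(down_proj, hidden)


def fuse_qkv(state_dict, prefix):
    """Stack separate q/k/v weight rows into one qkv_proj matrix.

    Rows are output features, so concatenating the row lists is the
    equivalent of concatenating along the output dimension.
    """
    fused = dict(state_dict)
    q = fused.pop(f"{prefix}.q_proj.weight")
    k = fused.pop(f"{prefix}.k_proj.weight")
    v = fused.pop(f"{prefix}.v_proj.weight")
    fused[f"{prefix}.qkv_proj.weight"] = q + k + v
    return fused


# Toy MLP: 2 inputs, 3 hidden units, 1 output.
up = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
down = [[2.0, 3.0, 4.0]]
mlp_out = relu2_mlp([1.0, -2.0], up, down)

# Toy attention weights: 2 query rows, 1 key row, 1 value row.
toy = {
    "attn.q_proj.weight": [[1.0, 0.0], [0.0, 1.0]],
    "attn.k_proj.weight": [[2.0, 2.0]],
    "attn.v_proj.weight": [[3.0, 3.0]],
}
fused = fuse_qkv(toy, "attn")
print(mlp_out, list(fused))
```

In the toy MLP only the first hidden unit is positive, so the negative pre-activations are zeroed before squaring and contribute nothing to the output.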

## Usage

@@ -41,32 +47,35 @@ from neuronx_distributed_inference.models.config import NeuronConfig
from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config

# Import model classes from src
from src.modeling_afm_4_5b_base import NeuronAFM45BBaseForCausalLM, AFM45BBaseInferenceConfig
from src.modeling_afm import NeuronAFMForCausalLM, AFMInferenceConfig

model_path = "/path/to/AFM-4.5B-Base/"
compiled_model_path = "/path/to/compiled/"

# Configure
neuron_config = NeuronConfig(
    tp_degree=32,
    batch_size=None,
    seq_len=512,
    torch_dtype=torch.None,
    tp_degree=1,
    batch_size=1,
    seq_len=128,
    torch_dtype=torch.bfloat16,
)

config = AFM45BBaseInferenceConfig(
    neuron_config,
config = AFMInferenceConfig(

Could you move the test to a different file to follow the structure in the PR Conversation?

Contributor Author

I'm not sure I follow - the file is under test/integration/test_model.py

Oh I see, you are saying there already is a file. Approved.

    neuron_config=neuron_config,
    load_config=load_pretrained_config(model_path),
)

# Compile and load
model = NeuronAFM45BBaseForCausalLM(model_path, config)
model = NeuronAFMForCausalLM(model_path, config)
model.compile(compiled_model_path)
model.load(compiled_model_path)

# Generate
tokenizer = AutoTokenizer.from_pretrained(model_path)
# ... (see integration test for full example)
inputs = tokenizer("1+1=", return_tensors="pt")
outputs = model.generate(inputs.input_ids, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))
# Output: 1+1=2, 2+2=4, 3+3=6, 4+4=8, ...
```

## Compatibility Matrix
@@ -93,10 +102,10 @@ python3 test/integration/test_model.py

## Example Checkpoints

* AFM-4.5B-Base
* arcee-ai/AFM-4.5B-Base

## Maintainer

Neuroboros Team - Annapurna Labs

**Last Updated:** 2026-01-29
**Last Updated:** 2026-02-06