30 changes: 28 additions & 2 deletions F2LLM/README.md
@@ -26,8 +26,8 @@ In this repo we provide a streamlined and efficient script for training embedding models

- Set up the environment following `requirements.txt`. Note that transformers>=4.51.0 is required for training Qwen3 models.
- Download data and backbone models from Hugging Face (we use Qwen3 models).
- Run `tokenize_data_qwen.py` to tokenize the downloaded data
- Modify model path, data path, and other arguments in `configs/config.json`.
- Run `python tokenize_data_general.py --model_path <path_to_model> [--arch encoder|decoder|auto] [--no_append_eos_decoder]` to tokenize the downloaded data for both decoder and encoder models. On mac/CPU, no flash-attn is needed; the model will fall back to eager attention.
- Modify model path, data path, and other arguments in `configs/config.json` (decoder) or `configs/config_bert.json` (encoder). You can also set `model_arch` to force behavior if auto-detect is undesirable.
- Start training with `accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config.json`.

Note: we recommend setting `num_processes` to 1 in `configs/accelerate_config.yaml` and launching the training code once to generate the cache for the training data before starting the actual training.
@@ -42,6 +42,32 @@ where N_NODE is the number of machines; N_PROCESSES is N_NODE\*8; MASTER_IP is the IP of the master node

On worker nodes, also run the above command but modify `machine_rank` accordingly.

### Support for Encoder-Only Models

- Decoder-only models: last non-padded token pooling (unchanged); uses flash-attn when available, otherwise falls back to eager.
- Encoder-only models: auto-detected (`BertModel`, `RobertaModel`, `DebertaModel`, `ElectraModel`, `AlbertModel`, `DistilBertModel`) or forced via `model_arch`/`--arch`.
- Pooling options for encoders: `cls` (default), `mean`, or a `cls_mean` hybrid (illustrated in the sketch after this list).
- Tokenization: `tokenize_data_general.py` handles both; you can force `--arch encoder|decoder|auto` and skip EOS appending with `--no_append_eos_decoder`.
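
The pooling modes differ only in how token-level hidden states are reduced to a single vector per sequence. Below is a minimal sketch on dummy tensors, simplified from the logic in `model.py`:

```
import torch

# Dummy batch: 2 sequences, max length 6, hidden size 8 (stand-ins for real model outputs).
hidden = torch.randn(2, 6, 8)                        # last_hidden_state, [B, L, D]
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1, 1]])  # 1 = real token, 0 = padding
seq_lens = attention_mask.sum(dim=1)                 # non-padded length per sequence

# Encoder, `cls`: take the first ([CLS]) token.
cls_pooled = hidden[:, 0, :]

# Encoder, `mean`: average over non-padded tokens only.
mask = attention_mask.unsqueeze(-1)                  # [B, L, 1]
mean_pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1)

# Encoder, `cls_mean`: average of the CLS and mean-pooled vectors.
cls_mean_pooled = 0.5 * (cls_pooled + mean_pooled)

# Decoder, `last_token`: take the last non-padded token.
last_token_pooled = hidden[torch.arange(hidden.size(0)), seq_lens - 1, :]
```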

Quick start (encoder):

```
python tokenize_data_general.py \
--model_path bert-base-uncased \
--data_dir training_data \
--output_dir training_data/data_tokenized_bert \
--max_seq_length 512 \
--num_processes 8 \
--arch encoder

accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config_bert.json
```
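
The `--output_dir` used for tokenization should match `train_data_path` in `configs/config_bert.json`.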

Notes and tips:
- Typical encoder max sequence length: 512; learning rate: 2e-5 to 5e-5.
- For gated/private HF models, run `huggingface-cli login`.
- On mac/CPU, flash-attn is not required; the code will use eager attention automatically.

### Citation

If you use the F2LLM models, data, or code, please cite the following technical report.
4 changes: 4 additions & 0 deletions F2LLM/arguments.py
@@ -19,6 +19,10 @@ class Args:
min_lr: float = 1e-6
weight_decay: float = 1e-2
warmup_steps: int = 100
# model architecture: 'decoder' (default) or 'encoder'
model_arch: str = "decoder"
# pooling strategy for embeddings: 'last_token' (decoder); 'cls', 'mean', or 'cls_mean' (encoder)
pooling: str = "last_token"
# embedding-related settings
num_hard_neg: int = 7
# train steps take precedence over epochs, set to -1 to disable
19 changes: 19 additions & 0 deletions F2LLM/configs/config_bert.json
@@ -0,0 +1,19 @@
{
"model_path": "bert-base-uncased",
"experiment_id": "bert-base-uncased+lr.2e-5+bs.16x32+context.512+2epochs",
"train_data_path": "training_data/data_tokenized",
"output_dir": "output",
"tb_dir": "output/tb",
"cache_dir": "cache",
"train_batch_size": 16,
"checkpointing_steps": 5000,
"validation_steps": 5000,
"max_seq_length": 512,
"learning_rate": 2e-5,
"min_lr": 1e-7,
"weight_decay": 0.01,
"warmup_steps": 500,
"train_epochs": 2,
"log_interval": 100,
"num_hard_neg": 7
}
67 changes: 57 additions & 10 deletions F2LLM/model.py
@@ -1,5 +1,5 @@
import torch
from transformers import AutoModel, AutoTokenizer
from transformers import AutoModel, AutoTokenizer, AutoConfig


class F2LLM:
@@ -12,8 +12,33 @@ def __init__(self,
self.args = args
self.dtype = torch.bfloat16
self.device = None # set after accelerator.prepare
self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation='flash_attention_2')
self.lm.config.use_cache = False
config = AutoConfig.from_pretrained(model_path)
encoder_archs = ['BertModel', 'RobertaModel', 'DebertaModel', 'ElectraModel', 'AlbertModel', 'DistilBertModel']

# Allow explicit override via args.model_arch; otherwise infer from config
if self.args and getattr(self.args, 'model_arch', None):
arch_flag = self.args.model_arch.lower()
self.is_encoder_only = arch_flag == 'encoder'
else:
self.is_encoder_only = any(arch in (getattr(config, 'architectures', None) or []) for arch in encoder_archs)

# Choose attention impl: prefer flash_attention_2 when available on CUDA for decoders; otherwise fallback to eager
if not self.is_encoder_only and torch.cuda.is_available():
try:
import flash_attn # noqa: F401
attn_impl = 'flash_attention_2'
except Exception:
attn_impl = 'eager'
else:
attn_impl = 'eager'
self.lm = AutoModel.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=self.dtype,
attn_implementation=attn_impl
)
if not self.is_encoder_only:
self.lm.config.use_cache = False
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.max_seq_length = max_seq_length

@@ -24,14 +49,36 @@ def forward(self, batch):
bs = batch['bs']
num_hard_neg = int((len(batch['input_ids']) - 2*bs) / bs)

outputs = self.lm(batch['input_ids'],
batch['attention_mask'],
)
outputs = self.lm(
batch['input_ids'],
batch['attention_mask'],
)

hidden = outputs.last_hidden_state # [total_bs, seq_len, dim]

# Pooling per-architecture
if self.is_encoder_only:
pooling = getattr(self.args, 'pooling', 'cls') if self.args else 'cls'
if pooling == 'mean':
mask = batch['attention_mask'].unsqueeze(-1) # [B, L, 1]
summed = (hidden * mask).sum(dim=1, keepdim=True)
lengths = mask.sum(dim=1, keepdim=True).clamp_min(1)
pooled = summed / lengths
elif pooling == 'cls_mean':
mask = batch['attention_mask'].unsqueeze(-1)
summed = (hidden * mask).sum(dim=1, keepdim=True)
lengths = mask.sum(dim=1, keepdim=True).clamp_min(1)
mean_pooled = summed / lengths
pooled = 0.5 * (hidden[:, 0:1, :] + mean_pooled)
else: # default CLS
pooled = hidden[:, 0:1, :]
else:
# decoder-style: last non-pad token representation
pooled = torch.stack([hidden[i, [batch['seq_lens'][i]-1]] for i in range(len(batch['seq_lens']))])

passage_features_all_tokens = outputs.last_hidden_state
return {
'query_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs)]),
'passage_passage_features': torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(bs, 2*bs)]),
'negative_passage_features': None if num_hard_neg == 0 else torch.stack([passage_features_all_tokens[i, [batch['seq_lens'][i]-1]] for i in range(2*bs, len(batch['seq_lens']))]).view(bs, num_hard_neg, -1)
'query_passage_features': pooled[:bs],
'passage_passage_features': pooled[bs:2*bs],
'negative_passage_features': None if num_hard_neg == 0 else pooled[2*bs:].view(bs, num_hard_neg, -1)
}

7 changes: 6 additions & 1 deletion F2LLM/requirements.txt
@@ -1,7 +1,12 @@
accelerate
datasets
deepspeed
flash-attn
# flash-attn is GPU-only and skipped on macOS/ARM; the pip environment marker below limits installation to Linux x86_64.
flash-attn; platform_system == "Linux" and platform_machine == "x86_64"
torch
transformers
tensorboard
scikit-learn
numpy
pandas
pytest
22 changes: 21 additions & 1 deletion F2LLM/run.py
@@ -3,7 +3,8 @@
from transformers import (
AutoTokenizer,
set_seed,
get_scheduler
get_scheduler,
AutoConfig
)
import os, json, random
from datasets import load_dataset
@@ -22,6 +23,17 @@
args.num_processes = accelerator.num_processes
accelerator.print(args)

# Detect architecture and normalize tokenizer padding
config = AutoConfig.from_pretrained(args.model_path)
encoder_archs = ['BertModel', 'RobertaModel', 'DebertaModel', 'ElectraModel', 'AlbertModel', 'DistilBertModel']
detected_encoder = any(arch in (getattr(config, 'architectures', None) or []) for arch in encoder_archs)
if args.model_arch:
is_encoder_only = args.model_arch.lower() == "encoder"
else:
is_encoder_only = detected_encoder
args.model_arch = "encoder" if is_encoder_only else "decoder"
accelerator.print(f"Model architecture: {'encoder' if is_encoder_only else 'decoder'} | Pooling: {args.pooling}")

def _stack(input_ids, max_len):
data = [ids[:max_len] for ids in input_ids] # input_ids: list of lists
lens = [len(x) for x in data]
@@ -70,6 +82,14 @@ def collate_fn(batch_raw):
valid_datasets.append((dataset_name, dataset['test']))

tokenizer = AutoTokenizer.from_pretrained(args.model_path)
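# Some tokenizers (e.g. GPT-2-style decoders) ship without a pad token; reuse EOS/UNK or add one so right-padded batching works.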
if tokenizer.pad_token_id is None:
if tokenizer.eos_token_id is not None:
tokenizer.pad_token = tokenizer.eos_token
elif getattr(tokenizer, 'unk_token', None):
tokenizer.pad_token = tokenizer.unk_token
else:
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tokenizer.padding_side = 'right'

train_loaders = {
name: DataLoader(ds, shuffle=True, batch_size=args.train_batch_size, collate_fn=collate_fn)
135 changes: 135 additions & 0 deletions F2LLM/smoke_encoder_decoder.py
@@ -0,0 +1,135 @@
"""
Lightweight smoke checks for encoder/decoder pooling and tokenizer behaviors.
Run: python smoke_encoder_decoder.py
"""
import torch
from tokenize_data_general import process_sent
from model import F2LLM


class MockTokenizer:
def __init__(self, eos_token_id=2, pad_token_id=0):
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id

def __call__(self, sentence, max_length, truncation=True, add_special_tokens=False):
# deterministic token ids based on length
base = list(range(1, min(max_length, len(sentence.split())) + 1))
if add_special_tokens:
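# 101/102 stand in for BERT's [CLS]/[SEP] token ids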
ids = [101] + base
if len(ids) < max_length:
ids.append(102)
else:
ids = base
ids = ids[:max_length]

class Output:
def __init__(self, ids):
self.input_ids = ids
return Output(ids)


def test_process_sent_encoder_special_tokens():
tok = MockTokenizer()
arr = process_sent("hello world", tok, max_seq_length=5, is_encoder_only=True, append_eos_decoder=True)
assert arr[0] == 101, "CLS should be first"
assert arr[-1] == 102, "SEP should be last when room remains"


def test_process_sent_decoder_eos_appended():
tok = MockTokenizer(eos_token_id=9)
arr = process_sent("a b c", tok, max_seq_length=6, is_encoder_only=False, append_eos_decoder=True)
assert arr[-1] == 9, "EOS should be appended for decoder when enabled"


def test_process_sent_decoder_skip_eos():
tok = MockTokenizer(eos_token_id=9)
arr = process_sent("a b c", tok, max_seq_length=6, is_encoder_only=False, append_eos_decoder=False)
assert arr[-1] != 9, "EOS should not be appended when disabled"


def test_encoder_pooling_variants():
class Args:
pooling = "cls"
model_arch = "encoder"
args = Args()
model = F2LLM.__new__(F2LLM)
model.args = args
model.is_encoder_only = True
bs = 2
num_hard_neg = 1
seq_lens = torch.tensor([5, 6, 7, 8, 5, 6])
hidden = torch.randn(bs * (2 + num_hard_neg), 10, 4)
attn_mask = torch.ones(bs * (2 + num_hard_neg), 10, dtype=torch.long)
batch = {
'input_ids': torch.zeros_like(attn_mask),
'attention_mask': attn_mask,
'seq_lens': seq_lens,
'bs': bs
}
class MockLM:
def __call__(self, input_ids, attention_mask):
class Output:
last_hidden_state = hidden
return Output()
model.lm = MockLM()
model.lm.device = hidden.device
model.forward = F2LLM.forward.__get__(model, F2LLM)

out_cls = model.forward(batch)
assert out_cls['query_passage_features'].shape == (bs, 1, hidden.size(-1))

model.args.pooling = "mean"
out_mean = model.forward(batch)
assert out_mean['query_passage_features'].shape == (bs, 1, hidden.size(-1))

model.args.pooling = "cls_mean"
out_cls_mean = model.forward(batch)
assert out_cls_mean['query_passage_features'].shape == (bs, 1, hidden.size(-1))


def test_decoder_pooling_last_token():
model = F2LLM.__new__(F2LLM)
model.args = None
model.is_encoder_only = False
bs = 2
num_hard_neg = 1
seq_lens = torch.tensor([2, 3, 4, 5, 6, 7])
hidden = torch.randn(bs * (2 + num_hard_neg), 8, 4)
attn_mask = torch.ones(bs * (2 + num_hard_neg), 8, dtype=torch.long)
batch = {
'input_ids': torch.zeros_like(attn_mask),
'attention_mask': attn_mask,
'seq_lens': seq_lens,
'bs': bs
}
class MockLM:
def __call__(self, input_ids, attention_mask):
class Output:
last_hidden_state = hidden
return Output()
model.lm = MockLM()
model.lm.device = hidden.device
model.forward = F2LLM.forward.__get__(model, F2LLM)

out = model.forward(batch)
assert out['query_passage_features'].shape == (bs, 1, hidden.size(-1))
assert out['negative_passage_features'].shape == (bs, num_hard_neg, hidden.size(-1))


def main():
tests = [
test_process_sent_encoder_special_tokens,
test_process_sent_decoder_eos_appended,
test_process_sent_decoder_skip_eos,
test_encoder_pooling_variants,
test_decoder_pooling_last_token,
]
for t in tests:
t()
print(f"{t.__name__}: ok")
print("All smoke tests passed.")


if __name__ == "__main__":
main()