Perf: Adding torch.compile + static cache, ~3x speed up #88
tsdocode wants to merge 12 commits into huggingface:main
Conversation
Simple inference code to verify output:

```python
import time

import torch
from PIL import Image

torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)

from models.vision_language_model import VisionLanguageModel
from data.processors import get_tokenizer, get_image_processor
from torch.utils import benchmark

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16
print(f"Using device: {device}")

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True

# debugging
torch._logging.set_logs(graph_breaks=True, recompiles=True)

torch.manual_seed(666)

if __name__ == "__main__":
    model = VisionLanguageModel.from_pretrained("lusxvr/nanoVLM-222M").to(
        device, dtype=dtype
    )
    model.eval()
    # model.decoder = torch.compile(model.decoder, mode="reduce-overhead", fullgraph=True)

    tokenizer = get_tokenizer(model.cfg.lm_tokenizer)
    image_processor = get_image_processor(model.cfg.vit_img_size)

    text = "What is this?"
    template = f"Question: {text} Answer:"
    encoded_batch = tokenizer.batch_encode_plus([template], return_tensors="pt")
    tokens = encoded_batch["input_ids"].to(device)

    image_path = "assets/image.png"
    image = Image.open(image_path)
    image = image_processor(image)
    image = image.unsqueeze(0).to(device, dtype)

    # Print table header
    print("\n" + "=" * 80)
    print(f"{'Configuration':<25} {'Time (s)':<12} {'Generated Text'}")
    print("=" * 80)

    # Without KV cache
    start = time.time()
    result = model.generate(tokens, image, max_new_tokens=128, use_kv_cache=False)
    end = time.time()
    generated_text = tokenizer.decode(result[0])
    print(f"{'Without KV cache':<25} {end - start:<12.3f} {generated_text[:50]}...")

    # Dynamic KV cache
    start = time.time()
    result = model.generate(
        tokens,
        image,
        max_new_tokens=128,
        use_kv_cache=True,
        kv_cache_implementation="dynamic",
    )
    end = time.time()
    generated_text = tokenizer.decode(result[0])
    print(f"{'Dynamic KV cache':<25} {end - start:<12.3f} {generated_text[:50]}...")

    # Static KV cache
    start = time.time()
    result = model.generate(
        tokens,
        image,
        max_new_tokens=128,
        use_kv_cache=True,
        kv_cache_implementation="static",
    )
    end = time.time()
    generated_text = tokenizer.decode(result[0])
    print(f"{'Static KV cache':<25} {end - start:<12.3f} {generated_text[:50]}...")

    model.decoder = torch.compile(
        model.decoder, mode="reduce-overhead", fullgraph=True
    )

    # Static KV cache (compiled) - multiple runs
    print("-" * 80)
    print("Static KV cache (compiled) - Multiple runs:")
    print("-" * 80)
    for i in range(3):
        start = time.time()
        result = model.generate(
            tokens,
            image,
            max_new_tokens=128,
            use_kv_cache=True,
            kv_cache_implementation="static",
        )
        end = time.time()
        generated_text = tokenizer.decode(result[0])
        print(f"{'Run ' + str(i+1):<25} {end - start:<12.3f} {generated_text[:50]}...")
    print("=" * 80)
```

Output on A100:
This PR is still WIP, but the code above is testable. @andimarafioti @lusxvr if you guys have a moment, please give it a shot!
Update:
@andimarafioti can you help me review this PR?
Hi! Sorry, I was busy with other stuff, I'll get to it soon 🙏
I was looking at it. Is this 3x faster than the current KV-cache implementation? How much of that comes from torch.compile and how much from the KV cache implementation?
The ~3x speedup comes from the combination of the static KV cache and torch.compile.
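For intuition, here is a minimal, self-contained sketch of why the two work well together (toy dimensions and a single attention call; this is not the PR's actual implementation). The point is that a static cache is pre-allocated at the maximum sequence length, so every decode step runs with identical tensor shapes and the compiled graph can be reused instead of recompiling (or falling back to dynamic shapes) as the sequence grows:

```python
import torch
import torch.nn.functional as F

B, H, MAX_LEN, D = 1, 8, 1024, 64  # toy batch / heads / max length / head dim
device = "cuda" if torch.cuda.is_available() else "cpu"

# Static cache: allocated once at MAX_LEN, so its shape never changes.
k_cache = torch.zeros(B, H, MAX_LEN, D, device=device)
v_cache = torch.zeros(B, H, MAX_LEN, D, device=device)

@torch.compile(fullgraph=True)
def decode_step(q, k_cache, v_cache, mask):
    # Attend over the full fixed-size cache; unfilled slots are masked out,
    # so the compiler sees the same shapes at every step.
    return F.scaled_dot_product_attention(q, k_cache, v_cache, attn_mask=mask)

# Toy decode loop: each step writes one new K/V slot in place and calls the
# compiled step with unchanged shapes (no reallocation, no recompilation).
for pos in range(4):
    q = torch.randn(B, H, 1, D, device=device)
    k_cache[:, :, pos] = torch.randn(B, H, D, device=device)
    v_cache[:, :, pos] = torch.randn(B, H, D, device=device)
    mask = torch.zeros(B, H, 1, MAX_LEN, dtype=torch.bool, device=device)
    mask[..., : pos + 1] = True  # attend only to filled positions
    out = decode_step(q, k_cache, v_cache, mask)  # (B, H, 1, D)
```

With a dynamic cache, by contrast, the K/V tensors grow by one position every step, so the shapes change on each iteration; that is why the static cache is the natural partner for a compiled decoder.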
Update: new benchmark results on A100 and H100 below. @andimarafioti, please advise on the next step!
Did:
Result for 1000 tokens:
- A100:
  - ~2.6x faster for fp32
  - ~3.2x faster for fp16
- H100:
  - ~1.5x faster for fp32
  - ~2.7x faster for fp16
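For reference, here is a rough sketch of how per-dtype numbers like these could be measured. It is a hypothetical helper, not the PR's benchmark code: it assumes `model`, `tokens`, and `image` are prepared exactly as in the script earlier in this thread, and it runs one warm-up generation so compilation time is not counted.

```python
import time
import torch

def time_generate(model, tokens, image, n_tokens=1000, **kwargs):
    # Hypothetical timing helper: warm-up call first (excludes compilation /
    # cache setup), then a synchronized timed call.
    model.generate(tokens, image, max_new_tokens=n_tokens, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    model.generate(tokens, image, max_new_tokens=n_tokens, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time() - start

# Baseline (e.g. dynamic KV cache, uncompiled decoder) vs. static cache with a
# compiled decoder; repeat with the model loaded in fp32 and fp16/bf16 to fill
# out a table like the one above.
baseline = time_generate(model, tokens, image,
                         use_kv_cache=True, kv_cache_implementation="dynamic")
model.decoder = torch.compile(model.decoder, mode="reduce-overhead", fullgraph=True)
optimized = time_generate(model, tokens, image,
                          use_kv_cache=True, kv_cache_implementation="static")
print(f"Speedup: {baseline / optimized:.2f}x")
```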
TODO: