14 changes: 11 additions & 3 deletions test/test_ucm_connector_save_load.py
@@ -254,7 +254,13 @@ def run_once(
load_block_ids=([], []),
dump_block_ids=(dump_hashes, dump_vllm_block_ids),
)
connector.connector.kv_caches = kv_caches

if (
not hasattr(connector.connector, "k_store")
or connector.connector.k_store is None
):
connector.connector.register_kv_caches(kv_caches)

connector.bind_connector_metadata(metadata)

total_bytes = compute_total_bytes(kv_caches, batch_size, is_mla)
@@ -267,7 +273,7 @@ def run_once(

write_bw = (total_bytes / (1024**3)) / write_time if write_time > 0 else 0.0

lookup = connector.connector.store.lookup(dump_hashes)
lookup = connector.connector.k_store.lookup(dump_hashes)
if not all(lookup):
raise RuntimeError("Found missing cache blocks before load test.")

@@ -277,7 +283,7 @@ def run_once(
load_block_ids=(dump_hashes, load_vllm_block_ids),
dump_block_ids=([], []),
)
connector.connector.kv_caches = kv_caches

connector.bind_connector_metadata(load_metadata)

forward_context = build_forward_context(kv_caches, is_mla)
@@ -375,6 +381,8 @@ def broadcast(self, tensor, src):
mla,
)

connector.connector.register_kv_caches(kv_caches)

w_sizes, w_times, w_bws = [], [], []
r_sizes, r_times, r_bws = [], [], []

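Note on the connector hunks above: the test no longer assigns `connector.connector.kv_caches` directly before each bind; it registers the caches once via `register_kv_caches`, guarded so that repeated `run_once` calls reuse the already-populated `k_store`. A minimal sketch of that pattern, using a hypothetical helper name (`ensure_kv_registered` is not part of the test; the real check is inlined in `run_once`):

```python
def ensure_kv_registered(connector, kv_caches):
    """Register KV caches on the inner connector only once.

    Hypothetical helper mirroring the guard added in this diff: after the
    first call k_store is populated, so later run_once iterations skip
    re-registration and go straight to bind_connector_metadata().
    """
    inner = connector.connector
    if not hasattr(inner, "k_store") or inner.k_store is None:
        inner.register_kv_caches(kv_caches)
    return inner.k_store
```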
100 changes: 50 additions & 50 deletions ucm/store/test/e2e/nfsstore_embed_fetch.py
@@ -32,7 +32,9 @@
import torch

from ucm.store.nfsstore.nfsstore_connector import UcmNfsStore
from ucm.store.pcstore.pcstore_connector_v1 import UcmPcStoreV1
from ucm.store.ucmstore import UcmKVStoreBase
from ucm.store.ucmstore_v1 import UcmKVStoreBaseV1


def setup(
@@ -42,7 +44,7 @@ def setup(
io_size,
transferStreamNumber,
transferIoDirect,
) -> UcmKVStoreBase:
) -> UcmKVStoreBaseV1:
config = {
"storage_backends": storage_backends,
"kv_block_size": block_size,
@@ -51,8 +53,9 @@
"io_size": io_size,
"transferStreamNumber": transferStreamNumber,
"transferIoDirect": transferIoDirect,
"unique_id": secrets.token_hex(8),
}
return UcmNfsStore(config)
return UcmPcStoreV1(config)


def make_aligned_tensor(shape, dtype, device, alignment=4096):
@@ -79,64 +82,59 @@ def make_aligned_tensor(shape, dtype, device, alignment=4096):
def make_buffers(
block_number, device_id, batch_size, head_dim, block_len, block_layer, num_head, kv
):
hashes = [secrets.token_hex(16) for _ in range(block_number)]
kv_caches = {}
for i in range(block_layer):
kv_caches[i] = make_aligned_tensor(
hashes = [secrets.token_bytes(16) for _ in range(block_number)]
kvcaches = {}
for layer_id in range(block_layer):
kvcaches[layer_id] = make_aligned_tensor(
[kv, block_number, block_len, num_head, head_dim],
dtype=torch.float16,
dtype=torch.bfloat16,
device=f"cuda:{device_id}",
)
return hashes, kv_caches
kvcaches[layer_id].random_()
return hashes, kvcaches


def store_all_hashes(hashes: List[str]):
def store_all_hashes(hashes: List[bytes]):
file_path = os.path.join(os.path.dirname(__file__), "kvcache_block_hashes.txt")
with open(file_path, "w", encoding="utf-8") as f:
for h in hashes:
f.write(h + "\n")
f.write(h.hex() + "\n")


def load_hashes_from_file() -> List[str]:
def load_hashes_from_file() -> List[bytes]:
file_path = os.path.join(os.path.dirname(__file__), "kvcache_block_hashes.txt")
if not os.path.exists(file_path):
return []
with open(file_path, "r", encoding="utf-8") as f:
return [line.strip() for line in f.readlines()]
return [bytes.fromhex(line.strip()) for line in f.readlines()]


def embed(
store: UcmKVStoreBase,
hashes: List[str],
store: UcmKVStoreBaseV1,
hashes: List[bytes],
kvcaches: Dict[int, torch.Tensor],
mla: bool,
):
start_time = time.perf_counter()

total_block_ids, total_offsets, total_tensors = [], [], []
total_tensors = []
total_size = 0

for i, hash_val in enumerate(hashes):
offset = 0
tensors = []
for layer_id, kv_layer in kvcaches.items():
k_tensor = kv_layer[0][i] # kv=1
total_tensors.append(k_tensor)
total_block_ids.append(hash_val)
total_offsets.append(offset)
k_tensor = kv_layer[0][i].contiguous()
tensors.append(k_tensor)
sz = k_tensor.numel() * k_tensor.element_size()
offset += sz
total_size += sz

if not mla:
v_tensor = kv_layer[1][i]
total_tensors.append(v_tensor)
total_block_ids.append(hash_val)
total_offsets.append(offset)
v_tensor = kv_layer[1][i].contiguous()
tensors.append(v_tensor)
sz = v_tensor.numel() * v_tensor.element_size()
offset += sz
total_size += sz

task = store.dump(total_block_ids, total_offsets, total_tensors)
total_tensors.append(tensors)
task = store.dump(hashes, [], total_tensors)
store.wait(task)

elapsed_time = time.perf_counter() - start_time
@@ -151,8 +149,8 @@ def embed(


def fetch(
store: UcmKVStoreBase,
hashes: List[str],
store: UcmKVStoreBaseV1,
hashes: List[bytes],
kvcaches: Dict[int, torch.Tensor],
mla: bool,
):
@@ -162,32 +160,33 @@ def fetch(
for f in founds:
assert f, "Cache block miss detected"

block_ids, offsets, tensors = [], [], []
total_tensors = []
total_size = 0

for i, hash_val in enumerate(hashes):
offset = 0
tensors = []
for layer_id, kv_layer in kvcaches.items():
k_tensor = kv_layer[0][i] # kv=1
block_ids.append(hash_val)
offsets.append(offset)
k_tensor = kv_layer[0][i].contiguous()
tensors.append(k_tensor)
sz = k_tensor.numel() * k_tensor.element_size()
offset += sz
total_size += sz

if not mla:
v_tensor = kv_layer[1][i]
block_ids.append(hash_val)
offsets.append(offset)
v_tensor = kv_layer[1][i].contiguous()
tensors.append(v_tensor)
sz = v_tensor.numel() * v_tensor.element_size()
offset += sz
total_size += sz
total_tensors.append(tensors)

task = store.load(block_ids, offsets, tensors)
ret = store.wait(task)
assert ret == 0, "Load operation failed"
task = store.load(hashes, [], total_tensors)
try:
ret = store.wait(task)
if ret is None:
ret = 0
except RuntimeError as e:
print(f"Load operation failed with error: {e}")
raise
assert ret == 0, f"Load operation failed with return code: {ret}"

elapsed_time = time.perf_counter() - start_time
throughput_gbps = (total_size / (1024**3)) / elapsed_time if elapsed_time > 0 else 0
@@ -226,6 +225,10 @@ def run(
block_dim = head_size * num_head
io_size = block_dim * block_len * block_elem_size
block_size = io_size * block_layer

if not mla:
block_size = block_size * 2

batch_size = int(num_tokens / block_len)
real_blocks = batch_size + 10

@@ -257,16 +260,13 @@ def run(
kv,
)

results = store.create(hashes[:batch_size])
assert sum(results) == 0, "Create operation failed"

w_size, w_time, w_bw = embed(
store,
hashes[:batch_size],
kvcaches,
mla,
)
store.commit(hashes[:batch_size], True)
time.sleep(1)

if r == 0:
store_all_hashes(hashes[:batch_size])
@@ -349,10 +349,10 @@ def run(
try:
result = run(
storage_backends=".",
device_id=1,
repeat=1,
device_id=6,
repeat=2,
num_head=1,
block_len=128,
block_len=64,
transferStreamNumber=32,
num_tokens=4096,
block_layer=61,
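For reference, the call shape this diff moves `embed`/`fetch` to: the V1 store's `dump` and `load` take the block hashes, an empty offsets list, and one list of tensors per block (each block's K slice, plus its V slice when not MLA, across all layers), instead of the old flat `(block_ids, offsets, tensors)` triples. A minimal sketch derived from the changed functions above (illustrative only; tensor shapes and config follow the test, not a documented API):

```python
def build_block_tensors(kvcaches, hashes, mla):
    """One inner list per block hash: K (and V when not MLA) for every layer."""
    per_block = []
    for i, _ in enumerate(hashes):
        tensors = []
        for kv_layer in kvcaches.values():
            tensors.append(kv_layer[0][i].contiguous())      # K slice for block i
            if not mla:
                tensors.append(kv_layer[1][i].contiguous())  # V slice for block i
        per_block.append(tensors)
    return per_block

# Usage against a UcmPcStoreV1 instance, with kvcaches/hashes from make_buffers:
#   task = store.dump(hashes, [], build_block_tensors(kvcaches, hashes, mla))
#   store.wait(task)
#   task = store.load(hashes, [], build_block_tensors(kvcaches, hashes, mla))
#   store.wait(task)
```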