I have a dataloader written in grain as shown below:
```python
import grain
import numpy as np
from pathlib import Path
from grain.multiprocessing import SharedMemoryArray

BOS_ID = 50256


class BOSFinder:
    def __init__(self, tokens):
        # Precompute BOS positions once per shard
        self.tokens = tokens
        self.size = len(tokens)
        self.bos_idx = np.where(tokens == BOS_ID)[0]
        self.i = 0
        self.batch_iter = 0

    def next_batch(self, batch_size: int, max_seq_len: int):
        n = len(self.bos_idx)
        starts = []
        ends = []
        idx = self.i
        for i in range(batch_size):
            cur_len = 0
            target_len = max_seq_len + 1
            while cur_len < target_len:
                if idx >= n:
                    raise StopIteration("Insufficient BOS ahead; hit tail of shard.")
                cur = self.bos_idx[idx]
                starts.append(cur)
                remaining = target_len - cur_len
                next_bos = self.bos_idx[idx + 1] if idx + 1 < n else self.size
                # Take either remaining tokens or up to next BOS
                end = min(next_bos, cur + remaining)
                ends.append(end)
                cur_len += end - cur
                idx += 1
            assert cur_len == target_len
        self.i = idx
        self.batch_iter += 1
        return starts, ends


class LoadShardTokens(grain.transforms.Map):
    def map(self, path):
        file = Path(path)
        header = np.fromfile(str(file), count=256, dtype=np.int32)
        assert header[0] == 20240520, "magic number mismatch in the data .bin file"
        assert header[1] == 1, "unsupported version"
        num_tokens = int(header[2])
        with file.open("rb", buffering=0) as f:
            f.seek(256 * 4)
            tokens = SharedMemoryArray((num_tokens,), dtype=np.uint16)
            nbytes = f.readinto(tokens)
            assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
        bos_idx = np.flatnonzero(tokens == BOS_ID)
        return {"path": str(file), "tokens": tokens, "bos_idx": bos_idx, "size": num_tokens}


def make_grain_shard_loader(files, prefetch=2, worker_count=1):
    # files should be a list of pathlib.Path or str
    source = grain.sources.SharedMemoryDataSource([str(p) for p in files])
    sampler = grain.samplers.SequentialSampler(num_records=len(source))
    ops = [LoadShardTokens()]
    return grain.DataLoader(
        data_source=source,
        sampler=sampler,
        operations=ops,
        worker_count=worker_count,
        worker_buffer_size=prefetch,
    )
```

If I use this dataloader as-is, it does not touch any GPU, since JAX is not imported anywhere. But if I use it in my training script, where JAX is imported and using the GPUs, Grain detects JAX and every worker starts taking around ~550 MB on each GPU.
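For context, this is roughly how I drive the loader (a minimal sketch; the shard paths and the batch parameters here are placeholders, not my real config):

```python
import jax  # merely importing JAX in the training script triggers the behavior

# Placeholder shard paths for illustration only.
files = ["data/shard_000.bin", "data/shard_001.bin"]

loader = make_grain_shard_loader(files, prefetch=2, worker_count=4)

for record in loader:
    finder = BOSFinder(record["tokens"])
    starts, ends = finder.next_batch(batch_size=8, max_seq_len=1024)
    # ... slice out the (start, end) segments and feed them to the model ...
    break

# With the jax import above, nvidia-smi shows each Grain worker process
# holding ~550 MB on every GPU; without it, the workers use no GPU memory.
```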
Apparently there is no way to turn this off: every time JAX is detected, Grain uses it for some pytree handling. IMHO this is extremely bad behavior from an efficiency point of view.
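The only mitigation I can think of is to hide the accelerators from the worker processes via environment variables before the DataLoader spawns them. This is an untested sketch, and it assumes the workers inherit `os.environ` from the parent at spawn time; `JAX_PLATFORMS` is a standard JAX environment variable, not a Grain option:

```python
import os
import jax

# Initialize the GPU backend in the main (training) process first,
# so the env var below does not affect it.
jax.devices()

# Hypothetical mitigation: force any JAX imported inside the Grain worker
# processes onto CPU. Whether this helps depends on the workers inheriting
# this environment when they are spawned.
os.environ["JAX_PLATFORMS"] = "cpu"

loader = make_grain_shard_loader(files, prefetch=2, worker_count=4)
```

Even if this works, it is a workaround at best; an explicit switch in Grain to opt out of the JAX pytree path would be much cleaner.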