
Grain workers use GPUs, and there is no way to avoid it #1199

@AakashKumarNain

I have a dataloader written in grain as shown below:

import grain
import numpy as np
from pathlib import Path
from grain.multiprocessing import SharedMemoryArray


BOS_ID = 50256

class BOSFinder:
    def __init__(self, tokens):
        # Precompute BOS positions once per shard
        self.tokens = tokens
        self.size = len(tokens)
        self.bos_idx = np.where(tokens == BOS_ID)[0]
        self.i = 0
        self.batch_iter = 0

    def next_batch(self, batch_size: int, max_seq_len: int):
        n = len(self.bos_idx)
        starts = []
        ends = []
        
        idx = self.i
        for i in range(batch_size):
            cur_len = 0
            target_len = max_seq_len + 1
            
            while cur_len < target_len:
                if idx >= n:
                    raise StopIteration(f"Insufficient BOS ahead; hit tail of shard.")
                
                cur = self.bos_idx[idx]
                starts.append(cur)
                
                remaining = target_len - cur_len
                next_bos = self.bos_idx[idx + 1] if idx + 1 < n else self.size
                
                # Take either remaining tokens or up to next BOS
                end = min(next_bos, cur + remaining)
                ends.append(end)
                
                cur_len += end - cur
                idx += 1
            
            assert cur_len == target_len
        
        self.i = idx
        self.batch_iter += 1
        return starts, ends



class LoadShardTokens(grain.transforms.Map):
    def map(self, path):
        file = Path(path)

        header = np.fromfile(str(file), count=256, dtype=np.int32)
        assert header[0] == 20240520, "magic number mismatch in the data .bin file"
        assert header[1] == 1, "unsupported version"
        num_tokens = int(header[2])

        with file.open("rb", buffering=0) as f:
            f.seek(256 * 4)
            tokens = SharedMemoryArray((num_tokens,), dtype=np.uint16)
            nbytes = f.readinto(tokens)
            assert nbytes == 2 * num_tokens, "number of tokens read does not match header"

        bos_idx = np.flatnonzero(tokens == BOS_ID)
        return {"path": str(file), "tokens": tokens, "bos_idx": bos_idx, "size": num_tokens}



def make_grain_shard_loader(files, prefetch=2, worker_count=1):

    # files should be a list of pathlib.Path or str
    source = grain.sources.SharedMemoryDataSource([str(p) for p in files])
    sampler = grain.samplers.SequentialSampler(num_records=len(source))
    ops = [LoadShardTokens()]
    return grain.DataLoader(
        data_source=source,
        sampler=sampler,
        operations=ops,
        worker_count=worker_count,          
        worker_buffer_size=prefetch,
    )
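
For context, this is roughly how the loader and BOSFinder are meant to be used together. The shard directory, batch size, and sequence length below are illustrative placeholders, not values from my training script:

files = sorted(Path("data").glob("*.bin"))
loader = make_grain_shard_loader(files, worker_count=4)

for shard in loader:
    finder = BOSFinder(shard["tokens"])
    try:
        while True:
            starts, ends = finder.next_batch(batch_size=8, max_seq_len=1024)
            # slice shard["tokens"] with each (start, end) pair to assemble the batch
    except StopIteration:
        # not enough BOS tokens left in this shard; move on to the next one
        continue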

If I use this dataloader as is, it does not touch the GPUs at all, since jax is never imported. But once I use it inside my training script, where jax is imported and is using the GPUs, grain detects JAX and every worker process starts taking ~550 MB on each GPU.

Apparently there is no way to turn this off: whenever JAX is detected, grain uses it for some pytree handling. IMHO this is extremely bad behavior from an efficiency point of view.
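
To make the conditions concrete, here is a minimal sketch of the setup that triggers it (the loader construction mirrors the code above; `files` is a list of shard paths as in the earlier example, and the rest is illustrative). Nothing in the loader itself touches JAX; the only difference from the standalone case is that jax is imported and initialized in the parent process:

import jax

# jax claims the GPUs in the parent (training) process, as it normally does
devices = jax.devices("gpu")

# iterating the loader now makes every grain worker process also allocate
# ~550 MB on each GPU, even though the workers never call jax directly
loader = make_grain_shard_loader(files, worker_count=4)
for shard in loader:
    ...  # training step on `devices`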
