Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crossfit/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@


try:
from crossfit.op.tokenize import Tokenizer
from crossfit.op.tokenize import Tokenizer, TokenCounter

__all__.append("Tokenizer")
__all__.extend(["Tokenizer", "TokenCounter"])
except ImportError:
pass

Expand Down
55 changes: 54 additions & 1 deletion crossfit/op/tokenize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import Optional

import cudf
from cudf.core.subword_tokenizer import SubwordTokenizer, _cast_to_appropriate_type
Expand Down Expand Up @@ -145,7 +146,7 @@ def from_pretrained(cls, name, cache_dir=None):

# Save vocabulary to disk
# `save_vocabulary()` automatically appends `-vocab.txt` suffix.
vocab_path = tokenizer.save_vocabulary(cache_dir, "{tokenizer_class}")[0]
vocab_path = tokenizer.save_vocabulary(cache_dir, f"{tokenizer_class}")[0]

# Hash the vocabulary and save it
hash_vocab(vocab_path, hashed_vocab_path)
Expand All @@ -166,3 +167,55 @@ def clip_tokens(token_o, max_length, return_type="pt"):
del token_o["metadata"]

return token_o


class TokenCounter(Op):
    """Count non-padding tokens in tokenized list columns.

    Expects list columns of token ids (e.g. the ``input_ids`` produced by
    ``Tokenizer``). A row's token count is the position of its first ``0``
    (padding) id; rows containing no ``0`` are assigned ``max_length``.
    """

    def __init__(
        self,
        cols=None,
        keep_cols=None,
        max_length: Optional[int] = None,
    ):
        """
        Parameters
        ----------
        cols : list of str, optional
            Columns holding token-id lists. With ``None`` or a single
            column the output column is named ``token_count``; with
            multiple columns one ``{col}_token_count`` per input column.
        keep_cols : list of str, optional
            Columns to pass through unchanged (handled by ``Op``).
        max_length : int, optional
            Count reported for rows with no padding token. Defaults to the
            list length of the first row of the column.
        """
        super().__init__(cols=cols, keep_cols=keep_cols)
        self.max_length = max_length

    def call_column(self, data):
        """Return a Series of per-row token counts for one list column.

        Raises
        ------
        ValueError
            If ``data`` is a DataFrame instead of a Series.
        """
        if isinstance(data, cudf.DataFrame):
            raise ValueError(
                "data must be a Series, got DataFrame. Add a pre step to convert to Series"
            )
        # Index of the first padding (0) token; -1 for unpadded rows.
        first_zero = data.list.astype(int).list.index(0)
        # NOTE(review): assumes every row is padded to the same length as
        # the first row when max_length is not given — confirm upstream.
        max_length = self.max_length or data.list.len().iloc[0]
        # Cast explicitly: replace() can promote the dtype (e.g. to int64),
        # which would contradict the "int32" schema promised by meta().
        num_tokens = first_zero.replace(-1, max_length).astype("int32")
        return num_tokens

    def call(self, data):
        """Build the output DataFrame of token counts.

        Accepts a Series (or a DataFrame when ``cols`` selects columns).
        """
        output = cudf.DataFrame()

        if self.cols is None or len(self.cols) == 1:
            if self.cols:
                data = data[self.cols[0]]

            if not isinstance(data, cudf.Series):
                raise ValueError("data must be a cudf Series")

            output["token_count"] = self.call_column(data)
            return output

        for col in self.cols:
            if col not in data.columns:
                raise ValueError(f"Column {col} not found in data")

            output[f"{col}_token_count"] = self.call_column(data[col])

        return output

    def meta(self):
        """Output-schema hint for dask: mapping of column name -> dtype.

        Mirrors the naming logic of ``call()``: a single (or unspecified)
        column yields ``token_count``; multiple columns are prefixed.
        """
        if self.cols is not None and len(self.cols) > 1:
            return {f"{col}_token_count": "int32" for col in self.cols}
        return {"token_count": "int32"}
42 changes: 42 additions & 0 deletions tests/op/test_tokenize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest

cudf = pytest.importorskip("cudf")
dask_cudf = pytest.importorskip("dask_cudf")

import crossfit as cf # noqa: E402
from crossfit import op # noqa: E402


@pytest.mark.singlegpu
def test_token_counter(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
):
    """Tokenize a small text column and check per-row token counts."""
    df = cudf.DataFrame(
        {
            "text": [
                "!",
                "query: how much protein should a female eat",
                "query: summit define",
                "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",  # noqa: E501
                "passage: Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.",  # noqa: E501
            ]
        }
    )

    # Two partitions so the counter is exercised through dask as well.
    ddf = dask_cudf.from_cudf(df, npartitions=2)

    model = cf.SentenceTransformerModel(model_name)

    pipe = op.Sequential(
        op.Tokenizer(model, cols=["text"]),
        op.TokenCounter(cols=["input_ids"]),
    )

    num_tokens = pipe(ddf).compute()
    expected = cudf.DataFrame(
        {
            "token_count": cudf.Series([3, 11, 6, 75, 50], dtype="int32")
        }
    )

    # Use the public testing API rather than the internal
    # cudf.testing.testing submodule, which is not a stable path.
    cudf.testing.assert_frame_equal(num_tokens, expected)