Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions tests/test_simple_public.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import subprocess
import sys
import requests
import ssl
from pathlib import Path

import tiktoken

Expand Down Expand Up @@ -40,3 +43,21 @@ def test_optional_blobfile_dependency():
assert "blobfile" not in sys.modules
"""
subprocess.check_call([sys.executable, "-c", prog])


def test_custom_http_client():
    """End-to-end check that a caller-supplied session is accepted by the public API.

    Verifies both ``get_encoding`` and ``encoding_for_model`` forward the
    injected client and still produce correct tokenizations.
    """
    session = requests.Session()

    # Pin the session to the system CA bundle when one is discoverable, and
    # tag outgoing requests with a distinctive User-Agent.
    bundle = ssl.get_default_verify_paths().cafile
    if bundle and Path(bundle).exists():
        session.verify = bundle
    session.headers.update({"User-Agent": "custom-tiktoken-client"})

    gpt2_enc = tiktoken.get_encoding("gpt2", http_client=session)
    assert gpt2_enc.encode("hello world") == [31373, 995]
    assert gpt2_enc.decode([31373, 995]) == "hello world"

    model_enc = tiktoken.encoding_for_model("gpt-4", http_client=session)
    assert model_enc.name == "cl100k_base"
    assert model_enc.encode("hello world") == [15339, 1917]

45 changes: 35 additions & 10 deletions tiktoken/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,26 @@
import base64
import hashlib
import os
from typing import Any, Protocol

import requests

def read_file(blobpath: str) -> bytes:

class HttpClient(Protocol):
    """Structural (duck-typed) interface for HTTP clients.

    Any object with a requests-style ``get(url, **kwargs)`` satisfies it —
    e.g. ``requests.Session``. The returned response is expected to support
    ``raise_for_status()`` and ``.content``, as used by ``read_file``.
    """

    def get(self, url: str, **kwargs: Any) -> Any: ...


def read_file(blobpath: str, http_client: HttpClient | None = None) -> bytes:
    """Read raw bytes from a local path or an http(s) URL.

    Args:
        blobpath: Filesystem path (no ``://``) or URL to read.
        http_client: Optional requests-style client used for http(s) fetches;
            a fresh ``requests.Session`` is created when omitted.

    Returns:
        The file or response body as bytes.
    """
    if "://" not in blobpath:
        # Local file: plain unbuffered binary read.
        with open(blobpath, "rb", buffering=0) as f:
            return f.read()

    if blobpath.startswith(("http://", "https://")):
        # avoiding blobfile for public files helps avoid auth issues, like MFA prompts.
        if http_client is None:
            http_client = requests.Session()

        # Use the injected client (not module-level requests.get) so callers
        # can control TLS verification, headers, proxies, etc.
        resp = http_client.get(blobpath)
        resp.raise_for_status()
        return resp.content

Expand All @@ -33,7 +41,11 @@ def check_hash(data: bytes, expected_hash: str) -> bool:
return actual_hash == expected_hash


def read_file_cached(blobpath: str, expected_hash: str | None = None) -> bytes:
def read_file_cached(
blobpath: str,
expected_hash: str | None = None,
http_client: HttpClient | None = None,
) -> bytes:
user_specified_cache = True
if "TIKTOKEN_CACHE_DIR" in os.environ:
cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
Expand All @@ -47,7 +59,7 @@ def read_file_cached(blobpath: str, expected_hash: str | None = None) -> bytes:

if cache_dir == "":
# disable caching
return read_file(blobpath)
return read_file(blobpath, http_client=http_client)

cache_key = hashlib.sha1(blobpath.encode()).hexdigest()

Expand All @@ -64,7 +76,7 @@ def read_file_cached(blobpath: str, expected_hash: str | None = None) -> bytes:
except OSError:
pass

contents = read_file(blobpath)
contents = read_file(blobpath, http_client=http_client)
if expected_hash and not check_hash(contents, expected_hash):
raise ValueError(
f"Hash mismatch for data downloaded from {blobpath} (expected {expected_hash}). "
Expand Down Expand Up @@ -93,7 +105,10 @@ def data_gym_to_mergeable_bpe_ranks(
vocab_bpe_hash: str | None = None,
encoder_json_hash: str | None = None,
clobber_one_byte_tokens: bool = False,
http_client: HttpClient | None = None,
) -> dict[bytes, int]:
if http_client is None:
http_client = requests.Session()
# NB: do not add caching to this function
rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]

Expand All @@ -107,7 +122,7 @@ def data_gym_to_mergeable_bpe_ranks(
assert len(rank_to_intbyte) == 2**8

# vocab_bpe contains the merges along with associated ranks
vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash).decode()
vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash, http_client=http_client).decode()
bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]

def decode_data_gym(value: str) -> bytes:
Expand All @@ -129,7 +144,11 @@ def decode_data_gym(value: str) -> bytes:
# check that the encoder file matches the merges file
# this sanity check is important since tiktoken assumes that ranks are ordered the same
# as merge priority
encoder_json = json.loads(read_file_cached(encoder_json_file, encoder_json_hash))
encoder_json = json.loads(
read_file_cached(
encoder_json_file, encoder_json_hash, http_client=http_client
).decode()
)
encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
# drop these two special tokens if present, since they're not mergeable bpe tokens
encoder_json_loaded.pop(b"<|endoftext|>", None)
Expand Down Expand Up @@ -157,9 +176,15 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No
f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")


def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: str | None = None) -> dict[bytes, int]:
def load_tiktoken_bpe(
tiktoken_bpe_file: str,
expected_hash: str | None = None,
http_client: HttpClient | None = None,
) -> dict[bytes, int]:
if http_client is None:
http_client = requests.Session()
# NB: do not add caching to this function
contents = read_file_cached(tiktoken_bpe_file, expected_hash)
contents = read_file_cached(tiktoken_bpe_file, expected_hash, http_client=http_client)
ret = {}
for line in contents.splitlines():
if not line:
Expand Down
9 changes: 7 additions & 2 deletions tiktoken/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

import requests

from .core import Encoding
from .load import HttpClient
from .registry import get_encoding

# TODO: these will likely be replaced by an API endpoint
Expand Down Expand Up @@ -110,9 +113,11 @@ def encoding_name_for_model(model_name: str) -> str:
return encoding_name


def encoding_for_model(
    model_name: str, http_client: HttpClient | None = None
) -> Encoding:
    """Returns the encoding used by a model.

    Args:
        model_name: Model name, e.g. ``"gpt-4"``.
        http_client: Optional requests-style client forwarded to
            ``get_encoding`` for any network fetches of encoding data.

    Raises a KeyError if the model name is not recognised.
    """
    return get_encoding(encoding_name_for_model(model_name), http_client=http_client)
16 changes: 9 additions & 7 deletions tiktoken/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
import threading
from typing import Any, Callable, Sequence

import tiktoken_ext
import requests

import tiktoken
import tiktoken_ext
from tiktoken.core import Encoding
from tiktoken.load import HttpClient

_lock = threading.RLock()
ENCODINGS: dict[str, Encoding] = {}
ENCODING_CONSTRUCTORS: dict[str, Callable[[], dict[str, Any]]] | None = None
ENCODING_CONSTRUCTORS: dict[str, Callable[[HttpClient | None], dict[str, Any]]] | None = None


@functools.lru_cache
Expand All @@ -24,7 +26,9 @@ def _available_plugin_modules() -> Sequence[str]:
# - it's a separate top-level package because namespace subpackages of non-namespace
# packages don't quite do what you want with editable installs
mods = []
plugin_mods = pkgutil.iter_modules(tiktoken_ext.__path__, tiktoken_ext.__name__ + ".")
plugin_mods = pkgutil.iter_modules(
tiktoken_ext.__path__, tiktoken_ext.__name__ + "."
)
for _, mod_name, _ in plugin_mods:
mods.append(mod_name)
return mods
Expand Down Expand Up @@ -58,9 +62,7 @@ def _find_constructors() -> None:
raise




def get_encoding(encoding_name: str) -> Encoding:
def get_encoding(encoding_name: str, http_client: HttpClient | None = None) -> Encoding:
if not isinstance(encoding_name, str):
raise ValueError(f"Expected a string in get_encoding, got {type(encoding_name)}")

Expand All @@ -83,7 +85,7 @@ def get_encoding(encoding_name: str) -> Encoding:
)

constructor = ENCODING_CONSTRUCTORS[encoding_name]
enc = Encoding(**constructor())
enc = Encoding(**constructor(http_client=http_client))
ENCODINGS[encoding_name] = enc
return enc

Expand Down
37 changes: 24 additions & 13 deletions tiktoken_ext/openai_public.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe
import requests

from tiktoken.load import HttpClient, data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe

ENDOFTEXT = "<|endoftext|>"
FIM_PREFIX = "<|fim_prefix|>"
Expand All @@ -9,17 +11,16 @@
# The pattern in the original GPT-2 release is:
# r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# This is equivalent, but executes faster:
# Tokenization regex shared by the r50k/p50k family of encodings.
# Equivalent to the pattern in the original GPT-2 release, but faster
# (possessive quantifiers avoid backtracking).
r50k_pat_str = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s"""


def gpt2():
def gpt2(http_client: HttpClient | None = None):
mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe",
encoder_json_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json",
vocab_bpe_hash="1ce1664773c50f3e0cc8842619a93edc4624525b728b188a9e0be33b7726adc5",
encoder_json_hash="196139668be63f3b5d6574427317ae82f612a97c5d1cdaf36ed2256dbf636783",
http_client=http_client,
)
return {
"name": "gpt2",
Expand All @@ -30,10 +31,11 @@ def gpt2():
}


def r50k_base():
def r50k_base(http_client: HttpClient | None = None):
mergeable_ranks = load_tiktoken_bpe(
"https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken",
expected_hash="306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930",
http_client=http_client,
)
return {
"name": "r50k_base",
Expand All @@ -44,10 +46,11 @@ def r50k_base():
}


def p50k_base():
def p50k_base(http_client: HttpClient | None = None):
mergeable_ranks = load_tiktoken_bpe(
"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
http_client=http_client,
)
return {
"name": "p50k_base",
Expand All @@ -58,12 +61,18 @@ def p50k_base():
}


def p50k_edit():
def p50k_edit(http_client: HttpClient | None = None):
mergeable_ranks = load_tiktoken_bpe(
"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
http_client=http_client,
)
special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
special_tokens = {
ENDOFTEXT: 50256,
FIM_PREFIX: 50281,
FIM_MIDDLE: 50282,
FIM_SUFFIX: 50283,
}
return {
"name": "p50k_edit",
"pat_str": r50k_pat_str,
Expand All @@ -72,10 +81,11 @@ def p50k_edit():
}


def cl100k_base():
def cl100k_base(http_client: HttpClient | None = None):
mergeable_ranks = load_tiktoken_bpe(
"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
http_client=http_client,
)
special_tokens = {
ENDOFTEXT: 100257,
Expand All @@ -92,10 +102,11 @@ def cl100k_base():
}


def o200k_base():
def o200k_base(http_client: HttpClient | None = None):
mergeable_ranks = load_tiktoken_bpe(
"https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken",
expected_hash="446a9538cb6c348e3516120d7c08b09f57c36495e2acfffe59a5bf8b0cfb1a2d",
http_client=http_client,
)
special_tokens = {ENDOFTEXT: 199999, ENDOFPROMPT: 200018}
# This regex could be made more efficient. If I was the one working on this encoding, I would
Expand All @@ -120,8 +131,8 @@ def o200k_base():
}


def o200k_harmony():
base_enc = o200k_base()
def o200k_harmony(http_client: HttpClient | None = None):
base_enc = o200k_base(http_client=http_client)
name = "o200k_harmony"
pat_str = base_enc["pat_str"]
mergeable_ranks = base_enc["mergeable_ranks"]
Expand Down