diff --git a/tests/test_simple_public.py b/tests/test_simple_public.py
index 8458c12..e33591e 100644
--- a/tests/test_simple_public.py
+++ b/tests/test_simple_public.py
@@ -1,5 +1,8 @@
 import subprocess
 import sys
+import requests
+import ssl
+from pathlib import Path
 
 import tiktoken
 
@@ -40,3 +43,20 @@ def test_optional_blobfile_dependency():
     assert "blobfile" not in sys.modules
     """
     subprocess.check_call([sys.executable, "-c", prog])
+
+
+def test_custom_http_client():
+    custom_session = requests.Session()
+
+    ca_bundle = ssl.get_default_verify_paths().cafile
+    if ca_bundle and Path(ca_bundle).exists():
+        custom_session.verify = ca_bundle
+    custom_session.headers.update({"User-Agent": "custom-tiktoken-client"})
+
+    enc = tiktoken.get_encoding("gpt2", http_client=custom_session)
+    assert enc.encode("hello world") == [31373, 995]
+    assert enc.decode([31373, 995]) == "hello world"
+
+    enc = tiktoken.encoding_for_model("gpt-4", http_client=custom_session)
+    assert enc.name == "cl100k_base"
+    assert enc.encode("hello world") == [15339, 1917]
diff --git a/tiktoken/load.py b/tiktoken/load.py
index dc2eba6..8106b55 100644
--- a/tiktoken/load.py
+++ b/tiktoken/load.py
@@ -3,18 +3,26 @@
 import base64
 import hashlib
 import os
+from typing import Any, Protocol
+
+import requests
+
+
+class HttpClient(Protocol):
+    def get(self, url: str, **kwargs: Any) -> Any: ...
 
 
-def read_file(blobpath: str) -> bytes:
+def read_file(blobpath: str, http_client: HttpClient | None = None) -> bytes:
     if "://" not in blobpath:
         with open(blobpath, "rb", buffering=0) as f:
             return f.read()
 
     if blobpath.startswith(("http://", "https://")):
         # avoiding blobfile for public files helps avoid auth issues, like MFA prompts.
-        import requests
+        if http_client is None:
+            http_client = requests.Session()
 
-        resp = requests.get(blobpath)
+        resp = http_client.get(blobpath)
         resp.raise_for_status()
         return resp.content
 
@@ -33,7 +41,11 @@ def check_hash(data: bytes, expected_hash: str) -> bool:
     return actual_hash == expected_hash
 
 
-def read_file_cached(blobpath: str, expected_hash: str | None = None) -> bytes:
+def read_file_cached(
+    blobpath: str,
+    expected_hash: str | None = None,
+    http_client: HttpClient | None = None,
+) -> bytes:
     user_specified_cache = True
     if "TIKTOKEN_CACHE_DIR" in os.environ:
         cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
@@ -47,7 +59,7 @@ def read_file_cached(
 
     if cache_dir == "":
         # disable caching
-        return read_file(blobpath)
+        return read_file(blobpath, http_client=http_client)
 
     cache_key = hashlib.sha1(blobpath.encode()).hexdigest()
 
@@ -64,7 +76,7 @@ def read_file_cached(
         except OSError:
             pass
 
-    contents = read_file(blobpath)
+    contents = read_file(blobpath, http_client=http_client)
     if expected_hash and not check_hash(contents, expected_hash):
         raise ValueError(
             f"Hash mismatch for data downloaded from {blobpath} (expected {expected_hash}). "
@@ -93,7 +105,10 @@ def data_gym_to_mergeable_bpe_ranks(
     vocab_bpe_hash: str | None = None,
     encoder_json_hash: str | None = None,
     clobber_one_byte_tokens: bool = False,
+    http_client: HttpClient | None = None,
 ) -> dict[bytes, int]:
+    if http_client is None:
+        http_client = requests.Session()
     # NB: do not add caching to this function
     rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]
 
@@ -107,7 +122,7 @@ def data_gym_to_mergeable_bpe_ranks(
     assert len(rank_to_intbyte) == 2**8
 
     # vocab_bpe contains the merges along with associated ranks
-    vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash).decode()
+    vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash, http_client=http_client).decode()
     bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]
 
     def decode_data_gym(value: str) -> bytes:
@@ -129,7 +144,11 @@ def data_gym_to_mergeable_bpe_ranks(
     # check that the encoder file matches the merges file
     # this sanity check is important since tiktoken assumes that ranks are ordered the same
     # as merge priority
-    encoder_json = json.loads(read_file_cached(encoder_json_file, encoder_json_hash))
+    encoder_json = json.loads(
+        read_file_cached(
+            encoder_json_file, encoder_json_hash, http_client=http_client
+        ).decode()
+    )
     encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
     # drop these two special tokens if present, since they're not mergeable bpe tokens
     encoder_json_loaded.pop(b"<|endoftext|>", None)
@@ -157,9 +176,15 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No
             f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")
 
 
-def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: str | None = None) -> dict[bytes, int]:
+def load_tiktoken_bpe(
+    tiktoken_bpe_file: str,
+    expected_hash: str | None = None,
+    http_client: HttpClient | None = None,
+) -> dict[bytes, int]:
+    if http_client is None:
+        http_client = requests.Session()
     # NB: do not add caching to this function
-    contents = read_file_cached(tiktoken_bpe_file, expected_hash)
+    contents = read_file_cached(tiktoken_bpe_file, expected_hash, http_client=http_client)
     ret = {}
     for line in contents.splitlines():
         if not line:
diff --git a/tiktoken/model.py b/tiktoken/model.py
index 5c669af..a09b7f7 100644
--- a/tiktoken/model.py
+++ b/tiktoken/model.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from .core import Encoding
+from .load import HttpClient
 from .registry import get_encoding
 
 # TODO: these will likely be replaced by an API endpoint
@@ -110,9 +111,11 @@ def encoding_name_for_model(model_name: str) -> str:
     return encoding_name
 
 
-def encoding_for_model(model_name: str) -> Encoding:
+def encoding_for_model(
+    model_name: str, http_client: HttpClient | None = None
+) -> Encoding:
     """Returns the encoding used by a model.
 
     Raises a KeyError if the model name is not recognised.
     """
-    return get_encoding(encoding_name_for_model(model_name))
+    return get_encoding(encoding_name_for_model(model_name), http_client=http_client)
diff --git a/tiktoken/registry.py b/tiktoken/registry.py
index 17c4574..42fe52b 100644
--- a/tiktoken/registry.py
+++ b/tiktoken/registry.py
@@ -6,14 +6,14 @@
 import threading
 from typing import Any, Callable, Sequence
 
-import tiktoken_ext
-
 import tiktoken
+import tiktoken_ext
 from tiktoken.core import Encoding
+from tiktoken.load import HttpClient
 
 _lock = threading.RLock()
 ENCODINGS: dict[str, Encoding] = {}
-ENCODING_CONSTRUCTORS: dict[str, Callable[[], dict[str, Any]]] | None = None
+ENCODING_CONSTRUCTORS: dict[str, Callable[..., dict[str, Any]]] | None = None
 
 
 @functools.lru_cache
@@ -24,7 +24,9 @@ def _available_plugin_modules() -> Sequence[str]:
     # - it's a separate top-level package because namespace subpackages of non-namespace
     #   packages don't quite do what you want with editable installs
     mods = []
-    plugin_mods = pkgutil.iter_modules(tiktoken_ext.__path__, tiktoken_ext.__name__ + ".")
+    plugin_mods = pkgutil.iter_modules(
+        tiktoken_ext.__path__, tiktoken_ext.__name__ + "."
+    )
     for _, mod_name, _ in plugin_mods:
         mods.append(mod_name)
     return mods
@@ -58,9 +60,7 @@ def _find_constructors() -> None:
             raise
 
 
-
-
-def get_encoding(encoding_name: str) -> Encoding:
+def get_encoding(encoding_name: str, http_client: HttpClient | None = None) -> Encoding:
     if not isinstance(encoding_name, str):
         raise ValueError(f"Expected a string in get_encoding, got {type(encoding_name)}")
 
@@ -83,7 +83,7 @@ def get_encoding(encoding_name: str) -> Encoding:
         )
 
     constructor = ENCODING_CONSTRUCTORS[encoding_name]
-    enc = Encoding(**constructor())
+    enc = Encoding(**constructor(http_client=http_client))
     ENCODINGS[encoding_name] = enc
     return enc
 
diff --git a/tiktoken_ext/openai_public.py b/tiktoken_ext/openai_public.py
index 02c9ee2..0a435d4 100644
--- a/tiktoken_ext/openai_public.py
+++ b/tiktoken_ext/openai_public.py
@@ -1,4 +1,4 @@
-from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe
+from tiktoken.load import HttpClient, data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe
 
 ENDOFTEXT = "<|endoftext|>"
 FIM_PREFIX = "<|fim_prefix|>"
@@ -9,17 +9,16 @@
 # The pattern in the original GPT-2 release is:
 # r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
 # This is equivalent, but executes faster:
-r50k_pat_str = (
-    r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s"""
-)
+r50k_pat_str = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s"""
 
 
-def gpt2():
+def gpt2(http_client: HttpClient | None = None):
     mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
         vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe",
         encoder_json_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json",
         vocab_bpe_hash="1ce1664773c50f3e0cc8842619a93edc4624525b728b188a9e0be33b7726adc5",
         encoder_json_hash="196139668be63f3b5d6574427317ae82f612a97c5d1cdaf36ed2256dbf636783",
+        http_client=http_client,
     )
     return {
         "name": "gpt2",
@@ -30,10 +29,11 @@ def gpt2():
 }
 
 
-def r50k_base():
+def r50k_base(http_client: HttpClient | None = None):
     mergeable_ranks = load_tiktoken_bpe(
         "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken",
         expected_hash="306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930",
+        http_client=http_client,
     )
     return {
         "name": "r50k_base",
@@ -44,10 +44,11 @@ def r50k_base():
 }
 
 
-def p50k_base():
+def p50k_base(http_client: HttpClient | None = None):
     mergeable_ranks = load_tiktoken_bpe(
         "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
         expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
+        http_client=http_client,
     )
     return {
         "name": "p50k_base",
@@ -58,12 +59,18 @@ def p50k_base():
 }
 
 
-def p50k_edit():
+def p50k_edit(http_client: HttpClient | None = None):
     mergeable_ranks = load_tiktoken_bpe(
         "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
         expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
+        http_client=http_client,
     )
-    special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
+    special_tokens = {
+        ENDOFTEXT: 50256,
+        FIM_PREFIX: 50281,
+        FIM_MIDDLE: 50282,
+        FIM_SUFFIX: 50283,
+    }
     return {
         "name": "p50k_edit",
         "pat_str": r50k_pat_str,
@@ -72,10 +79,11 @@ def p50k_edit():
 }
 
 
-def cl100k_base():
+def cl100k_base(http_client: HttpClient | None = None):
     mergeable_ranks = load_tiktoken_bpe(
         "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
         expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
+        http_client=http_client,
    )
     special_tokens = {
         ENDOFTEXT: 100257,
@@ -92,10 +100,11 @@ def cl100k_base():
 }
 
 
-def o200k_base():
+def o200k_base(http_client: HttpClient | None = None):
     mergeable_ranks = load_tiktoken_bpe(
         "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken",
         expected_hash="446a9538cb6c348e3516120d7c08b09f57c36495e2acfffe59a5bf8b0cfb1a2d",
+        http_client=http_client,
     )
     special_tokens = {ENDOFTEXT: 199999, ENDOFPROMPT: 200018}
     # This regex could be made more efficient. If I was the one working on this encoding, I would
@@ -120,8 +129,8 @@ def o200k_base():
 }
 
 
-def o200k_harmony():
-    base_enc = o200k_base()
+def o200k_harmony(http_client: HttpClient | None = None):
+    base_enc = o200k_base(http_client=http_client)
     name = "o200k_harmony"
     pat_str = base_enc["pat_str"]
     mergeable_ranks = base_enc["mergeable_ranks"]