diff --git a/.gitignore b/.gitignore index ddc5107..250de2a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ tmp/ +models/ checkpoints/ *.bin diff --git a/experiments/mnist.py b/experiments/01_mnist.py similarity index 100% rename from experiments/mnist.py rename to experiments/01_mnist.py diff --git a/experiments/mini_transformer.py b/experiments/02_mini_transformer.py similarity index 96% rename from experiments/mini_transformer.py rename to experiments/02_mini_transformer.py index c135d8d..4c99d7d 100644 --- a/experiments/mini_transformer.py +++ b/experiments/02_mini_transformer.py @@ -10,24 +10,26 @@ # Reproduce https://sdbuchanan.com/blog/jax-2/ from functools import partial + import grain import jax import jax.numpy as jnp import numpy as np import optax from jax.nn.initializers import truncated_normal + from julax.core import Learner, Trainer from julax.einops import Rearrange from julax.experiment import Experiment from julax.layers import ( Chain, - Linear, + Embedding, LayerNorm, + Linear, Parallel, - Repeated, + Repeat, + Residual, RotaryEmbedding, - SkipConnection, - Embedding, Unembedding, ) from julax.observers import default_observer @@ -75,11 +77,11 @@ def main( out_dim=dim, w_init=truncated_normal(stddev=param_std), ), - blocks=Repeated( + blocks=Repeat( n=num_layers, layer=Chain( - attn=SkipConnection( - layer=Chain( + attn=Residual( + processor=Chain( norm_attn=LayerNorm(dim=dim), attn=Chain( # qkv projection @@ -130,8 +132,8 @@ def main( ), ) ), - mlp=SkipConnection( - layer=Chain( + mlp=Residual( + processor=Chain( norm_mlp=LayerNorm(dim=dim), mlp=Chain( up=Linear( diff --git a/experiments/03_Llama_3.2_1B.py b/experiments/03_Llama_3.2_1B.py new file mode 100644 index 0000000..25b7e1f --- /dev/null +++ b/experiments/03_Llama_3.2_1B.py @@ -0,0 +1,440 @@ +# /// script +# dependencies = [ +# "julax", +# "pyarrow", +# ] +# +# [tool.uv.sources] +# julax = { path = "../", editable = true } +# /// + +import os +import pickle +from safetensors import safe_open + +import grain +import jax +import jax.numpy as jnp +import numpy as np +from grain._src.core.sharding import ShardByJaxProcess, even_split +from grain.experimental import FlatMapIterDataset, FlatMapTransform, ParquetIterDataset +from jax import Array + +from julax.base import Dtype +from julax.core import LayerBase, Param, State +from julax.einops import Rearrange +from julax.layers import ( + Branch, + Chain, + Embedding, + Linear, + Parallel, + Repeat, + Residual, + RMSNorm, + Select, +) +from julax.utils import identity + + +# Adapted from: +# https://github.com/AI-Hypercomputer/maxtext/blob/9204d6bbbf8bb19a05ebed72a55cfec687e0e044/src/MaxText/layers/embeddings.py#L486-L622 +# TODO: The real and imaginary part are interleaved. benchmark with the HF +# transformer style (first half as real, second half as imaginary). +def apply_rotary_emb( + inputs: jax.Array, + timescale: jax.Array, + position: None | jax.Array = None, + fprop_dtype: Dtype | None = jnp.bfloat16, +) -> jax.Array: + """Applies LLaMA variant of rotary position embedding. + + Args: + inputs: The input sequence on which to apply the Rotary position + embedding. It is assumed of shape [B, S, N, H]. + position: Optional position array [B, S]. Only needed when the sequence + is packed. + + Returns: + A jax.Array of shape [B, S, N, H] with rotary position embeddings applied. + """ + # Ensure input is 4D + if len(inputs.shape) != 4: + raise ValueError( + "Input is assumed to be a rank 4 tensor of shape [B, S, N, H]." + ) + # Determine positions if not provided + if position is None: + seq_length = inputs.shape[1] + position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :] + + # Calculate sinusoidal input + position = position[:, :, jnp.newaxis, jnp.newaxis] + sinusoid_inp = position / timescale + + sin = jnp.sin(sinusoid_inp) + cos = jnp.cos(sinusoid_inp) + + r, i = jnp.split(inputs, 2, axis=-1) + pos_r = cos * r - sin * i + pos_i = sin * r + cos * i + outputs = jnp.concatenate([pos_r, pos_i], axis=-1) + + if fprop_dtype: + outputs = outputs.astype(fprop_dtype) + + return outputs + + +class LLaMARotaryEmbedding(LayerBase): + embedding_dims: int + min_timescale: int = 1 + max_timescale: int = 10_000 + cast_as_fprop_dtype: bool = True + fprop_dtype: Dtype = jnp.bfloat16 + + scaling_factor: float = 8.0 + low_freq_factor: float = 1.0 + high_freq_factor: float = 4.0 + original_max_position_embeddings: int = 8192 + + def _apply_scaling_factor(self, freq): + """apply scaling factor to rotary position embedding.""" + low_freq_wavelen = self.original_max_position_embeddings / self.low_freq_factor + high_freq_wavelen = ( + self.original_max_position_embeddings / self.high_freq_factor + ) + wavelen = 2 * jnp.pi / freq + + def lower_wavelen(freq): + return freq + + def bigger_or_equal_wavelen(freq): + def bigger_wavelen(freq): + return freq / self.scaling_factor + + def equal_wavelen(freq): + smooth = ( + self.original_max_position_embeddings / wavelen + - self.low_freq_factor + ) / (self.high_freq_factor - self.low_freq_factor) + return (1 - smooth) * freq / self.scaling_factor + smooth * freq + + bigger_wavelen_cond = wavelen > low_freq_wavelen + return jax.lax.cond( + bigger_wavelen_cond, bigger_wavelen, equal_wavelen, freq + ) + + lower_wavelen_cond = wavelen < high_freq_wavelen + return jax.lax.cond( + lower_wavelen_cond, lower_wavelen, bigger_or_equal_wavelen, freq + ) + + def state(self, rng) -> State: + half_embedding_dim = self.embedding_dims // 2 + fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims + timescale = ( + self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction + ) + timescale = 1.0 / jax.vmap(self._apply_scaling_factor)(1.0 / timescale) + return State(timescale=timescale) + + def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]: + return apply_rotary_emb( + x, + s["timescale"], + position=None, + fprop_dtype=self.fprop_dtype if self.cast_as_fprop_dtype else None, + ), s + + +class Tokenize(FlatMapTransform): + def __init__(self, tokenizer_path: str) -> None: + super().__init__() + with open(tokenizer_path, "rb") as f: + self.tokenizer = pickle.load(f) + self.bos_token_id = self.tokenizer.encode_single_token("<|bos|>") + + def encode(self, text: str) -> list[int]: + return [self.bos_token_id] + self.tokenizer.encode_ordinary(text) + + def flat_map(self, element): + return self.encode(element) + + def get_segment_ids(self, tokens: np.ndarray): + assert tokens.ndim == 2 + bos_mask = tokens == self.bos_token_id + bos_mask[:, 0] = False + segment_ids = np.cumsum(bos_mask, axis=1) + + return segment_ids + + +def create_dataset( + batch_size: int, + seq_len: int, + data_dir: str, + tokenizer_path: str, + split: str = "train", + seed: int = 2025, +) -> grain.IterDataset: + files = sorted([os.path.join(data_dir, f) for f in os.listdir(data_dir)]) + if split == "train": + files = files[:-1] + else: + # TODO: + raise ValueError("Unsupported yet") + + tokenize = Tokenize(tokenizer_path) + + # TODO: window shuffle? + # TODO: prefetch to device? + ds = grain.experimental.InterleaveIterDataset( + grain.MapDataset.source(files) + .shuffle(seed=seed) + .slice(slice(*even_split(len(files), ShardByJaxProcess(drop_remainder=True)))) + .map( + lambda file_path: FlatMapIterDataset( + ParquetIterDataset(file_path).map(lambda x: x["text"]), tokenize + ) + .batch(batch_size * seq_len + 1) + .map(np.array) + .map( + lambda x: { + "inputs": { + "token_ids": x[:-1].reshape(batch_size, seq_len), + "segment_ids": tokenize.get_segment_ids( + x[:-1].reshape(batch_size, seq_len) + ), + }, + "target_labels": x[1:].reshape(batch_size, seq_len), + } + ) + ), # pyright: ignore[reportArgumentType] + cycle_length=4, + ) + + return ds.mp_prefetch( + grain.MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) + ) + + +def attention(qkv, timescale): + q, k, v = qkv.values() + q = apply_rotary_emb(q, timescale) + k = apply_rotary_emb(k, timescale) + o = jax.nn.dot_product_attention(q, k, v, is_causal=True) + return o + + +class Transformer(LayerBase): + emb: Embedding + rope: LLaMARotaryEmbedding + blocks: Repeat + out_norm: RMSNorm + + def forward(self, x: dict, p: Param, s: State) -> tuple[Array, State]: + S = State(rope=s["rope"]) + + h = x["token_ids"] + h, S["emb"] = self.emb(h, p["emb"], s["emb"]) + h, S["blocks"] = self.blocks( + {"hidden": h, "timescale": s["rope"]["timescale"]}, p["blocks"], s["blocks"] + ) + h, S["out_norm"] = self.out_norm(h["hidden"], p["out_norm"], s["out_norm"]) + + o = self.emb.attend(h, p["emb"]) + + return o, S + + +def create_transformer( + batch_size=1, + seq_len=10, + dim=2048, + num_q_heads=32, + num_kv_heads=8, + head_dim=64, + ffn_hidden_dim=8192, + vocab_size=128256, +) -> Transformer: + return Transformer( + emb=Embedding(in_dim=vocab_size, out_dim=dim, param_dtype=jnp.bfloat16), + rope=LLaMARotaryEmbedding( + embedding_dims=head_dim, + min_timescale=1, + max_timescale=500_000, + ), + out_norm=RMSNorm(dim=dim, eps=1e-05, scale_dtype=jnp.bfloat16), + blocks=Repeat( + n=16, + layer=Branch( # {hidden(in), timescale} => {hidden(out), timescale} + hidden=Chain( + attn=Residual( + Chain( + Parallel( + qkv=Chain( + norm=RMSNorm(dim=dim, eps=1e-05), + qkv_proj=Branch( + q=Chain( + Linear( + in_dim=dim, + out_dim=num_q_heads * head_dim, + param_dtype=jnp.bfloat16, + ), + Rearrange( + "B T (N H) -> B T N H", + B=batch_size, + T=seq_len, + N=num_q_heads, + H=head_dim, + ), + ), + k=Chain( + Linear( + in_dim=dim, + out_dim=num_kv_heads * head_dim, + param_dtype=jnp.bfloat16, + ), + Rearrange( + "B S (K H) -> B S K H", + B=batch_size, + S=seq_len, + K=num_kv_heads, + H=head_dim, + ), + ), + v=Chain( + Linear( + in_dim=dim, + out_dim=num_kv_heads * head_dim, + param_dtype=jnp.bfloat16, + ), + Rearrange( + "B S (K H) -> B S K H", + B=batch_size, + S=seq_len, + K=num_kv_heads, + H=head_dim, + ), + ), + ), + ), + timescale=identity, + reduce=attention, + ), + Rearrange( + "B T N H -> B T (N H)", + B=batch_size, + T=seq_len, + N=num_q_heads, + H=head_dim, + ), + Linear( + in_dim=dim, + out_dim=dim, + param_dtype=jnp.bfloat16, + ), + ), + skip_through=Select(key="hidden"), + ), + ffn=Residual( + Chain( + norm=RMSNorm(dim=dim, eps=1e-05), + up=Branch( + # up_proj + Linear( + in_dim=dim, + out_dim=ffn_hidden_dim, + param_dtype=jnp.bfloat16, + ), + # gate_proj + Chain( + proj=Linear( + in_dim=dim, + out_dim=ffn_hidden_dim, + param_dtype=jnp.bfloat16, + ), + activation=jax.nn.silu, + ), + reduce=jnp.multiply, + ), + down=Linear( + in_dim=ffn_hidden_dim, + out_dim=dim, + param_dtype=jnp.bfloat16, + ), + ) + ), + ), + timescale=Select(key="timescale"), + ), + ), + ) + + +def from_hf(): + tensors = {} + with safe_open( + "models/Llama-3.2-1B-Instruct/model.safetensors", framework="flax", device="cpu" + ) as f: + for key in f.keys(): + tensors[key] = f.get_tensor(key) + return tensors + + +def verify(): + m = create_transformer() + p, s = m.init() + + tensors = from_hf() + input_ids = jnp.array([[128000, 791, 6367, 311, 28915, 264, 1695, 19692, 374, 220]]) + p, s = m.init() + + w_ln1 = [] + w_q = [] + w_k = [] + w_v = [] + w_o = [] + w_ln2 = [] + w_up = [] + w_gate = [] + w_down = [] + + for i in range(16): + w_ln1.append(tensors[f"model.layers.{i}.input_layernorm.weight"]) + w_q.append(tensors[f"model.layers.{i}.self_attn.q_proj.weight"].T) + w_k.append(tensors[f"model.layers.{i}.self_attn.k_proj.weight"].T) + w_v.append(tensors[f"model.layers.{i}.self_attn.v_proj.weight"].T) + w_o.append(tensors[f"model.layers.{i}.self_attn.o_proj.weight"].T) + + w_ln2.append(tensors[f"model.layers.{i}.post_attention_layernorm.weight"]) + w_up.append(tensors[f"model.layers.{i}.mlp.up_proj.weight"].T) + w_gate.append(tensors[f"model.layers.{i}.mlp.gate_proj.weight"].T) + w_down.append(tensors[f"model.layers.{i}.mlp.down_proj.weight"].T) + + p["blocks"]["hidden"]["attn"]["#0"]["#0"]["qkv"]["norm"]["scale"] = jnp.stack( + w_ln1, axis=0 + ) + p["blocks"]["hidden"]["attn"]["#0"]["#0"]["qkv"]["qkv_proj"]["q"]["#0"]["w"] = ( + jnp.stack(w_q, axis=0) + ) + p["blocks"]["hidden"]["attn"]["#0"]["#0"]["qkv"]["qkv_proj"]["k"]["#0"]["w"] = ( + jnp.stack(w_k, axis=0) + ) + p["blocks"]["hidden"]["attn"]["#0"]["#0"]["qkv"]["qkv_proj"]["v"]["#0"]["w"] = ( + jnp.stack(w_v, axis=0) + ) + p["blocks"]["hidden"]["attn"]["#0"]["#2"]["w"] = jnp.stack(w_o, axis=0) + + p["blocks"]["hidden"]["ffn"]["#0"]["norm"]["scale"] = jnp.stack(w_ln2, axis=0) + p["blocks"]["hidden"]["ffn"]["#0"]["up"]["#0"]["w"] = jnp.stack(w_up, axis=0) + p["blocks"]["hidden"]["ffn"]["#0"]["up"]["#1"]["proj"]["w"] = jnp.stack( + w_gate, axis=0 + ) + p["blocks"]["hidden"]["ffn"]["#0"]["down"]["w"] = jnp.stack(w_down, axis=0) + + p["emb"]["w"] = tensors["model.embed_tokens.weight"] + p["out_norm"]["scale"] = tensors["model.norm.weight"] + + return m({"token_ids": input_ids}, p, s) diff --git a/pyproject.toml b/pyproject.toml index 1ce96e7..640585f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "julax" -version = "0.0.3" +version = "0.0.4-dev" description = "Just Layers over JAX" readme = "README.md" authors = [ @@ -10,6 +10,7 @@ requires-python = ">=3.12" dependencies = [ "einops>=0.8.1", "grain>=0.2.12", + "humanize>=4.13.0", "jax>=0.7.2", "optax>=0.2.6", "orbax-checkpoint>=0.11.25", @@ -25,10 +26,6 @@ requires = ["hatchling"] build-backend = "hatchling.build" [project.optional-dependencies] -dev = [ - "pytest>=8.3.2", - "pytest-cov>=5.0.0", -] tpu = [ "jax[tpu]>=0.7.2", ] @@ -39,3 +36,12 @@ ignore = ["E741", "F811"] [[tool.uv.index]] url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" default = true + +[dependency-groups] +dev = [ + "pyarrow>=22.0.0", + "pytest>=8.4.2", + "safetensors>=0.7.0", + "tensorflow-datasets>=4.9.9", + "tiktoken>=0.12.0", +] diff --git a/src/julax/base.py b/src/julax/base.py index 39d1923..817f582 100644 --- a/src/julax/base.py +++ b/src/julax/base.py @@ -32,6 +32,9 @@ def values(self): def items(self): return self.root.items() + def __len__(self): + return len(self.root) + def __hash__(self): return hash(frozenset(self.root.items())) diff --git a/src/julax/core.py b/src/julax/core.py index 234b658..1633f58 100644 --- a/src/julax/core.py +++ b/src/julax/core.py @@ -15,7 +15,7 @@ ##### -from julax.base import PRNG, Dtype, OutShardingType, PyTree, dispatch +from julax.base import PRNG, PyTree, dispatch # TODO: use RootModel[dict] for better customization # Or maybe SimpleNamespace? @@ -26,10 +26,6 @@ class LayerBase(BaseModel, ABC): - param_dtype: Dtype | None = None - param_sharding: OutShardingType = None - out_sharding: OutShardingType = None - model_config = ConfigDict( arbitrary_types_allowed=True, frozen=True, @@ -40,13 +36,6 @@ class LayerBase(BaseModel, ABC): ), ) - @classmethod - def __pydantic_init_subclass__(cls, **kwargs): - # TODO: respect `FieldInfo` - jax.tree_util.register_dataclass( - cls, data_fields=list(cls.model_fields.keys()), meta_fields=[] - ) - def sublayers(self) -> dict: attrs_flatten, treedef = jax.tree.flatten( dict(self), is_leaf=lambda x: isinstance(x, LayerBase) @@ -66,12 +55,37 @@ def sublayers(self) -> dict: res[k] = v return res + def __getitem__(self, key: str) -> "LayerBase": + return self.sublayers()[key] + + def _ipython_display_(self): + from julax.pprint import pprint + + pprint(self) + def param(self, rng: PRNG) -> Param: return Param() + def param_length(self) -> int: + return 0 + def state(self, rng: PRNG) -> State: return State() + def state_length(self) -> int: + return 0 + + def numel(self) -> tuple[int, int]: + num_params = self.param_length() + num_states = self.state_length() + + for sublayer in self.sublayers().values(): + p, s = sublayer.numel() + num_params += p + num_states += s + + return num_params, num_states + @dispatch def init(self, seed: int = 0) -> tuple[Param, State]: return self.init(jax.random.key(seed)) @@ -109,7 +123,10 @@ def init( assert len(layer_params.keys() & sublayer_params.keys()) == 0 assert len(layer_states.keys() & sublayer_states.keys()) == 0 - return sublayer_params | layer_params, sublayer_states | layer_states + return ( + sublayer_params | layer_params, + sublayer_states | layer_states, + ) @abstractmethod def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]: ... @@ -165,7 +182,10 @@ def init( self, layer_params, layer_states, sublayer_params, sublayer_states ) -> tuple[Param, State]: layer_states["optimizer"] = self.optimizer.init(sublayer_params["learner"]) - return sublayer_params | layer_params, sublayer_states | layer_states + return ( + sublayer_params | layer_params, + sublayer_states | layer_states, + ) def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]: loss, state = self.learner(x, p["learner"], s["learner"]) diff --git a/src/julax/experiment.py b/src/julax/experiment.py index d80ef10..6769388 100644 --- a/src/julax/experiment.py +++ b/src/julax/experiment.py @@ -1,5 +1,6 @@ +from functools import cached_property import jax -from jax.sharding import PartitionSpec +from jax.sharding import PartitionSpec, Mesh from julax.utils import create_mesh @@ -10,7 +11,7 @@ import logging -from pydantic import Field +from pydantic import Field, computed_field from .observers import default_observer, ObserverBase @@ -30,6 +31,11 @@ class Experiment(LayerBase): checkpoint_manager: ocp.CheckpointManager | None = None observer: ObserverBase = Field(default_factory=default_observer) + @computed_field + @cached_property + def mesh(self) -> Mesh: + return create_mesh(self.mesh_shape) + def state(self, rng: PRNG) -> State: return State(input=iter(self.dataset), step=0) @@ -86,7 +92,7 @@ def close(self): self.checkpoint_manager.close() def run(self) -> tuple[Param, State]: - with create_mesh(self.mesh_shape) as mesh: + with self.mesh as mesh: p, s = self.restore() self.observer(self, p, s) diff --git a/src/julax/inputs.py b/src/julax/inputs.py deleted file mode 100644 index f47dd07..0000000 --- a/src/julax/inputs.py +++ /dev/null @@ -1,6 +0,0 @@ -from julax.base import PyTree - - -def create_global_input( - process_local_data: PyTree, -): ... diff --git a/src/julax/layers.py b/src/julax/layers.py index cc78f55..cc57fd3 100644 --- a/src/julax/layers.py +++ b/src/julax/layers.py @@ -1,20 +1,19 @@ from typing import Callable import jax +import jax.numpy as jnp from jax import Array +from jax.nn.initializers import Initializer, lecun_normal, ones, variance_scaling, zeros from jax.sharding import PartitionSpec as P -import jax.numpy as jnp -from jax.nn.initializers import ( - Initializer, - lecun_normal, - ones, - zeros, - variance_scaling, -) -from julax.base import Dtype +from julax.base import Dtype, OutShardingType +from julax.utils import identity + -from .core import PRNG, LayerBase, LayerLike, PyTree, Param, State, dispatch +from typing import Annotated + +from pydantic import Field +from .core import PRNG, LayerBase, LayerLike, Param, PyTree, State, dispatch class F(LayerBase): @@ -29,39 +28,54 @@ def to_layer(x: Callable): return F(f=x) -class SkipConnection(LayerBase): - layer: LayerLike - connection: Callable = jnp.add +# TODO: generalize to select subtree +class Select(LayerBase): + key: int | str def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]: - S = State() - o, S["layer"] = self.layer(x, p["layer"], s["layer"]) - return self.connection(o, x), S - - -class Repeated(LayerBase): + match self.key: + case int(k): + return x[k], s + case str(k) if k.startswith("."): + return getattr(x, k), s + case str(k): + return x[k], s + case _: + raise ValueError(f"Unsupported key type: {type(self.key)}") + + +class Repeat(LayerBase): n: int layer: LayerLike def sublayers(self) -> dict: - return {f"layer_{i}": self.layer for i in range(self.n)} + return {f"#{i}": self.layer for i in range(self.n)} + + @dispatch + def init(self, rng: PRNG) -> tuple[Param, State]: + def scan_init(carry, rng): + p, s = self.layer.init(rng) + return carry, (p, s) + + rngs = jax.random.split(rng, self.n) + _, (P, S) = jax.lax.scan(scan_init, None, rngs) + return P, S + + def __getitem__(self, key) -> LayerBase: + return self.layer def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]: - S = State() - o = x - for i in range(self.n): - o, S[f"layer_{i}"] = self.layer(o, p[f"layer_{i}"], s[f"layer_{i}"]) - return o, S + def scan_forward(x, ps): + return self.layer(x, *ps) + o, s = jax.lax.scan(scan_forward, x, (p, s)) -class NamedLayers(LayerBase): - names: tuple[str, ...] - layers: tuple[LayerLike, ...] + return o, s - def __init__(self, *args, **kwargs): - names = tuple(f"layer_{i}" for i in range(len(args))) + tuple(kwargs.keys()) - layers = tuple(args) + tuple(kwargs.values()) - super().__init__(names=names, layers=layers) + +class NamedLayers(LayerBase): + names: Annotated[tuple[str, ...], Field(repr=False)] + layers: Annotated[tuple[LayerLike, ...], Field(repr=False)] def sublayers(self) -> dict: return {k: v for k, v in zip(self.names, self.layers)} @@ -69,7 +83,15 @@ def sublayers(self) -> dict: class Chain(NamedLayers): def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + names = tuple(f"#{i}" for i in range(len(args))) + tuple(kwargs.keys()) + layers = tuple(args) + tuple(kwargs.values()) + super().__init__(names=names, layers=layers) + + def __getitem__(self, key: str | int) -> LayerBase: + if isinstance(key, int): + return self.layers[key] + else: + return self.sublayers()[key] def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]: h = x @@ -82,32 +104,49 @@ def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]: class Branch(NamedLayers): """1 -> N""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + reduce: Callable | None = None + + def __init__(self, *args, reduce: Callable | None = None, **kwargs): + names = tuple(f"#{i}" for i in range(len(args))) + tuple(kwargs.keys()) + layers = tuple(args) + tuple(kwargs.values()) + super().__init__(names=names, layers=layers, reduce=reduce) def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]: O = {} S = State() for name, layer in zip(self.names, self.layers): O[name], S[name] = layer(x, p[name], s[name]) - # ??? return dict? - return tuple(O.values()), S + if self.reduce is not None: + args = (v for k, v in O.items() if k.startswith("#")) + kwargs = {k: v for k, v in O.items() if not k.startswith("#")} + O = self.reduce(*args, **kwargs) + return O, S -class Parallel(NamedLayers): +class Residual(Branch): + def __init__(self, processor, *, skip_through=identity, reduce: Callable = jnp.add): + super().__init__(processor, skip_through, reduce=reduce) + + +class Parallel(Branch): """N -> N""" + # place holder to bypass link check def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]: + assert isinstance(x, dict) assert len(x) == len(self.layers) O = {} S = State() - for name, layer, xᵢ in zip(self.names, self.layers, x): + for name, layer, xᵢ in zip(self.names, self.layers, x.values()): O[name], S[name] = layer(xᵢ, p[name], s[name]) - # ??? return dict? - return tuple(O.values()), S + if self.reduce is not None: + args = (v for k, v in O.items() if k.startswith("#")) + kwargs = {k: v for k, v in O.items() if not k.startswith("#")} + O = self.reduce(*args, **kwargs) + return O, S ##### @@ -117,36 +156,38 @@ class Linear(LayerBase): in_dim: int out_dim: int w_init: Initializer = lecun_normal() - b_init: None | Initializer = zeros + b_init: Initializer | None = None + + param_dtype: Dtype | None = None + param_sharding: OutShardingType = None + out_sharding: OutShardingType = None def param(self, rng: PRNG) -> Param: + p = Param() rng_w, rng_b = jax.random.split(rng) - return Param( - w=self.w_init( - rng_w, - (self.in_dim, self.out_dim), - dtype=self.param_dtype, - out_sharding=self.param_sharding, - ), - b=( - self.b_init( - rng_b, - (self.out_dim,), - dtype=self.param_dtype, - out_sharding=( - None - if self.param_sharding is None - else P(self.param_sharding[-1]) - ), - ) - if self.b_init - else None - ), + p["w"] = self.w_init( + rng_w, + (self.in_dim, self.out_dim), + dtype=self.param_dtype, + out_sharding=self.param_sharding, ) + if self.b_init: + p["b"] = self.b_init( + rng_b, + (self.out_dim,), + dtype=self.param_dtype, + out_sharding=( + None if self.param_sharding is None else P(self.param_sharding[-1]) + ), + ) + return p + + def param_length(self) -> int: + return self.in_dim * self.out_dim + (self.out_dim if self.b_init else 0) def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]: o = jnp.einsum("...d,dh->...h", x, p["w"], out_sharding=self.out_sharding) - if p["b"] is not None: + if self.b_init is not None: o += p["b"] return o, s @@ -157,6 +198,9 @@ class Dropout(LayerBase): def state(self, rng: PRNG) -> State: return State(rng=rng, is_training=True) + def state_length(self) -> int: + return 4 # typically 32 bits? + def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]: rng, s["rng"] = jax.random.split(s["rng"]) if s["is_training"] and self.rate > 0: @@ -190,7 +234,11 @@ def test_mode(s: State): class Embedding(LayerBase): in_dim: int out_dim: int - w_init: Initializer = variance_scaling(1.0, "fan_in", "normal", out_axis=0) + w_init: Initializer = variance_scaling(1.0, "fan_out", "normal") + + param_dtype: Dtype | None = None + param_sharding: OutShardingType = None + out_sharding: OutShardingType = None def param(self, rng: PRNG) -> Param: return Param( @@ -202,9 +250,15 @@ def param(self, rng: PRNG) -> Param: ) ) + def param_length(self) -> int: + return self.in_dim * self.out_dim + def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]: return p["w"].at[x].get(out_sharding=self.out_sharding), s + def attend(self, x: Array, p: Param) -> Array: + return jnp.einsum("...ld,nd->...ln", x, p["w"], out_sharding=self.out_sharding) + class RotaryEmbedding(LayerBase): """Rotary Position Embedding.""" @@ -227,6 +281,9 @@ def state(self, rng: PRNG) -> State: timescale = timescale * self.rope_linear_scaling_factor return State(timescale=timescale) + def state_length(self) -> int: + return self.embedding_dims // 2 + def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]: seq_length = x.shape[1] position = jnp.arange(seq_length, dtype=jnp.float32)[ @@ -245,11 +302,6 @@ def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]: return x_out, s -class Unembedding(Embedding): - def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]: - return jnp.einsum("bld,dn->bln", x, p["w"], out_sharding=self.out_sharding), s - - class LayerNorm(LayerBase): dim: int epsilon: float = 1e-5 @@ -257,6 +309,10 @@ class LayerNorm(LayerBase): b_init: Initializer = zeros compute_dtype: Dtype | None = None + param_dtype: Dtype | None = None + param_sharding: OutShardingType = None + out_sharding: OutShardingType = None + def param(self, rng: PRNG) -> Param: w_rng, b_rng = jax.random.split(rng) return Param( @@ -264,7 +320,7 @@ def param(self, rng: PRNG) -> Param: w_rng, (self.dim,), dtype=self.param_dtype, - out_sharding=self.out_sharding, + out_sharding=self.param_sharding, ), b=self.b_init( b_rng, @@ -276,8 +332,70 @@ def param(self, rng: PRNG) -> Param: ), ) + def param_length(self) -> int: + return 2 * self.dim + def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]: x_std = jax.nn.standardize( x.astype(self.compute_dtype), epsilon=self.epsilon ).astype(self.param_dtype) - return x_std * p["w"] + p["b"], s + o = x_std * p["w"] + p["b"] + if self.out_sharding is not None: + o = jax.lax.with_sharding_constraint(o, self.out_sharding) + return o, s + + +class RMSNorm(LayerBase): + dim: int + eps: float = 1e-8 + zero_center: bool = False + scale_init: Initializer | None = ones + scale_dtype: Dtype | None = None + scale_sharding: OutShardingType = None + + dtype: Dtype = jnp.float32 + param_sharding: OutShardingType = None + out_sharding: OutShardingType = None + + def param(self, rng: PRNG) -> Param: + if self.scale_init is None: + return Param() + else: + return Param( + scale=self.scale_init( + rng, + (self.dim,), + dtype=self.scale_dtype, + out_sharding=( + None + if self.param_sharding is None + else P(self.param_sharding[-1]) + ), + ) + ) + + def param_length(self) -> int: + if self.scale_init is None: + return 0 + else: + return self.dim + + def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]: + x_dtype = x.dtype + + x = x.astype(self.dtype) + rms = jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + self.eps) + + if self.zero_center: + x = x - x.mean(axis=-1, keepdims=True) + + o = x * rms + + if self.scale_init is not None: + o = o * p["scale"] + + o = o.astype(x_dtype) + + if self.out_sharding is not None: + o = jax.lax.with_sharding_constraint(o, self.out_sharding) + return o, s diff --git a/src/julax/pprint.py b/src/julax/pprint.py new file mode 100644 index 0000000..55db446 --- /dev/null +++ b/src/julax/pprint.py @@ -0,0 +1,112 @@ +import jax +from julax.base import dispatch +from julax.core import LayerBase + +from rich.console import Console +from rich.tree import Tree +from rich.markup import escape + +import humanize +import zlib + +from julax.layers import Repeat + +console = Console() + + +def string_color(s: str) -> str: + return f"color({(zlib.adler32(s.encode()) + 123) % 256})" + + +def colorof(layer: LayerBase) -> str: + return string_color(layer.__class__.__name__) + + +def nameof(layer: LayerBase) -> str: + s = layer.__class__.__name__ + return f"[{colorof(layer)}]{s}[/]" + + +def argsof(layer: LayerBase) -> str: + fields = layer.model_dump( + exclude_defaults=True, exclude_unset=True, exclude_none=True + ) + for k, v in type(layer).model_fields.items(): + if k in fields and v.repr is False: + del fields[k] + if isinstance(getattr(layer, k), LayerBase): + del fields[k] + + if fields: + args = [f"[bright_blue]{k}[/]={escape(repr(v))}" for k, v in fields.items()] + return "(" + ", ".join(args) + ")" + else: + return "" + + +def param_info(layer: LayerBase) -> str: + num_params, num_states = layer.numel() + total = num_params + num_states + if total == 0: + return "" + s = f" [dim]# Total Params: [bright_green]{humanize.intcomma(total)}[/][/]" + if num_states == 0: + return s + else: + return ( + s + + f" [trainable=[bright_green]{humanize.intcomma(num_params)}[/], non_trainable=[bright_green]{humanize.intcomma(num_states)}[/]]" + ) + + +def summary(layer: LayerBase) -> str: + return f"[bold]{nameof(layer)}[/][dim]{argsof(layer)}[/]" + + +@dispatch +def to_rich(layer: Repeat) -> Tree: + root = Tree(summary(layer) + param_info(layer), guide_style=colorof(layer)) + child = to_rich(layer.layer) + child.label = f"[{colorof(layer.layer)}]0..{layer.n - 1}[/] [bright_yellow]=>[/] {child.label}" + root.children.append(child) + return root + + +@dispatch +def to_rich(layer: LayerBase) -> Tree: + root = Tree(summary(layer) + param_info(layer), guide_style=colorof(layer)) + for name, sublayer in layer.sublayers().items(): + child = to_rich(sublayer) + child.label = ( + f"[{colorof(sublayer)}]{name}[/] [bright_yellow]=>[/] {summary(sublayer)}" + ) + child.guide_style = colorof(sublayer) + root.children.append(child) + return root + + +@dispatch +def to_rich(x: jax.Array) -> str: + return str(jax.typeof(x)) + + +@dispatch +def to_rich(t: dict) -> Tree: + root = Tree("") + for k, v in t.items(): + child = to_rich(v) + if isinstance(child, Tree): + if child.label: + child.label = f"[bright_blue]{k}[/] [bright_yellow]=>[/] {child.label}" + else: + child.label = f"[bright_blue]{k}[/]" + root.children.append(child) + elif isinstance(child, str): + root.add(f"[bright_blue]{k}[/] [bright_yellow]=>[/] {child}") + else: + raise NotImplementedError() + return root + + +def pprint(x): + console.print(to_rich(x)) diff --git a/tests/test_inputs.py b/tests/test_inputs.py new file mode 100644 index 0000000..3a2ee1b --- /dev/null +++ b/tests/test_inputs.py @@ -0,0 +1,107 @@ +import numpy as np +import pytest +from julax.inputs import preprocess_text_inputs + +# The beginning-of-sequence token ID is consistently 0 for all test cases. +BOS_TOKEN_ID = 0 + +# List of test cases for parameterization. Each case is defined with pytest.param +# for better readability and separate tracking in test results. +# Each parameter tuple contains: (tokens, expected_mask, expected_position_ids) +TEST_CASES = [ + pytest.param( + # Input: A single sequence starting with BOS. + np.array([[BOS_TOKEN_ID, 1, 2, 3]]), + # Expected mask: A standard 4x4 causal mask. + np.tril(np.ones((4, 4), dtype=bool))[None, :, :], + # Expected positions: Standard increasing positions. + np.array([[0, 1, 2, 3]]), + id="single_sequence_with_bos", + ), + pytest.param( + # Input: A single sequence with no BOS tokens. + np.array([[1, 2, 3, 4]]), + # Expected mask: A standard 4x4 causal mask. + np.tril(np.ones((4, 4), dtype=bool))[None, :, :], + # Expected positions: Standard increasing positions. + np.array([[0, 1, 2, 3]]), + id="no_bos_tokens", + ), + pytest.param( + # Input: Two packed sequences: [BOS, 1] and [BOS, 2, 3]. + np.array([[BOS_TOKEN_ID, 1, BOS_TOKEN_ID, 2, 3]]), + # Expected mask: A block-causal mask for segments (0,1) and (2,3,4). + np.array( + [ + [ + [1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 1, 1, 0], + [0, 0, 1, 1, 1], + ] + ], + dtype=bool, + ), + # Expected positions: Positions reset at the second BOS token. + np.array([[0, 1, 0, 1, 2]]), + id="packed_sequences", + ), + pytest.param( + # Input: Two packed sequences where the first doesn't start with BOS: [1, 2] and [BOS, 3]. + np.array([[1, 2, BOS_TOKEN_ID, 3]]), + # Expected mask: A block-causal mask for segments (0,1) and (2,3). + np.array( + [[[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 1]]], dtype=bool + ), + # Expected positions: Positions reset at the BOS token. + np.array([[0, 1, 0, 1]]), + id="sequence_without_initial_bos", + ), + pytest.param( + # Input: A batch of 2 sequences with mixed packing. + np.array( + [ + [ + BOS_TOKEN_ID, + 1, + BOS_TOKEN_ID, + 3, + 4, + ], # Packed: [BOS, 1] and [BOS, 3, 4] + [1, 2, 3, 4, 5], # Not packed + ] + ), + # Expected mask: A batch of masks, one block-causal and one standard-causal. + np.array( + [ + [ + [1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 1, 1, 0], + [0, 0, 1, 1, 1], + ], # Block-causal + np.tril(np.ones((5, 5), dtype=bool)), # Standard causal + ], + dtype=bool, + ), + # Expected positions: A batch of positions, one resetting and one standard. + np.array([[0, 1, 0, 1, 2], [0, 1, 2, 3, 4]]), + id="batch_of_sequences", + ), +] + + +@pytest.mark.parametrize("tokens, expected_mask, expected_position_ids", TEST_CASES) +def test_preprocess_text_inputs(tokens, expected_mask, expected_position_ids): + """ + Tests the preprocess_text_inputs function with various parameterized scenarios + covering single sequences, packed sequences, and batched inputs. + """ + # Run the function being tested. + result = preprocess_text_inputs(tokens, BOS_TOKEN_ID) + + # Assert that the generated mask and position_ids match the expected outputs. + np.testing.assert_array_equal(result["mask"], expected_mask) + np.testing.assert_array_equal(result["position_ids"], expected_position_ids) diff --git a/uv.lock b/uv.lock index 5fbac64..b307da6 100644 --- a/uv.lock +++ b/uv.lock @@ -68,6 +68,72 @@ wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/96/2a/a4773109619010192e72f48e95165b14790413a51f513c879c8d63f67e17/beartype-0.22.2-py3-none-any.whl", hash = "sha256:12077afe3528eba5c5b801f816712f7ff06f6da5509994c79561e29b48bcedb8", size = 1317280, upload-time = "2025-10-04T06:37:53.99Z" }, ] +[[package]] +name = "certifi" +version = "2025.11.12" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a2/8c/58f469717fa48465e4a50c014a0400602d3c437d7c0c468e17ada824da3a/certifi-2025.11.12.tar.gz", hash = "sha256:d8ab5478f2ecd78af242878415affce761ca6bc54a22a27e026d7c25357c3316", size = 160538, upload-time = "2025-11-12T02:54:51.517Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/70/7d/9bc192684cea499815ff478dfcdc13835ddf401365057044fb721ec6bddb/certifi-2025.11.12-py3-none-any.whl", hash = "sha256:97de8790030bbd5c2d96b7ec782fc2f7820ef8dba6db909ccf95449f2d062d4b", size = 159438, upload-time = "2025-11-12T02:54:49.735Z" }, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.4" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/13/69/33ddede1939fdd074bce5434295f38fae7136463422fe4fd3e0e89b98062/charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a", size = 129418, upload-time = "2025-10-14T04:42:32.879Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f3/85/1637cd4af66fa687396e757dec650f28025f2a2f5a5531a3208dc0ec43f2/charset_normalizer-3.4.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0a98e6759f854bd25a58a73fa88833fba3b7c491169f86ce1180c948ab3fd394", size = 208425, upload-time = "2025-10-14T04:40:53.353Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/9d/6a/04130023fef2a0d9c62d0bae2649b69f7b7d8d24ea5536feef50551029df/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b5b290ccc2a263e8d185130284f8501e3e36c5e02750fc6b6bdeb2e9e96f1e25", size = 148162, upload-time = "2025-10-14T04:40:54.558Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/78/29/62328d79aa60da22c9e0b9a66539feae06ca0f5a4171ac4f7dc285b83688/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74bb723680f9f7a6234dcf67aea57e708ec1fbdf5699fb91dfd6f511b0a320ef", size = 144558, upload-time = "2025-10-14T04:40:55.677Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/86/bb/b32194a4bf15b88403537c2e120b817c61cd4ecffa9b6876e941c3ee38fe/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f1e34719c6ed0b92f418c7c780480b26b5d9c50349e9a9af7d76bf757530350d", size = 161497, upload-time = "2025-10-14T04:40:57.217Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/19/89/a54c82b253d5b9b111dc74aca196ba5ccfcca8242d0fb64146d4d3183ff1/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2437418e20515acec67d86e12bf70056a33abdacb5cb1655042f6538d6b085a8", size = 159240, upload-time = "2025-10-14T04:40:58.358Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c0/10/d20b513afe03acc89ec33948320a5544d31f21b05368436d580dec4e234d/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11d694519d7f29d6cd09f6ac70028dba10f92f6cdd059096db198c283794ac86", size = 153471, upload-time = "2025-10-14T04:40:59.468Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/61/fa/fbf177b55bdd727010f9c0a3c49eefa1d10f960e5f09d1d887bf93c2e698/charset_normalizer-3.4.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ac1c4a689edcc530fc9d9aa11f5774b9e2f33f9a0c6a57864e90908f5208d30a", size = 150864, upload-time = "2025-10-14T04:41:00.623Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/05/12/9fbc6a4d39c0198adeebbde20b619790e9236557ca59fc40e0e3cebe6f40/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:21d142cc6c0ec30d2efee5068ca36c128a30b0f2c53c1c07bd78cb6bc1d3be5f", size = 150647, upload-time = "2025-10-14T04:41:01.754Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ad/1f/6a9a593d52e3e8c5d2b167daf8c6b968808efb57ef4c210acb907c365bc4/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:5dbe56a36425d26d6cfb40ce79c314a2e4dd6211d51d6d2191c00bed34f354cc", size = 145110, upload-time = "2025-10-14T04:41:03.231Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/30/42/9a52c609e72471b0fc54386dc63c3781a387bb4fe61c20231a4ebcd58bdd/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5bfbb1b9acf3334612667b61bd3002196fe2a1eb4dd74d247e0f2a4d50ec9bbf", size = 162839, upload-time = "2025-10-14T04:41:04.715Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c4/5b/c0682bbf9f11597073052628ddd38344a3d673fda35a36773f7d19344b23/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:d055ec1e26e441f6187acf818b73564e6e6282709e9bcb5b63f5b23068356a15", size = 150667, upload-time = "2025-10-14T04:41:05.827Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e4/24/a41afeab6f990cf2daf6cb8c67419b63b48cf518e4f56022230840c9bfb2/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:af2d8c67d8e573d6de5bc30cdb27e9b95e49115cd9baad5ddbd1a6207aaa82a9", size = 160535, upload-time = "2025-10-14T04:41:06.938Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/2a/e5/6a4ce77ed243c4a50a1fecca6aaaab419628c818a49434be428fe24c9957/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:780236ac706e66881f3b7f2f32dfe90507a09e67d1d454c762cf642e6e1586e0", size = 154816, upload-time = "2025-10-14T04:41:08.101Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a8/ef/89297262b8092b312d29cdb2517cb1237e51db8ecef2e9af5edbe7b683b1/charset_normalizer-3.4.4-cp312-cp312-win32.whl", hash = "sha256:5833d2c39d8896e4e19b689ffc198f08ea58116bee26dea51e362ecc7cd3ed26", size = 99694, upload-time = "2025-10-14T04:41:09.23Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/3d/2d/1e5ed9dd3b3803994c155cd9aacb60c82c331bad84daf75bcb9c91b3295e/charset_normalizer-3.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:a79cfe37875f822425b89a82333404539ae63dbdddf97f84dcbc3d339aae9525", size = 107131, upload-time = "2025-10-14T04:41:10.467Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/d0/d9/0ed4c7098a861482a7b6a95603edce4c0d9db2311af23da1fb2b75ec26fc/charset_normalizer-3.4.4-cp312-cp312-win_arm64.whl", hash = "sha256:376bec83a63b8021bb5c8ea75e21c4ccb86e7e45ca4eb81146091b56599b80c3", size = 100390, upload-time = "2025-10-14T04:41:11.915Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/97/45/4b3a1239bbacd321068ea6e7ac28875b03ab8bc0aa0966452db17cd36714/charset_normalizer-3.4.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:e1f185f86a6f3403aa2420e815904c67b2f9ebc443f045edd0de921108345794", size = 208091, upload-time = "2025-10-14T04:41:13.346Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/7d/62/73a6d7450829655a35bb88a88fca7d736f9882a27eacdca2c6d505b57e2e/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6b39f987ae8ccdf0d2642338faf2abb1862340facc796048b604ef14919e55ed", size = 147936, upload-time = "2025-10-14T04:41:14.461Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/89/c5/adb8c8b3d6625bef6d88b251bbb0d95f8205831b987631ab0c8bb5d937c2/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3162d5d8ce1bb98dd51af660f2121c55d0fa541b46dff7bb9b9f86ea1d87de72", size = 144180, upload-time = "2025-10-14T04:41:15.588Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/91/ed/9706e4070682d1cc219050b6048bfd293ccf67b3d4f5a4f39207453d4b99/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:81d5eb2a312700f4ecaa977a8235b634ce853200e828fbadf3a9c50bab278328", size = 161346, upload-time = "2025-10-14T04:41:16.738Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/d5/0d/031f0d95e4972901a2f6f09ef055751805ff541511dc1252ba3ca1f80cf5/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5bd2293095d766545ec1a8f612559f6b40abc0eb18bb2f5d1171872d34036ede", size = 158874, upload-time = "2025-10-14T04:41:17.923Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f5/83/6ab5883f57c9c801ce5e5677242328aa45592be8a00644310a008d04f922/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a8a8b89589086a25749f471e6a900d3f662d1d3b6e2e59dcecf787b1cc3a1894", size = 153076, upload-time = "2025-10-14T04:41:19.106Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/75/1e/5ff781ddf5260e387d6419959ee89ef13878229732732ee73cdae01800f2/charset_normalizer-3.4.4-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bc7637e2f80d8530ee4a78e878bce464f70087ce73cf7c1caf142416923b98f1", size = 150601, upload-time = "2025-10-14T04:41:20.245Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/d7/57/71be810965493d3510a6ca79b90c19e48696fb1ff964da319334b12677f0/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f8bf04158c6b607d747e93949aa60618b61312fe647a6369f88ce2ff16043490", size = 150376, upload-time = "2025-10-14T04:41:21.398Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e5/d5/c3d057a78c181d007014feb7e9f2e65905a6c4ef182c0ddf0de2924edd65/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:554af85e960429cf30784dd47447d5125aaa3b99a6f0683589dbd27e2f45da44", size = 144825, upload-time = "2025-10-14T04:41:22.583Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e6/8c/d0406294828d4976f275ffbe66f00266c4b3136b7506941d87c00cab5272/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:74018750915ee7ad843a774364e13a3db91682f26142baddf775342c3f5b1133", size = 162583, upload-time = "2025-10-14T04:41:23.754Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/d7/24/e2aa1f18c8f15c4c0e932d9287b8609dd30ad56dbe41d926bd846e22fb8d/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:c0463276121fdee9c49b98908b3a89c39be45d86d1dbaa22957e38f6321d4ce3", size = 150366, upload-time = "2025-10-14T04:41:25.27Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e4/5b/1e6160c7739aad1e2df054300cc618b06bf784a7a164b0f238360721ab86/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:362d61fd13843997c1c446760ef36f240cf81d3ebf74ac62652aebaf7838561e", size = 160300, upload-time = "2025-10-14T04:41:26.725Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/7a/10/f882167cd207fbdd743e55534d5d9620e095089d176d55cb22d5322f2afd/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9a26f18905b8dd5d685d6d07b0cdf98a79f3c7a918906af7cc143ea2e164c8bc", size = 154465, upload-time = "2025-10-14T04:41:28.322Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/89/66/c7a9e1b7429be72123441bfdbaf2bc13faab3f90b933f664db506dea5915/charset_normalizer-3.4.4-cp313-cp313-win32.whl", hash = "sha256:9b35f4c90079ff2e2edc5b26c0c77925e5d2d255c42c74fdb70fb49b172726ac", size = 99404, upload-time = "2025-10-14T04:41:29.95Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c4/26/b9924fa27db384bdcd97ab83b4f0a8058d96ad9626ead570674d5e737d90/charset_normalizer-3.4.4-cp313-cp313-win_amd64.whl", hash = "sha256:b435cba5f4f750aa6c0a0d92c541fb79f69a387c91e61f1795227e4ed9cece14", size = 107092, upload-time = "2025-10-14T04:41:31.188Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/af/8f/3ed4bfa0c0c72a7ca17f0380cd9e4dd842b09f664e780c13cff1dcf2ef1b/charset_normalizer-3.4.4-cp313-cp313-win_arm64.whl", hash = "sha256:542d2cee80be6f80247095cc36c418f7bddd14f4a6de45af91dfad36d817bba2", size = 100408, upload-time = "2025-10-14T04:41:32.624Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/2a/35/7051599bd493e62411d6ede36fd5af83a38f37c4767b92884df7301db25d/charset_normalizer-3.4.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:da3326d9e65ef63a817ecbcc0df6e94463713b754fe293eaa03da99befb9a5bd", size = 207746, upload-time = "2025-10-14T04:41:33.773Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/10/9a/97c8d48ef10d6cd4fcead2415523221624bf58bcf68a802721a6bc807c8f/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8af65f14dc14a79b924524b1e7fffe304517b2bff5a58bf64f30b98bbc5079eb", size = 147889, upload-time = "2025-10-14T04:41:34.897Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/10/bf/979224a919a1b606c82bd2c5fa49b5c6d5727aa47b4312bb27b1734f53cd/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74664978bb272435107de04e36db5a9735e78232b85b77d45cfb38f758efd33e", size = 143641, upload-time = "2025-10-14T04:41:36.116Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ba/33/0ad65587441fc730dc7bd90e9716b30b4702dc7b617e6ba4997dc8651495/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:752944c7ffbfdd10c074dc58ec2d5a8a4cd9493b314d367c14d24c17684ddd14", size = 160779, upload-time = "2025-10-14T04:41:37.229Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/67/ed/331d6b249259ee71ddea93f6f2f0a56cfebd46938bde6fcc6f7b9a3d0e09/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d1f13550535ad8cff21b8d757a3257963e951d96e20ec82ab44bc64aeb62a191", size = 159035, upload-time = "2025-10-14T04:41:38.368Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/67/ff/f6b948ca32e4f2a4576aa129d8bed61f2e0543bf9f5f2b7fc3758ed005c9/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ecaae4149d99b1c9e7b88bb03e3221956f68fd6d50be2ef061b2381b61d20838", size = 152542, upload-time = "2025-10-14T04:41:39.862Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/16/85/276033dcbcc369eb176594de22728541a925b2632f9716428c851b149e83/charset_normalizer-3.4.4-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:cb6254dc36b47a990e59e1068afacdcd02958bdcce30bb50cc1700a8b9d624a6", size = 149524, upload-time = "2025-10-14T04:41:41.319Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/9e/f2/6a2a1f722b6aba37050e626530a46a68f74e63683947a8acff92569f979a/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c8ae8a0f02f57a6e61203a31428fa1d677cbe50c93622b4149d5c0f319c1d19e", size = 150395, upload-time = "2025-10-14T04:41:42.539Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/60/bb/2186cb2f2bbaea6338cad15ce23a67f9b0672929744381e28b0592676824/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:47cc91b2f4dd2833fddaedd2893006b0106129d4b94fdb6af1f4ce5a9965577c", size = 143680, upload-time = "2025-10-14T04:41:43.661Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/7d/a5/bf6f13b772fbb2a90360eb620d52ed8f796f3c5caee8398c3b2eb7b1c60d/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:82004af6c302b5d3ab2cfc4cc5f29db16123b1a8417f2e25f9066f91d4411090", size = 162045, upload-time = "2025-10-14T04:41:44.821Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/df/c5/d1be898bf0dc3ef9030c3825e5d3b83f2c528d207d246cbabe245966808d/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:2b7d8f6c26245217bd2ad053761201e9f9680f8ce52f0fcd8d0755aeae5b2152", size = 149687, upload-time = "2025-10-14T04:41:46.442Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a5/42/90c1f7b9341eef50c8a1cb3f098ac43b0508413f33affd762855f67a410e/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:799a7a5e4fb2d5898c60b640fd4981d6a25f1c11790935a44ce38c54e985f828", size = 160014, upload-time = "2025-10-14T04:41:47.631Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/76/be/4d3ee471e8145d12795ab655ece37baed0929462a86e72372fd25859047c/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:99ae2cffebb06e6c22bdc25801d7b30f503cc87dbd283479e7b606f70aff57ec", size = 154044, upload-time = "2025-10-14T04:41:48.81Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/b0/6f/8f7af07237c34a1defe7defc565a9bc1807762f672c0fde711a4b22bf9c0/charset_normalizer-3.4.4-cp314-cp314-win32.whl", hash = "sha256:f9d332f8c2a2fcbffe1378594431458ddbef721c1769d78e2cbc06280d8155f9", size = 99940, upload-time = "2025-10-14T04:41:49.946Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/4b/51/8ade005e5ca5b0d80fb4aff72a3775b325bdc3d27408c8113811a7cbe640/charset_normalizer-3.4.4-cp314-cp314-win_amd64.whl", hash = "sha256:8a6562c3700cce886c5be75ade4a5db4214fda19fede41d9792d100288d8f94c", size = 107104, upload-time = "2025-10-14T04:41:51.051Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/da/5f/6b8f83a55bb8278772c5ae54a577f3099025f9ade59d0136ac24a0df4bde/charset_normalizer-3.4.4-cp314-cp314-win_arm64.whl", hash = "sha256:de00632ca48df9daf77a2c65a484531649261ec9f25489917f09e455cb09ddb2", size = 100743, upload-time = "2025-10-14T04:41:52.122Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" }, +] + [[package]] name = "chex" version = "0.1.91" @@ -103,80 +169,6 @@ wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] -[[package]] -name = "coverage" -version = "7.10.7" -source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } -sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/51/26/d22c300112504f5f9a9fd2297ce33c35f3d353e4aeb987c8419453b2a7c2/coverage-7.10.7.tar.gz", hash = "sha256:f4ab143ab113be368a3e9b795f9cd7906c5ef407d6173fe9675a902e1fffc239", size = 827704, upload-time = "2025-09-21T20:03:56.815Z" } -wheels = [ - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/13/e4/eb12450f71b542a53972d19117ea5a5cea1cab3ac9e31b0b5d498df1bd5a/coverage-7.10.7-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7bb3b9ddb87ef7725056572368040c32775036472d5a033679d1fa6c8dc08417", size = 218290, upload-time = "2025-09-21T20:01:36.455Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/37/66/593f9be12fc19fb36711f19a5371af79a718537204d16ea1d36f16bd78d2/coverage-7.10.7-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:18afb24843cbc175687225cab1138c95d262337f5473512010e46831aa0c2973", size = 218515, upload-time = "2025-09-21T20:01:37.982Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/66/80/4c49f7ae09cafdacc73fbc30949ffe77359635c168f4e9ff33c9ebb07838/coverage-7.10.7-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:399a0b6347bcd3822be369392932884b8216d0944049ae22925631a9b3d4ba4c", size = 250020, upload-time = "2025-09-21T20:01:39.617Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a6/90/a64aaacab3b37a17aaedd83e8000142561a29eb262cede42d94a67f7556b/coverage-7.10.7-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:314f2c326ded3f4b09be11bc282eb2fc861184bc95748ae67b360ac962770be7", size = 252769, upload-time = "2025-09-21T20:01:41.341Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/98/2e/2dda59afd6103b342e096f246ebc5f87a3363b5412609946c120f4e7750d/coverage-7.10.7-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c41e71c9cfb854789dee6fc51e46743a6d138b1803fab6cb860af43265b42ea6", size = 253901, upload-time = "2025-09-21T20:01:43.042Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/53/dc/8d8119c9051d50f3119bb4a75f29f1e4a6ab9415cd1fa8bf22fcc3fb3b5f/coverage-7.10.7-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bc01f57ca26269c2c706e838f6422e2a8788e41b3e3c65e2f41148212e57cd59", size = 250413, upload-time = "2025-09-21T20:01:44.469Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/98/b3/edaff9c5d79ee4d4b6d3fe046f2b1d799850425695b789d491a64225d493/coverage-7.10.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a6442c59a8ac8b85812ce33bc4d05bde3fb22321fa8294e2a5b487c3505f611b", size = 251820, upload-time = "2025-09-21T20:01:45.915Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/11/25/9a0728564bb05863f7e513e5a594fe5ffef091b325437f5430e8cfb0d530/coverage-7.10.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:78a384e49f46b80fb4c901d52d92abe098e78768ed829c673fbb53c498bef73a", size = 249941, upload-time = "2025-09-21T20:01:47.296Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e0/fd/ca2650443bfbef5b0e74373aac4df67b08180d2f184b482c41499668e258/coverage-7.10.7-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:5e1e9802121405ede4b0133aa4340ad8186a1d2526de5b7c3eca519db7bb89fb", size = 249519, upload-time = "2025-09-21T20:01:48.73Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/24/79/f692f125fb4299b6f963b0745124998ebb8e73ecdfce4ceceb06a8c6bec5/coverage-7.10.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d41213ea25a86f69efd1575073d34ea11aabe075604ddf3d148ecfec9e1e96a1", size = 251375, upload-time = "2025-09-21T20:01:50.529Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/5e/75/61b9bbd6c7d24d896bfeec57acba78e0f8deac68e6baf2d4804f7aae1f88/coverage-7.10.7-cp312-cp312-win32.whl", hash = "sha256:77eb4c747061a6af8d0f7bdb31f1e108d172762ef579166ec84542f711d90256", size = 220699, upload-time = "2025-09-21T20:01:51.941Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ca/f3/3bf7905288b45b075918d372498f1cf845b5b579b723c8fd17168018d5f5/coverage-7.10.7-cp312-cp312-win_amd64.whl", hash = "sha256:f51328ffe987aecf6d09f3cd9d979face89a617eacdaea43e7b3080777f647ba", size = 221512, upload-time = "2025-09-21T20:01:53.481Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/5c/44/3e32dbe933979d05cf2dac5e697c8599cfe038aaf51223ab901e208d5a62/coverage-7.10.7-cp312-cp312-win_arm64.whl", hash = "sha256:bda5e34f8a75721c96085903c6f2197dc398c20ffd98df33f866a9c8fd95f4bf", size = 220147, upload-time = "2025-09-21T20:01:55.2Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/9a/94/b765c1abcb613d103b64fcf10395f54d69b0ef8be6a0dd9c524384892cc7/coverage-7.10.7-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:981a651f543f2854abd3b5fcb3263aac581b18209be49863ba575de6edf4c14d", size = 218320, upload-time = "2025-09-21T20:01:56.629Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/72/4f/732fff31c119bb73b35236dd333030f32c4bfe909f445b423e6c7594f9a2/coverage-7.10.7-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:73ab1601f84dc804f7812dc297e93cd99381162da39c47040a827d4e8dafe63b", size = 218575, upload-time = "2025-09-21T20:01:58.203Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/87/02/ae7e0af4b674be47566707777db1aa375474f02a1d64b9323e5813a6cdd5/coverage-7.10.7-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:a8b6f03672aa6734e700bbcd65ff050fd19cddfec4b031cc8cf1c6967de5a68e", size = 249568, upload-time = "2025-09-21T20:01:59.748Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a2/77/8c6d22bf61921a59bce5471c2f1f7ac30cd4ac50aadde72b8c48d5727902/coverage-7.10.7-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:10b6ba00ab1132a0ce4428ff68cf50a25efd6840a42cdf4239c9b99aad83be8b", size = 252174, upload-time = "2025-09-21T20:02:01.192Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/b1/20/b6ea4f69bbb52dac0aebd62157ba6a9dddbfe664f5af8122dac296c3ee15/coverage-7.10.7-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c79124f70465a150e89340de5963f936ee97097d2ef76c869708c4248c63ca49", size = 253447, upload-time = "2025-09-21T20:02:02.701Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f9/28/4831523ba483a7f90f7b259d2018fef02cb4d5b90bc7c1505d6e5a84883c/coverage-7.10.7-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:69212fbccdbd5b0e39eac4067e20a4a5256609e209547d86f740d68ad4f04911", size = 249779, upload-time = "2025-09-21T20:02:04.185Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a7/9f/4331142bc98c10ca6436d2d620c3e165f31e6c58d43479985afce6f3191c/coverage-7.10.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:7ea7c6c9d0d286d04ed3541747e6597cbe4971f22648b68248f7ddcd329207f0", size = 251604, upload-time = "2025-09-21T20:02:06.034Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ce/60/bda83b96602036b77ecf34e6393a3836365481b69f7ed7079ab85048202b/coverage-7.10.7-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b9be91986841a75042b3e3243d0b3cb0b2434252b977baaf0cd56e960fe1e46f", size = 249497, upload-time = "2025-09-21T20:02:07.619Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/5f/af/152633ff35b2af63977edd835d8e6430f0caef27d171edf2fc76c270ef31/coverage-7.10.7-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:b281d5eca50189325cfe1f365fafade89b14b4a78d9b40b05ddd1fc7d2a10a9c", size = 249350, upload-time = "2025-09-21T20:02:10.34Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/9d/71/d92105d122bd21cebba877228990e1646d862e34a98bb3374d3fece5a794/coverage-7.10.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:99e4aa63097ab1118e75a848a28e40d68b08a5e19ce587891ab7fd04475e780f", size = 251111, upload-time = "2025-09-21T20:02:12.122Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a2/9e/9fdb08f4bf476c912f0c3ca292e019aab6712c93c9344a1653986c3fd305/coverage-7.10.7-cp313-cp313-win32.whl", hash = "sha256:dc7c389dce432500273eaf48f410b37886be9208b2dd5710aaf7c57fd442c698", size = 220746, upload-time = "2025-09-21T20:02:13.919Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/b1/b1/a75fd25df44eab52d1931e89980d1ada46824c7a3210be0d3c88a44aaa99/coverage-7.10.7-cp313-cp313-win_amd64.whl", hash = "sha256:cac0fdca17b036af3881a9d2729a850b76553f3f716ccb0360ad4dbc06b3b843", size = 221541, upload-time = "2025-09-21T20:02:15.57Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/14/3a/d720d7c989562a6e9a14b2c9f5f2876bdb38e9367126d118495b89c99c37/coverage-7.10.7-cp313-cp313-win_arm64.whl", hash = "sha256:4b6f236edf6e2f9ae8fcd1332da4e791c1b6ba0dc16a2dc94590ceccb482e546", size = 220170, upload-time = "2025-09-21T20:02:17.395Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/bb/22/e04514bf2a735d8b0add31d2b4ab636fc02370730787c576bb995390d2d5/coverage-7.10.7-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:a0ec07fd264d0745ee396b666d47cef20875f4ff2375d7c4f58235886cc1ef0c", size = 219029, upload-time = "2025-09-21T20:02:18.936Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/11/0b/91128e099035ece15da3445d9015e4b4153a6059403452d324cbb0a575fa/coverage-7.10.7-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:dd5e856ebb7bfb7672b0086846db5afb4567a7b9714b8a0ebafd211ec7ce6a15", size = 219259, upload-time = "2025-09-21T20:02:20.44Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/8b/51/66420081e72801536a091a0c8f8c1f88a5c4bf7b9b1bdc6222c7afe6dc9b/coverage-7.10.7-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:f57b2a3c8353d3e04acf75b3fed57ba41f5c0646bbf1d10c7c282291c97936b4", size = 260592, upload-time = "2025-09-21T20:02:22.313Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/5d/22/9b8d458c2881b22df3db5bb3e7369e63d527d986decb6c11a591ba2364f7/coverage-7.10.7-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1ef2319dd15a0b009667301a3f84452a4dc6fddfd06b0c5c53ea472d3989fbf0", size = 262768, upload-time = "2025-09-21T20:02:24.287Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f7/08/16bee2c433e60913c610ea200b276e8eeef084b0d200bdcff69920bd5828/coverage-7.10.7-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:83082a57783239717ceb0ad584de3c69cf581b2a95ed6bf81ea66034f00401c0", size = 264995, upload-time = "2025-09-21T20:02:26.133Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/20/9d/e53eb9771d154859b084b90201e5221bca7674ba449a17c101a5031d4054/coverage-7.10.7-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:50aa94fb1fb9a397eaa19c0d5ec15a5edd03a47bf1a3a6111a16b36e190cff65", size = 259546, upload-time = "2025-09-21T20:02:27.716Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ad/b0/69bc7050f8d4e56a89fb550a1577d5d0d1db2278106f6f626464067b3817/coverage-7.10.7-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:2120043f147bebb41c85b97ac45dd173595ff14f2a584f2963891cbcc3091541", size = 262544, upload-time = "2025-09-21T20:02:29.216Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ef/4b/2514b060dbd1bc0aaf23b852c14bb5818f244c664cb16517feff6bb3a5ab/coverage-7.10.7-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:2fafd773231dd0378fdba66d339f84904a8e57a262f583530f4f156ab83863e6", size = 260308, upload-time = "2025-09-21T20:02:31.226Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/54/78/7ba2175007c246d75e496f64c06e94122bdb914790a1285d627a918bd271/coverage-7.10.7-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:0b944ee8459f515f28b851728ad224fa2d068f1513ef6b7ff1efafeb2185f999", size = 258920, upload-time = "2025-09-21T20:02:32.823Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c0/b3/fac9f7abbc841409b9a410309d73bfa6cfb2e51c3fada738cb607ce174f8/coverage-7.10.7-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4b583b97ab2e3efe1b3e75248a9b333bd3f8b0b1b8e5b45578e05e5850dfb2c2", size = 261434, upload-time = "2025-09-21T20:02:34.86Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ee/51/a03bec00d37faaa891b3ff7387192cef20f01604e5283a5fabc95346befa/coverage-7.10.7-cp313-cp313t-win32.whl", hash = "sha256:2a78cd46550081a7909b3329e2266204d584866e8d97b898cd7fb5ac8d888b1a", size = 221403, upload-time = "2025-09-21T20:02:37.034Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/53/22/3cf25d614e64bf6d8e59c7c669b20d6d940bb337bdee5900b9ca41c820bb/coverage-7.10.7-cp313-cp313t-win_amd64.whl", hash = "sha256:33a5e6396ab684cb43dc7befa386258acb2d7fae7f67330ebb85ba4ea27938eb", size = 222469, upload-time = "2025-09-21T20:02:39.011Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/49/a1/00164f6d30d8a01c3c9c48418a7a5be394de5349b421b9ee019f380df2a0/coverage-7.10.7-cp313-cp313t-win_arm64.whl", hash = "sha256:86b0e7308289ddde73d863b7683f596d8d21c7d8664ce1dee061d0bcf3fbb4bb", size = 220731, upload-time = "2025-09-21T20:02:40.939Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/23/9c/5844ab4ca6a4dd97a1850e030a15ec7d292b5c5cb93082979225126e35dd/coverage-7.10.7-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:b06f260b16ead11643a5a9f955bd4b5fd76c1a4c6796aeade8520095b75de520", size = 218302, upload-time = "2025-09-21T20:02:42.527Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f0/89/673f6514b0961d1f0e20ddc242e9342f6da21eaba3489901b565c0689f34/coverage-7.10.7-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:212f8f2e0612778f09c55dd4872cb1f64a1f2b074393d139278ce902064d5b32", size = 218578, upload-time = "2025-09-21T20:02:44.468Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/05/e8/261cae479e85232828fb17ad536765c88dd818c8470aca690b0ac6feeaa3/coverage-7.10.7-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3445258bcded7d4aa630ab8296dea4d3f15a255588dd535f980c193ab6b95f3f", size = 249629, upload-time = "2025-09-21T20:02:46.503Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/82/62/14ed6546d0207e6eda876434e3e8475a3e9adbe32110ce896c9e0c06bb9a/coverage-7.10.7-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:bb45474711ba385c46a0bfe696c695a929ae69ac636cda8f532be9e8c93d720a", size = 252162, upload-time = "2025-09-21T20:02:48.689Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ff/49/07f00db9ac6478e4358165a08fb41b469a1b053212e8a00cb02f0d27a05f/coverage-7.10.7-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:813922f35bd800dca9994c5971883cbc0d291128a5de6b167c7aa697fcf59360", size = 253517, upload-time = "2025-09-21T20:02:50.31Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a2/59/c5201c62dbf165dfbc91460f6dbbaa85a8b82cfa6131ac45d6c1bfb52deb/coverage-7.10.7-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:93c1b03552081b2a4423091d6fb3787265b8f86af404cff98d1b5342713bdd69", size = 249632, upload-time = "2025-09-21T20:02:51.971Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/07/ae/5920097195291a51fb00b3a70b9bbd2edbfe3c84876a1762bd1ef1565ebc/coverage-7.10.7-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:cc87dd1b6eaf0b848eebb1c86469b9f72a1891cb42ac7adcfbce75eadb13dd14", size = 251520, upload-time = "2025-09-21T20:02:53.858Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/b9/3c/a815dde77a2981f5743a60b63df31cb322c944843e57dbd579326625a413/coverage-7.10.7-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:39508ffda4f343c35f3236fe8d1a6634a51f4581226a1262769d7f970e73bffe", size = 249455, upload-time = "2025-09-21T20:02:55.807Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/aa/99/f5cdd8421ea656abefb6c0ce92556709db2265c41e8f9fc6c8ae0f7824c9/coverage-7.10.7-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:925a1edf3d810537c5a3abe78ec5530160c5f9a26b1f4270b40e62cc79304a1e", size = 249287, upload-time = "2025-09-21T20:02:57.784Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c3/7a/e9a2da6a1fc5d007dd51fca083a663ab930a8c4d149c087732a5dbaa0029/coverage-7.10.7-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:2c8b9a0636f94c43cd3576811e05b89aa9bc2d0a85137affc544ae5cb0e4bfbd", size = 250946, upload-time = "2025-09-21T20:02:59.431Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ef/5b/0b5799aa30380a949005a353715095d6d1da81927d6dbed5def2200a4e25/coverage-7.10.7-cp314-cp314-win32.whl", hash = "sha256:b7b8288eb7cdd268b0304632da8cb0bb93fadcfec2fe5712f7b9cc8f4d487be2", size = 221009, upload-time = "2025-09-21T20:03:01.324Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/da/b0/e802fbb6eb746de006490abc9bb554b708918b6774b722bb3a0e6aa1b7de/coverage-7.10.7-cp314-cp314-win_amd64.whl", hash = "sha256:1ca6db7c8807fb9e755d0379ccc39017ce0a84dcd26d14b5a03b78563776f681", size = 221804, upload-time = "2025-09-21T20:03:03.4Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/9e/e8/71d0c8e374e31f39e3389bb0bd19e527d46f00ea8571ec7ec8fd261d8b44/coverage-7.10.7-cp314-cp314-win_arm64.whl", hash = "sha256:097c1591f5af4496226d5783d036bf6fd6cd0cbc132e071b33861de756efb880", size = 220384, upload-time = "2025-09-21T20:03:05.111Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/62/09/9a5608d319fa3eba7a2019addeacb8c746fb50872b57a724c9f79f146969/coverage-7.10.7-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:a62c6ef0d50e6de320c270ff91d9dd0a05e7250cac2a800b7784bae474506e63", size = 219047, upload-time = "2025-09-21T20:03:06.795Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f5/6f/f58d46f33db9f2e3647b2d0764704548c184e6f5e014bef528b7f979ef84/coverage-7.10.7-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:9fa6e4dd51fe15d8738708a973470f67a855ca50002294852e9571cdbd9433f2", size = 219266, upload-time = "2025-09-21T20:03:08.495Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/74/5c/183ffc817ba68e0b443b8c934c8795553eb0c14573813415bd59941ee165/coverage-7.10.7-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:8fb190658865565c549b6b4706856d6a7b09302c797eb2cf8e7fe9dabb043f0d", size = 260767, upload-time = "2025-09-21T20:03:10.172Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/0f/48/71a8abe9c1ad7e97548835e3cc1adbf361e743e9d60310c5f75c9e7bf847/coverage-7.10.7-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:affef7c76a9ef259187ef31599a9260330e0335a3011732c4b9effa01e1cd6e0", size = 262931, upload-time = "2025-09-21T20:03:11.861Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/84/fd/193a8fb132acfc0a901f72020e54be5e48021e1575bb327d8ee1097a28fd/coverage-7.10.7-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6e16e07d85ca0cf8bafe5f5d23a0b850064e8e945d5677492b06bbe6f09cc699", size = 265186, upload-time = "2025-09-21T20:03:13.539Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/b1/8f/74ecc30607dd95ad50e3034221113ccb1c6d4e8085cc761134782995daae/coverage-7.10.7-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:03ffc58aacdf65d2a82bbeb1ffe4d01ead4017a21bfd0454983b88ca73af94b9", size = 259470, upload-time = "2025-09-21T20:03:15.584Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/0f/55/79ff53a769f20d71b07023ea115c9167c0bb56f281320520cf64c5298a96/coverage-7.10.7-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:1b4fd784344d4e52647fd7857b2af5b3fbe6c239b0b5fa63e94eb67320770e0f", size = 262626, upload-time = "2025-09-21T20:03:17.673Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/88/e2/dac66c140009b61ac3fc13af673a574b00c16efdf04f9b5c740703e953c0/coverage-7.10.7-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:0ebbaddb2c19b71912c6f2518e791aa8b9f054985a0769bdb3a53ebbc765c6a1", size = 260386, upload-time = "2025-09-21T20:03:19.36Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a2/f1/f48f645e3f33bb9ca8a496bc4a9671b52f2f353146233ebd7c1df6160440/coverage-7.10.7-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:a2d9a3b260cc1d1dbdb1c582e63ddcf5363426a1a68faa0f5da28d8ee3c722a0", size = 258852, upload-time = "2025-09-21T20:03:21.007Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/bb/3b/8442618972c51a7affeead957995cfa8323c0c9bcf8fa5a027421f720ff4/coverage-7.10.7-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:a3cc8638b2480865eaa3926d192e64ce6c51e3d29c849e09d5b4ad95efae5399", size = 261534, upload-time = "2025-09-21T20:03:23.12Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/b2/dc/101f3fa3a45146db0cb03f5b4376e24c0aac818309da23e2de0c75295a91/coverage-7.10.7-cp314-cp314t-win32.whl", hash = "sha256:67f8c5cbcd3deb7a60b3345dffc89a961a484ed0af1f6f73de91705cc6e31235", size = 221784, upload-time = "2025-09-21T20:03:24.769Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/4c/a1/74c51803fc70a8a40d7346660379e144be772bab4ac7bb6e6b905152345c/coverage-7.10.7-cp314-cp314t-win_amd64.whl", hash = "sha256:e1ed71194ef6dea7ed2d5cb5f7243d4bcd334bfb63e59878519be558078f848d", size = 222905, upload-time = "2025-09-21T20:03:26.93Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/12/65/f116a6d2127df30bcafbceef0302d8a64ba87488bf6f73a6d8eebf060873/coverage-7.10.7-cp314-cp314t-win_arm64.whl", hash = "sha256:7fe650342addd8524ca63d77b2362b02345e5f1a093266787d210c70a50b471a", size = 220922, upload-time = "2025-09-21T20:03:28.672Z" }, - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ec/16/114df1c291c22cac3b0c127a73e0af5c12ed7bbb6558d310429a0ae24023/coverage-7.10.7-py3-none-any.whl", hash = "sha256:f7941f6f2fe6dd6807a1208737b8a0cbcf1cc6d7b07d24998ad2d63590868260", size = 209952, upload-time = "2025-09-21T20:03:53.918Z" }, -] - [[package]] name = "dm-tree" version = "0.1.9" @@ -202,6 +194,15 @@ wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c5/37/15603079854394f16e3833a7b50696c1f3cbf30a2243a119f64f18a16f36/dm_tree-0.1.9-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1f5d1e96b3a7de22b25b13a5eb30f41f8cf9c02dd4479a24920de99e780903c", size = 153052, upload-time = "2025-01-30T20:45:35.907Z" }, ] +[[package]] +name = "docstring-parser" +version = "0.17.0" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/b2/9d/c3b43da9515bd270df0f80548d9944e389870713cc1fe2b8fb35fe2bcefd/docstring_parser-0.17.0.tar.gz", hash = "sha256:583de4a309722b3315439bb31d64ba3eebada841f2e2cee23b99df001434c912", size = 27442, upload-time = "2025-07-21T07:35:01.868Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl", hash = "sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708", size = 36896, upload-time = "2025-07-21T07:35:00.684Z" }, +] + [[package]] name = "einops" version = "0.8.1" @@ -221,6 +222,14 @@ wheels = [ ] [package.optional-dependencies] +edc = [ + { name = "typing-extensions" }, +] +enp = [ + { name = "einops" }, + { name = "numpy" }, + { name = "typing-extensions" }, +] epath = [ { name = "fsspec" }, { name = "importlib-resources" }, @@ -230,6 +239,13 @@ epath = [ epy = [ { name = "typing-extensions" }, ] +etree = [ + { name = "absl-py" }, + { name = "einops" }, + { name = "numpy" }, + { name = "tqdm" }, + { name = "typing-extensions" }, +] [[package]] name = "fsspec" @@ -240,6 +256,18 @@ wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/47/71/70db47e4f6ce3e5c37a607355f80da8860a33226be640226ac52cb05ef2e/fsspec-2025.9.0-py3-none-any.whl", hash = "sha256:530dc2a2af60a414a832059574df4a6e10cce927f6f4a78209390fe38955cfb7", size = 199289, upload-time = "2025-09-02T19:10:47.708Z" }, ] +[[package]] +name = "googleapis-common-protos" +version = "1.72.0" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e5/7b/adfd75544c415c487b33061fe7ae526165241c1ea133f9a9125a56b39fd8/googleapis_common_protos-1.72.0.tar.gz", hash = "sha256:e55a601c1b32b52d7a3e65f43563e2aa61bcd737998ee672ac9b951cd49319f5", size = 147433, upload-time = "2025-11-06T18:29:24.087Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c4/ab/09169d5a4612a5f92490806649ac8d41e3ec9129c636754575b3553f4ea4/googleapis_common_protos-1.72.0-py3-none-any.whl", hash = "sha256:4299c5a82d5ae1a9702ada957347726b167f9f8d1fc352477702a1e851ff4038", size = 297515, upload-time = "2025-11-06T18:29:13.14Z" }, +] + [[package]] name = "grain" version = "0.2.12" @@ -274,6 +302,24 @@ wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/1e/c7/316e7ca04d26695ef0635dc81683d628350810eb8e9b2299fc08ba49f366/humanize-4.13.0-py3-none-any.whl", hash = "sha256:b810820b31891813b1673e8fec7f1ed3312061eab2f26e3fa192c393d11ed25f", size = 128869, upload-time = "2025-08-25T09:39:18.54Z" }, ] +[[package]] +name = "idna" +version = "3.11" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/6f/6d/0703ccc57f3a7233505399edb88de3cbd678da106337b9fcde432b65ed60/idna-3.11.tar.gz", hash = "sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902", size = 194582, upload-time = "2025-10-12T14:55:20.501Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" }, +] + +[[package]] +name = "immutabledict" +version = "4.2.2" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ce/12/1da8e1a9050d0603ba65fb1796ed8860a705b906701c96e77f85cc7490be/immutabledict-4.2.2.tar.gz", hash = "sha256:cb6ed3090df593148f94cb407d218ca526fd2639694afdb553dc4f50ce6feeca", size = 6099, upload-time = "2025-10-12T13:32:59.755Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/63/7b/04ab6afa1ff7eb9ccb09049918c0407b205f5009092c0416147d163e4e2b/immutabledict-4.2.2-py3-none-any.whl", hash = "sha256:97c31d098a2c850e93a958badeef765e4736ed7942ec73e439facd764a3a7217", size = 4736, upload-time = "2025-10-12T13:32:58.326Z" }, +] + [[package]] name = "importlib-resources" version = "6.5.2" @@ -308,6 +354,13 @@ wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/d5/e6/5fd0f6fff79eb47469ff9c4fa27125b661517a2fbf8884689b02e9fdfaa8/jax-0.7.2-py3-none-any.whl", hash = "sha256:e7e32f9be51ae5cc6854225958c57de8cca2187d279844338465b15e8a1fe7f2", size = 2835570, upload-time = "2025-09-16T16:48:51.33Z" }, ] +[package.optional-dependencies] +tpu = [ + { name = "jaxlib" }, + { name = "libtpu" }, + { name = "requests" }, +] + [[package]] name = "jaxlib" version = "0.7.2" @@ -340,11 +393,12 @@ wheels = [ [[package]] name = "julax" -version = "0.0.3.dev0" +version = "0.0.4.dev0" source = { editable = "." } dependencies = [ { name = "einops" }, { name = "grain" }, + { name = "humanize" }, { name = "jax" }, { name = "optax" }, { name = "orbax-checkpoint" }, @@ -353,24 +407,53 @@ dependencies = [ ] [package.optional-dependencies] +tpu = [ + { name = "jax", extra = ["tpu"] }, +] + +[package.dev-dependencies] dev = [ + { name = "pyarrow" }, { name = "pytest" }, - { name = "pytest-cov" }, + { name = "safetensors" }, + { name = "tensorflow-datasets" }, + { name = "tiktoken" }, ] [package.metadata] requires-dist = [ { name = "einops", specifier = ">=0.8.1" }, { name = "grain", specifier = ">=0.2.12" }, + { name = "humanize", specifier = ">=4.13.0" }, { name = "jax", specifier = ">=0.7.2" }, + { name = "jax", extras = ["tpu"], marker = "extra == 'tpu'", specifier = ">=0.7.2" }, { name = "optax", specifier = ">=0.2.6" }, { name = "orbax-checkpoint", specifier = ">=0.11.25" }, { name = "plum-dispatch", specifier = ">=2.5.8" }, { name = "pydantic", specifier = ">=2.12.0" }, - { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.3.2" }, - { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=5.0.0" }, ] -provides-extras = ["dev"] +provides-extras = ["tpu"] + +[package.metadata.requires-dev] +dev = [ + { name = "pyarrow", specifier = ">=22.0.0" }, + { name = "pytest", specifier = ">=8.4.2" }, + { name = "safetensors", specifier = ">=0.7.0" }, + { name = "tensorflow-datasets", specifier = ">=4.9.9" }, + { name = "tiktoken", specifier = ">=0.12.0" }, +] + +[[package]] +name = "libtpu" +version = "0.0.23" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/41/69/e5c2a6fc622eec8aed70cd86441bb18e60fd4946dcbec2cf12923cf6a290/libtpu-0.0.23-cp312-cp312-manylinux_2_31_x86_64.whl", hash = "sha256:e633423474d8dfec61ee6b282a89f4470172bf6d317cfe6e54ef9da221125074", size = 155127197, upload-time = "2025-09-12T21:19:11.424Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/81/de/a1abb348faa4cab7fd07c0762328c5c7aaa4e4f98413c6fe3724d2217fa5/libtpu-0.0.23-cp313-cp313-manylinux_2_31_x86_64.whl", hash = "sha256:5c8b2a7c98afdfe88d712356023bbe594253b468b31aa2be8a88848a1f560a04", size = 155127356, upload-time = "2025-09-12T21:18:33.791Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/de/10/3b52120060de97384055ca42915ab452494bd04440d58e1b0512c11fa068/libtpu-0.0.23-cp313-cp313t-manylinux_2_31_x86_64.whl", hash = "sha256:d6e875ba793907169de2cbdb8c0832508bb4e7e210c0dcd07b9f3a8484657c1b", size = 155127879, upload-time = "2025-09-12T21:18:45.123Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/31/b1/27322ddf7bfc6854d6703b60369d81916d8f568791e4512e7f7e227e441e/libtpu-0.0.23-cp314-cp314-manylinux_2_31_x86_64.whl", hash = "sha256:7e327e5c02c677ece76374ae680567282f7d5e5aa7fe513c49a3bacc9099af60", size = 155127561, upload-time = "2025-09-12T21:18:53.779Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/d1/c7/1ea05fd6b72dfdd60718470474a36fb325b28eb740fd93c37d7c38b92e0b/libtpu-0.0.23-cp314-cp314t-manylinux_2_31_x86_64.whl", hash = "sha256:f1206abe1805690cdeeb27e827a93fd09f803cf05cbd8108e36ffa506935e815", size = 155127695, upload-time = "2025-09-12T21:19:19.669Z" }, +] [[package]] name = "markdown-it-py" @@ -631,6 +714,15 @@ wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/d1/c1/8ccc8ba81154fb9c29c62032a1aa5e2f56045d1446a4605a249daf433974/plum_dispatch-2.5.8-py3-none-any.whl", hash = "sha256:02c6561718e83b5599c863d8c2bb4a64d8e852ac84ec09e49043145c3f48313a", size = 42061, upload-time = "2025-10-07T17:54:22.953Z" }, ] +[[package]] +name = "promise" +version = "2.3" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/cf/9c/fb5d48abfe5d791cd496e4242ebcf87a4bb2e0c3dcd6e0ae68c11426a528/promise-2.3.tar.gz", hash = "sha256:dfd18337c523ba4b6a58801c164c1904a9d4d1b1747c7d5dbf45b693a49d93d0", size = 19534, upload-time = "2019-12-18T07:31:43.07Z" } + [[package]] name = "protobuf" version = "6.32.1" @@ -645,6 +737,75 @@ wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/97/b7/15cc7d93443d6c6a84626ae3258a91f4c6ac8c0edd5df35ea7658f71b79c/protobuf-6.32.1-py3-none-any.whl", hash = "sha256:2601b779fc7d32a866c6b4404f9d42a3f67c5b9f3f15b4db3cccabe06b95c346", size = 169289, upload-time = "2025-09-11T21:38:41.234Z" }, ] +[[package]] +name = "psutil" +version = "7.1.3" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e1/88/bdd0a41e5857d5d703287598cbf08dad90aed56774ea52ae071bae9071b6/psutil-7.1.3.tar.gz", hash = "sha256:6c86281738d77335af7aec228328e944b30930899ea760ecf33a4dba66be5e74", size = 489059, upload-time = "2025-11-02T12:25:54.619Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/bd/93/0c49e776b8734fef56ec9c5c57f923922f2cf0497d62e0f419465f28f3d0/psutil-7.1.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0005da714eee687b4b8decd3d6cc7c6db36215c9e74e5ad2264b90c3df7d92dc", size = 239751, upload-time = "2025-11-02T12:25:58.161Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/6f/8d/b31e39c769e70780f007969815195a55c81a63efebdd4dbe9e7a113adb2f/psutil-7.1.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:19644c85dcb987e35eeeaefdc3915d059dac7bd1167cdcdbf27e0ce2df0c08c0", size = 240368, upload-time = "2025-11-02T12:26:00.491Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/62/61/23fd4acc3c9eebbf6b6c78bcd89e5d020cfde4acf0a9233e9d4e3fa698b4/psutil-7.1.3-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:95ef04cf2e5ba0ab9eaafc4a11eaae91b44f4ef5541acd2ee91d9108d00d59a7", size = 287134, upload-time = "2025-11-02T12:26:02.613Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/30/1c/f921a009ea9ceb51aa355cb0cc118f68d354db36eae18174bab63affb3e6/psutil-7.1.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1068c303be3a72f8e18e412c5b2a8f6d31750fb152f9cb106b54090296c9d251", size = 289904, upload-time = "2025-11-02T12:26:05.207Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a6/82/62d68066e13e46a5116df187d319d1724b3f437ddd0f958756fc052677f4/psutil-7.1.3-cp313-cp313t-win_amd64.whl", hash = "sha256:18349c5c24b06ac5612c0428ec2a0331c26443d259e2a0144a9b24b4395b58fa", size = 249642, upload-time = "2025-11-02T12:26:07.447Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/df/ad/c1cd5fe965c14a0392112f68362cfceb5230819dbb5b1888950d18a11d9f/psutil-7.1.3-cp313-cp313t-win_arm64.whl", hash = "sha256:c525ffa774fe4496282fb0b1187725793de3e7c6b29e41562733cae9ada151ee", size = 245518, upload-time = "2025-11-02T12:26:09.719Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/2e/bb/6670bded3e3236eb4287c7bcdc167e9fae6e1e9286e437f7111caed2f909/psutil-7.1.3-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b403da1df4d6d43973dc004d19cee3b848e998ae3154cc8097d139b77156c353", size = 239843, upload-time = "2025-11-02T12:26:11.968Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/b8/66/853d50e75a38c9a7370ddbeefabdd3d3116b9c31ef94dc92c6729bc36bec/psutil-7.1.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:ad81425efc5e75da3f39b3e636293360ad8d0b49bed7df824c79764fb4ba9b8b", size = 240369, upload-time = "2025-11-02T12:26:14.358Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/41/bd/313aba97cb5bfb26916dc29cf0646cbe4dd6a89ca69e8c6edce654876d39/psutil-7.1.3-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8f33a3702e167783a9213db10ad29650ebf383946e91bc77f28a5eb083496bc9", size = 288210, upload-time = "2025-11-02T12:26:16.699Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c2/fa/76e3c06e760927a0cfb5705eb38164254de34e9bd86db656d4dbaa228b04/psutil-7.1.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fac9cd332c67f4422504297889da5ab7e05fd11e3c4392140f7370f4208ded1f", size = 291182, upload-time = "2025-11-02T12:26:18.848Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/0f/1d/5774a91607035ee5078b8fd747686ebec28a962f178712de100d00b78a32/psutil-7.1.3-cp314-cp314t-win_amd64.whl", hash = "sha256:3792983e23b69843aea49c8f5b8f115572c5ab64c153bada5270086a2123c7e7", size = 250466, upload-time = "2025-11-02T12:26:21.183Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/00/ca/e426584bacb43a5cb1ac91fae1937f478cd8fbe5e4ff96574e698a2c77cd/psutil-7.1.3-cp314-cp314t-win_arm64.whl", hash = "sha256:31d77fcedb7529f27bb3a0472bea9334349f9a04160e8e6e5020f22c59893264", size = 245756, upload-time = "2025-11-02T12:26:23.148Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ef/94/46b9154a800253e7ecff5aaacdf8ebf43db99de4a2dfa18575b02548654e/psutil-7.1.3-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:2bdbcd0e58ca14996a42adf3621a6244f1bb2e2e528886959c72cf1e326677ab", size = 238359, upload-time = "2025-11-02T12:26:25.284Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/68/3a/9f93cff5c025029a36d9a92fef47220ab4692ee7f2be0fba9f92813d0cb8/psutil-7.1.3-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:bc31fa00f1fbc3c3802141eede66f3a2d51d89716a194bf2cd6fc68310a19880", size = 239171, upload-time = "2025-11-02T12:26:27.23Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ce/b1/5f49af514f76431ba4eea935b8ad3725cdeb397e9245ab919dbc1d1dc20f/psutil-7.1.3-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3bb428f9f05c1225a558f53e30ccbad9930b11c3fc206836242de1091d3e7dd3", size = 263261, upload-time = "2025-11-02T12:26:29.48Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e0/95/992c8816a74016eb095e73585d747e0a8ea21a061ed3689474fabb29a395/psutil-7.1.3-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:56d974e02ca2c8eb4812c3f76c30e28836fffc311d55d979f1465c1feeb2b68b", size = 264635, upload-time = "2025-11-02T12:26:31.74Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/55/4c/c3ed1a622b6ae2fd3c945a366e64eb35247a31e4db16cf5095e269e8eb3c/psutil-7.1.3-cp37-abi3-win_amd64.whl", hash = "sha256:f39c2c19fe824b47484b96f9692932248a54c43799a84282cfe58d05a6449efd", size = 247633, upload-time = "2025-11-02T12:26:33.887Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c9/ad/33b2ccec09bf96c2b2ef3f9a6f66baac8253d7565d8839e024a6b905d45d/psutil-7.1.3-cp37-abi3-win_arm64.whl", hash = "sha256:bd0d69cee829226a761e92f28140bec9a5ee9d5b4fb4b0cc589068dbfff559b1", size = 244608, upload-time = "2025-11-02T12:26:36.136Z" }, +] + +[[package]] +name = "pyarrow" +version = "22.0.0" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/30/53/04a7fdc63e6056116c9ddc8b43bc28c12cdd181b85cbeadb79278475f3ae/pyarrow-22.0.0.tar.gz", hash = "sha256:3d600dc583260d845c7d8a6db540339dd883081925da2bd1c5cb808f720b3cd9", size = 1151151, upload-time = "2025-10-24T12:30:00.762Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/af/63/ba23862d69652f85b615ca14ad14f3bcfc5bf1b99ef3f0cd04ff93fdad5a/pyarrow-22.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:bea79263d55c24a32b0d79c00a1c58bb2ee5f0757ed95656b01c0fb310c5af3d", size = 34211578, upload-time = "2025-10-24T10:05:21.583Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/b1/d0/f9ad86fe809efd2bcc8be32032fa72e8b0d112b01ae56a053006376c5930/pyarrow-22.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:12fe549c9b10ac98c91cf791d2945e878875d95508e1a5d14091a7aaa66d9cf8", size = 35989906, upload-time = "2025-10-24T10:05:29.485Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/b4/a8/f910afcb14630e64d673f15904ec27dd31f1e009b77033c365c84e8c1e1d/pyarrow-22.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:334f900ff08ce0423407af97e6c26ad5d4e3b0763645559ece6fbf3747d6a8f5", size = 45021677, upload-time = "2025-10-24T10:05:38.274Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/13/95/aec81f781c75cd10554dc17a25849c720d54feafb6f7847690478dcf5ef8/pyarrow-22.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:c6c791b09c57ed76a18b03f2631753a4960eefbbca80f846da8baefc6491fcfe", size = 47726315, upload-time = "2025-10-24T10:05:47.314Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/bb/d4/74ac9f7a54cfde12ee42734ea25d5a3c9a45db78f9def949307a92720d37/pyarrow-22.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:c3200cb41cdbc65156e5f8c908d739b0dfed57e890329413da2748d1a2cd1a4e", size = 47990906, upload-time = "2025-10-24T10:05:58.254Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/2e/71/fedf2499bf7a95062eafc989ace56572f3343432570e1c54e6599d5b88da/pyarrow-22.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ac93252226cf288753d8b46280f4edf3433bf9508b6977f8dd8526b521a1bbb9", size = 50306783, upload-time = "2025-10-24T10:06:08.08Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/68/ed/b202abd5a5b78f519722f3d29063dda03c114711093c1995a33b8e2e0f4b/pyarrow-22.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:44729980b6c50a5f2bfcc2668d36c569ce17f8b17bccaf470c4313dcbbf13c9d", size = 27972883, upload-time = "2025-10-24T10:06:14.204Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a6/d6/d0fac16a2963002fc22c8fa75180a838737203d558f0ed3b564c4a54eef5/pyarrow-22.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:e6e95176209257803a8b3d0394f21604e796dadb643d2f7ca21b66c9c0b30c9a", size = 34204629, upload-time = "2025-10-24T10:06:20.274Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c6/9c/1d6357347fbae062ad3f17082f9ebc29cc733321e892c0d2085f42a2212b/pyarrow-22.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:001ea83a58024818826a9e3f89bf9310a114f7e26dfe404a4c32686f97bd7901", size = 35985783, upload-time = "2025-10-24T10:06:27.301Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ff/c0/782344c2ce58afbea010150df07e3a2f5fdad299cd631697ae7bd3bac6e3/pyarrow-22.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:ce20fe000754f477c8a9125543f1936ea5b8867c5406757c224d745ed033e691", size = 45020999, upload-time = "2025-10-24T10:06:35.387Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/1b/8b/5362443737a5307a7b67c1017c42cd104213189b4970bf607e05faf9c525/pyarrow-22.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:e0a15757fccb38c410947df156f9749ae4a3c89b2393741a50521f39a8cf202a", size = 47724601, upload-time = "2025-10-24T10:06:43.551Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/69/4d/76e567a4fc2e190ee6072967cb4672b7d9249ac59ae65af2d7e3047afa3b/pyarrow-22.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:cedb9dd9358e4ea1d9bce3665ce0797f6adf97ff142c8e25b46ba9cdd508e9b6", size = 48001050, upload-time = "2025-10-24T10:06:52.284Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/01/5e/5653f0535d2a1aef8223cee9d92944cb6bccfee5cf1cd3f462d7cb022790/pyarrow-22.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:252be4a05f9d9185bb8c18e83764ebcfea7185076c07a7a662253af3a8c07941", size = 50307877, upload-time = "2025-10-24T10:07:02.405Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/2d/f8/1d0bd75bf9328a3b826e24a16e5517cd7f9fbf8d34a3184a4566ef5a7f29/pyarrow-22.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:a4893d31e5ef780b6edcaf63122df0f8d321088bb0dee4c8c06eccb1ca28d145", size = 27977099, upload-time = "2025-10-24T10:08:07.259Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/90/81/db56870c997805bf2b0f6eeeb2d68458bf4654652dccdcf1bf7a42d80903/pyarrow-22.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:f7fe3dbe871294ba70d789be16b6e7e52b418311e166e0e3cba9522f0f437fb1", size = 34336685, upload-time = "2025-10-24T10:07:11.47Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/1c/98/0727947f199aba8a120f47dfc229eeb05df15bcd7a6f1b669e9f882afc58/pyarrow-22.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:ba95112d15fd4f1105fb2402c4eab9068f0554435e9b7085924bcfaac2cc306f", size = 36032158, upload-time = "2025-10-24T10:07:18.626Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/96/b4/9babdef9c01720a0785945c7cf550e4acd0ebcd7bdd2e6f0aa7981fa85e2/pyarrow-22.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:c064e28361c05d72eed8e744c9605cbd6d2bb7481a511c74071fd9b24bc65d7d", size = 44892060, upload-time = "2025-10-24T10:07:26.002Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f8/ca/2f8804edd6279f78a37062d813de3f16f29183874447ef6d1aadbb4efa0f/pyarrow-22.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:6f9762274496c244d951c819348afbcf212714902742225f649cf02823a6a10f", size = 47504395, upload-time = "2025-10-24T10:07:34.09Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/b9/f0/77aa5198fd3943682b2e4faaf179a674f0edea0d55d326d83cb2277d9363/pyarrow-22.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:a9d9ffdc2ab696f6b15b4d1f7cec6658e1d788124418cb30030afbae31c64746", size = 48066216, upload-time = "2025-10-24T10:07:43.528Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/79/87/a1937b6e78b2aff18b706d738c9e46ade5bfcf11b294e39c87706a0089ac/pyarrow-22.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ec1a15968a9d80da01e1d30349b2b0d7cc91e96588ee324ce1b5228175043e95", size = 50288552, upload-time = "2025-10-24T10:07:53.519Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/60/ae/b5a5811e11f25788ccfdaa8f26b6791c9807119dffcf80514505527c384c/pyarrow-22.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:bba208d9c7decf9961998edf5c65e3ea4355d5818dd6cd0f6809bec1afb951cc", size = 28262504, upload-time = "2025-10-24T10:08:00.932Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/bd/b0/0fa4d28a8edb42b0a7144edd20befd04173ac79819547216f8a9f36f9e50/pyarrow-22.0.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:9bddc2cade6561f6820d4cd73f99a0243532ad506bc510a75a5a65a522b2d74d", size = 34224062, upload-time = "2025-10-24T10:08:14.101Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/0f/a8/7a719076b3c1be0acef56a07220c586f25cd24de0e3f3102b438d18ae5df/pyarrow-22.0.0-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:e70ff90c64419709d38c8932ea9fe1cc98415c4f87ea8da81719e43f02534bc9", size = 35990057, upload-time = "2025-10-24T10:08:21.842Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/89/3c/359ed54c93b47fb6fe30ed16cdf50e3f0e8b9ccfb11b86218c3619ae50a8/pyarrow-22.0.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:92843c305330aa94a36e706c16209cd4df274693e777ca47112617db7d0ef3d7", size = 45068002, upload-time = "2025-10-24T10:08:29.034Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/55/fc/4945896cc8638536ee787a3bd6ce7cec8ec9acf452d78ec39ab328efa0a1/pyarrow-22.0.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:6dda1ddac033d27421c20d7a7943eec60be44e0db4e079f33cc5af3b8280ccde", size = 47737765, upload-time = "2025-10-24T10:08:38.559Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/cd/5e/7cb7edeb2abfaa1f79b5d5eb89432356155c8426f75d3753cbcb9592c0fd/pyarrow-22.0.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:84378110dd9a6c06323b41b56e129c504d157d1a983ce8f5443761eb5256bafc", size = 48048139, upload-time = "2025-10-24T10:08:46.784Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/88/c6/546baa7c48185f5e9d6e59277c4b19f30f48c94d9dd938c2a80d4d6b067c/pyarrow-22.0.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:854794239111d2b88b40b6ef92aa478024d1e5074f364033e73e21e3f76b25e0", size = 50314244, upload-time = "2025-10-24T10:08:55.771Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/3c/79/755ff2d145aafec8d347bf18f95e4e81c00127f06d080135dfc86aea417c/pyarrow-22.0.0-cp314-cp314-win_amd64.whl", hash = "sha256:b883fe6fd85adad7932b3271c38ac289c65b7337c2c132e9569f9d3940620730", size = 28757501, upload-time = "2025-10-24T10:09:59.891Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/0e/d2/237d75ac28ced3147912954e3c1a174df43a95f4f88e467809118a8165e0/pyarrow-22.0.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:7a820d8ae11facf32585507c11f04e3f38343c1e784c9b5a8b1da5c930547fe2", size = 34355506, upload-time = "2025-10-24T10:09:02.953Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/1e/2c/733dfffe6d3069740f98e57ff81007809067d68626c5faef293434d11bd6/pyarrow-22.0.0-cp314-cp314t-macosx_12_0_x86_64.whl", hash = "sha256:c6ec3675d98915bf1ec8b3c7986422682f7232ea76cad276f4c8abd5b7319b70", size = 36047312, upload-time = "2025-10-24T10:09:10.334Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/7c/2b/29d6e3782dc1f299727462c1543af357a0f2c1d3c160ce199950d9ca51eb/pyarrow-22.0.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:3e739edd001b04f654b166204fc7a9de896cf6007eaff33409ee9e50ceaff754", size = 45081609, upload-time = "2025-10-24T10:09:18.61Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/8d/42/aa9355ecc05997915af1b7b947a7f66c02dcaa927f3203b87871c114ba10/pyarrow-22.0.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:7388ac685cab5b279a41dfe0a6ccd99e4dbf322edfb63e02fc0443bf24134e91", size = 47703663, upload-time = "2025-10-24T10:09:27.369Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ee/62/45abedde480168e83a1de005b7b7043fd553321c1e8c5a9a114425f64842/pyarrow-22.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:f633074f36dbc33d5c05b5dc75371e5660f1dbf9c8b1d95669def05e5425989c", size = 48066543, upload-time = "2025-10-24T10:09:34.908Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/84/e9/7878940a5b072e4f3bf998770acafeae13b267f9893af5f6d4ab3904b67e/pyarrow-22.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:4c19236ae2402a8663a2c8f21f1870a03cc57f0bef7e4b6eb3238cc82944de80", size = 50288838, upload-time = "2025-10-24T10:09:44.394Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/7b/03/f335d6c52b4a4761bcc83499789a1e2e16d9d201a58c327a9b5cc9a41bd9/pyarrow-22.0.0-cp314-cp314t-win_amd64.whl", hash = "sha256:0c34fe18094686194f204a3b1787a27456897d8a2d62caf84b61e8dfbc0252ae", size = 29185594, upload-time = "2025-10-24T10:09:53.111Z" }, +] + [[package]] name = "pydantic" version = "2.12.0" @@ -748,20 +909,6 @@ wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, ] -[[package]] -name = "pytest-cov" -version = "7.0.0" -source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } -dependencies = [ - { name = "coverage" }, - { name = "pluggy" }, - { name = "pytest" }, -] -sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/5e/f7/c933acc76f5208b3b00089573cf6a2bc26dc80a8aece8f52bb7d6b1855ca/pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1", size = 54328, upload-time = "2025-09-09T10:57:02.113Z" } -wheels = [ - { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, -] - [[package]] name = "pyyaml" version = "6.0.3" @@ -808,6 +955,99 @@ wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" }, ] +[[package]] +name = "regex" +version = "2025.11.3" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/cc/a9/546676f25e573a4cf00fe8e119b78a37b6a8fe2dc95cda877b30889c9c45/regex-2025.11.3.tar.gz", hash = "sha256:1fedc720f9bb2494ce31a58a1631f9c82df6a09b49c19517ea5cc280b4541e01", size = 414669, upload-time = "2025-11-03T21:34:22.089Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e8/74/18f04cb53e58e3fb107439699bd8375cf5a835eec81084e0bddbd122e4c2/regex-2025.11.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:bc8ab71e2e31b16e40868a40a69007bc305e1109bd4658eb6cad007e0bf67c41", size = 489312, upload-time = "2025-11-03T21:31:34.343Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/78/3f/37fcdd0d2b1e78909108a876580485ea37c91e1acf66d3bb8e736348f441/regex-2025.11.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:22b29dda7e1f7062a52359fca6e58e548e28c6686f205e780b02ad8ef710de36", size = 291256, upload-time = "2025-11-03T21:31:35.675Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/bf/26/0a575f58eb23b7ebd67a45fccbc02ac030b737b896b7e7a909ffe43ffd6a/regex-2025.11.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3a91e4a29938bc1a082cc28fdea44be420bf2bebe2665343029723892eb073e1", size = 288921, upload-time = "2025-11-03T21:31:37.07Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ea/98/6a8dff667d1af907150432cf5abc05a17ccd32c72a3615410d5365ac167a/regex-2025.11.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:08b884f4226602ad40c5d55f52bf91a9df30f513864e0054bad40c0e9cf1afb7", size = 798568, upload-time = "2025-11-03T21:31:38.784Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/64/15/92c1db4fa4e12733dd5a526c2dd2b6edcbfe13257e135fc0f6c57f34c173/regex-2025.11.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3e0b11b2b2433d1c39c7c7a30e3f3d0aeeea44c2a8d0bae28f6b95f639927a69", size = 864165, upload-time = "2025-11-03T21:31:40.559Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f9/e7/3ad7da8cdee1ce66c7cd37ab5ab05c463a86ffeb52b1a25fe7bd9293b36c/regex-2025.11.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:87eb52a81ef58c7ba4d45c3ca74e12aa4b4e77816f72ca25258a85b3ea96cb48", size = 912182, upload-time = "2025-11-03T21:31:42.002Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/84/bd/9ce9f629fcb714ffc2c3faf62b6766ecb7a585e1e885eb699bcf130a5209/regex-2025.11.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a12ab1f5c29b4e93db518f5e3872116b7e9b1646c9f9f426f777b50d44a09e8c", size = 803501, upload-time = "2025-11-03T21:31:43.815Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/7c/0f/8dc2e4349d8e877283e6edd6c12bdcebc20f03744e86f197ab6e4492bf08/regex-2025.11.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7521684c8c7c4f6e88e35ec89680ee1aa8358d3f09d27dfbdf62c446f5d4c695", size = 787842, upload-time = "2025-11-03T21:31:45.353Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f9/73/cff02702960bc185164d5619c0c62a2f598a6abff6695d391b096237d4ab/regex-2025.11.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:7fe6e5440584e94cc4b3f5f4d98a25e29ca12dccf8873679a635638349831b98", size = 858519, upload-time = "2025-11-03T21:31:46.814Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/61/83/0e8d1ae71e15bc1dc36231c90b46ee35f9d52fab2e226b0e039e7ea9c10a/regex-2025.11.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:8e026094aa12b43f4fd74576714e987803a315c76edb6b098b9809db5de58f74", size = 850611, upload-time = "2025-11-03T21:31:48.289Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c8/f5/70a5cdd781dcfaa12556f2955bf170cd603cb1c96a1827479f8faea2df97/regex-2025.11.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:435bbad13e57eb5606a68443af62bed3556de2f46deb9f7d4237bc2f1c9fb3a0", size = 789759, upload-time = "2025-11-03T21:31:49.759Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/59/9b/7c29be7903c318488983e7d97abcf8ebd3830e4c956c4c540005fcfb0462/regex-2025.11.3-cp312-cp312-win32.whl", hash = "sha256:3839967cf4dc4b985e1570fd8d91078f0c519f30491c60f9ac42a8db039be204", size = 266194, upload-time = "2025-11-03T21:31:51.53Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/1a/67/3b92df89f179d7c367be654ab5626ae311cb28f7d5c237b6bb976cd5fbbb/regex-2025.11.3-cp312-cp312-win_amd64.whl", hash = "sha256:e721d1b46e25c481dc5ded6f4b3f66c897c58d2e8cfdf77bbced84339108b0b9", size = 277069, upload-time = "2025-11-03T21:31:53.151Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/d7/55/85ba4c066fe5094d35b249c3ce8df0ba623cfd35afb22d6764f23a52a1c5/regex-2025.11.3-cp312-cp312-win_arm64.whl", hash = "sha256:64350685ff08b1d3a6fff33f45a9ca183dc1d58bbfe4981604e70ec9801bbc26", size = 270330, upload-time = "2025-11-03T21:31:54.514Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e1/a7/dda24ebd49da46a197436ad96378f17df30ceb40e52e859fc42cac45b850/regex-2025.11.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:c1e448051717a334891f2b9a620fe36776ebf3dd8ec46a0b877c8ae69575feb4", size = 489081, upload-time = "2025-11-03T21:31:55.9Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/19/22/af2dc751aacf88089836aa088a1a11c4f21a04707eb1b0478e8e8fb32847/regex-2025.11.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:9b5aca4d5dfd7fbfbfbdaf44850fcc7709a01146a797536a8f84952e940cca76", size = 291123, upload-time = "2025-11-03T21:31:57.758Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a3/88/1a3ea5672f4b0a84802ee9891b86743438e7c04eb0b8f8c4e16a42375327/regex-2025.11.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:04d2765516395cf7dda331a244a3282c0f5ae96075f728629287dfa6f76ba70a", size = 288814, upload-time = "2025-11-03T21:32:01.12Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/fb/8c/f5987895bf42b8ddeea1b315c9fedcfe07cadee28b9c98cf50d00adcb14d/regex-2025.11.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5d9903ca42bfeec4cebedba8022a7c97ad2aab22e09573ce9976ba01b65e4361", size = 798592, upload-time = "2025-11-03T21:32:03.006Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/99/2a/6591ebeede78203fa77ee46a1c36649e02df9eaa77a033d1ccdf2fcd5d4e/regex-2025.11.3-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:639431bdc89d6429f6721625e8129413980ccd62e9d3f496be618a41d205f160", size = 864122, upload-time = "2025-11-03T21:32:04.553Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/94/d6/be32a87cf28cf8ed064ff281cfbd49aefd90242a83e4b08b5a86b38e8eb4/regex-2025.11.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:f117efad42068f9715677c8523ed2be1518116d1c49b1dd17987716695181efe", size = 912272, upload-time = "2025-11-03T21:32:06.148Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/62/11/9bcef2d1445665b180ac7f230406ad80671f0fc2a6ffb93493b5dd8cd64c/regex-2025.11.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4aecb6f461316adf9f1f0f6a4a1a3d79e045f9b71ec76055a791affa3b285850", size = 803497, upload-time = "2025-11-03T21:32:08.162Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e5/a7/da0dc273d57f560399aa16d8a68ae7f9b57679476fc7ace46501d455fe84/regex-2025.11.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:3b3a5f320136873cc5561098dfab677eea139521cb9a9e8db98b7e64aef44cbc", size = 787892, upload-time = "2025-11-03T21:32:09.769Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/da/4b/732a0c5a9736a0b8d6d720d4945a2f1e6f38f87f48f3173559f53e8d5d82/regex-2025.11.3-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:75fa6f0056e7efb1f42a1c34e58be24072cb9e61a601340cc1196ae92326a4f9", size = 858462, upload-time = "2025-11-03T21:32:11.769Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/0c/f5/a2a03df27dc4c2d0c769220f5110ba8c4084b0bfa9ab0f9b4fcfa3d2b0fc/regex-2025.11.3-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:dbe6095001465294f13f1adcd3311e50dd84e5a71525f20a10bd16689c61ce0b", size = 850528, upload-time = "2025-11-03T21:32:13.906Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/d6/09/e1cd5bee3841c7f6eb37d95ca91cdee7100b8f88b81e41c2ef426910891a/regex-2025.11.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:454d9b4ae7881afbc25015b8627c16d88a597479b9dea82b8c6e7e2e07240dc7", size = 789866, upload-time = "2025-11-03T21:32:15.748Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/eb/51/702f5ea74e2a9c13d855a6a85b7f80c30f9e72a95493260193c07f3f8d74/regex-2025.11.3-cp313-cp313-win32.whl", hash = "sha256:28ba4d69171fc6e9896337d4fc63a43660002b7da53fc15ac992abcf3410917c", size = 266189, upload-time = "2025-11-03T21:32:17.493Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/8b/00/6e29bb314e271a743170e53649db0fdb8e8ff0b64b4f425f5602f4eb9014/regex-2025.11.3-cp313-cp313-win_amd64.whl", hash = "sha256:bac4200befe50c670c405dc33af26dad5a3b6b255dd6c000d92fe4629f9ed6a5", size = 277054, upload-time = "2025-11-03T21:32:19.042Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/25/f1/b156ff9f2ec9ac441710764dda95e4edaf5f36aca48246d1eea3f1fd96ec/regex-2025.11.3-cp313-cp313-win_arm64.whl", hash = "sha256:2292cd5a90dab247f9abe892ac584cb24f0f54680c73fcb4a7493c66c2bf2467", size = 270325, upload-time = "2025-11-03T21:32:21.338Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/20/28/fd0c63357caefe5680b8ea052131acbd7f456893b69cc2a90cc3e0dc90d4/regex-2025.11.3-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:1eb1ebf6822b756c723e09f5186473d93236c06c579d2cc0671a722d2ab14281", size = 491984, upload-time = "2025-11-03T21:32:23.466Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/df/ec/7014c15626ab46b902b3bcc4b28a7bae46d8f281fc7ea9c95e22fcaaa917/regex-2025.11.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:1e00ec2970aab10dc5db34af535f21fcf32b4a31d99e34963419636e2f85ae39", size = 292673, upload-time = "2025-11-03T21:32:25.034Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/23/ab/3b952ff7239f20d05f1f99e9e20188513905f218c81d52fb5e78d2bf7634/regex-2025.11.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:a4cb042b615245d5ff9b3794f56be4138b5adc35a4166014d31d1814744148c7", size = 291029, upload-time = "2025-11-03T21:32:26.528Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/21/7e/3dc2749fc684f455f162dcafb8a187b559e2614f3826877d3844a131f37b/regex-2025.11.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:44f264d4bf02f3176467d90b294d59bf1db9fe53c141ff772f27a8b456b2a9ed", size = 807437, upload-time = "2025-11-03T21:32:28.363Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/1b/0b/d529a85ab349c6a25d1ca783235b6e3eedf187247eab536797021f7126c6/regex-2025.11.3-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7be0277469bf3bd7a34a9c57c1b6a724532a0d235cd0dc4e7f4316f982c28b19", size = 873368, upload-time = "2025-11-03T21:32:30.4Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/7d/18/2d868155f8c9e3e9d8f9e10c64e9a9f496bb8f7e037a88a8bed26b435af6/regex-2025.11.3-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:0d31e08426ff4b5b650f68839f5af51a92a5b51abd8554a60c2fbc7c71f25d0b", size = 914921, upload-time = "2025-11-03T21:32:32.123Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/2d/71/9d72ff0f354fa783fe2ba913c8734c3b433b86406117a8db4ea2bf1c7a2f/regex-2025.11.3-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e43586ce5bd28f9f285a6e729466841368c4a0353f6fd08d4ce4630843d3648a", size = 812708, upload-time = "2025-11-03T21:32:34.305Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e7/19/ce4bf7f5575c97f82b6e804ffb5c4e940c62609ab2a0d9538d47a7fdf7d4/regex-2025.11.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:0f9397d561a4c16829d4e6ff75202c1c08b68a3bdbfe29dbfcdb31c9830907c6", size = 795472, upload-time = "2025-11-03T21:32:36.364Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/03/86/fd1063a176ffb7b2315f9a1b08d17b18118b28d9df163132615b835a26ee/regex-2025.11.3-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:dd16e78eb18ffdb25ee33a0682d17912e8cc8a770e885aeee95020046128f1ce", size = 868341, upload-time = "2025-11-03T21:32:38.042Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/12/43/103fb2e9811205e7386366501bc866a164a0430c79dd59eac886a2822950/regex-2025.11.3-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:ffcca5b9efe948ba0661e9df0fa50d2bc4b097c70b9810212d6b62f05d83b2dd", size = 854666, upload-time = "2025-11-03T21:32:40.079Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/7d/22/e392e53f3869b75804762c7c848bd2dd2abf2b70fb0e526f58724638bd35/regex-2025.11.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:c56b4d162ca2b43318ac671c65bd4d563e841a694ac70e1a976ac38fcf4ca1d2", size = 799473, upload-time = "2025-11-03T21:32:42.148Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/4f/f9/8bd6b656592f925b6845fcbb4d57603a3ac2fb2373344ffa1ed70aa6820a/regex-2025.11.3-cp313-cp313t-win32.whl", hash = "sha256:9ddc42e68114e161e51e272f667d640f97e84a2b9ef14b7477c53aac20c2d59a", size = 268792, upload-time = "2025-11-03T21:32:44.13Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e5/87/0e7d603467775ff65cd2aeabf1b5b50cc1c3708556a8b849a2fa4dd1542b/regex-2025.11.3-cp313-cp313t-win_amd64.whl", hash = "sha256:7a7c7fdf755032ffdd72c77e3d8096bdcb0eb92e89e17571a196f03d88b11b3c", size = 280214, upload-time = "2025-11-03T21:32:45.853Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/8d/d0/2afc6f8e94e2b64bfb738a7c2b6387ac1699f09f032d363ed9447fd2bb57/regex-2025.11.3-cp313-cp313t-win_arm64.whl", hash = "sha256:df9eb838c44f570283712e7cff14c16329a9f0fb19ca492d21d4b7528ee6821e", size = 271469, upload-time = "2025-11-03T21:32:48.026Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/31/e9/f6e13de7e0983837f7b6d238ad9458800a874bf37c264f7923e63409944c/regex-2025.11.3-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:9697a52e57576c83139d7c6f213d64485d3df5bf84807c35fa409e6c970801c6", size = 489089, upload-time = "2025-11-03T21:32:50.027Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a3/5c/261f4a262f1fa65141c1b74b255988bd2fa020cc599e53b080667d591cfc/regex-2025.11.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:e18bc3f73bd41243c9b38a6d9f2366cd0e0137a9aebe2d8ff76c5b67d4c0a3f4", size = 291059, upload-time = "2025-11-03T21:32:51.682Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/8e/57/f14eeb7f072b0e9a5a090d1712741fd8f214ec193dba773cf5410108bb7d/regex-2025.11.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:61a08bcb0ec14ff4e0ed2044aad948d0659604f824cbd50b55e30b0ec6f09c73", size = 288900, upload-time = "2025-11-03T21:32:53.569Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/3c/6b/1d650c45e99a9b327586739d926a1cd4e94666b1bd4af90428b36af66dc7/regex-2025.11.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c9c30003b9347c24bcc210958c5d167b9e4f9be786cb380a7d32f14f9b84674f", size = 799010, upload-time = "2025-11-03T21:32:55.222Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/99/ee/d66dcbc6b628ce4e3f7f0cbbb84603aa2fc0ffc878babc857726b8aab2e9/regex-2025.11.3-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4e1e592789704459900728d88d41a46fe3969b82ab62945560a31732ffc19a6d", size = 864893, upload-time = "2025-11-03T21:32:57.239Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/bf/2d/f238229f1caba7ac87a6c4153d79947fb0261415827ae0f77c304260c7d3/regex-2025.11.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:6538241f45eb5a25aa575dbba1069ad786f68a4f2773a29a2bd3dd1f9de787be", size = 911522, upload-time = "2025-11-03T21:32:59.274Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/bd/3d/22a4eaba214a917c80e04f6025d26143690f0419511e0116508e24b11c9b/regex-2025.11.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bce22519c989bb72a7e6b36a199384c53db7722fe669ba891da75907fe3587db", size = 803272, upload-time = "2025-11-03T21:33:01.393Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/84/b1/03188f634a409353a84b5ef49754b97dbcc0c0f6fd6c8ede505a8960a0a4/regex-2025.11.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:66d559b21d3640203ab9075797a55165d79017520685fb407b9234d72ab63c62", size = 787958, upload-time = "2025-11-03T21:33:03.379Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/99/6a/27d072f7fbf6fadd59c64d210305e1ff865cc3b78b526fd147db768c553b/regex-2025.11.3-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:669dcfb2e38f9e8c69507bace46f4889e3abbfd9b0c29719202883c0a603598f", size = 859289, upload-time = "2025-11-03T21:33:05.374Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/9a/70/1b3878f648e0b6abe023172dacb02157e685564853cc363d9961bcccde4e/regex-2025.11.3-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:32f74f35ff0f25a5021373ac61442edcb150731fbaa28286bbc8bb1582c89d02", size = 850026, upload-time = "2025-11-03T21:33:07.131Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/dd/d5/68e25559b526b8baab8e66839304ede68ff6727237a47727d240006bd0ff/regex-2025.11.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e6c7a21dffba883234baefe91bc3388e629779582038f75d2a5be918e250f0ed", size = 789499, upload-time = "2025-11-03T21:33:09.141Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/fc/df/43971264857140a350910d4e33df725e8c94dd9dee8d2e4729fa0d63d49e/regex-2025.11.3-cp314-cp314-win32.whl", hash = "sha256:795ea137b1d809eb6836b43748b12634291c0ed55ad50a7d72d21edf1cd565c4", size = 271604, upload-time = "2025-11-03T21:33:10.9Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/01/6f/9711b57dc6894a55faf80a4c1b5aa4f8649805cb9c7aef46f7d27e2b9206/regex-2025.11.3-cp314-cp314-win_amd64.whl", hash = "sha256:9f95fbaa0ee1610ec0fc6b26668e9917a582ba80c52cc6d9ada15e30aa9ab9ad", size = 280320, upload-time = "2025-11-03T21:33:12.572Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f1/7e/f6eaa207d4377481f5e1775cdeb5a443b5a59b392d0065f3417d31d80f87/regex-2025.11.3-cp314-cp314-win_arm64.whl", hash = "sha256:dfec44d532be4c07088c3de2876130ff0fbeeacaa89a137decbbb5f665855a0f", size = 273372, upload-time = "2025-11-03T21:33:14.219Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c3/06/49b198550ee0f5e4184271cee87ba4dfd9692c91ec55289e6282f0f86ccf/regex-2025.11.3-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:ba0d8a5d7f04f73ee7d01d974d47c5834f8a1b0224390e4fe7c12a3a92a78ecc", size = 491985, upload-time = "2025-11-03T21:33:16.555Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ce/bf/abdafade008f0b1c9da10d934034cb670432d6cf6cbe38bbb53a1cfd6cf8/regex-2025.11.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:442d86cf1cfe4faabf97db7d901ef58347efd004934da045c745e7b5bd57ac49", size = 292669, upload-time = "2025-11-03T21:33:18.32Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f9/ef/0c357bb8edbd2ad8e273fcb9e1761bc37b8acbc6e1be050bebd6475f19c1/regex-2025.11.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:fd0a5e563c756de210bb964789b5abe4f114dacae9104a47e1a649b910361536", size = 291030, upload-time = "2025-11-03T21:33:20.048Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/79/06/edbb67257596649b8fb088d6aeacbcb248ac195714b18a65e018bf4c0b50/regex-2025.11.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bf3490bcbb985a1ae97b2ce9ad1c0f06a852d5b19dde9b07bdf25bf224248c95", size = 807674, upload-time = "2025-11-03T21:33:21.797Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f4/d9/ad4deccfce0ea336296bd087f1a191543bb99ee1c53093dcd4c64d951d00/regex-2025.11.3-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3809988f0a8b8c9dcc0f92478d6501fac7200b9ec56aecf0ec21f4a2ec4b6009", size = 873451, upload-time = "2025-11-03T21:33:23.741Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/13/75/a55a4724c56ef13e3e04acaab29df26582f6978c000ac9cd6810ad1f341f/regex-2025.11.3-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:f4ff94e58e84aedb9c9fce66d4ef9f27a190285b451420f297c9a09f2b9abee9", size = 914980, upload-time = "2025-11-03T21:33:25.999Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/67/1e/a1657ee15bd9116f70d4a530c736983eed997b361e20ecd8f5ca3759d5c5/regex-2025.11.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7eb542fd347ce61e1321b0a6b945d5701528dca0cd9759c2e3bb8bd57e47964d", size = 812852, upload-time = "2025-11-03T21:33:27.852Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/b8/6f/f7516dde5506a588a561d296b2d0044839de06035bb486b326065b4c101e/regex-2025.11.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:d6c2d5919075a1f2e413c00b056ea0c2f065b3f5fe83c3d07d325ab92dce51d6", size = 795566, upload-time = "2025-11-03T21:33:32.364Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/d9/dd/3d10b9e170cc16fb34cb2cef91513cf3df65f440b3366030631b2984a264/regex-2025.11.3-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:3f8bf11a4827cc7ce5a53d4ef6cddd5ad25595d3c1435ef08f76825851343154", size = 868463, upload-time = "2025-11-03T21:33:34.459Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f5/8e/935e6beff1695aa9085ff83195daccd72acc82c81793df480f34569330de/regex-2025.11.3-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:22c12d837298651e5550ac1d964e4ff57c3f56965fc1812c90c9fb2028eaf267", size = 854694, upload-time = "2025-11-03T21:33:36.793Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/92/12/10650181a040978b2f5720a6a74d44f841371a3d984c2083fc1752e4acf6/regex-2025.11.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:62ba394a3dda9ad41c7c780f60f6e4a70988741415ae96f6d1bf6c239cf01379", size = 799691, upload-time = "2025-11-03T21:33:39.079Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/67/90/8f37138181c9a7690e7e4cb388debbd389342db3c7381d636d2875940752/regex-2025.11.3-cp314-cp314t-win32.whl", hash = "sha256:4bf146dca15cdd53224a1bf46d628bd7590e4a07fbb69e720d561aea43a32b38", size = 274583, upload-time = "2025-11-03T21:33:41.302Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/8f/cd/867f5ec442d56beb56f5f854f40abcfc75e11d10b11fdb1869dd39c63aaf/regex-2025.11.3-cp314-cp314t-win_amd64.whl", hash = "sha256:adad1a1bcf1c9e76346e091d22d23ac54ef28e1365117d99521631078dfec9de", size = 284286, upload-time = "2025-11-03T21:33:43.324Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/20/31/32c0c4610cbc070362bf1d2e4ea86d1ea29014d400a6d6c2486fcfd57766/regex-2025.11.3-cp314-cp314t-win_arm64.whl", hash = "sha256:c54f768482cef41e219720013cd05933b6f971d9562544d691c68699bf2b6801", size = 274741, upload-time = "2025-11-03T21:33:45.557Z" }, +] + +[[package]] +name = "requests" +version = "2.32.5" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, +] + [[package]] name = "rich" version = "14.1.0" @@ -821,6 +1061,28 @@ wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e3/30/3c4d035596d3cf444529e0b2953ad0466f6049528a879d27534700580395/rich-14.1.0-py3-none-any.whl", hash = "sha256:536f5f1785986d6dbdea3c75205c473f970777b4a0d6c6dd1b696aa05a3fa04f", size = 243368, upload-time = "2025-07-25T07:32:56.73Z" }, ] +[[package]] +name = "safetensors" +version = "0.7.0" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/29/9c/6e74567782559a63bd040a236edca26fd71bc7ba88de2ef35d75df3bca5e/safetensors-0.7.0.tar.gz", hash = "sha256:07663963b67e8bd9f0b8ad15bb9163606cd27cc5a1b96235a50d8369803b96b0", size = 200878, upload-time = "2025-11-19T15:18:43.199Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/fa/47/aef6c06649039accf914afef490268e1067ed82be62bcfa5b7e886ad15e8/safetensors-0.7.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c82f4d474cf725255d9e6acf17252991c3c8aac038d6ef363a4bf8be2f6db517", size = 467781, upload-time = "2025-11-19T15:18:35.84Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e8/00/374c0c068e30cd31f1e1b46b4b5738168ec79e7689ca82ee93ddfea05109/safetensors-0.7.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:94fd4858284736bb67a897a41608b5b0c2496c9bdb3bf2af1fa3409127f20d57", size = 447058, upload-time = "2025-11-19T15:18:34.416Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f1/06/578ffed52c2296f93d7fd2d844cabfa92be51a587c38c8afbb8ae449ca89/safetensors-0.7.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e07d91d0c92a31200f25351f4acb2bc6aff7f48094e13ebb1d0fb995b54b6542", size = 491748, upload-time = "2025-11-19T15:18:09.79Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ae/33/1debbbb70e4791dde185edb9413d1fe01619255abb64b300157d7f15dddd/safetensors-0.7.0-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8469155f4cb518bafb4acf4865e8bb9d6804110d2d9bdcaa78564b9fd841e104", size = 503881, upload-time = "2025-11-19T15:18:16.145Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/8e/1c/40c2ca924d60792c3be509833df711b553c60effbd91da6f5284a83f7122/safetensors-0.7.0-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:54bef08bf00a2bff599982f6b08e8770e09cc012d7bba00783fc7ea38f1fb37d", size = 623463, upload-time = "2025-11-19T15:18:21.11Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/9b/3a/13784a9364bd43b0d61eef4bea2845039bc2030458b16594a1bd787ae26e/safetensors-0.7.0-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:42cb091236206bb2016d245c377ed383aa7f78691748f3bb6ee1bfa51ae2ce6a", size = 532855, upload-time = "2025-11-19T15:18:25.719Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a0/60/429e9b1cb3fc651937727befe258ea24122d9663e4d5709a48c9cbfceecb/safetensors-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac7252938f0696ddea46f5e855dd3138444e82236e3be475f54929f0c510d48", size = 507152, upload-time = "2025-11-19T15:18:33.023Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/3c/a8/4b45e4e059270d17af60359713ffd83f97900d45a6afa73aaa0d737d48b6/safetensors-0.7.0-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1d060c70284127fa805085d8f10fbd0962792aed71879d00864acda69dbab981", size = 541856, upload-time = "2025-11-19T15:18:31.075Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/06/87/d26d8407c44175d8ae164a95b5a62707fcc445f3c0c56108e37d98070a3d/safetensors-0.7.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:cdab83a366799fa730f90a4ebb563e494f28e9e92c4819e556152ad55e43591b", size = 674060, upload-time = "2025-11-19T15:18:37.211Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/11/f5/57644a2ff08dc6325816ba7217e5095f17269dada2554b658442c66aed51/safetensors-0.7.0-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:672132907fcad9f2aedcb705b2d7b3b93354a2aec1b2f706c4db852abe338f85", size = 771715, upload-time = "2025-11-19T15:18:38.689Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/86/31/17883e13a814bd278ae6e266b13282a01049b0c81341da7fd0e3e71a80a3/safetensors-0.7.0-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:5d72abdb8a4d56d4020713724ba81dac065fedb7f3667151c4a637f1d3fb26c0", size = 714377, upload-time = "2025-11-19T15:18:40.162Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/4a/d8/0c8a7dc9b41dcac53c4cbf9df2b9c83e0e0097203de8b37a712b345c0be5/safetensors-0.7.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b0f6d66c1c538d5a94a73aa9ddca8ccc4227e6c9ff555322ea40bdd142391dd4", size = 677368, upload-time = "2025-11-19T15:18:41.627Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/05/e5/cb4b713c8a93469e3c5be7c3f8d77d307e65fe89673e731f5c2bfd0a9237/safetensors-0.7.0-cp38-abi3-win32.whl", hash = "sha256:c74af94bf3ac15ac4d0f2a7c7b4663a15f8c2ab15ed0fc7531ca61d0835eccba", size = 326423, upload-time = "2025-11-19T15:18:45.74Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/5d/e6/ec8471c8072382cb91233ba7267fd931219753bb43814cbc71757bfd4dab/safetensors-0.7.0-cp38-abi3-win_amd64.whl", hash = "sha256:d1239932053f56f3456f32eb9625590cc7582e905021f94636202a864d470755", size = 341380, upload-time = "2025-11-19T15:18:44.427Z" }, +] + [[package]] name = "scipy" version = "1.16.2" @@ -882,6 +1144,19 @@ wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/97/30/2f9a5243008f76dfc5dee9a53dfb939d9b31e16ce4bd4f2e628bfc5d89d2/scipy-1.16.2-cp314-cp314t-win_arm64.whl", hash = "sha256:d2a4472c231328d4de38d5f1f68fdd6d28a615138f842580a8a321b5845cf779", size = 26448374, upload-time = "2025-09-11T17:45:03.45Z" }, ] +[[package]] +name = "simple-parsing" +version = "0.1.7" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +dependencies = [ + { name = "docstring-parser" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/eb/c5/f1e2fcb3a81085cdf3cfed48b8c8ce0e7cc30c95dee734cbb35d6265336a/simple_parsing-0.1.7.tar.gz", hash = "sha256:225e6b35252d68f7894716101fe3bd7e6dd3d30ab7b1c3c023f77a42dbe1336f", size = 96375, upload-time = "2025-01-20T19:46:35.986Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/4f/9c/e9ea38750027a6de3e3c5e68a19fda0e7b0cd3db8045f30d0f6bc113b911/simple_parsing-0.1.7-py3-none-any.whl", hash = "sha256:5276e6c90c157362dd0173d1eecebe58361a66b457129cc9bba13b78a4e85092", size = 112782, upload-time = "2025-01-20T19:46:33.325Z" }, +] + [[package]] name = "simplejson" version = "3.20.2" @@ -917,6 +1192,56 @@ wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/05/5b/83e1ff87eb60ca706972f7e02e15c0b33396e7bdbd080069a5d1b53cf0d8/simplejson-3.20.2-py3-none-any.whl", hash = "sha256:3b6bb7fb96efd673eac2e4235200bfffdc2353ad12c54117e1e4e2fc485ac017", size = 57309, upload-time = "2025-09-26T16:29:35.312Z" }, ] +[[package]] +name = "six" +version = "1.17.0" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031, upload-time = "2024-12-04T17:35:28.174Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, +] + +[[package]] +name = "tensorflow-datasets" +version = "4.9.9" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "array-record", marker = "sys_platform == 'linux'" }, + { name = "dm-tree" }, + { name = "etils", extra = ["edc", "enp", "epath", "epy", "etree"] }, + { name = "immutabledict" }, + { name = "numpy" }, + { name = "promise" }, + { name = "protobuf" }, + { name = "psutil" }, + { name = "pyarrow" }, + { name = "requests" }, + { name = "simple-parsing" }, + { name = "tensorflow-metadata" }, + { name = "termcolor" }, + { name = "toml" }, + { name = "tqdm" }, + { name = "wrapt" }, +] +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c9/92/a436764aeea5aa0c85774770afdc6063b1016dd38b67e39c5b6240cf1deb/tensorflow_datasets-4.9.9.tar.gz", hash = "sha256:9cb245cad97e7d227f0b8e006491cfef860ff8d4b9d84a3c68f8b96d6295355e", size = 3943946, upload-time = "2025-05-28T13:38:17.691Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/16/e0/657192dbc03636532ccbd5c90669d31a65187365b99ba685db36bb31dd67/tensorflow_datasets-4.9.9-py3-none-any.whl", hash = "sha256:b94902d414cdc12a1014cda9ee5815c502c3d44215b780e06dacbd7949abd14e", size = 5319309, upload-time = "2025-05-28T13:38:15.693Z" }, +] + +[[package]] +name = "tensorflow-metadata" +version = "1.17.2" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "googleapis-common-protos" }, + { name = "protobuf" }, +] +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ec/bf/443fc7c325e59b9016f710bc7f6b0574e73eca7b789a6a6a2054ca77e887/tensorflow_metadata-1.17.2-py3-none-any.whl", hash = "sha256:7da3f8501d6ccfcdbe1a56e975c3624150ce6829048ab9efe62409c362d509b4", size = 31536, upload-time = "2025-06-24T18:11:31.326Z" }, +] + [[package]] name = "tensorstore" version = "0.1.78" @@ -939,6 +1264,71 @@ wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/db/a2/dbd1af0e97d5d549051309d72c6e3f2fe81fae636f9db3692d21adc9c731/tensorstore-0.1.78-cp313-cp313-win_amd64.whl", hash = "sha256:e0073de8fa3074bc4cc92ced0210310fd89851899faf42a5ba256f0ba87d095c", size = 12711250, upload-time = "2025-10-06T17:44:27.926Z" }, ] +[[package]] +name = "termcolor" +version = "3.2.0" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/87/56/ab275c2b56a5e2342568838f0d5e3e66a32354adcc159b495e374cda43f5/termcolor-3.2.0.tar.gz", hash = "sha256:610e6456feec42c4bcd28934a8c87a06c3fa28b01561d46aa09a9881b8622c58", size = 14423, upload-time = "2025-10-25T19:11:42.586Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f9/d5/141f53d7c1eb2a80e6d3e9a390228c3222c27705cbe7f048d3623053f3ca/termcolor-3.2.0-py3-none-any.whl", hash = "sha256:a10343879eba4da819353c55cb8049b0933890c2ebf9ad5d3ecd2bb32ea96ea6", size = 7698, upload-time = "2025-10-25T19:11:41.536Z" }, +] + +[[package]] +name = "tiktoken" +version = "0.12.0" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +dependencies = [ + { name = "regex" }, + { name = "requests" }, +] +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/7d/ab/4d017d0f76ec3171d469d80fc03dfbb4e48a4bcaddaa831b31d526f05edc/tiktoken-0.12.0.tar.gz", hash = "sha256:b18ba7ee2b093863978fcb14f74b3707cdc8d4d4d3836853ce7ec60772139931", size = 37806, upload-time = "2025-10-06T20:22:45.419Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a4/85/be65d39d6b647c79800fd9d29241d081d4eeb06271f383bb87200d74cf76/tiktoken-0.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b97f74aca0d78a1ff21b8cd9e9925714c15a9236d6ceacf5c7327c117e6e21e8", size = 1050728, upload-time = "2025-10-06T20:21:52.756Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/4a/42/6573e9129bc55c9bf7300b3a35bef2c6b9117018acca0dc760ac2d93dffe/tiktoken-0.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2b90f5ad190a4bb7c3eb30c5fa32e1e182ca1ca79f05e49b448438c3e225a49b", size = 994049, upload-time = "2025-10-06T20:21:53.782Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/66/c5/ed88504d2f4a5fd6856990b230b56d85a777feab84e6129af0822f5d0f70/tiktoken-0.12.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:65b26c7a780e2139e73acc193e5c63ac754021f160df919add909c1492c0fb37", size = 1129008, upload-time = "2025-10-06T20:21:54.832Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f4/90/3dae6cc5436137ebd38944d396b5849e167896fc2073da643a49f372dc4f/tiktoken-0.12.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:edde1ec917dfd21c1f2f8046b86348b0f54a2c0547f68149d8600859598769ad", size = 1152665, upload-time = "2025-10-06T20:21:56.129Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a3/fe/26df24ce53ffde419a42f5f53d755b995c9318908288c17ec3f3448313a3/tiktoken-0.12.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:35a2f8ddd3824608b3d650a000c1ef71f730d0c56486845705a8248da00f9fe5", size = 1194230, upload-time = "2025-10-06T20:21:57.546Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/20/cc/b064cae1a0e9fac84b0d2c46b89f4e57051a5f41324e385d10225a984c24/tiktoken-0.12.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83d16643edb7fa2c99eff2ab7733508aae1eebb03d5dfc46f5565862810f24e3", size = 1254688, upload-time = "2025-10-06T20:21:58.619Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/81/10/b8523105c590c5b8349f2587e2fdfe51a69544bd5a76295fc20f2374f470/tiktoken-0.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffc5288f34a8bc02e1ea7047b8d041104791d2ddbf42d1e5fa07822cbffe16bd", size = 878694, upload-time = "2025-10-06T20:21:59.876Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/00/61/441588ee21e6b5cdf59d6870f86beb9789e532ee9718c251b391b70c68d6/tiktoken-0.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:775c2c55de2310cc1bc9a3ad8826761cbdc87770e586fd7b6da7d4589e13dab3", size = 1050802, upload-time = "2025-10-06T20:22:00.96Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/1f/05/dcf94486d5c5c8d34496abe271ac76c5b785507c8eae71b3708f1ad9b45a/tiktoken-0.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a01b12f69052fbe4b080a2cfb867c4de12c704b56178edf1d1d7b273561db160", size = 993995, upload-time = "2025-10-06T20:22:02.788Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a0/70/5163fe5359b943f8db9946b62f19be2305de8c3d78a16f629d4165e2f40e/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:01d99484dc93b129cd0964f9d34eee953f2737301f18b3c7257bf368d7615baa", size = 1128948, upload-time = "2025-10-06T20:22:03.814Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/0c/da/c028aa0babf77315e1cef357d4d768800c5f8a6de04d0eac0f377cb619fa/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:4a1a4fcd021f022bfc81904a911d3df0f6543b9e7627b51411da75ff2fe7a1be", size = 1151986, upload-time = "2025-10-06T20:22:05.173Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a0/5a/886b108b766aa53e295f7216b509be95eb7d60b166049ce2c58416b25f2a/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:981a81e39812d57031efdc9ec59fa32b2a5a5524d20d4776574c4b4bd2e9014a", size = 1194222, upload-time = "2025-10-06T20:22:06.265Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f4/f8/4db272048397636ac7a078d22773dd2795b1becee7bc4922fe6207288d57/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9baf52f84a3f42eef3ff4e754a0db79a13a27921b457ca9832cf944c6be4f8f3", size = 1255097, upload-time = "2025-10-06T20:22:07.403Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/8e/32/45d02e2e0ea2be3a9ed22afc47d93741247e75018aac967b713b2941f8ea/tiktoken-0.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:b8a0cd0c789a61f31bf44851defbd609e8dd1e2c8589c614cc1060940ef1f697", size = 879117, upload-time = "2025-10-06T20:22:08.418Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ce/76/994fc868f88e016e6d05b0da5ac24582a14c47893f4474c3e9744283f1d5/tiktoken-0.12.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d5f89ea5680066b68bcb797ae85219c72916c922ef0fcdd3480c7d2315ffff16", size = 1050309, upload-time = "2025-10-06T20:22:10.939Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f6/b8/57ef1456504c43a849821920d582a738a461b76a047f352f18c0b26c6516/tiktoken-0.12.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b4e7ed1c6a7a8a60a3230965bdedba8cc58f68926b835e519341413370e0399a", size = 993712, upload-time = "2025-10-06T20:22:12.115Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/72/90/13da56f664286ffbae9dbcfadcc625439142675845baa62715e49b87b68b/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:fc530a28591a2d74bce821d10b418b26a094bf33839e69042a6e86ddb7a7fb27", size = 1128725, upload-time = "2025-10-06T20:22:13.541Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/05/df/4f80030d44682235bdaecd7346c90f67ae87ec8f3df4a3442cb53834f7e4/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:06a9f4f49884139013b138920a4c393aa6556b2f8f536345f11819389c703ebb", size = 1151875, upload-time = "2025-10-06T20:22:14.559Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/22/1f/ae535223a8c4ef4c0c1192e3f9b82da660be9eb66b9279e95c99288e9dab/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:04f0e6a985d95913cabc96a741c5ffec525a2c72e9df086ff17ebe35985c800e", size = 1194451, upload-time = "2025-10-06T20:22:15.545Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/78/a7/f8ead382fce0243cb625c4f266e66c27f65ae65ee9e77f59ea1653b6d730/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:0ee8f9ae00c41770b5f9b0bb1235474768884ae157de3beb5439ca0fd70f3e25", size = 1253794, upload-time = "2025-10-06T20:22:16.624Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/93/e0/6cc82a562bc6365785a3ff0af27a2a092d57c47d7a81d9e2295d8c36f011/tiktoken-0.12.0-cp313-cp313t-win_amd64.whl", hash = "sha256:dc2dd125a62cb2b3d858484d6c614d136b5b848976794edfb63688d539b8b93f", size = 878777, upload-time = "2025-10-06T20:22:18.036Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/72/05/3abc1db5d2c9aadc4d2c76fa5640134e475e58d9fbb82b5c535dc0de9b01/tiktoken-0.12.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:a90388128df3b3abeb2bfd1895b0681412a8d7dc644142519e6f0a97c2111646", size = 1050188, upload-time = "2025-10-06T20:22:19.563Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e3/7b/50c2f060412202d6c95f32b20755c7a6273543b125c0985d6fa9465105af/tiktoken-0.12.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:da900aa0ad52247d8794e307d6446bd3cdea8e192769b56276695d34d2c9aa88", size = 993978, upload-time = "2025-10-06T20:22:20.702Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/14/27/bf795595a2b897e271771cd31cb847d479073497344c637966bdf2853da1/tiktoken-0.12.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:285ba9d73ea0d6171e7f9407039a290ca77efcdb026be7769dccc01d2c8d7fff", size = 1129271, upload-time = "2025-10-06T20:22:22.06Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f5/de/9341a6d7a8f1b448573bbf3425fa57669ac58258a667eb48a25dfe916d70/tiktoken-0.12.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:d186a5c60c6a0213f04a7a802264083dea1bbde92a2d4c7069e1a56630aef830", size = 1151216, upload-time = "2025-10-06T20:22:23.085Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/75/0d/881866647b8d1be4d67cb24e50d0c26f9f807f994aa1510cb9ba2fe5f612/tiktoken-0.12.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:604831189bd05480f2b885ecd2d1986dc7686f609de48208ebbbddeea071fc0b", size = 1194860, upload-time = "2025-10-06T20:22:24.602Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/b3/1e/b651ec3059474dab649b8d5b69f5c65cd8fcd8918568c1935bd4136c9392/tiktoken-0.12.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:8f317e8530bb3a222547b85a58583238c8f74fd7a7408305f9f63246d1a0958b", size = 1254567, upload-time = "2025-10-06T20:22:25.671Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/80/57/ce64fd16ac390fafde001268c364d559447ba09b509181b2808622420eec/tiktoken-0.12.0-cp314-cp314-win_amd64.whl", hash = "sha256:399c3dd672a6406719d84442299a490420b458c44d3ae65516302a99675888f3", size = 921067, upload-time = "2025-10-06T20:22:26.753Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ac/a4/72eed53e8976a099539cdd5eb36f241987212c29629d0a52c305173e0a68/tiktoken-0.12.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:c2c714c72bc00a38ca969dae79e8266ddec999c7ceccd603cc4f0d04ccd76365", size = 1050473, upload-time = "2025-10-06T20:22:27.775Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/e6/d7/0110b8f54c008466b19672c615f2168896b83706a6611ba6e47313dbc6e9/tiktoken-0.12.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:cbb9a3ba275165a2cb0f9a83f5d7025afe6b9d0ab01a22b50f0e74fee2ad253e", size = 993855, upload-time = "2025-10-06T20:22:28.799Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/5f/77/4f268c41a3957c418b084dd576ea2fad2e95da0d8e1ab705372892c2ca22/tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:dfdfaa5ffff8993a3af94d1125870b1d27aed7cb97aa7eb8c1cefdbc87dbee63", size = 1129022, upload-time = "2025-10-06T20:22:29.981Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/4e/2b/fc46c90fe5028bd094cd6ee25a7db321cb91d45dc87531e2bdbb26b4867a/tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:584c3ad3d0c74f5269906eb8a659c8bfc6144a52895d9261cdaf90a0ae5f4de0", size = 1150736, upload-time = "2025-10-06T20:22:30.996Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/28/c0/3c7a39ff68022ddfd7d93f3337ad90389a342f761c4d71de99a3ccc57857/tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:54c891b416a0e36b8e2045b12b33dd66fb34a4fe7965565f1b482da50da3e86a", size = 1194908, upload-time = "2025-10-06T20:22:32.073Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ab/0d/c1ad6f4016a3968c048545f5d9b8ffebf577774b2ede3e2e352553b685fe/tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5edb8743b88d5be814b1a8a8854494719080c28faaa1ccbef02e87354fe71ef0", size = 1253706, upload-time = "2025-10-06T20:22:33.385Z" }, + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/af/df/c7891ef9d2712ad774777271d39fdef63941ffba0a9d59b7ad1fd2765e57/tiktoken-0.12.0-cp314-cp314t-win_amd64.whl", hash = "sha256:f61c0aea5565ac82e2ec50a05e02a6c44734e91b51c10510b084ea1b8e633a71", size = 920667, upload-time = "2025-10-06T20:22:34.444Z" }, +] + +[[package]] +name = "toml" +version = "0.10.2" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/be/ba/1f744cdc819428fc6b5084ec34d9b30660f6f9daaf70eead706e3203ec3c/toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f", size = 22253, upload-time = "2020-11-01T01:40:22.204Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588, upload-time = "2020-11-01T01:40:20.672Z" }, +] + [[package]] name = "toolz" version = "1.0.0" @@ -948,6 +1338,18 @@ wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/03/98/eb27cc78ad3af8e302c9d8ff4977f5026676e130d28dd7578132a457170c/toolz-1.0.0-py3-none-any.whl", hash = "sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236", size = 56383, upload-time = "2024-10-04T16:17:01.533Z" }, ] +[[package]] +name = "tqdm" +version = "4.67.1" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0" @@ -969,6 +1371,15 @@ wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/dc/9b/47798a6c91d8bdb567fe2698fe81e0c6b7cb7ef4d13da4114b41d239f65d/typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7", size = 14611, upload-time = "2025-10-01T02:14:40.154Z" }, ] +[[package]] +name = "urllib3" +version = "2.5.0" +source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" } +sdist = { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/15/22/9ee70a2574a4f4599c47dd506532914ce044817c7752a79b6a51286319bc/urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760", size = 393185, upload-time = "2025-06-18T14:07:41.644Z" } +wheels = [ + { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, +] + [[package]] name = "wrapt" version = "1.17.3"