From c812aafb8e926aa21e98a9a86c3d7ae1988fcb39 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 17 Oct 2025 18:25:57 +0100 Subject: [PATCH 01/72] wip --- demo.py | 44 +++++ muutils/cli/arg_bool.py | 261 ++++++++++++++++++++++++++ muutils/ml/cuda_mem_info.py | 53 ++++++ muutils/web/html_to_pdf.py | 37 ++++ tests/benchmark_parallel.py | 347 +++++++++++++++++++++++++++++++++++ tests/test_benchmark_demo.py | 50 +++++ 6 files changed, 792 insertions(+) create mode 100644 demo.py create mode 100644 muutils/cli/arg_bool.py create mode 100644 muutils/ml/cuda_mem_info.py create mode 100644 muutils/web/html_to_pdf.py create mode 100644 tests/benchmark_parallel.py create mode 100644 tests/test_benchmark_demo.py diff --git a/demo.py b/demo.py new file mode 100644 index 00000000..671940cb --- /dev/null +++ b/demo.py @@ -0,0 +1,44 @@ +#%% +import torch +import numpy as np +from muutils.dbg import dbg_tensor + +# Different shapes +scalar = torch.tensor(42.0) +dbg_tensor(scalar) + +vector = torch.randn(10) +dbg_tensor(vector) + +matrix = torch.randn(5, 8) +dbg_tensor(matrix) + +# With NaN values +nan_tensor = torch.randn(100, 100) +nan_tensor[0:20, 0:20] = float('nan') +dbg_tensor(nan_tensor) + +# with Inf values +inf_tensor = torch.randn(100, 100) +inf_tensor[0:20, 0:20] = float('inf') +dbg_tensor(inf_tensor) + +# Different dtypes +bool_tensor = torch.rand(50, 50) > 0.5 +dbg_tensor(bool_tensor) + +int_tensor = torch.randint(-1000, 1000, (50, 50), dtype=torch.int32) +dbg_tensor(int_tensor) + +# CUDA if available +if torch.cuda.is_available(): + cuda_tensor = torch.randn(50, 50).cuda() + dbg_tensor(cuda_tensor) + +# NumPy +np_array = np.random.randn(50, 50) +dbg_tensor(np_array) + +# With gradients +grad_tensor = torch.randn(50, 50, requires_grad=True) +dbg_tensor(grad_tensor) \ No newline at end of file diff --git a/muutils/cli/arg_bool.py b/muutils/cli/arg_bool.py new file mode 100644 index 00000000..e320fea7 --- /dev/null +++ b/muutils/cli/arg_bool.py @@ -0,0 +1,261 @@ +import argparse +from collections.abc import Callable, Iterable, Sequence +from typing import Any, Final, override + + +def format_function_docstring[T_callable: Callable[..., Any]]( + mapping: dict[str, Any], + /, +) -> Callable[[T_callable], T_callable]: + """Decorator to format function docstring with the given keyword arguments""" + + # I think we don't need to use functools.wraps here, since we return the same function + def decorator(func: T_callable) -> T_callable: + assert func.__doc__ is not None, "Function must have a docstring to format." + func.__doc__ = func.__doc__.format_map(mapping) + return func + + return decorator + + +# Default token sets (lowercase). You can override per-option. +TRUE_SET_DEFAULT: Final[set[str]] = {"1", "true", "t", "yes", "y", "on"} +FALSE_SET_DEFAULT: Final[set[str]] = {"0", "false", "f", "no", "n", "off"} + + +def _normalize_set(tokens: Iterable[str] | None, fallback: set[str]) -> set[str]: + """Normalize a collection of tokens to a lowercase set, or return fallback.""" + if tokens is None: + return set(fallback) + return {str(t).lower() for t in tokens} + + +def parse_bool_token( + token: str, + true_set: set[str] | None = None, + false_set: set[str] | None = None, +) -> bool: + """Strict string-to-bool converter for argparse and friends. + + # Parameters: + - `token : str` + input token + - `true_set : set[str] | None` + accepted truthy strings (case-insensitive). + Defaults to TRUE_SET_DEFAULT when None. + - `false_set : set[str] | None` + accepted falsy strings (case-insensitive). 
+ Defaults to FALSE_SET_DEFAULT when None. + + # Returns: + - `bool` + parsed boolean + + # Raises: + - `argparse.ArgumentTypeError` : if not a recognized boolean string + """ + ts: set[str] = _normalize_set(true_set, TRUE_SET_DEFAULT) + fs: set[str] = _normalize_set(false_set, FALSE_SET_DEFAULT) + v: str = token.lower() + if v in ts: + return True + if v in fs: + return False + valid: list[str] = sorted(ts | fs) + raise argparse.ArgumentTypeError(f"expected one of {valid}") + + +class BoolFlagOrValue(argparse.Action): + """summary + + Configurable boolean action supporting any combination of: + --flag -> True (if allow_bare) + --no-flag -> False (if allow_no and --no-flag is registered) + --flag true|false -> parsed via custom sets + --flag=true|false -> parsed via custom sets + + Notes: + - The --no-flag form never accepts a value. It forces False. + - If allow_no is False but you still register a --no-flag alias, + using it will produce a usage error. + - Do not pass type= to this action. + + # Parameters: + - `option_strings : list[str]` + provided by argparse + - `dest : str` + attribute name on the namespace + - `nargs : int | str | None` + must be '?' for optional value + - `true_set : set[str] | None` + accepted truthy strings (case-insensitive). Defaults provided. + - `false_set : set[str] | None` + accepted falsy strings (case-insensitive). Defaults provided. + - `allow_no : bool` + whether the --no-flag form is allowed (defaults to True) + - `allow_bare : bool` + whether bare --flag (no value) is allowed (defaults to True) + - `**kwargs` + forwarded to base class + + # Raises: + - `ValueError` : if nargs is not '?' or if type= is provided + """ + + def __init__( + self, + option_strings: Sequence[str], + dest: str, + nargs: int | str | None = None, + **kwargs: bool | set[str] | None, + ) -> None: + # Extract custom kwargs before calling super().__init__ + true_set_opt: set[str] | None = kwargs.pop("true_set", None) # pyright: ignore[reportAssignmentType] + false_set_opt: set[str] | None = kwargs.pop("false_set", None) # pyright: ignore[reportAssignmentType] + allow_no_opt: bool = bool(kwargs.pop("allow_no", True)) + allow_bare_opt: bool = bool(kwargs.pop("allow_bare", True)) + + if "type" in kwargs and kwargs["type"] is not None: + raise ValueError("BoolFlagOrValue does not accept type=. 
Remove it.") + + if nargs not in (None, "?"): + raise ValueError("BoolFlagOrValue requires nargs='?'") + + super().__init__( + option_strings=option_strings, + dest=dest, + nargs="?", + **kwargs, # pyright: ignore[reportArgumentType] + ) + # Store normalized config + self.true_set: set[str] = _normalize_set(true_set_opt, TRUE_SET_DEFAULT) + self.false_set: set[str] = _normalize_set(false_set_opt, FALSE_SET_DEFAULT) + self.allow_no: bool = allow_no_opt + self.allow_bare: bool = allow_bare_opt + + def _parse_token(self, token: str) -> bool: + """Parse a boolean token using this action's configured sets.""" + return parse_bool_token(token, self.true_set, self.false_set) + + @override + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: str | Sequence[str] | None, + option_string: str | None = None, + ) -> None: + # Negated form handling + if option_string is not None and option_string.startswith("--no-"): + if not self.allow_no: + parser.error(f"{option_string} is not allowed for this option") + return # pyright: ignore[reportUnreachable] + if values is not None: + dest_flag: str = self.dest.replace("_", "-") + parser.error(f"{option_string} does not take a value; use --{dest_flag} true|false") + return # pyright: ignore[reportUnreachable] + setattr(namespace, self.dest, False) + return + + # Bare positive flag -> True (if allowed) + if values is None: + if not self.allow_bare: + valid: list[str] = sorted(self.true_set | self.false_set) + parser.error(f"option {option_string} requires a value; expected one of {valid}") + return # pyright: ignore[reportUnreachable] + setattr(namespace, self.dest, True) + return + + # we take only one value + if not isinstance(values, str): + if len(values) != 1: + parser.error( + f"{option_string} expects a single value, got {len(values) = }, {values = }" + ) + return # pyright: ignore[reportUnreachable] + values = values[0] # type: ignore[assignment] + + # Positive flag with explicit value -> parse + try: + val: bool = self._parse_token(values) + except argparse.ArgumentTypeError as e: + parser.error(str(e)) + return # pyright: ignore[reportUnreachable] + setattr(namespace, self.dest, val) + + +def add_bool_flag( + parser: argparse.ArgumentParser, + name: str, + *, + default: bool = False, + help: str = "", + true_set: set[str] | None = None, + false_set: set[str] | None = None, + allow_no: bool = False, + allow_bare: bool = True, +) -> None: + """summary + + Add a configurable boolean option that supports (depending on options): + -- (bare positive, if allow_bare) + --no- (negated, if allow_no) + -- true|false + --=true|false + + # Parameters: + - `parser : argparse.ArgumentParser` + parser to modify + - `name : str` + base long option name (without leading dashes) + - `default : bool` + default value (defaults to False) + - `help : str` + help text (optional) + - `true_set : set[str] | None` + accepted truthy strings (case-insensitive). Defaults used when None. + - `false_set : set[str] | None` + accepted falsy strings (case-insensitive). Defaults used when None. 
+ - `allow_no : bool` + whether to register/allow the --no- alias (defaults to True) + - `allow_bare : bool` + whether bare -- implies True (defaults to True) + + # Returns: + - `None` + nothing; parser is modified + + # Modifies: + - `parser` : adds a new argument with dest `` (hyphens -> underscores) + + # Usage: + ```python + p = argparse.ArgumentParser() + add_bool_flag(p, "feature", default=False, help="enable/disable feature") + ns = p.parse_args(["--feature=false"]) + assert ns.feature is False + ``` + """ + long_opt: str = f"--{name}" + dest: str = name.replace("-", "_") + option_strings: list[str] = [long_opt] + if allow_no: + option_strings.append(f"--no-{name}") + + tokens_preview: str = "{true,false}" + readable_name: str = name.replace("-", " ") + arg_help: str = help or (f"enable/disable {readable_name}; also accepts explicit true|false") + + parser.add_argument( + *option_strings, + dest=dest, + action=BoolFlagOrValue, + nargs="?", + default=default, + metavar=tokens_preview, + help=arg_help, + true_set=true_set, + false_set=false_set, + allow_no=allow_no, + allow_bare=allow_bare, + ) diff --git a/muutils/ml/cuda_mem_info.py b/muutils/ml/cuda_mem_info.py new file mode 100644 index 00000000..40b92a03 --- /dev/null +++ b/muutils/ml/cuda_mem_info.py @@ -0,0 +1,53 @@ +import torch + +# pyright: reportUnreachable=false, reportUnnecessaryIsInstance=false + + +def _to_cuda_device(device: int | str | torch.device) -> torch.device: + """Return a normalized CUDA device object.""" + dev: torch.device + if isinstance(device, torch.device): + dev = device + elif isinstance(device, int): + dev = torch.device(f"cuda:{device}") + elif isinstance(device, str): + # Accept forms like "cuda", "cuda:0", or bare index "0" + dev = torch.device(device) + else: + raise TypeError(f"Unsupported device type: {type(device).__name__}") + + if dev.type != "cuda": + raise ValueError(f"Device {dev} is not a CUDA device") + + return dev + + +def cuda_mem_info(dev: torch.device) -> tuple[int, int]: + """Return (free, total) bytes for a CUDA device.""" + current_idx: int = torch.cuda.current_device() + if dev.index != current_idx: + torch.cuda.set_device(dev) + free: int + total: int + free, total = torch.cuda.mem_get_info() + torch.cuda.set_device(current_idx) + else: + free, total = torch.cuda.mem_get_info() + return free, total + + +def cuda_memory_used(device: int | str | torch.device = 0) -> int: + """Return bytes currently allocated on a CUDA device.""" + dev: torch.device = _to_cuda_device(device) + free, total = cuda_mem_info(dev) + used: int = total - free + return used + + +def cuda_memory_fraction(device: int | str | torch.device = 0) -> float: + """Return fraction of total memory in use on a CUDA device.""" + dev: torch.device = _to_cuda_device(device) + free, total = cuda_mem_info(dev) + used: int = total - free + fraction: float = used / total if total else 0.0 + return fraction diff --git a/muutils/web/html_to_pdf.py b/muutils/web/html_to_pdf.py new file mode 100644 index 00000000..f8e6990e --- /dev/null +++ b/muutils/web/html_to_pdf.py @@ -0,0 +1,37 @@ +from pathlib import Path +import subprocess + +from weasyprint import HTML as WeasyHTML + +def html_to_pdf(src: Path, dst: Path) -> None: + "write HTML file to PDF using WeasyPrint." 
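+    # WeasyPrint's HTML(filename=...).write_pdf(target) renders the page
+    # in-process (no browser engine) and writes the PDF bytes to `dst`.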
+ WeasyHTML(filename=src.as_posix()).write_pdf(dst.as_posix()) + + +def crop(pdf_in: Path, pdf_out: Path, margin_pt: int = 2) -> None: + """Run pdfcrop with a tiny safety margin.""" + subprocess.run( + ["pdfcrop", "--margins", str(margin_pt), pdf_in.as_posix(), pdf_out.as_posix()], + check=True, + ) + + +def save_html_to_pdf( + html: str, + pdf_out: Path, + pdfcrop: bool = True, + margin_pt: int = 2, +) -> None: + """Save HTML string to PDF file.""" + if isinstance(pdf_out, str): + pdf_out = Path(pdf_out) + temp_html: Path = pdf_out.with_suffix(".html") + temp_html.write_text(html, encoding="utf-8") + + html_to_pdf(temp_html, pdf_out) + + if pdfcrop: + crop(pdf_out, pdf_out, margin_pt) + + # Clean up temporary HTML file + temp_html.unlink(missing_ok=True) \ No newline at end of file diff --git a/tests/benchmark_parallel.py b/tests/benchmark_parallel.py new file mode 100644 index 00000000..8fe6743c --- /dev/null +++ b/tests/benchmark_parallel.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python3 +"""Benchmark test comparing run_maybe_parallel with other parallelization techniques. + +Run with: python tests/benchmark_parallel.py +""" + +import time +import multiprocessing +from typing import List, Callable, Any, Dict +import pandas as pd +import numpy as np +from collections import defaultdict + +from muutils.parallel import run_maybe_parallel + + +def cpu_bound_task(x: int) -> int: + """CPU-intensive task for benchmarking.""" + # Simulate CPU work with a loop + result = 0 + for i in range(1000): + result += (x * i) % 1000 + return result + + +def io_bound_task(x: int) -> int: + """IO-bound task for benchmarking.""" + time.sleep(0.001) # Simulate I/O wait + return x * 2 + + +def light_cpu_task(x: int) -> int: + """Light CPU task for benchmarking.""" + return x ** 2 + x * 3 + 7 + + +class BenchmarkRunner: + """Run benchmarks and collect timing data.""" + + def __init__(self): + self.results = defaultdict(list) + self.cpu_count = multiprocessing.cpu_count() + + def time_execution(self, func: Callable, *args, **kwargs) -> float: + """Time a single execution.""" + start = time.perf_counter() + func(*args, **kwargs) + return time.perf_counter() - start + + def benchmark_method(self, method_name: str, method_func: Callable, + task_func: Callable, data: List[int], + runs: int = 3) -> Dict[str, float]: + """Benchmark a single method multiple times.""" + times = [] + for _ in range(runs): + _, duration = method_func(task_func, data) + times.append(duration) + + return { + 'mean': np.mean(times), + 'std': np.std(times), + 'min': np.min(times), + 'max': np.max(times), + 'median': np.median(times) + } + + def run_benchmark_suite(self, data_sizes: List[int], task_funcs: Dict[str, Callable], + runs_per_method: int = 3) -> pd.DataFrame: + """Run complete benchmark suite and return results as DataFrame.""" + + for data_size in data_sizes: + test_data = list(range(data_size)) + + for task_name, task_func in task_funcs.items(): + print(f"\nBenchmarking {task_name} with {data_size} items...") + + # Sequential baseline + stats = self.benchmark_method( + "sequential", benchmark_sequential, task_func, test_data, runs_per_method + ) + self._record_result("sequential", task_name, data_size, stats) + + # Pool.map + stats = self.benchmark_method( + "pool_map", benchmark_pool_map, task_func, test_data, runs_per_method + ) + self._record_result("pool_map", task_name, data_size, stats) + + # Pool.imap with optimal chunk size + chunksize = max(1, data_size // (self.cpu_count * 4)) + imap_func = lambda f, d: benchmark_pool_imap(f, 
d, chunksize=chunksize) + stats = self.benchmark_method( + "pool_imap", imap_func, task_func, test_data, runs_per_method + ) + self._record_result("pool_imap", task_name, data_size, stats) + + # Pool.imap_unordered + imap_unord_func = lambda f, d: benchmark_pool_imap_unordered(f, d, chunksize=chunksize) + stats = self.benchmark_method( + "pool_imap_unordered", imap_unord_func, task_func, test_data, runs_per_method + ) + self._record_result("pool_imap_unordered", task_name, data_size, stats) + + # run_maybe_parallel (ordered) + rmp_func = lambda f, d: benchmark_run_maybe_parallel(f, d, parallel=True) + stats = self.benchmark_method( + "run_maybe_parallel", rmp_func, task_func, test_data, runs_per_method + ) + self._record_result("run_maybe_parallel", task_name, data_size, stats) + + # run_maybe_parallel (unordered) + rmp_unord_func = lambda f, d: benchmark_run_maybe_parallel(f, d, parallel=True, keep_ordered=False) + stats = self.benchmark_method( + "run_maybe_parallel_unordered", rmp_unord_func, task_func, test_data, runs_per_method + ) + self._record_result("run_maybe_parallel_unordered", task_name, data_size, stats) + + return self._create_dataframe() + + def _record_result(self, method: str, task: str, data_size: int, stats: Dict[str, float]): + """Record benchmark result.""" + self.results['method'].append(method) + self.results['task'].append(task) + self.results['data_size'].append(data_size) + self.results['mean_time'].append(stats['mean']) + self.results['std_time'].append(stats['std']) + self.results['min_time'].append(stats['min']) + self.results['max_time'].append(stats['max']) + self.results['median_time'].append(stats['median']) + + def _create_dataframe(self) -> pd.DataFrame: + """Create DataFrame from results.""" + df = pd.DataFrame(self.results) + + # Calculate speedup relative to sequential + sequential_times = df[df['method'] == 'sequential'][['task', 'data_size', 'mean_time']] + sequential_times = sequential_times.rename(columns={'mean_time': 'sequential_time'}) + + df = df.merge(sequential_times, on=['task', 'data_size']) + df['speedup'] = df['sequential_time'] / df['mean_time'] + + return df + + +def benchmark_sequential(func: Callable, data: List[int]) -> tuple[List[Any], float]: + """Benchmark sequential processing.""" + start = time.perf_counter() + results = [func(x) for x in data] + end = time.perf_counter() + return results, end - start + + +def benchmark_pool_map(func: Callable, data: List[int], processes: int = None) -> tuple[List[Any], float]: + """Benchmark using multiprocessing.Pool.map.""" + start = time.perf_counter() + with multiprocessing.Pool(processes) as pool: + results = pool.map(func, data) + end = time.perf_counter() + return results, end - start + + +def benchmark_pool_imap(func: Callable, data: List[int], processes: int = None, chunksize: int = 1) -> tuple[List[Any], float]: + """Benchmark using multiprocessing.Pool.imap.""" + start = time.perf_counter() + with multiprocessing.Pool(processes) as pool: + results = list(pool.imap(func, data, chunksize=chunksize)) + end = time.perf_counter() + return results, end - start + + +def benchmark_pool_imap_unordered(func: Callable, data: List[int], processes: int = None, chunksize: int = 1) -> tuple[List[Any], float]: + """Benchmark using multiprocessing.Pool.imap_unordered.""" + start = time.perf_counter() + with multiprocessing.Pool(processes) as pool: + results = list(pool.imap_unordered(func, data, chunksize=chunksize)) + end = time.perf_counter() + return results, end - start + + +def 
benchmark_run_maybe_parallel(func: Callable, data: List[int], parallel: bool | int, keep_ordered: bool = True, chunksize: int = None) -> tuple[List[Any], float]: + """Benchmark using run_maybe_parallel.""" + start = time.perf_counter() + results = run_maybe_parallel( + func=func, + iterable=data, + parallel=parallel, + keep_ordered=keep_ordered, + chunksize=chunksize, + pbar="none" # Disable progress bar for fair comparison + ) + end = time.perf_counter() + return results, end - start + + +def plot_speedup_by_data_size(df: pd.DataFrame, task_type: str = None, save_path: str = None): + """Plot speedup vs data size for different methods.""" + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(figsize=(10, 6)) + + # Filter by task type if specified + plot_df = df[df['task'] == task_type] if task_type else df + + # Group by method and plot + for method in plot_df['method'].unique(): + if method == 'sequential': + continue + method_df = plot_df[plot_df['method'] == method] + ax.plot(method_df['data_size'], method_df['speedup'], + marker='o', label=method) + + ax.set_xlabel('Data Size') + ax.set_ylabel('Speedup (vs Sequential)') + ax.set_title(f'Speedup by Data Size{f" ({task_type} tasks)" if task_type else ""}') + ax.set_xscale('log') + ax.axhline(y=1, color='gray', linestyle='--', alpha=0.5) + ax.legend() + ax.grid(True, alpha=0.3) + + if save_path: + plt.savefig(save_path) + else: + plt.show() + + +def plot_timing_comparison(df: pd.DataFrame, data_size: int = None, save_path: str = None): + """Plot timing comparison as bar chart.""" + import matplotlib.pyplot as plt + + # Filter by data size if specified + plot_df = df[df['data_size'] == data_size] if data_size else df + + # Pivot for easier plotting + pivot_df = plot_df.pivot_table( + index='task', columns='method', values='mean_time' + ) + + ax = pivot_df.plot(kind='bar', figsize=(12, 6), rot=0) + ax.set_ylabel('Time (seconds)') + ax.set_title(f'Timing Comparison{f" (Data Size: {data_size})" if data_size else ""}') + ax.legend(title='Method', bbox_to_anchor=(1.05, 1), loc='upper left') + + if save_path: + plt.tight_layout() + plt.savefig(save_path) + else: + plt.show() + + +def plot_efficiency_heatmap(df: pd.DataFrame, save_path: str = None): + """Plot efficiency heatmap (speedup across methods and tasks).""" + import matplotlib.pyplot as plt + import seaborn as sns + + # Create pivot table for heatmap + pivot_df = df.pivot_table( + index=['task', 'data_size'], + columns='method', + values='speedup' + ) + + # Create heatmap + plt.figure(figsize=(12, 8)) + sns.heatmap(pivot_df, annot=True, fmt='.2f', cmap='YlOrRd', + vmin=0, center=1, cbar_kws={'label': 'Speedup'}) + plt.title('Parallelization Efficiency Heatmap') + plt.tight_layout() + + if save_path: + plt.savefig(save_path) + else: + plt.show() + + +def print_summary_stats(df: pd.DataFrame): + """Print summary statistics from benchmark results.""" + print("\n=== BENCHMARK SUMMARY ===") + print(f"\nTotal configurations tested: {len(df)}") + + # Best method by task type + print("\nBest methods by task type (highest average speedup):") + best_by_task = df[df['method'] != 'sequential'].groupby('task').apply( + lambda x: x.loc[x['speedup'].idxmax()][['method', 'speedup', 'data_size']] + ) + print(best_by_task) + + # Overall best speedups + print("\nTop 5 speedups achieved:") + top_speedups = df[df['method'] != 'sequential'].nlargest(5, 'speedup')[ + ['method', 'task', 'data_size', 'speedup', 'mean_time'] + ] + print(top_speedups) + + # Method rankings + print("\nAverage speedup by 
method:") + avg_speedup = df[df['method'] != 'sequential'].groupby('method')['speedup'].agg(['mean', 'std']) + print(avg_speedup.sort_values('mean', ascending=False)) + + +def main(): + """Run benchmarks and display results.""" + print("Starting parallelization benchmark...") + + # Configure benchmark parameters + data_sizes = [100, 1000, 5000, 10000] + task_funcs = { + 'cpu_bound': cpu_bound_task, + 'io_bound': io_bound_task, + 'light_cpu': light_cpu_task + } + + # Run benchmarks + runner = BenchmarkRunner() + df = runner.run_benchmark_suite(data_sizes, task_funcs, runs_per_method=3) + + # Save results + df.to_csv('benchmark_results.csv', index=False) + print("\nResults saved to benchmark_results.csv") + + # Display summary + print_summary_stats(df) + + # Create visualizations + import matplotlib + matplotlib.use('Agg') # Use non-interactive backend + + # Plot speedup by data size for each task type + for task in task_funcs.keys(): + plot_speedup_by_data_size(df, task, f'speedup_{task}.png') + print(f"Saved speedup plot for {task} tasks to speedup_{task}.png") + + # Plot timing comparison for largest data size + plot_timing_comparison(df, data_sizes[-1], 'timing_comparison.png') + print(f"Saved timing comparison to timing_comparison.png") + + # Plot efficiency heatmap + plot_efficiency_heatmap(df, 'efficiency_heatmap.png') + print("Saved efficiency heatmap to efficiency_heatmap.png") + + return df + + +if __name__ == "__main__": + df = main() + print("\nDataFrame columns:", df.columns.tolist()) + print("\nFirst few rows:") + print(df.head(10)) \ No newline at end of file diff --git a/tests/test_benchmark_demo.py b/tests/test_benchmark_demo.py new file mode 100644 index 00000000..172a8ab4 --- /dev/null +++ b/tests/test_benchmark_demo.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +"""Simple demo of using the benchmark script.""" + +import sys +import os +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from tests.benchmark_parallel import ( + BenchmarkRunner, cpu_bound_task, io_bound_task, light_cpu_task, + print_summary_stats +) + + +def quick_benchmark(): + """Run a quick benchmark with small data sizes.""" + print("Running quick benchmark demo...\n") + + # Small data sizes for quick demo + data_sizes = [100, 500, 1000] + task_funcs = { + 'cpu_bound': cpu_bound_task, + 'io_bound': io_bound_task, + 'light_cpu': light_cpu_task + } + + # Run benchmarks + runner = BenchmarkRunner() + df = runner.run_benchmark_suite(data_sizes, task_funcs, runs_per_method=2) + + # Print results + print("\n" + "="*60) + print("BENCHMARK RESULTS DATAFRAME") + print("="*60) + print(df.to_string()) + + print("\n" + "="*60) + print_summary_stats(df) + + # Show example of filtering data + print("\n" + "="*60) + print("EXAMPLE: CPU-bound tasks only") + print("="*60) + cpu_df = df[df['task'] == 'cpu_bound'] + print(cpu_df[['method', 'data_size', 'mean_time', 'speedup']].to_string()) + + return df + + +if __name__ == "__main__": + df = quick_benchmark() \ No newline at end of file From d626f204ac1e8f5667528e57e2b968cd2b0e4435 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 28 Oct 2025 17:57:41 +0000 Subject: [PATCH 02/72] remove old demo --- demo.py | 44 -------------------------------------------- 1 file changed, 44 deletions(-) delete mode 100644 demo.py diff --git a/demo.py b/demo.py deleted file mode 100644 index 671940cb..00000000 --- a/demo.py +++ /dev/null @@ -1,44 +0,0 @@ -#%% -import torch -import numpy as np -from muutils.dbg import dbg_tensor - -# Different shapes 
-scalar = torch.tensor(42.0) -dbg_tensor(scalar) - -vector = torch.randn(10) -dbg_tensor(vector) - -matrix = torch.randn(5, 8) -dbg_tensor(matrix) - -# With NaN values -nan_tensor = torch.randn(100, 100) -nan_tensor[0:20, 0:20] = float('nan') -dbg_tensor(nan_tensor) - -# with Inf values -inf_tensor = torch.randn(100, 100) -inf_tensor[0:20, 0:20] = float('inf') -dbg_tensor(inf_tensor) - -# Different dtypes -bool_tensor = torch.rand(50, 50) > 0.5 -dbg_tensor(bool_tensor) - -int_tensor = torch.randint(-1000, 1000, (50, 50), dtype=torch.int32) -dbg_tensor(int_tensor) - -# CUDA if available -if torch.cuda.is_available(): - cuda_tensor = torch.randn(50, 50).cuda() - dbg_tensor(cuda_tensor) - -# NumPy -np_array = np.random.randn(50, 50) -dbg_tensor(np_array) - -# With gradients -grad_tensor = torch.randn(50, 50, requires_grad=True) -dbg_tensor(grad_tensor) \ No newline at end of file From 83e0ec10e2d1b3c2ba217356907049e15034c6e0 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 28 Oct 2025 17:59:05 +0000 Subject: [PATCH 03/72] fix type hints for legacy compat --- tests/benchmark_parallel.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/benchmark_parallel.py b/tests/benchmark_parallel.py index 8fe6743c..0af0773f 100644 --- a/tests/benchmark_parallel.py +++ b/tests/benchmark_parallel.py @@ -6,7 +6,7 @@ import time import multiprocessing -from typing import List, Callable, Any, Dict +from typing import List, Callable, Any, Dict, Tuple import pandas as pd import numpy as np from collections import defaultdict @@ -142,7 +142,7 @@ def _create_dataframe(self) -> pd.DataFrame: return df -def benchmark_sequential(func: Callable, data: List[int]) -> tuple[List[Any], float]: +def benchmark_sequential(func: Callable, data: List[int]) -> Tuple[List[Any], float]: """Benchmark sequential processing.""" start = time.perf_counter() results = [func(x) for x in data] @@ -150,7 +150,7 @@ def benchmark_sequential(func: Callable, data: List[int]) -> tuple[List[Any], fl return results, end - start -def benchmark_pool_map(func: Callable, data: List[int], processes: int = None) -> tuple[List[Any], float]: +def benchmark_pool_map(func: Callable, data: List[int], processes: int = None) -> Tuple[List[Any], float]: """Benchmark using multiprocessing.Pool.map.""" start = time.perf_counter() with multiprocessing.Pool(processes) as pool: @@ -159,7 +159,7 @@ def benchmark_pool_map(func: Callable, data: List[int], processes: int = None) - return results, end - start -def benchmark_pool_imap(func: Callable, data: List[int], processes: int = None, chunksize: int = 1) -> tuple[List[Any], float]: +def benchmark_pool_imap(func: Callable, data: List[int], processes: int = None, chunksize: int = 1) -> Tuple[List[Any], float]: """Benchmark using multiprocessing.Pool.imap.""" start = time.perf_counter() with multiprocessing.Pool(processes) as pool: @@ -168,7 +168,7 @@ def benchmark_pool_imap(func: Callable, data: List[int], processes: int = None, return results, end - start -def benchmark_pool_imap_unordered(func: Callable, data: List[int], processes: int = None, chunksize: int = 1) -> tuple[List[Any], float]: +def benchmark_pool_imap_unordered(func: Callable, data: List[int], processes: int = None, chunksize: int = 1) -> Tuple[List[Any], float]: """Benchmark using multiprocessing.Pool.imap_unordered.""" start = time.perf_counter() with multiprocessing.Pool(processes) as pool: @@ -177,7 +177,7 @@ def benchmark_pool_imap_unordered(func: Callable, data: List[int], processes: in return 
results, end - start -def benchmark_run_maybe_parallel(func: Callable, data: List[int], parallel: bool | int, keep_ordered: bool = True, chunksize: int = None) -> tuple[List[Any], float]: +def benchmark_run_maybe_parallel(func: Callable, data: List[int], parallel: bool | int, keep_ordered: bool = True, chunksize: int = None) -> Tuple[List[Any], float]: """Benchmark using run_maybe_parallel.""" start = time.perf_counter() results = run_maybe_parallel( From 46f635bf039b44a92eeaa44a6bc08c517cbd7fb0 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 28 Oct 2025 18:00:21 +0000 Subject: [PATCH 04/72] make format --- muutils/cli/arg_bool.py | 18 +- muutils/web/html_to_pdf.py | 3 +- tests/benchmark_parallel.py | 320 +++++++++++++++++++++-------------- tests/test_benchmark_demo.py | 42 ++--- 4 files changed, 234 insertions(+), 149 deletions(-) diff --git a/muutils/cli/arg_bool.py b/muutils/cli/arg_bool.py index e320fea7..94207535 100644 --- a/muutils/cli/arg_bool.py +++ b/muutils/cli/arg_bool.py @@ -1,9 +1,11 @@ import argparse from collections.abc import Callable, Iterable, Sequence -from typing import Any, Final, override +from typing import Any, Final, override, TypeVar +T_callable = TypeVar("T_callable", bound=Callable[..., Any]) -def format_function_docstring[T_callable: Callable[..., Any]]( + +def format_function_docstring( mapping: dict[str, Any], /, ) -> Callable[[T_callable], T_callable]: @@ -152,7 +154,9 @@ def __call__( return # pyright: ignore[reportUnreachable] if values is not None: dest_flag: str = self.dest.replace("_", "-") - parser.error(f"{option_string} does not take a value; use --{dest_flag} true|false") + parser.error( + f"{option_string} does not take a value; use --{dest_flag} true|false" + ) return # pyright: ignore[reportUnreachable] setattr(namespace, self.dest, False) return @@ -161,7 +165,9 @@ def __call__( if values is None: if not self.allow_bare: valid: list[str] = sorted(self.true_set | self.false_set) - parser.error(f"option {option_string} requires a value; expected one of {valid}") + parser.error( + f"option {option_string} requires a value; expected one of {valid}" + ) return # pyright: ignore[reportUnreachable] setattr(namespace, self.dest, True) return @@ -244,7 +250,9 @@ def add_bool_flag( tokens_preview: str = "{true,false}" readable_name: str = name.replace("-", " ") - arg_help: str = help or (f"enable/disable {readable_name}; also accepts explicit true|false") + arg_help: str = help or ( + f"enable/disable {readable_name}; also accepts explicit true|false" + ) parser.add_argument( *option_strings, diff --git a/muutils/web/html_to_pdf.py b/muutils/web/html_to_pdf.py index f8e6990e..2728c92c 100644 --- a/muutils/web/html_to_pdf.py +++ b/muutils/web/html_to_pdf.py @@ -3,6 +3,7 @@ from weasyprint import HTML as WeasyHTML + def html_to_pdf(src: Path, dst: Path) -> None: "write HTML file to PDF using WeasyPrint." 
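     # WeasyPrint's HTML(filename=...).write_pdf(target) renders the page
     # in-process (no browser engine) and writes the PDF bytes to `dst`.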
WeasyHTML(filename=src.as_posix()).write_pdf(dst.as_posix()) @@ -34,4 +35,4 @@ def save_html_to_pdf( crop(pdf_out, pdf_out, margin_pt) # Clean up temporary HTML file - temp_html.unlink(missing_ok=True) \ No newline at end of file + temp_html.unlink(missing_ok=True) diff --git a/tests/benchmark_parallel.py b/tests/benchmark_parallel.py index 0af0773f..887a4a91 100644 --- a/tests/benchmark_parallel.py +++ b/tests/benchmark_parallel.py @@ -31,114 +31,157 @@ def io_bound_task(x: int) -> int: def light_cpu_task(x: int) -> int: """Light CPU task for benchmarking.""" - return x ** 2 + x * 3 + 7 + return x**2 + x * 3 + 7 class BenchmarkRunner: """Run benchmarks and collect timing data.""" - + def __init__(self): self.results = defaultdict(list) self.cpu_count = multiprocessing.cpu_count() - + def time_execution(self, func: Callable, *args, **kwargs) -> float: """Time a single execution.""" start = time.perf_counter() func(*args, **kwargs) return time.perf_counter() - start - - def benchmark_method(self, method_name: str, method_func: Callable, - task_func: Callable, data: List[int], - runs: int = 3) -> Dict[str, float]: + + def benchmark_method( + self, + method_name: str, + method_func: Callable, + task_func: Callable, + data: List[int], + runs: int = 3, + ) -> Dict[str, float]: """Benchmark a single method multiple times.""" times = [] for _ in range(runs): _, duration = method_func(task_func, data) times.append(duration) - + return { - 'mean': np.mean(times), - 'std': np.std(times), - 'min': np.min(times), - 'max': np.max(times), - 'median': np.median(times) + "mean": np.mean(times), + "std": np.std(times), + "min": np.min(times), + "max": np.max(times), + "median": np.median(times), } - - def run_benchmark_suite(self, data_sizes: List[int], task_funcs: Dict[str, Callable], - runs_per_method: int = 3) -> pd.DataFrame: + + def run_benchmark_suite( + self, + data_sizes: List[int], + task_funcs: Dict[str, Callable], + runs_per_method: int = 3, + ) -> pd.DataFrame: """Run complete benchmark suite and return results as DataFrame.""" - + for data_size in data_sizes: test_data = list(range(data_size)) - + for task_name, task_func in task_funcs.items(): print(f"\nBenchmarking {task_name} with {data_size} items...") - + # Sequential baseline stats = self.benchmark_method( - "sequential", benchmark_sequential, task_func, test_data, runs_per_method + "sequential", + benchmark_sequential, + task_func, + test_data, + runs_per_method, ) self._record_result("sequential", task_name, data_size, stats) - + # Pool.map stats = self.benchmark_method( - "pool_map", benchmark_pool_map, task_func, test_data, runs_per_method + "pool_map", + benchmark_pool_map, + task_func, + test_data, + runs_per_method, ) self._record_result("pool_map", task_name, data_size, stats) - + # Pool.imap with optimal chunk size chunksize = max(1, data_size // (self.cpu_count * 4)) - imap_func = lambda f, d: benchmark_pool_imap(f, d, chunksize=chunksize) + imap_func = lambda f, d: benchmark_pool_imap(f, d, chunksize=chunksize) # noqa: E731 stats = self.benchmark_method( "pool_imap", imap_func, task_func, test_data, runs_per_method ) self._record_result("pool_imap", task_name, data_size, stats) - + # Pool.imap_unordered - imap_unord_func = lambda f, d: benchmark_pool_imap_unordered(f, d, chunksize=chunksize) + imap_unord_func = lambda f, d: benchmark_pool_imap_unordered( # noqa: E731 + f, d, chunksize=chunksize + ) stats = self.benchmark_method( - "pool_imap_unordered", imap_unord_func, task_func, test_data, runs_per_method + 
"pool_imap_unordered", + imap_unord_func, + task_func, + test_data, + runs_per_method, ) self._record_result("pool_imap_unordered", task_name, data_size, stats) - + # run_maybe_parallel (ordered) - rmp_func = lambda f, d: benchmark_run_maybe_parallel(f, d, parallel=True) + rmp_func = lambda f, d: benchmark_run_maybe_parallel( # noqa: E731 + f, d, parallel=True + ) stats = self.benchmark_method( - "run_maybe_parallel", rmp_func, task_func, test_data, runs_per_method + "run_maybe_parallel", + rmp_func, + task_func, + test_data, + runs_per_method, ) self._record_result("run_maybe_parallel", task_name, data_size, stats) - + # run_maybe_parallel (unordered) - rmp_unord_func = lambda f, d: benchmark_run_maybe_parallel(f, d, parallel=True, keep_ordered=False) + rmp_unord_func = lambda f, d: benchmark_run_maybe_parallel( # noqa: E731 + f, d, parallel=True, keep_ordered=False + ) stats = self.benchmark_method( - "run_maybe_parallel_unordered", rmp_unord_func, task_func, test_data, runs_per_method + "run_maybe_parallel_unordered", + rmp_unord_func, + task_func, + test_data, + runs_per_method, ) - self._record_result("run_maybe_parallel_unordered", task_name, data_size, stats) - + self._record_result( + "run_maybe_parallel_unordered", task_name, data_size, stats + ) + return self._create_dataframe() - - def _record_result(self, method: str, task: str, data_size: int, stats: Dict[str, float]): + + def _record_result( + self, method: str, task: str, data_size: int, stats: Dict[str, float] + ): """Record benchmark result.""" - self.results['method'].append(method) - self.results['task'].append(task) - self.results['data_size'].append(data_size) - self.results['mean_time'].append(stats['mean']) - self.results['std_time'].append(stats['std']) - self.results['min_time'].append(stats['min']) - self.results['max_time'].append(stats['max']) - self.results['median_time'].append(stats['median']) - + self.results["method"].append(method) + self.results["task"].append(task) + self.results["data_size"].append(data_size) + self.results["mean_time"].append(stats["mean"]) + self.results["std_time"].append(stats["std"]) + self.results["min_time"].append(stats["min"]) + self.results["max_time"].append(stats["max"]) + self.results["median_time"].append(stats["median"]) + def _create_dataframe(self) -> pd.DataFrame: """Create DataFrame from results.""" df = pd.DataFrame(self.results) - + # Calculate speedup relative to sequential - sequential_times = df[df['method'] == 'sequential'][['task', 'data_size', 'mean_time']] - sequential_times = sequential_times.rename(columns={'mean_time': 'sequential_time'}) - - df = df.merge(sequential_times, on=['task', 'data_size']) - df['speedup'] = df['sequential_time'] / df['mean_time'] - + sequential_times = df[df["method"] == "sequential"][ + ["task", "data_size", "mean_time"] + ] + sequential_times = sequential_times.rename( + columns={"mean_time": "sequential_time"} + ) + + df = df.merge(sequential_times, on=["task", "data_size"]) + df["speedup"] = df["sequential_time"] / df["mean_time"] + return df @@ -150,7 +193,9 @@ def benchmark_sequential(func: Callable, data: List[int]) -> Tuple[List[Any], fl return results, end - start -def benchmark_pool_map(func: Callable, data: List[int], processes: int = None) -> Tuple[List[Any], float]: +def benchmark_pool_map( + func: Callable, data: List[int], processes: int = None +) -> Tuple[List[Any], float]: """Benchmark using multiprocessing.Pool.map.""" start = time.perf_counter() with multiprocessing.Pool(processes) as pool: @@ -159,7 +204,9 @@ 
def benchmark_pool_map(func: Callable, data: List[int], processes: int = None) - return results, end - start -def benchmark_pool_imap(func: Callable, data: List[int], processes: int = None, chunksize: int = 1) -> Tuple[List[Any], float]: +def benchmark_pool_imap( + func: Callable, data: List[int], processes: int = None, chunksize: int = 1 +) -> Tuple[List[Any], float]: """Benchmark using multiprocessing.Pool.imap.""" start = time.perf_counter() with multiprocessing.Pool(processes) as pool: @@ -168,7 +215,9 @@ def benchmark_pool_imap(func: Callable, data: List[int], processes: int = None, return results, end - start -def benchmark_pool_imap_unordered(func: Callable, data: List[int], processes: int = None, chunksize: int = 1) -> Tuple[List[Any], float]: +def benchmark_pool_imap_unordered( + func: Callable, data: List[int], processes: int = None, chunksize: int = 1 +) -> Tuple[List[Any], float]: """Benchmark using multiprocessing.Pool.imap_unordered.""" start = time.perf_counter() with multiprocessing.Pool(processes) as pool: @@ -177,7 +226,13 @@ def benchmark_pool_imap_unordered(func: Callable, data: List[int], processes: in return results, end - start -def benchmark_run_maybe_parallel(func: Callable, data: List[int], parallel: bool | int, keep_ordered: bool = True, chunksize: int = None) -> Tuple[List[Any], float]: +def benchmark_run_maybe_parallel( + func: Callable, + data: List[int], + parallel: bool | int, + keep_ordered: bool = True, + chunksize: int = None, +) -> Tuple[List[Any], float]: """Benchmark using run_maybe_parallel.""" start = time.perf_counter() results = run_maybe_parallel( @@ -186,60 +241,63 @@ def benchmark_run_maybe_parallel(func: Callable, data: List[int], parallel: bool parallel=parallel, keep_ordered=keep_ordered, chunksize=chunksize, - pbar="none" # Disable progress bar for fair comparison + pbar="none", # Disable progress bar for fair comparison ) end = time.perf_counter() return results, end - start -def plot_speedup_by_data_size(df: pd.DataFrame, task_type: str = None, save_path: str = None): +def plot_speedup_by_data_size( + df: pd.DataFrame, task_type: str = None, save_path: str = None +): """Plot speedup vs data size for different methods.""" import matplotlib.pyplot as plt - + fig, ax = plt.subplots(figsize=(10, 6)) - + # Filter by task type if specified - plot_df = df[df['task'] == task_type] if task_type else df - + plot_df = df[df["task"] == task_type] if task_type else df + # Group by method and plot - for method in plot_df['method'].unique(): - if method == 'sequential': + for method in plot_df["method"].unique(): + if method == "sequential": continue - method_df = plot_df[plot_df['method'] == method] - ax.plot(method_df['data_size'], method_df['speedup'], - marker='o', label=method) - - ax.set_xlabel('Data Size') - ax.set_ylabel('Speedup (vs Sequential)') - ax.set_title(f'Speedup by Data Size{f" ({task_type} tasks)" if task_type else ""}') - ax.set_xscale('log') - ax.axhline(y=1, color='gray', linestyle='--', alpha=0.5) + method_df = plot_df[plot_df["method"] == method] + ax.plot(method_df["data_size"], method_df["speedup"], marker="o", label=method) + + ax.set_xlabel("Data Size") + ax.set_ylabel("Speedup (vs Sequential)") + ax.set_title(f"Speedup by Data Size{f' ({task_type} tasks)' if task_type else ''}") + ax.set_xscale("log") + ax.axhline(y=1, color="gray", linestyle="--", alpha=0.5) ax.legend() ax.grid(True, alpha=0.3) - + if save_path: plt.savefig(save_path) else: plt.show() -def plot_timing_comparison(df: pd.DataFrame, data_size: int = None, 
save_path: str = None): +def plot_timing_comparison( + df: pd.DataFrame, data_size: int = None, save_path: str = None +): """Plot timing comparison as bar chart.""" import matplotlib.pyplot as plt - + # Filter by data size if specified - plot_df = df[df['data_size'] == data_size] if data_size else df - + plot_df = df[df["data_size"] == data_size] if data_size else df + # Pivot for easier plotting - pivot_df = plot_df.pivot_table( - index='task', columns='method', values='mean_time' + pivot_df = plot_df.pivot_table(index="task", columns="method", values="mean_time") + + ax = pivot_df.plot(kind="bar", figsize=(12, 6), rot=0) + ax.set_ylabel("Time (seconds)") + ax.set_title( + f"Timing Comparison{f' (Data Size: {data_size})' if data_size else ''}" ) - - ax = pivot_df.plot(kind='bar', figsize=(12, 6), rot=0) - ax.set_ylabel('Time (seconds)') - ax.set_title(f'Timing Comparison{f" (Data Size: {data_size})" if data_size else ""}') - ax.legend(title='Method', bbox_to_anchor=(1.05, 1), loc='upper left') - + ax.legend(title="Method", bbox_to_anchor=(1.05, 1), loc="upper left") + if save_path: plt.tight_layout() plt.savefig(save_path) @@ -251,21 +309,26 @@ def plot_efficiency_heatmap(df: pd.DataFrame, save_path: str = None): """Plot efficiency heatmap (speedup across methods and tasks).""" import matplotlib.pyplot as plt import seaborn as sns - + # Create pivot table for heatmap pivot_df = df.pivot_table( - index=['task', 'data_size'], - columns='method', - values='speedup' + index=["task", "data_size"], columns="method", values="speedup" ) - + # Create heatmap plt.figure(figsize=(12, 8)) - sns.heatmap(pivot_df, annot=True, fmt='.2f', cmap='YlOrRd', - vmin=0, center=1, cbar_kws={'label': 'Speedup'}) - plt.title('Parallelization Efficiency Heatmap') + sns.heatmap( + pivot_df, + annot=True, + fmt=".2f", + cmap="YlOrRd", + vmin=0, + center=1, + cbar_kws={"label": "Speedup"}, + ) + plt.title("Parallelization Efficiency Heatmap") plt.tight_layout() - + if save_path: plt.savefig(save_path) else: @@ -276,67 +339,76 @@ def print_summary_stats(df: pd.DataFrame): """Print summary statistics from benchmark results.""" print("\n=== BENCHMARK SUMMARY ===") print(f"\nTotal configurations tested: {len(df)}") - + # Best method by task type print("\nBest methods by task type (highest average speedup):") - best_by_task = df[df['method'] != 'sequential'].groupby('task').apply( - lambda x: x.loc[x['speedup'].idxmax()][['method', 'speedup', 'data_size']] + best_by_task = ( + df[df["method"] != "sequential"] + .groupby("task") + .apply( + lambda x: x.loc[x["speedup"].idxmax()][["method", "speedup", "data_size"]] + ) ) print(best_by_task) - + # Overall best speedups print("\nTop 5 speedups achieved:") - top_speedups = df[df['method'] != 'sequential'].nlargest(5, 'speedup')[ - ['method', 'task', 'data_size', 'speedup', 'mean_time'] + top_speedups = df[df["method"] != "sequential"].nlargest(5, "speedup")[ + ["method", "task", "data_size", "speedup", "mean_time"] ] print(top_speedups) - + # Method rankings print("\nAverage speedup by method:") - avg_speedup = df[df['method'] != 'sequential'].groupby('method')['speedup'].agg(['mean', 'std']) - print(avg_speedup.sort_values('mean', ascending=False)) + avg_speedup = ( + df[df["method"] != "sequential"] + .groupby("method")["speedup"] + .agg(["mean", "std"]) + ) + print(avg_speedup.sort_values("mean", ascending=False)) def main(): """Run benchmarks and display results.""" print("Starting parallelization benchmark...") - + # Configure benchmark parameters data_sizes = [100, 1000, 
5000, 10000] task_funcs = { - 'cpu_bound': cpu_bound_task, - 'io_bound': io_bound_task, - 'light_cpu': light_cpu_task + "cpu_bound": cpu_bound_task, + "io_bound": io_bound_task, + "light_cpu": light_cpu_task, } - + # Run benchmarks runner = BenchmarkRunner() df = runner.run_benchmark_suite(data_sizes, task_funcs, runs_per_method=3) - + # Save results - df.to_csv('benchmark_results.csv', index=False) + df.to_csv("benchmark_results.csv", index=False) print("\nResults saved to benchmark_results.csv") - + # Display summary print_summary_stats(df) - + # Create visualizations import matplotlib - matplotlib.use('Agg') # Use non-interactive backend - + + matplotlib.use("Agg") # Use non-interactive backend + # Plot speedup by data size for each task type for task in task_funcs.keys(): - plot_speedup_by_data_size(df, task, f'speedup_{task}.png') + plot_speedup_by_data_size(df, task, f"speedup_{task}.png") print(f"Saved speedup plot for {task} tasks to speedup_{task}.png") - + # Plot timing comparison for largest data size - plot_timing_comparison(df, data_sizes[-1], 'timing_comparison.png') - print(f"Saved timing comparison to timing_comparison.png") - + plot_timing_comparison(df, data_sizes[-1], "timing_comparison.png") + print("Saved timing comparison to timing_comparison.png") + # Plot efficiency heatmap - plot_efficiency_heatmap(df, 'efficiency_heatmap.png') + plot_efficiency_heatmap(df, "efficiency_heatmap.png") print("Saved efficiency heatmap to efficiency_heatmap.png") - + return df @@ -344,4 +416,4 @@ def main(): df = main() print("\nDataFrame columns:", df.columns.tolist()) print("\nFirst few rows:") - print(df.head(10)) \ No newline at end of file + print(df.head(10)) diff --git a/tests/test_benchmark_demo.py b/tests/test_benchmark_demo.py index 172a8ab4..7c496d0a 100644 --- a/tests/test_benchmark_demo.py +++ b/tests/test_benchmark_demo.py @@ -3,48 +3,52 @@ import sys import os + sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from tests.benchmark_parallel import ( - BenchmarkRunner, cpu_bound_task, io_bound_task, light_cpu_task, - print_summary_stats + BenchmarkRunner, + cpu_bound_task, + io_bound_task, + light_cpu_task, + print_summary_stats, ) def quick_benchmark(): """Run a quick benchmark with small data sizes.""" print("Running quick benchmark demo...\n") - + # Small data sizes for quick demo data_sizes = [100, 500, 1000] task_funcs = { - 'cpu_bound': cpu_bound_task, - 'io_bound': io_bound_task, - 'light_cpu': light_cpu_task + "cpu_bound": cpu_bound_task, + "io_bound": io_bound_task, + "light_cpu": light_cpu_task, } - + # Run benchmarks runner = BenchmarkRunner() df = runner.run_benchmark_suite(data_sizes, task_funcs, runs_per_method=2) - + # Print results - print("\n" + "="*60) + print("\n" + "=" * 60) print("BENCHMARK RESULTS DATAFRAME") - print("="*60) + print("=" * 60) print(df.to_string()) - - print("\n" + "="*60) + + print("\n" + "=" * 60) print_summary_stats(df) - + # Show example of filtering data - print("\n" + "="*60) + print("\n" + "=" * 60) print("EXAMPLE: CPU-bound tasks only") - print("="*60) - cpu_df = df[df['task'] == 'cpu_bound'] - print(cpu_df[['method', 'data_size', 'mean_time', 'speedup']].to_string()) - + print("=" * 60) + cpu_df = df[df["task"] == "cpu_bound"] + print(cpu_df[["method", "data_size", "mean_time", "speedup"]].to_string()) + return df if __name__ == "__main__": - df = quick_benchmark() \ No newline at end of file + df = quick_benchmark() From 8446d8e6b9a4ee5cae98f6463bddbe5df4b7985f Mon Sep 17 00:00:00 2001 From: 
Michael Ivanitskiy Date: Tue, 28 Oct 2025 18:04:38 +0000 Subject: [PATCH 05/72] add Command --- muutils/cli/command.py | 94 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 muutils/cli/command.py diff --git a/muutils/cli/command.py b/muutils/cli/command.py new file mode 100644 index 00000000..c0aee11f --- /dev/null +++ b/muutils/cli/command.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import os +import subprocess +import sys +from dataclasses import dataclass +from typing import Any, List, Union + + +@dataclass +class Command: + """Simple typed command with shell flag and subprocess helpers.""" + + cmd: Union[List[str], str] + shell: bool = False + env: dict[str, str] | None = None + inherit_env: bool = True + + def __post_init__(self) -> None: + """Enforce cmd type when shell is False.""" + if self.shell is False and isinstance(self.cmd, str): + raise ValueError("cmd must be List[str] when shell is False") + + def _quote_env(self) -> str: + """Return KEY=VAL tokens for env values. ignores `inherit_env`.""" + if not self.env: + return "" + + parts: List[str] = [] + for k, v in self.env.items(): + token: str = f"{k}={v}" + parts.append(token) + prefix: str = " ".join(parts) + return prefix + + @property + def cmd_joined(self) -> str: + """Return cmd as a single string, joining with spaces if it's a list. no env included.""" + if isinstance(self.cmd, str): + return self.cmd + else: + return " ".join(self.cmd) + + @property + def cmd_for_subprocess(self) -> Union[List[str], str]: + """Return cmd, splitting if shell is True and cmd is a string.""" + if self.shell: + if isinstance(self.cmd, str): + return self.cmd + else: + return " ".join(self.cmd) + else: + assert isinstance(self.cmd, list) + return self.cmd + + def script_line(self) -> str: + """Return a single shell string, prefixing KEY=VAL for env if provided.""" + return f"{self._quote_env()} {self.cmd_joined}".strip() + + @property + def env_final(self) -> dict[str, str]: + """Return final env dict, merging with os.environ if inherit_env is True.""" + return { + **(os.environ if self.inherit_env else {}), + **(self.env or {}), + } + + def run( + self, + **kwargs: Any, + ) -> subprocess.CompletedProcess[Any]: + """Call subprocess.run with this command.""" + try: + return subprocess.run( + self.cmd_for_subprocess, + shell=self.shell, + env=self.env_final, + **kwargs, + ) + except subprocess.CalledProcessError as e: + print(f"Command failed: `{self.script_line()}`", file=sys.stderr) + raise e + + def Popen( + self, + **kwargs: Any, + ) -> subprocess.Popen[Any]: + """Call subprocess.Popen with this command.""" + return subprocess.Popen( + self.cmd_for_subprocess, + shell=self.shell, + env=self.env_final, + **kwargs, + ) From 534612af815e2c294776b6738f557136f5153e6a Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 28 Oct 2025 18:05:42 +0000 Subject: [PATCH 06/72] fix type hint --- tests/benchmark_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/benchmark_parallel.py b/tests/benchmark_parallel.py index 887a4a91..b5174160 100644 --- a/tests/benchmark_parallel.py +++ b/tests/benchmark_parallel.py @@ -6,7 +6,7 @@ import time import multiprocessing -from typing import List, Callable, Any, Dict, Tuple +from typing import List, Callable, Any, Dict, Tuple, Union import pandas as pd import numpy as np from collections import defaultdict @@ -229,7 +229,7 @@ def benchmark_pool_imap_unordered( def benchmark_run_maybe_parallel( func: Callable, 
data: List[int], - parallel: bool | int, + parallel: Union[bool, int], keep_ordered: bool = True, chunksize: int = None, ) -> Tuple[List[Any], float]: From 98b6a1d7445f6564218ade25186bacd9a5c82658 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 28 Oct 2025 18:11:10 +0000 Subject: [PATCH 07/72] try to fix type checking --- pyproject.toml | 1 + tests/benchmark_parallel.py | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 988437f6..a6f983ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -142,6 +142,7 @@ "tests/input_data", "tests/junk_data", "tests/_temp/", + "tests/benchmark_parallel.py", # wip stuff "_wip/", # not our problem diff --git a/tests/benchmark_parallel.py b/tests/benchmark_parallel.py index b5174160..d1a065f8 100644 --- a/tests/benchmark_parallel.py +++ b/tests/benchmark_parallel.py @@ -7,7 +7,7 @@ import time import multiprocessing from typing import List, Callable, Any, Dict, Tuple, Union -import pandas as pd +import pandas as pd # type: ignore[import-untyped] import numpy as np from collections import defaultdict @@ -251,7 +251,7 @@ def plot_speedup_by_data_size( df: pd.DataFrame, task_type: str = None, save_path: str = None ): """Plot speedup vs data size for different methods.""" - import matplotlib.pyplot as plt + import matplotlib.pyplot as plt # type: ignore[import-untyped] fig, ax = plt.subplots(figsize=(10, 6)) @@ -283,7 +283,7 @@ def plot_timing_comparison( df: pd.DataFrame, data_size: int = None, save_path: str = None ): """Plot timing comparison as bar chart.""" - import matplotlib.pyplot as plt + import matplotlib.pyplot as plt # type: ignore[import-untyped] # Filter by data size if specified plot_df = df[df["data_size"] == data_size] if data_size else df @@ -307,8 +307,8 @@ def plot_timing_comparison( def plot_efficiency_heatmap(df: pd.DataFrame, save_path: str = None): """Plot efficiency heatmap (speedup across methods and tasks).""" - import matplotlib.pyplot as plt - import seaborn as sns + import matplotlib.pyplot as plt # type: ignore[import-untyped] + import seaborn as sns # type: ignore[import-untyped] # Create pivot table for heatmap pivot_df = df.pivot_table( @@ -392,7 +392,7 @@ def main(): print_summary_stats(df) # Create visualizations - import matplotlib + import matplotlib # type: ignore[import-untyped] matplotlib.use("Agg") # Use non-interactive backend From e503b9087a3d091b9e230291eedc047ee970c444 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 28 Oct 2025 18:18:54 +0000 Subject: [PATCH 08/72] type fixes --- muutils/cli/arg_bool.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/muutils/cli/arg_bool.py b/muutils/cli/arg_bool.py index 94207535..e8ab91d8 100644 --- a/muutils/cli/arg_bool.py +++ b/muutils/cli/arg_bool.py @@ -112,8 +112,8 @@ def __init__( **kwargs: bool | set[str] | None, ) -> None: # Extract custom kwargs before calling super().__init__ - true_set_opt: set[str] | None = kwargs.pop("true_set", None) # pyright: ignore[reportAssignmentType] - false_set_opt: set[str] | None = kwargs.pop("false_set", None) # pyright: ignore[reportAssignmentType] + true_set_opt: set[str] | None = kwargs.pop("true_set", None) # type: ignore[assignment,misc] + false_set_opt: set[str] | None = kwargs.pop("false_set", None) # type: ignore[assignment,misc] allow_no_opt: bool = bool(kwargs.pop("allow_no", True)) allow_bare_opt: bool = bool(kwargs.pop("allow_bare", True)) @@ -127,7 +127,7 @@ def __init__( 
option_strings=option_strings, dest=dest, nargs="?", - **kwargs, # pyright: ignore[reportArgumentType] + **kwargs, # type: ignore[arg-type] ) # Store normalized config self.true_set: set[str] = _normalize_set(true_set_opt, TRUE_SET_DEFAULT) @@ -151,13 +151,13 @@ def __call__( if option_string is not None and option_string.startswith("--no-"): if not self.allow_no: parser.error(f"{option_string} is not allowed for this option") - return # pyright: ignore[reportUnreachable] + return if values is not None: dest_flag: str = self.dest.replace("_", "-") parser.error( f"{option_string} does not take a value; use --{dest_flag} true|false" ) - return # pyright: ignore[reportUnreachable] + return setattr(namespace, self.dest, False) return @@ -168,7 +168,7 @@ def __call__( parser.error( f"option {option_string} requires a value; expected one of {valid}" ) - return # pyright: ignore[reportUnreachable] + return setattr(namespace, self.dest, True) return @@ -178,7 +178,7 @@ def __call__( parser.error( f"{option_string} expects a single value, got {len(values) = }, {values = }" ) - return # pyright: ignore[reportUnreachable] + return values = values[0] # type: ignore[assignment] # Positive flag with explicit value -> parse @@ -186,7 +186,7 @@ def __call__( val: bool = self._parse_token(values) except argparse.ArgumentTypeError as e: parser.error(str(e)) - return # pyright: ignore[reportUnreachable] + return setattr(namespace, self.dest, val) From eb5fda639c1862e6b142be4694e763ca5b595692 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 28 Oct 2025 18:22:01 +0000 Subject: [PATCH 09/72] fix @override import --- muutils/cli/arg_bool.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/muutils/cli/arg_bool.py b/muutils/cli/arg_bool.py index e8ab91d8..309fbf81 100644 --- a/muutils/cli/arg_bool.py +++ b/muutils/cli/arg_bool.py @@ -1,6 +1,12 @@ import argparse +import sys from collections.abc import Callable, Iterable, Sequence -from typing import Any, Final, override, TypeVar +from typing import Any, Final, TypeVar + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override T_callable = TypeVar("T_callable", bound=Callable[..., Any]) From e85fa01286e892bddad411987bab40897e6e4eeb Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 28 Oct 2025 18:34:21 +0000 Subject: [PATCH 10/72] weasyprint fixes --- .meta/requirements/requirements-all.txt | 62 +++- .meta/requirements/requirements-extras.txt | 62 +++- .meta/requirements/requirements.txt | 64 +++- muutils/web/html_to_pdf.py | 2 +- pyproject.toml | 4 + uv.lock | 384 ++++++++++++++++++++- 6 files changed, 560 insertions(+), 18 deletions(-) diff --git a/.meta/requirements/requirements-all.txt b/.meta/requirements/requirements-all.txt index 3fcfb22b..fb40ed6c 100644 --- a/.meta/requirements/requirements-all.txt +++ b/.meta/requirements/requirements-all.txt @@ -60,6 +60,10 @@ bleach==6.3.0 ; python_full_version >= '3.10' # via nbconvert bracex==2.6 ; python_full_version >= '3.11' # via wcmatch +brotli==1.1.0 ; platform_python_implementation == 'CPython' + # via fonttools +brotlicffi==1.1.0.0 ; platform_python_implementation != 'CPython' + # via fonttools certifi==2025.10.5 # via # httpcore @@ -68,13 +72,17 @@ certifi==2025.10.5 cffi==1.17.1 ; python_full_version < '3.9' # via # argon2-cffi-bindings + # brotlicffi # cryptography # pyzmq + # weasyprint cffi==2.0.0 ; python_full_version >= '3.9' # via # argon2-cffi-bindings + # brotlicffi # cryptography # pyzmq + # 
weasyprint charset-normalizer==3.4.4 # via requests cli-exit-tools==1.2.7 ; python_full_version >= '3.11' @@ -119,6 +127,10 @@ coverage==7.11.0 ; python_full_version >= '3.10' coverage-badge==1.1.2 cryptography==46.0.3 ; platform_machine != 'ppc64le' and platform_machine != 's390x' and sys_platform == 'linux' # via secretstorage +cssselect2==0.7.0 ; python_full_version < '3.9' + # via weasyprint +cssselect2==0.8.0 ; python_full_version >= '3.9' + # via weasyprint cycler==0.12.1 # via matplotlib debugpy==1.8.17 @@ -151,9 +163,13 @@ filelock==3.19.1 ; python_full_version == '3.9.*' filelock==3.20.0 ; python_full_version >= '3.10' and python_full_version < '3.14' # via torch fonttools==4.57.0 ; python_full_version < '3.9' - # via matplotlib + # via + # matplotlib + # weasyprint fonttools==4.60.1 ; python_full_version >= '3.9' - # via matplotlib + # via + # matplotlib + # weasyprint fqdn==1.5.1 # via jsonschema fsspec==2025.3.0 ; python_full_version < '3.9' @@ -162,6 +178,8 @@ fsspec==2025.9.0 ; python_full_version >= '3.9' and python_full_version < '3.14' # via torch h11==0.16.0 # via httpcore +html5lib==1.1 ; python_full_version < '3.9' + # via weasyprint httpcore==1.0.9 # via httpx httpx==0.28.1 @@ -594,11 +612,17 @@ pexpect==4.9.0 ; (python_full_version < '3.10' and sys_platform == 'emscripten') pickleshare==0.7.5 ; python_full_version < '3.9' # via ipython pillow==10.4.0 ; python_full_version < '3.9' - # via matplotlib + # via + # matplotlib + # weasyprint pillow==11.3.0 ; python_full_version == '3.9.*' - # via matplotlib + # via + # matplotlib + # weasyprint pillow==12.0.0 ; python_full_version >= '3.10' - # via matplotlib + # via + # matplotlib + # weasyprint pkgutil-resolve-name==1.3.10 ; python_full_version < '3.9' # via jsonschema platformdirs==4.3.6 ; python_full_version < '3.9' @@ -634,6 +658,8 @@ pycln==2.5.0 ; python_full_version < '3.9' pycln==2.6.0 ; python_full_version >= '3.9' pycparser==2.23 ; python_full_version < '3.9' or implementation_name != 'PyPy' # via cffi +pydyf==0.11.0 + # via weasyprint pygments==2.19.2 # via # ipython @@ -648,6 +674,10 @@ pyparsing==3.1.4 ; python_full_version < '3.9' # via matplotlib pyparsing==3.2.5 ; python_full_version >= '3.9' # via matplotlib +pyphen==0.16.0 ; python_full_version < '3.9' + # via weasyprint +pyphen==0.17.2 ; python_full_version >= '3.9' + # via weasyprint pytest==8.3.5 ; python_full_version < '3.9' # via pytest-cov pytest==8.4.2 ; python_full_version >= '3.9' @@ -777,6 +807,7 @@ six==1.17.0 # via # astunparse # bleach + # html5lib # python-dateutil # rfc3339-validator sniffio==1.3.1 @@ -796,9 +827,17 @@ terminado==0.18.1 # jupyter-server # jupyter-server-terminals tinycss2==1.2.1 ; python_full_version < '3.9' - # via bleach + # via + # bleach + # cssselect2 + # weasyprint tinycss2==1.4.0 ; python_full_version >= '3.9' - # via bleach + # via + # bleach + # cssselect2 + # weasyprint +tinyhtml5==2.0.0 ; python_full_version >= '3.9' + # via weasyprint tomli==2.3.0 ; python_full_version <= '3.11' # via # coverage @@ -914,6 +953,10 @@ wcmatch==10.1 ; python_full_version >= '3.11' # via igittigitt wcwidth==0.2.14 # via prompt-toolkit +weasyprint==61.2 ; python_full_version < '3.9' + # via muutils +weasyprint==66.0 ; python_full_version >= '3.9' + # via muutils webcolors==24.8.0 ; python_full_version < '3.9' # via jsonschema webcolors==24.11.1 ; python_full_version >= '3.9' @@ -921,7 +964,10 @@ webcolors==24.11.1 ; python_full_version >= '3.9' webencodings==0.5.1 # via # bleach + # cssselect2 + # html5lib # tinycss2 + # tinyhtml5 
websocket-client==1.8.0 ; python_full_version < '3.9' # via jupyter-server websocket-client==1.9.0 ; python_full_version >= '3.9' @@ -938,3 +984,5 @@ zipp==3.23.0 ; (python_full_version >= '3.9' and python_full_version < '3.12' an # via # importlib-metadata # importlib-resources +zopfli==0.2.3.post1 + # via fonttools diff --git a/.meta/requirements/requirements-extras.txt b/.meta/requirements/requirements-extras.txt index 8eeba0b1..acc63cd8 100644 --- a/.meta/requirements/requirements-extras.txt +++ b/.meta/requirements/requirements-extras.txt @@ -1,5 +1,5 @@ # This file was autogenerated by uv via the following command: -# uv export --no-hashes --no-group dev --no-group lint --extra array --extra array_no_torch --extra notebook --extra parallel +# uv export --no-hashes --no-group dev --no-group lint --extra array --extra array_no_torch --extra notebook --extra parallel --extra web -e . appnope==0.1.4 ; python_full_version < '3.9' and sys_platform == 'darwin' # via ipython @@ -7,10 +7,26 @@ asttokens==3.0.0 # via stack-data backcall==0.2.0 ; python_full_version < '3.9' # via ipython +brotli==1.1.0 ; platform_python_implementation == 'CPython' + # via fonttools +brotlicffi==1.1.0.0 ; platform_python_implementation != 'CPython' + # via fonttools +cffi==1.17.1 ; python_full_version < '3.9' + # via + # brotlicffi + # weasyprint +cffi==2.0.0 ; python_full_version >= '3.9' + # via + # brotlicffi + # weasyprint colorama==0.4.6 ; sys_platform == 'win32' # via # ipython # tqdm +cssselect2==0.7.0 ; python_full_version < '3.9' + # via weasyprint +cssselect2==0.8.0 ; python_full_version >= '3.9' + # via weasyprint decorator==5.2.1 # via ipython dill==0.4.0 @@ -27,10 +43,16 @@ filelock==3.19.1 ; python_full_version == '3.9.*' # via torch filelock==3.20.0 ; python_full_version >= '3.10' and python_full_version < '3.14' # via torch +fonttools==4.57.0 ; python_full_version < '3.9' + # via weasyprint +fonttools==4.60.1 ; python_full_version >= '3.9' + # via weasyprint fsspec==2025.3.0 ; python_full_version < '3.9' # via torch fsspec==2025.9.0 ; python_full_version >= '3.9' and python_full_version < '3.14' # via torch +html5lib==1.1 ; python_full_version < '3.9' + # via weasyprint importlib-metadata==8.5.0 ; python_full_version < '3.9' # via typeguard importlib-metadata==8.7.0 ; python_full_version == '3.9.*' and platform_machine == 'x86_64' and sys_platform == 'linux' @@ -159,26 +181,52 @@ pexpect==4.9.0 ; (python_full_version < '3.10' and sys_platform == 'emscripten') # via ipython pickleshare==0.7.5 ; python_full_version < '3.9' # via ipython +pillow==10.4.0 ; python_full_version < '3.9' + # via weasyprint +pillow==11.3.0 ; python_full_version == '3.9.*' + # via weasyprint +pillow==12.0.0 ; python_full_version >= '3.10' + # via weasyprint prompt-toolkit==3.0.52 # via ipython ptyprocess==0.7.0 ; (python_full_version < '3.10' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32') # via pexpect pure-eval==0.2.3 # via stack-data +pycparser==2.23 ; python_full_version < '3.9' or implementation_name != 'PyPy' + # via cffi +pydyf==0.11.0 + # via weasyprint pygments==2.19.2 # via # ipython # ipython-pygments-lexers +pyphen==0.16.0 ; python_full_version < '3.9' + # via weasyprint +pyphen==0.17.2 ; python_full_version >= '3.9' + # via weasyprint setuptools==80.9.0 ; (python_full_version >= '3.12' and python_full_version < '3.14') or (python_full_version == '3.9.*' and platform_machine == 'x86_64' and sys_platform == 'linux') # via # torch # triton +six==1.17.0 ; 
python_full_version < '3.9' + # via html5lib stack-data==0.6.3 # via ipython sympy==1.13.3 ; python_full_version < '3.9' # via torch sympy==1.14.0 ; python_full_version >= '3.9' and python_full_version < '3.14' # via torch +tinycss2==1.2.1 ; python_full_version < '3.9' + # via + # cssselect2 + # weasyprint +tinycss2==1.4.0 ; python_full_version >= '3.9' + # via + # cssselect2 + # weasyprint +tinyhtml5==2.0.0 ; python_full_version >= '3.9' + # via weasyprint torch==2.4.1 ; python_full_version < '3.9' # via muutils torch==2.8.0 ; python_full_version == '3.9.*' @@ -216,7 +264,19 @@ wadler-lindig==0.1.7 ; python_full_version >= '3.10' # via jaxtyping wcwidth==0.2.14 # via prompt-toolkit +weasyprint==61.2 ; python_full_version < '3.9' + # via muutils +weasyprint==66.0 ; python_full_version >= '3.9' + # via muutils +webencodings==0.5.1 + # via + # cssselect2 + # html5lib + # tinycss2 + # tinyhtml5 zipp==3.20.2 ; python_full_version < '3.9' # via importlib-metadata zipp==3.23.0 ; python_full_version == '3.9.*' and platform_machine == 'x86_64' and sys_platform == 'linux' # via importlib-metadata +zopfli==0.2.3.post1 + # via fonttools diff --git a/.meta/requirements/requirements.txt b/.meta/requirements/requirements.txt index ae10d07c..d486ae25 100644 --- a/.meta/requirements/requirements.txt +++ b/.meta/requirements/requirements.txt @@ -1,5 +1,5 @@ # This file was autogenerated by uv via the following command: -# uv export --no-hashes --group dev --group lint --extra array --extra array_no_torch --extra notebook --extra parallel +# uv export --no-hashes --group dev --group lint --extra array --extra array_no_torch --extra notebook --extra parallel --extra web -e . # via lmcat anyio==4.5.2 ; python_full_version < '3.9' @@ -60,6 +60,10 @@ bleach==6.3.0 ; python_full_version >= '3.10' # via nbconvert bracex==2.6 ; python_full_version >= '3.11' # via wcmatch +brotli==1.1.0 ; platform_python_implementation == 'CPython' + # via fonttools +brotlicffi==1.1.0.0 ; platform_python_implementation != 'CPython' + # via fonttools certifi==2025.10.5 # via # httpcore @@ -68,13 +72,17 @@ certifi==2025.10.5 cffi==1.17.1 ; python_full_version < '3.9' # via # argon2-cffi-bindings + # brotlicffi # cryptography # pyzmq + # weasyprint cffi==2.0.0 ; python_full_version >= '3.9' # via # argon2-cffi-bindings + # brotlicffi # cryptography # pyzmq + # weasyprint charset-normalizer==3.4.4 # via requests cli-exit-tools==1.2.7 ; python_full_version >= '3.11' @@ -119,6 +127,10 @@ coverage==7.11.0 ; python_full_version >= '3.10' coverage-badge==1.1.2 cryptography==46.0.3 ; platform_machine != 'ppc64le' and platform_machine != 's390x' and sys_platform == 'linux' # via secretstorage +cssselect2==0.7.0 ; python_full_version < '3.9' + # via weasyprint +cssselect2==0.8.0 ; python_full_version >= '3.9' + # via weasyprint cycler==0.12.1 # via matplotlib debugpy==1.8.17 @@ -151,9 +163,13 @@ filelock==3.19.1 ; python_full_version == '3.9.*' filelock==3.20.0 ; python_full_version >= '3.10' and python_full_version < '3.14' # via torch fonttools==4.57.0 ; python_full_version < '3.9' - # via matplotlib + # via + # matplotlib + # weasyprint fonttools==4.60.1 ; python_full_version >= '3.9' - # via matplotlib + # via + # matplotlib + # weasyprint fqdn==1.5.1 # via jsonschema fsspec==2025.3.0 ; python_full_version < '3.9' @@ -162,6 +178,8 @@ fsspec==2025.9.0 ; python_full_version >= '3.9' and python_full_version < '3.14' # via torch h11==0.16.0 # via httpcore +html5lib==1.1 ; python_full_version < '3.9' + # via weasyprint httpcore==1.0.9 # via 
httpx httpx==0.28.1 @@ -594,11 +612,17 @@ pexpect==4.9.0 ; (python_full_version < '3.10' and sys_platform == 'emscripten') pickleshare==0.7.5 ; python_full_version < '3.9' # via ipython pillow==10.4.0 ; python_full_version < '3.9' - # via matplotlib + # via + # matplotlib + # weasyprint pillow==11.3.0 ; python_full_version == '3.9.*' - # via matplotlib + # via + # matplotlib + # weasyprint pillow==12.0.0 ; python_full_version >= '3.10' - # via matplotlib + # via + # matplotlib + # weasyprint pkgutil-resolve-name==1.3.10 ; python_full_version < '3.9' # via jsonschema platformdirs==4.3.6 ; python_full_version < '3.9' @@ -634,6 +658,8 @@ pycln==2.5.0 ; python_full_version < '3.9' pycln==2.6.0 ; python_full_version >= '3.9' pycparser==2.23 ; python_full_version < '3.9' or implementation_name != 'PyPy' # via cffi +pydyf==0.11.0 + # via weasyprint pygments==2.19.2 # via # ipython @@ -648,6 +674,10 @@ pyparsing==3.1.4 ; python_full_version < '3.9' # via matplotlib pyparsing==3.2.5 ; python_full_version >= '3.9' # via matplotlib +pyphen==0.16.0 ; python_full_version < '3.9' + # via weasyprint +pyphen==0.17.2 ; python_full_version >= '3.9' + # via weasyprint pytest==8.3.5 ; python_full_version < '3.9' # via pytest-cov pytest==8.4.2 ; python_full_version >= '3.9' @@ -777,6 +807,7 @@ six==1.17.0 # via # astunparse # bleach + # html5lib # python-dateutil # rfc3339-validator sniffio==1.3.1 @@ -796,9 +827,17 @@ terminado==0.18.1 # jupyter-server # jupyter-server-terminals tinycss2==1.2.1 ; python_full_version < '3.9' - # via bleach + # via + # bleach + # cssselect2 + # weasyprint tinycss2==1.4.0 ; python_full_version >= '3.9' - # via bleach + # via + # bleach + # cssselect2 + # weasyprint +tinyhtml5==2.0.0 ; python_full_version >= '3.9' + # via weasyprint tomli==2.3.0 ; python_full_version <= '3.11' # via # coverage @@ -914,6 +953,10 @@ wcmatch==10.1 ; python_full_version >= '3.11' # via igittigitt wcwidth==0.2.14 # via prompt-toolkit +weasyprint==61.2 ; python_full_version < '3.9' + # via muutils +weasyprint==66.0 ; python_full_version >= '3.9' + # via muutils webcolors==24.8.0 ; python_full_version < '3.9' # via jsonschema webcolors==24.11.1 ; python_full_version >= '3.9' @@ -921,7 +964,10 @@ webcolors==24.11.1 ; python_full_version >= '3.9' webencodings==0.5.1 # via # bleach + # cssselect2 + # html5lib # tinycss2 + # tinyhtml5 websocket-client==1.8.0 ; python_full_version < '3.9' # via jupyter-server websocket-client==1.9.0 ; python_full_version >= '3.9' @@ -938,3 +984,5 @@ zipp==3.23.0 ; (python_full_version >= '3.9' and python_full_version < '3.12' an # via # importlib-metadata # importlib-resources +zopfli==0.2.3.post1 + # via fonttools diff --git a/muutils/web/html_to_pdf.py b/muutils/web/html_to_pdf.py index 2728c92c..905910d9 100644 --- a/muutils/web/html_to_pdf.py +++ b/muutils/web/html_to_pdf.py @@ -1,7 +1,7 @@ from pathlib import Path import subprocess -from weasyprint import HTML as WeasyHTML +from weasyprint import HTML as WeasyHTML # type: ignore[import-not-found] def html_to_pdf(src: Path, dst: Path) -> None: diff --git a/pyproject.toml b/pyproject.toml index a6f983ab..e27daa60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,10 @@ "tqdm>=4.67.1", ] + web = [ + "weasyprint>=60.0", + ] + [dependency-groups] dev = [ # typing diff --git a/uv.lock b/uv.lock index c0c0a1ac..82ba38dc 100644 --- a/uv.lock +++ b/uv.lock @@ -419,6 +419,141 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/9d/2a/9186535ce58db529927f6cf5990a849aa9e052eea3e2cfefe20b9e1802da/bracex-2.6-py3-none-any.whl", hash = "sha256:0b0049264e7340b3ec782b5cb99beb325f36c3782a32e36e876452fd49a09952", size = 11508, upload-time = "2025-06-22T19:12:29.781Z" }, ] +[[package]] +name = "brotli" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2f/c2/f9e977608bdf958650638c3f1e28f85a1b075f075ebbe77db8555463787b/Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724", size = 7372270, upload-time = "2023-09-07T14:05:41.643Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6d/3a/dbf4fb970c1019a57b5e492e1e0eae745d32e59ba4d6161ab5422b08eefe/Brotli-1.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e1140c64812cb9b06c922e77f1c26a75ec5e3f0fb2bf92cc8c58720dec276752", size = 873045, upload-time = "2023-09-07T14:03:16.894Z" }, + { url = "https://files.pythonhosted.org/packages/dd/11/afc14026ea7f44bd6eb9316d800d439d092c8d508752055ce8d03086079a/Brotli-1.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c8fd5270e906eef71d4a8d19b7c6a43760c6abcfcc10c9101d14eb2357418de9", size = 446218, upload-time = "2023-09-07T14:03:18.917Z" }, + { url = "https://files.pythonhosted.org/packages/36/83/7545a6e7729db43cb36c4287ae388d6885c85a86dd251768a47015dfde32/Brotli-1.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ae56aca0402a0f9a3431cddda62ad71666ca9d4dc3a10a142b9dce2e3c0cda3", size = 2903872, upload-time = "2023-09-07T14:03:20.398Z" }, + { url = "https://files.pythonhosted.org/packages/32/23/35331c4d9391fcc0f29fd9bec2c76e4b4eeab769afbc4b11dd2e1098fb13/Brotli-1.1.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43ce1b9935bfa1ede40028054d7f48b5469cd02733a365eec8a329ffd342915d", size = 2941254, upload-time = "2023-09-07T14:03:21.914Z" }, + { url = "https://files.pythonhosted.org/packages/3b/24/1671acb450c902edb64bd765d73603797c6c7280a9ada85a195f6b78c6e5/Brotli-1.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:7c4855522edb2e6ae7fdb58e07c3ba9111e7621a8956f481c68d5d979c93032e", size = 2857293, upload-time = "2023-09-07T14:03:24Z" }, + { url = "https://files.pythonhosted.org/packages/d5/00/40f760cc27007912b327fe15bf6bfd8eaecbe451687f72a8abc587d503b3/Brotli-1.1.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:38025d9f30cf4634f8309c6874ef871b841eb3c347e90b0851f63d1ded5212da", size = 3002385, upload-time = "2023-09-07T14:03:26.248Z" }, + { url = "https://files.pythonhosted.org/packages/b8/cb/8aaa83f7a4caa131757668c0fb0c4b6384b09ffa77f2fba9570d87ab587d/Brotli-1.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e6a904cb26bfefc2f0a6f240bdf5233be78cd2488900a2f846f3c3ac8489ab80", size = 2911104, upload-time = "2023-09-07T14:03:27.849Z" }, + { url = "https://files.pythonhosted.org/packages/bc/c4/65456561d89d3c49f46b7fbeb8fe6e449f13bdc8ea7791832c5d476b2faf/Brotli-1.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a37b8f0391212d29b3a91a799c8e4a2855e0576911cdfb2515487e30e322253d", size = 2809981, upload-time = "2023-09-07T14:03:29.92Z" }, + { url = "https://files.pythonhosted.org/packages/05/1b/cf49528437bae28abce5f6e059f0d0be6fecdcc1d3e33e7c54b3ca498425/Brotli-1.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = 
"sha256:e84799f09591700a4154154cab9787452925578841a94321d5ee8fb9a9a328f0", size = 2935297, upload-time = "2023-09-07T14:03:32.035Z" }, + { url = "https://files.pythonhosted.org/packages/81/ff/190d4af610680bf0c5a09eb5d1eac6e99c7c8e216440f9c7cfd42b7adab5/Brotli-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f66b5337fa213f1da0d9000bc8dc0cb5b896b726eefd9c6046f699b169c41b9e", size = 2930735, upload-time = "2023-09-07T14:03:33.801Z" }, + { url = "https://files.pythonhosted.org/packages/80/7d/f1abbc0c98f6e09abd3cad63ec34af17abc4c44f308a7a539010f79aae7a/Brotli-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5dab0844f2cf82be357a0eb11a9087f70c5430b2c241493fc122bb6f2bb0917c", size = 2933107, upload-time = "2024-10-18T12:32:09.016Z" }, + { url = "https://files.pythonhosted.org/packages/34/ce/5a5020ba48f2b5a4ad1c0522d095ad5847a0be508e7d7569c8630ce25062/Brotli-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e4fe605b917c70283db7dfe5ada75e04561479075761a0b3866c081d035b01c1", size = 2845400, upload-time = "2024-10-18T12:32:11.134Z" }, + { url = "https://files.pythonhosted.org/packages/44/89/fa2c4355ab1eecf3994e5a0a7f5492c6ff81dfcb5f9ba7859bd534bb5c1a/Brotli-1.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1e9a65b5736232e7a7f91ff3d02277f11d339bf34099a56cdab6a8b3410a02b2", size = 3031985, upload-time = "2024-10-18T12:32:12.813Z" }, + { url = "https://files.pythonhosted.org/packages/af/a4/79196b4a1674143d19dca400866b1a4d1a089040df7b93b88ebae81f3447/Brotli-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:58d4b711689366d4a03ac7957ab8c28890415e267f9b6589969e74b6e42225ec", size = 2927099, upload-time = "2024-10-18T12:32:14.733Z" }, + { url = "https://files.pythonhosted.org/packages/e9/54/1c0278556a097f9651e657b873ab08f01b9a9ae4cac128ceb66427d7cd20/Brotli-1.1.0-cp310-cp310-win32.whl", hash = "sha256:be36e3d172dc816333f33520154d708a2657ea63762ec16b62ece02ab5e4daf2", size = 333172, upload-time = "2023-09-07T14:03:35.212Z" }, + { url = "https://files.pythonhosted.org/packages/f7/65/b785722e941193fd8b571afd9edbec2a9b838ddec4375d8af33a50b8dab9/Brotli-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0c6244521dda65ea562d5a69b9a26120769b7a9fb3db2fe9545935ed6735b128", size = 357255, upload-time = "2023-09-07T14:03:36.447Z" }, + { url = "https://files.pythonhosted.org/packages/96/12/ad41e7fadd5db55459c4c401842b47f7fee51068f86dd2894dd0dcfc2d2a/Brotli-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a3daabb76a78f829cafc365531c972016e4aa8d5b4bf60660ad8ecee19df7ccc", size = 873068, upload-time = "2023-09-07T14:03:37.779Z" }, + { url = "https://files.pythonhosted.org/packages/95/4e/5afab7b2b4b61a84e9c75b17814198ce515343a44e2ed4488fac314cd0a9/Brotli-1.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c8146669223164fc87a7e3de9f81e9423c67a79d6b3447994dfb9c95da16e2d6", size = 446244, upload-time = "2023-09-07T14:03:39.223Z" }, + { url = "https://files.pythonhosted.org/packages/9d/e6/f305eb61fb9a8580c525478a4a34c5ae1a9bcb12c3aee619114940bc513d/Brotli-1.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30924eb4c57903d5a7526b08ef4a584acc22ab1ffa085faceb521521d2de32dd", size = 2906500, upload-time = "2023-09-07T14:03:40.858Z" }, + { url = "https://files.pythonhosted.org/packages/3e/4f/af6846cfbc1550a3024e5d3775ede1e00474c40882c7bf5b37a43ca35e91/Brotli-1.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ceb64bbc6eac5a140ca649003756940f8d6a7c444a68af170b3187623b43bebf", size = 2943950, upload-time = 
"2023-09-07T14:03:42.896Z" }, + { url = "https://files.pythonhosted.org/packages/b3/e7/ca2993c7682d8629b62630ebf0d1f3bb3d579e667ce8e7ca03a0a0576a2d/Brotli-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a469274ad18dc0e4d316eefa616d1d0c2ff9da369af19fa6f3daa4f09671fd61", size = 2918527, upload-time = "2023-09-07T14:03:44.552Z" }, + { url = "https://files.pythonhosted.org/packages/b3/96/da98e7bedc4c51104d29cc61e5f449a502dd3dbc211944546a4cc65500d3/Brotli-1.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:524f35912131cc2cabb00edfd8d573b07f2d9f21fa824bd3fb19725a9cf06327", size = 2845489, upload-time = "2023-09-07T14:03:46.594Z" }, + { url = "https://files.pythonhosted.org/packages/e8/ef/ccbc16947d6ce943a7f57e1a40596c75859eeb6d279c6994eddd69615265/Brotli-1.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:5b3cc074004d968722f51e550b41a27be656ec48f8afaeeb45ebf65b561481dd", size = 2914080, upload-time = "2023-09-07T14:03:48.204Z" }, + { url = "https://files.pythonhosted.org/packages/80/d6/0bd38d758d1afa62a5524172f0b18626bb2392d717ff94806f741fcd5ee9/Brotli-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:19c116e796420b0cee3da1ccec3b764ed2952ccfcc298b55a10e5610ad7885f9", size = 2813051, upload-time = "2023-09-07T14:03:50.348Z" }, + { url = "https://files.pythonhosted.org/packages/14/56/48859dd5d129d7519e001f06dcfbb6e2cf6db92b2702c0c2ce7d97e086c1/Brotli-1.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:510b5b1bfbe20e1a7b3baf5fed9e9451873559a976c1a78eebaa3b86c57b4265", size = 2938172, upload-time = "2023-09-07T14:03:52.395Z" }, + { url = "https://files.pythonhosted.org/packages/3d/77/a236d5f8cd9e9f4348da5acc75ab032ab1ab2c03cc8f430d24eea2672888/Brotli-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a1fd8a29719ccce974d523580987b7f8229aeace506952fa9ce1d53a033873c8", size = 2933023, upload-time = "2023-09-07T14:03:53.96Z" }, + { url = "https://files.pythonhosted.org/packages/f1/87/3b283efc0f5cb35f7f84c0c240b1e1a1003a5e47141a4881bf87c86d0ce2/Brotli-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c247dd99d39e0338a604f8c2b3bc7061d5c2e9e2ac7ba9cc1be5a69cb6cd832f", size = 2935871, upload-time = "2024-10-18T12:32:16.688Z" }, + { url = "https://files.pythonhosted.org/packages/f3/eb/2be4cc3e2141dc1a43ad4ca1875a72088229de38c68e842746b342667b2a/Brotli-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1b2c248cd517c222d89e74669a4adfa5577e06ab68771a529060cf5a156e9757", size = 2847784, upload-time = "2024-10-18T12:32:18.459Z" }, + { url = "https://files.pythonhosted.org/packages/66/13/b58ddebfd35edde572ccefe6890cf7c493f0c319aad2a5badee134b4d8ec/Brotli-1.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2a24c50840d89ded6c9a8fdc7b6ed3692ed4e86f1c4a4a938e1e92def92933e0", size = 3034905, upload-time = "2024-10-18T12:32:20.192Z" }, + { url = "https://files.pythonhosted.org/packages/84/9c/bc96b6c7db824998a49ed3b38e441a2cae9234da6fa11f6ed17e8cf4f147/Brotli-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f31859074d57b4639318523d6ffdca586ace54271a73ad23ad021acd807eb14b", size = 2929467, upload-time = "2024-10-18T12:32:21.774Z" }, + { url = "https://files.pythonhosted.org/packages/e7/71/8f161dee223c7ff7fea9d44893fba953ce97cf2c3c33f78ba260a91bcff5/Brotli-1.1.0-cp311-cp311-win32.whl", hash = "sha256:39da8adedf6942d76dc3e46653e52df937a3c4d6d18fdc94a7c29d263b1f5b50", size = 333169, upload-time = "2023-09-07T14:03:55.404Z" }, + { url = 
"https://files.pythonhosted.org/packages/02/8a/fece0ee1057643cb2a5bbf59682de13f1725f8482b2c057d4e799d7ade75/Brotli-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:aac0411d20e345dc0920bdec5548e438e999ff68d77564d5e9463a7ca9d3e7b1", size = 357253, upload-time = "2023-09-07T14:03:56.643Z" }, + { url = "https://files.pythonhosted.org/packages/5c/d0/5373ae13b93fe00095a58efcbce837fd470ca39f703a235d2a999baadfbc/Brotli-1.1.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:32d95b80260d79926f5fab3c41701dbb818fde1c9da590e77e571eefd14abe28", size = 815693, upload-time = "2024-10-18T12:32:23.824Z" }, + { url = "https://files.pythonhosted.org/packages/8e/48/f6e1cdf86751300c288c1459724bfa6917a80e30dbfc326f92cea5d3683a/Brotli-1.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b760c65308ff1e462f65d69c12e4ae085cff3b332d894637f6273a12a482d09f", size = 422489, upload-time = "2024-10-18T12:32:25.641Z" }, + { url = "https://files.pythonhosted.org/packages/06/88/564958cedce636d0f1bed313381dfc4b4e3d3f6015a63dae6146e1b8c65c/Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409", size = 873081, upload-time = "2023-09-07T14:03:57.967Z" }, + { url = "https://files.pythonhosted.org/packages/58/79/b7026a8bb65da9a6bb7d14329fd2bd48d2b7f86d7329d5cc8ddc6a90526f/Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2", size = 446244, upload-time = "2023-09-07T14:03:59.319Z" }, + { url = "https://files.pythonhosted.org/packages/e5/18/c18c32ecea41b6c0004e15606e274006366fe19436b6adccc1ae7b2e50c2/Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451", size = 2906505, upload-time = "2023-09-07T14:04:01.327Z" }, + { url = "https://files.pythonhosted.org/packages/08/c8/69ec0496b1ada7569b62d85893d928e865df29b90736558d6c98c2031208/Brotli-1.1.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7f4bf76817c14aa98cc6697ac02f3972cb8c3da93e9ef16b9c66573a68014f91", size = 2944152, upload-time = "2023-09-07T14:04:03.033Z" }, + { url = "https://files.pythonhosted.org/packages/ab/fb/0517cea182219d6768113a38167ef6d4eb157a033178cc938033a552ed6d/Brotli-1.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d0c5516f0aed654134a2fc936325cc2e642f8a0e096d075209672eb321cff408", size = 2919252, upload-time = "2023-09-07T14:04:04.675Z" }, + { url = "https://files.pythonhosted.org/packages/c7/53/73a3431662e33ae61a5c80b1b9d2d18f58dfa910ae8dd696e57d39f1a2f5/Brotli-1.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6c3020404e0b5eefd7c9485ccf8393cfb75ec38ce75586e046573c9dc29967a0", size = 2845955, upload-time = "2023-09-07T14:04:06.585Z" }, + { url = "https://files.pythonhosted.org/packages/55/ac/bd280708d9c5ebdbf9de01459e625a3e3803cce0784f47d633562cf40e83/Brotli-1.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4ed11165dd45ce798d99a136808a794a748d5dc38511303239d4e2363c0695dc", size = 2914304, upload-time = "2023-09-07T14:04:08.668Z" }, + { url = "https://files.pythonhosted.org/packages/76/58/5c391b41ecfc4527d2cc3350719b02e87cb424ef8ba2023fb662f9bf743c/Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180", size = 2814452, upload-time = "2023-09-07T14:04:10.736Z" }, + { url = 
"https://files.pythonhosted.org/packages/c7/4e/91b8256dfe99c407f174924b65a01f5305e303f486cc7a2e8a5d43c8bec3/Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248", size = 2938751, upload-time = "2023-09-07T14:04:12.875Z" }, + { url = "https://files.pythonhosted.org/packages/5a/a6/e2a39a5d3b412938362bbbeba5af904092bf3f95b867b4a3eb856104074e/Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966", size = 2933757, upload-time = "2023-09-07T14:04:14.551Z" }, + { url = "https://files.pythonhosted.org/packages/13/f0/358354786280a509482e0e77c1a5459e439766597d280f28cb097642fc26/Brotli-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:87a3044c3a35055527ac75e419dfa9f4f3667a1e887ee80360589eb8c90aabb9", size = 2936146, upload-time = "2024-10-18T12:32:27.257Z" }, + { url = "https://files.pythonhosted.org/packages/80/f7/daf538c1060d3a88266b80ecc1d1c98b79553b3f117a485653f17070ea2a/Brotli-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c5529b34c1c9d937168297f2c1fde7ebe9ebdd5e121297ff9c043bdb2ae3d6fb", size = 2848055, upload-time = "2024-10-18T12:32:29.376Z" }, + { url = "https://files.pythonhosted.org/packages/ad/cf/0eaa0585c4077d3c2d1edf322d8e97aabf317941d3a72d7b3ad8bce004b0/Brotli-1.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ca63e1890ede90b2e4454f9a65135a4d387a4585ff8282bb72964fab893f2111", size = 3035102, upload-time = "2024-10-18T12:32:31.371Z" }, + { url = "https://files.pythonhosted.org/packages/d8/63/1c1585b2aa554fe6dbce30f0c18bdbc877fa9a1bf5ff17677d9cca0ac122/Brotli-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e79e6520141d792237c70bcd7a3b122d00f2613769ae0cb61c52e89fd3443839", size = 2930029, upload-time = "2024-10-18T12:32:33.293Z" }, + { url = "https://files.pythonhosted.org/packages/5f/3b/4e3fd1893eb3bbfef8e5a80d4508bec17a57bb92d586c85c12d28666bb13/Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0", size = 333276, upload-time = "2023-09-07T14:04:16.49Z" }, + { url = "https://files.pythonhosted.org/packages/3d/d5/942051b45a9e883b5b6e98c041698b1eb2012d25e5948c58d6bf85b1bb43/Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951", size = 357255, upload-time = "2023-09-07T14:04:17.83Z" }, + { url = "https://files.pythonhosted.org/packages/0a/9f/fb37bb8ffc52a8da37b1c03c459a8cd55df7a57bdccd8831d500e994a0ca/Brotli-1.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8bf32b98b75c13ec7cf774164172683d6e7891088f6316e54425fde1efc276d5", size = 815681, upload-time = "2024-10-18T12:32:34.942Z" }, + { url = "https://files.pythonhosted.org/packages/06/b3/dbd332a988586fefb0aa49c779f59f47cae76855c2d00f450364bb574cac/Brotli-1.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7bc37c4d6b87fb1017ea28c9508b36bbcb0c3d18b4260fcdf08b200c74a6aee8", size = 422475, upload-time = "2024-10-18T12:32:36.485Z" }, + { url = "https://files.pythonhosted.org/packages/bb/80/6aaddc2f63dbcf2d93c2d204e49c11a9ec93a8c7c63261e2b4bd35198283/Brotli-1.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c0ef38c7a7014ffac184db9e04debe495d317cc9c6fb10071f7fefd93100a4f", size = 2906173, upload-time = "2024-10-18T12:32:37.978Z" }, + { url = 
"https://files.pythonhosted.org/packages/ea/1d/e6ca79c96ff5b641df6097d299347507d39a9604bde8915e76bf026d6c77/Brotli-1.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91d7cc2a76b5567591d12c01f019dd7afce6ba8cba6571187e21e2fc418ae648", size = 2943803, upload-time = "2024-10-18T12:32:39.606Z" }, + { url = "https://files.pythonhosted.org/packages/ac/a3/d98d2472e0130b7dd3acdbb7f390d478123dbf62b7d32bda5c830a96116d/Brotli-1.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a93dde851926f4f2678e704fadeb39e16c35d8baebd5252c9fd94ce8ce68c4a0", size = 2918946, upload-time = "2024-10-18T12:32:41.679Z" }, + { url = "https://files.pythonhosted.org/packages/c4/a5/c69e6d272aee3e1423ed005d8915a7eaa0384c7de503da987f2d224d0721/Brotli-1.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0db75f47be8b8abc8d9e31bc7aad0547ca26f24a54e6fd10231d623f183d089", size = 2845707, upload-time = "2024-10-18T12:32:43.478Z" }, + { url = "https://files.pythonhosted.org/packages/58/9f/4149d38b52725afa39067350696c09526de0125ebfbaab5acc5af28b42ea/Brotli-1.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6967ced6730aed543b8673008b5a391c3b1076d834ca438bbd70635c73775368", size = 2936231, upload-time = "2024-10-18T12:32:45.224Z" }, + { url = "https://files.pythonhosted.org/packages/5a/5a/145de884285611838a16bebfdb060c231c52b8f84dfbe52b852a15780386/Brotli-1.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7eedaa5d036d9336c95915035fb57422054014ebdeb6f3b42eac809928e40d0c", size = 2848157, upload-time = "2024-10-18T12:32:46.894Z" }, + { url = "https://files.pythonhosted.org/packages/50/ae/408b6bfb8525dadebd3b3dd5b19d631da4f7d46420321db44cd99dcf2f2c/Brotli-1.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d487f5432bf35b60ed625d7e1b448e2dc855422e87469e3f450aa5552b0eb284", size = 3035122, upload-time = "2024-10-18T12:32:48.844Z" }, + { url = "https://files.pythonhosted.org/packages/af/85/a94e5cfaa0ca449d8f91c3d6f78313ebf919a0dbd55a100c711c6e9655bc/Brotli-1.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:832436e59afb93e1836081a20f324cb185836c617659b07b129141a8426973c7", size = 2930206, upload-time = "2024-10-18T12:32:51.198Z" }, + { url = "https://files.pythonhosted.org/packages/c2/f0/a61d9262cd01351df22e57ad7c34f66794709acab13f34be2675f45bf89d/Brotli-1.1.0-cp313-cp313-win32.whl", hash = "sha256:43395e90523f9c23a3d5bdf004733246fba087f2948f87ab28015f12359ca6a0", size = 333804, upload-time = "2024-10-18T12:32:52.661Z" }, + { url = "https://files.pythonhosted.org/packages/7e/c1/ec214e9c94000d1c1974ec67ced1c970c148aa6b8d8373066123fc3dbf06/Brotli-1.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:9011560a466d2eb3f5a6e4929cf4a09be405c64154e12df0dd72713f6500e32b", size = 358517, upload-time = "2024-10-18T12:32:54.066Z" }, + { url = "https://files.pythonhosted.org/packages/34/1b/16114a20c0a43c20331f03431178ed8b12280b12c531a14186da0bc5b276/Brotli-1.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:efa8b278894b14d6da122a72fefcebc28445f2d3f880ac59d46c90f4c13be9a3", size = 873053, upload-time = "2023-09-07T14:04:58.335Z" }, + { url = "https://files.pythonhosted.org/packages/36/49/2afe4aa5a23a13dad4c7160ae574668eec58b3c80b56b74a826cebff7ab8/Brotli-1.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:03d20af184290887bdea3f0f78c4f737d126c74dc2f3ccadf07e54ceca3bf208", size = 446211, upload-time = "2023-09-07T14:04:59.928Z" }, + { url = 
"https://files.pythonhosted.org/packages/10/9d/6463edb80a9e0a944f70ed0c4d41330178526626d7824f729e81f78a3f24/Brotli-1.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6172447e1b368dcbc458925e5ddaf9113477b0ed542df258d84fa28fc45ceea7", size = 2904604, upload-time = "2023-09-07T14:05:02.348Z" }, + { url = "https://files.pythonhosted.org/packages/a4/bd/cfaac88c14f97d9e1f2e51a304c3573858548bb923d011b19f76b295f81c/Brotli-1.1.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a743e5a28af5f70f9c080380a5f908d4d21d40e8f0e0c8901604d15cfa9ba751", size = 2941707, upload-time = "2023-09-07T14:05:04.639Z" }, + { url = "https://files.pythonhosted.org/packages/60/3f/2618fa887d7af6828246822f10d9927244dab22db7a96ec56041a2fd1fbd/Brotli-1.1.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0541e747cce78e24ea12d69176f6a7ddb690e62c425e01d31cc065e69ce55b48", size = 2672420, upload-time = "2023-09-07T14:05:06.709Z" }, + { url = "https://files.pythonhosted.org/packages/e7/41/1c6d15c8d5b55db2c3c249c64c352c8a1bc97f5e5c55183f5930866fc012/Brotli-1.1.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:cdbc1fc1bc0bff1cef838eafe581b55bfbffaed4ed0318b724d0b71d4d377619", size = 2757410, upload-time = "2023-09-07T14:05:09.28Z" }, + { url = "https://files.pythonhosted.org/packages/6c/5b/ca72fd8aa1278dfbb12eb320b6e409aefabcd767b85d607c9d54c9dadd1a/Brotli-1.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:890b5a14ce214389b2cc36ce82f3093f96f4cc730c1cffdbefff77a7c71f2a97", size = 2911143, upload-time = "2023-09-07T14:05:11.737Z" }, + { url = "https://files.pythonhosted.org/packages/b1/53/110657f4017d34a2e9a96d9630a388ad7e56092023f1d46d11648c6c0bce/Brotli-1.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ab4fbee0b2d9098c74f3057b2bc055a8bd92ccf02f65944a241b4349229185a", size = 2809968, upload-time = "2023-09-07T14:05:13.351Z" }, + { url = "https://files.pythonhosted.org/packages/3f/2a/fbc95429b45e4aa4a3a3a815e4af11772bfd8ef94e883dcff9ceaf556662/Brotli-1.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:141bd4d93984070e097521ed07e2575b46f817d08f9fa42b16b9b5f27b5ac088", size = 2935402, upload-time = "2023-09-07T14:05:15.039Z" }, + { url = "https://files.pythonhosted.org/packages/4e/52/02acd2992e5a2c10adf65fa920fad0c29e11e110f95eeb11bcb20342ecd2/Brotli-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fce1473f3ccc4187f75b4690cfc922628aed4d3dd013d047f95a9b3919a86596", size = 2931208, upload-time = "2023-09-07T14:05:16.747Z" }, + { url = "https://files.pythonhosted.org/packages/6b/35/5d258d1aeb407e1fc6fcbbff463af9c64d1ecc17042625f703a1e9d22ec5/Brotli-1.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d2b35ca2c7f81d173d2fadc2f4f31e88cc5f7a39ae5b6db5513cf3383b0e0ec7", size = 2933171, upload-time = "2024-10-18T12:33:10.342Z" }, + { url = "https://files.pythonhosted.org/packages/cc/58/b25ca26492da9880e517753967685903c6002ddc2aade93d6e56df817b30/Brotli-1.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:af6fa6817889314555aede9a919612b23739395ce767fe7fcbea9a80bf140fe5", size = 2845347, upload-time = "2024-10-18T12:33:12.367Z" }, + { url = "https://files.pythonhosted.org/packages/12/cf/91b84beaa051c9376a22cc38122dc6fbb63abcebd5a4b8503e9c388de7b1/Brotli-1.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:2feb1d960f760a575dbc5ab3b1c00504b24caaf6986e2dc2b01c09c87866a943", size = 3031668, upload-time = "2024-10-18T12:33:14.347Z" }, + { url = 
"https://files.pythonhosted.org/packages/38/05/04a57ba75aed972be0c6ad5f2f5ea34c83f5fecf57787cc6e54aac21a323/Brotli-1.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4410f84b33374409552ac9b6903507cdb31cd30d2501fc5ca13d18f73548444a", size = 2926949, upload-time = "2024-10-18T12:33:15.988Z" }, + { url = "https://files.pythonhosted.org/packages/c9/2f/fbe6938f33d2cd9b7d7fb591991eb3fb57ffa40416bb873bbbacab60a381/Brotli-1.1.0-cp38-cp38-win32.whl", hash = "sha256:db85ecf4e609a48f4b29055f1e144231b90edc90af7481aa731ba2d059226b1b", size = 333179, upload-time = "2023-09-07T14:05:18.343Z" }, + { url = "https://files.pythonhosted.org/packages/39/a5/9322c8436072e77b8646f6bde5e19ee66f62acf7aa01337ded10777077fa/Brotli-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3d7954194c36e304e1523f55d7042c59dc53ec20dd4e9ea9d151f1b62b4415c0", size = 357254, upload-time = "2023-09-07T14:05:19.792Z" }, + { url = "https://files.pythonhosted.org/packages/1b/aa/aa6e0c9848ee4375514af0b27abf470904992939b7363ae78fc8aca8a9a8/Brotli-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5fb2ce4b8045c78ebbc7b8f3c15062e435d47e7393cc57c25115cfd49883747a", size = 873048, upload-time = "2023-09-07T14:05:21.205Z" }, + { url = "https://files.pythonhosted.org/packages/ae/32/38bba1a8bef9ecb1cda08439fd28d7e9c51aff13b4783a4f1610da90b6c2/Brotli-1.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7905193081db9bfa73b1219140b3d315831cbff0d8941f22da695832f0dd188f", size = 446207, upload-time = "2023-09-07T14:05:23.21Z" }, + { url = "https://files.pythonhosted.org/packages/3c/6a/14cc20ddc53efc274601c8195791a27cfb7acc5e5134e0f8c493a8b8821a/Brotli-1.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a77def80806c421b4b0af06f45d65a136e7ac0bdca3c09d9e2ea4e515367c7e9", size = 2903803, upload-time = "2023-09-07T14:05:24.864Z" }, + { url = "https://files.pythonhosted.org/packages/9a/26/62b2d894d4e82d7a7f4e0bb9007a42bbc765697a5679b43186acd68d7a79/Brotli-1.1.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8dadd1314583ec0bf2d1379f7008ad627cd6336625d6679cf2f8e67081b83acf", size = 2941149, upload-time = "2023-09-07T14:05:26.479Z" }, + { url = "https://files.pythonhosted.org/packages/a9/ca/00d55bbdd8631236c61777742d8a8454cf6a87eb4125cad675912c68bec7/Brotli-1.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:901032ff242d479a0efa956d853d16875d42157f98951c0230f69e69f9c09bac", size = 2672253, upload-time = "2023-09-07T14:05:28.133Z" }, + { url = "https://files.pythonhosted.org/packages/e2/e6/4a730f6e5b5d538e92d09bc51bf69119914f29a222f9e1d65ae4abb27a4e/Brotli-1.1.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:22fc2a8549ffe699bfba2256ab2ed0421a7b8fadff114a3d201794e45a9ff578", size = 2757005, upload-time = "2023-09-07T14:05:29.812Z" }, + { url = "https://files.pythonhosted.org/packages/cb/6b/8cf297987fe3c1bf1c87f0c0b714af2ce47092b8d307b9f6ecbc65f98968/Brotli-1.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ae15b066e5ad21366600ebec29a7ccbc86812ed267e4b28e860b8ca16a2bc474", size = 2910658, upload-time = "2023-09-07T14:05:31.376Z" }, + { url = "https://files.pythonhosted.org/packages/2c/1f/be9443995821c933aad7159803f84ef4923c6f5b72c2affd001192b310fc/Brotli-1.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:949f3b7c29912693cee0afcf09acd6ebc04c57af949d9bf77d6101ebb61e388c", size = 2809728, upload-time = "2023-09-07T14:05:32.923Z" }, + { url = 
"https://files.pythonhosted.org/packages/76/2f/213bab6efa902658c80a1247142d42b138a27ccdd6bade49ca9cd74e714a/Brotli-1.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:89f4988c7203739d48c6f806f1e87a1d96e0806d44f0fba61dba81392c9e474d", size = 2935043, upload-time = "2023-09-07T14:05:34.607Z" }, + { url = "https://files.pythonhosted.org/packages/27/89/bbb14fa98e895d1e601491fba54a5feec167d262f0d3d537a3b0d4cd0029/Brotli-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:de6551e370ef19f8de1807d0a9aa2cdfdce2e85ce88b122fe9f6b2b076837e59", size = 2930639, upload-time = "2023-09-07T14:05:36.317Z" }, + { url = "https://files.pythonhosted.org/packages/14/87/03a6d6e1866eddf9f58cc57e35befbeb5514da87a416befe820150cae63f/Brotli-1.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0737ddb3068957cf1b054899b0883830bb1fec522ec76b1098f9b6e0f02d9419", size = 2932834, upload-time = "2024-10-18T12:33:18.364Z" }, + { url = "https://files.pythonhosted.org/packages/a4/d5/e5f85e04f75144d1a89421ba432def6bdffc8f28b04f5b7d540bbd03362c/Brotli-1.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4f3607b129417e111e30637af1b56f24f7a49e64763253bbc275c75fa887d4b2", size = 2845213, upload-time = "2024-10-18T12:33:20.059Z" }, + { url = "https://files.pythonhosted.org/packages/99/bf/25ef07add7afbb1aacd4460726a1a40370dfd60c0810b6f242a6d3871d7e/Brotli-1.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:6c6e0c425f22c1c719c42670d561ad682f7bfeeef918edea971a79ac5252437f", size = 3031573, upload-time = "2024-10-18T12:33:22.541Z" }, + { url = "https://files.pythonhosted.org/packages/55/22/948a97bda5c9dc9968d56b9ed722d9727778db43739cf12ef26ff69be94d/Brotli-1.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:494994f807ba0b92092a163a0a283961369a65f6cbe01e8891132b7a320e61eb", size = 2926885, upload-time = "2024-10-18T12:33:24.781Z" }, + { url = "https://files.pythonhosted.org/packages/31/ba/e53d107399b535ef89deb6977dd8eae468e2dde7b1b74c6cbe2c0e31fda2/Brotli-1.1.0-cp39-cp39-win32.whl", hash = "sha256:f0d8a7a6b5983c2496e364b969f0e526647a06b075d034f3297dc66f3b360c64", size = 333171, upload-time = "2023-09-07T14:05:38.071Z" }, + { url = "https://files.pythonhosted.org/packages/99/b3/f7b3af539f74b82e1c64d28685a5200c631cc14ae751d37d6ed819655627/Brotli-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdad5b9014d83ca68c25d2e9444e28e967ef16e80f6b436918c700c117a85467", size = 357258, upload-time = "2023-09-07T14:05:39.591Z" }, +] + +[[package]] +name = "brotlicffi" +version = "1.1.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", version = "1.17.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "cffi", version = "2.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/95/9d/70caa61192f570fcf0352766331b735afa931b4c6bc9a348a0925cc13288/brotlicffi-1.1.0.0.tar.gz", hash = "sha256:b77827a689905143f87915310b93b273ab17888fd43ef350d4832c4a71083c13", size = 465192, upload-time = "2023-09-14T14:22:40.707Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/11/7b96009d3dcc2c931e828ce1e157f03824a69fb728d06bfd7b2fc6f93718/brotlicffi-1.1.0.0-cp37-abi3-macosx_10_9_x86_64.whl", hash = "sha256:9b7ae6bd1a3f0df532b6d67ff674099a96d22bc0948955cb338488c31bfb8851", size = 453786, upload-time = "2023-09-14T14:21:57.72Z" }, + { url = 
"https://files.pythonhosted.org/packages/d6/e6/a8f46f4a4ee7856fbd6ac0c6fb0dc65ed181ba46cd77875b8d9bbe494d9e/brotlicffi-1.1.0.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19ffc919fa4fc6ace69286e0a23b3789b4219058313cf9b45625016bf7ff996b", size = 2911165, upload-time = "2023-09-14T14:21:59.613Z" }, + { url = "https://files.pythonhosted.org/packages/be/20/201559dff14e83ba345a5ec03335607e47467b6633c210607e693aefac40/brotlicffi-1.1.0.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9feb210d932ffe7798ee62e6145d3a757eb6233aa9a4e7db78dd3690d7755814", size = 2927895, upload-time = "2023-09-14T14:22:01.22Z" }, + { url = "https://files.pythonhosted.org/packages/cd/15/695b1409264143be3c933f708a3f81d53c4a1e1ebbc06f46331decbf6563/brotlicffi-1.1.0.0-cp37-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:84763dbdef5dd5c24b75597a77e1b30c66604725707565188ba54bab4f114820", size = 2851834, upload-time = "2023-09-14T14:22:03.571Z" }, + { url = "https://files.pythonhosted.org/packages/b4/40/b961a702463b6005baf952794c2e9e0099bde657d0d7e007f923883b907f/brotlicffi-1.1.0.0-cp37-abi3-win32.whl", hash = "sha256:1b12b50e07c3911e1efa3a8971543e7648100713d4e0971b13631cce22c587eb", size = 341731, upload-time = "2023-09-14T14:22:05.74Z" }, + { url = "https://files.pythonhosted.org/packages/1c/fa/5408a03c041114ceab628ce21766a4ea882aa6f6f0a800e04ee3a30ec6b9/brotlicffi-1.1.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:994a4f0681bb6c6c3b0925530a1926b7a189d878e6e5e38fae8efa47c5d9c613", size = 366783, upload-time = "2023-09-14T14:22:07.096Z" }, + { url = "https://files.pythonhosted.org/packages/e5/3b/bd4f3d2bcf2306ae66b0346f5b42af1962480b200096ffc7abc3bd130eca/brotlicffi-1.1.0.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2e4aeb0bd2540cb91b069dbdd54d458da8c4334ceaf2d25df2f4af576d6766ca", size = 397397, upload-time = "2023-09-14T14:22:08.519Z" }, + { url = "https://files.pythonhosted.org/packages/54/10/1fd57864449360852c535c2381ee7120ba8f390aa3869df967c44ca7eba1/brotlicffi-1.1.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b7b0033b0d37bb33009fb2fef73310e432e76f688af76c156b3594389d81391", size = 379698, upload-time = "2023-09-14T14:22:10.52Z" }, + { url = "https://files.pythonhosted.org/packages/e5/95/15aa422aa6450e6556e54a5fd1650ff59f470aed77ac739aa90ab63dc611/brotlicffi-1.1.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54a07bb2374a1eba8ebb52b6fafffa2afd3c4df85ddd38fcc0511f2bb387c2a8", size = 378635, upload-time = "2023-09-14T14:22:11.982Z" }, + { url = "https://files.pythonhosted.org/packages/6c/a7/f254e13b2cb43337d6d99a4ec10394c134e41bfda8a2eff15b75627f4a3d/brotlicffi-1.1.0.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7901a7dc4b88f1c1475de59ae9be59799db1007b7d059817948d8e4f12e24e35", size = 385719, upload-time = "2023-09-14T14:22:13.483Z" }, + { url = "https://files.pythonhosted.org/packages/72/a9/0971251c4427c14b2a827dba3d910d4d3330dabf23d4278bf6d06a978847/brotlicffi-1.1.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ce01c7316aebc7fce59da734286148b1d1b9455f89cf2c8a4dfce7d41db55c2d", size = 361760, upload-time = "2023-09-14T14:22:14.767Z" }, + { url = "https://files.pythonhosted.org/packages/75/ff/e227f8547f5ef11d861abae091d5dc012c2b1eb2e7358eff429fafbd608e/brotlicffi-1.1.0.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = 
"sha256:9b6068e0f3769992d6b622a1cd2e7835eae3cf8d9da123d7f51ca9c1e9c333e5", size = 397391, upload-time = "2023-09-14T14:22:23.595Z" }, + { url = "https://files.pythonhosted.org/packages/85/2d/9e8057f9c73c29090ce885fe2a133c17082ce2aa0712c533a52a5aeb042f/brotlicffi-1.1.0.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8557a8559509b61e65083f8782329188a250102372576093c88930c875a69838", size = 379693, upload-time = "2023-09-14T14:22:25.618Z" }, + { url = "https://files.pythonhosted.org/packages/50/22/62b4bf874a0be46e79bb46db4e52533f757d85107ee0cdfcc800314e865f/brotlicffi-1.1.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a7ae37e5d79c5bdfb5b4b99f2715a6035e6c5bf538c3746abc8e26694f92f33", size = 378627, upload-time = "2023-09-14T14:22:27.527Z" }, + { url = "https://files.pythonhosted.org/packages/ff/cb/648a47cd457a3afe3bacdfcd62e89fde6666be503d06403a6c2f157b7d61/brotlicffi-1.1.0.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:391151ec86bb1c683835980f4816272a87eaddc46bb91cbf44f62228b84d8cca", size = 385712, upload-time = "2023-09-14T14:22:28.835Z" }, + { url = "https://files.pythonhosted.org/packages/4b/df/d81660ba62bb54cefd6e95d5315710a8871ebf0872a4bd61b13388181742/brotlicffi-1.1.0.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:2f3711be9290f0453de8eed5275d93d286abe26b08ab4a35d7452caa1fef532f", size = 361750, upload-time = "2023-09-14T14:22:30.772Z" }, + { url = "https://files.pythonhosted.org/packages/35/9b/e0b577351e1d9d5890e1a56900c4ceaaef783b807145cd229446a43cf437/brotlicffi-1.1.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:1a807d760763e398bbf2c6394ae9da5815901aa93ee0a37bca5efe78d4ee3171", size = 397392, upload-time = "2023-09-14T14:22:32.2Z" }, + { url = "https://files.pythonhosted.org/packages/4f/7f/a16534d28386f74781db8b4544a764cf955abae336379a76f50e745bb0ee/brotlicffi-1.1.0.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fa8ca0623b26c94fccc3a1fdd895be1743b838f3917300506d04aa3346fd2a14", size = 379695, upload-time = "2023-09-14T14:22:33.85Z" }, + { url = "https://files.pythonhosted.org/packages/50/2a/699388b5e489726991132441b55aff0691dd73c49105ef220408a5ab98d6/brotlicffi-1.1.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3de0cf28a53a3238b252aca9fed1593e9d36c1d116748013339f0949bfc84112", size = 378629, upload-time = "2023-09-14T14:22:35.9Z" }, + { url = "https://files.pythonhosted.org/packages/4a/3f/58254e7fbe6011bf043e4dcade0e16995a9f82b731734fad97220d201f42/brotlicffi-1.1.0.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6be5ec0e88a4925c91f3dea2bb0013b3a2accda6f77238f76a34a1ea532a1cb0", size = 385712, upload-time = "2023-09-14T14:22:37.767Z" }, + { url = "https://files.pythonhosted.org/packages/40/16/2a29a625a6f74d13726387f83484dfaaf6fcdaafaadfbe26a0412ae268cc/brotlicffi-1.1.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:d9eb71bb1085d996244439154387266fd23d6ad37161f6f52f1cd41dd95a3808", size = 361747, upload-time = "2023-09-14T14:22:39.368Z" }, +] + [[package]] name = "certifi" version = "2025.10.5" @@ -1481,6 +1616,43 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ba/af/72cd6ef29f9c5f731251acadaeb821559fe25f10852f44a63374c9ca08c1/cryptography-46.0.3-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = 
"sha256:94cd0549accc38d1494e1f8de71eca837d0509d0d44bf11d158524b0e12cebf9", size = 4409447, upload-time = "2025-10-15T23:18:24.209Z" }, ] +[[package]] +name = "cssselect2" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.9'", +] +dependencies = [ + { name = "tinycss2", version = "1.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "webencodings", marker = "python_full_version < '3.9'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e7/fc/326cb6f988905998f09bb54a3f5d98d4462ba119363c0dfad29750d48c09/cssselect2-0.7.0.tar.gz", hash = "sha256:1ccd984dab89fc68955043aca4e1b03e0cf29cad9880f6e28e3ba7a74b14aa5a", size = 35888, upload-time = "2022-09-19T12:55:11.876Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9d/3a/e39436efe51894243ff145a37c4f9a030839b97779ebcc4f13b3ba21c54e/cssselect2-0.7.0-py3-none-any.whl", hash = "sha256:fd23a65bfd444595913f02fc71f6b286c29261e354c41d722ca7a261a49b5969", size = 15586, upload-time = "2022-09-19T12:55:07.56Z" }, +] + +[[package]] +name = "cssselect2" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14'", + "python_full_version == '3.13.*'", + "python_full_version == '3.12.*'", + "python_full_version == '3.11.*'", + "python_full_version == '3.10.*'", + "python_full_version == '3.9.*'", +] +dependencies = [ + { name = "tinycss2", version = "1.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "webencodings", marker = "python_full_version >= '3.9'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9f/86/fd7f58fc498b3166f3a7e8e0cddb6e620fe1da35b02248b1bd59e95dbaaa/cssselect2-0.8.0.tar.gz", hash = "sha256:7674ffb954a3b46162392aee2a3a0aedb2e14ecf99fcc28644900f4e6e3e9d3a", size = 35716, upload-time = "2025-03-05T14:46:07.988Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0f/e7/aa315e6a749d9b96c2504a1ba0ba031ba2d0517e972ce22682e3fccecb09/cssselect2-0.8.0-py3-none-any.whl", hash = "sha256:46fc70ebc41ced7a32cd42d58b1884d72ade23d21e5a4eaaf022401c13f0e76e", size = 15454, upload-time = "2025-03-05T14:46:06.463Z" }, +] + [[package]] name = "cycler" version = "0.12.1" @@ -1713,6 +1885,13 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/90/27/45f8957c3132917f91aaa56b700bcfc2396be1253f685bd5c68529b6f610/fonttools-4.57.0-py3-none-any.whl", hash = "sha256:3122c604a675513c68bd24c6a8f9091f1c2376d18e8f5fe5a101746c81b3e98f", size = 1093605, upload-time = "2025-04-03T11:07:11.341Z" }, ] +[package.optional-dependencies] +woff = [ + { name = "brotli", marker = "python_full_version < '3.9' and platform_python_implementation == 'CPython'" }, + { name = "brotlicffi", marker = "python_full_version < '3.9' and platform_python_implementation != 'CPython'" }, + { name = "zopfli", marker = "python_full_version < '3.9'" }, +] + [[package]] name = "fonttools" version = "4.60.1" @@ -1786,6 +1965,13 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/93/0dd45cd283c32dea1545151d8c3637b4b8c53cdb3a625aeb2885b184d74d/fonttools-4.60.1-py3-none-any.whl", hash = "sha256:906306ac7afe2156fcf0042173d6ebbb05416af70f6b370967b47f8f00103bbb", size = 1143175, upload-time = "2025-09-29T21:13:24.134Z" }, ] +[package.optional-dependencies] +woff = [ + { name = "brotli", marker = "python_full_version >= '3.9' and platform_python_implementation == 'CPython'" }, + 
{ name = "brotlicffi", marker = "python_full_version >= '3.9' and platform_python_implementation != 'CPython'" }, + { name = "zopfli", marker = "python_full_version >= '3.9'" }, +] + [[package]] name = "fqdn" version = "1.5.1" @@ -1832,6 +2018,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, ] +[[package]] +name = "html5lib" +version = "1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six", marker = "python_full_version < '3.9'" }, + { name = "webencodings", marker = "python_full_version < '3.9'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ac/b6/b55c3f49042f1df3dcd422b7f224f939892ee94f22abcf503a9b7339eaf2/html5lib-1.1.tar.gz", hash = "sha256:b2e5b40261e20f354d198eae92afc10d750afb487ed5e50f9c4eaf07c184146f", size = 272215, upload-time = "2020-06-22T23:32:38.834Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6c/dd/a834df6482147d48e225a49515aabc28974ad5a4ca3215c18a882565b028/html5lib-1.1-py2.py3-none-any.whl", hash = "sha256:0d78f8fde1c230e99fe37986a60526d7049ed4bf8a9fadbad5f00e22e58e041d", size = 112173, upload-time = "2020-06-22T23:32:36.781Z" }, +] + [[package]] name = "httpcore" version = "1.0.9" @@ -3882,6 +4081,10 @@ parallel = [ { name = "multiprocess" }, { name = "tqdm" }, ] +web = [ + { name = "weasyprint", version = "61.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "weasyprint", version = "66.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, +] [package.dev-dependencies] dev = [ @@ -3940,8 +4143,9 @@ requires-dist = [ { name = "torch", marker = "python_full_version < '3.9' and extra == 'array'", specifier = ">=1.13.1,<2.5.0" }, { name = "torch", marker = "python_full_version == '3.13.*' and extra == 'array'", specifier = ">=2.5.0" }, { name = "tqdm", marker = "extra == 'parallel'", specifier = ">=4.67.1" }, + { name = "weasyprint", marker = "extra == 'web'", specifier = ">=60.0" }, ] -provides-extras = ["array", "array-no-torch", "notebook", "parallel"] +provides-extras = ["array", "array-no-torch", "notebook", "parallel", "web"] [package.metadata.requires-dev] dev = [ @@ -5725,6 +5929,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/e3/59cd50310fc9b59512193629e1984c1f95e5c8ae6e5d8c69532ccc65a7fe/pycparser-2.23-py3-none-any.whl", hash = "sha256:e5c6e8d3fbad53479cab09ac03729e0a9faf2bee3db8208a550daf5af81a5934", size = 118140, upload-time = "2025-09-09T13:23:46.651Z" }, ] +[[package]] +name = "pydyf" +version = "0.11.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2e/c2/97fc6ce4ce0045080dc99446def812081b57750ed8aa67bfdfafa4561fe5/pydyf-0.11.0.tar.gz", hash = "sha256:394dddf619cca9d0c55715e3c55ea121a9bf9cbc780cdc1201a2427917b86b64", size = 17769, upload-time = "2024-07-12T12:26:51.95Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/ac/d5db977deaf28c6ecbc61bbca269eb3e8f0b3a1f55c8549e5333e606e005/pydyf-0.11.0-py3-none-any.whl", hash = "sha256:0aaf9e2ebbe786ec7a78ec3fbffa4cdcecde53fd6f563221d53c6bc1328848a3", size = 8104, upload-time = "2024-07-12T12:26:49.896Z" }, +] + [[package]] name = "pygments" version = "2.19.2" @@ -5763,6 +5976,35 @@ wheels = [ { 
url = "https://files.pythonhosted.org/packages/10/5e/1aa9a93198c6b64513c9d7752de7422c06402de6600a8767da1524f9570b/pyparsing-3.2.5-py3-none-any.whl", hash = "sha256:e38a4f02064cf41fe6593d328d0512495ad1f3d8a91c4f73fc401b3079a59a5e", size = 113890, upload-time = "2025-09-21T04:11:04.117Z" }, ] +[[package]] +name = "pyphen" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.9'", +] +sdist = { url = "https://files.pythonhosted.org/packages/33/87/493fc9e2597923a2b02a1facc376a3bf8e682698ae177b340c0c5fd1fdec/pyphen-0.16.0.tar.gz", hash = "sha256:2c006b3ddf072c9571ab97606d9ab3c26a92eaced4c0d59fd1d26988f308f413", size = 2072790, upload-time = "2024-07-30T11:43:12.755Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/34/839a8cb56f145abf2da52ba4607b0e45b79fa018cb154fcba149fb76f179/pyphen-0.16.0-py3-none-any.whl", hash = "sha256:b4a4c6d7d5654b698b5fc68123148bb799b3debe0175d1d5dc3edfe93066fc4c", size = 2073300, upload-time = "2024-07-30T11:43:08.608Z" }, +] + +[[package]] +name = "pyphen" +version = "0.17.2" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14'", + "python_full_version == '3.13.*'", + "python_full_version == '3.12.*'", + "python_full_version == '3.11.*'", + "python_full_version == '3.10.*'", + "python_full_version == '3.9.*'", +] +sdist = { url = "https://files.pythonhosted.org/packages/69/56/e4d7e1bd70d997713649c5ce530b2d15a5fc2245a74ca820fc2d51d89d4d/pyphen-0.17.2.tar.gz", hash = "sha256:f60647a9c9b30ec6c59910097af82bc5dd2d36576b918e44148d8b07ef3b4aa3", size = 2079470, upload-time = "2025-01-20T13:18:36.296Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/1f/c2142d2edf833a90728e5cdeb10bdbdc094dde8dbac078cee0cf33f5e11b/pyphen-0.17.2-py3-none-any.whl", hash = "sha256:3a07fb017cb2341e1d9ff31b8634efb1ae4dc4b130468c7c39dd3d32e7c3affd", size = 2079358, upload-time = "2025-01-20T13:18:29.629Z" }, +] + [[package]] name = "pytest" version = "8.3.5" @@ -7032,6 +7274,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e6/34/ebdc18bae6aa14fbee1a08b63c015c72b64868ff7dae68808ab500c492e2/tinycss2-1.4.0-py3-none-any.whl", hash = "sha256:3a49cf47b7675da0b15d0c6e1df8df4ebd96e9394bb905a5775adb0d884c5289", size = 26610, upload-time = "2024-10-24T14:58:28.029Z" }, ] +[[package]] +name = "tinyhtml5" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "webencodings", marker = "python_full_version >= '3.9'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fd/03/6111ed99e9bf7dfa1c30baeef0e0fb7e0bd387bd07f8e5b270776fe1de3f/tinyhtml5-2.0.0.tar.gz", hash = "sha256:086f998833da24c300c414d9fe81d9b368fd04cb9d2596a008421cbc705fcfcc", size = 179507, upload-time = "2024-10-29T15:37:14.078Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/de/27c57899297163a4a84104d5cec0af3b1ac5faf62f44667e506373c6b8ce/tinyhtml5-2.0.0-py3-none-any.whl", hash = "sha256:13683277c5b176d070f82d099d977194b7a1e26815b016114f581a74bbfbf47e", size = 39793, upload-time = "2024-10-29T15:37:11.743Z" }, +] + [[package]] name = "tomli" version = "2.3.0" @@ -7618,6 +7872,56 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl", hash = "sha256:a7bb560c8aee30f9957e5f9895805edd20602f2d7f720186dfd906e82b4982e1", size = 37286, upload-time = "2025-09-22T16:29:51.641Z" }, ] +[[package]] 
+name = "weasyprint" +version = "61.2" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.9'", +] +dependencies = [ + { name = "cffi", version = "1.17.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "cssselect2", version = "0.7.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "fonttools", version = "4.57.0", source = { registry = "https://pypi.org/simple" }, extra = ["woff"], marker = "python_full_version < '3.9'" }, + { name = "html5lib", marker = "python_full_version < '3.9'" }, + { name = "pillow", version = "10.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "pydyf", marker = "python_full_version < '3.9'" }, + { name = "pyphen", version = "0.16.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "tinycss2", version = "1.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f8/1a/32a7de6916ead1bd5b5ed6a5f4431d5011426850566d7ba2947f896cdd19/weasyprint-61.2.tar.gz", hash = "sha256:47df6cfeeff8c6c28cf2e4caf837cde17715efe462708ada74baa2eb391b6059", size = 447333, upload-time = "2024-03-08T10:34:15.081Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/f0/f313cee61f320b3c651dcc213767ff6eda0b4398fc04ceae8b4882b58b48/weasyprint-61.2-py3-none-any.whl", hash = "sha256:76c6dc0e75e09182d5645d92c66ddf86b1b992c9420235b723fb374b584e5bf4", size = 271459, upload-time = "2024-03-08T10:34:12.419Z" }, +] + +[[package]] +name = "weasyprint" +version = "66.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14'", + "python_full_version == '3.13.*'", + "python_full_version == '3.12.*'", + "python_full_version == '3.11.*'", + "python_full_version == '3.10.*'", + "python_full_version == '3.9.*'", +] +dependencies = [ + { name = "cffi", version = "2.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "cssselect2", version = "0.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "fonttools", version = "4.60.1", source = { registry = "https://pypi.org/simple" }, extra = ["woff"], marker = "python_full_version >= '3.9'" }, + { name = "pillow", version = "11.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.9.*'" }, + { name = "pillow", version = "12.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "pydyf", marker = "python_full_version >= '3.9'" }, + { name = "pyphen", version = "0.17.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "tinycss2", version = "1.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "tinyhtml5", marker = "python_full_version >= '3.9'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/32/99/480b5430b7eb0916e7d5df1bee7d9508b28b48fee28da894d0a050e0e930/weasyprint-66.0.tar.gz", hash = "sha256:da71dc87dc129ac9cffdc65e5477e90365ab9dbae45c744014ec1d06303dde40", size = 504224, upload-time = "2025-07-24T11:44:42.771Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/0f/d1/c5d9b341bf3d556c1e4c6566b3efdda0b1bb175510aa7b09dd3eee246923/weasyprint-66.0-py3-none-any.whl", hash = "sha256:82b0783b726fcd318e2c977dcdddca76515b30044bc7a830cc4fbe717582a6d0", size = 301965, upload-time = "2025-07-24T11:44:40.968Z" }, +] + [[package]] name = "webcolors" version = "24.8.0" @@ -7728,3 +8032,81 @@ sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50e wheels = [ { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, ] + +[[package]] +name = "zopfli" +version = "0.2.3.post1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5e/7c/a8f6696e694709e2abcbccd27d05ef761e9b6efae217e11d977471555b62/zopfli-0.2.3.post1.tar.gz", hash = "sha256:96484dc0f48be1c5d7ae9f38ed1ce41e3675fd506b27c11a6607f14b49101e99", size = 175629, upload-time = "2024-10-18T15:42:05.946Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/5b/7f21751e0da525a78a0269600c1d45dee565f9f0a9f875e1374b00778a82/zopfli-0.2.3.post1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e0137dd64a493ba6a4be37405cfd6febe650a98cc1e9dca8f6b8c63b1db11b41", size = 296334, upload-time = "2024-10-18T15:40:29.42Z" }, + { url = "https://files.pythonhosted.org/packages/96/a9/b9bcac622a66ecfef22e2c735feefd3b9f31b8a45ca2ef8c1438604d2157/zopfli-0.2.3.post1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:aa588b21044f8a74e423d8c8a4c7fc9988501878aacced793467010039c50734", size = 163885, upload-time = "2024-10-18T15:40:31.638Z" }, + { url = "https://files.pythonhosted.org/packages/79/b6/02dcb076ceb3120dc7a7e1cb197add5189c265ef9424b595430f19583dad/zopfli-0.2.3.post1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9f4a7ec2770e6af05f5a02733fd3900f30a9cd58e5d6d3727e14c5bcd6e7d587", size = 790653, upload-time = "2024-10-18T15:40:32.734Z" }, + { url = "https://files.pythonhosted.org/packages/74/b5/720b8a6a0a103caee1c10deb52139ba25aa0b37263cd423521bc6c416ce2/zopfli-0.2.3.post1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f7d69c1a7168ad0e9cb864e8663acb232986a0c9c9cb9801f56bf6214f53a54d", size = 849105, upload-time = "2024-10-18T15:40:34.223Z" }, + { url = "https://files.pythonhosted.org/packages/e7/a6/74f03eb4c0243bc418634ebdceb4715a28db8ab281c89cde1b7d2c243c13/zopfli-0.2.3.post1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c2d2bc8129707e34c51f9352c4636ca313b52350bbb7e04637c46c1818a2a70", size = 825695, upload-time = "2024-10-18T15:40:35.851Z" }, + { url = "https://files.pythonhosted.org/packages/f0/5c/eb1cb5a4e3c7becb5576944e225f3df05198c6d3ad20e4c762eb505c59b8/zopfli-0.2.3.post1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:39e576f93576c5c223b41d9c780bbb91fd6db4babf3223d2a4fe7bf568e2b5a8", size = 1753293, upload-time = "2024-10-18T15:40:37.417Z" }, + { url = "https://files.pythonhosted.org/packages/57/15/04d1b212e8932acfb0ec3a513f13bfdc5cfb874ba2c23ee0771dffb1063d/zopfli-0.2.3.post1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:cbe6df25807227519debd1a57ab236f5f6bad441500e85b13903e51f93a43214", size = 1904912, upload-time = "2024-10-18T15:40:39.158Z" }, + { url = 
"https://files.pythonhosted.org/packages/88/d5/dd458a9053129bc6cf6cd2554c595020f463ba7438f32313b70a697850f2/zopfli-0.2.3.post1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7cce242b5df12b2b172489daf19c32e5577dd2fac659eb4b17f6a6efb446fd5c", size = 1834445, upload-time = "2024-10-18T15:40:40.437Z" }, + { url = "https://files.pythonhosted.org/packages/a9/c2/f1ddc57f8458fae8c54df9acd079fbd3a7ebaa12d839576719262a942cba/zopfli-0.2.3.post1-cp310-cp310-win32.whl", hash = "sha256:f815fcc2b2a457977724bad97fb4854022980f51ce7b136925e336b530545ae1", size = 82633, upload-time = "2024-10-18T15:40:41.634Z" }, + { url = "https://files.pythonhosted.org/packages/2c/f5/6b750c8326c00c46e486c180efb0f2d23cd0e43ecf8e0c9947586dda664a/zopfli-0.2.3.post1-cp310-cp310-win_amd64.whl", hash = "sha256:0cc20b02a9531559945324c38302fd4ba763311632d0ec8a1a0aa9c10ea363e6", size = 99343, upload-time = "2024-10-18T15:40:43.159Z" }, + { url = "https://files.pythonhosted.org/packages/92/6d/c8224a8fc77c1dff6caaa2dc63794a40ea284c82ac20030fb2521092dca6/zopfli-0.2.3.post1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:518f1f4ed35dd69ce06b552f84e6d081f07c552b4c661c5312d950a0b764a58a", size = 296334, upload-time = "2024-10-18T15:40:44.684Z" }, + { url = "https://files.pythonhosted.org/packages/f8/da/df0f87a489d223f184d69e9e88c80c1314be43b2361acffefdc09659e00d/zopfli-0.2.3.post1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:615a8ac9dda265e9cc38b2a76c3142e4a9f30fea4a79c85f670850783bc6feb4", size = 163886, upload-time = "2024-10-18T15:40:45.812Z" }, + { url = "https://files.pythonhosted.org/packages/39/b7/14529a7ae608cedddb2f791cbc13a392a246e2e6d9c9b4b8bcda707d08d8/zopfli-0.2.3.post1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a82fc2dbebe6eb908b9c665e71496f8525c1bc4d2e3a7a7722ef2b128b6227c8", size = 823654, upload-time = "2024-10-18T15:40:46.969Z" }, + { url = "https://files.pythonhosted.org/packages/57/48/217c7bd720553d9e68b96926c02820e8b6184ef6dbac937823abad85b154/zopfli-0.2.3.post1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37d011e92f7b9622742c905fdbed9920a1d0361df84142807ea2a528419dea7f", size = 826188, upload-time = "2024-10-18T15:40:48.147Z" }, + { url = "https://files.pythonhosted.org/packages/2f/8b/5ab8c4c6db2564a0c3369e584090c101ffad4f9d0a39396e0d3e80c98413/zopfli-0.2.3.post1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e63d558847166543c2c9789e6f985400a520b7eacc4b99181668b2c3aeadd352", size = 850573, upload-time = "2024-10-18T15:40:49.481Z" }, + { url = "https://files.pythonhosted.org/packages/33/f8/f52ec5c713f3325c852f19af7c8e3f98109ddcd1ce400dc39005072a2fea/zopfli-0.2.3.post1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:60db20f06c3d4c5934b16cfa62a2cc5c3f0686bffe0071ed7804d3c31ab1a04e", size = 1754164, upload-time = "2024-10-18T15:40:50.952Z" }, + { url = "https://files.pythonhosted.org/packages/92/24/6a6018125e1cc6ee5880a0ae60456fdc8a2da43f2f14b487cf49439a3448/zopfli-0.2.3.post1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:716cdbfc57bfd3d3e31a58e6246e8190e6849b7dbb7c4ce39ef8bbf0edb8f6d5", size = 1906135, upload-time = "2024-10-18T15:40:52.484Z" }, + { url = "https://files.pythonhosted.org/packages/87/ad/697521dac8b46f0e0d081a3da153687d7583f3a2cd5466af1ddb9928394f/zopfli-0.2.3.post1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3a89277ed5f8c0fb2d0b46d669aa0633123aa7381f1f6118c12f15e0fb48f8ca", size = 1835047, upload-time = "2024-10-18T15:40:54.453Z" }, 
+ { url = "https://files.pythonhosted.org/packages/95/00/042c0cdba957343d7a83e572fc5ffe62de03d57c43075c8cf920b8b542e6/zopfli-0.2.3.post1-cp311-cp311-win32.whl", hash = "sha256:75a26a2307b10745a83b660c404416e984ee6fca515ec7f0765f69af3ce08072", size = 82635, upload-time = "2024-10-18T15:40:55.632Z" }, + { url = "https://files.pythonhosted.org/packages/e6/cc/07119cba00db12d7ef0472637b7d71a95f2c8e9a20ed460d759acd274887/zopfli-0.2.3.post1-cp311-cp311-win_amd64.whl", hash = "sha256:81c341d9bb87a6dbbb0d45d6e272aca80c7c97b4b210f9b6e233bf8b87242f29", size = 99345, upload-time = "2024-10-18T15:40:56.965Z" }, + { url = "https://files.pythonhosted.org/packages/3f/ce/b6441cc01881d06e0b5883f32c44e7cc9772e0d04e3e59277f59f80b9a19/zopfli-0.2.3.post1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:3f0197b6aa6eb3086ae9e66d6dd86c4d502b6c68b0ec490496348ae8c05ecaef", size = 295489, upload-time = "2024-10-18T15:40:57.96Z" }, + { url = "https://files.pythonhosted.org/packages/93/f0/24dd708f00ae0a925bc5c9edae858641c80f6a81a516810dc4d21688a930/zopfli-0.2.3.post1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5fcfc0dc2761e4fcc15ad5d273b4d58c2e8e059d3214a7390d4d3c8e2aee644e", size = 163010, upload-time = "2024-10-18T15:40:59.444Z" }, + { url = "https://files.pythonhosted.org/packages/65/57/0378eeeb5e3e1e83b1b0958616b2bf954f102ba5b0755b9747dafbd8cb72/zopfli-0.2.3.post1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cac2b37ab21c2b36a10b685b1893ebd6b0f83ae26004838ac817680881576567", size = 823649, upload-time = "2024-10-18T15:41:00.642Z" }, + { url = "https://files.pythonhosted.org/packages/ab/8a/3ab8a616d4655acf5cf63c40ca84e434289d7d95518a1a42d28b4a7228f8/zopfli-0.2.3.post1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d5ab297d660b75c159190ce6d73035502310e40fd35170aed7d1a1aea7ddd65", size = 826557, upload-time = "2024-10-18T15:41:02.431Z" }, + { url = "https://files.pythonhosted.org/packages/ed/4d/7f6820af119c4fec6efaf007bffee7bc9052f695853a711a951be7afd26b/zopfli-0.2.3.post1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ba214f4f45bec195ee8559651154d3ac2932470b9d91c5715fc29c013349f8c", size = 851127, upload-time = "2024-10-18T15:41:04.259Z" }, + { url = "https://files.pythonhosted.org/packages/e1/db/1ef5353ab06f9f2fb0c25ed0cddf1418fe275cc2ee548bc4a29340c44fe1/zopfli-0.2.3.post1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:c1e0ed5d84ffa2d677cc9582fc01e61dab2e7ef8b8996e055f0a76167b1b94df", size = 1754183, upload-time = "2024-10-18T15:41:05.808Z" }, + { url = "https://files.pythonhosted.org/packages/39/03/44f8f39950354d330fa798e4bab1ac8e38ec787d3fde25d5b9c7770065a2/zopfli-0.2.3.post1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:bfa1eb759e07d8b7aa7a310a2bc535e127ee70addf90dc8d4b946b593c3e51a8", size = 1905945, upload-time = "2024-10-18T15:41:07.136Z" }, + { url = "https://files.pythonhosted.org/packages/74/7b/94b920c33cc64255f59e3cfc77c829b5c6e60805d189baeada728854a342/zopfli-0.2.3.post1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:cd2c002f160502608dcc822ed2441a0f4509c52e86fcfd1a09e937278ed1ca14", size = 1835885, upload-time = "2024-10-18T15:41:08.705Z" }, + { url = "https://files.pythonhosted.org/packages/ad/89/c869ac844351e285a6165e2da79b715b0619a122e3160d183805adf8ab45/zopfli-0.2.3.post1-cp312-cp312-win32.whl", hash = "sha256:7be5cc6732eb7b4df17305d8a7b293223f934a31783a874a01164703bc1be6cd", size = 82743, upload-time = "2024-10-18T15:41:10.377Z" }, + { 
url = "https://files.pythonhosted.org/packages/29/e6/c98912fd3a589d8a7316c408fd91519f72c237805c4400b753e3942fda0b/zopfli-0.2.3.post1-cp312-cp312-win_amd64.whl", hash = "sha256:4e50ffac74842c1c1018b9b73875a0d0a877c066ab06bf7cccbaa84af97e754f", size = 99403, upload-time = "2024-10-18T15:41:11.547Z" }, + { url = "https://files.pythonhosted.org/packages/2b/24/0e552e2efce9a20625b56e9609d1e33c2966be33fc008681121ec267daec/zopfli-0.2.3.post1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ecb7572df5372abce8073df078207d9d1749f20b8b136089916a4a0868d56051", size = 295485, upload-time = "2024-10-18T15:41:12.57Z" }, + { url = "https://files.pythonhosted.org/packages/08/83/b2564369fb98797a617fe2796097b1d719a4937234375757ad2a3febc04b/zopfli-0.2.3.post1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a1cf720896d2ce998bc8e051d4b4ce0d8bec007aab6243102e8e1d22a0b2fb3f", size = 163000, upload-time = "2024-10-18T15:41:13.743Z" }, + { url = "https://files.pythonhosted.org/packages/3c/55/81d419739c2aab35e19b58bce5498dcb58e6446e5eb69f2d3c748b1c9151/zopfli-0.2.3.post1-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5aad740b4d4fcbaaae4887823925166ffd062db3b248b3f432198fc287381d1a", size = 823699, upload-time = "2024-10-18T15:41:14.874Z" }, + { url = "https://files.pythonhosted.org/packages/9e/91/89f07c8ea3c9bc64099b3461627b07a8384302235ee0f357eaa86f98f509/zopfli-0.2.3.post1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6617fb10f9e4393b331941861d73afb119cd847e88e4974bdbe8068ceef3f73f", size = 826612, upload-time = "2024-10-18T15:41:16.069Z" }, + { url = "https://files.pythonhosted.org/packages/41/31/46670fc0c7805d42bc89702440fa9b73491d68abbc39e28d687180755178/zopfli-0.2.3.post1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a53b18797cdef27e019db595d66c4b077325afe2fd62145953275f53d84ce40c", size = 851148, upload-time = "2024-10-18T15:41:17.403Z" }, + { url = "https://files.pythonhosted.org/packages/22/00/71ad39277bbb88f9fd20fb786bd3ff2ea4025c53b31652a0da796fb546cd/zopfli-0.2.3.post1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:b78008a69300d929ca2efeffec951b64a312e9a811e265ea4a907ab546d79fa6", size = 1754215, upload-time = "2024-10-18T15:41:18.661Z" }, + { url = "https://files.pythonhosted.org/packages/d0/4e/e542c508d20c3dfbef1b90fcf726f824f505e725747f777b0b7b7d1deb95/zopfli-0.2.3.post1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:0aa5f90d6298bda02a95bc8dc8c3c19004d5a4e44bda00b67ca7431d857b4b54", size = 1905988, upload-time = "2024-10-18T15:41:19.933Z" }, + { url = "https://files.pythonhosted.org/packages/ba/a5/817ac1ecc888723e91dc172e8c6eeab9f48a1e52285803b965084e11bbd5/zopfli-0.2.3.post1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2768c877f76c8a0e7519b1c86c93757f3c01492ddde55751e9988afb7eff64e1", size = 1835907, upload-time = "2024-10-18T15:41:21.582Z" }, + { url = "https://files.pythonhosted.org/packages/cd/35/2525f90c972d8aafc39784a8c00244eeee8e8221b26cbc576748ee9dc1cd/zopfli-0.2.3.post1-cp313-cp313-win32.whl", hash = "sha256:71390dbd3fbf6ebea9a5d85ffed8c26ee1453ee09248e9b88486e30e0397b775", size = 82742, upload-time = "2024-10-18T15:41:23.362Z" }, + { url = "https://files.pythonhosted.org/packages/2f/c6/49b27570923956d52d37363e8f5df3a31a61bd7719bb8718527a9df3ae5f/zopfli-0.2.3.post1-cp313-cp313-win_amd64.whl", hash = "sha256:a86eb88e06bd87e1fff31dac878965c26b0c26db59ddcf78bb0379a954b120de", size = 99408, upload-time = "2024-10-18T15:41:24.377Z" }, + { 
url = "https://files.pythonhosted.org/packages/aa/d7/fa32bb88c4a1c1382b06b0ebf026ce7fa6e0f365419ef6316a03ad217bf3/zopfli-0.2.3.post1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3827170de28faf144992d3d4dcf8f3998fe3c8a6a6f4a08f1d42c2ec6119d2bb", size = 296371, upload-time = "2024-10-18T15:41:25.803Z" }, + { url = "https://files.pythonhosted.org/packages/e0/77/650ee17075ea488fb8ed004d20d079ababfdfc02ddd62162690d072d257c/zopfli-0.2.3.post1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b0ec13f352ea5ae0fc91f98a48540512eed0767d0ec4f7f3cb92d92797983d18", size = 163876, upload-time = "2024-10-18T15:41:27.274Z" }, + { url = "https://files.pythonhosted.org/packages/a6/a1/6a169e3aefa6a92afc53ec71eedc3ebaebce7c519d4611d89d61257b2b03/zopfli-0.2.3.post1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f272186e03ad55e7af09ab78055535c201b1a0bcc2944edb1768298d9c483a4", size = 825929, upload-time = "2024-10-18T15:41:28.478Z" }, + { url = "https://files.pythonhosted.org/packages/3a/43/1751061a7e70eaa9b5efb88a80aa0aaf18493a467c8036e82b266bf35692/zopfli-0.2.3.post1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:29ea74e72ffa6e291b8c6f2504ce6c146b4fe990c724c1450eb8e4c27fd31431", size = 655151, upload-time = "2024-10-18T15:41:29.725Z" }, + { url = "https://files.pythonhosted.org/packages/3b/ec/f9c1ab2b4b4c57a70e1f8687519580fcb5f26576c4a7a6e8fae44f84e4d1/zopfli-0.2.3.post1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:eb45a34f23da4f8bc712b6376ca5396914b0b7c09adbb001dad964eb7f3132f8", size = 704493, upload-time = "2024-10-18T15:41:31.234Z" }, + { url = "https://files.pythonhosted.org/packages/c5/9d/53c2deaaf54155f735da71abbb5cafa00c1f66ae793dfb1d8b07908a1db8/zopfli-0.2.3.post1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:6482db9876c68faac2d20a96b566ffbf65ddaadd97b222e4e73641f4f8722fc4", size = 1753342, upload-time = "2024-10-18T15:41:32.529Z" }, + { url = "https://files.pythonhosted.org/packages/e4/a0/cce4664c31276902cd52735ab3a529c5d96d05d36fa1f539bc1d0986a6ec/zopfli-0.2.3.post1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:95a260cafd56b8fffa679918937401c80bb38e1681c448b988022e4c3610965d", size = 1904825, upload-time = "2024-10-18T15:41:33.842Z" }, + { url = "https://files.pythonhosted.org/packages/dd/25/cc836bea8563ac2497f7e9f5eb804d4b5842e185a1af83fc2b19ab06e73d/zopfli-0.2.3.post1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:676919fba7311125244eb0c4393679ac5fe856e5864a15d122bd815205369fa0", size = 1834458, upload-time = "2024-10-18T15:41:35.193Z" }, + { url = "https://files.pythonhosted.org/packages/73/fe/8768104cf6fb90c0d8a74b581f7033a534e02101a0c5c9f4c8b8f8a5fce4/zopfli-0.2.3.post1-cp38-cp38-win32.whl", hash = "sha256:b9026a21b6d41eb0e2e63f5bc1242c3fcc43ecb770963cda99a4307863dac12e", size = 82628, upload-time = "2024-10-18T15:41:36.902Z" }, + { url = "https://files.pythonhosted.org/packages/d4/f4/8230c23739433b8a4dacedd7edc1012df0fcc7e4dc555187c8acfa2405fd/zopfli-0.2.3.post1-cp38-cp38-win_amd64.whl", hash = "sha256:3c163911f8bad94b3e1db0a572e7c28ba681a0c91d0002ea1e4fa9264c21ef17", size = 99337, upload-time = "2024-10-18T15:41:37.872Z" }, + { url = "https://files.pythonhosted.org/packages/f9/cc/c3cc7e83396d3e864103a1ff0be68ee7033a4e50ac5c415df998d2134a7f/zopfli-0.2.3.post1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b05296e8bc88c92e2b21e0a9bae4740c1551ee613c1d93a51fd28a7a0b2b6fbb", size = 296328, upload-time = "2024-10-18T15:41:39.215Z" }, + { url = 
"https://files.pythonhosted.org/packages/2d/91/75e0b2a701a6fe816a4bcd9370a805e1f152c8a6b54269afa147b5085f77/zopfli-0.2.3.post1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f12000a6accdd4bf0a3fa6eaa1b1c7a7bc80af0a2edf3f89d770d3dcce1d0e22", size = 163880, upload-time = "2024-10-18T15:41:40.456Z" }, + { url = "https://files.pythonhosted.org/packages/e8/61/dff95be9ebbf5bdd963774f8e5117f957274a8d8e081c58028fb7d624400/zopfli-0.2.3.post1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a241a68581d34d67b40c425cce3d1fd211c092f99d9250947824ccba9f491949", size = 825513, upload-time = "2024-10-18T15:41:41.584Z" }, + { url = "https://files.pythonhosted.org/packages/91/db/e3057bfdb21855e4db2821c39f67aaedcb39e8bf5c490985009acbaf3d5a/zopfli-0.2.3.post1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3657e416ffb8f31d9d3424af12122bb251befae109f2e271d87d825c92fc5b7b", size = 654716, upload-time = "2024-10-18T15:41:42.858Z" }, + { url = "https://files.pythonhosted.org/packages/b9/e1/11ed92cf3043a2e89b4a0ffcdcf67084da0f84e7b3c927a862a1e2510546/zopfli-0.2.3.post1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:4915a41375bdee4db749ecd07d985a0486eb688a6619f713b7bf6fbfd145e960", size = 704138, upload-time = "2024-10-18T15:41:44.24Z" }, + { url = "https://files.pythonhosted.org/packages/a3/b4/48a44ab8a9e80a9c17527397852a6e1f5cc7f1a7d8dcc9d40d6912874ce3/zopfli-0.2.3.post1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:bbe429fc50686bb2a2608a30843e36fbaa123462a5284f136c7d9e0145220bfd", size = 1753103, upload-time = "2024-10-18T15:41:45.772Z" }, + { url = "https://files.pythonhosted.org/packages/da/c1/fd0ebe0766854610f6d45679745af0220a33b4c478aa1333d48b060a108c/zopfli-0.2.3.post1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:2345e713260a350bea0b01a816a469ea356bc2d63d009a0d777691ecbbcf7493", size = 1904816, upload-time = "2024-10-18T15:41:47.337Z" }, + { url = "https://files.pythonhosted.org/packages/40/b1/04262314c2c9a1f39f74b8a9d4ba4e31496041ce67e930e39fd5d1fbf798/zopfli-0.2.3.post1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:fc39f5c27f962ec8660d8d20c24762431131b5d8c672b44b0a54cf2b5bcde9b9", size = 1834332, upload-time = "2024-10-18T15:41:48.942Z" }, + { url = "https://files.pythonhosted.org/packages/08/a6/e30077630b027e9ee84a765447f03c8b5c323a88da567fed0882e9fe4f09/zopfli-0.2.3.post1-cp39-cp39-win32.whl", hash = "sha256:9a6aec38a989bad7ddd1ef53f1265699e49e294d08231b5313d61293f3cd6237", size = 82630, upload-time = "2024-10-18T15:41:50.206Z" }, + { url = "https://files.pythonhosted.org/packages/bd/c2/fa6c4498d16c09e4700e0e92865b4d42ec66c089ea57d92757ad0ebfc556/zopfli-0.2.3.post1-cp39-cp39-win_amd64.whl", hash = "sha256:b3df42f52502438ee973042cc551877d24619fa1cd38ef7b7e9ac74200daca8b", size = 99338, upload-time = "2024-10-18T15:41:51.21Z" }, + { url = "https://files.pythonhosted.org/packages/da/92/62942d9b44b3d56e2d223924b759e2c2219f925da15a8acb103061e362ea/zopfli-0.2.3.post1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:4c1226a7e2c7105ac31503a9bb97454743f55d88164d6d46bc138051b77f609b", size = 155889, upload-time = "2024-10-18T15:41:52.781Z" }, + { url = "https://files.pythonhosted.org/packages/61/db/9502c4256f126ccf0fc4686f1f59f2696cdaec079d7d57231bf120422ba6/zopfli-0.2.3.post1-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48dba9251060289101343110ab47c0756f66f809bb4d1ddbb6d5c7e7752115c5", size = 130129, upload-time = "2024-10-18T15:41:54.255Z" }, 
+ { url = "https://files.pythonhosted.org/packages/a3/7d/1e8c36825798269a9271ac4477b592622fddc2948772fd2fcaceb54a7178/zopfli-0.2.3.post1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89899641d4de97dbad8e0cde690040d078b6aea04066dacaab98e0b5a23573f2", size = 126242, upload-time = "2024-10-18T15:41:55.302Z" }, + { url = "https://files.pythonhosted.org/packages/7e/6a/2c1ae9972f2745c074938d6a610e71ed47c36f911220d592f1e403822084/zopfli-0.2.3.post1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:3654bfc927bc478b1c3f3ff5056ed7b20a1a37fa108ca503256d0a699c03bbb1", size = 99378, upload-time = "2024-10-18T15:41:56.39Z" }, + { url = "https://files.pythonhosted.org/packages/b7/5f/f8a5451ee32054a1c54c47ff3e052bcf4f5d66808df9854931935ae8b56f/zopfli-0.2.3.post1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c4278d1873ce6e803e5d4f8d702fd3026bd67fca744aa98881324d1157ddf748", size = 150783, upload-time = "2024-10-18T15:41:57.458Z" }, + { url = "https://files.pythonhosted.org/packages/1e/87/d0dae45684b9aa9914671326e28030aaa33e5b01de847187b27cb61301a1/zopfli-0.2.3.post1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:1d8cc06605519e82b16df090e17cb3990d1158861b2872c3117f1168777b81e4", size = 99367, upload-time = "2024-10-18T15:41:58.594Z" }, + { url = "https://files.pythonhosted.org/packages/6a/18/5cb5ef140def15a833861139b72d9f1afb888132e6b09cb2f40c88935843/zopfli-0.2.3.post1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:1f990634fd5c5c8ced8edddd8bd45fab565123b4194d6841e01811292650acae", size = 155882, upload-time = "2024-10-18T15:41:59.638Z" }, + { url = "https://files.pythonhosted.org/packages/73/7e/324c6232a425c514785bd5df6976c5a906c295d1cd854072e3204d4fbaec/zopfli-0.2.3.post1-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:91a2327a4d7e77471fa4fbb26991c6de4a738c6fc6a33e09bb25f56a870a4b7b", size = 130123, upload-time = "2024-10-18T15:42:00.815Z" }, + { url = "https://files.pythonhosted.org/packages/f0/0d/df0fe119da7ac80a42ff8e9bb94701f5d7bb21067c38b20ee52d48e46d41/zopfli-0.2.3.post1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8fbe5bcf10d01aab3513550f284c09fef32f342b36f56bfae2120a9c4d12c130", size = 126239, upload-time = "2024-10-18T15:42:01.949Z" }, + { url = "https://files.pythonhosted.org/packages/c3/93/b06e0b4a13c5e78d9cc3a7627b4133f72daf2dfa81b6d74f444220b01c62/zopfli-0.2.3.post1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:34a99592f3d9eb6f737616b5bd74b48a589fdb3cb59a01a50d636ea81d6af272", size = 99367, upload-time = "2024-10-18T15:42:04.616Z" }, +] From 90b8c2e7a3ef41052d5ea7d4af1de7fa82f77924 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 28 Oct 2025 18:37:08 +0000 Subject: [PATCH 11/72] fix import type hint ignore --- muutils/web/html_to_pdf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/muutils/web/html_to_pdf.py b/muutils/web/html_to_pdf.py index 905910d9..7df4e295 100644 --- a/muutils/web/html_to_pdf.py +++ b/muutils/web/html_to_pdf.py @@ -1,7 +1,7 @@ from pathlib import Path import subprocess -from weasyprint import HTML as WeasyHTML # type: ignore[import-not-found] +from weasyprint import HTML as WeasyHTML # type: ignore[import-untyped] def html_to_pdf(src: Path, dst: Path) -> None: From 30d0cb081193ecb5828f9db9184f5cb1e2b9a32e Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 13:31:27 +0000 Subject: [PATCH 12/72] fix not passing `doc` to `SerializableField` init --- 
muutils/json_serialize/serializable_field.py | 1 + 1 file changed, 1 insertion(+) diff --git a/muutils/json_serialize/serializable_field.py b/muutils/json_serialize/serializable_field.py index 12cff5cd..ad34992a 100644 --- a/muutils/json_serialize/serializable_field.py +++ b/muutils/json_serialize/serializable_field.py @@ -297,6 +297,7 @@ class MyClass: repr=repr, hash=hash, compare=compare, + doc=doc, metadata=metadata, kw_only=kw_only, serialize=serialize, From 49ee854f671411932539d9546ba2765161ee31d7 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 14:14:15 +0000 Subject: [PATCH 13/72] fix minor serialization bugs --- muutils/json_serialize/json_serialize.py | 42 ++++++++++++++++-------- muutils/json_serialize/util.py | 4 +-- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/muutils/json_serialize/json_serialize.py b/muutils/json_serialize/json_serialize.py index 53904267..4db08fab 100644 --- a/muutils/json_serialize/json_serialize.py +++ b/muutils/json_serialize/json_serialize.py @@ -126,6 +126,15 @@ def serialize(self) -> dict: uid="dictionaries", desc="dictionaries", ), + SerializerHandler( + check=lambda self, obj, path: isinstance_namedtuple(obj), + serialize_func=lambda self, obj, path: { + str(k): self.json_serialize(v, tuple(path) + (k,)) + for k, v in obj._asdict().items() + }, + uid="namedtuple -> dict", + desc="namedtuples as dicts", + ), SerializerHandler( check=lambda self, obj, path: isinstance(obj, (list, tuple)), serialize_func=lambda self, obj, path: [ @@ -157,12 +166,6 @@ def _serialize_override_serialize_func( uid=".serialize override", desc="objects with .serialize method", ), - SerializerHandler( - check=lambda self, obj, path: isinstance_namedtuple(obj), - serialize_func=lambda self, obj, path: self.json_serialize(dict(obj._asdict())), - uid="namedtuple -> dict", - desc="namedtuples as dicts", - ), SerializerHandler( check=lambda self, obj, path: is_dataclass(obj), serialize_func=lambda self, obj, path: { @@ -212,13 +215,24 @@ def _serialize_override_serialize_func( desc="pandas DataFrames", ), SerializerHandler( - check=lambda self, obj, path: isinstance(obj, (set, list, tuple)) - or isinstance(obj, Iterable), + check=lambda self, obj, path: isinstance(obj, (set, frozenset)), + serialize_func=lambda self, obj, path: { + _FORMAT_KEY: "set" if isinstance(obj, set) else "frozenset", + "data": [ + self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj) + ], + }, + uid="set -> dict[_FORMAT_KEY: 'set', data: list(...)]", + desc="sets as dicts with format key", + ), + SerializerHandler( + check=lambda self, obj, path: isinstance(obj, Iterable) + and not isinstance(obj, (list, tuple, str)), serialize_func=lambda self, obj, path: [ self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj) ], - uid="(set, list, tuple, Iterable) -> list", - desc="sets, lists, tuples, and Iterables as lists", + uid="Iterable -> list", + desc="Iterables (not lists/tuples/strings) as lists", ), SerializerHandler( check=lambda self, obj, path: True, @@ -285,6 +299,7 @@ def json_serialize( obj: Any, path: ObjectPath = tuple(), ) -> JSONitem: + handler = None try: for handler in self.handlers: if handler.check(self, obj, path): @@ -298,14 +313,15 @@ def json_serialize( raise ValueError(f"no handler found for object with {type(obj) = }") except Exception as e: - if self.error_mode == "except": + if self.error_mode == ErrorMode.EXCEPT: obj_str: str = repr(obj) if len(obj_str) > 1000: obj_str = obj_str[:1000] + "..." 
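+                # `handler` is the most recently tried handler (possibly the one
+                # whose check/serialize raised); it is still None only when the
+                # handler list is empty, hence the placeholder below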
+ handler_uid = handler.uid if handler else "no handler matched" raise SerializationException( - f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}" + f"error serializing at {path = } with last handler: '{handler_uid}'\nfrom: {e}\nobj: {obj_str}" ) from e - elif self.error_mode == "warn": + elif self.error_mode == ErrorMode.WARN: warnings.warn( f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}" ) diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index 06068118..902a2ee5 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -114,10 +114,10 @@ def newfunc(*args, **kwargs): def _recursive_hashify(obj: Any, force: bool = True) -> Hashableitem: if isinstance(obj, typing.Mapping): return tuple((k, _recursive_hashify(v)) for k, v in obj.items()) - elif isinstance(obj, (tuple, list, Iterable)): - return tuple(_recursive_hashify(v) for v in obj) elif isinstance(obj, (bool, int, float, str)): return obj + elif isinstance(obj, (tuple, list, Iterable)): + return tuple(_recursive_hashify(v) for v in obj) else: if force: return str(obj) From 573ef839afca0e8fb14deb6ff30ce21cdc554225 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 14:14:26 +0000 Subject: [PATCH 14/72] add a lot of tests thx claude --- tests/unit/cli/test_arg_bool.py | 581 ++++++++++++++ tests/unit/cli/test_command.py | 132 ++++ .../errormode/test_errormode_functionality.py | 694 ++++++++++++++++ tests/unit/json_serialize/test_array.py | 134 ++++ tests/unit/json_serialize/test_array_torch.py | 228 ++++++ .../json_serialize/test_json_serialize.py | 745 ++++++++++++++++++ .../json_serialize/test_serializable_field.py | 421 ++++++++++ tests/unit/json_serialize/test_util.py | 232 ++++++ tests/unit/logger/test_log_util.py | 176 +++++ tests/unit/test_collect_warnings.py | 464 +++++++++++ tests/unit/test_jsonlines.py | 194 +++++ 11 files changed, 4001 insertions(+) create mode 100644 tests/unit/cli/test_arg_bool.py create mode 100644 tests/unit/cli/test_command.py create mode 100644 tests/unit/json_serialize/test_array_torch.py create mode 100644 tests/unit/json_serialize/test_json_serialize.py create mode 100644 tests/unit/json_serialize/test_serializable_field.py create mode 100644 tests/unit/logger/test_log_util.py create mode 100644 tests/unit/test_collect_warnings.py create mode 100644 tests/unit/test_jsonlines.py diff --git a/tests/unit/cli/test_arg_bool.py b/tests/unit/cli/test_arg_bool.py new file mode 100644 index 00000000..6cd8ae50 --- /dev/null +++ b/tests/unit/cli/test_arg_bool.py @@ -0,0 +1,581 @@ +"""Tests for muutils.cli.arg_bool module.""" + +from __future__ import annotations + +import argparse +import pytest +from pytest import mark, param + +from muutils.cli.arg_bool import ( + parse_bool_token, + BoolFlagOrValue, + add_bool_flag, + TRUE_SET_DEFAULT, + FALSE_SET_DEFAULT, +) + + +# ============================================================================ +# Tests for parse_bool_token +# ============================================================================ + + +def test_parse_bool_token_valid(): + """Test parse_bool_token with valid true/false tokens.""" + # True tokens from default set + assert parse_bool_token("true") is True + assert parse_bool_token("1") is True + assert parse_bool_token("t") is True + assert parse_bool_token("yes") is True + assert parse_bool_token("y") is True + assert parse_bool_token("on") is True + + # False tokens from default set + 
assert parse_bool_token("false") is False + assert parse_bool_token("0") is False + assert parse_bool_token("f") is False + assert parse_bool_token("no") is False + assert parse_bool_token("n") is False + assert parse_bool_token("off") is False + + +def test_parse_bool_token_case_insensitive(): + """Test parse_bool_token is case-insensitive.""" + assert parse_bool_token("TRUE") is True + assert parse_bool_token("True") is True + assert parse_bool_token("TrUe") is True + assert parse_bool_token("FALSE") is False + assert parse_bool_token("False") is False + assert parse_bool_token("FaLsE") is False + assert parse_bool_token("YES") is True + assert parse_bool_token("NO") is False + assert parse_bool_token("ON") is True + assert parse_bool_token("OFF") is False + + +def test_parse_bool_token_invalid(): + """Test parse_bool_token with invalid tokens raises ArgumentTypeError.""" + with pytest.raises(argparse.ArgumentTypeError, match="expected one of"): + parse_bool_token("invalid") + + with pytest.raises(argparse.ArgumentTypeError, match="expected one of"): + parse_bool_token("maybe") + + with pytest.raises(argparse.ArgumentTypeError, match="expected one of"): + parse_bool_token("2") + + with pytest.raises(argparse.ArgumentTypeError, match="expected one of"): + parse_bool_token("") + + +def test_parse_bool_token_custom_sets(): + """Test parse_bool_token with custom true/false sets.""" + custom_true = {"enabled", "active"} + custom_false = {"disabled", "inactive"} + + assert ( + parse_bool_token("enabled", true_set=custom_true, false_set=custom_false) + is True + ) + assert ( + parse_bool_token("ACTIVE", true_set=custom_true, false_set=custom_false) is True + ) + assert ( + parse_bool_token("disabled", true_set=custom_true, false_set=custom_false) + is False + ) + assert ( + parse_bool_token("INACTIVE", true_set=custom_true, false_set=custom_false) + is False + ) + + # Default tokens should not work with custom sets + with pytest.raises(argparse.ArgumentTypeError): + parse_bool_token("true", true_set=custom_true, false_set=custom_false) + + +# ============================================================================ +# Tests for BoolFlagOrValue +# ============================================================================ + + +def test_BoolFlagOrValue_bare_flag(): + """Test bare flag (--flag with no value) → True.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--flag", + action=BoolFlagOrValue, + nargs="?", + default=False, + allow_bare=True, + ) + + # Bare flag should be True + args = parser.parse_args(["--flag"]) + assert args.flag is True + + # No flag should use default + args = parser.parse_args([]) + assert args.flag is False + + +def test_BoolFlagOrValue_negated(): + """Test negated flag (--no-flag) → False.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--flag", + "--no-flag", + dest="flag", + action=BoolFlagOrValue, + nargs="?", + default=True, + allow_no=True, + ) + + # --no-flag should be False + args = parser.parse_args(["--no-flag"]) + assert args.flag is False + + # --flag should be True (bare) + args = parser.parse_args(["--flag"]) + assert args.flag is True + + # No flag should use default + args = parser.parse_args([]) + assert args.flag is True + + +def test_BoolFlagOrValue_explicit_values(): + """Test explicit values: --flag true, --flag false.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--flag", + action=BoolFlagOrValue, + nargs="?", + default=False, + ) + + # --flag true + args = parser.parse_args(["--flag", 
"true"]) + assert args.flag is True + + # --flag false + args = parser.parse_args(["--flag", "false"]) + assert args.flag is False + + # --flag 1 + args = parser.parse_args(["--flag", "1"]) + assert args.flag is True + + # --flag 0 + args = parser.parse_args(["--flag", "0"]) + assert args.flag is False + + # --flag yes + args = parser.parse_args(["--flag", "yes"]) + assert args.flag is True + + # --flag no + args = parser.parse_args(["--flag", "no"]) + assert args.flag is False + + +def test_BoolFlagOrValue_equals_syntax(): + """Test --flag=true and --flag=false syntax.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--flag", + action=BoolFlagOrValue, + nargs="?", + default=False, + ) + + # --flag=true + args = parser.parse_args(["--flag=true"]) + assert args.flag is True + + # --flag=false + args = parser.parse_args(["--flag=false"]) + assert args.flag is False + + # --flag=1 + args = parser.parse_args(["--flag=1"]) + assert args.flag is True + + # --flag=0 + args = parser.parse_args(["--flag=0"]) + assert args.flag is False + + +def test_BoolFlagOrValue_allow_bare_false(): + """Test error on bare flag when allow_bare=False.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--flag", + action=BoolFlagOrValue, + nargs="?", + default=False, + allow_bare=False, + ) + + # Bare flag should error + with pytest.raises(SystemExit): + parser.parse_args(["--flag"]) + + # Explicit value should work + args = parser.parse_args(["--flag", "true"]) + assert args.flag is True + + +def test_BoolFlagOrValue_invalid_token(): + """Test --flag invalid raises error.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--flag", + action=BoolFlagOrValue, + nargs="?", + default=False, + ) + + # Invalid token should error + with pytest.raises(SystemExit): + parser.parse_args(["--flag", "invalid"]) + + with pytest.raises(SystemExit): + parser.parse_args(["--flag", "maybe"]) + + +def test_BoolFlagOrValue_no_flag_with_value_error(): + """Test --no-flag with a value raises error.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--flag", + "--no-flag", + dest="flag", + action=BoolFlagOrValue, + nargs="?", + default=True, + allow_no=True, + ) + + # --no-flag with value should error + with pytest.raises(SystemExit): + parser.parse_args(["--no-flag", "true"]) + + with pytest.raises(SystemExit): + parser.parse_args(["--no-flag=false"]) + + +def test_BoolFlagOrValue_allow_no_false(): + """Test error when using --no-flag but allow_no=False.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--flag", + "--no-flag", + dest="flag", + action=BoolFlagOrValue, + nargs="?", + default=True, + allow_no=False, + ) + + # --no-flag should error when allow_no=False + with pytest.raises(SystemExit): + parser.parse_args(["--no-flag"]) + + +def test_BoolFlagOrValue_custom_true_false_sets(): + """Test BoolFlagOrValue with custom true/false sets.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--flag", + action=BoolFlagOrValue, + nargs="?", + default=False, + true_set={"enabled", "active"}, + false_set={"disabled", "inactive"}, + ) + + args = parser.parse_args(["--flag", "enabled"]) + assert args.flag is True + + args = parser.parse_args(["--flag", "disabled"]) + assert args.flag is False + + # Default tokens should not work + with pytest.raises(SystemExit): + parser.parse_args(["--flag", "true"]) + + +def test_BoolFlagOrValue_invalid_nargs(): + """Test that BoolFlagOrValue raises ValueError for invalid nargs.""" + parser = 
argparse.ArgumentParser() + + # nargs other than '?' or None should raise ValueError + with pytest.raises(ValueError, match="requires nargs='?'"): + parser.add_argument( + "--flag", + action=BoolFlagOrValue, + nargs=1, + ) + + with pytest.raises(ValueError, match="requires nargs='?'"): + parser.add_argument( + "--flag2", + action=BoolFlagOrValue, + nargs="*", + ) + + +def test_BoolFlagOrValue_type_not_allowed(): + """Test that BoolFlagOrValue raises ValueError when type= is provided.""" + parser = argparse.ArgumentParser() + + with pytest.raises(ValueError, match="does not accept type="): + parser.add_argument( + "--flag", + action=BoolFlagOrValue, + nargs="?", + type=str, + ) + + +# ============================================================================ +# Tests for add_bool_flag +# ============================================================================ + + +def test_add_bool_flag_integration(): + """Test full integration with various argument combinations.""" + parser = argparse.ArgumentParser() + add_bool_flag(parser, "feature", default=False, help="Enable feature") + + # Bare flag + args = parser.parse_args(["--feature"]) + assert args.feature is True + + # Explicit true + args = parser.parse_args(["--feature", "true"]) + assert args.feature is True + + # Explicit false + args = parser.parse_args(["--feature", "false"]) + assert args.feature is False + + # Equals syntax + args = parser.parse_args(["--feature=true"]) + assert args.feature is True + + args = parser.parse_args(["--feature=false"]) + assert args.feature is False + + # No flag (default) + args = parser.parse_args([]) + assert args.feature is False + + +def test_add_bool_flag_allow_no(): + """Test both --flag and --no-flag work when allow_no=True.""" + parser = argparse.ArgumentParser() + add_bool_flag(parser, "feature", default=False, allow_no=True) + + # --feature + args = parser.parse_args(["--feature"]) + assert args.feature is True + + # --no-feature + args = parser.parse_args(["--no-feature"]) + assert args.feature is False + + # No flag (default) + args = parser.parse_args([]) + assert args.feature is False + + +def test_add_bool_flag_dest_conversion(): + """Test 'some-flag' → namespace.some_flag.""" + parser = argparse.ArgumentParser() + add_bool_flag(parser, "some-flag", default=False) + + args = parser.parse_args(["--some-flag"]) + assert args.some_flag is True + assert not hasattr(args, "some-flag") + + args = parser.parse_args(["--some-flag", "false"]) + assert args.some_flag is False + + +def test_add_bool_flag_custom_true_false_sets(): + """Test add_bool_flag with custom true/false sets.""" + parser = argparse.ArgumentParser() + add_bool_flag( + parser, + "feature", + default=False, + true_set={"enabled", "on"}, + false_set={"disabled", "off"}, + ) + + args = parser.parse_args(["--feature", "enabled"]) + assert args.feature is True + + args = parser.parse_args(["--feature", "disabled"]) + assert args.feature is False + + # Default tokens should not work + with pytest.raises(SystemExit): + parser.parse_args(["--feature", "true"]) + + +def test_add_bool_flag_allow_bare_false(): + """Test add_bool_flag with allow_bare=False.""" + parser = argparse.ArgumentParser() + add_bool_flag(parser, "feature", default=False, allow_bare=False) + + # Bare flag should error + with pytest.raises(SystemExit): + parser.parse_args(["--feature"]) + + # Explicit value should work + args = parser.parse_args(["--feature", "true"]) + assert args.feature is True + + +def test_add_bool_flag_default_true(): + """Test add_bool_flag 
with default=True.""" + parser = argparse.ArgumentParser() + add_bool_flag(parser, "feature", default=True) + + # No flag should use default=True + args = parser.parse_args([]) + assert args.feature is True + + # Explicit false should override + args = parser.parse_args(["--feature", "false"]) + assert args.feature is False + + +def test_add_bool_flag_multiple_flags(): + """Test multiple boolean flags in the same parser.""" + parser = argparse.ArgumentParser() + add_bool_flag(parser, "feature-a", default=False) + add_bool_flag(parser, "feature-b", default=True) + add_bool_flag(parser, "feature-c", default=False, allow_no=True) + + args = parser.parse_args( + [ + "--feature-a", + "--feature-b", + "false", + "--no-feature-c", + ] + ) + assert args.feature_a is True + assert args.feature_b is False + assert args.feature_c is False + + +def test_add_bool_flag_help_text(): + """Test that help text is generated or used correctly.""" + parser = argparse.ArgumentParser() + add_bool_flag(parser, "feature", default=False, help="Custom help text") + + # Check that the help is stored (can't easily test output without parsing help text) + action = None + for act in parser._actions: + if hasattr(act, "dest") and act.dest == "feature": + action = act + break + + assert action is not None + assert action.help == "Custom help text" + + +def test_add_bool_flag_default_help(): + """Test that default help text is generated when not provided.""" + parser = argparse.ArgumentParser() + add_bool_flag(parser, "my-feature", default=False) + + action = None + for act in parser._actions: + if hasattr(act, "dest") and act.dest == "my_feature": + action = act + break + + assert action is not None + assert "enable/disable my feature" in action.help + + +# ============================================================================ +# Integration and edge case tests +# ============================================================================ + + +def test_multiple_values_error(): + """Test that passing multiple values to a flag raises an error.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--flag", + action=BoolFlagOrValue, + nargs="?", + default=False, + ) + + # This should work with nargs='?', only one value accepted + args = parser.parse_args(["--flag", "true"]) + assert args.flag is True + + +@mark.parametrize( + "token, expected", + [ + param("true", True, id="true"), + param("false", False, id="false"), + param("1", True, id="1"), + param("0", False, id="0"), + param("yes", True, id="yes"), + param("no", False, id="no"), + param("on", True, id="on"), + param("off", False, id="off"), + param("t", True, id="t"), + param("f", False, id="f"), + param("y", True, id="y"), + param("n", False, id="n"), + param("TRUE", True, id="TRUE"), + param("FALSE", False, id="FALSE"), + param("Yes", True, id="Yes"), + param("No", False, id="No"), + ], +) +def test_parse_bool_token_parametrized(token: str, expected: bool): + """Parametrized test for all valid boolean tokens.""" + assert parse_bool_token(token) == expected + + +@mark.parametrize( + "invalid_token", + [ + param("invalid", id="invalid"), + param("maybe", id="maybe"), + param("2", id="2"), + param("-1", id="-1"), + param("", id="empty"), + param("truee", id="truee"), + param("yess", id="yess"), + ], +) +def test_parse_bool_token_invalid_parametrized(invalid_token: str): + """Parametrized test for invalid boolean tokens.""" + with pytest.raises(argparse.ArgumentTypeError): + parse_bool_token(invalid_token) + + +def test_constants_exist(): + """Test 
that the default token sets are defined correctly.""" + assert isinstance(TRUE_SET_DEFAULT, set) + assert isinstance(FALSE_SET_DEFAULT, set) + assert len(TRUE_SET_DEFAULT) > 0 + assert len(FALSE_SET_DEFAULT) > 0 + assert "true" in TRUE_SET_DEFAULT + assert "false" in FALSE_SET_DEFAULT + assert TRUE_SET_DEFAULT.isdisjoint(FALSE_SET_DEFAULT) diff --git a/tests/unit/cli/test_command.py b/tests/unit/cli/test_command.py new file mode 100644 index 00000000..0e8b11f7 --- /dev/null +++ b/tests/unit/cli/test_command.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import os +import subprocess + +import pytest + +from muutils.cli.command import Command + + +def test_Command_init(): + """Test Command initialization with list and string cmds.""" + # Valid: list cmd with shell=False (default) + cmd1 = Command(cmd=["echo", "hello"]) + assert cmd1.cmd == ["echo", "hello"] + assert cmd1.shell is False + + # Valid: string cmd with shell=True + cmd2 = Command(cmd="echo hello", shell=True) + assert cmd2.cmd == "echo hello" + assert cmd2.shell is True + + # Invalid: string cmd with shell=False should raise ValueError + with pytest.raises( + ValueError, match="cmd must be List\\[str\\] when shell is False" + ): + Command(cmd="echo hello", shell=False) + + # Valid: list cmd with shell=True is allowed (will be joined) + cmd3 = Command(cmd=["echo", "hello"], shell=True) + assert cmd3.cmd == ["echo", "hello"] + assert cmd3.shell is True + + +def test_Command_properties(): + """Test cmd_joined and cmd_for_subprocess properties in both shell modes.""" + # Test with shell=False (list cmd) + cmd_list = Command(cmd=["echo", "hello", "world"]) + assert cmd_list.cmd_joined == "echo hello world" + assert cmd_list.cmd_for_subprocess == ["echo", "hello", "world"] + + # Test with shell=True and string cmd + cmd_str = Command(cmd="echo hello world", shell=True) + assert cmd_str.cmd_joined == "echo hello world" + assert cmd_str.cmd_for_subprocess == "echo hello world" + + # Test with shell=True and list cmd (should be joined for subprocess) + cmd_list_shell = Command(cmd=["echo", "hello", "world"], shell=True) + assert cmd_list_shell.cmd_joined == "echo hello world" + assert cmd_list_shell.cmd_for_subprocess == "echo hello world" + + +def test_Command_script_line(): + """Test script_line with env vars formatting.""" + # No env vars + cmd1 = Command(cmd=["echo", "hello"]) + assert cmd1.script_line() == "echo hello" + + # With env vars + cmd2 = Command(cmd=["echo", "hello"], env={"FOO": "bar", "BAZ": "qux"}) + script = cmd2.script_line() + # env vars can be in any order, so check both are present + assert "FOO=bar" in script + assert "BAZ=qux" in script + assert "echo hello" in script + # Verify format: env vars come before command + assert script.endswith("echo hello") + + # With shell=True + cmd3 = Command(cmd="echo $FOO", shell=True, env={"FOO": "bar"}) + assert cmd3.script_line() == "FOO=bar echo $FOO" + + +def test_Command_env_final(): + """Test env_final with inherit_env=True and inherit_env=False.""" + # Set a test environment variable + os.environ["TEST_VAR_COMMAND"] = "original" + + try: + # inherit_env=True (default) should merge with os.environ + cmd1 = Command(cmd=["echo", "test"], env={"FOO": "bar"}) + env1 = cmd1.env_final + assert env1["FOO"] == "bar" + assert env1["TEST_VAR_COMMAND"] == "original" + + # inherit_env=False should only include provided env + cmd2 = Command(cmd=["echo", "test"], env={"FOO": "bar"}, inherit_env=False) + env2 = cmd2.env_final + assert env2["FOO"] == "bar" + assert 
"TEST_VAR_COMMAND" not in env2 + + # Custom env should override inherited env + os.environ["OVERRIDE_TEST"] = "old" + cmd3 = Command(cmd=["echo", "test"], env={"OVERRIDE_TEST": "new"}) + env3 = cmd3.env_final + assert env3["OVERRIDE_TEST"] == "new" + + finally: + # Clean up test env vars + os.environ.pop("TEST_VAR_COMMAND", None) + os.environ.pop("OVERRIDE_TEST", None) + + +def test_Command_run(): + """Test running a simple command and capturing output.""" + # Simple successful command + cmd = Command(cmd=["echo", "hello"]) + result = cmd.run(capture_output=True, text=True) + assert result.returncode == 0 + assert "hello" in result.stdout + + # Command with env vars + cmd2 = Command(cmd=["sh", "-c", "echo $TEST_VAR"], env={"TEST_VAR": "test_value"}) + result2 = cmd2.run(capture_output=True, text=True) + assert result2.returncode == 0 + assert "test_value" in result2.stdout + + # Shell command + cmd3 = Command(cmd="echo shell test", shell=True) + result3 = cmd3.run(capture_output=True, text=True) + assert result3.returncode == 0 + assert "shell test" in result3.stdout + + # Test that CalledProcessError is properly raised and handled + cmd4 = Command(cmd=["sh", "-c", "exit 1"]) + result4 = cmd4.run(capture_output=True) + assert result4.returncode == 1 # Should not raise by default + + # When check=True is passed, it should raise CalledProcessError + cmd5 = Command(cmd=["sh", "-c", "exit 1"]) + with pytest.raises(subprocess.CalledProcessError): + cmd5.run(check=True, capture_output=True) diff --git a/tests/unit/errormode/test_errormode_functionality.py b/tests/unit/errormode/test_errormode_functionality.py index d3ae7fe2..12cee8b4 100644 --- a/tests/unit/errormode/test_errormode_functionality.py +++ b/tests/unit/errormode/test_errormode_functionality.py @@ -109,6 +109,700 @@ def log_func(msg: str): assert log == ["test-log", "test-log-2", "test-log-3"] +def test_custom_showwarning(): + """Test custom_showwarning function with traceback handling and frame extraction.""" + from muutils.errormode import custom_showwarning + + # Capture warnings + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + # Call custom_showwarning directly + custom_showwarning("test warning message", UserWarning) + + # Check that a warning was issued + assert len(w) == 1 + assert issubclass(w[0].category, UserWarning) + assert "test warning message" in str(w[0].message) + + # Check that the warning has traceback information + assert w[0].filename is not None + assert w[0].lineno is not None + + +def test_custom_showwarning_with_category(): + """Test custom_showwarning with different warning categories.""" + from muutils.errormode import custom_showwarning + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + custom_showwarning("deprecation test", DeprecationWarning) + + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + + +def test_custom_showwarning_default_category(): + """Test custom_showwarning uses UserWarning as default.""" + from muutils.errormode import custom_showwarning + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + # Call without specifying category + custom_showwarning("default category test", category=None) + + assert len(w) == 1 + assert issubclass(w[0].category, UserWarning) + + +def test_ErrorMode_process_except_from(): + """Test exception chaining with except_from parameter.""" + base_exception = ValueError("base error") + + try: + ErrorMode.EXCEPT.process( + "chained 
error message",
+            except_cls=RuntimeError,
+            except_from=base_exception,
+        )
+    except RuntimeError as e:
+        # Check the exception message
+        assert str(e) == "chained error message"
+        # Check that __cause__ is set correctly
+        assert e.__cause__ is base_exception
+        assert isinstance(e.__cause__, ValueError)
+        assert str(e.__cause__) == "base error"
+    else:
+        pytest.fail("Expected RuntimeError to be raised")
+
+
+def test_ErrorMode_process_except_from_different_types():
+    """Test exception chaining with different exception types."""
+    # Test with KeyError -> TypeError
+    base = KeyError("key not found")
+    try:
+        ErrorMode.EXCEPT.process("type error", except_cls=TypeError, except_from=base)
+    except TypeError as e:
+        assert e.__cause__ is base
+    else:
+        pytest.fail("Expected TypeError to be raised")
+
+    # Test with AttributeError -> ValueError
+    base2 = AttributeError("attribute missing")
+    try:
+        ErrorMode.EXCEPT.process(
+            "value error", except_cls=ValueError, except_from=base2
+        )
+    except ValueError as e:
+        assert e.__cause__ is base2
+    else:
+        pytest.fail("Expected ValueError to be raised")
+
+
+def test_ErrorMode_process_custom_funcs():
+    """Test custom warn_func and log_func parameters."""
+    # Test custom warn_func
+    warnings_captured = []
+
+    def custom_warn(msg: str, category, source=None):
+        warnings_captured.append({"msg": msg, "category": category, "source": source})
+
+    ErrorMode.WARN.process(
+        "custom warn test", warn_cls=UserWarning, warn_func=custom_warn
+    )
+
+    assert len(warnings_captured) == 1
+    assert warnings_captured[0]["msg"] == "custom warn test"
+    assert warnings_captured[0]["category"] == UserWarning  # noqa: E721
+
+    # Test custom log_func
+    logs_captured = []
+
+    def custom_log(msg: str):
+        logs_captured.append(msg)
+
+    ErrorMode.LOG.process("custom log test", log_func=custom_log)
+
+    assert len(logs_captured) == 1
+    assert logs_captured[0] == "custom log test"
+
+
+def test_ErrorMode_process_custom_warn_func_with_except_from():
+    """Test custom warn_func with except_from to augment the message."""
+    warnings_captured = []
+
+    def custom_warn(msg: str, category, source=None):
+        warnings_captured.append(msg)
+
+    base_exception = ValueError("source exception")
+
+    ErrorMode.WARN.process(
+        "warning message",
+        warn_cls=UserWarning,
+        warn_func=custom_warn,
+        except_from=base_exception,
+    )
+
+    assert len(warnings_captured) == 1
+    # Check that the message is augmented with source
+    assert "warning message" in warnings_captured[0]
+    assert "Source of warning" in warnings_captured[0]
+    assert "source exception" in warnings_captured[0]
+
+
+def test_ErrorMode_serialize_load():
+    """Test round-trip serialization and loading."""
+    # Test EXCEPT
+    serialized = ErrorMode.EXCEPT.serialize()
+    loaded = ErrorMode.load(serialized)
+    assert loaded is ErrorMode.EXCEPT
+
+    # Test WARN
+    serialized = ErrorMode.WARN.serialize()
+    loaded = ErrorMode.load(serialized)
+    assert loaded is ErrorMode.WARN
+
+    # Test LOG
+    serialized = ErrorMode.LOG.serialize()
+    loaded = ErrorMode.load(serialized)
+    assert loaded is ErrorMode.LOG
+
+    # Test IGNORE
+    serialized = ErrorMode.IGNORE.serialize()
+    loaded = ErrorMode.load(serialized)
+    assert loaded is ErrorMode.IGNORE
+
+
+def test_ErrorMode_serialize_format():
+    """Test that serialize returns the expected format."""
+    assert ErrorMode.EXCEPT.serialize() == "ErrorMode.Except"
+    assert ErrorMode.WARN.serialize() == "ErrorMode.Warn"
+    assert ErrorMode.LOG.serialize() == "ErrorMode.Log"
+    assert ErrorMode.IGNORE.serialize() == "ErrorMode.Ignore"
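+
+
+# Editor's sketch (not part of the original suite): the call pattern the
+# chaining tests above exercise, as it might appear in downstream code that
+# lets callers choose an error-handling mode. `check_positive` is hypothetical.
+def test_ErrorMode_usage_sketch():
+    """Usage sketch: route a validation failure through a caller-chosen mode."""
+
+    def check_positive(x: int, error_mode: ErrorMode) -> None:
+        if x <= 0:
+            error_mode.process(
+                f"expected a positive value, got {x}", except_cls=ValueError
+            )
+
+    # EXCEPT raises,
+    with pytest.raises(ValueError):
+        check_positive(-1, ErrorMode.EXCEPT)
+
+    # while IGNORE silently continues.
+    check_positive(-1, ErrorMode.IGNORE)
+
+
+def test_ERROR_MODE_ALIASES():
+    """Test that all aliases resolve correctly."""
+    from muutils.errormode import 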
ERROR_MODE_ALIASES + + # Test EXCEPT aliases + assert ERROR_MODE_ALIASES["except"] is ErrorMode.EXCEPT + assert ERROR_MODE_ALIASES["e"] is ErrorMode.EXCEPT + assert ERROR_MODE_ALIASES["error"] is ErrorMode.EXCEPT + assert ERROR_MODE_ALIASES["err"] is ErrorMode.EXCEPT + assert ERROR_MODE_ALIASES["raise"] is ErrorMode.EXCEPT + + # Test WARN aliases + assert ERROR_MODE_ALIASES["warn"] is ErrorMode.WARN + assert ERROR_MODE_ALIASES["w"] is ErrorMode.WARN + assert ERROR_MODE_ALIASES["warning"] is ErrorMode.WARN + + # Test LOG aliases + assert ERROR_MODE_ALIASES["log"] is ErrorMode.LOG + assert ERROR_MODE_ALIASES["l"] is ErrorMode.LOG + assert ERROR_MODE_ALIASES["print"] is ErrorMode.LOG + assert ERROR_MODE_ALIASES["output"] is ErrorMode.LOG + assert ERROR_MODE_ALIASES["show"] is ErrorMode.LOG + assert ERROR_MODE_ALIASES["display"] is ErrorMode.LOG + + # Test IGNORE aliases + assert ERROR_MODE_ALIASES["ignore"] is ErrorMode.IGNORE + assert ERROR_MODE_ALIASES["i"] is ErrorMode.IGNORE + assert ERROR_MODE_ALIASES["silent"] is ErrorMode.IGNORE + assert ERROR_MODE_ALIASES["quiet"] is ErrorMode.IGNORE + assert ERROR_MODE_ALIASES["nothing"] is ErrorMode.IGNORE + + +def test_ErrorMode_from_any_with_string(): + """Test from_any with string inputs.""" + # Test base values + assert ErrorMode.from_any("except") is ErrorMode.EXCEPT + assert ErrorMode.from_any("warn") is ErrorMode.WARN + assert ErrorMode.from_any("log") is ErrorMode.LOG + assert ErrorMode.from_any("ignore") is ErrorMode.IGNORE + + # Test with uppercase + assert ErrorMode.from_any("EXCEPT") is ErrorMode.EXCEPT + assert ErrorMode.from_any("WARN") is ErrorMode.WARN + + # Test with whitespace + assert ErrorMode.from_any(" except ") is ErrorMode.EXCEPT + assert ErrorMode.from_any(" warn ") is ErrorMode.WARN + + +def test_ErrorMode_from_any_with_aliases(): + """Test from_any with alias strings.""" + # Test EXCEPT aliases + assert ErrorMode.from_any("error") is ErrorMode.EXCEPT + assert ErrorMode.from_any("e") is ErrorMode.EXCEPT + assert ErrorMode.from_any("raise") is ErrorMode.EXCEPT + + # Test WARN aliases + assert ErrorMode.from_any("warning") is ErrorMode.WARN + assert ErrorMode.from_any("w") is ErrorMode.WARN + + # Test LOG aliases + assert ErrorMode.from_any("print") is ErrorMode.LOG + assert ErrorMode.from_any("l") is ErrorMode.LOG + assert ErrorMode.from_any("output") is ErrorMode.LOG + + # Test IGNORE aliases + assert ErrorMode.from_any("silent") is ErrorMode.IGNORE + assert ErrorMode.from_any("i") is ErrorMode.IGNORE + assert ErrorMode.from_any("quiet") is ErrorMode.IGNORE + + +def test_ErrorMode_from_any_with_prefix(): + """Test from_any with ErrorMode. 
prefix.""" + assert ErrorMode.from_any("ErrorMode.except") is ErrorMode.EXCEPT + assert ErrorMode.from_any("ErrorMode.warn") is ErrorMode.WARN + assert ErrorMode.from_any("ErrorMode.log") is ErrorMode.LOG + assert ErrorMode.from_any("ErrorMode.ignore") is ErrorMode.IGNORE + + # Test with mixed case + assert ErrorMode.from_any("ErrorMode.Except") is ErrorMode.EXCEPT + assert ErrorMode.from_any("ErrorMode.WARN") is ErrorMode.WARN + + +def test_ErrorMode_from_any_with_ErrorMode_instance(): + """Test from_any with ErrorMode instance.""" + assert ErrorMode.from_any(ErrorMode.EXCEPT) is ErrorMode.EXCEPT + assert ErrorMode.from_any(ErrorMode.WARN) is ErrorMode.WARN + assert ErrorMode.from_any(ErrorMode.LOG) is ErrorMode.LOG + assert ErrorMode.from_any(ErrorMode.IGNORE) is ErrorMode.IGNORE + + +def test_ErrorMode_from_any_without_aliases(): + """Test from_any with allow_aliases=False.""" + # Base values should still work + assert ErrorMode.from_any("except", allow_aliases=False) is ErrorMode.EXCEPT + + # Aliases should fail + with pytest.raises(KeyError): + ErrorMode.from_any("error", allow_aliases=False) + + with pytest.raises(KeyError): + ErrorMode.from_any("e", allow_aliases=False) + + +def test_ErrorMode_from_any_invalid_string(): + """Test from_any with invalid string.""" + with pytest.raises(KeyError): + ErrorMode.from_any("invalid_mode") + + with pytest.raises(KeyError): + ErrorMode.from_any("not_a_mode") + + +def test_ErrorMode_from_any_invalid_type(): + """Test from_any with invalid type.""" + with pytest.raises(TypeError): + ErrorMode.from_any(123) # type: ignore + + with pytest.raises(TypeError): + ErrorMode.from_any(None) # type: ignore + + with pytest.raises(TypeError): + ErrorMode.from_any([]) # type: ignore + + +def test_ErrorMode_str_repr(): + """Test __str__ and __repr__ methods.""" + assert str(ErrorMode.EXCEPT) == "ErrorMode.Except" + assert str(ErrorMode.WARN) == "ErrorMode.Warn" + assert str(ErrorMode.LOG) == "ErrorMode.Log" + assert str(ErrorMode.IGNORE) == "ErrorMode.Ignore" + + assert repr(ErrorMode.EXCEPT) == "ErrorMode.Except" + assert repr(ErrorMode.WARN) == "ErrorMode.Warn" + assert repr(ErrorMode.LOG) == "ErrorMode.Log" + assert repr(ErrorMode.IGNORE) == "ErrorMode.Ignore" + + +def test_ErrorMode_process_unknown_mode(): + """Test that an unknown error mode raises ValueError.""" + # This is a edge case that shouldn't normally happen, but testing defensively + # We can't easily create an invalid ErrorMode, so we test the else branch + # by mocking or checking that all modes are handled + # All enum values should be handled in process, so this is more of a sanity check + pass + + +def test_warn_with_except_from_builtin(): + """Test WARN mode with except_from using built-in warnings.warn.""" + import muutils.errormode as errormode + + # Make sure we're using the default warn function + errormode.GLOBAL_WARN_FUNC = warnings.warn # type: ignore + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + base_exception = ValueError("base error") + ErrorMode.WARN.process( + "test warning", warn_cls=UserWarning, except_from=base_exception + ) + + assert len(w) == 1 + # Message should include source information + message_str = str(w[0].message) + assert "test warning" in message_str + assert "Source of warning" in message_str + assert "base error" in message_str + + +def test_custom_showwarning_with_warning_instance(): + """Test custom_showwarning when passed a Warning instance instead of string.""" + from muutils.errormode import custom_showwarning 
+ + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + # Create a warning instance + warning_instance = UserWarning("instance warning") + custom_showwarning(warning_instance, UserWarning) + + assert len(w) == 1 + assert "instance warning" in str(w[0].message) + + +def test_log_with_custom_func(): + """Test LOG mode with custom log function passed directly.""" + logs = [] + + def my_logger(msg: str): + logs.append(f"LOGGED: {msg}") + + ErrorMode.LOG.process("test message", log_func=my_logger) + + assert len(logs) == 1 + assert logs[0] == "LOGGED: test message" + + +def test_multiple_log_functions(): + """Test that different log functions can be used.""" + log1 = [] + log2 = [] + + def logger1(msg: str): + log1.append(msg) + + def logger2(msg: str): + log2.append(msg) + + ErrorMode.LOG.process("message 1", log_func=logger1) + ErrorMode.LOG.process("message 2", log_func=logger2) + + assert log1 == ["message 1"] + assert log2 == ["message 2"] + + +def test_warn_with_source_parameter(): + """Test that warn_func receives proper parameters.""" + calls = [] + + def tracking_warn(msg: str, category, source=None): + calls.append({"msg": msg, "category": category, "source": source}) + + ErrorMode.WARN.process( + "test message", warn_cls=DeprecationWarning, warn_func=tracking_warn + ) + + assert len(calls) == 1 + assert calls[0]["msg"] == "test message" + assert calls[0]["category"] == DeprecationWarning # noqa: E721 + + +def test_ErrorMode_enum_values(): + """Test that ErrorMode has the expected enum values.""" + assert ErrorMode.EXCEPT.value == "except" + assert ErrorMode.WARN.value == "warn" + assert ErrorMode.LOG.value == "log" + assert ErrorMode.IGNORE.value == "ignore" + + +def test_from_any_without_prefix(): + """Test from_any with allow_prefix=False.""" + # Should still work with plain values + assert ErrorMode.from_any("except", allow_prefix=False) is ErrorMode.EXCEPT + + # Should fail with prefix + with pytest.raises(KeyError): + ErrorMode.from_any("ErrorMode.except", allow_prefix=False) + + +def test_GLOBAL_WARN_FUNC(): + """Test that GLOBAL_WARN_FUNC is used when no warn_func is provided.""" + import muutils.errormode as errormode + + # Save original + original_warn_func = errormode.GLOBAL_WARN_FUNC + + try: + # Set custom global warn function + captured = [] + + def global_warn(msg: str, category, source=None): + captured.append(msg) + + errormode.GLOBAL_WARN_FUNC = global_warn # type: ignore + + # Use WARN mode without providing warn_func + ErrorMode.WARN.process("test with global", warn_cls=UserWarning) + + assert len(captured) == 1 + assert captured[0] == "test with global" + + finally: + # Restore original + errormode.GLOBAL_WARN_FUNC = original_warn_func + + +def test_GLOBAL_LOG_FUNC(): + """Test that GLOBAL_LOG_FUNC is used when no log_func is provided.""" + import muutils.errormode as errormode + + # Save original + original_log_func = errormode.GLOBAL_LOG_FUNC + + try: + # Set custom global log function + captured = [] + + def global_log(msg: str): + captured.append(msg) + + errormode.GLOBAL_LOG_FUNC = global_log + + # Use LOG mode without providing log_func + ErrorMode.LOG.process("test with global log") + + assert len(captured) == 1 + assert captured[0] == "test with global log" + + finally: + # Restore original + errormode.GLOBAL_LOG_FUNC = original_log_func + + +def test_custom_warn_func_signature(): + """Test that custom warn_func follows the WarningFunc protocol.""" + from muutils.errormode import WarningFunc + + # Create a function that 
matches the protocol + def my_warn(msg: str, category: type[Warning], source=None) -> None: + pass + + # This should work without errors + warn_func: WarningFunc = my_warn # type: ignore + + # Use it with ErrorMode + ErrorMode.WARN.process("test", warn_cls=UserWarning, warn_func=warn_func) + + +def test_ErrorMode_all_enum_members(): + """Test that all ErrorMode enum members are accessible.""" + # Verify all enum members exist + assert hasattr(ErrorMode, "EXCEPT") + assert hasattr(ErrorMode, "WARN") + assert hasattr(ErrorMode, "LOG") + assert hasattr(ErrorMode, "IGNORE") + + # Test that they are unique + modes = [ErrorMode.EXCEPT, ErrorMode.WARN, ErrorMode.LOG, ErrorMode.IGNORE] + assert len(set(modes)) == 4 + + +def test_custom_showwarning_frame_extraction(): + """Test that custom_showwarning correctly extracts frame information.""" + import sys + from muutils.errormode import custom_showwarning + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + # Call from this specific line so we can verify frame info + line_number = 0 + + def call_showwarning(): + nonlocal line_number + line_number = sys._getframe().f_lineno + 1 + custom_showwarning("frame test", UserWarning) + + call_showwarning() + + assert len(w) == 1 + # The warning should have been issued with correct file and line info + assert w[0].filename == __file__ + # Line number should be close to where we called it + assert isinstance(w[0].lineno, int) + + +def test_exception_traceback_attached(): + """Test that raised exceptions have traceback attached.""" + try: + ErrorMode.EXCEPT.process("test traceback", except_cls=ValueError) + except ValueError as e: + # Check that exception has traceback + assert e.__traceback__ is not None + else: + pytest.fail("Expected ValueError to be raised") + + +def test_exception_traceback_with_chaining(): + """Test that chained exceptions have correct traceback.""" + base = RuntimeError("base") + + try: + ErrorMode.EXCEPT.process("chained", except_cls=ValueError, except_from=base) + except ValueError as e: + # Check traceback exists + assert e.__traceback__ is not None + # Check cause is set + assert e.__cause__ is base + else: + pytest.fail("Expected ValueError to be raised") + + +def test_warn_with_default_warn_func(): + """Test WARN mode with default warnings.warn function.""" + import muutils.errormode as errormode + + # Ensure we're using default + errormode.GLOBAL_WARN_FUNC = warnings.warn # type: ignore + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + ErrorMode.WARN.process("default warn func test", warn_cls=UserWarning) + + assert len(w) == 1 + assert "default warn func test" in str(w[0].message) + + +def test_from_any_strip_whitespace(): + """Test that from_any strips whitespace correctly.""" + # Leading/trailing spaces + assert ErrorMode.from_any(" except") is ErrorMode.EXCEPT + assert ErrorMode.from_any("warn ") is ErrorMode.WARN + assert ErrorMode.from_any(" log ") is ErrorMode.LOG + + # Tabs and newlines + assert ErrorMode.from_any("\texcept\t") is ErrorMode.EXCEPT + assert ErrorMode.from_any("\nwarn\n") is ErrorMode.WARN + + +def test_load_with_prefix(): + """Test load method with ErrorMode. 
prefix.""" + # load uses allow_prefix=True + loaded = ErrorMode.load("ErrorMode.Except") + assert loaded is ErrorMode.EXCEPT + + loaded = ErrorMode.load("ErrorMode.warn") + assert loaded is ErrorMode.WARN + + +def test_load_without_aliases(): + """Test that load does not accept aliases.""" + # load uses allow_aliases=False + with pytest.raises((KeyError, ValueError)): + ErrorMode.load("error") # alias should not work + + with pytest.raises((KeyError, ValueError)): + ErrorMode.load("e") # alias should not work + + +def test_ERROR_MODE_ALIASES_completeness(): + """Test that ERROR_MODE_ALIASES contains all expected aliases.""" + from muutils.errormode import ERROR_MODE_ALIASES + + # Count aliases per mode + except_aliases = [k for k, v in ERROR_MODE_ALIASES.items() if v is ErrorMode.EXCEPT] + warn_aliases = [k for k, v in ERROR_MODE_ALIASES.items() if v is ErrorMode.WARN] + log_aliases = [k for k, v in ERROR_MODE_ALIASES.items() if v is ErrorMode.LOG] + ignore_aliases = [k for k, v in ERROR_MODE_ALIASES.items() if v is ErrorMode.IGNORE] + + # Verify we have multiple aliases for each mode + assert len(except_aliases) >= 5 # except, e, error, err, raise + assert len(warn_aliases) >= 3 # warn, w, warning + assert len(log_aliases) >= 6 # log, l, print, output, show, display + assert len(ignore_aliases) >= 5 # ignore, i, silent, quiet, nothing + + +def test_custom_exception_classes(): + """Test process with various custom exception classes.""" + + class CustomError(Exception): + pass + + class NestedCustomError(CustomError): + pass + + # Test with custom exception + with pytest.raises(CustomError): + ErrorMode.EXCEPT.process("custom", except_cls=CustomError) + + # Test with nested custom exception + with pytest.raises(NestedCustomError): + ErrorMode.EXCEPT.process("nested custom", except_cls=NestedCustomError) + + +def test_custom_warning_classes(): + """Test process with various custom warning classes.""" + + class CustomWarning(UserWarning): + pass + + class NestedCustomWarning(CustomWarning): + pass + + # Test with custom warning + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + def custom_warn(msg: str, category, source=None): + warnings.warn(msg, category) + + ErrorMode.WARN.process("custom", warn_cls=CustomWarning, warn_func=custom_warn) + + assert len(w) == 1 + assert issubclass(w[0].category, CustomWarning) + + +def test_ignore_with_all_parameters(): + """Test that IGNORE mode ignores all parameters.""" + # None of these should raise or warn + ErrorMode.IGNORE.process("ignored message") + ErrorMode.IGNORE.process("ignored", except_cls=ValueError) + ErrorMode.IGNORE.process("ignored", warn_cls=UserWarning) + ErrorMode.IGNORE.process("ignored", except_from=ValueError("base")) + + # Also test with custom functions (they should not be called) + called = [] + + def should_not_be_called(msg: str): + called.append(msg) + + ErrorMode.IGNORE.process("ignored", log_func=should_not_be_called) + + # log_func should not have been called + assert len(called) == 0 + + +def test_from_any_case_insensitivity(): + """Test that from_any is case insensitive.""" + # Test various cases + assert ErrorMode.from_any("EXCEPT") is ErrorMode.EXCEPT + assert ErrorMode.from_any("Except") is ErrorMode.EXCEPT + assert ErrorMode.from_any("eXcEpT") is ErrorMode.EXCEPT + + assert ErrorMode.from_any("WARN") is ErrorMode.WARN + assert ErrorMode.from_any("Warn") is ErrorMode.WARN + + # Test with aliases + assert ErrorMode.from_any("ERROR") is ErrorMode.EXCEPT + assert 
ErrorMode.from_any("Error") is ErrorMode.EXCEPT + assert ErrorMode.from_any("RAISE") is ErrorMode.EXCEPT + + # def test_logging_pass(): # errmode: ErrorMode = ErrorMode.LOG diff --git a/tests/unit/json_serialize/test_array.py b/tests/unit/json_serialize/test_array.py index ce348161..bd4c46df 100644 --- a/tests/unit/json_serialize/test_array.py +++ b/tests/unit/json_serialize/test_array.py @@ -9,6 +9,7 @@ load_array, serialize_array, ) +from muutils.json_serialize.util import _FORMAT_KEY # pylint: disable=missing-class-docstring @@ -87,3 +88,136 @@ def test_serialize_load_zero_dim(self): ) loaded_array = load_array(serialized_array) assert np.array_equal(loaded_array, self.array_zero_dim) + + +def test_array_shape_dtype_preservation(): + """Test that various shapes and dtypes are preserved through serialization.""" + # Test different shapes + shapes_and_arrays = [ + (np.array([1, 2, 3], dtype=np.int32), "1D int32"), + (np.array([[1.5, 2.5], [3.5, 4.5]], dtype=np.float32), "2D float32"), + (np.array([[[1]], [[2]]], dtype=np.int8), "3D int8"), + (np.array([[[[1, 2, 3, 4]]]], dtype=np.int16), "4D int16"), + ] + + # Test different dtypes + dtype_tests = [ + (np.array([1, 2, 3], dtype=np.int8), np.int8), + (np.array([1, 2, 3], dtype=np.int16), np.int16), + (np.array([1, 2, 3], dtype=np.int32), np.int32), + (np.array([1, 2, 3], dtype=np.int64), np.int64), + (np.array([1.0, 2.0, 3.0], dtype=np.float16), np.float16), + (np.array([1.0, 2.0, 3.0], dtype=np.float32), np.float32), + (np.array([1.0, 2.0, 3.0], dtype=np.float64), np.float64), + (np.array([True, False, True], dtype=np.bool_), np.bool_), + ] + + jser = JsonSerializer(array_mode="array_list_meta") + + # Test shapes preservation + for arr, description in shapes_and_arrays: + for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: + serialized = serialize_array(jser, arr, "test", array_mode=mode) # type: ignore[arg-type] + loaded = load_array(serialized) + assert loaded.shape == arr.shape, ( + f"Shape mismatch for {description} in {mode}" + ) + assert loaded.dtype == arr.dtype, ( + f"Dtype mismatch for {description} in {mode}" + ) + assert np.array_equal(loaded, arr), ( + f"Data mismatch for {description} in {mode}" + ) + + # Test dtypes preservation + for arr, expected_dtype in dtype_tests: + for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: + serialized = serialize_array(jser, arr, "test", array_mode=mode) # type: ignore[arg-type] + loaded = load_array(serialized) + assert loaded.dtype == expected_dtype, f"Dtype not preserved: {mode}" + assert np.array_equal(loaded, arr), f"Data not preserved: {mode}" + + +def test_array_serialization_handlers(): + """Test integration with JsonSerializer - ensure arrays are serialized correctly when part of larger objects.""" + # Test that JsonSerializer properly handles arrays in different contexts + jser = JsonSerializer(array_mode="array_list_meta") + + # Array in a dict + data_dict = { + "metadata": {"name": "test"}, + "array": np.array([1, 2, 3, 4]), + "nested": {"inner_array": np.array([[1, 2], [3, 4]])}, + } + + serialized = jser.json_serialize(data_dict) + assert isinstance(serialized["array"], dict) + assert _FORMAT_KEY in serialized["array"] + assert serialized["array"]["shape"] == [4] + + # Array in a list + data_list = [ + {"value": 1}, + np.array([10, 20, 30]), + {"value": 2, "data": np.array([[1, 2]])}, + ] + + serialized_list = jser.json_serialize(data_list) + assert isinstance(serialized_list[1], dict) + assert _FORMAT_KEY in serialized_list[1] + + # Test 
different array modes + for mode in ["list", "array_list_meta", "array_hex_meta", "array_b64_meta"]: + jser_mode = JsonSerializer(array_mode=mode) # type: ignore[arg-type] + arr = np.array([[1, 2, 3], [4, 5, 6]]) + result = jser_mode.json_serialize(arr) + + if mode == "list": + assert isinstance(result, list) + else: + assert isinstance(result, dict) + assert _FORMAT_KEY in result + + +def test_array_edge_cases(): + """Test edge cases: empty arrays, unusual dtypes, and boundary conditions.""" + jser = JsonSerializer(array_mode="array_list_meta") + + # Empty arrays with different shapes + empty_1d = np.array([], dtype=np.int32) + empty_2d = np.array([[], []], dtype=np.float32).reshape(2, 0) + empty_3d = np.array([[]], dtype=np.int64).reshape(1, 1, 0) + + for empty_arr in [empty_1d, empty_2d, empty_3d]: + for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: + serialized = serialize_array(jser, empty_arr, "test", array_mode=mode) # type: ignore[arg-type] + loaded = load_array(serialized) + assert loaded.shape == empty_arr.shape + assert loaded.dtype == empty_arr.dtype + assert np.array_equal(loaded, empty_arr) + + # Complex dtypes + complex_arr = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64) + serialized = serialize_array( + jser, complex_arr, "test", array_mode="array_list_meta" + ) + loaded = load_array(serialized) + assert loaded.dtype == np.complex64 + assert np.array_equal(loaded, complex_arr) + + # Large arrays (test that serialization doesn't break) + large_arr = np.random.rand(100, 100) + for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: + serialized = serialize_array(jser, large_arr, "test", array_mode=mode) # type: ignore[arg-type] + loaded = load_array(serialized) + assert np.allclose(loaded, large_arr) + + # Arrays with special values + special_arr = np.array([np.inf, -np.inf, np.nan, 0.0, -0.0], dtype=np.float64) + for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: + serialized = serialize_array(jser, special_arr, "test", array_mode=mode) # type: ignore[arg-type] + loaded = load_array(serialized) + # Use special comparison for NaN + assert np.isnan(loaded[2]) and np.isnan(special_arr[2]) + assert np.array_equal(loaded[:2], special_arr[:2]) # inf values + assert np.array_equal(loaded[3:], special_arr[3:]) # zeros diff --git a/tests/unit/json_serialize/test_array_torch.py b/tests/unit/json_serialize/test_array_torch.py new file mode 100644 index 00000000..11844638 --- /dev/null +++ b/tests/unit/json_serialize/test_array_torch.py @@ -0,0 +1,228 @@ +import numpy as np +import pytest +import torch + +from muutils.json_serialize import JsonSerializer +from muutils.json_serialize.array import ( + arr_metadata, + array_n_elements, + load_array, + serialize_array, +) +from muutils.json_serialize.util import _FORMAT_KEY + +# pylint: disable=missing-class-docstring + + +def test_arr_metadata_torch(): + """Test arr_metadata() with torch tensors.""" + # 1D tensor + tensor_1d = torch.tensor([1, 2, 3, 4, 5]) + metadata_1d = arr_metadata(tensor_1d) + assert metadata_1d["shape"] == [5] + assert "int64" in metadata_1d["dtype"] # Could be "torch.int64" or "int64" + assert metadata_1d["n_elements"] == 5 + + # 2D tensor + tensor_2d = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + metadata_2d = arr_metadata(tensor_2d) + assert metadata_2d["shape"] == [2, 2] + assert "float32" in metadata_2d["dtype"] + assert metadata_2d["n_elements"] == 4 + + # 3D tensor + tensor_3d = torch.randn(3, 4, 5, dtype=torch.float64) + metadata_3d = 
arr_metadata(tensor_3d) + assert metadata_3d["shape"] == [3, 4, 5] + assert "float64" in metadata_3d["dtype"] + assert metadata_3d["n_elements"] == 60 + + # Zero-dimensional tensor + tensor_0d = torch.tensor(42) + metadata_0d = arr_metadata(tensor_0d) + assert metadata_0d["shape"] == [] + assert metadata_0d["n_elements"] == 1 + + +def test_array_n_elements_torch(): + """Test array_n_elements() with torch tensors.""" + assert array_n_elements(torch.tensor([1, 2, 3])) == 3 + assert array_n_elements(torch.tensor([[1, 2], [3, 4]])) == 4 + assert array_n_elements(torch.randn(2, 3, 4)) == 24 + assert array_n_elements(torch.tensor(42)) == 1 + + +def test_serialize_load_torch_tensors(): + """Test round-trip serialization of torch tensors.""" + jser = JsonSerializer(array_mode="array_list_meta") + + # Test various tensor types + tensors = [ + torch.tensor([1, 2, 3, 4], dtype=torch.int32), + torch.tensor([[1.5, 2.5], [3.5, 4.5]], dtype=torch.float32), + torch.tensor([[[1, 2]], [[3, 4]]], dtype=torch.int64), + torch.tensor([True, False, True], dtype=torch.bool), + ] + + for tensor in tensors: + for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: + serialized = serialize_array(jser, tensor, "test", array_mode=mode) # type: ignore[arg-type] + loaded = load_array(serialized) + + # Convert to numpy for comparison + tensor_np = tensor.cpu().numpy() + assert np.array_equal(loaded, tensor_np) + assert loaded.shape == tuple(tensor.shape) + + +def test_torch_shape_dtype_preservation(): + """Test that various torch tensor shapes and dtypes are preserved.""" + jser = JsonSerializer(array_mode="array_list_meta") + + # Different dtypes + dtype_tests = [ + (torch.tensor([1, 2, 3], dtype=torch.int8), torch.int8), + (torch.tensor([1, 2, 3], dtype=torch.int16), torch.int16), + (torch.tensor([1, 2, 3], dtype=torch.int32), torch.int32), + (torch.tensor([1, 2, 3], dtype=torch.int64), torch.int64), + (torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16), torch.float16), + (torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32), torch.float32), + (torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64), torch.float64), + (torch.tensor([True, False, True], dtype=torch.bool), torch.bool), + ] + + for tensor, expected_dtype in dtype_tests: + for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: + serialized = serialize_array(jser, tensor, "test", array_mode=mode) # type: ignore[arg-type] + loaded = load_array(serialized) + + # Convert for comparison + tensor_np = tensor.cpu().numpy() + assert np.array_equal(loaded, tensor_np) + assert loaded.dtype.name == tensor_np.dtype.name + + +def test_torch_zero_dim_tensor(): + """Test zero-dimensional torch tensors.""" + jser = JsonSerializer(array_mode="array_list_meta") + + tensor_0d = torch.tensor(42) + + for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: + serialized = serialize_array(jser, tensor_0d, "test", array_mode=mode) # type: ignore[arg-type] + loaded = load_array(serialized) + + # Zero-dim tensors have special handling + assert loaded.shape == tensor_0d.shape + assert np.array_equal(loaded, tensor_0d.cpu().numpy()) + + +def test_torch_edge_cases(): + """Test edge cases with torch tensors.""" + jser = JsonSerializer(array_mode="array_list_meta") + + # Empty tensors + empty_1d = torch.tensor([], dtype=torch.float32) + serialized = serialize_array(jser, empty_1d, "test", array_mode="array_list_meta") + loaded = load_array(serialized) + assert loaded.shape == (0,) + + # Tensors with special values + special_tensor = torch.tensor( + 
[float("inf"), float("-inf"), float("nan"), 0.0, -0.0] + ) + for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: + serialized = serialize_array(jser, special_tensor, "test", array_mode=mode) # type: ignore[arg-type] + loaded = load_array(serialized) + + # Check special values + assert np.isinf(loaded[0]) and loaded[0] > 0 + assert np.isinf(loaded[1]) and loaded[1] < 0 + assert np.isnan(loaded[2]) + + # Large tensor + large_tensor = torch.randn(100, 100) + serialized = serialize_array( + jser, large_tensor, "test", array_mode="array_b64_meta" + ) + loaded = load_array(serialized) + assert np.allclose(loaded, large_tensor.cpu().numpy()) + + +def test_torch_gpu_tensors(): + """Test serialization of GPU tensors (if CUDA is available).""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + jser = JsonSerializer(array_mode="array_list_meta") + + # Create GPU tensor + tensor_gpu = torch.tensor([1, 2, 3, 4], dtype=torch.float32, device="cuda") + + for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: + # Need to move to CPU first for numpy conversion + tensor_cpu_torch = tensor_gpu.cpu() + serialized = serialize_array(jser, tensor_cpu_torch, "test", array_mode=mode) # type: ignore[arg-type] + loaded = load_array(serialized) + + # Should match the CPU version + tensor_cpu = tensor_gpu.cpu().numpy() + assert np.array_equal(loaded, tensor_cpu) + + +def test_torch_serialization_integration(): + """Test torch tensors integrated with JsonSerializer in complex structures.""" + jser = JsonSerializer(array_mode="array_list_meta") + + # Mixed structure with torch tensors + data = { + "model_weights": torch.randn(10, 5), + "biases": torch.randn(5), + "metadata": {"epochs": 10, "lr": 0.001}, + "history": [ + {"loss": torch.tensor(0.5), "accuracy": torch.tensor(0.95)}, + {"loss": torch.tensor(0.3), "accuracy": torch.tensor(0.97)}, + ], + } + + serialized = jser.json_serialize(data) + + # Check structure is preserved + assert isinstance(serialized["model_weights"], dict) + assert _FORMAT_KEY in serialized["model_weights"] + assert serialized["model_weights"]["shape"] == [10, 5] + + assert isinstance(serialized["biases"], dict) + assert serialized["biases"]["shape"] == [5] + + assert serialized["metadata"]["epochs"] == 10 + + # Check nested tensors + assert isinstance(serialized["history"][0]["loss"], dict) + assert _FORMAT_KEY in serialized["history"][0]["loss"] + + +def test_mixed_numpy_torch(): + """Test that both numpy arrays and torch tensors can coexist in serialization.""" + jser = JsonSerializer(array_mode="array_list_meta") + + data = { + "numpy_array": np.array([1, 2, 3]), + "torch_tensor": torch.tensor([4, 5, 6]), + "nested": { + "np": np.array([[1, 2]]), + "torch": torch.tensor([[3, 4]]), + }, + } + + serialized = jser.json_serialize(data) + + # Both should be serialized as dicts with metadata + assert isinstance(serialized["numpy_array"], dict) + assert isinstance(serialized["torch_tensor"], dict) + assert _FORMAT_KEY in serialized["numpy_array"] + assert _FORMAT_KEY in serialized["torch_tensor"] + + # Check format strings identify the type + assert "numpy" in serialized["numpy_array"][_FORMAT_KEY] + assert "torch" in serialized["torch_tensor"][_FORMAT_KEY] diff --git a/tests/unit/json_serialize/test_json_serialize.py b/tests/unit/json_serialize/test_json_serialize.py new file mode 100644 index 00000000..8e016c75 --- /dev/null +++ b/tests/unit/json_serialize/test_json_serialize.py @@ -0,0 +1,745 @@ +"""Tests for 
muutils.json_serialize.json_serialize module. + +IMPORTANT: This tests the core json_serialize functionality. Array-specific tests are in test_array.py, +and utility function tests are in test_util.py. We focus on JsonSerializer class and handler system here. +""" + +from __future__ import annotations + +import warnings +from collections import namedtuple +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import pytest + +from muutils.errormode import ErrorMode +from muutils.json_serialize.json_serialize import ( + BASE_HANDLERS, + DEFAULT_HANDLERS, + JsonSerializer, + SerializerHandler, + json_serialize, +) +from muutils.json_serialize.util import SerializationException, _FORMAT_KEY + + +# ============================================================================ +# Test classes and fixtures +# ============================================================================ + + +@dataclass +class SimpleDataclass: + """Simple dataclass for testing.""" + + x: int + y: str + z: bool = True + + +@dataclass +class NestedDataclass: + """Nested dataclass for testing.""" + + simple: SimpleDataclass + data: dict[str, Any] + + +class ClassWithSerialize: + """Class with custom serialize method.""" + + def __init__(self, value: int): + self.value = value + self.name = "test" + + def serialize(self) -> dict: + """Custom serialization.""" + return {"custom_value": self.value * 2, "custom_name": self.name.upper()} + + +class UnserializableClass: + """Class that can't be easily serialized.""" + + def __init__(self): + self.data = "test" + + +# ============================================================================ +# Tests for basic type serialization +# ============================================================================ + + +def test_json_serialize_basic_types(): + """Test serialization of basic Python types: int, float, str, bool, None, list, dict.""" + serializer = JsonSerializer() + + # Test primitives + assert serializer.json_serialize(42) == 42 + assert serializer.json_serialize(3.14) == 3.14 + assert serializer.json_serialize("hello") == "hello" + assert serializer.json_serialize(True) is True + assert serializer.json_serialize(False) is False + assert serializer.json_serialize(None) is None + + # Test list + result = serializer.json_serialize([1, 2, 3]) + assert result == [1, 2, 3] + assert isinstance(result, list) + + # Test dict + result = serializer.json_serialize({"a": 1, "b": 2}) + assert result == {"a": 1, "b": 2} + assert isinstance(result, dict) + + # Test empty containers + assert serializer.json_serialize([]) == [] + assert serializer.json_serialize({}) == {} + + +def test_json_serialize_function(): + """Test the module-level json_serialize function with default config.""" + # Test that it works with basic types + assert json_serialize(42) == 42 + assert json_serialize("test") == "test" + assert json_serialize([1, 2, 3]) == [1, 2, 3] + assert json_serialize({"key": "value"}) == {"key": "value"} + + # Test with more complex types + obj = SimpleDataclass(x=10, y="hello", z=False) + result = json_serialize(obj) + assert result == {"x": 10, "y": "hello", "z": False} + + +# ============================================================================ +# Tests for .serialize() method override +# ============================================================================ + + +def test_json_serialize_serialize_method(): + """Test objects with .serialize() method are handled correctly.""" + serializer = JsonSerializer() + + obj = 
ClassWithSerialize(value=5) + result = serializer.json_serialize(obj) + + # Should use the custom serialize method + assert result == {"custom_value": 10, "custom_name": "TEST"} + assert result["custom_value"] == obj.value * 2 + assert result["custom_name"] == obj.name.upper() + + +def test_serialize_method_priority(): + """Test that .serialize() method takes priority over other handlers.""" + serializer = JsonSerializer() + + # Even though this is a dataclass, the .serialize() method should take priority + @dataclass + class DataclassWithSerialize: + x: int + y: int + + def serialize(self) -> dict: + return {"sum": self.x + self.y} + + obj = DataclassWithSerialize(x=3, y=7) + result = serializer.json_serialize(obj) + + # Should use custom serialize, not dataclass handler + assert result == {"sum": 10} + assert "x" not in result + assert "y" not in result + + +# ============================================================================ +# Tests for custom handlers +# ============================================================================ + + +def test_JsonSerializer_custom_handlers(): + """Test adding custom pre/post handlers and verify execution order.""" + # Create a custom handler that captures specific types + custom_check_count = {"count": 0} + custom_serialize_count = {"count": 0} + + def custom_check(self, obj, path): + custom_check_count["count"] += 1 + return isinstance(obj, str) and obj.startswith("CUSTOM:") + + def custom_serialize(self, obj, path): + custom_serialize_count["count"] += 1 + return {"custom": True, "value": obj[7:]} # Remove "CUSTOM:" prefix + + custom_handler = SerializerHandler( + check=custom_check, + serialize_func=custom_serialize, + uid="custom_string_handler", + desc="Custom handler for strings starting with CUSTOM:", + ) + + # Create serializer with custom handler in handlers_pre (before defaults) + serializer = JsonSerializer(handlers_pre=(custom_handler,)) + + # Test that custom handler is used + result = serializer.json_serialize("CUSTOM:test") + assert result == {"custom": True, "value": "test"} + assert custom_serialize_count["count"] == 1 + + # Test that normal strings still work (use default handler) + result = serializer.json_serialize("normal string") + assert result == "normal string" + + +def test_custom_handler_execution_order(): + """Test that handlers_pre are executed before default handlers.""" + executed_handlers = [] + + def tracking_check(handler_name): + def check(self, obj, path): + executed_handlers.append(handler_name) + return isinstance(obj, dict) and "_test_marker" in obj + + return check + + def tracking_serialize(handler_name): + def serialize(self, obj, path): + return {"handled_by": handler_name} + + return serialize + + handler1 = SerializerHandler( + check=tracking_check("handler1"), + serialize_func=tracking_serialize("handler1"), + uid="handler1", + desc="First custom handler", + ) + + handler2 = SerializerHandler( + check=tracking_check("handler2"), + serialize_func=tracking_serialize("handler2"), + uid="handler2", + desc="Second custom handler", + ) + + serializer = JsonSerializer(handlers_pre=(handler1, handler2)) + + test_obj = {"_test_marker": True} + result = serializer.json_serialize(test_obj) + + # First handler that matches should be used (handler1) + assert result == {"handled_by": "handler1"} + assert executed_handlers[0] == "handler1" + + +# ============================================================================ +# Tests for DEFAULT_HANDLERS +# 
============================================================================ + + +def test_DEFAULT_HANDLERS(): + """Test that all default type handlers serialize correctly.""" + serializer = JsonSerializer() + + # Test dataclass + dc = SimpleDataclass(x=1, y="test", z=False) + result = serializer.json_serialize(dc) + assert result == {"x": 1, "y": "test", "z": False} + + # Test namedtuple - should serialize as dict + Point = namedtuple("Point", ["x", "y"]) + point = Point(10, 20) + result = serializer.json_serialize(point) + assert result == {"x": 10, "y": 20} + assert isinstance(result, dict) + + # Test Path + path = Path("/home/user/test.txt") + result = serializer.json_serialize(path) + assert result == "/home/user/test.txt" + assert isinstance(result, str) + + # Test set (should become dict with _FORMAT_KEY) + result = serializer.json_serialize({1, 2, 3}) + assert isinstance(result, dict) + assert result[_FORMAT_KEY] == "set" + assert set(result["data"]) == {1, 2, 3} + + # Test tuple (should become list) + result = serializer.json_serialize((1, 2, 3)) + assert result == [1, 2, 3] + assert isinstance(result, list) + + +def test_BASE_HANDLERS(): + """Test that BASE_HANDLERS work correctly (primitives, dicts, lists, tuples).""" + serializer = JsonSerializer(handlers_default=BASE_HANDLERS) + + # Base handlers should handle primitives + assert serializer.json_serialize(42) == 42 + assert serializer.json_serialize("test") == "test" + assert serializer.json_serialize(True) is True + assert serializer.json_serialize(None) is None + + # Base handlers should handle dicts and lists + assert serializer.json_serialize([1, 2, 3]) == [1, 2, 3] + assert serializer.json_serialize({"a": 1}) == {"a": 1} + assert serializer.json_serialize((1, 2)) == [1, 2] + + +def test_fallback_handler(): + """Test that the fallback handler works for unserializable objects.""" + serializer = JsonSerializer() + + obj = UnserializableClass() + result = serializer.json_serialize(obj) + + # Fallback handler should return dict with special keys + assert isinstance(result, dict) + assert "__name__" in result + assert "__module__" in result + assert "type" in result + assert "repr" in result + + +# ============================================================================ +# Tests for nested structures +# ============================================================================ + + +def test_nested_structures(): + """Test serialization of mixed types and nested dicts/lists.""" + serializer = JsonSerializer() + + # Nested dicts and lists + nested = {"outer": {"inner": [1, 2, {"deep": "value"}]}} + result = serializer.json_serialize(nested) + assert result == {"outer": {"inner": [1, 2, {"deep": "value"}]}} + + # List of dicts + list_of_dicts = [{"a": 1}, {"b": 2}, {"c": 3}] + result = serializer.json_serialize(list_of_dicts) + assert result == [{"a": 1}, {"b": 2}, {"c": 3}] + + # Dict of lists + dict_of_lists = {"nums": [1, 2, 3], "strs": ["a", "b", "c"]} + result = serializer.json_serialize(dict_of_lists) + assert result == {"nums": [1, 2, 3], "strs": ["a", "b", "c"]} + + +def test_nested_dataclasses(): + """Test serialization of nested dataclasses.""" + serializer = JsonSerializer() + + simple = SimpleDataclass(x=5, y="inner", z=True) + nested = NestedDataclass(simple=simple, data={"key": "value"}) + + result = serializer.json_serialize(nested) + assert result == { + "simple": {"x": 5, "y": "inner", "z": True}, + "data": {"key": "value"}, + } + + +def test_deeply_nested_structure(): + """Test very deeply nested 
structures.""" + serializer = JsonSerializer() + + deep = {"l1": {"l2": {"l3": {"l4": {"l5": [1, 2, 3]}}}}} + result = serializer.json_serialize(deep) + assert result == {"l1": {"l2": {"l3": {"l4": {"l5": [1, 2, 3]}}}}} + + +def test_mixed_types_nested(): + """Test nested structures with mixed types (dataclass, dict, list, primitives).""" + serializer = JsonSerializer() + + dc = SimpleDataclass(x=100, y="test", z=False) + mixed = { + "dataclass": dc, + "list": [1, 2, dc], + "nested": {"inner_dc": dc, "values": [10, 20]}, + "primitive": 42, + } + + result = serializer.json_serialize(mixed) + expected_dc = {"x": 100, "y": "test", "z": False} + assert result == { + "dataclass": expected_dc, + "list": [1, 2, expected_dc], + "nested": {"inner_dc": expected_dc, "values": [10, 20]}, + "primitive": 42, + } + + +# ============================================================================ +# Tests for ErrorMode handling +# ============================================================================ + + +def test_error_mode_except(): + """Test that ErrorMode.EXCEPT raises SerializationException on errors.""" + + # Create a handler that always raises an error + def error_check(self, obj, path): + return isinstance(obj, str) and obj == "ERROR" + + def error_serialize(self, obj, path): + raise ValueError("Intentional error") + + error_handler = SerializerHandler( + check=error_check, + serialize_func=error_serialize, + uid="error_handler", + desc="Handler that raises errors", + ) + + serializer = JsonSerializer( + error_mode=ErrorMode.EXCEPT, handlers_pre=(error_handler,) + ) + + with pytest.raises(SerializationException) as exc_info: + serializer.json_serialize("ERROR") + + assert "error serializing" in str(exc_info.value) + assert "error_handler" in str(exc_info.value) + + +def test_error_mode_warn(): + """Test that ErrorMode.WARN returns repr on errors and emits warnings.""" + + # Create a handler that always raises an error + def error_check(self, obj, path): + return isinstance(obj, str) and obj == "ERROR" + + def error_serialize(self, obj, path): + raise ValueError("Intentional error") + + error_handler = SerializerHandler( + check=error_check, + serialize_func=error_serialize, + uid="error_handler", + desc="Handler that raises errors", + ) + + serializer = JsonSerializer( + error_mode=ErrorMode.WARN, handlers_pre=(error_handler,) + ) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = serializer.json_serialize("ERROR") + + # Should return repr instead of raising + assert result == "'ERROR'" + # Should have emitted a warning + assert len(w) > 0 + assert "error serializing" in str(w[0].message) + + +def test_error_mode_ignore(): + """Test that ErrorMode.IGNORE returns repr on errors without warnings.""" + + # Create a handler that always raises an error + def error_check(self, obj, path): + return isinstance(obj, str) and obj == "ERROR" + + def error_serialize(self, obj, path): + raise ValueError("Intentional error") + + error_handler = SerializerHandler( + check=error_check, + serialize_func=error_serialize, + uid="error_handler", + desc="Handler that raises errors", + ) + + serializer = JsonSerializer( + error_mode=ErrorMode.IGNORE, handlers_pre=(error_handler,) + ) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = serializer.json_serialize("ERROR") + + # Should return repr + assert result == "'ERROR'" + # Should not have emitted warnings + assert len(w) == 0 + + +# 
============================================================================ +# Tests for write_only_format +# ============================================================================ + + +def test_write_only_format(): + """Test that write_only_format changes _FORMAT_KEY to __write_format__.""" + + # Create a handler that outputs _FORMAT_KEY + def format_check(self, obj, path): + return isinstance(obj, str) and obj.startswith("FORMAT:") + + def format_serialize(self, obj, path): + return {_FORMAT_KEY: "custom_format", "data": obj[7:]} + + format_handler = SerializerHandler( + check=format_check, + serialize_func=format_serialize, + uid="format_handler", + desc="Handler that uses _FORMAT_KEY", + ) + + # Without write_only_format + serializer1 = JsonSerializer(handlers_pre=(format_handler,)) + result1 = serializer1.json_serialize("FORMAT:test") + assert _FORMAT_KEY in result1 + assert result1[_FORMAT_KEY] == "custom_format" + + # With write_only_format + serializer2 = JsonSerializer(handlers_pre=(format_handler,), write_only_format=True) + result2 = serializer2.json_serialize("FORMAT:test") + assert _FORMAT_KEY not in result2 + assert "__write_format__" in result2 + assert result2["__write_format__"] == "custom_format" + assert result2["data"] == "test" + + +# ============================================================================ +# Tests for SerializerHandler.serialize() +# ============================================================================ + + +def test_SerializerHandler_serialize(): + """Test that SerializerHandler can serialize its own metadata.""" + + def simple_check(self, obj, path): + """Check if object is an integer.""" + return isinstance(obj, int) + + def simple_serialize(self, obj, path): + """Serialize integer.""" + return obj * 2 + + handler = SerializerHandler( + check=simple_check, + serialize_func=simple_serialize, + uid="test_handler", + desc="Test handler description", + ) + + metadata = handler.serialize() + + assert isinstance(metadata, dict) + assert "check" in metadata + assert "serialize_func" in metadata + assert "uid" in metadata + assert "desc" in metadata + + assert metadata["uid"] == "test_handler" + assert metadata["desc"] == "Test handler description" + + # Check that code and doc are included + assert "code" in metadata["check"] + assert "doc" in metadata["check"] + assert "code" in metadata["serialize_func"] + assert "doc" in metadata["serialize_func"] + + +# ============================================================================ +# Tests for hashify +# ============================================================================ + + +def test_hashify(): + """Test JsonSerializer.hashify() method.""" + serializer = JsonSerializer() + + # Test that it converts to hashable types + result = serializer.hashify({"a": [1, 2, 3]}) + assert isinstance(result, tuple) + assert result == (("a", (1, 2, 3)),) + # Should be hashable + hash(result) + + # Test with list + result = serializer.hashify([1, 2, 3]) + assert result == (1, 2, 3) + hash(result) + + # Test with primitive (already hashable) + result = serializer.hashify(42) + assert result == 42 + hash(result) + + +def test_hashify_force(): + """Test hashify with force parameter.""" + serializer = JsonSerializer() + + # With force=True (default), should handle unhashable objects + obj = UnserializableClass() + result = serializer.hashify(obj, force=True) + assert isinstance(result, tuple) # Converted to hashable form + + +# 
============================================================================ +# Tests for path tracking +# ============================================================================ + + +def test_path_tracking(): + """Test that paths are correctly tracked through nested serialization.""" + paths_seen = [] + + def tracking_check(self, obj, path): + paths_seen.append(path) + return False # Never actually handle, just track + + tracking_handler = SerializerHandler( + check=tracking_check, + serialize_func=lambda self, obj, path: obj, + uid="tracking", + desc="Path tracking handler", + ) + + serializer = JsonSerializer(handlers_pre=(tracking_handler,)) + + # Serialize nested structure + nested = {"a": {"b": [1, 2]}} + serializer.json_serialize(nested) + + # Check that we saw paths for nested elements + assert tuple() in paths_seen # Root + assert ("a",) in paths_seen # nested dict + assert ("a", "b") in paths_seen # nested list + assert ("a", "b", 0) in paths_seen # first element + assert ("a", "b", 1) in paths_seen # second element + + +# ============================================================================ +# Tests for initialization +# ============================================================================ + + +def test_JsonSerializer_init_no_positional_args(): + """Test that JsonSerializer raises ValueError on positional arguments.""" + with pytest.raises(ValueError, match="no positional arguments"): + JsonSerializer("invalid", "args") + + # Should work with keyword args + serializer = JsonSerializer(error_mode=ErrorMode.WARN) + assert serializer.error_mode == ErrorMode.WARN + + +def test_JsonSerializer_init_defaults(): + """Test JsonSerializer default initialization values.""" + serializer = JsonSerializer() + + assert serializer.array_mode == "array_list_meta" + assert serializer.error_mode == ErrorMode.EXCEPT + assert serializer.write_only_format is False + assert serializer.handlers == DEFAULT_HANDLERS + + +def test_JsonSerializer_init_custom_values(): + """Test JsonSerializer with custom initialization values.""" + custom_handler = SerializerHandler( + check=lambda self, obj, path: False, + serialize_func=lambda self, obj, path: obj, + uid="custom", + desc="Custom handler", + ) + + serializer = JsonSerializer( + array_mode="list", + error_mode="warn", + handlers_pre=(custom_handler,), + handlers_default=BASE_HANDLERS, + write_only_format=True, + ) + + assert serializer.array_mode == "list" + assert serializer.error_mode == ErrorMode.WARN + assert serializer.write_only_format is True + assert serializer.handlers[0] == custom_handler + assert len(serializer.handlers) == len(BASE_HANDLERS) + 1 + + +# ============================================================================ +# Edge cases and integration tests +# ============================================================================ + + +def test_empty_handlers(): + """Test serializer with no handlers.""" + serializer = JsonSerializer(handlers_default=tuple()) + + # Should fail to serialize anything + with pytest.raises(SerializationException): + serializer.json_serialize(42) + + +# TODO: Implement circular reference protection in the future. 
see https://github.com/mivanit/muutils/issues/62 +@pytest.mark.skip( + reason="we don't currently have circular reference protection, see https://github.com/mivanit/muutils/issues/62" +) +def test_circular_reference_protection(): + """Test that circular references don't cause infinite loops (will hit recursion limit).""" + # Note: This test verifies the expected behavior (recursion error) rather than + # infinite loop, as the module doesn't explicitly handle circular references + serializer = JsonSerializer() + + # Create circular reference + circular = {"a": None} + circular["a"] = circular # type: ignore + + # Should eventually raise RecursionError + with pytest.raises(RecursionError): + serializer.json_serialize(circular) + + +def test_large_nested_structure(): + """Test serialization of large nested structure.""" + serializer = JsonSerializer() + + # Create large nested list + large = [[i, i * 2, i * 3] for i in range(100)] + result = serializer.json_serialize(large) + assert len(result) == 100 + assert result[0] == [0, 0, 0] + assert result[99] == [99, 198, 297] + + +def test_mixed_container_types(): + """Test serialization of sets, frozensets, and other iterables.""" + serializer = JsonSerializer() + + # Set - serialized with format key + result = serializer.json_serialize({3, 1, 2}) + assert isinstance(result, dict) + assert _FORMAT_KEY in result + assert result[_FORMAT_KEY] == "set" + assert set(result["data"]) == {1, 2, 3} + + # Frozenset - serialized with format key + result = serializer.json_serialize(frozenset([4, 5, 6])) + assert isinstance(result, dict) + assert _FORMAT_KEY in result + assert result[_FORMAT_KEY] == "frozenset" + assert set(result["data"]) == {4, 5, 6} + + # Generator (Iterable) - serialized as list + gen = (x * 2 for x in range(3)) + result = serializer.json_serialize(gen) + assert result == [0, 2, 4] + + +def test_string_keys_in_dict(): + """Test that dict keys are converted to strings.""" + serializer = JsonSerializer() + + # Integer keys should be converted to strings + result = serializer.json_serialize({1: "a", 2: "b", 3: "c"}) + assert result == {"1": "a", "2": "b", "3": "c"} + assert all(isinstance(k, str) for k in result.keys()) diff --git a/tests/unit/json_serialize/test_serializable_field.py b/tests/unit/json_serialize/test_serializable_field.py new file mode 100644 index 00000000..99515f33 --- /dev/null +++ b/tests/unit/json_serialize/test_serializable_field.py @@ -0,0 +1,421 @@ +"""Tests for muutils.json_serialize.serializable_field module. + +Tests the SerializableField class and serializable_field function, +which extend dataclasses.Field with serialization capabilities. 
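+
+A minimal usage sketch (the `Point` class below is hypothetical; the
+`serialization_fn`/`deserialize_fn` arguments mirror the tests in this file):
+
+    @serializable_dataclass
+    class Point(SerializableDataclass):
+        x: int = serializable_field(default=0, serialization_fn=str, deserialize_fn=int)
+
+so that `Point(x=1).serialize()["x"] == "1"` and `Point.load({"x": "1"}).x == 1`.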
+""" + +from __future__ import annotations + +import dataclasses +from dataclasses import field +from typing import Any + +import pytest + +from muutils.json_serialize import ( + SerializableDataclass, + serializable_dataclass, + serializable_field, +) +from muutils.json_serialize.serializable_field import SerializableField + + +# ============================================================================ +# Test SerializableField creation with various parameters +# ============================================================================ + + +def test_SerializableField_creation(): + """Test creating SerializableField with various parameters.""" + # Basic creation with default parameters + sf1 = SerializableField() + assert sf1.serialize is True + assert sf1.serialization_fn is None + assert sf1.loading_fn is None + assert sf1.deserialize_fn is None + assert sf1.assert_type is True + assert sf1.custom_typecheck_fn is None + assert sf1.default is dataclasses.MISSING + assert sf1.default_factory is dataclasses.MISSING + + # Creation with default value + sf2 = SerializableField(default=42) + assert sf2.default == 42 + assert sf2.init is True + assert sf2.repr is True + assert sf2.compare is True + + # Creation with default_factory + sf3 = SerializableField(default_factory=list) + assert sf3.default_factory == list # noqa: E721 + assert sf3.default is dataclasses.MISSING + + # Creation with custom parameters + sf4 = SerializableField( + default=100, + init=True, + repr=False, + hash=True, + compare=False, + serialize=True, + ) + assert sf4.default == 100 + assert sf4.init is True + assert sf4.repr is False + assert sf4.hash is True + assert sf4.compare is False + assert sf4.serialize is True + + # Creation with serialization parameters + def custom_serialize(x): + return str(x) + + def custom_deserialize(x): + return int(x) + + sf5 = SerializableField( + serialization_fn=custom_serialize, + deserialize_fn=custom_deserialize, + assert_type=False, + ) + assert sf5.serialization_fn == custom_serialize + assert sf5.deserialize_fn == custom_deserialize + assert sf5.assert_type is False + + +def test_SerializableField_init_serialize_validation(): + """Test that init=True and serialize=False raises ValueError.""" + with pytest.raises(ValueError, match="Cannot have init=True and serialize=False"): + SerializableField(init=True, serialize=False) + + +def test_SerializableField_loading_deserialize_conflict(): + """Test that passing both loading_fn and deserialize_fn raises ValueError.""" + + def dummy_fn(x): + return x + + with pytest.raises( + ValueError, match="Cannot pass both loading_fn and deserialize_fn" + ): + SerializableField(loading_fn=dummy_fn, deserialize_fn=dummy_fn) + + +def test_SerializableField_doc(): + """Test doc parameter handling across Python versions.""" + sf = SerializableField(doc="Test documentation") + assert sf.doc == "Test documentation" + + +# ============================================================================ +# Test from_Field() method +# ============================================================================ + + +def test_from_Field(): + """Test converting a dataclasses.Field to SerializableField.""" + # Create a standard dataclasses.Field + dc_field = field( + default=42, + init=True, + repr=True, + hash=None, + compare=True, + ) + + # Convert to SerializableField + sf = SerializableField.from_Field(dc_field) + + # Verify all standard Field properties were copied + assert sf.default == 42 + assert sf.init is True + assert sf.repr is True + assert 
sf.hash is None
+    assert sf.compare is True
+
+    # Verify SerializableField-specific properties have defaults
+    assert sf.serialize == sf.repr  # serialize defaults to repr value
+    assert sf.serialization_fn is None
+    assert sf.loading_fn is None
+    assert sf.deserialize_fn is None
+
+    # Test with default_factory; serialize defaults to repr=True, so there is
+    # no init=True, serialize=False conflict here
+    dc_field2 = field(default_factory=list, repr=True, init=True)
+    sf2 = SerializableField.from_Field(dc_field2)
+    assert sf2.default_factory == list  # noqa: E721
+    assert sf2.default is dataclasses.MISSING
+    assert sf2.serialize is True  # should match repr=True
+
+
+# ============================================================================
+# Test serialization_fn and deserialize_fn
+# ============================================================================
+
+
+def test_serialization_deserialize_fn():
+    """Test custom serialization and deserialization functions."""
+
+    @serializable_dataclass
+    class CustomSerialize(SerializableDataclass):
+        # Serialize as uppercase, deserialize as lowercase
+        value: str = serializable_field(
+            serialization_fn=lambda x: x.upper(),
+            deserialize_fn=lambda x: x.lower(),
+        )
+
+    # Test serialization
+    instance = CustomSerialize(value="Hello")
+    serialized = instance.serialize()
+    assert serialized["value"] == "HELLO"
+
+    # Test deserialization
+    loaded = CustomSerialize.load({"value": "WORLD"})
+    assert loaded.value == "world"
+
+
+def test_serialization_fn_with_complex_type():
+    """Test serialization_fn with more complex transformations."""
+
+    @serializable_dataclass
+    class ComplexSerialize(SerializableDataclass):
+        # Store a tuple as a list
+        coords: tuple[int, int] = serializable_field(
+            default=(0, 0),
+            serialization_fn=lambda x: list(x),
+            deserialize_fn=lambda x: tuple(x),
+        )
+
+    instance = ComplexSerialize(coords=(3, 4))
+    serialized = instance.serialize()
+    assert serialized["coords"] == [3, 4]  # serialized as list
+
+    loaded = ComplexSerialize.load({"coords": [5, 6]})
+    assert loaded.coords == (5, 6)  # loaded as tuple
+
+
+# ============================================================================
+# Test loading_fn (takes full data dict)
+# ============================================================================
+
+
+def test_loading_fn():
+    """Test a non-init, non-serialized computed field (loading_fn itself is covered below)."""
+
+    @serializable_dataclass
+    class WithLoadingFn(SerializableDataclass):
+        x: int
+        y: int
+        # computed field that depends on other fields
+        sum_xy: int = serializable_field(
+            init=False,
+            serialize=False,
+            default=0,
+        )
+
+    # Create instance
+    instance = WithLoadingFn(x=3, y=4)
+    instance.sum_xy = instance.x + instance.y
+    assert instance.sum_xy == 7
+
+
+def test_loading_fn_vs_deserialize_fn():
+    """Test the difference between loading_fn (dict) and deserialize_fn (value)."""
+
+    @serializable_dataclass
+    class WithLoadingFn(SerializableDataclass):
+        value: int = serializable_field(
+            serialization_fn=lambda x: x * 2,
+            loading_fn=lambda data: data["value"] // 2,  # takes full dict
+        )
+
+    @serializable_dataclass
+    class WithDeserializeFn(SerializableDataclass):
+        value: int = serializable_field(
+            serialization_fn=lambda x: x * 2,
+            deserialize_fn=lambda x: x // 2,  # takes just the value
+        )
+
+    # Both should behave the same in this case
+    instance1 = WithLoadingFn(value=10)
+    serialized1 = instance1.serialize()
+    assert serialized1["value"] == 20
+
+    loaded1 = WithLoadingFn.load({"value": 20})
+    assert loaded1.value == 10
+
+    instance2 = WithDeserializeFn(value=10)
+    
serialized2 = instance2.serialize() + assert serialized2["value"] == 20 + + loaded2 = WithDeserializeFn.load({"value": 20}) + assert loaded2.value == 10 + + +# ============================================================================ +# Test field validation: assert_type and custom_typecheck_fn +# ============================================================================ + + +def test_field_validation_assert_type(): + """Test assert_type parameter for type validation.""" + + @serializable_dataclass + class StrictType(SerializableDataclass): + value: int = serializable_field(assert_type=True) + + @serializable_dataclass + class LooseType(SerializableDataclass): + value: int = serializable_field(assert_type=False) + + # Strict type checking should warn with wrong type (using WARN mode by default) + with pytest.warns(UserWarning, match="Type mismatch"): + instance = StrictType.load({"value": "not an int"}) + assert instance.value == "not an int" + + # Loose type checking should allow wrong type without warning + instance2 = LooseType.load({"value": "not an int"}) + assert instance2.value == "not an int" + + +def test_field_validation_custom_typecheck_fn(): + """Test custom_typecheck_fn for custom type validation.""" + + def is_positive(value: Any) -> bool: + """Check if value is a positive number.""" + return isinstance(value, (int, float)) and value > 0 + + @serializable_dataclass + class PositiveNumber(SerializableDataclass): + value: int = serializable_field( + custom_typecheck_fn=lambda t: True # Accept any type + ) + + # This should work because custom_typecheck_fn returns True + instance = PositiveNumber(value=42) + assert instance.value == 42 + + +# ============================================================================ +# Test serializable_field() function +# ============================================================================ + + +def test_serializable_field_function(): + """Test the serializable_field() function wrapper.""" + # Test basic usage + f1 = serializable_field() + assert isinstance(f1, SerializableField) + assert f1.serialize is True + + # Test with default + f2 = serializable_field(default=100) + assert f2.default == 100 + + # Test with default_factory + f3 = serializable_field(default_factory=list) + assert f3.default_factory == list # noqa: E721 + + # Test with all parameters + f4 = serializable_field( + default=42, + init=True, + repr=False, + hash=True, + compare=False, + serialize=True, + serialization_fn=str, + deserialize_fn=int, + assert_type=False, + ) + assert f4.default == 42 + assert f4.repr is False + assert f4.hash is True + assert f4.serialization_fn == str # noqa: E721 + assert f4.deserialize_fn == int # noqa: E721 + + +def test_serializable_field_no_positional_args(): + """Test that serializable_field doesn't accept positional arguments.""" + with pytest.raises(AssertionError, match="unexpected positional arguments"): + serializable_field("invalid") # type: ignore + + +def test_serializable_field_description_deprecated(): + """Test that 'description' parameter is deprecated in favor of 'doc'.""" + import warnings + + # Using description should raise DeprecationWarning + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + f = serializable_field(description="Test description") + # Check that a deprecation warning was issued + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "`description` is deprecated" in str(w[0].message) + # Verify doc was set + assert f.doc == "Test 
description" + + # Using both doc and description should raise ValueError + with pytest.raises(ValueError, match="cannot pass both"): + serializable_field(doc="Doc", description="Description") + + +# ============================================================================ +# Integration tests with SerializableDataclass +# ============================================================================ + + +def test_serializable_field_integration(): + """Test SerializableField integration with SerializableDataclass.""" + + @serializable_dataclass + class IntegrationTest(SerializableDataclass): + # Regular field + normal: str + + # Field with custom serialization (no default, so must come before fields with defaults) + custom: str = serializable_field( + serialization_fn=lambda x: x.upper(), + deserialize_fn=lambda x: x.lower(), + ) + + # Field with default + with_default: int = serializable_field(default=42) + + # Field with default_factory + with_factory: list = serializable_field(default_factory=list) + + # Non-serialized field + internal: int = serializable_field(init=False, serialize=False, default=0) + + # Create instance + instance = IntegrationTest( + normal="test", + custom="hello", + with_default=100, + with_factory=[1, 2, 3], + ) + instance.internal = 999 + + # Serialize + serialized = instance.serialize() + assert serialized["normal"] == "test" + assert serialized["with_default"] == 100 + assert serialized["with_factory"] == [1, 2, 3] + assert serialized["custom"] == "HELLO" # uppercase + assert "internal" not in serialized # not serialized + + # Load + loaded = IntegrationTest.load( + { + "normal": "loaded", + "custom": "WORLD", + "with_default": 200, + "with_factory": [4, 5], + } + ) + assert loaded.normal == "loaded" + assert loaded.with_default == 200 + assert loaded.with_factory == [4, 5] + assert loaded.custom == "world" # lowercase + assert loaded.internal == 0 # default value diff --git a/tests/unit/json_serialize/test_util.py b/tests/unit/json_serialize/test_util.py index 0c7d5297..9b88c038 100644 --- a/tests/unit/json_serialize/test_util.py +++ b/tests/unit/json_serialize/test_util.py @@ -1,4 +1,5 @@ from collections import namedtuple +from dataclasses import dataclass from typing import NamedTuple import pytest @@ -6,7 +7,10 @@ # Module code assumed to be imported from my_module from muutils.json_serialize.util import ( UniversalContainer, + _FORMAT_KEY, _recursive_hashify, + array_safe_eq, + dc_eq, isinstance_namedtuple, safe_getsource, string_as_lines, @@ -77,3 +81,231 @@ def raises_error(): print(f"Source of wrapped_func: {error_source}") # Check for the original function's source since the decorator doesn't change this assert any("def raises_error():" in line for line in error_source) + + +# Additional tests from TODO.md + + +def test_try_catch_exception_handling(): + """Test that try_catch properly catches exceptions and returns default error message.""" + + @try_catch + def raises_runtime_error(): + raise RuntimeError("runtime error message") + + @try_catch + def raises_key_error(): + raise KeyError("missing key") + + @try_catch + def raises_zero_division(): + return 1 / 0 + + # Test that exceptions are caught and serialized + assert raises_runtime_error() == "RuntimeError: runtime error message" + assert raises_key_error() == "KeyError: 'missing key'" + result = raises_zero_division() + assert "ZeroDivisionError" in result + + # Test with arguments + @try_catch + def func_with_args(a, b): + if a == 0: + raise ValueError(f"a cannot be 0, got {a}") + return a + b 
+ + assert func_with_args(1, 2) == 3 + assert func_with_args(0, 2) == "ValueError: a cannot be 0, got 0" + + +def test_array_safe_eq(): + """Test array_safe_eq with numpy arrays, torch tensors, and nested arrays.""" + # Basic types + assert array_safe_eq(1, 1) is True + assert array_safe_eq(1, 2) is False + # Note: strings are treated as sequences by array_safe_eq, so we test differently + assert array_safe_eq(1.5, 1.5) is True + assert array_safe_eq(True, True) is True + + # Lists and sequences + assert array_safe_eq([1, 2, 3], [1, 2, 3]) is True + assert array_safe_eq([1, 2, 3], [1, 2, 4]) is False + assert array_safe_eq([], []) is True + assert array_safe_eq((1, 2, 3), (1, 2, 3)) is True + + # Nested arrays + assert array_safe_eq([[1, 2], [3, 4]], [[1, 2], [3, 4]]) is True + assert array_safe_eq([[1, 2], [3, 4]], [[1, 2], [3, 5]]) is False + assert array_safe_eq([[[1]], [[2]]], [[[1]], [[2]]]) is True + + # Dicts + assert array_safe_eq({"a": 1, "b": 2}, {"a": 1, "b": 2}) is True + assert array_safe_eq({"a": 1, "b": 2}, {"a": 1, "b": 3}) is False + assert array_safe_eq({}, {}) is True + + # Mixed nested structures + assert ( + array_safe_eq({"a": [1, 2], "b": {"c": 3}}, {"a": [1, 2], "b": {"c": 3}}) + is True + ) + assert ( + array_safe_eq({"a": [1, 2], "b": {"c": 3}}, {"a": [1, 2], "b": {"c": 4}}) + is False + ) + + # Identity check + obj = {"a": 1} + assert array_safe_eq(obj, obj) is True + + # Type mismatch + assert array_safe_eq(1, 1.0) is False # Different types + assert array_safe_eq([1, 2], (1, 2)) is False + + # Try with numpy if available (note: numpy returns np.True_ not Python True) + try: + import numpy as np + + arr1 = np.array([1, 2, 3]) + arr2 = np.array([1, 2, 3]) + arr3 = np.array([1, 2, 4]) + assert array_safe_eq(arr1, arr2) # Use == not is for numpy bool + assert not array_safe_eq(arr1, arr3) + except ImportError: + pass # Skip numpy tests if not available + + # Try with torch if available (note: torch also may return tensor bool) + try: + import torch + + t1 = torch.tensor([1.0, 2.0, 3.0]) + t2 = torch.tensor([1.0, 2.0, 3.0]) + t3 = torch.tensor([1.0, 2.0, 4.0]) + assert array_safe_eq(t1, t2) # Use == not is for torch bool + assert not array_safe_eq(t1, t3) + except ImportError: + pass # Skip torch tests if not available + + +def test_dc_eq(): + """Test dc_eq for dataclasses equal and unequal cases.""" + + @dataclass + class Point: + x: int + y: int + + @dataclass + class Point3D: + x: int + y: int + z: int + + @dataclass + class PointWithArray: + x: int + coords: list + + # Equal dataclasses + p1 = Point(1, 2) + p2 = Point(1, 2) + assert dc_eq(p1, p2) is True + + # Unequal dataclasses + p3 = Point(1, 3) + assert dc_eq(p1, p3) is False + + # Identity + assert dc_eq(p1, p1) is True + + # Different classes - default behavior (false_when_class_mismatch=True) + p3d = Point3D(1, 2, 3) + assert dc_eq(p1, p3d) is False + + # Different classes - except_when_class_mismatch=True + with pytest.raises( + TypeError, match="Cannot compare dataclasses of different classes" + ): + dc_eq(p1, p3d, except_when_class_mismatch=True) + + # Dataclasses with arrays + pa1 = PointWithArray(1, [1, 2, 3]) + pa2 = PointWithArray(1, [1, 2, 3]) + pa3 = PointWithArray(1, [1, 2, 4]) + assert dc_eq(pa1, pa2) is True + assert dc_eq(pa1, pa3) is False + + # Test with nested structures + @dataclass + class Container: + items: list + metadata: dict + + c1 = Container([1, 2, 3], {"name": "test"}) + c2 = Container([1, 2, 3], {"name": "test"}) + c3 = Container([1, 2, 3], {"name": "other"}) + assert dc_eq(c1, 
c2) is True
+    assert dc_eq(c1, c3) is False
+
+    # Point2D: a different class with the same fields as Point, for contrast
+    @dataclass
+    class Point2D:
+        x: int
+        y: int
+
+    # Different classes with mismatched fields - should raise with except_when_field_mismatch
+    with pytest.raises(AttributeError):
+        dc_eq(p1, p3d, except_when_field_mismatch=True)
+
+
+def test_FORMAT_KEY():
+    """Test that FORMAT_KEY constant is accessible and has expected value."""
+    # Test that the format key exists and is a string
+    assert isinstance(_FORMAT_KEY, str)
+    assert _FORMAT_KEY == "__muutils_format__"
+
+    # Test that it can be used in dictionaries (common use case)
+    data = {_FORMAT_KEY: "custom_type", "value": 42}
+    assert data[_FORMAT_KEY] == "custom_type"
+    assert _FORMAT_KEY in data
+
+
+def test_edge_cases():
+    """Test edge cases for utility functions: None values, empty containers, mixed types."""
+    # string_as_lines with None
+    assert string_as_lines(None) == []
+    # Empty string splits to empty list (splitlines behavior)
+    assert string_as_lines("") == []
+    assert string_as_lines("single") == ["single"]
+
+    # _recursive_hashify with empty containers
+    assert _recursive_hashify([]) == ()
+    assert _recursive_hashify({}) == ()
+    assert _recursive_hashify(()) == ()
+
+    # _recursive_hashify with mixed nested types
+    mixed = {"list": [1, 2], "dict": {"nested": True}, "tuple": (3, 4)}
+    result = _recursive_hashify(mixed)
+    assert isinstance(result, tuple)
+
+    # array_safe_eq with empty containers
+    assert array_safe_eq([], []) is True
+    assert array_safe_eq({}, {}) is True
+    assert array_safe_eq((), ()) is True
+
+    # array_safe_eq with None
+    assert array_safe_eq(None, None) is True
+    assert array_safe_eq(None, 0) is False
+
+    # try_catch with function returning None
+    @try_catch
+    def returns_none():
+        return None
+
+    assert returns_none() is None
+
+    # UniversalContainer with various types
+    uc = UniversalContainer()
+    assert None in uc
+    assert [] in uc
+    assert {} in uc
+    assert object() in uc
diff --git a/tests/unit/logger/test_log_util.py b/tests/unit/logger/test_log_util.py
new file mode 100644
index 00000000..55f9a985
--- /dev/null
+++ b/tests/unit/logger/test_log_util.py
@@ -0,0 +1,176 @@
+from __future__ import annotations
+
+import os
+from pathlib import Path
+
+import pytest
+
+from muutils.jsonlines import jsonl_write
+from muutils.logger.log_util import (
+    gather_log,
+    gather_stream,
+    gather_val,
+    get_any_from_stream,
+)
+
+TEMP_PATH: Path = Path("tests/_temp/logger")
+
+
+def test_gather_log():
+    """Test gathering and sorting all streams from a multi-stream log file"""
+    # Create test directory
+    os.makedirs(TEMP_PATH, exist_ok=True)
+    log_file = TEMP_PATH / "test_gather_log.jsonl"
+
+    # Create test data with multiple streams
+    test_data = [
+        {"msg": "stream1_msg1", "value": 1, "_stream": "stream1"},
+        {"msg": "stream2_msg1", "value": 10, "_stream": "stream2"},
+        {"msg": "stream1_msg2", "value": 2, "_stream": "stream1"},
+        {"msg": "default_msg1", "value": 100},  # no _stream key
+        {"msg": "stream2_msg2", "value": 20, "_stream": "stream2"},
+        {"msg": "stream1_msg3", "value": 3, "_stream": "stream1"},
+    ]
+
+    jsonl_write(str(log_file), test_data)
+
+    # Gather all streams
+    result = gather_log(str(log_file))
+
+    # Verify correct streams are present
+    assert "stream1" in result
+    assert "stream2" in result
+    assert "default" in result
+
+    # Verify stream separation
+    assert len(result["stream1"]) == 3
+    assert len(result["stream2"]) == 2
+    assert len(result["default"]) == 1
+
+    # Verify data 
integrity + assert result["stream1"][0]["msg"] == "stream1_msg1" + assert result["stream1"][1]["msg"] == "stream1_msg2" + assert result["stream1"][2]["msg"] == "stream1_msg3" + + assert result["stream2"][0]["msg"] == "stream2_msg1" + assert result["stream2"][1]["msg"] == "stream2_msg2" + + assert result["default"][0]["msg"] == "default_msg1" + assert result["default"][0]["value"] == 100 + + +def test_gather_stream(): + """Test extracting a specific stream from a log file""" + os.makedirs(TEMP_PATH, exist_ok=True) + log_file = TEMP_PATH / "test_gather_stream.jsonl" + + # Create test data with multiple streams + test_data = [ + {"msg": "stream1_msg1", "idx": 1, "_stream": "target"}, + {"msg": "stream2_msg1", "idx": 2, "_stream": "other"}, + {"msg": "stream1_msg2", "idx": 3, "_stream": "target"}, + {"msg": "no_stream", "idx": 4}, # no _stream key + {"msg": "stream2_msg2", "idx": 5, "_stream": "other"}, + {"msg": "stream1_msg3", "idx": 6, "_stream": "target"}, + ] + + jsonl_write(str(log_file), test_data) + + # Gather only the "target" stream + result = gather_stream(str(log_file), "target") + + # Verify filtering + assert len(result) == 3 + + # Verify correct items were selected + assert result[0]["msg"] == "stream1_msg1" + assert result[0]["idx"] == 1 + assert result[1]["msg"] == "stream1_msg2" + assert result[1]["idx"] == 3 + assert result[2]["msg"] == "stream1_msg3" + assert result[2]["idx"] == 6 + + # Verify all items have the correct stream + for item in result: + assert item["_stream"] == "target" + + # Test with non-existent stream + empty_result = gather_stream(str(log_file), "nonexistent") + assert len(empty_result) == 0 + + +def test_gather_val(): + """Test extracting specific keys from a specific stream""" + os.makedirs(TEMP_PATH, exist_ok=True) + log_file = TEMP_PATH / "test_gather_val.jsonl" + + # Create test data matching the example from the docstring + test_data = [ + {"a": 1, "b": 2, "c": 3, "_stream": "s1"}, + {"a": 4, "b": 5, "c": 6, "_stream": "s1"}, + {"a": 7, "b": 8, "c": 9, "_stream": "s2"}, + {"a": 10, "b": 11, "_stream": "s1"}, # missing key 'c' + {"a": 13, "b": 14, "c": 15, "_stream": "s1"}, + ] + + jsonl_write(str(log_file), test_data) + + # Test basic key extraction + result = gather_val(str(log_file), "s1", ("a", "b")) + + # Verify data structure + assert len(result) == 4 # s1 has 4 entries + assert result[0] == [1, 2] + assert result[1] == [4, 5] + assert result[2] == [10, 11] + assert result[3] == [13, 14] + + # Test with three keys (should skip the entry missing 'c') + result_three_keys = gather_val(str(log_file), "s1", ("a", "b", "c")) + assert len(result_three_keys) == 3 # one entry missing 'c' is skipped + assert result_three_keys[0] == [1, 2, 3] + assert result_three_keys[1] == [4, 5, 6] + assert result_three_keys[2] == [13, 14, 15] + + # Test with allow_skip=False - should raise error on missing key + with pytest.raises(ValueError, match="missing keys"): + gather_val(str(log_file), "s1", ("a", "b", "c"), allow_skip=False) + + # Test with different stream + result_s2 = gather_val(str(log_file), "s2", ("a", "c")) + assert len(result_s2) == 1 + assert result_s2[0] == [7, 9] + + # Test with non-existent stream + empty_result = gather_val(str(log_file), "nonexistent", ("a", "b")) + assert len(empty_result) == 0 + + +def test_get_any_from_stream(): + """Test extracting first value of a key from stream and KeyError on missing key""" + # Test with a list of dicts + stream = [ + {"foo": "bar", "value": 1}, + {"foo": "baz", "value": 2}, + {"other": "data", "value": 
3}, + ] + + # Test successful key extraction (first occurrence) + result = get_any_from_stream(stream, "foo") + assert result == "bar" # should get the first one + + # Test key that exists later + result_value = get_any_from_stream(stream, "value") + assert result_value == 1 # first occurrence + + # Test key that appears only in later entry + result_other = get_any_from_stream(stream, "other") + assert result_other == "data" + + # Test KeyError on missing key + with pytest.raises(KeyError, match="key 'nonexistent' not found in stream"): + get_any_from_stream(stream, "nonexistent") + + # Test with empty stream + with pytest.raises(KeyError, match="key 'foo' not found in stream"): + get_any_from_stream([], "foo") diff --git a/tests/unit/test_collect_warnings.py b/tests/unit/test_collect_warnings.py new file mode 100644 index 00000000..ffa5764f --- /dev/null +++ b/tests/unit/test_collect_warnings.py @@ -0,0 +1,464 @@ +from __future__ import annotations + +import sys +import warnings +from io import StringIO + +import pytest + +from muutils.collect_warnings import CollateWarnings + + +def test_basic_warning_capture(): + """Test that warnings issued inside the context populate the counts dict.""" + with CollateWarnings(print_on_exit=False) as cw: + warnings.warn("test warning 1", UserWarning) + warnings.warn("test warning 2", DeprecationWarning) + + assert len(cw.counts) == 2 + + # Check that the warnings are in the counts dict + warning_messages = [msg for (_, _, _, msg) in cw.counts.keys()] + assert "test warning 1" in warning_messages + assert "test warning 2" in warning_messages + + # Check that the category names are correct + categories = [cat for (_, _, cat, _) in cw.counts.keys()] + assert "UserWarning" in categories + assert "DeprecationWarning" in categories + + # Check that counts are 1 for each + assert all(count == 1 for count in cw.counts.values()) + + +def test_collation(): + """Test that duplicate warnings from the same line increment count correctly.""" + with CollateWarnings(print_on_exit=False) as cw: + # Issue the same warning multiple times from a loop (same line) + for _ in range(3): + warnings.warn("duplicate warning", UserWarning) + warnings.warn("different warning", UserWarning) + + # The duplicate warnings from the same line should be collated + # Find the duplicate warning entry + duplicate_count = None + different_count = None + for (filename, lineno, category, message), count in cw.counts.items(): + if message == "duplicate warning": + duplicate_count = count + elif message == "different warning": + different_count = count + + assert duplicate_count == 3 + assert different_count == 1 + + +def test_print_on_exit_true(): + """Test that warnings are printed to stderr on exit when print_on_exit=True.""" + # Capture stderr + old_stderr = sys.stderr + sys.stderr = StringIO() + + try: + with CollateWarnings(print_on_exit=True) as cw: + warnings.warn("printed warning", UserWarning) + + assert cw + + # Get the output + stderr_output = sys.stderr.getvalue() + + # Check that the warning was printed + assert "printed warning" in stderr_output + assert "UserWarning" in stderr_output + assert "(1x)" in stderr_output # Default format includes count + + finally: + # Restore stderr + sys.stderr = old_stderr + + +def test_print_on_exit_false(): + """Test that no output is produced but counts are tracked when print_on_exit=False.""" + # Capture stderr + old_stderr = sys.stderr + sys.stderr = StringIO() + + try: + with CollateWarnings(print_on_exit=False) as cw: + warnings.warn("silent 
warning", UserWarning) + + # Get the output + stderr_output = sys.stderr.getvalue() + + # Check that nothing was printed + assert stderr_output == "" + + # But counts should still be tracked + assert len(cw.counts) == 1 + warning_messages = [msg for (_, _, _, msg) in cw.counts.keys()] + assert "silent warning" in warning_messages + + finally: + # Restore stderr + sys.stderr = old_stderr + + +def test_custom_format_string(): + """Test that custom fmt parameter controls output format.""" + # Capture stderr + old_stderr = sys.stderr + sys.stderr = StringIO() + + try: + custom_fmt = "WARNING: {message} ({category}) appeared {count} times" + with CollateWarnings(print_on_exit=True, fmt=custom_fmt) as cw: + warnings.warn("custom format warning", UserWarning) + + assert cw + + # Get the output + stderr_output = sys.stderr.getvalue() + + # Check that the custom format was used + assert "WARNING: custom format warning" in stderr_output + assert "(UserWarning)" in stderr_output + assert "appeared 1 times" in stderr_output + + # Check that default format was NOT used + assert "(1x)" not in stderr_output + + finally: + # Restore stderr + sys.stderr = old_stderr + + +def test_multiple_different_warnings(): + """Test handling of multiple different warnings.""" + with CollateWarnings(print_on_exit=False) as cw: + warnings.warn("warning 1", UserWarning) + warnings.warn("warning 2", DeprecationWarning) + warnings.warn("warning 3", FutureWarning) + warnings.warn("warning 4", RuntimeWarning) + + assert len(cw.counts) == 4 + + categories = [cat for (_, _, cat, _) in cw.counts.keys()] + assert "UserWarning" in categories + assert "DeprecationWarning" in categories + assert "FutureWarning" in categories + assert "RuntimeWarning" in categories + + +def test_no_warnings(): + """Test that CollateWarnings works correctly when no warnings are issued.""" + with CollateWarnings(print_on_exit=False) as cw: + # No warnings issued + pass + + assert len(cw.counts) == 0 + + +def test_same_message_different_categories(): + """Test that same message with different categories are counted separately.""" + with CollateWarnings(print_on_exit=False) as cw: + # Issue same message with different categories from the same line in a loop + for _ in range(2): + warnings.warn("same message", UserWarning) + warnings.warn("same message", DeprecationWarning) + + # Find the counts for each category + user_warning_count = 0 + deprecation_warning_count = 0 + for (_, _, category, message), count in cw.counts.items(): + if message == "same message" and category == "UserWarning": + user_warning_count += count + elif message == "same message" and category == "DeprecationWarning": + deprecation_warning_count += count + + assert user_warning_count == 2 + assert deprecation_warning_count == 1 + + +def test_filename_and_lineno_tracking(): + """Test that filename and line number are tracked correctly.""" + with CollateWarnings(print_on_exit=False) as cw: + warnings.warn("tracked warning", UserWarning) + + assert len(cw.counts) == 1 + + # Get the filename and lineno + (filename, lineno, category, message) = list(cw.counts.keys())[0] + + # Check that filename and lineno are present and reasonable + assert filename is not None + assert isinstance(filename, str) + assert lineno is not None + assert isinstance(lineno, int) + assert lineno > 0 + + +def test_context_manager_re_entry_fails(): + """Test that CollateWarnings cannot be re-entered while active.""" + cw = CollateWarnings(print_on_exit=False) + + with cw: + # Try to re-enter while still inside the 
context + with pytest.raises(RuntimeError, match="cannot be re-entered"): + with cw: + pass + + +def test_format_string_all_fields(): + """Test that all format fields work correctly.""" + old_stderr = sys.stderr + sys.stderr = StringIO() + + try: + fmt = "count={count} file={filename} line={lineno} cat={category} msg={message}" + with CollateWarnings(print_on_exit=True, fmt=fmt) as cw: + warnings.warn("test all fields", UserWarning) + + assert cw + + stderr_output = sys.stderr.getvalue() + + # Check that all fields are present + assert "count=1" in stderr_output + assert "file=" in stderr_output + assert "line=" in stderr_output + assert "cat=UserWarning" in stderr_output + assert "msg=test all fields" in stderr_output + + finally: + sys.stderr = old_stderr + + +def test_warning_with_stacklevel(): + """Test that warnings with different stacklevels are handled correctly.""" + + def issue_warning(): + warnings.warn("nested warning", UserWarning, stacklevel=2) + + with CollateWarnings(print_on_exit=False) as cw: + issue_warning() + + assert len(cw.counts) == 1 + warning_messages = [msg for (_, _, _, msg) in cw.counts.keys()] + assert "nested warning" in warning_messages + + +def test_counts_dict_structure(): + """Test the structure of the counts dictionary.""" + with CollateWarnings(print_on_exit=False) as cw: + warnings.warn("test warning", UserWarning) + + # Check that counts is a Counter + from collections import Counter + + assert isinstance(cw.counts, Counter) + + # Check the key structure + key = list(cw.counts.keys())[0] + assert isinstance(key, tuple) + assert len(key) == 4 + + filename, lineno, category, message = key + assert isinstance(filename, str) + assert isinstance(lineno, int) + assert isinstance(category, str) + assert isinstance(message, str) + + +def test_large_number_of_warnings(): + """Test handling of a large number of duplicate warnings.""" + with CollateWarnings(print_on_exit=False) as cw: + for i in range(1000): + warnings.warn("repeated warning", UserWarning) + + assert len(cw.counts) == 1 + + # Find the count + count = list(cw.counts.values())[0] + assert count == 1000 + + +def test_mixed_warning_counts(): + """Test a mix of different warning counts.""" + with CollateWarnings(print_on_exit=False) as cw: + # Warning A: 5 times + for _ in range(5): + warnings.warn("warning A", UserWarning) + + # Warning B: 3 times + for _ in range(3): + warnings.warn("warning B", DeprecationWarning) + + # Warning C: 1 time + warnings.warn("warning C", FutureWarning) + + assert len(cw.counts) == 3 + + # Extract counts by message + counts_by_message = {} + for (_, _, _, message), count in cw.counts.items(): + counts_by_message[message] = count + + assert counts_by_message["warning A"] == 5 + assert counts_by_message["warning B"] == 3 + assert counts_by_message["warning C"] == 1 + + +def test_exception_propagation(): + """Test that exceptions from the with-block are propagated.""" + with pytest.raises(ValueError, match="test exception"): + with CollateWarnings(print_on_exit=False) as cw: + warnings.warn("warning before exception", UserWarning) + raise ValueError("test exception") + + # Counts should still be populated even though an exception was raised + assert len(cw.counts) == 1 + + +def test_warning_with_special_characters(): + """Test warnings with special characters in messages.""" + with CollateWarnings(print_on_exit=False) as cw: + warnings.warn("warning with 'quotes' and \"double quotes\"", UserWarning) + warnings.warn("warning with\nnewline", UserWarning) + 
warnings.warn("warning with\ttab", UserWarning) + + assert len(cw.counts) == 3 + + messages = [msg for (_, _, _, msg) in cw.counts.keys()] + assert "warning with 'quotes' and \"double quotes\"" in messages + assert "warning with\nnewline" in messages + assert "warning with\ttab" in messages + + +def test_empty_warning_message(): + """Test warning with empty message.""" + with CollateWarnings(print_on_exit=False) as cw: + warnings.warn("", UserWarning) + + assert len(cw.counts) == 1 + messages = [msg for (_, _, _, msg) in cw.counts.keys()] + assert "" in messages + + +def test_unicode_warning_message(): + """Test warnings with unicode characters.""" + with CollateWarnings(print_on_exit=False) as cw: + warnings.warn("warning with unicode: 你好 мир 🌍", UserWarning) + + assert len(cw.counts) == 1 + messages = [msg for (_, _, _, msg) in cw.counts.keys()] + assert "warning with unicode: 你好 мир 🌍" in messages + + +def test_custom_warning_class(): + """Test with custom warning classes.""" + + class CustomWarning(UserWarning): + pass + + with CollateWarnings(print_on_exit=False) as cw: + warnings.warn("custom warning", CustomWarning) + + assert len(cw.counts) == 1 + categories = [cat for (_, _, cat, _) in cw.counts.keys()] + assert "CustomWarning" in categories + + +def test_default_format_string(): + """Test the default format string output.""" + old_stderr = sys.stderr + sys.stderr = StringIO() + + try: + with CollateWarnings(print_on_exit=True) as cw: + warnings.warn("test default format", UserWarning) + + assert cw + + stderr_output = sys.stderr.getvalue().strip() + + # Default format: "({count}x) {filename}:{lineno} {category}: {message}" + assert stderr_output.startswith("(1x)") + assert "UserWarning: test default format" in stderr_output + assert ":" in stderr_output # filename:lineno separator + + finally: + sys.stderr = old_stderr + + +def test_collate_warnings_with_warnings_always(): + """Test that warnings.simplefilter('always') is set correctly.""" + # This test verifies that even if we would normally suppress duplicate warnings, + # CollateWarnings captures them all + with CollateWarnings(print_on_exit=False) as cw: + # These would normally be suppressed if the same warning is issued twice + # from the same location, but CollateWarnings should capture all of them + for _ in range(3): + warnings.warn("repeated warning", UserWarning) + + # All 3 warnings should be captured + count = list(cw.counts.values())[0] + assert count == 3 + + +def test_multiple_warnings_same_line(): + """Test multiple different warnings from the same line.""" + with CollateWarnings(print_on_exit=False) as cw: + warnings.warn("warning 1", UserWarning) + warnings.warn("warning 2", UserWarning) # noqa: E702 + + # Should have 2 different warnings (different messages, same line) + assert len(cw.counts) == 2 + + +def test_counts_accessible_after_exit(): + """Test that counts are accessible after exiting the context.""" + with CollateWarnings(print_on_exit=False) as cw: + warnings.warn("test warning", UserWarning) + + # After exiting, counts should still be accessible + assert len(cw.counts) == 1 + assert cw.counts is not None + + # Should be able to iterate over counts + for key, count in cw.counts.items(): + assert isinstance(key, tuple) + assert isinstance(count, int) + + +def test_print_on_exit_default_true(): + """Test that print_on_exit defaults to True.""" + old_stderr = sys.stderr + sys.stderr = StringIO() + + try: + # Don't specify print_on_exit, should default to True + with CollateWarnings() as cw: + 
warnings.warn("default print test", UserWarning) + + assert cw + stderr_output = sys.stderr.getvalue() + assert "default print test" in stderr_output + + finally: + sys.stderr = old_stderr + + +def test_exit_twice_fails(): + """Test that calling __exit__ twice raises RuntimeError.""" + cw = CollateWarnings(print_on_exit=False) + + # Enter the context + cw.__enter__() + + # Exit once + cw.__exit__(None, None, None) + + # Try to exit again - should raise RuntimeError + with pytest.raises(RuntimeError, match="exited twice"): + cw.__exit__(None, None, None) diff --git a/tests/unit/test_jsonlines.py b/tests/unit/test_jsonlines.py new file mode 100644 index 00000000..e06462d4 --- /dev/null +++ b/tests/unit/test_jsonlines.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +import gzip +import json +from pathlib import Path + +import pytest + +from muutils.jsonlines import jsonl_load, jsonl_load_log, jsonl_write + +TEMP_PATH: Path = Path("tests/_temp/jsonl") + + +def test_jsonl_load(): + """Test loading jsonlines file - write data, load it back, verify it matches.""" + # Create temp directory + TEMP_PATH.mkdir(parents=True, exist_ok=True) + + test_file = TEMP_PATH / "test_load.jsonl" + + # Create test data + test_data = [ + {"id": 1, "name": "Alice", "value": 42.5}, + {"id": 2, "name": "Bob", "value": 17.3}, + {"id": 3, "name": "Charlie", "value": None}, + {"list": [1, 2, 3], "nested": {"a": 1, "b": 2}}, + ] + + # Write the data manually + with open(test_file, "w", encoding="UTF-8") as f: + for item in test_data: + f.write(json.dumps(item) + "\n") + + # Load it back using jsonl_load + loaded_data = jsonl_load(str(test_file)) + + # Verify the data matches + assert loaded_data == test_data + assert len(loaded_data) == 4 + assert loaded_data[0]["name"] == "Alice" + assert loaded_data[3]["nested"]["b"] == 2 + + +def test_jsonl_write(): + """Test writing jsonlines data - write using jsonl_write, read raw contents, verify format.""" + # Create temp directory + TEMP_PATH.mkdir(parents=True, exist_ok=True) + + test_file = TEMP_PATH / "test_write.jsonl" + + # Test data + test_data = [ + {"id": 1, "status": "active"}, + {"id": 2, "status": "inactive"}, + {"id": 3, "status": "pending", "metadata": {"priority": "high"}}, + ] + + # Write using jsonl_write + jsonl_write(str(test_file), test_data) + + # Read raw contents + with open(test_file, "r", encoding="UTF-8") as f: + lines = f.readlines() + + # Verify format + assert len(lines) == 3 + + # Each line should be valid JSON + for i, line in enumerate(lines): + assert line.endswith("\n") + parsed = json.loads(line) + assert parsed == test_data[i] + + # Verify specific content + assert json.loads(lines[0]) == {"id": 1, "status": "active"} + assert json.loads(lines[2])["metadata"]["priority"] == "high" + + +def test_gzip_support(): + """Test .gz extension auto-detection for both reading and writing.""" + # Create temp directory + TEMP_PATH.mkdir(parents=True, exist_ok=True) + + test_file_gz = TEMP_PATH / "test_gzip.jsonl.gz" + test_file_gzip = TEMP_PATH / "test_gzip2.jsonl.gzip" + + # Test data + test_data = [ + {"compressed": True, "value": 123}, + {"compressed": True, "value": 456}, + ] + + # Test with .gz extension - auto-detection + jsonl_write(str(test_file_gz), test_data) + + # Verify it's actually gzipped by trying to read with gzip + with gzip.open(test_file_gz, "rt", encoding="UTF-8") as f: + lines = f.readlines() + assert len(lines) == 2 + + # Load back using jsonl_load with auto-detection + loaded_data = jsonl_load(str(test_file_gz)) + assert 
loaded_data == test_data + + # Test with .gzip extension + jsonl_write(str(test_file_gzip), test_data) + loaded_data_gzip = jsonl_load(str(test_file_gzip)) + assert loaded_data_gzip == test_data + + # Test explicit use_gzip parameter + test_file_explicit = TEMP_PATH / "test_explicit.jsonl" + jsonl_write(str(test_file_explicit), test_data, use_gzip=True) + + # Should be gzipped even without .gz extension + with gzip.open(test_file_explicit, "rt", encoding="UTF-8") as f: + lines = f.readlines() + assert len(lines) == 2 + + loaded_explicit = jsonl_load(str(test_file_explicit), use_gzip=True) + assert loaded_explicit == test_data + + +def test_jsonl_load_log(): + """Test jsonl_load_log with dict assertion - test with valid dicts and non-dict items.""" + # Create temp directory + TEMP_PATH.mkdir(parents=True, exist_ok=True) + + # Test with valid dict data + test_file_valid = TEMP_PATH / "test_log_valid.jsonl" + valid_data = [ + {"level": "INFO", "message": "Starting process"}, + {"level": "WARNING", "message": "Low memory"}, + {"level": "ERROR", "message": "Connection failed"}, + ] + + jsonl_write(str(test_file_valid), valid_data) + loaded_log = jsonl_load_log(str(test_file_valid)) + + assert loaded_log == valid_data + assert all(isinstance(item, dict) for item in loaded_log) + + # Test with non-dict items - should raise AssertionError + test_file_invalid = TEMP_PATH / "test_log_invalid.jsonl" + invalid_data = [ + {"level": "INFO", "message": "Valid entry"}, + "not a dict", # This is invalid + {"level": "ERROR", "message": "Another valid entry"}, + ] + + jsonl_write(str(test_file_invalid), invalid_data) + + with pytest.raises(AssertionError) as exc_info: + jsonl_load_log(str(test_file_invalid)) + + # Verify the error message contains useful information + error_msg = str(exc_info.value) + assert "idx = 1" in error_msg + assert "is not a dict" in error_msg + + # Test with list item + test_file_list = TEMP_PATH / "test_log_list.jsonl" + list_data = [ + {"level": "INFO"}, + [1, 2, 3], # List instead of dict + ] + + jsonl_write(str(test_file_list), list_data) + + with pytest.raises(AssertionError) as exc_info: + jsonl_load_log(str(test_file_list)) + + error_msg = str(exc_info.value) + assert "idx = 1" in error_msg + assert "is not a dict" in error_msg + + +def test_gzip_compresslevel(): + """Test that gzip_compresslevel parameter works without errors.""" + # Create temp directory + TEMP_PATH.mkdir(parents=True, exist_ok=True) + + test_file = TEMP_PATH / "test_compresslevel.jsonl.gz" + + # Create test data + test_data = [{"value": i, "data": "content"} for i in range(10)] + + # Write with different compression levels - should not error + jsonl_write(str(test_file), test_data, gzip_compresslevel=1) + loaded_data = jsonl_load(str(test_file)) + assert loaded_data == test_data + + jsonl_write(str(test_file), test_data, gzip_compresslevel=9) + loaded_data = jsonl_load(str(test_file)) + assert loaded_data == test_data From 3fb953d35a277ce54948c4da17eefae8ec056f35 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 14:29:58 +0000 Subject: [PATCH 15/72] fix benchmark parallel tests --- tests/test_benchmark_demo.py | 54 ------------- .../benchmark_parallel}/benchmark_parallel.py | 77 +++++++++++-------- .../benchmark_parallel/test_benchmark_demo.py | 17 ++++ 3 files changed, 62 insertions(+), 86 deletions(-) delete mode 100644 tests/test_benchmark_demo.py rename tests/{ => unit/benchmark_parallel}/benchmark_parallel.py (88%) create mode 100644 
tests/unit/benchmark_parallel/test_benchmark_demo.py diff --git a/tests/test_benchmark_demo.py b/tests/test_benchmark_demo.py deleted file mode 100644 index 7c496d0a..00000000 --- a/tests/test_benchmark_demo.py +++ /dev/null @@ -1,54 +0,0 @@ -#!/usr/bin/env python3 -"""Simple demo of using the benchmark script.""" - -import sys -import os - -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from tests.benchmark_parallel import ( - BenchmarkRunner, - cpu_bound_task, - io_bound_task, - light_cpu_task, - print_summary_stats, -) - - -def quick_benchmark(): - """Run a quick benchmark with small data sizes.""" - print("Running quick benchmark demo...\n") - - # Small data sizes for quick demo - data_sizes = [100, 500, 1000] - task_funcs = { - "cpu_bound": cpu_bound_task, - "io_bound": io_bound_task, - "light_cpu": light_cpu_task, - } - - # Run benchmarks - runner = BenchmarkRunner() - df = runner.run_benchmark_suite(data_sizes, task_funcs, runs_per_method=2) - - # Print results - print("\n" + "=" * 60) - print("BENCHMARK RESULTS DATAFRAME") - print("=" * 60) - print(df.to_string()) - - print("\n" + "=" * 60) - print_summary_stats(df) - - # Show example of filtering data - print("\n" + "=" * 60) - print("EXAMPLE: CPU-bound tasks only") - print("=" * 60) - cpu_df = df[df["task"] == "cpu_bound"] - print(cpu_df[["method", "data_size", "mean_time", "speedup"]].to_string()) - - return df - - -if __name__ == "__main__": - df = quick_benchmark() diff --git a/tests/benchmark_parallel.py b/tests/unit/benchmark_parallel/benchmark_parallel.py similarity index 88% rename from tests/benchmark_parallel.py rename to tests/unit/benchmark_parallel/benchmark_parallel.py index d1a065f8..9d04530b 100644 --- a/tests/benchmark_parallel.py +++ b/tests/unit/benchmark_parallel/benchmark_parallel.py @@ -1,12 +1,12 @@ -#!/usr/bin/env python3 """Benchmark test comparing run_maybe_parallel with other parallelization techniques. 
Run with: python tests/benchmark_parallel.py """ +from pathlib import Path import time import multiprocessing -from typing import List, Callable, Any, Dict, Tuple, Union +from typing import List, Callable, Any, Dict, Optional, Sequence, Tuple, Union import pandas as pd # type: ignore[import-untyped] import numpy as np from collections import defaultdict @@ -308,7 +308,6 @@ def plot_timing_comparison( def plot_efficiency_heatmap(df: pd.DataFrame, save_path: str = None): """Plot efficiency heatmap (speedup across methods and tasks).""" import matplotlib.pyplot as plt # type: ignore[import-untyped] - import seaborn as sns # type: ignore[import-untyped] # Create pivot table for heatmap pivot_df = df.pivot_table( @@ -317,15 +316,19 @@ def plot_efficiency_heatmap(df: pd.DataFrame, save_path: str = None): # Create heatmap plt.figure(figsize=(12, 8)) - sns.heatmap( - pivot_df, - annot=True, - fmt=".2f", - cmap="YlOrRd", - vmin=0, - center=1, - cbar_kws={"label": "Speedup"}, - ) + # sns.heatmap( + # pivot_df, + # annot=True, + # fmt=".2f", + # cmap="YlOrRd", + # vmin=0, + # center=1, + # cbar_kws={"label": "Speedup"}, + # ) + plt.imshow(pivot_df, aspect="auto", cmap="YlOrRd", vmin=0) + plt.colorbar(label="Speedup") + plt.yticks(range(len(pivot_df.index)), [f"{t[0]}-{t[1]}" for t in pivot_df.index]) + plt.xticks(range(len(pivot_df.columns)), pivot_df.columns, rotation=45) plt.title("Parallelization Efficiency Heatmap") plt.tight_layout() @@ -368,17 +371,29 @@ def print_summary_stats(df: pd.DataFrame): print(avg_speedup.sort_values("mean", ascending=False)) -def main(): +_DEFAULT_TASK_FUNCS: dict[str, Callable[[int], int]] = { + "cpu_bound": cpu_bound_task, + "io_bound": io_bound_task, + "light_cpu": light_cpu_task, +} + + +def main( + data_sizes: Sequence[int] = (100, 1000, 5000, 10000), + base_path: Path = Path("."), + plot: bool = True, + task_funcs: Optional[Dict[str, Callable[[int], int]]] = None, +): """Run benchmarks and display results.""" print("Starting parallelization benchmark...") + base_path = Path(base_path) + base_path.mkdir(parents=True, exist_ok=True) + + # Configure benchmark parameters - data_sizes = [100, 1000, 5000, 10000] - task_funcs = { - "cpu_bound": cpu_bound_task, - "io_bound": io_bound_task, - "light_cpu": light_cpu_task, - } + if task_funcs is None: + task_funcs = _DEFAULT_TASK_FUNCS # Run benchmarks runner = BenchmarkRunner() @@ -391,23 +406,21 @@ def main(): # Display summary print_summary_stats(df) - # Create visualizations - import matplotlib # type: ignore[import-untyped] + if plot: + # Create visualizations + import matplotlib # type: ignore[import-untyped] - matplotlib.use("Agg") # Use non-interactive backend + matplotlib.use("Agg") # Use non-interactive backend - # Plot speedup by data size for each task type - for task in task_funcs.keys(): - plot_speedup_by_data_size(df, task, f"speedup_{task}.png") - print(f"Saved speedup plot for {task} tasks to speedup_{task}.png") + # Plot speedup by data size for each task type + for task in task_funcs.keys(): + plot_speedup_by_data_size(df, task, base_path / f"speedup_{task}.png") - # Plot timing comparison for largest data size - plot_timing_comparison(df, data_sizes[-1], "timing_comparison.png") - print("Saved timing comparison to timing_comparison.png") + # Plot timing comparison for largest data size + plot_timing_comparison(df, data_sizes[-1], base_path / "timing_comparison.png") - # Plot efficiency heatmap - plot_efficiency_heatmap(df, "efficiency_heatmap.png") - print("Saved efficiency heatmap to 
efficiency_heatmap.png") + # Plot efficiency heatmap + plot_efficiency_heatmap(df, base_path / "efficiency_heatmap.png") return df diff --git a/tests/unit/benchmark_parallel/test_benchmark_demo.py b/tests/unit/benchmark_parallel/test_benchmark_demo.py new file mode 100644 index 00000000..8042447b --- /dev/null +++ b/tests/unit/benchmark_parallel/test_benchmark_demo.py @@ -0,0 +1,17 @@ +"""Simple demo of using the benchmark script.""" + + +from benchmark_parallel import io_bound_task, light_cpu_task, main + + +def test_main(): + """Test the main function of the benchmark script.""" + main( + data_sizes=(1, 2), + base_path="tests/_temp/benchmark_demo", + plot=True, + task_funcs= { + "io_bound": io_bound_task, + "light_cpu": light_cpu_task, + }, + ) From cfca63cd0be3a1476fccb2b6ff485b4edc76c825 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 14:30:35 +0000 Subject: [PATCH 16/72] format, fix one path --- tests/unit/benchmark_parallel/benchmark_parallel.py | 3 +-- tests/unit/benchmark_parallel/test_benchmark_demo.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/unit/benchmark_parallel/benchmark_parallel.py b/tests/unit/benchmark_parallel/benchmark_parallel.py index 9d04530b..29a7391f 100644 --- a/tests/unit/benchmark_parallel/benchmark_parallel.py +++ b/tests/unit/benchmark_parallel/benchmark_parallel.py @@ -390,7 +390,6 @@ def main( base_path = Path(base_path) base_path.mkdir(parents=True, exist_ok=True) - # Configure benchmark parameters if task_funcs is None: task_funcs = _DEFAULT_TASK_FUNCS @@ -400,7 +399,7 @@ def main( df = runner.run_benchmark_suite(data_sizes, task_funcs, runs_per_method=3) # Save results - df.to_csv("benchmark_results.csv", index=False) + df.to_csv(base_path / "benchmark_results.csv", index=False) print("\nResults saved to benchmark_results.csv") # Display summary diff --git a/tests/unit/benchmark_parallel/test_benchmark_demo.py b/tests/unit/benchmark_parallel/test_benchmark_demo.py index 8042447b..0184fef6 100644 --- a/tests/unit/benchmark_parallel/test_benchmark_demo.py +++ b/tests/unit/benchmark_parallel/test_benchmark_demo.py @@ -1,6 +1,5 @@ """Simple demo of using the benchmark script.""" - from benchmark_parallel import io_bound_task, light_cpu_task, main @@ -10,7 +9,7 @@ def test_main(): data_sizes=(1, 2), base_path="tests/_temp/benchmark_demo", plot=True, - task_funcs= { + task_funcs={ "io_bound": io_bound_task, "light_cpu": light_cpu_task, }, From 7edf9df01d23b9a5c4283a2430cad73ebce141c3 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 14:34:42 +0000 Subject: [PATCH 17/72] more dbg tests --- tests/unit/test_dbg.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_dbg.py b/tests/unit/test_dbg.py index 80fa32c9..b792bd54 100644 --- a/tests/unit/test_dbg.py +++ b/tests/unit/test_dbg.py @@ -13,7 +13,9 @@ _process_path, _CWD, # we do use this as a global in `test_dbg_counter_increments` - _COUNTER, # noqa: F401 + _COUNTER, + dbg_auto, + dbg_dict, # noqa: F401 grep_repr, _normalize_for_loose, _compile_pattern, @@ -219,6 +221,13 @@ def test_dbg_non_callable_formatter() -> None: dbg(42, formatter="not callable") # type: ignore +def test_misc() -> None: + d1 = {"apple": 1, "banana": 2, "cherry": 3} + dbg_dict(d1) + dbg_auto(d1) + l1 = [10, 20, 30] + dbg_auto(l1) + # # --- Tests for tensor_info_dict and tensor_info --- # def test_tensor_info_dict_with_nan() -> None: # tensor: DummyTensor = DummyTensor() From 8ba78fb4d7e8f096bd63de408d1d165d722ff6a4 
Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 14:36:12 +0000 Subject: [PATCH 18/72] makefile merging wip --- makefile | 538 +++++++++++++--- makefile-old | 1695 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 2151 insertions(+), 82 deletions(-) create mode 100644 makefile-old diff --git a/makefile b/makefile index ee992946..ee6e3fea 100644 --- a/makefile +++ b/makefile @@ -2,10 +2,14 @@ #| python project makefile template | #| originally by Michael Ivanitskiy (mivanits@umich.edu) | #| https://github.com/mivanit/python-project-makefile-template | -#| version: v0.3.4 | +#| version: v0.4.0 | #| license: https://creativecommons.org/licenses/by-sa/4.0/ | -#| modifications from the original should be denoted with `~~~~~` | -#| as this makes it easier to find edits when updating makefile | +#|==================================================================| +#| CUSTOMIZATION: | +#| - modify PACKAGE_NAME and other variables in config section | +#| - mark custom changes with `~~~~~` for easier template updates | +#| - run `make help` to see available targets | +#| - run `make help=TARGET` for detailed info about specific target | #|==================================================================| @@ -21,9 +25,10 @@ # configuration & variables # ================================================== +# !!! MODIFY AT LEAST THIS PART TO SUIT YOUR PROJECT !!! # it assumes that the source is in a directory named the same as the package name # this also gets passed to some other places -PACKAGE_NAME := muutils +PACKAGE_NAME := myproject # for checking you are on the right branch when publishing PUBLISH_BRANCH := main @@ -123,32 +128,7 @@ endif # if you want different behavior for different python versions # -------------------------------------------------- -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# compatibility mode for python <3.10 - -# loose typing, allow warnings for python <3.10 -# -------------------------------------------------- -TYPECHECK_ARGS ?= -# COMPATIBILITY_MODE: whether to run in compatibility mode for python <3.10 -COMPATIBILITY_MODE := $(shell $(PYTHON) -c "import sys; print(1 if sys.version_info < (3, 11) else 0)") - -# compatibility mode for python <3.10 -# -------------------------------------------------- - -# whether to run pytest with warnings as errors -WARN_STRICT ?= 0 - -ifneq ($(WARN_STRICT), 0) - PYTEST_OPTIONS += -W error -endif - -# compatibility mode for python <3.10 -ifeq ($(COMPATIBILITY_MODE), 1) - JUNK := $(info !!! 
WARNING !!!: Detected python version less than 3.10, some behavior will be different) - TYPECHECK_ARGS += --disable-error-code misc --disable-error-code syntax --disable-error-code import-not-found --no-check-untyped-defs -endif - -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# COMPATIBILITY_MODE := $(shell $(PYTHON) -c "import sys; print(1 if sys.version_info < (3, 10) else 0)") # options we might want to pass to pytest # -------------------------------------------------- @@ -1255,6 +1235,334 @@ endef export SCRIPT_MYPY_REPORT +# get information about makefile recipes/targets +define SCRIPT_RECIPE_INFO +# source: https://github.com/mivanit/python-project-makefile-template/tree/main/scripts/make/recipe_info.py + +"""CLI tool to get information about Makefile recipes/targets.""" + +from __future__ import annotations + +import argparse +import difflib +import fnmatch +import re +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Set, Union + + +def _scan_makefile( + lines: List[str], + target_name: Optional[str] = None, +) -> Union[Dict[str, int], int]: + """Scan makefile for target definitions, skipping define blocks. + + Args: + lines: Makefile lines + target_name: If provided, return line index for this specific target. + If None, return dict of all targets. + + Returns: + If target_name is None: dict mapping target names to line indices + If target_name is provided: line index of that target, or -1 if not found + + """ + in_define_block: bool = False + target_rx: re.Pattern = re.compile(r"^([a-zA-Z0-9_-]+)[ \t]*:") + targets: Dict[str, int] = {} + + for i, line in enumerate(lines): + # Track if we're inside a define block (embedded scripts) + if line.startswith("define "): + in_define_block = True + continue + if line.startswith("endef"): + in_define_block = False + continue + + # Skip lines inside define blocks + if in_define_block: + continue + + # Match target definitions + match = target_rx.match(line) + if match: + tgt_name: str = match.group(1) + if target_name is not None: + # Looking for specific target + if tgt_name == target_name: + return i + else: + # Collecting all targets + targets[tgt_name] = i + + # Return results based on mode + if target_name is not None: + return -1 # Target not found + return targets + + +class Colors: + """ANSI color codes""" + + def __init__(self, enabled: bool = True) -> None: + "init color codes, or empty strings if `not enabled`" + if enabled: + self.RESET = "\033[0m" + self.BOLD = "\033[1m" + self.RED = "\033[31m" + self.GREEN = "\033[32m" + self.YELLOW = "\033[33m" + self.BLUE = "\033[34m" + self.MAGENTA = "\033[35m" + self.WHITE = "\033[37m" + else: + self.RESET = self.BOLD = "" + self.RED = self.GREEN = self.YELLOW = "" + self.BLUE = self.MAGENTA = self.WHITE = "" + + +@dataclass +class MakeRecipe: + """Information about a Makefile recipe/target.""" + + target: str + comments: List[str] + dependencies: List[str] + echo_message: str + + @classmethod + def from_makefile(cls, lines: List[str], target: str) -> MakeRecipe: + """Parse and create a MakeRecipe from makefile lines for *target*.""" + i: int = _scan_makefile(lines, target_name=target) + if i == -1: + err_msg: str = f"target '{target}' not found in makefile" + raise ValueError(err_msg) + + line: str = lines[i] + + # contiguous comment block immediately above + # (skip backward past .PHONY declarations and blank lines) + comments: List[str] = [] + j: int = i - 1 + blank_count: int = 0 + while j >= 0: + stripped: str = 
lines[j].lstrip() + if stripped.startswith("#"): + comments.append(stripped[1:].lstrip()) + blank_count = 0 # Reset blank counter when we hit a comment + j -= 1 + elif stripped == "": + # Track consecutive blank lines + blank_count += 1 + if blank_count >= 2: # noqa: PLR2004 + # Hit 2 blank lines in a row - stop + break + j -= 1 + elif stripped.startswith(".PHONY:"): + # Skip .PHONY declarations + blank_count = 0 # Reset blank counter + j -= 1 + else: + # Hit a non-comment, non-blank, non-.PHONY line - stop + break + comments.reverse() + + # prerequisites + deps_str: str = line.split(":", 1)[1].strip() + deps: List[str] = deps_str.split() if deps_str else [] + + # first echo in the recipe + echo_msg: str = "" + k: int = i + 1 + while k < len(lines) and ( + lines[k].startswith("\t") or lines[k].startswith(" ") + ): + stripped: str = lines[k].lstrip() + m = re.match(r"@?echo[ \t]+(.*)", stripped) + if m: + content: str = m.group(1).strip() + if (content.startswith('"') and content.endswith('"')) or ( + content.startswith("'") and content.endswith("'") + ): + content = content[1:-1] + echo_msg = content + break + k += 1 + + return cls( + target=target, + comments=comments, + dependencies=deps, + echo_message=echo_msg, + ) + + def describe(self, color: bool = False) -> List[str]: + """Return a list of description lines for this recipe.""" + output: List[str] = [] + c: Colors = Colors(enabled=color) + + # Target name (bold blue) with colon in white + output.append(f"{c.BOLD}{c.BLUE}{self.target}{c.RESET}{c.WHITE}:{c.RESET}") + + # Echo message (description) in yellow + if self.echo_message: + output.append(f" {c.YELLOW}{self.echo_message}{c.RESET}") + + # Dependencies in magenta + if self.dependencies: + deps_str = " ".join( + f"{c.MAGENTA}{dep}{c.RESET}" for dep in self.dependencies + ) + output.append(f" {c.RED}depends-on:{c.RESET} {deps_str}") + + # Comments in green + if self.comments: + output.append(f" {c.RED}comments:{c.RESET}") + output.extend(f" {c.GREEN}{line}{c.RESET}" for line in self.comments) + + return output + + +def find_all_targets(lines: List[str]) -> List[str]: + """Find all .PHONY target names in the makefile.""" + # First, get all .PHONY declarations + phony_targets: Set[str] = set() + # Use chr(36) to get dollar sign - works both standalone and embedded in makefile + # issue being that the makefile processes dollar sign as an escape character + phony_pattern: re.Pattern = re.compile(r"^\.PHONY:\s+(.+)" + chr(36)) + + for line in lines: + match = phony_pattern.match(line) + if match: + # Get all targets from this .PHONY line (space-separated) + target_names: List[str] = match.group(1).split() + phony_targets.update(target_names) + + # Now scan for actual target definitions and filter to .PHONY ones + all_target_defs: Dict[str, int] = _scan_makefile(lines) + return [tgt for tgt in all_target_defs if tgt in phony_targets] + + +def get_all_recipes(lines: List[str]) -> List[MakeRecipe]: + """Get MakeRecipe objects for all .PHONY targets in the makefile.""" + targets: List[str] = find_all_targets(lines) + return [MakeRecipe.from_makefile(lines, target) for target in targets] + + +def describe_target(makefile_path: Path, target: str) -> None: + """Emit the description for *target*.""" + lines: List[str] = makefile_path.read_text(encoding="utf-8").splitlines() + recipe: MakeRecipe = MakeRecipe.from_makefile(lines, target) + + for line in recipe.describe(): + print(line) + + +def main() -> None: + """CLI entry point.""" + parser: argparse.ArgumentParser = argparse.ArgumentParser( 
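[editor's note on _scan_makefile above] For intuition, here is what the scanner returns on a toy makefile (a standalone sketch, not part of the embedded script): targets inside define/endef blocks are ignored, .PHONY lines and tab-indented recipe lines never match the target regex, and the two call modes return either a dict of all targets or a single line index:

    lines = [
        "define SCRIPT_FOO",   # start of an embedded script block
        "demo:",               # looks like a target, but sits inside the define block
        "endef",
        ".PHONY: build",       # leading '.' does not match the target regex
        "build: deps",
        "\techo building",     # recipe lines start with a tab, so they do not match
        "deps:",
    ]
    assert _scan_makefile(lines) == {"build": 4, "deps": 6}
    assert _scan_makefile(lines, target_name="build") == 4
    assert _scan_makefile(lines, target_name="missing") == -1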
+ "recipe_info", + description="Get detailed information about Makefile recipes/targets", + ) + parser.add_argument( + "-f", + "--file", + default="makefile", + help="Path to the Makefile (default: ./makefile)", + ) + parser.add_argument( + "-a", + "--all", + action="store_true", + help="Print help for all targets in the Makefile", + ) + parser.add_argument( + "--no-color", + action="store_true", + help="Disable colored output (color is enabled by default)", + ) + parser.add_argument("targets", nargs="*", help="Target names") + args: argparse.Namespace = parser.parse_args() + + lines: List[str] = Path(args.file).read_text(encoding="utf-8").splitlines() + + # Get recipes to describe + if args.all: + recipes: List[MakeRecipe] = get_all_recipes(lines) + elif args.targets: + recipes = [] + all_targets: List[str] = find_all_targets(lines) + c: Colors = Colors(enabled=not args.no_color) + for tgt in args.targets: + # Check if target contains wildcard characters + if any(char in tgt for char in ["*", "?", "["]): + # Pattern matching mode + matched_targets: List[str] = [ + t for t in all_targets if fnmatch.fnmatch(t, tgt) + ] + if matched_targets: + for matched in matched_targets: + recipes.append(MakeRecipe.from_makefile(lines, matched)) + else: + print( + f"Error: no targets match pattern '{c.RED}{tgt}{c.RESET}'", + file=sys.stderr, + ) + sys.exit(1) + else: + # Exact target lookup + try: + recipes.append(MakeRecipe.from_makefile(lines, tgt)) + except ValueError: + # Find similar targets (fuzzy matching) + fuzzy_matches: List[str] = difflib.get_close_matches( + tgt, + all_targets, + n=5, + cutoff=0.6, + ) + # Also find targets that contain the attempted target + substring_matches: List[str] = [ + t for t in all_targets if tgt in t and t not in fuzzy_matches + ] + # Combine and deduplicate while preserving order + matches: List[str] = fuzzy_matches + substring_matches + matches = matches[:5] # Limit to 5 suggestions + + print( + f"Error: target '{c.RED}{tgt}{c.RESET}' not found in makefile", + file=sys.stderr, + ) + if matches: + suggestions: str = ", ".join( + f"{c.BLUE}{m}{c.RESET}" for m in matches + ) + print(f"Did you mean: {suggestions}?", file=sys.stderr) + sys.exit(1) + else: + recipes = [] + + if not recipes: + parser.error("Provide target names or use --all flag") + + # Print descriptions (color is True by default, unless --no-color is passed) + descriptions: List[str] = [ + line for recipe in recipes for line in recipe.describe(color=not args.no_color) + ] + print("\n".join(descriptions).replace("\n\n", f"\n{'-' * 40}\n")) + + +if __name__ == "__main__": + main() + +endef + +export SCRIPT_RECIPE_INFO + + ## ## ######## ######## ###### #### ####### ## ## ## ## ## ## ## ## ## ## ## ## ### ## ## ## ## ## ## ## ## ## ## #### ## @@ -1342,14 +1650,30 @@ dep-check-torch: @echo "see if torch is installed, and which CUDA version and devices it sees" $(PYTHON) -c "$$SCRIPT_CHECK_TORCH" +# sync dependencies and export to requirements.txt files +# - syncs all extras and groups with uv (including dev dependencies) +# - compiles bytecode for faster imports +# - exports to requirements.txt files per tool.uv-exports.exports config +# configure via pyproject.toml:[tool.uv-exports]: +# [tool.uv-exports] +# exports = [ +# { name = "base", extras = [], groups = [] }, # base package deps only +# { name = "dev", extras = [], groups = ["dev"] }, # dev dependencies +# { name = "all", extras = ["all"], groups = ["dev"] } # everything +# ] .PHONY: dep dep: @echo "Exporting dependencies as per $(PYPROJECT) 
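[editor's note on the suggestion logic above] The "did you mean" output combines difflib's fuzzy matching with a substring fallback, capped at five suggestions. A standalone sketch of that logic, with the target names invented for illustration:

    import difflib

    all_targets = ["test", "typing", "typing-report", "format", "format-check"]
    tgt = "formt"  # a user typo
    fuzzy = difflib.get_close_matches(tgt, all_targets, n=5, cutoff=0.6)
    substring = [t for t in all_targets if tgt in t and t not in fuzzy]
    print((fuzzy + substring)[:5])  # ['format']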
section 'tool.uv-exports.exports'" uv sync --all-extras --all-groups --compile-bytecode mkdir -p $(REQUIREMENTS_DIR) $(PYTHON) -c "$$SCRIPT_EXPORT_REQUIREMENTS" $(PYPROJECT) $(REQUIREMENTS_DIR) | sh -x - + +# verify that requirements.txt files match current dependencies +# - exports deps to temp directory +# - diffs temp against existing requirements files +# - FAILS if any differences found (means you need to run `make dep`) +# useful in CI to catch when pyproject.toml changed but requirements weren't regenerated .PHONY: dep-check dep-check: @echo "Checking that exported requirements are up to date" @@ -1380,58 +1704,48 @@ dep-clean: # checks (formatting/linting, typing, tests) # ================================================== -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# added gen-extra-tests and it is required by some other recipes: -# format-check, typing, test - -# extra tests with python >=3.10 type hints -.PHONY: gen-extra-tests -gen-extra-tests: - if [ $(COMPATIBILITY_MODE) -eq 0 ]; then \ - echo "converting certain tests to modern format"; \ - $(PYTHON) tests/util/replace_type_hints.py tests/unit/validate_type/test_validate_type.py "# DO NOT EDIT, GENERATED FILE" > tests/unit/validate_type/test_validate_type_GENERATED.py; \ - fi; \ - -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -# runs ruff and pycln to format the code +# format code AND auto-fix linting issues +# performs TWO operations: reformats code, then auto-fixes safe linting issues +# configure in pyproject.toml:[tool.ruff] .PHONY: format format: @echo "format the source code" $(PYTHON) -m ruff format --config $(PYPROJECT) . $(PYTHON) -m ruff check --fix --config $(PYPROJECT) . - $(PYTHON) -m pycln --config $(PYPROJECT) --all . -# runs ruff and pycln to check if the code is formatted correctly +# runs ruff to check if the code is formatted correctly .PHONY: format-check format-check: @echo "check if the source code is formatted correctly" $(PYTHON) -m ruff check --config $(PYPROJECT) . - $(PYTHON) -m pycln --check --config $(PYPROJECT) . -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # runs type checks with mypy -# at some point, need to add back --check-untyped-defs to mypy call -# but it complains when we specify arguments by keyword where positional is fine -# not sure how to fix this .PHONY: typing -typing: gen-extra-tests +typing: clean @echo "running type checks" $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_ARGS) . -# $(PYTHON) -m ty check muutils/ - -# generates a report of the mypy output +# generate summary report of type check errors grouped by file +# outputs TOML format showing error count per file +# useful for tracking typing progress across large codebases .PHONY: typing-report -typing-report: clean gen-extra-tests +typing-report: @echo "generate a report of the type check output -- errors per file" $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_ARGS) . | $(PYTHON) -c "$$SCRIPT_MYPY_REPORT" --mode toml +# run tests with pytest +# - automatically runs `make clean` first +# - respects COV and VERBOSE makefile variables +# - pass custom args: make test PYTEST_OPTIONS="--maxfail=1 -x" +# makefile variables: +# COV=1 # generate coverage reports (default: 1) +# VERBOSE=1 # verbose pytest output (default: 0) +# PYTEST_OPTIONS="..." 
# pass additional pytest arguments +# pytest config in pyproject.toml:[tool.pytest.ini_options] .PHONY: test -test: clean gen-extra-tests +test: clean @echo "running tests" $(PYTHON) -m pytest $(PYTEST_OPTIONS) $(TESTS_DIR) -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .PHONY: check check: clean format-check test typing @@ -1478,9 +1792,13 @@ docs-combined: docs-md $(PANDOC) -f markdown -t plain $(DOCS_DIR)/combined/$(PACKAGE_NAME).md -o $(DOCS_DIR)/combined/$(PACKAGE_NAME).txt $(PANDOC) -f markdown -t html $(DOCS_DIR)/combined/$(PACKAGE_NAME).md -o $(DOCS_DIR)/combined/$(PACKAGE_NAME).html -# generates coverage reports as html and text with `pytest-cov`, and a badge with `coverage-badge` -# if `.coverage` is not found, will run tests first -# also removes the `.gitignore` file that `coverage html` creates, since we count that as part of the docs +# generate coverage reports from test results +# WARNING: if .coverage file not found, will automatically run `make test` first +# - generates text report: $(COVERAGE_REPORTS_DIR)/coverage.txt +# - generates SVG badge: $(COVERAGE_REPORTS_DIR)/coverage.svg +# - generates HTML report: $(COVERAGE_REPORTS_DIR)/html/ +# - removes .gitignore from html dir (we publish coverage with docs) +# run tests with: make test COV=1 (COV=1 is default) .PHONY: cov cov: @echo "generate coverage reports" @@ -1499,15 +1817,39 @@ cov: docs: cov docs-html docs-combined todo lmcat @echo "generate all documentation and coverage reports" -# removed all generated documentation files, but leaves everything in `$DOCS_RESOURCES_DIR` -# and leaves things defined in `pyproject.toml:tool.makefile.docs.no_clean` -# (templates, svg, css, make_docs.py script) -# distinct from `make clean` +# remove generated documentation files, but preserve resources +# - removes all docs except those in DOCS_RESOURCES_DIR +# - preserves files/patterns specified in pyproject.toml config +# - distinct from `make clean` (which removes temp build files, not docs) +# configure via pyproject.toml:[tool.makefile.docs]: +# [tool.makefile.docs] +# output_dir = "docs" # must match DOCS_DIR in makefile +# no_clean = [ # files/patterns to preserve when cleaning +# "resources/**", +# "*.svg", +# "*.css" +# ] .PHONY: docs-clean docs-clean: @echo "remove generated docs except resources" $(PYTHON) -c "$$SCRIPT_DOCS_CLEAN" $(PYPROJECT) $(DOCS_DIR) $(DOCS_RESOURCES_DIR) + +# get all TODO's from the code +# configure via pyproject.toml:[tool.makefile.inline-todo]: +# [tool.makefile.inline-todo] +# search_dir = "." # directory to search for TODOs +# out_file_base = "docs/other/todo-inline" # output file path (without extension) +# context_lines = 2 # lines of context around each TODO +# extensions = ["py", "md"] # file extensions to search +# tags = ["CRIT", "TODO", "FIXME", "HACK", "BUG", "DOC"] # tags to look for +# exclude = ["docs/**", ".venv/**", "scripts/get_todos.py"] # patterns to exclude +# branch = "main" # git branch for URLs +# # repo_url = "..." # repository URL (defaults to [project.urls.{repository,github}]) +# # template_md = "..." # custom jinja2 template for markdown output +# # template_issue = "..." # custom format string for issues +# # template_html_source = "..." 
# custom html template path +# tag_label_map = { "BUG" = "bug", "TODO" = "enhancement", "DOC" = "documentation" } # mapping of tags to GitHub issue labels .PHONY: todo todo: @echo "get all TODO's from the code" @@ -1535,10 +1877,13 @@ lmcat: # build and publish # ================================================== -# verifies that the current branch is $(PUBLISH_BRANCH) and that git is clean -# used before publishing +# verify git is ready for publishing +# REQUIRES: +# - current branch must be $(PUBLISH_BRANCH) +# - no uncommitted changes (git status --porcelain must be empty) +# EXITS with error if either condition fails .PHONY: verify-git -verify-git: +verify-git: @echo "checking git status" if [ "$(shell git branch --show-current)" != $(PUBLISH_BRANCH) ]; then \ echo "!!! ERROR !!!"; \ @@ -1555,14 +1900,25 @@ verify-git: fi; \ +# build package distribution files +# creates wheel (.whl) and source distribution (.tar.gz) in dist/ .PHONY: build -build: +build: @echo "build the package" uv build -# gets the commit log, checks everything, builds, and then publishes with twine -# will ask the user to confirm the new version number (and this allows for editing the tag info) -# will also print the contents of $(PYPI_TOKEN_FILE) to the console for the user to copy and paste in when prompted by twine +# publish package to PyPI and create git tag +# PREREQUISITES: +# - must be on $(PUBLISH_BRANCH) branch with clean git status (verified by verify-git) +# - must have $(PYPI_TOKEN_FILE) with your PyPI token +# - version in pyproject.toml must be different from $(LAST_VERSION_FILE) +# PROCESS: +# 1. runs all checks, builds package +# 2. prompts for version confirmation (you can edit $(COMMIT_LOG_FILE) at this point) +# 3. creates git commit updating $(LAST_VERSION_FILE) +# 4. creates annotated git tag with commit log as description +# 5. pushes tag to origin +# 6. 
uploads to PyPI via twine (you'll paste token from $(PYPI_TOKEN_FILE)) .PHONY: publish publish: gen-commit-log check build verify-git version gen-version-info @echo "run all checks, build, and then publish" @@ -1599,8 +1955,6 @@ publish: gen-commit-log check build verify-git version gen-version-info # removes $(TESTS_TEMP_DIR) to remove temporary test files # recursively removes all `__pycache__` directories and `*.pyc` or `*.pyo` files # distinct from `make docs-clean`, which only removes generated documentation files -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# slight modification in last line for extra tests .PHONY: clean clean: @echo "clean up temporary files" @@ -1612,10 +1966,12 @@ clean: rm -rf build rm -rf $(PACKAGE_NAME).egg-info rm -rf $(TESTS_TEMP_DIR) - $(PYTHON_BASE) -Bc "import pathlib; [p.unlink() for path in ['$(PACKAGE_NAME)', '$(TESTS_DIR)', '$(DOCS_DIR)'] for pattern in ['*.py[co]', '__pycache__/*'] for p in pathlib.Path(path).rglob(pattern)]" - rm -rf tests/unit/validate_type/test_validate_type_GENERATED.py -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + $(PYTHON) -Bc "import pathlib; [p.unlink() for path in ['$(PACKAGE_NAME)', '$(TESTS_DIR)', '$(DOCS_DIR)'] for pattern in ['*.py[co]', '__pycache__/*'] for p in pathlib.Path(path).rglob(pattern)]" +# remove all generated/build files including .venv +# runs: clean + docs-clean + dep-clean +# removes .venv, uv.lock, requirements.txt files, generated docs, build artifacts +# run `make dep` after this to reinstall dependencies .PHONY: clean-all clean-all: clean docs-clean dep-clean @echo "clean up all temporary files, dep files, venv, and generated docs" @@ -1675,10 +2031,28 @@ info-long: info @echo " RUN_GLOBAL = $(RUN_GLOBAL)" @echo " TYPECHECK_ARGS = $(TYPECHECK_ARGS)" -# immediately print out the help targets, and then local variables (but those take a bit longer) +# Smart help command: shows general help, or detailed info about specific targets +# Usage: +# make help - shows general help (list of targets + makefile variables) +# make help="test" - shows detailed info about the 'test' recipe +# make HELP="test clean" - shows detailed info about multiple recipes +# make h=* - shows detailed info about all recipes (wildcard expansion) +# make H="test" - same as HELP (case variations supported) +# +# All variations work: help/HELP/h/H with values like "foo", "foo bar", "*", "--all" .PHONY: help -help: help-targets info - @echo -n "" +help: + @$(eval HELP_ARG := $(or $(HELP),$(help),$(H),$(h))) + @$(eval HELP_EXPANDED := $(if $(filter *,$(HELP_ARG)),--all,$(HELP_ARG))) + @if [ -n "$(HELP_EXPANDED)" ]; then \ + $(PYTHON) -c "$$SCRIPT_RECIPE_INFO" -f makefile $(HELP_EXPANDED); \ + else \ + $(MAKE) --no-print-directory help-targets info; \ + echo ""; \ + echo "To get detailed info about specific make targets, use:"; \ + echo " make help=TARGET or make HELP=\"TARGET1 TARGET2\""; \ + echo " make H=* or make h=--all"; \ + fi ###### ## ## ###### ######## ####### ## ## diff --git a/makefile-old b/makefile-old new file mode 100644 index 00000000..ee992946 --- /dev/null +++ b/makefile-old @@ -0,0 +1,1695 @@ +#|==================================================================| +#| python project makefile template | +#| originally by Michael Ivanitskiy (mivanits@umich.edu) | +#| https://github.com/mivanit/python-project-makefile-template | +#| version: v0.3.4 | +#| license: https://creativecommons.org/licenses/by-sa/4.0/ | +#| modifications from the original should be denoted with `~~~~~` | +#| as this makes it 
easier to find edits when updating makefile | +#|==================================================================| + + + ###### ######## ###### +## ## ## ## ## +## ## ## +## ###### ## #### +## ## ## ## +## ## ## ## ## + ###### ## ###### + +# ================================================== +# configuration & variables +# ================================================== + +# it assumes that the source is in a directory named the same as the package name +# this also gets passed to some other places +PACKAGE_NAME := muutils + +# for checking you are on the right branch when publishing +PUBLISH_BRANCH := main + +# where to put docs +# if you change this, you must also change pyproject.toml:tool.makefile.docs.output_dir to match +DOCS_DIR := docs + +# where the tests are, for pytest +TESTS_DIR := tests + +# tests temp directory to clean up. will remove this in `make clean` +TESTS_TEMP_DIR := $(TESTS_DIR)/_temp/ + +# probably don't change these: +# -------------------------------------------------- + +# where the pyproject.toml file is. no idea why you would change this but just in case +PYPROJECT := pyproject.toml + +# dir to store various configuration files +# use of `.meta/` inspired by https://news.ycombinator.com/item?id=36472613 +META_DIR := .meta + +# requirements.txt files for base package, all extras, dev, and all +REQUIREMENTS_DIR := $(META_DIR)/requirements + +# local files (don't push this to git!) +LOCAL_DIR := $(META_DIR)/local + +# will print this token when publishing. make sure not to commit this file!!! +PYPI_TOKEN_FILE := $(LOCAL_DIR)/.pypi-token + +# version files +VERSIONS_DIR := $(META_DIR)/versions + +# the last version that was auto-uploaded. will use this to create a commit log for version tag +# see `gen-commit-log` target +LAST_VERSION_FILE := $(VERSIONS_DIR)/.lastversion + +# current version (writing to file needed due to shell escaping issues) +VERSION_FILE := $(VERSIONS_DIR)/.version + +# base python to use. Will add `uv run` in front of this if `RUN_GLOBAL` is not set to 1 +PYTHON_BASE := python + +# where the commit log will be stored +COMMIT_LOG_FILE := $(LOCAL_DIR)/.commit_log + +# pandoc commands (for docs) +PANDOC ?= pandoc + +# where to put the coverage reports +# note that this will be published with the docs! 
+# modify the `docs` targets and `.gitignore` if you don't want that +COVERAGE_REPORTS_DIR := $(DOCS_DIR)/coverage + +# this stuff in the docs will be kept +# in addition to anything specified in `pyproject.toml:tool.makefile.docs.no_clean` +DOCS_RESOURCES_DIR := $(DOCS_DIR)/resources + +# location of the make docs script +MAKE_DOCS_SCRIPT_PATH := $(DOCS_RESOURCES_DIR)/make_docs.py + +# version vars - extracted automatically from `pyproject.toml`, `$(LAST_VERSION_FILE)`, and $(PYTHON) +# -------------------------------------------------- + +# assuming your `pyproject.toml` has a line that looks like `version = "0.0.1"`, `gen-version-info` will extract this +PROJ_VERSION := NULL +# `gen-version-info` will read the last version from `$(LAST_VERSION_FILE)`, or `NULL` if it doesn't exist +LAST_VERSION := NULL +# get the python version, now that we have picked the python command +PYTHON_VERSION := NULL + + +# ================================================== +# reading command line options +# ================================================== + +# for formatting or something, we might want to run python without uv +# RUN_GLOBAL=1 to use global `PYTHON_BASE` instead of `uv run $(PYTHON_BASE)` +RUN_GLOBAL ?= 0 + +# for running tests or other commands without updating the env, set this to 1 +# and it will pass `--no-sync` to `uv run` +UV_NOSYNC ?= 0 + +ifeq ($(RUN_GLOBAL),0) + ifeq ($(UV_NOSYNC),1) + PYTHON = uv run --no-sync $(PYTHON_BASE) + else + PYTHON = uv run $(PYTHON_BASE) + endif +else + PYTHON = $(PYTHON_BASE) +endif + +# if you want different behavior for different python versions +# -------------------------------------------------- +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# compatibility mode for python <3.10 + +# loose typing, allow warnings for python <3.10 +# -------------------------------------------------- +TYPECHECK_ARGS ?= +# COMPATIBILITY_MODE: whether to run in compatibility mode for python <3.10 +COMPATIBILITY_MODE := $(shell $(PYTHON) -c "import sys; print(1 if sys.version_info < (3, 11) else 0)") + +# compatibility mode for python <3.10 +# -------------------------------------------------- + +# whether to run pytest with warnings as errors +WARN_STRICT ?= 0 + +ifneq ($(WARN_STRICT), 0) + PYTEST_OPTIONS += -W error +endif + +# compatibility mode for python <3.10 +ifeq ($(COMPATIBILITY_MODE), 1) + JUNK := $(info !!! WARNING !!!: Detected python version less than 3.10, some behavior will be different) + TYPECHECK_ARGS += --disable-error-code misc --disable-error-code syntax --disable-error-code import-not-found --no-check-untyped-defs +endif + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +# options we might want to pass to pytest +# -------------------------------------------------- + +# base options for pytest, will be appended to if `COV` or `VERBOSE` are 1. +# user can also set this when running make to add more options +PYTEST_OPTIONS ?= + +# set to `1` to run pytest with `--cov=.` to get coverage reports in a `.coverage` file +COV ?= 1 +# set to `1` to run pytest with `--verbose` +VERBOSE ?= 0 + +ifeq ($(VERBOSE),1) + PYTEST_OPTIONS += --verbose +endif + +ifeq ($(COV),1) + PYTEST_OPTIONS += --cov=. 
+endif + +# ================================================== +# default target (help) +# ================================================== + +# first/default target is help +.PHONY: default +default: help + + + + ###### ###### ######## #### ######## ######## ###### +## ## ## ## ## ## ## ## ## ## ## ## +## ## ## ## ## ## ## ## ## + ###### ## ######## ## ######## ## ###### + ## ## ## ## ## ## ## ## +## ## ## ## ## ## ## ## ## ## ## + ###### ###### ## ## #### ## ## ###### + +# ================================================== +# python scripts we want to use inside the makefile +# when developing, these are populated by `scripts/assemble_make.py` +# ================================================== + +# create commands for exporting requirements as specified in `pyproject.toml:tool.uv-exports.exports` +define SCRIPT_EXPORT_REQUIREMENTS +# source: https://github.com/mivanit/python-project-makefile-template/tree/main/scripts/make/export_requirements.py + +"export to requirements.txt files based on pyproject.toml configuration" + +from __future__ import annotations + +import sys +import warnings + +try: + import tomllib # type: ignore[import-not-found] +except ImportError: + import tomli as tomllib # type: ignore +from functools import reduce +from pathlib import Path +from typing import Any, Dict, List, Union + +TOOL_PATH: str = "tool.makefile.uv-exports" + + +def deep_get(d: dict, path: str, default: Any = None, sep: str = ".") -> Any: # noqa: ANN401 + "get a value from a nested dictionary" + return reduce( + lambda x, y: x.get(y, default) if isinstance(x, dict) else default, # function + path.split(sep) if isinstance(path, str) else path, # sequence + d, # initial + ) + + +def export_configuration( + export: dict, + all_groups: List[str], + all_extras: List[str], + export_opts: dict, + output_dir: Path, +) -> None: + "print to console a uv command for make which will export a requirements.txt file" + # get name and validate + name = export.get("name") + if not name or not name.isalnum(): + warnings.warn( + f"Export configuration missing valid 'name' field {export}", + ) + return + + # get other options with default fallbacks + filename: str = export.get("filename") or f"requirements-{name}.txt" + groups: Union[List[str], bool, None] = export.get("groups") + extras: Union[List[str], bool] = export.get("extras", []) + options: List[str] = export.get("options", []) + + # init command + cmd: List[str] = ["uv", "export", *export_opts.get("args", [])] + + # handle groups + if groups is not None: + groups_list: List[str] = [] + if isinstance(groups, bool): + if groups: + groups_list = all_groups.copy() + else: + groups_list = groups + + for group in all_groups: + if group in groups_list: + cmd.extend(["--group", group]) + else: + cmd.extend(["--no-group", group]) + + # handle extras + extras_list: List[str] = [] + if isinstance(extras, bool): + if extras: + extras_list = all_extras.copy() + else: + extras_list = extras + + for extra in extras_list: + cmd.extend(["--extra", extra]) + + # add extra options + cmd.extend(options) + + # assemble the command and print to console -- makefile will run it + output_path = output_dir / filename + print(f"{' '.join(cmd)} > {output_path.as_posix()}") + + +def main( + pyproject_path: Path, + output_dir: Path, +) -> None: + "export to requirements.txt files based on pyproject.toml configuration" + # read pyproject.toml + with open(pyproject_path, "rb") as f: + pyproject_data: dict = tomllib.load(f) + + # all available groups + all_groups: List[str] = 
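[editor's note on export_configuration above] Concretely, export_configuration does not run uv itself; it prints a shell command that the dep target later pipes into `sh -x`. A sketch of one call, assuming the helper above is in scope (the "lint" group and the extra args are invented; --no-hashes is a real `uv export` flag):

    from pathlib import Path

    export_configuration(
        export={"name": "dev", "groups": ["dev"], "extras": [], "options": []},
        all_groups=["dev", "lint"],
        all_extras=["array"],
        export_opts={"args": ["--no-hashes"]},
        output_dir=Path(".meta/requirements"),
    )
    # prints:
    # uv export --no-hashes --group dev --no-group lint > .meta/requirements/requirements-dev.txt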
list(pyproject_data.get("dependency-groups", {}).keys()) + all_extras: List[str] = list( + deep_get(pyproject_data, "project.optional-dependencies", {}).keys(), + ) + + # options for exporting + export_opts: dict = deep_get(pyproject_data, TOOL_PATH, {}) + + # what are we exporting? + exports: List[Dict[str, Any]] = export_opts.get("exports", []) + if not exports: + exports = [{"name": "all", "groups": [], "extras": [], "options": []}] + + # export each configuration + for export in exports: + export_configuration( + export=export, + all_groups=all_groups, + all_extras=all_extras, + export_opts=export_opts, + output_dir=output_dir, + ) + + +if __name__ == "__main__": + main( + pyproject_path=Path(sys.argv[1]), + output_dir=Path(sys.argv[2]), + ) + +endef + +export SCRIPT_EXPORT_REQUIREMENTS + + +# get the version from `pyproject.toml:project.version` +define SCRIPT_GET_VERSION +# source: https://github.com/mivanit/python-project-makefile-template/tree/main/scripts/make/get_version.py + +"write the current version of the project to a file" + +from __future__ import annotations + +import sys + +try: + try: + import tomllib # type: ignore[import-not-found] + except ImportError: + import tomli as tomllib # type: ignore + + pyproject_path: str = sys.argv[1].strip() + + with open(pyproject_path, "rb") as f: + pyproject_data: dict = tomllib.load(f) + + print("v" + pyproject_data["project"]["version"], end="") +except Exception: # noqa: BLE001 + print("NULL", end="") + sys.exit(1) + +endef + +export SCRIPT_GET_VERSION + + +# get the commit log since the last version from `$(LAST_VERSION_FILE)` +define SCRIPT_GET_COMMIT_LOG +# source: https://github.com/mivanit/python-project-makefile-template/tree/main/scripts/make/get_commit_log.py + +"pretty print a commit log amd wrote it to a file" + +from __future__ import annotations + +import subprocess +import sys +from typing import List + + +def main( + last_version: str, + commit_log_file: str, +) -> None: + "pretty print a commit log amd wrote it to a file" + if last_version == "NULL": + print("!!! 
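[editor's note on the commit log format above] The `--pretty=format:- %s (%h)` format yields one bullet per commit, newest first; main() then reverses them so the tag description reads oldest-to-newest. For example (abbreviated hashes are hypothetical, subjects taken from this patch series):

    commits = [
        "- more dbg tests (3f2a91c)",        # newest first, as git emits them
        "- format, fix one path (9d0453b)",
        "- wip (c812aaf)",
    ]
    print("\n".join(reversed(commits)))
    # - wip (c812aaf)
    # - format, fix one path (9d0453b)
    # - more dbg tests (3f2a91c)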
ERROR !!!", file=sys.stderr) + print("LAST_VERSION is NULL, can't get commit log!", file=sys.stderr) + sys.exit(1) + + try: + log_cmd: List[str] = [ + "git", + "log", + f"{last_version}..HEAD", + "--pretty=format:- %s (%h)", + ] + commits: List[str] = ( + subprocess.check_output(log_cmd).decode("utf-8").strip().split("\n") # noqa: S603 + ) + with open(commit_log_file, "w") as f: + f.write("\n".join(reversed(commits))) + except subprocess.CalledProcessError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main( + last_version=sys.argv[1].strip(), + commit_log_file=sys.argv[2].strip(), + ) + +endef + +export SCRIPT_GET_COMMIT_LOG + + +# get cuda information and whether torch sees it +define SCRIPT_CHECK_TORCH +# source: https://github.com/mivanit/python-project-makefile-template/tree/main/scripts/make/check_torch.py + +"print info about current python, torch, cuda, and devices" + +from __future__ import annotations + +import os +import re +import subprocess +import sys +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + + +def print_info_dict( + info: Dict[str, Union[Any, Dict[str, Any]]], + indent: str = " ", + level: int = 1, +) -> None: + "pretty print the info" + indent_str: str = indent * level + longest_key_len: int = max(map(len, info.keys())) + for key, value in info.items(): + if isinstance(value, dict): + print(f"{indent_str}{key:<{longest_key_len}}:") + print_info_dict(value, indent, level + 1) + else: + print(f"{indent_str}{key:<{longest_key_len}} = {value}") + + +def get_nvcc_info() -> Dict[str, str]: + "get info about cuda from nvcc --version" + # Run the nvcc command. + try: + result: subprocess.CompletedProcess[str] = subprocess.run( # noqa: S603 + ["nvcc", "--version"], # noqa: S607 + check=True, + capture_output=True, + text=True, + ) + except Exception as e: # noqa: BLE001 + return {"Failed to run 'nvcc --version'": str(e)} + + output: str = result.stdout + lines: List[str] = [line.strip() for line in output.splitlines() if line.strip()] + + # Ensure there are exactly 5 lines in the output. + assert len(lines) == 5, ( # noqa: PLR2004 + f"Expected exactly 5 lines from nvcc --version, got {len(lines)} lines:\n{output}" + ) + + # Compile shared regex for release info. 
+ release_regex: re.Pattern = re.compile( + r"Cuda compilation tools,\s*release\s*([^,]+),\s*(V.+)", + ) + + # Define a mapping for each desired field: + # key -> (line index, regex pattern, group index, transformation function) + patterns: Dict[str, Tuple[int, re.Pattern, int, Callable[[str], str]]] = { + "build_time": ( + 2, + re.compile(r"Built on (.+)"), + 1, + lambda s: s.replace("_", " "), + ), + "release": (3, release_regex, 1, str.strip), + "release_V": (3, release_regex, 2, str.strip), + "build": (4, re.compile(r"Build (.+)"), 1, str.strip), + } + + info: Dict[str, str] = {} + for key, (line_index, pattern, group_index, transform) in patterns.items(): + match: Optional[re.Match] = pattern.search(lines[line_index]) + if not match: + err_msg: str = ( + f"Unable to parse {key} from nvcc output: {lines[line_index]}" + ) + raise ValueError(err_msg) + info[key] = transform(match.group(group_index)) + + info["release_short"] = info["release"].replace(".", "").strip() + + return info + + +def get_torch_info() -> Tuple[List[Exception], Dict[str, Any]]: + "get info about pytorch and cuda devices" + exceptions: List[Exception] = [] + info: Dict[str, Any] = {} + + try: + import torch + except ImportError as e: + info["torch.__version__"] = "not available" + exceptions.append(e) + return exceptions, info + + try: + info["torch.__version__"] = torch.__version__ + info["torch.cuda.is_available()"] = torch.cuda.is_available() + + if torch.cuda.is_available(): + info["torch.version.cuda"] = torch.version.cuda + info["torch.cuda.device_count()"] = torch.cuda.device_count() + + if torch.cuda.device_count() > 0: + info["torch.cuda.current_device()"] = torch.cuda.current_device() + n_devices: int = torch.cuda.device_count() + info["n_devices"] = n_devices + for current_device in range(n_devices): + try: + current_device_info: Dict[str, Union[str, int]] = {} + + dev_prop = torch.cuda.get_device_properties( + torch.device(f"cuda:{current_device}"), + ) + + current_device_info["name"] = dev_prop.name + current_device_info["version"] = ( + f"{dev_prop.major}.{dev_prop.minor}" + ) + current_device_info["total_memory"] = ( + f"{dev_prop.total_memory} ({dev_prop.total_memory:.1e})" + ) + current_device_info["multi_processor_count"] = ( + dev_prop.multi_processor_count + ) + current_device_info["is_integrated"] = dev_prop.is_integrated + current_device_info["is_multi_gpu_board"] = ( + dev_prop.is_multi_gpu_board + ) + + info[f"device cuda:{current_device}"] = current_device_info + + except Exception as e: # noqa: PERF203,BLE001 + exceptions.append(e) + else: + err_msg_nodevice: str = ( + f"{torch.cuda.device_count() = } devices detected, invalid" + ) + raise ValueError(err_msg_nodevice) # noqa: TRY301 + + else: + err_msg_nocuda: str = ( + f"CUDA is NOT available in torch: {torch.cuda.is_available() = }" + ) + raise ValueError(err_msg_nocuda) # noqa: TRY301 + + except Exception as e: # noqa: BLE001 + exceptions.append(e) + + return exceptions, info + + +if __name__ == "__main__": + print(f"python: {sys.version}") + print_info_dict( + { + "python executable path: sys.executable": str(sys.executable), + "sys.platform": sys.platform, + "current working directory: os.getcwd()": os.getcwd(), # noqa: PTH109 + "Host name: os.name": os.name, + "CPU count: os.cpu_count()": str(os.cpu_count()), + }, + ) + + nvcc_info: Dict[str, Any] = get_nvcc_info() + print("nvcc:") + print_info_dict(nvcc_info) + + torch_exceptions, torch_info = get_torch_info() + print("torch:") + print_info_dict(torch_info) + + if torch_exceptions: + 
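[editor's note on print_info_dict above] The printer pads every key at a given nesting level to the longest key at that level, so values line up in columns. A sketch with invented values, assuming print_info_dict is in scope (whitespace in the expected output is approximate):

    print_info_dict({
        "torch.__version__": "2.4.0",
        "torch.cuda.is_available()": True,
        "device cuda:0": {
            "name": "NVIDIA A100",
            "multi_processor_count": 108,
        },
    })
    #  torch.__version__          = 2.4.0
    #  torch.cuda.is_available()  = True
    #  device cuda:0              :
    #    name                  = NVIDIA A100
    #    multi_processor_count = 108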
print("torch_exceptions:") + for e in torch_exceptions: + print(f" {e}") + +endef + +export SCRIPT_CHECK_TORCH + + +# get todo's from the code +define SCRIPT_GET_TODOS +# source: https://github.com/mivanit/python-project-makefile-template/tree/main/scripts/make/get_todos.py + +"read all TODO type comments and write them to markdown, jsonl, html. configurable in pyproject.toml" + +from __future__ import annotations + +import argparse +import fnmatch +import json +import textwrap +import urllib.parse +import warnings +from dataclasses import asdict, dataclass, field +from functools import reduce +from pathlib import Path +from typing import Any, Dict, List, Union + +from jinja2 import Template + +try: + import tomllib # type: ignore[import-not-found] +except ImportError: + import tomli as tomllib # type: ignore + +TOOL_PATH: str = "tool.makefile.inline-todo" + + +def deep_get(d: dict, path: str, default: Any = None, sep: str = ".") -> Any: # noqa: ANN401 + "get a value from a nested dictionary" + return reduce( + lambda x, y: x.get(y, default) if isinstance(x, dict) else default, # function + path.split(sep) if isinstance(path, str) else path, # sequence + d, # initial + ) + + +_TEMPLATE_MD_LIST: str = """\ +# Inline TODOs + +{% for tag, file_map in grouped|dictsort %} +# {{ tag }} +{% for filepath, item_list in file_map|dictsort %} +## [`{{ filepath }}`](/{{ filepath }}) +{% for itm in item_list %} +- {{ itm.stripped_title }} + local link: [`/{{ filepath }}:{{ itm.line_num }}`](/{{ filepath }}#L{{ itm.line_num }}) + | view on GitHub: [{{ itm.file }}#L{{ itm.line_num }}]({{ itm.code_url | safe }}) + | [Make Issue]({{ itm.issue_url | safe }}) +{% if itm.context %} + ```{{ itm.file_lang }} +{{ itm.context_indented }} + ``` +{% endif %} +{% endfor %} + +{% endfor %} +{% endfor %} +""" + +_TEMPLATE_MD_TABLE: str = """\ +# Inline TODOs + +| Location | Tag | Todo | GitHub | Issue | +|:---------|:----|:-----|:-------|:------| +{% for itm in all_items %}| [`{{ itm.file }}:{{ itm.line_num }}`](/{{ itm.file }}#L{{ itm.line_num }}) | {{ itm.tag }} | {{ itm.stripped_title_escaped }} | [View]({{ itm.code_url | safe }}) | [Create]({{ itm.issue_url | safe }}) | +{% endfor %} +""" + +TEMPLATES_MD: Dict[str, str] = dict( + standard=_TEMPLATE_MD_LIST, + table=_TEMPLATE_MD_TABLE, +) + +TEMPLATE_ISSUE: str = """\ +# source + +[`{file}#L{line_num}`]({code_url}) + +# context +```{file_lang} +{context} +``` +""" + + +@dataclass +class Config: + """Configuration for the inline-todo scraper""" + + search_dir: Path = Path() + out_file_base: Path = Path("docs/todo-inline") + tags: List[str] = field( + default_factory=lambda: ["CRIT", "TODO", "FIXME", "HACK", "BUG"], + ) + extensions: List[str] = field(default_factory=lambda: ["py", "md"]) + exclude: List[str] = field(default_factory=lambda: ["docs/**", ".venv/**"]) + context_lines: int = 2 + valid_post_tag: Union[str, List[str]] = " \t:<>|[](){{}}" + valid_pre_tag: Union[str, List[str]] = " \t:<>|[](){{}}#" + tag_label_map: Dict[str, str] = field( + default_factory=lambda: { + "CRIT": "bug", + "TODO": "enhancement", + "FIXME": "bug", + "BUG": "bug", + "HACK": "enhancement", + }, + ) + extension_lang_map: Dict[str, str] = field( + default_factory=lambda: { + "py": "python", + "md": "markdown", + "html": "html", + "css": "css", + "js": "javascript", + }, + ) + + templates_md: dict[str, str] = field(default_factory=lambda: TEMPLATES_MD) + # templates for the output markdown file + + template_issue: str = TEMPLATE_ISSUE + # template for the issue creation + + 
template_html_source: Path = Path("docs/resources/templates/todo-template.html") + # template source for the output html file (interactive table) + + @property + def template_html(self) -> str: + "read the html template" + return self.template_html_source.read_text(encoding="utf-8") + + template_code_url_: str = "{repo_url}/blob/{branch}/{file}#L{line_num}" + # template for the code url + + @property + def template_code_url(self) -> str: + "code url with repo url and branch substituted" + return self.template_code_url_.replace("{repo_url}", self.repo_url).replace( + "{branch}", + self.branch, + ) + + repo_url: str = "UNKNOWN" + # for the issue creation url + + branch: str = "main" + # branch for links to files on github + + @classmethod + def read(cls, config_file: Path) -> Config: + "read from a file, or return default" + output: Config + if config_file.is_file(): + # read file and load if present + with config_file.open("rb") as f: + data: Dict[str, Any] = tomllib.load(f) + + # try to get the repo url + repo_url: str = "UNKNOWN" + try: + urls: Dict[str, str] = { + k.lower(): v for k, v in data["project"]["urls"].items() + } + if "repository" in urls: + repo_url = urls["repository"] + if "github" in urls: + repo_url = urls["github"] + except Exception as e: # noqa: BLE001 + warnings.warn( + f"No repository URL found in pyproject.toml, 'make issue' links will not work.\n{e}", + ) + + # load the inline-todo config if present + data_inline_todo: Dict[str, Any] = deep_get( + d=data, + path=TOOL_PATH, + default={}, + ) + + if "repo_url" not in data_inline_todo: + data_inline_todo["repo_url"] = repo_url + + output = cls.load(data_inline_todo) + else: + # return default otherwise + output = cls() + + return output + + @classmethod + def load(cls, data: dict) -> Config: + "load from a dictionary, converting to `Path` as needed" + # process variables that should be paths + data = { + k: Path(v) + if k in {"search_dir", "out_file_base", "template_html_source"} + else v + for k, v in data.items() + } + + # default value for the templates + data["templates_md"] = { + **TEMPLATES_MD, + **data.get("templates_md", {}), + } + + return cls(**data) + + +CFG: Config = Config() +# this is messy, but we use a global config so we can get `TodoItem().issue_url` to work + + +@dataclass +class TodoItem: + """Holds one todo occurrence""" + + tag: str + file: str + line_num: int + content: str + context: str = "" + + def serialize(self) -> Dict[str, Union[str, int]]: + "serialize to a dict we can dump to json" + return { + **asdict(self), + "issue_url": self.issue_url, + "file_lang": self.file_lang, + "stripped_title": self.stripped_title, + "code_url": self.code_url, + } + + @property + def context_indented(self) -> str: + """Returns the context with each line indented""" + dedented: str = textwrap.dedent(self.context) + return textwrap.indent(dedented, " ") + + @property + def code_url(self) -> str: + """Returns a URL to the code on GitHub""" + return CFG.template_code_url.format( + file=self.file, + line_num=self.line_num, + ) + + @property + def stripped_title(self) -> str: + """Returns the title of the issue, stripped of the tag""" + return self.content.split(self.tag, 1)[-1].lstrip(":").strip() + + @property + def stripped_title_escaped(self) -> str: + """Returns the title of the issue, stripped of the tag and escaped for markdown""" + return self.stripped_title.replace("|", "\\|") + + @property + def issue_url(self) -> str: + """Constructs a GitHub issue creation URL for a given TodoItem.""" + # title + 
title: str = self.stripped_title + if not title: + title = "Issue from inline todo" + # body + body: str = CFG.template_issue.format( + file=self.file, + line_num=self.line_num, + context=self.context, + context_indented=self.context_indented, + code_url=self.code_url, + file_lang=self.file_lang, + ).strip() + # labels + label: str = CFG.tag_label_map.get(self.tag, self.tag) + # assemble url + query: Dict[str, str] = dict(title=title, body=body, labels=label) + query_string: str = urllib.parse.urlencode(query, quote_via=urllib.parse.quote) + return f"{CFG.repo_url}/issues/new?{query_string}" + + @property + def file_lang(self) -> str: + """Returns the language for the file extension""" + ext: str = Path(self.file).suffix.lstrip(".") + return CFG.extension_lang_map.get(ext, ext) + + +def scrape_file( + file_path: Path, + cfg: Config, +) -> List[TodoItem]: + """Scrapes a file for lines containing any of the specified tags""" + items: List[TodoItem] = [] + if not file_path.is_file(): + return items + lines: List[str] = file_path.read_text(encoding="utf-8").splitlines(True) + + # over all lines + for i, line in enumerate(lines): + # over all tags + for tag in cfg.tags: + # check tag is present + if tag in line[:200]: + # check tag is surrounded by valid strings + tag_idx_start: int = line.index(tag) + tag_idx_end: int = tag_idx_start + len(tag) + if ( + line[tag_idx_start - 1] in cfg.valid_pre_tag + and line[tag_idx_end] in cfg.valid_post_tag + ): + # get the context and add the item + start: int = max(0, i - cfg.context_lines) + end: int = min(len(lines), i + cfg.context_lines + 1) + snippet: str = "".join(lines[start:end]) + items.append( + TodoItem( + tag=tag, + file=file_path.as_posix(), + line_num=i + 1, + content=line.strip("\n"), + context=snippet.strip("\n"), + ), + ) + break + return items + + +def collect_files( + search_dir: Path, + extensions: List[str], + exclude: List[str], +) -> List[Path]: + """Recursively collects all files with specified extensions, excluding matches via globs""" + results: List[Path] = [] + for ext in extensions: + results.extend(search_dir.rglob(f"*.{ext}")) + + return [ + f + for f in results + if not any(fnmatch.fnmatch(f.as_posix(), pattern) for pattern in exclude) + ] + + +def group_items_by_tag_and_file( + items: List[TodoItem], +) -> Dict[str, Dict[str, List[TodoItem]]]: + """Groups items by tag, then by file""" + grouped: Dict[str, Dict[str, List[TodoItem]]] = {} + for itm in items: + grouped.setdefault(itm.tag, {}).setdefault(itm.file, []).append(itm) + for tag_dict in grouped.values(): + for file_list in tag_dict.values(): + file_list.sort(key=lambda x: x.line_num) + return grouped + + +def main(config_file: Path) -> None: + "cli interface to get todos" + global CFG # noqa: PLW0603 + # read configuration + cfg: Config = Config.read(config_file) + CFG = cfg + + # get data + files: List[Path] = collect_files(cfg.search_dir, cfg.extensions, cfg.exclude) + all_items: List[TodoItem] = [] + n_files: int = len(files) + for i, fpath in enumerate(files): + print(f"Scraping {i + 1:>2}/{n_files:>2}: {fpath.as_posix():<60}", end="\r") + all_items.extend(scrape_file(fpath, cfg)) + + # create dir + cfg.out_file_base.parent.mkdir(parents=True, exist_ok=True) + + # write raw to jsonl + with open(cfg.out_file_base.with_suffix(".jsonl"), "w", encoding="utf-8") as f: + for itm in all_items: + f.write(json.dumps(itm.serialize()) + "\n") + + # group, render + grouped: Dict[str, Dict[str, List[TodoItem]]] = group_items_by_tag_and_file( + all_items, + ) + + # render each 
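[editor's note on scrape_file above] Tag detection requires the tag to be fenced by characters from valid_pre_tag and valid_post_tag, so `# TODO: fix` matches while an incidental `MASTODON` (which contains the substring TODO) does not. A minimal end-to-end sketch, assuming the scrape_file and Config definitions above are in scope; the scratch file name is invented:

    from pathlib import Path

    demo = Path("demo_todo.py")
    demo.write_text("x = 1\n# TODO: refactor this\ny = 2\n", encoding="utf-8")
    items = scrape_file(demo, Config(context_lines=1))
    print(items[0].tag, items[0].line_num, items[0].stripped_title)
    # TODO 2 refactor this
    demo.unlink()  # clean up the scratch file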
template and save + for template_key, template in cfg.templates_md.items(): + rendered: str = Template(template).render(grouped=grouped, all_items=all_items) + template_out_path: Path = Path( + cfg.out_file_base.with_stem( + cfg.out_file_base.stem + f"-{template_key}", + ).with_suffix(".md"), + ) + template_out_path.write_text(rendered, encoding="utf-8") + + # write html output + try: + html_rendered: str = cfg.template_html.replace( + "//{{DATA}}//", + json.dumps([itm.serialize() for itm in all_items]), + ) + cfg.out_file_base.with_suffix(".html").write_text( + html_rendered, + encoding="utf-8", + ) + except Exception as e: # noqa: BLE001 + warnings.warn(f"Failed to write html output: {e}") + + print("wrote to:") + print(cfg.out_file_base.with_suffix(".md").as_posix()) + + +if __name__ == "__main__": + # parse args + parser: argparse.ArgumentParser = argparse.ArgumentParser("inline_todo") + parser.add_argument( + "--config-file", + default="pyproject.toml", + help="Path to the TOML config, will look under [tool.inline-todo].", + ) + args: argparse.Namespace = parser.parse_args() + # call main + main(Path(args.config_file)) + +endef + +export SCRIPT_GET_TODOS + + +# markdown to html using pdoc +define SCRIPT_PDOC_MARKDOWN2_CLI +# source: https://github.com/mivanit/python-project-makefile-template/tree/main/scripts/make/pdoc_markdown2_cli.py + +"cli to convert markdown files to HTML using pdoc's markdown2" + +from __future__ import annotations + +import argparse +from pathlib import Path +from typing import Optional + +from pdoc.markdown2 import Markdown, _safe_mode # type: ignore + + +def convert_file( + input_path: Path, + output_path: Path, + safe_mode: Optional[_safe_mode] = None, + encoding: str = "utf-8", +) -> None: + """Convert a markdown file to HTML""" + # Read markdown input + text: str = input_path.read_text(encoding=encoding) + + # Convert to HTML using markdown2 + markdown: Markdown = Markdown( + extras=["fenced-code-blocks", "header-ids", "markdown-in-html", "tables"], + safe_mode=safe_mode, + ) + html: str = markdown.convert(text) + + # Write HTML output + output_path.write_text(str(html), encoding=encoding) + + +def main() -> None: + "cli entry point" + parser: argparse.ArgumentParser = argparse.ArgumentParser( + description="Convert markdown files to HTML using pdoc's markdown2", + ) + parser.add_argument("input", type=Path, help="Input markdown file path") + parser.add_argument("output", type=Path, help="Output HTML file path") + parser.add_argument( + "--safe-mode", + choices=["escape", "replace"], + help="Sanitize literal HTML: 'escape' escapes HTML meta chars, 'replace' replaces with [HTML_REMOVED]", + ) + parser.add_argument( + "--encoding", + default="utf-8", + help="Character encoding for reading/writing files (default: utf-8)", + ) + + args: argparse.Namespace = parser.parse_args() + + convert_file( + args.input, + args.output, + safe_mode=args.safe_mode, + encoding=args.encoding, + ) + + +if __name__ == "__main__": + main() + +endef + +export SCRIPT_PDOC_MARKDOWN2_CLI + +# clean up the docs (configurable in pyproject.toml) +define SCRIPT_DOCS_CLEAN +# source: https://github.com/mivanit/python-project-makefile-template/tree/main/scripts/make/docs_clean.py + +"clean up docs directory based on pyproject.toml configuration" + +from __future__ import annotations + +import shutil +import sys +from functools import reduce +from pathlib import Path +from typing import Any, List, Set + +try: + import tomllib # type: ignore[import-not-found] +except ImportError: + import 
tomli as tomllib # type: ignore + +TOOL_PATH: str = "tool.makefile.docs" +DEFAULT_DOCS_DIR: str = "docs" + + +def deep_get(d: dict, path: str, default: Any = None, sep: str = ".") -> Any: # noqa: ANN401 + """Get nested dictionary value via separated path with default.""" + return reduce( + lambda x, y: x.get(y, default) if isinstance(x, dict) else default, # function + path.split(sep) if isinstance(path, str) else path, # sequence + d, # initial + ) + + +def read_config(pyproject_path: Path) -> tuple[Path, Set[Path]]: + "read configuration from pyproject.toml" + if not pyproject_path.is_file(): + return Path(DEFAULT_DOCS_DIR), set() + + with pyproject_path.open("rb") as f: + config = tomllib.load(f) + + preserved: List[str] = deep_get(config, f"{TOOL_PATH}.no_clean", []) + docs_dir: Path = Path(deep_get(config, f"{TOOL_PATH}.output_dir", DEFAULT_DOCS_DIR)) + + # Convert to absolute paths and validate + preserve_set: Set[Path] = set() + for p in preserved: + full_path = (docs_dir / p).resolve() + if not full_path.as_posix().startswith(docs_dir.resolve().as_posix()): + err_msg: str = f"Preserved path '{p}' must be within docs directory" + raise ValueError(err_msg) + preserve_set.add(docs_dir / p) + + return docs_dir, preserve_set + + +def clean_docs(docs_dir: Path, preserved: Set[Path]) -> None: + """delete files not in preserved set + + TODO: this is not recursive + """ + for path in docs_dir.iterdir(): + if path.is_file() and path not in preserved: + path.unlink() + elif path.is_dir() and path not in preserved: + shutil.rmtree(path) + + +def main( + pyproject_path: str, + docs_dir_cli: str, + extra_preserve: list[str], +) -> None: + "Clean up docs directory based on pyproject.toml configuration." + docs_dir: Path + preserved: Set[Path] + docs_dir, preserved = read_config(Path(pyproject_path)) + + assert docs_dir.is_dir(), f"Docs directory '{docs_dir}' not found" + assert docs_dir == Path(docs_dir_cli), ( + f"Docs directory mismatch: {docs_dir = } != {docs_dir_cli = }. this is probably because you changed one of `pyproject.toml:{TOOL_PATH}.output_dir` (the former) or `makefile:DOCS_DIR` (the latter) without updating the other." + ) + + for x in extra_preserve: + preserved.add(Path(x)) + clean_docs(docs_dir, preserved) + + +if __name__ == "__main__": + main(sys.argv[1], sys.argv[2], sys.argv[3:]) + +endef + +export SCRIPT_DOCS_CLEAN + +# generate a report of the mypy output +define SCRIPT_MYPY_REPORT +# source: https://github.com/mivanit/python-project-makefile-template/tree/main/scripts/make/mypy_report.py + +"usage: mypy ... 
| mypy_report.py [--mode jsonl|toml]"

from __future__ import annotations

import argparse
import json
import re
import sys
from pathlib import Path
from typing import Dict, List, Tuple


def parse_mypy_output(lines: List[str]) -> Dict[str, int]:
    "given mypy output, turn it into a dict of `filename: error_count`"
    pattern: re.Pattern[str] = re.compile(r"^(?P<file>[^:]+):\d+:\s+error:")
    counts: Dict[str, int] = {}
    for line in lines:
        m = pattern.match(line)
        if m:
            f_raw: str = m.group("file")
            f_norm: str = Path(f_raw).as_posix()
            counts[f_norm] = counts.get(f_norm, 0) + 1
    return counts


def main() -> None:
    "cli interface for mypy_report"
    parser: argparse.ArgumentParser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["jsonl", "toml"], default="jsonl")
    args: argparse.Namespace = parser.parse_args()
    lines: List[str] = sys.stdin.read().splitlines()
    error_dict: Dict[str, int] = parse_mypy_output(lines)
    sorted_errors: List[Tuple[str, int]] = sorted(
        error_dict.items(),
        key=lambda x: x[1],
    )
    if len(sorted_errors) == 0:
        print("# no errors found!")
        return
    if args.mode == "jsonl":
        for fname, count in sorted_errors:
            print(json.dumps({"filename": fname, "errors": count}))
    elif args.mode == "toml":
        for fname, count in sorted_errors:
            print(f'"{fname}", # {count}')
    else:
        err_msg: str = f"unknown mode {args.mode}"
        raise ValueError(err_msg)
    print(f"# total errors: {sum(error_dict.values())}")


if __name__ == "__main__":
    main()

endef

export SCRIPT_MYPY_REPORT
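
# usage sketch for the report script above (illustrative, not part of the template):
# pipe mypy output into it, which is exactly what the `typing-report` target below does:
#   $(PYTHON) -m mypy --config-file $(PYPROJECT) . | $(PYTHON) -c "$$SCRIPT_MYPY_REPORT" --mode toml
# in `toml` mode this prints one line per file, sorted by error count, roughly like:
#   "muutils/example.py", # 3
#   # total errors: 3
# (the filename and count above are made up for illustration)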


## ## ######## ######## ###### #### ####### ## ##
## ## ## ## ## ## ## ## ## ## ### ##
## ## ## ## ## ## ## ## ## #### ##
## ## ###### ######## ###### ## ## ## ## ## ##
 ## ## ## ## ## ## ## ## ## ## ####
 ## ## ## ## ## ## ## ## ## ## ## ###
 ### ######## ## ## ###### #### ####### ## ##

# ==================================================
# getting version info
# we do this in a separate target because it takes a bit of time
# ==================================================

# this recipe is weird. we need it because:
# - a one liner for getting the version with toml is unwieldy, and using regex is fragile
# - using $$SCRIPT_GET_VERSION within $(shell ...) doesn't work because of escaping issues
# - trying to write to the file inside the `gen-version-info` recipe doesn't work,
#   shell eval happens before our `python -c ...` gets run and `cat` doesn't see the new file
.PHONY: write-proj-version
write-proj-version:
	@mkdir -p $(VERSIONS_DIR)
	@$(PYTHON) -c "$$SCRIPT_GET_VERSION" "$(PYPROJECT)" > $(VERSION_FILE)

# gets version info from $(PYPROJECT), last version from $(LAST_VERSION_FILE), and python version
# uses just `python` for everything except getting the python version. no echo here, because this is "private"
.PHONY: gen-version-info
gen-version-info: write-proj-version
	@mkdir -p $(LOCAL_DIR)
	$(eval PROJ_VERSION := $(shell cat $(VERSION_FILE)) )
	$(eval LAST_VERSION := $(shell [ -f $(LAST_VERSION_FILE) ] && cat $(LAST_VERSION_FILE) || echo NULL) )
	$(eval PYTHON_VERSION := $(shell $(PYTHON) -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}')") )

# getting commit log since the tag specified in $(LAST_VERSION_FILE)
# will write to $(COMMIT_LOG_FILE)
# when publishing, the contents of $(COMMIT_LOG_FILE) will be used as the tag description (but can be edited during the process)
# no echo here, because this is "private"
.PHONY: gen-commit-log
gen-commit-log: gen-version-info
	@if [ "$(LAST_VERSION)" = "NULL" ]; then \
		echo "!!! ERROR !!!"; \
		echo "LAST_VERSION is NULL, can't get commit log!"; \
		exit 1; \
	fi
	@mkdir -p $(LOCAL_DIR)
	@$(PYTHON) -c "$$SCRIPT_GET_COMMIT_LOG" "$(LAST_VERSION)" "$(COMMIT_LOG_FILE)"


# force the version info to be read, printing it out
# also force the commit log to be generated, and cat it out
.PHONY: version
version: gen-commit-log
	@echo "Current version is $(PROJ_VERSION), last auto-uploaded version is $(LAST_VERSION)"
	@echo "Commit log since last version from '$(COMMIT_LOG_FILE)':"
	@cat $(COMMIT_LOG_FILE)
	@echo ""
	@if [ "$(PROJ_VERSION)" = "$(LAST_VERSION)" ]; then \
		echo "!!! ERROR !!!"; \
		echo "Python package $(PROJ_VERSION) is the same as last published version $(LAST_VERSION), exiting!"; \
		exit 1; \
	fi


######## ######## ######## ######
## ## ## ## ## ## ##
## ## ## ## ## ##
## ## ###### ######## ######
## ## ## ## ##
## ## ## ## ## ##
######## ######## ## ######

# ==================================================
# dependencies and setup
# ==================================================

.PHONY: setup
setup: dep-check
	@echo "install and update via uv"
	@echo "To activate the virtual environment, run one of:"
	@echo "  source .venv/bin/activate"
	@echo "  source .venv/Scripts/activate"

.PHONY: dep-check-torch
dep-check-torch:
	@echo "see if torch is installed, and which CUDA version and devices it sees"
	$(PYTHON) -c "$$SCRIPT_CHECK_TORCH"

.PHONY: dep
dep:
	@echo "Exporting dependencies as per $(PYPROJECT) section 'tool.uv-exports.exports'"
	uv sync --all-extras --all-groups --compile-bytecode
	mkdir -p $(REQUIREMENTS_DIR)
	$(PYTHON) -c "$$SCRIPT_EXPORT_REQUIREMENTS" $(PYPROJECT) $(REQUIREMENTS_DIR) | sh -x


.PHONY: dep-check
dep-check:
	@echo "Checking that exported requirements are up to date"
	uv sync --all-extras --all-groups
	mkdir -p $(REQUIREMENTS_DIR)-TEMP
	$(PYTHON) -c "$$SCRIPT_EXPORT_REQUIREMENTS" $(PYPROJECT) $(REQUIREMENTS_DIR)-TEMP | sh -x
	diff -r $(REQUIREMENTS_DIR)-TEMP $(REQUIREMENTS_DIR)
	rm -rf $(REQUIREMENTS_DIR)-TEMP


.PHONY: dep-clean
dep-clean:
	@echo "clean up lock files, .venv, and requirements files"
	rm -rf .venv
	rm -rf uv.lock
	rm -rf $(REQUIREMENTS_DIR)/*.txt


 ###### ## ## ######## ###### ## ## ######
## ## ## ## ## ## ## ## ## ## ##
## ## ## ## ## ## ## ##
## ######### ###### ## ##### ######
## ## ## ## ## ## ## ##
## ## ## ## ## ## ## ## ## ## ##
 ###### ## ## ######## ###### ## ## ######

# ==================================================
# checks (formatting/linting, typing, tests)
# ==================================================

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# added gen-extra-tests and 
it is required by some other recipes: +# format-check, typing, test + +# extra tests with python >=3.10 type hints +.PHONY: gen-extra-tests +gen-extra-tests: + if [ $(COMPATIBILITY_MODE) -eq 0 ]; then \ + echo "converting certain tests to modern format"; \ + $(PYTHON) tests/util/replace_type_hints.py tests/unit/validate_type/test_validate_type.py "# DO NOT EDIT, GENERATED FILE" > tests/unit/validate_type/test_validate_type_GENERATED.py; \ + fi; \ + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +# runs ruff and pycln to format the code +.PHONY: format +format: + @echo "format the source code" + $(PYTHON) -m ruff format --config $(PYPROJECT) . + $(PYTHON) -m ruff check --fix --config $(PYPROJECT) . + $(PYTHON) -m pycln --config $(PYPROJECT) --all . + +# runs ruff and pycln to check if the code is formatted correctly +.PHONY: format-check +format-check: + @echo "check if the source code is formatted correctly" + $(PYTHON) -m ruff check --config $(PYPROJECT) . + $(PYTHON) -m pycln --check --config $(PYPROJECT) . + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# runs type checks with mypy +# at some point, need to add back --check-untyped-defs to mypy call +# but it complains when we specify arguments by keyword where positional is fine +# not sure how to fix this +.PHONY: typing +typing: gen-extra-tests + @echo "running type checks" + $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_ARGS) . +# $(PYTHON) -m ty check muutils/ + + +# generates a report of the mypy output +.PHONY: typing-report +typing-report: clean gen-extra-tests + @echo "generate a report of the type check output -- errors per file" + $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_ARGS) . | $(PYTHON) -c "$$SCRIPT_MYPY_REPORT" --mode toml + +.PHONY: test +test: clean gen-extra-tests + @echo "running tests" + $(PYTHON) -m pytest $(PYTEST_OPTIONS) $(TESTS_DIR) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.PHONY: check +check: clean format-check test typing + @echo "run format checks, tests, and typing checks" + + +######## ####### ###### ###### +## ## ## ## ## ## ## ## +## ## ## ## ## ## +## ## ## ## ## ###### +## ## ## ## ## ## +## ## ## ## ## ## ## ## +######## ####### ###### ###### + +# ================================================== +# coverage & docs +# ================================================== + +# generates a whole tree of documentation in html format. +# see `$(MAKE_DOCS_SCRIPT_PATH)` and the templates in `$(DOCS_RESOURCES_DIR)/templates/html/` for more info +.PHONY: docs-html +docs-html: + @echo "generate html docs" + $(PYTHON) $(MAKE_DOCS_SCRIPT_PATH) + +# instead of a whole website, generates a single markdown file with all docs using the templates in `$(DOCS_RESOURCES_DIR)/templates/markdown/`. +# this is useful if you want to have a copy that you can grep/search, but those docs are much messier. +# docs-combined will use pandoc to convert them to other formats. 
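# as a sketch of extending the conversions described above: pdf output (noted below as
# needing extra deps) would need a LaTeX engine for pandoc; a hypothetical target could
# look like this, commented out since `docs-pdf` is illustrative and not part of the template:
# .PHONY: docs-pdf
# docs-pdf: docs-md
# 	$(PANDOC) -f markdown $(DOCS_DIR)/combined/$(PACKAGE_NAME).md -o $(DOCS_DIR)/combined/$(PACKAGE_NAME).pdf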
.PHONY: docs-md
docs-md:
	@echo "generate combined (single-file) docs in markdown"
	mkdir $(DOCS_DIR)/combined -p
	$(PYTHON) $(MAKE_DOCS_SCRIPT_PATH) --combined

# after running docs-md, this will convert the combined markdown file to other formats:
# gfm (github-flavored markdown), plain text, and html
# requires pandoc in path, pointed to by $(PANDOC)
# pdf output would be nice but requires other deps
.PHONY: docs-combined
docs-combined: docs-md
	@echo "generate combined (single-file) docs in markdown and convert to other formats"
	@echo "requires pandoc in path"
	$(PANDOC) -f markdown -t gfm $(DOCS_DIR)/combined/$(PACKAGE_NAME).md -o $(DOCS_DIR)/combined/$(PACKAGE_NAME)_gfm.md
	$(PANDOC) -f markdown -t plain $(DOCS_DIR)/combined/$(PACKAGE_NAME).md -o $(DOCS_DIR)/combined/$(PACKAGE_NAME).txt
	$(PANDOC) -f markdown -t html $(DOCS_DIR)/combined/$(PACKAGE_NAME).md -o $(DOCS_DIR)/combined/$(PACKAGE_NAME).html

# generates coverage reports as html and text with `pytest-cov`, and a badge with `coverage-badge`
# if `.coverage` is not found, will run tests first
# also removes the `.gitignore` file that `coverage html` creates, since we count that as part of the docs
.PHONY: cov
cov:
	@echo "generate coverage reports"
	@if [ ! -f .coverage ]; then \
		echo ".coverage not found, running tests first..."; \
		$(MAKE) test; \
	fi
	mkdir $(COVERAGE_REPORTS_DIR) -p
	$(PYTHON) -m coverage report -m > $(COVERAGE_REPORTS_DIR)/coverage.txt
	$(PYTHON) -m coverage_badge -f -o $(COVERAGE_REPORTS_DIR)/coverage.svg
	$(PYTHON) -m coverage html --directory=$(COVERAGE_REPORTS_DIR)/html/
	rm -rf $(COVERAGE_REPORTS_DIR)/html/.gitignore

# runs the coverage report, then the html docs, the combined docs, and the todo/lmcat outputs
.PHONY: docs
docs: cov docs-html docs-combined todo lmcat
	@echo "generate all documentation and coverage reports"

# removes all generated documentation files, but leaves everything in `$(DOCS_RESOURCES_DIR)`
# and leaves things defined in `pyproject.toml:tool.makefile.docs.no_clean`
# (templates, svg, css, make_docs.py script)
# distinct from `make clean`
.PHONY: docs-clean
docs-clean:
	@echo "remove generated docs except resources"
	$(PYTHON) -c "$$SCRIPT_DOCS_CLEAN" $(PYPROJECT) $(DOCS_DIR) $(DOCS_RESOURCES_DIR)

.PHONY: todo
todo:
	@echo "get all TODO's from the code"
	$(PYTHON) -c "$$SCRIPT_GET_TODOS"

.PHONY: lmcat-tree
lmcat-tree:
	@echo "show in console the lmcat tree view"
	-$(PYTHON) -m lmcat -t --output STDOUT

.PHONY: lmcat
lmcat:
	@echo "write the lmcat full output to pyproject.toml:[tool.lmcat.output]"
	-$(PYTHON) -m lmcat
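
# for reference, the `todo` target's scraper is configured via `pyproject.toml:[tool.makefile.inline-todo]`;
# a minimal config sketch (values shown are the script's own defaults, all fields optional):
#   [tool.makefile.inline-todo]
#   search_dir = "."
#   out_file_base = "docs/todo-inline"
#   tags = ["CRIT", "TODO", "FIXME", "HACK", "BUG"]
#   exclude = ["docs/**", ".venv/**"]
#   context_lines = 2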
ERROR !!!"; \ + echo "Git is not clean, exiting!"; \ + git status; \ + exit 1; \ + fi; \ + + +.PHONY: build +build: + @echo "build the package" + uv build + +# gets the commit log, checks everything, builds, and then publishes with twine +# will ask the user to confirm the new version number (and this allows for editing the tag info) +# will also print the contents of $(PYPI_TOKEN_FILE) to the console for the user to copy and paste in when prompted by twine +.PHONY: publish +publish: gen-commit-log check build verify-git version gen-version-info + @echo "run all checks, build, and then publish" + + @echo "Enter the new version number if you want to upload to pypi and create a new tag" + @echo "Now would also be the time to edit $(COMMIT_LOG_FILE), as that will be used as the tag description" + @read -p "Confirm: " NEW_VERSION; \ + if [ "$$NEW_VERSION" = $(PROJ_VERSION) ]; then \ + echo "!!! ERROR !!!"; \ + echo "Version confirmed. Proceeding with publish."; \ + else \ + echo "Version mismatch, exiting: you gave $$NEW_VERSION but expected $(PROJ_VERSION)"; \ + exit 1; \ + fi; + + @echo "pypi username: __token__" + @echo "pypi token from '$(PYPI_TOKEN_FILE)' :" + echo $$(cat $(PYPI_TOKEN_FILE)) + + echo "Uploading!"; \ + echo $(PROJ_VERSION) > $(LAST_VERSION_FILE); \ + git add $(LAST_VERSION_FILE); \ + git commit -m "Auto update to $(PROJ_VERSION)"; \ + git tag -a $(PROJ_VERSION) -F $(COMMIT_LOG_FILE); \ + git push origin $(PROJ_VERSION); \ + twine upload dist/* --verbose + +# ================================================== +# cleanup of temp files +# ================================================== + +# cleans up temp files from formatter, type checking, tests, coverage +# removes all built files +# removes $(TESTS_TEMP_DIR) to remove temporary test files +# recursively removes all `__pycache__` directories and `*.pyc` or `*.pyo` files +# distinct from `make docs-clean`, which only removes generated documentation files +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# slight modification in last line for extra tests +.PHONY: clean +clean: + @echo "clean up temporary files" + rm -rf .mypy_cache + rm -rf .ruff_cache + rm -rf .pytest_cache + rm -rf .coverage + rm -rf dist + rm -rf build + rm -rf $(PACKAGE_NAME).egg-info + rm -rf $(TESTS_TEMP_DIR) + $(PYTHON_BASE) -Bc "import pathlib; [p.unlink() for path in ['$(PACKAGE_NAME)', '$(TESTS_DIR)', '$(DOCS_DIR)'] for pattern in ['*.py[co]', '__pycache__/*'] for p in pathlib.Path(path).rglob(pattern)]" + rm -rf tests/unit/validate_type/test_validate_type_GENERATED.py +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.PHONY: clean-all +clean-all: clean docs-clean dep-clean + @echo "clean up all temporary files, dep files, venv, and generated docs" + + +## ## ######## ## ######## +## ## ## ## ## ## +## ## ## ## ## ## +######### ###### ## ######## +## ## ## ## ## +## ## ## ## ## +## ## ######## ######## ## + +# ================================================== +# smart help command +# ================================================== + +# listing targets is from stackoverflow +# https://stackoverflow.com/questions/4219255/how-do-you-get-the-list-of-targets-in-a-makefile +# no .PHONY because this will only be run before `make help` +# it's a separate command because getting the `info` takes a bit of time +# and we want to show the make targets right away without making the user wait for `info` to finish running +help-targets: + @echo -n "# make targets" + @echo ":" + @cat makefile | sed -n '/^\.PHONY: / h; /\(^\t@*echo\|^\t:\)/ {H; 
x; /PHONY/ s/.PHONY: \(.*\)\n.*"\(.*\)"/ make \1\t\2/p; d; x}'| sort -k2,2 |expand -t 30 + + +.PHONY: info +info: gen-version-info + @echo "# makefile variables" + @echo " PYTHON = $(PYTHON)" + @echo " PYTHON_VERSION = $(PYTHON_VERSION)" + @echo " PACKAGE_NAME = $(PACKAGE_NAME)" + @echo " PROJ_VERSION = $(PROJ_VERSION)" + @echo " LAST_VERSION = $(LAST_VERSION)" + @echo " PYTEST_OPTIONS = $(PYTEST_OPTIONS)" + +.PHONY: info-long +info-long: info + @echo "# other variables" + @echo " PUBLISH_BRANCH = $(PUBLISH_BRANCH)" + @echo " DOCS_DIR = $(DOCS_DIR)" + @echo " COVERAGE_REPORTS_DIR = $(COVERAGE_REPORTS_DIR)" + @echo " TESTS_DIR = $(TESTS_DIR)" + @echo " TESTS_TEMP_DIR = $(TESTS_TEMP_DIR)" + @echo " PYPROJECT = $(PYPROJECT)" + @echo " REQUIREMENTS_DIR = $(REQUIREMENTS_DIR)" + @echo " LOCAL_DIR = $(LOCAL_DIR)" + @echo " PYPI_TOKEN_FILE = $(PYPI_TOKEN_FILE)" + @echo " LAST_VERSION_FILE = $(LAST_VERSION_FILE)" + @echo " PYTHON_BASE = $(PYTHON_BASE)" + @echo " COMMIT_LOG_FILE = $(COMMIT_LOG_FILE)" + @echo " PANDOC = $(PANDOC)" + @echo " COV = $(COV)" + @echo " VERBOSE = $(VERBOSE)" + @echo " RUN_GLOBAL = $(RUN_GLOBAL)" + @echo " TYPECHECK_ARGS = $(TYPECHECK_ARGS)" + +# immediately print out the help targets, and then local variables (but those take a bit longer) +.PHONY: help +help: help-targets info + @echo -n "" + + + ###### ## ## ###### ######## ####### ## ## +## ## ## ## ## ## ## ## ## ### ### +## ## ## ## ## ## ## #### #### +## ## ## ###### ## ## ## ## ### ## +## ## ## ## ## ## ## ## ## +## ## ## ## ## ## ## ## ## ## ## + ###### ####### ###### ## ####### ## ## + +# ================================================== +# custom targets +# ================================================== +# (put them down here, or delimit with ~~~~~) \ No newline at end of file From 50d14a6c425dc34302b0fe23f3378209c579b5a8 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 14:39:10 +0000 Subject: [PATCH 19/72] merging makefile pt2 --- makefile | 59 +++++++++++++++++++++++++++++++++++++++++++++++----- makefile-old | 58 ++------------------------------------------------- 2 files changed, 56 insertions(+), 61 deletions(-) diff --git a/makefile b/makefile index ee6e3fea..34dce67b 100644 --- a/makefile +++ b/makefile @@ -128,7 +128,32 @@ endif # if you want different behavior for different python versions # -------------------------------------------------- -# COMPATIBILITY_MODE := $(shell $(PYTHON) -c "import sys; print(1 if sys.version_info < (3, 10) else 0)") +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# compatibility mode for python <3.10 + +# loose typing, allow warnings for python <3.10 +# -------------------------------------------------- +TYPECHECK_ARGS ?= +# COMPATIBILITY_MODE: whether to run in compatibility mode for python <3.10 +COMPATIBILITY_MODE := $(shell $(PYTHON) -c "import sys; print(1 if sys.version_info < (3, 11) else 0)") + +# compatibility mode for python <3.10 +# -------------------------------------------------- + +# whether to run pytest with warnings as errors +WARN_STRICT ?= 0 + +ifneq ($(WARN_STRICT), 0) + PYTEST_OPTIONS += -W error +endif + +# compatibility mode for python <3.10 +ifeq ($(COMPATIBILITY_MODE), 1) + JUNK := $(info !!! 
WARNING !!!: Detected python version less than 3.10, some behavior will be different) + TYPECHECK_ARGS += --disable-error-code misc --disable-error-code syntax --disable-error-code import-not-found --no-check-untyped-defs +endif + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # options we might want to pass to pytest # -------------------------------------------------- @@ -1704,6 +1729,20 @@ dep-clean: # checks (formatting/linting, typing, tests) # ================================================== +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# added gen-extra-tests and it is required by some other recipes: +# format-check, typing, test + +# extra tests with python >=3.10 type hints +.PHONY: gen-extra-tests +gen-extra-tests: + if [ $(COMPATIBILITY_MODE) -eq 0 ]; then \ + echo "converting certain tests to modern format"; \ + $(PYTHON) tests/util/replace_type_hints.py tests/unit/validate_type/test_validate_type.py "# DO NOT EDIT, GENERATED FILE" > tests/unit/validate_type/test_validate_type_GENERATED.py; \ + fi; \ + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # format code AND auto-fix linting issues # performs TWO operations: reformats code, then auto-fixes safe linting issues # configure in pyproject.toml:[tool.ruff] @@ -1719,17 +1758,22 @@ format-check: @echo "check if the source code is formatted correctly" $(PYTHON) -m ruff check --config $(PYPROJECT) . +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # runs type checks with mypy +# at some point, need to add back --check-untyped-defs to mypy call +# but it complains when we specify arguments by keyword where positional is fine +# not sure how to fix this .PHONY: typing -typing: clean +typing: gen-extra-tests @echo "running type checks" $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_ARGS) . +# $(PYTHON) -m ty check muutils/ # generate summary report of type check errors grouped by file # outputs TOML format showing error count per file # useful for tracking typing progress across large codebases .PHONY: typing-report -typing-report: +typing-report: clean gen-extra-tests @echo "generate a report of the type check output -- errors per file" $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_ARGS) . | $(PYTHON) -c "$$SCRIPT_MYPY_REPORT" --mode toml @@ -1743,9 +1787,10 @@ typing-report: # PYTEST_OPTIONS="..." 
# pass additional pytest arguments # pytest config in pyproject.toml:[tool.pytest.ini_options] .PHONY: test -test: clean +test: clean gen-extra-tests @echo "running tests" $(PYTHON) -m pytest $(PYTEST_OPTIONS) $(TESTS_DIR) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .PHONY: check check: clean format-check test typing @@ -1955,6 +2000,8 @@ publish: gen-commit-log check build verify-git version gen-version-info # removes $(TESTS_TEMP_DIR) to remove temporary test files # recursively removes all `__pycache__` directories and `*.pyc` or `*.pyo` files # distinct from `make docs-clean`, which only removes generated documentation files +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# slight modification in last line for extra tests .PHONY: clean clean: @echo "clean up temporary files" @@ -1966,7 +2013,9 @@ clean: rm -rf build rm -rf $(PACKAGE_NAME).egg-info rm -rf $(TESTS_TEMP_DIR) - $(PYTHON) -Bc "import pathlib; [p.unlink() for path in ['$(PACKAGE_NAME)', '$(TESTS_DIR)', '$(DOCS_DIR)'] for pattern in ['*.py[co]', '__pycache__/*'] for p in pathlib.Path(path).rglob(pattern)]" + $(PYTHON_BASE) -Bc "import pathlib; [p.unlink() for path in ['$(PACKAGE_NAME)', '$(TESTS_DIR)', '$(DOCS_DIR)'] for pattern in ['*.py[co]', '__pycache__/*'] for p in pathlib.Path(path).rglob(pattern)]" + rm -rf tests/unit/validate_type/test_validate_type_GENERATED.py +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # remove all generated/build files including .venv # runs: clean + docs-clean + dep-clean diff --git a/makefile-old b/makefile-old index ee992946..e5d58dd8 100644 --- a/makefile-old +++ b/makefile-old @@ -123,32 +123,7 @@ endif # if you want different behavior for different python versions # -------------------------------------------------- -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# compatibility mode for python <3.10 -# loose typing, allow warnings for python <3.10 -# -------------------------------------------------- -TYPECHECK_ARGS ?= -# COMPATIBILITY_MODE: whether to run in compatibility mode for python <3.10 -COMPATIBILITY_MODE := $(shell $(PYTHON) -c "import sys; print(1 if sys.version_info < (3, 11) else 0)") - -# compatibility mode for python <3.10 -# -------------------------------------------------- - -# whether to run pytest with warnings as errors -WARN_STRICT ?= 0 - -ifneq ($(WARN_STRICT), 0) - PYTEST_OPTIONS += -W error -endif - -# compatibility mode for python <3.10 -ifeq ($(COMPATIBILITY_MODE), 1) - JUNK := $(info !!! 
WARNING !!!: Detected python version less than 3.10, some behavior will be different) - TYPECHECK_ARGS += --disable-error-code misc --disable-error-code syntax --disable-error-code import-not-found --no-check-untyped-defs -endif - -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # options we might want to pass to pytest # -------------------------------------------------- @@ -1380,19 +1355,6 @@ dep-clean: # checks (formatting/linting, typing, tests) # ================================================== -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# added gen-extra-tests and it is required by some other recipes: -# format-check, typing, test - -# extra tests with python >=3.10 type hints -.PHONY: gen-extra-tests -gen-extra-tests: - if [ $(COMPATIBILITY_MODE) -eq 0 ]; then \ - echo "converting certain tests to modern format"; \ - $(PYTHON) tests/util/replace_type_hints.py tests/unit/validate_type/test_validate_type.py "# DO NOT EDIT, GENERATED FILE" > tests/unit/validate_type/test_validate_type_GENERATED.py; \ - fi; \ - -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # runs ruff and pycln to format the code .PHONY: format @@ -1409,7 +1371,7 @@ format-check: $(PYTHON) -m ruff check --config $(PYPROJECT) . $(PYTHON) -m pycln --check --config $(PYPROJECT) . -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # runs type checks with mypy # at some point, need to add back --check-untyped-defs to mypy call # but it complains when we specify arguments by keyword where positional is fine @@ -1431,7 +1393,7 @@ typing-report: clean gen-extra-tests test: clean gen-extra-tests @echo "running tests" $(PYTHON) -m pytest $(PYTEST_OPTIONS) $(TESTS_DIR) -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + .PHONY: check check: clean format-check test typing @@ -1599,22 +1561,6 @@ publish: gen-commit-log check build verify-git version gen-version-info # removes $(TESTS_TEMP_DIR) to remove temporary test files # recursively removes all `__pycache__` directories and `*.pyc` or `*.pyo` files # distinct from `make docs-clean`, which only removes generated documentation files -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# slight modification in last line for extra tests -.PHONY: clean -clean: - @echo "clean up temporary files" - rm -rf .mypy_cache - rm -rf .ruff_cache - rm -rf .pytest_cache - rm -rf .coverage - rm -rf dist - rm -rf build - rm -rf $(PACKAGE_NAME).egg-info - rm -rf $(TESTS_TEMP_DIR) - $(PYTHON_BASE) -Bc "import pathlib; [p.unlink() for path in ['$(PACKAGE_NAME)', '$(TESTS_DIR)', '$(DOCS_DIR)'] for pattern in ['*.py[co]', '__pycache__/*'] for p in pathlib.Path(path).rglob(pattern)]" - rm -rf tests/unit/validate_type/test_validate_type_GENERATED.py -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .PHONY: clean-all clean-all: clean docs-clean dep-clean From a3ca6137c8435adadd5edf32e598659042a63337 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 14:39:28 +0000 Subject: [PATCH 20/72] remove old makefile --- makefile-old | 1641 -------------------------------------------------- 1 file changed, 1641 deletions(-) delete mode 100644 makefile-old diff --git a/makefile-old b/makefile-old deleted file mode 100644 index e5d58dd8..00000000 --- a/makefile-old +++ /dev/null @@ -1,1641 +0,0 @@ -#|==================================================================| -#| python project makefile template | -#| originally by Michael Ivanitskiy (mivanits@umich.edu) | -#| https://github.com/mivanit/python-project-makefile-template | -#| version: v0.3.4 | -#| 
license: https://creativecommons.org/licenses/by-sa/4.0/ | -#| modifications from the original should be denoted with `~~~~~` | -#| as this makes it easier to find edits when updating makefile | -#|==================================================================| - - - ###### ######## ###### -## ## ## ## ## -## ## ## -## ###### ## #### -## ## ## ## -## ## ## ## ## - ###### ## ###### - -# ================================================== -# configuration & variables -# ================================================== - -# it assumes that the source is in a directory named the same as the package name -# this also gets passed to some other places -PACKAGE_NAME := muutils - -# for checking you are on the right branch when publishing -PUBLISH_BRANCH := main - -# where to put docs -# if you change this, you must also change pyproject.toml:tool.makefile.docs.output_dir to match -DOCS_DIR := docs - -# where the tests are, for pytest -TESTS_DIR := tests - -# tests temp directory to clean up. will remove this in `make clean` -TESTS_TEMP_DIR := $(TESTS_DIR)/_temp/ - -# probably don't change these: -# -------------------------------------------------- - -# where the pyproject.toml file is. no idea why you would change this but just in case -PYPROJECT := pyproject.toml - -# dir to store various configuration files -# use of `.meta/` inspired by https://news.ycombinator.com/item?id=36472613 -META_DIR := .meta - -# requirements.txt files for base package, all extras, dev, and all -REQUIREMENTS_DIR := $(META_DIR)/requirements - -# local files (don't push this to git!) -LOCAL_DIR := $(META_DIR)/local - -# will print this token when publishing. make sure not to commit this file!!! -PYPI_TOKEN_FILE := $(LOCAL_DIR)/.pypi-token - -# version files -VERSIONS_DIR := $(META_DIR)/versions - -# the last version that was auto-uploaded. will use this to create a commit log for version tag -# see `gen-commit-log` target -LAST_VERSION_FILE := $(VERSIONS_DIR)/.lastversion - -# current version (writing to file needed due to shell escaping issues) -VERSION_FILE := $(VERSIONS_DIR)/.version - -# base python to use. Will add `uv run` in front of this if `RUN_GLOBAL` is not set to 1 -PYTHON_BASE := python - -# where the commit log will be stored -COMMIT_LOG_FILE := $(LOCAL_DIR)/.commit_log - -# pandoc commands (for docs) -PANDOC ?= pandoc - -# where to put the coverage reports -# note that this will be published with the docs! 
-# modify the `docs` targets and `.gitignore` if you don't want that -COVERAGE_REPORTS_DIR := $(DOCS_DIR)/coverage - -# this stuff in the docs will be kept -# in addition to anything specified in `pyproject.toml:tool.makefile.docs.no_clean` -DOCS_RESOURCES_DIR := $(DOCS_DIR)/resources - -# location of the make docs script -MAKE_DOCS_SCRIPT_PATH := $(DOCS_RESOURCES_DIR)/make_docs.py - -# version vars - extracted automatically from `pyproject.toml`, `$(LAST_VERSION_FILE)`, and $(PYTHON) -# -------------------------------------------------- - -# assuming your `pyproject.toml` has a line that looks like `version = "0.0.1"`, `gen-version-info` will extract this -PROJ_VERSION := NULL -# `gen-version-info` will read the last version from `$(LAST_VERSION_FILE)`, or `NULL` if it doesn't exist -LAST_VERSION := NULL -# get the python version, now that we have picked the python command -PYTHON_VERSION := NULL - - -# ================================================== -# reading command line options -# ================================================== - -# for formatting or something, we might want to run python without uv -# RUN_GLOBAL=1 to use global `PYTHON_BASE` instead of `uv run $(PYTHON_BASE)` -RUN_GLOBAL ?= 0 - -# for running tests or other commands without updating the env, set this to 1 -# and it will pass `--no-sync` to `uv run` -UV_NOSYNC ?= 0 - -ifeq ($(RUN_GLOBAL),0) - ifeq ($(UV_NOSYNC),1) - PYTHON = uv run --no-sync $(PYTHON_BASE) - else - PYTHON = uv run $(PYTHON_BASE) - endif -else - PYTHON = $(PYTHON_BASE) -endif - -# if you want different behavior for different python versions -# -------------------------------------------------- - - -# options we might want to pass to pytest -# -------------------------------------------------- - -# base options for pytest, will be appended to if `COV` or `VERBOSE` are 1. -# user can also set this when running make to add more options -PYTEST_OPTIONS ?= - -# set to `1` to run pytest with `--cov=.` to get coverage reports in a `.coverage` file -COV ?= 1 -# set to `1` to run pytest with `--verbose` -VERBOSE ?= 0 - -ifeq ($(VERBOSE),1) - PYTEST_OPTIONS += --verbose -endif - -ifeq ($(COV),1) - PYTEST_OPTIONS += --cov=. 
-endif - -# ================================================== -# default target (help) -# ================================================== - -# first/default target is help -.PHONY: default -default: help - - - - ###### ###### ######## #### ######## ######## ###### -## ## ## ## ## ## ## ## ## ## ## ## -## ## ## ## ## ## ## ## ## - ###### ## ######## ## ######## ## ###### - ## ## ## ## ## ## ## ## -## ## ## ## ## ## ## ## ## ## ## - ###### ###### ## ## #### ## ## ###### - -# ================================================== -# python scripts we want to use inside the makefile -# when developing, these are populated by `scripts/assemble_make.py` -# ================================================== - -# create commands for exporting requirements as specified in `pyproject.toml:tool.uv-exports.exports` -define SCRIPT_EXPORT_REQUIREMENTS -# source: https://github.com/mivanit/python-project-makefile-template/tree/main/scripts/make/export_requirements.py - -"export to requirements.txt files based on pyproject.toml configuration" - -from __future__ import annotations - -import sys -import warnings - -try: - import tomllib # type: ignore[import-not-found] -except ImportError: - import tomli as tomllib # type: ignore -from functools import reduce -from pathlib import Path -from typing import Any, Dict, List, Union - -TOOL_PATH: str = "tool.makefile.uv-exports" - - -def deep_get(d: dict, path: str, default: Any = None, sep: str = ".") -> Any: # noqa: ANN401 - "get a value from a nested dictionary" - return reduce( - lambda x, y: x.get(y, default) if isinstance(x, dict) else default, # function - path.split(sep) if isinstance(path, str) else path, # sequence - d, # initial - ) - - -def export_configuration( - export: dict, - all_groups: List[str], - all_extras: List[str], - export_opts: dict, - output_dir: Path, -) -> None: - "print to console a uv command for make which will export a requirements.txt file" - # get name and validate - name = export.get("name") - if not name or not name.isalnum(): - warnings.warn( - f"Export configuration missing valid 'name' field {export}", - ) - return - - # get other options with default fallbacks - filename: str = export.get("filename") or f"requirements-{name}.txt" - groups: Union[List[str], bool, None] = export.get("groups") - extras: Union[List[str], bool] = export.get("extras", []) - options: List[str] = export.get("options", []) - - # init command - cmd: List[str] = ["uv", "export", *export_opts.get("args", [])] - - # handle groups - if groups is not None: - groups_list: List[str] = [] - if isinstance(groups, bool): - if groups: - groups_list = all_groups.copy() - else: - groups_list = groups - - for group in all_groups: - if group in groups_list: - cmd.extend(["--group", group]) - else: - cmd.extend(["--no-group", group]) - - # handle extras - extras_list: List[str] = [] - if isinstance(extras, bool): - if extras: - extras_list = all_extras.copy() - else: - extras_list = extras - - for extra in extras_list: - cmd.extend(["--extra", extra]) - - # add extra options - cmd.extend(options) - - # assemble the command and print to console -- makefile will run it - output_path = output_dir / filename - print(f"{' '.join(cmd)} > {output_path.as_posix()}") - - -def main( - pyproject_path: Path, - output_dir: Path, -) -> None: - "export to requirements.txt files based on pyproject.toml configuration" - # read pyproject.toml - with open(pyproject_path, "rb") as f: - pyproject_data: dict = tomllib.load(f) - - # all available groups - all_groups: List[str] = 
list(pyproject_data.get("dependency-groups", {}).keys()) - all_extras: List[str] = list( - deep_get(pyproject_data, "project.optional-dependencies", {}).keys(), - ) - - # options for exporting - export_opts: dict = deep_get(pyproject_data, TOOL_PATH, {}) - - # what are we exporting? - exports: List[Dict[str, Any]] = export_opts.get("exports", []) - if not exports: - exports = [{"name": "all", "groups": [], "extras": [], "options": []}] - - # export each configuration - for export in exports: - export_configuration( - export=export, - all_groups=all_groups, - all_extras=all_extras, - export_opts=export_opts, - output_dir=output_dir, - ) - - -if __name__ == "__main__": - main( - pyproject_path=Path(sys.argv[1]), - output_dir=Path(sys.argv[2]), - ) - -endef - -export SCRIPT_EXPORT_REQUIREMENTS - - -# get the version from `pyproject.toml:project.version` -define SCRIPT_GET_VERSION -# source: https://github.com/mivanit/python-project-makefile-template/tree/main/scripts/make/get_version.py - -"write the current version of the project to a file" - -from __future__ import annotations - -import sys - -try: - try: - import tomllib # type: ignore[import-not-found] - except ImportError: - import tomli as tomllib # type: ignore - - pyproject_path: str = sys.argv[1].strip() - - with open(pyproject_path, "rb") as f: - pyproject_data: dict = tomllib.load(f) - - print("v" + pyproject_data["project"]["version"], end="") -except Exception: # noqa: BLE001 - print("NULL", end="") - sys.exit(1) - -endef - -export SCRIPT_GET_VERSION - - -# get the commit log since the last version from `$(LAST_VERSION_FILE)` -define SCRIPT_GET_COMMIT_LOG -# source: https://github.com/mivanit/python-project-makefile-template/tree/main/scripts/make/get_commit_log.py - -"pretty print a commit log amd wrote it to a file" - -from __future__ import annotations - -import subprocess -import sys -from typing import List - - -def main( - last_version: str, - commit_log_file: str, -) -> None: - "pretty print a commit log amd wrote it to a file" - if last_version == "NULL": - print("!!! 
ERROR !!!", file=sys.stderr) - print("LAST_VERSION is NULL, can't get commit log!", file=sys.stderr) - sys.exit(1) - - try: - log_cmd: List[str] = [ - "git", - "log", - f"{last_version}..HEAD", - "--pretty=format:- %s (%h)", - ] - commits: List[str] = ( - subprocess.check_output(log_cmd).decode("utf-8").strip().split("\n") # noqa: S603 - ) - with open(commit_log_file, "w") as f: - f.write("\n".join(reversed(commits))) - except subprocess.CalledProcessError as e: - print(f"Error: {e}", file=sys.stderr) - sys.exit(1) - - -if __name__ == "__main__": - main( - last_version=sys.argv[1].strip(), - commit_log_file=sys.argv[2].strip(), - ) - -endef - -export SCRIPT_GET_COMMIT_LOG - - -# get cuda information and whether torch sees it -define SCRIPT_CHECK_TORCH -# source: https://github.com/mivanit/python-project-makefile-template/tree/main/scripts/make/check_torch.py - -"print info about current python, torch, cuda, and devices" - -from __future__ import annotations - -import os -import re -import subprocess -import sys -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - - -def print_info_dict( - info: Dict[str, Union[Any, Dict[str, Any]]], - indent: str = " ", - level: int = 1, -) -> None: - "pretty print the info" - indent_str: str = indent * level - longest_key_len: int = max(map(len, info.keys())) - for key, value in info.items(): - if isinstance(value, dict): - print(f"{indent_str}{key:<{longest_key_len}}:") - print_info_dict(value, indent, level + 1) - else: - print(f"{indent_str}{key:<{longest_key_len}} = {value}") - - -def get_nvcc_info() -> Dict[str, str]: - "get info about cuda from nvcc --version" - # Run the nvcc command. - try: - result: subprocess.CompletedProcess[str] = subprocess.run( # noqa: S603 - ["nvcc", "--version"], # noqa: S607 - check=True, - capture_output=True, - text=True, - ) - except Exception as e: # noqa: BLE001 - return {"Failed to run 'nvcc --version'": str(e)} - - output: str = result.stdout - lines: List[str] = [line.strip() for line in output.splitlines() if line.strip()] - - # Ensure there are exactly 5 lines in the output. - assert len(lines) == 5, ( # noqa: PLR2004 - f"Expected exactly 5 lines from nvcc --version, got {len(lines)} lines:\n{output}" - ) - - # Compile shared regex for release info. 
- release_regex: re.Pattern = re.compile( - r"Cuda compilation tools,\s*release\s*([^,]+),\s*(V.+)", - ) - - # Define a mapping for each desired field: - # key -> (line index, regex pattern, group index, transformation function) - patterns: Dict[str, Tuple[int, re.Pattern, int, Callable[[str], str]]] = { - "build_time": ( - 2, - re.compile(r"Built on (.+)"), - 1, - lambda s: s.replace("_", " "), - ), - "release": (3, release_regex, 1, str.strip), - "release_V": (3, release_regex, 2, str.strip), - "build": (4, re.compile(r"Build (.+)"), 1, str.strip), - } - - info: Dict[str, str] = {} - for key, (line_index, pattern, group_index, transform) in patterns.items(): - match: Optional[re.Match] = pattern.search(lines[line_index]) - if not match: - err_msg: str = ( - f"Unable to parse {key} from nvcc output: {lines[line_index]}" - ) - raise ValueError(err_msg) - info[key] = transform(match.group(group_index)) - - info["release_short"] = info["release"].replace(".", "").strip() - - return info - - -def get_torch_info() -> Tuple[List[Exception], Dict[str, Any]]: - "get info about pytorch and cuda devices" - exceptions: List[Exception] = [] - info: Dict[str, Any] = {} - - try: - import torch - except ImportError as e: - info["torch.__version__"] = "not available" - exceptions.append(e) - return exceptions, info - - try: - info["torch.__version__"] = torch.__version__ - info["torch.cuda.is_available()"] = torch.cuda.is_available() - - if torch.cuda.is_available(): - info["torch.version.cuda"] = torch.version.cuda - info["torch.cuda.device_count()"] = torch.cuda.device_count() - - if torch.cuda.device_count() > 0: - info["torch.cuda.current_device()"] = torch.cuda.current_device() - n_devices: int = torch.cuda.device_count() - info["n_devices"] = n_devices - for current_device in range(n_devices): - try: - current_device_info: Dict[str, Union[str, int]] = {} - - dev_prop = torch.cuda.get_device_properties( - torch.device(f"cuda:{current_device}"), - ) - - current_device_info["name"] = dev_prop.name - current_device_info["version"] = ( - f"{dev_prop.major}.{dev_prop.minor}" - ) - current_device_info["total_memory"] = ( - f"{dev_prop.total_memory} ({dev_prop.total_memory:.1e})" - ) - current_device_info["multi_processor_count"] = ( - dev_prop.multi_processor_count - ) - current_device_info["is_integrated"] = dev_prop.is_integrated - current_device_info["is_multi_gpu_board"] = ( - dev_prop.is_multi_gpu_board - ) - - info[f"device cuda:{current_device}"] = current_device_info - - except Exception as e: # noqa: PERF203,BLE001 - exceptions.append(e) - else: - err_msg_nodevice: str = ( - f"{torch.cuda.device_count() = } devices detected, invalid" - ) - raise ValueError(err_msg_nodevice) # noqa: TRY301 - - else: - err_msg_nocuda: str = ( - f"CUDA is NOT available in torch: {torch.cuda.is_available() = }" - ) - raise ValueError(err_msg_nocuda) # noqa: TRY301 - - except Exception as e: # noqa: BLE001 - exceptions.append(e) - - return exceptions, info - - -if __name__ == "__main__": - print(f"python: {sys.version}") - print_info_dict( - { - "python executable path: sys.executable": str(sys.executable), - "sys.platform": sys.platform, - "current working directory: os.getcwd()": os.getcwd(), # noqa: PTH109 - "Host name: os.name": os.name, - "CPU count: os.cpu_count()": str(os.cpu_count()), - }, - ) - - nvcc_info: Dict[str, Any] = get_nvcc_info() - print("nvcc:") - print_info_dict(nvcc_info) - - torch_exceptions, torch_info = get_torch_info() - print("torch:") - print_info_dict(torch_info) - - if torch_exceptions: - 
print("torch_exceptions:") - for e in torch_exceptions: - print(f" {e}") - -endef - -export SCRIPT_CHECK_TORCH - - -# get todo's from the code -define SCRIPT_GET_TODOS -# source: https://github.com/mivanit/python-project-makefile-template/tree/main/scripts/make/get_todos.py - -"read all TODO type comments and write them to markdown, jsonl, html. configurable in pyproject.toml" - -from __future__ import annotations - -import argparse -import fnmatch -import json -import textwrap -import urllib.parse -import warnings -from dataclasses import asdict, dataclass, field -from functools import reduce -from pathlib import Path -from typing import Any, Dict, List, Union - -from jinja2 import Template - -try: - import tomllib # type: ignore[import-not-found] -except ImportError: - import tomli as tomllib # type: ignore - -TOOL_PATH: str = "tool.makefile.inline-todo" - - -def deep_get(d: dict, path: str, default: Any = None, sep: str = ".") -> Any: # noqa: ANN401 - "get a value from a nested dictionary" - return reduce( - lambda x, y: x.get(y, default) if isinstance(x, dict) else default, # function - path.split(sep) if isinstance(path, str) else path, # sequence - d, # initial - ) - - -_TEMPLATE_MD_LIST: str = """\ -# Inline TODOs - -{% for tag, file_map in grouped|dictsort %} -# {{ tag }} -{% for filepath, item_list in file_map|dictsort %} -## [`{{ filepath }}`](/{{ filepath }}) -{% for itm in item_list %} -- {{ itm.stripped_title }} - local link: [`/{{ filepath }}:{{ itm.line_num }}`](/{{ filepath }}#L{{ itm.line_num }}) - | view on GitHub: [{{ itm.file }}#L{{ itm.line_num }}]({{ itm.code_url | safe }}) - | [Make Issue]({{ itm.issue_url | safe }}) -{% if itm.context %} - ```{{ itm.file_lang }} -{{ itm.context_indented }} - ``` -{% endif %} -{% endfor %} - -{% endfor %} -{% endfor %} -""" - -_TEMPLATE_MD_TABLE: str = """\ -# Inline TODOs - -| Location | Tag | Todo | GitHub | Issue | -|:---------|:----|:-----|:-------|:------| -{% for itm in all_items %}| [`{{ itm.file }}:{{ itm.line_num }}`](/{{ itm.file }}#L{{ itm.line_num }}) | {{ itm.tag }} | {{ itm.stripped_title_escaped }} | [View]({{ itm.code_url | safe }}) | [Create]({{ itm.issue_url | safe }}) | -{% endfor %} -""" - -TEMPLATES_MD: Dict[str, str] = dict( - standard=_TEMPLATE_MD_LIST, - table=_TEMPLATE_MD_TABLE, -) - -TEMPLATE_ISSUE: str = """\ -# source - -[`{file}#L{line_num}`]({code_url}) - -# context -```{file_lang} -{context} -``` -""" - - -@dataclass -class Config: - """Configuration for the inline-todo scraper""" - - search_dir: Path = Path() - out_file_base: Path = Path("docs/todo-inline") - tags: List[str] = field( - default_factory=lambda: ["CRIT", "TODO", "FIXME", "HACK", "BUG"], - ) - extensions: List[str] = field(default_factory=lambda: ["py", "md"]) - exclude: List[str] = field(default_factory=lambda: ["docs/**", ".venv/**"]) - context_lines: int = 2 - valid_post_tag: Union[str, List[str]] = " \t:<>|[](){{}}" - valid_pre_tag: Union[str, List[str]] = " \t:<>|[](){{}}#" - tag_label_map: Dict[str, str] = field( - default_factory=lambda: { - "CRIT": "bug", - "TODO": "enhancement", - "FIXME": "bug", - "BUG": "bug", - "HACK": "enhancement", - }, - ) - extension_lang_map: Dict[str, str] = field( - default_factory=lambda: { - "py": "python", - "md": "markdown", - "html": "html", - "css": "css", - "js": "javascript", - }, - ) - - templates_md: dict[str, str] = field(default_factory=lambda: TEMPLATES_MD) - # templates for the output markdown file - - template_issue: str = TEMPLATE_ISSUE - # template for the issue creation - - 
template_html_source: Path = Path("docs/resources/templates/todo-template.html") - # template source for the output html file (interactive table) - - @property - def template_html(self) -> str: - "read the html template" - return self.template_html_source.read_text(encoding="utf-8") - - template_code_url_: str = "{repo_url}/blob/{branch}/{file}#L{line_num}" - # template for the code url - - @property - def template_code_url(self) -> str: - "code url with repo url and branch substituted" - return self.template_code_url_.replace("{repo_url}", self.repo_url).replace( - "{branch}", - self.branch, - ) - - repo_url: str = "UNKNOWN" - # for the issue creation url - - branch: str = "main" - # branch for links to files on github - - @classmethod - def read(cls, config_file: Path) -> Config: - "read from a file, or return default" - output: Config - if config_file.is_file(): - # read file and load if present - with config_file.open("rb") as f: - data: Dict[str, Any] = tomllib.load(f) - - # try to get the repo url - repo_url: str = "UNKNOWN" - try: - urls: Dict[str, str] = { - k.lower(): v for k, v in data["project"]["urls"].items() - } - if "repository" in urls: - repo_url = urls["repository"] - if "github" in urls: - repo_url = urls["github"] - except Exception as e: # noqa: BLE001 - warnings.warn( - f"No repository URL found in pyproject.toml, 'make issue' links will not work.\n{e}", - ) - - # load the inline-todo config if present - data_inline_todo: Dict[str, Any] = deep_get( - d=data, - path=TOOL_PATH, - default={}, - ) - - if "repo_url" not in data_inline_todo: - data_inline_todo["repo_url"] = repo_url - - output = cls.load(data_inline_todo) - else: - # return default otherwise - output = cls() - - return output - - @classmethod - def load(cls, data: dict) -> Config: - "load from a dictionary, converting to `Path` as needed" - # process variables that should be paths - data = { - k: Path(v) - if k in {"search_dir", "out_file_base", "template_html_source"} - else v - for k, v in data.items() - } - - # default value for the templates - data["templates_md"] = { - **TEMPLATES_MD, - **data.get("templates_md", {}), - } - - return cls(**data) - - -CFG: Config = Config() -# this is messy, but we use a global config so we can get `TodoItem().issue_url` to work - - -@dataclass -class TodoItem: - """Holds one todo occurrence""" - - tag: str - file: str - line_num: int - content: str - context: str = "" - - def serialize(self) -> Dict[str, Union[str, int]]: - "serialize to a dict we can dump to json" - return { - **asdict(self), - "issue_url": self.issue_url, - "file_lang": self.file_lang, - "stripped_title": self.stripped_title, - "code_url": self.code_url, - } - - @property - def context_indented(self) -> str: - """Returns the context with each line indented""" - dedented: str = textwrap.dedent(self.context) - return textwrap.indent(dedented, " ") - - @property - def code_url(self) -> str: - """Returns a URL to the code on GitHub""" - return CFG.template_code_url.format( - file=self.file, - line_num=self.line_num, - ) - - @property - def stripped_title(self) -> str: - """Returns the title of the issue, stripped of the tag""" - return self.content.split(self.tag, 1)[-1].lstrip(":").strip() - - @property - def stripped_title_escaped(self) -> str: - """Returns the title of the issue, stripped of the tag and escaped for markdown""" - return self.stripped_title.replace("|", "\\|") - - @property - def issue_url(self) -> str: - """Constructs a GitHub issue creation URL for a given TodoItem.""" - # title - 
title: str = self.stripped_title - if not title: - title = "Issue from inline todo" - # body - body: str = CFG.template_issue.format( - file=self.file, - line_num=self.line_num, - context=self.context, - context_indented=self.context_indented, - code_url=self.code_url, - file_lang=self.file_lang, - ).strip() - # labels - label: str = CFG.tag_label_map.get(self.tag, self.tag) - # assemble url - query: Dict[str, str] = dict(title=title, body=body, labels=label) - query_string: str = urllib.parse.urlencode(query, quote_via=urllib.parse.quote) - return f"{CFG.repo_url}/issues/new?{query_string}" - - @property - def file_lang(self) -> str: - """Returns the language for the file extension""" - ext: str = Path(self.file).suffix.lstrip(".") - return CFG.extension_lang_map.get(ext, ext) - - -def scrape_file( - file_path: Path, - cfg: Config, -) -> List[TodoItem]: - """Scrapes a file for lines containing any of the specified tags""" - items: List[TodoItem] = [] - if not file_path.is_file(): - return items - lines: List[str] = file_path.read_text(encoding="utf-8").splitlines(True) - - # over all lines - for i, line in enumerate(lines): - # over all tags - for tag in cfg.tags: - # check tag is present - if tag in line[:200]: - # check tag is surrounded by valid strings - tag_idx_start: int = line.index(tag) - tag_idx_end: int = tag_idx_start + len(tag) - if ( - line[tag_idx_start - 1] in cfg.valid_pre_tag - and line[tag_idx_end] in cfg.valid_post_tag - ): - # get the context and add the item - start: int = max(0, i - cfg.context_lines) - end: int = min(len(lines), i + cfg.context_lines + 1) - snippet: str = "".join(lines[start:end]) - items.append( - TodoItem( - tag=tag, - file=file_path.as_posix(), - line_num=i + 1, - content=line.strip("\n"), - context=snippet.strip("\n"), - ), - ) - break - return items - - -def collect_files( - search_dir: Path, - extensions: List[str], - exclude: List[str], -) -> List[Path]: - """Recursively collects all files with specified extensions, excluding matches via globs""" - results: List[Path] = [] - for ext in extensions: - results.extend(search_dir.rglob(f"*.{ext}")) - - return [ - f - for f in results - if not any(fnmatch.fnmatch(f.as_posix(), pattern) for pattern in exclude) - ] - - -def group_items_by_tag_and_file( - items: List[TodoItem], -) -> Dict[str, Dict[str, List[TodoItem]]]: - """Groups items by tag, then by file""" - grouped: Dict[str, Dict[str, List[TodoItem]]] = {} - for itm in items: - grouped.setdefault(itm.tag, {}).setdefault(itm.file, []).append(itm) - for tag_dict in grouped.values(): - for file_list in tag_dict.values(): - file_list.sort(key=lambda x: x.line_num) - return grouped - - -def main(config_file: Path) -> None: - "cli interface to get todos" - global CFG # noqa: PLW0603 - # read configuration - cfg: Config = Config.read(config_file) - CFG = cfg - - # get data - files: List[Path] = collect_files(cfg.search_dir, cfg.extensions, cfg.exclude) - all_items: List[TodoItem] = [] - n_files: int = len(files) - for i, fpath in enumerate(files): - print(f"Scraping {i + 1:>2}/{n_files:>2}: {fpath.as_posix():<60}", end="\r") - all_items.extend(scrape_file(fpath, cfg)) - - # create dir - cfg.out_file_base.parent.mkdir(parents=True, exist_ok=True) - - # write raw to jsonl - with open(cfg.out_file_base.with_suffix(".jsonl"), "w", encoding="utf-8") as f: - for itm in all_items: - f.write(json.dumps(itm.serialize()) + "\n") - - # group, render - grouped: Dict[str, Dict[str, List[TodoItem]]] = group_items_by_tag_and_file( - all_items, - ) - - # render each 
template and save - for template_key, template in cfg.templates_md.items(): - rendered: str = Template(template).render(grouped=grouped, all_items=all_items) - template_out_path: Path = Path( - cfg.out_file_base.with_stem( - cfg.out_file_base.stem + f"-{template_key}", - ).with_suffix(".md"), - ) - template_out_path.write_text(rendered, encoding="utf-8") - - # write html output - try: - html_rendered: str = cfg.template_html.replace( - "//{{DATA}}//", - json.dumps([itm.serialize() for itm in all_items]), - ) - cfg.out_file_base.with_suffix(".html").write_text( - html_rendered, - encoding="utf-8", - ) - except Exception as e: # noqa: BLE001 - warnings.warn(f"Failed to write html output: {e}") - - print("wrote to:") - print(cfg.out_file_base.with_suffix(".md").as_posix()) - - -if __name__ == "__main__": - # parse args - parser: argparse.ArgumentParser = argparse.ArgumentParser("inline_todo") - parser.add_argument( - "--config-file", - default="pyproject.toml", - help="Path to the TOML config, will look under [tool.inline-todo].", - ) - args: argparse.Namespace = parser.parse_args() - # call main - main(Path(args.config_file)) - -endef - -export SCRIPT_GET_TODOS - - -# markdown to html using pdoc -define SCRIPT_PDOC_MARKDOWN2_CLI -# source: https://github.com/mivanit/python-project-makefile-template/tree/main/scripts/make/pdoc_markdown2_cli.py - -"cli to convert markdown files to HTML using pdoc's markdown2" - -from __future__ import annotations - -import argparse -from pathlib import Path -from typing import Optional - -from pdoc.markdown2 import Markdown, _safe_mode # type: ignore - - -def convert_file( - input_path: Path, - output_path: Path, - safe_mode: Optional[_safe_mode] = None, - encoding: str = "utf-8", -) -> None: - """Convert a markdown file to HTML""" - # Read markdown input - text: str = input_path.read_text(encoding=encoding) - - # Convert to HTML using markdown2 - markdown: Markdown = Markdown( - extras=["fenced-code-blocks", "header-ids", "markdown-in-html", "tables"], - safe_mode=safe_mode, - ) - html: str = markdown.convert(text) - - # Write HTML output - output_path.write_text(str(html), encoding=encoding) - - -def main() -> None: - "cli entry point" - parser: argparse.ArgumentParser = argparse.ArgumentParser( - description="Convert markdown files to HTML using pdoc's markdown2", - ) - parser.add_argument("input", type=Path, help="Input markdown file path") - parser.add_argument("output", type=Path, help="Output HTML file path") - parser.add_argument( - "--safe-mode", - choices=["escape", "replace"], - help="Sanitize literal HTML: 'escape' escapes HTML meta chars, 'replace' replaces with [HTML_REMOVED]", - ) - parser.add_argument( - "--encoding", - default="utf-8", - help="Character encoding for reading/writing files (default: utf-8)", - ) - - args: argparse.Namespace = parser.parse_args() - - convert_file( - args.input, - args.output, - safe_mode=args.safe_mode, - encoding=args.encoding, - ) - - -if __name__ == "__main__": - main() - -endef - -export SCRIPT_PDOC_MARKDOWN2_CLI - -# clean up the docs (configurable in pyproject.toml) -define SCRIPT_DOCS_CLEAN -# source: https://github.com/mivanit/python-project-makefile-template/tree/main/scripts/make/docs_clean.py - -"clean up docs directory based on pyproject.toml configuration" - -from __future__ import annotations - -import shutil -import sys -from functools import reduce -from pathlib import Path -from typing import Any, List, Set - -try: - import tomllib # type: ignore[import-not-found] -except ImportError: - import 
tomli as tomllib # type: ignore - -TOOL_PATH: str = "tool.makefile.docs" -DEFAULT_DOCS_DIR: str = "docs" - - -def deep_get(d: dict, path: str, default: Any = None, sep: str = ".") -> Any: # noqa: ANN401 - """Get nested dictionary value via separated path with default.""" - return reduce( - lambda x, y: x.get(y, default) if isinstance(x, dict) else default, # function - path.split(sep) if isinstance(path, str) else path, # sequence - d, # initial - ) - - -def read_config(pyproject_path: Path) -> tuple[Path, Set[Path]]: - "read configuration from pyproject.toml" - if not pyproject_path.is_file(): - return Path(DEFAULT_DOCS_DIR), set() - - with pyproject_path.open("rb") as f: - config = tomllib.load(f) - - preserved: List[str] = deep_get(config, f"{TOOL_PATH}.no_clean", []) - docs_dir: Path = Path(deep_get(config, f"{TOOL_PATH}.output_dir", DEFAULT_DOCS_DIR)) - - # Convert to absolute paths and validate - preserve_set: Set[Path] = set() - for p in preserved: - full_path = (docs_dir / p).resolve() - if not full_path.as_posix().startswith(docs_dir.resolve().as_posix()): - err_msg: str = f"Preserved path '{p}' must be within docs directory" - raise ValueError(err_msg) - preserve_set.add(docs_dir / p) - - return docs_dir, preserve_set - - -def clean_docs(docs_dir: Path, preserved: Set[Path]) -> None: - """delete files not in preserved set - - TODO: this is not recursive - """ - for path in docs_dir.iterdir(): - if path.is_file() and path not in preserved: - path.unlink() - elif path.is_dir() and path not in preserved: - shutil.rmtree(path) - - -def main( - pyproject_path: str, - docs_dir_cli: str, - extra_preserve: list[str], -) -> None: - "Clean up docs directory based on pyproject.toml configuration." - docs_dir: Path - preserved: Set[Path] - docs_dir, preserved = read_config(Path(pyproject_path)) - - assert docs_dir.is_dir(), f"Docs directory '{docs_dir}' not found" - assert docs_dir == Path(docs_dir_cli), ( - f"Docs directory mismatch: {docs_dir = } != {docs_dir_cli = }. this is probably because you changed one of `pyproject.toml:{TOOL_PATH}.output_dir` (the former) or `makefile:DOCS_DIR` (the latter) without updating the other." - ) - - for x in extra_preserve: - preserved.add(Path(x)) - clean_docs(docs_dir, preserved) - - -if __name__ == "__main__": - main(sys.argv[1], sys.argv[2], sys.argv[3:]) - -endef - -export SCRIPT_DOCS_CLEAN - -# generate a report of the mypy output -define SCRIPT_MYPY_REPORT -# source: https://github.com/mivanit/python-project-makefile-template/tree/main/scripts/make/mypy_report.py - -"usage: mypy ... 
| mypy_report.py [--mode jsonl|exclude]" - -from __future__ import annotations - -import argparse -import json -import re -import sys -from pathlib import Path -from typing import Dict, List, Tuple - - -def parse_mypy_output(lines: List[str]) -> Dict[str, int]: - "given mypy output, turn it into a dict of `filename: error_count`" - pattern: re.Pattern[str] = re.compile(r"^(?P[^:]+):\d+:\s+error:") - counts: Dict[str, int] = {} - for line in lines: - m = pattern.match(line) - if m: - f_raw: str = m.group("file") - f_norm: str = Path(f_raw).as_posix() - counts[f_norm] = counts.get(f_norm, 0) + 1 - return counts - - -def main() -> None: - "cli interface for mypy_report" - parser: argparse.ArgumentParser = argparse.ArgumentParser() - parser.add_argument("--mode", choices=["jsonl", "toml"], default="jsonl") - args: argparse.Namespace = parser.parse_args() - lines: List[str] = sys.stdin.read().splitlines() - error_dict: Dict[str, int] = parse_mypy_output(lines) - sorted_errors: List[Tuple[str, int]] = sorted( - error_dict.items(), - key=lambda x: x[1], - ) - if len(sorted_errors) == 0: - print("# no errors found!") - return - if args.mode == "jsonl": - for fname, count in sorted_errors: - print(json.dumps({"filename": fname, "errors": count})) - elif args.mode == "toml": - for fname, count in sorted_errors: - print(f'"{fname}", # {count}') - else: - err_msg: str = f"unknown mode {args.mode}" - raise ValueError(err_msg) - print(f"# total errors: {sum(error_dict.values())}") - - -if __name__ == "__main__": - main() - -endef - -export SCRIPT_MYPY_REPORT - - -## ## ######## ######## ###### #### ####### ## ## -## ## ## ## ## ## ## ## ## ## ### ## -## ## ## ## ## ## ## ## ## #### ## -## ## ###### ######## ###### ## ## ## ## ## ## - ## ## ## ## ## ## ## ## ## ## #### - ## ## ## ## ## ## ## ## ## ## ## ### - ### ######## ## ## ###### #### ####### ## ## - -# ================================================== -# getting version info -# we do this in a separate target because it takes a bit of time -# ================================================== - -# this recipe is weird. we need it because: -# - a one liner for getting the version with toml is unwieldy, and using regex is fragile -# - using $$SCRIPT_GET_VERSION within $(shell ...) doesn't work because of escaping issues -# - trying to write to the file inside the `gen-version-info` recipe doesn't work, -# shell eval happens before our `python -c ...` gets run and `cat` doesn't see the new file -.PHONY: write-proj-version -write-proj-version: - @mkdir -p $(VERSIONS_DIR) - @$(PYTHON) -c "$$SCRIPT_GET_VERSION" "$(PYPROJECT)" > $(VERSION_FILE) - -# gets version info from $(PYPROJECT), last version from $(LAST_VERSION_FILE), and python version -# uses just `python` for everything except getting the python version. 
no echo here, because this is "private" -.PHONY: gen-version-info -gen-version-info: write-proj-version - @mkdir -p $(LOCAL_DIR) - $(eval PROJ_VERSION := $(shell cat $(VERSION_FILE)) ) - $(eval LAST_VERSION := $(shell [ -f $(LAST_VERSION_FILE) ] && cat $(LAST_VERSION_FILE) || echo NULL) ) - $(eval PYTHON_VERSION := $(shell $(PYTHON) -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}')") ) - -# getting commit log since the tag specified in $(LAST_VERSION_FILE) -# will write to $(COMMIT_LOG_FILE) -# when publishing, the contents of $(COMMIT_LOG_FILE) will be used as the tag description (but can be edited during the process) -# no echo here, because this is "private" -.PHONY: gen-commit-log -gen-commit-log: gen-version-info - @if [ "$(LAST_VERSION)" = "NULL" ]; then \ - echo "!!! ERROR !!!"; \ - echo "LAST_VERSION is NULL, cant get commit log!"; \ - exit 1; \ - fi - @mkdir -p $(LOCAL_DIR) - @$(PYTHON) -c "$$SCRIPT_GET_COMMIT_LOG" "$(LAST_VERSION)" "$(COMMIT_LOG_FILE)" - - -# force the version info to be read, printing it out -# also force the commit log to be generated, and cat it out -.PHONY: version -version: gen-commit-log - @echo "Current version is $(PROJ_VERSION), last auto-uploaded version is $(LAST_VERSION)" - @echo "Commit log since last version from '$(COMMIT_LOG_FILE)':" - @cat $(COMMIT_LOG_FILE) - @echo "" - @if [ "$(PROJ_VERSION)" = "$(LAST_VERSION)" ]; then \ - echo "!!! ERROR !!!"; \ - echo "Python package $(PROJ_VERSION) is the same as last published version $(LAST_VERSION), exiting!"; \ - exit 1; \ - fi - - - -######## ######## ######## ###### -## ## ## ## ## ## ## -## ## ## ## ## ## -## ## ###### ######## ###### -## ## ## ## ## -## ## ## ## ## ## -######## ######## ## ###### - -# ================================================== -# dependencies and setup -# ================================================== - -.PHONY: setup -setup: dep-check - @echo "install and update via uv" - @echo "To activate the virtual environment, run one of:" - @echo " source .venv/bin/activate" - @echo " source .venv/Scripts/activate" - -.PHONY: dep-check-torch -dep-check-torch: - @echo "see if torch is installed, and which CUDA version and devices it sees" - $(PYTHON) -c "$$SCRIPT_CHECK_TORCH" - -.PHONY: dep -dep: - @echo "Exporting dependencies as per $(PYPROJECT) section 'tool.uv-exports.exports'" - uv sync --all-extras --all-groups --compile-bytecode - mkdir -p $(REQUIREMENTS_DIR) - $(PYTHON) -c "$$SCRIPT_EXPORT_REQUIREMENTS" $(PYPROJECT) $(REQUIREMENTS_DIR) | sh -x - - -.PHONY: dep-check -dep-check: - @echo "Checking that exported requirements are up to date" - uv sync --all-extras --all-groups - mkdir -p $(REQUIREMENTS_DIR)-TEMP - $(PYTHON) -c "$$SCRIPT_EXPORT_REQUIREMENTS" $(PYPROJECT) $(REQUIREMENTS_DIR)-TEMP | sh -x - diff -r $(REQUIREMENTS_DIR)-TEMP $(REQUIREMENTS_DIR) - rm -rf $(REQUIREMENTS_DIR)-TEMP - - -.PHONY: dep-clean -dep-clean: - @echo "clean up lock files, .venv, and requirements files" - rm -rf .venv - rm -rf uv.lock - rm -rf $(REQUIREMENTS_DIR)/*.txt - - - ###### ## ## ######## ###### ## ## ###### -## ## ## ## ## ## ## ## ## ## ## -## ## ## ## ## ## ## ## -## ######### ###### ## ##### ###### -## ## ## ## ## ## ## ## -## ## ## ## ## ## ## ## ## ## ## - ###### ## ## ######## ###### ## ## ###### - -# ================================================== -# checks (formatting/linting, typing, tests) -# ================================================== - - -# runs ruff and pycln to format the code -.PHONY: format -format: - @echo 
"format the source code" - $(PYTHON) -m ruff format --config $(PYPROJECT) . - $(PYTHON) -m ruff check --fix --config $(PYPROJECT) . - $(PYTHON) -m pycln --config $(PYPROJECT) --all . - -# runs ruff and pycln to check if the code is formatted correctly -.PHONY: format-check -format-check: - @echo "check if the source code is formatted correctly" - $(PYTHON) -m ruff check --config $(PYPROJECT) . - $(PYTHON) -m pycln --check --config $(PYPROJECT) . - - -# runs type checks with mypy -# at some point, need to add back --check-untyped-defs to mypy call -# but it complains when we specify arguments by keyword where positional is fine -# not sure how to fix this -.PHONY: typing -typing: gen-extra-tests - @echo "running type checks" - $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_ARGS) . -# $(PYTHON) -m ty check muutils/ - - -# generates a report of the mypy output -.PHONY: typing-report -typing-report: clean gen-extra-tests - @echo "generate a report of the type check output -- errors per file" - $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_ARGS) . | $(PYTHON) -c "$$SCRIPT_MYPY_REPORT" --mode toml - -.PHONY: test -test: clean gen-extra-tests - @echo "running tests" - $(PYTHON) -m pytest $(PYTEST_OPTIONS) $(TESTS_DIR) - - -.PHONY: check -check: clean format-check test typing - @echo "run format checks, tests, and typing checks" - - -######## ####### ###### ###### -## ## ## ## ## ## ## ## -## ## ## ## ## ## -## ## ## ## ## ###### -## ## ## ## ## ## -## ## ## ## ## ## ## ## -######## ####### ###### ###### - -# ================================================== -# coverage & docs -# ================================================== - -# generates a whole tree of documentation in html format. -# see `$(MAKE_DOCS_SCRIPT_PATH)` and the templates in `$(DOCS_RESOURCES_DIR)/templates/html/` for more info -.PHONY: docs-html -docs-html: - @echo "generate html docs" - $(PYTHON) $(MAKE_DOCS_SCRIPT_PATH) - -# instead of a whole website, generates a single markdown file with all docs using the templates in `$(DOCS_RESOURCES_DIR)/templates/markdown/`. -# this is useful if you want to have a copy that you can grep/search, but those docs are much messier. -# docs-combined will use pandoc to convert them to other formats. -.PHONY: docs-md -docs-md: - @echo "generate combined (single-file) docs in markdown" - mkdir $(DOCS_DIR)/combined -p - $(PYTHON) $(MAKE_DOCS_SCRIPT_PATH) --combined - -# after running docs-md, this will convert the combined markdown file to other formats: -# gfm (github-flavored markdown), plain text, and html -# requires pandoc in path, pointed to by $(PANDOC) -# pdf output would be nice but requires other deps -.PHONY: docs-combined -docs-combined: docs-md - @echo "generate combined (single-file) docs in markdown and convert to other formats" - @echo "requires pandoc in path" - $(PANDOC) -f markdown -t gfm $(DOCS_DIR)/combined/$(PACKAGE_NAME).md -o $(DOCS_DIR)/combined/$(PACKAGE_NAME)_gfm.md - $(PANDOC) -f markdown -t plain $(DOCS_DIR)/combined/$(PACKAGE_NAME).md -o $(DOCS_DIR)/combined/$(PACKAGE_NAME).txt - $(PANDOC) -f markdown -t html $(DOCS_DIR)/combined/$(PACKAGE_NAME).md -o $(DOCS_DIR)/combined/$(PACKAGE_NAME).html - -# generates coverage reports as html and text with `pytest-cov`, and a badge with `coverage-badge` -# if `.coverage` is not found, will run tests first -# also removes the `.gitignore` file that `coverage html` creates, since we count that as part of the docs -.PHONY: cov -cov: - @echo "generate coverage reports" - @if [ ! 
-f .coverage ]; then \ - echo ".coverage not found, running tests first..."; \ - $(MAKE) test; \ - fi - mkdir $(COVERAGE_REPORTS_DIR) -p - $(PYTHON) -m coverage report -m > $(COVERAGE_REPORTS_DIR)/coverage.txt - $(PYTHON) -m coverage_badge -f -o $(COVERAGE_REPORTS_DIR)/coverage.svg - $(PYTHON) -m coverage html --directory=$(COVERAGE_REPORTS_DIR)/html/ - rm -rf $(COVERAGE_REPORTS_DIR)/html/.gitignore - -# runs the coverage report, then the docs, then the combined docs -.PHONY: docs -docs: cov docs-html docs-combined todo lmcat - @echo "generate all documentation and coverage reports" - -# removed all generated documentation files, but leaves everything in `$DOCS_RESOURCES_DIR` -# and leaves things defined in `pyproject.toml:tool.makefile.docs.no_clean` -# (templates, svg, css, make_docs.py script) -# distinct from `make clean` -.PHONY: docs-clean -docs-clean: - @echo "remove generated docs except resources" - $(PYTHON) -c "$$SCRIPT_DOCS_CLEAN" $(PYPROJECT) $(DOCS_DIR) $(DOCS_RESOURCES_DIR) - -.PHONY: todo -todo: - @echo "get all TODO's from the code" - $(PYTHON) -c "$$SCRIPT_GET_TODOS" - -.PHONY: lmcat-tree -lmcat-tree: - @echo "show in console the lmcat tree view" - -$(PYTHON) -m lmcat -t --output STDOUT - -.PHONY: lmcat -lmcat: - @echo "write the lmcat full output to pyproject.toml:[tool.lmcat.output]" - -$(PYTHON) -m lmcat - -######## ## ## #### ## ######## -## ## ## ## ## ## ## ## -## ## ## ## ## ## ## ## -######## ## ## ## ## ## ## -## ## ## ## ## ## ## ## -## ## ## ## ## ## ## ## -######## ####### #### ######## ######## - -# ================================================== -# build and publish -# ================================================== - -# verifies that the current branch is $(PUBLISH_BRANCH) and that git is clean -# used before publishing -.PHONY: verify-git -verify-git: - @echo "checking git status" - if [ "$(shell git branch --show-current)" != $(PUBLISH_BRANCH) ]; then \ - echo "!!! ERROR !!!"; \ - echo "Git is not on the $(PUBLISH_BRANCH) branch, exiting!"; \ - git branch; \ - git status; \ - exit 1; \ - fi; \ - if [ -n "$(shell git status --porcelain)" ]; then \ - echo "!!! ERROR !!!"; \ - echo "Git is not clean, exiting!"; \ - git status; \ - exit 1; \ - fi; \ - - -.PHONY: build -build: - @echo "build the package" - uv build - -# gets the commit log, checks everything, builds, and then publishes with twine -# will ask the user to confirm the new version number (and this allows for editing the tag info) -# will also print the contents of $(PYPI_TOKEN_FILE) to the console for the user to copy and paste in when prompted by twine -.PHONY: publish -publish: gen-commit-log check build verify-git version gen-version-info - @echo "run all checks, build, and then publish" - - @echo "Enter the new version number if you want to upload to pypi and create a new tag" - @echo "Now would also be the time to edit $(COMMIT_LOG_FILE), as that will be used as the tag description" - @read -p "Confirm: " NEW_VERSION; \ - if [ "$$NEW_VERSION" = $(PROJ_VERSION) ]; then \ - echo "!!! ERROR !!!"; \ - echo "Version confirmed. 
Proceeding with publish."; \ - else \ - echo "Version mismatch, exiting: you gave $$NEW_VERSION but expected $(PROJ_VERSION)"; \ - exit 1; \ - fi; - - @echo "pypi username: __token__" - @echo "pypi token from '$(PYPI_TOKEN_FILE)' :" - echo $$(cat $(PYPI_TOKEN_FILE)) - - echo "Uploading!"; \ - echo $(PROJ_VERSION) > $(LAST_VERSION_FILE); \ - git add $(LAST_VERSION_FILE); \ - git commit -m "Auto update to $(PROJ_VERSION)"; \ - git tag -a $(PROJ_VERSION) -F $(COMMIT_LOG_FILE); \ - git push origin $(PROJ_VERSION); \ - twine upload dist/* --verbose - -# ================================================== -# cleanup of temp files -# ================================================== - -# cleans up temp files from formatter, type checking, tests, coverage -# removes all built files -# removes $(TESTS_TEMP_DIR) to remove temporary test files -# recursively removes all `__pycache__` directories and `*.pyc` or `*.pyo` files -# distinct from `make docs-clean`, which only removes generated documentation files - -.PHONY: clean-all -clean-all: clean docs-clean dep-clean - @echo "clean up all temporary files, dep files, venv, and generated docs" - - -## ## ######## ## ######## -## ## ## ## ## ## -## ## ## ## ## ## -######### ###### ## ######## -## ## ## ## ## -## ## ## ## ## -## ## ######## ######## ## - -# ================================================== -# smart help command -# ================================================== - -# listing targets is from stackoverflow -# https://stackoverflow.com/questions/4219255/how-do-you-get-the-list-of-targets-in-a-makefile -# no .PHONY because this will only be run before `make help` -# it's a separate command because getting the `info` takes a bit of time -# and we want to show the make targets right away without making the user wait for `info` to finish running -help-targets: - @echo -n "# make targets" - @echo ":" - @cat makefile | sed -n '/^\.PHONY: / h; /\(^\t@*echo\|^\t:\)/ {H; x; /PHONY/ s/.PHONY: \(.*\)\n.*"\(.*\)"/ make \1\t\2/p; d; x}'| sort -k2,2 |expand -t 30 - - -.PHONY: info -info: gen-version-info - @echo "# makefile variables" - @echo " PYTHON = $(PYTHON)" - @echo " PYTHON_VERSION = $(PYTHON_VERSION)" - @echo " PACKAGE_NAME = $(PACKAGE_NAME)" - @echo " PROJ_VERSION = $(PROJ_VERSION)" - @echo " LAST_VERSION = $(LAST_VERSION)" - @echo " PYTEST_OPTIONS = $(PYTEST_OPTIONS)" - -.PHONY: info-long -info-long: info - @echo "# other variables" - @echo " PUBLISH_BRANCH = $(PUBLISH_BRANCH)" - @echo " DOCS_DIR = $(DOCS_DIR)" - @echo " COVERAGE_REPORTS_DIR = $(COVERAGE_REPORTS_DIR)" - @echo " TESTS_DIR = $(TESTS_DIR)" - @echo " TESTS_TEMP_DIR = $(TESTS_TEMP_DIR)" - @echo " PYPROJECT = $(PYPROJECT)" - @echo " REQUIREMENTS_DIR = $(REQUIREMENTS_DIR)" - @echo " LOCAL_DIR = $(LOCAL_DIR)" - @echo " PYPI_TOKEN_FILE = $(PYPI_TOKEN_FILE)" - @echo " LAST_VERSION_FILE = $(LAST_VERSION_FILE)" - @echo " PYTHON_BASE = $(PYTHON_BASE)" - @echo " COMMIT_LOG_FILE = $(COMMIT_LOG_FILE)" - @echo " PANDOC = $(PANDOC)" - @echo " COV = $(COV)" - @echo " VERBOSE = $(VERBOSE)" - @echo " RUN_GLOBAL = $(RUN_GLOBAL)" - @echo " TYPECHECK_ARGS = $(TYPECHECK_ARGS)" - -# immediately print out the help targets, and then local variables (but those take a bit longer) -.PHONY: help -help: help-targets info - @echo -n "" - - - ###### ## ## ###### ######## ####### ## ## -## ## ## ## ## ## ## ## ## ### ### -## ## ## ## ## ## ## #### #### -## ## ## ###### ## ## ## ## ### ## -## ## ## ## ## ## ## ## ## -## ## ## ## ## ## ## ## ## ## ## - ###### ####### ###### ## ####### ## ## - -# 
================================================== -# custom targets -# ================================================== -# (put them down here, or delimit with ~~~~~) \ No newline at end of file From 8a1ed6b835eba1ace1484e158316e93a00fadbec Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 14:40:46 +0000 Subject: [PATCH 21/72] _COUNTER assert to prevent removal from imports --- tests/unit/test_dbg.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit/test_dbg.py b/tests/unit/test_dbg.py index b792bd54..2c87dd2d 100644 --- a/tests/unit/test_dbg.py +++ b/tests/unit/test_dbg.py @@ -21,6 +21,8 @@ _compile_pattern, ) +assert _COUNTER + DBG_MODULE_NAME: str = "muutils.dbg" @@ -228,6 +230,7 @@ def test_misc() -> None: l1 = [10, 20, 30] dbg_auto(l1) + # # --- Tests for tensor_info_dict and tensor_info --- # def test_tensor_info_dict_with_nan() -> None: # tensor: DummyTensor = DummyTensor() From c4fa438bdc48ddc7c1291eea9a9c4a97753b38a1 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 14:42:41 +0000 Subject: [PATCH 22/72] fix type hint --- tests/unit/cli/test_arg_bool.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/cli/test_arg_bool.py b/tests/unit/cli/test_arg_bool.py index 6cd8ae50..10631ff4 100644 --- a/tests/unit/cli/test_arg_bool.py +++ b/tests/unit/cli/test_arg_bool.py @@ -503,6 +503,7 @@ def test_add_bool_flag_default_help(): break assert action is not None + assert action.help is not None assert "enable/disable my feature" in action.help From 16b5e64869830f9e90f5d2ab2869a080ea7613c7 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 16:02:16 +0000 Subject: [PATCH 23/72] type fix --- tests/unit/validate_type/test_validate_type.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/validate_type/test_validate_type.py b/tests/unit/validate_type/test_validate_type.py index f7fb76b8..51e5e2ae 100644 --- a/tests/unit/validate_type/test_validate_type.py +++ b/tests/unit/validate_type/test_validate_type.py @@ -224,7 +224,7 @@ def test_validate_type_set(value, expected_type, expected_result): (1, "a", 3.14, "b", True, None, (1, 2, 3)), # no idea why this throws type error, only locally, and only for the generated modern types typing.Tuple[ # type: ignore[misc] - int, str, float, str, bool, type(None), typing.Tuple[int, int, int] + int, str, float, str, bool, None, typing.Tuple[int, int, int] ], True, ), From 3aada0105ed17d5502b78d41f913eae810a972b0 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 16:02:35 +0000 Subject: [PATCH 24/72] will be doing lots of type checking fixes B) --- .meta/requirements/requirements-all.txt | 3 + .meta/requirements/requirements-dev.txt | 3 + .meta/requirements/requirements.txt | 3 + TODO.md | 90 +++++++++++++++++++++++++ makefile | 4 +- pyproject.toml | 1 + uv.lock | 30 +++++++++ 7 files changed, 133 insertions(+), 1 deletion(-) create mode 100644 TODO.md diff --git a/.meta/requirements/requirements-all.txt b/.meta/requirements/requirements-all.txt index fb40ed6c..5daae08c 100644 --- a/.meta/requirements/requirements-all.txt +++ b/.meta/requirements/requirements-all.txt @@ -47,6 +47,7 @@ backports-tarfile==1.2.0 ; python_full_version < '3.12' and platform_machine != # via jaraco-context backports-zoneinfo==0.2.1 ; python_full_version < '3.9' # via arrow +basedpyright==1.32.1 beartype==0.19.0 ; python_full_version < '3.9' beartype==0.22.2 ; python_full_version == '3.9.*' beartype==0.22.4 ; python_full_version >= '3.10' @@ 
-480,6 +481,8 @@ networkx==3.5 ; python_full_version >= '3.11' and python_full_version < '3.14' # via torch nh3==0.3.1 # via readme-renderer +nodejs-wheel-binaries==22.20.0 + # via basedpyright notebook==7.3.3 ; python_full_version < '3.9' # via jupyter notebook==7.4.7 ; python_full_version >= '3.9' diff --git a/.meta/requirements/requirements-dev.txt b/.meta/requirements/requirements-dev.txt index 73ec345b..45c5df2c 100644 --- a/.meta/requirements/requirements-dev.txt +++ b/.meta/requirements/requirements-dev.txt @@ -47,6 +47,7 @@ backports-tarfile==1.2.0 ; python_full_version < '3.12' and platform_machine != # via jaraco-context backports-zoneinfo==0.2.1 ; python_full_version < '3.9' # via arrow +basedpyright==1.32.1 beartype==0.19.0 ; python_full_version < '3.9' beartype==0.22.2 ; python_full_version == '3.9.*' beartype==0.22.4 ; python_full_version >= '3.10' @@ -422,6 +423,8 @@ nest-asyncio==1.6.0 # via ipykernel nh3==0.3.1 # via readme-renderer +nodejs-wheel-binaries==22.20.0 + # via basedpyright notebook==7.3.3 ; python_full_version < '3.9' # via jupyter notebook==7.4.7 ; python_full_version >= '3.9' diff --git a/.meta/requirements/requirements.txt b/.meta/requirements/requirements.txt index d486ae25..0f737086 100644 --- a/.meta/requirements/requirements.txt +++ b/.meta/requirements/requirements.txt @@ -47,6 +47,7 @@ backports-tarfile==1.2.0 ; python_full_version < '3.12' and platform_machine != # via jaraco-context backports-zoneinfo==0.2.1 ; python_full_version < '3.9' # via arrow +basedpyright==1.32.1 beartype==0.19.0 ; python_full_version < '3.9' beartype==0.22.2 ; python_full_version == '3.9.*' beartype==0.22.4 ; python_full_version >= '3.10' @@ -480,6 +481,8 @@ networkx==3.5 ; python_full_version >= '3.11' and python_full_version < '3.14' # via torch nh3==0.3.1 # via readme-renderer +nodejs-wheel-binaries==22.20.0 + # via basedpyright notebook==7.3.3 ; python_full_version < '3.9' # via jupyter notebook==7.4.7 ; python_full_version >= '3.9' diff --git a/TODO.md b/TODO.md new file mode 100644 index 00000000..7db97076 --- /dev/null +++ b/TODO.md @@ -0,0 +1,90 @@ +# Mypy Type Error Fixes + +## Source Files (Priority 1) + +### muutils/jsonlines.py +- Line 63-68: Change `items: Sequence[JSONitem]` to `items: Sequence[Any]` +- Current signature is too restrictive for actual usage patterns + +### muutils/logger/log_util.py +- Line 47: Change `keys: tuple[str]` to `keys: tuple[str, ...]` in `gather_val()` +- Line 5: Change return type from `None` to `Any` in `get_any_from_stream()` + +### muutils/json_serialize/array.py +- Line 167: Change `load_array()` return type from `Any` to `np.ndarray` + +## Test Files (Priority 2) + +### tests/unit/json_serialize/test_json_serialize.py +Lines 129-130, 151-152, 496-497, 502-505, 709-711, 745: +- Add `assert isinstance(result, dict)` before dict operations +- Or use `cast(dict[str, Any], result)` + +Lines 266, 723, 730: +- Add `assert isinstance(result["key"], list)` before calling `set()` + +Line 656: +- Pass `error_mode=ErrorMode.EXCEPT` instead of string + +### tests/unit/json_serialize/test_serializable_field.py +Line 127: +- Remove incorrect indexing of Field object + +Line 143: +- Add type annotation: `dc_field2: Field[list] = field(default_factory=list)` + +Lines 312-334: +- Fix variable assignments - `serializable_field()` returns SerializableField, not primitive values +- Don't access `.default`, `.repr`, `.hash`, etc. 
on non-Field objects + +### tests/unit/json_serialize/test_array.py +Lines 122, 125, 128, 138, 195-197: +- Add `assert isinstance(loaded, np.ndarray)` after `load_array()` calls + +Lines 154-156, 166-167: +- Add dict type assertions before indexing serialized results + +### tests/unit/json_serialize/test_array_torch.py +Lines 23, 30, 37: +- Add `assert isinstance(shape, list)` before `in` checks + +Lines 126, 146: +- Add type ignore or change signature to accept torch.Tensor + +Lines 191-193, 195-196, 198, 201-202, 221-224, 227-228: +- Add dict type assertions before indexing + +### tests/unit/test_jsonlines.py +Lines 40-41: +- Add dict type assertions before indexing + +Lines 59, 94, 106, 112, 136, 150, 167, 188, 192: +- Will be fixed by jsonl_write signature change + +### tests/unit/logger/test_log_util.py +Lines 35, 77, 116: +- Will be fixed by jsonl_write signature change + +Lines 119, 129, 137, 140, 145: +- Will be fixed by gather_val signature change + +Lines 159, 163, 167: +- Will be fixed by get_any_from_stream return type change + +### tests/unit/benchmark_parallel/benchmark_parallel.py +Lines 197, 208, 219, 234, 251 (task_type, save_path), 283, 308: +- Add `| None` to type hints: `param: Type = None` → `param: Type | None = None` + +Line 399: +- Change parameter type from `list[int]` to `Sequence[int]` or cast argument + +Lines 416, 419, 422: +- Change parameter types to accept `str | Path` or cast Path to str + +### tests/unit/benchmark_parallel/test_benchmark_demo.py +Line 10: +- Change `main()` signature to accept `base_path: str | Path` +- Or wrap string in `Path()` call + +## Total Errors +157 errors across 8 files diff --git a/makefile b/makefile index 34dce67b..c46723c3 100644 --- a/makefile +++ b/makefile @@ -1766,7 +1766,9 @@ format-check: .PHONY: typing typing: gen-extra-tests @echo "running type checks" - $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_ARGS) . + $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_ARGS) . || true + $(PYTHON) -m basedpyright check . || true + $(PYTHON) -m ty check . 
|| true # $(PYTHON) -m ty check muutils/ # generate summary report of type check errors grouped by file diff --git a/pyproject.toml b/pyproject.toml index e27daa60..e639eae7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ "typing-extensions; python_version < '3.11'", "beartype>=0.14.1", "ty", + "basedpyright", # tests & coverage "pytest>=8.2.2", "pytest-cov>=4.1.0", diff --git a/uv.lock b/uv.lock index 82ba38dc..150f1ce6 100644 --- a/uv.lock +++ b/uv.lock @@ -290,6 +290,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/cc/e27fd6493bbce8dbea7e6c1bc861fe3d3bc22c4f7c81f4c3befb8ff5bfaf/backports.zoneinfo-0.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:4a0f800587060bf8880f954dbef70de6c11bbe59c673c3d818921f042f9954a6", size = 38967, upload-time = "2020-06-23T13:51:13.735Z" }, ] +[[package]] +name = "basedpyright" +version = "1.32.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nodejs-wheel-binaries" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4f/a5/691d02a30bda15acb6a5727bb696dd7f3fcae1ad5b9f2708020c2645af8c/basedpyright-1.32.1.tar.gz", hash = "sha256:ce979891a3c4649e7c31d665acb06fd451f33fedfd500bc7796ee0950034aa54", size = 22757919, upload-time = "2025-10-23T12:53:28.169Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/d5/17d24fd7ba9d899b82859ee04f4599a1e8a02a85c0753bc15eb3ca7ffff7/basedpyright-1.32.1-py3-none-any.whl", hash = "sha256:06b5cc56693e3690653955e19fbe5d2e38f2a343563b40ef95fd1b10fa556fb6", size = 11841548, upload-time = "2025-10-23T12:53:25.541Z" }, +] + [[package]] name = "beartype" version = "0.19.0" @@ -4088,6 +4100,7 @@ web = [ [package.dev-dependencies] dev = [ + { name = "basedpyright" }, { name = "beartype", version = "0.19.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, { name = "beartype", version = "0.22.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.9.*'" }, { name = "beartype", version = "0.22.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, @@ -4149,6 +4162,7 @@ provides-extras = ["array", "array-no-torch", "notebook", "parallel", "web"] [package.metadata.requires-dev] dev = [ + { name = "basedpyright" }, { name = "beartype", specifier = ">=0.14.1" }, { name = "beautifulsoup4" }, { name = "coverage-badge", specifier = ">=1.1.0" }, @@ -4510,6 +4524,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/34/67/d5e07efd38194f52b59b8af25a029b46c0643e9af68204ee263022924c27/nh3-0.3.1-cp38-abi3-win_arm64.whl", hash = "sha256:a3e810a92fb192373204456cac2834694440af73d749565b4348e30235da7f0b", size = 586369, upload-time = "2025-10-07T03:27:57.234Z" }, ] +[[package]] +name = "nodejs-wheel-binaries" +version = "22.20.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/54/02f58c8119e2f1984e2572cc77a7b469dbaf4f8d171ad376e305749ef48e/nodejs_wheel_binaries-22.20.0.tar.gz", hash = "sha256:a62d47c9fd9c32191dff65bbe60261504f26992a0a19fe8b4d523256a84bd351", size = 8058, upload-time = "2025-09-26T09:48:00.906Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/24/6d/333e5458422f12318e3c3e6e7f194353aa68b0d633217c7e89833427ca01/nodejs_wheel_binaries-22.20.0-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:455add5ac4f01c9c830ab6771dbfad0fdf373f9b040d3aabe8cca9b6c56654fb", size = 53246314, upload-time = "2025-09-26T09:47:32.536Z" }, + { url = 
"https://files.pythonhosted.org/packages/56/30/dcd6879d286a35b3c4c8f9e5e0e1bcf4f9e25fe35310fc77ecf97f915a23/nodejs_wheel_binaries-22.20.0-py2.py3-none-macosx_11_0_x86_64.whl", hash = "sha256:5d8c12f97eea7028b34a84446eb5ca81829d0c428dfb4e647e09ac617f4e21fa", size = 53644391, upload-time = "2025-09-26T09:47:36.093Z" }, + { url = "https://files.pythonhosted.org/packages/58/be/c7b2e7aa3bb281d380a1c531f84d0ccfe225832dfc3bed1ca171753b9630/nodejs_wheel_binaries-22.20.0-py2.py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a2b0989194148f66e9295d8f11bc463bde02cbe276517f4d20a310fb84780ae", size = 60282516, upload-time = "2025-09-26T09:47:39.88Z" }, + { url = "https://files.pythonhosted.org/packages/3e/c5/8befacf4190e03babbae54cb0809fb1a76e1600ec3967ab8ee9f8fc85b65/nodejs_wheel_binaries-22.20.0-py2.py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b5c500aa4dc046333ecb0a80f183e069e5c30ce637f1c1a37166b2c0b642dc21", size = 60347290, upload-time = "2025-09-26T09:47:43.712Z" }, + { url = "https://files.pythonhosted.org/packages/c0/bd/cfffd1e334277afa0714962c6ec432b5fe339340a6bca2e5fa8e678e7590/nodejs_wheel_binaries-22.20.0-py2.py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3279eb1b99521f0d20a850bbfc0159a658e0e85b843b3cf31b090d7da9f10dfc", size = 62178798, upload-time = "2025-09-26T09:47:47.752Z" }, + { url = "https://files.pythonhosted.org/packages/08/14/10b83a9c02faac985b3e9f5e65d63a34fc0f46b48d8a2c3e4caa3e1e7318/nodejs_wheel_binaries-22.20.0-py2.py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d29705797b33bade62d79d8f106c2453c8a26442a9b2a5576610c0f7e7c351ed", size = 62772957, upload-time = "2025-09-26T09:47:51.266Z" }, + { url = "https://files.pythonhosted.org/packages/b4/a9/c6a480259aa0d6b270aac2c6ba73a97444b9267adde983a5b7e34f17e45a/nodejs_wheel_binaries-22.20.0-py2.py3-none-win_amd64.whl", hash = "sha256:4bd658962f24958503541963e5a6f2cc512a8cb301e48a69dc03c879f40a28ae", size = 40120431, upload-time = "2025-09-26T09:47:54.363Z" }, + { url = "https://files.pythonhosted.org/packages/42/b1/6a4eb2c6e9efa028074b0001b61008c9d202b6b46caee9e5d1b18c088216/nodejs_wheel_binaries-22.20.0-py2.py3-none-win_arm64.whl", hash = "sha256:1fccac931faa210d22b6962bcdbc99269d16221d831b9a118bbb80fe434a60b8", size = 38844133, upload-time = "2025-09-26T09:47:57.357Z" }, +] + [[package]] name = "notebook" version = "7.3.3" From 6f5736c2fbc8f6d2d669cafb4d3741bec9d7c8a4 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 16:23:45 +0000 Subject: [PATCH 25/72] typing-summary recipe we got like 20k lines of errors lol --- .gitignore | 1 + makefile | 13 ++++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 50c217e8..370eb29f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # local stuff (pypi token, commit log) .meta/local/** +.meta/.type-errors/** .coverage.* # this one is cursed diff --git a/makefile b/makefile index c46723c3..4a9ab232 100644 --- a/makefile +++ b/makefile @@ -1767,10 +1767,21 @@ format-check: typing: gen-extra-tests @echo "running type checks" $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_ARGS) . || true - $(PYTHON) -m basedpyright check . || true + $(PYTHON) -m basedpyright . || true $(PYTHON) -m ty check . || true # $(PYTHON) -m ty check muutils/ +.PHONY: typing-summary +typing-summary: gen-extra-tests + @echo "running type checks and saving to .meta/.type-errors/" + @mkdir -p .meta/.type-errors + @$(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_ARGS) . 
> .meta/.type-errors/mypy.txt 2>&1 || true + @$(PYTHON) -m basedpyright . > .meta/.type-errors/basedpyright.txt 2>&1 || true + @$(PYTHON) -m ty check . > .meta/.type-errors/ty.txt 2>&1 || true + @echo "mypy: $$(tail -n 1 .meta/.type-errors/mypy.txt)" + @echo "basedpyright: $$(tail -n 1 .meta/.type-errors/basedpyright.txt)" + @echo "ty: $$(tail -n 1 .meta/.type-errors/ty.txt)" + # generate summary report of type check errors grouped by file # outputs TOML format showing error count per file # useful for tracking typing progress across large codebases From 83eec0a9cba92aaa5518e548a75bd6d9291fd232 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 30 Oct 2025 13:56:22 +0000 Subject: [PATCH 26/72] some typing fixes --- muutils/logger/log_util.py | 7 +++++-- .../unit/benchmark_parallel/benchmark_parallel.py | 14 +++++++------- .../unit/benchmark_parallel/test_benchmark_demo.py | 4 +++- tests/unit/test_statcounter.py | 4 ++-- 4 files changed, 17 insertions(+), 12 deletions(-) diff --git a/muutils/logger/log_util.py b/muutils/logger/log_util.py index 8b93972b..94e8f949 100644 --- a/muutils/logger/log_util.py +++ b/muutils/logger/log_util.py @@ -1,8 +1,11 @@ from __future__ import annotations +from typing import TypeVar from muutils.jsonlines import jsonl_load_log +T_StreamValue = TypeVar("T_StreamValue") -def get_any_from_stream(stream: list[dict], key: str) -> None: + +def get_any_from_stream(stream: list[dict[str, T_StreamValue]], key: str) -> T_StreamValue: """get the first value of a key from a stream. errors if not found""" for msg in stream: if key in msg: @@ -44,7 +47,7 @@ def gather_stream( def gather_val( file: str, stream: str, - keys: tuple[str], + keys: tuple[str, ...], allow_skip: bool = True, ) -> list[list]: """gather specific keys from a specific stream in a log file diff --git a/tests/unit/benchmark_parallel/benchmark_parallel.py b/tests/unit/benchmark_parallel/benchmark_parallel.py index 29a7391f..f2534bf9 100644 --- a/tests/unit/benchmark_parallel/benchmark_parallel.py +++ b/tests/unit/benchmark_parallel/benchmark_parallel.py @@ -194,7 +194,7 @@ def benchmark_sequential(func: Callable, data: List[int]) -> Tuple[List[Any], fl def benchmark_pool_map( - func: Callable, data: List[int], processes: int = None + func: Callable, data: List[int], processes: int | None = None ) -> Tuple[List[Any], float]: """Benchmark using multiprocessing.Pool.map.""" start = time.perf_counter() @@ -205,7 +205,7 @@ def benchmark_pool_map( def benchmark_pool_imap( - func: Callable, data: List[int], processes: int = None, chunksize: int = 1 + func: Callable, data: List[int], processes: int | None = None, chunksize: int = 1 ) -> Tuple[List[Any], float]: """Benchmark using multiprocessing.Pool.imap.""" start = time.perf_counter() @@ -216,7 +216,7 @@ def benchmark_pool_imap( def benchmark_pool_imap_unordered( - func: Callable, data: List[int], processes: int = None, chunksize: int = 1 + func: Callable, data: List[int], processes: int | None = None, chunksize: int = 1 ) -> Tuple[List[Any], float]: """Benchmark using multiprocessing.Pool.imap_unordered.""" start = time.perf_counter() @@ -231,7 +231,7 @@ def benchmark_run_maybe_parallel( data: List[int], parallel: Union[bool, int], keep_ordered: bool = True, - chunksize: int = None, + chunksize: int | None = None, ) -> Tuple[List[Any], float]: """Benchmark using run_maybe_parallel.""" start = time.perf_counter() @@ -248,7 +248,7 @@ def benchmark_run_maybe_parallel( def plot_speedup_by_data_size( - df: pd.DataFrame, task_type: str = None, save_path: 
str = None + df: pd.DataFrame, task_type: str | None = None, save_path: str | None = None ): """Plot speedup vs data size for different methods.""" import matplotlib.pyplot as plt # type: ignore[import-untyped] @@ -280,7 +280,7 @@ def plot_speedup_by_data_size( def plot_timing_comparison( - df: pd.DataFrame, data_size: int = None, save_path: str = None + df: pd.DataFrame, data_size: int | None = None, save_path: str | None = None ): """Plot timing comparison as bar chart.""" import matplotlib.pyplot as plt # type: ignore[import-untyped] @@ -305,7 +305,7 @@ def plot_timing_comparison( plt.show() -def plot_efficiency_heatmap(df: pd.DataFrame, save_path: str = None): +def plot_efficiency_heatmap(df: pd.DataFrame, save_path: str | None = None): """Plot efficiency heatmap (speedup across methods and tasks).""" import matplotlib.pyplot as plt # type: ignore[import-untyped] diff --git a/tests/unit/benchmark_parallel/test_benchmark_demo.py b/tests/unit/benchmark_parallel/test_benchmark_demo.py index 0184fef6..b7af7f19 100644 --- a/tests/unit/benchmark_parallel/test_benchmark_demo.py +++ b/tests/unit/benchmark_parallel/test_benchmark_demo.py @@ -1,5 +1,7 @@ """Simple demo of using the benchmark script.""" +from pathlib import Path + from benchmark_parallel import io_bound_task, light_cpu_task, main @@ -7,7 +9,7 @@ def test_main(): """Test the main function of the benchmark script.""" main( data_sizes=(1, 2), - base_path="tests/_temp/benchmark_demo", + base_path=Path("tests/_temp/benchmark_demo"), plot=True, task_funcs={ "io_bound": io_bound_task, diff --git a/tests/unit/test_statcounter.py b/tests/unit/test_statcounter.py index caeb71ed..c3dba2d5 100644 --- a/tests/unit/test_statcounter.py +++ b/tests/unit/test_statcounter.py @@ -5,7 +5,7 @@ from muutils.statcounter import StatCounter -def _compute_err(a: float, b: float, /) -> dict[str, float]: +def _compute_err(a: float, b: float | np.floating, /) -> dict[str, float]: return dict( num_a=float(a), num_b=float(b), @@ -14,7 +14,7 @@ def _compute_err(a: float, b: float, /) -> dict[str, float]: ) -def _compare_np_custom(arr: np.ndarray) -> dict[str, dict]: +def _compare_np_custom(arr: np.ndarray) -> dict[str, dict[str, float]]: counter: StatCounter = StatCounter(arr) return dict( mean=_compute_err(counter.mean(), np.mean(arr)), From d004e54c31d92d5b081ede0fc45ffbca67db3fae Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 30 Oct 2025 15:18:52 +0000 Subject: [PATCH 27/72] make JSONitem type recursive! 
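
A rough sketch of what the recursive alias permits (illustrative, not part of
the diff; `BaseType` here stands in for the scalar union defined earlier in
`util.py`, and recursive type aliases are assumed to be available -- mypy
enables them by default since roughly 0.990):

    import typing

    BaseType = typing.Union[bool, int, float, str, None]
    JSONitem = typing.Union[
        BaseType,
        typing.List["JSONitem"],
        typing.Dict[str, "JSONitem"],
    ]

    # type-checks at arbitrary nesting depth; the old alias unrolled two
    # levels by hand and then fell back to Any
    nested: JSONitem = {"a": [1, {"b": [True, None, "c"]}]}
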
--- muutils/json_serialize/util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index 902a2ee5..9253594d 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -28,9 +28,8 @@ JSONitem = Union[ BaseType, - # mypy doesn't like recursive types, so we just go down a few levels manually - typing.List[Union[BaseType, typing.List[Any], typing.Dict[str, Any]]], - typing.Dict[str, Union[BaseType, typing.List[Any], typing.Dict[str, Any]]], + typing.List["JSONitem"], + typing.Dict[str, "JSONitem"], ] JSONdict = typing.Dict[str, JSONitem] From 16248032d4912ebdde920a7c6cede5c5a7c61cd3 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 30 Oct 2025 16:34:41 +0000 Subject: [PATCH 28/72] typing --- muutils/json_serialize/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index 9253594d..c2309006 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -263,8 +263,8 @@ def dc_eq( f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`" ) if except_when_field_mismatch: - dc1_fields: set = set([fld.name for fld in dataclasses.fields(dc1)]) - dc2_fields: set = set([fld.name for fld in dataclasses.fields(dc2)]) + dc1_fields: set[str] = set([fld.name for fld in dataclasses.fields(dc1)]) + dc2_fields: set[str] = set([fld.name for fld in dataclasses.fields(dc2)]) fields_match: bool = set(dc1_fields) == set(dc2_fields) if not fields_match: # if the fields match, keep going From 8a8de950b5879524100f7836fa2b77c6fd81f8b9 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 30 Oct 2025 16:34:59 +0000 Subject: [PATCH 29/72] fix assert (counter was 0 lol) --- tests/unit/test_dbg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_dbg.py b/tests/unit/test_dbg.py index 2c87dd2d..9cf04158 100644 --- a/tests/unit/test_dbg.py +++ b/tests/unit/test_dbg.py @@ -21,7 +21,7 @@ _compile_pattern, ) -assert _COUNTER +assert _COUNTER is not None DBG_MODULE_NAME: str = "muutils.dbg" From 3dd88f507a3554bc17a88815759e06718131a0a4 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 30 Oct 2025 16:35:14 +0000 Subject: [PATCH 30/72] typing fixes --- tests/unit/test_statcounter.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_statcounter.py b/tests/unit/test_statcounter.py index c3dba2d5..1013ba60 100644 --- a/tests/unit/test_statcounter.py +++ b/tests/unit/test_statcounter.py @@ -5,13 +5,14 @@ from muutils.statcounter import StatCounter -def _compute_err(a: float, b: float | np.floating, /) -> dict[str, float]: - return dict( +def _compute_err(a: float, b: float | np.floating, /) -> dict[str, int | float]: + result: dict[str, int | float] = dict( num_a=float(a), num_b=float(b), diff=float(b - a), # frac_err=float((b - a) / a), # this causes division by zero, whatever ) + return result def _compare_np_custom(arr: np.ndarray) -> dict[str, dict[str, float]]: From 9d0824e7e9215ea04e3c3ba57560b740ff1e50a9 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 30 Oct 2025 16:35:22 +0000 Subject: [PATCH 31/72] better type checking config --- makefile | 6 +++--- pyproject.toml | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/makefile b/makefile index 4a9ab232..7e90c104 100644 --- a/makefile +++ b/makefile @@ -1775,9 +1775,9 @@ typing: 
gen-extra-tests typing-summary: gen-extra-tests @echo "running type checks and saving to .meta/.type-errors/" @mkdir -p .meta/.type-errors - @$(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_ARGS) . > .meta/.type-errors/mypy.txt 2>&1 || true - @$(PYTHON) -m basedpyright . > .meta/.type-errors/basedpyright.txt 2>&1 || true - @$(PYTHON) -m ty check . > .meta/.type-errors/ty.txt 2>&1 || true + $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_ARGS) . > .meta/.type-errors/mypy.txt 2>&1 || true + $(PYTHON) -m basedpyright . > .meta/.type-errors/basedpyright.txt 2>&1 || true + $(PYTHON) -m ty check . > .meta/.type-errors/ty.txt 2>&1 || true @echo "mypy: $$(tail -n 1 .meta/.type-errors/mypy.txt)" @echo "basedpyright: $$(tail -n 1 .meta/.type-errors/basedpyright.txt)" @echo "ty: $$(tail -n 1 .meta/.type-errors/ty.txt)" diff --git a/pyproject.toml b/pyproject.toml index e639eae7..da311947 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -141,6 +141,29 @@ all = true exclude = ["tests/input_data", "tests/junk_data", "_wip/"] +[tool.basedpyright] + include = ["muutils", "tests"] + reportConstantRedefinition = false # I always use all caps for globals, not just consts + reportDeprecated = false # this library is backwards compatible back to 3.8, so we are using lots of deprecated stuff + exclude = [ + "tests/input_data", + "tests/junk_data", + "tests/_temp", + "_wip", + "docs/resources/make_docs.py", + ".venv", + ] + +[tool.ty.src] + exclude = [ + "tests/input_data/", + "tests/junk_data/", + "tests/_temp/", + "tests/benchmark_parallel.py", + "_wip/", + "docs/resources/make_docs.py", + ] + [tool.mypy] exclude = [ # tests From 6997c6c6e0f9175a704d6578bb51e35b2b2cfb3f Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 30 Oct 2025 16:35:32 +0000 Subject: [PATCH 32/72] todos for claude --- TODO.md | 94 ++++++--------------------------------------------------- 1 file changed, 10 insertions(+), 84 deletions(-) diff --git a/TODO.md b/TODO.md index 7db97076..240e1ae9 100644 --- a/TODO.md +++ b/TODO.md @@ -1,90 +1,16 @@ -# Mypy Type Error Fixes +# Type Error Fixing TODO -## Source Files (Priority 1) +## Instructions -### muutils/jsonlines.py -- Line 63-68: Change `items: Sequence[JSONitem]` to `items: Sequence[Any]` -- Current signature is too restrictive for actual usage patterns +1. Read the type checker output files: + - `.meta/.type-errors/mypy.txt` + - `.meta/.type-errors/basedpyright.txt` + - `.meta/.type-errors/ty.txt` -### muutils/logger/log_util.py -- Line 47: Change `keys: tuple[str]` to `keys: tuple[str, ...]` in `gather_val()` -- Line 5: Change return type from `None` to `Any` in `get_any_from_stream()` + NOTE: the latter two files are many thousands of lines, you will have to pick the first few or last few hundred lines to get a sense of the errors. -### muutils/json_serialize/array.py -- Line 167: Change `load_array()` return type from `Any` to `np.ndarray` +2. Find the fix with the best **"number of errors / complexity of change" ratio** -## Test Files (Priority 2) +3. 
Implement that fix
+
+Run type checking only on the specific file you are changing to verify that the errors are fixed.
\ No newline at end of file

From e90cfa0d7a30e19e530e69700e966e1406a2df46 Mon Sep 17 00:00:00 2001
From: Michael Ivanitskiy
Date: Thu, 30 Oct 2025 17:33:14 +0000
Subject: [PATCH 33/72] fix array stuff, mostly for type hints.
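
With this change, zero-dimensional arrays always round-trip through the
metadata-dict form. A rough sketch of the round trip (illustrative, not part
of the diff; assumes `JsonSerializer` is importable from
`muutils.json_serialize` with a default constructor, as the tests use it):

    import numpy as np

    from muutils.json_serialize import JsonSerializer
    from muutils.json_serialize.array import load_array, serialize_array

    jser = JsonSerializer()
    scalar = np.array(3.14)  # zero-dim array

    s = serialize_array(jser, scalar, "x", array_mode="array_list_meta")
    # roughly: {"__muutils_format__": "numpy.ndarray:zero_dim",
    #           "data": 3.14, "shape": [], "dtype": "float64", "n_elements": 1}
    assert np.array_equal(load_array(s), scalar)
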
no more 0-dim list mode --- muutils/json_serialize/array.py | 78 +++++++++++++++++-------- tests/unit/json_serialize/test_array.py | 7 ++- 2 files changed, 58 insertions(+), 27 deletions(-) diff --git a/muutils/json_serialize/array.py b/muutils/json_serialize/array.py index 73514134..f09a65a5 100644 --- a/muutils/json_serialize/array.py +++ b/muutils/json_serialize/array.py @@ -11,7 +11,7 @@ import base64 import typing import warnings -from typing import Any, Iterable, Literal, Optional, Sequence +from typing import Any, Iterable, Literal, Optional, Sequence, TypedDict try: import numpy as np @@ -25,6 +25,12 @@ # pylint: disable=unused-argument +# Recursive type for nested numeric lists (output of arr.tolist()) +NumericList = typing.Union[ + typing.List[typing.Union[int, float, bool]], + typing.List["NumericList"], +] + ArrayMode = Literal[ "list", "array_list_meta", @@ -45,7 +51,23 @@ def array_n_elements(arr) -> int: # type: ignore[name-defined] raise TypeError(f"invalid type: {type(arr)}") -def arr_metadata(arr) -> dict[str, list[int] | str | int]: +class ArrayMetadata(TypedDict): + """Metadata for a numpy/torch array""" + shape: list[int] + dtype: str + n_elements: int + + +class SerializedArrayWithMeta(TypedDict): + """Serialized array with metadata (for array_list_meta, array_hex_meta, array_b64_meta, zero_dim modes)""" + __muutils_format__: str + data: typing.Union[NumericList, str, int, float, bool] # list, hex str, b64 str, or scalar for zero_dim + shape: list[int] + dtype: str + n_elements: int + + +def arr_metadata(arr) -> ArrayMetadata: """get metadata for a numpy array""" return { "shape": list(arr.shape), @@ -61,7 +83,7 @@ def serialize_array( arr: np.ndarray, path: str | Sequence[str | int], array_mode: ArrayMode | None = None, -) -> JSONitem: +) -> SerializedArrayWithMeta | NumericList: """serialize a numpy or pytorch array in one of several modes if the object is zero-dimensional, simply get the unique item @@ -101,34 +123,40 @@ def serialize_array( arr_type: str = f"{type(arr).__module__}.{type(arr).__name__}" arr_np: np.ndarray = arr if isinstance(arr, np.ndarray) else np.array(arr) + # Handle list mode first (no metadata needed) + if array_mode == "list": + return arr_np.tolist() + + # For all other modes, compute metadata once + metadata: ArrayMetadata = arr_metadata(arr if len(arr.shape) == 0 else arr_np) + # handle zero-dimensional arrays if len(arr.shape) == 0: - return { - _FORMAT_KEY: f"{arr_type}:zero_dim", - "data": arr.item(), - **arr_metadata(arr), - } + return SerializedArrayWithMeta( + __muutils_format__=f"{arr_type}:zero_dim", + data=arr.item(), + **metadata, + ) + # Handle the metadata modes if array_mode == "array_list_meta": - return { - _FORMAT_KEY: f"{arr_type}:array_list_meta", - "data": arr_np.tolist(), - **arr_metadata(arr_np), - } - elif array_mode == "list": - return arr_np.tolist() + return SerializedArrayWithMeta( + __muutils_format__=f"{arr_type}:array_list_meta", + data=arr_np.tolist(), + **metadata, + ) elif array_mode == "array_hex_meta": - return { - _FORMAT_KEY: f"{arr_type}:array_hex_meta", - "data": arr_np.tobytes().hex(), - **arr_metadata(arr_np), - } + return SerializedArrayWithMeta( + __muutils_format__=f"{arr_type}:array_hex_meta", + data=arr_np.tobytes().hex(), + **metadata, + ) elif array_mode == "array_b64_meta": - return { - _FORMAT_KEY: f"{arr_type}:array_b64_meta", - "data": base64.b64encode(arr_np.tobytes()).decode(), - **arr_metadata(arr_np), - } + return SerializedArrayWithMeta( + 
__muutils_format__=f"{arr_type}:array_b64_meta", + data=base64.b64encode(arr_np.tobytes()).decode(), + **metadata, + ) else: raise KeyError(f"invalid array_mode: {array_mode}") diff --git a/tests/unit/json_serialize/test_array.py b/tests/unit/json_serialize/test_array.py index bd4c46df..d65f0f12 100644 --- a/tests/unit/json_serialize/test_array.py +++ b/tests/unit/json_serialize/test_array.py @@ -14,7 +14,7 @@ # pylint: disable=missing-class-docstring -class TestYourModule: +class TestArray: def setup_method(self): self.array_1d = np.array([1, 2, 3]) self.array_2d = np.array([[1, 2], [3, 4]]) @@ -75,17 +75,20 @@ def test_serialize_load_integration(self): def test_serialize_load_zero_dim(self): for array_mode in [ - "list", + # TODO: do we even want to support "list" mode for zero-dim arrays? + # "list", "array_list_meta", "array_hex_meta", "array_b64_meta", ]: + print(array_mode) serialized_array = serialize_array( self.jser, self.array_zero_dim, "test_path", array_mode=array_mode, # type: ignore[arg-type] ) + print(serialized_array) loaded_array = load_array(serialized_array) assert np.array_equal(loaded_array, self.array_zero_dim) From ae666967cb1a6fd6d20ccce874658dc1d2b4688b Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 30 Oct 2025 17:33:50 +0000 Subject: [PATCH 34/72] type hint global ignores --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index da311947..8ddbe34d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,6 +145,7 @@ include = ["muutils", "tests"] reportConstantRedefinition = false # I always use all caps for globals, not just consts reportDeprecated = false # this library is backwards compatible back to 3.8, so we are using lots of deprecated stuff + reportUnsupportedDunderAll = false # we use __all__ a lot for docs stuff exclude = [ "tests/input_data", "tests/junk_data", From c340c1a032d7887fc84285ae7447cb04c22d12bc Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 30 Oct 2025 17:34:08 +0000 Subject: [PATCH 35/72] none type handled? 
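
A sketch of the intended behavior (illustrative, not part of the diff;
assumes `validate_type`'s default mode, which returns a bool as the
`_return_func` helpers suggest):

    import typing

    from muutils.validate_type import validate_type

    # bare None in a hint means "must be None", matching how typing
    # treats None as shorthand for NoneType
    assert validate_type(None, None)
    assert not validate_type(0, None)
    # the tuple-hint shorthand from the earlier test fix now validates
    assert validate_type((1, None), typing.Tuple[int, None])
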
---
 muutils/validate_type.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/muutils/validate_type.py b/muutils/validate_type.py
index 81272897..bb4b2361 100644
--- a/muutils/validate_type.py
+++ b/muutils/validate_type.py
@@ -89,6 +89,10 @@ def validate_type(
         else _return_validation_bool
     )
 
+    # handle None type (used in type hints like tuple[int, None])
+    if expected_type is None:
+        return _return_func(value is None)
+
     # base type without args
     if isinstance(expected_type, type):
         try:

From 54b69cb31fa9c42ba0b40fea0ac9f82f1a2fcabc Mon Sep 17 00:00:00 2001
From: Michael Ivanitskiy
Date: Thu, 30 Oct 2025 17:34:26 +0000
Subject: [PATCH 36/72] explicit kwargs

---
 muutils/cli/arg_bool.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/muutils/cli/arg_bool.py b/muutils/cli/arg_bool.py
index 309fbf81..673a4eed 100644
--- a/muutils/cli/arg_bool.py
+++ b/muutils/cli/arg_bool.py
@@ -115,14 +115,12 @@ def __init__(
         option_strings: Sequence[str],
         dest: str,
         nargs: int | str | None = None,
-        **kwargs: bool | set[str] | None,
+        true_set: set[str] | None = None,
+        false_set: set[str] | None = None,
+        allow_no: bool = True,
+        allow_bare: bool = True,
+        **kwargs: Any,
     ) -> None:
-        # Extract custom kwargs before calling super().__init__
-        true_set_opt: set[str] | None = kwargs.pop("true_set", None)  # type: ignore[assignment,misc]
-        false_set_opt: set[str] | None = kwargs.pop("false_set", None)  # type: ignore[assignment,misc]
-        allow_no_opt: bool = bool(kwargs.pop("allow_no", True))
-        allow_bare_opt: bool = bool(kwargs.pop("allow_bare", True))
-
         if "type" in kwargs and kwargs["type"] is not None:
             raise ValueError("BoolFlagOrValue does not accept type=. Remove it.")
 
@@ -133,13 +131,13 @@ def __init__(
             option_strings=option_strings,
             dest=dest,
             nargs="?",
-            **kwargs,  # type: ignore[arg-type]
+            **kwargs,
         )
         # Store normalized config
-        self.true_set: set[str] = _normalize_set(true_set_opt, TRUE_SET_DEFAULT)
-        self.false_set: set[str] = _normalize_set(false_set_opt, FALSE_SET_DEFAULT)
-        self.allow_no: bool = allow_no_opt
-        self.allow_bare: bool = allow_bare_opt
+        self.true_set: set[str] = _normalize_set(true_set, TRUE_SET_DEFAULT)
+        self.false_set: set[str] = _normalize_set(false_set, FALSE_SET_DEFAULT)
+        self.allow_no: bool = allow_no
+        self.allow_bare: bool = allow_bare
 
     def _parse_token(self, token: str) -> bool:
         """Parse a boolean token using this action's configured sets."""

From 2d3ab897c3611dd7ac93f3b5721a10e07623860a Mon Sep 17 00:00:00 2001
From: Michael Ivanitskiy
Date: Thu, 30 Oct 2025 17:34:39 +0000
Subject: [PATCH 37/72] fix dict type hints

---
 muutils/dbg.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/muutils/dbg.py b/muutils/dbg.py
index 3faa70fc..c0cad685 100644
--- a/muutils/dbg.py
+++ b/muutils/dbg.py
@@ -215,18 +215,18 @@ def tensor_info(tensor: typing.Any) -> str:
     return array_summary(tensor, as_list=False, **DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS)
 
 
-DBG_DICT_DEFAULTS: typing.Dict[str, typing.Union[bool, int, str]] = dict(
-    key_types=True,
-    val_types=True,
-    max_len=32,
-    indent="  ",
-    max_depth=3,
-)
+DBG_DICT_DEFAULTS: typing.Dict[str, typing.Union[bool, int, str]] = {
+    "key_types": True,
+    "val_types": True,
+    "max_len": 32,
+    "indent": "  ",
+    "max_depth": 3,
+}
 
-DBG_LIST_DEFAULTS: typing.Dict[str, typing.Union[bool, int, str]] = dict(
-    max_len=16,
-    summary_show_types=True,
-)
+DBG_LIST_DEFAULTS: typing.Dict[str, typing.Union[bool, int, str]] = {
+    "max_len": 16,
+    "summary_show_types": True,
+}
 
 
 def list_info(

From bbe6ec77595b86c141cfaaf6802f31cb3151b3db Mon Sep 17 00:00:00 2001
From: Michael Ivanitskiy
Date: Thu, 30 Oct 2025 17:34:45 +0000
Subject: [PATCH 38/72] make format

---
 muutils/json_serialize/array.py | 6 +++++-
 muutils/logger/log_util.py      | 4 +++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/muutils/json_serialize/array.py b/muutils/json_serialize/array.py
index f09a65a5..00950507 100644
--- a/muutils/json_serialize/array.py
+++ b/muutils/json_serialize/array.py
@@ -53,6 +53,7 @@ def array_n_elements(arr) -> int:  # type: ignore[name-defined]
 
 class ArrayMetadata(TypedDict):
     """Metadata for a numpy/torch array"""
+
     shape: list[int]
     dtype: str
     n_elements: int
@@ -60,8 +61,11 @@ class ArrayMetadata(TypedDict):
 
 class SerializedArrayWithMeta(TypedDict):
     """Serialized array with metadata (for array_list_meta, array_hex_meta, array_b64_meta, zero_dim modes)"""
+
     __muutils_format__: str
-    data: typing.Union[NumericList, str, int, float, bool]  # list, hex str, b64 str, or scalar for zero_dim
+    data: typing.Union[
+        NumericList, str, int, float, bool
+    ]  # list, hex str, b64 str, or scalar for zero_dim
     shape: list[int]
     dtype: str
     n_elements: int

diff --git a/muutils/logger/log_util.py b/muutils/logger/log_util.py
index 94e8f949..80ded213 100644
--- a/muutils/logger/log_util.py
+++ b/muutils/logger/log_util.py
@@ -5,7 +5,9 @@
 T_StreamValue = TypeVar("T_StreamValue")
 
 
-def get_any_from_stream(stream: list[dict[str, T_StreamValue]], key: str) -> T_StreamValue:
+def get_any_from_stream(
+    stream: list[dict[str, T_StreamValue]], key: str
+) -> T_StreamValue:
     """get the first value of a key from a stream. errors if not found"""
     for msg in stream:
         if key in msg:

From 52ba540f9e734d332794f7802ffaf262e613d91a Mon Sep 17 00:00:00 2001
From: Michael Ivanitskiy
Date: Thu, 30 Oct 2025 18:31:25 +0000
Subject: [PATCH 39/72] json serialize related type fixes

---
 muutils/json_serialize/array.py            | 18 ++++++++---
 muutils/json_serialize/json_serialize.py   | 25 +++++++++------
 muutils/json_serialize/util.py             | 30 ++++++++++++++-----
 .../json_serialize/test_json_serialize.py  | 11 ++++++-
 4 files changed, 61 insertions(+), 23 deletions(-)

diff --git a/muutils/json_serialize/array.py b/muutils/json_serialize/array.py
index 00950507..aaa80faf 100644
--- a/muutils/json_serialize/array.py
+++ b/muutils/json_serialize/array.py
@@ -134,12 +134,16 @@ def serialize_array(
     # For all other modes, compute metadata once
     metadata: ArrayMetadata = arr_metadata(arr if len(arr.shape) == 0 else arr_np)
 
+    # TYPING: ty<=0.0.1a24 does not appear to support unpacking TypedDicts, so we do things manually. change it back later maybe?
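+    # (hypothetical sketch, not part of this commit: once TypedDict unpacking is
+    #  accepted by all the checkers we run, each branch below could collapse back
+    #  to `return SerializedArrayWithMeta(__muutils_format__=..., data=..., **metadata)`)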
+
     # handle zero-dimensional arrays
     if len(arr.shape) == 0:
         return SerializedArrayWithMeta(
             __muutils_format__=f"{arr_type}:zero_dim",
             data=arr.item(),
-            **metadata,
+            shape=metadata["shape"],
+            dtype=metadata["dtype"],
+            n_elements=metadata["n_elements"],
         )
 
     # Handle the metadata modes
     if array_mode == "array_list_meta":
         return SerializedArrayWithMeta(
             __muutils_format__=f"{arr_type}:array_list_meta",
             data=arr_np.tolist(),
-            **metadata,
+            shape=metadata["shape"],
+            dtype=metadata["dtype"],
+            n_elements=metadata["n_elements"],
         )
     elif array_mode == "array_hex_meta":
         return SerializedArrayWithMeta(
             __muutils_format__=f"{arr_type}:array_hex_meta",
             data=arr_np.tobytes().hex(),
-            **metadata,
+            shape=metadata["shape"],
+            dtype=metadata["dtype"],
+            n_elements=metadata["n_elements"],
         )
     elif array_mode == "array_b64_meta":
         return SerializedArrayWithMeta(
             __muutils_format__=f"{arr_type}:array_b64_meta",
             data=base64.b64encode(arr_np.tobytes()).decode(),
-            **metadata,
+            shape=metadata["shape"],
+            dtype=metadata["dtype"],
+            n_elements=metadata["n_elements"],
         )
     else:
         raise KeyError(f"invalid array_mode: {array_mode}")

diff --git a/muutils/json_serialize/json_serialize.py b/muutils/json_serialize/json_serialize.py
index 4db08fab..a2b589da 100644
--- a/muutils/json_serialize/json_serialize.py
+++ b/muutils/json_serialize/json_serialize.py
@@ -21,8 +21,9 @@
 try:
     from muutils.json_serialize.array import ArrayMode, serialize_array
 except ImportError as e:
+    # TYPING: obviously, these types are all wrong if we can't import array.py
     ArrayMode = str  # type: ignore[misc]
-    serialize_array = lambda *args, **kwargs: None  # noqa: E731
+    serialize_array = lambda *args, **kwargs: None  # type: ignore[assignment, invalid-assignment] # noqa: E731 # pyright: ignore[reportUnknownVariableType, reportUnknownLambdaType]
     warnings.warn(
         f"muutils.json_serialize.array could not be imported probably because missing numpy, array serialization will not work: \n{e}",
         ImportWarning,
@@ -196,7 +197,9 @@ def _serialize_override_serialize_func(
     SerializerHandler(
         check=lambda self, obj, path: str(type(obj)) == "<class 'torch.Tensor'>",
         serialize_func=lambda self, obj, path: serialize_array(
-            self, obj.detach().cpu(), path=path
+            self,
+            obj.detach().cpu(),
+            path=path,  # pyright: ignore[reportAny]
         ),
         uid="torch.Tensor",
         desc="pytorch tensors",
@@ -205,11 +208,12 @@
     SerializerHandler(
         check=lambda self, obj, path: (
            str(type(obj)) == "<class 'pandas.core.frame.DataFrame'>"
         ),
-        serialize_func=lambda self, obj, path: {
-            _FORMAT_KEY: "pandas.DataFrame",
-            "columns": obj.columns.tolist(),
-            "data": obj.to_dict(orient="records"),
-            "path": path,  # type: ignore
+        # TYPING: type checkers have no idea that obj is a DataFrame here
+        serialize_func=lambda self, obj, path: {  # pyright: ignore[reportArgumentType, reportAny]
+            _FORMAT_KEY: "pandas.DataFrame",  # type: ignore[misc]
+            "columns": obj.columns.tolist(),  # pyright: ignore[reportAny]
+            "data": obj.to_dict(orient="records"),  # pyright: ignore[reportAny]
+            "path": path,
         },
         uid="pandas.DataFrame",
         desc="pandas DataFrames",
@@ -217,7 +221,7 @@
     SerializerHandler(
         check=lambda self, obj, path: isinstance(obj, (set, frozenset)),
         serialize_func=lambda self, obj, path: {
-            _FORMAT_KEY: "set" if isinstance(obj, set) else "frozenset",
+            _FORMAT_KEY: "set" if isinstance(obj, set) else "frozenset",  # type: ignore[misc]
             "data": [
                 self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj)
             ],
@@ -306,8 +310,9 @@ def json_serialize(
             output: JSONitem = handler.serialize_func(self, obj, path)
             if self.write_only_format:
                 if isinstance(output, dict) and _FORMAT_KEY in output:
-                    new_fmt: JSONitem = output.pop(_FORMAT_KEY)
-                    output["__write_format__"] = new_fmt
+                    # TYPING: JSONitem has no idea that _FORMAT_KEY is str
+                    new_fmt: str = output.pop(_FORMAT_KEY)  # type: ignore
+                    output["__write_format__"] = new_fmt  # type: ignore
             return output
 
     raise ValueError(f"no handler found for object with {type(obj) = }")

diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py
index c2309006..8dc1f7e6 100644
--- a/muutils/json_serialize/util.py
+++ b/muutils/json_serialize/util.py
@@ -8,7 +8,7 @@
 import sys
 import typing
 import warnings
-from typing import Any, Callable, Iterable, Union
+from typing import TYPE_CHECKING, Any, Callable, Final, Iterable, Union
 
 _NUMPY_WORKING: bool
 try:
@@ -26,18 +26,32 @@
     None,
 ]
 
-JSONitem = Union[
-    BaseType,
-    typing.List["JSONitem"],
-    typing.Dict[str, "JSONitem"],
-]
+# At type-checking time, include array serialization types to avoid nominal type errors
+# This avoids runtime circular imports since array.py imports from util.py
+if TYPE_CHECKING:
+    from muutils.json_serialize.array import NumericList, SerializedArrayWithMeta
+
+    JSONitem = Union[
+        BaseType,
+        typing.Sequence["JSONitem"],
+        typing.Dict[str, "JSONitem"],
+        SerializedArrayWithMeta,
+        NumericList,
+    ]
+else:
+    JSONitem = Union[
+        BaseType,
+        typing.Sequence["JSONitem"],
+        typing.Dict[str, "JSONitem"],
+    ]
+
 JSONdict = typing.Dict[str, JSONitem]
 
 Hashableitem = Union[bool, int, float, str, tuple]
 
-_FORMAT_KEY: str = "__muutils_format__"
-_REF_KEY: str = "$ref"
+_FORMAT_KEY: Final[str] = "__muutils_format__"
+_REF_KEY: Final[str] = "$ref"
 
 # or if python version <3.9
 if typing.TYPE_CHECKING or sys.version_info < (3, 9):

diff --git a/tests/unit/json_serialize/test_json_serialize.py b/tests/unit/json_serialize/test_json_serialize.py
index 8e016c75..71ede73d 100644
--- a/tests/unit/json_serialize/test_json_serialize.py
+++ b/tests/unit/json_serialize/test_json_serialize.py
@@ -123,6 +123,7 @@ def test_json_serialize_serialize_method():
 
     obj = ClassWithSerialize(value=5)
     result = serializer.json_serialize(obj)
+    assert isinstance(result, dict)
 
     # Should use the custom serialize method
     assert result == {"custom_value": 10, "custom_name": "TEST"}
@@ -145,6 +146,7 @@ def serialize(self) -> dict:
 
     obj = DataclassWithSerialize(x=3, y=7)
     result = serializer.json_serialize(obj)
+    assert isinstance(result, dict)
 
     # Should use custom serialize, not dataclass handler
     assert result == {"sum": 10}
@@ -263,6 +265,7 @@ def test_DEFAULT_HANDLERS():
     result = serializer.json_serialize({1, 2, 3})
     assert isinstance(result, dict)
     assert result[_FORMAT_KEY] == "set"
+    assert isinstance(result["data"], list)
     assert set(result["data"]) == {1, 2, 3}
 
     # Test tuple (should become list)
@@ -493,12 +496,14 @@ def format_serialize(self, obj, path):
     # Without write_only_format
     serializer1 = JsonSerializer(handlers_pre=(format_handler,))
     result1 = serializer1.json_serialize("FORMAT:test")
+    assert isinstance(result1, dict)
     assert _FORMAT_KEY in result1
     assert result1[_FORMAT_KEY] == "custom_format"
 
     # With write_only_format
     serializer2 = JsonSerializer(handlers_pre=(format_handler,), write_only_format=True)
     result2 = serializer2.json_serialize("FORMAT:test")
+    assert isinstance(result2, dict)
     assert _FORMAT_KEY not in result2
     assert "__write_format__" in result2
     assert result2["__write_format__"] == "custom_format"
@@ -653,7 +658,7 @@ def test_JsonSerializer_init_custom_values():
 
     serializer = JsonSerializer(
         array_mode="list",
-        error_mode="warn",
+        error_mode=ErrorMode.WARN,
         handlers_pre=(custom_handler,),
         handlers_default=BASE_HANDLERS,
         write_only_format=True,
@@ -706,6 +711,7 @@ def test_large_nested_structure():
     # Create large nested list
     large = [[i, i * 2, i * 3] for i in range(100)]
     result = serializer.json_serialize(large)
+    assert isinstance(result, list)
     assert len(result) == 100
     assert result[0] == [0, 0, 0]
     assert result[99] == [99, 198, 297]
@@ -720,6 +726,7 @@ def test_mixed_container_types():
     assert isinstance(result, dict)
     assert _FORMAT_KEY in result
     assert result[_FORMAT_KEY] == "set"
+    assert isinstance(result["data"], list)
     assert set(result["data"]) == {1, 2, 3}
 
     # Frozenset - serialized with format key
@@ -727,6 +734,7 @@
     assert isinstance(result, dict)
     assert _FORMAT_KEY in result
     assert result[_FORMAT_KEY] == "frozenset"
+    assert isinstance(result["data"], list)
     assert set(result["data"]) == {4, 5, 6}
 
     # Generator (Iterable) - serialized as list
@@ -741,5 +749,6 @@ def test_string_keys_in_dict():
 
     # Integer keys should be converted to strings
     result = serializer.json_serialize({1: "a", 2: "b", 3: "c"})
+    assert isinstance(result, dict)
     assert result == {"1": "a", "2": "b", "3": "c"}
     assert all(isinstance(k, str) for k in result.keys())

From 2c3e5daf0dd9ade589c2d209fe0ea7f0fa868404 Mon Sep 17 00:00:00 2001
From: Michael Ivanitskiy
Date: Thu, 30 Oct 2025 18:31:42 +0000
Subject: [PATCH 40/72] typed dicts for dbg config

---
 muutils/dbg.py | 72 ++++++++++++++++++++++++++++++++++----------------
 1 file changed, 49 insertions(+), 23 deletions(-)

diff --git a/muutils/dbg.py b/muutils/dbg.py
index c0cad685..89363398 100644
--- a/muutils/dbg.py
+++ b/muutils/dbg.py
@@ -44,6 +44,35 @@
 _ExpType_list = typing.TypeVar("_ExpType_list", bound=typing.List[typing.Any])
 
 
+# TypedDict definitions for configuration dictionaries
+class DBGDictDefaultsType(typing.TypedDict):
+    key_types: bool
+    val_types: bool
+    max_len: int
+    indent: str
+    max_depth: int
+
+
+class DBGListDefaultsType(typing.TypedDict):
+    max_len: int
+    summary_show_types: bool
+
+
+class DBGTensorArraySummaryDefaultsType(typing.TypedDict):
+    fmt: str
+    precision: int
+    stats: bool
+    shape: bool
+    dtype: bool
+    device: bool
+    requires_grad: bool
+    sparkline: bool
+    sparkline_bins: int
+    sparkline_logy: typing.Union[None, bool]
+    colored: bool
+    eq_char: str
+
+
 # Sentinel type for no expression passed
 class _NoExpPassedSentinel:
     """Unique sentinel type used to indicate that no expression was passed."""
@@ -188,22 +217,20 @@ def square(x: int) -> int:
 
 # formatted `dbg_*` functions with their helpers
 
-DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: typing.Dict[
-    str, typing.Union[None, bool, int, str]
-] = dict(
-    fmt="unicode",
-    precision=2,
-    stats=True,
-    shape=True,
-    dtype=True,
-    device=True,
-    requires_grad=True,
-    sparkline=True,
-    sparkline_bins=7,
-    sparkline_logy=None,  # None means auto-detect
-    colored=True,
-    eq_char="=",
-)
+DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: DBGTensorArraySummaryDefaultsType = {
+    "fmt": "unicode",
+    "precision": 2,
+    "stats": True,
+    "shape": True,
+    "dtype": True,
+    "device": True,
+    "requires_grad": True,
+    "sparkline": True,
+    "sparkline_bins": 7,
+    "sparkline_logy": None,  # None means auto-detect
+    "colored": True,
+    "eq_char": "=",
+}
 
 DBG_TENSOR_VAL_JOINER: str = ": "
 
@@ -215,7 +242,7 @@ def tensor_info(tensor: typing.Any) -> str:
     return array_summary(tensor, as_list=False, **DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS)
 
 
-DBG_DICT_DEFAULTS: typing.Dict[str, typing.Union[bool, int, str]] = {
+DBG_DICT_DEFAULTS: DBGDictDefaultsType = {
     "key_types": True,
     "val_types": True,
     "max_len": 32,
     "indent": "  ",
     "max_depth": 3,
 }
 
-DBG_LIST_DEFAULTS: typing.Dict[str, typing.Union[bool, int, str]] = {
+DBG_LIST_DEFAULTS: DBGListDefaultsType = {
     "max_len": 16,
     "summary_show_types": True,
 }
@@ -234,8 +261,7 @@ def list_info(
 ) -> str:
     len_l: int = len(lst)
     output: str
-    # TYPING: make `DBG_LIST_DEFAULTS` and the others typed dicts
-    if len_l > DBG_LIST_DEFAULTS["max_len"]:  # type: ignore[operator]
+    if len_l > DBG_LIST_DEFAULTS["max_len"]:
        output = f" str:
     len_d: int = len(d)
-    indent: str = DBG_DICT_DEFAULTS["indent"]  # type: ignore[assignment]
+    indent: str = DBG_DICT_DEFAULTS["indent"]
     # summary line
     output: str = f"{indent * depth} 0 and len_d < DBG_DICT_DEFAULTS["max_len"]:  # type: ignore[operator]
+    if depth < DBG_DICT_DEFAULTS["max_depth"]:
+        if len_d > 0 and len_d < DBG_DICT_DEFAULTS["max_len"]:
         for k, v in d.items():
             key_str: str = repr(k) if not isinstance(k, str) else k

From cabe36b49f1298c117a1b8b351999b7b2dc70692 Mon Sep 17 00:00:00 2001
From: Michael Ivanitskiy
Date: Thu, 30 Oct 2025 18:52:45 +0000
Subject: [PATCH 41/72] `ty check muutils` passes !!!

---
 muutils/interval.py          |  4 +++-
 muutils/kappa.py             |  6 ++++--
 muutils/misc/classes.py      | 18 +++++++++++-------
 muutils/misc/func.py         |  3 ++-
 muutils/mlutils.py           |  3 ++-
 muutils/nbutils/print_tex.py | 12 ++++++------
 muutils/parallel.py          |  3 ++-
 muutils/tensor_info.py       |  5 ++++-
 muutils/timeit_fancy.py      | 20 +++++++++++++------
 muutils/validate_type.py     |  6 +++---
 10 files changed, 50 insertions(+), 30 deletions(-)

diff --git a/muutils/interval.py b/muutils/interval.py
index 8b1addb3..fc1d4a00 100644
--- a/muutils/interval.py
+++ b/muutils/interval.py
@@ -115,7 +115,9 @@ def __init__(
         )
 
         # Ensure lower bound is less than upper bound
-        if self.lower > self.upper:
+        # TYPING: ty throws a @Todo here
+        # Operator `>` is not supported for types `Sequence[@Todo]` and `Sequence[@Todo]`, in comparing `@Todo | Sequence[@Todo]` with `@Todo | Sequence[@Todo]` (ty: unsupported-operator)
+        if self.lower > self.upper:  # type: ignore[unsupported-operator]
             raise ValueError("Lower bound must be less than upper bound")
 
         if math.isnan(self.lower) or math.isnan(self.upper):

diff --git a/muutils/kappa.py b/muutils/kappa.py
index 819f1ad6..e964b4ff 100644
--- a/muutils/kappa.py
+++ b/muutils/kappa.py
@@ -7,14 +7,16 @@
 
 from __future__ import annotations
 
-from typing import Callable, Mapping, TypeVar
+from typing import Callable, Final, Mapping, TypeVar
 
 _kappa_K = TypeVar("_kappa_K")
 _kappa_V = TypeVar("_kappa_V")
 
 # get the docstring of this file
-_BASE_DOC: str = (
+_BASE_DOC: Final[str] = (
+    # TYPING: type checkers complain here, they have no idea that this module does in fact have a __doc__
     __doc__
+    or "anonymous getitem class"
     + """
 
 source function docstring:

diff --git a/muutils/misc/classes.py b/muutils/misc/classes.py
index 7874a7c8..ef1d5917 100644
--- a/muutils/misc/classes.py
+++ b/muutils/misc/classes.py
@@ -17,7 +17,7 @@ def is_abstract(cls: type) -> bool:
     """
     if not hasattr(cls, "__abstractmethods__"):
         return False  # an ordinary class
-    elif len(cls.__abstractmethods__) == 0:
+    elif len(cls.__abstractmethods__) == 0:  # type: ignore[invalid-argument-type] # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
         return False  # a concrete implementation of an abstract class
     else:
         return True  # an abstract class
@@ -69,18 +69,22 @@ def isinstance_by_type_name(o: object, type_name: str):
 class IsDataclass(Protocol):
     # Generic type for any dataclass instance
     # https://stackoverflow.com/questions/54668000/type-hint-for-an-instance-of-a-non-specific-dataclass
-    __dataclass_fields__: ClassVar[dict[str, Any]]
+    __dataclass_fields__: ClassVar[dict[str, Any]]  # pyright: ignore[reportExplicitAny]
 
 
-def get_hashable_eq_attrs(dc: IsDataclass) -> tuple[Any]:
+def get_hashable_eq_attrs(dc: IsDataclass) -> tuple[Any]:  # pyright: ignore[reportExplicitAny]
     """Returns a tuple of all fields used for equality comparison, including the type of the dataclass itself.
 
     The type is included to preserve the unequal equality behavior of instances of different dataclasses whose fields are identical.
     Essentially used to generate a hashable dataclass representation for equality comparison even if it's not frozen.
     """
-    return *(
-        getattr(dc, fld.name)
-        for fld in filter(lambda x: x.compare, dc.__dataclass_fields__.values())
-    ), type(dc)
+    # TYPING: ty gives @Todo here
+    return (  # type: ignore[invalid-return-type]
+        *(
+            getattr(dc, fld.name)
+            for fld in filter(lambda x: x.compare, dc.__dataclass_fields__.values())
+        ),
+        type(dc),
+    )
 
 
 def dataclass_set_equals(

diff --git a/muutils/misc/func.py b/muutils/misc/func.py
index 3aa09299..5db91583 100644
--- a/muutils/misc/func.py
+++ b/muutils/misc/func.py
@@ -254,7 +254,8 @@ def typed_lambda(
     # Raises:
     - `ValueError` if the number of input types doesn't match the lambda's parameters.
     """
-    code: CodeType = fn.__code__
+    # it will just error here if fn.__code__ doesn't exist
+    code: CodeType = fn.__code__  # type: ignore[unresolved-attribute]
     n_params: int = code.co_argcount
 
     if len(in_types) != n_params:

diff --git a/muutils/mlutils.py b/muutils/mlutils.py
index 6c1abcf5..628d0e94 100644
--- a/muutils/mlutils.py
+++ b/muutils/mlutils.py
@@ -153,7 +153,8 @@ def decorator(method: F) -> F:
             method_name = method_name_orig
         else:
             method_name = custom_name
-            method.__name__ = custom_name
+            # TYPING: ty complains here
+            method.__name__ = custom_name  # type: ignore[unresolved-attribute]
         assert method_name not in method_dict, (
             f"Method name already exists in method_dict: {method_name = }, {list(method_dict.keys()) = }"
         )

diff --git a/muutils/nbutils/print_tex.py b/muutils/nbutils/print_tex.py
index ec9e44aa..215772aa 100644
--- a/muutils/nbutils/print_tex.py
+++ b/muutils/nbutils/print_tex.py
@@ -1,21 +1,21 @@
 """quickly print a sympy expression in latex"""
 
-import sympy as sp  # type: ignore
-from IPython.display import Math, display  # type: ignore
+import sympy as sp  # type: ignore # pyright: ignore[reportMissingTypeStubs]
+from IPython.display import Math, display  # type: ignore # pyright: ignore[reportUnknownVariableType]
 
 
 def print_tex(
-    expr: sp.Expr,
+    expr: sp.Expr,  # type: ignore
     name: str | None = None,
     plain: bool = False,
    rendered: bool = True,
 ):
     """function for easily rendering a sympy expression in latex"""
-    out: str = sp.latex(expr)
+    out: str = sp.latex(expr)  # pyright: ignore[reportUnknownVariableType]
     if name is not None:
         out = f"{name} = {out}"
 
     if plain:
-        print(out)
+        print(out)  # pyright: ignore[reportUnknownArgumentType]
     if rendered:
-        display(Math(out))
+        display(Math(out))  # pyright: ignore[reportUnusedCallResult]

diff --git a/muutils/parallel.py b/muutils/parallel.py
index 4476edfd..f5189a15 100644
--- a/muutils/parallel.py
+++ b/muutils/parallel.py
@@ -238,7 +238,8 @@ def run_maybe_parallel(
     ]
     if parallel:
         # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
-        pool = mp.Pool(num_processes)
+        # TYPING: messy here
+        pool = mp.Pool(num_processes)  # type: ignore[possibly-missing-attribute] # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType, reportUnknownVariableType]
 
         # use `imap` if we want to keep the order, otherwise use `imap_unordered`
         if keep_ordered:

diff --git a/muutils/tensor_info.py b/muutils/tensor_info.py
index bd8b1a1b..edd7ad77 100644
--- a/muutils/tensor_info.py
+++ b/muutils/tensor_info.py
@@ -275,9 +275,12 @@ def array_info(
     return result
 
 
+SparklineFormat = Literal["unicode", "latex", "ascii"]
+
+
 def generate_sparkline(
     histogram: np.ndarray,
-    format: Literal["unicode", "latex", "ascii"] = "unicode",
+    format: SparklineFormat = "unicode",
     log_y: Optional[bool] = None,
 ) -> tuple[str, bool]:
     """Generate a sparkline visualization of the histogram.

diff --git a/muutils/timeit_fancy.py b/muutils/timeit_fancy.py
index 0354ff19..30dafc3e 100644
--- a/muutils/timeit_fancy.py
+++ b/muutils/timeit_fancy.py
@@ -17,15 +17,15 @@ class FancyTimeitResult(NamedTuple):
     """return type of `timeit_fancy`"""
 
     timings: StatCounter
-    return_value: T_return  # type: ignore[valid-type]
+    return_value: T_return  # type: ignore[valid-type] # pyright: ignore[reportGeneralTypeIssues]
     profile: Union[pstats.Stats, None]
 
 
 def timeit_fancy(
     cmd: Union[Callable[[], T_return], str],
-    setup: Union[str, Callable[[], Any]] = lambda: None,
+    setup: Union[str, Callable[[], Any]] = lambda: None,  # pyright: ignore[reportExplicitAny]
     repeats: int = 5,
-    namespace: Union[dict[str, Any], None] = None,
+    namespace: Union[dict[str, Any], None] = None,  # pyright: ignore[reportExplicitAny]
     get_return: bool = True,
     do_profiling: bool = False,
 ) -> FancyTimeitResult:
@@ -75,14 +75,16 @@
     return_value: T_return | None = None
     if (get_return or do_profiling) and isinstance(cmd, str):
         warnings.warn(
-            "Can't do profiling or get return value from `cmd` because it is a string."
-            " If you want to get the return value, pass a callable instead.",
+            (
+                "Can't do profiling or get return value from `cmd` because it is a string."
+                + " If you want to get the return value, pass a callable instead."
+            ),
             UserWarning,
         )
     if (get_return or do_profiling) and not isinstance(cmd, str):
         # Optionally perform profiling
         if do_profiling:
-            profiler = cProfile.Profile()
+            profiler: cProfile.Profile = cProfile.Profile()
             profiler.enable()
 
         try:
@@ -93,6 +95,8 @@
         )
 
         if do_profiling:
+            # profiler is def bound here
+            assert isinstance(profiler, cProfile.Profile)  # pyright: ignore[reportPossiblyUnboundVariable]
             profiler.disable()
             profile = pstats.Stats(profiler).strip_dirs().sort_stats("cumulative")
 
@@ -102,6 +106,8 @@
     return FancyTimeitResult(
         timings=StatCounter(times),
-        return_value=return_value,
+        # TYPING: Argument is incorrect: Expected `typing.TypeVar`, found `None | @Todo` (ty: invalid-argument-type)
+        # no idea how to fix
+        return_value=return_value,  # type: ignore[invalid-argument-type]
         profile=profile,
     )

diff --git a/muutils/validate_type.py b/muutils/validate_type.py
index bb4b2361..d06834e0 100644
--- a/muutils/validate_type.py
+++ b/muutils/validate_type.py
@@ -230,9 +230,9 @@ def get_fn_allowed_kwargs(fn: typing.Callable) -> typing.Set[str]:
         fn = unwrap(fn)
         params = signature(fn).parameters
     except ValueError as e:
-        raise ValueError(
-            f"Cannot retrieve signature for {fn.__name__ = } {fn = }: {str(e)}"
-        ) from e
+        fn_name: str = getattr(fn, "__name__", str(fn))
+        err_msg = f"Cannot retrieve signature for {fn_name = } {fn = }: {str(e)}"
+        raise ValueError(err_msg) from e
 
     return {
         param.name

From 35b7435de92a603ffd41c2a86fd1de43038df70a Mon Sep 17 00:00:00 2001
From: Michael Ivanitskiy
Date: Thu, 30 Oct 2025 21:31:16 +0000
Subject: [PATCH 42/72] wip

---
 muutils/json_serialize/array.py              | 31 +++++++++++--------
 .../json_serialize/serializable_dataclass.py | 19 ++++++------
 muutils/json_serialize/util.py               | 28 ++++++++++-------
 3 files changed, 44 insertions(+), 34 deletions(-)

diff --git a/muutils/json_serialize/array.py b/muutils/json_serialize/array.py
index aaa80faf..1c5f53c3 100644
--- a/muutils/json_serialize/array.py
+++ b/muutils/json_serialize/array.py
@@ -1,4 +1,4 @@
-"""this utilities module handles serialization and loading of numpy and torch arrays as json 
+"""this utilities module handles serialization and loading of numpy and torch arrays as json
 
 - `array_list_meta` is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability.
 - `array_b64_meta` is the most efficient, but is not human readable.
@@ -11,7 +11,7 @@
 import base64
 import typing
 import warnings
-from typing import Any, Iterable, Literal, Optional, Sequence, TypedDict
+from typing import TYPE_CHECKING, Any, Iterable, Literal, Optional, Sequence, TypedDict
 
 try:
     import numpy as np
@@ -21,9 +21,13 @@
         ImportWarning,
     )
 
+if TYPE_CHECKING:
+    import numpy as np
+
 from muutils.json_serialize.util import _FORMAT_KEY, JSONitem
 
-# pylint: disable=unused-argument
+# TYPING: pyright complains way too much here
+# pyright: reportCallIssue=false,reportArgumentType=false,reportUnknownVariableType=false,reportUnknownMemberType=false
 
 # Recursive type for nested numeric lists (output of arr.tolist())
 NumericList = typing.Union[
@@ -181,7 +185,8 @@ def infer_array_mode(arr: JSONitem) -> ArrayMode:
     assumes the array was serialized via `serialize_array()`
     """
     if isinstance(arr, typing.Mapping):
-        fmt: str = arr.get(_FORMAT_KEY, "")  # type: ignore
+        # _FORMAT_KEY always maps to a string
+        fmt: str = arr.get(_FORMAT_KEY, "")  # type: ignore # pyright: ignore[reportAssignmentType]
         if fmt.endswith(":array_list_meta"):
             if not isinstance(arr["data"], Iterable):
                 raise ValueError(f"invalid list format: {type(arr['data']) = }\t{arr}")
@@ -206,7 +211,7 @@
     raise ValueError(f"cannot infer array_mode from\t{type(arr) = }\n{arr = }")
 
 
-def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any:
+def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any:  # pyright: ignore[reportExplicitAny, reportAny]
     """load a json-serialized array, infer the mode if not specified"""
     # return arr if its already a numpy array
     if isinstance(arr, np.ndarray) and array_mode is None:
@@ -226,24 +231,24 @@
         assert isinstance(arr, typing.Mapping), (
             f"invalid list format: {type(arr) = }\n{arr = }"
         )
-        data = np.array(arr["data"], dtype=arr["dtype"])  # type: ignore
-        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
+        data = np.array(arr["data"], dtype=arr["dtype"])  # type: ignore
+        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
             raise ValueError(f"invalid shape: {arr}")
-        return data
+        return data
 
     elif array_mode == "array_hex_meta":
         assert isinstance(arr, typing.Mapping), (
             f"invalid list format: {type(arr) = }\n{arr = }"
         )
-        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])  # type: ignore
-        return data.reshape(arr["shape"])  # type: ignore
+        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])  # type: ignore
+        return data.reshape(arr["shape"])  # type: ignore
 
     elif array_mode == "array_b64_meta":
         assert isinstance(arr, typing.Mapping), (
             f"invalid list format: {type(arr) = }\n{arr = }"
        )
-        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])  # type: ignore
-        return data.reshape(arr["shape"])  # type: ignore
+        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])  # type: ignore
+        return data.reshape(arr["shape"])  # type: ignore
 
     elif array_mode == "list":
         assert isinstance(arr, typing.Sequence), (
@@ -265,4 +270,4 @@
             raise ValueError(f"invalid shape: {arr}")
         return data
     else:
-        raise ValueError(f"invalid array_mode: {array_mode}")
+        raise ValueError(f"invalid array_mode: {array_mode}")  # pyright: ignore[reportUnreachable]

diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py
index 514a7b58..bfc443ce 100644
--- a/muutils/json_serialize/serializable_dataclass.py
+++ b/muutils/json_serialize/serializable_dataclass.py
@@ -72,14 +72,13 @@ class NestedClass(SerializableDataclass):
 
 # this is quite horrible, but unfortunately mypy fails if we try to assign to `dataclass_transform` directly
 # and every time we try to init a serializable dataclass it says the argument doesnt exist
-try:
-    try:
-        # type ignore here for legacy versions
-        from typing import dataclass_transform  # type: ignore[attr-defined]
-    except Exception:
+if sys.version_info >= (3, 11):
+    from typing import dataclass_transform
+else:
+    try:  # pyright: ignore[reportUnreachable]
         from typing_extensions import dataclass_transform
-except Exception:
-    from muutils.json_serialize.dataclass_transform_mock import dataclass_transform
+    except Exception:
+        from muutils.json_serialize.dataclass_transform_mock import dataclass_transform
 
 T = TypeVar("T")
@@ -112,9 +111,9 @@ def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
 
     if _zanj_loading_needs_import:
         try:
-            from zanj.loading import (  # type: ignore[import]
-                LoaderHandler,
-                register_loader_handler,
+            from zanj.loading import (  # type: ignore[import] # pyright: ignore[reportMissingImports]
+                LoaderHandler,  # pyright: ignore[reportUnknownVariableType]
+                register_loader_handler,  # pyright: ignore[reportUnknownVariableType]
             )
         except ImportError:
             # NOTE: if ZANJ is not installed, then failing to register the loader handler doesnt matter

diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py
index 8dc1f7e6..2d84face 100644
--- a/muutils/json_serialize/util.py
+++ b/muutils/json_serialize/util.py
@@ -18,6 +18,9 @@
     _NUMPY_WORKING = False
 
 
+# pyright: reportExplicitAny=false
+
+
 BaseType = Union[
     bool,
     int,
@@ -47,18 +50,20 @@
 
 JSONdict = typing.Dict[str, JSONitem]
 
-Hashableitem = Union[bool, int, float, str, tuple]
+Hashableitem = Union[bool, int, float, str, tuple]  # pyright: ignore[reportMissingTypeArgument]
 
 _FORMAT_KEY: Final[str] = "__muutils_format__"
 _REF_KEY: Final[str] = "$ref"
 
+
+# TODO: this bit is very broken
 # or if python version <3.9
 if typing.TYPE_CHECKING or sys.version_info < (3, 9):
     MonoTuple = typing.Sequence
 else:
 
-    class MonoTuple:
+    class MonoTuple:  # pyright: ignore[reportUnreachable]
         """tuple type hint, but for a tuple of any length with all the same type"""
 
         __slots__ = ()
@@ -86,20 +91,21 @@ def __class_getitem__(cls, params):
             raise TypeError(f"MonoTuple expects 1 type argument, got {params = }")
 
 
+# TYPING: we allow `Any` here because the container is... universal
 class UniversalContainer:
     """contains everything -- `x in UniversalContainer()` is always True"""
 
-    def __contains__(self, x: Any) -> bool:
+    def __contains__(self, x: Any) -> bool:  # pyright: ignore[reportAny]
         return True
 
 
-def isinstance_namedtuple(x: Any) -> bool:
+def isinstance_namedtuple(x: Any) -> bool:  # pyright: ignore[reportAny]
    """checks if `x` is a `namedtuple`
 
     credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
     """
-    t: type = type(x)
-    b: tuple = t.__bases__
+    t: type = type(x)  # pyright: ignore[reportUnknownVariableType, reportAny]
+    b: tuple[type, ...] = t.__bases__
     if len(b) != 1 or (b[0] is not tuple):
         return False
     f: Any = getattr(t, "_fields", None)
@@ -108,7 +114,7 @@ def isinstance_namedtuple(x: Any) -> bool:
     return all(isinstance(n, str) for n in f)
 
 
-def try_catch(func: Callable):
+def try_catch(func: Callable[[Any], Any]):
     """wraps the function to catch exceptions, returns serialized error message on exception
 
     returned func will return normal result on success, or error message on exception
@@ -270,15 +276,15 @@ def dc_eq(
     if dc1 is dc2:
         return True
 
-    if dc1.__class__ is not dc2.__class__:
+    if dc1.__class__ is not dc2.__class__:  # pyright: ignore[reportUnknownMemberType]
         if except_when_class_mismatch:
             # if the classes don't match, raise an error
             raise TypeError(
-                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
+                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"  # pyright: ignore[reportUnknownMemberType]
             )
         if except_when_field_mismatch:
-            dc1_fields: set[str] = set([fld.name for fld in dataclasses.fields(dc1)])
-            dc2_fields: set[str] = set([fld.name for fld in dataclasses.fields(dc2)])
+            dc1_fields: set[str] = set([fld.name for fld in dataclasses.fields(dc1)])  # pyright: ignore[reportUnknownArgumentType]
+            dc2_fields: set[str] = set([fld.name for fld in dataclasses.fields(dc2)])  # pyright: ignore[reportUnknownArgumentType]
             fields_match: bool = set(dc1_fields) == set(dc2_fields)
             if not fields_match:
                 # if the fields match, keep going

From a760fac2734c011fec9305e20da6d8d10d2b4aa2 Mon Sep 17 00:00:00 2001
From: Michael Ivanitskiy
Date: Mon, 3 Nov 2025 10:26:58 +0000
Subject: [PATCH 43/72] typing ignores?

---
 muutils/json_serialize/serializable_dataclass.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py
index bfc443ce..e8e2da74 100644
--- a/muutils/json_serialize/serializable_dataclass.py
+++ b/muutils/json_serialize/serializable_dataclass.py
@@ -66,7 +66,7 @@ class NestedClass(SerializableDataclass):
     SerializableField,
     serializable_field,
 )
-from muutils.json_serialize.util import _FORMAT_KEY, array_safe_eq, dc_eq
+from muutils.json_serialize.util import _FORMAT_KEY, JSONdict, JSONitem, array_safe_eq, dc_eq
 
 # pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access
@@ -424,11 +424,11 @@ def diff(
 
         # if we are working with serialized data, serialize the instances
         if of_serialized:
-            ser_self: dict = self.serialize()
-            ser_other: dict = other.serialize()
+            ser_self: JSONdict = self.serialize()
+            ser_other: JSONdict = other.serialize()
 
         # for each field in the class
-        for field in dataclasses.fields(self):  # type: ignore[arg-type]
+        for field in dataclasses.fields(self):  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
             # skip fields that are not for comparison
             if not field.compare:
                 continue
@@ -454,8 +454,12 @@ def diff(
                     raise ValueError("Non-serializable dataclass is not supported")
             else:
                 # get the values of either the serialized or the actual values
-                self_value_s = ser_self[field_name] if of_serialized else self_value
-                other_value_s = ser_other[field_name] if of_serialized else other_value
+                if of_serialized:
+                    self_value_s = ser_self[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
+                    other_value_s = ser_other[field_name]  # pyright: ignore[reportPossiblyUnboundVariable, reportUnknownVariableType]
+                else:
+                    self_value_s = self_value
+                    other_value_s = other_value
 
                 # compare the values
                 if not array_safe_eq(self_value_s, other_value_s):
                     diff_result[field_name] = {"self": self_value, "other": other_value}

From 15e14a6e14de16bab00b43ffca930a2f16b1b8d4 Mon Sep 17 00:00:00 2001
From: Michael Ivanitskiy
Date: Mon, 3 Nov 2025 12:01:09 +0000
Subject: [PATCH 44/72] better typing breakdown

---
 .meta/typing-summary.txt         |  80 +++++++++
 TODO.md                          |  21 ++-
 makefile                         |   5 +-
 muutils/misc/typing_breakdown.py | 278 +++++++++++++++++++++++++++++++
 4 files changed, 376 insertions(+), 8 deletions(-)
 create mode 100644 .meta/typing-summary.txt
 create mode 100644 muutils/misc/typing_breakdown.py

diff --git a/.meta/typing-summary.txt b/.meta/typing-summary.txt
new file mode 100644
index 00000000..c2836f41
--- /dev/null
+++ b/.meta/typing-summary.txt
@@ -0,0 +1,80 @@
+# Showing top 10 errors per category
+
+# mypy: Found 127 errors in 8 files (checked 115 source files)
+# basedpyright: 803 errors, 3864 warnings, 0 notes
+# ty: Found 216 diagnostics
+
+[type_errors.mypy]
+total_errors = 127
+
+[type_errors.mypy.by_type]
+"index" = 31
+"typeddict-item" = 26
+"arg-type" = 23
+"call-overload" = 21
+"attr-defined" = 11
+"literal-required" = 10
+"var-annotated" = 2
+"operator" = 2
+"no-redef" = 1
+
+[type_errors.mypy.by_file]
+"tests/unit/json_serialize/test_array_torch.py" = 58
+"tests/unit/json_serialize/test_array.py" = 22
+"tests/unit/test_jsonlines.py" = 21
+"tests/unit/json_serialize/test_serializable_field.py" = 11
+"tests/unit/json_serialize/test_json_serialize.py" = 7
+"tests/unit/benchmark_parallel/benchmark_parallel.py" = 4
+"tests/unit/logger/test_log_util.py" = 3
+"muutils/misc/typing_breakdown.py" = 1
+
+[type_errors.basedpyright]
+total_errors_by_type = 2521
+total_errors_by_file = 1148
+
+[type_errors.basedpyright.by_type]
+"reportUnknownParameterType" = 452
+"reportMissingParameterType" = 397
+"reportAny" = 385
+"reportUnusedCallResult" = 295
+"reportUnknownVariableType" = 232
+"reportMissingTypeArgument" = 201
+"reportExplicitAny" = 183
+"reportUnknownMemberType" = 145
+"reportUnknownLambdaType" = 131
+"reportUnusedParameter" = 100
+
+[type_errors.basedpyright.by_file]
+"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_json_serialize.py" = 167
+"/home/miv/projects/tools/muutils/tests/unit/test_dbg.py" = 156
+"/home/miv/projects/tools/muutils/tests/unit/test_parallel.py" = 134
+"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 125
+"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_array_torch.py" = 107
+"/home/miv/projects/tools/muutils/muutils/json_serialize/serializable_dataclass.py" = 98
+"/home/miv/projects/tools/muutils/tests/unit/misc/test_func.py" = 98
+"/home/miv/projects/tools/muutils/muutils/dictmagic.py" = 88
+"/home/miv/projects/tools/muutils/muutils/json_serialize/serializable_field.py" = 88
+"/home/miv/projects/tools/muutils/tests/unit/validate_type/test_get_kwargs.py" = 87
+
+[type_errors.ty]
+total_errors = 216
+
+[type_errors.ty.by_type]
+"unknown-argument" = 164
+"unresolved-attribute" = 33
+"invalid-argument-type" = 8
+"invalid-assignment" = 6
+"too-many-positional-arguments" = 3
+"invalid-return-type" = 1
+"unresolved-import" = 1
+
+[type_errors.ty.by_file]
+"tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 134
+"tests/unit/json_serialize/test_serializable_field.py" = 29
+"tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 26
+"tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 9 +"tests/unit/test_dictmagic.py" = 8 +"tests/unit/benchmark_parallel/benchmark_parallel.py" = 6 +"tests/unit/json_serialize/test_array_torch.py" = 2 +"tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 +"tests/unit/test_statcounter.py" = 1 diff --git a/TODO.md b/TODO.md index 240e1ae9..95e46cfa 100644 --- a/TODO.md +++ b/TODO.md @@ -2,15 +2,26 @@ ## Instructions -1. Read the type checker output files: +1. Read the entire file `.meta/typing-summary.txt` to get an overview of the current type errors in the codebase. + +2. Read the type checker output files: - `.meta/.type-errors/mypy.txt` - `.meta/.type-errors/basedpyright.txt` - `.meta/.type-errors/ty.txt` - NOTE: the latter two files are many thousands of lines, you will have to pick the first few or last few hundred lines to get a sense of the errors. + NOTE: the files are many thousands of lines, you will have to pick a *random* few hundred lines to read. it is important that you pick a random set of lines, since you will be working in parallel with other Claude instances, and we want to avoid everyone working on the same errors. + +3. Decide on a good fix to make. For example, you might pick: + - the fix with the best **"number of errors / complexity of change" ratio** + - a fix that gets us closer to having no errors in a specific file (or group of files) + - a fix that gets us closer to removing an entire category of errors + +4. Implement that fix + +run type checking only on the specific file you are changing to verify that the errors are fixed. -2. Find the fix with the best **"number of errors / complexity of change" ratio** -3. Implement that fix +# Guidelines: -run type checking only on the specific file you are changing to verify that the errors are fixed. \ No newline at end of file +- make sure all type hints are python>=3.8 compatible +- always err on the side of STRICTER type hints! \ No newline at end of file diff --git a/makefile b/makefile index 7e90c104..ebfad4b1 100644 --- a/makefile +++ b/makefile @@ -1778,9 +1778,8 @@ typing-summary: gen-extra-tests $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_ARGS) . > .meta/.type-errors/mypy.txt 2>&1 || true $(PYTHON) -m basedpyright . > .meta/.type-errors/basedpyright.txt 2>&1 || true $(PYTHON) -m ty check . > .meta/.type-errors/ty.txt 2>&1 || true - @echo "mypy: $$(tail -n 1 .meta/.type-errors/mypy.txt)" - @echo "basedpyright: $$(tail -n 1 .meta/.type-errors/basedpyright.txt)" - @echo "ty: $$(tail -n 1 .meta/.type-errors/ty.txt)" + @echo "generating typing summary and breakdown..." + $(PYTHON) -m muutils.misc.typing_breakdown --error-dir .meta/.type-errors --output .meta/typing-summary.txt # generate summary report of type check errors grouped by file # outputs TOML format showing error count per file diff --git a/muutils/misc/typing_breakdown.py b/muutils/misc/typing_breakdown.py new file mode 100644 index 00000000..95f76b39 --- /dev/null +++ b/muutils/misc/typing_breakdown.py @@ -0,0 +1,278 @@ +"""Parse type checker outputs and generate detailed breakdown of errors by type and file. 
+
+Usage:
+	python -m muutils.misc.typing_breakdown [OPTIONS]
+
+Examples:
+	python -m muutils.misc.typing_breakdown
+	python -m muutils.misc.typing_breakdown --error-dir .meta/.type-errors
+	python -m muutils.misc.typing_breakdown --top-n 15 --output .meta/typing-summary.txt
+"""
+
+from __future__ import annotations
+
+import argparse
+import re
+import sys
+from collections import defaultdict
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Callable, Dict, Iterator, List, Literal, Tuple
+import warnings
+
+
+@dataclass
+class TypeCheckResult:
+	"results from parsing a type checker output"
+	type_checker: Literal["mypy", "basedpyright", "ty"]
+	by_type: Dict[str, int] = field(default_factory=lambda: defaultdict(int))
+	by_file: Dict[str, int] = field(default_factory=lambda: defaultdict(int))
+
+	@property
+	def total_errors(self) -> int:
+		"total number of errors across all types, validates they match between type and file dicts"
+		total_by_type: int = sum(self.by_type.values())
+		total_by_file: int = sum(self.by_file.values())
+
+		if total_by_type != total_by_file:
+			raise ValueError(
+				f"Error count mismatch for {self.type_checker}: "
+				f"by_type={total_by_type}, by_file={total_by_file}"
+			)
+
+		return total_by_type
+
+	def filter_by(self, top_n: int | None) -> TypeCheckResult:
+		"return a copy with errors sorted by count and filtered to top_n items (or all if None)"
+		# Sort by count (descending)
+		sorted_by_type: List[Tuple[str, int]] = sorted(
+			self.by_type.items(),
+			key=lambda x: x[1],
+			reverse=True,
+		)
+		sorted_by_file: List[Tuple[str, int]] = sorted(
+			self.by_file.items(),
+			key=lambda x: x[1],
+			reverse=True,
+		)
+
+		# Apply top_n limit if specified
+		if top_n is not None:
+			sorted_by_type = sorted_by_type[:top_n]
+			sorted_by_file = sorted_by_file[:top_n]
+
+		# Create new instance with filtered data (dicts maintain insertion order in Python 3.7+)
+		result: TypeCheckResult = TypeCheckResult(type_checker=self.type_checker)
+		result.by_type = dict(sorted_by_type)
+		result.by_file = dict(sorted_by_file)
+
+		return result
+
+	def to_toml(self) -> str:
+		"format as TOML-like output"
+		lines: List[str] = []
+
+		# Main section with total
+		lines.append(f"[type_errors.{self.type_checker}]")
+		try:
+			lines.append(f"total_errors = {self.total_errors}")
+		except ValueError as e:
+			lines.append(f"total_errors_by_type = {sum(self.by_type.values())}")
+			lines.append(f"total_errors_by_file = {sum(self.by_file.values())}")
+		lines.append("")
+
+		# by_type section
+		lines.append(f"[type_errors.{self.type_checker}.by_type]")
+		error_type: str
+		count: int
+		for error_type, count in self.by_type.items():
+			# Always quote keys
+			lines.append(f'"{error_type}" = {count}')
+
+		lines.append("")
+
+		# by_file section
+		lines.append(f"[type_errors.{self.type_checker}.by_file]")
+		file_path: str
+		for file_path, count in self.by_file.items():
+			# Always quote file paths
+			lines.append(f'"{file_path}" = {count}')
+
+		return "\n".join(lines)
+
+
+def parse_mypy(content: str) -> TypeCheckResult:
+	"parse mypy output: file.py:line: error: message [error-code]"
+	result: TypeCheckResult = TypeCheckResult(type_checker="mypy")
+
+	pattern: re.Pattern[str] = re.compile(r"^(.+?):\d+: error: .+ \[(.+?)\]", re.MULTILINE)
+	match: re.Match[str]
+	for match in pattern.finditer(content):
+		file_path: str = match.group(1)
+		error_code: str = match.group(2)
+		result.by_type[error_code] += 1
+		result.by_file[file_path] += 1
+
+	return result
+
+
+def parse_basedpyright(content: str) -> TypeCheckResult:
+	"parse basedpyright output: path on line, then indented errors with (code)"
+	result: TypeCheckResult = TypeCheckResult(type_checker="basedpyright")
+
+	# Pattern for file paths (lines that start with /)
+	# Pattern for errors: indented line with - error/warning: message (code)
+	current_file: str = ""
+
+	line: str
+	for line in content.splitlines():
+		# Check if this is a file path line
+		if line and not line.startswith(" ") and line.startswith("/"):
+			current_file = line.strip()
+		# Check if this is an error/warning line
+		elif line.strip() and current_file:
+			# Match pattern like: "  path:line:col - warning: message (reportCode)"
+			match: re.Match[str] | None = re.search(r"\s+.+:\d+:\d+ - (error|warning): .+ \((\w+)\)", line)
+			if match:
+				error_type: str = match.group(1)
+				error_code: str = match.group(2)
+				result.by_type[error_code] += 1
+				result.by_file[current_file] += 1
+
+	return result
+
+
+def parse_ty(content: str) -> TypeCheckResult:
+	"parse ty output: error[error-code]: message then --> file:line:col"
+	result: TypeCheckResult = TypeCheckResult(type_checker="ty")
+
+	# Pattern for error type: error[code]: or warning[code]:
+	error_pattern: re.Pattern[str] = re.compile(r"^(error|warning)\[(.+?)\]:", re.MULTILINE)
+	# Pattern for location: --> file:line:col
+	location_pattern: re.Pattern[str] = re.compile(r"^\s+-->\s+(.+?):\d+:\d+", re.MULTILINE)
+
+	# Find all errors and their locations
+	errors: List[re.Match[str]] = list(error_pattern.finditer(content))
+	locations: List[re.Match[str]] = list(location_pattern.finditer(content))
+
+	# Match errors with locations (they should be in order)
+	i: int
+	error_match: re.Match[str]
+	for i, error_match in enumerate(errors):
+		error_code: str = error_match.group(2)
+		result.by_type[error_code] += 1
+
+		# Find the next location after this error
+		error_pos: int = error_match.end()
+		loc_match: re.Match[str]
+		for loc_match in locations:
+			if loc_match.start() > error_pos:
+				file_path: str = loc_match.group(1)
+				result.by_file[file_path] += 1
+				break
+
+	return result
+
+
+def extract_summary_line(file_path: Path) -> str:
+	"extract the last non-empty line from a file (typically the summary line)"
+	content: str = file_path.read_text(encoding="utf-8")
+	lines: List[str] = [line.strip() for line in content.splitlines() if line.strip()]
+	return lines[-1]
+
+
+def main(error_dir: str, output_file: str, top_n: int | None = 10) -> None:
+	"parse all type checker outputs and generate breakdown"
+	error_path: Path = Path(error_dir)
+	output_path: Path = Path(output_file)
+
+	output_lines: List[str] = []
+
+	# Add header comment with top_n info
+	if top_n is None:
+		output_lines.append("# Showing all errors")
+	else:
+		output_lines.append(f"# Showing top {top_n} errors per category")
+	output_lines.append("")
+
+	# First, extract summary lines from each type checker
+	checkers_files: List[Tuple[str, str]] = [
+		("mypy", "mypy.txt"),
+		("basedpyright", "basedpyright.txt"),
+		("ty", "ty.txt"),
+	]
+
+	name: str
+	filename: str
+	for name, filename in checkers_files:
+		file_path: Path = error_path / filename
+		summary: str = extract_summary_line(file_path)
+		output_lines.append(f"# {name}: {summary}")
+
+	output_lines.append("")
+
+	# Parse each type checker
+	checkers: List[Tuple[str, str, Callable[[str], TypeCheckResult]]] = [
+		("mypy", "mypy.txt", parse_mypy),
+		("basedpyright", "basedpyright.txt", parse_basedpyright),
+		("ty", "ty.txt", parse_ty),
+	]
+
+	parser_fn: Callable[[str], TypeCheckResult]
+	for name, filename, parser_fn in checkers:
+		file_path: Path = error_path / filename
+		content: str = file_path.read_text(encoding="utf-8")
+		result: TypeCheckResult = parser_fn(content)
+		# Filter and sort the result
+		filtered_result: TypeCheckResult = result.filter_by(top_n)
+		# Convert to TOML
+		breakdown: str = filtered_result.to_toml()
+		output_lines.append(breakdown)
+		output_lines.append("")  # Add blank line between checkers
+
+	# Write to output file
+	final_output: str = "\n".join(output_lines)
+	output_path.parent.mkdir(parents=True, exist_ok=True)
+	output_path.write_text(final_output, encoding="utf-8")
+
+	# Also print to stdout
+	print(final_output)
+
+
+if __name__ == "__main__":
+	parser: argparse.ArgumentParser = argparse.ArgumentParser(
+		description="Parse type checker outputs and generate detailed breakdown of errors by type and file",
+		formatter_class=argparse.RawDescriptionHelpFormatter,
+	)
+	parser.add_argument(
+		"--error-dir",
+		type=str,
+		default=".meta/.type-errors",
+		help="Directory containing type checker output files (default: .meta/.type-errors)",
+	)
+	parser.add_argument(
+		"--output",
+		"-o",
+		type=str,
+		default=".meta/typing-summary.txt",
+		help="Output file to write summary to (default: .meta/typing-summary.txt)",
+	)
+	parser.add_argument(
+		"--top-n",
+		"-n",
+		type=str,
+		default="10",
+		help='Number of top items to show in each category (default: 10). Use "all" or negative number for all items.',
+	)
+
+	args: argparse.Namespace = parser.parse_args()
+
+	# Parse top_n value
+	top_n_value: int | None
+	if args.top_n.lower() == "all":
+		top_n_value = None
+	else:
+		top_n_int: int = int(args.top_n)
+		top_n_value = None if top_n_int < 0 else top_n_int
+
+	main(args.error_dir, args.output, top_n_value)

From 9bf4bc3ab264756030f46d5b6872222f8ff5413d Mon Sep 17 00:00:00 2001
From: Michael Ivanitskiy
Date: Mon, 3 Nov 2025 12:04:07 +0000
Subject: [PATCH 45/72] better `try_catch` type hints

---
 muutils/json_serialize/util.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py
index 2d84face..13513183 100644
--- a/muutils/json_serialize/util.py
+++ b/muutils/json_serialize/util.py
@@ -8,7 +8,7 @@
 import sys
 import typing
 import warnings
-from typing import TYPE_CHECKING, Any, Callable, Final, Iterable, Union
+from typing import TYPE_CHECKING, Any, Callable, Final, Iterable, TypeVar, Union
 
 _NUMPY_WORKING: bool
 try:
@@ -114,14 +114,16 @@ def isinstance_namedtuple(x: Any) -> bool:  # pyright: ignore[reportAny]
     return all(isinstance(n, str) for n in f)
 
 
-def try_catch(func: Callable[[Any], Any]):
+T_FuncTryCatchReturn = TypeVar("T_FuncTryCatchReturn")
+
+def try_catch(func: Callable[..., T_FuncTryCatchReturn]) -> Callable[..., Union[T_FuncTryCatchReturn, str]]:
     """wraps the function to catch exceptions, returns serialized error message on exception
 
     returned func will return normal result on success, or error message on exception
     """
 
     @functools.wraps(func)
-    def newfunc(*args, **kwargs):
+    def newfunc(*args: Any, **kwargs: Any) -> Union[T_FuncTryCatchReturn, str]:
         try:
             return func(*args, **kwargs)
         except Exception as e:

From 67248304c9d14bdcd8aa7533585bf4c37e93eb54 Mon Sep 17 00:00:00 2001
From: Michael Ivanitskiy
Date: Mon, 3 Nov 2025 12:05:06 +0000
Subject: [PATCH 46/72] format

---
 muutils/misc/typing_breakdown.py | 479 ++++++++++++++++---------------
 1 file changed, 243 insertions(+), 236 deletions(-)

diff --git a/muutils/misc/typing_breakdown.py b/muutils/misc/typing_breakdown.py
index 95f76b39..696d2dff 100644
--- a/muutils/misc/typing_breakdown.py
+++ b/muutils/misc/typing_breakdown.py
@@ -13,266 +13,273 @@
 
 import argparse
 import re
-import sys
 from collections import defaultdict
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Callable, Dict, Iterator, List, Literal, Tuple
-import warnings
+from typing import Callable, Dict, List, Literal, Tuple
 
 
 @dataclass
 class TypeCheckResult:
-	"results from parsing a type checker output"
-	type_checker: Literal["mypy", "basedpyright", "ty"]
-	by_type: Dict[str, int] = field(default_factory=lambda: defaultdict(int))
-	by_file: Dict[str, int] = field(default_factory=lambda: defaultdict(int))
-
-	@property
-	def total_errors(self) -> int:
-		"total number of errors across all types, validates they match between type and file dicts"
-		total_by_type: int = sum(self.by_type.values())
-		total_by_file: int = sum(self.by_file.values())
-
-		if total_by_type != total_by_file:
-			raise ValueError(
-				f"Error count mismatch for {self.type_checker}: "
-				f"by_type={total_by_type}, by_file={total_by_file}"
-			)
-
-		return total_by_type
-
-	def filter_by(self, top_n: int | None) -> TypeCheckResult:
-		"return a copy with errors sorted by count and filtered to top_n items (or all if None)"
-		# Sort by count (descending)
-		sorted_by_type: List[Tuple[str, int]] = sorted(
-			self.by_type.items(),
-			key=lambda x: x[1],
-			reverse=True,
-		)
-		sorted_by_file: List[Tuple[str, int]] = sorted(
-			self.by_file.items(),
-			key=lambda x: x[1],
-			reverse=True,
-		)
-
-		# Apply top_n limit if specified
-		if top_n is not None:
-			sorted_by_type = sorted_by_type[:top_n]
-			sorted_by_file = sorted_by_file[:top_n]
-
-		# Create new instance with filtered data (dicts maintain insertion order in Python 3.7+)
-		result: TypeCheckResult = TypeCheckResult(type_checker=self.type_checker)
-		result.by_type = dict(sorted_by_type)
-		result.by_file = dict(sorted_by_file)
-
-		return result
-
-	def to_toml(self) -> str:
-		"format as TOML-like output"
-		lines: List[str] = []
-
-		# Main section with total
-		lines.append(f"[type_errors.{self.type_checker}]")
-		try:
-			lines.append(f"total_errors = {self.total_errors}")
-		except ValueError as e:
-			lines.append(f"total_errors_by_type = {sum(self.by_type.values())}")
-			lines.append(f"total_errors_by_file = {sum(self.by_file.values())}")
-		lines.append("")
-
-		# by_type section
-		lines.append(f"[type_errors.{self.type_checker}.by_type]")
-		error_type: str
-		count: int
-		for error_type, count in self.by_type.items():
-			# Always quote keys
-			lines.append(f'"{error_type}" = {count}')
-
-		lines.append("")
-
-		# by_file section
-		lines.append(f"[type_errors.{self.type_checker}.by_file]")
-		file_path: str
-		for file_path, count in self.by_file.items():
-			# Always quote file paths
-			lines.append(f'"{file_path}" = {count}')
-
-		return "\n".join(lines)
+    "results from parsing a type checker output"
+
+    type_checker: Literal["mypy", "basedpyright", "ty"]
+    by_type: Dict[str, int] = field(default_factory=lambda: defaultdict(int))
+    by_file: Dict[str, int] = field(default_factory=lambda: defaultdict(int))
+
+    @property
+    def total_errors(self) -> int:
+        "total number of errors across all types, validates they match between type and file dicts"
+        total_by_type: int = sum(self.by_type.values())
+        total_by_file: int = sum(self.by_file.values())
+
+        if total_by_type != total_by_file:
+            raise ValueError(
+                f"Error count mismatch for {self.type_checker}: "
+                f"by_type={total_by_type}, by_file={total_by_file}"
+            )
+
+        return total_by_type
+
+    def filter_by(self, top_n: int | None) -> TypeCheckResult:
+        "return a copy with errors sorted by count and filtered to top_n items (or all if None)"
+        # Sort by count (descending)
+        sorted_by_type: List[Tuple[str, int]] = sorted(
+            self.by_type.items(),
+            key=lambda x: x[1],
+            reverse=True,
+        )
+        sorted_by_file: List[Tuple[str, int]] = sorted(
+            self.by_file.items(),
+            key=lambda x: x[1],
+            reverse=True,
+        )
+
+        # Apply top_n limit if specified
+        if top_n is not None:
+            sorted_by_type = sorted_by_type[:top_n]
+            sorted_by_file = sorted_by_file[:top_n]
+
+        # Create new instance with filtered data (dicts maintain insertion order in Python 3.7+)
+        result: TypeCheckResult = TypeCheckResult(type_checker=self.type_checker)
+        result.by_type = dict(sorted_by_type)
+        result.by_file = dict(sorted_by_file)
+
+        return result
+
+    def to_toml(self) -> str:
+        "format as TOML-like output"
+        lines: List[str] = []
+
+        # Main section with total
+        lines.append(f"[type_errors.{self.type_checker}]")
+        try:
+            lines.append(f"total_errors = {self.total_errors}")
+        except ValueError:
+            lines.append(f"total_errors_by_type = {sum(self.by_type.values())}")
+            lines.append(f"total_errors_by_file = {sum(self.by_file.values())}")
+        lines.append("")
+
+        # by_type section
+        lines.append(f"[type_errors.{self.type_checker}.by_type]")
+        error_type: str
+        count: int
+        for error_type, count in self.by_type.items():
+            # Always quote keys
+            lines.append(f'"{error_type}" = {count}')
+
+        lines.append("")
+
+        # by_file section
+        lines.append(f"[type_errors.{self.type_checker}.by_file]")
+        file_path: str
+        for file_path, count in self.by_file.items():
+            # Always quote file paths
+            lines.append(f'"{file_path}" = {count}')
+
+        return "\n".join(lines)
 
 
 def parse_mypy(content: str) -> TypeCheckResult:
-	"parse mypy output: file.py:line: error: message [error-code]"
-	result: TypeCheckResult = TypeCheckResult(type_checker="mypy")
+    "parse mypy output: file.py:line: error: message [error-code]"
+    result: TypeCheckResult = TypeCheckResult(type_checker="mypy")
 
-	pattern: re.Pattern[str] = re.compile(r"^(.+?):\d+: error: .+ \[(.+?)\]", re.MULTILINE)
-	match: re.Match[str]
-	for match in pattern.finditer(content):
-		file_path: str = match.group(1)
-		error_code: str = match.group(2)
-		result.by_type[error_code] += 1
-		result.by_file[file_path] += 1
+    pattern: re.Pattern[str] = re.compile(
+        r"^(.+?):\d+: error: .+ \[(.+?)\]", re.MULTILINE
+    )
+    match: re.Match[str]
+    for match in pattern.finditer(content):
+        file_path: str = match.group(1)
+        error_code: str = match.group(2)
+        result.by_type[error_code] += 1
+        result.by_file[file_path] += 1
 
-	return result
+    return result
 
 
 def parse_basedpyright(content: str) -> TypeCheckResult:
-	"parse basedpyright output: path on line, then indented errors with (code)"
-	result: TypeCheckResult = TypeCheckResult(type_checker="basedpyright")
-
-	# Pattern for file paths (lines that start with /)
-	# Pattern for errors: indented line with - error/warning: message (code)
-	current_file: str = ""
-
-	line: str
-	for line in content.splitlines():
-		# Check if this is a file path line
-		if line and not line.startswith(" ") and line.startswith("/"):
-			current_file = line.strip()
-		# Check if this is an error/warning line
-		elif line.strip() and current_file:
-			# Match pattern like: "  path:line:col - warning: message (reportCode)"
-			match: re.Match[str] | None = re.search(r"\s+.+:\d+:\d+ - (error|warning): .+ \((\w+)\)", line)
-			if match:
-				error_type: str = match.group(1)
-				error_code: str = match.group(2)
-				result.by_type[error_code] += 1
-				result.by_file[current_file] += 1
-
-	return result
+    "parse basedpyright output: path on line, then indented errors with (code)"
+    result: TypeCheckResult = TypeCheckResult(type_checker="basedpyright")
+
+    # Pattern for file paths (lines that start with /)
+    # Pattern for errors: indented line with - error/warning: message (code)
+    current_file: str = ""
+
+    line: str
+    for line in content.splitlines():
+        # Check if this is a file path line
+        if line and not line.startswith(" ") and line.startswith("/"):
+            current_file = line.strip()
+        # Check if this is an error/warning line
+        elif line.strip() and current_file:
+            # Match pattern like: "  path:line:col - warning: message (reportCode)"
+            match: re.Match[str] | None = re.search(
+                r"\s+.+:\d+:\d+ - (error|warning): .+ \((\w+)\)", line
+            )
+            if match:
+                error_type: str = match.group(1)
+                error_code: str = match.group(2)
+                result.by_type[error_code] += 1
+                result.by_file[current_file] += 1
+
+    return result
 
 
 def parse_ty(content: str) -> TypeCheckResult:
-	"parse ty output: error[error-code]: message then --> file:line:col"
-	result: TypeCheckResult = TypeCheckResult(type_checker="ty")
-
-	# Pattern for error type: error[code]: or warning[code]:
-	error_pattern: re.Pattern[str] = re.compile(r"^(error|warning)\[(.+?)\]:", re.MULTILINE)
-	# Pattern for location: --> file:line:col
-	location_pattern: re.Pattern[str] = re.compile(r"^\s+-->\s+(.+?):\d+:\d+", re.MULTILINE)
-
-	# Find all errors and their locations
-	errors: List[re.Match[str]] = list(error_pattern.finditer(content))
-	locations: List[re.Match[str]] = list(location_pattern.finditer(content))
-
-	# Match errors with locations (they should be in order)
-	i: int
-	error_match: re.Match[str]
-	for i, error_match in enumerate(errors):
-		error_code: str = error_match.group(2)
-		result.by_type[error_code] += 1
-
-		# Find the next location after this error
-		error_pos: int = error_match.end()
-		loc_match: re.Match[str]
-		for loc_match in locations:
-			if loc_match.start() > error_pos:
-				file_path: str = loc_match.group(1)
-				result.by_file[file_path] += 1
-				break
-
-	return result
+    "parse ty output: error[error-code]: message then --> file:line:col"
+    result: TypeCheckResult = TypeCheckResult(type_checker="ty")
+
+    # Pattern for error type: error[code]: or warning[code]:
+    error_pattern: re.Pattern[str] = re.compile(
+        r"^(error|warning)\[(.+?)\]:", re.MULTILINE
+    )
+    # Pattern for location: --> file:line:col
+    location_pattern: re.Pattern[str] = re.compile(
+        r"^\s+-->\s+(.+?):\d+:\d+", re.MULTILINE
+    )
+
+    # Find all errors and their locations
+    errors: List[re.Match[str]] = list(error_pattern.finditer(content))
+    locations: List[re.Match[str]] = list(location_pattern.finditer(content))
+
+    # Match errors with locations (they should be in order)
+    i: int
+    error_match: re.Match[str]
+    for i, error_match in enumerate(errors):
+        error_code: str = error_match.group(2)
+        result.by_type[error_code] += 1
+
+        # Find the next location after this error
+        error_pos: int = error_match.end()
+        loc_match: re.Match[str]
+        for loc_match in locations:
+            if loc_match.start() > error_pos:
+                file_path: str = loc_match.group(1)
+                result.by_file[file_path] += 1
+                break
+
+    return result
 
 
 def extract_summary_line(file_path: Path) -> str:
-	"extract the last non-empty line from a file (typically the summary line)"
-	content: str = file_path.read_text(encoding="utf-8")
-	lines: List[str] = [line.strip() for line in content.splitlines() if line.strip()]
-	return lines[-1]
+    "extract the last non-empty line from a file (typically the summary line)"
+    content: str = file_path.read_text(encoding="utf-8")
+    lines: List[str] = [line.strip() for line in content.splitlines() if line.strip()]
+    return lines[-1]
 
 
 def main(error_dir: str, output_file: str, top_n: int | None = 10) -> None:
-	"parse all type checker outputs and generate breakdown"
-	error_path: Path = Path(error_dir)
-	output_path: Path = Path(output_file)
-
-	output_lines: List[str] = []
-
-	# Add header comment with top_n info
-	if top_n is None:
-		output_lines.append("# Showing all errors")
-	else:
-		output_lines.append(f"# Showing top {top_n} errors per category")
-	output_lines.append("")
-
-	# First, extract summary lines from each type checker
-	checkers_files: List[Tuple[str, str]] = [
-		("mypy", "mypy.txt"),
-		("basedpyright", "basedpyright.txt"),
-		("ty", "ty.txt"),
-	]
-
-	name: str
-	filename: str
-	for name, filename in checkers_files:
-		file_path: Path = error_path / filename
-		summary: str = extract_summary_line(file_path)
-		output_lines.append(f"# {name}: {summary}")
-
-	output_lines.append("")
-
-	# Parse each type checker
-	checkers: List[Tuple[str, str, Callable[[str], TypeCheckResult]]] = [
-		("mypy", "mypy.txt", parse_mypy),
-		("basedpyright", "basedpyright.txt", parse_basedpyright),
-		("ty", "ty.txt", parse_ty),
-	]
-
-	parser_fn: Callable[[str], TypeCheckResult]
-	for name, filename, parser_fn in checkers:
-		file_path: Path = error_path / filename
-		content: str = file_path.read_text(encoding="utf-8")
-		result: TypeCheckResult = parser_fn(content)
-		# Filter and sort the result
-		filtered_result: TypeCheckResult = result.filter_by(top_n)
-		# Convert to TOML
-		breakdown: str = filtered_result.to_toml()
-		output_lines.append(breakdown)
+    "parse all type checker outputs and generate breakdown"
+    error_path: Path = Path(error_dir)
+    output_path: Path = Path(output_file)
+
+    output_lines: List[str] = []
+
+    # Add header comment with top_n info
+    if top_n is None:
+        output_lines.append("# Showing all errors")
+    else:
+        output_lines.append(f"# Showing top {top_n} errors per category")
+    output_lines.append("")
+
+    # First, extract summary lines from each type checker
+    checkers_files: List[Tuple[str, str]] = [
+        ("mypy", "mypy.txt"),
+        ("basedpyright", "basedpyright.txt"),
+        ("ty", "ty.txt"),
+    ]
+
+    name: str
+    filename: str
+    for name, filename in checkers_files:
+        file_path: Path = error_path / filename
+        summary: str = extract_summary_line(file_path)
+        output_lines.append(f"# {name}: {summary}")
+
+    output_lines.append("")
+
+    # Parse each type checker
+    checkers: List[Tuple[str, str, Callable[[str], TypeCheckResult]]] = [
+        ("mypy", "mypy.txt", parse_mypy),
+        ("basedpyright", "basedpyright.txt", parse_basedpyright),
+        ("ty", "ty.txt", parse_ty),
+    ]
+
+    parser_fn: Callable[[str], TypeCheckResult]
+    for name, filename, parser_fn in checkers:
+        file_path: Path = error_path / filename
+        content: str = file_path.read_text(encoding="utf-8")
+        result: TypeCheckResult = parser_fn(content)
+        # Filter and sort the result
+        filtered_result: TypeCheckResult = result.filter_by(top_n)
+        # Convert to TOML
+        breakdown: str = filtered_result.to_toml()
+        output_lines.append(breakdown)
+
output_lines.append("") # Add blank line between checkers + + # Write to output file + final_output: str = "\n".join(output_lines) + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(final_output, encoding="utf-8") + + # Also print to stdout + print(final_output) if __name__ == "__main__": - parser: argparse.ArgumentParser = argparse.ArgumentParser( - description="Parse type checker outputs and generate detailed breakdown of errors by type and file", - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - parser.add_argument( - "--error-dir", - type=str, - default=".meta/.type-errors", - help="Directory containing type checker output files (default: .meta/.type-errors)", - ) - parser.add_argument( - "--output", - "-o", - type=str, - default=".meta/typing-summary.txt", - help="Output file to write summary to (default: .meta/typing-summary.txt)", - ) - parser.add_argument( - "--top-n", - "-n", - type=str, - default="10", - help='Number of top items to show in each category (default: 10). Use "all" or negative number for all items.', - ) - - args: argparse.Namespace = parser.parse_args() - - # Parse top_n value - top_n_value: int | None - if args.top_n.lower() == "all": - top_n_value = None - else: - top_n_int: int = int(args.top_n) - top_n_value = None if top_n_int < 0 else top_n_int - - main(args.error_dir, args.output, top_n_value) + parser: argparse.ArgumentParser = argparse.ArgumentParser( + description="Parse type checker outputs and generate detailed breakdown of errors by type and file", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--error-dir", + type=str, + default=".meta/.type-errors", + help="Directory containing type checker output files (default: .meta/.type-errors)", + ) + parser.add_argument( + "--output", + "-o", + type=str, + default=".meta/typing-summary.txt", + help="Output file to write summary to (default: .meta/typing-summary.txt)", + ) + parser.add_argument( + "--top-n", + "-n", + type=str, + default="10", + help='Number of top items to show in each category (default: 10). 
Use "all" or negative number for all items.', + ) + + args: argparse.Namespace = parser.parse_args() + + # Parse top_n value + top_n_value: int | None + if args.top_n.lower() == "all": + top_n_value = None + else: + top_n_int: int = int(args.top_n) + top_n_value = None if top_n_int < 0 else top_n_int + + main(args.error_dir, args.output, top_n_value) From d244867c00627c4d9a46dee2241a3670e604bd6b Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 3 Nov 2025 12:05:43 +0000 Subject: [PATCH 47/72] handling type warnings vs errors --- muutils/misc/typing_breakdown.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/muutils/misc/typing_breakdown.py b/muutils/misc/typing_breakdown.py index 696d2dff..04ba20e6 100644 --- a/muutils/misc/typing_breakdown.py +++ b/muutils/misc/typing_breakdown.py @@ -137,7 +137,8 @@ def parse_basedpyright(content: str) -> TypeCheckResult: r"\s+.+:\d+:\d+ - (error|warning): .+ \((\w+)\)", line ) if match: - error_type: str = match.group(1) + # TODO: handle warnings vs errors + _error_type: str = match.group(1) error_code: str = match.group(2) result.by_type[error_code] += 1 result.by_file[current_file] += 1 From 53d7000b5734e9c0fea97dd02cc241889b3466dc Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 3 Nov 2025 12:06:14 +0000 Subject: [PATCH 48/72] fixes in logger type hints --- muutils/logger/logger.py | 6 +++--- muutils/logger/simplelogger.py | 13 ++++++++----- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/muutils/logger/logger.py b/muutils/logger/logger.py index b4165f95..24501a0f 100644 --- a/muutils/logger/logger.py +++ b/muutils/logger/logger.py @@ -13,7 +13,7 @@ import time import typing from functools import partial -from typing import Callable, Sequence +from typing import Any, Callable, Sequence from muutils.json_serialize import JSONitem, json_serialize from muutils.logger.exception_context import ExceptionContext @@ -220,11 +220,11 @@ def log( # type: ignore # yes, the signatures are different here. 
# convert and add data # ======================================== # converting to dict - msg_dict: typing.Mapping + msg_dict: dict[str, Any] if not isinstance(msg, typing.Mapping): msg_dict = {"_msg": msg} else: - msg_dict = msg + msg_dict = dict(msg) # level+stream metadata if lvl is not None: diff --git a/muutils/logger/simplelogger.py b/muutils/logger/simplelogger.py index 07f1a306..b39a683a 100644 --- a/muutils/logger/simplelogger.py +++ b/muutils/logger/simplelogger.py @@ -4,7 +4,7 @@ import sys import time import typing -from typing import TextIO, Union +from typing import Any, TextIO, Union from muutils.json_serialize import JSONitem, json_serialize @@ -69,13 +69,16 @@ def log(self, msg: JSONitem, console_print: bool = False, **kwargs): if console_print: print(msg) + msg_dict: dict[str, Any] if not isinstance(msg, typing.Mapping): - msg = {"_msg": msg} + msg_dict = {"_msg": msg} + else: + msg_dict = dict(msg) if self._timestamp: - msg["_timestamp"] = time.time() + msg_dict["_timestamp"] = time.time() if len(kwargs) > 0: - msg["_kwargs"] = kwargs + msg_dict["_kwargs"] = kwargs - self._log_file_handle.write(json.dumps(json_serialize(msg)) + "\n") + self._log_file_handle.write(json.dumps(json_serialize(msg_dict)) + "\n") From 8a0ca930ad2b052f4fe439154753c93d76aef153 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 3 Nov 2025 12:07:18 +0000 Subject: [PATCH 49/72] format --- muutils/json_serialize/serializable_dataclass.py | 7 ++++++- muutils/json_serialize/util.py | 5 ++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index e8e2da74..0abed22f 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -66,7 +66,12 @@ class NestedClass(SerializableDataclass): SerializableField, serializable_field, ) -from muutils.json_serialize.util import _FORMAT_KEY, JSONdict, JSONitem, array_safe_eq, dc_eq +from muutils.json_serialize.util import ( + _FORMAT_KEY, + JSONdict, + array_safe_eq, + dc_eq, +) # pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index 13513183..47480315 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -116,7 +116,10 @@ def isinstance_namedtuple(x: Any) -> bool: # pyright: ignore[reportAny] T_FuncTryCatchReturn = TypeVar("T_FuncTryCatchReturn") -def try_catch(func: Callable[..., T_FuncTryCatchReturn]) -> Callable[..., Union[T_FuncTryCatchReturn, str]]: + +def try_catch( + func: Callable[..., T_FuncTryCatchReturn], +) -> Callable[..., Union[T_FuncTryCatchReturn, str]]: """wraps the function to catch exceptions, returns serialized error message on exception returned func will return normal result on success, or error message on exception From 84ad1947911b1d635df32330fd58d4b0e02fc26c Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 3 Nov 2025 12:18:02 +0000 Subject: [PATCH 50/72] fix and restrict lots of array saving/loading type hints --- muutils/json_serialize/array.py | 109 ++++++++++++++++++++++++-------- 1 file changed, 82 insertions(+), 27 deletions(-) diff --git a/muutils/json_serialize/array.py b/muutils/json_serialize/array.py index 1c5f53c3..3286035a 100644 --- a/muutils/json_serialize/array.py +++ b/muutils/json_serialize/array.py @@ -1,4 +1,4 @@ -"""this utilities module handles serialization and loading of numpy and torch 
arrays as json +"""this utilities module handles serialization and loading of numpy and torch arrays as json - `array_list_meta` is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability. - `array_b64_meta` is the most efficient, but is not human readable. @@ -11,7 +11,16 @@ import base64 import typing import warnings -from typing import TYPE_CHECKING, Any, Iterable, Literal, Optional, Sequence, TypedDict +from typing import ( + TYPE_CHECKING, + Iterable, + Literal, + Optional, + Sequence, + TypedDict, + Union, + overload, +) try: import numpy as np @@ -24,7 +33,7 @@ if TYPE_CHECKING: import numpy as np -from muutils.json_serialize.util import _FORMAT_KEY, JSONitem +from muutils.json_serialize.util import _FORMAT_KEY # TYPING: pyright complains way too much here # pyright: reportCallIssue=false,reportArgumentType=false,reportUnknownVariableType=false,reportUnknownMemberType=false @@ -87,7 +96,7 @@ def arr_metadata(arr) -> ArrayMetadata: def serialize_array( - jser: "JsonSerializer", # type: ignore[name-defined] # noqa: F821 + jser: "JsonSerializer", # type: ignore[name-defined] # noqa: F821 # pyright: ignore[reportUndefinedVariable] arr: np.ndarray, path: str | Sequence[str | int], array_mode: ArrayMode | None = None, @@ -179,42 +188,88 @@ def serialize_array( raise KeyError(f"invalid array_mode: {array_mode}") -def infer_array_mode(arr: JSONitem) -> ArrayMode: +@overload +def infer_array_mode( + arr: SerializedArrayWithMeta, +) -> Literal[ + "array_list_meta", + "array_hex_meta", + "array_b64_meta", + "external", + "zero_dim", +]: ... +@overload +def infer_array_mode(arr: NumericList) -> Literal["list"]: ... +def infer_array_mode( + arr: Union[SerializedArrayWithMeta, NumericList], +) -> ArrayMode: """given a serialized array, infer the mode assumes the array was serialized via `serialize_array()` """ + return_mode: ArrayMode if isinstance(arr, typing.Mapping): # _FORMAT_KEY always maps to a string fmt: str = arr.get(_FORMAT_KEY, "") # type: ignore # pyright: ignore[reportAssignmentType] if fmt.endswith(":array_list_meta"): if not isinstance(arr["data"], Iterable): raise ValueError(f"invalid list format: {type(arr['data']) = }\t{arr}") - return "array_list_meta" + return_mode = "array_list_meta" elif fmt.endswith(":array_hex_meta"): if not isinstance(arr["data"], str): raise ValueError(f"invalid hex format: {type(arr['data']) = }\t{arr}") - return "array_hex_meta" + return_mode = "array_hex_meta" elif fmt.endswith(":array_b64_meta"): if not isinstance(arr["data"], str): raise ValueError(f"invalid b64 format: {type(arr['data']) = }\t{arr}") - return "array_b64_meta" + return_mode = "array_b64_meta" elif fmt.endswith(":external"): - return "external" + return_mode = "external" elif fmt.endswith(":zero_dim"): - return "zero_dim" + return_mode = "zero_dim" else: raise ValueError(f"invalid format: {arr}") - elif isinstance(arr, list): - return "list" + elif isinstance(arr, list): # pyright: ignore[reportUnnecessaryIsInstance] + return_mode = "list" else: - raise ValueError(f"cannot infer array_mode from\t{type(arr) = }\n{arr = }") - - -def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any: # pyright: ignore[reportExplicitAny, reportAny] + raise ValueError(f"cannot infer array_mode from\t{type(arr) = }\n{arr = }") # pyright: ignore[reportUnreachable] + + return return_mode + + +@overload +def load_array( + arr: SerializedArrayWithMeta, + array_mode: Optional[ + Literal[ + "array_list_meta", + "array_hex_meta", + 
"array_b64_meta", + "external", + "zero_dim", + ] + ] = None, +) -> np.ndarray: ... +@overload +def load_array( + arr: NumericList, + array_mode: Optional[Literal["list"]] = None, +) -> np.ndarray: ... +@overload +def load_array( + arr: np.ndarray, + array_mode: None = None, +) -> np.ndarray: ... +def load_array( + arr: Union[SerializedArrayWithMeta, np.ndarray, NumericList], + array_mode: Optional[ArrayMode] = None, +) -> np.ndarray: """load a json-serialized array, infer the mode if not specified""" # return arr if its already a numpy array - if isinstance(arr, np.ndarray) and array_mode is None: + if isinstance(arr, np.ndarray): + assert array_mode is None, ( + "array_mode should not be specified when loading a numpy array, since that is a no-op" + ) return arr # try to infer the array_mode @@ -231,24 +286,24 @@ def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any: # assert isinstance(arr, typing.Mapping), ( f"invalid list format: {type(arr) = }\n{arr = }" ) - data = np.array(arr["data"], dtype=arr["dtype"]) # type: ignore - if tuple(arr["shape"]) != tuple(data.shape): # type: ignore + data = np.array(arr["data"], dtype=arr["dtype"]) # type: ignore + if tuple(arr["shape"]) != tuple(data.shape): # type: ignore raise ValueError(f"invalid shape: {arr}") - return data + return data elif array_mode == "array_hex_meta": assert isinstance(arr, typing.Mapping), ( f"invalid list format: {type(arr) = }\n{arr = }" ) - data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"]) # type: ignore - return data.reshape(arr["shape"]) # type: ignore + data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"]) # type: ignore + return data.reshape(arr["shape"]) # type: ignore elif array_mode == "array_b64_meta": assert isinstance(arr, typing.Mapping), ( f"invalid list format: {type(arr) = }\n{arr = }" ) - data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"]) # type: ignore - return data.reshape(arr["shape"]) # type: ignore + data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"]) # type: ignore + return data.reshape(arr["shape"]) # type: ignore elif array_mode == "list": assert isinstance(arr, typing.Sequence), ( @@ -256,13 +311,13 @@ def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any: # ) return np.array(arr) # type: ignore elif array_mode == "external": - # assume ZANJ has taken care of it assert isinstance(arr, typing.Mapping) if "data" not in arr: - raise KeyError( + raise KeyError( # pyright: ignore[reportUnreachable] f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}" ) - return arr["data"] + # we can ignore here since we assume ZANJ has taken care of it + return arr["data"] # type: ignore[return-value] # pyright: ignore[reportReturnType] elif array_mode == "zero_dim": assert isinstance(arr, typing.Mapping) data = np.array(arr["data"]) From 54d8e104ec42f1657237a7adb44d4f4bffa66f85 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 3 Nov 2025 12:20:26 +0000 Subject: [PATCH 51/72] add json_serialize.types to help avoid import cycles --- muutils/json_serialize/array.py | 2 +- muutils/json_serialize/types.py | 20 ++++++++++++++++++++ muutils/json_serialize/util.py | 21 ++++----------------- 3 files changed, 25 insertions(+), 18 deletions(-) create mode 100644 muutils/json_serialize/types.py diff --git a/muutils/json_serialize/array.py b/muutils/json_serialize/array.py index 3286035a..42f23463 100644 --- a/muutils/json_serialize/array.py +++ 
b/muutils/json_serialize/array.py @@ -33,7 +33,7 @@ if TYPE_CHECKING: import numpy as np -from muutils.json_serialize.util import _FORMAT_KEY +from muutils.json_serialize.types import _FORMAT_KEY # TYPING: pyright complains way too much here # pyright: reportCallIssue=false,reportArgumentType=false,reportUnknownVariableType=false,reportUnknownMemberType=false diff --git a/muutils/json_serialize/types.py b/muutils/json_serialize/types.py new file mode 100644 index 00000000..34284c92 --- /dev/null +++ b/muutils/json_serialize/types.py @@ -0,0 +1,20 @@ +"""base types, lets us avoid import cycles""" + +from __future__ import annotations + +from typing import Final, Union + + +BaseType = Union[ + bool, + int, + float, + str, + None, +] + +Hashableitem = Union[bool, int, float, str, tuple] # pyright: ignore[reportMissingTypeArgument] + + +_FORMAT_KEY: Final[str] = "__muutils_format__" +_REF_KEY: Final[str] = "$ref" diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index 47480315..62c73243 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -8,7 +8,9 @@ import sys import typing import warnings -from typing import TYPE_CHECKING, Any, Callable, Final, Iterable, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, Union + +from muutils.json_serialize.types import BaseType, Hashableitem _NUMPY_WORKING: bool try: @@ -20,17 +22,8 @@ # pyright: reportExplicitAny=false - -BaseType = Union[ - bool, - int, - float, - str, - None, -] - # At type-checking time, include array serialization types to avoid nominal type errors -# This avoids runtime circular imports since array.py imports from util.py +# This avoids superfluous imports at runtime if TYPE_CHECKING: from muutils.json_serialize.array import NumericList, SerializedArrayWithMeta @@ -50,12 +43,6 @@ JSONdict = typing.Dict[str, JSONitem] -Hashableitem = Union[bool, int, float, str, tuple] # pyright: ignore[reportMissingTypeArgument] - - -_FORMAT_KEY: Final[str] = "__muutils_format__" -_REF_KEY: Final[str] = "$ref" - # TODO: this bit is very broken # or if python version <3.9 From 7b4e4694c062c950e4dc33634e7ca1bb82d05633 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 3 Nov 2025 12:28:40 +0000 Subject: [PATCH 52/72] full typing summary --- .meta/typing-summary.txt | 165 ++++++++++++++++++++++++++++++++++----- makefile | 2 +- 2 files changed, 145 insertions(+), 22 deletions(-) diff --git a/.meta/typing-summary.txt b/.meta/typing-summary.txt index c2836f41..82bb7866 100644 --- a/.meta/typing-summary.txt +++ b/.meta/typing-summary.txt @@ -1,27 +1,30 @@ -# Showing top 10 errors per category +# Showing all errors -# mypy: Found 127 errors in 8 files (checked 115 source files) -# basedpyright: 803 errors, 3864 warnings, 0 notes -# ty: Found 216 diagnostics +# mypy: Found 159 errors in 9 files (checked 116 source files) +# basedpyright: 796 errors, 3811 warnings, 0 notes +# ty: Found 219 diagnostics [type_errors.mypy] -total_errors = 127 +total_errors = 159 [type_errors.mypy.by_type] -"index" = 31 +"dict-item" = 35 +"index" = 33 "typeddict-item" = 26 -"arg-type" = 23 -"call-overload" = 21 +"call-overload" = 23 +"arg-type" = 14 "attr-defined" = 11 "literal-required" = 10 "var-annotated" = 2 "operator" = 2 +"assignment" = 2 "no-redef" = 1 [type_errors.mypy.by_file] "tests/unit/json_serialize/test_array_torch.py" = 58 -"tests/unit/json_serialize/test_array.py" = 22 -"tests/unit/test_jsonlines.py" = 21 +"muutils/tensor_utils.py" = 39 
+"tests/unit/json_serialize/test_array.py" = 24 +"tests/unit/test_jsonlines.py" = 12 "tests/unit/json_serialize/test_serializable_field.py" = 11 "tests/unit/json_serialize/test_json_serialize.py" = 7 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 4 @@ -29,42 +32,161 @@ total_errors = 127 "muutils/misc/typing_breakdown.py" = 1 [type_errors.basedpyright] -total_errors_by_type = 2521 -total_errors_by_file = 1148 +total_errors = 3075 [type_errors.basedpyright.by_type] "reportUnknownParameterType" = 452 "reportMissingParameterType" = 397 -"reportAny" = 385 +"reportAny" = 364 "reportUnusedCallResult" = 295 -"reportUnknownVariableType" = 232 -"reportMissingTypeArgument" = 201 -"reportExplicitAny" = 183 -"reportUnknownMemberType" = 145 +"reportUnknownVariableType" = 226 +"reportMissingTypeArgument" = 195 +"reportExplicitAny" = 187 +"reportUnknownMemberType" = 141 "reportUnknownLambdaType" = 131 "reportUnusedParameter" = 100 +"reportIndexIssue" = 93 +"reportCallIssue" = 59 +"reportImplicitOverride" = 53 +"reportInvalidTypeForm" = 51 +"reportPrivateUsage" = 41 +"reportUnannotatedClassAttribute" = 41 +"reportPossiblyUnboundVariable" = 34 +"reportUnreachable" = 32 +"reportOptionalSubscript" = 27 +"reportUnnecessaryIsInstance" = 21 +"reportUntypedClassDecorator" = 17 +"reportMissingSuperCall" = 14 +"reportUntypedFunctionDecorator" = 13 +"reportUnusedVariable" = 11 +"reportInvalidTypeArguments" = 9 +"reportUnnecessaryComparison" = 8 +"reportCallInDefaultInitializer" = 8 +"reportUndefinedVariable" = 7 +"reportOptionalMemberAccess" = 7 +"reportAttributeAccessIssue" = 5 +"reportUninitializedInstanceVariable" = 5 +"reportUnusedClass" = 4 +"reportUnnecessaryTypeIgnoreComment" = 3 +"reportImplicitStringConcatenation" = 3 +"reportUnusedImport" = 3 +"reportMissingTypeStubs" = 3 +"reportMissingImports" = 3 +"reportUnusedExpression" = 3 +"reportOperatorIssue" = 2 +"reportArgumentType" = 2 +"reportUntypedNamedTuple" = 2 +"reportRedeclaration" = 1 +"reportUnusedFunction" = 1 +"reportGeneralTypeIssues" = 1 [type_errors.basedpyright.by_file] "/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_json_serialize.py" = 167 "/home/miv/projects/tools/muutils/tests/unit/test_dbg.py" = 156 "/home/miv/projects/tools/muutils/tests/unit/test_parallel.py" = 134 "/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 125 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_array_torch.py" = 107 -"/home/miv/projects/tools/muutils/muutils/json_serialize/serializable_dataclass.py" = 98 "/home/miv/projects/tools/muutils/tests/unit/misc/test_func.py" = 98 +"/home/miv/projects/tools/muutils/muutils/json_serialize/serializable_dataclass.py" = 97 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_array_torch.py" = 95 "/home/miv/projects/tools/muutils/muutils/dictmagic.py" = 88 "/home/miv/projects/tools/muutils/muutils/json_serialize/serializable_field.py" = 88 "/home/miv/projects/tools/muutils/tests/unit/validate_type/test_get_kwargs.py" = 87 +"/home/miv/projects/tools/muutils/tests/unit/validate_type/test_validate_type.py" = 86 +"/home/miv/projects/tools/muutils/tests/unit/validate_type/test_validate_type_GENERATED.py" = 86 +"/home/miv/projects/tools/muutils/muutils/json_serialize/json_serialize.py" = 81 +"/home/miv/projects/tools/muutils/muutils/misc/func.py" = 77 +"/home/miv/projects/tools/muutils/tests/unit/test_interval.py" = 75 +"/home/miv/projects/tools/muutils/muutils/misc/freezing.py" = 67 
+"/home/miv/projects/tools/muutils/tests/unit/cli/test_arg_bool.py" = 66 +"/home/miv/projects/tools/muutils/muutils/spinner.py" = 63 +"/home/miv/projects/tools/muutils/muutils/tensor_utils.py" = 63 +"/home/miv/projects/tools/muutils/tests/unit/benchmark_parallel/benchmark_parallel.py" = 56 +"/home/miv/projects/tools/muutils/tests/unit/web/test_bundle_html.py" = 55 +"/home/miv/projects/tools/muutils/muutils/tensor_info.py" = 52 +"/home/miv/projects/tools/muutils/tests/unit/errormode/test_errormode_init.py" = 46 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_serializable_field.py" = 46 +"/home/miv/projects/tools/muutils/muutils/parallel.py" = 45 +"/home/miv/projects/tools/muutils/muutils/statcounter.py" = 41 +"/home/miv/projects/tools/muutils/muutils/sysinfo.py" = 39 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_array.py" = 38 +"/home/miv/projects/tools/muutils/muutils/web/bundle_html.py" = 37 +"/home/miv/projects/tools/muutils/tests/unit/test_dictmagic.py" = 37 +"/home/miv/projects/tools/muutils/tests/unit/errormode/test_errormode_functionality.py" = 36 +"/home/miv/projects/tools/muutils/muutils/nbutils/convert_ipynb_to_script.py" = 34 +"/home/miv/projects/tools/muutils/muutils/nbutils/configure_notebook.py" = 33 +"/home/miv/projects/tools/muutils/tests/unit/test_spinner.py" = 33 +"/home/miv/projects/tools/muutils/tests/unit/misc/test_freeze.py" = 31 +"/home/miv/projects/tools/muutils/muutils/json_serialize/util.py" = 27 +"/home/miv/projects/tools/muutils/muutils/misc/sequence.py" = 27 +"/home/miv/projects/tools/muutils/tests/unit/misc/test_numerical_conversions.py" = 27 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 25 +"/home/miv/projects/tools/muutils/muutils/mlutils.py" = 24 +"/home/miv/projects/tools/muutils/muutils/validate_type.py" = 24 +"/home/miv/projects/tools/muutils/muutils/interval.py" = 20 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_util.py" = 20 +"/home/miv/projects/tools/muutils/tests/unit/validate_type/test_validate_type_special.py" = 20 +"/home/miv/projects/tools/muutils/muutils/logger/logger.py" = 19 +"/home/miv/projects/tools/muutils/tests/unit/test_jsonlines.py" = 19 +"/home/miv/projects/tools/muutils/muutils/math/matrix_powers.py" = 18 +"/home/miv/projects/tools/muutils/muutils/misc/string.py" = 18 +"/home/miv/projects/tools/muutils/tests/unit/test_kappa.py" = 17 +"/home/miv/projects/tools/muutils/muutils/logger/log_util.py" = 16 +"/home/miv/projects/tools/muutils/tests/unit/test_tensor_info_torch.py" = 15 +"/home/miv/projects/tools/muutils/tests/unit/nbutils/test_configure_notebook.py" = 14 +"/home/miv/projects/tools/muutils/muutils/jsonlines.py" = 13 +"/home/miv/projects/tools/muutils/muutils/misc/typing_breakdown.py" = 13 +"/home/miv/projects/tools/muutils/tests/unit/misc/test_misc.py" = 13 +"/home/miv/projects/tools/muutils/muutils/nbutils/run_notebook_tests.py" = 12 +"/home/miv/projects/tools/muutils/muutils/cli/arg_bool.py" = 11 +"/home/miv/projects/tools/muutils/muutils/errormode.py" = 11 +"/home/miv/projects/tools/muutils/muutils/json_serialize/dataclass_transform_mock.py" = 11 +"/home/miv/projects/tools/muutils/muutils/dbg.py" = 10 +"/home/miv/projects/tools/muutils/muutils/json_serialize/array.py" = 10 +"/home/miv/projects/tools/muutils/muutils/logger/exception_context.py" = 10 +"/home/miv/projects/tools/muutils/muutils/logger/headerfuncs.py" = 10 +"/home/miv/projects/tools/muutils/tests/unit/test_collect_warnings.py" = 10 
+"/home/miv/projects/tools/muutils/muutils/kappa.py" = 9 +"/home/miv/projects/tools/muutils/muutils/collect_warnings.py" = 8 +"/home/miv/projects/tools/muutils/tests/unit/test_console_unicode.py" = 8 +"/home/miv/projects/tools/muutils/tests/util/test_fire.py" = 8 +"/home/miv/projects/tools/muutils/muutils/cli/command.py" = 7 +"/home/miv/projects/tools/muutils/muutils/misc/classes.py" = 7 +"/home/miv/projects/tools/muutils/muutils/nbutils/mermaid.py" = 7 +"/home/miv/projects/tools/muutils/tests/unit/cli/test_command.py" = 7 +"/home/miv/projects/tools/muutils/muutils/logger/timing.py" = 6 +"/home/miv/projects/tools/muutils/tests/unit/nbutils/test_conversion.py" = 6 +"/home/miv/projects/tools/muutils/tests/conftest.py" = 5 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 5 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 5 +"/home/miv/projects/tools/muutils/muutils/logger/simplelogger.py" = 4 +"/home/miv/projects/tools/muutils/muutils/web/html_to_pdf.py" = 4 +"/home/miv/projects/tools/muutils/tests/unit/math/test_matrix_powers_torch.py" = 4 +"/home/miv/projects/tools/muutils/tests/unit/misc/test_sequence.py" = 4 +"/home/miv/projects/tools/muutils/tests/unit/test_mlutils.py" = 4 +"/home/miv/projects/tools/muutils/tests/unit/test_tensor_utils_torch.py" = 4 +"/home/miv/projects/tools/muutils/muutils/misc/hashing.py" = 3 +"/home/miv/projects/tools/muutils/tests/unit/test_tensor_info.py" = 3 +"/home/miv/projects/tools/muutils/muutils/logger/loggingstream.py" = 2 +"/home/miv/projects/tools/muutils/muutils/misc/numerical.py" = 2 +"/home/miv/projects/tools/muutils/tests/unit/logger/test_log_util.py" = 2 +"/home/miv/projects/tools/muutils/tests/unit/test_chunks.py" = 2 +"/home/miv/projects/tools/muutils/tests/unit/test_timeit_fancy.py" = 2 +"/home/miv/projects/tools/muutils/muutils/console_unicode.py" = 1 +"/home/miv/projects/tools/muutils/muutils/math/bins.py" = 1 +"/home/miv/projects/tools/muutils/muutils/misc/__init__.py" = 1 +"/home/miv/projects/tools/muutils/muutils/misc/b64_decode.py" = 1 +"/home/miv/projects/tools/muutils/tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 [type_errors.ty] -total_errors = 216 +total_errors = 219 [type_errors.ty.by_type] "unknown-argument" = 164 "unresolved-attribute" = 33 "invalid-argument-type" = 8 -"invalid-assignment" = 6 +"invalid-assignment" = 7 "too-many-positional-arguments" = 3 +"non-subscriptable" = 2 "invalid-return-type" = 1 "unresolved-import" = 1 @@ -75,6 +197,7 @@ total_errors = 216 "tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 9 "tests/unit/test_dictmagic.py" = 8 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 6 +"muutils/tensor_utils.py" = 3 "tests/unit/json_serialize/test_array_torch.py" = 2 "tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 "tests/unit/test_statcounter.py" = 1 diff --git a/makefile b/makefile index ebfad4b1..3c02972b 100644 --- a/makefile +++ b/makefile @@ -1779,7 +1779,7 @@ typing-summary: gen-extra-tests $(PYTHON) -m basedpyright . > .meta/.type-errors/basedpyright.txt 2>&1 || true $(PYTHON) -m ty check . > .meta/.type-errors/ty.txt 2>&1 || true @echo "generating typing summary and breakdown..." 
- $(PYTHON) -m muutils.misc.typing_breakdown --error-dir .meta/.type-errors --output .meta/typing-summary.txt + $(PYTHON) -m muutils.misc.typing_breakdown --error-dir .meta/.type-errors --top-n -1 --output .meta/typing-summary.txt # generate summary report of type check errors grouped by file # outputs TOML format showing error count per file From 1a7c9688a9602a8a8969bba9ba177287f29d6944 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 3 Nov 2025 12:39:05 +0000 Subject: [PATCH 53/72] typing summary stuff --- .meta/typing-summary.txt | 37 +++++++++++++++---------------------- TODO.md | 2 +- 2 files changed, 16 insertions(+), 23 deletions(-) diff --git a/.meta/typing-summary.txt b/.meta/typing-summary.txt index 82bb7866..9b36bec5 100644 --- a/.meta/typing-summary.txt +++ b/.meta/typing-summary.txt @@ -1,51 +1,47 @@ # Showing all errors -# mypy: Found 159 errors in 9 files (checked 116 source files) -# basedpyright: 796 errors, 3811 warnings, 0 notes -# ty: Found 219 diagnostics +# mypy: Found 117 errors in 7 files (checked 116 source files) +# basedpyright: 778 errors, 3791 warnings, 0 notes +# ty: Found 216 diagnostics [type_errors.mypy] -total_errors = 159 +total_errors = 117 [type_errors.mypy.by_type] -"dict-item" = 35 -"index" = 33 +"index" = 31 "typeddict-item" = 26 "call-overload" = 23 -"arg-type" = 14 +"arg-type" = 11 "attr-defined" = 11 "literal-required" = 10 "var-annotated" = 2 "operator" = 2 -"assignment" = 2 "no-redef" = 1 [type_errors.mypy.by_file] "tests/unit/json_serialize/test_array_torch.py" = 58 -"muutils/tensor_utils.py" = 39 "tests/unit/json_serialize/test_array.py" = 24 "tests/unit/test_jsonlines.py" = 12 "tests/unit/json_serialize/test_serializable_field.py" = 11 "tests/unit/json_serialize/test_json_serialize.py" = 7 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 4 -"tests/unit/logger/test_log_util.py" = 3 "muutils/misc/typing_breakdown.py" = 1 [type_errors.basedpyright] -total_errors = 3075 +total_errors = 3066 [type_errors.basedpyright.by_type] "reportUnknownParameterType" = 452 "reportMissingParameterType" = 397 -"reportAny" = 364 +"reportAny" = 374 "reportUnusedCallResult" = 295 -"reportUnknownVariableType" = 226 -"reportMissingTypeArgument" = 195 -"reportExplicitAny" = 187 +"reportUnknownVariableType" = 214 +"reportExplicitAny" = 194 +"reportMissingTypeArgument" = 185 "reportUnknownMemberType" = 141 "reportUnknownLambdaType" = 131 "reportUnusedParameter" = 100 -"reportIndexIssue" = 93 +"reportIndexIssue" = 91 "reportCallIssue" = 59 "reportImplicitOverride" = 53 "reportInvalidTypeForm" = 51 @@ -74,7 +70,6 @@ total_errors = 3075 "reportMissingImports" = 3 "reportUnusedExpression" = 3 "reportOperatorIssue" = 2 -"reportArgumentType" = 2 "reportUntypedNamedTuple" = 2 "reportRedeclaration" = 1 "reportUnusedFunction" = 1 @@ -99,9 +94,9 @@ total_errors = 3075 "/home/miv/projects/tools/muutils/muutils/misc/freezing.py" = 67 "/home/miv/projects/tools/muutils/tests/unit/cli/test_arg_bool.py" = 66 "/home/miv/projects/tools/muutils/muutils/spinner.py" = 63 -"/home/miv/projects/tools/muutils/muutils/tensor_utils.py" = 63 "/home/miv/projects/tools/muutils/tests/unit/benchmark_parallel/benchmark_parallel.py" = 56 "/home/miv/projects/tools/muutils/tests/unit/web/test_bundle_html.py" = 55 +"/home/miv/projects/tools/muutils/muutils/tensor_utils.py" = 54 "/home/miv/projects/tools/muutils/muutils/tensor_info.py" = 52 "/home/miv/projects/tools/muutils/tests/unit/errormode/test_errormode_init.py" = 46 
"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_serializable_field.py" = 46 @@ -178,15 +173,14 @@ total_errors = 3075 "/home/miv/projects/tools/muutils/tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 [type_errors.ty] -total_errors = 219 +total_errors = 216 [type_errors.ty.by_type] "unknown-argument" = 164 "unresolved-attribute" = 33 "invalid-argument-type" = 8 -"invalid-assignment" = 7 +"invalid-assignment" = 6 "too-many-positional-arguments" = 3 -"non-subscriptable" = 2 "invalid-return-type" = 1 "unresolved-import" = 1 @@ -197,7 +191,6 @@ total_errors = 219 "tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 9 "tests/unit/test_dictmagic.py" = 8 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 6 -"muutils/tensor_utils.py" = 3 "tests/unit/json_serialize/test_array_torch.py" = 2 "tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 "tests/unit/test_statcounter.py" = 1 diff --git a/TODO.md b/TODO.md index 95e46cfa..46890f31 100644 --- a/TODO.md +++ b/TODO.md @@ -13,8 +13,8 @@ 3. Decide on a good fix to make. For example, you might pick: - the fix with the best **"number of errors / complexity of change" ratio** - - a fix that gets us closer to having no errors in a specific file (or group of files) - a fix that gets us closer to removing an entire category of errors + - a fix that gets us closer to having no errors in a specific file (**FOCUS ON THIS!**) 4. Implement that fix From 2da2ec51550156e79c75e514e7c5ea7cf1625523 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 3 Nov 2025 12:39:17 +0000 Subject: [PATCH 54/72] fix typing breakdown type hints --- muutils/misc/typing_breakdown.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/muutils/misc/typing_breakdown.py b/muutils/misc/typing_breakdown.py index 04ba20e6..965b77f1 100644 --- a/muutils/misc/typing_breakdown.py +++ b/muutils/misc/typing_breakdown.py @@ -34,10 +34,8 @@ def total_errors(self) -> int: total_by_file: int = sum(self.by_file.values()) if total_by_type != total_by_file: - raise ValueError( - f"Error count mismatch for {self.type_checker}: " - f"by_type={total_by_type}, by_file={total_by_file}" - ) + err_msg: str = f"Error count mismatch for {self.type_checker}: by_type={total_by_type}, by_file={total_by_file}" + raise ValueError(err_msg) return total_by_type @@ -164,9 +162,8 @@ def parse_ty(content: str) -> TypeCheckResult: locations: List[re.Match[str]] = list(location_pattern.finditer(content)) # Match errors with locations (they should be in order) - i: int error_match: re.Match[str] - for i, error_match in enumerate(errors): + for error_match in errors: error_code: str = error_match.group(2) result.by_type[error_code] += 1 @@ -228,8 +225,8 @@ def main(error_dir: str, output_file: str, top_n: int | None = 10) -> None: parser_fn: Callable[[str], TypeCheckResult] for name, filename, parser_fn in checkers: - file_path: Path = error_path / filename - content: str = file_path.read_text(encoding="utf-8") + file_path_: Path = error_path / filename + content: str = file_path_.read_text(encoding="utf-8") result: TypeCheckResult = parser_fn(content) # Filter and sort the result filtered_result: TypeCheckResult = result.filter_by(top_n) @@ -241,7 +238,7 @@ def main(error_dir: str, output_file: str, top_n: int | None = 10) -> None: # Write to output file final_output: str = "\n".join(output_lines) output_path.parent.mkdir(parents=True, exist_ok=True) - output_path.write_text(final_output, encoding="utf-8") + _ 
= output_path.write_text(final_output, encoding="utf-8") # Also print to stdout print(final_output) @@ -252,20 +249,20 @@ def main(error_dir: str, output_file: str, top_n: int | None = 10) -> None: description="Parse type checker outputs and generate detailed breakdown of errors by type and file", formatter_class=argparse.RawDescriptionHelpFormatter, ) - parser.add_argument( + _ = parser.add_argument( "--error-dir", type=str, default=".meta/.type-errors", help="Directory containing type checker output files (default: .meta/.type-errors)", ) - parser.add_argument( + _ = parser.add_argument( "--output", "-o", type=str, default=".meta/typing-summary.txt", help="Output file to write summary to (default: .meta/typing-summary.txt)", ) - parser.add_argument( + _ = parser.add_argument( "--top-n", "-n", type=str, @@ -276,6 +273,7 @@ def main(error_dir: str, output_file: str, top_n: int | None = 10) -> None: args: argparse.Namespace = parser.parse_args() # Parse top_n value + assert isinstance(args.top_n, str) # pyright: ignore[reportAny] top_n_value: int | None if args.top_n.lower() == "all": top_n_value = None @@ -283,4 +281,4 @@ def main(error_dir: str, output_file: str, top_n: int | None = 10) -> None: top_n_int: int = int(args.top_n) top_n_value = None if top_n_int < 0 else top_n_int - main(args.error_dir, args.output, top_n_value) + main(error_dir=args.error_dir, output_file=args.output, top_n=top_n_value) # pyright: ignore[reportAny] From 79136e1b02b46eb3e3da86ca85cb78aec293bf3d Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 3 Nov 2025 12:41:30 +0000 Subject: [PATCH 55/72] more typing summary --- .meta/typing-summary.txt | 85 ++++++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 38 deletions(-) diff --git a/.meta/typing-summary.txt b/.meta/typing-summary.txt index 9b36bec5..4fad10f2 100644 --- a/.meta/typing-summary.txt +++ b/.meta/typing-summary.txt @@ -1,41 +1,46 @@ # Showing all errors -# mypy: Found 117 errors in 7 files (checked 116 source files) -# basedpyright: 778 errors, 3791 warnings, 0 notes -# ty: Found 216 diagnostics +# mypy: Found 126 errors in 13 files (checked 116 source files) +# basedpyright: 787 errors, 3772 warnings, 0 notes +# ty: Found 226 diagnostics [type_errors.mypy] -total_errors = 117 +total_errors = 126 [type_errors.mypy.by_type] "index" = 31 "typeddict-item" = 26 "call-overload" = 23 +"attr-defined" = 21 "arg-type" = 11 -"attr-defined" = 11 "literal-required" = 10 "var-annotated" = 2 "operator" = 2 -"no-redef" = 1 [type_errors.mypy.by_file] -"tests/unit/json_serialize/test_array_torch.py" = 58 -"tests/unit/json_serialize/test_array.py" = 24 +"tests/unit/json_serialize/test_array_torch.py" = 59 +"tests/unit/json_serialize/test_array.py" = 25 "tests/unit/test_jsonlines.py" = 12 "tests/unit/json_serialize/test_serializable_field.py" = 11 -"tests/unit/json_serialize/test_json_serialize.py" = 7 +"tests/unit/json_serialize/test_json_serialize.py" = 8 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 4 -"muutils/misc/typing_breakdown.py" = 1 +"muutils/json_serialize/serializable_dataclass.py" = 1 +"muutils/json_serialize/json_serialize.py" = 1 +"tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 1 +"tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 1 +"tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 1 +"tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 1 +"tests/unit/json_serialize/test_util.py" = 1 
[type_errors.basedpyright] -total_errors = 3066 +total_errors = 3059 [type_errors.basedpyright.by_type] "reportUnknownParameterType" = 452 "reportMissingParameterType" = 397 -"reportAny" = 374 -"reportUnusedCallResult" = 295 -"reportUnknownVariableType" = 214 +"reportAny" = 369 +"reportUnusedCallResult" = 291 +"reportUnknownVariableType" = 224 "reportExplicitAny" = 194 "reportMissingTypeArgument" = 185 "reportUnknownMemberType" = 141 @@ -45,50 +50,49 @@ total_errors = 3066 "reportCallIssue" = 59 "reportImplicitOverride" = 53 "reportInvalidTypeForm" = 51 -"reportPrivateUsage" = 41 "reportUnannotatedClassAttribute" = 41 "reportPossiblyUnboundVariable" = 34 "reportUnreachable" = 32 +"reportPrivateUsage" = 29 "reportOptionalSubscript" = 27 "reportUnnecessaryIsInstance" = 21 "reportUntypedClassDecorator" = 17 +"reportAttributeAccessIssue" = 15 "reportMissingSuperCall" = 14 "reportUntypedFunctionDecorator" = 13 -"reportUnusedVariable" = 11 +"reportUnusedVariable" = 9 "reportInvalidTypeArguments" = 9 "reportUnnecessaryComparison" = 8 "reportCallInDefaultInitializer" = 8 "reportUndefinedVariable" = 7 "reportOptionalMemberAccess" = 7 -"reportAttributeAccessIssue" = 5 "reportUninitializedInstanceVariable" = 5 "reportUnusedClass" = 4 "reportUnnecessaryTypeIgnoreComment" = 3 -"reportImplicitStringConcatenation" = 3 -"reportUnusedImport" = 3 "reportMissingTypeStubs" = 3 "reportMissingImports" = 3 "reportUnusedExpression" = 3 +"reportImplicitStringConcatenation" = 2 "reportOperatorIssue" = 2 "reportUntypedNamedTuple" = 2 -"reportRedeclaration" = 1 +"reportUnusedImport" = 1 "reportUnusedFunction" = 1 "reportGeneralTypeIssues" = 1 [type_errors.basedpyright.by_file] -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_json_serialize.py" = 167 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_json_serialize.py" = 168 "/home/miv/projects/tools/muutils/tests/unit/test_dbg.py" = 156 "/home/miv/projects/tools/muutils/tests/unit/test_parallel.py" = 134 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 125 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 126 +"/home/miv/projects/tools/muutils/muutils/json_serialize/serializable_dataclass.py" = 98 "/home/miv/projects/tools/muutils/tests/unit/misc/test_func.py" = 98 -"/home/miv/projects/tools/muutils/muutils/json_serialize/serializable_dataclass.py" = 97 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_array_torch.py" = 95 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_array_torch.py" = 96 "/home/miv/projects/tools/muutils/muutils/dictmagic.py" = 88 "/home/miv/projects/tools/muutils/muutils/json_serialize/serializable_field.py" = 88 "/home/miv/projects/tools/muutils/tests/unit/validate_type/test_get_kwargs.py" = 87 "/home/miv/projects/tools/muutils/tests/unit/validate_type/test_validate_type.py" = 86 "/home/miv/projects/tools/muutils/tests/unit/validate_type/test_validate_type_GENERATED.py" = 86 -"/home/miv/projects/tools/muutils/muutils/json_serialize/json_serialize.py" = 81 +"/home/miv/projects/tools/muutils/muutils/json_serialize/json_serialize.py" = 82 "/home/miv/projects/tools/muutils/muutils/misc/func.py" = 77 "/home/miv/projects/tools/muutils/tests/unit/test_interval.py" = 75 "/home/miv/projects/tools/muutils/muutils/misc/freezing.py" = 67 @@ -103,7 +107,7 @@ total_errors = 3066 "/home/miv/projects/tools/muutils/muutils/parallel.py" = 45 
"/home/miv/projects/tools/muutils/muutils/statcounter.py" = 41 "/home/miv/projects/tools/muutils/muutils/sysinfo.py" = 39 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_array.py" = 38 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_array.py" = 39 "/home/miv/projects/tools/muutils/muutils/web/bundle_html.py" = 37 "/home/miv/projects/tools/muutils/tests/unit/test_dictmagic.py" = 37 "/home/miv/projects/tools/muutils/tests/unit/errormode/test_errormode_functionality.py" = 36 @@ -111,14 +115,14 @@ total_errors = 3066 "/home/miv/projects/tools/muutils/muutils/nbutils/configure_notebook.py" = 33 "/home/miv/projects/tools/muutils/tests/unit/test_spinner.py" = 33 "/home/miv/projects/tools/muutils/tests/unit/misc/test_freeze.py" = 31 -"/home/miv/projects/tools/muutils/muutils/json_serialize/util.py" = 27 "/home/miv/projects/tools/muutils/muutils/misc/sequence.py" = 27 "/home/miv/projects/tools/muutils/tests/unit/misc/test_numerical_conversions.py" = 27 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 25 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 26 "/home/miv/projects/tools/muutils/muutils/mlutils.py" = 24 "/home/miv/projects/tools/muutils/muutils/validate_type.py" = 24 +"/home/miv/projects/tools/muutils/muutils/json_serialize/util.py" = 23 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_util.py" = 21 "/home/miv/projects/tools/muutils/muutils/interval.py" = 20 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_util.py" = 20 "/home/miv/projects/tools/muutils/tests/unit/validate_type/test_validate_type_special.py" = 20 "/home/miv/projects/tools/muutils/muutils/logger/logger.py" = 19 "/home/miv/projects/tools/muutils/tests/unit/test_jsonlines.py" = 19 @@ -129,7 +133,6 @@ total_errors = 3066 "/home/miv/projects/tools/muutils/tests/unit/test_tensor_info_torch.py" = 15 "/home/miv/projects/tools/muutils/tests/unit/nbutils/test_configure_notebook.py" = 14 "/home/miv/projects/tools/muutils/muutils/jsonlines.py" = 13 -"/home/miv/projects/tools/muutils/muutils/misc/typing_breakdown.py" = 13 "/home/miv/projects/tools/muutils/tests/unit/misc/test_misc.py" = 13 "/home/miv/projects/tools/muutils/muutils/nbutils/run_notebook_tests.py" = 12 "/home/miv/projects/tools/muutils/muutils/cli/arg_bool.py" = 11 @@ -149,10 +152,10 @@ total_errors = 3066 "/home/miv/projects/tools/muutils/muutils/nbutils/mermaid.py" = 7 "/home/miv/projects/tools/muutils/tests/unit/cli/test_command.py" = 7 "/home/miv/projects/tools/muutils/muutils/logger/timing.py" = 6 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 6 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 6 "/home/miv/projects/tools/muutils/tests/unit/nbutils/test_conversion.py" = 6 "/home/miv/projects/tools/muutils/tests/conftest.py" = 5 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 5 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 5 "/home/miv/projects/tools/muutils/muutils/logger/simplelogger.py" = 4 "/home/miv/projects/tools/muutils/muutils/web/html_to_pdf.py" = 4 "/home/miv/projects/tools/muutils/tests/unit/math/test_matrix_powers_torch.py" = 4 @@ -173,24 +176,30 @@ total_errors = 3066 
"/home/miv/projects/tools/muutils/tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 [type_errors.ty] -total_errors = 216 +total_errors = 226 [type_errors.ty.by_type] "unknown-argument" = 164 "unresolved-attribute" = 33 +"unresolved-import" = 11 "invalid-argument-type" = 8 "invalid-assignment" = 6 "too-many-positional-arguments" = 3 "invalid-return-type" = 1 -"unresolved-import" = 1 [type_errors.ty.by_file] -"tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 134 +"tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 135 "tests/unit/json_serialize/test_serializable_field.py" = 29 -"tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 26 -"tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 9 +"tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 27 +"tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 10 "tests/unit/test_dictmagic.py" = 8 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 6 -"tests/unit/json_serialize/test_array_torch.py" = 2 +"tests/unit/json_serialize/test_array_torch.py" = 3 +"muutils/json_serialize/json_serialize.py" = 1 +"muutils/json_serialize/serializable_dataclass.py" = 1 "tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 +"tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 1 +"tests/unit/json_serialize/test_array.py" = 1 +"tests/unit/json_serialize/test_json_serialize.py" = 1 +"tests/unit/json_serialize/test_util.py" = 1 "tests/unit/test_statcounter.py" = 1 From 5986d6371ec7c32648dcd40a9cb991fe5960497e Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 3 Nov 2025 12:43:02 +0000 Subject: [PATCH 56/72] type fixes --- tests/unit/test_statcounter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_statcounter.py b/tests/unit/test_statcounter.py index 1013ba60..f56a1979 100644 --- a/tests/unit/test_statcounter.py +++ b/tests/unit/test_statcounter.py @@ -6,7 +6,7 @@ def _compute_err(a: float, b: float | np.floating, /) -> dict[str, int | float]: - result: dict[str, int | float] = dict( + result: dict[str, int | float] = dict( # type: ignore[invalid-assignment] num_a=float(a), num_b=float(b), diff=float(b - a), @@ -20,11 +20,11 @@ def _compare_np_custom(arr: np.ndarray) -> dict[str, dict[str, float]]: return dict( mean=_compute_err(counter.mean(), np.mean(arr)), std=_compute_err(counter.std(), np.std(arr)), - min=_compute_err(counter.min(), np.min(arr)), + min=_compute_err(counter.min(), np.min(arr)), # pyright: ignore[reportUnknownArgumentType, reportAny] q1=_compute_err(counter.percentile(0.25), np.percentile(arr, 25)), median=_compute_err(counter.median(), np.median(arr)), q3=_compute_err(counter.percentile(0.75), np.percentile(arr, 75)), - max=_compute_err(counter.max(), np.max(arr)), + max=_compute_err(counter.max(), np.max(arr)), # pyright: ignore[reportUnknownArgumentType, reportAny] ) From 0ee66ed0d7a16a67c7ce5ec33acee2a1743ca999 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 3 Nov 2025 14:29:47 +0000 Subject: [PATCH 57/72] wip, lots of type fixes. 
--- .meta/typing-summary.txt | 115 ++++---- muutils/json_serialize/json_serialize.py | 49 ++-- .../json_serialize/serializable_dataclass.py | 2 +- muutils/json_serialize/util.py | 60 +++-- muutils/tensor_utils.py | 246 +++++++++--------- muutils/validate_type.py | 11 +- .../test_methods_no_override.py | 2 +- .../test_sdc_defaults.py | 2 +- .../test_sdc_properties_nested.py | 2 +- .../test_serializable_dataclass.py | 2 +- tests/unit/json_serialize/test_array.py | 2 +- tests/unit/json_serialize/test_array_torch.py | 4 +- .../json_serialize/test_json_serialize.py | 3 +- tests/unit/json_serialize/test_util.py | 4 +- tests/unit/logger/test_log_util.py | 7 +- tests/unit/test_jsonlines.py | 13 +- tests/unit/test_tensor_utils_torch.py | 25 +- 17 files changed, 284 insertions(+), 265 deletions(-) diff --git a/.meta/typing-summary.txt b/.meta/typing-summary.txt index 4fad10f2..444b42f2 100644 --- a/.meta/typing-summary.txt +++ b/.meta/typing-summary.txt @@ -1,76 +1,74 @@ # Showing all errors -# mypy: Found 126 errors in 13 files (checked 116 source files) -# basedpyright: 787 errors, 3772 warnings, 0 notes -# ty: Found 226 diagnostics +# mypy: Found 103 errors in 8 files (checked 116 source files) +# basedpyright: 666 errors, 3663 warnings, 0 notes +# ty: Found 215 diagnostics [type_errors.mypy] -total_errors = 126 +total_errors = 103 [type_errors.mypy.by_type] -"index" = 31 "typeddict-item" = 26 -"call-overload" = 23 -"attr-defined" = 21 -"arg-type" = 11 +"index" = 19 +"arg-type" = 14 +"call-overload" = 12 +"attr-defined" = 11 "literal-required" = 10 +"operator" = 6 "var-annotated" = 2 -"operator" = 2 +"valid-type" = 1 +"assignment" = 1 +"return" = 1 [type_errors.mypy.by_file] -"tests/unit/json_serialize/test_array_torch.py" = 59 -"tests/unit/json_serialize/test_array.py" = 25 +"tests/unit/json_serialize/test_array_torch.py" = 35 +"tests/unit/json_serialize/test_array.py" = 24 "tests/unit/test_jsonlines.py" = 12 +"tests/unit/json_serialize/test_json_serialize.py" = 11 "tests/unit/json_serialize/test_serializable_field.py" = 11 -"tests/unit/json_serialize/test_json_serialize.py" = 8 +"muutils/tensor_utils.py" = 5 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 4 -"muutils/json_serialize/serializable_dataclass.py" = 1 -"muutils/json_serialize/json_serialize.py" = 1 -"tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 1 -"tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 1 -"tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 1 -"tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 1 -"tests/unit/json_serialize/test_util.py" = 1 +"tests/unit/test_tensor_utils_torch.py" = 1 [type_errors.basedpyright] -total_errors = 3059 +total_errors = 2938 [type_errors.basedpyright.by_type] -"reportUnknownParameterType" = 452 -"reportMissingParameterType" = 397 -"reportAny" = 369 +"reportUnknownParameterType" = 437 +"reportMissingParameterType" = 386 +"reportAny" = 367 "reportUnusedCallResult" = 291 -"reportUnknownVariableType" = 224 -"reportExplicitAny" = 194 -"reportMissingTypeArgument" = 185 -"reportUnknownMemberType" = 141 -"reportUnknownLambdaType" = 131 +"reportUnknownVariableType" = 207 +"reportExplicitAny" = 201 +"reportMissingTypeArgument" = 186 +"reportUnknownMemberType" = 136 +"reportUnknownLambdaType" = 127 "reportUnusedParameter" = 100 -"reportIndexIssue" = 91 -"reportCallIssue" = 59 "reportImplicitOverride" = 53 -"reportInvalidTypeForm" = 51 +"reportInvalidTypeForm" = 49 
+"reportIndexIssue" = 49 +"reportCallIssue" = 45 "reportUnannotatedClassAttribute" = 41 +"reportPrivateUsage" = 35 "reportPossiblyUnboundVariable" = 34 "reportUnreachable" = 32 -"reportPrivateUsage" = 29 -"reportOptionalSubscript" = 27 "reportUnnecessaryIsInstance" = 21 "reportUntypedClassDecorator" = 17 -"reportAttributeAccessIssue" = 15 "reportMissingSuperCall" = 14 "reportUntypedFunctionDecorator" = 13 +"reportOptionalSubscript" = 13 "reportUnusedVariable" = 9 "reportInvalidTypeArguments" = 9 "reportUnnecessaryComparison" = 8 "reportCallInDefaultInitializer" = 8 "reportUndefinedVariable" = 7 "reportOptionalMemberAccess" = 7 +"reportUnnecessaryTypeIgnoreComment" = 6 "reportUninitializedInstanceVariable" = 5 "reportUnusedClass" = 4 -"reportUnnecessaryTypeIgnoreComment" = 3 "reportMissingTypeStubs" = 3 "reportMissingImports" = 3 +"reportArgumentType" = 3 "reportUnusedExpression" = 3 "reportImplicitStringConcatenation" = 2 "reportOperatorIssue" = 2 @@ -80,34 +78,33 @@ total_errors = 3059 "reportGeneralTypeIssues" = 1 [type_errors.basedpyright.by_file] -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_json_serialize.py" = 168 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_json_serialize.py" = 165 "/home/miv/projects/tools/muutils/tests/unit/test_dbg.py" = 156 "/home/miv/projects/tools/muutils/tests/unit/test_parallel.py" = 134 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 126 -"/home/miv/projects/tools/muutils/muutils/json_serialize/serializable_dataclass.py" = 98 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 125 "/home/miv/projects/tools/muutils/tests/unit/misc/test_func.py" = 98 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_array_torch.py" = 96 +"/home/miv/projects/tools/muutils/muutils/json_serialize/serializable_dataclass.py" = 97 "/home/miv/projects/tools/muutils/muutils/dictmagic.py" = 88 "/home/miv/projects/tools/muutils/muutils/json_serialize/serializable_field.py" = 88 "/home/miv/projects/tools/muutils/tests/unit/validate_type/test_get_kwargs.py" = 87 "/home/miv/projects/tools/muutils/tests/unit/validate_type/test_validate_type.py" = 86 "/home/miv/projects/tools/muutils/tests/unit/validate_type/test_validate_type_GENERATED.py" = 86 -"/home/miv/projects/tools/muutils/muutils/json_serialize/json_serialize.py" = 82 "/home/miv/projects/tools/muutils/muutils/misc/func.py" = 77 "/home/miv/projects/tools/muutils/tests/unit/test_interval.py" = 75 +"/home/miv/projects/tools/muutils/muutils/json_serialize/json_serialize.py" = 70 "/home/miv/projects/tools/muutils/muutils/misc/freezing.py" = 67 "/home/miv/projects/tools/muutils/tests/unit/cli/test_arg_bool.py" = 66 "/home/miv/projects/tools/muutils/muutils/spinner.py" = 63 "/home/miv/projects/tools/muutils/tests/unit/benchmark_parallel/benchmark_parallel.py" = 56 "/home/miv/projects/tools/muutils/tests/unit/web/test_bundle_html.py" = 55 -"/home/miv/projects/tools/muutils/muutils/tensor_utils.py" = 54 "/home/miv/projects/tools/muutils/muutils/tensor_info.py" = 52 +"/home/miv/projects/tools/muutils/muutils/tensor_utils.py" = 49 "/home/miv/projects/tools/muutils/tests/unit/errormode/test_errormode_init.py" = 46 "/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_serializable_field.py" = 46 "/home/miv/projects/tools/muutils/muutils/parallel.py" = 45 "/home/miv/projects/tools/muutils/muutils/statcounter.py" = 41 
"/home/miv/projects/tools/muutils/muutils/sysinfo.py" = 39 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_array.py" = 39 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_array.py" = 38 "/home/miv/projects/tools/muutils/muutils/web/bundle_html.py" = 37 "/home/miv/projects/tools/muutils/tests/unit/test_dictmagic.py" = 37 "/home/miv/projects/tools/muutils/tests/unit/errormode/test_errormode_functionality.py" = 36 @@ -117,17 +114,17 @@ total_errors = 3059 "/home/miv/projects/tools/muutils/tests/unit/misc/test_freeze.py" = 31 "/home/miv/projects/tools/muutils/muutils/misc/sequence.py" = 27 "/home/miv/projects/tools/muutils/tests/unit/misc/test_numerical_conversions.py" = 27 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 26 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 25 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_array_torch.py" = 25 "/home/miv/projects/tools/muutils/muutils/mlutils.py" = 24 "/home/miv/projects/tools/muutils/muutils/validate_type.py" = 24 -"/home/miv/projects/tools/muutils/muutils/json_serialize/util.py" = 23 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_util.py" = 21 "/home/miv/projects/tools/muutils/muutils/interval.py" = 20 "/home/miv/projects/tools/muutils/tests/unit/validate_type/test_validate_type_special.py" = 20 "/home/miv/projects/tools/muutils/muutils/logger/logger.py" = 19 "/home/miv/projects/tools/muutils/tests/unit/test_jsonlines.py" = 19 "/home/miv/projects/tools/muutils/muutils/math/matrix_powers.py" = 18 "/home/miv/projects/tools/muutils/muutils/misc/string.py" = 18 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_util.py" = 18 "/home/miv/projects/tools/muutils/tests/unit/test_kappa.py" = 17 "/home/miv/projects/tools/muutils/muutils/logger/log_util.py" = 16 "/home/miv/projects/tools/muutils/tests/unit/test_tensor_info_torch.py" = 15 @@ -152,16 +149,16 @@ total_errors = 3059 "/home/miv/projects/tools/muutils/muutils/nbutils/mermaid.py" = 7 "/home/miv/projects/tools/muutils/tests/unit/cli/test_command.py" = 7 "/home/miv/projects/tools/muutils/muutils/logger/timing.py" = 6 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 6 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 6 "/home/miv/projects/tools/muutils/tests/unit/nbutils/test_conversion.py" = 6 "/home/miv/projects/tools/muutils/tests/conftest.py" = 5 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 5 +"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 5 +"/home/miv/projects/tools/muutils/tests/unit/test_tensor_utils_torch.py" = 5 "/home/miv/projects/tools/muutils/muutils/logger/simplelogger.py" = 4 "/home/miv/projects/tools/muutils/muutils/web/html_to_pdf.py" = 4 "/home/miv/projects/tools/muutils/tests/unit/math/test_matrix_powers_torch.py" = 4 "/home/miv/projects/tools/muutils/tests/unit/misc/test_sequence.py" = 4 "/home/miv/projects/tools/muutils/tests/unit/test_mlutils.py" = 4 -"/home/miv/projects/tools/muutils/tests/unit/test_tensor_utils_torch.py" = 4 "/home/miv/projects/tools/muutils/muutils/misc/hashing.py" = 3 "/home/miv/projects/tools/muutils/tests/unit/test_tensor_info.py" = 3 
"/home/miv/projects/tools/muutils/muutils/logger/loggingstream.py" = 2 @@ -170,36 +167,30 @@ total_errors = 3059 "/home/miv/projects/tools/muutils/tests/unit/test_chunks.py" = 2 "/home/miv/projects/tools/muutils/tests/unit/test_timeit_fancy.py" = 2 "/home/miv/projects/tools/muutils/muutils/console_unicode.py" = 1 +"/home/miv/projects/tools/muutils/muutils/json_serialize/util.py" = 1 "/home/miv/projects/tools/muutils/muutils/math/bins.py" = 1 "/home/miv/projects/tools/muutils/muutils/misc/__init__.py" = 1 "/home/miv/projects/tools/muutils/muutils/misc/b64_decode.py" = 1 "/home/miv/projects/tools/muutils/tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 [type_errors.ty] -total_errors = 226 +total_errors = 215 [type_errors.ty.by_type] "unknown-argument" = 164 "unresolved-attribute" = 33 -"unresolved-import" = 11 "invalid-argument-type" = 8 -"invalid-assignment" = 6 +"invalid-assignment" = 5 "too-many-positional-arguments" = 3 "invalid-return-type" = 1 +"unresolved-import" = 1 [type_errors.ty.by_file] -"tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 135 +"tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 134 "tests/unit/json_serialize/test_serializable_field.py" = 29 -"tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 27 -"tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 10 +"tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 26 +"tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 9 "tests/unit/test_dictmagic.py" = 8 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 6 -"tests/unit/json_serialize/test_array_torch.py" = 3 -"muutils/json_serialize/json_serialize.py" = 1 -"muutils/json_serialize/serializable_dataclass.py" = 1 +"tests/unit/json_serialize/test_array_torch.py" = 2 "tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 -"tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 1 -"tests/unit/json_serialize/test_array.py" = 1 -"tests/unit/json_serialize/test_json_serialize.py" = 1 -"tests/unit/json_serialize/test_util.py" = 1 -"tests/unit/test_statcounter.py" = 1 diff --git a/muutils/json_serialize/json_serialize.py b/muutils/json_serialize/json_serialize.py index a2b589da..dd2f6512 100644 --- a/muutils/json_serialize/json_serialize.py +++ b/muutils/json_serialize/json_serialize.py @@ -14,28 +14,33 @@ import warnings from dataclasses import dataclass, is_dataclass from pathlib import Path -from typing import Any, Callable, Iterable, Mapping, Set, Union +from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Set, Union from muutils.errormode import ErrorMode -try: +if TYPE_CHECKING: + # always need array.py for type checking from muutils.json_serialize.array import ArrayMode, serialize_array -except ImportError as e: - # TYPING: obviously, these types are all wrong if we can't import array.py - ArrayMode = str # type: ignore[misc] - serialize_array = lambda *args, **kwargs: None # type: ignore[assignment, invalid-assignment] # noqa: E731 # pyright: ignore[reportUnknownVariableType, reportUnknownLambdaType] - warnings.warn( - f"muutils.json_serialize.array could not be imported probably because missing numpy, array serialization will not work: \n{e}", - ImportWarning, - ) +else: + try: + from muutils.json_serialize.array import ArrayMode, serialize_array + except ImportError as e: + # TYPING: obviously, these types are all wrong if we can't import 
array.py + ArrayMode = str # type: ignore[misc] + serialize_array = lambda *args, **kwargs: None # type: ignore[assignment, invalid-assignment] # noqa: E731 # pyright: ignore[reportUnknownVariableType, reportUnknownLambdaType] + warnings.warn( + f"muutils.json_serialize.array could not be imported probably because missing numpy, array serialization will not work: \n{e}", + ImportWarning, + ) + +from muutils.json_serialize.types import _FORMAT_KEY, Hashableitem # pyright: ignore[reportPrivateUsage] from muutils.json_serialize.util import ( - _FORMAT_KEY, - Hashableitem, + JSONdict, JSONitem, MonoTuple, SerializationException, - _recursive_hashify, + _recursive_hashify, # pyright: ignore[reportPrivateUsage, reportUnknownVariableType] isinstance_namedtuple, safe_getsource, string_as_lines, @@ -53,13 +58,13 @@ "__annotations__", ) -SERIALIZER_SPECIAL_FUNCS: dict[str, Callable] = { +SERIALIZER_SPECIAL_FUNCS: dict[str, Callable[..., str | list[str]]] = { "str": str, "dir": dir, - "type": try_catch(lambda x: str(type(x).__name__)), - "repr": try_catch(lambda x: repr(x)), - "code": try_catch(lambda x: inspect.getsource(x)), - "sourcefile": try_catch(lambda x: inspect.getsourcefile(x)), + "type": try_catch(lambda x: str(type(x).__name__)), # pyright: ignore[reportUnknownArgumentType, reportUnknownLambdaType] + "repr": try_catch(lambda x: repr(x)), # pyright: ignore[reportUnknownArgumentType, reportUnknownLambdaType] + "code": try_catch(lambda x: inspect.getsource(x)), # pyright: ignore[reportUnknownArgumentType, reportUnknownLambdaType] + "sourcefile": try_catch(lambda x: str(inspect.getsourcefile(x))), # pyright: ignore[reportUnknownArgumentType, reportUnknownLambdaType] } SERIALIZE_DIRECT_AS_STR: Set[str] = { @@ -89,7 +94,7 @@ class SerializerHandler: # description of this serializer desc: str - def serialize(self) -> dict: + def serialize(self) -> JSONdict: """serialize the handler info""" return { # get the code and doc of the check function @@ -241,7 +246,7 @@ def _serialize_override_serialize_func( SerializerHandler( check=lambda self, obj, path: True, serialize_func=lambda self, obj, path: { - **{k: str(getattr(obj, k, None)) for k in SERIALIZER_SPECIAL_KEYS}, + **{k: str(getattr(obj, k, None)) for k in SERIALIZER_SPECIAL_KEYS}, # type: ignore[typeddict-item] **{k: f(obj) for k, f in SERIALIZER_SPECIAL_FUNCS.items()}, }, uid="fallback", @@ -279,7 +284,7 @@ class JsonSerializer: def __init__( self, *args, - array_mode: ArrayMode = "array_list_meta", + array_mode: "ArrayMode" = "array_list_meta", error_mode: ErrorMode = ErrorMode.EXCEPT, handlers_pre: MonoTuple[SerializerHandler] = tuple(), handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS, @@ -290,7 +295,7 @@ def __init__( f"JsonSerializer takes no positional arguments!\n{args = }" ) - self.array_mode: ArrayMode = array_mode + self.array_mode: "ArrayMode" = array_mode self.error_mode: ErrorMode = ErrorMode.from_any(error_mode) self.write_only_format: bool = write_only_format # join up the handlers diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 0abed22f..81af3bda 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -66,8 +66,8 @@ class NestedClass(SerializableDataclass): SerializableField, serializable_field, ) +from muutils.json_serialize.types import _FORMAT_KEY from muutils.json_serialize.util import ( - _FORMAT_KEY, JSONdict, array_safe_eq, dc_eq, diff --git 
a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index 62c73243..19e25b95 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -98,7 +98,8 @@ def isinstance_namedtuple(x: Any) -> bool: # pyright: ignore[reportAny] f: Any = getattr(t, "_fields", None) if not isinstance(f, tuple): return False - return all(isinstance(n, str) for n in f) + # fine that the type is unknown -- that's what we want to check + return all(isinstance(n, str) for n in f) # pyright: ignore[reportUnknownVariableType] T_FuncTryCatchReturn = TypeVar("T_FuncTryCatchReturn") @@ -113,7 +114,7 @@ def try_catch( """ @functools.wraps(func) - def newfunc(*args: Any, **kwargs: Any) -> Union[T_FuncTryCatchReturn, str]: + def newfunc(*args: Any, **kwargs: Any) -> Union[T_FuncTryCatchReturn, str]: # pyright: ignore[reportAny] try: return func(*args, **kwargs) except Exception as e: @@ -122,16 +123,17 @@ def newfunc(*args: Any, **kwargs: Any) -> Union[T_FuncTryCatchReturn, str]: return newfunc -def _recursive_hashify(obj: Any, force: bool = True) -> Hashableitem: +# TYPING: can we get rid of any of these? +def _recursive_hashify(obj: Any, force: bool = True) -> Hashableitem: # pyright: ignore[reportUnknownParameterType, reportAny] if isinstance(obj, typing.Mapping): - return tuple((k, _recursive_hashify(v)) for k, v in obj.items()) + return tuple((k, _recursive_hashify(v)) for k, v in obj.items()) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] elif isinstance(obj, (bool, int, float, str)): return obj elif isinstance(obj, (tuple, list, Iterable)): - return tuple(_recursive_hashify(v) for v in obj) + return tuple(_recursive_hashify(v) for v in obj) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] else: if force: - return str(obj) + return str(obj) # pyright: ignore[reportAny] else: raise ValueError(f"cannot hashify:\n{obj}") @@ -151,7 +153,7 @@ def string_as_lines(s: str | None) -> list[str]: return s.splitlines(keepends=False) -def safe_getsource(func) -> list[str]: +def safe_getsource(func: Callable[..., Any]) -> list[str]: try: return string_as_lines(inspect.getsource(func)) except Exception as e: @@ -159,28 +161,28 @@ def safe_getsource(func) -> list[str]: # credit to https://stackoverflow.com/questions/51743827/how-to-compare-equality-of-dataclasses-holding-numpy-ndarray-boola-b-raises -def array_safe_eq(a: Any, b: Any) -> bool: +def array_safe_eq(a: Any, b: Any) -> bool: # pyright: ignore[reportAny] """check if two objects are equal, account for if numpy arrays or torch tensors""" if a is b: return True - if type(a) is not type(b): + if type(a) is not type(b): # pyright: ignore[reportAny] return False if ( - str(type(a)) == "<class 'numpy.ndarray'>" - and str(type(b)) == "<class 'numpy.ndarray'>" + str(type(a)) == "<class 'numpy.ndarray'>" # pyright: ignore[reportAny, reportUnknownArgumentType] + and str(type(b)) == "<class 'numpy.ndarray'>" # pyright: ignore[reportAny, reportUnknownArgumentType] ) or ( - str(type(a)) == "<class 'torch.Tensor'>" - and str(type(b)) == "<class 'torch.Tensor'>" + str(type(a)) == "<class 'torch.Tensor'>" # pyright: ignore[reportAny, reportUnknownArgumentType] + and str(type(b)) == "<class 'torch.Tensor'>" # pyright: ignore[reportAny, reportUnknownArgumentType] ): - return (a == b).all() + return (a == b).all() # pyright: ignore[reportAny] if ( - str(type(a)) == "<class 'pandas.core.frame.DataFrame'>" - and str(type(b)) == "<class 'pandas.core.frame.DataFrame'>" + str(type(a)) == "<class 'pandas.core.frame.DataFrame'>" # pyright: ignore[reportUnknownArgumentType, reportAny] + and str(type(b)) == "<class 'pandas.core.frame.DataFrame'>" # pyright: ignore[reportUnknownArgumentType, reportAny] ): - return a.equals(b) + return a.equals(b) # pyright: ignore[reportAny] if isinstance(a, typing.Sequence) and isinstance(b, 
typing.Sequence): if len(a) == 0 and len(b) == 0: @@ -188,22 +190,24 @@ def array_safe_eq(a: Any, b: Any) -> bool: return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b)) if isinstance(a, (dict, typing.Mapping)) and isinstance(b, (dict, typing.Mapping)): - return len(a) == len(b) and all( + return len(a) == len(b) and all( # pyright: ignore[reportUnknownArgumentType] array_safe_eq(k1, k2) and array_safe_eq(a[k1], b[k2]) - for k1, k2 in zip(a.keys(), b.keys()) + for k1, k2 in zip(a.keys(), b.keys()) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] ) try: - return bool(a == b) + return bool(a == b) # pyright: ignore[reportAny] except (TypeError, ValueError) as e: warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}") return NotImplemented # type: ignore[return-value] +# TYPING: see what can be done about so many `Any`s here def dc_eq( - dc1, - dc2, + dc1: Any, # pyright: ignore[reportAny] + dc2: Any, # pyright: ignore[reportAny] except_when_class_mismatch: bool = False, + # TODO: why is this unused? false_when_class_mismatch: bool = True, except_when_field_mismatch: bool = False, ) -> bool: @@ -268,15 +272,15 @@ def dc_eq( if dc1 is dc2: return True - if dc1.__class__ is not dc2.__class__: # pyright: ignore[reportUnknownMemberType] + if dc1.__class__ is not dc2.__class__: # pyright: ignore[reportAny] if except_when_class_mismatch: # if the classes don't match, raise an error raise TypeError( - f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`" # pyright: ignore[reportUnknownMemberType] + f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`" # pyright: ignore[reportAny] ) if except_when_field_mismatch: - dc1_fields: set[str] = set([fld.name for fld in dataclasses.fields(dc1)]) # pyright: ignore[reportUnknownArgumentType] - dc2_fields: set[str] = set([fld.name for fld in dataclasses.fields(dc2)]) # pyright: ignore[reportUnknownArgumentType] + dc1_fields: set[str] = set([fld.name for fld in dataclasses.fields(dc1)]) # pyright: ignore[reportAny] + dc2_fields: set[str] = set([fld.name for fld in dataclasses.fields(dc2)]) # pyright: ignore[reportAny] fields_match: bool = set(dc1_fields) == set(dc2_fields) if not fields_match: # if the fields match, keep going @@ -286,7 +290,7 @@ def dc_eq( return False return all( - array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name)) - for fld in dataclasses.fields(dc1) + array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name)) # pyright: ignore[reportAny] + for fld in dataclasses.fields(dc1) # pyright: ignore[reportAny] if fld.compare ) diff --git a/muutils/tensor_utils.py b/muutils/tensor_utils.py index 9bebcd40..fa53dc94 100644 --- a/muutils/tensor_utils.py +++ b/muutils/tensor_utils.py @@ -13,18 +13,18 @@ import json import typing +from typing import Any import jaxtyping import numpy as np import torch -from muutils.errormode import ErrorMode from muutils.dictmagic import dotlist_to_nested_dict # pylint: disable=missing-class-docstring -TYPE_TO_JAX_DTYPE: dict = { +TYPE_TO_JAX_DTYPE: dict[Any, Any] = { float: jaxtyping.Float, int: jaxtyping.Int, jaxtyping.Float: jaxtyping.Float, @@ -68,109 +68,112 @@ } "dict mapping python, numpy, and torch types to `jaxtyping` types" -# we check for version here, so it shouldn't error -if np.version.version < "2.0.0": - TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float # type: ignore[attr-defined] - TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int # type: ignore[attr-defined] +# np.float_ 
and np.int_ were deprecated in numpy 1.20 and removed in 2.0 +# use try/except for backwards compatibility and type checker friendliness +try: + TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] +except AttributeError: + pass # numpy 2.0+ removed these deprecated aliases # TODO: add proper type annotations to this signature # TODO: maybe get rid of this altogether? -def jaxtype_factory( - name: str, - array_type: type, - default_jax_dtype=jaxtyping.Float, - legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN, -) -> type: - """usage: - ``` - ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float) - x: ATensor["dim1 dim2", np.float32] - ``` - """ - legacy_mode_ = ErrorMode.from_any(legacy_mode) - - class _BaseArray: - """jaxtyping shorthand - (backwards compatible with older versions of muutils.tensor_utils) - - default_jax_dtype = {default_jax_dtype} - array_type = {array_type} - """ - - def __new__(cls, *args, **kwargs): - raise TypeError("Type FArray cannot be instantiated.") - - def __init_subclass__(cls, *args, **kwargs): - raise TypeError(f"Cannot subclass {cls.__name__}") - - @classmethod - def param_info(cls, params) -> str: - """useful for error printing""" - return "\n".join( - f"{k} = {v}" - for k, v in { - "cls.__name__": cls.__name__, - "cls.__doc__": cls.__doc__, - "params": params, - "type(params)": type(params), - }.items() - ) - - @typing._tp_cache # type: ignore - def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type: # type: ignore - # MyTensor["dim1 dim2"] - if isinstance(params, str): - return default_jax_dtype[array_type, params] - - elif isinstance(params, tuple): - if len(params) != 2: - raise Exception( - f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}" - ) - - if isinstance(params[0], str): - # MyTensor["dim1 dim2", int] - return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]] - - elif isinstance(params[0], tuple): - legacy_mode_.process( - f"legacy type annotation was used:\n{cls.param_info(params) = }", - except_cls=Exception, - ) - # MyTensor[("dim1", "dim2"), int] - shape_anot: list[str] = list() - for x in params[0]: - if isinstance(x, str): - shape_anot.append(x) - elif isinstance(x, int): - shape_anot.append(str(x)) - elif isinstance(x, tuple): - shape_anot.append("".join(str(y) for y in x)) - else: - raise Exception( - f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}" - ) - - return TYPE_TO_JAX_DTYPE[params[1]][ - array_type, " ".join(shape_anot) - ] - else: - raise Exception( - f"unexpected type for params:\n{cls.param_info(params)}" - ) - - _BaseArray.__name__ = name - - if _BaseArray.__doc__ is None: - _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }" - - _BaseArray.__doc__ = _BaseArray.__doc__.format( - default_jax_dtype=repr(default_jax_dtype), - array_type=repr(array_type), - ) - - return _BaseArray +# def jaxtype_factory( +# name: str, +# array_type: type, +# default_jax_dtype: type[jaxtyping.Float | jaxtyping.Int | jaxtyping.Bool] = jaxtyping.Float, +# legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN, +# ) -> type: +# """usage: +# ``` +# ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float) +# x: ATensor["dim1 dim2", np.float32] +# ``` +# """ +# legacy_mode_ = ErrorMode.from_any(legacy_mode) + +# 
class _BaseArray: +# """jaxtyping shorthand +# (backwards compatible with older versions of muutils.tensor_utils) + +# default_jax_dtype = {default_jax_dtype} +# array_type = {array_type} +# """ + +# def __new__(cls, *args: Any, **kwargs: Any) -> typing.NoReturn: +# raise TypeError("Type FArray cannot be instantiated.") + +# def __init_subclass__(cls, *args: Any, **kwargs: Any) -> typing.NoReturn: +# raise TypeError(f"Cannot subclass {cls.__name__}") + +# @classmethod +# def param_info(cls, params: typing.Union[str, tuple[Any, ...]]) -> str: +# """useful for error printing""" +# return "\n".join( +# f"{k} = {v}" +# for k, v in { +# "cls.__name__": cls.__name__, +# "cls.__doc__": cls.__doc__, +# "params": params, +# "type(params)": type(params), +# }.items() +# ) + +# @typing._tp_cache # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] +# def __class_getitem__(cls, params: typing.Union[str, tuple[Any, ...]]) -> type: # type: ignore[misc] +# # MyTensor["dim1 dim2"] +# if isinstance(params, str): +# return default_jax_dtype[array_type, params] + +# elif isinstance(params, tuple): +# if len(params) != 2: +# raise Exception( +# f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}" +# ) + +# if isinstance(params[0], str): +# # MyTensor["dim1 dim2", int] +# return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]] + +# elif isinstance(params[0], tuple): +# legacy_mode_.process( +# f"legacy type annotation was used:\n{cls.param_info(params) = }", +# except_cls=Exception, +# ) +# # MyTensor[("dim1", "dim2"), int] +# shape_anot: list[str] = list() +# for x in params[0]: +# if isinstance(x, str): +# shape_anot.append(x) +# elif isinstance(x, int): +# shape_anot.append(str(x)) +# elif isinstance(x, tuple): +# shape_anot.append("".join(str(y) for y in x)) +# else: +# raise Exception( +# f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}" +# ) + +# return TYPE_TO_JAX_DTYPE[params[1]][ +# array_type, " ".join(shape_anot) +# ] +# else: +# raise Exception( +# f"unexpected type for params:\n{cls.param_info(params)}" +# ) + +# _BaseArray.__name__ = name + +# if _BaseArray.__doc__ is None: +# _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }" + +# _BaseArray.__doc__ = _BaseArray.__doc__.format( +# default_jax_dtype=repr(default_jax_dtype), +# array_type=repr(array_type), +# ) + +# return _BaseArray if typing.TYPE_CHECKING: @@ -178,19 +181,19 @@ def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type: # type: i # but they make mypy unhappy and there is no way to only run if not mypy # so, later on we have more ignores class ATensor(torch.Tensor): - @typing._tp_cache # type: ignore - def __class_getitem__(cls, params): + @typing._tp_cache # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] + def __class_getitem__(cls, params: typing.Union[str, tuple[Any, ...]]) -> type: raise NotImplementedError() class NDArray(torch.Tensor): - @typing._tp_cache # type: ignore - def __class_getitem__(cls, params): + @typing._tp_cache # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] + def __class_getitem__(cls, params: typing.Union[str, tuple[Any, ...]]) -> type: raise NotImplementedError() -ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float) # type: ignore[misc, assignment] +# ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float) # type: 
ignore[misc, assignment] -NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float) # type: ignore[misc, assignment] +# NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float) # type: ignore[misc, assignment] def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype: @@ -201,7 +204,7 @@ def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dt return torch.from_numpy(np.array(0, dtype=dtype)).dtype -DTYPE_LIST: list = [ +DTYPE_LIST: list[Any] = [ *[ bool, int, @@ -261,16 +264,19 @@ def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dt ] "list of all the python, numpy, and torch numerical types I could think of" -if np.version.version < "2.0.0": - DTYPE_LIST.extend([np.float_, np.int_]) # type: ignore[attr-defined] +# np.float_ and np.int_ were deprecated in numpy 1.20 and removed in 2.0 +try: + DTYPE_LIST.extend([np.float_, np.int_]) # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] +except AttributeError: + pass # numpy 2.0+ removed these deprecated aliases -DTYPE_MAP: dict = { +DTYPE_MAP: dict[str, Any] = { **{str(x): x for x in DTYPE_LIST}, **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"}, } "mapping from string representations of types to their type" -TORCH_DTYPE_MAP: dict = { +TORCH_DTYPE_MAP: dict[str, torch.dtype] = { key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items() } "mapping from string representations of types to specifically torch types" @@ -420,7 +426,11 @@ class StateDictValueError(StateDictCompareError): def compare_state_dicts( - d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True + d1: dict[str, Any], + d2: dict[str, Any], + rtol: float = 1e-5, + atol: float = 1e-8, + verbose: bool = True, ) -> None: """compare two dicts of tensors @@ -442,11 +452,11 @@ def compare_state_dicts( - `StateDictValueError` : values don't match (but keys and shapes do) """ # check keys match - d1_keys: set = set(d1.keys()) - d2_keys: set = set(d2.keys()) - symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys) - keys_diff_1: set = d1_keys - d2_keys - keys_diff_2: set = d2_keys - d1_keys + d1_keys: set[str] = set(d1.keys()) + d2_keys: set[str] = set(d2.keys()) + symmetric_diff: set[str] = set.symmetric_difference(d1_keys, d2_keys) + keys_diff_1: set[str] = d1_keys - d2_keys + keys_diff_2: set[str] = d2_keys - d1_keys # sort sets for easier debugging symmetric_diff = set(sorted(symmetric_diff)) keys_diff_1 = set(sorted(keys_diff_1)) diff --git a/muutils/validate_type.py b/muutils/validate_type.py index d06834e0..f671990c 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -6,6 +6,7 @@ import types import typing import functools +from typing import Any # this is also for python <3.10 compatibility _GenericAliasTypeNames: typing.List[str] = [ @@ -15,11 +16,13 @@ "_BaseGenericAlias", ] -_GenericAliasTypesList: list = [ +_GenericAliasTypesList: list[Any] = [ getattr(typing, name, None) for name in _GenericAliasTypeNames ] -GenericAliasTypes: tuple = tuple([t for t in _GenericAliasTypesList if t is not None]) +GenericAliasTypes: tuple[Any, ...] = tuple( + [t for t in _GenericAliasTypesList if t is not None] +) class IncorrectTypeException(TypeError): @@ -103,7 +106,7 @@ def validate_type( raise e origin: typing.Any = typing.get_origin(expected_type) - args: tuple = typing.get_args(expected_type) + args: tuple[Any, ...] 
= typing.get_args(expected_type) # useful for debugging # print(f"{value = }, {expected_type = }, {origin = }, {args = }") @@ -224,7 +227,7 @@ def validate_type( ) -def get_fn_allowed_kwargs(fn: typing.Callable) -> typing.Set[str]: +def get_fn_allowed_kwargs(fn: typing.Callable[..., Any]) -> typing.Set[str]: """Get the allowed kwargs for a function, raising an exception if the signature cannot be determined.""" try: fn = unwrap(fn) diff --git a/tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py b/tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py index 431a3976..42ea0204 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py @@ -9,7 +9,7 @@ serializable_dataclass, SerializableDataclass, ) -from muutils.json_serialize.util import _FORMAT_KEY +from muutils.json_serialize.types import _FORMAT_KEY @serializable_dataclass diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py index eed9a29d..d97eb5fb 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py @@ -7,7 +7,7 @@ serializable_dataclass, serializable_field, ) -from muutils.json_serialize.util import _FORMAT_KEY +from muutils.json_serialize.types import _FORMAT_KEY # pylint: disable=missing-class-docstring diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py index 3f34cb44..5167970b 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py @@ -5,7 +5,7 @@ import pytest from muutils.json_serialize import SerializableDataclass, serializable_dataclass -from muutils.json_serialize.util import _FORMAT_KEY +from muutils.json_serialize.types import _FORMAT_KEY SUPPORTS_KW_ONLY: bool = sys.version_info >= (3, 10) diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index 5624efa5..bc20ab86 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -17,7 +17,7 @@ FieldIsNotInitOrSerializeWarning, FieldTypeMismatchError, ) -from muutils.json_serialize.util import _FORMAT_KEY +from muutils.json_serialize.types import _FORMAT_KEY # pylint: disable=missing-class-docstring, unused-variable diff --git a/tests/unit/json_serialize/test_array.py b/tests/unit/json_serialize/test_array.py index d65f0f12..bfa72678 100644 --- a/tests/unit/json_serialize/test_array.py +++ b/tests/unit/json_serialize/test_array.py @@ -9,7 +9,7 @@ load_array, serialize_array, ) -from muutils.json_serialize.util import _FORMAT_KEY +from muutils.json_serialize.types import _FORMAT_KEY # pylint: disable=missing-class-docstring diff --git a/tests/unit/json_serialize/test_array_torch.py b/tests/unit/json_serialize/test_array_torch.py index 11844638..f8786812 100644 --- a/tests/unit/json_serialize/test_array_torch.py +++ b/tests/unit/json_serialize/test_array_torch.py @@ -9,7 +9,7 @@ load_array, serialize_array, ) -from 
muutils.json_serialize.util import _FORMAT_KEY +from muutils.json_serialize.types import _FORMAT_KEY # pylint: disable=missing-class-docstring @@ -186,6 +186,7 @@ def test_torch_serialization_integration(): } serialized = jser.json_serialize(data) + assert isinstance(serialized, dict) # Check structure is preserved assert isinstance(serialized["model_weights"], dict) @@ -216,6 +217,7 @@ def test_mixed_numpy_torch(): } serialized = jser.json_serialize(data) + assert isinstance(serialized, dict) # Both should be serialized as dicts with metadata assert isinstance(serialized["numpy_array"], dict) diff --git a/tests/unit/json_serialize/test_json_serialize.py b/tests/unit/json_serialize/test_json_serialize.py index 71ede73d..7569a7ed 100644 --- a/tests/unit/json_serialize/test_json_serialize.py +++ b/tests/unit/json_serialize/test_json_serialize.py @@ -22,7 +22,8 @@ SerializerHandler, json_serialize, ) -from muutils.json_serialize.util import SerializationException, _FORMAT_KEY +from muutils.json_serialize.types import _FORMAT_KEY +from muutils.json_serialize.util import SerializationException # ============================================================================ diff --git a/tests/unit/json_serialize/test_util.py b/tests/unit/json_serialize/test_util.py index 9b88c038..7ade3161 100644 --- a/tests/unit/json_serialize/test_util.py +++ b/tests/unit/json_serialize/test_util.py @@ -4,10 +4,12 @@ import pytest +# pyright: reportPrivateUsage=false + # Module code assumed to be imported from my_module +from muutils.json_serialize.types import _FORMAT_KEY from muutils.json_serialize.util import ( UniversalContainer, - _FORMAT_KEY, _recursive_hashify, array_safe_eq, dc_eq, diff --git a/tests/unit/logger/test_log_util.py b/tests/unit/logger/test_log_util.py index 55f9a985..9fd4df86 100644 --- a/tests/unit/logger/test_log_util.py +++ b/tests/unit/logger/test_log_util.py @@ -5,6 +5,7 @@ import pytest +from muutils.json_serialize import JSONitem from muutils.jsonlines import jsonl_write from muutils.logger.log_util import ( gather_log, @@ -23,7 +24,7 @@ def test_gather_log(): log_file = TEMP_PATH / "test_gather_log.jsonl" # Create test data with multiple streams - test_data = [ + test_data: list[JSONitem] = [ {"msg": "stream1_msg1", "value": 1, "_stream": "stream1"}, {"msg": "stream2_msg1", "value": 10, "_stream": "stream2"}, {"msg": "stream1_msg2", "value": 2, "_stream": "stream1"}, @@ -65,7 +66,7 @@ def test_gather_stream(): log_file = TEMP_PATH / "test_gather_stream.jsonl" # Create test data with multiple streams - test_data = [ + test_data: list[JSONitem] = [ {"msg": "stream1_msg1", "idx": 1, "_stream": "target"}, {"msg": "stream2_msg1", "idx": 2, "_stream": "other"}, {"msg": "stream1_msg2", "idx": 3, "_stream": "target"}, @@ -105,7 +106,7 @@ def test_gather_val(): log_file = TEMP_PATH / "test_gather_val.jsonl" # Create test data matching the example from the docstring - test_data = [ + test_data: list[JSONitem] = [ {"a": 1, "b": 2, "c": 3, "_stream": "s1"}, {"a": 4, "b": 5, "c": 6, "_stream": "s1"}, {"a": 7, "b": 8, "c": 9, "_stream": "s2"}, diff --git a/tests/unit/test_jsonlines.py b/tests/unit/test_jsonlines.py index e06462d4..56b19f91 100644 --- a/tests/unit/test_jsonlines.py +++ b/tests/unit/test_jsonlines.py @@ -6,6 +6,7 @@ import pytest +from muutils.json_serialize import JSONitem from muutils.jsonlines import jsonl_load, jsonl_load_log, jsonl_write TEMP_PATH: Path = Path("tests/_temp/jsonl") @@ -49,7 +50,7 @@ def test_jsonl_write(): test_file = TEMP_PATH / "test_write.jsonl" # Test data 
- test_data = [ + test_data: list[JSONitem] = [ {"id": 1, "status": "active"}, {"id": 2, "status": "inactive"}, {"id": 3, "status": "pending", "metadata": {"priority": "high"}}, @@ -85,7 +86,7 @@ def test_gzip_support(): test_file_gzip = TEMP_PATH / "test_gzip2.jsonl.gzip" # Test data - test_data = [ + test_data: list[JSONitem] = [ {"compressed": True, "value": 123}, {"compressed": True, "value": 456}, ] @@ -127,7 +128,7 @@ def test_jsonl_load_log(): # Test with valid dict data test_file_valid = TEMP_PATH / "test_log_valid.jsonl" - valid_data = [ + valid_data: list[JSONitem] = [ {"level": "INFO", "message": "Starting process"}, {"level": "WARNING", "message": "Low memory"}, {"level": "ERROR", "message": "Connection failed"}, @@ -141,7 +142,7 @@ def test_jsonl_load_log(): # Test with non-dict items - should raise AssertionError test_file_invalid = TEMP_PATH / "test_log_invalid.jsonl" - invalid_data = [ + invalid_data: list[JSONitem] = [ {"level": "INFO", "message": "Valid entry"}, "not a dict", # This is invalid {"level": "ERROR", "message": "Another valid entry"}, @@ -159,7 +160,7 @@ def test_jsonl_load_log(): # Test with list item test_file_list = TEMP_PATH / "test_log_list.jsonl" - list_data = [ + list_data: list[JSONitem] = [ {"level": "INFO"}, [1, 2, 3], # List instead of dict ] @@ -182,7 +183,7 @@ def test_gzip_compresslevel(): test_file = TEMP_PATH / "test_compresslevel.jsonl.gz" # Create test data - test_data = [{"value": i, "data": "content"} for i in range(10)] + test_data: list[JSONitem] = [{"value": i, "data": "content"} for i in range(10)] # Write with different compression levels - should not error jsonl_write(str(test_file), test_data, gzip_compresslevel=1) diff --git a/tests/unit/test_tensor_utils_torch.py b/tests/unit/test_tensor_utils_torch.py index d3854f7c..bdf2c9e1 100644 --- a/tests/unit/test_tensor_utils_torch.py +++ b/tests/unit/test_tensor_utils_torch.py @@ -1,6 +1,5 @@ from __future__ import annotations -import jaxtyping import numpy as np import pytest import torch @@ -12,7 +11,7 @@ StateDictShapeError, compare_state_dicts, get_dict_shapes, - jaxtype_factory, + # jaxtype_factory, lpad_tensor, numpy_to_torch_dtype, pad_tensor, @@ -30,18 +29,18 @@ def test_pad_array(): assert np.array_equal(rpad_array(array, 5), np.array([1, 2, 3, 0, 0])) -def test_jaxtype_factory(): - ATensor = jaxtype_factory( - "ATensor", torch.Tensor, jaxtyping.Float, legacy_mode="ignore" - ) - assert ATensor.__name__ == "ATensor" - assert "default_jax_dtype = " in ATensor.__doc__ # type: ignore[operator] +# def test_jaxtype_factory(): +# ATensor = jaxtype_factory( +# "ATensor", torch.Tensor, jaxtyping.Float, legacy_mode="ignore" +# ) +# assert ATensor.__name__ == "ATensor" +# assert "default_jax_dtype = " in ATensor.__doc__ # type: ignore[operator] - x = ATensor[(1, 2, 3), np.float32] # type: ignore[index] - print(x) - y = ATensor["dim1 dim2", np.float32] # type: ignore[index] - print(y) +# x = ATensor[(1, 2, 3), np.float32] # type: ignore[index] +# print(x) +# y = ATensor["dim1 dim2", np.float32] # type: ignore[index] +# print(y) def test_numpy_to_torch_dtype(): From 22a39fe5590028647d945ef91ff49170a47ca4c6 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 3 Nov 2025 14:35:50 +0000 Subject: [PATCH 58/72] wip --- .meta/typing-summary.txt | 221 +++++++++++++++---------------- muutils/misc/typing_breakdown.py | 28 +++- pyproject.toml | 9 +- 3 files changed, 140 insertions(+), 118 deletions(-) diff --git a/.meta/typing-summary.txt b/.meta/typing-summary.txt index 444b42f2..0f2f483a 
100644 --- a/.meta/typing-summary.txt +++ b/.meta/typing-summary.txt @@ -1,24 +1,21 @@ # Showing all errors -# mypy: Found 103 errors in 8 files (checked 116 source files) -# basedpyright: 666 errors, 3663 warnings, 0 notes +# mypy: Found 97 errors in 6 files (checked 116 source files) +# basedpyright: 653 errors, 3639 warnings, 0 notes # ty: Found 215 diagnostics [type_errors.mypy] -total_errors = 103 +total_errors = 97 [type_errors.mypy.by_type] "typeddict-item" = 26 "index" = 19 -"arg-type" = 14 "call-overload" = 12 +"arg-type" = 11 "attr-defined" = 11 "literal-required" = 10 "operator" = 6 "var-annotated" = 2 -"valid-type" = 1 -"assignment" = 1 -"return" = 1 [type_errors.mypy.by_file] "tests/unit/json_serialize/test_array_torch.py" = 35 @@ -26,49 +23,46 @@ total_errors = 103 "tests/unit/test_jsonlines.py" = 12 "tests/unit/json_serialize/test_json_serialize.py" = 11 "tests/unit/json_serialize/test_serializable_field.py" = 11 -"muutils/tensor_utils.py" = 5 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 4 -"tests/unit/test_tensor_utils_torch.py" = 1 [type_errors.basedpyright] -total_errors = 2938 +total_errors = 2908 [type_errors.basedpyright.by_type] "reportUnknownParameterType" = 437 "reportMissingParameterType" = 386 -"reportAny" = 367 +"reportAny" = 358 "reportUnusedCallResult" = 291 -"reportUnknownVariableType" = 207 -"reportExplicitAny" = 201 -"reportMissingTypeArgument" = 186 +"reportUnknownVariableType" = 202 +"reportExplicitAny" = 195 +"reportMissingTypeArgument" = 183 "reportUnknownMemberType" = 136 "reportUnknownLambdaType" = 127 "reportUnusedParameter" = 100 "reportImplicitOverride" = 53 "reportInvalidTypeForm" = 49 -"reportIndexIssue" = 49 +"reportIndexIssue" = 46 "reportCallIssue" = 45 "reportUnannotatedClassAttribute" = 41 "reportPrivateUsage" = 35 "reportPossiblyUnboundVariable" = 34 -"reportUnreachable" = 32 -"reportUnnecessaryIsInstance" = 21 +"reportUnreachable" = 31 +"reportUnnecessaryIsInstance" = 20 "reportUntypedClassDecorator" = 17 "reportMissingSuperCall" = 14 -"reportUntypedFunctionDecorator" = 13 "reportOptionalSubscript" = 13 +"reportUntypedFunctionDecorator" = 12 "reportUnusedVariable" = 9 "reportInvalidTypeArguments" = 9 "reportUnnecessaryComparison" = 8 +"reportUnnecessaryTypeIgnoreComment" = 8 "reportCallInDefaultInitializer" = 8 "reportUndefinedVariable" = 7 "reportOptionalMemberAccess" = 7 -"reportUnnecessaryTypeIgnoreComment" = 6 "reportUninitializedInstanceVariable" = 5 "reportUnusedClass" = 4 "reportMissingTypeStubs" = 3 "reportMissingImports" = 3 -"reportArgumentType" = 3 "reportUnusedExpression" = 3 "reportImplicitStringConcatenation" = 2 "reportOperatorIssue" = 2 @@ -78,100 +72,99 @@ total_errors = 2938 "reportGeneralTypeIssues" = 1 [type_errors.basedpyright.by_file] -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_json_serialize.py" = 165 -"/home/miv/projects/tools/muutils/tests/unit/test_dbg.py" = 156 -"/home/miv/projects/tools/muutils/tests/unit/test_parallel.py" = 134 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 125 -"/home/miv/projects/tools/muutils/tests/unit/misc/test_func.py" = 98 -"/home/miv/projects/tools/muutils/muutils/json_serialize/serializable_dataclass.py" = 97 -"/home/miv/projects/tools/muutils/muutils/dictmagic.py" = 88 -"/home/miv/projects/tools/muutils/muutils/json_serialize/serializable_field.py" = 88 -"/home/miv/projects/tools/muutils/tests/unit/validate_type/test_get_kwargs.py" = 87 
-"/home/miv/projects/tools/muutils/tests/unit/validate_type/test_validate_type.py" = 86 -"/home/miv/projects/tools/muutils/tests/unit/validate_type/test_validate_type_GENERATED.py" = 86 -"/home/miv/projects/tools/muutils/muutils/misc/func.py" = 77 -"/home/miv/projects/tools/muutils/tests/unit/test_interval.py" = 75 -"/home/miv/projects/tools/muutils/muutils/json_serialize/json_serialize.py" = 70 -"/home/miv/projects/tools/muutils/muutils/misc/freezing.py" = 67 -"/home/miv/projects/tools/muutils/tests/unit/cli/test_arg_bool.py" = 66 -"/home/miv/projects/tools/muutils/muutils/spinner.py" = 63 -"/home/miv/projects/tools/muutils/tests/unit/benchmark_parallel/benchmark_parallel.py" = 56 -"/home/miv/projects/tools/muutils/tests/unit/web/test_bundle_html.py" = 55 -"/home/miv/projects/tools/muutils/muutils/tensor_info.py" = 52 -"/home/miv/projects/tools/muutils/muutils/tensor_utils.py" = 49 -"/home/miv/projects/tools/muutils/tests/unit/errormode/test_errormode_init.py" = 46 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_serializable_field.py" = 46 -"/home/miv/projects/tools/muutils/muutils/parallel.py" = 45 -"/home/miv/projects/tools/muutils/muutils/statcounter.py" = 41 -"/home/miv/projects/tools/muutils/muutils/sysinfo.py" = 39 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_array.py" = 38 -"/home/miv/projects/tools/muutils/muutils/web/bundle_html.py" = 37 -"/home/miv/projects/tools/muutils/tests/unit/test_dictmagic.py" = 37 -"/home/miv/projects/tools/muutils/tests/unit/errormode/test_errormode_functionality.py" = 36 -"/home/miv/projects/tools/muutils/muutils/nbutils/convert_ipynb_to_script.py" = 34 -"/home/miv/projects/tools/muutils/muutils/nbutils/configure_notebook.py" = 33 -"/home/miv/projects/tools/muutils/tests/unit/test_spinner.py" = 33 -"/home/miv/projects/tools/muutils/tests/unit/misc/test_freeze.py" = 31 -"/home/miv/projects/tools/muutils/muutils/misc/sequence.py" = 27 -"/home/miv/projects/tools/muutils/tests/unit/misc/test_numerical_conversions.py" = 27 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 25 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_array_torch.py" = 25 -"/home/miv/projects/tools/muutils/muutils/mlutils.py" = 24 -"/home/miv/projects/tools/muutils/muutils/validate_type.py" = 24 -"/home/miv/projects/tools/muutils/muutils/interval.py" = 20 -"/home/miv/projects/tools/muutils/tests/unit/validate_type/test_validate_type_special.py" = 20 -"/home/miv/projects/tools/muutils/muutils/logger/logger.py" = 19 -"/home/miv/projects/tools/muutils/tests/unit/test_jsonlines.py" = 19 -"/home/miv/projects/tools/muutils/muutils/math/matrix_powers.py" = 18 -"/home/miv/projects/tools/muutils/muutils/misc/string.py" = 18 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/test_util.py" = 18 -"/home/miv/projects/tools/muutils/tests/unit/test_kappa.py" = 17 -"/home/miv/projects/tools/muutils/muutils/logger/log_util.py" = 16 -"/home/miv/projects/tools/muutils/tests/unit/test_tensor_info_torch.py" = 15 -"/home/miv/projects/tools/muutils/tests/unit/nbutils/test_configure_notebook.py" = 14 -"/home/miv/projects/tools/muutils/muutils/jsonlines.py" = 13 -"/home/miv/projects/tools/muutils/tests/unit/misc/test_misc.py" = 13 -"/home/miv/projects/tools/muutils/muutils/nbutils/run_notebook_tests.py" = 12 -"/home/miv/projects/tools/muutils/muutils/cli/arg_bool.py" = 11 -"/home/miv/projects/tools/muutils/muutils/errormode.py" = 11 
-"/home/miv/projects/tools/muutils/muutils/json_serialize/dataclass_transform_mock.py" = 11 -"/home/miv/projects/tools/muutils/muutils/dbg.py" = 10 -"/home/miv/projects/tools/muutils/muutils/json_serialize/array.py" = 10 -"/home/miv/projects/tools/muutils/muutils/logger/exception_context.py" = 10 -"/home/miv/projects/tools/muutils/muutils/logger/headerfuncs.py" = 10 -"/home/miv/projects/tools/muutils/tests/unit/test_collect_warnings.py" = 10 -"/home/miv/projects/tools/muutils/muutils/kappa.py" = 9 -"/home/miv/projects/tools/muutils/muutils/collect_warnings.py" = 8 -"/home/miv/projects/tools/muutils/tests/unit/test_console_unicode.py" = 8 -"/home/miv/projects/tools/muutils/tests/util/test_fire.py" = 8 -"/home/miv/projects/tools/muutils/muutils/cli/command.py" = 7 -"/home/miv/projects/tools/muutils/muutils/misc/classes.py" = 7 -"/home/miv/projects/tools/muutils/muutils/nbutils/mermaid.py" = 7 -"/home/miv/projects/tools/muutils/tests/unit/cli/test_command.py" = 7 -"/home/miv/projects/tools/muutils/muutils/logger/timing.py" = 6 -"/home/miv/projects/tools/muutils/tests/unit/nbutils/test_conversion.py" = 6 -"/home/miv/projects/tools/muutils/tests/conftest.py" = 5 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 5 -"/home/miv/projects/tools/muutils/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 5 -"/home/miv/projects/tools/muutils/tests/unit/test_tensor_utils_torch.py" = 5 -"/home/miv/projects/tools/muutils/muutils/logger/simplelogger.py" = 4 -"/home/miv/projects/tools/muutils/muutils/web/html_to_pdf.py" = 4 -"/home/miv/projects/tools/muutils/tests/unit/math/test_matrix_powers_torch.py" = 4 -"/home/miv/projects/tools/muutils/tests/unit/misc/test_sequence.py" = 4 -"/home/miv/projects/tools/muutils/tests/unit/test_mlutils.py" = 4 -"/home/miv/projects/tools/muutils/muutils/misc/hashing.py" = 3 -"/home/miv/projects/tools/muutils/tests/unit/test_tensor_info.py" = 3 -"/home/miv/projects/tools/muutils/muutils/logger/loggingstream.py" = 2 -"/home/miv/projects/tools/muutils/muutils/misc/numerical.py" = 2 -"/home/miv/projects/tools/muutils/tests/unit/logger/test_log_util.py" = 2 -"/home/miv/projects/tools/muutils/tests/unit/test_chunks.py" = 2 -"/home/miv/projects/tools/muutils/tests/unit/test_timeit_fancy.py" = 2 -"/home/miv/projects/tools/muutils/muutils/console_unicode.py" = 1 -"/home/miv/projects/tools/muutils/muutils/json_serialize/util.py" = 1 -"/home/miv/projects/tools/muutils/muutils/math/bins.py" = 1 -"/home/miv/projects/tools/muutils/muutils/misc/__init__.py" = 1 -"/home/miv/projects/tools/muutils/muutils/misc/b64_decode.py" = 1 -"/home/miv/projects/tools/muutils/tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 +"tests/unit/json_serialize/test_json_serialize.py" = 165 +"tests/unit/test_dbg.py" = 156 +"tests/unit/test_parallel.py" = 134 +"tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 125 +"tests/unit/misc/test_func.py" = 98 +"muutils/json_serialize/serializable_dataclass.py" = 97 +"muutils/dictmagic.py" = 88 +"muutils/json_serialize/serializable_field.py" = 88 +"tests/unit/validate_type/test_get_kwargs.py" = 87 +"tests/unit/validate_type/test_validate_type.py" = 86 +"tests/unit/validate_type/test_validate_type_GENERATED.py" = 86 +"muutils/misc/func.py" = 77 +"tests/unit/test_interval.py" = 75 +"muutils/json_serialize/json_serialize.py" = 70 +"muutils/misc/freezing.py" = 67 +"tests/unit/cli/test_arg_bool.py" = 66 +"muutils/spinner.py" = 63 
+"tests/unit/benchmark_parallel/benchmark_parallel.py" = 56 +"tests/unit/web/test_bundle_html.py" = 55 +"muutils/tensor_info.py" = 52 +"tests/unit/errormode/test_errormode_init.py" = 46 +"tests/unit/json_serialize/test_serializable_field.py" = 46 +"muutils/parallel.py" = 45 +"muutils/statcounter.py" = 41 +"muutils/sysinfo.py" = 39 +"tests/unit/json_serialize/test_array.py" = 38 +"muutils/web/bundle_html.py" = 37 +"tests/unit/test_dictmagic.py" = 37 +"tests/unit/errormode/test_errormode_functionality.py" = 36 +"muutils/nbutils/convert_ipynb_to_script.py" = 34 +"muutils/nbutils/configure_notebook.py" = 33 +"tests/unit/test_spinner.py" = 33 +"tests/unit/misc/test_freeze.py" = 31 +"muutils/misc/sequence.py" = 27 +"tests/unit/misc/test_numerical_conversions.py" = 27 +"tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 25 +"tests/unit/json_serialize/test_array_torch.py" = 25 +"muutils/mlutils.py" = 24 +"muutils/tensor_utils.py" = 24 +"muutils/validate_type.py" = 24 +"muutils/interval.py" = 20 +"tests/unit/validate_type/test_validate_type_special.py" = 20 +"muutils/logger/logger.py" = 19 +"tests/unit/test_jsonlines.py" = 19 +"muutils/math/matrix_powers.py" = 18 +"muutils/misc/string.py" = 18 +"tests/unit/json_serialize/test_util.py" = 18 +"tests/unit/test_kappa.py" = 17 +"muutils/logger/log_util.py" = 16 +"tests/unit/test_tensor_info_torch.py" = 15 +"tests/unit/nbutils/test_configure_notebook.py" = 14 +"muutils/jsonlines.py" = 13 +"tests/unit/misc/test_misc.py" = 13 +"muutils/nbutils/run_notebook_tests.py" = 12 +"muutils/cli/arg_bool.py" = 11 +"muutils/errormode.py" = 11 +"muutils/json_serialize/dataclass_transform_mock.py" = 11 +"muutils/dbg.py" = 10 +"muutils/json_serialize/array.py" = 10 +"muutils/logger/exception_context.py" = 10 +"muutils/logger/headerfuncs.py" = 10 +"tests/unit/test_collect_warnings.py" = 10 +"muutils/kappa.py" = 9 +"muutils/collect_warnings.py" = 8 +"tests/unit/test_console_unicode.py" = 8 +"tests/util/test_fire.py" = 8 +"muutils/cli/command.py" = 7 +"muutils/misc/classes.py" = 7 +"muutils/nbutils/mermaid.py" = 7 +"tests/unit/cli/test_command.py" = 7 +"muutils/logger/timing.py" = 6 +"tests/unit/nbutils/test_conversion.py" = 6 +"tests/conftest.py" = 5 +"tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 5 +"tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 5 +"muutils/logger/simplelogger.py" = 4 +"muutils/web/html_to_pdf.py" = 4 +"tests/unit/math/test_matrix_powers_torch.py" = 4 +"tests/unit/misc/test_sequence.py" = 4 +"tests/unit/test_mlutils.py" = 4 +"muutils/misc/hashing.py" = 3 +"tests/unit/test_tensor_info.py" = 3 +"muutils/logger/loggingstream.py" = 2 +"muutils/misc/numerical.py" = 2 +"tests/unit/logger/test_log_util.py" = 2 +"tests/unit/test_chunks.py" = 2 +"tests/unit/test_timeit_fancy.py" = 2 +"muutils/console_unicode.py" = 1 +"muutils/json_serialize/util.py" = 1 +"muutils/math/bins.py" = 1 +"muutils/misc/__init__.py" = 1 +"muutils/misc/b64_decode.py" = 1 +"tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 [type_errors.ty] total_errors = 215 diff --git a/muutils/misc/typing_breakdown.py b/muutils/misc/typing_breakdown.py index 965b77f1..31bf50f4 100644 --- a/muutils/misc/typing_breakdown.py +++ b/muutils/misc/typing_breakdown.py @@ -12,6 +12,7 @@ from __future__ import annotations import argparse +import os import re from collections import defaultdict from dataclasses import dataclass, field @@ -19,6 +20,31 @@ from typing import Callable, Dict, List, Literal, Tuple +def 
strip_cwd(path: str) -> str: + """Strip the current working directory from a file path to make it relative. + + Args: + path: File path (absolute or relative) + + Returns: + Relative path with CWD stripped, or original path if not under CWD + """ + cwd: str = os.getcwd() + # Normalize both paths to handle different separators and resolve symlinks + abs_path: str = os.path.abspath(path) + abs_cwd: str = os.path.abspath(cwd) + + # Ensure CWD ends with separator for proper prefix matching + if not abs_cwd.endswith(os.sep): + abs_cwd += os.sep + + # Strip CWD prefix if present + if abs_path.startswith(abs_cwd): + return abs_path[len(abs_cwd):] + + return path + + @dataclass class TypeCheckResult: "results from parsing a type checker output" @@ -127,7 +153,7 @@ def parse_basedpyright(content: str) -> TypeCheckResult: for line in content.splitlines(): # Check if this is a file path line if line and not line.startswith(" ") and line.startswith("/"): - current_file = line.strip() + current_file = strip_cwd(line.strip()) # Check if this is an error/warning line elif line.strip() and current_file: # Match pattern like: " path:line:col - warning: message (reportCode)" diff --git a/pyproject.toml b/pyproject.toml index 8ddbe34d..232a7d76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -142,10 +142,8 @@ exclude = ["tests/input_data", "tests/junk_data", "_wip/"] [tool.basedpyright] + # file include/exclude include = ["muutils", "tests"] - reportConstantRedefinition = false # I always use all caps for globals, not just consts - reportDeprecated = false # this library is backwards compatible back to 3.8, so we are using lots of deprecated stuff - reportUnsupportedDunderAll = false # we use __all__ a lot for docs stuff exclude = [ "tests/input_data", "tests/junk_data", @@ -154,6 +152,11 @@ "docs/resources/make_docs.py", ".venv", ] + # rules + reportConstantRedefinition = false # I always use all caps for globals, not just consts + reportDeprecated = false # this library is backwards compatible back to 3.8, so we are using lots of deprecated stuff + reportUnsupportedDunderAll = false # we use __all__ a lot for docs stuff + # reportExplicitAny = false # we allow Any in many places. if it's there, it's intentional [tool.ty.src] exclude = [ From 24a4bccd556013631efc79e6f189c0e4f69b261b Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 3 Nov 2025 15:21:44 +0000 Subject: [PATCH 59/72] ? 
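Regenerates the typing summary using the relative paths produced by the strip_cwd helper from the previous commit, plus small config and benchmark tweaks. A rough usage sketch of that helper (the paths here are made up for illustration):

```
from muutils.misc.typing_breakdown import strip_cwd

# assuming the current working directory is /home/user/muutils:
print(strip_cwd("/home/user/muutils/tests/unit/test_dbg.py"))  # tests/unit/test_dbg.py
print(strip_cwd("/somewhere/else/file.py"))  # returned unchanged: not under the CWD
print(strip_cwd("tests/unit/test_dbg.py"))   # relative input resolves under the CWD
```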
--- .meta/typing-summary.txt | 87 +++++++++---------- muutils/misc/typing_breakdown.py | 2 +- pyproject.toml | 2 +- .../benchmark_parallel/benchmark_parallel.py | 22 ++--- 4 files changed, 56 insertions(+), 57 deletions(-) diff --git a/.meta/typing-summary.txt b/.meta/typing-summary.txt index 0f2f483a..9eec93ff 100644 --- a/.meta/typing-summary.txt +++ b/.meta/typing-summary.txt @@ -1,8 +1,8 @@ # Showing all errors # mypy: Found 97 errors in 6 files (checked 116 source files) -# basedpyright: 653 errors, 3639 warnings, 0 notes -# ty: Found 215 diagnostics +# basedpyright: 651 errors, 3446 warnings, 0 notes +# ty: Found 214 diagnostics [type_errors.mypy] total_errors = 97 @@ -26,16 +26,15 @@ total_errors = 97 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 4 [type_errors.basedpyright] -total_errors = 2908 +total_errors = 2715 [type_errors.basedpyright.by_type] "reportUnknownParameterType" = 437 "reportMissingParameterType" = 386 "reportAny" = 358 "reportUnusedCallResult" = 291 -"reportUnknownVariableType" = 202 -"reportExplicitAny" = 195 -"reportMissingTypeArgument" = 183 +"reportUnknownVariableType" = 201 +"reportMissingTypeArgument" = 182 "reportUnknownMemberType" = 136 "reportUnknownLambdaType" = 127 "reportUnusedParameter" = 100 @@ -51,11 +50,11 @@ total_errors = 2908 "reportUntypedClassDecorator" = 17 "reportMissingSuperCall" = 14 "reportOptionalSubscript" = 13 +"reportUnnecessaryTypeIgnoreComment" = 12 "reportUntypedFunctionDecorator" = 12 "reportUnusedVariable" = 9 "reportInvalidTypeArguments" = 9 "reportUnnecessaryComparison" = 8 -"reportUnnecessaryTypeIgnoreComment" = 8 "reportCallInDefaultInitializer" = 8 "reportUndefinedVariable" = 7 "reportOptionalMemberAccess" = 7 @@ -72,29 +71,29 @@ total_errors = 2908 "reportGeneralTypeIssues" = 1 [type_errors.basedpyright.by_file] -"tests/unit/json_serialize/test_json_serialize.py" = 165 -"tests/unit/test_dbg.py" = 156 -"tests/unit/test_parallel.py" = 134 -"tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 125 +"tests/unit/json_serialize/test_json_serialize.py" = 164 +"tests/unit/test_dbg.py" = 150 +"tests/unit/test_parallel.py" = 131 +"tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 121 "tests/unit/misc/test_func.py" = 98 -"muutils/json_serialize/serializable_dataclass.py" = 97 -"muutils/dictmagic.py" = 88 -"muutils/json_serialize/serializable_field.py" = 88 "tests/unit/validate_type/test_get_kwargs.py" = 87 "tests/unit/validate_type/test_validate_type.py" = 86 "tests/unit/validate_type/test_validate_type_GENERATED.py" = 86 -"muutils/misc/func.py" = 77 +"muutils/json_serialize/serializable_dataclass.py" = 80 "tests/unit/test_interval.py" = 75 -"muutils/json_serialize/json_serialize.py" = 70 -"muutils/misc/freezing.py" = 67 +"muutils/misc/func.py" = 70 "tests/unit/cli/test_arg_bool.py" = 66 -"muutils/spinner.py" = 63 -"tests/unit/benchmark_parallel/benchmark_parallel.py" = 56 +"muutils/misc/freezing.py" = 65 +"muutils/json_serialize/json_serialize.py" = 64 +"muutils/dictmagic.py" = 62 "tests/unit/web/test_bundle_html.py" = 55 -"muutils/tensor_info.py" = 52 +"muutils/spinner.py" = 51 +"muutils/json_serialize/serializable_field.py" = 49 +"tests/unit/benchmark_parallel/benchmark_parallel.py" = 49 "tests/unit/errormode/test_errormode_init.py" = 46 -"tests/unit/json_serialize/test_serializable_field.py" = 46 -"muutils/parallel.py" = 45 +"muutils/tensor_info.py" = 45 +"tests/unit/json_serialize/test_serializable_field.py" = 45 +"muutils/parallel.py" = 42 
"muutils/statcounter.py" = 41 "muutils/sysinfo.py" = 39 "tests/unit/json_serialize/test_array.py" = 38 @@ -103,43 +102,42 @@ total_errors = 2908 "tests/unit/errormode/test_errormode_functionality.py" = 36 "muutils/nbutils/convert_ipynb_to_script.py" = 34 "muutils/nbutils/configure_notebook.py" = 33 -"tests/unit/test_spinner.py" = 33 +"tests/unit/test_spinner.py" = 32 "tests/unit/misc/test_freeze.py" = 31 -"muutils/misc/sequence.py" = 27 "tests/unit/misc/test_numerical_conversions.py" = 27 -"tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 25 "tests/unit/json_serialize/test_array_torch.py" = 25 -"muutils/mlutils.py" = 24 -"muutils/tensor_utils.py" = 24 -"muutils/validate_type.py" = 24 -"muutils/interval.py" = 20 +"muutils/misc/sequence.py" = 22 +"muutils/mlutils.py" = 22 +"muutils/validate_type.py" = 20 "tests/unit/validate_type/test_validate_type_special.py" = 20 -"muutils/logger/logger.py" = 19 "tests/unit/test_jsonlines.py" = 19 +"muutils/logger/logger.py" = 18 "muutils/math/matrix_powers.py" = 18 "muutils/misc/string.py" = 18 "tests/unit/json_serialize/test_util.py" = 18 +"muutils/interval.py" = 17 "tests/unit/test_kappa.py" = 17 "muutils/logger/log_util.py" = 16 -"tests/unit/test_tensor_info_torch.py" = 15 +"muutils/tensor_utils.py" = 16 +"tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 15 "tests/unit/nbutils/test_configure_notebook.py" = 14 "muutils/jsonlines.py" = 13 -"tests/unit/misc/test_misc.py" = 13 "muutils/nbutils/run_notebook_tests.py" = 12 -"muutils/cli/arg_bool.py" = 11 "muutils/errormode.py" = 11 -"muutils/json_serialize/dataclass_transform_mock.py" = 11 +"tests/unit/misc/test_misc.py" = 11 +"tests/unit/test_tensor_info_torch.py" = 11 "muutils/dbg.py" = 10 "muutils/json_serialize/array.py" = 10 "muutils/logger/exception_context.py" = 10 -"muutils/logger/headerfuncs.py" = 10 "tests/unit/test_collect_warnings.py" = 10 "muutils/kappa.py" = 9 -"muutils/collect_warnings.py" = 8 +"muutils/misc/classes.py" = 9 +"muutils/cli/arg_bool.py" = 8 +"muutils/json_serialize/dataclass_transform_mock.py" = 8 +"muutils/logger/headerfuncs.py" = 8 "tests/unit/test_console_unicode.py" = 8 "tests/util/test_fire.py" = 8 -"muutils/cli/command.py" = 7 -"muutils/misc/classes.py" = 7 +"muutils/collect_warnings.py" = 7 "muutils/nbutils/mermaid.py" = 7 "tests/unit/cli/test_command.py" = 7 "muutils/logger/timing.py" = 6 @@ -147,27 +145,29 @@ total_errors = 2908 "tests/conftest.py" = 5 "tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 5 "tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 5 -"muutils/logger/simplelogger.py" = 4 "muutils/web/html_to_pdf.py" = 4 "tests/unit/math/test_matrix_powers_torch.py" = 4 "tests/unit/misc/test_sequence.py" = 4 "tests/unit/test_mlutils.py" = 4 +"muutils/cli/command.py" = 3 +"muutils/logger/simplelogger.py" = 3 "muutils/misc/hashing.py" = 3 -"tests/unit/test_tensor_info.py" = 3 -"muutils/logger/loggingstream.py" = 2 "muutils/misc/numerical.py" = 2 +"muutils/timeit_fancy.py" = 2 "tests/unit/logger/test_log_util.py" = 2 "tests/unit/test_chunks.py" = 2 +"tests/unit/test_tensor_info.py" = 2 "tests/unit/test_timeit_fancy.py" = 2 "muutils/console_unicode.py" = 1 "muutils/json_serialize/util.py" = 1 +"muutils/logger/loggingstream.py" = 1 "muutils/math/bins.py" = 1 "muutils/misc/__init__.py" = 1 "muutils/misc/b64_decode.py" = 1 "tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 [type_errors.ty] -total_errors = 215 +total_errors = 214 
[type_errors.ty.by_type] "unknown-argument" = 164 @@ -175,7 +175,6 @@ total_errors = 215 "invalid-argument-type" = 8 "invalid-assignment" = 5 "too-many-positional-arguments" = 3 -"invalid-return-type" = 1 "unresolved-import" = 1 [type_errors.ty.by_file] @@ -184,6 +183,6 @@ total_errors = 215 "tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 26 "tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 9 "tests/unit/test_dictmagic.py" = 8 -"tests/unit/benchmark_parallel/benchmark_parallel.py" = 6 +"tests/unit/benchmark_parallel/benchmark_parallel.py" = 5 "tests/unit/json_serialize/test_array_torch.py" = 2 "tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 diff --git a/muutils/misc/typing_breakdown.py b/muutils/misc/typing_breakdown.py index 31bf50f4..cf363e0d 100644 --- a/muutils/misc/typing_breakdown.py +++ b/muutils/misc/typing_breakdown.py @@ -40,7 +40,7 @@ def strip_cwd(path: str) -> str: # Strip CWD prefix if present if abs_path.startswith(abs_cwd): - return abs_path[len(abs_cwd):] + return abs_path[len(abs_cwd) :] return path diff --git a/pyproject.toml b/pyproject.toml index 232a7d76..831f8cf5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -156,7 +156,7 @@ reportConstantRedefinition = false # I always use all caps for globals, not just consts reportDeprecated = false # this library is backwards compatible back to 3.8, so we are using lots of deprecated stuff reportUnsupportedDunderAll = false # we use __all__ a lot for docs stuff - # reportExplicitAny = false # we allow Any in many places. if it's there, it's intentional + reportExplicitAny = false # we allow Any in many places. if it's there, it's intentional [tool.ty.src] exclude = [ diff --git a/tests/unit/benchmark_parallel/benchmark_parallel.py b/tests/unit/benchmark_parallel/benchmark_parallel.py index f2534bf9..c1288813 100644 --- a/tests/unit/benchmark_parallel/benchmark_parallel.py +++ b/tests/unit/benchmark_parallel/benchmark_parallel.py @@ -7,7 +7,7 @@ import time import multiprocessing from typing import List, Callable, Any, Dict, Optional, Sequence, Tuple, Union -import pandas as pd # type: ignore[import-untyped] +import pandas as pd # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs] import numpy as np from collections import defaultdict @@ -41,7 +41,7 @@ def __init__(self): self.results = defaultdict(list) self.cpu_count = multiprocessing.cpu_count() - def time_execution(self, func: Callable, *args, **kwargs) -> float: + def time_execution(self, func: Callable[..., float], *args, **kwargs) -> float: """Time a single execution.""" start = time.perf_counter() func(*args, **kwargs) @@ -62,17 +62,17 @@ def benchmark_method( times.append(duration) return { - "mean": np.mean(times), - "std": np.std(times), - "min": np.min(times), - "max": np.max(times), - "median": np.median(times), + "mean": float(np.mean(times)), + "std": float(np.std(times)), + "min": float(np.min(times)), + "max": float(np.max(times)), + "median": float(np.median(times)), } def run_benchmark_suite( self, - data_sizes: List[int], - task_funcs: Dict[str, Callable], + data_sizes: Sequence[int], + task_funcs: Dict[str, Callable[[int], int]], runs_per_method: int = 3, ) -> pd.DataFrame: """Run complete benchmark suite and return results as DataFrame.""" @@ -280,7 +280,7 @@ def plot_speedup_by_data_size( def plot_timing_comparison( - df: pd.DataFrame, data_size: int | None = None, save_path: str | None = None + df: pd.DataFrame, data_size: int | None = None, save_path: str 
| Path | None = None ): """Plot timing comparison as bar chart.""" import matplotlib.pyplot as plt # type: ignore[import-untyped] @@ -305,7 +305,7 @@ def plot_timing_comparison( plt.show() -def plot_efficiency_heatmap(df: pd.DataFrame, save_path: str | None = None): +def plot_efficiency_heatmap(df: pd.DataFrame, save_path: str | Path | None = None): """Plot efficiency heatmap (speedup across methods and tasks).""" import matplotlib.pyplot as plt # type: ignore[import-untyped] From 54efcf73a713200882aa3dcaf40b8f4ca4412b36 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 3 Nov 2025 15:53:17 +0000 Subject: [PATCH 60/72] wip --- .meta/typing-summary.txt | 47 +++++++++---------- muutils/json_serialize/array.py | 36 +++++++------- muutils/json_serialize/json_serialize.py | 6 +-- muutils/json_serialize/util.py | 33 ++++++------- tests/unit/json_serialize/test_array_torch.py | 19 ++++---- 5 files changed, 74 insertions(+), 67 deletions(-) diff --git a/.meta/typing-summary.txt b/.meta/typing-summary.txt index 9eec93ff..d550d0dc 100644 --- a/.meta/typing-summary.txt +++ b/.meta/typing-summary.txt @@ -1,19 +1,19 @@ # Showing all errors -# mypy: Found 97 errors in 6 files (checked 116 source files) -# basedpyright: 651 errors, 3446 warnings, 0 notes -# ty: Found 214 diagnostics +# mypy: Found 94 errors in 6 files (checked 116 source files) +# basedpyright: 637 errors, 3416 warnings, 0 notes +# ty: Found 209 diagnostics [type_errors.mypy] -total_errors = 97 +total_errors = 94 [type_errors.mypy.by_type] "typeddict-item" = 26 "index" = 19 +"attr-defined" = 13 "call-overload" = 12 -"arg-type" = 11 -"attr-defined" = 11 "literal-required" = 10 +"arg-type" = 6 "operator" = 6 "var-annotated" = 2 @@ -23,39 +23,39 @@ total_errors = 97 "tests/unit/test_jsonlines.py" = 12 "tests/unit/json_serialize/test_json_serialize.py" = 11 "tests/unit/json_serialize/test_serializable_field.py" = 11 -"tests/unit/benchmark_parallel/benchmark_parallel.py" = 4 +"tests/unit/benchmark_parallel/benchmark_parallel.py" = 1 [type_errors.basedpyright] -total_errors = 2715 +total_errors = 2705 [type_errors.basedpyright.by_type] -"reportUnknownParameterType" = 437 -"reportMissingParameterType" = 386 -"reportAny" = 358 -"reportUnusedCallResult" = 291 +"reportUnknownParameterType" = 434 +"reportMissingParameterType" = 384 +"reportAny" = 357 +"reportUnusedCallResult" = 292 "reportUnknownVariableType" = 201 -"reportMissingTypeArgument" = 182 +"reportMissingTypeArgument" = 181 "reportUnknownMemberType" = 136 "reportUnknownLambdaType" = 127 -"reportUnusedParameter" = 100 +"reportUnusedParameter" = 99 "reportImplicitOverride" = 53 "reportInvalidTypeForm" = 49 "reportIndexIssue" = 46 "reportCallIssue" = 45 "reportUnannotatedClassAttribute" = 41 -"reportPrivateUsage" = 35 "reportPossiblyUnboundVariable" = 34 +"reportPrivateUsage" = 33 "reportUnreachable" = 31 -"reportUnnecessaryIsInstance" = 20 +"reportUnnecessaryIsInstance" = 19 "reportUntypedClassDecorator" = 17 "reportMissingSuperCall" = 14 +"reportUnnecessaryTypeIgnoreComment" = 13 "reportOptionalSubscript" = 13 -"reportUnnecessaryTypeIgnoreComment" = 12 "reportUntypedFunctionDecorator" = 12 -"reportUnusedVariable" = 9 "reportInvalidTypeArguments" = 9 "reportUnnecessaryComparison" = 8 "reportCallInDefaultInitializer" = 8 +"reportUnusedVariable" = 8 "reportUndefinedVariable" = 7 "reportOptionalMemberAccess" = 7 "reportUninitializedInstanceVariable" = 5 @@ -105,7 +105,7 @@ total_errors = 2715 "tests/unit/test_spinner.py" = 32 "tests/unit/misc/test_freeze.py" = 31 
"tests/unit/misc/test_numerical_conversions.py" = 27 -"tests/unit/json_serialize/test_array_torch.py" = 25 +"tests/unit/json_serialize/test_array_torch.py" = 23 "muutils/misc/sequence.py" = 22 "muutils/mlutils.py" = 22 "muutils/validate_type.py" = 20 @@ -127,7 +127,6 @@ total_errors = 2715 "tests/unit/misc/test_misc.py" = 11 "tests/unit/test_tensor_info_torch.py" = 11 "muutils/dbg.py" = 10 -"muutils/json_serialize/array.py" = 10 "muutils/logger/exception_context.py" = 10 "tests/unit/test_collect_warnings.py" = 10 "muutils/kappa.py" = 9 @@ -152,6 +151,7 @@ total_errors = 2715 "muutils/cli/command.py" = 3 "muutils/logger/simplelogger.py" = 3 "muutils/misc/hashing.py" = 3 +"muutils/json_serialize/array.py" = 2 "muutils/misc/numerical.py" = 2 "muutils/timeit_fancy.py" = 2 "tests/unit/logger/test_log_util.py" = 2 @@ -167,13 +167,13 @@ total_errors = 2715 "tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 [type_errors.ty] -total_errors = 214 +total_errors = 209 [type_errors.ty.by_type] "unknown-argument" = 164 "unresolved-attribute" = 33 -"invalid-argument-type" = 8 "invalid-assignment" = 5 +"invalid-argument-type" = 3 "too-many-positional-arguments" = 3 "unresolved-import" = 1 @@ -183,6 +183,5 @@ total_errors = 214 "tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 26 "tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 9 "tests/unit/test_dictmagic.py" = 8 -"tests/unit/benchmark_parallel/benchmark_parallel.py" = 5 -"tests/unit/json_serialize/test_array_torch.py" = 2 +"tests/unit/benchmark_parallel/benchmark_parallel.py" = 2 "tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 diff --git a/muutils/json_serialize/array.py b/muutils/json_serialize/array.py index 42f23463..4e219bf3 100644 --- a/muutils/json_serialize/array.py +++ b/muutils/json_serialize/array.py @@ -13,6 +13,7 @@ import warnings from typing import ( TYPE_CHECKING, + Any, Iterable, Literal, Optional, @@ -32,8 +33,10 @@ if TYPE_CHECKING: import numpy as np + import torch + from muutils.json_serialize.json_serialize import JsonSerializer -from muutils.json_serialize.types import _FORMAT_KEY +from muutils.json_serialize.types import _FORMAT_KEY # pyright: ignore[reportPrivateUsage] # TYPING: pyright complains way too much here # pyright: reportCallIssue=false,reportArgumentType=false,reportUnknownVariableType=false,reportUnknownMemberType=false @@ -54,14 +57,15 @@ ] -def array_n_elements(arr) -> int: # type: ignore[name-defined] +def array_n_elements(arr: Any) -> int: # type: ignore[name-defined] # pyright: ignore[reportAny] """get the number of elements in an array""" if isinstance(arr, np.ndarray): return arr.size - elif str(type(arr)) == "": - return arr.nelement() + elif str(type(arr)) == "": # pyright: ignore[reportUnknownArgumentType, reportAny] + assert hasattr(arr, "nelement"), "torch Tensor does not have nelement() method? 
this should not happen" # pyright: ignore[reportAny] + return arr.nelement() # pyright: ignore[reportAny] else: - raise TypeError(f"invalid type: {type(arr)}") + raise TypeError(f"invalid type: {type(arr)}") # pyright: ignore[reportAny] class ArrayMetadata(TypedDict): @@ -84,21 +88,21 @@ class SerializedArrayWithMeta(TypedDict): n_elements: int -def arr_metadata(arr) -> ArrayMetadata: +def arr_metadata(arr: Any) -> ArrayMetadata: # pyright: ignore[reportAny] """get metadata for a numpy array""" return { - "shape": list(arr.shape), + "shape": list(arr.shape), # pyright: ignore[reportAny] "dtype": ( - arr.dtype.__name__ if hasattr(arr.dtype, "__name__") else str(arr.dtype) + arr.dtype.__name__ if hasattr(arr.dtype, "__name__") else str(arr.dtype) # pyright: ignore[reportAny] ), "n_elements": array_n_elements(arr), } def serialize_array( - jser: "JsonSerializer", # type: ignore[name-defined] # noqa: F821 # pyright: ignore[reportUndefinedVariable] - arr: np.ndarray, - path: str | Sequence[str | int], + jser: "JsonSerializer", # type: ignore[name-defined] # noqa: F821 + arr: "Union[np.ndarray, torch.Tensor]", + path: str | Sequence[str | int], # pyright: ignore[reportUnusedParameter] array_mode: ArrayMode | None = None, ) -> SerializedArrayWithMeta | NumericList: """serialize a numpy or pytorch array in one of several modes @@ -138,11 +142,11 @@ def serialize_array( array_mode = jser.array_mode arr_type: str = f"{type(arr).__module__}.{type(arr).__name__}" - arr_np: np.ndarray = arr if isinstance(arr, np.ndarray) else np.array(arr) + arr_np: np.ndarray = arr if isinstance(arr, np.ndarray) else np.array(arr) # pyright: ignore[reportUnnecessaryIsInstance] # Handle list mode first (no metadata needed) if array_mode == "list": - return arr_np.tolist() + return arr_np.tolist() # pyright: ignore[reportAny] # For all other modes, compute metadata once metadata: ArrayMetadata = arr_metadata(arr if len(arr.shape) == 0 else arr_np) @@ -153,7 +157,7 @@ def serialize_array( if len(arr.shape) == 0: return SerializedArrayWithMeta( __muutils_format__=f"{arr_type}:zero_dim", - data=arr.item(), + data=arr.item(), # pyright: ignore[reportAny] shape=metadata["shape"], dtype=metadata["dtype"], n_elements=metadata["n_elements"], @@ -163,7 +167,7 @@ def serialize_array( if array_mode == "array_list_meta": return SerializedArrayWithMeta( __muutils_format__=f"{arr_type}:array_list_meta", - data=arr_np.tolist(), + data=arr_np.tolist(), # pyright: ignore[reportAny] shape=metadata["shape"], dtype=metadata["dtype"], n_elements=metadata["n_elements"], @@ -210,7 +214,7 @@ def infer_array_mode( return_mode: ArrayMode if isinstance(arr, typing.Mapping): # _FORMAT_KEY always maps to a string - fmt: str = arr.get(_FORMAT_KEY, "") # type: ignore # pyright: ignore[reportAssignmentType] + fmt: str = arr.get(_FORMAT_KEY, "") # type: ignore if fmt.endswith(":array_list_meta"): if not isinstance(arr["data"], Iterable): raise ValueError(f"invalid list format: {type(arr['data']) = }\t{arr}") diff --git a/muutils/json_serialize/json_serialize.py b/muutils/json_serialize/json_serialize.py index dd2f6512..2bb9e789 100644 --- a/muutils/json_serialize/json_serialize.py +++ b/muutils/json_serialize/json_serialize.py @@ -27,7 +27,7 @@ except ImportError as e: # TYPING: obviously, these types are all wrong if we can't import array.py ArrayMode = str # type: ignore[misc] - serialize_array = lambda *args, **kwargs: None # type: ignore[assignment, invalid-assignment] # noqa: E731 # pyright: ignore[reportUnknownVariableType, 
reportUnknownLambdaType] + serialize_array = lambda *args, **kwargs: None # type: ignore[assignment, invalid-assignment] # noqa: E731 warnings.warn( f"muutils.json_serialize.array could not be imported probably because missing numpy, array serialization will not work: \n{e}", ImportWarning, @@ -316,8 +316,8 @@ def json_serialize( if self.write_only_format: if isinstance(output, dict) and _FORMAT_KEY in output: # TYPING: JSONitem has no idea that _FORMAT_KEY is str - new_fmt: str = output.pop(_FORMAT_KEY) # type: ignore - output["__write_format__"] = new_fmt # type: ignore + new_fmt: str = output.pop(_FORMAT_KEY) # type: ignore # pyright: ignore[reportAssignmentType] + output["__write_format__"] = new_fmt # type: ignore # pyright: ignore[reportGeneralTypeIssues] return output raise ValueError(f"no handler found for object with {type(obj) = }") diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index 19e25b95..c229097e 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -24,22 +24,23 @@ # At type-checking time, include array serialization types to avoid nominal type errors # This avoids superfluous imports at runtime -if TYPE_CHECKING: - from muutils.json_serialize.array import NumericList, SerializedArrayWithMeta - - JSONitem = Union[ - BaseType, - typing.Sequence["JSONitem"], - typing.Dict[str, "JSONitem"], - SerializedArrayWithMeta, - NumericList, - ] -else: - JSONitem = Union[ - BaseType, - typing.Sequence["JSONitem"], - typing.Dict[str, "JSONitem"], - ] +# if TYPE_CHECKING: +# from muutils.json_serialize.array import NumericList, SerializedArrayWithMeta + +# JSONitem = Union[ +# BaseType, +# typing.Sequence["JSONitem"], +# typing.Dict[str, "JSONitem"], +# SerializedArrayWithMeta, +# NumericList, +# ] +# else: + +JSONitem = Union[ + BaseType, + typing.Sequence["JSONitem"], + typing.Dict[str, "JSONitem"], +] JSONdict = typing.Dict[str, JSONitem] diff --git a/tests/unit/json_serialize/test_array_torch.py b/tests/unit/json_serialize/test_array_torch.py index f8786812..e037acdb 100644 --- a/tests/unit/json_serialize/test_array_torch.py +++ b/tests/unit/json_serialize/test_array_torch.py @@ -4,16 +4,20 @@ from muutils.json_serialize import JsonSerializer from muutils.json_serialize.array import ( + ArrayMode, arr_metadata, array_n_elements, load_array, serialize_array, ) -from muutils.json_serialize.types import _FORMAT_KEY +from muutils.json_serialize.types import _FORMAT_KEY # pyright: ignore[reportPrivateUsage] # pylint: disable=missing-class-docstring +_WITH_META_ARRAY_MODES: list[ArrayMode] = ["array_list_meta", "array_hex_meta", "array_b64_meta"] + + def test_arr_metadata_torch(): """Test arr_metadata() with torch tensors.""" # 1D tensor @@ -65,7 +69,7 @@ def test_serialize_load_torch_tensors(): ] for tensor in tensors: - for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: + for mode in _WITH_META_ARRAY_MODES: serialized = serialize_array(jser, tensor, "test", array_mode=mode) # type: ignore[arg-type] loaded = load_array(serialized) @@ -91,8 +95,8 @@ def test_torch_shape_dtype_preservation(): (torch.tensor([True, False, True], dtype=torch.bool), torch.bool), ] - for tensor, expected_dtype in dtype_tests: - for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: + for tensor, _expected_dtype in dtype_tests: + for mode in _WITH_META_ARRAY_MODES: serialized = serialize_array(jser, tensor, "test", array_mode=mode) # type: ignore[arg-type] loaded = load_array(serialized) @@ -108,7 +112,7 @@ def 
test_torch_zero_dim_tensor(): tensor_0d = torch.tensor(42) - for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: + for mode in _WITH_META_ARRAY_MODES: serialized = serialize_array(jser, tensor_0d, "test", array_mode=mode) # type: ignore[arg-type] loaded = load_array(serialized) @@ -116,7 +120,6 @@ def test_torch_zero_dim_tensor(): assert loaded.shape == tensor_0d.shape assert np.array_equal(loaded, tensor_0d.cpu().numpy()) - def test_torch_edge_cases(): """Test edge cases with torch tensors.""" jser = JsonSerializer(array_mode="array_list_meta") @@ -131,7 +134,7 @@ def test_torch_edge_cases(): special_tensor = torch.tensor( [float("inf"), float("-inf"), float("nan"), 0.0, -0.0] ) - for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: + for mode in _WITH_META_ARRAY_MODES: serialized = serialize_array(jser, special_tensor, "test", array_mode=mode) # type: ignore[arg-type] loaded = load_array(serialized) @@ -159,7 +162,7 @@ def test_torch_gpu_tensors(): # Create GPU tensor tensor_gpu = torch.tensor([1, 2, 3, 4], dtype=torch.float32, device="cuda") - for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: + for mode in _WITH_META_ARRAY_MODES: # Need to move to CPU first for numpy conversion tensor_cpu_torch = tensor_gpu.cpu() serialized = serialize_array(jser, tensor_cpu_torch, "test", array_mode=mode) # type: ignore[arg-type] From 0fdb121bec99563cb3fbc7ac70c8bdb2c05fe8cf Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 3 Nov 2025 15:55:49 +0000 Subject: [PATCH 61/72] re-run typing summary --- .meta/typing-summary.txt | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/.meta/typing-summary.txt b/.meta/typing-summary.txt index d550d0dc..b099a1bd 100644 --- a/.meta/typing-summary.txt +++ b/.meta/typing-summary.txt @@ -1,28 +1,28 @@ # Showing all errors -# mypy: Found 94 errors in 6 files (checked 116 source files) -# basedpyright: 637 errors, 3416 warnings, 0 notes +# mypy: Found 59 errors in 7 files (checked 116 source files) +# basedpyright: 598 errors, 3416 warnings, 0 notes # ty: Found 209 diagnostics [type_errors.mypy] -total_errors = 94 +total_errors = 59 [type_errors.mypy.by_type] -"typeddict-item" = 26 -"index" = 19 -"attr-defined" = 13 -"call-overload" = 12 -"literal-required" = 10 -"arg-type" = 6 +"index" = 20 +"attr-defined" = 11 +"call-overload" = 10 +"arg-type" = 8 "operator" = 6 +"return-value" = 2 "var-annotated" = 2 [type_errors.mypy.by_file] -"tests/unit/json_serialize/test_array_torch.py" = 35 -"tests/unit/json_serialize/test_array.py" = 24 -"tests/unit/test_jsonlines.py" = 12 -"tests/unit/json_serialize/test_json_serialize.py" = 11 +"tests/unit/json_serialize/test_array.py" = 19 +"tests/unit/json_serialize/test_array_torch.py" = 13 "tests/unit/json_serialize/test_serializable_field.py" = 11 +"tests/unit/test_jsonlines.py" = 7 +"muutils/json_serialize/json_serialize.py" = 4 +"tests/unit/json_serialize/test_json_serialize.py" = 4 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 1 [type_errors.basedpyright] @@ -49,8 +49,8 @@ total_errors = 2705 "reportUnnecessaryIsInstance" = 19 "reportUntypedClassDecorator" = 17 "reportMissingSuperCall" = 14 -"reportUnnecessaryTypeIgnoreComment" = 13 "reportOptionalSubscript" = 13 +"reportUnnecessaryTypeIgnoreComment" = 12 "reportUntypedFunctionDecorator" = 12 "reportInvalidTypeArguments" = 9 "reportUnnecessaryComparison" = 8 @@ -64,9 +64,9 @@ total_errors = 2705 "reportMissingImports" = 3 "reportUnusedExpression" = 3 
"reportImplicitStringConcatenation" = 2 +"reportUnusedImport" = 2 "reportOperatorIssue" = 2 "reportUntypedNamedTuple" = 2 -"reportUnusedImport" = 1 "reportUnusedFunction" = 1 "reportGeneralTypeIssues" = 1 @@ -84,7 +84,7 @@ total_errors = 2705 "muutils/misc/func.py" = 70 "tests/unit/cli/test_arg_bool.py" = 66 "muutils/misc/freezing.py" = 65 -"muutils/json_serialize/json_serialize.py" = 64 +"muutils/json_serialize/json_serialize.py" = 63 "muutils/dictmagic.py" = 62 "tests/unit/web/test_bundle_html.py" = 55 "muutils/spinner.py" = 51 @@ -152,6 +152,7 @@ total_errors = 2705 "muutils/logger/simplelogger.py" = 3 "muutils/misc/hashing.py" = 3 "muutils/json_serialize/array.py" = 2 +"muutils/json_serialize/util.py" = 2 "muutils/misc/numerical.py" = 2 "muutils/timeit_fancy.py" = 2 "tests/unit/logger/test_log_util.py" = 2 @@ -159,7 +160,6 @@ total_errors = 2705 "tests/unit/test_tensor_info.py" = 2 "tests/unit/test_timeit_fancy.py" = 2 "muutils/console_unicode.py" = 1 -"muutils/json_serialize/util.py" = 1 "muutils/logger/loggingstream.py" = 1 "muutils/math/bins.py" = 1 "muutils/misc/__init__.py" = 1 From c6f63e19d72e8281d18a3d8b6e036000929b5755 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 3 Nov 2025 16:16:15 +0000 Subject: [PATCH 62/72] wip --- .meta/typing-summary.txt | 64 +++++++++---------- muutils/json_serialize/array.py | 6 +- muutils/json_serialize/json_serialize.py | 22 +++---- muutils/json_serialize/types.py | 2 +- muutils/json_serialize/util.py | 6 +- muutils/logger/loggingstream.py | 3 +- muutils/math/bins.py | 18 ++++-- muutils/misc/__init__.py | 2 + muutils/misc/b64_decode.py | 2 +- muutils/misc/string.py | 42 +++++++++--- tests/unit/json_serialize/test_array_torch.py | 23 ++++--- 11 files changed, 113 insertions(+), 77 deletions(-) diff --git a/.meta/typing-summary.txt b/.meta/typing-summary.txt index b099a1bd..784295c6 100644 --- a/.meta/typing-summary.txt +++ b/.meta/typing-summary.txt @@ -1,17 +1,17 @@ # Showing all errors -# mypy: Found 59 errors in 7 files (checked 116 source files) -# basedpyright: 598 errors, 3416 warnings, 0 notes -# ty: Found 209 diagnostics +# mypy: Found 61 errors in 7 files (checked 116 source files) +# basedpyright: 565 errors, 3341 warnings, 0 notes +# ty: Found 212 diagnostics [type_errors.mypy] -total_errors = 59 +total_errors = 61 [type_errors.mypy.by_type] "index" = 20 "attr-defined" = 11 +"arg-type" = 10 "call-overload" = 10 -"arg-type" = 8 "operator" = 6 "return-value" = 2 "var-annotated" = 2 @@ -21,52 +21,52 @@ total_errors = 59 "tests/unit/json_serialize/test_array_torch.py" = 13 "tests/unit/json_serialize/test_serializable_field.py" = 11 "tests/unit/test_jsonlines.py" = 7 +"tests/unit/json_serialize/test_json_serialize.py" = 6 "muutils/json_serialize/json_serialize.py" = 4 -"tests/unit/json_serialize/test_json_serialize.py" = 4 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 1 [type_errors.basedpyright] -total_errors = 2705 +total_errors = 2649 [type_errors.basedpyright.by_type] -"reportUnknownParameterType" = 434 -"reportMissingParameterType" = 384 -"reportAny" = 357 -"reportUnusedCallResult" = 292 -"reportUnknownVariableType" = 201 -"reportMissingTypeArgument" = 181 -"reportUnknownMemberType" = 136 +"reportUnknownParameterType" = 426 +"reportMissingParameterType" = 379 +"reportAny" = 354 +"reportUnusedCallResult" = 291 +"reportUnknownVariableType" = 195 +"reportMissingTypeArgument" = 180 +"reportUnknownMemberType" = 133 "reportUnknownLambdaType" = 127 "reportUnusedParameter" = 99 -"reportImplicitOverride" = 53 
+"reportImplicitOverride" = 52 "reportInvalidTypeForm" = 49 -"reportIndexIssue" = 46 -"reportCallIssue" = 45 +"reportCallIssue" = 42 "reportUnannotatedClassAttribute" = 41 "reportPossiblyUnboundVariable" = 34 -"reportPrivateUsage" = 33 +"reportPrivateUsage" = 32 "reportUnreachable" = 31 +"reportIndexIssue" = 31 "reportUnnecessaryIsInstance" = 19 "reportUntypedClassDecorator" = 17 "reportMissingSuperCall" = 14 -"reportOptionalSubscript" = 13 -"reportUnnecessaryTypeIgnoreComment" = 12 +"reportUnnecessaryTypeIgnoreComment" = 13 "reportUntypedFunctionDecorator" = 12 "reportInvalidTypeArguments" = 9 "reportUnnecessaryComparison" = 8 -"reportCallInDefaultInitializer" = 8 "reportUnusedVariable" = 8 +"reportOptionalSubscript" = 8 "reportUndefinedVariable" = 7 "reportOptionalMemberAccess" = 7 "reportUninitializedInstanceVariable" = 5 +"reportCallInDefaultInitializer" = 4 "reportUnusedClass" = 4 "reportMissingTypeStubs" = 3 "reportMissingImports" = 3 "reportUnusedExpression" = 3 "reportImplicitStringConcatenation" = 2 -"reportUnusedImport" = 2 "reportOperatorIssue" = 2 "reportUntypedNamedTuple" = 2 +"reportUnusedImport" = 1 "reportUnusedFunction" = 1 "reportGeneralTypeIssues" = 1 @@ -84,9 +84,9 @@ total_errors = 2705 "muutils/misc/func.py" = 70 "tests/unit/cli/test_arg_bool.py" = 66 "muutils/misc/freezing.py" = 65 -"muutils/json_serialize/json_serialize.py" = 63 "muutils/dictmagic.py" = 62 "tests/unit/web/test_bundle_html.py" = 55 +"muutils/json_serialize/json_serialize.py" = 52 "muutils/spinner.py" = 51 "muutils/json_serialize/serializable_field.py" = 49 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 49 @@ -105,7 +105,6 @@ total_errors = 2705 "tests/unit/test_spinner.py" = 32 "tests/unit/misc/test_freeze.py" = 31 "tests/unit/misc/test_numerical_conversions.py" = 27 -"tests/unit/json_serialize/test_array_torch.py" = 23 "muutils/misc/sequence.py" = 22 "muutils/mlutils.py" = 22 "muutils/validate_type.py" = 20 @@ -113,7 +112,6 @@ total_errors = 2705 "tests/unit/test_jsonlines.py" = 19 "muutils/logger/logger.py" = 18 "muutils/math/matrix_powers.py" = 18 -"muutils/misc/string.py" = 18 "tests/unit/json_serialize/test_util.py" = 18 "muutils/interval.py" = 17 "tests/unit/test_kappa.py" = 17 @@ -149,10 +147,9 @@ total_errors = 2705 "tests/unit/misc/test_sequence.py" = 4 "tests/unit/test_mlutils.py" = 4 "muutils/cli/command.py" = 3 +"muutils/json_serialize/array.py" = 3 "muutils/logger/simplelogger.py" = 3 "muutils/misc/hashing.py" = 3 -"muutils/json_serialize/array.py" = 2 -"muutils/json_serialize/util.py" = 2 "muutils/misc/numerical.py" = 2 "muutils/timeit_fancy.py" = 2 "tests/unit/logger/test_log_util.py" = 2 @@ -160,22 +157,19 @@ total_errors = 2705 "tests/unit/test_tensor_info.py" = 2 "tests/unit/test_timeit_fancy.py" = 2 "muutils/console_unicode.py" = 1 -"muutils/logger/loggingstream.py" = 1 -"muutils/math/bins.py" = 1 -"muutils/misc/__init__.py" = 1 -"muutils/misc/b64_decode.py" = 1 +"muutils/json_serialize/util.py" = 1 "tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 [type_errors.ty] -total_errors = 209 +total_errors = 212 [type_errors.ty.by_type] "unknown-argument" = 164 "unresolved-attribute" = 33 +"invalid-argument-type" = 5 "invalid-assignment" = 5 -"invalid-argument-type" = 3 "too-many-positional-arguments" = 3 -"unresolved-import" = 1 +"unresolved-import" = 2 [type_errors.ty.by_file] "tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 134 @@ -184,4 +178,6 @@ total_errors = 209 
"tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 9 "tests/unit/test_dictmagic.py" = 8 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 2 +"tests/unit/json_serialize/test_json_serialize.py" = 2 +"muutils/logger/loggingstream.py" = 1 "tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 diff --git a/muutils/json_serialize/array.py b/muutils/json_serialize/array.py index 4e219bf3..045ec80a 100644 --- a/muutils/json_serialize/array.py +++ b/muutils/json_serialize/array.py @@ -62,8 +62,10 @@ def array_n_elements(arr: Any) -> int: # type: ignore[name-defined] # pyright: if isinstance(arr, np.ndarray): return arr.size elif str(type(arr)) == "": # pyright: ignore[reportUnknownArgumentType, reportAny] - assert hasattr(arr, "nelement"), "torch Tensor does not have nelement() method? this should not happen" # pyright: ignore[reportAny] - return arr.nelement() # pyright: ignore[reportAny] + assert hasattr(arr, "nelement"), ( + "torch Tensor does not have nelement() method? this should not happen" + ) # pyright: ignore[reportAny] + return arr.nelement() # pyright: ignore[reportAny] else: raise TypeError(f"invalid type: {type(arr)}") # pyright: ignore[reportAny] diff --git a/muutils/json_serialize/json_serialize.py b/muutils/json_serialize/json_serialize.py index 2bb9e789..382b7f61 100644 --- a/muutils/json_serialize/json_serialize.py +++ b/muutils/json_serialize/json_serialize.py @@ -283,10 +283,10 @@ class JsonSerializer: def __init__( self, - *args, + *args: None, array_mode: "ArrayMode" = "array_list_meta", error_mode: ErrorMode = ErrorMode.EXCEPT, - handlers_pre: MonoTuple[SerializerHandler] = tuple(), + handlers_pre: MonoTuple[SerializerHandler] = (), handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS, write_only_format: bool = False, ): @@ -305,8 +305,8 @@ def __init__( def json_serialize( self, - obj: Any, - path: ObjectPath = tuple(), + obj: Any, # pyright: ignore[reportAny] + path: ObjectPath = (), ) -> JSONitem: handler = None try: @@ -317,14 +317,14 @@ def json_serialize( if isinstance(output, dict) and _FORMAT_KEY in output: # TYPING: JSONitem has no idea that _FORMAT_KEY is str new_fmt: str = output.pop(_FORMAT_KEY) # type: ignore # pyright: ignore[reportAssignmentType] - output["__write_format__"] = new_fmt # type: ignore # pyright: ignore[reportGeneralTypeIssues] + output["__write_format__"] = new_fmt # type: ignore return output - raise ValueError(f"no handler found for object with {type(obj) = }") + raise ValueError(f"no handler found for object with {type(obj) = }") # pyright: ignore[reportAny] except Exception as e: if self.error_mode == ErrorMode.EXCEPT: - obj_str: str = repr(obj) + obj_str: str = repr(obj) # pyright: ignore[reportAny] if len(obj_str) > 1000: obj_str = obj_str[:1000] + "..." 
handler_uid = handler.uid if handler else "no handler matched" @@ -336,12 +336,12 @@ def json_serialize( f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}" ) - return repr(obj) + return repr(obj) # pyright: ignore[reportAny] def hashify( self, - obj: Any, - path: ObjectPath = tuple(), + obj: Any, # pyright: ignore[reportAny] + path: ObjectPath = (), force: bool = True, ) -> Hashableitem: """try to turn any object into something hashable""" @@ -354,6 +354,6 @@ def hashify( GLOBAL_JSON_SERIALIZER: JsonSerializer = JsonSerializer() -def json_serialize(obj: Any, path: ObjectPath = tuple()) -> JSONitem: +def json_serialize(obj: Any, path: ObjectPath = ()) -> JSONitem: # pyright: ignore[reportAny] """serialize object to json-serializable object with default config""" return GLOBAL_JSON_SERIALIZER.json_serialize(obj, path=path) diff --git a/muutils/json_serialize/types.py b/muutils/json_serialize/types.py index 34284c92..7c2d3ab0 100644 --- a/muutils/json_serialize/types.py +++ b/muutils/json_serialize/types.py @@ -13,7 +13,7 @@ None, ] -Hashableitem = Union[bool, int, float, str, tuple] # pyright: ignore[reportMissingTypeArgument] +Hashableitem = Union[BaseType, tuple["Hashableitem", ...]] _FORMAT_KEY: Final[str] = "__muutils_format__" diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index c229097e..b8231d88 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -8,7 +8,7 @@ import sys import typing import warnings -from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, Union +from typing import Any, Callable, Iterable, TypeVar, Union from muutils.json_serialize.types import BaseType, Hashableitem @@ -125,13 +125,13 @@ def newfunc(*args: Any, **kwargs: Any) -> Union[T_FuncTryCatchReturn, str]: # p # TYPING: can we get rid of any of these? 
-def _recursive_hashify(obj: Any, force: bool = True) -> Hashableitem: # pyright: ignore[reportUnknownParameterType, reportAny] +def _recursive_hashify(obj: Any, force: bool = True) -> Hashableitem: # pyright: ignore[reportAny] if isinstance(obj, typing.Mapping): return tuple((k, _recursive_hashify(v)) for k, v in obj.items()) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] elif isinstance(obj, (bool, int, float, str)): return obj elif isinstance(obj, (tuple, list, Iterable)): - return tuple(_recursive_hashify(v) for v in obj) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] + return tuple(_recursive_hashify(v) for v in obj) # pyright: ignore[reportUnknownVariableType] else: if force: return str(obj) # pyright: ignore[reportAny] diff --git a/muutils/logger/loggingstream.py b/muutils/logger/loggingstream.py index 77cad982..ca18222c 100644 --- a/muutils/logger/loggingstream.py +++ b/muutils/logger/loggingstream.py @@ -2,7 +2,7 @@ import time from dataclasses import dataclass, field -from typing import Any, Callable +from typing import Any, Callable, override from muutils.logger.simplelogger import AnyIO, NullIO from muutils.misc import sanitize_fname @@ -91,5 +91,6 @@ def __del__(self): self.handler.flush() self.handler.close() + @override def __str__(self): return f"LoggingStream(name={self.name}, aliases={self.aliases}, file={self.file}, default_level={self.default_level}, default_contents={self.default_contents})" diff --git a/muutils/math/bins.py b/muutils/math/bins.py index b1ac5047..a43c3343 100644 --- a/muutils/math/bins.py +++ b/muutils/math/bins.py @@ -29,25 +29,31 @@ def edges(self) -> Float[np.ndarray, "n_bins+1"]: ) if self.start == 0: return np.concatenate( - [ + [ # pyright: ignore[reportUnknownArgumentType] np.array([0]), np.logspace( - np.log10(self._log_min), np.log10(self.stop), self.n_bins + np.log10(self._log_min), # pyright: ignore[reportAny] + np.log10(self.stop), # pyright: ignore[reportAny] + self.n_bins, ), ] ) elif self.start < self._log_min and self._zero_in_small_start_log: return np.concatenate( - [ + [ # pyright: ignore[reportUnknownArgumentType] np.array([0]), np.logspace( - np.log10(self.start), np.log10(self.stop), self.n_bins + np.log10(self.start), # pyright: ignore[reportAny] + np.log10(self.stop), # pyright: ignore[reportAny] + self.n_bins, ), ] ) else: - return np.logspace( - np.log10(self.start), np.log10(self.stop), self.n_bins + 1 + return np.logspace( # pyright: ignore[reportUnknownVariableType] + np.log10(self.start), # pyright: ignore[reportAny] + np.log10(self.stop), # pyright: ignore[reportAny] + self.n_bins + 1, ) else: raise ValueError(f"Invalid scale {self.scale}, expected lin or log") diff --git a/muutils/misc/__init__.py b/muutils/misc/__init__.py index eff71720..2d194002 100644 --- a/muutils/misc/__init__.py +++ b/muutils/misc/__init__.py @@ -8,6 +8,8 @@ - `muutils.misc.classes` for some weird class utilities """ +# pyright: reportPrivateUsage=false + from muutils.misc.hashing import stable_hash from muutils.misc.sequence import ( WhenMissing, diff --git a/muutils/misc/b64_decode.py b/muutils/misc/b64_decode.py index 0001202b..c6cf41f6 100644 --- a/muutils/misc/b64_decode.py +++ b/muutils/misc/b64_decode.py @@ -6,4 +6,4 @@ input_file: Path = Path(argv[1]) out: Path = Path(argv[2]) input_text: str = input_file.read_text().replace("\n", "") - out.write_bytes(b64decode(input_text)) + out.write_bytes(b64decode(input_text)) # pyright: ignore[reportUnusedCallResult] diff --git 
a/muutils/misc/string.py b/muutils/misc/string.py index 4ff8269a..d0d65932 100644 --- a/muutils/misc/string.py +++ b/muutils/misc/string.py @@ -1,4 +1,5 @@ from __future__ import annotations +from typing import Any, Callable, TypeVar from muutils.misc.hashing import stable_hash @@ -55,34 +56,54 @@ def sanitize_name( return sanitized -def sanitize_fname(fname: str | None, **kwargs) -> str: +def sanitize_fname( + fname: str | None, + replace_invalid: str = "", + when_none: str | None = "_None_", + leading_digit_prefix: str = "", +) -> str: """sanitize a filename to posix standards - leave only alphanumerics, `_` (underscore), '-' (dash) and `.` (period) """ - return sanitize_name(fname, additional_allowed_chars="._-", **kwargs) + return sanitize_name( + name=fname, + additional_allowed_chars="._-", + replace_invalid=replace_invalid, + when_none=when_none, + leading_digit_prefix=leading_digit_prefix, + ) -def sanitize_identifier(fname: str | None, **kwargs) -> str: +def sanitize_identifier( + fname: str | None, + replace_invalid: str = "", + when_none: str | None = "_None_", +) -> str: """sanitize an identifier (variable or function name) - leave only alphanumerics and `_` (underscore) - prefix with `_` if it starts with a digit """ return sanitize_name( - fname, additional_allowed_chars="_", leading_digit_prefix="_", **kwargs + name=fname, + additional_allowed_chars="_", + replace_invalid=replace_invalid, + when_none=when_none, + leading_digit_prefix="_", ) def dict_to_filename( - data: dict, + data: dict[str, Any], format_str: str = "{key}_{val}", separator: str = ".", max_length: int = 255, ): # Convert the dictionary items to a list of strings using the format string formatted_items: list[str] = [ - format_str.format(key=k, val=v) for k, v in data.items() + format_str.format(key=k, val=v) + for k, v in data.items() # pyright: ignore[reportAny] ] # Join the formatted items using the separator @@ -99,10 +120,13 @@ def dict_to_filename( return f"h_{stable_hash(sanitized_str)}" -def dynamic_docstring(**doc_params): - def decorator(func): +T_Callable = TypeVar("T_Callable", bound=Callable[..., Any]) + + +def dynamic_docstring(**doc_params: str) -> Callable[[T_Callable], T_Callable]: + def decorator(func: T_Callable) -> T_Callable: if func.__doc__: - func.__doc__ = func.__doc__.format(**doc_params) + func.__doc__ = getattr(func, "__doc__", "").format(**doc_params) return func return decorator diff --git a/tests/unit/json_serialize/test_array_torch.py b/tests/unit/json_serialize/test_array_torch.py index e037acdb..3af93556 100644 --- a/tests/unit/json_serialize/test_array_torch.py +++ b/tests/unit/json_serialize/test_array_torch.py @@ -15,7 +15,11 @@ # pylint: disable=missing-class-docstring -_WITH_META_ARRAY_MODES: list[ArrayMode] = ["array_list_meta", "array_hex_meta", "array_b64_meta"] +_WITH_META_ARRAY_MODES: list[ArrayMode] = [ + "array_list_meta", + "array_hex_meta", + "array_b64_meta", +] def test_arr_metadata_torch(): @@ -120,6 +124,7 @@ def test_torch_zero_dim_tensor(): assert loaded.shape == tensor_0d.shape assert np.array_equal(loaded, tensor_0d.cpu().numpy()) + def test_torch_edge_cases(): """Test edge cases with torch tensors.""" jser = JsonSerializer(array_mode="array_list_meta") @@ -139,9 +144,9 @@ def test_torch_edge_cases(): loaded = load_array(serialized) # Check special values - assert np.isinf(loaded[0]) and loaded[0] > 0 - assert np.isinf(loaded[1]) and loaded[1] < 0 - assert np.isnan(loaded[2]) + assert np.isinf(loaded[0]) and loaded[0] > 0 # pyright: ignore[reportAny] + 
assert np.isinf(loaded[1]) and loaded[1] < 0 # pyright: ignore[reportAny] + assert np.isnan(loaded[2]) # pyright: ignore[reportAny] # Large tensor large_tensor = torch.randn(100, 100) @@ -199,11 +204,11 @@ def test_torch_serialization_integration(): assert isinstance(serialized["biases"], dict) assert serialized["biases"]["shape"] == [5] - assert serialized["metadata"]["epochs"] == 10 + assert serialized["metadata"]["epochs"] == 10 # pyright: ignore[reportArgumentType, reportCallIssue, reportIndexIssue, reportOptionalSubscript] # Check nested tensors - assert isinstance(serialized["history"][0]["loss"], dict) - assert _FORMAT_KEY in serialized["history"][0]["loss"] + assert isinstance(serialized["history"][0]["loss"], dict) # pyright: ignore[reportArgumentType, reportCallIssue, reportIndexIssue, reportOptionalSubscript] + assert _FORMAT_KEY in serialized["history"][0]["loss"] # pyright: ignore[reportArgumentType, reportCallIssue, reportIndexIssue, reportOptionalSubscript, reportOperatorIssue] def test_mixed_numpy_torch(): @@ -229,5 +234,5 @@ def test_mixed_numpy_torch(): assert _FORMAT_KEY in serialized["torch_tensor"] # Check format strings identify the type - assert "numpy" in serialized["numpy_array"][_FORMAT_KEY] - assert "torch" in serialized["torch_tensor"][_FORMAT_KEY] + assert "numpy" in serialized["numpy_array"][_FORMAT_KEY] # pyright: ignore[reportOperatorIssue] + assert "torch" in serialized["torch_tensor"][_FORMAT_KEY] # pyright: ignore[reportOperatorIssue] From 5b2c87179532e6702893cd563ce1bc66ff331976 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 4 Nov 2025 13:25:25 +0000 Subject: [PATCH 63/72] wip typing --- .meta/typing-summary.txt | 22 +++++++++------------- muutils/console_unicode.py | 2 +- muutils/logger/loggingstream.py | 4 +++- muutils/misc/hashing.py | 5 +++-- muutils/misc/numerical.py | 6 +++--- muutils/timeit_fancy.py | 4 ++-- 6 files changed, 21 insertions(+), 22 deletions(-) diff --git a/.meta/typing-summary.txt b/.meta/typing-summary.txt index 784295c6..d40abe7a 100644 --- a/.meta/typing-summary.txt +++ b/.meta/typing-summary.txt @@ -1,7 +1,7 @@ # Showing all errors # mypy: Found 61 errors in 7 files (checked 116 source files) -# basedpyright: 565 errors, 3341 warnings, 0 notes +# basedpyright: 564 errors, 3333 warnings, 0 notes # ty: Found 212 diagnostics [type_errors.mypy] @@ -26,13 +26,13 @@ total_errors = 61 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 1 [type_errors.basedpyright] -total_errors = 2649 +total_errors = 2641 [type_errors.basedpyright.by_type] -"reportUnknownParameterType" = 426 -"reportMissingParameterType" = 379 +"reportUnknownParameterType" = 425 +"reportMissingParameterType" = 378 "reportAny" = 354 -"reportUnusedCallResult" = 291 +"reportUnusedCallResult" = 290 "reportUnknownVariableType" = 195 "reportMissingTypeArgument" = 180 "reportUnknownMemberType" = 133 @@ -43,14 +43,14 @@ total_errors = 2649 "reportCallIssue" = 42 "reportUnannotatedClassAttribute" = 41 "reportPossiblyUnboundVariable" = 34 -"reportPrivateUsage" = 32 -"reportUnreachable" = 31 +"reportPrivateUsage" = 31 "reportIndexIssue" = 31 -"reportUnnecessaryIsInstance" = 19 +"reportUnreachable" = 30 +"reportUnnecessaryIsInstance" = 18 "reportUntypedClassDecorator" = 17 "reportMissingSuperCall" = 14 -"reportUnnecessaryTypeIgnoreComment" = 13 "reportUntypedFunctionDecorator" = 12 +"reportUnnecessaryTypeIgnoreComment" = 11 "reportInvalidTypeArguments" = 9 "reportUnnecessaryComparison" = 8 "reportUnusedVariable" = 8 @@ -149,14 +149,10 @@ total_errors = 2649 
"muutils/cli/command.py" = 3 "muutils/json_serialize/array.py" = 3 "muutils/logger/simplelogger.py" = 3 -"muutils/misc/hashing.py" = 3 -"muutils/misc/numerical.py" = 2 -"muutils/timeit_fancy.py" = 2 "tests/unit/logger/test_log_util.py" = 2 "tests/unit/test_chunks.py" = 2 "tests/unit/test_tensor_info.py" = 2 "tests/unit/test_timeit_fancy.py" = 2 -"muutils/console_unicode.py" = 1 "muutils/json_serialize/util.py" = 1 "tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 diff --git a/muutils/console_unicode.py b/muutils/console_unicode.py index 5ea3b993..f9e0b85b 100644 --- a/muutils/console_unicode.py +++ b/muutils/console_unicode.py @@ -28,7 +28,7 @@ def get_console_safe_str( ``` """ try: - default.encode(locale.getpreferredencoding()) + _ = default.encode(locale.getpreferredencoding()) return default except UnicodeEncodeError: return fallback diff --git a/muutils/logger/loggingstream.py b/muutils/logger/loggingstream.py index ca18222c..5d8f565d 100644 --- a/muutils/logger/loggingstream.py +++ b/muutils/logger/loggingstream.py @@ -2,7 +2,9 @@ import time from dataclasses import dataclass, field -from typing import Any, Callable, override + +# TYPING: ty fails to resolve this?? +from typing import Any, Callable, override # type: ignore[unresolved-import] from muutils.logger.simplelogger import AnyIO, NullIO from muutils.misc import sanitize_fname diff --git a/muutils/misc/hashing.py b/muutils/misc/hashing.py index 0e384cdc..e6ca77d4 100644 --- a/muutils/misc/hashing.py +++ b/muutils/misc/hashing.py @@ -3,6 +3,7 @@ import base64 import hashlib import json +from typing import Any def stable_hash(s: str | bytes) -> int: @@ -13,12 +14,12 @@ def stable_hash(s: str | bytes) -> int: s_bytes = s.encode("utf-8") else: s_bytes = s - hash_obj: hashlib._Hash = hashlib.md5(s_bytes) + hash_obj: hashlib._Hash = hashlib.md5(s_bytes) # pyright: ignore[reportPrivateUsage] # get digest and convert to int return int.from_bytes(hash_obj.digest(), "big") -def stable_json_dumps(d) -> str: +def stable_json_dumps(d: Any) -> str: # pyright: ignore[reportAny] return json.dumps( d, sort_keys=True, diff --git a/muutils/misc/numerical.py b/muutils/misc/numerical.py index 2cd74c16..441f046e 100644 --- a/muutils/misc/numerical.py +++ b/muutils/misc/numerical.py @@ -73,8 +73,8 @@ def str_to_numeric( """ # check is string - if not isinstance(quantity, str): - raise TypeError( + if not isinstance(quantity, str): # pyright: ignore[reportUnnecessaryIsInstance] + raise TypeError( # pyright: ignore[reportUnreachable] f"quantity must be a string, got '{type(quantity) = }' '{quantity = }'" ) @@ -97,7 +97,7 @@ def str_to_numeric( if mapping is True or mapping is None: _mapping = _REVERSE_SHORTEN_MAP else: - _mapping = mapping # type: ignore[assignment] + _mapping = mapping # type: ignore[assignment] # pyright: ignore[reportAssignmentType] quantity_original: str = quantity diff --git a/muutils/timeit_fancy.py b/muutils/timeit_fancy.py index 30dafc3e..df57573a 100644 --- a/muutils/timeit_fancy.py +++ b/muutils/timeit_fancy.py @@ -23,9 +23,9 @@ class FancyTimeitResult(NamedTuple): def timeit_fancy( cmd: Union[Callable[[], T_return], str], - setup: Union[str, Callable[[], Any]] = lambda: None, # pyright: ignore[reportExplicitAny] + setup: Union[str, Callable[[], Any]] = lambda: None, repeats: int = 5, - namespace: Union[dict[str, Any], None] = None, # pyright: ignore[reportExplicitAny] + namespace: Union[dict[str, Any], None] = None, get_return: bool = True, do_profiling: bool = False, ) -> FancyTimeitResult: From 
dc1fb451a636114399f648342f028846904ad676 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Sun, 18 Jan 2026 20:43:16 -0700 Subject: [PATCH 64/72] some typing fixes, wip --- .meta/typing-summary.txt | 38 +++++++++---------- muutils/json_serialize/serializable_field.py | 2 +- .../test_serializable_dataclass.py | 12 +++--- .../json_serialize/test_serializable_field.py | 10 ++--- 4 files changed, 30 insertions(+), 32 deletions(-) diff --git a/.meta/typing-summary.txt b/.meta/typing-summary.txt index d40abe7a..3871266c 100644 --- a/.meta/typing-summary.txt +++ b/.meta/typing-summary.txt @@ -1,32 +1,31 @@ # Showing all errors -# mypy: Found 61 errors in 7 files (checked 116 source files) -# basedpyright: 564 errors, 3333 warnings, 0 notes -# ty: Found 212 diagnostics +# mypy: Found 52 errors in 7 files (checked 116 source files) +# basedpyright: 550 errors, 3319 warnings, 0 notes +# ty: Found 204 diagnostics [type_errors.mypy] -total_errors = 61 +total_errors = 52 [type_errors.mypy.by_type] "index" = 20 -"attr-defined" = 11 "arg-type" = 10 "call-overload" = 10 "operator" = 6 +"attr-defined" = 4 "return-value" = 2 -"var-annotated" = 2 [type_errors.mypy.by_file] "tests/unit/json_serialize/test_array.py" = 19 "tests/unit/json_serialize/test_array_torch.py" = 13 -"tests/unit/json_serialize/test_serializable_field.py" = 11 "tests/unit/test_jsonlines.py" = 7 "tests/unit/json_serialize/test_json_serialize.py" = 6 "muutils/json_serialize/json_serialize.py" = 4 +"tests/unit/json_serialize/test_serializable_field.py" = 2 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 1 [type_errors.basedpyright] -total_errors = 2641 +total_errors = 2631 [type_errors.basedpyright.by_type] "reportUnknownParameterType" = 425 @@ -34,9 +33,9 @@ total_errors = 2641 "reportAny" = 354 "reportUnusedCallResult" = 290 "reportUnknownVariableType" = 195 -"reportMissingTypeArgument" = 180 -"reportUnknownMemberType" = 133 +"reportMissingTypeArgument" = 178 "reportUnknownLambdaType" = 127 +"reportUnknownMemberType" = 126 "reportUnusedParameter" = 99 "reportImplicitOverride" = 52 "reportInvalidTypeForm" = 49 @@ -56,7 +55,7 @@ total_errors = 2641 "reportUnusedVariable" = 8 "reportOptionalSubscript" = 8 "reportUndefinedVariable" = 7 -"reportOptionalMemberAccess" = 7 +"reportOptionalMemberAccess" = 6 "reportUninitializedInstanceVariable" = 5 "reportCallInDefaultInitializer" = 4 "reportUnusedClass" = 4 @@ -74,7 +73,7 @@ total_errors = 2641 "tests/unit/json_serialize/test_json_serialize.py" = 164 "tests/unit/test_dbg.py" = 150 "tests/unit/test_parallel.py" = 131 -"tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 121 +"tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 119 "tests/unit/misc/test_func.py" = 98 "tests/unit/validate_type/test_get_kwargs.py" = 87 "tests/unit/validate_type/test_validate_type.py" = 86 @@ -88,15 +87,15 @@ total_errors = 2641 "tests/unit/web/test_bundle_html.py" = 55 "muutils/json_serialize/json_serialize.py" = 52 "muutils/spinner.py" = 51 -"muutils/json_serialize/serializable_field.py" = 49 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 49 +"muutils/json_serialize/serializable_field.py" = 48 "tests/unit/errormode/test_errormode_init.py" = 46 "muutils/tensor_info.py" = 45 -"tests/unit/json_serialize/test_serializable_field.py" = 45 "muutils/parallel.py" = 42 "muutils/statcounter.py" = 41 "muutils/sysinfo.py" = 39 "tests/unit/json_serialize/test_array.py" = 38 +"tests/unit/json_serialize/test_serializable_field.py" = 38 
"muutils/web/bundle_html.py" = 37 "tests/unit/test_dictmagic.py" = 37 "tests/unit/errormode/test_errormode_functionality.py" = 36 @@ -157,23 +156,22 @@ total_errors = 2641 "tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 [type_errors.ty] -total_errors = 212 +total_errors = 204 [type_errors.ty.by_type] "unknown-argument" = 164 -"unresolved-attribute" = 33 -"invalid-argument-type" = 5 +"unresolved-attribute" = 27 "invalid-assignment" = 5 +"invalid-argument-type" = 4 "too-many-positional-arguments" = 3 -"unresolved-import" = 2 +"unresolved-import" = 1 [type_errors.ty.by_file] "tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 134 -"tests/unit/json_serialize/test_serializable_field.py" = 29 "tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 26 +"tests/unit/json_serialize/test_serializable_field.py" = 22 "tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 9 "tests/unit/test_dictmagic.py" = 8 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 2 "tests/unit/json_serialize/test_json_serialize.py" = 2 -"muutils/logger/loggingstream.py" = 1 "tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 diff --git a/muutils/json_serialize/serializable_field.py b/muutils/json_serialize/serializable_field.py index ad34992a..7313e4da 100644 --- a/muutils/json_serialize/serializable_field.py +++ b/muutils/json_serialize/serializable_field.py @@ -119,7 +119,7 @@ def __init__( self.custom_typecheck_fn: Optional[Callable[[type], bool]] = custom_typecheck_fn @classmethod - def from_Field(cls, field: dataclasses.Field) -> "SerializableField": + def from_Field(cls, field: "dataclasses.Field[Any]") -> "SerializableField": """copy all values from a `dataclasses.Field` to new `SerializableField`""" return cls( default=field.default, diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index bc20ab86..65848302 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -31,7 +31,7 @@ class BasicAutofields(SerializableDataclass): def test_basic_auto_fields(): data = dict(a="hello", b=42, c=[1, 2, 3]) - instance = BasicAutofields(**data) # type: ignore[arg-type] + instance = BasicAutofields(**data) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] data_with_format = data.copy() data_with_format[_FORMAT_KEY] = "BasicAutofields(SerializableDataclass)" assert instance.serialize() == data_with_format @@ -212,7 +212,7 @@ class MyClass(SerializableDataclass): age: int = serializable_field( serialization_fn=lambda x: x + 1, loading_fn=lambda x: x["age"] - 1 ) - items: list = serializable_field(default_factory=list) + items: List[Any] = serializable_field(default_factory=list) @property def full_name(self) -> str: @@ -590,7 +590,7 @@ def test_dict_type_validation(): # Invalid int_dict with pytest.raises(FieldTypeMismatchError): StrictDictContainer( - int_dict={"a": "not an int"}, # type: ignore[dict-item] + int_dict={"a": "not an int"}, # type: ignore[dict-item] # pyright: ignore[reportArgumentType] str_dict={"x": "hello"}, float_dict={"m": 1.0}, ) @@ -599,7 +599,7 @@ def test_dict_type_validation(): with pytest.raises(FieldTypeMismatchError): StrictDictContainer( int_dict={"a": 1}, - str_dict={"x": 123}, # type: ignore[dict-item] + str_dict={"x": 123}, # type: 
ignore[dict-item] # pyright: ignore[reportArgumentType] float_dict={"m": 1.0}, ) @@ -1017,7 +1017,7 @@ def test_error_handling(): with pytest.raises(TypeError): BaseClass.load({}) - x = BaseClass(base_field=42, shared_field="invalid") # type: ignore[arg-type] + x = BaseClass(base_field=42, shared_field="invalid") # type: ignore[arg-type] # pyright: ignore[reportArgumentType] assert not x.validate_fields_types() with pytest.raises(FieldTypeMismatchError): @@ -1059,4 +1059,4 @@ class Node(SerializableDataclass): loaded = Node.load(serialized) assert loaded.value == "one" # TODO: idk why we type ignore here - assert loaded.next.value == "two" # type: ignore[union-attr] + assert loaded.next.value == "two" # type: ignore[union-attr] # pyright: ignore[reportOptionalMemberAccess] diff --git a/tests/unit/json_serialize/test_serializable_field.py b/tests/unit/json_serialize/test_serializable_field.py index 99515f33..36ef6787 100644 --- a/tests/unit/json_serialize/test_serializable_field.py +++ b/tests/unit/json_serialize/test_serializable_field.py @@ -115,7 +115,7 @@ def test_SerializableField_doc(): def test_from_Field(): """Test converting a dataclasses.Field to SerializableField.""" # Create a standard dataclasses.Field - dc_field = field( + dc_field: dataclasses.Field[int] = field( # type: ignore[assignment] default=42, init=True, repr=True, @@ -140,7 +140,7 @@ def test_from_Field(): assert sf.deserialize_fn is None # Test with default_factory and init=False to avoid init=True, serialize=False error - dc_field2 = field(default_factory=list, repr=True, init=True) + dc_field2: dataclasses.Field[list[Any]] = field(default_factory=list, repr=True, init=True) # type: ignore[assignment] sf2 = SerializableField.from_Field(dc_field2) assert sf2.default_factory == list # noqa: E721 assert sf2.default is dataclasses.MISSING @@ -308,15 +308,15 @@ def test_serializable_field_function(): assert f1.serialize is True # Test with default - f2 = serializable_field(default=100) + f2: SerializableField = serializable_field(default=100) # type: ignore[assignment] assert f2.default == 100 # Test with default_factory - f3 = serializable_field(default_factory=list) + f3: SerializableField = serializable_field(default_factory=list) # type: ignore[assignment] assert f3.default_factory == list # noqa: E721 # Test with all parameters - f4 = serializable_field( + f4: SerializableField = serializable_field( # type: ignore[assignment] default=42, init=True, repr=False, From 47e7fe886f5c61d5e76fc2b8fc94c16235fa4f26 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Sun, 18 Jan 2026 21:03:44 -0700 Subject: [PATCH 65/72] more typing wip --- .meta/typing-summary.txt | 45 ++++++++----------- TODO.md | 2 +- muutils/json_serialize/json_serialize.py | 15 ++++--- .../benchmark_parallel/benchmark_parallel.py | 4 +- .../benchmark_parallel/test_benchmark_demo.py | 2 +- .../test_sdc_properties_nested.py | 2 +- tests/unit/json_serialize/test_array.py | 14 +++--- tests/unit/json_serialize/test_array_torch.py | 18 +++++--- .../json_serialize/test_json_serialize.py | 2 +- tests/unit/test_jsonlines.py | 10 ++++- 10 files changed, 63 insertions(+), 51 deletions(-) diff --git a/.meta/typing-summary.txt b/.meta/typing-summary.txt index 3871266c..eaef4f92 100644 --- a/.meta/typing-summary.txt +++ b/.meta/typing-summary.txt @@ -1,31 +1,28 @@ # Showing all errors -# mypy: Found 52 errors in 7 files (checked 116 source files) -# basedpyright: 550 errors, 3319 warnings, 0 notes -# ty: Found 204 diagnostics +# mypy: Found 20 errors in 5 files 
(checked 116 source files) +# basedpyright: 493 errors, 3319 warnings, 0 notes +# ty: Found 198 diagnostics [type_errors.mypy] -total_errors = 52 +total_errors = 20 [type_errors.mypy.by_type] -"index" = 20 -"arg-type" = 10 -"call-overload" = 10 +"arg-type" = 7 "operator" = 6 "attr-defined" = 4 -"return-value" = 2 +"call-overload" = 2 +"import-not-found" = 1 [type_errors.mypy.by_file] -"tests/unit/json_serialize/test_array.py" = 19 -"tests/unit/json_serialize/test_array_torch.py" = 13 -"tests/unit/test_jsonlines.py" = 7 +"tests/unit/json_serialize/test_array.py" = 9 "tests/unit/json_serialize/test_json_serialize.py" = 6 -"muutils/json_serialize/json_serialize.py" = 4 "tests/unit/json_serialize/test_serializable_field.py" = 2 -"tests/unit/benchmark_parallel/benchmark_parallel.py" = 1 +"tests/unit/json_serialize/test_array_torch.py" = 2 +"tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 [type_errors.basedpyright] -total_errors = 2631 +total_errors = 2593 [type_errors.basedpyright.by_type] "reportUnknownParameterType" = 425 @@ -39,11 +36,10 @@ total_errors = 2631 "reportUnusedParameter" = 99 "reportImplicitOverride" = 52 "reportInvalidTypeForm" = 49 -"reportCallIssue" = 42 "reportUnannotatedClassAttribute" = 41 +"reportCallIssue" = 36 "reportPossiblyUnboundVariable" = 34 "reportPrivateUsage" = 31 -"reportIndexIssue" = 31 "reportUnreachable" = 30 "reportUnnecessaryIsInstance" = 18 "reportUntypedClassDecorator" = 17 @@ -53,8 +49,8 @@ total_errors = 2631 "reportInvalidTypeArguments" = 9 "reportUnnecessaryComparison" = 8 "reportUnusedVariable" = 8 -"reportOptionalSubscript" = 8 "reportUndefinedVariable" = 7 +"reportIndexIssue" = 7 "reportOptionalMemberAccess" = 6 "reportUninitializedInstanceVariable" = 5 "reportCallInDefaultInitializer" = 4 @@ -94,7 +90,6 @@ total_errors = 2631 "muutils/parallel.py" = 42 "muutils/statcounter.py" = 41 "muutils/sysinfo.py" = 39 -"tests/unit/json_serialize/test_array.py" = 38 "tests/unit/json_serialize/test_serializable_field.py" = 38 "muutils/web/bundle_html.py" = 37 "tests/unit/test_dictmagic.py" = 37 @@ -108,7 +103,6 @@ total_errors = 2631 "muutils/mlutils.py" = 22 "muutils/validate_type.py" = 20 "tests/unit/validate_type/test_validate_type_special.py" = 20 -"tests/unit/test_jsonlines.py" = 19 "muutils/logger/logger.py" = 18 "muutils/math/matrix_powers.py" = 18 "tests/unit/json_serialize/test_util.py" = 18 @@ -117,6 +111,7 @@ total_errors = 2631 "muutils/logger/log_util.py" = 16 "muutils/tensor_utils.py" = 16 "tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 15 +"tests/unit/json_serialize/test_array.py" = 15 "tests/unit/nbutils/test_configure_notebook.py" = 14 "muutils/jsonlines.py" = 13 "muutils/nbutils/run_notebook_tests.py" = 12 @@ -144,6 +139,7 @@ total_errors = 2631 "muutils/web/html_to_pdf.py" = 4 "tests/unit/math/test_matrix_powers_torch.py" = 4 "tests/unit/misc/test_sequence.py" = 4 +"tests/unit/test_jsonlines.py" = 4 "tests/unit/test_mlutils.py" = 4 "muutils/cli/command.py" = 3 "muutils/json_serialize/array.py" = 3 @@ -156,22 +152,17 @@ total_errors = 2631 "tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 [type_errors.ty] -total_errors = 204 +total_errors = 198 [type_errors.ty.by_type] "unknown-argument" = 164 "unresolved-attribute" = 27 "invalid-assignment" = 5 -"invalid-argument-type" = 4 -"too-many-positional-arguments" = 3 -"unresolved-import" = 1 +"too-many-positional-arguments" = 2 [type_errors.ty.by_file] "tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 134 
"tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 26 "tests/unit/json_serialize/test_serializable_field.py" = 22 -"tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 9 +"tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 8 "tests/unit/test_dictmagic.py" = 8 -"tests/unit/benchmark_parallel/benchmark_parallel.py" = 2 -"tests/unit/json_serialize/test_json_serialize.py" = 2 -"tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 diff --git a/TODO.md b/TODO.md index 46890f31..67eb4c4d 100644 --- a/TODO.md +++ b/TODO.md @@ -18,7 +18,7 @@ 4. Implement that fix -run type checking only on the specific file you are changing to verify that the errors are fixed. +5. run type checking only on the specific file you are changing to verify that the errors are fixed. use `uv run `, not `python -m` # Guidelines: diff --git a/muutils/json_serialize/json_serialize.py b/muutils/json_serialize/json_serialize.py index 382b7f61..ae459b3e 100644 --- a/muutils/json_serialize/json_serialize.py +++ b/muutils/json_serialize/json_serialize.py @@ -14,7 +14,7 @@ import warnings from dataclasses import dataclass, is_dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Set, Union +from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Set, Union, cast from muutils.errormode import ErrorMode @@ -195,16 +195,19 @@ def _serialize_override_serialize_func( ), SerializerHandler( check=lambda self, obj, path: str(type(obj)) == "", - serialize_func=lambda self, obj, path: serialize_array(self, obj, path=path), + serialize_func=lambda self, obj, path: cast(JSONitem, serialize_array(self, obj, path=path)), uid="numpy.ndarray", desc="numpy arrays", ), SerializerHandler( check=lambda self, obj, path: str(type(obj)) == "", - serialize_func=lambda self, obj, path: serialize_array( - self, - obj.detach().cpu(), - path=path, # pyright: ignore[reportAny] + serialize_func=lambda self, obj, path: cast( + JSONitem, + serialize_array( + self, + obj.detach().cpu(), + path=path, # pyright: ignore[reportAny] + ), ), uid="torch.Tensor", desc="pytorch tensors", diff --git a/tests/unit/benchmark_parallel/benchmark_parallel.py b/tests/unit/benchmark_parallel/benchmark_parallel.py index c1288813..8c4fd98c 100644 --- a/tests/unit/benchmark_parallel/benchmark_parallel.py +++ b/tests/unit/benchmark_parallel/benchmark_parallel.py @@ -248,7 +248,7 @@ def benchmark_run_maybe_parallel( def plot_speedup_by_data_size( - df: pd.DataFrame, task_type: str | None = None, save_path: str | None = None + df: pd.DataFrame, task_type: str | None = None, save_path: str | Path | None = None ): """Plot speedup vs data size for different methods.""" import matplotlib.pyplot as plt # type: ignore[import-untyped] @@ -328,7 +328,7 @@ def plot_efficiency_heatmap(df: pd.DataFrame, save_path: str | Path | None = Non plt.imshow(pivot_df, aspect="auto", cmap="YlOrRd", vmin=0) plt.colorbar(label="Speedup") plt.yticks(range(len(pivot_df.index)), [f"{t[0]}-{t[1]}" for t in pivot_df.index]) - plt.xticks(range(len(pivot_df.columns)), pivot_df.columns, rotation=45) + plt.xticks(range(len(pivot_df.columns)), list(pivot_df.columns), rotation=45) plt.title("Parallelization Efficiency Heatmap") plt.tight_layout() diff --git a/tests/unit/benchmark_parallel/test_benchmark_demo.py b/tests/unit/benchmark_parallel/test_benchmark_demo.py index b7af7f19..2174f266 100644 --- a/tests/unit/benchmark_parallel/test_benchmark_demo.py +++ 
b/tests/unit/benchmark_parallel/test_benchmark_demo.py @@ -2,7 +2,7 @@ from pathlib import Path -from benchmark_parallel import io_bound_task, light_cpu_task, main +from .benchmark_parallel import io_bound_task, light_cpu_task, main def test_main(): diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py index 5167970b..ffbfc60e 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py @@ -76,7 +76,7 @@ def test_serialize_titled_person(): if SUPPORTS_KW_ONLY: with pytest.raises(TypeError): - TitledPerson("Jane", "Smith", "Dr.") + TitledPerson("Jane", "Smith", "Dr.") # type: ignore[too-many-positional-arguments] serialized = instance.serialize() diff --git a/tests/unit/json_serialize/test_array.py b/tests/unit/json_serialize/test_array.py index bfa72678..6c419d2b 100644 --- a/tests/unit/json_serialize/test_array.py +++ b/tests/unit/json_serialize/test_array.py @@ -154,9 +154,11 @@ def test_array_serialization_handlers(): } serialized = jser.json_serialize(data_dict) - assert isinstance(serialized["array"], dict) - assert _FORMAT_KEY in serialized["array"] - assert serialized["array"]["shape"] == [4] + assert isinstance(serialized, dict) + serialized_array = serialized["array"] + assert isinstance(serialized_array, dict) + assert _FORMAT_KEY in serialized_array + assert serialized_array["shape"] == [4] # Array in a list data_list = [ @@ -166,8 +168,10 @@ def test_array_serialization_handlers(): ] serialized_list = jser.json_serialize(data_list) - assert isinstance(serialized_list[1], dict) - assert _FORMAT_KEY in serialized_list[1] + assert isinstance(serialized_list, list) + serialized_list_item = serialized_list[1] + assert isinstance(serialized_list_item, dict) + assert _FORMAT_KEY in serialized_list_item # Test different array modes for mode in ["list", "array_list_meta", "array_hex_meta", "array_b64_meta"]: diff --git a/tests/unit/json_serialize/test_array_torch.py b/tests/unit/json_serialize/test_array_torch.py index 3af93556..d22e22ac 100644 --- a/tests/unit/json_serialize/test_array_torch.py +++ b/tests/unit/json_serialize/test_array_torch.py @@ -201,14 +201,22 @@ def test_torch_serialization_integration(): assert _FORMAT_KEY in serialized["model_weights"] assert serialized["model_weights"]["shape"] == [10, 5] - assert isinstance(serialized["biases"], dict) - assert serialized["biases"]["shape"] == [5] + serialized_biases = serialized["biases"] + assert isinstance(serialized_biases, dict) + assert serialized_biases["shape"] == [5] - assert serialized["metadata"]["epochs"] == 10 # pyright: ignore[reportArgumentType, reportCallIssue, reportIndexIssue, reportOptionalSubscript] + serialized_metadata = serialized["metadata"] + assert isinstance(serialized_metadata, dict) + assert serialized_metadata["epochs"] == 10 # Check nested tensors - assert isinstance(serialized["history"][0]["loss"], dict) # pyright: ignore[reportArgumentType, reportCallIssue, reportIndexIssue, reportOptionalSubscript] - assert _FORMAT_KEY in serialized["history"][0]["loss"] # pyright: ignore[reportArgumentType, reportCallIssue, reportIndexIssue, reportOptionalSubscript, reportOperatorIssue] + serialized_history = serialized["history"] + assert isinstance(serialized_history, list) + history_item_0 = serialized_history[0] + assert isinstance(history_item_0, dict) + history_item_0_loss 
= history_item_0["loss"] + assert isinstance(history_item_0_loss, dict) + assert _FORMAT_KEY in history_item_0_loss def test_mixed_numpy_torch(): diff --git a/tests/unit/json_serialize/test_json_serialize.py b/tests/unit/json_serialize/test_json_serialize.py index 7569a7ed..4fbce5ec 100644 --- a/tests/unit/json_serialize/test_json_serialize.py +++ b/tests/unit/json_serialize/test_json_serialize.py @@ -631,7 +631,7 @@ def tracking_check(self, obj, path): def test_JsonSerializer_init_no_positional_args(): """Test that JsonSerializer raises ValueError on positional arguments.""" with pytest.raises(ValueError, match="no positional arguments"): - JsonSerializer("invalid", "args") + JsonSerializer("invalid", "args") # type: ignore[invalid-argument-type] # Should work with keyword args serializer = JsonSerializer(error_mode=ErrorMode.WARN) diff --git a/tests/unit/test_jsonlines.py b/tests/unit/test_jsonlines.py index 56b19f91..96fd4f37 100644 --- a/tests/unit/test_jsonlines.py +++ b/tests/unit/test_jsonlines.py @@ -38,8 +38,14 @@ def test_jsonl_load(): # Verify the data matches assert loaded_data == test_data assert len(loaded_data) == 4 - assert loaded_data[0]["name"] == "Alice" - assert loaded_data[3]["nested"]["b"] == 2 + loaded_item_0 = loaded_data[0] + assert isinstance(loaded_item_0, dict) + assert loaded_item_0["name"] == "Alice" + loaded_item_3 = loaded_data[3] + assert isinstance(loaded_item_3, dict) + loaded_item_3_nested = loaded_item_3["nested"] + assert isinstance(loaded_item_3_nested, dict) + assert loaded_item_3_nested["b"] == 2 def test_jsonl_write(): From b593fe261ab39486fee93033975af9340b6e0245 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Sun, 18 Jan 2026 21:40:03 -0700 Subject: [PATCH 66/72] more wip typing fixes --- .meta/requirements/requirements-all.txt | 2 +- .meta/requirements/requirements-dev.txt | 2 +- .meta/requirements/requirements.txt | 2 +- .meta/typing-summary.txt | 71 +++++----- muutils/json_serialize/array.py | 23 ++++ .../json_serialize/serializable_dataclass.py | 34 +++-- muutils/json_serialize/util.py | 8 +- muutils/logger/exception_context.py | 17 ++- muutils/logger/timing.py | 8 +- muutils/misc/freezing.py | 28 ++-- pyproject.toml | 7 +- tests/unit/benchmark_parallel/__init__.py | 0 .../benchmark_parallel/test_benchmark_demo.py | 2 +- .../test_serializable_dataclass.py | 4 +- tests/unit/json_serialize/test_array.py | 127 +++++++++--------- uv.lock | 44 +++--- 16 files changed, 216 insertions(+), 163 deletions(-) create mode 100644 tests/unit/benchmark_parallel/__init__.py diff --git a/.meta/requirements/requirements-all.txt b/.meta/requirements/requirements-all.txt index 5daae08c..d1fa273f 100644 --- a/.meta/requirements/requirements-all.txt +++ b/.meta/requirements/requirements-all.txt @@ -896,7 +896,7 @@ triton==3.5.0 ; python_full_version >= '3.10' and python_full_version < '3.14' a # via torch twine==6.1.0 ; python_full_version < '3.9' twine==6.2.0 ; python_full_version >= '3.9' -ty==0.0.1a24 +ty==0.0.12 typeguard==4.4.0 ; python_full_version < '3.9' # via jaxtyping typer==0.20.0 diff --git a/.meta/requirements/requirements-dev.txt b/.meta/requirements/requirements-dev.txt index 45c5df2c..2e879649 100644 --- a/.meta/requirements/requirements-dev.txt +++ b/.meta/requirements/requirements-dev.txt @@ -724,7 +724,7 @@ traitlets==5.14.3 # nbformat twine==6.1.0 ; python_full_version < '3.9' twine==6.2.0 ; python_full_version >= '3.9' -ty==0.0.1a24 +ty==0.0.12 typer==0.20.0 # via pycln typing-extensions==4.13.2 ; python_full_version < '3.9' diff --git 
a/.meta/requirements/requirements.txt b/.meta/requirements/requirements.txt index 0f737086..4c3108dd 100644 --- a/.meta/requirements/requirements.txt +++ b/.meta/requirements/requirements.txt @@ -896,7 +896,7 @@ triton==3.5.0 ; python_full_version >= '3.10' and python_full_version < '3.14' a # via torch twine==6.1.0 ; python_full_version < '3.9' twine==6.2.0 ; python_full_version >= '3.9' -ty==0.0.1a24 +ty==0.0.12 typeguard==4.4.0 ; python_full_version < '3.9' # via jaxtyping typer==0.20.0 diff --git a/.meta/typing-summary.txt b/.meta/typing-summary.txt index eaef4f92..12f4d213 100644 --- a/.meta/typing-summary.txt +++ b/.meta/typing-summary.txt @@ -1,42 +1,42 @@ # Showing all errors -# mypy: Found 20 errors in 5 files (checked 116 source files) -# basedpyright: 493 errors, 3319 warnings, 0 notes -# ty: Found 198 diagnostics +# mypy: Found 22 errors in 5 files (checked 117 source files) +# basedpyright: 485 errors, 3280 warnings, 0 notes +# ty: Found 23 diagnostics [type_errors.mypy] -total_errors = 20 +total_errors = 22 [type_errors.mypy.by_type] -"arg-type" = 7 +"arg-type" = 8 "operator" = 6 "attr-defined" = 4 -"call-overload" = 2 -"import-not-found" = 1 +"call-overload" = 3 +"assignment" = 1 [type_errors.mypy.by_file] -"tests/unit/json_serialize/test_array.py" = 9 +"tests/unit/json_serialize/test_array.py" = 10 "tests/unit/json_serialize/test_json_serialize.py" = 6 +"tests/unit/json_serialize/test_array_torch.py" = 3 "tests/unit/json_serialize/test_serializable_field.py" = 2 -"tests/unit/json_serialize/test_array_torch.py" = 2 -"tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 +"muutils/logger/logger.py" = 1 [type_errors.basedpyright] -total_errors = 2593 +total_errors = 2557 [type_errors.basedpyright.by_type] -"reportUnknownParameterType" = 425 -"reportMissingParameterType" = 378 -"reportAny" = 354 +"reportUnknownParameterType" = 405 +"reportAny" = 362 +"reportMissingParameterType" = 357 "reportUnusedCallResult" = 290 "reportUnknownVariableType" = 195 "reportMissingTypeArgument" = 178 "reportUnknownLambdaType" = 127 -"reportUnknownMemberType" = 126 -"reportUnusedParameter" = 99 +"reportUnknownMemberType" = 125 +"reportUnusedParameter" = 98 "reportImplicitOverride" = 52 "reportInvalidTypeForm" = 49 -"reportUnannotatedClassAttribute" = 41 +"reportUnannotatedClassAttribute" = 40 "reportCallIssue" = 36 "reportPossiblyUnboundVariable" = 34 "reportPrivateUsage" = 31 @@ -78,13 +78,13 @@ total_errors = 2593 "tests/unit/test_interval.py" = 75 "muutils/misc/func.py" = 70 "tests/unit/cli/test_arg_bool.py" = 66 -"muutils/misc/freezing.py" = 65 "muutils/dictmagic.py" = 62 "tests/unit/web/test_bundle_html.py" = 55 "muutils/json_serialize/json_serialize.py" = 52 "muutils/spinner.py" = 51 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 49 "muutils/json_serialize/serializable_field.py" = 48 +"muutils/misc/freezing.py" = 46 "tests/unit/errormode/test_errormode_init.py" = 46 "muutils/tensor_info.py" = 45 "muutils/parallel.py" = 42 @@ -111,15 +111,14 @@ total_errors = 2593 "muutils/logger/log_util.py" = 16 "muutils/tensor_utils.py" = 16 "tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 15 -"tests/unit/json_serialize/test_array.py" = 15 "tests/unit/nbutils/test_configure_notebook.py" = 14 "muutils/jsonlines.py" = 13 +"tests/unit/json_serialize/test_array.py" = 13 "muutils/nbutils/run_notebook_tests.py" = 12 "muutils/errormode.py" = 11 "tests/unit/misc/test_misc.py" = 11 "tests/unit/test_tensor_info_torch.py" = 11 "muutils/dbg.py" = 10 
-"muutils/logger/exception_context.py" = 10 "tests/unit/test_collect_warnings.py" = 10 "muutils/kappa.py" = 9 "muutils/misc/classes.py" = 9 @@ -131,10 +130,9 @@ total_errors = 2593 "muutils/collect_warnings.py" = 7 "muutils/nbutils/mermaid.py" = 7 "tests/unit/cli/test_command.py" = 7 -"muutils/logger/timing.py" = 6 +"tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 7 "tests/unit/nbutils/test_conversion.py" = 6 "tests/conftest.py" = 5 -"tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 5 "tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 5 "muutils/web/html_to_pdf.py" = 4 "tests/unit/math/test_matrix_powers_torch.py" = 4 @@ -148,21 +146,26 @@ total_errors = 2593 "tests/unit/test_chunks.py" = 2 "tests/unit/test_tensor_info.py" = 2 "tests/unit/test_timeit_fancy.py" = 2 -"muutils/json_serialize/util.py" = 1 -"tests/unit/benchmark_parallel/test_benchmark_demo.py" = 1 +"muutils/logger/exception_context.py" = 1 [type_errors.ty] -total_errors = 198 +total_errors = 23 [type_errors.ty.by_type] -"unknown-argument" = 164 -"unresolved-attribute" = 27 -"invalid-assignment" = 5 -"too-many-positional-arguments" = 2 +"invalid-argument-type" = 10 +"invalid-assignment" = 4 +"unsupported-operator" = 4 +"no-matching-overload" = 2 +"unresolved-attribute" = 2 +"invalid-type-form" = 1 [type_errors.ty.by_file] -"tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 134 -"tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 26 -"tests/unit/json_serialize/test_serializable_field.py" = 22 -"tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 8 -"tests/unit/test_dictmagic.py" = 8 +"muutils/spinner.py" = 5 +"tests/unit/json_serialize/test_json_serialize.py" = 4 +"muutils/interval.py" = 3 +"muutils/misc/func.py" = 3 +"muutils/logger/logger.py" = 2 +"muutils/sysinfo.py" = 2 +"tests/unit/web/test_bundle_html.py" = 2 +"muutils/logger/simplelogger.py" = 1 +"tests/unit/validate_type/test_validate_type.py" = 1 diff --git a/muutils/json_serialize/array.py b/muutils/json_serialize/array.py index 045ec80a..b19394d3 100644 --- a/muutils/json_serialize/array.py +++ b/muutils/json_serialize/array.py @@ -101,6 +101,29 @@ def arr_metadata(arr: Any) -> ArrayMetadata: # pyright: ignore[reportAny] } +@overload +def serialize_array( + jser: "JsonSerializer", + arr: "Union[np.ndarray, torch.Tensor]", + path: str | Sequence[str | int], + array_mode: Literal["list"], +) -> NumericList: ... +@overload +def serialize_array( + jser: "JsonSerializer", + arr: "Union[np.ndarray, torch.Tensor]", + path: str | Sequence[str | int], + array_mode: Literal[ + "array_list_meta", "array_hex_meta", "array_b64_meta", "zero_dim", "external" + ], +) -> SerializedArrayWithMeta: ... +@overload +def serialize_array( + jser: "JsonSerializer", + arr: "Union[np.ndarray, torch.Tensor]", + path: str | Sequence[str | int], + array_mode: None = None, +) -> SerializedArrayWithMeta | NumericList: ... 
def serialize_array( jser: "JsonSerializer", # type: ignore[name-defined] # noqa: F821 arr: "Union[np.ndarray, torch.Tensor]", diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 81af3bda..2991abaf 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -58,7 +58,7 @@ class NestedClass(SerializableDataclass): import sys import typing import warnings -from typing import Any, Optional, Type, TypeVar +from typing import Any, Optional, Type, TypeVar, overload, TYPE_CHECKING from muutils.errormode import ErrorMode from muutils.validate_type import validate_type @@ -75,15 +75,19 @@ class NestedClass(SerializableDataclass): # pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access -# this is quite horrible, but unfortunately mypy fails if we try to assign to `dataclass_transform` directly -# and every time we try to init a serializable dataclass it says the argument doesnt exist -if sys.version_info >= (3, 11): - from typing import dataclass_transform +# For type checkers: always use typing_extensions which they can resolve +# At runtime: use stdlib if available (3.11+), else typing_extensions, else mock +if TYPE_CHECKING: + from typing_extensions import dataclass_transform, Self else: - try: # pyright: ignore[reportUnreachable] - from typing_extensions import dataclass_transform - except Exception: - from muutils.json_serialize.dataclass_transform_mock import dataclass_transform + if sys.version_info >= (3, 11): + from typing import dataclass_transform, Self + else: + try: + from typing_extensions import dataclass_transform, Self + except Exception: + from muutils.json_serialize.dataclass_transform_mock import dataclass_transform + Self = TypeVar("Self") T = TypeVar("T") @@ -354,8 +358,16 @@ def serialize(self) -> dict[str, Any]: f"decorate {self.__class__ = } with `@serializable_dataclass`" ) + @overload + @classmethod + def load(cls, data: dict[str, Any]) -> Self: ... + + @overload + @classmethod + def load(cls, data: Self) -> Self: ... + @classmethod - def load(cls: Type[T], data: dict[str, Any] | T) -> T: + def load(cls, data: dict[str, Any] | Self) -> Self: "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator" raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`") @@ -782,7 +794,7 @@ def serialize(self) -> dict[str, Any]: # ====================================================================== # mypy thinks this isnt a classmethod @classmethod # type: ignore[misc] - def load(cls, data: dict[str, Any] | T) -> Type[T]: + def load(cls, data: dict[str, Any] | T) -> T: # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ if isinstance(data, cls): return data diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index b8231d88..255e8741 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -208,7 +208,6 @@ def dc_eq( dc1: Any, # pyright: ignore[reportAny] dc2: Any, # pyright: ignore[reportAny] except_when_class_mismatch: bool = False, - # TODO: why is this unused? 
false_when_class_mismatch: bool = True, except_when_field_mismatch: bool = False, ) -> bool: @@ -279,16 +278,19 @@ def dc_eq( raise TypeError( f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`" # pyright: ignore[reportAny] ) + if false_when_class_mismatch: + # return False immediately without attempting field comparison + return False + # classes don't match but we'll try to compare fields anyway if except_when_field_mismatch: dc1_fields: set[str] = set([fld.name for fld in dataclasses.fields(dc1)]) # pyright: ignore[reportAny] dc2_fields: set[str] = set([fld.name for fld in dataclasses.fields(dc2)]) # pyright: ignore[reportAny] fields_match: bool = set(dc1_fields) == set(dc2_fields) if not fields_match: - # if the fields match, keep going + # if the fields don't match, raise an error raise AttributeError( f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`" ) - return False return all( array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name)) # pyright: ignore[reportAny] diff --git a/muutils/logger/exception_context.py b/muutils/logger/exception_context.py index 47605cbb..1604c44a 100644 --- a/muutils/logger/exception_context.py +++ b/muutils/logger/exception_context.py @@ -1,4 +1,8 @@ +from __future__ import annotations + import json +from types import TracebackType +from typing import IO from muutils.json_serialize import json_serialize @@ -20,13 +24,18 @@ class ExceptionContext: """ - def __init__(self, stream): - self.stream = stream + def __init__(self, stream: IO[str]) -> None: + self.stream: IO[str] = stream - def __enter__(self): + def __enter__(self) -> ExceptionContext: return self - def __exit__(self, exc_type, exc_value, exc_traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_traceback: TracebackType | None, + ) -> bool: if exc_type is not None: self.stream.write( json.dumps( diff --git a/muutils/logger/timing.py b/muutils/logger/timing.py index 1b7b2e99..c9e7aa2f 100644 --- a/muutils/logger/timing.py +++ b/muutils/logger/timing.py @@ -1,6 +1,7 @@ from __future__ import annotations import time +from types import TracebackType from typing import Literal @@ -16,7 +17,12 @@ def __enter__(self) -> "TimerContext": self.start_time = time.time() return self - def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]: + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> Literal[False]: self.end_time = time.time() self.elapsed_time = self.end_time - self.start_time return False diff --git a/muutils/misc/freezing.py b/muutils/misc/freezing.py index 6fc44bb0..38764c95 100644 --- a/muutils/misc/freezing.py +++ b/muutils/misc/freezing.py @@ -1,38 +1,38 @@ from __future__ import annotations -from typing import Any, TypeVar, overload +from typing import Any, Iterable, NoReturn, SupportsIndex, TypeVar, overload -class FrozenDict(dict): - def __setitem__(self, key, value): +class FrozenDict(dict): # type: ignore[type-arg] + def __setitem__(self, key: Any, value: Any) -> NoReturn: raise AttributeError("dict is frozen") - def __delitem__(self, key): + def __delitem__(self, key: Any) -> NoReturn: raise AttributeError("dict is frozen") -class FrozenList(list): - def __setitem__(self, index, value): +class FrozenList(list): # type: ignore[type-arg] + def __setitem__(self, index: SupportsIndex | slice, value: Any) -> NoReturn: raise AttributeError("list is frozen") - 
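
On the `freezing.py` hunk in progress here: annotating the mutators `NoReturn` documents that they can only raise, and checkers then treat any code after such a call as unreachable. A minimal sketch of the same idea:

```python
from typing import Any, NoReturn

class FrozenDict(dict):  # type: ignore[type-arg]
    def __setitem__(self, key: Any, value: Any) -> NoReturn:
        raise AttributeError("dict is frozen")

d = FrozenDict(a=1)
try:
    d["b"] = 2  # a checker knows this line never completes normally
except AttributeError as err:
    print(err)  # -> dict is frozen
```
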
def __delitem__(self, index): + def __delitem__(self, index: SupportsIndex | slice) -> NoReturn: raise AttributeError("list is frozen") - def append(self, value): + def append(self, value: Any) -> NoReturn: raise AttributeError("list is frozen") - def extend(self, iterable): + def extend(self, iterable: Iterable[Any]) -> NoReturn: raise AttributeError("list is frozen") - def insert(self, index, value): + def insert(self, index: SupportsIndex, value: Any) -> NoReturn: raise AttributeError("list is frozen") - def remove(self, value): + def remove(self, value: Any) -> NoReturn: raise AttributeError("list is frozen") - def pop(self, index=-1): + def pop(self, index: SupportsIndex = -1) -> NoReturn: raise AttributeError("list is frozen") - def clear(self): + def clear(self) -> NoReturn: raise AttributeError("list is frozen") @@ -103,7 +103,7 @@ def freeze(instance: Any) -> Any: # create a new class which inherits from the original class class FrozenClass(instance.__class__): # type: ignore[name-defined] - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> NoReturn: raise AttributeError("class is frozen") FrozenClass.__name__ = f"FrozenClass__{instance.__class__.__name__}" diff --git a/pyproject.toml b/pyproject.toml index 831f8cf5..f68c6342 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,7 @@ "mypy>=1.15; python_version >= '3.9'", "typing-extensions; python_version < '3.11'", "beartype>=0.14.1", - "ty", + "ty>=0.0.12", "basedpyright", # tests & coverage "pytest>=8.2.2", @@ -168,6 +168,11 @@ "docs/resources/make_docs.py", ] +# TODO: remove this once we clean up the `# type: ignore` comments that are needed +# for mypy/pyright but not for ty. See https://docs.astral.sh/ty/configuration/ +[tool.ty.rules] + unused-ignore-comment = "ignore" + [tool.mypy] exclude = [ # tests diff --git a/tests/unit/benchmark_parallel/__init__.py b/tests/unit/benchmark_parallel/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/benchmark_parallel/test_benchmark_demo.py b/tests/unit/benchmark_parallel/test_benchmark_demo.py index 2174f266..b2b17737 100644 --- a/tests/unit/benchmark_parallel/test_benchmark_demo.py +++ b/tests/unit/benchmark_parallel/test_benchmark_demo.py @@ -7,7 +7,7 @@ def test_main(): """Test the main function of the benchmark script.""" - main( + _ = main( data_sizes=(1, 2), base_path=Path("tests/_temp/benchmark_demo"), plot=True, diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index 65848302..a0f7af15 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -1049,8 +1049,8 @@ class Node(SerializableDataclass): next: Optional["Node"] = serializable_field(default=None) # Create a cycle - node1 = Node("one") - node2 = Node("two") + node1 = Node(value="one") + node2 = Node(value="two") node1.next = node2 node2.next = node1 diff --git a/tests/unit/json_serialize/test_array.py b/tests/unit/json_serialize/test_array.py index 6c419d2b..172507b5 100644 --- a/tests/unit/json_serialize/test_array.py +++ b/tests/unit/json_serialize/test_array.py @@ -56,44 +56,41 @@ def test_load_array(self): loaded_array = load_array(serialized_array, array_mode="array_list_meta") assert np.array_equal(loaded_array, self.array_3d) - def test_serialize_load_integration(self): - for array_mode in [ - 
"list", - "array_list_meta", - "array_hex_meta", - "array_b64_meta", - ]: - for array in [self.array_1d, self.array_2d, self.array_3d]: - serialized_array = serialize_array( - self.jser, - array, - "test_path", - array_mode=array_mode, # type: ignore[arg-type] - ) - loaded_array = load_array(serialized_array, array_mode=array_mode) # type: ignore[arg-type] - assert np.array_equal(loaded_array, array) - - def test_serialize_load_zero_dim(self): - for array_mode in [ - # TODO: do we even want to support "list" mode for zero-dim arrays? - # "list", - "array_list_meta", - "array_hex_meta", - "array_b64_meta", - ]: - print(array_mode) + @pytest.mark.parametrize( + "array_mode", + ["list", "array_list_meta", "array_hex_meta", "array_b64_meta"], + ) + def test_serialize_load_integration(self, array_mode: ArrayMode): + for array in [self.array_1d, self.array_2d, self.array_3d]: serialized_array = serialize_array( self.jser, - self.array_zero_dim, + array, "test_path", - array_mode=array_mode, # type: ignore[arg-type] + array_mode=array_mode, ) - print(serialized_array) - loaded_array = load_array(serialized_array) - assert np.array_equal(loaded_array, self.array_zero_dim) + # The overload combinations for serialize_array -> load_array are complex + # since array_mode determines both the serialized type and load method + loaded_array = load_array(serialized_array, array_mode=array_mode) # type: ignore[call-overload, arg-type] + assert np.array_equal(loaded_array, array) + + # TODO: do we even want to support "list" mode for zero-dim arrays? + @pytest.mark.parametrize( + "array_mode", + ["array_list_meta", "array_hex_meta", "array_b64_meta"], + ) + def test_serialize_load_zero_dim(self, array_mode: ArrayMode): + serialized_array = serialize_array( + self.jser, + self.array_zero_dim, + "test_path", + array_mode=array_mode, + ) + loaded_array = load_array(serialized_array) + assert np.array_equal(loaded_array, self.array_zero_dim) -def test_array_shape_dtype_preservation(): +@pytest.mark.parametrize("mode", ["array_list_meta", "array_hex_meta", "array_b64_meta"]) +def test_array_shape_dtype_preservation(mode: ArrayMode): """Test that various shapes and dtypes are preserved through serialization.""" # Test different shapes shapes_and_arrays = [ @@ -119,26 +116,24 @@ def test_array_shape_dtype_preservation(): # Test shapes preservation for arr, description in shapes_and_arrays: - for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: - serialized = serialize_array(jser, arr, "test", array_mode=mode) # type: ignore[arg-type] - loaded = load_array(serialized) - assert loaded.shape == arr.shape, ( - f"Shape mismatch for {description} in {mode}" - ) - assert loaded.dtype == arr.dtype, ( - f"Dtype mismatch for {description} in {mode}" - ) - assert np.array_equal(loaded, arr), ( - f"Data mismatch for {description} in {mode}" - ) + serialized = serialize_array(jser, arr, "test", array_mode=mode) + loaded = load_array(serialized) + assert loaded.shape == arr.shape, ( + f"Shape mismatch for {description} in {mode}" + ) + assert loaded.dtype == arr.dtype, ( + f"Dtype mismatch for {description} in {mode}" + ) + assert np.array_equal(loaded, arr), ( + f"Data mismatch for {description} in {mode}" + ) # Test dtypes preservation for arr, expected_dtype in dtype_tests: - for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: - serialized = serialize_array(jser, arr, "test", array_mode=mode) # type: ignore[arg-type] - loaded = load_array(serialized) - assert loaded.dtype == expected_dtype, 
f"Dtype not preserved: {mode}" - assert np.array_equal(loaded, arr), f"Data not preserved: {mode}" + serialized = serialize_array(jser, arr, "test", array_mode=mode) + loaded = load_array(serialized) + assert loaded.dtype == expected_dtype, f"Dtype not preserved: {mode}" + assert np.array_equal(loaded, arr), f"Data not preserved: {mode}" def test_array_serialization_handlers(): @@ -186,7 +181,8 @@ def test_array_serialization_handlers(): assert _FORMAT_KEY in result -def test_array_edge_cases(): +@pytest.mark.parametrize("mode", ["array_list_meta", "array_hex_meta", "array_b64_meta"]) +def test_array_edge_cases(mode: ArrayMode): """Test edge cases: empty arrays, unusual dtypes, and boundary conditions.""" jser = JsonSerializer(array_mode="array_list_meta") @@ -196,12 +192,11 @@ def test_array_edge_cases(): empty_3d = np.array([[]], dtype=np.int64).reshape(1, 1, 0) for empty_arr in [empty_1d, empty_2d, empty_3d]: - for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: - serialized = serialize_array(jser, empty_arr, "test", array_mode=mode) # type: ignore[arg-type] - loaded = load_array(serialized) - assert loaded.shape == empty_arr.shape - assert loaded.dtype == empty_arr.dtype - assert np.array_equal(loaded, empty_arr) + serialized = serialize_array(jser, empty_arr, "test", array_mode=mode) + loaded = load_array(serialized) + assert loaded.shape == empty_arr.shape + assert loaded.dtype == empty_arr.dtype + assert np.array_equal(loaded, empty_arr) # Complex dtypes complex_arr = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64) @@ -214,17 +209,15 @@ def test_array_edge_cases(): # Large arrays (test that serialization doesn't break) large_arr = np.random.rand(100, 100) - for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: - serialized = serialize_array(jser, large_arr, "test", array_mode=mode) # type: ignore[arg-type] - loaded = load_array(serialized) - assert np.allclose(loaded, large_arr) + serialized = serialize_array(jser, large_arr, "test", array_mode=mode) + loaded = load_array(serialized) + assert np.allclose(loaded, large_arr) # Arrays with special values special_arr = np.array([np.inf, -np.inf, np.nan, 0.0, -0.0], dtype=np.float64) - for mode in ["array_list_meta", "array_hex_meta", "array_b64_meta"]: - serialized = serialize_array(jser, special_arr, "test", array_mode=mode) # type: ignore[arg-type] - loaded = load_array(serialized) - # Use special comparison for NaN - assert np.isnan(loaded[2]) and np.isnan(special_arr[2]) - assert np.array_equal(loaded[:2], special_arr[:2]) # inf values - assert np.array_equal(loaded[3:], special_arr[3:]) # zeros + serialized = serialize_array(jser, special_arr, "test", array_mode=mode) + loaded = load_array(serialized) + # Use special comparison for NaN + assert np.isnan(loaded[2]) and np.isnan(special_arr[2]) + assert np.array_equal(loaded[:2], special_arr[:2]) # inf values + assert np.array_equal(loaded[3:], special_arr[3:]) # zeros diff --git a/uv.lock b/uv.lock index 150f1ce6..e26f3826 100644 --- a/uv.lock +++ b/uv.lock @@ -4183,7 +4183,7 @@ dev = [ { name = "tomli", marker = "python_full_version < '3.11'", specifier = ">=2.1.0" }, { name = "tornado", marker = "python_full_version >= '3.9'", specifier = ">=6.5" }, { name = "twine" }, - { name = "ty" }, + { name = "ty", specifier = ">=0.0.12" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] lint = [ @@ -7730,27 +7730,27 @@ wheels = [ [[package]] name = "ty" -version = "0.0.1a24" -source = { registry = "https://pypi.org/simple" } 
-sdist = { url = "https://files.pythonhosted.org/packages/fc/71/a1db0d604be8d0067342e7aad74ab0c7fec6bea20eb33b6a6324baabf45f/ty-0.0.1a24.tar.gz", hash = "sha256:3273c514df5b9954c9928ee93b6a0872d12310ea8de42249a6c197720853e096", size = 4386721, upload-time = "2025-10-23T13:33:29.729Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ab/89/21fb275cb676d3480b67fbbf6eb162aec200b4dcb10c7885bffc754dc73f/ty-0.0.1a24-py3-none-linux_armv6l.whl", hash = "sha256:d478cd02278b988d5767df5821a0f03b99ef848f6fc29e8c77f30e859b89c779", size = 8833903, upload-time = "2025-10-23T13:32:53.552Z" }, - { url = "https://files.pythonhosted.org/packages/a2/22/beb127bce67fc2a1f3704b6b39505d77a7078a61becfbe10c5ee7ed9f5d8/ty-0.0.1a24-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:de758790f05f0a3bb396da4c75f770c85ab3a46095ec188b830c916bd5a5bc10", size = 8691210, upload-time = "2025-10-23T13:32:55.706Z" }, - { url = "https://files.pythonhosted.org/packages/39/bd/190f5e934339669191179fa01c60f5a140822dc465f0d4d312985903d109/ty-0.0.1a24-py3-none-macosx_11_0_arm64.whl", hash = "sha256:68f325ddc8cfb7a7883501e5e22f01284c5d5912aaa901d21e477f38edf4e625", size = 8138421, upload-time = "2025-10-23T13:32:58.718Z" }, - { url = "https://files.pythonhosted.org/packages/40/84/f08020dabad1e660957bb641b2ba42fe1e1e87192c234b1fc1fd6fb42cf2/ty-0.0.1a24-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:49a52bbb1f8b0b29ad717d3fd70bd2afe752e991072fd13ff2fc14f03945c849", size = 8419861, upload-time = "2025-10-23T13:33:00.068Z" }, - { url = "https://files.pythonhosted.org/packages/e5/cc/e3812f7c1c2a0dcfb1bf8a5d6a7e5aa807a483a632c0d5734ea50a60a9ae/ty-0.0.1a24-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:12945fe358fb0f73acf0b72a29efcc80da73f8d95cfe7f11a81e4d8d730e7b18", size = 8641443, upload-time = "2025-10-23T13:33:01.887Z" }, - { url = "https://files.pythonhosted.org/packages/e3/8b/3fc047d04afbba4780aba031dc80e06f6e95d888bbddb8fd6da502975cfb/ty-0.0.1a24-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6231e190989798b0860d15a8f225e3a06a6ce442a7083d743eb84f5b4b83b980", size = 8997853, upload-time = "2025-10-23T13:33:03.951Z" }, - { url = "https://files.pythonhosted.org/packages/e0/d9/ae1475d9200ecf6b196a59357ea3e4f4aa00e1d38c9237ca3f267a4a3ef7/ty-0.0.1a24-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:7c6401f4a7532eab63dd7fe015c875792a701ca4b1a44fc0c490df32594e071f", size = 9676864, upload-time = "2025-10-23T13:33:05.744Z" }, - { url = "https://files.pythonhosted.org/packages/cc/d9/abd6849f0601b24d5d5098e47b00dfbdfe44a4f6776f2e54a21005739bdf/ty-0.0.1a24-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:83c69759bfa2a00278aa94210eded35aea599215d16460445cbbf5b36f77c454", size = 9351386, upload-time = "2025-10-23T13:33:07.807Z" }, - { url = "https://files.pythonhosted.org/packages/63/5c/639e0fe3b489c65b12b38385fe5032024756bc07f96cd994d7df3ab579ef/ty-0.0.1a24-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:71146713cb8f804aad2b2e87a8efa7e7df0a5a25aed551af34498bcc2721ae03", size = 9517674, upload-time = "2025-10-23T13:33:09.641Z" }, - { url = "https://files.pythonhosted.org/packages/78/ae/323f373fcf54a883e39ea3fb6f83ed6d1eda6dfd8246462d0cfd81dac781/ty-0.0.1a24-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4836854411059de592f0ecc62193f2b24fc3acbfe6ce6ce0bf2c6d1a5ea9de7", size = 9000468, upload-time = "2025-10-23T13:33:11.51Z" }, - { url = 
"https://files.pythonhosted.org/packages/14/26/1a4be005aa4326264f0e7ce554844d5ef8afc4c5600b9a38b05671e9ed18/ty-0.0.1a24-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:a7f0b8546d27605e09cd0fe08dc28c1d177bf7498316dd11c3bb8ef9440bf2e1", size = 8377164, upload-time = "2025-10-23T13:33:13.504Z" }, - { url = "https://files.pythonhosted.org/packages/73/2f/dcd6b449084e53a2beb536d8721a2517143a2353413b5b323d6eb9a31705/ty-0.0.1a24-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4e2fbf7dce2311127748824e03d9de2279e96ab5713029c3fa58acbaf19b2f51", size = 8672709, upload-time = "2025-10-23T13:33:15.213Z" }, - { url = "https://files.pythonhosted.org/packages/dc/2e/8b3b45d46085a79547e6db5295f42c6b798a0240d34454181e2ca947183c/ty-0.0.1a24-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f35b7f0a65f7e34e59f34173164946c89a4c4b1d1c18cabe662356a35f33efcd", size = 8788732, upload-time = "2025-10-23T13:33:17.347Z" }, - { url = "https://files.pythonhosted.org/packages/cf/c5/7675ff8693ad13044d86d8d4c824caf6bbb00340df05ad93d0e9d1e0338b/ty-0.0.1a24-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:120fe95eaf2a200f531f949e3dd0a9d95ab38915ce388412873eae28c499c0b9", size = 9095693, upload-time = "2025-10-23T13:33:19.836Z" }, - { url = "https://files.pythonhosted.org/packages/62/0b/bdba5d31aa3f0298900675fd355eec63a9c682aa46ef743dbac8f28b4608/ty-0.0.1a24-py3-none-win32.whl", hash = "sha256:d8d8379264a8c14e1f4ca9e117e72df3bf0a0b0ca64c5fd18affbb6142d8662a", size = 8361302, upload-time = "2025-10-23T13:33:21.572Z" }, - { url = "https://files.pythonhosted.org/packages/b4/48/127a45e16c49563df82829542ca64b0bc387591a777df450972bc85957e6/ty-0.0.1a24-py3-none-win_amd64.whl", hash = "sha256:2e826d75bddd958643128c309f6c47673ed6cef2ea5f2b3cd1a1159a1392971a", size = 9039221, upload-time = "2025-10-23T13:33:23.055Z" }, - { url = "https://files.pythonhosted.org/packages/31/67/9161fbb8c1a2005938bdb5ccd4e4c98ee4bea2d262afb777a4b69aa15eb5/ty-0.0.1a24-py3-none-win_arm64.whl", hash = "sha256:2efbfcdc94d306f0d25f3efe2a90c0f953132ca41a1a47d0bae679d11cdb15aa", size = 8514044, upload-time = "2025-10-23T13:33:27.816Z" }, +version = "0.0.12" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/78/ba1a4ad403c748fbba8be63b7e774a90e80b67192f6443d624c64fe4aaab/ty-0.0.12.tar.gz", hash = "sha256:cd01810e106c3b652a01b8f784dd21741de9fdc47bd595d02c122a7d5cefeee7", size = 4981303, upload-time = "2026-01-14T22:30:48.537Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7d/8f/c21314d074dda5fb13d3300fa6733fd0d8ff23ea83a721818740665b6314/ty-0.0.12-py3-none-linux_armv6l.whl", hash = "sha256:eb9da1e2c68bd754e090eab39ed65edf95168d36cbeb43ff2bd9f86b4edd56d1", size = 9614164, upload-time = "2026-01-14T22:30:44.016Z" }, + { url = "https://files.pythonhosted.org/packages/09/28/f8a4d944d13519d70c486e8f96d6fa95647ac2aa94432e97d5cfec1f42f6/ty-0.0.12-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:c181f42aa19b0ed7f1b0c2d559980b1f1d77cc09419f51c8321c7ddf67758853", size = 9542337, upload-time = "2026-01-14T22:30:05.687Z" }, + { url = "https://files.pythonhosted.org/packages/e1/9c/f576e360441de7a8201daa6dc4ebc362853bc5305e059cceeb02ebdd9a48/ty-0.0.12-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1f829e1eecd39c3e1b032149db7ae6a3284f72fc36b42436e65243a9ed1173db", size = 8909582, upload-time = "2026-01-14T22:30:46.089Z" }, + { url = 
"https://files.pythonhosted.org/packages/d6/13/0898e494032a5d8af3060733d12929e3e7716db6c75eac63fa125730a3e7/ty-0.0.12-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f45162e7826e1789cf3374627883cdeb0d56b82473a0771923e4572928e90be3", size = 9384932, upload-time = "2026-01-14T22:30:13.769Z" }, + { url = "https://files.pythonhosted.org/packages/e4/1a/b35b6c697008a11d4cedfd34d9672db2f0a0621ec80ece109e13fca4dfef/ty-0.0.12-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d11fec40b269bec01e751b2337d1c7ffa959a2c2090a950d7e21c2792442cccd", size = 9453140, upload-time = "2026-01-14T22:30:11.131Z" }, + { url = "https://files.pythonhosted.org/packages/dd/1e/71c9edbc79a3c88a0711324458f29c7dbf6c23452c6e760dc25725483064/ty-0.0.12-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09d99e37e761a4d2651ad9d5a610d11235fbcbf35dc6d4bc04abf54e7cf894f1", size = 9960680, upload-time = "2026-01-14T22:30:33.621Z" }, + { url = "https://files.pythonhosted.org/packages/0e/75/39375129f62dd22f6ad5a99cd2a42fd27d8b91b235ce2db86875cdad397d/ty-0.0.12-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d9ca0cdb17bd37397da7b16a7cd23423fc65c3f9691e453ad46c723d121225a1", size = 10904518, upload-time = "2026-01-14T22:30:08.464Z" }, + { url = "https://files.pythonhosted.org/packages/32/5e/26c6d88fafa11a9d31ca9f4d12989f57782ec61e7291d4802d685b5be118/ty-0.0.12-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fcf2757b905e7eddb7e456140066335b18eb68b634a9f72d6f54a427ab042c64", size = 10525001, upload-time = "2026-01-14T22:30:16.454Z" }, + { url = "https://files.pythonhosted.org/packages/c2/a5/2f0b91894af13187110f9ad7ee926d86e4e6efa755c9c88a820ed7f84c85/ty-0.0.12-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:00cf34c1ebe1147efeda3021a1064baa222c18cdac114b7b050bbe42deb4ca80", size = 10307103, upload-time = "2026-01-14T22:30:41.221Z" }, + { url = "https://files.pythonhosted.org/packages/4b/77/13d0410827e4bc713ebb7fdaf6b3590b37dcb1b82e0a81717b65548f2442/ty-0.0.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb3a655bd869352e9a22938d707631ac9fbca1016242b1f6d132d78f347c851", size = 10072737, upload-time = "2026-01-14T22:30:51.783Z" }, + { url = "https://files.pythonhosted.org/packages/e1/dd/fc36d8bac806c74cf04b4ca735bca14d19967ca84d88f31e121767880df1/ty-0.0.12-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4658e282c7cb82be304052f8f64f9925f23c3c4f90eeeb32663c74c4b095d7ba", size = 9368726, upload-time = "2026-01-14T22:30:18.683Z" }, + { url = "https://files.pythonhosted.org/packages/54/70/9e8e461647550f83e2fe54bc632ccbdc17a4909644783cdbdd17f7296059/ty-0.0.12-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:c167d838eaaa06e03bb66a517f75296b643d950fbd93c1d1686a187e5a8dbd1f", size = 9454704, upload-time = "2026-01-14T22:30:22.759Z" }, + { url = "https://files.pythonhosted.org/packages/04/9b/6292cf7c14a0efeca0539cf7d78f453beff0475cb039fbea0eb5d07d343d/ty-0.0.12-py3-none-musllinux_1_2_i686.whl", hash = "sha256:2956e0c9ab7023533b461d8a0e6b2ea7b78e01a8dde0688e8234d0fce10c4c1c", size = 9649829, upload-time = "2026-01-14T22:30:31.234Z" }, + { url = "https://files.pythonhosted.org/packages/49/bd/472a5d2013371e4870886cff791c94abdf0b92d43d305dd0f8e06b6ff719/ty-0.0.12-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5c6a3fd7479580009f21002f3828320621d8a82d53b7ba36993234e3ccad58c8", size = 10162814, upload-time = "2026-01-14T22:30:36.174Z" }, + { url = 
"https://files.pythonhosted.org/packages/31/e9/2ecbe56826759845a7c21d80aa28187865ea62bc9757b056f6cbc06f78ed/ty-0.0.12-py3-none-win32.whl", hash = "sha256:a91c24fd75c0f1796d8ede9083e2c0ec96f106dbda73a09fe3135e075d31f742", size = 9140115, upload-time = "2026-01-14T22:30:38.903Z" }, + { url = "https://files.pythonhosted.org/packages/5d/6d/d9531eff35a5c0ec9dbc10231fac21f9dd6504814048e81d6ce1c84dc566/ty-0.0.12-py3-none-win_amd64.whl", hash = "sha256:df151894be55c22d47068b0f3b484aff9e638761e2267e115d515fcc9c5b4a4b", size = 9884532, upload-time = "2026-01-14T22:30:25.112Z" }, + { url = "https://files.pythonhosted.org/packages/e9/f3/20b49e75967023b123a221134548ad7000f9429f13fdcdda115b4c26305f/ty-0.0.12-py3-none-win_arm64.whl", hash = "sha256:cea99d334b05629de937ce52f43278acf155d3a316ad6a35356635f886be20ea", size = 9313974, upload-time = "2026-01-14T22:30:27.44Z" }, ] [[package]] From bdb12fa6230e42d6345f183832ccee090da384a6 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 19 Jan 2026 14:04:39 -0700 Subject: [PATCH 67/72] typing wip --- .meta/typing-summary.txt | 22 +++++----- muutils/json_serialize/array.py | 31 ++++++------- muutils/logger/exception_context.py | 12 ++++-- muutils/logger/headerfuncs.py | 4 +- muutils/logger/log_util.py | 20 ++++----- muutils/logger/logger.py | 30 +++++++------ muutils/logger/simplelogger.py | 2 +- tests/unit/json_serialize/test_array.py | 43 +++++++++++++++++-- tests/unit/json_serialize/test_array_torch.py | 3 +- 9 files changed, 104 insertions(+), 63 deletions(-) diff --git a/.meta/typing-summary.txt b/.meta/typing-summary.txt index 12f4d213..615178a4 100644 --- a/.meta/typing-summary.txt +++ b/.meta/typing-summary.txt @@ -1,7 +1,7 @@ # Showing all errors # mypy: Found 22 errors in 5 files (checked 117 source files) -# basedpyright: 485 errors, 3280 warnings, 0 notes +# basedpyright: 474 errors, 3219 warnings, 0 notes # ty: Found 23 diagnostics [type_errors.mypy] @@ -22,15 +22,15 @@ total_errors = 22 "muutils/logger/logger.py" = 1 [type_errors.basedpyright] -total_errors = 2557 +total_errors = 2532 [type_errors.basedpyright.by_type] -"reportUnknownParameterType" = 405 -"reportAny" = 362 -"reportMissingParameterType" = 357 +"reportUnknownParameterType" = 393 +"reportAny" = 371 +"reportMissingParameterType" = 349 "reportUnusedCallResult" = 290 -"reportUnknownVariableType" = 195 -"reportMissingTypeArgument" = 178 +"reportUnknownVariableType" = 191 +"reportMissingTypeArgument" = 168 "reportUnknownLambdaType" = 127 "reportUnknownMemberType" = 125 "reportUnusedParameter" = 98 @@ -103,12 +103,10 @@ total_errors = 2557 "muutils/mlutils.py" = 22 "muutils/validate_type.py" = 20 "tests/unit/validate_type/test_validate_type_special.py" = 20 -"muutils/logger/logger.py" = 18 "muutils/math/matrix_powers.py" = 18 "tests/unit/json_serialize/test_util.py" = 18 "muutils/interval.py" = 17 "tests/unit/test_kappa.py" = 17 -"muutils/logger/log_util.py" = 16 "muutils/tensor_utils.py" = 16 "tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 15 "tests/unit/nbutils/test_configure_notebook.py" = 14 @@ -116,6 +114,7 @@ total_errors = 2557 "tests/unit/json_serialize/test_array.py" = 13 "muutils/nbutils/run_notebook_tests.py" = 12 "muutils/errormode.py" = 11 +"muutils/logger/logger.py" = 11 "tests/unit/misc/test_misc.py" = 11 "tests/unit/test_tensor_info_torch.py" = 11 "muutils/dbg.py" = 10 @@ -124,13 +123,13 @@ total_errors = 2557 "muutils/misc/classes.py" = 9 "muutils/cli/arg_bool.py" = 8 "muutils/json_serialize/dataclass_transform_mock.py" = 8 
-"muutils/logger/headerfuncs.py" = 8 "tests/unit/test_console_unicode.py" = 8 "tests/util/test_fire.py" = 8 "muutils/collect_warnings.py" = 7 "muutils/nbutils/mermaid.py" = 7 "tests/unit/cli/test_command.py" = 7 "tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 7 +"muutils/logger/headerfuncs.py" = 6 "tests/unit/nbutils/test_conversion.py" = 6 "tests/conftest.py" = 5 "tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 5 @@ -141,12 +140,13 @@ total_errors = 2557 "tests/unit/test_mlutils.py" = 4 "muutils/cli/command.py" = 3 "muutils/json_serialize/array.py" = 3 -"muutils/logger/simplelogger.py" = 3 +"muutils/logger/simplelogger.py" = 2 "tests/unit/logger/test_log_util.py" = 2 "tests/unit/test_chunks.py" = 2 "tests/unit/test_tensor_info.py" = 2 "tests/unit/test_timeit_fancy.py" = 2 "muutils/logger/exception_context.py" = 1 +"muutils/logger/log_util.py" = 1 [type_errors.ty] total_errors = 23 diff --git a/muutils/json_serialize/array.py b/muutils/json_serialize/array.py index b19394d3..552c815a 100644 --- a/muutils/json_serialize/array.py +++ b/muutils/json_serialize/array.py @@ -56,6 +56,15 @@ "zero_dim", ] +# Modes that produce SerializedArrayWithMeta (dict with metadata) +ArrayModeWithMeta = Literal[ + "array_list_meta", + "array_hex_meta", + "array_b64_meta", + "zero_dim", + "external", +] + def array_n_elements(arr: Any) -> int: # type: ignore[name-defined] # pyright: ignore[reportAny] """get the number of elements in an array""" @@ -113,9 +122,7 @@ def serialize_array( jser: "JsonSerializer", arr: "Union[np.ndarray, torch.Tensor]", path: str | Sequence[str | int], - array_mode: Literal[ - "array_list_meta", "array_hex_meta", "array_b64_meta", "zero_dim", "external" - ], + array_mode: ArrayModeWithMeta, ) -> SerializedArrayWithMeta: ... @overload def serialize_array( @@ -220,13 +227,7 @@ def serialize_array( @overload def infer_array_mode( arr: SerializedArrayWithMeta, -) -> Literal[ - "array_list_meta", - "array_hex_meta", - "array_b64_meta", - "external", - "zero_dim", -]: ... +) -> ArrayModeWithMeta: ... @overload def infer_array_mode(arr: NumericList) -> Literal["list"]: ... def infer_array_mode( @@ -269,15 +270,7 @@ def infer_array_mode( @overload def load_array( arr: SerializedArrayWithMeta, - array_mode: Optional[ - Literal[ - "array_list_meta", - "array_hex_meta", - "array_b64_meta", - "external", - "zero_dim", - ] - ] = None, + array_mode: Optional[ArrayModeWithMeta] = None, ) -> np.ndarray: ... @overload def load_array( diff --git a/muutils/logger/exception_context.py b/muutils/logger/exception_context.py index 1604c44a..fd77d95c 100644 --- a/muutils/logger/exception_context.py +++ b/muutils/logger/exception_context.py @@ -2,11 +2,17 @@ import json from types import TracebackType -from typing import IO +from typing import Protocol from muutils.json_serialize import json_serialize +class WritableStream(Protocol): + """Protocol for objects that support write operations.""" + + def write(self, msg: str) -> int: ... 
+ + class ExceptionContext: """context manager which catches all exceptions happening while the context is open, `.write()` the exception trace to the given stream, and then raises the exception @@ -24,8 +30,8 @@ class ExceptionContext: """ - def __init__(self, stream: IO[str]) -> None: - self.stream: IO[str] = stream + def __init__(self, stream: WritableStream) -> None: + self.stream: WritableStream = stream def __enter__(self) -> ExceptionContext: return self diff --git a/muutils/logger/headerfuncs.py b/muutils/logger/headerfuncs.py index 8f327268..49241954 100644 --- a/muutils/logger/headerfuncs.py +++ b/muutils/logger/headerfuncs.py @@ -10,7 +10,7 @@ class HeaderFunction(Protocol): - def __call__(self, msg: Any, lvl: int, **kwargs) -> str: ... + def __call__(self, msg: Any, lvl: int, **kwargs: Any) -> str: ... def md_header_function( @@ -19,7 +19,7 @@ def md_header_function( stream: str | None = None, indent_lvl: str = " ", extra_indent: str = "", - **kwargs, + **kwargs: Any, ) -> str: """standard header function. will output diff --git a/muutils/logger/log_util.py b/muutils/logger/log_util.py index 80ded213..08de7455 100644 --- a/muutils/logger/log_util.py +++ b/muutils/logger/log_util.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TypeVar +from typing import Any, TypeVar from muutils.jsonlines import jsonl_load_log T_StreamValue = TypeVar("T_StreamValue") @@ -16,10 +16,10 @@ def get_any_from_stream( raise KeyError(f"key '{key}' not found in stream") -def gather_log(file: str) -> dict[str, list[dict]]: +def gather_log(file: str) -> dict[str, list[dict[str, Any]]]: """gathers and sorts all streams from a log""" - data: list[dict] = jsonl_load_log(file) - output: dict[str, list[dict]] = dict() + data: list[dict[str, Any]] = jsonl_load_log(file) + output: dict[str, list[dict[str, Any]]] = dict() for item in data: stream: str = item.get("_stream", "default") @@ -33,11 +33,11 @@ def gather_log(file: str) -> dict[str, list[dict]]: def gather_stream( file: str, stream: str, -) -> list[dict]: +) -> list[dict[str, Any]]: """gets all entries from a specific stream in a log file""" - data: list[dict] = jsonl_load_log(file) + data: list[dict[str, Any]] = jsonl_load_log(file) - output: list[dict] = list() + output: list[dict[str, Any]] = list() for item in data: # select for the stream @@ -51,7 +51,7 @@ def gather_val( stream: str, keys: tuple[str, ...], allow_skip: bool = True, -) -> list[list]: +) -> list[list[Any]]: """gather specific keys from a specific stream in a log file example: @@ -70,9 +70,9 @@ def gather_val( ``` """ - data: list[dict] = jsonl_load_log(file) + data: list[dict[str, Any]] = jsonl_load_log(file) - output: list[list] = list() + output: list[list[Any]] = list() for item in data: # select for the stream diff --git a/muutils/logger/logger.py b/muutils/logger/logger.py index 24501a0f..7ef6054d 100644 --- a/muutils/logger/logger.py +++ b/muutils/logger/logger.py @@ -79,8 +79,8 @@ def __init__( keep_last_msg_time: bool = True, # junk args timestamp: bool = True, - **kwargs, - ): + **kwargs: Any, + ) -> None: # junk arg checking # ================================================== if len(kwargs) > 0: @@ -152,18 +152,22 @@ def _exception_context( # level: int = -256, # **kwargs, ) -> ExceptionContext: + import sys + s: LoggingStream = self._streams[stream] - return ExceptionContext(stream=s) + handler = s.handler if s.handler is not None else sys.stderr + return ExceptionContext(stream=handler) - def log( # type: ignore # yes, the signatures are 
different here. + def log( self, msg: JSONitem = None, + *, lvl: int | None = None, stream: str | None = None, console_print: bool = False, extra_indent: str = "", - **kwargs, - ): + **kwargs: Any, + ) -> None: """logging function ### Parameters: @@ -271,13 +275,13 @@ def log_elapsed_last( lvl: int | None = None, stream: str | None = None, console_print: bool = True, - **kwargs, - ) -> float: + **kwargs: Any, + ) -> None: """logs the time elapsed since the last message was printed to the console (in any stream)""" if self._last_msg_time is None: raise ValueError("no last message time!") else: - return self.log( + self.log( {"elapsed_time": round(time.time() - self._last_msg_time, 6)}, lvl=(lvl if lvl is not None else self._console_print_threshold), stream=stream, @@ -294,13 +298,13 @@ def flush_all(self): if stream.handler is not None: stream.handler.flush() - def __getattr__(self, stream: str) -> Callable: + def __getattr__(self, stream: str) -> Callable[..., Any]: if stream.startswith("_"): raise AttributeError(f"invalid stream name {stream} (no underscores)") return partial(self.log, stream=stream) - def __getitem__(self, stream: str): + def __getitem__(self, stream: str) -> Callable[..., Any]: return partial(self.log, stream=stream) - def __call__(self, *args, **kwargs): - return self.log(*args, **kwargs) + def __call__(self, *args: Any, **kwargs: Any) -> None: + self.log(*args, **kwargs) diff --git a/muutils/logger/simplelogger.py b/muutils/logger/simplelogger.py index b39a683a..38e19a77 100644 --- a/muutils/logger/simplelogger.py +++ b/muutils/logger/simplelogger.py @@ -64,7 +64,7 @@ def __init__( assert log_path is not None self._log_file_handle = open(log_path, "w", encoding="utf-8") - def log(self, msg: JSONitem, console_print: bool = False, **kwargs): + def log(self, msg: JSONitem, *, console_print: bool = False, **kwargs: Any) -> None: """log a message to the log file, and optionally to the console""" if console_print: print(msg) diff --git a/tests/unit/json_serialize/test_array.py b/tests/unit/json_serialize/test_array.py index 172507b5..a5eed85b 100644 --- a/tests/unit/json_serialize/test_array.py +++ b/tests/unit/json_serialize/test_array.py @@ -4,6 +4,7 @@ from muutils.json_serialize import JsonSerializer from muutils.json_serialize.array import ( ArrayMode, + ArrayModeWithMeta, arr_metadata, array_n_elements, load_array, @@ -73,12 +74,48 @@ def test_serialize_load_integration(self, array_mode: ArrayMode): loaded_array = load_array(serialized_array, array_mode=array_mode) # type: ignore[call-overload, arg-type] assert np.array_equal(loaded_array, array) + def test_serialize_load_list(self): + """Test serialize/load with 'list' mode - separate function for type safety.""" + for array in [self.array_1d, self.array_2d, self.array_3d]: + serialized_array = serialize_array( + self.jser, array, "test_path", array_mode="list" + ) + loaded_array = load_array(serialized_array, array_mode="list") + assert np.array_equal(loaded_array, array) + + def test_serialize_load_array_list_meta(self): + """Test serialize/load with 'array_list_meta' mode - separate function for type safety.""" + for array in [self.array_1d, self.array_2d, self.array_3d]: + serialized_array = serialize_array( + self.jser, array, "test_path", array_mode="array_list_meta" + ) + loaded_array = load_array(serialized_array, array_mode="array_list_meta") + assert np.array_equal(loaded_array, array) + + def test_serialize_load_array_hex_meta(self): + """Test serialize/load with 'array_hex_meta' mode - separate function 
for type safety."""
+        for array in [self.array_1d, self.array_2d, self.array_3d]:
+            serialized_array = serialize_array(
+                self.jser, array, "test_path", array_mode="array_hex_meta"
+            )
+            loaded_array = load_array(serialized_array, array_mode="array_hex_meta")
+            assert np.array_equal(loaded_array, array)
+
+    def test_serialize_load_array_b64_meta(self):
+        """Test serialize/load with 'array_b64_meta' mode - separate function for type safety."""
+        for array in [self.array_1d, self.array_2d, self.array_3d]:
+            serialized_array = serialize_array(
+                self.jser, array, "test_path", array_mode="array_b64_meta"
+            )
+            loaded_array = load_array(serialized_array, array_mode="array_b64_meta")
+            assert np.array_equal(loaded_array, array)
+
     # TODO: do we even want to support "list" mode for zero-dim arrays?
     @pytest.mark.parametrize(
         "array_mode",
         ["array_list_meta", "array_hex_meta", "array_b64_meta"],
     )
-    def test_serialize_load_zero_dim(self, array_mode: ArrayMode):
+    def test_serialize_load_zero_dim(self, array_mode: ArrayModeWithMeta):
         serialized_array = serialize_array(
             self.jser,
             self.array_zero_dim,
@@ -90,7 +127,7 @@ def test_serialize_load_zero_dim(self, array_mode: ArrayMode):
         assert np.array_equal(loaded_array, self.array_zero_dim)
 
 
 @pytest.mark.parametrize("mode", ["array_list_meta", "array_hex_meta", "array_b64_meta"])
-def test_array_shape_dtype_preservation(mode: ArrayMode):
+def test_array_shape_dtype_preservation(mode: ArrayModeWithMeta):
     """Test that various shapes and dtypes are preserved through serialization."""
     # Test different shapes
     shapes_and_arrays = [
@@ -182,7 +219,7 @@ def test_array_serialization_handlers():
     assert _FORMAT_KEY in result
 
 
 @pytest.mark.parametrize("mode", ["array_list_meta", "array_hex_meta", "array_b64_meta"])
-def test_array_edge_cases(mode: ArrayMode):
+def test_array_edge_cases(mode: ArrayModeWithMeta):
     """Test edge cases: empty arrays, unusual dtypes, and boundary conditions."""
     jser = JsonSerializer(array_mode="array_list_meta")
 
diff --git a/tests/unit/json_serialize/test_array_torch.py b/tests/unit/json_serialize/test_array_torch.py
index d22e22ac..5b003974 100644
--- a/tests/unit/json_serialize/test_array_torch.py
+++ b/tests/unit/json_serialize/test_array_torch.py
@@ -5,6 +5,7 @@
 from muutils.json_serialize import JsonSerializer
 from muutils.json_serialize.array import (
     ArrayMode,
+    ArrayModeWithMeta,
     arr_metadata,
     array_n_elements,
     load_array,
@@ -15,7 +16,7 @@
 # pylint: disable=missing-class-docstring
 
 
-_WITH_META_ARRAY_MODES: list[ArrayMode] = [
+_WITH_META_ARRAY_MODES: list[ArrayModeWithMeta] = [
     "array_list_meta",
     "array_hex_meta",
     "array_b64_meta",

From cbcce7cea3b915dbef49d3b55e48f698b5d966f0 Mon Sep 17 00:00:00 2001
From: Michael Ivanitskiy
Date: Mon, 19 Jan 2026 14:04:48 -0700
Subject: [PATCH 68/72] make format

---
 muutils/json_serialize/json_serialize.py      |  4 +++-
 .../json_serialize/serializable_dataclass.py  |  5 ++++-
 tests/unit/json_serialize/test_array.py       | 20 +++++++++----------
 tests/unit/json_serialize/test_array_torch.py |  1 -
 .../json_serialize/test_serializable_field.py |  4 +++-
 5 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/muutils/json_serialize/json_serialize.py b/muutils/json_serialize/json_serialize.py
index ae459b3e..7d6454e4 100644
--- a/muutils/json_serialize/json_serialize.py
+++ b/muutils/json_serialize/json_serialize.py
@@ -195,7 +195,9 @@ def _serialize_override_serialize_func(
         ),
         SerializerHandler(
             check=lambda self, obj, path: str(type(obj)) == "<class 'numpy.ndarray'>",
-            serialize_func=lambda self, obj, path: cast(JSONitem, serialize_array(self, obj, path=path)),
+            serialize_func=lambda self, obj, path: cast(
JSONitem, serialize_array(self, obj, path=path) + ), uid="numpy.ndarray", desc="numpy arrays", ), diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 2991abaf..44181132 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -86,7 +86,10 @@ class NestedClass(SerializableDataclass): try: from typing_extensions import dataclass_transform, Self except Exception: - from muutils.json_serialize.dataclass_transform_mock import dataclass_transform + from muutils.json_serialize.dataclass_transform_mock import ( + dataclass_transform, + ) + Self = TypeVar("Self") T = TypeVar("T") diff --git a/tests/unit/json_serialize/test_array.py b/tests/unit/json_serialize/test_array.py index a5eed85b..5a3e3b85 100644 --- a/tests/unit/json_serialize/test_array.py +++ b/tests/unit/json_serialize/test_array.py @@ -126,7 +126,9 @@ def test_serialize_load_zero_dim(self, array_mode: ArrayModeWithMeta): assert np.array_equal(loaded_array, self.array_zero_dim) -@pytest.mark.parametrize("mode", ["array_list_meta", "array_hex_meta", "array_b64_meta"]) +@pytest.mark.parametrize( + "mode", ["array_list_meta", "array_hex_meta", "array_b64_meta"] +) def test_array_shape_dtype_preservation(mode: ArrayModeWithMeta): """Test that various shapes and dtypes are preserved through serialization.""" # Test different shapes @@ -155,15 +157,9 @@ def test_array_shape_dtype_preservation(mode: ArrayModeWithMeta): for arr, description in shapes_and_arrays: serialized = serialize_array(jser, arr, "test", array_mode=mode) loaded = load_array(serialized) - assert loaded.shape == arr.shape, ( - f"Shape mismatch for {description} in {mode}" - ) - assert loaded.dtype == arr.dtype, ( - f"Dtype mismatch for {description} in {mode}" - ) - assert np.array_equal(loaded, arr), ( - f"Data mismatch for {description} in {mode}" - ) + assert loaded.shape == arr.shape, f"Shape mismatch for {description} in {mode}" + assert loaded.dtype == arr.dtype, f"Dtype mismatch for {description} in {mode}" + assert np.array_equal(loaded, arr), f"Data mismatch for {description} in {mode}" # Test dtypes preservation for arr, expected_dtype in dtype_tests: @@ -218,7 +214,9 @@ def test_array_serialization_handlers(): assert _FORMAT_KEY in result -@pytest.mark.parametrize("mode", ["array_list_meta", "array_hex_meta", "array_b64_meta"]) +@pytest.mark.parametrize( + "mode", ["array_list_meta", "array_hex_meta", "array_b64_meta"] +) def test_array_edge_cases(mode: ArrayModeWithMeta): """Test edge cases: empty arrays, unusual dtypes, and boundary conditions.""" jser = JsonSerializer(array_mode="array_list_meta") diff --git a/tests/unit/json_serialize/test_array_torch.py b/tests/unit/json_serialize/test_array_torch.py index 5b003974..a4ae803c 100644 --- a/tests/unit/json_serialize/test_array_torch.py +++ b/tests/unit/json_serialize/test_array_torch.py @@ -4,7 +4,6 @@ from muutils.json_serialize import JsonSerializer from muutils.json_serialize.array import ( - ArrayMode, ArrayModeWithMeta, arr_metadata, array_n_elements, diff --git a/tests/unit/json_serialize/test_serializable_field.py b/tests/unit/json_serialize/test_serializable_field.py index 36ef6787..2b6bb088 100644 --- a/tests/unit/json_serialize/test_serializable_field.py +++ b/tests/unit/json_serialize/test_serializable_field.py @@ -140,7 +140,9 @@ def test_from_Field(): assert sf.deserialize_fn is None # Test with default_factory and init=False to avoid init=True, serialize=False error - 
dc_field2: dataclasses.Field[list[Any]] = field(default_factory=list, repr=True, init=True) # type: ignore[assignment] + dc_field2: dataclasses.Field[list[Any]] = field( + default_factory=list, repr=True, init=True + ) # type: ignore[assignment] sf2 = SerializableField.from_Field(dc_field2) assert sf2.default_factory == list # noqa: E721 assert sf2.default is dataclasses.MISSING From cf22ca470fac31a229eab43d18189c4597ea2299 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 19 Jan 2026 14:17:25 -0700 Subject: [PATCH 69/72] more typing fixes! --- .meta/typing-summary.txt | 36 ++++++++------------ CHANGELOG.md | 17 +++++++++ muutils/json_serialize/serializable_field.py | 8 ++--- muutils/logger/exception_context.py | 2 +- muutils/spinner.py | 31 +++++++++++------ muutils/tensor_info.py | 28 ++++++++++++++- tests/unit/json_serialize/test_array.py | 14 ++++---- tests/unit/logger/test_logger.py | 6 ++-- 8 files changed, 95 insertions(+), 47 deletions(-) create mode 100644 CHANGELOG.md diff --git a/.meta/typing-summary.txt b/.meta/typing-summary.txt index 615178a4..a9ac3565 100644 --- a/.meta/typing-summary.txt +++ b/.meta/typing-summary.txt @@ -1,36 +1,31 @@ # Showing all errors -# mypy: Found 22 errors in 5 files (checked 117 source files) -# basedpyright: 474 errors, 3219 warnings, 0 notes -# ty: Found 23 diagnostics +# mypy: Found 10 errors in 3 files (checked 117 source files) +# basedpyright: 471 errors, 3193 warnings, 0 notes +# ty: Found 17 diagnostics [type_errors.mypy] -total_errors = 22 +total_errors = 10 [type_errors.mypy.by_type] -"arg-type" = 8 "operator" = 6 -"attr-defined" = 4 -"call-overload" = 3 -"assignment" = 1 +"arg-type" = 4 [type_errors.mypy.by_file] -"tests/unit/json_serialize/test_array.py" = 10 "tests/unit/json_serialize/test_json_serialize.py" = 6 -"tests/unit/json_serialize/test_array_torch.py" = 3 "tests/unit/json_serialize/test_serializable_field.py" = 2 -"muutils/logger/logger.py" = 1 +"tests/unit/json_serialize/test_array_torch.py" = 2 [type_errors.basedpyright] -total_errors = 2532 +total_errors = 2520 [type_errors.basedpyright.by_type] -"reportUnknownParameterType" = 393 -"reportAny" = 371 -"reportMissingParameterType" = 349 +"reportUnknownParameterType" = 385 +"reportAny" = 376 +"reportMissingParameterType" = 341 "reportUnusedCallResult" = 290 "reportUnknownVariableType" = 191 -"reportMissingTypeArgument" = 168 +"reportMissingTypeArgument" = 167 "reportUnknownLambdaType" = 127 "reportUnknownMemberType" = 125 "reportUnusedParameter" = 98 @@ -81,7 +76,6 @@ total_errors = 2532 "muutils/dictmagic.py" = 62 "tests/unit/web/test_bundle_html.py" = 55 "muutils/json_serialize/json_serialize.py" = 52 -"muutils/spinner.py" = 51 "tests/unit/benchmark_parallel/benchmark_parallel.py" = 49 "muutils/json_serialize/serializable_field.py" = 48 "muutils/misc/freezing.py" = 46 @@ -89,6 +83,7 @@ total_errors = 2532 "muutils/tensor_info.py" = 45 "muutils/parallel.py" = 42 "muutils/statcounter.py" = 41 +"muutils/spinner.py" = 39 "muutils/sysinfo.py" = 39 "tests/unit/json_serialize/test_serializable_field.py" = 38 "muutils/web/bundle_html.py" = 37 @@ -149,10 +144,10 @@ total_errors = 2532 "muutils/logger/log_util.py" = 1 [type_errors.ty] -total_errors = 23 +total_errors = 17 [type_errors.ty.by_type] -"invalid-argument-type" = 10 +"invalid-argument-type" = 4 "invalid-assignment" = 4 "unsupported-operator" = 4 "no-matching-overload" = 2 @@ -160,12 +155,11 @@ total_errors = 23 "invalid-type-form" = 1 [type_errors.ty.by_file] -"muutils/spinner.py" = 5 
"tests/unit/json_serialize/test_json_serialize.py" = 4 "muutils/interval.py" = 3 "muutils/misc/func.py" = 3 -"muutils/logger/logger.py" = 2 "muutils/sysinfo.py" = 2 "tests/unit/web/test_bundle_html.py" = 2 +"muutils/logger/logger.py" = 1 "muutils/logger/simplelogger.py" = 1 "tests/unit/validate_type/test_validate_type.py" = 1 diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..50fc9e4f --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,17 @@ +# Changelog + +## [Unreleased] + +### Breaking Changes + +- **`muutils.logger`**: `Logger.log()` and `SimpleLogger.log()` now require keyword arguments for all parameters after `msg`. This change was made to fix type checker compatibility between the two classes. + + **Before:** + ```python + logger.log("message", -10) # lvl as positional arg + ``` + + **After:** + ```python + logger.log("message", lvl=-10) # lvl as keyword arg + ``` diff --git a/muutils/json_serialize/serializable_field.py b/muutils/json_serialize/serializable_field.py index 7313e4da..bbd5da75 100644 --- a/muutils/json_serialize/serializable_field.py +++ b/muutils/json_serialize/serializable_field.py @@ -143,7 +143,7 @@ def from_Field(cls, field: "dataclasses.Field[Any]") -> "SerializableField": @overload def serializable_field( # only `default_factory` is provided - *_args, + *_args: Any, default_factory: Callable[[], Sfield_T], default: dataclasses._MISSING_TYPE = dataclasses.MISSING, init: bool = True, @@ -162,7 +162,7 @@ def serializable_field( # only `default_factory` is provided ) -> Sfield_T: ... @overload def serializable_field( # only `default` is provided - *_args, + *_args: Any, default: Sfield_T, default_factory: dataclasses._MISSING_TYPE = dataclasses.MISSING, init: bool = True, @@ -181,7 +181,7 @@ def serializable_field( # only `default` is provided ) -> Sfield_T: ... @overload def serializable_field( # both `default` and `default_factory` are MISSING - *_args, + *_args: Any, default: dataclasses._MISSING_TYPE = dataclasses.MISSING, default_factory: dataclasses._MISSING_TYPE = dataclasses.MISSING, init: bool = True, @@ -199,7 +199,7 @@ def serializable_field( # both `default` and `default_factory` are MISSING **kwargs: Any, ) -> Any: ... def serializable_field( # general implementation - *_args, + *_args: Any, default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING, default_factory: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING, init: bool = True, diff --git a/muutils/logger/exception_context.py b/muutils/logger/exception_context.py index fd77d95c..11acd7ca 100644 --- a/muutils/logger/exception_context.py +++ b/muutils/logger/exception_context.py @@ -10,7 +10,7 @@ class WritableStream(Protocol): """Protocol for objects that support write operations.""" - def write(self, msg: str) -> int: ... + def write(self, __s: str) -> int: ... class ExceptionContext: diff --git a/muutils/spinner.py b/muutils/spinner.py index 0abf1a1d..13507c4b 100644 --- a/muutils/spinner.py +++ b/muutils/spinner.py @@ -3,12 +3,15 @@ using the base `Spinner` class while some code is running. 
""" +from __future__ import annotations + import os import time from dataclasses import dataclass, field import threading import sys from functools import wraps +from types import TracebackType from typing import ( List, Dict, @@ -66,21 +69,22 @@ def __post_init__(self): @classmethod def from_any(cls, arg: "SpinnerConfigArg") -> "SpinnerConfig": - if isinstance(arg, str): + # check SpinnerConfig first to help type narrowing + if isinstance(arg, SpinnerConfig): + return arg + elif isinstance(arg, str): return SPINNERS[arg] elif isinstance(arg, list): return SpinnerConfig(working=arg) elif isinstance(arg, dict): return SpinnerConfig(**arg) - elif isinstance(arg, SpinnerConfig): - return arg else: raise TypeError( f"to create a SpinnerConfig, you must pass a string (key), list (working seq), dict (kwargs to SpinnerConfig), or SpinnerConfig, but got {type(arg) = }, {arg = }" ) -SpinnerConfigArg = Union[str, List[str], SpinnerConfig, dict] +SpinnerConfigArg = Union[str, List[str], SpinnerConfig, Dict[str, Any]] SPINNERS: Dict[str, SpinnerConfig] = dict( default=SpinnerConfig(working=["|", "/", "-", "\\"], success="#", fail="X"), @@ -262,7 +266,7 @@ def long_running_function(): def __init__( self, # no positional args - *args, + *args: Any, config: SpinnerConfigArg = "default", update_interval: float = 0.1, initial_value: str = "", @@ -410,16 +414,21 @@ def stop(self, failed: bool = False) -> None: self.state = "fail" if failed else "success" -class NoOpContextManager(ContextManager): +class NoOpContextManager(ContextManager): # type: ignore[type-arg] """A context manager that does nothing.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: pass - def __enter__(self): + def __enter__(self) -> NoOpContextManager: return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: pass @@ -439,7 +448,7 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # TODO: type hint that the `update_status` kwarg is not needed when calling the function we just decorated def spinner_decorator( - *args, + *args: Any, # passed to `Spinner.__init__` config: SpinnerConfigArg = "default", update_interval: float = 0.1, @@ -452,7 +461,7 @@ def spinner_decorator( # deprecated spinner_chars: Union[str, Sequence[str], None] = None, spinner_complete: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> Callable[[DecoratedFunction], DecoratedFunction]: """see `Spinner` for parameters. 
Also takes `mutable_kwarg_key` diff --git a/muutils/tensor_info.py b/muutils/tensor_info.py index edd7ad77..023e5203 100644 --- a/muutils/tensor_info.py +++ b/muutils/tensor_info.py @@ -3,7 +3,33 @@ from __future__ import annotations import numpy as np -from typing import Union, Any, Literal, List, Dict, overload, Optional +from typing import Union, Any, Literal, List, Dict, overload, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from typing import TypedDict +else: + try: + from typing import TypedDict + except ImportError: + from typing_extensions import TypedDict + + +class ArraySummarySettings(TypedDict): + """Type definition for array_summary default settings.""" + + fmt: OutputFormat + precision: int + stats: bool + shape: bool + dtype: bool + device: bool + requires_grad: bool + sparkline: bool + sparkline_bins: int + sparkline_logy: Optional[bool] + colored: bool + as_list: bool + eq_char: str # Global color definitions COLORS: Dict[str, Dict[str, str]] = { diff --git a/tests/unit/json_serialize/test_array.py b/tests/unit/json_serialize/test_array.py index 5a3e3b85..2dee92cc 100644 --- a/tests/unit/json_serialize/test_array.py +++ b/tests/unit/json_serialize/test_array.py @@ -132,7 +132,7 @@ def test_serialize_load_zero_dim(self, array_mode: ArrayModeWithMeta): def test_array_shape_dtype_preservation(mode: ArrayModeWithMeta): """Test that various shapes and dtypes are preserved through serialization.""" # Test different shapes - shapes_and_arrays = [ + shapes_and_arrays: list[tuple[np.ndarray, str]] = [ (np.array([1, 2, 3], dtype=np.int32), "1D int32"), (np.array([[1.5, 2.5], [3.5, 4.5]], dtype=np.float32), "2D float32"), (np.array([[[1]], [[2]]], dtype=np.int8), "3D int8"), @@ -140,7 +140,7 @@ def test_array_shape_dtype_preservation(mode: ArrayModeWithMeta): ] # Test different dtypes - dtype_tests = [ + dtype_tests: list[tuple[np.ndarray, type[np.generic]]] = [ (np.array([1, 2, 3], dtype=np.int8), np.int8), (np.array([1, 2, 3], dtype=np.int16), np.int16), (np.array([1, 2, 3], dtype=np.int32), np.int32), @@ -222,11 +222,13 @@ def test_array_edge_cases(mode: ArrayModeWithMeta): jser = JsonSerializer(array_mode="array_list_meta") # Empty arrays with different shapes - empty_1d = np.array([], dtype=np.int32) - empty_2d = np.array([[], []], dtype=np.float32).reshape(2, 0) - empty_3d = np.array([[]], dtype=np.int64).reshape(1, 1, 0) + empty_arrays: list[np.ndarray] = [ + np.array([], dtype=np.int32), + np.array([[], []], dtype=np.float32).reshape(2, 0), + np.array([[]], dtype=np.int64).reshape(1, 1, 0), + ] - for empty_arr in [empty_1d, empty_2d, empty_3d]: + for empty_arr in empty_arrays: serialized = serialize_array(jser, empty_arr, "test", array_mode=mode) loaded = load_array(serialized) assert loaded.shape == empty_arr.shape diff --git a/tests/unit/logger/test_logger.py b/tests/unit/logger/test_logger.py index 15be01dc..2aaa05ac 100644 --- a/tests/unit/logger/test_logger.py +++ b/tests/unit/logger/test_logger.py @@ -13,7 +13,7 @@ def test_logger(): logger.mystream("hello mystream") logger.mystream("hello mystream, again") - logger.log("something is wrong!", -10) - logger.log("something is very wrong!", -30) + logger.log("something is wrong!", lvl=-10) + logger.log("something is very wrong!", lvl=-30) - logger.log("not very important", 50) + logger.log("not very important", lvl=50) From e208c3b55bfedf451e935cab1ba31bea00c3d651 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 19 Jan 2026 15:58:56 -0700 Subject: [PATCH 70/72] more wip typing fixes --- 
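Sanity check for the `dc_eq` docstring rewrite in `muutils/json_serialize/util.py`
below -- a minimal sketch, with hypothetical dataclasses `A`/`B` and every flag
passed explicitly (the flowchart documents the flag semantics, not the defaults):

```python
import dataclasses

from muutils.json_serialize.util import dc_eq


@dataclasses.dataclass
class A:
    x: int


@dataclasses.dataclass
class B:
    x: int


# same class: dc_eq falls through to comparing field values
assert dc_eq(A(1), A(1))
assert not dc_eq(A(1), A(2))

# class mismatch with except_when_class_mismatch=True raises TypeError
try:
    dc_eq(A(1), B(1), except_when_class_mismatch=True)
except TypeError:
    pass

# class mismatch with false_when_class_mismatch=True short-circuits to False
assert not dc_eq(
    A(1),
    B(1),
    except_when_class_mismatch=False,
    false_when_class_mismatch=True,
)
```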
.meta/typing-summary.txt | 300 +++++++++++------- CHANGELOG.md | 5 + TODO.md | 3 +- muutils/dbg.py | 5 +- muutils/dictmagic.py | 6 +- .../dataclass_transform_mock.py | 4 +- .../json_serialize/serializable_dataclass.py | 30 +- muutils/json_serialize/util.py | 64 ++-- muutils/kappa.py | 6 +- muutils/logger/logger.py | 2 +- muutils/logger/simplelogger.py | 2 +- muutils/math/matrix_powers.py | 6 +- muutils/misc/func.py | 8 +- muutils/misc/typing_breakdown.py | 95 +++++- muutils/mlutils.py | 7 +- muutils/nbutils/configure_notebook.py | 2 +- muutils/nbutils/mermaid.py | 2 +- muutils/parallel.py | 12 +- muutils/spinner.py | 7 +- muutils/statcounter.py | 6 +- muutils/sysinfo.py | 2 +- muutils/tensor_info.py | 171 ++++++---- tests/unit/json_serialize/test_array_torch.py | 9 +- .../json_serialize/test_json_serialize.py | 15 +- .../json_serialize/test_serializable_field.py | 8 +- tests/unit/json_serialize/test_util.py | 67 +++- .../unit/validate_type/test_validate_type.py | 2 +- tests/unit/web/test_bundle_html.py | 2 + 28 files changed, 565 insertions(+), 283 deletions(-) diff --git a/.meta/typing-summary.txt b/.meta/typing-summary.txt index a9ac3565..ba02b48e 100644 --- a/.meta/typing-summary.txt +++ b/.meta/typing-summary.txt @@ -1,165 +1,233 @@ # Showing all errors -# mypy: Found 10 errors in 3 files (checked 117 source files) -# basedpyright: 471 errors, 3193 warnings, 0 notes -# ty: Found 17 diagnostics +# mypy: Found 5 errors in 2 files (checked 117 source files) +# basedpyright: 437 errors, 3173 warnings, 0 notes +# ty: Found 9 diagnostics [type_errors.mypy] -total_errors = 10 +total_errors = 5 [type_errors.mypy.by_type] -"operator" = 6 -"arg-type" = 4 +"method-assign" = 3 +"assignment" = 1 +"call-overload" = 1 [type_errors.mypy.by_file] -"tests/unit/json_serialize/test_json_serialize.py" = 6 -"tests/unit/json_serialize/test_serializable_field.py" = 2 -"tests/unit/json_serialize/test_array_torch.py" = 2 +"muutils/json_serialize/serializable_dataclass.py" = 4 +"muutils/dbg.py" = 1 [type_errors.basedpyright] -total_errors = 2520 +total_errors = 436 [type_errors.basedpyright.by_type] -"reportUnknownParameterType" = 385 -"reportAny" = 376 -"reportMissingParameterType" = 341 -"reportUnusedCallResult" = 290 -"reportUnknownVariableType" = 191 -"reportMissingTypeArgument" = 167 -"reportUnknownLambdaType" = 127 -"reportUnknownMemberType" = 125 -"reportUnusedParameter" = 98 -"reportImplicitOverride" = 52 -"reportInvalidTypeForm" = 49 -"reportUnannotatedClassAttribute" = 40 -"reportCallIssue" = 36 +"reportMissingTypeArgument" = 154 +"reportArgumentType" = 63 +"reportInvalidTypeForm" = 48 +"reportCallIssue" = 39 "reportPossiblyUnboundVariable" = 34 -"reportPrivateUsage" = 31 -"reportUnreachable" = 30 -"reportUnnecessaryIsInstance" = 18 -"reportUntypedClassDecorator" = 17 +"reportAttributeAccessIssue" = 26 "reportMissingSuperCall" = 14 -"reportUntypedFunctionDecorator" = 12 -"reportUnnecessaryTypeIgnoreComment" = 11 +"reportAssignmentType" = 11 "reportInvalidTypeArguments" = 9 -"reportUnnecessaryComparison" = 8 -"reportUnusedVariable" = 8 "reportUndefinedVariable" = 7 "reportIndexIssue" = 7 "reportOptionalMemberAccess" = 6 "reportUninitializedInstanceVariable" = 5 +"reportOperatorIssue" = 4 +"reportReturnType" = 4 +"reportMissingImports" = 3 +"reportIncompatibleMethodOverride" = 1 +"reportGeneralTypeIssues" = 1 + +[type_errors.basedpyright.by_file] +"muutils/misc/func.py" = 49 +"tests/unit/misc/test_func.py" = 36 +"tests/unit/test_dbg.py" = 32 +"muutils/nbutils/configure_notebook.py" = 22 
+"tests/unit/misc/test_freeze.py" = 20 +"muutils/dictmagic.py" = 19 +"tests/unit/test_interval.py" = 17 +"muutils/json_serialize/serializable_dataclass.py" = 16 +"muutils/mlutils.py" = 15 +"muutils/statcounter.py" = 14 +"tests/unit/benchmark_parallel/benchmark_parallel.py" = 12 +"tests/unit/test_parallel.py" = 12 +"muutils/sysinfo.py" = 10 +"muutils/parallel.py" = 9 +"tests/unit/json_serialize/test_array.py" = 9 +"tests/unit/validate_type/test_validate_type.py" = 9 +"tests/unit/validate_type/test_validate_type_GENERATED.py" = 9 +"tests/unit/validate_type/test_validate_type_special.py" = 9 +"muutils/misc/freezing.py" = 8 +"muutils/misc/sequence.py" = 8 +"muutils/jsonlines.py" = 7 +"tests/unit/test_dictmagic.py" = 7 +"tests/unit/test_spinner.py" = 7 +"muutils/json_serialize/serializable_field.py" = 6 +"muutils/spinner.py" = 6 +"tests/unit/json_serialize/test_serializable_field.py" = 6 +"tests/unit/web/test_bundle_html.py" = 6 +"muutils/interval.py" = 5 +"muutils/nbutils/convert_ipynb_to_script.py" = 5 +"tests/unit/json_serialize/test_json_serialize.py" = 5 +"muutils/tensor_info.py" = 4 +"tests/unit/json_serialize/test_util.py" = 4 +"tests/unit/errormode/test_errormode_functionality.py" = 3 +"muutils/dbg.py" = 2 +"muutils/nbutils/mermaid.py" = 2 +"muutils/nbutils/run_notebook_tests.py" = 2 +"tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 2 +"tests/unit/math/test_bins.py" = 2 +"tests/unit/misc/test_sequence.py" = 2 +"tests/unit/test_chunks.py" = 2 +"tests/unit/test_mlutils.py" = 2 +"tests/unit/test_tensor_info_torch.py" = 2 +"tests/unit/test_tensor_utils_torch.py" = 2 +"tests/unit/validate_type/test_get_kwargs.py" = 2 +"muutils/collect_warnings.py" = 1 +"muutils/errormode.py" = 1 +"muutils/kappa.py" = 1 +"muutils/web/bundle_html.py" = 1 +"tests/unit/nbutils/test_conversion.py" = 1 +"tests/unit/test_collect_warnings.py" = 1 +"tests/unit/test_timeit_fancy.py" = 1 +"tests/util/test_fire.py" = 1 + +[type_warnings.basedpyright] +total_warnings = 3173 + +[type_warnings.basedpyright.by_type] +"reportAny" = 758 +"reportUnknownParameterType" = 419 +"reportUnknownArgumentType" = 391 +"reportMissingParameterType" = 316 +"reportUnknownVariableType" = 304 +"reportUnusedCallResult" = 288 +"reportUnknownMemberType" = 235 +"reportUnknownLambdaType" = 124 +"reportUnusedParameter" = 98 +"reportImplicitOverride" = 52 +"reportUnannotatedClassAttribute" = 41 +"reportPrivateUsage" = 31 +"reportUnreachable" = 30 +"reportUnnecessaryIsInstance" = 20 +"reportUnnecessaryTypeIgnoreComment" = 14 +"reportUntypedFunctionDecorator" = 12 +"reportUnnecessaryComparison" = 8 +"reportUnusedVariable" = 8 "reportCallInDefaultInitializer" = 4 -"reportUnusedClass" = 4 +"reportPrivateLocalImportUsage" = 3 "reportMissingTypeStubs" = 3 -"reportMissingImports" = 3 +"reportUnusedClass" = 3 "reportUnusedExpression" = 3 "reportImplicitStringConcatenation" = 2 -"reportOperatorIssue" = 2 "reportUntypedNamedTuple" = 2 +"reportInvalidTypeVarUse" = 1 +"reportPrivateImportUsage" = 1 "reportUnusedImport" = 1 "reportUnusedFunction" = 1 -"reportGeneralTypeIssues" = 1 -[type_errors.basedpyright.by_file] -"tests/unit/json_serialize/test_json_serialize.py" = 164 -"tests/unit/test_dbg.py" = 150 -"tests/unit/test_parallel.py" = 131 -"tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 119 -"tests/unit/misc/test_func.py" = 98 -"tests/unit/validate_type/test_get_kwargs.py" = 87 -"tests/unit/validate_type/test_validate_type.py" = 86 -"tests/unit/validate_type/test_validate_type_GENERATED.py" = 86 
-"muutils/json_serialize/serializable_dataclass.py" = 80 -"tests/unit/test_interval.py" = 75 -"muutils/misc/func.py" = 70 +[type_warnings.basedpyright.by_file] +"tests/unit/json_serialize/test_json_serialize.py" = 184 +"tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py" = 179 +"tests/unit/test_parallel.py" = 173 +"tests/unit/benchmark_parallel/benchmark_parallel.py" = 159 +"tests/unit/test_dbg.py" = 154 +"muutils/json_serialize/serializable_dataclass.py" = 100 +"tests/unit/validate_type/test_get_kwargs.py" = 98 +"tests/unit/test_interval.py" = 87 +"tests/unit/misc/test_func.py" = 84 +"tests/unit/validate_type/test_validate_type.py" = 79 +"tests/unit/validate_type/test_validate_type_GENERATED.py" = 79 +"muutils/json_serialize/json_serialize.py" = 74 +"muutils/dictmagic.py" = 68 +"tests/unit/errormode/test_errormode_functionality.py" = 67 "tests/unit/cli/test_arg_bool.py" = 66 -"muutils/dictmagic.py" = 62 -"tests/unit/web/test_bundle_html.py" = 55 -"muutils/json_serialize/json_serialize.py" = 52 -"tests/unit/benchmark_parallel/benchmark_parallel.py" = 49 -"muutils/json_serialize/serializable_field.py" = 48 -"muutils/misc/freezing.py" = 46 -"tests/unit/errormode/test_errormode_init.py" = 46 +"muutils/statcounter.py" = 64 +"tests/unit/json_serialize/test_serializable_field.py" = 63 +"tests/unit/web/test_bundle_html.py" = 62 +"tests/unit/errormode/test_errormode_init.py" = 61 +"tests/unit/test_dictmagic.py" = 58 +"tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 56 +"tests/unit/test_spinner.py" = 54 +"muutils/json_serialize/serializable_field.py" = 53 +"muutils/misc/freezing.py" = 53 +"muutils/cli/command.py" = 52 +"muutils/nbutils/convert_ipynb_to_script.py" = 49 +"muutils/web/bundle_html.py" = 46 "muutils/tensor_info.py" = 45 -"muutils/parallel.py" = 42 -"muutils/statcounter.py" = 41 +"tests/unit/test_kappa.py" = 45 "muutils/spinner.py" = 39 -"muutils/sysinfo.py" = 39 -"tests/unit/json_serialize/test_serializable_field.py" = 38 -"muutils/web/bundle_html.py" = 37 -"tests/unit/test_dictmagic.py" = 37 -"tests/unit/errormode/test_errormode_functionality.py" = 36 -"muutils/nbutils/convert_ipynb_to_script.py" = 34 -"muutils/nbutils/configure_notebook.py" = 33 -"tests/unit/test_spinner.py" = 32 -"tests/unit/misc/test_freeze.py" = 31 -"tests/unit/misc/test_numerical_conversions.py" = 27 -"muutils/misc/sequence.py" = 22 -"muutils/mlutils.py" = 22 -"muutils/validate_type.py" = 20 -"tests/unit/validate_type/test_validate_type_special.py" = 20 -"muutils/math/matrix_powers.py" = 18 -"tests/unit/json_serialize/test_util.py" = 18 -"muutils/interval.py" = 17 -"tests/unit/test_kappa.py" = 17 -"muutils/tensor_utils.py" = 16 -"tests/unit/json_serialize/serializable_dataclass/test_methods_no_override.py" = 15 -"tests/unit/nbutils/test_configure_notebook.py" = 14 -"muutils/jsonlines.py" = 13 -"tests/unit/json_serialize/test_array.py" = 13 -"muutils/nbutils/run_notebook_tests.py" = 12 -"muutils/errormode.py" = 11 -"muutils/logger/logger.py" = 11 -"tests/unit/misc/test_misc.py" = 11 -"tests/unit/test_tensor_info_torch.py" = 11 -"muutils/dbg.py" = 10 -"tests/unit/test_collect_warnings.py" = 10 -"muutils/kappa.py" = 9 -"muutils/misc/classes.py" = 9 -"muutils/cli/arg_bool.py" = 8 -"muutils/json_serialize/dataclass_transform_mock.py" = 8 -"tests/unit/test_console_unicode.py" = 8 -"tests/util/test_fire.py" = 8 -"muutils/collect_warnings.py" = 7 -"muutils/nbutils/mermaid.py" = 7 +"tests/unit/misc/test_numerical_conversions.py" = 37 +"muutils/parallel.py" = 33 
+"muutils/misc/sequence.py" = 31 +"muutils/sysinfo.py" = 31 +"tests/unit/nbutils/test_configure_notebook.py" = 28 +"muutils/dbg.py" = 27 +"tests/unit/misc/test_freeze.py" = 27 +"tests/unit/math/test_matrix_powers_torch.py" = 26 +"tests/unit/test_tensor_info_torch.py" = 26 +"muutils/misc/func.py" = 25 +"muutils/nbutils/configure_notebook.py" = 25 +"muutils/validate_type.py" = 25 +"muutils/math/matrix_powers.py" = 24 +"muutils/tensor_utils.py" = 23 +"tests/unit/json_serialize/test_util.py" = 19 +"muutils/interval.py" = 18 +"muutils/cli/arg_bool.py" = 16 +"tests/unit/misc/test_misc.py" = 16 +"tests/unit/test_mlutils.py" = 16 +"muutils/jsonlines.py" = 14 +"muutils/nbutils/run_notebook_tests.py" = 14 +"tests/unit/misc/test_sequence.py" = 14 +"muutils/errormode.py" = 13 +"muutils/logger/logger.py" = 13 +"muutils/misc/classes.py" = 12 +"tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 11 +"tests/unit/nbutils/test_conversion.py" = 11 +"tests/unit/validate_type/test_validate_type_special.py" = 11 +"tests/unit/json_serialize/test_array.py" = 10 +"tests/unit/test_console_unicode.py" = 10 +"tests/unit/test_collect_warnings.py" = 9 +"tests/util/test_fire.py" = 9 +"tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 8 "tests/unit/cli/test_command.py" = 7 -"tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 7 +"tests/unit/test_jsonlines.py" = 7 +"muutils/collect_warnings.py" = 6 "muutils/logger/headerfuncs.py" = 6 -"tests/unit/nbutils/test_conversion.py" = 6 +"muutils/kappa.py" = 5 +"muutils/logger/log_util.py" = 5 +"muutils/web/html_to_pdf.py" = 5 "tests/conftest.py" = 5 -"tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py" = 5 -"muutils/web/html_to_pdf.py" = 4 -"tests/unit/math/test_matrix_powers_torch.py" = 4 -"tests/unit/misc/test_sequence.py" = 4 -"tests/unit/test_jsonlines.py" = 4 -"tests/unit/test_mlutils.py" = 4 -"muutils/cli/command.py" = 3 -"muutils/json_serialize/array.py" = 3 +"muutils/json_serialize/array.py" = 4 +"muutils/json_serialize/dataclass_transform_mock.py" = 4 +"muutils/misc/__init__.py" = 4 +"tests/unit/json_serialize/serializable_dataclass/test_helpers_torch.py" = 4 +"tests/unit/test_chunks.py" = 4 +"muutils/json_serialize/__init__.py" = 3 +"muutils/mlutils.py" = 3 +"tests/unit/test_sysinfo.py" = 3 "muutils/logger/simplelogger.py" = 2 +"muutils/nbutils/mermaid.py" = 2 "tests/unit/logger/test_log_util.py" = 2 -"tests/unit/test_chunks.py" = 2 +"tests/unit/nbutils/test_configure_notebook_torch.py" = 2 "tests/unit/test_tensor_info.py" = 2 "tests/unit/test_timeit_fancy.py" = 2 "muutils/logger/exception_context.py" = 1 -"muutils/logger/log_util.py" = 1 +"muutils/misc/string.py" = 1 +"tests/unit/test_tensor_utils_torch.py" = 1 [type_errors.ty] -total_errors = 17 +total_errors = 9 [type_errors.ty.by_type] -"invalid-argument-type" = 4 "invalid-assignment" = 4 -"unsupported-operator" = 4 -"no-matching-overload" = 2 -"unresolved-attribute" = 2 -"invalid-type-form" = 1 +"no-matching-overload" = 3 +"invalid-argument-type" = 2 [type_errors.ty.by_file] -"tests/unit/json_serialize/test_json_serialize.py" = 4 "muutils/interval.py" = 3 "muutils/misc/func.py" = 3 -"muutils/sysinfo.py" = 2 -"tests/unit/web/test_bundle_html.py" = 2 +"muutils/dbg.py" = 1 "muutils/logger/logger.py" = 1 "muutils/logger/simplelogger.py" = 1 -"tests/unit/validate_type/test_validate_type.py" = 1 diff --git a/CHANGELOG.md b/CHANGELOG.md index 50fc9e4f..d82349e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ 
## [Unreleased] +### Fixed + +- **`muutils.json_serialize.util.dc_eq`**: Fixed docstring that incorrectly stated `except_when_field_mismatch` defaults to `True` (actual default is `False`), and that it raises `TypeError` (it actually raises `AttributeError`) +- **`muutils.json_serialize.util.dc_eq`**: Updated flowchart in docstring to accurately reflect the control flow, including the missing `false_when_class_mismatch` decision branch + ### Breaking Changes - **`muutils.logger`**: `Logger.log()` and `SimpleLogger.log()` now require keyword arguments for all parameters after `msg`. This change was made to fix type checker compatibility between the two classes. diff --git a/TODO.md b/TODO.md index 67eb4c4d..2ded0655 100644 --- a/TODO.md +++ b/TODO.md @@ -24,4 +24,5 @@ # Guidelines: - make sure all type hints are python>=3.8 compatible -- always err on the side of STRICTER type hints! \ No newline at end of file +- always err on the side of STRICTER type hints! +- try to avoid breaking changes. check with the user before making breaking changes. if breaking changes are necessary, and the user agrees, make sure to document them properly and add them to CHANGELOG.md \ No newline at end of file diff --git a/muutils/dbg.py b/muutils/dbg.py index 89363398..a63b05da 100644 --- a/muutils/dbg.py +++ b/muutils/dbg.py @@ -59,7 +59,7 @@ class DBGListDefaultsType(typing.TypedDict): class DBGTensorArraySummaryDefaultsType(typing.TypedDict): - fmt: str + fmt: typing.Literal["unicode", "latex", "ascii"] precision: int stats: bool shape: bool @@ -239,7 +239,8 @@ def square(x: int) -> int: def tensor_info(tensor: typing.Any) -> str: from muutils.tensor_info import array_summary - return array_summary(tensor, as_list=False, **DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS) + # mypy can't match overloads with **TypedDict spread + return array_summary(tensor, as_list=False, **DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS) # type: ignore[call-overload] DBG_DICT_DEFAULTS: DBGDictDefaultsType = { diff --git a/muutils/dictmagic.py b/muutils/dictmagic.py index 595d8707..29e85a04 100644 --- a/muutils/dictmagic.py +++ b/muutils/dictmagic.py @@ -33,7 +33,7 @@ class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]): """like a defaultdict, but default_factory is passed the key as an argument""" - def __init__(self, default_factory: Callable[[_KT], _VT], *args, **kwargs): + def __init__(self, default_factory: Callable[[_KT], _VT], *args: Any, **kwargs: Any) -> None: if args: raise TypeError( f"DefaulterDict does not support positional arguments: *args = {args}" @@ -380,8 +380,8 @@ def _default_shapes_convert(x: tuple) -> str: def condense_tensor_dict( data: TensorDict | TensorIterable, fmt: TensorDictFormats = "dict", - *args, - shapes_convert: Callable[[tuple], Any] = _default_shapes_convert, + *args: Any, + shapes_convert: Callable[[tuple[Union[int, str], ...]], Any] = _default_shapes_convert, drop_batch_dims: int = 0, sep: str = ".", dims_names_map: Optional[dict[int, str]] = None, diff --git a/muutils/json_serialize/dataclass_transform_mock.py b/muutils/json_serialize/dataclass_transform_mock.py index 299c678f..b9020f38 100644 --- a/muutils/json_serialize/dataclass_transform_mock.py +++ b/muutils/json_serialize/dataclass_transform_mock.py @@ -12,10 +12,10 @@ def dataclass_transform( frozen_default: bool = False, field_specifiers: tuple[Union[type[Any], typing.Callable[..., Any]], ...] 
= (), **kwargs: Any, -) -> typing.Callable: +) -> typing.Callable[[Any], Any]: "mock `typing.dataclass_transform` for python <3.11" - def decorator(cls_or_fn): + def decorator(cls_or_fn: Any) -> Any: cls_or_fn.__dataclass_transform__ = { "eq_default": eq_default, "order_default": order_default, diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 44181132..3ad61692 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -92,7 +92,7 @@ class NestedClass(SerializableDataclass): Self = TypeVar("Self") -T = TypeVar("T") +T_SerializeableDataclass = TypeVar("T_SerializeableDataclass", bound="SerializableDataclass") class CantGetTypeHintsWarning(UserWarning): @@ -111,7 +111,7 @@ class ZanjMissingWarning(UserWarning): "flag to keep track of if we have successfully imported ZANJ" -def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]): +def zanj_register_loader_serializable_dataclass(cls: typing.Type[T_SerializeableDataclass]): """Register a serializable dataclass with the ZANJ import this allows `ZANJ().read()` to load the class and not just return plain dicts @@ -516,12 +516,12 @@ def __deepcopy__(self, memo: dict) -> "SerializableDataclass": # cache this so we don't have to keep getting it # TODO: are the types hashable? does this even make sense? @functools.lru_cache(typed=True) -def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]: +def get_cls_type_hints_cached(cls: Type[T_SerializeableDataclass]) -> dict[str, Any]: "cached typing.get_type_hints for a class" return typing.get_type_hints(cls) -def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]: +def get_cls_type_hints(cls: Type[T_SerializeableDataclass]) -> dict[str, Any]: "helper function to get type hints for a class" cls_type_hints: dict[str, Any] try: @@ -596,8 +596,8 @@ def serializable_dataclass( on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH, methods_no_override: list[str] | None = None, - **kwargs, -): + **kwargs: Any, +) -> Any: """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!** types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE` @@ -691,7 +691,7 @@ class to decorate. 
don't pass this arg, just use this as a decorator else: _properties_to_serialize = properties_to_serialize - def wrap(cls: Type[T]) -> Type[T]: + def wrap(cls: Type[T_SerializeableDataclass]) -> Type[T_SerializeableDataclass]: # Modify the __annotations__ dictionary to replace regular fields with SerializableField for field_name, field_type in cls.__annotations__.items(): field_value = getattr(cls, field_name, None) @@ -733,7 +733,7 @@ def wrap(cls: Type[T]) -> Type[T]: # define `serialize` func # done locally since it depends on args to the decorator # ====================================================================== - def serialize(self) -> dict[str, Any]: + def serialize(self: Any) -> dict[str, Any]: result: dict[str, Any] = { _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)" } @@ -797,7 +797,7 @@ def serialize(self) -> dict[str, Any]: # ====================================================================== # mypy thinks this isnt a classmethod @classmethod # type: ignore[misc] - def load(cls, data: dict[str, Any] | T) -> T: + def load(cls: type[T_SerializeableDataclass], data: dict[str, Any] | T_SerializeableDataclass) -> T_SerializeableDataclass: # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ if isinstance(data, cls): return data @@ -812,7 +812,9 @@ def load(cls, data: dict[str, Any] | T) -> T: ctor_kwargs: dict[str, Any] = dict() # iterate over the fields of the class - for field in dataclasses.fields(cls): + # mypy doesn't recognize @dataclass_transform for dataclasses.fields() + # https://github.com/python/mypy/issues/16241 + for field in dataclasses.fields(cls): # type: ignore[arg-type] # check if the field is a SerializableField assert isinstance(field, SerializableField), ( f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. 
this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" @@ -853,7 +855,7 @@ def load(cls, data: dict[str, Any] | T) -> T: ctor_kwargs[field.name] = value # create a new instance of the class with the constructor kwargs - output: cls = cls(**ctor_kwargs) + output: T_SerializeableDataclass = cls(**ctor_kwargs) # validate the types of the fields if needed if on_typecheck_mismatch != ErrorMode.IGNORE: @@ -903,14 +905,14 @@ def load(cls, data: dict[str, Any] | T) -> T: # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments if "serialize" not in _methods_no_override: # type is `Callable[[T], dict]` - cls.serialize = serialize # type: ignore[attr-defined] + cls.serialize = serialize # type: ignore[attr-defined, method-assign] if "load" not in _methods_no_override: # type is `Callable[[dict], T]` - cls.load = load # type: ignore[attr-defined] + cls.load = load # type: ignore[attr-defined, method-assign, assignment] if "validate_field_type" not in _methods_no_override: # type is `Callable[[T, ErrorMode], bool]` - cls.validate_fields_types = SerializableDataclass__validate_fields_types # type: ignore[attr-defined] + cls.validate_fields_types = SerializableDataclass__validate_fields_types # type: ignore[attr-defined, method-assign] if "__eq__" not in _methods_no_override: # type is `Callable[[T, T], bool]` diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index 255e8741..b8706993 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -228,8 +228,8 @@ def dc_eq( if `False`, will attempt to compare the fields. - `except_when_field_mismatch: bool` only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`. - if `True`, will throw `TypeError` if the fields are different. - (default: `True`) + if `True`, will throw `AttributeError` if the fields are different. + (default: `False`) # Returns: - `bool`: True if the dataclasses are equal, False otherwise @@ -238,34 +238,40 @@ def dc_eq( - `TypeError`: if the dataclasses are of different classes - `AttributeError`: if the dataclasses have different fields - # TODO: after "except when class mismatch" is False, shouldn't we then go to "field keys match"? ``` - [START] - ▼ - ┌───────────┐ ┌─────────┐ - │dc1 is dc2?├─►│ classes │ - └──┬────────┘No│ match? │ - ──── │ ├─────────┤ - (True)◄──┘Yes │No │Yes - ──── ▼ ▼ - ┌────────────────┐ ┌────────────┐ - │ except when │ │ fields keys│ - │ class mismatch?│ │ match? │ - ├───────────┬────┘ ├───────┬────┘ - │Yes │No │No │Yes - ▼ ▼ ▼ ▼ - ─────────── ┌──────────┐ ┌────────┐ - { raise } │ except │ │ field │ - { TypeError } │ when │ │ values │ - ─────────── │ field │ │ match? │ - │ mismatch?│ ├────┬───┘ - ├───────┬──┘ │ │Yes - │Yes │No │No ▼ - ▼ ▼ │ ──── - ─────────────── ───── │ (True) - { raise } (False)◄┘ ──── - { AttributeError} ───── - ─────────────── + [START] + ▼ + ┌─────────────┐ + │ dc1 is dc2? │───Yes───► (True) + └──────┬──────┘ + │No + ▼ + ┌───────────────┐ + │ classes match?│───Yes───► [compare field values] ───► (True/False) + └──────┬────────┘ + │No + ▼ + ┌────────────────────────────┐ + │ except_when_class_mismatch?│───Yes───► { raise TypeError } + └─────────────┬──────────────┘ + │No + ▼ + ┌────────────────────────────┐ + │ false_when_class_mismatch? 
│───Yes───► (False) + └─────────────┬──────────────┘ + │No + ▼ + ┌────────────────────────────┐ + │ except_when_field_mismatch?│───No────► [compare field values] + └─────────────┬──────────────┘ + │Yes + ▼ + ┌───────────────┐ + │ fields match? │───Yes───► [compare field values] + └──────┬────────┘ + │No + ▼ + { raise AttributeError } ``` """ diff --git a/muutils/kappa.py b/muutils/kappa.py index e964b4ff..32970669 100644 --- a/muutils/kappa.py +++ b/muutils/kappa.py @@ -34,15 +34,15 @@ def __init__(self, func_getitem: Callable[[_kappa_K], _kappa_V]) -> None: ) ) - def __getitem__(self, x) -> _kappa_V: + def __getitem__(self, x: _kappa_K) -> _kappa_V: return self.func_getitem(x) - def __iter__(self): + def __iter__(self) -> None: # type: ignore[override] raise NotImplementedError( "This method is not implemented for Kappa, we don't know the valid inputs" ) - def __len__(self): + def __len__(self) -> int: raise NotImplementedError( "This method is not implemented for Kappa, no idea how many valid inputs there are" ) diff --git a/muutils/logger/logger.py b/muutils/logger/logger.py index 7ef6054d..dbbc74a3 100644 --- a/muutils/logger/logger.py +++ b/muutils/logger/logger.py @@ -228,7 +228,7 @@ def log( if not isinstance(msg, typing.Mapping): msg_dict = {"_msg": msg} else: - msg_dict = dict(msg) + msg_dict = dict(typing.cast(typing.Mapping[str, Any], msg)) # level+stream metadata if lvl is not None: diff --git a/muutils/logger/simplelogger.py b/muutils/logger/simplelogger.py index 38e19a77..b9d6f77f 100644 --- a/muutils/logger/simplelogger.py +++ b/muutils/logger/simplelogger.py @@ -73,7 +73,7 @@ def log(self, msg: JSONitem, *, console_print: bool = False, **kwargs: Any) -> N if not isinstance(msg, typing.Mapping): msg_dict = {"_msg": msg} else: - msg_dict = dict(msg) + msg_dict = dict(typing.cast(typing.Mapping[str, Any], msg)) if self._timestamp: msg_dict["_timestamp"] = time.time() diff --git a/muutils/math/matrix_powers.py b/muutils/math/matrix_powers.py index c5e6bf48..c3db19ac 100644 --- a/muutils/math/matrix_powers.py +++ b/muutils/math/matrix_powers.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Sequence, TYPE_CHECKING +from typing import Any, List, Sequence, TYPE_CHECKING import numpy as np from jaxtyping import Float, Int @@ -83,9 +83,9 @@ def matrix_powers( # BUG: breaks with integer matrices??? # TYPING: jaxtyping hints not working here, separate file for torch implementation? def matrix_powers_torch( - A, # : Float["torch.Tensor", "n n"], + A: Any, # : Float["torch.Tensor", "n n"], powers: Sequence[int], -): # Float["torch.Tensor", "n_powers n n"]: +) -> Any: # Float["torch.Tensor", "n_powers n n"]: """Compute multiple powers of a matrix efficiently. 
Uses binary exponentiation to compute powers in O(log max(powers)) diff --git a/muutils/misc/func.py b/muutils/misc/func.py index 5db91583..2bff7bb8 100644 --- a/muutils/misc/func.py +++ b/muutils/misc/func.py @@ -119,7 +119,7 @@ def decorator( func: Callable[FuncParams, ReturnType], ) -> Callable[FuncParams, ReturnType]: @functools.wraps(func) - def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType: + def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType: # pyright: ignore[reportUnknownParameterType] if kwarg_name in kwargs: value: Any = kwargs[kwarg_name] if not validator(value): @@ -179,7 +179,7 @@ def decorator( func: Callable[FuncParams, ReturnType], ) -> Callable[FuncParams, ReturnType]: @functools.wraps(func) - def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType: + def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType: # pyright: ignore[reportUnknownParameterType] if kwarg_name in kwargs: # TODO: no way to type hint this, I think if check(kwargs[kwarg_name]): # type: ignore[arg-type] @@ -225,7 +225,7 @@ def decorator( LambdaArgsTypes = TypeVar("LambdaArgsTypes", bound=Tuple[type, ...]) -def typed_lambda( +def typed_lambda( # pyright: ignore[reportUnknownParameterType] fn: Callable[[Unpack[LambdaArgs]], ReturnType], in_types: LambdaArgsTypes, out_type: type[ReturnType], @@ -271,7 +271,7 @@ def typed_lambda( annotations["return"] = out_type @functools.wraps(fn) - def wrapped(*args: Unpack[LambdaArgs]) -> ReturnType: + def wrapped(*args: Unpack[LambdaArgs]) -> ReturnType: # pyright: ignore[reportUnknownParameterType] return fn(*args) wrapped.__annotations__ = annotations diff --git a/muutils/misc/typing_breakdown.py b/muutils/misc/typing_breakdown.py index cf363e0d..a5cee463 100644 --- a/muutils/misc/typing_breakdown.py +++ b/muutils/misc/typing_breakdown.py @@ -52,6 +52,9 @@ class TypeCheckResult: type_checker: Literal["mypy", "basedpyright", "ty"] by_type: Dict[str, int] = field(default_factory=lambda: defaultdict(int)) by_file: Dict[str, int] = field(default_factory=lambda: defaultdict(int)) + # Separate tracking for warnings (used by basedpyright) + warnings_by_type: Dict[str, int] = field(default_factory=lambda: defaultdict(int)) + warnings_by_file: Dict[str, int] = field(default_factory=lambda: defaultdict(int)) @property def total_errors(self) -> int: @@ -78,19 +81,45 @@ def filter_by(self, top_n: int | None) -> TypeCheckResult: key=lambda x: x[1], reverse=True, ) + sorted_warnings_by_type: List[Tuple[str, int]] = sorted( + self.warnings_by_type.items(), + key=lambda x: x[1], + reverse=True, + ) + sorted_warnings_by_file: List[Tuple[str, int]] = sorted( + self.warnings_by_file.items(), + key=lambda x: x[1], + reverse=True, + ) # Apply top_n limit if specified if top_n is not None: sorted_by_type = sorted_by_type[:top_n] sorted_by_file = sorted_by_file[:top_n] + sorted_warnings_by_type = sorted_warnings_by_type[:top_n] + sorted_warnings_by_file = sorted_warnings_by_file[:top_n] # Create new instance with filtered data (dicts maintain insertion order in Python 3.7+) result: TypeCheckResult = TypeCheckResult(type_checker=self.type_checker) result.by_type = dict(sorted_by_type) result.by_file = dict(sorted_by_file) + result.warnings_by_type = dict(sorted_warnings_by_type) + result.warnings_by_file = dict(sorted_warnings_by_file) return result + @property + def total_warnings(self) -> int: + "total number of warnings across all types" + total_by_type: int = 
sum(self.warnings_by_type.values()) + total_by_file: int = sum(self.warnings_by_file.values()) + + if total_by_type != total_by_file: + err_msg: str = f"Warning count mismatch for {self.type_checker}: by_type={total_by_type}, by_file={total_by_file}" + raise ValueError(err_msg) + + return total_by_type + def to_toml(self) -> str: "format as TOML-like output" lines: List[str] = [] @@ -121,6 +150,30 @@ def to_toml(self) -> str: # Always quote file paths lines.append(f'"{file_path}" = {count}') + # Add warnings sections if there are any warnings + if self.warnings_by_type or self.warnings_by_file: + lines.append("") + lines.append(f"[type_warnings.{self.type_checker}]") + try: + lines.append(f"total_warnings = {self.total_warnings}") + except ValueError: + lines.append(f"total_warnings_by_type = {sum(self.warnings_by_type.values())}") + lines.append(f"total_warnings_by_file = {sum(self.warnings_by_file.values())}") + lines.append("") + + # warnings by_type section + lines.append(f"[type_warnings.{self.type_checker}.by_type]") + warning_type: str + for warning_type, count in self.warnings_by_type.items(): + lines.append(f'"{warning_type}" = {count}') + + lines.append("") + + # warnings by_file section + lines.append(f"[type_warnings.{self.type_checker}.by_file]") + for file_path, count in self.warnings_by_file.items(): + lines.append(f'"{file_path}" = {count}') + return "\n".join(lines) @@ -147,25 +200,53 @@ def parse_basedpyright(content: str) -> TypeCheckResult: # Pattern for file paths (lines that start with /) # Pattern for errors: indented line with - error/warning: message (code) + # Some diagnostics span multiple lines with (reportCode) on a continuation line current_file: str = "" + pending_diagnostic_type: str | None = None # "error" or "warning" waiting for code line: str for line in content.splitlines(): - # Check if this is a file path line + # Check if this is a file path line (starts with / and no leading space) if line and not line.startswith(" ") and line.startswith("/"): current_file = strip_cwd(line.strip()) - # Check if this is an error/warning line + pending_diagnostic_type = None + elif line.strip() and current_file: - # Match pattern like: " path:line:col - warning: message (reportCode)" + # Try to match single-line format: " path:line:col - warning: message (reportCode)" match: re.Match[str] | None = re.search( r"\s+.+:\d+:\d+ - (error|warning): .+ \((\w+)\)", line ) if match: - # TODO: handle warnings vs errors - _error_type: str = match.group(1) + diagnostic_type: str = match.group(1) error_code: str = match.group(2) - result.by_type[error_code] += 1 - result.by_file[current_file] += 1 + if diagnostic_type == "warning": + result.warnings_by_type[error_code] += 1 + result.warnings_by_file[current_file] += 1 + else: + result.by_type[error_code] += 1 + result.by_file[current_file] += 1 + pending_diagnostic_type = None + else: + # Check if this is a diagnostic line without code (multi-line format start) + diag_match: re.Match[str] | None = re.search( + r"\s+.+:\d+:\d+ - (error|warning): ", line + ) + if diag_match: + pending_diagnostic_type = diag_match.group(1) + # Check if this is a continuation line with the code + elif pending_diagnostic_type: + code_match: re.Match[str] | None = re.search( + r"\((\w+)\)\s*$", line + ) + if code_match: + error_code = code_match.group(1) + if pending_diagnostic_type == "warning": + result.warnings_by_type[error_code] += 1 + result.warnings_by_file[current_file] += 1 + else: + result.by_type[error_code] += 1 + 
result.by_file[current_file] += 1 + pending_diagnostic_type = None return result diff --git a/muutils/mlutils.py b/muutils/mlutils.py index 628d0e94..7fd60988 100644 --- a/muutils/mlutils.py +++ b/muutils/mlutils.py @@ -9,7 +9,7 @@ import warnings from itertools import islice from pathlib import Path -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Callable, Generator, Iterable, Optional, TypeVar, Union ARRAY_IMPORTS: bool try: @@ -97,7 +97,10 @@ def set_reproducibility(seed: int = DEFAULT_SEED): os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" -def chunks(it, chunk_size): +T = TypeVar("T") + + +def chunks(it: Iterable[T], chunk_size: int) -> Generator[list[T], Any, None]: """Yield successive chunks from an iterator.""" # https://stackoverflow.com/a/61435714 iterator = iter(it) diff --git a/muutils/nbutils/configure_notebook.py b/muutils/nbutils/configure_notebook.py index 63f080be..4f207bf6 100644 --- a/muutils/nbutils/configure_notebook.py +++ b/muutils/nbutils/configure_notebook.py @@ -190,7 +190,7 @@ def setup_plots( def configure_notebook( - *args, + *args: typing.Any, seed: int = 42, device: typing.Any = None, # this can be a string, torch.device, or None dark_mode: bool = True, diff --git a/muutils/nbutils/mermaid.py b/muutils/nbutils/mermaid.py index 98c5ced6..392dfa57 100644 --- a/muutils/nbutils/mermaid.py +++ b/muutils/nbutils/mermaid.py @@ -12,7 +12,7 @@ ) -def mm(graph): +def mm(graph: str) -> None: """for plotting mermaid.js diagrams""" graphbytes = graph.encode("ascii") base64_bytes = base64.b64encode(graphbytes) diff --git a/muutils/parallel.py b/muutils/parallel.py index f5189a15..faddc330 100644 --- a/muutils/parallel.py +++ b/muutils/parallel.py @@ -31,7 +31,7 @@ class ProgressBarFunction(Protocol): "a protocol for a progress bar function" - def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable: ... + def __call__(self, iterable: Iterable[Any], **kwargs: Any) -> Iterable[Any]: ... ProgressBarOption = Literal["tqdm", "spinner", "none", None] @@ -52,7 +52,7 @@ def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable: ... DEFAULT_PBAR_FN = "spinner" -def spinner_fn_wrap(x: Iterable, **kwargs) -> List: +def spinner_fn_wrap(x: Iterable[Any], **kwargs: Any) -> List[Any]: "spinner wrapper" spinnercontext_allowed_kwargs: set[str] = get_fn_allowed_kwargs( SpinnerContext.__init__ @@ -72,7 +72,7 @@ def spinner_fn_wrap(x: Iterable, **kwargs) -> List: return output -def map_kwargs_for_tqdm(kwargs: dict) -> dict: +def map_kwargs_for_tqdm(kwargs: Dict[str, Any]) -> Dict[str, Any]: "map kwargs for tqdm, cant wrap because the pbar dissapears?" 
tqdm_allowed_kwargs: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__) mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in tqdm_allowed_kwargs} @@ -86,7 +86,7 @@ def map_kwargs_for_tqdm(kwargs: dict) -> dict: return mapped_kwargs -def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable: +def no_progress_fn_wrap(x: Iterable[Any], **kwargs: Any) -> Iterable[Any]: "fallback to no progress bar" return x @@ -94,8 +94,8 @@ def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable: def set_up_progress_bar_fn( pbar: Union[ProgressBarFunction, ProgressBarOption], pbar_kwargs: Optional[Dict[str, Any]] = None, - **extra_kwargs, -) -> Tuple[ProgressBarFunction, dict]: + **extra_kwargs: Any, +) -> Tuple[ProgressBarFunction, Dict[str, Any]]: """set up the progress bar function and its kwargs # Parameters: diff --git a/muutils/spinner.py b/muutils/spinner.py index 13507c4b..81c96a31 100644 --- a/muutils/spinner.py +++ b/muutils/spinner.py @@ -439,7 +439,12 @@ def __enter__(self) -> "SpinnerContext": self.start() return self - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: self.stop(failed=exc_type is not None) diff --git a/muutils/statcounter.py b/muutils/statcounter.py index 60853a7f..4054e3c7 100644 --- a/muutils/statcounter.py +++ b/muutils/statcounter.py @@ -9,7 +9,7 @@ from collections import Counter from functools import cached_property from itertools import chain -from typing import Callable, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence, Union # _GeneralArray = Union[np.ndarray, "torch.Tensor"] @@ -224,8 +224,8 @@ def load(cls, data: dict) -> "StatCounter": @classmethod def from_list_arrays( cls, - arr, - map_func: Callable = float, + arr: Any, + map_func: Callable[[Any], float] = float, ) -> "StatCounter": """calls `map_func` on each element of `universal_flatten(arr)`""" return cls([map_func(x) for x in universal_flatten(arr)]) diff --git a/muutils/sysinfo.py b/muutils/sysinfo.py index 9aa0c16c..c213af19 100644 --- a/muutils/sysinfo.py +++ b/muutils/sysinfo.py @@ -9,7 +9,7 @@ def _popen(cmd: list[str], split_out: bool = False) -> dict[str, typing.Any]: - p: subprocess.Popen = subprocess.Popen( + p: subprocess.Popen[bytes] = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) diff --git a/muutils/tensor_info.py b/muutils/tensor_info.py index 023e5203..0d25e7b7 100644 --- a/muutils/tensor_info.py +++ b/muutils/tensor_info.py @@ -13,24 +13,6 @@ except ImportError: from typing_extensions import TypedDict - -class ArraySummarySettings(TypedDict): - """Type definition for array_summary default settings.""" - - fmt: OutputFormat - precision: int - stats: bool - shape: bool - dtype: bool - device: bool - requires_grad: bool - sparkline: bool - sparkline_bins: int - sparkline_logy: Optional[bool] - colored: bool - as_list: bool - eq_char: str - # Global color definitions COLORS: Dict[str, Dict[str, str]] = { "latex": { @@ -93,6 +75,25 @@ class ArraySummarySettings(TypedDict): OutputFormat = Literal["unicode", "latex", "ascii"] + +class ArraySummarySettings(TypedDict): + """Type definition for array_summary default settings.""" + + fmt: OutputFormat + precision: int + stats: bool + shape: bool + dtype: bool + device: bool + requires_grad: bool + sparkline: bool + sparkline_bins: int + sparkline_logy: Optional[bool] + colored: bool + as_list: bool + eq_char: str + + 
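Note on the hunk above: it relocates `ArraySummarySettings` so that it sits below the `OutputFormat` alias. A `TypedDict` body evaluates its annotation expressions when the class is created, so `fmt: OutputFormat` needs the alias to already be bound. A minimal sketch of the ordering constraint, with a hypothetical `SummarySettings` standing in for the real class:

```
from typing import Literal, Optional, TypedDict

# The alias must be defined first: without `from __future__ import
# annotations`, the class body below evaluates `OutputFormat` eagerly,
# and an unbound name would raise NameError at import time.
OutputFormat = Literal["unicode", "latex", "ascii"]


class SummarySettings(TypedDict):
    fmt: OutputFormat
    precision: int
    sparkline_logy: Optional[bool]


settings: SummarySettings = {
    "fmt": "unicode",
    "precision": 2,
    "sparkline_logy": None,
}
print(settings["fmt"])  # unicode
```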
SYMBOLS: Dict[OutputFormat, Dict[str, str]] = { "latex": { "range": r"\mathcal{R}", @@ -367,21 +368,21 @@ def generate_sparkline( return spark, log_y -DEFAULT_SETTINGS: Dict[str, Any] = dict( - fmt="unicode", - precision=2, - stats=True, - shape=True, - dtype=True, - device=True, - requires_grad=True, - sparkline=False, - sparkline_bins=5, - sparkline_logy=None, - colored=False, - as_list=False, - eq_char="=", -) +DEFAULT_SETTINGS: ArraySummarySettings = { + "fmt": "unicode", + "precision": 2, + "stats": True, + "shape": True, + "dtype": True, + "device": True, + "requires_grad": True, + "sparkline": False, + "sparkline_bins": 5, + "sparkline_logy": None, + "colored": False, + "as_list": False, + "eq_char": "=", +} def apply_color( @@ -424,7 +425,7 @@ def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> s return type_colored -def format_shape_colored(shape_val, colors: Dict[str, str], using_tex: bool) -> str: +def format_shape_colored(shape_val: Any, colors: Dict[str, str], using_tex: bool) -> str: """Format shape with proper coloring for both 1D and multi-D arrays.""" def apply_color(text: str, color_key: str) -> str: @@ -476,30 +477,70 @@ class _UseDefaultType: @overload def array_summary( array: Any, + fmt: OutputFormat = ..., + precision: int = ..., + stats: bool = ..., + shape: bool = ..., + dtype: bool = ..., + device: bool = ..., + requires_grad: bool = ..., + sparkline: bool = ..., + sparkline_bins: int = ..., + sparkline_logy: Optional[bool] = ..., + colored: bool = ..., + eq_char: str = ..., + *, as_list: Literal[True], - **kwargs, ) -> List[str]: ... @overload def array_summary( array: Any, - as_list: Literal[False], - **kwargs, + fmt: OutputFormat = ..., + precision: int = ..., + stats: bool = ..., + shape: bool = ..., + dtype: bool = ..., + device: bool = ..., + requires_grad: bool = ..., + sparkline: bool = ..., + sparkline_bins: int = ..., + sparkline_logy: Optional[bool] = ..., + colored: bool = ..., + eq_char: str = ..., + as_list: Literal[False] = ..., ) -> str: ... -def array_summary( # type: ignore[misc] - array, - fmt: OutputFormat = _USE_DEFAULT, # type: ignore[assignment] - precision: int = _USE_DEFAULT, # type: ignore[assignment] - stats: bool = _USE_DEFAULT, # type: ignore[assignment] - shape: bool = _USE_DEFAULT, # type: ignore[assignment] - dtype: bool = _USE_DEFAULT, # type: ignore[assignment] - device: bool = _USE_DEFAULT, # type: ignore[assignment] - requires_grad: bool = _USE_DEFAULT, # type: ignore[assignment] - sparkline: bool = _USE_DEFAULT, # type: ignore[assignment] - sparkline_bins: int = _USE_DEFAULT, # type: ignore[assignment] - sparkline_logy: Optional[bool] = _USE_DEFAULT, # type: ignore[assignment] - colored: bool = _USE_DEFAULT, # type: ignore[assignment] - eq_char: str = _USE_DEFAULT, # type: ignore[assignment] - as_list: bool = _USE_DEFAULT, # type: ignore[assignment] +@overload +def array_summary( + array: Any, + fmt: OutputFormat = ..., + precision: int = ..., + stats: bool = ..., + shape: bool = ..., + dtype: bool = ..., + device: bool = ..., + requires_grad: bool = ..., + sparkline: bool = ..., + sparkline_bins: int = ..., + sparkline_logy: Optional[bool] = ..., + colored: bool = ..., + eq_char: str = ..., + as_list: bool = ..., +) -> Union[str, List[str]]: ... 
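Note on this hunk: the overloads above advertise plain `OutputFormat`, `int`, `bool`, and so on, while the implementation that follows accepts `Union[X, _UseDefaultType]` and narrows with `isinstance(..., _UseDefaultType)` rather than `is _USE_DEFAULT`. `isinstance` is a narrowing construct type checkers understand, whereas an identity test against a module-level sentinel instance generally leaves the union un-narrowed. A minimal sketch of the sentinel pattern, assuming a single hypothetical `precision` knob:

```
from typing import Union


class _UseDefaultType:
    """A dedicated sentinel class, so checkers can narrow via isinstance()."""


_USE_DEFAULT = _UseDefaultType()
DEFAULT_PRECISION: int = 2


def summarize(precision: Union[int, _UseDefaultType] = _USE_DEFAULT) -> str:
    # After this check, `precision` is narrowed to `int` on the fall-through
    # path; `precision is _USE_DEFAULT` would not narrow the Union.
    if isinstance(precision, _UseDefaultType):
        precision = DEFAULT_PRECISION
    return f"precision={precision}"


print(summarize())   # precision=2
print(summarize(4))  # precision=4
```

The payoff is that callers see honest public types while the implementation can still distinguish "argument omitted" from every real value, including falsy ones like `0` or `None`.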
+def array_summary( + array: Any, + fmt: Union[OutputFormat, _UseDefaultType] = _USE_DEFAULT, + precision: Union[int, _UseDefaultType] = _USE_DEFAULT, + stats: Union[bool, _UseDefaultType] = _USE_DEFAULT, + shape: Union[bool, _UseDefaultType] = _USE_DEFAULT, + dtype: Union[bool, _UseDefaultType] = _USE_DEFAULT, + device: Union[bool, _UseDefaultType] = _USE_DEFAULT, + requires_grad: Union[bool, _UseDefaultType] = _USE_DEFAULT, + sparkline: Union[bool, _UseDefaultType] = _USE_DEFAULT, + sparkline_bins: Union[int, _UseDefaultType] = _USE_DEFAULT, + sparkline_logy: Union[Optional[bool], _UseDefaultType] = _USE_DEFAULT, + colored: Union[bool, _UseDefaultType] = _USE_DEFAULT, + eq_char: Union[str, _UseDefaultType] = _USE_DEFAULT, + as_list: Union[bool, _UseDefaultType] = _USE_DEFAULT, ) -> Union[str, List[str]]: """Format array information into a readable summary. @@ -535,31 +576,31 @@ def array_summary( # type: ignore[misc] - `Union[str, List[str]]` Formatted statistical summary, either as string or list of strings """ - if fmt is _USE_DEFAULT: + if isinstance(fmt, _UseDefaultType): fmt = DEFAULT_SETTINGS["fmt"] - if precision is _USE_DEFAULT: + if isinstance(precision, _UseDefaultType): precision = DEFAULT_SETTINGS["precision"] - if stats is _USE_DEFAULT: + if isinstance(stats, _UseDefaultType): stats = DEFAULT_SETTINGS["stats"] - if shape is _USE_DEFAULT: + if isinstance(shape, _UseDefaultType): shape = DEFAULT_SETTINGS["shape"] - if dtype is _USE_DEFAULT: + if isinstance(dtype, _UseDefaultType): dtype = DEFAULT_SETTINGS["dtype"] - if device is _USE_DEFAULT: + if isinstance(device, _UseDefaultType): device = DEFAULT_SETTINGS["device"] - if requires_grad is _USE_DEFAULT: + if isinstance(requires_grad, _UseDefaultType): requires_grad = DEFAULT_SETTINGS["requires_grad"] - if sparkline is _USE_DEFAULT: + if isinstance(sparkline, _UseDefaultType): sparkline = DEFAULT_SETTINGS["sparkline"] - if sparkline_bins is _USE_DEFAULT: + if isinstance(sparkline_bins, _UseDefaultType): sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"] - if sparkline_logy is _USE_DEFAULT: + if isinstance(sparkline_logy, _UseDefaultType): sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"] - if colored is _USE_DEFAULT: + if isinstance(colored, _UseDefaultType): colored = DEFAULT_SETTINGS["colored"] - if as_list is _USE_DEFAULT: + if isinstance(as_list, _UseDefaultType): as_list = DEFAULT_SETTINGS["as_list"] - if eq_char is _USE_DEFAULT: + if isinstance(eq_char, _UseDefaultType): eq_char = DEFAULT_SETTINGS["eq_char"] array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins) diff --git a/tests/unit/json_serialize/test_array_torch.py b/tests/unit/json_serialize/test_array_torch.py index a4ae803c..b79e18ab 100644 --- a/tests/unit/json_serialize/test_array_torch.py +++ b/tests/unit/json_serialize/test_array_torch.py @@ -242,5 +242,10 @@ def test_mixed_numpy_torch(): assert _FORMAT_KEY in serialized["torch_tensor"] # Check format strings identify the type - assert "numpy" in serialized["numpy_array"][_FORMAT_KEY] # pyright: ignore[reportOperatorIssue] - assert "torch" in serialized["torch_tensor"][_FORMAT_KEY] # pyright: ignore[reportOperatorIssue] + numpy_format = serialized["numpy_array"][_FORMAT_KEY] + assert isinstance(numpy_format, str) + assert "numpy" in numpy_format + + torch_format = serialized["torch_tensor"][_FORMAT_KEY] + assert isinstance(torch_format, str) + assert "torch" in torch_format diff --git a/tests/unit/json_serialize/test_json_serialize.py b/tests/unit/json_serialize/test_json_serialize.py 
index 4fbce5ec..ff68143f 100644 --- a/tests/unit/json_serialize/test_json_serialize.py +++ b/tests/unit/json_serialize/test_json_serialize.py @@ -546,10 +546,15 @@ def simple_serialize(self, obj, path): assert metadata["desc"] == "Test handler description" # Check that code and doc are included - assert "code" in metadata["check"] - assert "doc" in metadata["check"] - assert "code" in metadata["serialize_func"] - assert "doc" in metadata["serialize_func"] + check_data = metadata["check"] + assert isinstance(check_data, dict) + assert "code" in check_data + assert "doc" in check_data + + serialize_func_data = metadata["serialize_func"] + assert isinstance(serialize_func_data, dict) + assert "code" in serialize_func_data + assert "doc" in serialize_func_data # ============================================================================ @@ -631,7 +636,7 @@ def tracking_check(self, obj, path): def test_JsonSerializer_init_no_positional_args(): """Test that JsonSerializer raises ValueError on positional arguments.""" with pytest.raises(ValueError, match="no positional arguments"): - JsonSerializer("invalid", "args") # type: ignore[invalid-argument-type] + JsonSerializer("invalid", "args") # type: ignore[arg-type] # Should work with keyword args serializer = JsonSerializer(error_mode=ErrorMode.WARN) diff --git a/tests/unit/json_serialize/test_serializable_field.py b/tests/unit/json_serialize/test_serializable_field.py index 2b6bb088..11cc775f 100644 --- a/tests/unit/json_serialize/test_serializable_field.py +++ b/tests/unit/json_serialize/test_serializable_field.py @@ -116,7 +116,7 @@ def test_from_Field(): """Test converting a dataclasses.Field to SerializableField.""" # Create a standard dataclasses.Field dc_field: dataclasses.Field[int] = field( # type: ignore[assignment] - default=42, + default=42, # type: ignore[arg-type] init=True, repr=True, hash=None, @@ -140,9 +140,9 @@ def test_from_Field(): assert sf.deserialize_fn is None # Test with default_factory and init=False to avoid init=True, serialize=False error - dc_field2: dataclasses.Field[list[Any]] = field( - default_factory=list, repr=True, init=True - ) # type: ignore[assignment] + dc_field2: dataclasses.Field[list[Any]] = field( # type: ignore[assignment] + default_factory=list, repr=True, init=True # type: ignore[arg-type] + ) sf2 = SerializableField.from_Field(dc_field2) assert sf2.default_factory == list # noqa: E721 assert sf2.default is dataclasses.MISSING diff --git a/tests/unit/json_serialize/test_util.py b/tests/unit/json_serialize/test_util.py index 7ade3161..1dae89c7 100644 --- a/tests/unit/json_serialize/test_util.py +++ b/tests/unit/json_serialize/test_util.py @@ -1,5 +1,5 @@ from collections import namedtuple -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import NamedTuple import pytest @@ -248,15 +248,72 @@ class Container: assert dc_eq(c1, c2) is True assert dc_eq(c1, c3) is False - # Test except_when_field_mismatch with different classes but same fields + # Test except_when_field_mismatch with different classes and different fields + # Must set false_when_class_mismatch=False to reach the field check + with pytest.raises(AttributeError, match="different fields"): + dc_eq(p1, p3d, except_when_field_mismatch=True, false_when_class_mismatch=False) + + # Test except_when_field_mismatch with different classes but SAME fields - should NOT raise @dataclass class Point2D: x: int y: int - # Different classes but same fields - should raise with except_when_field_mismatch - with 
pytest.raises(AttributeError): - dc_eq(p1, p3d, except_when_field_mismatch=True) + p2d = Point2D(1, 2) + # Same fields, different classes, same values - should return True + result = dc_eq( + p1, p2d, except_when_field_mismatch=True, false_when_class_mismatch=False + ) + assert result is True + + # Different classes, same fields, different values - should return False + p2d_diff = Point2D(1, 99) + assert ( + dc_eq(p2d_diff, p1, false_when_class_mismatch=False, except_when_field_mismatch=True) + is False + ) + + # Test parameter precedence: except_when_class_mismatch takes precedence over false_when_class_mismatch + with pytest.raises(TypeError, match="Cannot compare dataclasses of different classes"): + dc_eq(p1, p3d, except_when_class_mismatch=True, false_when_class_mismatch=True) + + # Test parameter precedence: except_when_class_mismatch takes precedence over except_when_field_mismatch + with pytest.raises(TypeError, match="Cannot compare dataclasses of different classes"): + dc_eq(p1, p3d, except_when_class_mismatch=True, except_when_field_mismatch=True) + + # Test with empty dataclasses + @dataclass + class Empty: + pass + + @dataclass + class AlsoEmpty: + pass + + e1, e2 = Empty(), Empty() + assert dc_eq(e1, e2) is True + + # Different empty classes - same fields (none), should be equal when allowing cross-class comparison + ae = AlsoEmpty() + assert dc_eq(e1, ae, false_when_class_mismatch=False) is True + + # Test with compare=False fields - these should be ignored in comparison + @dataclass + class WithIgnored: + x: int + ignored: int = field(compare=False) + + w1 = WithIgnored(1, 100) + w2 = WithIgnored(1, 999) # ignored field differs + assert dc_eq(w1, w2) is True # Should still be equal since ignored field is not compared + + # Test with non-dataclass objects - should raise TypeError + class NotADataclass: + def __init__(self, x: int): + self.x = x + + with pytest.raises(TypeError): + dc_eq(NotADataclass(1), NotADataclass(1)) def test_FORMAT_KEY(): diff --git a/tests/unit/validate_type/test_validate_type.py b/tests/unit/validate_type/test_validate_type.py index 51e5e2ae..075069d4 100644 --- a/tests/unit/validate_type/test_validate_type.py +++ b/tests/unit/validate_type/test_validate_type.py @@ -82,7 +82,7 @@ def test_validate_type_any(value): ("hello", Union[int, str], True), (3.14, Union[int, float], True), (True, Union[int, str], True), - (None, Union[int, type(None)], True), + (None, Union[int, None], True), (None, Union[int, str], False), (5, Union[int, str], True), (5.0, Union[int, str], False), diff --git a/tests/unit/web/test_bundle_html.py b/tests/unit/web/test_bundle_html.py index 39438a76..820168d5 100644 --- a/tests/unit/web/test_bundle_html.py +++ b/tests/unit/web/test_bundle_html.py @@ -232,6 +232,7 @@ def test_tag_attr_override(site: dict[str, Path]) -> None: def test_cli_smoke(tmp_path: Path, site: dict[str, Path]) -> None: html_copy = tmp_path / "page.html" html_copy.write_text(site["html"].read_text()) + assert bundle_html.__file__ is not None exe = Path(bundle_html.__file__).resolve() subprocess.check_call( [sys.executable, str(exe), str(html_copy), "--output", str(html_copy)] @@ -470,6 +471,7 @@ def test_fragment_in_src_kept(tiny_site: dict[str, Path]) -> None: def test_cli_overwrite(tmp_path: Path, tiny_site: dict[str, Path]) -> None: copy = tmp_path / "page.html" copy.write_text(tiny_site["html"].read_text()) + assert bundle_html.__file__ is not None exe = Path(bundle_html.__file__).resolve() subprocess.check_call([sys.executable, str(exe), str(copy), 
"--output", str(copy)]) res = copy.read_text() From 3f19fe05c788229c9143062ac8a8dbb4a787e484 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 19 Jan 2026 15:59:07 -0700 Subject: [PATCH 71/72] make format --- muutils/dictmagic.py | 8 ++++++-- .../json_serialize/serializable_dataclass.py | 13 ++++++++++--- muutils/misc/typing_breakdown.py | 12 +++++++----- muutils/tensor_info.py | 4 +++- .../json_serialize/test_serializable_field.py | 4 +++- tests/unit/json_serialize/test_util.py | 19 +++++++++++++++---- 6 files changed, 44 insertions(+), 16 deletions(-) diff --git a/muutils/dictmagic.py b/muutils/dictmagic.py index 29e85a04..17e907e0 100644 --- a/muutils/dictmagic.py +++ b/muutils/dictmagic.py @@ -33,7 +33,9 @@ class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]): """like a defaultdict, but default_factory is passed the key as an argument""" - def __init__(self, default_factory: Callable[[_KT], _VT], *args: Any, **kwargs: Any) -> None: + def __init__( + self, default_factory: Callable[[_KT], _VT], *args: Any, **kwargs: Any + ) -> None: if args: raise TypeError( f"DefaulterDict does not support positional arguments: *args = {args}" @@ -381,7 +383,9 @@ def condense_tensor_dict( data: TensorDict | TensorIterable, fmt: TensorDictFormats = "dict", *args: Any, - shapes_convert: Callable[[tuple[Union[int, str], ...]], Any] = _default_shapes_convert, + shapes_convert: Callable[ + [tuple[Union[int, str], ...]], Any + ] = _default_shapes_convert, drop_batch_dims: int = 0, sep: str = ".", dims_names_map: Optional[dict[int, str]] = None, diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 3ad61692..d252467c 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -92,7 +92,9 @@ class NestedClass(SerializableDataclass): Self = TypeVar("Self") -T_SerializeableDataclass = TypeVar("T_SerializeableDataclass", bound="SerializableDataclass") +T_SerializeableDataclass = TypeVar( + "T_SerializeableDataclass", bound="SerializableDataclass" +) class CantGetTypeHintsWarning(UserWarning): @@ -111,7 +113,9 @@ class ZanjMissingWarning(UserWarning): "flag to keep track of if we have successfully imported ZANJ" -def zanj_register_loader_serializable_dataclass(cls: typing.Type[T_SerializeableDataclass]): +def zanj_register_loader_serializable_dataclass( + cls: typing.Type[T_SerializeableDataclass], +): """Register a serializable dataclass with the ZANJ import this allows `ZANJ().read()` to load the class and not just return plain dicts @@ -797,7 +801,10 @@ def serialize(self: Any) -> dict[str, Any]: # ====================================================================== # mypy thinks this isnt a classmethod @classmethod # type: ignore[misc] - def load(cls: type[T_SerializeableDataclass], data: dict[str, Any] | T_SerializeableDataclass) -> T_SerializeableDataclass: + def load( + cls: type[T_SerializeableDataclass], + data: dict[str, Any] | T_SerializeableDataclass, + ) -> T_SerializeableDataclass: # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ if isinstance(data, cls): return data diff --git a/muutils/misc/typing_breakdown.py b/muutils/misc/typing_breakdown.py index a5cee463..9776e82c 100644 --- a/muutils/misc/typing_breakdown.py +++ b/muutils/misc/typing_breakdown.py @@ -157,8 +157,12 @@ def to_toml(self) -> str: try: lines.append(f"total_warnings = {self.total_warnings}") except ValueError: - 
lines.append(f"total_warnings_by_type = {sum(self.warnings_by_type.values())}") - lines.append(f"total_warnings_by_file = {sum(self.warnings_by_file.values())}") + lines.append( + f"total_warnings_by_type = {sum(self.warnings_by_type.values())}" + ) + lines.append( + f"total_warnings_by_file = {sum(self.warnings_by_file.values())}" + ) lines.append("") # warnings by_type section @@ -235,9 +239,7 @@ def parse_basedpyright(content: str) -> TypeCheckResult: pending_diagnostic_type = diag_match.group(1) # Check if this is a continuation line with the code elif pending_diagnostic_type: - code_match: re.Match[str] | None = re.search( - r"\((\w+)\)\s*$", line - ) + code_match: re.Match[str] | None = re.search(r"\((\w+)\)\s*$", line) if code_match: error_code = code_match.group(1) if pending_diagnostic_type == "warning": diff --git a/muutils/tensor_info.py b/muutils/tensor_info.py index 0d25e7b7..945ead39 100644 --- a/muutils/tensor_info.py +++ b/muutils/tensor_info.py @@ -425,7 +425,9 @@ def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> s return type_colored -def format_shape_colored(shape_val: Any, colors: Dict[str, str], using_tex: bool) -> str: +def format_shape_colored( + shape_val: Any, colors: Dict[str, str], using_tex: bool +) -> str: """Format shape with proper coloring for both 1D and multi-D arrays.""" def apply_color(text: str, color_key: str) -> str: diff --git a/tests/unit/json_serialize/test_serializable_field.py b/tests/unit/json_serialize/test_serializable_field.py index 11cc775f..05eed8b9 100644 --- a/tests/unit/json_serialize/test_serializable_field.py +++ b/tests/unit/json_serialize/test_serializable_field.py @@ -141,7 +141,9 @@ def test_from_Field(): # Test with default_factory and init=False to avoid init=True, serialize=False error dc_field2: dataclasses.Field[list[Any]] = field( # type: ignore[assignment] - default_factory=list, repr=True, init=True # type: ignore[arg-type] + default_factory=list, + repr=True, + init=True, # type: ignore[arg-type] ) sf2 = SerializableField.from_Field(dc_field2) assert sf2.default_factory == list # noqa: E721 diff --git a/tests/unit/json_serialize/test_util.py b/tests/unit/json_serialize/test_util.py index 1dae89c7..6c7fbf31 100644 --- a/tests/unit/json_serialize/test_util.py +++ b/tests/unit/json_serialize/test_util.py @@ -269,16 +269,25 @@ class Point2D: # Different classes, same fields, different values - should return False p2d_diff = Point2D(1, 99) assert ( - dc_eq(p2d_diff, p1, false_when_class_mismatch=False, except_when_field_mismatch=True) + dc_eq( + p2d_diff, + p1, + false_when_class_mismatch=False, + except_when_field_mismatch=True, + ) is False ) # Test parameter precedence: except_when_class_mismatch takes precedence over false_when_class_mismatch - with pytest.raises(TypeError, match="Cannot compare dataclasses of different classes"): + with pytest.raises( + TypeError, match="Cannot compare dataclasses of different classes" + ): dc_eq(p1, p3d, except_when_class_mismatch=True, false_when_class_mismatch=True) # Test parameter precedence: except_when_class_mismatch takes precedence over except_when_field_mismatch - with pytest.raises(TypeError, match="Cannot compare dataclasses of different classes"): + with pytest.raises( + TypeError, match="Cannot compare dataclasses of different classes" + ): dc_eq(p1, p3d, except_when_class_mismatch=True, except_when_field_mismatch=True) # Test with empty dataclasses @@ -305,7 +314,9 @@ class WithIgnored: w1 = WithIgnored(1, 100) w2 = WithIgnored(1, 999) # ignored 
field differs - assert dc_eq(w1, w2) is True # Should still be equal since ignored field is not compared + assert ( + dc_eq(w1, w2) is True + ) # Should still be equal since ignored field is not compared # Test with non-dataclass objects - should raise TypeError class NotADataclass: From db2f5ecf090f3889d70fe93ad99edf03a09666c2 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 19 Jan 2026 16:00:19 -0700 Subject: [PATCH 72/72] re-run typing summary --- .meta/typing-summary.txt | 34 +++++++++++++--------------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/.meta/typing-summary.txt b/.meta/typing-summary.txt index ba02b48e..668975e2 100644 --- a/.meta/typing-summary.txt +++ b/.meta/typing-summary.txt @@ -1,29 +1,26 @@ # Showing all errors -# mypy: Found 5 errors in 2 files (checked 117 source files) -# basedpyright: 437 errors, 3173 warnings, 0 notes -# ty: Found 9 diagnostics +# mypy: Found 1 error in 1 file (checked 117 source files) +# basedpyright: 435 errors, 3172 warnings, 0 notes +# ty: Found 6 diagnostics [type_errors.mypy] -total_errors = 5 +total_errors = 1 [type_errors.mypy.by_type] -"method-assign" = 3 -"assignment" = 1 -"call-overload" = 1 +"arg-type" = 1 [type_errors.mypy.by_file] -"muutils/json_serialize/serializable_dataclass.py" = 4 -"muutils/dbg.py" = 1 +"tests/unit/json_serialize/test_serializable_field.py" = 1 [type_errors.basedpyright] -total_errors = 436 +total_errors = 434 [type_errors.basedpyright.by_type] "reportMissingTypeArgument" = 154 -"reportArgumentType" = 63 +"reportArgumentType" = 62 "reportInvalidTypeForm" = 48 -"reportCallIssue" = 39 +"reportCallIssue" = 38 "reportPossiblyUnboundVariable" = 34 "reportAttributeAccessIssue" = 26 "reportMissingSuperCall" = 14 @@ -73,7 +70,6 @@ total_errors = 436 "muutils/tensor_info.py" = 4 "tests/unit/json_serialize/test_util.py" = 4 "tests/unit/errormode/test_errormode_functionality.py" = 3 -"muutils/dbg.py" = 2 "muutils/nbutils/mermaid.py" = 2 "muutils/nbutils/run_notebook_tests.py" = 2 "tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py" = 2 @@ -94,14 +90,14 @@ total_errors = 436 "tests/util/test_fire.py" = 1 [type_warnings.basedpyright] -total_warnings = 3173 +total_warnings = 3172 [type_warnings.basedpyright.by_type] "reportAny" = 758 "reportUnknownParameterType" = 419 "reportUnknownArgumentType" = 391 "reportMissingParameterType" = 316 -"reportUnknownVariableType" = 304 +"reportUnknownVariableType" = 303 "reportUnusedCallResult" = 288 "reportUnknownMemberType" = 235 "reportUnknownLambdaType" = 124 @@ -163,8 +159,8 @@ total_warnings = 3173 "muutils/misc/sequence.py" = 31 "muutils/sysinfo.py" = 31 "tests/unit/nbutils/test_configure_notebook.py" = 28 -"muutils/dbg.py" = 27 "tests/unit/misc/test_freeze.py" = 27 +"muutils/dbg.py" = 26 "tests/unit/math/test_matrix_powers_torch.py" = 26 "tests/unit/test_tensor_info_torch.py" = 26 "muutils/misc/func.py" = 25 @@ -218,16 +214,12 @@ total_warnings = 3173 "tests/unit/test_tensor_utils_torch.py" = 1 [type_errors.ty] -total_errors = 9 +total_errors = 6 [type_errors.ty.by_type] "invalid-assignment" = 4 -"no-matching-overload" = 3 "invalid-argument-type" = 2 [type_errors.ty.by_file] "muutils/interval.py" = 3 "muutils/misc/func.py" = 3 -"muutils/dbg.py" = 1 -"muutils/logger/logger.py" = 1 -"muutils/logger/simplelogger.py" = 1
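
Closing note: with the warning/error split and the multi-line diagnostic handling added to `parse_basedpyright`, the summary above now tallies basedpyright warnings separately from errors. A compact, self-contained sketch of that two-regex state machine — the `tally` helper and sample text are hypothetical, but the patterns mirror the ones in the patch:

```
import re
from collections import defaultdict


def tally(output: str) -> tuple[dict[str, int], dict[str, int]]:
    """Count (errors, warnings) by diagnostic code in basedpyright-style output.

    Some diagnostics wrap, leaving the trailing `(reportCode)` on a
    continuation line, so we remember the pending severity until it shows up.
    """
    errors: dict[str, int] = defaultdict(int)
    warnings: dict[str, int] = defaultdict(int)
    pending: str | None = None  # "error" or "warning" awaiting its code
    for line in output.splitlines():
        # single-line form: "path:line:col - severity: message (code)"
        m = re.search(r":\d+:\d+ - (error|warning): .+ \((\w+)\)", line)
        if m:
            (warnings if m.group(1) == "warning" else errors)[m.group(2)] += 1
            pending = None
            continue
        # multi-line form: severity now, code on a later line
        start = re.search(r":\d+:\d+ - (error|warning): ", line)
        if start:
            pending = start.group(1)
            continue
        if pending:
            code = re.search(r"\((\w+)\)\s*$", line)
            if code:
                (warnings if pending == "warning" else errors)[code.group(1)] += 1
                pending = None
    return dict(errors), dict(warnings)


sample = """\
  a.py:1:5 - error: bad thing (reportCallIssue)
  a.py:2:1 - warning: a long message that wraps
    onto the next line (reportAny)
"""
print(tally(sample))  # ({'reportCallIssue': 1}, {'reportAny': 1})
```

Because `pending` is cleared whenever a complete diagnostic is recorded, a stray parenthesized token on an unrelated line cannot be double-counted.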