From f3da504531946436a1f0a4ed97ea4378740de92b Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Tue, 20 Jan 2026 23:41:14 -0500 Subject: [PATCH 01/25] Enhanced debug logs and added new version tag. --- .github/workflows/release.yaml | 5 + README.md | 62 ++++------ SMALL_README.md | 163 --------------------------- docs/{LIGHTPAPER.md => LITEPAPER.md} | 0 pyproject.toml | 10 +- tensorlink/ml/formatter.py | 5 - tensorlink/nodes/validator_thread.py | 7 +- tensorlink/nodes/worker_thread.py | 11 +- tensorlink/p2p/smart_node.py | 1 + tensorlink/p2p/torch_node.py | 123 ++++++++++++++++++-- 10 files changed, 150 insertions(+), 237 deletions(-) delete mode 100644 SMALL_README.md rename docs/{LIGHTPAPER.md => LITEPAPER.md} (100%) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index da44335..505d8a8 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -30,6 +30,11 @@ jobs: run: | poetry install --only main --no-interaction --no-ansi + - name: Sync version from git tag + run: | + VERSION=${GITHUB_REF_NAME#v} + poetry version "$VERSION" + # Build Python Package - name: Build wheel + sdist run: | diff --git a/README.md b/README.md index 1bb4ecb..4416f7e 100644 --- a/README.md +++ b/README.md @@ -28,22 +28,24 @@ ## What is Tensorlink? Tensorlink is a Python library and decentralized compute platform for running PyTorch and Hugging Face models across -peer-to-peer networks. It enables you to run, train, and serve large models securely across distributed hardware without relying on -centralized cloud inference providers. +peer-to-peer networks. It lets you run, train, and serve large models securely on distributed hardware without relying +on centralized cloud inference providers. + +With Tensorlink, models can be automatically sharded across multiple GPUs, enabling execution beyond local VRAM limits. 
+You can host models on your own devices, expose them through a REST API, stream tokens in real time, and optionally +route requests only to your own hardware for private usage. Tensorlink supports both distributed training with +optimizers and low-latency inference across the network. ### Key Features -- **Native PyTorch & REST API Access** – Use models directly in Python or via HTTP endpoints -- **Run Large Models Without Local VRAM** – Execute models that exceed your GPU capacity -- **Remote Access to Your Own Hardware** – Securely host and access models on your devices via API -- **Plug-and-Play Distributed Execution** – Automatic model sharding across multiple GPUs -- **Training & Inference Support** – Train models with distributed optimizers or run inference across the network -- **Streaming Generation** – Token-by-token streaming for real-time responses -- **Privacy Controls** – Route queries exclusively to your own hardware for private usage -- **Earn Rewards for Idle Compute** – Contribute GPU resources to the network and get compensated + +- **Native PyTorch & REST API Access** — Use models directly in Python or via HTTP endpoints. +- **Run Large Models** — Automatic offloading and model sharding across peers. +- **Plug-and-Play Distributed Execution** — No manual cluster setup required. +- **Streaming Generation** — Token-by-token responses for real-time apps. +- **Privacy Controls** — Route traffic exclusively to your own machines, or leverage hybrid privacy-enhanced model workflows. > **Early Access:** Tensorlink is under active development. APIs and internals may evolve. > [Join our Discord](https://discord.gg/aCW2kTNzJ2) for updates, support, and roadmap discussions. 
-> Learn more in the [**Litepaper**](docs/LITEPAPER.md) ## Quick Start @@ -73,7 +75,7 @@ model = DistributedModel( ) optimizer = model.create_optimizer(lr=0.001) ``` -> See [Examples](docs/examples) for streaming generation, distributed training, custom models, +> See [Examples](https://github.com/mattjhawken/tensorlink/blob/main/docs/examples) for streaming generation, distributed training, custom models, > and network configurations. ### Option 2: Accessing Models via HTTP @@ -95,7 +97,7 @@ response = requests.post( print(response.json()) ``` ->Access the public network or configure your own hardware for private API access. See [Examples](docs/examples) for +>Access the public network or configure your own hardware for private API access. See [Examples](https://github.com/mattjhawken/tensorlink/blob/main/docs/examples) for >streaming, chat completions, and API reference. ### Option 3: Run a Node @@ -104,13 +106,13 @@ Run Tensorlink nodes to host models, shard workloads across GPUs, and expose the Nodes can act as workers (run models), validators (route requests + expose API), or both. This allows you to build private clusters, public compute providers, or local development environments. -1. Download the latest `tensorlink-node` from [Releases](releases) +1. Download the latest `tensorlink-node` from [Releases](https://github.com/mattjhawken/tensorlink/releases) 2. Edit `config.json` to configure your nodes. 3. Run: `./run-node.sh` > By default, the config is set for running a public worker node. Your GPU will process network workloads and earn -> rewards via the networking layer ([Smartnodes](https://smartnodes.ca)). See [Examples](docs/examples) for different -> device and network configurations. +> rewards via the networking layer ([Smartnodes](https://smartnodes.ca)). See [Examples](https://github.com/mattjhawken/tensorlink/blob/main/docs/examples) +> for different device and network configurations. --- @@ -147,7 +149,7 @@ running a public worker node. 
| `mining_script` | `str` | Path to mining / GPU workload executable | | `seed_validators` | `List[List[str, int, str]]` | Path to mining / GPU workload executable | -> For common configuration recipes and examples, see [**Examples: Node Configuration**](docs/examples/EXAMPLES.md#node-configuration-examples) +> For common configuration recipes and examples, see [**Examples: Node Configuration**](https://github.com/mattjhawken/tensorlink/blob/main/docs/examples/EXAMPLES.md#node-configuration-examples) --- @@ -186,15 +188,6 @@ Simple generation endpoint with flexible output formats. | `history` | array | `null` | Chat history for multi-turn conversations | | `is_chat_completion` | bool | `false` | Determines whether to format chat output | -# In _generate_streaming: -should_filter = request.is_chat_completion - -# Or if you want finer control: -should_filter = ( - request.is_chat_completion or - (request.input_format == "chat" and request.output_format == "openai") -) - #### Example: Basic Generation ```python @@ -379,7 +372,7 @@ import requests r = requests.post( "http://localhost:64747/request-model", - json={"hf_name": "Qwen/Qwen2.5-7B-Instruct"} + json={"hf_name": "Qwen/Qwen3-8B"} ) print(r.json()) @@ -396,25 +389,20 @@ models may appear. 
Please report any bugs via [Issues](https://github.com/mattjh - **Token IDs**: Automatically handles missing pad/eos tokens with safe fallbacks - **Format Control**: Use `input_format="chat"` and `output_format="openai"` for seamless integration -> For complete examples, error handling, and advanced usage, see [**Examples: HTTP API**](docs/examples/EXAMPLES.md#http-api-examples) +> For complete examples, error handling, and advanced usage, see [**Examples: HTTP API**](https://github.com/mattjhawken/tensorlink/blob/main/docs/examples/EXAMPLES.md#http-api-examples) --- ## Learn More - 📚 **[Documentation](https://smartnodes.ca/tensorlink/docs)** – Full API reference and guides -- 🎯 **[Examples](docs/examples/EXAMPLES.md)** – Comprehensive usage patterns and recipes +- 🎯 **[Examples](https://github.com/mattjhawken/tensorlink/blob/main/docs/examples/EXAMPLES.md)** – Comprehensive usage patterns and recipes - 💬 **[Discord Community](https://discord.gg/aCW2kTNzJ2)** – Get help and connect with developers -- 🎮 **[Live Demo](https://smartnodes.ca/localhostGPT)** – Try localhostGPT powered by Tensorlink -- 📘 **[Litepaper](docs/LITEPAPER.md)** – Technical overview and architecture +- 🎮 **[Live Demo](https://smartnodes.ca/tensorlink)** – Try the chatbot demo powered by a model on Tensorlink +- 📘 **[Litepaper](https://github.com/mattjhawken/tensorlink/blob/main/docs/LITEPAPER.md)** – Technical overview and architecture ## Contributing -We welcome contributions! Here's how to get involved: - -- **Report bugs** via [GitHub Issues](https://github.com/mattjhawken/tensorlink/issues) -- **Suggest features** on our [Discord](https://discord.gg/aCW2kTNzJ2) -- **Submit PRs** to improve code or documentation -- **Support the project** via [Buy Me a Coffee](https://www.buymeacoffee.com/smartnodes) +Read our [contribution guide](https://github.com/mattjhawken/tensorlink/blob/main/.github/CONTRIBUTING.md). Tensorlink is released under the [MIT License](LICENSE). 
diff --git a/SMALL_README.md b/SMALL_README.md deleted file mode 100644 index be6ec97..0000000 --- a/SMALL_README.md +++ /dev/null @@ -1,163 +0,0 @@ -# Tensorlink - -**Peer-to-peer AI Inference & Distributed Execution with PyTorch** - -## What is Tensorlink? - -Tensorlink is a Python library and decentralized compute platform for running PyTorch and Hugging Face -models across peer-to-peer networks. It provides a compelling alternative to centralized cloud providers, -allowing you to run, train, and serve large models across distributed hardware. - -## Key Features - -- **Native PyTorch & REST API Access** – Use models directly in Python or via HTTP endpoints -- **Run Large Models Without Local VRAM** – Execute models that exceed your GPU capacity -- **Remote Access to Your Own Hardware** – Securely host and access models on your devices -- **Plug-and-Play Distributed Execution** – Automatic model sharding across multiple GPUs -- **Training & Inference Support** – Train models with distributed optimizers or run inference -- **Streaming Generation** – Token-by-token streaming for real-time responses -- **Privacy Controls** – Route queries exclusively to your own hardware -- **Earn Rewards for Idle Compute** – Contribute GPU resources and get compensated - -> **Note:** Tensorlink is under active development. APIs may evolve. 
- -## Installation - -```bash -pip install tensorlink -``` - -**Requirements:** Python 3.10+, PyTorch 2.3+, UNIX/MacOS (Windows: use WSL) - -## Quick Start - -### Option 1: Distributed Models in Python - -Execute Hugging Face models distributed across the network: - -```python -from tensorlink.ml import DistributedModel - -# Connect to a model on the network -model = DistributedModel( - model="Qwen/Qwen2.5-7B-Instruct", - training=True, - device="cuda" -) - -# Optimizer instantiation method -optimizer = model.create_optimizer(optimizer_type="Adam", lr=0.001) - -# Use like any PyTorch model and optimizer -``` - -### Option 2: HTTP API Access - -Access models via OpenAI-compatible HTTP endpoints: - -```python -import requests - -# Simple generation -response = requests.post( - "http://smartnodes.ddns.net/tensorlink-api/v1/generate", - json={ - "hf_name": "Qwen/Qwen2.5-7B-Instruct", - "message": "Explain quantum computing in one sentence.", - "max_new_tokens": 50, - } -) - -print(response.json()["generated_text"]) -``` - -### Option 3: Chat Completions (OpenAI-Compatible) - -```python -import requests - -response = requests.post( - "http://localhost:64747/v1/chat/completions", - json={ - "model": "Qwen/Qwen2.5-7B-Instruct", - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Explain quantum computing."} - ], - "max_tokens": 128, - "temperature": 0.7, - "stream": False - } -) - -print(response.json()["choices"][0]["message"]["content"]) -``` - -### Streaming Responses - -```python -response = requests.post( - "http://localhost:64747/v1/chat/completions", - json={ - "model": "Qwen/Qwen2.5-7B-Instruct", - "messages": [{"role": "user", "content": "Write a story about AI"}], - "stream": True - }, - stream=True -) - -for line in response.iter_lines(): - if line and line.startswith(b"data: "): - chunk = line.decode()[6:] - if chunk != "[DONE]": - import json - data = json.loads(chunk) - content = 
data["choices"][0]["delta"].get("content", "") - print(content, end="", flush=True) -``` - -## API Endpoints - -- `POST /v1/generate` – Simple text generation -- `POST /v1/chat/completions` – OpenAI-compatible chat interface -- `POST /request-model` – Preload models across the network - -## Running Your Own Node - -Host models on your own hardware and expose them via API: - -1. Download the latest release from [GitHub](https://github.com/mattjhawken/tensorlink/releases) -2. Configure `config.json` for your setup -3. Run: `./run-node.sh` - -By default, nodes contribute to the public network and earn rewards. Configure private mode to use only your own devices. - -## Documentation & Resources - -- **Full Documentation:** [smartnodes.ca/tensorlink/docs](https://smartnodes.ca/tensorlink/docs) -- **Examples & Guides:** [docs/examples](https://github.com/mattjhawken/tensorlink/tree/main/docs/examples) -- **GitHub Repository:** [github.com/mattjhawken/tensorlink](https://github.com/mattjhawken/tensorlink) -- **Discord Community:** [discord.gg/aCW2kTNzJ2](https://discord.gg/aCW2kTNzJ2) -- **Live Demo:** [smartnodes.ca/localhostGPT](https://smartnodes.ca/localhostGPT) -- **Litepaper:** [Technical Overview](https://github.com/mattjhawken/tensorlink/blob/main/docs/LITEPAPER.md) - -## Use Cases - -- **Researchers:** Run large models without expensive cloud compute -- **Developers:** Build AI applications with distributed inference -- **Organizations:** Deploy private AI infrastructure across your devices -- **GPU Owners:** Monetize idle compute resources -- **Startups:** Scale AI services without infrastructure costs - -## Contributing - -We welcome contributions! 
- -- Report bugs via [GitHub Issues](https://github.com/mattjhawken/tensorlink/issues) -- Suggest features on [Discord](https://discord.gg/aCW2kTNzJ2) -- Submit pull requests to improve code or documentation -- Support the project via [Buy Me a Coffee](https://www.buymeacoffee.com/smartnodes) - -## License - -Tensorlink is released under the [MIT License](https://github.com/mattjhawken/tensorlink/blob/main/LICENSE). diff --git a/docs/LIGHTPAPER.md b/docs/LITEPAPER.md similarity index 100% rename from docs/LIGHTPAPER.md rename to docs/LITEPAPER.md diff --git a/pyproject.toml b/pyproject.toml index 24cea15..dd14bbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,17 +1,17 @@ [tool.poetry] name = "tensorlink" -version = "0.2.0.post1" +version = "0.3.0-rc.1" description = "Tensorlink is a decentralized Python library for distributed PyTorch models, providing easy Hugging Face API access and enabling users to share compute resources across a global network." authors = [ "Smartnodes Lab ", "Matthew Hawken", "Nikita Vassilyev" ] -readme = "SMALL_README.md" +readme = "README.md" packages = [{include = "tensorlink"}] -homepage = "https://smartnodes.ca/tensorlink" -documentation = "https://smartnodes.ca/tensorlink/docs" -repository = "https://github.com/smartnodes-lab/tensorlink" +homepage = "https://tensorlink.io" +documentation = "https://tensorlink.io/docs" +repository = "https://github.com/mattjhawken/tensorlink" license = "MIT" classifiers = [ "Development Status :: 3 - Alpha", diff --git a/tensorlink/ml/formatter.py b/tensorlink/ml/formatter.py index 5c635f0..6835d80 100644 --- a/tensorlink/ml/formatter.py +++ b/tensorlink/ml/formatter.py @@ -121,12 +121,7 @@ def normalize_generate_args( if top_p is not None: top_p = max(0.0, min(top_p, 1.0)) - # OPTIONAL EXTENSIONS - reasoning = getattr(request, "reasoning", None) - enable_thinking = getattr(request, "enable_thinking", None) - # BUILD ARGS DICT and FILTER BY GENERATE SIGNATURE - # Build args dict as before args 
= { "pad_token_id": pad_token_id, "eos_token_id": eos_token_id, diff --git a/tensorlink/nodes/validator_thread.py b/tensorlink/nodes/validator_thread.py index 520585d..7521ff0 100644 --- a/tensorlink/nodes/validator_thread.py +++ b/tensorlink/nodes/validator_thread.py @@ -970,7 +970,7 @@ def run(self): self.clean_port_mappings() self.get_workers() if counter % 180 == 0: - self.print_status() + self.print_ui_status() time.sleep(1) counter += 1 @@ -979,11 +979,6 @@ def stop(self): self.keeper.write_state() super().stop() - def print_status(self): - self.print_base_status() - print(f" Current Proposal: {self.current_proposal}") - print("=============================================\n") - def get_tensorlink_status(self): # Path to package root (where this file lives) base_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/tensorlink/nodes/worker_thread.py b/tensorlink/nodes/worker_thread.py index b9c8584..a2ed716 100644 --- a/tensorlink/nodes/worker_thread.py +++ b/tensorlink/nodes/worker_thread.py @@ -5,7 +5,6 @@ import torch.nn as nn from dotenv import get_key -import psutil import hashlib import json import logging @@ -60,9 +59,7 @@ def __init__( level=logging.INFO, tag="Worker", ) - self.available_gpu_memory = get_gpu_memory() - self.total_gpu_memory = self.available_gpu_memory - self.available_ram = psutil.virtual_memory().available + self.mining_active = mining_active self.reserved_memory = reserved_memory @@ -210,7 +207,7 @@ def run(self): if counter % 180 == 0: self.keeper.clean_node() self.clean_port_mappings() - self.print_status() + self.print_ui_status() time.sleep(1) counter += 1 @@ -267,7 +264,3 @@ def handle_statistics_request(self, callee, additional_context: dict = None): def activate(self): self.training = True - - def print_status(self): - self.print_base_status() - print("=============================================\n") diff --git a/tensorlink/p2p/smart_node.py b/tensorlink/p2p/smart_node.py index e306aeb..b33d480 100644 --- 
a/tensorlink/p2p/smart_node.py +++ b/tensorlink/p2p/smart_node.py @@ -273,6 +273,7 @@ def __init__( self.upnp = False self.public_key = None + self._start_time = time.time() # DHT Storage self.dht = DHT(self) diff --git a/tensorlink/p2p/torch_node.py b/tensorlink/p2p/torch_node.py index 02cf09f..cf6710d 100644 --- a/tensorlink/p2p/torch_node.py +++ b/tensorlink/p2p/torch_node.py @@ -1,5 +1,3 @@ -import os - from tensorlink.ml.utils import get_gpu_memory from tensorlink.nodes.shared_memory import get_from_shared_memory from tensorlink.p2p.connection import Connection @@ -9,8 +7,9 @@ import logging import queue import threading -import time import json +import os +import time import psutil @@ -18,6 +17,39 @@ MSG_STREAM_END = b"END__" +def _bar(current, total, width=20): + if total <= 0: + return "?" * width + ratio = min(max(current / total, 0), 1) + filled = int(width * ratio) + return "█" * filled + "░" * (width - filled) + + +def _fmt_gb(x): + return f"{x / 1e9:.2f}" + + +def _uptime(start_time): + s = int(time.time() - start_time) + h, s = divmod(s, 3600) + m, s = divmod(s, 60) + return f"{h:02}:{m:02}:{s:02}" + + +class ANSI: + RESET = "\033[0m" + BOLD = "\033[1m" + DIM = "\033[2m" + + CYAN = "\033[36m" + GREEN = "\033[32m" + YELLOW = "\033[33m" + RED = "\033[31m" + BLUE = "\033[34m" + MAGENTA = "\033[35m" + GRAY = "\033[90m" + + def format_size(size_bytes): """ Format the size to display in GB, MB, or KB with one decimal place. 
@@ -59,6 +91,8 @@ def __init__( # Available GPU mpc estimation self.available_gpu_memory = get_gpu_memory() + self.total_gpu_memory = self.available_gpu_memory + self.available_ram = psutil.virtual_memory().available self._mpc_comms = None self.memory_manager = {} @@ -853,14 +887,79 @@ def _stop_mpc_comms(self): self.debug_print("Shutting down distributed ML processes...", tag="Torchnode") self._mpc_comms.join() - def print_base_status(self): + def print_ui_status(self): + used_vram = self.total_gpu_memory - self.available_gpu_memory + total_vram = self.total_gpu_memory + actual_vram = get_gpu_memory() + + ram = psutil.virtual_memory() + used_ram = ram.total - ram.available + + streams = len(getattr(self, "stream_buffers", {})) + modules = len(self.modules) + + in_q = len(getattr(self, "endpoint_requests", {}).get("incoming", [])) + out_q = len(getattr(self, "endpoint_requests", {}).get("outgoing", [])) + + def c(label, colour): + return f"{colour}{label}{ANSI.RESET}" + + def line(label, value, colour=ANSI.CYAN): + return f"{c(label + ':', ANSI.DIM):<16} {colour}{value}{ANSI.RESET}" + + width = 80 + sep = f"{ANSI.DIM}{'─' * width}{ANSI.RESET}" + + # --- Header --- + role_name = "Validator" if self.role.startswith("V") else "Worker" + title = f" Tensorlink {role_name} Node " + + print() + print(sep) + print(f"{ANSI.BOLD}{ANSI.MAGENTA}{title.center(width)}{ANSI.RESET}") + print(sep) + + # --- Identity --- + print(line("Node ID", self.rsa_key_hash, ANSI.YELLOW)) + print(line("Address", f"{self.host}:{self.port}", ANSI.GREEN)) + print(line("Uptime", _uptime(self._start_time), ANSI.BLUE)) + + # --- Network --- + print(sep) + print(line("Connections", len(self.nodes), ANSI.CYAN)) + print(line(" Workers", len(self.workers), ANSI.CYAN)) + print(line(" Validators", len(self.validators), ANSI.CYAN)) + print(line(" Users", len(self.users), ANSI.CYAN)) + + # --- Resources --- + print(sep) + + vram_bar_estimate = _bar(total_vram - used_vram, total_vram) + vram_bar = 
_bar(total_vram - actual_vram, total_vram) + ram_bar = _bar(used_ram, ram.total) + + print( + f"{ANSI.DIM}{'VRAM':<14}:{ANSI.RESET} " + f"{ANSI.MAGENTA}[{vram_bar_estimate}]{ANSI.RESET} " + f"{ANSI.MAGENTA}[{vram_bar}]{ANSI.RESET} " + f"{ANSI.YELLOW}{_fmt_gb(used_vram)} / {_fmt_gb(total_vram)} GB{ANSI.RESET}" + ) + print( - f"\n=========== Node Status Report ({'Worker' if self.role == 'W' else 'Validator'}) ===========" + f"{ANSI.DIM}{'RAM':<14}:{ANSI.RESET} " + f"{ANSI.GREEN}[{ram_bar}]{ANSI.RESET} " + f"{ANSI.YELLOW}{_fmt_gb(used_ram)} / {_fmt_gb(ram.total)} GB{ANSI.RESET}" ) - print(f" Node ID: {self.rsa_key_hash} ({self.host}:{self.port})") - print(f" Connections: {len(self.nodes)}") - print(f" Workers: {self.workers}") - print(f" Validators: {self.validators}") - print(f" Users: {self.users}") - print(f" VRAM Available: {self.available_gpu_memory / 1e9:.2f} GB") - print(f" RAM Available: {psutil.virtual_memory().available / 1e9:.2f} GB") + + print(line("Modules", modules, ANSI.MAGENTA)) + + # --- Validator --- + if self.role.startswith("V"): + print(sep) + print(line("Proposal ID", self.current_proposal, ANSI.YELLOW)) + print(line("Streams", streams, ANSI.BLUE)) + print(line("API Jobs", f"in={in_q} out={out_q}", ANSI.CYAN)) + print(line("Queues", f"in={in_q} out={out_q}", ANSI.CYAN)) + + print(sep) + print() From e0400a5fe7352eac619613811c1bf4180acbb12c Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Wed, 21 Jan 2026 10:15:13 -0500 Subject: [PATCH 02/25] Enhanced debug logging and some more reasoning logic segmentation for API requests (attempting reasoning toggles for API generate output #99). 
--- tensorlink/ml/formatter.py | 131 ++++++++++++++++++++++------------- tensorlink/ml/validator.py | 63 +++++++++++------ tensorlink/p2p/torch_node.py | 18 +++-- 3 files changed, 133 insertions(+), 79 deletions(-) diff --git a/tensorlink/ml/formatter.py b/tensorlink/ml/formatter.py index 6835d80..fd49f4f 100644 --- a/tensorlink/ml/formatter.py +++ b/tensorlink/ml/formatter.py @@ -132,10 +132,6 @@ def normalize_generate_args( } if top_p is not None and do_sample: args["top_p"] = top_p - if getattr(request, "reasoning", None) is not None: - args["reasoning"] = request.reasoning - if getattr(request, "enable_thinking", None) is not None: - args["enable_thinking"] = request.enable_thinking # Filter based on allowed kwargs if allowed_generate_args is not None: @@ -146,108 +142,121 @@ def normalize_generate_args( return args -def extract_assistant_response(text: str, model_name: str = None) -> str: +def extract_reasoning_and_answer(text: str): """ - Universal extractor that removes system/user/thought tags and returns - the final human-readable assistant response. + Extract reasoning blocks and clean answer. + Returns (reasoning, answer) """ - # Remove reasoning or hidden thought blocks (e.g. ...) 
- text = re.sub( - r"<\s*(think|reflection|thought|internal|analysis)\s*>.*?<\s*/\1\s*>", - "", + reasoning_blocks = [] + + def _collect(match): + reasoning_blocks.append(match.group(0)) + return "" + + # Capture , , etc + cleaned = re.sub( + r"<\s*(think|reflection|thought|internal|analysis)\s*>(.*?)<\s*/\1\s*>", + lambda m: _collect(m), text, flags=re.DOTALL | re.IGNORECASE, ) - # Remove common chat tags used by newer models - text = re.sub(r"<\|im_start\|>\s*\w+\s*", "", text) - text = re.sub(r"<\|im_end\|>", "", text) - text = re.sub(r"<\|assistant\|>", "", text) - text = re.sub(r"<\|user\|>", "", text) - text = re.sub(r"<\|system\|>", "", text) + reasoning = "\n\n".join( + re.sub(r"<[^>]+>", "", b).strip() for b in reasoning_blocks + ).strip() - # Strip out any prefixes like "assistant:" or "Assistant:" - text = re.sub(r"(?i)\bassistant\s*[::]\s*", "", text) + # Clean scaffolding + cleaned = re.sub(r"<\|im_start\|>\s*\w+\s*", "", cleaned) + cleaned = re.sub(r"<\|im_end\|>", "", cleaned) + cleaned = re.sub(r"<\|assistant\|>", "", cleaned) + cleaned = re.sub(r"<\|user\|>", "", cleaned) + cleaned = re.sub(r"<\|system\|>", "", cleaned) + cleaned = re.sub(r"(?i)\bassistant\s*[::]\s*", "", cleaned) + cleaned = re.sub(r"(?i)\b(system|user)\s*[::]\s*", "", cleaned) - # Remove lingering system/user scaffolding - text = re.sub(r"(?i)\b(system|user)\s*[::]\s*", "", text) - text = text.strip().replace("\r", "") + cleaned = cleaned.strip().replace("\r", "") - # If multiple paragraphs, prefer the last coherent chunk - # (models sometimes prepend hidden reasoning) - if "\n\n" in text: - parts = [p.strip() for p in text.split("\n\n") if len(p.strip()) > 10] + if "\n\n" in cleaned: + parts = [p.strip() for p in cleaned.split("\n\n") if len(p.strip()) > 10] if parts: - text = parts[-1] + cleaned = parts[-1] - # Fallback: if text still empty, just return as-is (safe default) - return text.strip() or "[No output produced]" + return reasoning, cleaned or "[No output produced]" 
-def format_chat_prompt(model_name, current_message, history): - """Format the chat history and current message into a prompt suitable for the specified model.""" +def format_chat_prompt(model_name, current_message, history, enable_thinking=True): + """ + Format the chat history and current message into a prompt suitable for + the specified model. + Args: + model_name: Name of the model + current_message: Current user message + history: Conversation history + enable_thinking: Whether to allow reasoning/thinking tokens + """ # Different models require different formatting if "Qwen" in model_name: - # Qwen-specific formatting system_prompt = ( "You are a helpful assistant. Respond directly to the user's questions." ) + # Modify system prompt to discourage thinking if disabled + if not enable_thinking: + system_prompt += " Provide concise, direct answers without showing your reasoning process." + formatted_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n" - # Add conversation history if history and len(history) > 0: for msg in history: role = msg["role"] content = msg["content"] formatted_prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n" - # Add the current message formatted_prompt += f"<|im_start|>user\n{current_message}<|im_end|>\n" formatted_prompt += "<|im_start|>assistant\n" return formatted_prompt elif "llama" in model_name.lower(): - # Llama-style formatting system_prompt = ( "You are a helpful assistant. Respond directly to the user's questions." ) + + if not enable_thinking: + system_prompt += " Provide concise, direct answers without showing your reasoning process." 
+ formatted_prompt = f"[INST] <>\n{system_prompt}\n<>\n\n" - # Add conversation history if history and len(history) > 0: for i, msg in enumerate(history): if msg["role"] == "user": if i > 0: formatted_prompt += "[/INST]\n\n[INST] " formatted_prompt += f"{msg['content']}" - else: # assistant + else: formatted_prompt += f" [/INST]\n\n{msg['content']}\n\n[INST] " - # Add the current message and prepare for response formatted_prompt += f"{current_message} [/INST]\n\n" - return formatted_prompt else: - # Generic formatting for other models system_prompt = ( "You are a helpful assistant. Respond directly to the user's questions." ) + + if not enable_thinking: + system_prompt += " Provide concise, direct answers without showing your reasoning process." + formatted_prompt = f"System: {system_prompt}\n\n" - # Add conversation history if history and len(history) > 0: for msg in history: role_prefix = "User: " if msg["role"] == "user" else "Assistant: " formatted_prompt += f"{role_prefix}{msg['content']}\n\n" - # Add the current message formatted_prompt += f"User: {current_message}\n\nAssistant: " - return formatted_prompt @@ -304,6 +313,7 @@ def format_non_streaming_response( prompt_tokens: int, completion_tokens: int, start_time: float, + reasoning_text: Optional[str] = None, ) -> Dict[str, Any]: """ Format a complete non-streaming response. 
@@ -314,6 +324,7 @@ def format_non_streaming_response( prompt_tokens: Number of tokens in the prompt completion_tokens: Number of tokens generated start_time: Generation start timestamp + reasoning_text: Extracted reasoning/thinking text (if any) Returns: Formatted response dict based on output_format @@ -321,7 +332,7 @@ def format_non_streaming_response( processing_time = time.time() - start_time if request.output_format == "openai": - return { + response = { "id": str(request.id), "object": "chat.completion", "created": int(start_time), @@ -340,9 +351,15 @@ def format_non_streaming_response( }, "processing_time": processing_time, } + + # Add reasoning to message if present + if reasoning_text: + response["choices"][0]["message"]["reasoning"] = reasoning_text + + return response + elif request.output_format == "simple": - # Simple format with metadata - return { + response = { "id": str(request.id), "model": request.hf_name, "text": output_text, @@ -354,9 +371,18 @@ def format_non_streaming_response( "processing_time": processing_time, "finish_reason": "stop", } + + # Add reasoning as separate field if present + if reasoning_text: + response["reasoning"] = reasoning_text + + return response else: - # Raw format - just the text (legacy compatibility) - return {"text": output_text} + # Raw format + response = {"text": output_text} + if reasoning_text: + response["reasoning"] = reasoning_text + return response @staticmethod def format_stream_chunk( @@ -391,7 +417,6 @@ def format_stream_chunk( ], } else: - # Simple streaming format chunk_data = { "id": str(request.id), "model": request.hf_name, @@ -409,6 +434,7 @@ def format_stream_final( prompt_tokens: int, completion_tokens: int, full_text: Optional[str] = None, + reasoning_text: Optional[str] = None, ) -> str: """ Format the final streaming chunk with usage stats. 
@@ -436,6 +462,10 @@ def format_stream_final( "total_tokens": prompt_tokens + completion_tokens, }, } + + if reasoning_text: + final_data["reasoning"] = reasoning_text + return f"data: {json.dumps(final_data)}\n\ndata: [DONE]\n\n" else: # Simple format final chunk @@ -452,6 +482,9 @@ def format_stream_final( if full_text is not None: final_data["full_text"] = full_text + if reasoning_text: + final_data["reasoning"] = reasoning_text + return f"data: {json.dumps(final_data)}\n\ndata: [DONE]\n\n" @staticmethod diff --git a/tensorlink/ml/validator.py b/tensorlink/ml/validator.py index 3f77476..c790b66 100644 --- a/tensorlink/ml/validator.py +++ b/tensorlink/ml/validator.py @@ -1,4 +1,5 @@ import inspect +import re from tensorlink.ml.graphing import ModelParser from tensorlink.ml.worker import DistributedWorker @@ -9,7 +10,7 @@ format_chat_prompt, format_stream_chunk, format_stream_final, - extract_assistant_response, + extract_reasoning_and_answer, ) from tensorlink.ml.utils import load_models_cache, save_models_cache from tensorlink.api.models import GenerationRequest @@ -577,7 +578,13 @@ def _generate(self, request, job_id, start_time): # FORMAT PROMPT if request.input_format == "chat": formatted_prompt = format_chat_prompt( - request.hf_name, request.message, request.history + request.hf_name, + request.message, + request.history, + enable_thinking=( + getattr(request, "enable_thinking", False) + or getattr(request, "reasoning", False) + ), ) else: formatted_prompt = request.message @@ -585,13 +592,13 @@ def _generate(self, request, job_id, start_time): # TOKENIZE # Get model's max length model_max_length = getattr(tokenizer, 'model_max_length', 2048) - if model_max_length > 1000000: + if model_max_length > 100000: model_max_length = 2048 # Tokenize with appropriate max_length max_length = min( getattr(request, 'max_length', 512), - model_max_length - 10, # Leave room for generation + model_max_length - 10, ) inputs = tokenizer( @@ -628,7 +635,6 @@ def 
_generate(self, request, job_id, start_time): # GENERATE with torch.no_grad(): try: - print(f"ARGS: {args}") outputs = distributed_model.generate( input_ids, # max_new_tokens=args["max_new_tokens"], @@ -660,9 +666,13 @@ def _generate(self, request, job_id, start_time): else: text = generated_text.strip() - # Clean for chat models + reasoning_text = None if request.input_format == "chat": - text = extract_assistant_response(text, request.hf_name) + reasoning_text, text = extract_reasoning_and_answer(text) + + # Respect enable_thinking flag + if not getattr(request, "enable_thinking", True): + reasoning_text = None request.output = text @@ -672,6 +682,7 @@ def _generate(self, request, job_id, start_time): request.formatted_response = ResponseFormatter.format_non_streaming_response( request=request, output_text=text, + reasoning_text=reasoning_text, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, start_time=start_time, @@ -693,14 +704,20 @@ def _generate_streaming(self, request: GenerationRequest, job_id: str): # Format input if request.input_format == "chat": formatted_prompt = format_chat_prompt( - request.hf_name, request.message, request.history + request.hf_name, + request.message, + request.history, + enable_thinking=( + getattr(request, "enable_thinking", False) + or getattr(request, "reasoning", False) + ), ) else: formatted_prompt = request.message # Tokenize model_max_length = getattr(tokenizer, 'model_max_length', 2048) - if model_max_length > 1000000: + if model_max_length > 100000: model_max_length = 2048 max_length = min(getattr(request, 'max_length', 512), model_max_length - 10) @@ -737,19 +754,17 @@ def _generate_streaming(self, request: GenerationRequest, job_id: str): return # Build generation kwargs - generation_kwargs = { - "input_ids": input_ids, - "stream": True, - "max_new_tokens": args["max_new_tokens"], - "temperature": args["temperature"], - # "pad_token_id": args["pad_token_id"], - # "eos_token_id": args["eos_token_id"], - 
"do_sample": args["do_sample"], - "num_beams": args["num_beams"], - } - - if "top_p" in args: - generation_kwargs["top_p"] = args["top_p"] + generation_kwargs = {"input_ids": input_ids, "stream": True, **args} + # generation_kwargs = { + # "input_ids": input_ids, + # "stream": True, + # "max_new_tokens": args["max_new_tokens"], + # "temperature": args["temperature"], + # # "pad_token_id": args["pad_token_id"], + # # "eos_token_id": args["eos_token_id"], + # "do_sample": args["do_sample"], + # "num_beams": args["num_beams"], + # } # Setup streamer if isinstance(distributed_model.model, OffloadedModule): @@ -792,9 +807,11 @@ def _generate_streaming(self, request: GenerationRequest, job_id: str): (request.id, {"chunk": formatted_chunk, "done": False}), ) + reasoning_text = None + # Clean output if request.input_format == "chat": - cleaned_text = extract_assistant_response(full_text, request.hf_name) + cleaned_text = extract_reasoning_and_answer(full_text) else: cleaned_text = full_text diff --git a/tensorlink/p2p/torch_node.py b/tensorlink/p2p/torch_node.py index cf6710d..e35921f 100644 --- a/tensorlink/p2p/torch_node.py +++ b/tensorlink/p2p/torch_node.py @@ -888,9 +888,9 @@ def _stop_mpc_comms(self): self._mpc_comms.join() def print_ui_status(self): - used_vram = self.total_gpu_memory - self.available_gpu_memory total_vram = self.total_gpu_memory - actual_vram = get_gpu_memory() + used_vram_est = total_vram - self.available_gpu_memory + used_vram_actual = total_vram - get_gpu_memory() ram = psutil.virtual_memory() used_ram = ram.total - ram.available @@ -934,15 +934,19 @@ def line(label, value, colour=ANSI.CYAN): # --- Resources --- print(sep) - vram_bar_estimate = _bar(total_vram - used_vram, total_vram) - vram_bar = _bar(total_vram - actual_vram, total_vram) + vram_bar_estimate = _bar(used_vram_est, total_vram) + vram_bar_actual = _bar(used_vram_actual, total_vram) ram_bar = _bar(used_ram, ram.total) print( - f"{ANSI.DIM}{'VRAM':<14}:{ANSI.RESET} " + 
f"{ANSI.DIM}{'VRAM EST.':<14}:{ANSI.RESET} " f"{ANSI.MAGENTA}[{vram_bar_estimate}]{ANSI.RESET} " - f"{ANSI.MAGENTA}[{vram_bar}]{ANSI.RESET} " - f"{ANSI.YELLOW}{_fmt_gb(used_vram)} / {_fmt_gb(total_vram)} GB{ANSI.RESET}" + f"{ANSI.YELLOW}{_fmt_gb(used_vram_est)} / {_fmt_gb(total_vram)} GB{ANSI.RESET}" + ) + print( + f"{ANSI.DIM}{'VRAM ACT.':<14}:{ANSI.RESET} " + f"{ANSI.MAGENTA}[{vram_bar_actual}]{ANSI.RESET} " + f"{ANSI.YELLOW}{_fmt_gb(used_vram_actual)} / {_fmt_gb(total_vram)} GB{ANSI.RESET}" ) print( From 697d1ab10c7ff06b8e168505d797abda491c026b Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Wed, 21 Jan 2026 10:43:33 -0500 Subject: [PATCH 03/25] Attempting reasoning toggle for API generate output (pt2) #99. --- tensorlink/ml/validator.py | 78 ++++++++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 28 deletions(-) diff --git a/tensorlink/ml/validator.py b/tensorlink/ml/validator.py index c790b66..2aae32e 100644 --- a/tensorlink/ml/validator.py +++ b/tensorlink/ml/validator.py @@ -581,26 +581,19 @@ def _generate(self, request, job_id, start_time): request.hf_name, request.message, request.history, - enable_thinking=( - getattr(request, "enable_thinking", False) - or getattr(request, "reasoning", False) - ), + enable_thinking=request.reasoning, ) else: formatted_prompt = request.message # TOKENIZE - # Get model's max length model_max_length = getattr(tokenizer, 'model_max_length', 2048) if model_max_length > 100000: model_max_length = 2048 - - # Tokenize with appropriate max_length max_length = min( getattr(request, 'max_length', 512), model_max_length - 10, ) - inputs = tokenizer( formatted_prompt, return_tensors="pt", @@ -707,10 +700,7 @@ def _generate_streaming(self, request: GenerationRequest, job_id: str): request.hf_name, request.message, request.history, - enable_thinking=( - getattr(request, "enable_thinking", False) - or getattr(request, "reasoning", False) - ), + enable_thinking=request.reasoning, ) else: formatted_prompt = 
request.message @@ -790,30 +780,61 @@ def _generate_streaming(self, request: GenerationRequest, job_id: str): # Stream tokens full_text = "" token_count = 0 + in_reasoning_block = False + reasoning_buffer = "" for token_text in streamer: full_text += token_text - token_count += 1 - formatted_chunk = ResponseFormatter.format_stream_chunk( - request=request, - token_text=token_text, - index=token_count, - start_time=start_time, - ) + # Track if we're inside a reasoning block (simple detection) + if request.input_format == "chat" and not request.reasoning: + # Check for start of reasoning tags + if re.search( + r'<\s*(think|reflection|thought|internal|analysis)\s*>', + reasoning_buffer + token_text, + re.IGNORECASE, + ): + in_reasoning_block = True + reasoning_buffer += token_text + continue + + # Check for end of reasoning tags + if in_reasoning_block: + reasoning_buffer += token_text + if re.search( + r'<\s*/\s*(think|reflection|thought|internal|analysis)\s*>', + reasoning_buffer, + re.IGNORECASE, + ): + in_reasoning_block = False + reasoning_buffer = "" + continue + + # Only send non-reasoning tokens when reasoning is disabled + if not in_reasoning_block: + token_count += 1 + formatted_chunk = ResponseFormatter.format_stream_chunk( + request=request, + token_text=token_text, + index=token_count, + start_time=start_time, + ) - self.send_request( - "update_stream", - (request.id, {"chunk": formatted_chunk, "done": False}), - ) + self.send_request( + "update_stream", + (request.id, {"chunk": formatted_chunk, "done": False}), + ) reasoning_text = None + cleaned_text = full_text - # Clean output + # Extract reasoning and clean output if request.input_format == "chat": - cleaned_text = extract_reasoning_and_answer(full_text) - else: - cleaned_text = full_text + reasoning_text, cleaned_text = extract_reasoning_and_answer(full_text) + + # Only include reasoning if explicitly enabled + if not request.reasoning: + reasoning_text = None request.output = cleaned_text @@ -823,7 
+844,8 @@ def _generate_streaming(self, request: GenerationRequest, job_id: str): start_time=start_time, prompt_tokens=prompt_tokens, completion_tokens=token_count, - full_text=cleaned_text if request.output_format != "openai" else None, + full_text=cleaned_text, + reasoning_text=reasoning_text, ) self.send_request( From 908142776a848a91614aefa9b32b683ac3db0eac Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Wed, 21 Jan 2026 12:41:00 -0500 Subject: [PATCH 04/25] Attempting reasoning toggle for API generate output (pt3) #99. --- tensorlink/ml/validator.py | 39 ++++++++------------------------------ 1 file changed, 8 insertions(+), 31 deletions(-) diff --git a/tensorlink/ml/validator.py b/tensorlink/ml/validator.py index 2aae32e..630e507 100644 --- a/tensorlink/ml/validator.py +++ b/tensorlink/ml/validator.py @@ -581,7 +581,7 @@ def _generate(self, request, job_id, start_time): request.hf_name, request.message, request.history, - enable_thinking=request.reasoning, + enable_thinking=request.reasoning, # Use consistent field name ) else: formatted_prompt = request.message @@ -590,10 +590,12 @@ def _generate(self, request, job_id, start_time): model_max_length = getattr(tokenizer, 'model_max_length', 2048) if model_max_length > 100000: model_max_length = 2048 + max_length = min( getattr(request, 'max_length', 512), model_max_length - 10, ) + inputs = tokenizer( formatted_prompt, return_tensors="pt", @@ -615,7 +617,6 @@ def _generate(self, request, job_id, start_time): ) except ValueError as e: - # Prompt is too long request.output = f"Error: {str(e)}" request.formatted_response = ResponseFormatter.format_error_response( error_message=str(e), @@ -628,18 +629,8 @@ def _generate(self, request, job_id, start_time): # GENERATE with torch.no_grad(): try: - outputs = distributed_model.generate( - input_ids, - # max_new_tokens=args["max_new_tokens"], - # temperature=args["temperature"], - # pad_token_id=args["pad_token_id"], - # eos_token_id=args["eos_token_id"], - # 
do_sample=args["do_sample"], - # num_beams=args["num_beams"], - # **({} if "top_p" not in args else {"top_p": args["top_p"]}), - ) + outputs = distributed_model.generate(input_ids, **args) except RuntimeError as e: - # Handle CUDA OOM or other runtime errors error_msg = f"Generation failed: {str(e)}" request.output = error_msg request.formatted_response = ResponseFormatter.format_error_response( @@ -663,8 +654,8 @@ def _generate(self, request, job_id, start_time): if request.input_format == "chat": reasoning_text, text = extract_reasoning_and_answer(text) - # Respect enable_thinking flag - if not getattr(request, "enable_thinking", True): + # Respect reasoning flag - only include reasoning if explicitly enabled + if not request.reasoning: reasoning_text = None request.output = text @@ -684,10 +675,7 @@ def _generate(self, request, job_id, start_time): def _generate_streaming(self, request: GenerationRequest, job_id: str): """ Fetches tokenizer, ensures generate arguments are not problematic with - normalize_generate_args, and calls DistributedModel.generate with stream. If model is - fully loaded on a single worker, the worker will envoke model.generate with stream and send - tokens back to us via the RemoteStreamer. If the model is distributed among multiple workers, - the streaming is done directly on this end. + normalize_generate_args, and calls DistributedModel.generate with stream. 
""" try: start_time = getattr(request, 'start_time', time.time()) @@ -700,7 +688,7 @@ def _generate_streaming(self, request: GenerationRequest, job_id: str): request.hf_name, request.message, request.history, - enable_thinking=request.reasoning, + enable_thinking=request.reasoning, # Use consistent field name ) else: formatted_prompt = request.message @@ -732,7 +720,6 @@ def _generate_streaming(self, request: GenerationRequest, job_id: str): allowed_generate_args=distributed_model._generate_args, ) except ValueError as e: - # Send error to stream error_chunk = ResponseFormatter.format_stream_error( error_message=str(e), error_type="prompt_too_long" ) @@ -745,16 +732,6 @@ def _generate_streaming(self, request: GenerationRequest, job_id: str): # Build generation kwargs generation_kwargs = {"input_ids": input_ids, "stream": True, **args} - # generation_kwargs = { - # "input_ids": input_ids, - # "stream": True, - # "max_new_tokens": args["max_new_tokens"], - # "temperature": args["temperature"], - # # "pad_token_id": args["pad_token_id"], - # # "eos_token_id": args["eos_token_id"], - # "do_sample": args["do_sample"], - # "num_beams": args["num_beams"], - # } # Setup streamer if isinstance(distributed_model.model, OffloadedModule): From 64a8089ddd88b5084a7cbf2608c171b3da73f323 Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Wed, 21 Jan 2026 13:06:17 -0500 Subject: [PATCH 05/25] Attempting reasoning toggle for API generate output (pt4) #99. 
--- tensorlink/api/node.py | 36 +++++++++++++++--------------------- tensorlink/ml/validator.py | 5 +++-- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/tensorlink/api/node.py b/tensorlink/api/node.py index d9f359a..2aeb81f 100644 --- a/tensorlink/api/node.py +++ b/tensorlink/api/node.py @@ -5,6 +5,7 @@ ModelStatusResponse, ChatCompletionRequest, ) +from tensorlink.ml.formatter import ResponseFormatter from fastapi.responses import StreamingResponse from fastapi import FastAPI, HTTPException, APIRouter, Request, Query from collections import defaultdict @@ -378,7 +379,7 @@ async def _generate_stream(self, request, request_id, start_time): # Mark request as streaming request.stream = True - request.start_time = start_time # Make sure start_time is set + request.start_time = start_time # Add to processing queue self.smart_node.endpoint_requests["incoming"].append(request) @@ -389,20 +390,14 @@ async def _generate_stream(self, request, request_id, start_time): # Wait for next token with timeout token_data = await asyncio.wait_for(token_queue.get(), timeout=30.0) + # Check if generation is complete if token_data.get("done"): - # Send final chunk if provided - final_chunk = token_data.get("final_chunk") - if final_chunk: - yield final_chunk - # Also check if there's a token in the done message - elif token_data.get("token"): - yield token_data.get("token") - else: - # Fallback - yield "data: [DONE]\n\n" + # Get the SSE-formatted string (could be final chunk or error) + sse_chunk = token_data.get("token", "data: [DONE]\n\n") + yield sse_chunk break - # Pull fully-formatted SSE string from 'token' + # Get the SSE-formatted chunk string sse_chunk = token_data.get("token") if sse_chunk: yield sse_chunk @@ -411,18 +406,17 @@ async def _generate_stream(self, request, request_id, start_time): continue except asyncio.TimeoutError: - error_chunk = { - "error": { - "message": "Generation timed out", - "type": "timeout_error", - } - } - yield f"data: 
{json.dumps(error_chunk)}\n\n" + error_chunk = ResponseFormatter.format_stream_error( + error_message="Generation timed out", error_type="timeout_error" + ) + yield error_chunk break except Exception as e: - error_chunk = {"error": {"message": str(e), "type": "internal_error"}} - yield f"data: {json.dumps(error_chunk)}\n\n" + error_chunk = ResponseFormatter.format_stream_error( + error_message=str(e), error_type="internal_error" + ) + yield error_chunk finally: # Clean up if request.id in self.streaming_responses: diff --git a/tensorlink/ml/validator.py b/tensorlink/ml/validator.py index 630e507..5955ee5 100644 --- a/tensorlink/ml/validator.py +++ b/tensorlink/ml/validator.py @@ -629,7 +629,7 @@ def _generate(self, request, job_id, start_time): # GENERATE with torch.no_grad(): try: - outputs = distributed_model.generate(input_ids, **args) + outputs = distributed_model.generate(input_ids) except RuntimeError as e: error_msg = f"Generation failed: {str(e)}" request.output = error_msg @@ -731,10 +731,11 @@ def _generate_streaming(self, request: GenerationRequest, job_id: str): return # Build generation kwargs - generation_kwargs = {"input_ids": input_ids, "stream": True, **args} + generation_kwargs = {"input_ids": input_ids, "stream": True} # Setup streamer if isinstance(distributed_model.model, OffloadedModule): + generation_kwargs.update(**args) module_id = distributed_model.model.module_id streamer = RemoteStreamer( poll_fn=lambda: self._poll_remote_token(module_id, tokenizer) From 1d307652a67e920ed23acdbe342a6e034564f635 Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Wed, 21 Jan 2026 14:47:53 -0500 Subject: [PATCH 06/25] Attempting reasoning toggle for API generate output (pt5) #99. 
--- tensorlink/api/models.py | 19 ++++++++------- tensorlink/api/node.py | 36 ++++++++++++---------------- tensorlink/config/models.json | 2 +- tensorlink/ml/validator.py | 45 +++++++++++++++++++---------------- 4 files changed, 51 insertions(+), 51 deletions(-) diff --git a/tensorlink/api/models.py b/tensorlink/api/models.py index 85217ae..a545718 100644 --- a/tensorlink/api/models.py +++ b/tensorlink/api/models.py @@ -17,20 +17,21 @@ class JobRequest(BaseModel): class GenerationRequest(BaseModel): model_config = ConfigDict(protected_namespaces=()) - # Input fields hf_name: str message: str + + # Generation params (all optional) + max_new_tokens: Optional[int] = None + max_length: Optional[int] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + do_sample: Optional[bool] = None + num_beams: Optional[int] = None + reasoning: Optional[bool] = None + prompt: str = None model_type: Optional[str] = "auto" - # Generation parameters - max_length: int = 2048 - max_new_tokens: int = 2048 - temperature: float = 0.7 - do_sample: bool = True - num_beams: int = 1 - reasoning: bool = False - # Chat/history history: Optional[List[dict]] = None diff --git a/tensorlink/api/node.py b/tensorlink/api/node.py index d9f359a..2aeb81f 100644 --- a/tensorlink/api/node.py +++ b/tensorlink/api/node.py @@ -5,6 +5,7 @@ ModelStatusResponse, ChatCompletionRequest, ) +from tensorlink.ml.formatter import ResponseFormatter from fastapi.responses import StreamingResponse from fastapi import FastAPI, HTTPException, APIRouter, Request, Query from collections import defaultdict @@ -378,7 +379,7 @@ async def _generate_stream(self, request, request_id, start_time): # Mark request as streaming request.stream = True - request.start_time = start_time # Make sure start_time is set + request.start_time = start_time # Add to processing queue self.smart_node.endpoint_requests["incoming"].append(request) @@ -389,20 +390,14 @@ async def _generate_stream(self, request, request_id, 
start_time): # Wait for next token with timeout token_data = await asyncio.wait_for(token_queue.get(), timeout=30.0) + # Check if generation is complete if token_data.get("done"): - # Send final chunk if provided - final_chunk = token_data.get("final_chunk") - if final_chunk: - yield final_chunk - # Also check if there's a token in the done message - elif token_data.get("token"): - yield token_data.get("token") - else: - # Fallback - yield "data: [DONE]\n\n" + # Get the SSE-formatted string (could be final chunk or error) + sse_chunk = token_data.get("token", "data: [DONE]\n\n") + yield sse_chunk break - # Pull fully-formatted SSE string from 'token' + # Get the SSE-formatted chunk string sse_chunk = token_data.get("token") if sse_chunk: yield sse_chunk @@ -411,18 +406,17 @@ async def _generate_stream(self, request, request_id, start_time): continue except asyncio.TimeoutError: - error_chunk = { - "error": { - "message": "Generation timed out", - "type": "timeout_error", - } - } - yield f"data: {json.dumps(error_chunk)}\n\n" + error_chunk = ResponseFormatter.format_stream_error( + error_message="Generation timed out", error_type="timeout_error" + ) + yield error_chunk break except Exception as e: - error_chunk = {"error": {"message": str(e), "type": "internal_error"}} - yield f"data: {json.dumps(error_chunk)}\n\n" + error_chunk = ResponseFormatter.format_stream_error( + error_message=str(e), error_type="internal_error" + ) + yield error_chunk finally: # Clean up if request.id in self.streaming_responses: diff --git a/tensorlink/config/models.json b/tensorlink/config/models.json index 83ddb66..6779e49 100644 --- a/tensorlink/config/models.json +++ b/tensorlink/config/models.json @@ -1,5 +1,5 @@ { "DEFAULT_MODELS": [ - "HuggingFaceTB/SmolLM-135M" + "Qwen/Qwen3-8B" ] } diff --git a/tensorlink/ml/validator.py b/tensorlink/ml/validator.py index 630e507..2ef14fd 100644 --- a/tensorlink/ml/validator.py +++ b/tensorlink/ml/validator.py @@ -629,7 +629,7 @@ def 
_generate(self, request, job_id, start_time): # GENERATE with torch.no_grad(): try: - outputs = distributed_model.generate(input_ids, **args) + outputs = distributed_model.generate(input_ids) except RuntimeError as e: error_msg = f"Generation failed: {str(e)}" request.output = error_msg @@ -653,6 +653,7 @@ def _generate(self, request, job_id, start_time): reasoning_text = None if request.input_format == "chat": reasoning_text, text = extract_reasoning_and_answer(text) + print(reasoning_text) # Respect reasoning flag - only include reasoning if explicitly enabled if not request.reasoning: @@ -731,10 +732,11 @@ def _generate_streaming(self, request: GenerationRequest, job_id: str): return # Build generation kwargs - generation_kwargs = {"input_ids": input_ids, "stream": True, **args} + generation_kwargs = {"input_ids": input_ids, "stream": True} # Setup streamer if isinstance(distributed_model.model, OffloadedModule): + generation_kwargs.update(**args) module_id = distributed_model.model.module_id streamer = RemoteStreamer( poll_fn=lambda: self._poll_remote_token(module_id, tokenizer) @@ -759,35 +761,38 @@ def _generate_streaming(self, request: GenerationRequest, job_id: str): token_count = 0 in_reasoning_block = False reasoning_buffer = "" + start_re = re.compile( + r'<\s*(think|reflection|thought|internal|analysis)\s*>', + re.IGNORECASE, + ) + end_re = re.compile( + r'<\s*/\s*(think|reflection|thought|internal|analysis)\s*>', + re.IGNORECASE, + ) for token_text in streamer: full_text += token_text - # Track if we're inside a reasoning block (simple detection) if request.input_format == "chat" and not request.reasoning: - # Check for start of reasoning tags - if re.search( - r'<\s*(think|reflection|thought|internal|analysis)\s*>', - reasoning_buffer + token_text, - re.IGNORECASE, - ): - in_reasoning_block = True - reasoning_buffer += token_text - continue - # Check for end of reasoning tags - if in_reasoning_block: + # ENTER only if we're NOT already inside + if not 
in_reasoning_block: + if start_re.search(token_text): + print(f"ENTERING REASON: {token_text}") + in_reasoning_block = True + reasoning_buffer = token_text + continue + + # EXIT only if we ARE inside + else: reasoning_buffer += token_text - if re.search( - r'<\s*/\s*(think|reflection|thought|internal|analysis)\s*>', - reasoning_buffer, - re.IGNORECASE, - ): + if end_re.search(reasoning_buffer): + print(f"EXITING REASON: {token_text}") in_reasoning_block = False reasoning_buffer = "" continue - # Only send non-reasoning tokens when reasoning is disabled + # Only emit visible tokens when not in reasoning if not in_reasoning_block: token_count += 1 formatted_chunk = ResponseFormatter.format_stream_chunk( From 0d737dc9f0bf76fce403db55be2e85a268b5c726 Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Wed, 21 Jan 2026 21:05:49 -0500 Subject: [PATCH 07/25] Attempting reasoning toggle for API generate output (pt6) #99. Generate input arg normalization and filtering. --- tensorlink/api/models.py | 4 +- tensorlink/ml/formatter.py | 155 ++++++++++------------ tensorlink/ml/module.py | 9 -- tensorlink/ml/validator.py | 254 +++++++++++++++++-------------------- tensorlink/nodes/nodes.py | 2 +- tests/test_model_api.py | 4 +- 6 files changed, 185 insertions(+), 243 deletions(-) diff --git a/tensorlink/api/models.py b/tensorlink/api/models.py index a545718..c3cd021 100644 --- a/tensorlink/api/models.py +++ b/tensorlink/api/models.py @@ -27,7 +27,8 @@ class GenerationRequest(BaseModel): top_p: Optional[float] = None do_sample: Optional[bool] = None num_beams: Optional[int] = None - reasoning: Optional[bool] = None + reasoning: Optional[bool] = False + stream: bool = False prompt: str = None model_type: Optional[str] = "auto" @@ -42,7 +43,6 @@ class GenerationRequest(BaseModel): # Processing metadata processing: bool = False id: int = None - stream: bool = False start_time: float = 0 # Format control diff --git a/tensorlink/ml/formatter.py b/tensorlink/ml/formatter.py index 
fd49f4f..e6f17a3 100644 --- a/tensorlink/ml/formatter.py +++ b/tensorlink/ml/formatter.py @@ -12,16 +12,8 @@ def normalize_generate_args( allowed_generate_args: Optional[set] = None, ) -> Dict[str, Any]: """ - Normalize and validate generation arguments to prevent errors. - - Args: - request: GenerationRequest object - tokenizer: The tokenizer for the model - prompt_tokens: Number of tokens in the prompt (if already computed) - model_max_length: Maximum sequence length the model supports - allowed_generate_args: Generate function to get input args - Returns: - Dictionary of validated generation arguments + Normalize and validate generation arguments without injecting defaults. + Only user-provided, non-None values are included. """ # TOKEN IDs @@ -29,115 +21,96 @@ def normalize_generate_args( eos_token_id = tokenizer.eos_token_id vocab_size = len(tokenizer) - # Fallback if None if pad_token_id is None: pad_token_id = eos_token_id if eos_token_id is not None else 0 if eos_token_id is None: eos_token_id = pad_token_id - # Prevent identical pad and eos (causes generation to stop immediately) if pad_token_id == eos_token_id: if pad_token_id == 0 and vocab_size > 1: eos_token_id = 1 elif pad_token_id > 0: eos_token_id = 0 - # Ensure within vocab bounds eos_token_id = min(eos_token_id, vocab_size - 1) - # MODEL CONSTRAINTS - # Get model's maximum sequence length if model_max_length is None: - model_max_length = getattr(tokenizer, 'model_max_length', 2048) - # Some tokenizers have unrealistic defaults - if model_max_length > 1000000: + model_max_length = getattr(tokenizer, "model_max_length", 2048) + if model_max_length > 1_000_000: model_max_length = 2048 - # MAX_NEW_TOKENS - max_new_tokens = getattr(request, "max_new_tokens", None) - - # Default if not specified - if not max_new_tokens or max_new_tokens < 1: - max_new_tokens = 256 - - # Ensure we have room for generation - if prompt_tokens is not None: - # Calculate available space for new tokens - available_space = 
model_max_length - prompt_tokens - - if available_space < 10: - # Prompt is too long, we need at least some room to generate - raise ValueError( - f"Prompt is too long ({prompt_tokens} tokens). " - f"Model max length is {model_max_length}, leaving only " - f"{available_space} tokens for generation. " - f"Please use a shorter prompt." - ) - - # Cap max_new_tokens to available space - if max_new_tokens > available_space: - original = max_new_tokens - max_new_tokens = available_space - print( - f"Reduced max_new_tokens from {original} to {max_new_tokens} " - f"to fit within model's {model_max_length} token limit " - f"(prompt uses {prompt_tokens} tokens)" - ) - - # Ensure minimum generation length - max_new_tokens = max(max_new_tokens, 1) - - # TEMPERATURE - temperature = getattr(request, "temperature", None) - if temperature is None or temperature <= 0: - temperature = 0.7 + args: Dict[str, Any] = { + "pad_token_id": pad_token_id, + "eos_token_id": eos_token_id, + } - # Clamp to reasonable range - temperature = max(0.01, min(temperature, 2.0)) + # ---------- MAX_NEW_TOKENS ---------- + max_new_tokens = getattr(request, "max_new_tokens", None) - # SAMPLING - do_sample = bool(getattr(request, "do_sample", True)) + if max_new_tokens is not None: + if max_new_tokens < 1: + raise ValueError("max_new_tokens must be >= 1") + + if prompt_tokens is not None: + available_space = model_max_length - prompt_tokens + + if available_space < 10: + raise ValueError( + f"Prompt is too long ({prompt_tokens} tokens). " + f"Model max length is {model_max_length}, leaving only " + f"{available_space} tokens for generation." 
+ ) + + if max_new_tokens > available_space: + original = max_new_tokens + max_new_tokens = available_space + print( + f"Reduced max_new_tokens from {original} to {max_new_tokens} " + f"to fit model limit ({model_max_length})" + ) + + args["max_new_tokens"] = max_new_tokens + + # ---------- DO_SAMPLE ---------- + do_sample = getattr(request, "do_sample", None) + if do_sample is not None: + do_sample = bool(do_sample) + args["do_sample"] = do_sample + else: + do_sample = None - # Force do_sample=False if temperature is very low (greedy decoding) - if temperature < 0.01: - do_sample = False - temperature = 1.0 # Temperature ignored when do_sample=False + # ---------- TEMPERATURE ---------- + temperature = getattr(request, "temperature", None) + if temperature is not None and do_sample: + temperature = max(0.01, min(float(temperature), 2.0)) + args["temperature"] = temperature - # BEAMS + # ---------- NUM_BEAMS ---------- num_beams = getattr(request, "num_beams", None) - if not num_beams or num_beams < 1: - num_beams = 1 + if num_beams is not None: + if num_beams < 1: + raise ValueError("num_beams must be >= 1") + args["num_beams"] = num_beams - # INCOMPATIBLE COMBINATIONS - # Can't use sampling with beam search (in most implementations) - if do_sample and num_beams > 1: + if do_sample and num_beams and num_beams > 1: print( - f"do_sample=True is incompatible with num_beams={num_beams}. 
" - f"Setting num_beams=1" + f"do_sample=True incompatible with num_beams={num_beams}, " + f"forcing num_beams=1" ) - num_beams = 1 + args["num_beams"] = 1 - # TOP_P (if provided) + # ---------- TOP_P ---------- top_p = getattr(request, "top_p", None) - if top_p is not None: - top_p = max(0.0, min(top_p, 1.0)) - - # BUILD ARGS DICT and FILTER BY GENERATE SIGNATURE - args = { - "pad_token_id": pad_token_id, - "eos_token_id": eos_token_id, - "max_new_tokens": max_new_tokens, - "temperature": temperature, - "do_sample": do_sample, - "num_beams": num_beams, - } if top_p is not None and do_sample: - args["top_p"] = top_p + top_p = max(0.0, min(float(top_p), 1.0)) + if top_p < 1.0: + args["top_p"] = top_p - # Filter based on allowed kwargs + # ---------- FILTER ALLOWED ---------- if allowed_generate_args is not None: - if allowed_generate_args is not None: - if allowed_generate_args != None: - args = {k: v for k, v in args.items() if k in allowed_generate_args} + args = {k: v for k, v in args.items() if k in allowed_generate_args} + + # ---------- DROP NONE ---------- + args = {k: v for k, v in args.items() if v is not None} return args diff --git a/tensorlink/ml/module.py b/tensorlink/ml/module.py index 5c6a860..f1379ee 100644 --- a/tensorlink/ml/module.py +++ b/tensorlink/ml/module.py @@ -610,15 +610,6 @@ def distribute_model(self, config=None, model_type: str = "chat"): if self.model_name: self._load_model_skeleton(model_type) - sig = inspect.signature(self.generate) - if any( - p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() - ): - # Accepts **kwargs, just mark as None or empty set - self._generate_args = None - else: - self._generate_args = set(sig.parameters.keys()) - grouped_layers = {} host_modules = {} diff --git a/tensorlink/ml/validator.py b/tensorlink/ml/validator.py index 2ef14fd..31cb8e6 100644 --- a/tensorlink/ml/validator.py +++ b/tensorlink/ml/validator.py @@ -142,6 +142,22 @@ def __next__(self): return token +def 
_post_process_output(request, tokenizer, formatted_prompt, text): + # Remove prompt echo + if text.startswith(formatted_prompt): + text = text[len(formatted_prompt) :].strip() + else: + text = text.strip() + + reasoning_text = None + if request.input_format == "chat": + reasoning_text, text = extract_reasoning_and_answer(text) + if not request.reasoning: + reasoning_text = None + + return reasoning_text, text + + class DistributedValidator(DistributedWorker): def __init__(self, node, trusted=False, endpoint=True): super().__init__(node, trusted) @@ -540,6 +556,56 @@ def check_node(self): # "message": f"Model {model_name} is not loaded", # } + def _prepare_generation(self, request, job_id): + distributed_model = self.models[job_id] + tokenizer = self.tokenizers[request.hf_name] + + # FORMAT PROMPT + if request.input_format == "chat": + formatted_prompt = format_chat_prompt( + request.hf_name, + request.message, + request.history, + enable_thinking=request.reasoning, + ) + else: + formatted_prompt = request.message + + # TOKENIZE + model_max_length = getattr(tokenizer, "model_max_length", 2048) + if model_max_length > 100000: + model_max_length = 2048 + + max_length = getattr(request, "max_length", 512) or 512 + max_length = min(max_length, model_max_length - 10) + + inputs = tokenizer( + formatted_prompt, + return_tensors="pt", + truncation=True, + max_length=max_length, + ) + + input_ids = inputs.input_ids.to(self.device) + prompt_tokens = input_ids.shape[1] + + # NORMALIZE ARGS + args = normalize_generate_args( + request, + tokenizer, + prompt_tokens=prompt_tokens, + model_max_length=model_max_length, + ) + + return { + "distributed_model": distributed_model, + "tokenizer": tokenizer, + "formatted_prompt": formatted_prompt, + "input_ids": input_ids, + "prompt_tokens": prompt_tokens, + "args": args, + } + def _handle_generate_request(self, request: GenerationRequest, job_id: str): """Main entry point for generate requests""" self._record_request(request.hf_name) @@ 
-572,50 +638,8 @@ def _generate(self, request, job_id, start_time): Fetches tokenizer, ensures generate arguments are not problematic with normalize_generate_args, and calls DistributedModel.generate. """ - distributed_model = self.models[job_id] - tokenizer = self.tokenizers[request.hf_name] - - # FORMAT PROMPT - if request.input_format == "chat": - formatted_prompt = format_chat_prompt( - request.hf_name, - request.message, - request.history, - enable_thinking=request.reasoning, # Use consistent field name - ) - else: - formatted_prompt = request.message - - # TOKENIZE - model_max_length = getattr(tokenizer, 'model_max_length', 2048) - if model_max_length > 100000: - model_max_length = 2048 - - max_length = min( - getattr(request, 'max_length', 512), - model_max_length - 10, - ) - - inputs = tokenizer( - formatted_prompt, - return_tensors="pt", - truncation=True, - max_length=max_length, - ) - - prompt_tokens = inputs.input_ids.shape[1] - input_ids = inputs.input_ids.to(self.device) - - # NORMALIZE ARGS WITH PROMPT TOKEN COUNT try: - args = normalize_generate_args( - request, - tokenizer, - prompt_tokens=prompt_tokens, - model_max_length=model_max_length, - allowed_generate_args=distributed_model._generate_args, - ) - + ctx = self._prepare_generation(request, job_id) except ValueError as e: request.output = f"Error: {str(e)}" request.formatted_response = ResponseFormatter.format_error_response( @@ -626,10 +650,16 @@ def _generate(self, request, job_id, start_time): ) return - # GENERATE + distributed_model = ctx["distributed_model"] + tokenizer = ctx["tokenizer"] + formatted_prompt = ctx["formatted_prompt"] + input_ids = ctx["input_ids"] + prompt_tokens = ctx["prompt_tokens"] + args = ctx["args"] + with torch.no_grad(): try: - outputs = distributed_model.generate(input_ids) + outputs = distributed_model.generate(input_ids, **args) except RuntimeError as e: error_msg = f"Generation failed: {str(e)}" request.output = error_msg @@ -641,27 +671,12 @@ def 
_generate(self, request, job_id, start_time): ) return - # DECODE - generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) - - # Remove prompt echo - if generated_text.startswith(formatted_prompt): - text = generated_text[len(formatted_prompt) :].strip() - else: - text = generated_text.strip() - - reasoning_text = None - if request.input_format == "chat": - reasoning_text, text = extract_reasoning_and_answer(text) - print(reasoning_text) - - # Respect reasoning flag - only include reasoning if explicitly enabled - if not request.reasoning: - reasoning_text = None + generated_text = tokenizer.decode(outputs[0]) + reasoning_text, text = _post_process_output( + request, tokenizer, formatted_prompt, generated_text + ) request.output = text - - # COUNT TOKENS & FORMAT RESPONSE completion_tokens = len(tokenizer.encode(text, add_special_tokens=False)) request.formatted_response = ResponseFormatter.format_non_streaming_response( @@ -679,94 +694,64 @@ def _generate_streaming(self, request: GenerationRequest, job_id: str): normalize_generate_args, and calls DistributedModel.generate with stream. 
""" try: - start_time = getattr(request, 'start_time', time.time()) - distributed_model = self.models[job_id] - tokenizer = self.tokenizers[request.hf_name] - - # Format input - if request.input_format == "chat": - formatted_prompt = format_chat_prompt( - request.hf_name, - request.message, - request.history, - enable_thinking=request.reasoning, # Use consistent field name - ) - else: - formatted_prompt = request.message - - # Tokenize - model_max_length = getattr(tokenizer, 'model_max_length', 2048) - if model_max_length > 100000: - model_max_length = 2048 - - max_length = min(getattr(request, 'max_length', 512), model_max_length - 10) - - inputs = tokenizer( - formatted_prompt, - return_tensors="pt", - truncation=True, - max_length=max_length, - ) - - input_ids = inputs.input_ids.to(self.device) - prompt_tokens = input_ids.shape[1] - - # Normalize args - try: - args = normalize_generate_args( - request, - tokenizer, - prompt_tokens=prompt_tokens, - model_max_length=model_max_length, - allowed_generate_args=distributed_model._generate_args, - ) - except ValueError as e: - error_chunk = ResponseFormatter.format_stream_error( - error_message=str(e), error_type="prompt_too_long" - ) - self.send_request( - "update_stream", - (request.id, {"done": True, "final_chunk": error_chunk}), - ) - request.output = f"Error: {str(e)}" - return - - # Build generation kwargs - generation_kwargs = {"input_ids": input_ids, "stream": True} + start_time = getattr(request, "start_time", time.time()) + + ctx = self._prepare_generation(request, job_id) + distributed_model = ctx["distributed_model"] + tokenizer = ctx["tokenizer"] + formatted_prompt = ctx["formatted_prompt"] + input_ids = ctx["input_ids"] + prompt_tokens = ctx["prompt_tokens"] + args = ctx["args"] + + # ---- Build kwargs ---- + generation_kwargs = { + "input_ids": input_ids, + **args, + } - # Setup streamer + # ---- Setup streamer + thread ---- if isinstance(distributed_model.model, OffloadedModule): - 
generation_kwargs.update(**args) + generation_kwargs["stream"] = True + module_id = distributed_model.model.module_id streamer = RemoteStreamer( poll_fn=lambda: self._poll_remote_token(module_id, tokenizer) ) + generation_thread = Thread( - target=distributed_model.generate, kwargs=generation_kwargs + target=distributed_model.generate, + kwargs=generation_kwargs, + daemon=True, ) generation_thread.start() + else: streamer = TextIteratorStreamer( tokenizer, skip_prompt=True, skip_special_tokens=True ) - generation_kwargs.pop("stream") + generation_kwargs["streamer"] = streamer + generation_thread = Thread( - target=distributed_model.generate, kwargs=generation_kwargs + target=distributed_model.generate, + kwargs=generation_kwargs, + daemon=True, ) generation_thread.start() - # Stream tokens + # ---- Stream tokens ---- full_text = "" token_count = 0 in_reasoning_block = False reasoning_buffer = "" + start_re = re.compile( - r'<\s*(think|reflection|thought|internal|analysis)\s*>', + r"<\s*(think|reflection|thought|internal|analysis)\s*>", re.IGNORECASE, ) end_re = re.compile( - r'<\s*/\s*(think|reflection|thought|internal|analysis)\s*>', + r"<\s*/\s*(think|reflection|thought|internal|analysis)\s*>", re.IGNORECASE, ) @@ -775,26 +760,21 @@ def _generate_streaming(self, request: GenerationRequest, job_id: str): if request.input_format == "chat" and not request.reasoning: - # ENTER only if we're NOT already inside if not in_reasoning_block: if start_re.search(token_text): - print(f"ENTERING REASON: {token_text}") in_reasoning_block = True reasoning_buffer = token_text continue - - # EXIT only if we ARE inside else: reasoning_buffer += token_text if end_re.search(reasoning_buffer): - print(f"EXITING REASON: {token_text}") in_reasoning_block = False reasoning_buffer = "" continue - # Only emit visible tokens when not in reasoning if not in_reasoning_block: token_count += 1 + formatted_chunk = ResponseFormatter.format_stream_chunk( request=request, token_text=token_text, @@ 
-807,20 +787,17 @@ def _generate_streaming(self, request: GenerationRequest, job_id: str): (request.id, {"chunk": formatted_chunk, "done": False}), ) + # ---- Finalize ---- reasoning_text = None cleaned_text = full_text - # Extract reasoning and clean output if request.input_format == "chat": reasoning_text, cleaned_text = extract_reasoning_and_answer(full_text) - - # Only include reasoning if explicitly enabled if not request.reasoning: reasoning_text = None request.output = cleaned_text - # Send final chunk final_chunk = ResponseFormatter.format_stream_final( request=request, start_time=start_time, @@ -837,12 +814,15 @@ def _generate_streaming(self, request: GenerationRequest, job_id: str): except Exception as e: error_chunk = ResponseFormatter.format_stream_error( - error_message=str(e), error_type="generation_error" + error_message=str(e), + error_type="generation_error", ) + self.send_request( "update_stream", (request.id, {"done": True, "final_chunk": error_chunk}), ) + request.output = f"Error during generation: {str(e)}" def _poll_remote_token(self, module_id: str, tokenizer): diff --git a/tensorlink/nodes/nodes.py b/tensorlink/nodes/nodes.py index 948948b..9e344ee 100644 --- a/tensorlink/nodes/nodes.py +++ b/tensorlink/nodes/nodes.py @@ -17,7 +17,7 @@ class BaseNodeConfig: upnp: bool = True max_connections: int = 0 - on_chain: bool = False + on_chain: bool = True local_test: bool = False print_level: int = logging.INFO priority_nodes: Optional[List[List[str]]] = None diff --git a/tests/test_model_api.py b/tests/test_model_api.py index 10340d1..53f2a96 100644 --- a/tests/test_model_api.py +++ b/tests/test_model_api.py @@ -140,9 +140,7 @@ def test_generate_openai(model_env): generate_payload = { "hf_name": cfg["name"], "message": "Hi there, tell me something interesting.", - "max_new_tokens": 10, - "do_sample": True, - "num_beams": 2, + # "max_new_tokens": 50, "output_format": "openai", } From f039ddea132e53043dd47914084616cde36f43ae Mon Sep 17 00:00:00 2001 
From: mattjhawken Date: Thu, 22 Jan 2026 11:20:12 -0500 Subject: [PATCH 08/25] Added tokenizer-level reasoning toggle for certain HF models #99 --- tensorlink/ml/formatter.py | 56 +++++++++++++-- tensorlink/ml/validator.py | 137 +++++++++++++++---------------------- tests/test_model_api.py | 16 +++-- 3 files changed, 119 insertions(+), 90 deletions(-) diff --git a/tensorlink/ml/formatter.py b/tensorlink/ml/formatter.py index e6f17a3..60c2dfe 100644 --- a/tensorlink/ml/formatter.py +++ b/tensorlink/ml/formatter.py @@ -158,10 +158,12 @@ def _collect(match): return reasoning, cleaned or "[No output produced]" -def format_chat_prompt(model_name, current_message, history, enable_thinking=True): +def format_chat_prompt_manual( + model_name, current_message, history, enable_thinking=True +): """ - Format the chat history and current message into a prompt suitable for - the specified model. + Manually format the chat history and current message into a prompt. + This is the fallback for models without native reasoning support. Args: model_name: Name of the model @@ -177,7 +179,7 @@ def format_chat_prompt(model_name, current_message, history, enable_thinking=Tru # Modify system prompt to discourage thinking if disabled if not enable_thinking: - system_prompt += " Provide concise, direct answers without showing your reasoning process." + system_prompt += " Provide concise, direct answers without showing your reasoning/thinking process." formatted_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n" @@ -233,6 +235,52 @@ def format_chat_prompt(model_name, current_message, history, enable_thinking=Tru return formatted_prompt +def format_chat_prompt( + model_name, current_message, history, enable_thinking=True, tokenizer=None +): + """ + Format the chat history and current message into a prompt. + Uses tokenizer's apply_chat_template if it supports enable_thinking, + otherwise falls back to manual formatting. 
+ + Args: + model_name: Name of the model + current_message: Current user message + history: Conversation history + enable_thinking: Whether to allow reasoning/thinking tokens + tokenizer: Optional tokenizer instance (if None, uses manual formatting) + + Returns: + tuple: (formatted_prompt, reasoning_supported) + """ + supports_reasoning = getattr(tokenizer, "supports_reasoning", False) + # Check if tokenizer supports native reasoning + if tokenizer and supports_reasoning: + # Build messages list + messages = [] + if history and len(history) > 0: + messages.extend(history) + messages.append({"role": "user", "content": current_message}) + + # Use tokenizer's native reasoning support + formatted_prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=enable_thinking, + ) + + return formatted_prompt, True + + else: + # Fall back to manual formatting + formatted_prompt = format_chat_prompt_manual( + model_name, current_message, history, enable_thinking=enable_thinking + ) + + return formatted_prompt, False + + def format_stream_final(request, start_time, prompt_tokens, token_count): if request.output_format == "openai": return { diff --git a/tensorlink/ml/validator.py b/tensorlink/ml/validator.py index 31cb8e6..8fcc511 100644 --- a/tensorlink/ml/validator.py +++ b/tensorlink/ml/validator.py @@ -8,8 +8,6 @@ ResponseFormatter, normalize_generate_args, format_chat_prompt, - format_stream_chunk, - format_stream_final, extract_reasoning_and_answer, ) from tensorlink.ml.utils import load_models_cache, save_models_cache @@ -38,79 +36,6 @@ DEFAULT_MODELS = MODELS["DEFAULT_MODELS"] -def _format_response( - request, - clean_output: str, - raw_output: str, - processing_time: float, -): - """ - Format the response based on the requested format type. - This runs in the validator process after generation completes. 
- - Args: - request: The original generation request - clean_output: Cleaned/extracted output text - raw_output: Raw model output - processing_time: Time taken to process the request - - Returns: - Dictionary formatted according to output_format - """ - timestamp = int(time.time()) - request_id = getattr(request, 'id') - - if request.output_format == "simple": - # Minimal response - just the text (no cleaning for simple) - return {"response": raw_output} - - elif request.output_format == "openai": - # OpenAI-compatible format (always cleaned) - return { - "id": request_id, - "object": "chat.completion", - "created": timestamp, - "model": request.hf_name, - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": clean_output}, - "finish_reason": "stop", - } - ], - "usage": { - "prompt_tokens": getattr(request, 'prompt_tokens', -1), - "completion_tokens": getattr(request, 'completion_tokens', -1), - "total_tokens": getattr(request, 'total_tokens', -1), - }, - } - - else: # "full" format (default, comprehensive response with all metadata) - # For full format, don't clean unless it's openai-style request - output_text = raw_output - return { - "id": request_id, - "model": request.hf_name, - "response": output_text, - "raw_output": raw_output, - "created": timestamp, - "processing_time": round(processing_time, 3), - "generation_params": { - "max_length": request.max_length, - "max_new_tokens": request.max_new_tokens, - "temperature": request.temperature, - "do_sample": request.do_sample, - "num_beams": request.num_beams, - }, - "metadata": { - "has_history": bool(request.history), - "history_length": len(request.history) if request.history else 0, - "prompt_used": request.prompt is not None, - "formatted_as_chat": request.output_format == "openai", - }, - } - - class RemoteStreamer: def __init__(self, poll_fn, sleep=0.01): self.poll_fn = poll_fn @@ -142,6 +67,29 @@ def __next__(self): return token +def _supports_reasoning(tokenizer): + """ + Check 
if a tokenizer supports reasoning mode (enable_thinking parameter). + + Args: + tokenizer: HuggingFace tokenizer instance + + Returns: + bool: True if tokenizer supports enable_thinking parameter + """ + if not hasattr(tokenizer, 'apply_chat_template'): + return False + + try: + # Get the signature of apply_chat_template + sig = inspect.signature(tokenizer.apply_chat_template) + + # Check if 'enable_thinking' is a parameter + return 'enable_thinking' in sig.parameters + except Exception: + return False + + def _post_process_output(request, tokenizer, formatted_prompt, text): # Remove prompt echo if text.startswith(formatted_prompt): @@ -150,10 +98,18 @@ def _post_process_output(request, tokenizer, formatted_prompt, text): text = text.strip() reasoning_text = None + + # Only extract reasoning if chat format AND reasoning is supported/enabled if request.input_format == "chat": - reasoning_text, text = extract_reasoning_and_answer(text) - if not request.reasoning: - reasoning_text = None + reasoning_supported = getattr(request, '_reasoning_supported', False) + + # Extract reasoning blocks if the model/tokenizer supports it + if reasoning_supported: + reasoning_text, text = extract_reasoning_and_answer(text) + + # Only include reasoning in response if explicitly requested + if not request.reasoning: + reasoning_text = None return reasoning_text, text @@ -562,14 +518,26 @@ def _prepare_generation(self, request, job_id): # FORMAT PROMPT if request.input_format == "chat": - formatted_prompt = format_chat_prompt( + formatted_prompt, reasoning_supported = format_chat_prompt( request.hf_name, request.message, request.history, enable_thinking=request.reasoning, + tokenizer=tokenizer, ) + + # Track whether reasoning is actually supported + request._reasoning_supported = reasoning_supported + + # Log if reasoning was requested but not supported + if request.reasoning and not reasoning_supported: + print( + f"Note: Reasoning requested for {request.hf_name} but tokenizer " + 
f"doesn't support enable_thinking. Using manual prompt formatting." + ) else: formatted_prompt = request.message + request._reasoning_supported = False # TOKENIZE model_max_length = getattr(tokenizer, "model_max_length", 2048) @@ -603,6 +571,7 @@ def _prepare_generation(self, request, job_id): "formatted_prompt": formatted_prompt, "input_ids": input_ids, "prompt_tokens": prompt_tokens, + "reasoning_supported": getattr(request, '_reasoning_supported', False), "args": args, } @@ -762,12 +731,14 @@ def _generate_streaming(self, request: GenerationRequest, job_id: str): if not in_reasoning_block: if start_re.search(token_text): + print(f"ENTERING REASONING: {token_text}") in_reasoning_block = True reasoning_buffer = token_text continue else: reasoning_buffer += token_text if end_re.search(reasoning_buffer): + print(f"EXITING REASONING: {token_text}") in_reasoning_block = False reasoning_buffer = "" continue @@ -790,9 +761,11 @@ def _generate_streaming(self, request: GenerationRequest, job_id: str): # ---- Finalize ---- reasoning_text = None cleaned_text = full_text - + print(f"FINAL_TEXT: {cleaned_text}") if request.input_format == "chat": reasoning_text, cleaned_text = extract_reasoning_and_answer(full_text) + print(f"REASON_TEXT: {cleaned_text}") + print(f"FINAL_TEXT: {cleaned_text}") if not request.reasoning: reasoning_text = None @@ -967,7 +940,9 @@ def _finalize_hosted_job(self, job_id: str): # Load tokenizer if model_name not in self.tokenizers: - self.tokenizers[model_name] = AutoTokenizer.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name) + setattr(tokenizer, "supports_reasoning", _supports_reasoning(tokenizer)) + self.tokenizers[model_name] = tokenizer setattr(distributed_model, 'tokenizer', self.tokenizers[model_name]) diff --git a/tests/test_model_api.py b/tests/test_model_api.py index 53f2a96..4299a20 100644 --- a/tests/test_model_api.py +++ b/tests/test_model_api.py @@ -32,13 +32,22 @@ ), pytest.param( { - "name": 
"HuggingFaceTB/SmolLM-135M", + "name": "HuggingFaceTB/SmolLM2-135M", "timeout": 120, "sleep": 10, "parsed": True, }, - id="smollm-135m", + id="smollm2-135m", ), + # pytest.param( + # { + # "name": "BabyLM-community/babylm-baseline-100m-gpt2", + # "timeout": 120, + # "sleep": 10, + # "parsed": True, + # }, + # id="babylm-baseline-100m-gpt2", + # ), ] @@ -81,7 +90,6 @@ def test_generate_simple(model_env): "hf_name": cfg["name"], "message": "Hi.", "max_new_tokens": 10, - "do_sample": True, "num_beams": 2, "output_format": "simple", # Explicitly set to simple } @@ -207,7 +215,6 @@ def test_streaming_generation_openai(model_env): "message": "Hi there, tell me something interesting.", "max_new_tokens": 10, "stream": True, - "do_sample": False, "num_beams": 1, "output_format": "openai", } @@ -273,7 +280,6 @@ def test_streaming_generation_simple(model_env): "message": "Hi there, tell me something interesting.", "max_new_tokens": 10, "stream": True, - "do_sample": False, "num_beams": 1, "output_format": "simple", } From 0169a92638864f0bb4992b2a6ef69c551a3a549e Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Sat, 24 Jan 2026 23:48:55 -0500 Subject: [PATCH 09/25] Improved validator module hosting configurations, still need to update node config. Refining tests for model parsing capabilities. 
--- docs/examples/{public_api.py => api.py} | 0 tensorlink/api/node.py | 17 +- tensorlink/ml/graphing.py | 252 ++++++++++++++---------- tensorlink/ml/module.py | 9 +- tensorlink/ml/validator.py | 244 ++++++++++++++++------- tensorlink/ml/worker.py | 6 +- tensorlink/nodes/job_monitor.py | 2 +- tensorlink/p2p/torch_node.py | 128 ++++++------ tests/test_model_api.py | 2 +- tests/test_model_parser.py | 235 ++++++++++++++++++++++ 10 files changed, 641 insertions(+), 254 deletions(-) rename docs/examples/{public_api.py => api.py} (100%) create mode 100644 tests/test_model_parser.py diff --git a/docs/examples/public_api.py b/docs/examples/api.py similarity index 100% rename from docs/examples/public_api.py rename to docs/examples/api.py diff --git a/tensorlink/api/node.py b/tensorlink/api/node.py index 2aeb81f..892c616 100644 --- a/tensorlink/api/node.py +++ b/tensorlink/api/node.py @@ -448,12 +448,17 @@ def _check_model_status(self, model_name: str) -> dict: try: # Check if there is a public job with this module - for module_id, module in self.smart_node.modules.items(): - if module.get("model_name", "") == model_name: - if module.get("public", False): - status = "loaded" - message = f"Model {model_name} is loaded and ready" - break + for job_id in self.smart_node.jobs: + job_data = self.smart_node.dht.query(job_id) + if ( + job_data.get("model_name", "") == model_name + and job_data.get("hosted") + and job_data.get("api") + and job_data.get("public") + ): + status = "loaded" + message = f"Model {model_name} is loaded and ready" + break except Exception as e: logging.error(f"Error checking model status: {e}") diff --git a/tensorlink/ml/graphing.py b/tensorlink/ml/graphing.py index 8148cc2..3c6970f 100644 --- a/tensorlink/ml/graphing.py +++ b/tensorlink/ml/graphing.py @@ -1,6 +1,5 @@ from tensorlink.ml.utils import estimate_memory from tensorlink.ml.injector import find_loop_in_module_hierarchy - from transformers import AutoModel, AutoConfig from accelerate import 
init_empty_weights from collections import defaultdict @@ -55,9 +54,8 @@ def _create_grouped_entry(parent_path: str, group: list) -> dict: "num_layers": len(group), } - # Preserve parent_forward_code if present - if "parent_forward_code" in configs[0]: - grouped_config["parent_forward_code"] = configs[0]["parent_forward_code"] + # Preserve parent_module_path if present (but not parent_forward_code) + if "parent_module_path" in configs[0]: grouped_config["parent_module_path"] = configs[0]["parent_module_path"] return {grouped_path: grouped_config} @@ -68,12 +66,12 @@ def _group_sequential_layers(config: dict) -> dict: Group consecutive layers assigned to the same worker into single entries. For example: - model.layers.0 -> worker1 - model.layers.1 -> worker1 - model.layers.2 -> worker1 + model.layers.0 -> worker1 + model.layers.1 -> worker1 + model.layers.2 -> worker1 Becomes: - model.layers.0-2 -> worker1 + model.layers.0-2 -> worker1 """ # Group paths by their parent and extract layer patterns layer_groups = defaultdict(list) @@ -133,7 +131,6 @@ def _group_sequential_layers(config: dict) -> dict: def _is_loop_iterable_module(module: nn.Module, module_path: str) -> bool: """ Detect if this module is iterated over in a loop during the forward pass. - Uses the same loop detection logic as the injector. 
Returns: @@ -144,15 +141,12 @@ def _is_loop_iterable_module(module: nn.Module, module_path: str) -> bool: module_with_loop, loop_node, path = find_loop_in_module_hierarchy( module, max_depth=1 # Only check this level ) - # If we found a loop at this level, it's loop-iterable return True - except ValueError: # No loop found, check if modulelist if isinstance(module, nn.ModuleList): return True - return False @@ -161,7 +155,7 @@ def __init__(self, user_memory: int = 0, verbose=False): self.user_memory = user_memory self.model_name = "" self.assigned_workers = defaultdict(list) - self.forward_code_cache = {} + self.assigned_memory = 0 self.verbose = verbose self.module_paths = {} # Track all module paths @@ -171,12 +165,10 @@ def create_distributed_config( workers: dict, training: bool, trusted: bool, - handle_layers: bool = True, input_obfuscation: bool = False, optimizer_type: str = "adam", optimizer_spec: dict = {}, - host_load_small: bool = False, - host_threshold_mb: int = 50, + host_max_memory_bytes: int = 0, host_max_depth: int = 2, max_offload_depth: int = 3, max_seq_len: int = 4096, @@ -184,8 +176,119 @@ def create_distributed_config( model_type: str = "chat", ): """ - Creates a distributed configuration for a model, determining how it should be allocated across nodes. + Build a distributed execution configuration for a model by assigning its + submodules across available workers and optionally the local host. + + This method recursively walks the model graph, estimates memory usage for + each submodule (parameters, optimizer state, activations, and KV cache), + and determines whether the module should be: + - kept on the local host, + - fully offloaded to a remote worker, + - split into children and recursively assigned, or + - marked as unassignable. + + The result is a config dictionary describing how the model should be + partitioned for distributed inference or training. 
+ + Parameters + ---------- + model : Union[nn.Module, str] + Either a PyTorch model instance or a HuggingFace model name. If a string + is provided, the model is instantiated with empty weights using + `AutoConfig` to avoid loading parameters into memory. + + workers : dict + Mapping of worker_id -> worker metadata. Each worker entry must contain + at least: + { + "gpu_memory": + } + This memory is decremented as modules are assigned. + + training : bool + Whether the configuration is for training or inference. Training mode + increases memory estimates to include gradients, optimizer state, and + activation storage. + + trusted : bool + Indicates whether workers are trusted. Used for downstream logic such as + security policies, encryption, or obfuscation decisions. + + input_obfuscation : bool, optional (default=False) + Whether inputs should be obfuscated when sent to workers. This flag is + propagated into the distributed config for runtime enforcement. + + optimizer_type : str, optional (default="adam") + Optimizer type used for memory estimation (e.g. "adam", "sgd"). This + affects optimizer state size during training. + + optimizer_spec : dict, optional + Extra optimizer configuration to attach to each assigned module + (e.g. learning rate, betas, weight decay). Stored in the config and + passed to workers. + + host_max_memory_bytes : int, optional (default=0) + Maximum number of bytes the local host is allowed to consume for loading + small submodules. If 0, the host will not keep modules locally. + + host_max_depth : int, optional (default=2) + Maximum recursion depth at which the host is allowed to keep modules. + Prevents deep layers from being pinned locally. + + max_offload_depth : int, optional (default=3) + Maximum recursion depth for offloading. If exceeded, the module is marked + as unassignable and an AssignmentError may be raised in verbose mode. 
+ + max_seq_len : int, optional (default=4096) + Maximum sequence length used for estimating activation and KV cache + memory during inference or training. + + batch_size : int, optional (default=1) + Batch size used for memory estimation of activations and optimizer state. + + model_type : str, optional (default="chat") + Logical model type (e.g. "chat", "vision", "embedding"). Stored in the + config and used by downstream execution logic. + + Returns + ------- + dict + A dictionary with the following keys: + + - success : bool + Whether assignment completed successfully. + + - config : dict + Mapping of module_path -> assignment spec, where each entry may be: + { + "type": "loaded" | "offloaded" | "unassigned", + "device": "host" (if loaded), + "assigned_workers": [worker_id] (if offloaded), + "module_id": list, + "memory": bytes, + "module": str, + "module_path": str, + "training": bool, + "optimizer_spec": dict, + "batch_size": int, + "model_type": str, + "parent_module_path": str (optional, for pipelining) + } + + - model_memory : int + Total estimated memory footprint of the model under the provided + parameters (including activations and KV cache). + + Notes + ----- + - Modules that are too large or loop-iterable are recursively split into + children until they can be assigned. + - Sequential layers may later be grouped for pipeline parallelism via + `_group_sequential_layers`. + - Worker memory is decremented as assignments occur to prevent overcommit. + - If assignment fails, `success=False` is returned and config may be partial. 
""" + if isinstance(model, str): self.model_name = model model_config = AutoConfig.from_pretrained(model) @@ -199,6 +302,7 @@ def create_distributed_config( config = {} success = True + model_memory, breakdown = estimate_memory( model, training=training, @@ -210,21 +314,6 @@ def create_distributed_config( include_kv_cache=True, ) - # # Log the model structure first - # if self.verbose: - # print("\n" + "=" * 80) - # print("MODEL STRUCTURE:") - # print("=" * 80) - # self._log_model_structure( - # model, - # prefix="model", - # training=training, - # optimizer_type=optimizer_type, - # max_seq_len=max_seq_len, - # batch_size=batch_size, - # ) - # print("=" * 80 + "\n") - try: config, _ = self._recurse_module( module=model, @@ -232,13 +321,11 @@ def create_distributed_config( workers_state=workers_state, training=training, trusted=trusted, - handle_layers=handle_layers, input_obfuscation=input_obfuscation, last_worker=None, optimizer_type=optimizer_type, optimizer_spec=optimizer_spec, - host_load_small=host_load_small, - host_threshold_mb=host_threshold_mb, + host_max_memory_bytes=host_max_memory_bytes, host_max_depth=host_max_depth, max_offload_depth=max_offload_depth, max_seq_len=max_seq_len, @@ -257,54 +344,6 @@ def create_distributed_config( return {"success": success, "config": config, "model_memory": model_memory} - def _log_model_structure( - self, - module: nn.Module, - prefix: str = "model", - depth: int = 0, - training=False, - optimizer_type=None, - max_seq_len: int = 2048, - batch_size: int = 1, - ): - """ - Recursively log the entire model structure with module paths. 
- - Args: - module: The module to log - prefix: Current path prefix - depth: Current depth in the hierarchy - """ - indent = " " * depth - module_type = type(module).__name__ - - memory, breakdown = estimate_memory( - module, - training=training, - seq_length=max_seq_len, - optimizer_type=optimizer_type, - batch_size=batch_size, - recursive=True, - count_activations=True, - include_kv_cache=(depth == 0), - ) - - print(f"{indent}{prefix} [{module_type}] (~{memory/1e6:.1f}MB)") - - # Store in module_paths dict - self.module_paths[prefix] = {'type': module_type, 'memory_mb': memory / 1e6} - - # Recurse into children - for child_name, child_module in module.named_children(): - child_path = f"{prefix}.{child_name}" - self._log_model_structure( - child_module, - child_path, - depth + 1, - max_seq_len=max_seq_len, - batch_size=batch_size, - ) - def _recurse_module( self, module: nn.Module, @@ -312,15 +351,13 @@ def _recurse_module( workers_state: dict, training: bool, trusted: bool, - handle_layers: bool, input_obfuscation: bool, last_worker: Optional[str] = None, depth: int = 0, ids: list = None, optimizer_type="adam", optimizer_spec=None, - host_load_small: bool = False, - host_threshold_mb: int = 50, + host_max_memory_bytes: int = 0, host_max_depth: int = 1, max_offload_depth: int = 3, max_seq_len: int = 2048, @@ -341,7 +378,7 @@ def _recurse_module( if self.verbose: print(f"{indent}Processing: {module_path}") - # sum children memory + # Get memory of current module memory, breakdown = estimate_memory( module, training=training, @@ -357,13 +394,13 @@ def _recurse_module( memory -= breakdown.get("activations", 0) if self.verbose: - print(f"{indent} Memory required: {memory / 1e6:.2f}MB") + print(f"{indent} Memory required: {memory / 1e6:.2f}MB") # Local host small module logic if ( not is_root - and host_load_small - and (memory / 1e6) <= host_threshold_mb + and host_max_memory_bytes + and memory <= host_max_memory_bytes - self.assigned_memory and depth <= 
host_max_depth ): config[module_path] = { @@ -383,9 +420,8 @@ def _recurse_module( "batch_size": batch_size, "model_type": model_type, } - if self.verbose: - print(f"{indent} Kept on host (local) — {memory / 1e6:.2f}MB") + print(f"{indent} Kept on host (local) — {memory / 1e6:.2f}MB") return config, None # Check if module is loop-iterable BEFORE trying to assign @@ -393,7 +429,7 @@ def _recurse_module( if is_loop_iterable and depth > 0: if self.verbose: - print(f"{indent} Module is loop-iterable, will recurse into children") + print(f"{indent} Module is loop-iterable, will recurse into children") # Don't try to assign, just skip to recursion assigned_worker = None else: @@ -431,7 +467,7 @@ def _recurse_module( ) if self.verbose: - print(f"{indent} Assigned to {assigned_worker}") + print(f"{indent} Assigned to {assigned_worker}") return config, assigned_worker @@ -452,10 +488,11 @@ def _recurse_module( if self.verbose: reason = "is loop-iterable" if is_loop_iterable else "too large" print( - f"{indent} Module {module_path} ({memory / 1e6:.2f}MB) {reason}, recursing into children..." + f"{indent} Module {module_path} ({memory / 1e6:.2f}MB) {reason}, recursing into children..." 
) children = list(module.named_children()) + if not children: config[module_path] = { "type": "unassigned", @@ -463,13 +500,11 @@ def _recurse_module( "module_path": module_path, } if self.verbose: - print(f"{indent} No children to recurse into - FAILED") - + print(f"{indent} No children to recurse into - FAILED") raise AssignmentError( f"Unable to assign {module_path}: no children to distribute" ) - parent_forward_code = self._extract_forward_code(module) child_workers = set() prev_child_worker = last_worker last_successful_worker = last_worker @@ -485,13 +520,11 @@ def _recurse_module( training=training, trusted=trusted, last_worker=prev_child_worker, - handle_layers=handle_layers, input_obfuscation=input_obfuscation, depth=depth + 1, optimizer_type=optimizer_type, optimizer_spec=optimizer_spec, - host_load_small=host_load_small, - host_threshold_mb=host_threshold_mb, + host_max_memory_bytes=host_max_memory_bytes, host_max_depth=host_max_depth, max_offload_depth=max_offload_depth, max_seq_len=max_seq_len, @@ -500,6 +533,7 @@ def _recurse_module( ) config.update(child_config) + if child_last_worker: prev_child_worker = child_last_worker last_successful_worker = child_last_worker @@ -507,14 +541,18 @@ def _recurse_module( except AssignmentError as e: if self.verbose: - print(f"{indent} Child {child_path} failed: {e}") + print(f"{indent} Child {child_path} failed: {e}") raise - if len(child_workers) > 1 and parent_forward_code: - for child_path, child_cfg in config.items(): - if child_cfg.get("assigned_workers", [None])[0] in child_workers: - child_cfg["parent_forward_code"] = parent_forward_code - child_cfg["parent_module_path"] = module_path + # Add parent_module_path when children span multiple workers + if len(child_workers) > 1: + # Get the children that were just processed (belong to this parent) + for child_name, _ in children: + child_path = f"{module_path}.{child_name}" + if child_path in config: + child_cfg = config[child_path] + if child_cfg.get("type") 
== "offloaded": + child_cfg["parent_module_path"] = module_path return config, last_successful_worker diff --git a/tensorlink/ml/module.py b/tensorlink/ml/module.py index f1379ee..e5fa6cc 100644 --- a/tensorlink/ml/module.py +++ b/tensorlink/ml/module.py @@ -250,9 +250,10 @@ def __init__( def forward(self, *args, **kwargs): """ - Performs the forward pass through the model. - - Splits input into micro-batches and runs them in parallel. - - Creates multiple parallel streams of workers for model parallel acceleration + Performs the forward pass through the distributed model, sending intermediate + forward tensors to downstream workers. + - optional: splitting inputs into micro-batches and running them in parallel. + - optional: multiple parallel streams of workers for model parallel acceleration """ if not args and "input_ids" in kwargs: args = kwargs.pop("input_ids") @@ -1208,7 +1209,7 @@ def forward(self, *args, **kwargs): # Relay forward pass to next roles self.parent_model.send_request( - "send_forward", (self.worker_id, size, shm_name, tag) + "send_forward", (self.worker_id, self.module_id, size, shm_name, tag) ) # Wait for response, change to appending waiting thread to list in master diff --git a/tensorlink/ml/validator.py b/tensorlink/ml/validator.py index 8fcc511..a46981a 100644 --- a/tensorlink/ml/validator.py +++ b/tensorlink/ml/validator.py @@ -1,6 +1,3 @@ -import inspect -import re - from tensorlink.ml.graphing import ModelParser from tensorlink.ml.worker import DistributedWorker from tensorlink.ml.module import DistributedModel, OffloadedModule @@ -10,16 +7,18 @@ format_chat_prompt, extract_reasoning_and_answer, ) -from tensorlink.ml.utils import load_models_cache, save_models_cache +from tensorlink.ml.utils import load_models_cache, save_models_cache, get_gpu_memory from tensorlink.api.models import GenerationRequest from transformers import AutoTokenizer, TextIteratorStreamer from collections import defaultdict -from threading import Thread +from 
threading import Thread, Lock import torch import logging +import inspect import json import time +import re import gc import os @@ -127,14 +126,21 @@ def __init__(self, node, trusted=False, endpoint=True): self.tokenizers = {} - # Track models that are in the process of being initialized (job_id) - self.models_initializing = set() + # Track models that are in the process of being initialized + self.models_initializing = set() # job_id # Configuration self.TRACKING_DAYS = 7 # Track requests for past 7 days self.MIN_REQUESTS_THRESHOLD = 10 # Minimum requests to consider auto-loading self.MAX_AUTO_MODELS = 10 # Maximum models to auto-load + # Track reserved host memory during initialization + self.host_memory_reserved = 0 + self.initializing_reservations = {} # job_id -> reserved_memory + + # Lock for thread-safe memory operations + self.memory_lock = Lock() + def _ensure_model_entry(self, model_name: str): """Ensure a model has an entry in the cache with proper structure""" if model_name not in self.model_cache: @@ -330,87 +336,107 @@ def _manage_auto_loaded_models(self): if self.models_initializing: self._try_finalize_initializing_models() - def inspect_model(self, model_name: str, job_data: dict, hosted=False) -> dict: + def inspect_model( + self, model_name: str, job_data: dict, hosted: bool = False + ) -> dict: """Inspect a model to determine network requirements and store distribution in JSON cache""" - parser = ModelParser() - model_name: str = job_data.get("model_name", model_name) + try: + parser = ModelParser() + model_name: str = job_data.get("model_name", model_name) - # Get network worker information to assign modules - workers = self.send_request("get_workers", None) + # Get network worker information to assign modules + workers = self.send_request("get_workers", None) - batch_size = job_data.get("batch_size", None) + batch_size = job_data.get("batch_size", None) - if batch_size is None: - if job_data.get("training", False): - batch_size = 256 + if 
batch_size is None: + if job_data.get("training", False): + batch_size = 256 + else: + batch_size = 1 + + if job_data.get("optimizer") is None: + optimizer_type = "adam" + optimizer_spec = {} else: - batch_size = 1 + optimizer_type = job_data["optimizer"]["type"] + optimizer_spec = job_data.get("optimizer") - if job_data.get("optimizer") is None: - optimizer_type = "adam" - optimizer_spec = {} - else: - optimizer_type = job_data["optimizer"]["type"] - optimizer_spec = job_data.get("optimizer") - - # Load HF model, create and save distribution - distribution = parser.create_distributed_config( - model_name, - workers=workers, - training=job_data.get("training", False), - trusted=False, - handle_layers=False, - input_obfuscation=False, - optimizer_type=optimizer_type, - optimizer_spec=optimizer_spec, - host_load_small=hosted, - host_max_depth=1, - host_threshold_mb=75, - max_offload_depth=3, - batch_size=job_data.get("batch_size", batch_size), - max_seq_len=job_data.get("max_seq_len", 4096), - model_type=job_data.get("model_type", "chat"), - ) + # Get available host memory accounting for concurrent initializations + if hosted: + available_host_memory = self._get_available_host_memory() + host_memory_budget = available_host_memory + else: + host_memory_budget = 0 - job_data["distribution"] = distribution + # Load HF model, create and save distribution + distribution = parser.create_distributed_config( + model_name, + workers=workers, + training=job_data.get("training", False), + trusted=False, + input_obfuscation=False, + optimizer_type=optimizer_type, + optimizer_spec=optimizer_spec, + host_max_memory_bytes=host_memory_budget, + host_max_depth=1, + max_offload_depth=3, + batch_size=job_data.get("batch_size", batch_size), + max_seq_len=job_data.get("max_seq_len", 4096), + model_type=job_data.get("model_type", "chat"), + ) - offloaded_count = sum( - 1 - for v in distribution["config"].values() - if "offloaded" in v.get("type", "") - ) + job_data["distribution"] = 
distribution - if ( - len(distribution["config"]) == 0 - or offloaded_count - > 4 # TODO This limit on number of distributions is not ideal - or not distribution["success"] - ): - return {} + offloaded_count = sum( + 1 + for v in distribution["config"].values() + if "offloaded" in v.get("type", "") + ) - # Store distribution in JSON cache - self._ensure_model_entry(model_name) - self.model_cache[model_name]["distribution"] = distribution - save_models_cache(self.model_cache) + if ( + len(distribution["config"]) == 0 + or offloaded_count + > 4 # TODO This limit on number of distributions is not ideal + or not distribution["success"] + ): + return {} - self.send_request( - "debug_print", - ( - f"DistributedValidator -> Retrieved HF model: {job_data}", - "bright_blue", - logging.DEBUG, - ), - ) + # Reserve the host memory this model will use + host_memory_used = distribution.get("host_memory_used", 0) + if host_memory_used > 0: + self._reserve_host_memory(job_data["id"], host_memory_used) + + # Store distribution in JSON cache + self._ensure_model_entry(model_name) + self.model_cache[model_name]["distribution"] = distribution + save_models_cache(self.model_cache) - gc.collect() # Force garbage collection + self.send_request( + "debug_print", + ( + f"DistributedValidator -> Retrieved HF model: {job_data}, Reserved: {host_memory_used / 1e9:.2f}GB", + "bright_blue", + logging.DEBUG, + ), + ) - # Send out job request - try: - new_job_data = self.send_request("send_job_request", job_data) - return new_job_data + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Send out job request + try: + new_job_data = self.send_request("send_job_request", job_data) + return new_job_data + + except Exception as e: + self._release_host_memory(job_data["id"]) + print(str(e)) except Exception as e: - print(str(e)) + self._release_host_memory(job_data["id"]) + raise def check_node(self): """Check for node requests/updates""" @@ -422,6 +448,9 @@ def 
check_node(self): # Clean up old request data self._cleanup_old_requests() + # Manage any ghost memory caches + self._audit_memory_reservations() + # Manage autoloaded models based on popularity (or DEFAULT_MODELS fallback) self._manage_auto_loaded_models() @@ -896,6 +925,7 @@ def _initialize_hosted_job( logging.error(f"Error initializing hosted job for {model_name}: {str(e)}") job_id = job_data.get("id") self.models_initializing.discard(job_id) + self._release_host_memory(job_id) del self.models[job_id] if job_id in self.model_state: del self.model_state[job_id] @@ -935,9 +965,11 @@ def _finalize_hosted_job(self, job_id: str): # Distribute the model across workers distributed_model.distribute_model(distribution) distributed_model.job_id = job_id - model_name = distributed_model.model_name + # Update available GPU memory + self._release_host_memory(job_id) + # Load tokenizer if model_name not in self.tokenizers: tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -964,6 +996,7 @@ def _finalize_hosted_job(self, job_id: str): except Exception as e: logging.error(f"Error finalizing hosted job for {model_name}: {str(e)}") self.models_initializing.discard(job_id) + self._release_host_memory(job_id) if job_id in self.models: del self.models[job_id] return False @@ -971,6 +1004,8 @@ def _finalize_hosted_job(self, job_id: str): def _remove_hosted_job(self, job_id: str): """Remove a hosted job and clean up all associated resources""" try: + self._release_host_memory(job_id) + # Remove from initializing set if present self.models_initializing.discard(job_id) @@ -1048,6 +1083,8 @@ def _remove_hosted_job(self, job_id: str): # Force garbage collection to free memory gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() self.send_request( "debug_print", @@ -1069,5 +1106,64 @@ def _remove_hosted_job(self, job_id: str): ), ) + def _reserve_host_memory(self, job_id: str, amount: int): + """Reserve host memory for a model being initialized""" + with 
self.memory_lock: + self.host_memory_reserved += amount + self.initializing_reservations[job_id] = amount + + def _release_host_memory(self, job_id: str): + """Release reserved host memory when initialization completes or fails""" + with self.memory_lock: + if job_id in self.initializing_reservations: + reserved = self.initializing_reservations[job_id] + self.host_memory_reserved -= reserved + del self.initializing_reservations[job_id] + + def _get_available_host_memory(self) -> int: + """Get currently available host memory accounting for reservations""" + with self.memory_lock: + total_memory = get_gpu_memory() + return total_memory - self.host_memory_reserved + + def _audit_memory_reservations(self): + """ + Audit memory reservations and clean up any orphaned reservations. + Called periodically to prevent memory leaks from edge cases. + """ + with self.memory_lock: + # Find job_ids that have reservations but aren't in models or models_initializing + orphaned_reservations = [] + + for job_id in list(self.initializing_reservations.keys()): + if job_id not in self.models and job_id not in self.models_initializing: + orphaned_reservations.append(job_id) + + # Release orphaned reservations + for job_id in orphaned_reservations: + reserved = self.initializing_reservations[job_id] + self.host_memory_reserved -= reserved + del self.initializing_reservations[job_id] + + self.send_request( + "debug_print", + ( + f"Released orphaned reservation: {job_id} ({reserved / 1e9:.2f}GB)", + "yellow", + logging.WARNING, + ), + ) + + if orphaned_reservations: + self.send_request( + "debug_print", + ( + f"Memory audit: Released {len(orphaned_reservations)} orphaned reservations. 
" + f"Total reserved: {self.host_memory_reserved / 1e9:.2f}GB", + "cyan", + logging.INFO, + ), + ) + def main_loop(self): self.check_node() diff --git a/tensorlink/ml/worker.py b/tensorlink/ml/worker.py index f4daaf1..b654955 100644 --- a/tensorlink/ml/worker.py +++ b/tensorlink/ml/worker.py @@ -187,7 +187,7 @@ def __init__(self, node, trusted=False): self.scaler = None self.use_amp = False - self.GC_CHECK_INTERVAL = 1_000 + self.GC_CHECK_INTERVAL = 2_000 self.CHECK_COUNTER = 1 # Initialize CUDA streams for overlapping operations @@ -362,7 +362,7 @@ def _handle_forward(self, module_id, key, size, name): output_bytes = tensor_to_bytes(detached_out) size, name = store_in_shared_memory(output_bytes) - self.send_request("send_forward", (module.host, size, name, key)) + self.send_request("send_forward", (module.host, module_id, size, name, key)) # Incremental training counter if module.training: @@ -443,7 +443,7 @@ def _run_generate(): self._send_stream_end(module_id, host_id) size, name = store_in_shared_memory(output_bytes) - self.send_request("send_forward", (host_id, size, name, "generate")) + self.send_request("send_forward", (host_id, module_id, size, name, "generate")) if self.device.type == "cuda": torch.cuda.empty_cache() diff --git a/tensorlink/nodes/job_monitor.py b/tensorlink/nodes/job_monitor.py index 082dca1..dacd865 100644 --- a/tensorlink/nodes/job_monitor.py +++ b/tensorlink/nodes/job_monitor.py @@ -432,7 +432,7 @@ def _cleanup_job(self, job_data: Dict, final_status: JobStatus): def _cleanup_workers(self, job_data: Dict): """Clean up worker resources and send shutdown signals.""" for module_id, module_info in job_data["distribution"].items(): - if module_info["type"] == "offloaded": + if "offloaded" in module_info["type"]: for worker_id in module_info["assigned_workers"]: try: node = self.node.nodes[worker_id] diff --git a/tensorlink/p2p/torch_node.py b/tensorlink/p2p/torch_node.py index e35921f..c17c605 100644 --- a/tensorlink/p2p/torch_node.py +++ 
b/tensorlink/p2p/torch_node.py @@ -249,43 +249,52 @@ def _handle_backward(self, data: bytes, node: Connection): return True def _handle_forward(self, data: bytes, node: Connection): + """Handle a received forward pass from a node""" # Basic check, must be upgraded to check if we are expecting the request if node.node_id not in self.nodes: node.ghosts += 1 return False + + # Received a forward pass + eos = data.find(b"::") + size = int(data[7:eos]) + formatted_size = format_size(size) + self.debug_print(f"RECEIVED FORWARD: {formatted_size}", tag="Torchnode") + + # TODO we must check that the forward received corresponds to a sent pass/specific module + # must also do with backwards + tensor = data[eos + 2 : eos + 2 + size] + payload = json.loads(data[eos + 2 + size :]) + + if isinstance(payload, dict): + module_id = payload.get("module_id") + key = payload.get("key") else: - # Received a forward pass - eos = data.find(b"::") - size = int(data[7:eos]) - formatted_size = format_size(size) - self.debug_print(f"RECEIVED FORWARD: {formatted_size}", tag="Torchnode") + module_id = None + key = payload - # TODO we must check that the forward received corresponds to a sent pass/specific module - # must also do with backwards - tensor = data[eos + 2 : eos + 2 + size] - key = json.loads(data[eos + 2 + size :]) + if not isinstance(key, str): + key = tuple(key) + # Create shared mpc block and store tensor + self._store_tensor_in_shared_memory(key, tensor) + return True - if not isinstance(key, str): - key = tuple(key) + if module_id not in self.modules: + self.debug_print( + f"Unknown module_id in forward: {module_id}", tag="Torchnode" + ) + return False - # Create shared mpc block and store tensor - self._store_tensor_in_shared_memory(key, tensor) - else: - module_id = None - for module in self.modules: - if node.node_id in self.modules[module]["assigned_workers"]: - module_id = module - break - - shm = shared_memory.SharedMemory(create=True, size=size) - buffer = 
shm.buf[:size] - buffer[:] = tensor - - self.modules[module_id]["forward_queue"][key] = (size, shm.name) - self.memory_manager[key] = shm.name - del buffer - shm.close() - return True + shm = shared_memory.SharedMemory(create=True, size=size) + buffer = shm.buf[:size] + buffer[:] = tensor + + self.modules[module_id]["forward_queue"][key] = (size, shm.name) + self.memory_manager[key] = shm.name + + del buffer + shm.close() + return True def _handle_generate(self, data: bytes, node: Connection): # Received a forward pass @@ -319,6 +328,9 @@ def _handle_generate(self, data: bytes, node: Connection): return True def _handle_module(self, data: bytes, node: Connection): + """ + Load a module sent by a validator node + """ module_id = data[6:70].decode() file_name = module_id + self.rsa_key_hash if os.path.exists(file_name): @@ -408,7 +420,6 @@ def handle_requests(self, request=None): "send_forward": self._handle_send_forward, "send_backward": self._handle_send_backward, "send_parameters": self._handle_send_parameters, - "is_loaded": self._handle_is_loaded, "check_module": self._handle_check_module, "check_module_request": self._handle_check_module_request, "check_forward": self._handle_check_forward, @@ -469,7 +480,9 @@ def _handle_check_module_loaded(self, request): self.response_queue.put({"status": "SUCCESS", "return": return_val}) def _handle_module_loaded_request(self, request): - # Send module loaded message to node + """ + Send module loaded message from worker back to a validator + """ module_id = request["args"] module = self.modules[module_id] node_id = module["host"] @@ -479,6 +492,9 @@ def _handle_module_loaded_request(self, request): self.response_queue.put({"status": "SUCCESS", "return": None}) def _handle_optimizer_response_request(self, request): + """ + Send response after an update to the distributed optimizer was called + """ module_id, response_type = request["args"] node_id = self.modules[module_id]["host"] node = self.nodes[node_id] @@ -492,10 
+508,10 @@ def _handle_optimizer_response_request(self, request): def _handle_send_forward(self, request): # Send forward pass tensor from shared mpc to a node - worker_id, size, shm_name, tag = request["args"] + worker_id, module_id, size, shm_name, tag = request["args"] node = self.nodes[worker_id] forward_bytes = get_from_shared_memory(size, shm_name, encoded=True) - self.send_forward(node, forward_bytes, tag) + self.send_forward(node, forward_bytes, tag, module_id) self.response_queue.put({"status": "SUCCESS", "return": None}) def _handle_send_generate(self, request): @@ -544,17 +560,11 @@ def _handle_send_parameters(self, request): ) self.response_queue.put({"status": "SUCCESS", "return": None}) - def _handle_is_loaded(self, request): - return_val = False - for module_id, module in self.modules.items(): - if module.get("terminated"): - pass - else: - return_val = True - - self.response_queue.put({"status": "SUCCESS", "return": return_val}) - def _handle_check_module(self, request): + """ + Invoked by a worker or validator ML process to see if there are any significant state + changes to any modules (ie loading or termination). 
+ """ if self.role == "V": return_val = { "job_id": request["args"], @@ -567,7 +577,9 @@ def _handle_check_module(self, request): return_val = None for module_id, module in self.modules.items(): + # "mem_info" is added to module info upon initially receiving it if "mem_info" in module: + # Return the module info to the ML process if self.role == "V": if return_val.get("job_id") == module.get("job_id"): return_val["distribution"][module_id] = module["distribution"] @@ -580,6 +592,7 @@ def _handle_check_module(self, request): del module["mem_info"] + # "termination" is added to module info when the job is closing elif "termination" in module: return_val = module_id del self.modules[module_id] @@ -794,10 +807,17 @@ def _handle_debug_print(self, request): self.debug_print(message, colour=colour, level=level, tag=tag) self.response_queue.put({"status": "SUCCESS", "return": False}) - def send_forward(self, node: Connection, forward_bytes, context): + def send_forward(self, node: Connection, forward_bytes, context, module_id): """Send forward pass to node, must contain args (module args) and context (module + epoch id)""" + + # Inject module_id into context + payload = { + "module_id": module_id, + "key": context, + } + size = str(len(forward_bytes)).encode() + b"::" - json_data = b"FORWARD" + size + forward_bytes + json.dumps(context).encode() + json_data = b"FORWARD" + size + forward_bytes + json.dumps(payload).encode() self.send_to_node(node, json_data) def _store_tensor_in_shared_memory(self, key, tensor: bytes, backward=False): @@ -889,8 +909,7 @@ def _stop_mpc_comms(self): def print_ui_status(self): total_vram = self.total_gpu_memory - used_vram_est = total_vram - self.available_gpu_memory - used_vram_actual = total_vram - get_gpu_memory() + used_vram = total_vram - get_gpu_memory() ram = psutil.virtual_memory() used_ram = ram.total - ram.available @@ -934,21 +953,14 @@ def line(label, value, colour=ANSI.CYAN): # --- Resources --- print(sep) - vram_bar_estimate = 
_bar(used_vram_est, total_vram) - vram_bar_actual = _bar(used_vram_actual, total_vram) + vram_bar = _bar(used_vram, total_vram) ram_bar = _bar(used_ram, ram.total) print( - f"{ANSI.DIM}{'VRAM EST.':<14}:{ANSI.RESET} " - f"{ANSI.MAGENTA}[{vram_bar_estimate}]{ANSI.RESET} " - f"{ANSI.YELLOW}{_fmt_gb(used_vram_est)} / {_fmt_gb(total_vram)} GB{ANSI.RESET}" - ) - print( - f"{ANSI.DIM}{'VRAM ACT.':<14}:{ANSI.RESET} " - f"{ANSI.MAGENTA}[{vram_bar_actual}]{ANSI.RESET} " - f"{ANSI.YELLOW}{_fmt_gb(used_vram_actual)} / {_fmt_gb(total_vram)} GB{ANSI.RESET}" + f"{ANSI.DIM}{'VRAM':<14}:{ANSI.RESET} " + f"{ANSI.MAGENTA}[{vram_bar}]{ANSI.RESET} " + f"{ANSI.YELLOW}{_fmt_gb(used_vram)} / {_fmt_gb(total_vram)} GB{ANSI.RESET}" ) - print( f"{ANSI.DIM}{'RAM':<14}:{ANSI.RESET} " f"{ANSI.GREEN}[{ram_bar}]{ANSI.RESET} " diff --git a/tests/test_model_api.py b/tests/test_model_api.py index 4299a20..e9c97f5 100644 --- a/tests/test_model_api.py +++ b/tests/test_model_api.py @@ -33,7 +33,7 @@ pytest.param( { "name": "HuggingFaceTB/SmolLM2-135M", - "timeout": 120, + "timeout": 1200, "sleep": 10, "parsed": True, }, diff --git a/tests/test_model_parser.py b/tests/test_model_parser.py new file mode 100644 index 0000000..eb588e0 --- /dev/null +++ b/tests/test_model_parser.py @@ -0,0 +1,235 @@ +""" +Test the distribution of models across workers to visually analyze distribution methods +and ensure different distribution techniques are working accordingly +""" + +from tensorlink.ml.graphing import ModelParser +from tensorlink.ml.utils import estimate_memory, format_memory_size +import pandas as pd + +MODELS = [ + "Qwen/Qwen2.5-7B-Instruct", + "Qwen/Qwen2.5-14B-Instruct", + "Qwen/Qwen3-8B", + "Qwen/Qwen3-14B", +] + +# Two workers with a 24GB and 16GB GPU capacity +WORKERS = { + '509d89bf56704c67873c328e4f706a705b2fdc1671ebacab1083c9c6d2df650f': { + 'id': '509d89bf56704c67873c328e4f706a705b2fdc1671ebacab1083c9c6d2df650f', + 'gpu_memory': 24e9, + 'total_gpu_memory': 24e9, + 'role': 'W', + 'training': 
False, + }, + '209d89bf56704c67873c328e4f706a705b2fdc1671ebacab1083c9c6d2df650f': { + 'id': '209d89bf56704c67873c328e4f706a705b2fdc1671ebacab1083c9c6d2df650f', + 'gpu_memory': 16e9, + 'total_gpu_memory': 16e9, + 'role': 'W', + 'training': False, + }, +} + + +def test_model_distributions(): + """ + Print out model distributions for a variety of different models to inspect their memory + estimates and how they are distributed among the default workers. + + No actual assertion statements are made in this function. + """ + batch_sizes = [1] + seq_lengths = [1024, 4096, 8196] + + # Collect rows for the final DataFrame + rows = [] + + for model_name in MODELS: + for bs in batch_sizes: + for seqlen in seq_lengths: + parser = ModelParser(verbose=False) + + try: + config = parser.create_distributed_config( + model_name, + WORKERS, + training=False, + trusted=False, + input_obfuscation=False, + host_max_memory_bytes=50e6, + max_seq_len=seqlen, + batch_size=bs, + optimizer_type="adam", + ) + + success = config.get("success", False) + model_memory = ( + format_memory_size(config["model_memory"]) if success else 0 + ) + components_memory = { + k: {n['type']: n['memory']} for k, n in config["config"].items() + } + total_components_memory = sum( + [list(v.values())[-1] for v in components_memory.values()] + ) + memory_breakdown = { + k: {k2: format_memory_size(v2) for k2, v2 in v.items()} + for k, v in components_memory.items() + } + components_memory = { + k: {n['type']: format_memory_size(n['memory'])} + for k, n in config["config"].items() + } + + # Append to rows + rows.append( + { + "model": model_name, + "batch_size": bs, + "seq_length": seqlen, + "model_memory": model_memory, + "components_sum": format_memory_size( + total_components_memory + ), + "success": success, + "error": None if success else config.get("error", None), + "memory_breakdown": memory_breakdown, + } + ) + + except Exception as e: + rows.append( + { + "model": model_name, + "batch_size": bs, + 
"seq_length": seqlen, + "success": False, + "error": str(e), + } + ) + + # Create a pandas DataFrame + df = pd.DataFrame(rows) + + # Configure pandas to display multi-line cells + pd.set_option('display.max_colwidth', None) # No truncation + pd.set_option('display.width', None) # Auto-detect terminal width + pd.set_option('display.max_rows', None) # Show all rows + + print("\n\n========== FINAL RESULTS TABLE ==========\n") + print(df.to_string(index=False)) + print("\nTable stored in variable: df") + + +def test_config_combinations(): + """ + Test various combinations of configuration parameters to see how they affect + model distribution and memory allocation. Results shown as a simple DataFrame. + """ + # Base test parameters + test_model = "Qwen/Qwen2.5-14B-Instruct" + batch_size = 1 + seq_length = 4096 + + # Define test configurations to explore + test_configs = [ + { + "input_obfuscation": False, + "host_max_memory_bytes": 5e7, + }, + { + "input_obfuscation": False, + "host_max_memory_bytes": 5e7, + }, + { + "input_obfuscation": False, + "host_max_memory_bytes": 5e8, + }, + { + "input_obfuscation": False, + "host_max_memory_bytes": 5e8, + }, + ] + + results = [] + + for test_config in test_configs: + parser = ModelParser(verbose=False) + + try: + config = parser.create_distributed_config( + test_model, + WORKERS, + training=False, + trusted=False, + optimizer_type="adam", + max_seq_len=seq_length, + batch_size=batch_size, + **test_config, + ) + + success = config.get("success", False) + + if success: + components_memory = { + k: {n['type']: n['memory']} for k, n in config["config"].items() + } + total_components_memory = sum( + [list(v.values())[-1] for v in components_memory.values()] + ) + components_memory = { + k: {n['type']: format_memory_size(n['memory'])} + for k, n in config["config"].items() + } + else: + components_memory = {} + total_components_memory = 0 + + results.append( + { + "input_obfuscation": test_config["input_obfuscation"], + "success": 
success, + "model_memory": ( + format_memory_size(config["model_memory"]) if success else "N/A" + ), + "components_sum": ( + format_memory_size(total_components_memory) + if success + else "N/A" + ), + "component_breakdown": components_memory, + "error": None if success else config.get("error", "Unknown error"), + } + ) + + except Exception as e: + results.append( + { + "input_obfuscation": test_config["input_obfuscation"], + "host_threshold_mb": test_config["host_threshold_mb"], + "success": False, + "model_memory": "N/A", + "components_sum": "N/A", + "num_components": 0, + "component_breakdown": "N/A", + "error": str(e), + } + ) + + # Create DataFrame + df = pd.DataFrame(results) + + # Set display options + pd.set_option('display.max_columns', None) + pd.set_option('display.width', None) + pd.set_option('display.max_colwidth', None) + + print("\n\n========== CONFIGURATION COMPARISON ==========\n") + print(f"Test Model: {test_model}") + print(f"Batch Size: {batch_size}, Seq Length: {seq_length}\n") + print(df.to_string(index=False)) + print("\n\nResults stored in variable: df") + + return df From 038288addcf105c2f6beb3305966c48caca6a126 Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Mon, 26 Jan 2026 12:46:54 -0500 Subject: [PATCH 10/25] Improved functionality for host loading. Validators can now act as workers and load larger modules. More configuration for max host module size, max gpu memory, etc. Enhanced docs for a number of key classes. 
--- bin/config.json | 3 +- bin/run_node.py | 13 +- tensorlink/ml/graphing.py | 404 +++++++++++++++------------ tensorlink/ml/injector.py | 56 ++-- tensorlink/ml/module.py | 74 ++++- tensorlink/ml/validator.py | 51 +++- tensorlink/ml/worker.py | 1 + tensorlink/nodes/nodes.py | 195 +++++++++++-- tensorlink/nodes/validator_thread.py | 13 + tests/conftest.py | 9 +- tests/test_model_api.py | 2 +- tests/test_model_parser.py | 13 +- 12 files changed, 591 insertions(+), 243 deletions(-) diff --git a/bin/config.json b/bin/config.json index 2279ed4..238ce89 100644 --- a/bin/config.json +++ b/bin/config.json @@ -18,7 +18,8 @@ ] }, "ml": { - "trusted": false + "trusted": false, + "max_vram_gb": 0 } } } \ No newline at end of file diff --git a/bin/run_node.py b/bin/run_node.py index fbd7b49..a3b0898 100644 --- a/bin/run_node.py +++ b/bin/run_node.py @@ -210,11 +210,17 @@ def main(): # Get node type from config node_type = config.get("node", {}).get("type", "worker").lower() - if node_type not in ["worker", "validator"]: + if node_type not in ["worker", "validator", "both"]: raise ValueError( - f"Invalid node type: {node_type}. Must be 'worker' or 'validator'" + f"Invalid node type: {node_type}. 
Must be 'worker', 'validator', or 'both'" ) + max_vram_gb = config.get("ml", {}).get("max_vram_gb", 0) + enable_hosting = False + if node_type == "both": + node_type = "validator" + enable_hosting = True + # Parse common config trusted = config.get("ml", {}).get("trusted", False) mode = config.get("node", {}).get("mode", "private") @@ -241,6 +247,7 @@ def main(): print_level=log_level, priority_nodes=config.get("node", {}).get("priority_nodes", []), seed_validators=config.get("crypto", {}).get("seed_validators", []), + # max_vram_gb=max_vram_gb ), trusted=trusted, utilization=True, @@ -260,6 +267,8 @@ def main(): print_level=log_level, priority_nodes=config.get("node", {}).get("priority_nodes", []), seed_validators=config.get("crypto", {}).get("seed_validators", []), + max_vram_gb=max_vram_gb, + enable_hosting=enable_hosting, ), trusted=trusted, ) diff --git a/tensorlink/ml/graphing.py b/tensorlink/ml/graphing.py index 3c6970f..14ec70a 100644 --- a/tensorlink/ml/graphing.py +++ b/tensorlink/ml/graphing.py @@ -54,7 +54,7 @@ def _create_grouped_entry(parent_path: str, group: list) -> dict: "num_layers": len(group), } - # Preserve parent_module_path if present (but not parent_forward_code) + # Preserve parent_module_path if present if "parent_module_path" in configs[0]: grouped_config["parent_module_path"] = configs[0]["parent_module_path"] @@ -151,13 +151,40 @@ def _is_loop_iterable_module(module: nn.Module, module_path: str) -> bool: class ModelParser: + """ + Parses a PyTorch model and constructs a distributed execution configuration + for Tensorlink by analyzing module structure, memory requirements, and + forward-pass behavior. It will assign individual models, modules, or groups + of sequential models in the model + + The ModelParser is responsible for: + - Walking the module hierarchy. + - Estimating memory usage per submodule. + - Assigning modules to workers or host. + - Detecting and rewriting forward loops for offloaded execution. 
+ - Producing a configuration graph used by DistributedModel. + + This class does not execute the model itself, but prepares the metadata and + transformed forward methods required for distributed inference or training. + """ + def __init__(self, user_memory: int = 0, verbose=False): - self.user_memory = user_memory + """ + Initialize a ModelParser instance. + + Parameters + ---------- + verbose : bool, optional + If True, enables verbose logging during model parsing, memory estimation, + and assignment steps. Default is False. + """ + self.model_name = "" self.assigned_workers = defaultdict(list) self.assigned_memory = 0 self.verbose = verbose self.module_paths = {} # Track all module paths + self._host_max_module_bytes = 0 def create_distributed_config( self, @@ -167,8 +194,9 @@ def create_distributed_config( trusted: bool, input_obfuscation: bool = False, optimizer_type: str = "adam", - optimizer_spec: dict = {}, + optimizer_spec: Optional[dict] = None, host_max_memory_bytes: int = 0, + host_max_module_bytes: int = 0, host_max_depth: int = 2, max_offload_depth: int = 3, max_seq_len: int = 4096, @@ -185,109 +213,110 @@ def create_distributed_config( - kept on the local host, - fully offloaded to a remote worker, - split into children and recursively assigned, or - - marked as unassignable. + - marked as unassigned. The result is a config dictionary describing how the model should be partitioned for distributed inference or training. - Parameters - ---------- - model : Union[nn.Module, str] - Either a PyTorch model instance or a HuggingFace model name. If a string - is provided, the model is instantiated with empty weights using - `AutoConfig` to avoid loading parameters into memory. - - workers : dict - Mapping of worker_id -> worker metadata. Each worker entry must contain - at least: - { - "gpu_memory": - } - This memory is decremented as modules are assigned. - - training : bool - Whether the configuration is for training or inference. 
Training mode - increases memory estimates to include gradients, optimizer state, and - activation storage. - - trusted : bool - Indicates whether workers are trusted. Used for downstream logic such as - security policies, encryption, or obfuscation decisions. + Args: + model : Union[nn.Module, str] + Either a PyTorch model instance or a HuggingFace model name. If a string + is provided, the model is instantiated with empty weights using + `AutoConfig` to avoid loading parameters into memory. + + workers : dict + Mapping of worker_id -> worker metadata. Each worker entry must contain + at least: + { + "gpu_memory": + } + This memory is decremented as modules are assigned. - input_obfuscation : bool, optional (default=False) - Whether inputs should be obfuscated when sent to workers. This flag is - propagated into the distributed config for runtime enforcement. + training : bool + Whether the configuration is for training or inference. Training mode + increases memory estimates to include gradients, optimizer state, and + activation storage. - optimizer_type : str, optional (default="adam") - Optimizer type used for memory estimation (e.g. "adam", "sgd"). This - affects optimizer state size during training. + trusted : bool + Indicates whether workers are trusted. Used for downstream logic such as + security policies, encryption, or obfuscation decisions. - optimizer_spec : dict, optional - Extra optimizer configuration to attach to each assigned module - (e.g. learning rate, betas, weight decay). Stored in the config and - passed to workers. + input_obfuscation : bool, optional (default=False) + Whether inputs should be obfuscated when sent to workers. This flag is + propagated into the distributed config for runtime enforcement. - host_max_memory_bytes : int, optional (default=0) - Maximum number of bytes the local host is allowed to consume for loading - small submodules. If 0, the host will not keep modules locally. 
+ optimizer_type : str, optional (default="adam") + Optimizer type used for memory estimation (e.g. "adam", "sgd"). This + affects optimizer state size during training. - host_max_depth : int, optional (default=2) - Maximum recursion depth at which the host is allowed to keep modules. - Prevents deep layers from being pinned locally. + optimizer_spec : dict, optional + Extra optimizer configuration to attach to each assigned module + (e.g. learning rate, betas, weight decay). Stored in the config and + passed to workers. - max_offload_depth : int, optional (default=3) - Maximum recursion depth for offloading. If exceeded, the module is marked - as unassignable and an AssignmentError may be raised in verbose mode. + host_max_memory_bytes : int, optional (default=0) + Maximum number of bytes the local host is allowed to consume for loading + small submodules. If 0, the host will not keep modules locally. - max_seq_len : int, optional (default=4096) - Maximum sequence length used for estimating activation and KV cache - memory during inference or training. + host_max_module_bytes : int, optional (default=0) + Maximum bytes size the local host is allowed to consume for an individual + submodule. If 0, the host will consider module size. - batch_size : int, optional (default=1) - Batch size used for memory estimation of activations and optimizer state. + host_max_depth : int, optional (default=2) + Maximum recursion depth at which the host is allowed to keep modules. + Prevents deep layers from being pinned locally. - model_type : str, optional (default="chat") - Logical model type (e.g. "chat", "vision", "embedding"). Stored in the - config and used by downstream execution logic. + max_offload_depth : int, optional (default=3) + Maximum recursion depth for offloading. If exceeded, the module is marked + as unassigned and an AssignmentError may be raised in verbose mode. 
- Returns - ------- - dict - A dictionary with the following keys: + max_seq_len : int, optional (default=4096) + Maximum sequence length used for estimating activation and KV cache + memory during inference or training. - - success : bool - Whether assignment completed successfully. + batch_size : int, optional (default=1) + Batch size used for memory estimation of activations and optimizer state. - - config : dict - Mapping of module_path -> assignment spec, where each entry may be: - { - "type": "loaded" | "offloaded" | "unassigned", - "device": "host" (if loaded), - "assigned_workers": [worker_id] (if offloaded), - "module_id": list, - "memory": bytes, - "module": str, - "module_path": str, - "training": bool, - "optimizer_spec": dict, - "batch_size": int, - "model_type": str, - "parent_module_path": str (optional, for pipelining) - } + model_type : str, optional (default="chat") + Logical model type (e.g. "chat", "vision", "embedding"). Stored in the + config and used by downstream execution logic. - - model_memory : int - Total estimated memory footprint of the model under the provided - parameters (including activations and KV cache). - - Notes - ----- - - Modules that are too large or loop-iterable are recursively split into - children until they can be assigned. - - Sequential layers may later be grouped for pipeline parallelism via - `_group_sequential_layers`. - - Worker memory is decremented as assignments occur to prevent overcommit. - - If assignment fails, `success=False` is returned and config may be partial. + Returns: + dict: A dictionary with the following keys: + - success : bool + Whether assignment completed successfully. 
+ - config : dict + Mapping of module_path -> assignment spec, where each entry may be: + { + "type": "loaded" | "offloaded" | "unassigned", + "device": "host" (if loaded), + "assigned_workers": [worker_id] (if offloaded), + "memory": bytes, + "module": str, + "module_path": str, + "training": bool, + "optimizer_spec": dict, + "batch_size": int, + "model_type": str, + "parent_module_path": str (optional, for pipelining) + } + - model_memory : int + Total estimated memory footprint of the model under the provided + parameters (including activations and KV cache). + - host_memory_used: int + Assigned memory to validator + + Notes: + - Modules that are too large or loop-iterable are recursively split into + children until they can be assigned. + - Sequential layers may later be grouped for pipeline parallelism via + `_group_sequential_layers`. + - Worker memory is decremented as assignments occur to prevent overcommit. + - If assignment fails, `success=False` is returned and config may be partial. 
""" + self.assigned_memory = 0 + if optimizer_spec is None: + optimizer_spec = {} if isinstance(model, str): self.model_name = model @@ -303,6 +332,10 @@ def create_distributed_config( config = {} success = True + self._host_max_module_bytes = host_max_module_bytes + if host_max_module_bytes == 0: + self._host_max_module_bytes = 1e15 # Set to massive number if not specified + model_memory, breakdown = estimate_memory( model, training=training, @@ -315,7 +348,7 @@ def create_distributed_config( ) try: - config, _ = self._recurse_module( + config, _, _ = self._recurse_module( module=model, module_path="model", workers_state=workers_state, @@ -342,7 +375,12 @@ def create_distributed_config( except AssignmentError as e: success = False - return {"success": success, "config": config, "model_memory": model_memory} + return { + "success": success, + "config": config, + "model_memory": model_memory, + "host_memory_used": self.assigned_memory, + } def _recurse_module( self, @@ -354,7 +392,6 @@ def _recurse_module( input_obfuscation: bool, last_worker: Optional[str] = None, depth: int = 0, - ids: list = None, optimizer_type="adam", optimizer_spec=None, host_max_memory_bytes: int = 0, @@ -364,16 +401,12 @@ def _recurse_module( batch_size: int = 1, model_type: str = "chat", count_activations: bool = True, + obfuscation_layer_assigned: bool = False, ): config = {} - if ids is None: - ids = [] indent = " " * depth - # Root rule, host can never load entire model - is_root = module_path == "model" - # Log current module being processed if self.verbose: print(f"{indent}Processing: {module_path}") @@ -396,33 +429,80 @@ def _recurse_module( if self.verbose: print(f"{indent} Memory required: {memory / 1e6:.2f}MB") - # Local host small module logic + # --- Input obfuscation enforcement --- + force_host = False + if input_obfuscation and not obfuscation_layer_assigned: + # Keep the first substantial layer on host for input obfuscation + # This ensures raw inputs are transformed before 
being sent to workers + + # Check if this is a leaf module with parameters (actual layer, not container) + has_params = any(True for _ in module.parameters(recurse=False)) + has_no_children = len(list(module.children())) == 0 + + if has_params and has_no_children: + # This is a leaf layer with parameters - good candidate for obfuscation layer + force_host = True + obfuscation_layer_assigned = True + elif depth <= 1: + # At shallow depth, still enforce obfuscation even for containers + # to ensure we capture embedding layers or initial processing + force_host = True + + # If we need to force host but have no host memory budget, that's an error + if force_host and host_max_memory_bytes == 0: + raise ValueError( + f"input_obfuscation=True requires host_max_memory_bytes > 0 to keep " + f"the input transformation layer on the host." + ) + + # Local host module if we have the memory OR input obfuscation is enabled if ( - not is_root - and host_max_memory_bytes + host_max_memory_bytes and memory <= host_max_memory_bytes - self.assigned_memory and depth <= host_max_depth - ): - config[module_path] = { - "type": "loaded", - "device": "host", - "name": self.model_name, - "module_id": ids, - "memory": memory, - "module": ( - f"{type(module)}".split(".")[-1].split(">")[0][:-1] - if not isinstance(module, str) - else module - ), - "module_path": module_path, - "training": training, - "optimizer_spec": optimizer_spec, - "batch_size": batch_size, - "model_type": model_type, - } - if self.verbose: - print(f"{indent} Kept on host (local) — {memory / 1e6:.2f}MB") - return config, None + and memory <= self._host_max_module_bytes + ) or force_host: + # Double-check we can actually fit this on host if forced + if force_host and memory > host_max_memory_bytes - self.assigned_memory: + if self.verbose: + print( + f"{indent} WARNING: Obfuscation layer too large for host ({memory / 1e6:.2f}MB > {(host_max_memory_bytes - self.assigned_memory) / 1e6:.2f}MB available)" + ) + # Don't force it 
if it truly won't fit + force_host = False + else: + prev_assigned = self.assigned_memory + try: + self.assigned_memory += memory + config[module_path] = { + "type": "loaded", + "device": "host", + "name": self.model_name, + "memory": memory, + "module": ( + f"{type(module)}".split(".")[-1].split(">")[0][:-1] + if not isinstance(module, str) + else module + ), + "module_path": module_path, + "training": training, + "optimizer_spec": optimizer_spec, + "batch_size": batch_size, + "model_type": model_type, + "input_boundary": ( + True if input_obfuscation and depth == 0 else False + ), + } + + if self.verbose: + why = "obfuscation boundary" if force_host else "host budget" + print(f"{indent} Kept on host ({why}) — {memory / 1e6:.2f}MB") + + return config, None, obfuscation_layer_assigned + + except Exception: + self.assigned_memory = prev_assigned + raise # Check if module is loop-iterable BEFORE trying to assign is_loop_iterable = _is_loop_iterable_module(module, module_path) @@ -443,7 +523,6 @@ def _recurse_module( "type": "offloaded", "name": self.model_name, "assigned_workers": [assigned_worker], - "module_id": ids, "memory": memory, "module": ( f"{type(module)}".split(".")[-1].split(">")[0][:-1] @@ -459,7 +538,6 @@ def _recurse_module( self.assigned_workers[assigned_worker].append( { - "module_id": ids, "memory": memory, "module": module, "module_path": module_path, @@ -469,7 +547,7 @@ def _recurse_module( if self.verbose: print(f"{indent} Assigned to {assigned_worker}") - return config, assigned_worker + return config, assigned_worker, obfuscation_layer_assigned # Check if we've exceeded max recursion depth if depth >= max_offload_depth: @@ -513,23 +591,26 @@ def _recurse_module( child_path = f"{module_path}.{child_name}" try: - child_config, child_last_worker = self._recurse_module( - module=child_module, - module_path=child_path, - workers_state=workers_state, - training=training, - trusted=trusted, - last_worker=prev_child_worker, - 
input_obfuscation=input_obfuscation, - depth=depth + 1, - optimizer_type=optimizer_type, - optimizer_spec=optimizer_spec, - host_max_memory_bytes=host_max_memory_bytes, - host_max_depth=host_max_depth, - max_offload_depth=max_offload_depth, - max_seq_len=max_seq_len, - batch_size=batch_size, - count_activations=False, + child_config, child_last_worker, obfuscation_layer_assigned = ( + self._recurse_module( + module=child_module, + module_path=child_path, + workers_state=workers_state, + training=training, + trusted=trusted, + last_worker=prev_child_worker, + input_obfuscation=input_obfuscation, + depth=depth + 1, + optimizer_type=optimizer_type, + optimizer_spec=optimizer_spec, + host_max_memory_bytes=host_max_memory_bytes, + host_max_depth=host_max_depth, + max_offload_depth=max_offload_depth, + max_seq_len=max_seq_len, + batch_size=batch_size, + count_activations=False, + obfuscation_layer_assigned=obfuscation_layer_assigned, + ) ) config.update(child_config) @@ -544,17 +625,16 @@ def _recurse_module( print(f"{indent} Child {child_path} failed: {e}") raise - # Add parent_module_path when children span multiple workers - if len(child_workers) > 1: - # Get the children that were just processed (belong to this parent) - for child_name, _ in children: - child_path = f"{module_path}.{child_name}" - if child_path in config: - child_cfg = config[child_path] - if child_cfg.get("type") == "offloaded": - child_cfg["parent_module_path"] = module_path + # Add parent_module_path + # Get the children that were just processed (belong to this parent) + for child_name, _ in children: + child_path = f"{module_path}.{child_name}" + if child_path in config: + child_cfg = config[child_path] + if child_cfg.get("type") == "offloaded": + child_cfg["parent_module_path"] = module_path - return config, last_successful_worker + return config, last_successful_worker, obfuscation_layer_assigned def _try_assign_worker( self, @@ -589,30 +669,6 @@ def _try_assign_worker( return None - def 
_extract_forward_code(self, module: nn.Module): - """ - Extract the forward pass logic from the source code. Allows workers to - execute the parent's forward logic locally. - """ - # Check cache - module_class = type(module) - if module_class in self.forward_code_cache: - return self.forward_code_cache[module_class] - - try: - forward_method = module.forward - source = inspect.getsource(forward_method) - source = textwrap.dedent(source) - self.forward_code_cache[module_class] = source - return source - - except (OSError, TypeError) as e: - if self.verbose: - print( - f"Could not extract forward code for {module_class.__name__}: {e}" - ) - return None - def _log_assignment_summary(self, config: dict, workers_state: dict): """ Log a summary of the final assignment after configuration is complete. diff --git a/tensorlink/ml/injector.py b/tensorlink/ml/injector.py index e3d96e3..77c2949 100644 --- a/tensorlink/ml/injector.py +++ b/tensorlink/ml/injector.py @@ -564,35 +564,51 @@ def _generate_worker_calls( return calls +def _analyze_forward(fn): + """ + Extract source, parse AST, extract args, and locate loop node. + """ + source = inspect.getsource(fn) + source = textwrap.dedent(source) + tree = ast.parse(source) + + arg_extractor = FunctionArgExtractor() + arg_extractor.visit(tree) + + loop_finder = LoopFinder() + loop_finder.visit(tree) + + return source, tree, arg_extractor, loop_finder + + def generate_new_forward_method( - parent_module, offloaded_modules: List + parent_module, base_module, offloaded_modules: List ) -> types.FunctionType: """ Generate a new forward method with loop replaced by worker calls. 
Args: parent_module: The module whose forward pass contains the loop + base_module: Base model containing the module offloaded_modules: List of OffloadedModule instances to call sequentially - original_globals: Global namespace to use for the new function (optional) Returns: New forward function (unbound) """ - # Get original forward source original_forward = parent_module.forward - source = inspect.getsource(original_forward) - source = textwrap.dedent(source) - # Parse and analyze - tree = ast.parse(source) - - # Extract function arguments - arg_extractor = FunctionArgExtractor() - arg_extractor.visit(tree) + # First attempt: parent forward + source, tree, arg_extractor, loop_finder = _analyze_forward(original_forward) - # Find the loop - loop_finder = LoopFinder() - loop_finder.visit(tree) + # Fallback: base forward + if not loop_finder.loop_node: + try: + original_forward = base_module.forward + source, tree, arg_extractor, loop_finder = _analyze_forward( + original_forward + ) + except Exception: + raise ValueError("No suitable loop found in forward pass") if not loop_finder.loop_node: raise ValueError("No suitable loop found in forward pass") @@ -603,15 +619,13 @@ def generate_new_forward_method( loop_analyzer.visit(stmt) # Analyze variables created BEFORE the loop - func_node = tree.body[0] # The forward function + func_node = tree.body[0] pre_loop_analyzer = VariableUsageAnalyzer() for stmt in func_node.body: - # Stop when we reach the loop if stmt == loop_finder.loop_node: break pre_loop_analyzer.visit(stmt) - # Variables that exist before the loop are those written before it pre_loop_vars = pre_loop_analyzer.variables_written # Extract the layer call to preserve kwargs @@ -628,7 +642,11 @@ def generate_new_forward_method( # Generate new forward code new_forward_code = _generate_new_forward_source( - source, loop_finder.loop_node, layer_call_info, loop_vars, offloaded_modules + source, + loop_finder.loop_node, + layer_call_info, + loop_vars, + 
offloaded_modules, ) # Prepare namespace @@ -636,7 +654,7 @@ def generate_new_forward_method( try: exec(new_forward_code, namespace) - return namespace['forward'] + return namespace["forward"] except Exception as e: print("=" * 80) print("ERROR COMPILING NEW FORWARD") diff --git a/tensorlink/ml/module.py b/tensorlink/ml/module.py index e5fa6cc..8b1d90d 100644 --- a/tensorlink/ml/module.py +++ b/tensorlink/ml/module.py @@ -148,7 +148,8 @@ class DistributedModel(nn.Module): """ A modular distributed model that supports offloading submodules while handling local operations. This model can be instantiated - by either a Worker or a User, where the host is referred to as the 'master' node. + by either a Worker or a User, where the host is referred to as + the 'master' node. Features: - Handles distributed training across multiple nodes. @@ -172,7 +173,7 @@ def __init__( tokenizer=None, ): """ - Args: + Parameters: model (nn.Module): The base model to distribute. n_pipelines (int): Number of parallel pipelines for computation. optimizer (Type[optim.Optimizer]): Optimizer class to use. 
@@ -841,7 +842,9 @@ def _inject_grouped_layer_forward( parent_module.offloaded_modules = offloaded_modules - new_forward = generate_new_forward_method(parent_module, offloaded_modules) + new_forward = generate_new_forward_method( + parent_module, self.model, offloaded_modules + ) parent_module.forward = types.MethodType(new_forward, parent_module) @@ -955,17 +958,7 @@ def _load_single_host_module(self, module_id: str, module_info: Dict[str, Any]): "Module path refers to full model; loading directly into root model" ) - state_dict = self._load_module_weights(self.model_name, ["model"]) - - missing_keys, unexpected_keys = self.model.load_state_dict( - state_dict, strict=False - ) - - if missing_keys: - logging.warning(f"Model missing keys: {missing_keys}") - if unexpected_keys: - logging.warning(f"Model unexpected keys: {unexpected_keys}") - + self.model = self._load_full_model(self.model_name, module_info) return # Navigate to parent @@ -1003,6 +996,59 @@ def _load_single_host_module(self, module_id: str, module_info: Dict[str, Any]): logging.info(f"Successfully loaded host module {module_class}") + def _load_full_model(self, model_name: str, module_info: dict) -> torch.nn.Module: + """ + Load a complete model from HuggingFace with optimal memory usage. + Uses HF's native loading which is more memory-efficient than manual skeleton+weights. 
+ """ + model_type = module_info.get('model_type', 'chat') + num_gpus = torch.cuda.device_count() + + # Force garbage collection before loading + load_kwargs = { + "low_cpu_mem_usage": True, + "torch_dtype": torch.float16, # TODO route quantization params through job requests, should also be done for module loading + } + + # Only use device_map for multi-GPU + if num_gpus > 1: + load_kwargs["device_map"] = "auto" + else: + # For single GPU, load to CPU first then move + load_kwargs["device_map"] = "cpu" + + logging.info(f"Loading full model {model_name} with type {model_type}") + + # Load model based on type + if model_type in ("causal", "chat"): + final_model = AutoModelForCausalLM.from_pretrained( + model_name, **load_kwargs + ) + elif model_type == "seq2seq": + final_model = AutoModelForSeq2SeqLM.from_pretrained( + model_name, **load_kwargs + ) + elif model_type == "vision2text": + final_model = AutoModelForVision2Seq.from_pretrained( + model_name, **load_kwargs + ) + elif model_type == "audio2text": + final_model = AutoModelForSpeechSeq2Seq.from_pretrained( + model_name, **load_kwargs + ) + else: + model_config = AutoConfig.from_pretrained(model_name) + final_model = AutoModel.from_pretrained( + model_name, config=model_config, **load_kwargs + ) + + # Move to GPU only after fully loaded (for single GPU) + if num_gpus == 1 and self.device.type == "cuda": + final_model = final_model.to(self.device) + + logging.info(f"Successfully loaded full model {model_name}") + return final_model + def _load_module_weights( self, model_name: str, module_paths: List[str] ) -> Dict[str, torch.Tensor]: diff --git a/tensorlink/ml/validator.py b/tensorlink/ml/validator.py index a46981a..3032a09 100644 --- a/tensorlink/ml/validator.py +++ b/tensorlink/ml/validator.py @@ -7,7 +7,12 @@ format_chat_prompt, extract_reasoning_and_answer, ) -from tensorlink.ml.utils import load_models_cache, save_models_cache, get_gpu_memory +from tensorlink.ml.utils import ( + load_models_cache, + 
save_models_cache, + get_gpu_memory, + attach_tensor, +) from tensorlink.api.models import GenerationRequest from transformers import AutoTokenizer, TextIteratorStreamer @@ -114,9 +119,28 @@ def _post_process_output(request, tokenizer, formatted_prompt, text): class DistributedValidator(DistributedWorker): - def __init__(self, node, trusted=False, endpoint=True): + """ + Backend logic for handling Distributed Models, assigning workers, and + handling job requests from users. To be run alongside the background + ValidatorThread that manages its networking, event loops, connections, etc. + Validators do not perform heavy computation by default but can be configured + to also host modules via enable_hosting. + """ + + def __init__( + self, + node, + trusted: bool = False, + endpoint: bool = True, + enable_hosting: bool = False, + max_vram_gb: float = 0, + max_module_bytes: int = 0, + ): super().__init__(node, trusted) self.endpoint = endpoint + self._hosting_enabled = enable_hosting + self._max_vram_bytes = max_vram_gb * 1e9 # Convert to bytes + self._max_module_bytes = max_module_bytes self.model_cache = load_models_cache() self.models = {} # job_id -> model instance self.model_state = ( @@ -379,6 +403,7 @@ def inspect_model( optimizer_type=optimizer_type, optimizer_spec=optimizer_spec, host_max_memory_bytes=host_memory_budget, + host_max_module_bytes=self._max_module_bytes, host_max_depth=1, max_offload_depth=3, batch_size=job_data.get("batch_size", batch_size), @@ -397,7 +422,7 @@ def inspect_model( if ( len(distribution["config"]) == 0 or offloaded_count - > 4 # TODO This limit on number of distributions is not ideal + > 5 # TODO This limit on number of distributions is not ideal or not distribution["success"] ): return {} @@ -651,7 +676,7 @@ def _generate(self, request, job_id, start_time): distributed_model = ctx["distributed_model"] tokenizer = ctx["tokenizer"] formatted_prompt = ctx["formatted_prompt"] - input_ids = ctx["input_ids"] + input_ids = 
attach_tensor(ctx["input_ids"], self.device) prompt_tokens = ctx["prompt_tokens"] args = ctx["args"] @@ -953,6 +978,9 @@ def _finalize_hosted_job(self, job_id: str): # Get the DistributedModel instance distributed_model = self.models[job_id] + if not distribution: + distribution = distributed_model.config + # Update state self.model_state[job_id] = "distributing" @@ -1122,9 +1150,18 @@ def _release_host_memory(self, job_id: str): def _get_available_host_memory(self) -> int: """Get currently available host memory accounting for reservations""" - with self.memory_lock: - total_memory = get_gpu_memory() - return total_memory - self.host_memory_reserved + available_memory = 0 + max_vram_bytes = self._max_vram_bytes + if max_vram_bytes <= 0: + max_vram_bytes = ( + 1e15 # Set to massive number (1PB) when max vram was not specified + ) + + if self._hosting_enabled: + with self.memory_lock: + total_memory = min(get_gpu_memory(), max_vram_bytes) + available_memory += total_memory - self.host_memory_reserved + return available_memory def _audit_memory_reservations(self): """ diff --git a/tensorlink/ml/worker.py b/tensorlink/ml/worker.py index b654955..4467404 100644 --- a/tensorlink/ml/worker.py +++ b/tensorlink/ml/worker.py @@ -842,6 +842,7 @@ def _load_single_module( # explicit path provided target_module = get_nested_module(base_model, parent_module_path) effective_layer_path = parent_module_path + else: # parent_module_path is 'model' or empty -> try to find by class name effective_layer_path = _find_module_path_by_class( diff --git a/tensorlink/nodes/nodes.py b/tensorlink/nodes/nodes.py index 9e344ee..6c160ea 100644 --- a/tensorlink/nodes/nodes.py +++ b/tensorlink/nodes/nodes.py @@ -15,6 +15,27 @@ @dataclass class BaseNodeConfig: + """ + Base configuration shared across all Tensorlink node roles. + + Attributes + ---------- + upnp : bool + Whether to attempt UPnP port forwarding. + max_connections : int + Maximum number of peer connections allowed. 
+ on_chain : bool + Whether to interact with the on-chain Smartnodes layer. + local_test : bool + Enables local-only networking for testing. + print_level : int + Logging verbosity level. + priority_nodes : Optional[List[List[str]]] + Preferred peers to connect to first. + seed_validators : Optional[List[List[str]]] + Bootstrap validators for network discovery. + """ + upnp: bool = True max_connections: int = 0 on_chain: bool = True @@ -26,12 +47,20 @@ class BaseNodeConfig: @dataclass class WorkerConfig(BaseNodeConfig): + """ + Configuration specific to Worker nodes. + """ + duplicate: str = "" load_previous_state: bool = False @dataclass class ValidatorConfig(BaseNodeConfig): + """ + Configuration specific to Validator nodes. + """ + endpoint: bool = True endpoint_url: str = "0.0.0.0" endpoint_port: int = 64747 @@ -40,6 +69,10 @@ class ValidatorConfig(BaseNodeConfig): @dataclass class UserConfig(BaseNodeConfig): + """ + Configuration specific to User nodes. + """ + pass @@ -70,30 +103,56 @@ def show_spinner(stop_event, message="Processing"): class BaseNode: + """ + Base node runner that handles the startup of the P2P node thread + alongside a distributed ML process (i.e. DistributedWorker, + DistributedValidator, or DistributedModel). + """ + def __init__( self, config: BaseNodeConfig, trusted: bool = False, utilization: bool = True, ): + """ + Initialize a BaseNode instance. + + Parameters + ---------- + config : BaseNodeConfig + Configuration object for the node role. + + trusted : bool, optional + Whether this node is trusted within the network (bypasses some + verification or security checks). Default is False. + + utilization : bool, optional + If True, runs distributed ML logic in a background thread to allow + concurrent network operation. If False, runs synchronously. 
+ """ self.config = config self.trusted = trusted self.utilization = utilization + # IPC primitives for communicating with the role process self.node_requests = mp.Queue() self.node_responses = mp.Queue() self.mpc_lock = mp.Lock() + # Multiprocessing lifecycle handles self.node_process = None self._stop_event = mp.Event() + # Install signal handlers and immediately start the node self._setup_signal_handlers() self.start() def _setup_signal_handlers(self): """ - Set up signal handlers for graceful shutdown. - Uses a multiprocessing Event to signal across processes. + Set up OS signal handlers for graceful shutdown. + + Uses a multiprocessing Event to propagate stop signals across processes. """ def handler(signum, frame): @@ -107,38 +166,56 @@ def handler(signum, frame): signal.signal(sig, handler) def start(self): + """ + Spawn the multiprocessing role process. + """ self.node_process = mp.Process(target=self.run_role, daemon=True) self.node_process.start() def cleanup(self): - # Process cleanup + """ + Gracefully shut down the role process and release resources. + """ if self.node_process is not None and self.node_process.exitcode is None: - # Send a stop request to the role instance + # Ask the role to stop cleanly first response = self.send_request("stop", (None,), timeout=15) if response: self.node_process.join(timeout=15) - # If the process is still alive, terminate it + # Force terminate if still alive if self.node_process.is_alive(): print("Forcing termination for node process.") self.node_process.terminate() - # Final join to ensure it's completely shut down self.node_process.join() - self.node_process = None # Reset to None after cleanup + self.node_process = None def send_request(self, request_type, args, timeout=5): """ - Sends a request to the roles and waits for the response. + Send a request to the role process and wait for a response. + + Parameters + ---------- + request_type : str + Type of request to send. 
+ args : tuple + Arguments to forward to the role. + timeout : int, optional + Timeout in seconds for request/response. + + Returns + ------- + Any + Value returned by the role handler. """ request = {"type": request_type, "args": args} try: self.mpc_lock.acquire(timeout=timeout) self.node_requests.put(request) - response = self.node_responses.get( - timeout=timeout - ) # Blocking call, waits for response + + # Blocking wait for response + response = self.node_responses.get(timeout=timeout) except Exception as e: print(f"Error sending '{request_type}' request: {e}") @@ -150,9 +227,18 @@ def send_request(self, request_type, args, timeout=5): return response["return"] def run_role(self): + """ + Entry point for the multiprocessing role process. + + Subclasses must override this to construct and run the appropriate + node thread (WorkerThread, ValidatorThread, UserThread). + """ raise NotImplementedError("Subclasses must implement this method") def connect_node(self, host: str, port: int, node_id: str = None, timeout: int = 5): + """ + Request a connection to another node in the network. + """ if node_id is None: node_id = "" @@ -160,12 +246,24 @@ def connect_node(self, host: str, port: int, node_id: str = None, timeout: int = class Worker(BaseNode): + """ + Tensorlink Worker node runner. + + Workers perform distributed ML execution and communicate with validators + to run offloaded modules. + """ + def __init__(self, config: WorkerConfig, **kwargs): - self.mining_active = mp.Value('b', False) - self.reserved_memory = mp.Value('d', 0.0) + # Shared state for mining / memory tracking + self.mining_active = mp.Value("b", False) + self.reserved_memory = mp.Value("d", 0.0) + super().__init__(config, **kwargs) def run_role(self): + """ + Launch the WorkerThread inside the role process. 
+ """ node = WorkerThread( self.node_requests, self.node_responses, @@ -177,12 +275,18 @@ def run_role(self): node.activate() node.run() + # Keep process alive while the node thread is running while node.is_alive(): time.sleep(1) def start(self): + """ + Start the worker role and the DistributedWorker controller. + """ super().start() + distributed_worker = DistributedWorker(self, trusted=self.trusted) + if self.utilization: t = threading.Thread(target=distributed_worker.run, daemon=True) t.start() @@ -192,10 +296,45 @@ def start(self): class Validator(BaseNode): - def __init__(self, config: ValidatorConfig, **kwargs): + """ + Tensorlink Validator node runner. + + Validators coordinate jobs, verify execution, and optionally host + distributed modules. + """ + + def __init__( + self, + config: ValidatorConfig, + enable_hosting: bool = False, + max_vram_gb: float = 0, + max_module_bytes: int = 0, + **kwargs, + ): + """ + Initialize a Validator node. + + Parameters + ---------- + enable_hosting : bool + Whether this validator may host modules locally. + max_vram_gb : float + Maximum VRAM budget for hosted execution. + max_module_bytes : int + Maximum module size allowed for hosting. + """ + self._enable_hosting = enable_hosting + self._max_vram_gb = max_vram_gb + self._max_module_bytes = max_module_bytes + super().__init__(config, **kwargs) + self.config = config + def run_role(self): + """ + Launch the ValidatorThread inside the role process. + """ node = ValidatorThread( self.node_requests, self.node_responses, @@ -208,10 +347,22 @@ def run_role(self): time.sleep(1) def start(self): + """ + Start the validator role and DistributedValidator controller. 
+ """ from tensorlink.ml.validator import DistributedValidator super().start() - distributed_validator = DistributedValidator(self, trusted=self.trusted) + + distributed_validator = DistributedValidator( + self, + trusted=self.trusted, + endpoint=self.config.endpoint, + enable_hosting=self._enable_hosting, + max_vram_gb=self._max_vram_gb, + max_module_bytes=self._max_module_bytes, + ) + if self.utilization: t = threading.Thread(target=distributed_validator.run, daemon=True) t.start() @@ -221,10 +372,20 @@ def start(self): class User(BaseNode): + """ + Tensorlink User node runner. + + Users submit jobs and interact with distributed models but do not + perform validation or heavy execution themselves. + """ + def __init__(self, config: UserConfig, **kwargs): super().__init__(config, **kwargs) def run_role(self): + """ + Launch the UserThread inside the role process. + """ node = UserThread( self.node_requests, self.node_responses, @@ -237,7 +398,9 @@ def run_role(self): time.sleep(1) def cleanup(self): - """Downloads parameters from workers before shutting down""" + """ + Download parameters from workers before shutting down. + """ if hasattr(self, "distributed_model"): if self.distributed_model.training: self.distributed_model.parameters(distributed=True, load=False) diff --git a/tensorlink/nodes/validator_thread.py b/tensorlink/nodes/validator_thread.py index 7521ff0..ca1a9f4 100644 --- a/tensorlink/nodes/validator_thread.py +++ b/tensorlink/nodes/validator_thread.py @@ -20,6 +20,16 @@ class ValidatorThread(Torchnode): + """ + Coordinates connections, job requests, and smart contract updates across + the Tensorlink network. + + The ValidatorThread is responsible for: + - Discovering and maintaining connections to workers and peers. + - Validating job requests and proposals. + - Interacting with the underlying smart contract layer (Smartnodes). 
+ """ + def __init__( self, request_queue, @@ -36,6 +46,9 @@ def __init__( priority_nodes: list = None, seed_validators: list = None, ): + """ + Initialize a Validator P2P Node. + """ super(ValidatorThread, self).__init__( request_queue, response_queue, diff --git a/tests/conftest.py b/tests/conftest.py index 5d8da8d..b0c2678 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,7 +46,9 @@ def uwv_nodes(): endpoint=False, endpoint_url="127.0.0.1", load_previous_state=False, - ) + ), + enable_hosting=True, + max_vram_gb=0.4, ) worker = Worker( @@ -83,7 +85,10 @@ def wwv_nodes(): endpoint=True, endpoint_url="127.0.0.1", load_previous_state=False, - ) + ), + enable_hosting=True, + max_vram_gb=0, + max_module_bytes=int(4e8), ) worker = Worker( diff --git a/tests/test_model_api.py b/tests/test_model_api.py index e9c97f5..4299a20 100644 --- a/tests/test_model_api.py +++ b/tests/test_model_api.py @@ -33,7 +33,7 @@ pytest.param( { "name": "HuggingFaceTB/SmolLM2-135M", - "timeout": 1200, + "timeout": 120, "sleep": 10, "parsed": True, }, diff --git a/tests/test_model_parser.py b/tests/test_model_parser.py index eb588e0..6d1a307 100644 --- a/tests/test_model_parser.py +++ b/tests/test_model_parser.py @@ -137,19 +137,19 @@ def test_config_combinations(): test_configs = [ { "input_obfuscation": False, - "host_max_memory_bytes": 5e7, + "host_max_memory_bytes": 0, }, { - "input_obfuscation": False, - "host_max_memory_bytes": 5e7, + "input_obfuscation": True, + "host_max_memory_bytes": 0, }, { "input_obfuscation": False, - "host_max_memory_bytes": 5e8, + "host_max_memory_bytes": 5e7, }, { - "input_obfuscation": False, - "host_max_memory_bytes": 5e8, + "input_obfuscation": True, + "host_max_memory_bytes": 5e7, }, ] @@ -208,7 +208,6 @@ def test_config_combinations(): results.append( { "input_obfuscation": test_config["input_obfuscation"], - "host_threshold_mb": test_config["host_threshold_mb"], "success": False, "model_memory": "N/A", "components_sum": "N/A", From 
14ce6e347f529c5ba914072a90c3d724dca94bac Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Mon, 26 Jan 2026 20:46:36 -0500 Subject: [PATCH 11/25] Addressing bug with loop finder for llama models. --- tensorlink/ml/graphing.py | 2 +- tensorlink/ml/injector.py | 19 +++++-------------- tensorlink/ml/module.py | 6 +++--- tests/conftest.py | 5 +++-- tests/test_model_api.py | 9 --------- tests/test_model_parser.py | 35 +++++++++++++++++++++++++++-------- 6 files changed, 39 insertions(+), 37 deletions(-) diff --git a/tensorlink/ml/graphing.py b/tensorlink/ml/graphing.py index 14ec70a..d9a1d60 100644 --- a/tensorlink/ml/graphing.py +++ b/tensorlink/ml/graphing.py @@ -197,7 +197,7 @@ def create_distributed_config( optimizer_spec: Optional[dict] = None, host_max_memory_bytes: int = 0, host_max_module_bytes: int = 0, - host_max_depth: int = 2, + host_max_depth: int = 1, max_offload_depth: int = 3, max_seq_len: int = 4096, batch_size: int = 1, diff --git a/tensorlink/ml/injector.py b/tensorlink/ml/injector.py index 77c2949..d083003 100644 --- a/tensorlink/ml/injector.py +++ b/tensorlink/ml/injector.py @@ -595,24 +595,15 @@ def generate_new_forward_method( Returns: New forward function (unbound) """ - original_forward = parent_module.forward + if hasattr(base_module, "model"): + parent_module = base_module.model + original_forward = parent_module.forward + else: + original_forward = parent_module.model.forward # First attempt: parent forward source, tree, arg_extractor, loop_finder = _analyze_forward(original_forward) - # Fallback: base forward - if not loop_finder.loop_node: - try: - original_forward = base_module.forward - source, tree, arg_extractor, loop_finder = _analyze_forward( - original_forward - ) - except Exception: - raise ValueError("No suitable loop found in forward pass") - - if not loop_finder.loop_node: - raise ValueError("No suitable loop found in forward pass") - # Analyze variable usage in loop loop_analyzer = VariableUsageAnalyzer() for stmt in 
loop_finder.loop_node.body: diff --git a/tensorlink/ml/module.py b/tensorlink/ml/module.py index 8b1d90d..f8319c5 100644 --- a/tensorlink/ml/module.py +++ b/tensorlink/ml/module.py @@ -835,10 +835,10 @@ def _inject_grouped_layer_forward( Modify the parent module's forward method to call the offloaded layer group instead of looping through individual layers. """ - parent_path = list(grouped_layers.values())[0].get("parent_module_path", "") - assert isinstance(self.model, nn.Module), "Invalid model type" - parent_module = get_nested_module(self.model, parent_path) + parent_module = self.model + if hasattr(parent_module, "model"): + parent_module = parent_module.model parent_module.offloaded_modules = offloaded_modules diff --git a/tests/conftest.py b/tests/conftest.py index b0c2678..f4cd6b4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -49,6 +49,7 @@ def uwv_nodes(): ), enable_hosting=True, max_vram_gb=0.4, + max_module_bytes=int(1e8), ) worker = Worker( @@ -87,8 +88,8 @@ def wwv_nodes(): load_previous_state=False, ), enable_hosting=True, - max_vram_gb=0, - max_module_bytes=int(4e8), + max_vram_gb=0.4, + max_module_bytes=int(1e8), ) worker = Worker( diff --git a/tests/test_model_api.py b/tests/test_model_api.py index 4299a20..e6d2add 100644 --- a/tests/test_model_api.py +++ b/tests/test_model_api.py @@ -39,15 +39,6 @@ }, id="smollm2-135m", ), - # pytest.param( - # { - # "name": "BabyLM-community/babylm-baseline-100m-gpt2", - # "timeout": 120, - # "sleep": 10, - # "parsed": True, - # }, - # id="babylm-baseline-100m-gpt2", - # ), ] diff --git a/tests/test_model_parser.py b/tests/test_model_parser.py index 6d1a307..53b7212 100644 --- a/tests/test_model_parser.py +++ b/tests/test_model_parser.py @@ -129,27 +129,46 @@ def test_config_combinations(): model distribution and memory allocation. Results shown as a simple DataFrame. 
""" # Base test parameters - test_model = "Qwen/Qwen2.5-14B-Instruct" + test_model = "HuggingFaceTB/SmolLM2-135M" batch_size = 1 seq_length = 4096 # Define test configurations to explore + test_workers = { + '509d89bf56704c67873c328e4f706a705b2fdc1671ebacab1083c9c6d2df650f': { + 'id': '509d89bf56704c67873c328e4f706a705b2fdc1671ebacab1083c9c6d2df650f', + 'gpu_memory': 4e8, + 'total_gpu_memory': 4e8, + 'role': 'W', + 'training': False, + }, + '209d89bf56704c67873c328e4f706a705b2fdc1671ebacab1083c9c6d2df650f': { + 'id': '209d89bf56704c67873c328e4f706a705b2fdc1671ebacab1083c9c6d2df650f', + 'gpu_memory': 4e8, + 'total_gpu_memory': 4e8, + 'role': 'W', + 'training': False, + }, + } + test_configs = [ { "input_obfuscation": False, "host_max_memory_bytes": 0, + "host_max_module_bytes": 0, + "host_max_depth": 1, }, { - "input_obfuscation": True, + "input_obfuscation": False, "host_max_memory_bytes": 0, + "host_max_module_bytes": 1e8, + "host_max_depth": 1, }, { "input_obfuscation": False, - "host_max_memory_bytes": 5e7, - }, - { - "input_obfuscation": True, - "host_max_memory_bytes": 5e7, + "host_max_memory_bytes": 4e8, + "host_max_module_bytes": 1e8, + "host_max_depth": 1, }, ] @@ -161,7 +180,7 @@ def test_config_combinations(): try: config = parser.create_distributed_config( test_model, - WORKERS, + test_workers, training=False, trusted=False, optimizer_type="adam", From 141514578f9e735e1fadd17e0faf15e2db2d6d74 Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Tue, 27 Jan 2026 08:42:27 -0500 Subject: [PATCH 12/25] Configured max vram limit for worker and validator nodes. 
--- bin/run_node.py | 15 ++- tensorlink/ml/utils.py | 6 +- tensorlink/nodes/nodes.py | 6 +- tensorlink/nodes/worker_thread.py | 6 +- tensorlink/p2p/torch_node.py | 4 +- tests/test_model_api.py | 166 ------------------------------ tests/test_model_parser.py | 28 ++--- 7 files changed, 35 insertions(+), 196 deletions(-) diff --git a/bin/run_node.py b/bin/run_node.py index a3b0898..01214c5 100644 --- a/bin/run_node.py +++ b/bin/run_node.py @@ -210,16 +210,14 @@ def main(): # Get node type from config node_type = config.get("node", {}).get("type", "worker").lower() - if node_type not in ["worker", "validator", "both"]: + if node_type not in ["worker", "validator"]: raise ValueError( f"Invalid node type: {node_type}. Must be 'worker', 'validator', or 'both'" ) max_vram_gb = config.get("ml", {}).get("max_vram_gb", 0) - enable_hosting = False - if node_type == "both": - node_type = "validator" - enable_hosting = True + max_module_bytes = config.get("ml", {}).get("max_module_bytes", 1e8) + enable_hosting = True # Parse common config trusted = config.get("ml", {}).get("trusted", False) @@ -247,10 +245,10 @@ def main(): print_level=log_level, priority_nodes=config.get("node", {}).get("priority_nodes", []), seed_validators=config.get("crypto", {}).get("seed_validators", []), - # max_vram_gb=max_vram_gb ), trusted=trusted, utilization=True, + max_vram_gb=max_vram_gb, ) run_worker_loop(worker, config) @@ -267,10 +265,11 @@ def main(): print_level=log_level, priority_nodes=config.get("node", {}).get("priority_nodes", []), seed_validators=config.get("crypto", {}).get("seed_validators", []), - max_vram_gb=max_vram_gb, - enable_hosting=enable_hosting, ), trusted=trusted, + max_vram_gb=max_vram_gb, + max_module_bytes=int(max_module_bytes), + enable_hosting=enable_hosting, ) run_validator_loop(validator) diff --git a/tensorlink/ml/utils.py b/tensorlink/ml/utils.py index d6f7f68..6d8a19b 100644 --- a/tensorlink/ml/utils.py +++ b/tensorlink/ml/utils.py @@ -127,9 +127,10 @@ def 
estimate_memory( return total, breakdown -def get_gpu_memory(): +def get_gpu_memory(max_vram_gb: float = 0): # Check how much available mpc we can allocate to the roles memory = 0 + max_memory_bytes = int(max_vram_gb * 1e9) if torch.cuda.is_available(): devices = list(range(torch.cuda.device_count())) @@ -142,6 +143,9 @@ def get_gpu_memory(): else: memory += 4e8 + if max_memory_bytes > 0: + memory = min(memory, max_memory_bytes) + return memory diff --git a/tensorlink/nodes/nodes.py b/tensorlink/nodes/nodes.py index 6c160ea..1dbc31c 100644 --- a/tensorlink/nodes/nodes.py +++ b/tensorlink/nodes/nodes.py @@ -253,10 +253,10 @@ class Worker(BaseNode): to run offloaded modules. """ - def __init__(self, config: WorkerConfig, **kwargs): + def __init__(self, config: WorkerConfig, max_vram_gb: float = 0, **kwargs): # Shared state for mining / memory tracking self.mining_active = mp.Value("b", False) - self.reserved_memory = mp.Value("d", 0.0) + self._max_vram_gb = max_vram_gb super().__init__(config, **kwargs) @@ -269,7 +269,7 @@ def run_role(self): self.node_responses, **vars(self.config), mining_active=self.mining_active, - reserved_memory=self.reserved_memory, + max_vram_gb=self._max_vram_gb, ) node.activate() diff --git a/tensorlink/nodes/worker_thread.py b/tensorlink/nodes/worker_thread.py index a2ed716..b9a2fdc 100644 --- a/tensorlink/nodes/worker_thread.py +++ b/tensorlink/nodes/worker_thread.py @@ -29,7 +29,7 @@ def __init__( on_chain=False, local_test=False, mining_active=None, - reserved_memory=None, + max_vram_gb=0, duplicate="", load_previous_state=False, priority_nodes: list = None, @@ -45,6 +45,7 @@ def __init__( local_test=local_test, priority_nodes=priority_nodes, seed_validators=seed_validators, + max_vram_gb=max_vram_gb, ) self.training = False @@ -61,7 +62,6 @@ def __init__( ) self.mining_active = mining_active - self.reserved_memory = reserved_memory if self.on_chain: self.public_key = get_key(".tensorlink.env", "PUBLIC_KEY") @@ -227,7 +227,7 @@ def 
load_distributed_module(self, module: nn.Module, graph: dict = None): # proof["output"] = handle_output(self.model(dummy_input)).sum() def get_available_gpu_memory(self): - available_gpu_memory = get_gpu_memory() + available_gpu_memory = get_gpu_memory(self._max_vram_gb) for module_id, module_info in self.modules.items(): # Account for modules that are not in CUDA and are still initializing diff --git a/tensorlink/p2p/torch_node.py b/tensorlink/p2p/torch_node.py index c17c605..bf11b8c 100644 --- a/tensorlink/p2p/torch_node.py +++ b/tensorlink/p2p/torch_node.py @@ -78,6 +78,7 @@ def __init__( local_test=False, priority_nodes: list = None, seed_validators: list = None, + max_vram_gb: float = 0, ): super(Torchnode, self).__init__( role=role, @@ -90,7 +91,8 @@ def __init__( ) # Available GPU mpc estimation - self.available_gpu_memory = get_gpu_memory() + self._max_vram_gb = max_vram_gb + self.available_gpu_memory = get_gpu_memory(self._max_vram_gb) self.total_gpu_memory = self.available_gpu_memory self.available_ram = psutil.virtual_memory().available diff --git a/tests/test_model_api.py b/tests/test_model_api.py index e6d2add..388952e 100644 --- a/tests/test_model_api.py +++ b/tests/test_model_api.py @@ -386,169 +386,3 @@ def test_chat_completions(model_env): assert result["usage"]["total_tokens"] == ( result["usage"]["prompt_tokens"] + result["usage"]["completion_tokens"] ) - - -# @pytest.fixture(params=MODELS, scope="module") -# def model_env(request, connected_wwv_nodes): -# cfg = request.param -# worker, worker2, validator, _ = connected_wwv_nodes -# -# payload = {"hf_name": cfg["name"], "model_type": "causal"} -# response = requests.post(f"{SERVER_URL}/request-model", json=payload, timeout=cfg["timeout"]) -# assert response.status_code == 200 -# -# time.sleep(cfg["sleep"]) -# yield cfg, (worker, worker2, validator) -# -# -# # ------------------------- -# # Non-streaming tests -# # ------------------------- -# @pytest.mark.parametrize("output_format", ["simple", 
"openai"]) -# def test_generate(model_env, output_format): -# cfg, _ = model_env -# payload = { -# "hf_name": cfg["name"], -# "message": "Hi.", -# "max_new_tokens": 10, -# "do_sample": True, -# "num_beams": 2, -# "output_format": output_format, -# } -# response = requests.post(f"{SERVER_URL}/v1/generate", json=payload, timeout=100) -# assert response.status_code == 200 -# result = response.json() -# if output_format == "simple": -# validate_simple_response(result, cfg["name"]) -# else: -# validate_openai_response(result, cfg["name"]) -# -# -# # ------------------------- -# # Streaming tests -# # ------------------------- -# @pytest.mark.parametrize("output_format", ["simple", "openai"]) -# def test_streaming_generation(model_env, output_format): -# cfg, _ = model_env -# payload = { -# "hf_name": cfg["name"], -# "message": "Hi.", -# "max_new_tokens": 10, -# "do_sample": False, -# "num_beams": 1, -# "stream": True, -# "output_format": output_format, -# } -# full_text, chunks, content_fields = stream_and_collect(payload) -# assert isinstance(full_text, str) -# assert content_fields > 0 -# print(f"[{output_format}] Streamed {chunks} chunks: {full_text}") -# -# -# # ------------------------- -# # Chat completions (non-streaming) -# # ------------------------- -# def test_chat_completions(model_env): -# cfg, _ = model_env -# chat_payload = { -# "model": cfg["name"], -# "messages": [ -# {"role": "system", "content": "You are a helpful assistant."}, -# {"role": "user", "content": "Say 'Hello world' and nothing else."}, -# ], -# "max_tokens": 10, -# "temperature": 0.7, -# "stream": False, -# } -# response = requests.post(f"{SERVER_URL}/v1/chat/completions", json=chat_payload, timeout=120) -# assert response.status_code == 200 -# result = response.json() -# validate_openai_response(result, cfg["name"]) -# -# -# # ------------------------- -# # Chat completions streaming -# # ------------------------- -# def test_chat_completions_streaming(model_env): -# cfg, _ = model_env -# 
chat_payload = { -# "model": cfg["name"], -# "messages": [ -# {"role": "system", "content": "You are a helpful assistant."}, -# {"role": "user", "content": "Say 'Hello world' and nothing else."}, -# ], -# "max_tokens": 10, -# "temperature": 0.7, -# "stream": True, -# } -# full_text, chunks, content_fields = stream_and_collect(chat_payload) -# assert isinstance(full_text, str) -# assert content_fields > 0 -# print(f"✅ Chat completions stream received {chunks} chunks: {full_text}") -# -# -# # ------------------------- -# # Helper validators -# # ------------------------- -# def validate_simple_response(result, expected_model): -# assert "id" in result -# assert "model" in result and result["model"] == expected_model -# assert "text" in result and len(result["text"]) > 0 -# assert "usage" in result -# usage = result["usage"] -# assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] -# assert "processing_time" in result -# assert result.get("finish_reason", "stop") == "stop" -# -# -# def validate_openai_response(result, expected_model): -# assert result["object"] == "chat.completion" -# assert result["model"] == expected_model -# assert len(result["choices"]) > 0 -# choice = result["choices"][0] -# assert choice["index"] == 0 -# msg = choice["message"] -# assert msg["role"] == "assistant" -# assert isinstance(msg["content"], str) -# usage = result["usage"] -# assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] -# -# -# # ------------------------- -# # Streaming helper -# # ------------------------- -# def stream_and_collect(payload): -# response = requests.post(f"{SERVER_URL}/v1/generate", json=payload, stream=True, timeout=120) -# assert response.status_code == 200 -# full_text = "" -# done_received = False -# chunks = 0 -# content_fields = 0 -# -# for line in response.iter_lines(): -# if not line: -# continue -# decoded = line.decode("utf-8") -# if not decoded.startswith("data: "): -# continue -# data = 
decoded[6:] -# if data == "[DONE]": -# done_received = True -# break -# chunk = json.loads(data) -# chunks += 1 -# delta = chunk.get("choices", [{}])[0].get("delta", {}) -# if "content" in delta: -# content_fields += 1 -# full_text += delta["content"] -# elif chunk.get("token"): -# content_fields += 1 -# full_text += chunk["token"] -# elif chunk.get("done") is True: -# full_text = chunk.get("full_text", full_text) -# done_received = True -# break -# -# assert done_received -# assert chunks > 0 -# return full_text, chunks, content_fields diff --git a/tests/test_model_parser.py b/tests/test_model_parser.py index 53b7212..03ab1f8 100644 --- a/tests/test_model_parser.py +++ b/tests/test_model_parser.py @@ -129,7 +129,7 @@ def test_config_combinations(): model distribution and memory allocation. Results shown as a simple DataFrame. """ # Base test parameters - test_model = "HuggingFaceTB/SmolLM2-135M" + test_model = "Qwen/Qwen3-14B" batch_size = 1 seq_length = 4096 @@ -155,21 +155,21 @@ def test_config_combinations(): { "input_obfuscation": False, "host_max_memory_bytes": 0, - "host_max_module_bytes": 0, - "host_max_depth": 1, - }, - { - "input_obfuscation": False, - "host_max_memory_bytes": 0, - "host_max_module_bytes": 1e8, - "host_max_depth": 1, - }, - { - "input_obfuscation": False, - "host_max_memory_bytes": 4e8, "host_max_module_bytes": 1e8, "host_max_depth": 1, }, + # { + # "input_obfuscation": False, + # "host_max_memory_bytes": 0, + # "host_max_module_bytes": 1e8, + # "host_max_depth": 1, + # }, + # { + # "input_obfuscation": False, + # "host_max_memory_bytes": 4e8, + # "host_max_module_bytes": 1e8, + # "host_max_depth": 1, + # }, ] results = [] @@ -180,7 +180,7 @@ def test_config_combinations(): try: config = parser.create_distributed_config( test_model, - test_workers, + WORKERS, training=False, trusted=False, optimizer_type="adam", From 057697223398ed95980d75a10edb8daffc0d9b92 Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Tue, 27 Jan 2026 11:26:02 -0500 
Subject: [PATCH 13/25] Minor node binary bug fix + debug print enhancements. --- bin/config.json | 4 +- bin/run-node.sh | 8 +- bin/run_node.py | 2 +- tensorlink/p2p/connection.py | 1 - tensorlink/p2p/smart_node.py | 281 +++++++++++------------------------ tensorlink/p2p/torch_node.py | 33 ++-- 6 files changed, 114 insertions(+), 215 deletions(-) diff --git a/bin/config.json b/bin/config.json index 238ce89..e553265 100644 --- a/bin/config.json +++ b/bin/config.json @@ -4,10 +4,10 @@ "type": "worker", "mode": "public", "endpoint": false, - "endpoint_url": "127.0.0.1", + "endpoint_url": "0.0.0.0", "endpoint_port": 64747, "priority_nodes": [], - "logging": "INFO" + "logging": "ERROR" }, "crypto": { "address": "0x1Bc3a15dfFa205AA24F6386D959334ac1BF27336", diff --git a/bin/run-node.sh b/bin/run-node.sh index b5668d8..75ce2c5 100755 --- a/bin/run-node.sh +++ b/bin/run-node.sh @@ -53,12 +53,12 @@ try: config = json.load(f) if 'config' in config: config = config['config'] - node_type = config.get('node', {}).get('type', 'worker').lower() - print(node_type) + node_type = config.get('node', {}).get('type', 'worker') + print(str(node_type).strip().lower()) except Exception as e: - print('worker', file=sys.stderr) + print('worker') print(f'Warning: Could not read node type from config.json, defaulting to worker. Error: {e}', file=sys.stderr) -" 2>&1 +" } # Trap any unexpected errors diff --git a/bin/run_node.py b/bin/run_node.py index 01214c5..4e8e926 100644 --- a/bin/run_node.py +++ b/bin/run_node.py @@ -1,5 +1,5 @@ """ -TensorLink node runner. +Tensorlink node runner. 
Supports both Worker and Validator node types based on config.json """ diff --git a/tensorlink/p2p/connection.py b/tensorlink/p2p/connection.py index ff8c5e3..d204574 100644 --- a/tensorlink/p2p/connection.py +++ b/tensorlink/p2p/connection.py @@ -152,7 +152,6 @@ def send(self, data: bytes): except Exception as e: self.main_node.debug_print( f"Connection send error: {e}", - colour="bright_red", level=logging.ERROR, tag="Connection", ) diff --git a/tensorlink/p2p/smart_node.py b/tensorlink/p2p/smart_node.py index b33d480..0bf90b4 100644 --- a/tensorlink/p2p/smart_node.py +++ b/tensorlink/p2p/smart_node.py @@ -44,10 +44,13 @@ "bright_white": "\033[97m", } +VERBOSE = 5 + # Map logging levels to colors LEVEL_COLOURS = { + VERBOSE: "gray", logging.DEBUG: "blue", - logging.INFO: "gray", + logging.INFO: "green", logging.WARNING: "yellow", logging.ERROR: "red", logging.CRITICAL: "bright_red", @@ -281,8 +284,9 @@ def __init__( if self.upnp: self._init_upnp() - self._init_sock() + self.VERBOSE = VERBOSE + self._init_sock() self._priority_nodes = priority_nodes or [] self._seed_validators = seed_validators or [] if self.on_chain: @@ -394,12 +398,7 @@ def _handle_pong_response(self, node: Connection) -> bool: node.ping = time.time() - node.pinged node.pinged = -1 else: - self.debug_print( - "Received pong with no corresponding ping", - colour="red", - level=logging.WARNING, - tag="Smartnode", - ) + self._log_warning("Received pong with no corresponding ping") node.ghosts += 1 return True @@ -417,22 +416,14 @@ def _handle_value_response(self, data: bytes, node: Connection) -> bool: """ # Validate response packet size if len(data) < 86: - self.debug_print( - "Received incomplete value response", - colour="red", - level=logging.WARNING, - tag="Smartnode", + self._log_warning( + f"Received incomplete value response from: {node.node_id}" ) return False # Check if we have an active request for this node if node.node_id not in self.requests: - self.debug_print( - "Received 
unsolicited data",
-            colour="red",
-            level=logging.WARNING,
-            tag="Smartnode",
-        )
+        self._log_warning(f"Received unsolicited data from: {node.node_id}")
             return False
 
         value_id = data[22:86].decode()
@@ -445,12 +436,7 @@ def _handle_value_response(self, data: bytes, node: Connection) -> bool:
             self._store_request(value_id, value)
             return True
 
-        self.debug_print(
-            f"Ghost data from node: {node.node_id}",
-            colour="red",
-            level=logging.WARNING,
-            tag="Smartnode",
-        )
+        self._log_warning(f"Received ghost data from node: {node.node_id}")
         node.ghosts += 1
         return False
 
@@ -483,16 +469,32 @@ def _handle_value_request(self, data: bytes, node: Connection) -> bool:
 
         return True
 
-    def _log_error(self, message: str) -> None:
+    def _log_error(self, message: str, tag="Smartnode") -> None:
         """
         Log error messages with appropriate severity.
 
         Args:
             message (str): Error message to log
         """
-        self.debug_print(
-            f"{message}", colour="bright_red", level=logging.ERROR, tag="Smartnode"
-        )
+        self.debug_print(f"{message}", level=logging.ERROR, tag=tag)
+
+    def _log_warning(self, message: str, tag="Smartnode") -> None:
+        """
+        Log warning messages with appropriate severity.
+
+        Args:
+            message (str): Warning message to log
+        """
+        self.debug_print(f"{message}", level=logging.WARNING, tag=tag)
+
+    def _log_debug(self, message: str, tag="Smartnode") -> None:
+        """
+        Log debug messages with appropriate severity. 
+ + Args: + message (str): Error message to log + """ + self.debug_print(f"{message}", level=logging.DEBUG, tag=tag) def debug_print(self, message, level=logging.DEBUG, colour=None, tag=None) -> None: """Print to console if debug is enabled""" @@ -627,7 +629,7 @@ def _listen(self): pass except Exception as e: - self.debug_print( + self._log_error( f"Listen connection error {e}", colour="bright_red", level=logging.CRITICAL, @@ -687,12 +689,7 @@ def _validate_node_credentials( ) if dht_info and dht_info.get("reputation", 0) < 40: - self.debug_print( - f"Poor reputation: {node_info['node_id_hash']}", - colour="red", - level=logging.WARNING, - tag="Smartnode", - ) + self._log_warning(f"Poor reputation: {node_info['node_id_hash']}") connection.close() return False @@ -788,12 +785,7 @@ def _handle_instigator_proof(self, node_info: dict) -> float: proof = decrypt(node_info['random'].encode(), self.role) return float(proof) except Exception as e: - self.debug_print( - f"Proof validation failed: {e}", - colour="bright_red", - level=logging.WARNING, - tag="Smartnode", - ) + self._log_warning(f"Proof validation failed: {e}") return 0 def _validate_response( @@ -824,10 +816,7 @@ def _validate_response( # Select a new port for the node to use (since we accepted connection from the listening/main socket) our_port = self._get_next_port() - self.debug_print( - f"Selected next port: {our_port} for new connection.", - tag="Smartnode", - ) + self._log_debug(f"Selected next port: {our_port} for new connection.") self.add_port_mapping(our_port, our_port) # Send the new port and proof of random number @@ -839,12 +828,7 @@ def _validate_response( return rand_n_proof == expected_rand_n, our_port, new_port, main_port except Exception as e: - self.debug_print( - f"Response validation failed: {e}", - colour="bright_red", - level=logging.WARNING, - tag="Smartnode", - ) + self._log_warning(f"Response validation failed: {e}") return False, 0, 0, 0 def _finalize_connection( @@ -892,29 +876,17 @@ 
def _finalize_connection( return self._store_node_connection(thread_client, node_info) except Exception as e: - self.debug_print( - f"Connection finalization failed: {e}", - colour="bright_red", - level=logging.ERROR, - tag="Smartnode", - ) + self._log_error(f"Connection finalization failed: {e}") if connection: try: connection.close() except Exception as close_error: - self.debug_print( - f"Error closing connection: {close_error}", - colour="bright_red", - level=logging.ERROR, - tag="Smartnode", - ) + self._log_error(f"Error closing connection: {close_error}") return False def _establish_instigator_connection(self, host: str, port: int) -> socket.socket: """Establish connection as the instigator""" - self.debug_print( - f"Switching connection to new port: {host}:{port}", tag="Smartnode" - ) + self._log_debug(f"Switching connection to new port: {host}:{port}") # Increase wait time to allow receiver to fully set up socket time.sleep(2.5) # Increased from 1 to 2.5 seconds @@ -930,40 +902,23 @@ def _establish_instigator_connection(self, host: str, port: int) -> socket.socke return new_sock except socket.timeout: - self.debug_print( - f"Port swap connection timeout: {host}:{port}", - colour="bright_red", - level=logging.WARNING, - tag="Smartnode", - ) + self._log_warning(f"Port swap connection timeout: {host}:{port}") new_sock.close() raise ConnectionError(f"Connection timeout to {host}:{port}") except ConnectionRefusedError: - self.debug_print( - f"Port swap connection refused: {host}:{port}", - colour="bright_red", - level=logging.WARNING, - tag="Smartnode", - ) + self._log_warning(f"Port swap connection refused: {host}:{port}") new_sock.close() raise ConnectionError(f"Connection refused to {host}:{port}") except Exception as e: - self.debug_print( - f"Port swap failed ({self.host}:{port}): {e}", - colour="bright_red", - level=logging.ERROR, - tag="Smartnode", - ) + self._log_error(f"Port swap failed ({self.host}:{port}): {e}") new_sock.close() raise def 
_establish_receiver_connection(self, port: int) -> Optional[socket.socket]: """Establish connection as the receiver""" - self.debug_print( - f"Listening for the instigator on the new port: {port}", tag="Smartnode" - ) + self._log_debug(f"Listening for the instigator on the new port: {port}") # Create socket earlier to reduce race condition new_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -982,22 +937,12 @@ def _establish_receiver_connection(self, port: int) -> Optional[socket.socket]: return connection except socket.timeout: - self.debug_print( - f"Timeout waiting for instigator on port {port}", - colour="bright_red", - level=logging.ERROR, - tag="Smartnode", - ) + self._log_error(f"Timeout waiting for instigator on port {port}") new_sock.close() return None except Exception as e: - self.debug_print( - f"Error accepting instigator connection: {e}", - colour="bright_red", - level=logging.ERROR, - tag="Smartnode", - ) + self._log_error(f"Error accepting instigator connection: {e}") new_sock.close() return None @@ -1009,9 +954,7 @@ def _is_duplicate_connection(self, node_address: tuple) -> bool: """Check if we already have a connection to this node""" for node in self.nodes.values(): if node.host == node_address[0] and node.port == node_address[1]: - self.debug_print( - f"Already connected to node: {node.node_id}", tag="Smartnode" - ) + self._log_debug(f"Already connected to node: {node.node_id}") return True return False @@ -1108,9 +1051,7 @@ def connect_node( # Avoid duplicate connections if id_hash is not None and id_hash in self.nodes and not reconnect: - self.debug_print( - f"connect_node: Already connected to {id_hash}", tag="Smartnode" - ) + self._log_debug(f"connect_node: Already connected to {id_hash}") return True if _can_connect: @@ -1125,9 +1066,8 @@ def connect_node( # Select a free local port for outbound connection our_port = self._get_next_port() self.add_port_mapping(our_port, our_port) - self.debug_print( - f"Selected next port: 
{our_port} for new connection", - tag="Smartnode", + self._log_debug( + f"Selected next port: {our_port} for new connection" ) # Bind locally and connect to target node @@ -1135,12 +1075,7 @@ def connect_node( sock.connect((host, port)) sock.settimeout(10) # Prevent hanging connections - self.debug_print( - f"connect_node: connecting to {host}:{port}", - colour="blue", - level=logging.INFO, - tag="Smartnode", - ) + self._log_debug(f"connect_node: connecting to {host}:{port}") # Send initial identity message (role + keys) message = json.dumps( @@ -1156,11 +1091,8 @@ def connect_node( except Exception as e: # Retry with exponential backoff wait_time = backoff * (2**attempt) - self.debug_print( - f"Attempt {attempt + 1}/{max_attempts} failed: {e}. Retrying in {wait_time}s", - level=logging.WARNING, - colour="bright_red", - tag="Smartnode", + self._log_warning( + f"Attempt {attempt + 1}/{max_attempts} failed: {e}. Retrying in {wait_time}s" ) time.sleep(wait_time) @@ -1177,7 +1109,12 @@ def bootstrap(self): """ # Connect with some seed nodes from config file if self.on_chain and not self.local_test: - self.debug_print("Bootstrapping to public network...", tag="Smartnode") + self.debug_print( + "Bootstrapping to public network...", + tag="Smartnode", + level=logging.INFO, + colour="cyan", + ) for seed_validator in self._seed_validators: host, port, id_hash = seed_validator connected = self.connect_node(host, port, id_hash) @@ -1216,7 +1153,12 @@ def bootstrap(self): # candidates.append(validator_id) if self._priority_nodes: - self.debug_print("Connecting priority nodes...", tag="Smartnode") + self.debug_print( + "Connecting priority nodes...", + tag="Smartnode", + level=logging.INFO, + colour="cyan", + ) for seed_node in self._priority_nodes: host, port = seed_node self.connect_node(host, port) @@ -1270,23 +1212,13 @@ def _init_upnp(self) -> None: # Clean up mappings previously created by this application. 
try: if devices_found == 0: - self.debug_print( - "No UPnP devices found.", - colour="bright_red", - level=logging.ERROR, - tag="Smartnode", - ) + self._log_error("No UPnP devices found.") return self.clean_port_mappings() except Exception as e: - self.debug_print( - f"Error during UPnP cleanup: {e}", - colour="bright_red", - level=logging.ERROR, - tag="Smartnode", - ) + self._log_error(f"Error during UPnP cleanup: {e}") self.add_port_mapping(self.port, self.port) @@ -1311,7 +1243,7 @@ def add_port_mapping(self, external_port, internal_port): ) if result: - self.debug_print( + self._log_debug( f"UPnP port forward successful on port {self.port}", tag="Smartnode", ) @@ -1321,18 +1253,13 @@ def add_port_mapping(self, external_port, internal_port): f"Failed to initialize UPnP. (internal port: {internal_port}," f" external port: {external_port})", level=logging.CRITICAL, - colour="bright_red", tag="Smartnode", ) return False except Exception as e: if "ConflictInMapping" in str(e): - self.debug_print( - f"Port {external_port} is already mapped.", - level=logging.DEBUG, - tag="Smartnode", - ) + self._log_debug(f"Port {external_port} is already mapped.") return False else: raise e @@ -1345,23 +1272,14 @@ def remove_port_mapping(self, external_port): result = self.upnp.deleteportmapping(external_port, "TCP") if result is True: - self.debug_print( - f"Successfully removed UPnP port mapping for external port {external_port}", - tag="Smartnode", + self._log_debug( + f"Successfully removed UPnP port mapping for external port {external_port}" ) else: - self.debug_print( - f"Could not remove port mapping: {result}", - level=logging.WARNING, - colour="yellow", - tag="Smartnode", - ) + self._log_warning(f"Could not remove port mapping: {result}") except Exception as e: - self.debug_print( - f"Error removing UPnP port mapping for port {external_port}: {e}", - level=logging.ERROR, - colour="bright_red", - tag="Smartnode", + self._log_error( + f"Error removing UPnP port mapping for 
port {external_port}: {e}" ) def clean_port_mappings(self): @@ -1372,9 +1290,7 @@ def clean_port_mappings(self): index = 38751 if not self.upnp: - self.debug_print( - "UPnP is not initialized.", level=logging.WARNING, tag="Smartnode" - ) + self._log_warning("UPnP is not initialized.") return mappings while True: @@ -1392,11 +1308,7 @@ def clean_port_mappings(self): if "SpecifiedArrayIndexInvalid" in str(e): break - self.debug_print( - f"Error retrieving port mapping at index {index}: {e}", - level=logging.ERROR, - tag="Smartnode", - ) + self._log_error(f"Error retrieving port mapping at index {index}: {e}") break if index > 39_000: @@ -1427,19 +1339,14 @@ def _can_connect(self, host: str, port: int): """Makes sure we are not trying to connect to ourselves or a connected nodes""" # Check if trying to connect to self if host == self.host and port == self.port: - self.debug_print( - "connect_with_node: cannot connect with yourself!", - level=logging.WARNING, - tag="Smartnode", - ) + self._log_warning("connect_with_node: cannot connect with yourself!") return False # Check if already connected for node in self.nodes.values(): if node.host == host and (node.port == port or node.main_port == port): - self.debug_print( - f"connect_with_node: already connected with node: {node.node_id}", - tag="Smartnode", + self._log_warning( + f"connect_with_node: already connected with node: {node.node_id}" ) return False @@ -1450,33 +1357,22 @@ def send_to_node(self, n: Connection, data: bytes) -> None: if n in self.nodes.values(): self.debug_print( f"send_to_node: Sending {len(data)} to node: {n.host}:{n.port}", - tag="Smartnode", + level=self.VERBOSE, ) n.send(data) else: - self.debug_print( - "send_to_node: node not found!", - colour="red", - level=logging.WARNING, - tag="Smartnode", - ) + self._log_warning("send_to_node: node not found!") def send_to_node_from_file(self, n: Connection, file, tag): if n in self.nodes.values(): n.send_from_file(file, tag) else: - self.debug_print( - 
"send_to_node: node not found!", - colour="red", - level=logging.WARNING, - tag="Smartnode", - ) + self._log_warning("send_to_node: node not found!") def handle_message(self, node: Connection, data) -> None: """Callback method to handles incoming data from connections""" - self.debug_print( - f"handle_message from {node.host}:{node.port} -> {data.__sizeof__() / 1e6}MB", - tag="Smartnode", + self._log_debug( + f"handle_message from {node.host}:{node.port} -> {data.__sizeof__() / 1e6}MB" ) # Update last seen value @@ -1515,7 +1411,7 @@ def close_connection(self, n: socket.socket, additional_info: str = None) -> Non if additional_info: message += f": {additional_info}" - self.debug_print(message, colour="red", level=logging.DEBUG, tag="Smartnode") + self.debug_print(message, colour="red", tag="Smartnode") self.remove_port_mapping(n.getsockname()[1]) n.close() @@ -1523,7 +1419,7 @@ def _stop_upnp(self) -> None: """Shuts down UPnP on port""" if self.upnp: self.clean_port_mappings() - self.debug_print("_stop_upnp: UPnP cleaned.", tag="Smartnode") + self._log_debug("_stop_upnp: UPnP cleaned.") def stop(self) -> None: """Shut down nodes and all associated connections/threads""" @@ -1538,12 +1434,7 @@ def stop(self) -> None: try: self.sock.close() except Exception as e: - self.debug_print( - f"Error closing socket: {e}", - colour="bright_red", - level=logging.ERROR, - tag="Smartnode", - ) + self._log_error(f"Error closing socket: {e}") for node in list(self.nodes.values()): node.stop() diff --git a/tensorlink/p2p/torch_node.py b/tensorlink/p2p/torch_node.py index bf11b8c..cdeea94 100644 --- a/tensorlink/p2p/torch_node.py +++ b/tensorlink/p2p/torch_node.py @@ -152,12 +152,7 @@ def handle_data(self, data: bytes, node: Connection): return False except Exception as e: - self.debug_print( - f"Error handling data: {e}", - colour="bright_red", - level=logging.ERROR, - tag="Torchnode", - ) + self._log_error(f"Error handling data: {e}", tag="Torchnode") def _train_updated(self, 
data: bytes): mode = False if data[13:14] == b"0" else True @@ -173,16 +168,16 @@ def _update_train(self, data: bytes, node: Connection): def _handle_parameters(self, data: bytes): module_id = data[10:74].decode() - self.debug_print( - f"Received Parameters for: {module_id}", colour="blue", tag="Torchnode" - ) + self._log_debug(f"Received Parameters for: {module_id}", tag="Torchnode") + file_name = f"tmp/{module_id}_parameters" key = "PREQPREQPREQ" + module_id self.memory_manager[key] = file_name + return True def _handle_parameters_request(self, data: bytes): - self.debug_print("RECEIVED PARAMS REQUEST", tag="Torchnode") + self._log_debug("RECEIVED PARAMS REQUEST", tag="Torchnode") # TODO Must ensure requesting node is indeed the master or an overseeing validator module_id = data[10:74].decode() @@ -208,7 +203,6 @@ def _handle_optimizer_response(self, data: bytes, node: Connection): if response_type == "loaded": self.debug_print( f"Optimizer for module: {module_id} loaded on worker {node.node_id}", - level=logging.INFO, colour="bright_cyan", tag="Torchnode", ) @@ -363,7 +357,7 @@ def _handle_module(self, data: bytes, node: Connection): self._remove_request(node.node_id, req) self.debug_print( - f"Loading distributed module: {module_id}", + f"Loading distributed module: {module_id}, mo", colour="bright_cyan", level=logging.INFO, tag="Torchnode", @@ -971,6 +965,21 @@ def line(label, value, colour=ANSI.CYAN): print(line("Modules", modules, ANSI.MAGENTA)) + if self.print_level == logging.DEBUG: + print(line("Module Info", len(self.modules), ANSI.MAGENTA)) + for k in list(self.modules)[:10]: + print(f"{ANSI.DIM} └─ {k}{ANSI.RESET}") + if len(self.modules) > 10: + print(f"{ANSI.DIM} ... 
{len(self.modules) - 10} more{ANSI.RESET}") + + # Jobs (if present) + jobs = getattr(self, "jobs", {}) + print(line("Jobs", len(jobs), ANSI.CYAN)) + + if self.print_level == logging.DEBUG: + for jid in list(jobs)[:10]: + print(f"{ANSI.DIM} └─ {jid}{ANSI.RESET}") + # --- Validator --- if self.role.startswith("V"): print(sep) From 64a5aad7e588bd475ff3e7af7f4c2cd6a36db2ff Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Tue, 27 Jan 2026 21:47:36 -0500 Subject: [PATCH 14/25] Minor node binary bug fix + debug print enhancements. --- bin/config.json | 3 +- bin/run_node.py | 7 +++ tensorlink/nodes/nodes.py | 21 +++++--- tensorlink/nodes/user_thread.py | 36 ++++++-------- tensorlink/nodes/validator_thread.py | 2 + tensorlink/nodes/worker_thread.py | 74 +++++++++++++++------------- tensorlink/p2p/torch_node.py | 68 +++++++++++++------------ 7 files changed, 118 insertions(+), 93 deletions(-) diff --git a/bin/config.json b/bin/config.json index e553265..59191d7 100644 --- a/bin/config.json +++ b/bin/config.json @@ -19,7 +19,8 @@ }, "ml": { "trusted": false, - "max_vram_gb": 0 + "max_vram_gb": 0, + "max_module_bytes": 2e8 } } } \ No newline at end of file diff --git a/bin/run_node.py b/bin/run_node.py index 4e8e926..fc445e8 100644 --- a/bin/run_node.py +++ b/bin/run_node.py @@ -182,7 +182,10 @@ def run_worker_loop(worker, config): except KeyboardInterrupt: logging.info("Exiting...") + print("Exiting...") + finally: + worker.cleanup() if mining_process: stop_mining(mining_process) @@ -197,6 +200,10 @@ def run_validator_loop(validator): except KeyboardInterrupt: logging.info("Exiting...") + print("Exiting...") + + finally: + validator.cleanup() def main(): diff --git a/tensorlink/nodes/nodes.py b/tensorlink/nodes/nodes.py index 1dbc31c..ba165f1 100644 --- a/tensorlink/nodes/nodes.py +++ b/tensorlink/nodes/nodes.py @@ -176,19 +176,28 @@ def cleanup(self): """ Gracefully shut down the role process and release resources. 
""" - if self.node_process is not None and self.node_process.exitcode is None: - # Ask the role to stop cleanly first - response = self.send_request("stop", (None,), timeout=15) - if response: - self.node_process.join(timeout=15) + if self.node_process is not None: + # Signal process to stop + self._stop_event.set() + + # Ask the role to stop cleanly first via IPC + try: + response = self.send_request("stop", (None,), timeout=10) + print(f"Stop request response: {response}") + except Exception as e: + print(f"Error sending stop request: {e}") + + # Wait for graceful shutdown + self.node_process.join(timeout=10) # Force terminate if still alive if self.node_process.is_alive(): print("Forcing termination for node process.") self.node_process.terminate() + self.node_process.join() - self.node_process.join() self.node_process = None + print("Node cleanup complete") def send_request(self, request_type, args, timeout=5): """ diff --git a/tensorlink/nodes/user_thread.py b/tensorlink/nodes/user_thread.py index 2e56a86..1f0bd6c 100644 --- a/tensorlink/nodes/user_thread.py +++ b/tensorlink/nodes/user_thread.py @@ -1,7 +1,6 @@ from tensorlink.p2p.connection import Connection from tensorlink.p2p.torch_node import Torchnode -from dotenv import get_key import hashlib import json import logging @@ -68,25 +67,6 @@ def __init__( # ): # time.sleep(5) - if self.on_chain: - self.public_key = get_key(".tensorlink.env", "PUBLIC_KEY") - if not self.public_key: - self.debug_print( - "Public key not found in .env file, using donation wallet...", - tag="User", - ) - self.public_key = "0x1Bc3a15dfFa205AA24F6386D959334ac1BF27336" - - self.dht.store(hashlib.sha256(b"ADDRESS").hexdigest(), self.public_key) - - attempts = 0 - while attempts < 3 and len(self.validators) == 0: - self.bootstrap() - - if len(self.nodes) == 0: - time.sleep(3) - attempts += 1 - def handle_data(self, data: bytes, node: Connection) -> bool: """ Callback function to receive streamed data from worker roles. 
@@ -491,6 +471,22 @@ def run(self): # 2 or more of the identical info super().run() + should_bootstrap = bool(self._priority_nodes) or self.on_chain + if should_bootstrap: + attempts = 0 + while attempts < 3 and len(self.validators) == 0: + self.bootstrap() + + if len(self.nodes) == 0: + time.sleep(3) + attempts += 1 + else: + self.debug_print( + "Skipping bootstrap (no priority nodes and not on-chain).", + tag="Worker", + level=logging.INFO, + ) + while not self.terminate_flag.is_set(): # Handle job oversight, and inspect other jobs (includes job verification and reporting) time.sleep(3) diff --git a/tensorlink/nodes/validator_thread.py b/tensorlink/nodes/validator_thread.py index ca1a9f4..f7a4537 100644 --- a/tensorlink/nodes/validator_thread.py +++ b/tensorlink/nodes/validator_thread.py @@ -988,6 +988,8 @@ def run(self): time.sleep(1) counter += 1 + self.stop() + def stop(self): self.keeper.write_state() super().stop() diff --git a/tensorlink/nodes/worker_thread.py b/tensorlink/nodes/worker_thread.py index b9a2fdc..a942941 100644 --- a/tensorlink/nodes/worker_thread.py +++ b/tensorlink/nodes/worker_thread.py @@ -63,33 +63,6 @@ def __init__( self.mining_active = mining_active - if self.on_chain: - self.public_key = get_key(".tensorlink.env", "PUBLIC_KEY") - if not self.public_key: - self.debug_print( - "Public key not found in .env file, using donation wallet...", - tag="Worker", - ) - self.public_key = "0x1Bc3a15dfFa205AA24F6386D959334ac1BF27336" - - self.dht.store(hashlib.sha256(b"ADDRESS").hexdigest(), self.public_key) - - should_bootstrap = bool(self._priority_nodes) or self.on_chain - if should_bootstrap: - attempts = 0 - while attempts < 3 and len(self.validators) == 0: - self.bootstrap() - - if len(self.nodes) == 0: - time.sleep(3) - attempts += 1 - else: - self.debug_print( - "Skipping bootstrap (no priority nodes and not on-chain).", - tag="Worker", - level=logging.INFO, - ) - # Finally, load up previous saved state if any if on_chain or 
load_previous_state: self.keeper.load_previous_state() @@ -201,16 +174,47 @@ def run(self): # Accept users and back-check history # Get proposees from SC and send our state to them super().run() + try: + if self.on_chain: + self.public_key = get_key(".tensorlink.env", "PUBLIC_KEY") + if not self.public_key: + self.debug_print( + "Public key not found in .env file, using donation wallet...", + tag="Worker", + ) + self.public_key = "0x1Bc3a15dfFa205AA24F6386D959334ac1BF27336" + + self.dht.store(hashlib.sha256(b"ADDRESS").hexdigest(), self.public_key) + + should_bootstrap = bool(self._priority_nodes) or self.on_chain + if should_bootstrap: + attempts = 0 + while attempts < 3 and len(self.validators) == 0: + self.bootstrap() + + if len(self.nodes) == 0: + time.sleep(3) + attempts += 1 + else: + self.debug_print( + "Skipping bootstrap (no priority nodes and not on-chain).", + tag="Worker", + level=logging.INFO, + ) + + counter = 0 + while not self.terminate_flag.is_set(): + if counter % 180 == 0: + self.keeper.clean_node() + self.clean_port_mappings() + self.print_ui_status() - counter = 0 - while not self.terminate_flag.is_set(): - if counter % 180 == 0: - self.keeper.clean_node() - self.clean_port_mappings() - self.print_ui_status() + time.sleep(1) + counter += 1 + except KeyboardInterrupt: + pass - time.sleep(1) - counter += 1 + self.stop() def load_distributed_module(self, module: nn.Module, graph: dict = None): pass diff --git a/tensorlink/p2p/torch_node.py b/tensorlink/p2p/torch_node.py index cdeea94..993aeef 100644 --- a/tensorlink/p2p/torch_node.py +++ b/tensorlink/p2p/torch_node.py @@ -572,29 +572,36 @@ def _handle_check_module(self, request): else: return_val = None - for module_id, module in self.modules.items(): - # "mem_info" is added to module info upon initially receiving it - if "mem_info" in module: - # Return the module info to the ML process - if self.role == "V": - if return_val.get("job_id") == module.get("job_id"): - 
return_val["distribution"][module_id] = module["distribution"] - return_val["model_name"] = module.get("model_name", "") - return_val["optimizer"] = module["optimizer"] - return_val["training"] = module["training"] - else: - return_val = module - return_val["module_id"] = module_id - - del module["mem_info"] - - # "termination" is added to module info when the job is closing - elif "termination" in module: - return_val = module_id - del self.modules[module_id] - break + try: + for module_id, module in self.modules.items(): + # "mem_info" is added to module info upon initially receiving it + if "mem_info" in module: + # Return the module info to the ML process + if self.role == "V": + if return_val.get("job_id") == module.get("job_id"): + return_val["distribution"][module_id] = module[ + "distribution" + ] + return_val["model_name"] = module.get("model_name", "") + return_val["optimizer"] = module["optimizer"] + return_val["training"] = module["training"] + else: + return_val = module + return_val["module_id"] = module_id + + del module["mem_info"] + + # "termination" is added to module info when the job is closing + elif "termination" in module: + return_val = module_id + del self.modules[module_id] + break + + self.response_queue.put({"status": "SUCCESS", "return": return_val}) - self.response_queue.put({"status": "SUCCESS", "return": return_val}) + except Exception as e: + self._log_error(f"Error handling module: {e}") + self.response_queue.put({"status": "FAILURE", "return": None}) def _handle_check_module_request(self, request): request_type, worker_id, module_id = request["args"] @@ -776,16 +783,8 @@ def _handle_get_info(self, request): ) def _handle_stop(self, request): - self.response_queue.put({"status": "SUCCESS", "return": True}) self.terminate_flag.set() - - def _handle_check_shutdown(self, request): - if self.terminate_flag.is_set(): - self.response_queue.put({"status": "SUCCESS", "return": True}) - t = threading.Thread(target=self._stop_mpc_comms) - 
t.start() - else: - self.response_queue.put({"status": "SUCCESS", "return": False}) + self.response_queue.put({"status": "SUCCESS", "return": True}) def _handle_debug_print(self, request): if len(request["args"]) == 1: @@ -897,6 +896,13 @@ def run(self): def stop(self): super().stop() + self._stop_mpc_comms() + + def _handle_check_shutdown(self, request): + if self.terminate_flag.is_set(): + self.response_queue.put({"status": "SUCCESS", "return": True}) + else: + self.response_queue.put({"status": "SUCCESS", "return": False}) def _stop_mpc_comms(self): self.mpc_terminate_flag.set() From a67b4351826428e6e026ee15ffa2df73f69b3c84 Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Wed, 28 Jan 2026 09:18:44 -0500 Subject: [PATCH 15/25] Reduce public model tracking to 1 day. --- tensorlink/ml/validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorlink/ml/validator.py b/tensorlink/ml/validator.py index 3032a09..91c0ab6 100644 --- a/tensorlink/ml/validator.py +++ b/tensorlink/ml/validator.py @@ -154,7 +154,7 @@ def __init__( self.models_initializing = set() # job_id # Configuration - self.TRACKING_DAYS = 7 # Track requests for past 7 days + self.TRACKING_DAYS = 1 # Track requests for past 1 day self.MIN_REQUESTS_THRESHOLD = 10 # Minimum requests to consider auto-loading self.MAX_AUTO_MODELS = 10 # Maximum models to auto-load From b8c53ee704acd1e50c8b73367927796e78f22937 Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Wed, 28 Jan 2026 09:30:57 -0500 Subject: [PATCH 16/25] Fixed debug print statement in smart_node. Reduced public model loading interval by 10x. 
--- tensorlink/ml/validator.py | 7 ++++--- tensorlink/p2p/smart_node.py | 7 +------ tensorlink/p2p/torch_node.py | 3 +-- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/tensorlink/ml/validator.py b/tensorlink/ml/validator.py index 91c0ab6..3ac5d42 100644 --- a/tensorlink/ml/validator.py +++ b/tensorlink/ml/validator.py @@ -476,9 +476,6 @@ def check_node(self): # Manage any ghost memory caches self._audit_memory_reservations() - # Manage autoloaded models based on popularity (or DEFAULT_MODELS fallback) - self._manage_auto_loaded_models() - # Check if jobs are still active for job_id, model in self.models.items(): model_name = model.model_name @@ -491,6 +488,10 @@ def check_node(self): self.CHECK_COUNTER = 1 + if self.CHECK_COUNTER * 10 % self.GC_CHECK_INTERVAL == 0: + # Manage autoloaded models based on popularity (or DEFAULT_MODELS fallback) + self._manage_auto_loaded_models() + if self.models_initializing: # Only call model management if we have models actively initializing self._try_finalize_initializing_models() diff --git a/tensorlink/p2p/smart_node.py b/tensorlink/p2p/smart_node.py index 0bf90b4..1286634 100644 --- a/tensorlink/p2p/smart_node.py +++ b/tensorlink/p2p/smart_node.py @@ -629,12 +629,7 @@ def _listen(self): pass except Exception as e: - self._log_error( - f"Listen connection error {e}", - colour="bright_red", - level=logging.CRITICAL, - tag="Smartnode", - ) + self._log_error(f"Listen connection error {e}") # self.reconnect_nodes() diff --git a/tensorlink/p2p/torch_node.py b/tensorlink/p2p/torch_node.py index 993aeef..e9dab66 100644 --- a/tensorlink/p2p/torch_node.py +++ b/tensorlink/p2p/torch_node.py @@ -357,9 +357,8 @@ def _handle_module(self, data: bytes, node: Connection): self._remove_request(node.node_id, req) self.debug_print( - f"Loading distributed module: {module_id}, mo", + f"Loading distributed module: {module_info}", colour="bright_cyan", - level=logging.INFO, tag="Torchnode", ) From 
0f2aa4f6deb471c2b55a8a53d21fe92794282138 Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Wed, 28 Jan 2026 10:21:25 -0500 Subject: [PATCH 17/25] Fixed #24. Added more flake coverage to .pre-commit --- .pre-commit-config.yaml | 2 +- tensorlink/api/node.py | 128 +++++++++++++++++---------- tensorlink/nodes/nodes.py | 1 - tensorlink/nodes/user_thread.py | 43 +++++---- tensorlink/nodes/validator_thread.py | 60 +++++++------ tensorlink/nodes/worker_thread.py | 7 +- 6 files changed, 141 insertions(+), 100 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 207e25b..074d4c5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,4 +13,4 @@ repos: rev: 7.1.1 hooks: - id: flake8 - files: ^(tensorlink/p2p|tensorlink/nodes)/ + files: ^(tensorlink/p2p|tensorlink/nodes|tensorlink/api)/ diff --git a/tensorlink/api/node.py b/tensorlink/api/node.py index 892c616..ab009e7 100644 --- a/tensorlink/api/node.py +++ b/tensorlink/api/node.py @@ -16,7 +16,6 @@ import random import queue import time -import json def build_hf_job_data( @@ -52,6 +51,47 @@ def build_hf_job_data( } +def _parse_chat_messages(messages): + """ + Parse chat messages into system messages, history, and last user message. 
+ Returns: (system_messages, history, last_user_message) + """ + system_messages = [] + conversation = [] + + for msg in messages: + if msg.role not in ("system", "user", "assistant"): + continue + + if msg.role == "system": + system_messages.append(msg.content) + else: + conversation.append({"role": msg.role, "content": msg.content}) + + # Find last user message + last_user_message = None + last_user_idx = None + + for idx in range(len(conversation) - 1, -1, -1): + if conversation[idx]["role"] == "user": + last_user_message = conversation[idx]["content"] + last_user_idx = idx + break + + if last_user_message is None: + raise HTTPException(status_code=400, detail="No user message found") + + # Build history (everything before the last user message) + history = conversation[:last_user_idx] + + # Prepend system message to history if present + if system_messages: + combined_system = "\n".join(system_messages) + history.insert(0, {"role": "system", "content": combined_system}) + + return system_messages, history, last_user_message + + class TensorlinkAPI: def __init__(self, smart_node, host="0.0.0.0", port=64747): self.smart_node = smart_node @@ -73,6 +113,16 @@ def __init__(self, smart_node, host="0.0.0.0", port=64747): self._start_server() def _define_routes(self): + """Register all API routes by delegating to specialized methods""" + self._register_generate_routes() + self._register_model_routes() + self._register_stats_routes() + self._register_network_routes() + self.app.include_router(self.router) + + def _register_generate_routes(self): + """Register generation and chat completion endpoints""" + @self.router.post("/v1/generate") async def generate(request: GenerationRequest): """Updated /v1/generate endpoint""" @@ -82,18 +132,7 @@ async def generate(request: GenerationRequest): request.output_format = getattr(request, "output_format", "simple") # Log model request - current_time = time.time() - self.model_request_timestamps[request.hf_name].append(current_time) 
- cutoff = current_time - 300 - self.model_request_timestamps[request.hf_name] = [ - ts - for ts in self.model_request_timestamps[request.hf_name] - if ts > cutoff - ] - - if request.hf_name not in self.model_name_to_request: - self.model_name_to_request[request.hf_name] = 1 - self.model_name_to_request[request.hf_name] += 1 + self._log_model_request(request.hf_name) request.output = None request_id = f"req_{hash(random.random())}" @@ -150,39 +189,10 @@ async def chat_completions(request: ChatCompletionRequest): status_code=400, detail="messages cannot be empty" ) - # Separate system messages from conversation - system_messages = [] - conversation = [] - - for msg in request.messages: - if msg.role not in ("system", "user", "assistant"): - continue - - if msg.role == "system": - system_messages.append(msg.content) - else: - conversation.append({"role": msg.role, "content": msg.content}) - - # Find last user message - last_user_message = None - last_user_idx = None - - for idx in range(len(conversation) - 1, -1, -1): - if conversation[idx]["role"] == "user": - last_user_message = conversation[idx]["content"] - last_user_idx = idx - break - - if last_user_message is None: - raise HTTPException(status_code=400, detail="No user message found") - - # Build history (everything before the last user message) - history = conversation[:last_user_idx] - - # Prepend system message to history if present - if system_messages: - combined_system = "\n".join(system_messages) - history.insert(0, {"role": "system", "content": combined_system}) + # Parse messages into system messages, history, and last user message + system_messages, history, last_user_message = _parse_chat_messages( + request.messages + ) # Create GenerationRequest gen_request = GenerationRequest( @@ -207,6 +217,9 @@ async def chat_completions(request: ChatCompletionRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + def _register_model_routes(self): + """Register model management 
endpoints""" + @self.router.post("/request-model", response_model=ModelStatusResponse) def request_model(job_request: JobRequest, request: Request): """ @@ -289,7 +302,7 @@ def list_available_models(): loading_models = [] # Query the node's worker for model status - response = self.smart_node.request_queue.put( + _ = self.smart_node.request_queue.put( {"type": "get_loaded_models", "args": None} ) @@ -311,6 +324,9 @@ def list_available_models(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + def _register_stats_routes(self): + """Register statistics and monitoring endpoints""" + @self.app.get("/stats") async def get_network_stats(): return self.smart_node.get_tensorlink_status() @@ -334,6 +350,9 @@ async def get_proposals(limit: int = Query(30, ge=1, le=180)): """ return self.smart_node.keeper.get_proposals(limit=limit) + def _register_network_routes(self): + """Register network and node information endpoints""" + @self.app.get("/node-info") async def get_node_info(node_id: str): """ @@ -368,7 +387,20 @@ async def get_worker_claims(node_address: str): """Get claim information for a specific worker node""" return self.smart_node.contract_manager.get_worker_claim_data(node_address) - self.app.include_router(self.router) + def _log_model_request(self, model_name: str): + """Log and track model requests for prioritization""" + current_time = time.time() + self.model_request_timestamps[model_name].append(current_time) + + # Keep only requests from last 5 minutes + cutoff = current_time - 300 + self.model_request_timestamps[model_name] = [ + ts for ts in self.model_request_timestamps[model_name] if ts > cutoff + ] + + if model_name not in self.model_name_to_request: + self.model_name_to_request[model_name] = 1 + self.model_name_to_request[model_name] += 1 async def _generate_stream(self, request, request_id, start_time): """Generator function for streaming tokens""" diff --git a/tensorlink/nodes/nodes.py b/tensorlink/nodes/nodes.py index 
ba165f1..17686c5 100644 --- a/tensorlink/nodes/nodes.py +++ b/tensorlink/nodes/nodes.py @@ -281,7 +281,6 @@ def run_role(self): max_vram_gb=self._max_vram_gb, ) - node.activate() node.run() # Keep process alive while the node thread is running diff --git a/tensorlink/nodes/user_thread.py b/tensorlink/nodes/user_thread.py index 1f0bd6c..4ba5af5 100644 --- a/tensorlink/nodes/user_thread.py +++ b/tensorlink/nodes/user_thread.py @@ -469,26 +469,31 @@ def run(self): # Get proposees from SC and send our state to them # If we are the next proposee, accept info from validators and only add info to the final state if there are # 2 or more of the identical info - super().run() + try: + super().run() - should_bootstrap = bool(self._priority_nodes) or self.on_chain - if should_bootstrap: - attempts = 0 - while attempts < 3 and len(self.validators) == 0: - self.bootstrap() + should_bootstrap = bool(self._priority_nodes) or self.on_chain + if should_bootstrap: + attempts = 0 + while attempts < 3 and len(self.validators) == 0: + self.bootstrap() - if len(self.nodes) == 0: - time.sleep(3) - attempts += 1 - else: - self.debug_print( - "Skipping bootstrap (no priority nodes and not on-chain).", - tag="Worker", - level=logging.INFO, - ) + if len(self.nodes) == 0: + time.sleep(3) + attempts += 1 + else: + self.debug_print( + "Skipping bootstrap (no priority nodes and not on-chain).", + tag="Worker", + level=logging.INFO, + ) + + while not self.terminate_flag.is_set(): + # Handle job oversight, and inspect other jobs (includes job verification and reporting) + time.sleep(3) - while not self.terminate_flag.is_set(): - # Handle job oversight, and inspect other jobs (includes job verification and reporting) - time.sleep(3) + except KeyboardInterrupt: + self.terminate_flag.set() - self.stop() + finally: + self.stop() diff --git a/tensorlink/nodes/validator_thread.py b/tensorlink/nodes/validator_thread.py index f7a4537..a6edd4d 100644 --- a/tensorlink/nodes/validator_thread.py +++ 
b/tensorlink/nodes/validator_thread.py @@ -960,35 +960,39 @@ def distribute_job(self): # ] def run(self): - super().run() + try: + super().run() - if self.on_chain: - time.sleep(15) - self.execution_listener = threading.Thread( - target=self.contract_manager.proposal_creator, daemon=True - ) - self.execution_listener.start() - self.proposal_listener = threading.Thread( - target=self.contract_manager.proposal_validator, daemon=True - ) - self.proposal_listener.start() - - counter = 0 - # Loop for active job and network moderation - while not self.terminate_flag.is_set(): - if counter % 300 == 0: - self.keeper.write_state() - if counter % 120 == 0: - self.keeper.clean_node() - self.clean_port_mappings() - self.get_workers() - if counter % 180 == 0: - self.print_ui_status() - - time.sleep(1) - counter += 1 - - self.stop() + if self.on_chain: + time.sleep(15) + self.execution_listener = threading.Thread( + target=self.contract_manager.proposal_creator, daemon=True + ) + self.execution_listener.start() + self.proposal_listener = threading.Thread( + target=self.contract_manager.proposal_validator, daemon=True + ) + self.proposal_listener.start() + + counter = 0 + # Loop for active job and network moderation + while not self.terminate_flag.is_set(): + if counter % 300 == 0: + self.keeper.write_state() + if counter % 120 == 0: + self.keeper.clean_node() + self.clean_port_mappings() + self.get_workers() + if counter % 180 == 0: + self.print_ui_status() + + time.sleep(1) + counter += 1 + except KeyboardInterrupt: + self.terminate_flag.set() + + finally: + self.stop() def stop(self): self.keeper.write_state() diff --git a/tensorlink/nodes/worker_thread.py b/tensorlink/nodes/worker_thread.py index a942941..cb8b976 100644 --- a/tensorlink/nodes/worker_thread.py +++ b/tensorlink/nodes/worker_thread.py @@ -173,8 +173,8 @@ def _handle_job_req(self, data: bytes, node: Connection): def run(self): # Accept users and back-check history # Get proposees from SC and send our state to 
them - super().run() try: + super().run() if self.on_chain: self.public_key = get_key(".tensorlink.env", "PUBLIC_KEY") if not self.public_key: @@ -212,9 +212,10 @@ def run(self): time.sleep(1) counter += 1 except KeyboardInterrupt: - pass + self.terminate_flag.set() - self.stop() + finally: + self.stop() def load_distributed_module(self, module: nn.Module, graph: dict = None): pass From 8807fd96f6c1f51e89c18185b1faf91d4474a7cd Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Wed, 28 Jan 2026 10:52:42 -0500 Subject: [PATCH 18/25] Fixed model autoload times in validator.py --- tensorlink/ml/validator.py | 13 +++++++++---- tensorlink/nodes/nodes.py | 4 +--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tensorlink/ml/validator.py b/tensorlink/ml/validator.py index 3ac5d42..722650b 100644 --- a/tensorlink/ml/validator.py +++ b/tensorlink/ml/validator.py @@ -303,6 +303,7 @@ def _manage_auto_loaded_models(self): desired_instances[model_name] = round(share * self.MAX_AUTO_MODELS) can_allocate = True + # Ensure each model has at least one instance for model_name, desired in desired_instances.items(): if not can_allocate: @@ -324,6 +325,8 @@ def _manage_auto_loaded_models(self): ), ) can_allocate = self._initialize_hosted_job(model_name) + if not can_allocate: + break # Finalize any first-load initializations if self.models_initializing: @@ -331,9 +334,6 @@ def _manage_auto_loaded_models(self): # Allocate duplicates based on proportional demand for model_name, target_count in desired_instances.items(): - if not can_allocate: - break - current_total = len(self.public_models.get(model_name, [])) current_total += sum( 1 if job_id in self.models_initializing else 0 @@ -356,6 +356,9 @@ def _manage_auto_loaded_models(self): if not can_allocate: break + if not can_allocate: + break + # Finalize any duplicate initializations if self.models_initializing: self._try_finalize_initializing_models() @@ -488,7 +491,9 @@ def check_node(self): self.CHECK_COUNTER = 1 - if 
self.CHECK_COUNTER * 10 % self.GC_CHECK_INTERVAL == 0: + if ( + self.CHECK_COUNTER % self.GC_CHECK_INTERVAL * 20 == 0 + ): # less frequent than garbage collection # Manage autoloaded models based on popularity (or DEFAULT_MODELS fallback) self._manage_auto_loaded_models() diff --git a/tensorlink/nodes/nodes.py b/tensorlink/nodes/nodes.py index 17686c5..f8edb2e 100644 --- a/tensorlink/nodes/nodes.py +++ b/tensorlink/nodes/nodes.py @@ -182,8 +182,7 @@ def cleanup(self): # Ask the role to stop cleanly first via IPC try: - response = self.send_request("stop", (None,), timeout=10) - print(f"Stop request response: {response}") + _ = self.send_request("stop", (None,), timeout=10) except Exception as e: print(f"Error sending stop request: {e}") @@ -197,7 +196,6 @@ def cleanup(self): self.node_process.join() self.node_process = None - print("Node cleanup complete") def send_request(self, request_type, args, timeout=5): """ From 4151cf2ae650a9fb8c2483c2cd34f4c8268ddae5 Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Wed, 28 Jan 2026 11:26:18 -0500 Subject: [PATCH 19/25] Added 'api' key to job_data for hosted jobs in validator.py --- tensorlink/ml/validator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorlink/ml/validator.py b/tensorlink/ml/validator.py index 722650b..6af82dd 100644 --- a/tensorlink/ml/validator.py +++ b/tensorlink/ml/validator.py @@ -489,13 +489,12 @@ def check_node(self): if not is_active: self._remove_hosted_job(job_id) - self.CHECK_COUNTER = 1 - if ( self.CHECK_COUNTER % self.GC_CHECK_INTERVAL * 20 == 0 ): # less frequent than garbage collection # Manage autoloaded models based on popularity (or DEFAULT_MODELS fallback) self._manage_auto_loaded_models() + self.CHECK_COUNTER = 1 if self.models_initializing: # Only call model management if we have models actively initializing @@ -911,6 +910,7 @@ def _initialize_hosted_job( "author": None, "active": True, "hosted": True, + "api": True, "training": False, "payment": payment, 
"time": time_limit, From 8c06431318a9de3f6b3e75b659aab9cd4669c4a4 Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Wed, 28 Jan 2026 11:49:14 -0500 Subject: [PATCH 20/25] Moved query debug print to VERBOSE. Fixed minor autoload check interval bug. --- tensorlink/ml/validator.py | 2 +- tensorlink/p2p/dht.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorlink/ml/validator.py b/tensorlink/ml/validator.py index 6af82dd..a55ef06 100644 --- a/tensorlink/ml/validator.py +++ b/tensorlink/ml/validator.py @@ -490,7 +490,7 @@ def check_node(self): self._remove_hosted_job(job_id) if ( - self.CHECK_COUNTER % self.GC_CHECK_INTERVAL * 20 == 0 + self.CHECK_COUNTER % (self.GC_CHECK_INTERVAL * 20) == 0 ): # less frequent than garbage collection # Manage autoloaded models based on popularity (or DEFAULT_MODELS fallback) self._manage_auto_loaded_models() diff --git a/tensorlink/p2p/dht.py b/tensorlink/p2p/dht.py index 60abe87..ae2b4d9 100644 --- a/tensorlink/p2p/dht.py +++ b/tensorlink/p2p/dht.py @@ -77,7 +77,7 @@ def query( if self.node.rsa_key_hash not in keys_to_exclude: keys_to_exclude.append(self.node.rsa_key_hash) - self.node.debug_print(f"Querying for {key}", tag="DHT") + self.node.debug_print(f"Querying for {key}", tag="DHT", level=self.node.VERBOSE) closest_node = None closest_distance = float("inf") From 9da2721bb53d4e64eb6bf6c0e66c27d77fede472 Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Wed, 28 Jan 2026 12:44:39 -0500 Subject: [PATCH 21/25] Changed seed validator port to 38752 --- bin/config.json | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bin/config.json b/bin/config.json index 59191d7..3ac7c72 100644 --- a/bin/config.json +++ b/bin/config.json @@ -1,20 +1,20 @@ { "config": { "node": { - "type": "worker", + "type": "validator", "mode": "public", - "endpoint": false, + "endpoint": true, "endpoint_url": "0.0.0.0", "endpoint_port": 64747, "priority_nodes": [], - "logging": "ERROR" + "logging": "INFO" }, "crypto": { 
"address": "0x1Bc3a15dfFa205AA24F6386D959334ac1BF27336", "mining": false, "mining_script": "path/to/mining.executable", "seed_validators": [ - ["smartnodes.ddns.net", 38751, "58ef79797cd451e19df4a73fbd9871797f9c6a2995783c7f6fd2406978a2ba2e"] + ["smartnodes.ddns.net", 38752, "58ef79797cd451e19df4a73fbd9871797f9c6a2995783c7f6fd2406978a2ba2e"] ] }, "ml": { From b9c97c128dfd0ba06978e86e89ea8808053332cf Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Wed, 28 Jan 2026 14:12:06 -0500 Subject: [PATCH 22/25] Removed memory lock from validator. Minor update to node binary venv naming. --- bin/run-node.sh | 2 +- tensorlink/ml/validator.py | 81 +++++++++++++++++--------------------- 2 files changed, 38 insertions(+), 45 deletions(-) diff --git a/bin/run-node.sh b/bin/run-node.sh index 75ce2c5..1edc3b0 100755 --- a/bin/run-node.sh +++ b/bin/run-node.sh @@ -1,6 +1,6 @@ #!/bin/bash -VENV_PATH="venv" +VENV_PATH=".venv" export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True diff --git a/tensorlink/ml/validator.py b/tensorlink/ml/validator.py index a55ef06..d8f0c69 100644 --- a/tensorlink/ml/validator.py +++ b/tensorlink/ml/validator.py @@ -17,7 +17,7 @@ from transformers import AutoTokenizer, TextIteratorStreamer from collections import defaultdict -from threading import Thread, Lock +from threading import Thread import torch import logging import inspect @@ -162,9 +162,6 @@ def __init__( self.host_memory_reserved = 0 self.initializing_reservations = {} # job_id -> reserved_memory - # Lock for thread-safe memory operations - self.memory_lock = Lock() - def _ensure_model_entry(self, model_name: str): """Ensure a model has an entry in the cache with proper structure""" if model_name not in self.model_cache: @@ -1142,17 +1139,15 @@ def _remove_hosted_job(self, job_id: str): def _reserve_host_memory(self, job_id: str, amount: int): """Reserve host memory for a model being initialized""" - with self.memory_lock: - self.host_memory_reserved += amount - 
self.initializing_reservations[job_id] = amount + self.host_memory_reserved += amount + self.initializing_reservations[job_id] = amount def _release_host_memory(self, job_id: str): """Release reserved host memory when initialization completes or fails""" - with self.memory_lock: - if job_id in self.initializing_reservations: - reserved = self.initializing_reservations[job_id] - self.host_memory_reserved -= reserved - del self.initializing_reservations[job_id] + if job_id in self.initializing_reservations: + reserved = self.initializing_reservations[job_id] + self.host_memory_reserved -= reserved + del self.initializing_reservations[job_id] def _get_available_host_memory(self) -> int: """Get currently available host memory accounting for reservations""" @@ -1164,9 +1159,8 @@ def _get_available_host_memory(self) -> int: ) if self._hosting_enabled: - with self.memory_lock: - total_memory = min(get_gpu_memory(), max_vram_bytes) - available_memory += total_memory - self.host_memory_reserved + total_memory = min(get_gpu_memory(), max_vram_bytes) + available_memory += total_memory - self.host_memory_reserved return available_memory def _audit_memory_reservations(self): @@ -1174,39 +1168,38 @@ def _audit_memory_reservations(self): Audit memory reservations and clean up any orphaned reservations. Called periodically to prevent memory leaks from edge cases. 
""" - with self.memory_lock: - # Find job_ids that have reservations but aren't in models or models_initializing - orphaned_reservations = [] + # Find job_ids that have reservations but aren't in models or models_initializing + orphaned_reservations = [] - for job_id in list(self.initializing_reservations.keys()): - if job_id not in self.models and job_id not in self.models_initializing: - orphaned_reservations.append(job_id) + for job_id in list(self.initializing_reservations.keys()): + if job_id not in self.models and job_id not in self.models_initializing: + orphaned_reservations.append(job_id) - # Release orphaned reservations - for job_id in orphaned_reservations: - reserved = self.initializing_reservations[job_id] - self.host_memory_reserved -= reserved - del self.initializing_reservations[job_id] + # Release orphaned reservations + for job_id in orphaned_reservations: + reserved = self.initializing_reservations[job_id] + self.host_memory_reserved -= reserved + del self.initializing_reservations[job_id] - self.send_request( - "debug_print", - ( - f"Released orphaned reservation: {job_id} ({reserved / 1e9:.2f}GB)", - "yellow", - logging.WARNING, - ), - ) + self.send_request( + "debug_print", + ( + f"Released orphaned reservation: {job_id} ({reserved / 1e9:.2f}GB)", + "yellow", + logging.WARNING, + ), + ) - if orphaned_reservations: - self.send_request( - "debug_print", - ( - f"Memory audit: Released {len(orphaned_reservations)} orphaned reservations. " - f"Total reserved: {self.host_memory_reserved / 1e9:.2f}GB", - "cyan", - logging.INFO, - ), - ) + if orphaned_reservations: + self.send_request( + "debug_print", + ( + f"Memory audit: Released {len(orphaned_reservations)} orphaned reservations. 
" + f"Total reserved: {self.host_memory_reserved / 1e9:.2f}GB", + "cyan", + logging.INFO, + ), + ) def main_loop(self): self.check_node() From 21b3383450c284799dc4ca5fa54421e971874c6a Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Wed, 28 Jan 2026 15:22:04 -0500 Subject: [PATCH 23/25] Reimplemented memory lock for validator.py, fixed job_id bug during memory lock in inspect_model. --- tensorlink/ml/validator.py | 86 ++++++++++++++++------------ tensorlink/nodes/validator_thread.py | 2 +- 2 files changed, 50 insertions(+), 38 deletions(-) diff --git a/tensorlink/ml/validator.py b/tensorlink/ml/validator.py index d8f0c69..e9851fd 100644 --- a/tensorlink/ml/validator.py +++ b/tensorlink/ml/validator.py @@ -17,10 +17,11 @@ from transformers import AutoTokenizer, TextIteratorStreamer from collections import defaultdict -from threading import Thread +from threading import Thread, Lock import torch import logging import inspect +import hashlib import json import time import re @@ -162,6 +163,9 @@ def __init__( self.host_memory_reserved = 0 self.initializing_reservations = {} # job_id -> reserved_memory + # Lock for thread-safe memory operations + self.memory_lock = Lock() + def _ensure_model_entry(self, model_name: str): """Ensure a model has an entry in the cache with proper structure""" if model_name not in self.model_cache: @@ -427,6 +431,10 @@ def inspect_model( ): return {} + job_data["time"] = time.time() + job_id = hashlib.sha256(json.dumps(job_data).encode()).hexdigest() + job_data["id"] = job_id + # Reserve the host memory this model will use host_memory_used = distribution.get("host_memory_used", 0) if host_memory_used > 0: @@ -1139,15 +1147,17 @@ def _remove_hosted_job(self, job_id: str): def _reserve_host_memory(self, job_id: str, amount: int): """Reserve host memory for a model being initialized""" - self.host_memory_reserved += amount - self.initializing_reservations[job_id] = amount + with self.memory_lock: + self.host_memory_reserved += amount + 
self.initializing_reservations[job_id] = amount def _release_host_memory(self, job_id: str): """Release reserved host memory when initialization completes or fails""" - if job_id in self.initializing_reservations: - reserved = self.initializing_reservations[job_id] - self.host_memory_reserved -= reserved - del self.initializing_reservations[job_id] + with self.memory_lock: + if job_id in self.initializing_reservations: + reserved = self.initializing_reservations[job_id] + self.host_memory_reserved -= reserved + del self.initializing_reservations[job_id] def _get_available_host_memory(self) -> int: """Get currently available host memory accounting for reservations""" @@ -1159,8 +1169,9 @@ def _get_available_host_memory(self) -> int: ) if self._hosting_enabled: - total_memory = min(get_gpu_memory(), max_vram_bytes) - available_memory += total_memory - self.host_memory_reserved + with self.memory_lock: + total_memory = min(get_gpu_memory(), max_vram_bytes) + available_memory += total_memory - self.host_memory_reserved return available_memory def _audit_memory_reservations(self): @@ -1168,38 +1179,39 @@ def _audit_memory_reservations(self): Audit memory reservations and clean up any orphaned reservations. Called periodically to prevent memory leaks from edge cases. 
""" - # Find job_ids that have reservations but aren't in models or models_initializing - orphaned_reservations = [] + with self.memory_lock: + # Find job_ids that have reservations but aren't in models or models_initializing + orphaned_reservations = [] - for job_id in list(self.initializing_reservations.keys()): - if job_id not in self.models and job_id not in self.models_initializing: - orphaned_reservations.append(job_id) + for job_id in list(self.initializing_reservations.keys()): + if job_id not in self.models and job_id not in self.models_initializing: + orphaned_reservations.append(job_id) - # Release orphaned reservations - for job_id in orphaned_reservations: - reserved = self.initializing_reservations[job_id] - self.host_memory_reserved -= reserved - del self.initializing_reservations[job_id] + # Release orphaned reservations + for job_id in orphaned_reservations: + reserved = self.initializing_reservations[job_id] + self.host_memory_reserved -= reserved + del self.initializing_reservations[job_id] - self.send_request( - "debug_print", - ( - f"Released orphaned reservation: {job_id} ({reserved / 1e9:.2f}GB)", - "yellow", - logging.WARNING, - ), - ) + self.send_request( + "debug_print", + ( + f"Released orphaned reservation: {job_id} ({reserved / 1e9:.2f}GB)", + "yellow", + logging.WARNING, + ), + ) - if orphaned_reservations: - self.send_request( - "debug_print", - ( - f"Memory audit: Released {len(orphaned_reservations)} orphaned reservations. " - f"Total reserved: {self.host_memory_reserved / 1e9:.2f}GB", - "cyan", - logging.INFO, - ), - ) + if orphaned_reservations: + self.send_request( + "debug_print", + ( + f"Memory audit: Released {len(orphaned_reservations)} orphaned reservations. 
" + f"Total reserved: {self.host_memory_reserved / 1e9:.2f}GB", + "cyan", + logging.INFO, + ), + ) def main_loop(self): self.check_node() diff --git a/tensorlink/nodes/validator_thread.py b/tensorlink/nodes/validator_thread.py index a6edd4d..8b14894 100644 --- a/tensorlink/nodes/validator_thread.py +++ b/tensorlink/nodes/validator_thread.py @@ -521,9 +521,9 @@ def create_hf_job(self, job_info: dict, requesters_ip: str = None): _time = job_info.get("time") job_data = job_info - job_data["time"] = _time if not job_data.get("id"): + job_data["time"] = _time job_id = hashlib.sha256(json.dumps(job_data).encode()).hexdigest() job_data["id"] = job_id From 2c90c2b7db41459edf50d387b4624652cb441acc Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Thu, 29 Jan 2026 09:57:18 -0500 Subject: [PATCH 24/25] Fixed job id bug for user-requested jobs (created from last push). Minor tweak to memory estimations. --- tensorlink/ml/utils.py | 41 +++++++++++---------------------- tensorlink/ml/validator.py | 7 +++--- tensorlink/nodes/user_thread.py | 1 + 3 files changed, 19 insertions(+), 30 deletions(-) diff --git a/tensorlink/ml/utils.py b/tensorlink/ml/utils.py index 6d8a19b..d64bd29 100644 --- a/tensorlink/ml/utils.py +++ b/tensorlink/ml/utils.py @@ -34,7 +34,7 @@ def estimate_memory( recursive: bool = True, count_activations: bool = True, ) -> tuple[float, dict]: - """Estimate memory with better control over what's counted.""" + """Estimate GPU memory required for a model.""" dtype_size = torch.tensor([], dtype=dtype).element_size() @@ -46,9 +46,10 @@ def estimate_memory( "kv_cache": 0, } - # Parameters, only count at this level if not recursive + # ---- parameters ---- if recursive: param_bytes = sum(p.numel() * p.element_size() for p in module.parameters()) + param_bytes += sum(b.numel() * b.element_size() for b in module.buffers()) else: param_bytes = sum( p.numel() * p.element_size() for p in module.parameters(recurse=False) @@ -59,22 +60,16 @@ def estimate_memory( 
breakdown["parameters"] = param_bytes + # ---- training extras ---- if training: breakdown["gradients"] = param_bytes - param_numel = sum(p.numel() for p in module.parameters()) - if optimizer_type.lower() in {"adam", "adamw"}: - # exp_avg + exp_avg_sq - opt_bytes = 2 * param_bytes * (4 / dtype_size) - + breakdown["optimizer"] = 2 * param_bytes * (4 / dtype_size) else: - opt_bytes = param_numel * dtype_size + breakdown["optimizer"] = param_bytes - breakdown["optimizer"] = opt_bytes - - # Only count activations if requested + # ---- activations ---- if count_activations: - # Try to infer hidden size from the module if hasattr(module, "config"): hidden_size = module.config.hidden_size elif hasattr(module, "hidden_size"): @@ -84,27 +79,16 @@ def estimate_memory( elif hasattr(module, "d_model"): hidden_size = module.d_model else: - # Estimate from parameter count - if recursive: - total_params = sum(p.numel() for p in module.parameters()) - else: - total_params = sum(p.numel() for p in module.parameters(recurse=False)) + total_params = sum(p.numel() for p in module.parameters()) + hidden_size = max(256, min(int((total_params / 12) ** 0.5), 8192)) - if total_params > 0: - # Rough heuristic: for transformer layers, params ≈ 12 * hidden_size^2 - hidden_size = max(128, min(int((total_params / 12) ** 0.5), 8192)) - else: - # Absolute last resort for modules with no parameters - hidden_size = 512 - - # More conservative activation multiplier activation_multiplier = 4 if not training else 7 breakdown["activations"] = ( batch_size * seq_length * hidden_size * dtype_size * activation_multiplier ) - if include_kv_cache and hasattr(module, 'config') and not training: + if include_kv_cache and hasattr(module, "config") and not training: num_layers = module.config.num_hidden_layers num_heads = getattr( module.config, @@ -123,7 +107,10 @@ def estimate_memory( * dtype_size ) - total = sum(breakdown.values()) * 1.30 # add 30% overhead + # ---- overhead ---- + OVERHEAD = 1.30 + total = 
sum(breakdown.values()) * OVERHEAD + return total, breakdown diff --git a/tensorlink/ml/validator.py b/tensorlink/ml/validator.py index e9851fd..b5c9fa9 100644 --- a/tensorlink/ml/validator.py +++ b/tensorlink/ml/validator.py @@ -431,9 +431,10 @@ def inspect_model( ): return {} - job_data["time"] = time.time() - job_id = hashlib.sha256(json.dumps(job_data).encode()).hexdigest() - job_data["id"] = job_id + if job_data.get("id") is None: + job_data["time"] = time.time() + job_id = hashlib.sha256(json.dumps(job_data).encode()).hexdigest() + job_data["id"] = job_id # Reserve the host memory this model will use host_memory_used = distribution.get("host_memory_used", 0) diff --git a/tensorlink/nodes/user_thread.py b/tensorlink/nodes/user_thread.py index 4ba5af5..c6f9520 100644 --- a/tensorlink/nodes/user_thread.py +++ b/tensorlink/nodes/user_thread.py @@ -397,6 +397,7 @@ def _send_job_req(self, validator: Connection, job_info): raise "Validator not a seed validator" message = b"JOB-REQ" + json.dumps(job_info).encode() + print(f'STORING: {validator.node_id}, {job_info["id"]}') self._store_request(validator.node_id, job_info["id"]) self.send_to_node(validator, message) start_time = time.time() From ef15a6cf25c791395d7541b6d44bbc4120014899 Mon Sep 17 00:00:00 2001 From: mattjhawken Date: Thu, 29 Jan 2026 13:00:53 -0500 Subject: [PATCH 25/25] Attempt at minor garbage collection improvement when loading modules on worker --- tensorlink/ml/worker.py | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/tensorlink/ml/worker.py b/tensorlink/ml/worker.py index 4467404..2dc0fdc 100644 --- a/tensorlink/ml/worker.py +++ b/tensorlink/ml/worker.py @@ -139,6 +139,11 @@ def _load_model_skeleton(model_name: str, module_id: str, model_type: str = "cha skeleton_model = AutoModel.from_config(model_config) skeleton_model.eval() # Set to eval mode initially + + # Ensure no cached gradients or cached computations + for param in 
skeleton_model.parameters(): + param.requires_grad = False + return skeleton_model @@ -747,19 +752,8 @@ def _load_grouped_layers( f"Loading grouped layers {layer_range[0]}-{layer_range[1]} from {model_name}" ) - # Adjust layer paths if they include module_path prefix - adjusted_layer_paths = [] - for layer_path in layer_paths: - # If base_model is the submodule, layer paths should be relative to it - # eg 'model.layers.0' -> 'layers.0' - - # Check if layer_path starts with 'model.' - if layer_path.startswith('model.'): - adjusted_layer_paths.append(layer_path[6:]) - else: - adjusted_layer_paths.append(layer_path) - - # Create the layer group wrapper with empty weights + # Create the layer group wrapper with the skeleton's layers + # Extract references quickly before cleanup grouped_module = _create_layer_group_wrapper( base_model, layer_paths, @@ -769,12 +763,21 @@ def _load_grouped_layers( loop_iterator_name, ) - # Get name of model for loading weights - base_model_prefix = getattr(base_model, "base_model_prefix", None) + # CRITICAL: Aggressively cleanup skeleton immediately after extraction + # This is the key fix - cleanup happens as soon as we have the wrapper + del base_model + gc.collect() + if self.device.type == "cuda": + torch.cuda.synchronize() + torch.cuda.empty_cache() + + # Convert grouped module to empty tensors on CPU + # This removes any weight data that might have been copied from skeleton + grouped_module = grouped_module.to_empty(device="cpu") - # Aggressive cleanup before loading weights + # Another cleanup pass to ensure skeleton is gone with self.memory_efficient_context(): - del base_model + pass # Now load only the weights for the assigned layers logging.info(f"Loading weights for layers {layer_range[0]}-{layer_range[1]}") @@ -784,7 +787,6 @@ def _load_grouped_layers( ) # Load the state dict into the grouped module - grouped_module = grouped_module.to_empty(device="cpu") missing_keys, unexpected_keys = grouped_module.load_state_dict( 
state_dict, strict=False )