diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index fb65c3b..bf88196 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -47,6 +47,13 @@ jobs: - name: Install dependencies run: npm ci + - name: Package seeds directory + run: | + cd seeds + zip -r ../seeds.zip *.png + cd .. + echo "Created seeds.zip with $(unzip -l seeds.zip | tail -1 | awk '{print $2}') bytes" + - name: Build Tauri app uses: tauri-apps/tauri-action@v0 env: @@ -57,3 +64,10 @@ jobs: releaseBody: 'See the assets below to download and install Biome.' releaseDraft: false prerelease: false + + - name: Upload seeds.zip to release + if: matrix.platform == 'ubuntu-22.04' + uses: softprops/action-gh-release@v1 + with: + files: seeds.zip + token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index 0d04c27..487c740 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,12 @@ dist-ssr .env/ *.bak +.venv/ +uv.lock +__pycache__ +# seed directory created by standalone instance of server.py +src-tauri/world_engine + # env .env .vercel diff --git a/README.md b/README.md index 30faf38..35a353c 100644 --- a/README.md +++ b/README.md @@ -4,18 +4,18 @@ # Biome - **Explore AI-generated worlds in real-time, running locally on your GPU.** +**Explore AI-generated worlds in real-time, running locally on your GPU.** - [![Website](https://img.shields.io/badge/over.world-000000?logo=data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHdpZHRoPSIyNCIgaGVpZ2h0PSIyNCIgdmlld0JveD0iMCAwIDI0IDI0IiBmaWxsPSJub25lIiBzdHJva2U9IndoaXRlIiBzdHJva2Utd2lkdGg9IjIiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIgc3Ryb2tlLWxpbmVqb2luPSJyb3VuZCI+PGNpcmNsZSBjeD0iMTIiIGN5PSIxMiIgcj0iMTAiLz48bGluZSB4MT0iMiIgeTE9IjEyIiB4Mj0iMjIiIHkyPSIxMiIvPjxwYXRoIGQ9Ik0xMiAyYTE1LjMgMTUuMyAwIDAgMSA0IDEwIDE1LjMgMTUuMyAwIDAgMS00IDEwIDE1LjMgMTUuMyAwIDAgMS00LTEwIDE1LjMgMTUuMyAwIDAgMSA0LTEweiIvPjwvc3ZnPg==)](https://over.world) - [![Discord](https://img.shields.io/badge/Discord-5865F2?logo=discord&logoColor=white)](https://discord.gg/overworld) - [![X](https://img.shields.io/badge/X-000000?logo=x&logoColor=white)](https://x.com/overworld_ai) - [![Windows](https://img.shields.io/badge/Windows-0078D6?logo=windows&logoColor=white)](https://github.com/Overworldai/Biome/releases) - [![Linux](https://img.shields.io/badge/Linux-FCC624?logo=linux&logoColor=black)](https://github.com/Overworldai/Biome/releases) - [![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0) +[![Website](https://img.shields.io/badge/over.world-000000?logo=data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHdpZHRoPSIyNCIgaGVpZ2h0PSIyNCIgdmlld0JveD0iMCAwIDI0IDI0IiBmaWxsPSJub25lIiBzdHJva2U9IndoaXRlIiBzdHJva2Utd2lkdGg9IjIiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIgc3Ryb2tlLWxpbmVqb2luPSJyb3VuZCI+PGNpcmNsZSBjeD0iMTIiIGN5PSIxMiIgcj0iMTAiLz48bGluZSB4MT0iMiIgeTE9IjEyIiB4Mj0iMjIiIHkyPSIxMiIvPjxwYXRoIGQ9Ik0xMiAyYTE1LjMgMTUuMyAwIDAgMSA0IDEwIDE1LjMgMTUuMyAwIDAgMS00IDEwIDE1LjMgMTUuMyAwIDAgMS00LTEwIDE1LjMgMTUuMyAwIDAgMSA0LTEweiIvPjwvc3ZnPg==)](https://over.world) +[![Discord](https://img.shields.io/badge/Discord-5865F2?logo=discord&logoColor=white)](https://discord.gg/overworld) +[![X](https://img.shields.io/badge/X-000000?logo=x&logoColor=white)](https://x.com/overworld_ai) +[![Windows](https://img.shields.io/badge/Windows-0078D6?logo=windows&logoColor=white)](https://github.com/Overworldai/Biome/releases) 
+[![Linux](https://img.shields.io/badge/Linux-FCC624?logo=linux&logoColor=black)](https://github.com/Overworldai/Biome/releases) +[![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0) - **[Download the latest release](https://github.com/Overworldai/Biome/releases/latest)** +**[Download the latest release](https://github.com/Overworldai/Biome/releases/latest)** @@ -31,7 +31,6 @@ Biome installs just like a video game — download, run the installer, and start - Runs locally on your GPU via [World Engine](https://github.com/Overworldai/world_engine) - Lightweight native desktop application - ## Getting Started Grab the installer from the [Releases](https://github.com/Overworldai/Biome/releases/latest) page and you're good to go. diff --git a/package-lock.json b/package-lock.json index a36d09f..5201137 100644 --- a/package-lock.json +++ b/package-lock.json @@ -50,6 +50,7 @@ "integrity": "sha512-H3mcG6ZDLTlYfaSNi0iOKkigqMFvkTKlGUYlD8GW7nNOYRrevuA46iTypPyv+06V3fEmvvazfntkBU34L0azAw==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@babel/code-frame": "^7.28.6", "@babel/generator": "^7.28.6", @@ -1012,6 +1013,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "baseline-browser-mapping": "^2.9.0", "caniuse-lite": "^1.0.30001759", @@ -1227,6 +1229,7 @@ "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "license": "MIT", + "peer": true, "engines": { "node": ">=12" }, @@ -1283,6 +1286,7 @@ "resolved": "https://registry.npmjs.org/react/-/react-18.3.1.tgz", "integrity": "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==", "license": "MIT", + "peer": true, "dependencies": { "loose-envify": "^1.1.0" }, diff --git a/seeds/0_2.png b/seeds/0_2.png deleted file mode 100644 index 03620a6..0000000 Binary files a/seeds/0_2.png and /dev/null differ diff --git a/seeds/default.png b/seeds/default.png deleted file mode 100644 index f577521..0000000 Binary files a/seeds/default.png and /dev/null differ diff --git a/seeds/sample_00000.png b/seeds/sample_00000.png new file mode 100644 index 0000000..a372508 Binary files /dev/null and b/seeds/sample_00000.png differ diff --git a/seeds/sample_00001.png b/seeds/sample_00001.png new file mode 100644 index 0000000..b73e116 Binary files /dev/null and b/seeds/sample_00001.png differ diff --git a/seeds/sample_00002.png b/seeds/sample_00002.png new file mode 100644 index 0000000..19e495b Binary files /dev/null and b/seeds/sample_00002.png differ diff --git a/seeds/sample_00003.png b/seeds/sample_00003.png new file mode 100644 index 0000000..040434f Binary files /dev/null and b/seeds/sample_00003.png differ diff --git a/seeds/sample_00004.png b/seeds/sample_00004.png new file mode 100644 index 0000000..48c38db Binary files /dev/null and b/seeds/sample_00004.png differ diff --git a/seeds/sample_00005.png b/seeds/sample_00005.png new file mode 100644 index 0000000..28d99f5 Binary files /dev/null and b/seeds/sample_00005.png differ diff --git a/seeds/sample_00006.png b/seeds/sample_00006.png new file mode 100644 index 0000000..6723403 Binary files /dev/null and b/seeds/sample_00006.png differ diff --git a/seeds/sample_00007.png b/seeds/sample_00007.png new file mode 100644 index 0000000..097e6bd Binary files /dev/null and b/seeds/sample_00007.png differ diff --git a/seeds/sample_00008.png b/seeds/sample_00008.png new 
file mode 100644 index 0000000..8b956e6 Binary files /dev/null and b/seeds/sample_00008.png differ diff --git a/seeds/sample_00009.png b/seeds/sample_00009.png new file mode 100644 index 0000000..6c3f177 Binary files /dev/null and b/seeds/sample_00009.png differ diff --git a/seeds/sample_00010.png b/seeds/sample_00010.png new file mode 100644 index 0000000..c0e11e8 Binary files /dev/null and b/seeds/sample_00010.png differ diff --git a/seeds/sample_00011.png b/seeds/sample_00011.png new file mode 100644 index 0000000..cee5857 Binary files /dev/null and b/seeds/sample_00011.png differ diff --git a/seeds/sample_00012.png b/seeds/sample_00012.png new file mode 100644 index 0000000..8ea619f Binary files /dev/null and b/seeds/sample_00012.png differ diff --git a/seeds/sample_00013.png b/seeds/sample_00013.png new file mode 100644 index 0000000..624e466 Binary files /dev/null and b/seeds/sample_00013.png differ diff --git a/seeds/sample_00014.png b/seeds/sample_00014.png new file mode 100644 index 0000000..15be7aa Binary files /dev/null and b/seeds/sample_00014.png differ diff --git a/seeds/sample_00015.png b/seeds/sample_00015.png new file mode 100644 index 0000000..006b762 Binary files /dev/null and b/seeds/sample_00015.png differ diff --git a/seeds/sample_00016.png b/seeds/sample_00016.png new file mode 100644 index 0000000..1eaf70f Binary files /dev/null and b/seeds/sample_00016.png differ diff --git a/seeds/sample_00017.png b/seeds/sample_00017.png new file mode 100644 index 0000000..79be730 Binary files /dev/null and b/seeds/sample_00017.png differ diff --git a/seeds/sample_00018.png b/seeds/sample_00018.png new file mode 100644 index 0000000..442bf46 Binary files /dev/null and b/seeds/sample_00018.png differ diff --git a/seeds/sample_00019.png b/seeds/sample_00019.png new file mode 100644 index 0000000..686d83d Binary files /dev/null and b/seeds/sample_00019.png differ diff --git a/seeds/sample_00020.png b/seeds/sample_00020.png new file mode 100644 index 0000000..681d0c6 Binary files /dev/null and b/seeds/sample_00020.png differ diff --git a/seeds/sample_00021.png b/seeds/sample_00021.png new file mode 100644 index 0000000..f9aacb7 Binary files /dev/null and b/seeds/sample_00021.png differ diff --git a/seeds/sample_00022.png b/seeds/sample_00022.png new file mode 100644 index 0000000..54f30ac Binary files /dev/null and b/seeds/sample_00022.png differ diff --git a/seeds/sample_00023.png b/seeds/sample_00023.png new file mode 100644 index 0000000..d060d35 Binary files /dev/null and b/seeds/sample_00023.png differ diff --git a/seeds/sample_00024.png b/seeds/sample_00024.png new file mode 100644 index 0000000..a7d2f38 Binary files /dev/null and b/seeds/sample_00024.png differ diff --git a/seeds/sample_00025.png b/seeds/sample_00025.png new file mode 100644 index 0000000..b1241f6 Binary files /dev/null and b/seeds/sample_00025.png differ diff --git a/seeds/sample_00026.png b/seeds/sample_00026.png new file mode 100644 index 0000000..efc273f Binary files /dev/null and b/seeds/sample_00026.png differ diff --git a/seeds/sample_00027.png b/seeds/sample_00027.png new file mode 100644 index 0000000..4d2fc12 Binary files /dev/null and b/seeds/sample_00027.png differ diff --git a/seeds/sample_00028.png b/seeds/sample_00028.png new file mode 100644 index 0000000..36d9d4e Binary files /dev/null and b/seeds/sample_00028.png differ diff --git a/seeds/sample_00029.png b/seeds/sample_00029.png new file mode 100644 index 0000000..a17f0e0 Binary files /dev/null and b/seeds/sample_00029.png differ diff --git 
a/seeds/sample_00030.png b/seeds/sample_00030.png new file mode 100644 index 0000000..05a872b Binary files /dev/null and b/seeds/sample_00030.png differ diff --git a/seeds/sample_00031.png b/seeds/sample_00031.png new file mode 100644 index 0000000..3e7c46e Binary files /dev/null and b/seeds/sample_00031.png differ diff --git a/seeds/sample_00032.png b/seeds/sample_00032.png new file mode 100644 index 0000000..361b723 Binary files /dev/null and b/seeds/sample_00032.png differ diff --git a/seeds/sample_00033.png b/seeds/sample_00033.png new file mode 100644 index 0000000..2de6435 Binary files /dev/null and b/seeds/sample_00033.png differ diff --git a/seeds/sample_00034.png b/seeds/sample_00034.png new file mode 100644 index 0000000..cc4dacc Binary files /dev/null and b/seeds/sample_00034.png differ diff --git a/seeds/sample_00035.png b/seeds/sample_00035.png new file mode 100644 index 0000000..0ed86c4 Binary files /dev/null and b/seeds/sample_00035.png differ diff --git a/seeds/sample_00036.png b/seeds/sample_00036.png new file mode 100644 index 0000000..61b61c4 Binary files /dev/null and b/seeds/sample_00036.png differ diff --git a/seeds/sample_00037.png b/seeds/sample_00037.png new file mode 100644 index 0000000..e512707 Binary files /dev/null and b/seeds/sample_00037.png differ diff --git a/seeds/sample_00038.png b/seeds/sample_00038.png new file mode 100644 index 0000000..1e93074 Binary files /dev/null and b/seeds/sample_00038.png differ diff --git a/seeds/sample_00039.png b/seeds/sample_00039.png new file mode 100644 index 0000000..4ef44b4 Binary files /dev/null and b/seeds/sample_00039.png differ diff --git a/seeds/sample_00040.png b/seeds/sample_00040.png new file mode 100644 index 0000000..d33e9d4 Binary files /dev/null and b/seeds/sample_00040.png differ diff --git a/seeds/sample_00041.png b/seeds/sample_00041.png new file mode 100644 index 0000000..3195fc1 Binary files /dev/null and b/seeds/sample_00041.png differ diff --git a/seeds/sample_00042.png b/seeds/sample_00042.png new file mode 100644 index 0000000..28c1c83 Binary files /dev/null and b/seeds/sample_00042.png differ diff --git a/seeds/sample_00043.png b/seeds/sample_00043.png new file mode 100644 index 0000000..091bf79 Binary files /dev/null and b/seeds/sample_00043.png differ diff --git a/seeds/sample_00044.png b/seeds/sample_00044.png new file mode 100644 index 0000000..2e04bb1 Binary files /dev/null and b/seeds/sample_00044.png differ diff --git a/seeds/sample_00045.png b/seeds/sample_00045.png new file mode 100644 index 0000000..1489325 Binary files /dev/null and b/seeds/sample_00045.png differ diff --git a/seeds/sample_00046.png b/seeds/sample_00046.png new file mode 100644 index 0000000..30734f4 Binary files /dev/null and b/seeds/sample_00046.png differ diff --git a/seeds/sample_00047.png b/seeds/sample_00047.png new file mode 100644 index 0000000..c784774 Binary files /dev/null and b/seeds/sample_00047.png differ diff --git a/seeds/sample_00048.png b/seeds/sample_00048.png new file mode 100644 index 0000000..f92c06b Binary files /dev/null and b/seeds/sample_00048.png differ diff --git a/seeds/sample_00049.png b/seeds/sample_00049.png new file mode 100644 index 0000000..2a00f0e Binary files /dev/null and b/seeds/sample_00049.png differ diff --git a/seeds/sample_00050.png b/seeds/sample_00050.png new file mode 100644 index 0000000..9cf38d2 Binary files /dev/null and b/seeds/sample_00050.png differ diff --git a/seeds/starter (1).png b/seeds/starter (1).png deleted file mode 100644 index 5c5da00..0000000 Binary files 
a/seeds/starter (1).png and /dev/null differ diff --git a/seeds/starter (11).png b/seeds/starter (11).png deleted file mode 100644 index 220ed67..0000000 Binary files a/seeds/starter (11).png and /dev/null differ diff --git a/seeds/starter (12).png b/seeds/starter (12).png deleted file mode 100644 index b21ba2a..0000000 Binary files a/seeds/starter (12).png and /dev/null differ diff --git a/seeds/starter (14).png b/seeds/starter (14).png deleted file mode 100644 index 2463d72..0000000 Binary files a/seeds/starter (14).png and /dev/null differ diff --git a/seeds/starter (15).png b/seeds/starter (15).png deleted file mode 100644 index 71e354a..0000000 Binary files a/seeds/starter (15).png and /dev/null differ diff --git a/seeds/starter (16).png b/seeds/starter (16).png deleted file mode 100644 index bca251b..0000000 Binary files a/seeds/starter (16).png and /dev/null differ diff --git a/seeds/starter (17).png b/seeds/starter (17).png deleted file mode 100644 index 89f3001..0000000 Binary files a/seeds/starter (17).png and /dev/null differ diff --git a/seeds/starter (18).png b/seeds/starter (18).png deleted file mode 100644 index 307664a..0000000 Binary files a/seeds/starter (18).png and /dev/null differ diff --git a/seeds/starter (19).png b/seeds/starter (19).png deleted file mode 100644 index acfe829..0000000 Binary files a/seeds/starter (19).png and /dev/null differ diff --git a/seeds/starter (2).png b/seeds/starter (2).png deleted file mode 100644 index 25b50b9..0000000 Binary files a/seeds/starter (2).png and /dev/null differ diff --git a/seeds/starter (20).png b/seeds/starter (20).png deleted file mode 100644 index 6089f58..0000000 Binary files a/seeds/starter (20).png and /dev/null differ diff --git a/seeds/starter (21).png b/seeds/starter (21).png deleted file mode 100644 index 43bef47..0000000 Binary files a/seeds/starter (21).png and /dev/null differ diff --git a/seeds/starter (22).png b/seeds/starter (22).png deleted file mode 100644 index 56223af..0000000 Binary files a/seeds/starter (22).png and /dev/null differ diff --git a/seeds/starter (23).png b/seeds/starter (23).png deleted file mode 100644 index 77a4d6a..0000000 Binary files a/seeds/starter (23).png and /dev/null differ diff --git a/seeds/starter (24).png b/seeds/starter (24).png deleted file mode 100644 index 300f6e6..0000000 Binary files a/seeds/starter (24).png and /dev/null differ diff --git a/seeds/starter (26).png b/seeds/starter (26).png deleted file mode 100644 index 371c5f0..0000000 Binary files a/seeds/starter (26).png and /dev/null differ diff --git a/seeds/starter (27).png b/seeds/starter (27).png deleted file mode 100644 index 46c1bb0..0000000 Binary files a/seeds/starter (27).png and /dev/null differ diff --git a/seeds/starter (3).png b/seeds/starter (3).png deleted file mode 100644 index 478e3b9..0000000 Binary files a/seeds/starter (3).png and /dev/null differ diff --git a/seeds/starter (4).png b/seeds/starter (4).png deleted file mode 100644 index 638c058..0000000 Binary files a/seeds/starter (4).png and /dev/null differ diff --git a/seeds/starter (5).png b/seeds/starter (5).png deleted file mode 100644 index ad241de..0000000 Binary files a/seeds/starter (5).png and /dev/null differ diff --git a/seeds/starter (6).png b/seeds/starter (6).png deleted file mode 100644 index 6434850..0000000 Binary files a/seeds/starter (6).png and /dev/null differ diff --git a/seeds/starter (9).png b/seeds/starter (9).png deleted file mode 100644 index 7d8c0b1..0000000 Binary files a/seeds/starter (9).png and /dev/null differ 
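Note on the seed asset churn above: the new "Package seeds directory" workflow step zips `seeds/*.png` flat into `seeds.zip` and attaches it to the release (from the ubuntu-22.04 runner only), and this is the bundle that `download_default_seeds()` in `server.py` later pulls from `DEFAULT_SEEDS_URL` when no local seeds are present. Below is a minimal local sketch of that packaging step, useful for checking the archive layout before tagging a release; the script name and defaults are hypothetical and not part of this diff — only the flat-at-the-root PNG layout is what the server-side `extractall` into `DEFAULT_SEEDS_DIR` expects.

```python
# make_seeds_zip.py -- hypothetical helper, not part of this diff.
# Mirrors the release.yml "Package seeds directory" step: PNGs sit flat at the
# zip root, because download_default_seeds() extracts the archive directly
# into DEFAULT_SEEDS_DIR without stripping any directory prefix.
import zipfile
from pathlib import Path


def package_seeds(seeds_dir: str = "seeds", out_path: str = "seeds.zip") -> int:
    """Zip every top-level PNG in seeds_dir and return the number packed."""
    pngs = sorted(Path(seeds_dir).glob("*.png"))
    with zipfile.ZipFile(out_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for png in pngs:
            zf.write(png, arcname=png.name)  # flat layout, no "seeds/" prefix
    return len(pngs)


if __name__ == "__main__":
    print(f"Packed {package_seeds()} seed images into seeds.zip")
```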
diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index c3fb351..ded80c3 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -11,6 +11,7 @@ dependencies = [ "flate2", "image", "kill_tree", + "log", "reqwest", "serde", "serde_json", diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 69dab57..de2daee 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -23,7 +23,7 @@ tauri-plugin-opener = "2" tauri-plugin-fs = "2" serde = { version = "1", features = ["derive"] } serde_json = "1" -reqwest = { version = "0.12" } +reqwest = { version = "0.12", features = ["json"] } zip = "2" flate2 = "1" tar = "0.4" @@ -31,4 +31,5 @@ kill_tree = "0.2" base64 = "0.22" image = { version = "0.25", default-features = false, features = ["png", "jpeg"] } ctrlc = "3" +log = "0.4" diff --git a/src-tauri/server-components/engine_manager.py b/src-tauri/server-components/engine_manager.py new file mode 100644 index 0000000..4af880c --- /dev/null +++ b/src-tauri/server-components/engine_manager.py @@ -0,0 +1,266 @@ +""" +WorldEngine module - Handles AI world generation and frame streaming. + +Extracted from monolithic server.py to provide clean separation of concerns. +""" + +import asyncio +import base64 +import io +import logging +import time +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +from PIL import Image + +logger = logging.getLogger(__name__) + +# ============================================================================ +# Configuration +# ============================================================================ + +MODEL_URI = "Overworld/Waypoint-1-Small" +QUANT = None +N_FRAMES = 4096 +DEVICE = "cuda" +JPEG_QUALITY = 85 + +BUTTON_CODES = {} +# A-Z keys +for i in range(65, 91): + BUTTON_CODES[chr(i)] = i +# 0-9 keys +for i in range(10): + BUTTON_CODES[str(i)] = ord(str(i)) +# Special keys +BUTTON_CODES["UP"] = 0x26 +BUTTON_CODES["DOWN"] = 0x28 +BUTTON_CODES["LEFT"] = 0x25 +BUTTON_CODES["RIGHT"] = 0x27 +BUTTON_CODES["SHIFT"] = 0x10 +BUTTON_CODES["CTRL"] = 0x11 +BUTTON_CODES["SPACE"] = 0x20 +BUTTON_CODES["TAB"] = 0x09 +BUTTON_CODES["ENTER"] = 0x0D +BUTTON_CODES["MOUSE_LEFT"] = 0x01 +BUTTON_CODES["MOUSE_RIGHT"] = 0x02 +BUTTON_CODES["MOUSE_MIDDLE"] = 0x04 + +# Default prompt - describes the expected visual style +DEFAULT_PROMPT = ( + "First-person shooter gameplay footage from a true POV perspective, " + "the camera locked to the player's eyes as assault rifles, carbines, " + "machine guns, laser-sighted firearms, bullet-fed weapons, magazines, " + "barrels, muzzles, tracers, ammo, and launchers dominate the frame, " + "with constant gun handling, recoil, muzzle flash, shell ejection, " + "and ballistic impacts. Continuous real-time FPS motion with no cuts, " + "weapon-centric framing, realistic gun physics, authentic firearm " + "materials, high-caliber ammunition, laser optics, iron sights, and " + "relentless gun-driven action, rendered in ultra-realistic 4K at 60fps." 
+) + + +# ============================================================================ +# Session Management +# ============================================================================ + + +@dataclass +class Session: + """Tracks state for a single WebSocket connection.""" + + frame_count: int = 0 + max_frames: int = N_FRAMES - 2 + + +# ============================================================================ +# WorldEngine Manager +# ============================================================================ + + +class WorldEngineManager: + """Manages WorldEngine state and operations.""" + + def __init__(self): + self.engine = None + self.seed_frame = None + self.CtrlInput = None + self.current_prompt = DEFAULT_PROMPT + self.engine_warmed_up = False + + def load_seed_from_file( + self, file_path: str, target_size: tuple[int, int] = (360, 640) + ) -> torch.Tensor: + """Load a seed frame from a file path.""" + try: + img = Image.open(file_path).convert("RGB") + import numpy as np + + img_tensor = ( + torch.from_numpy(np.array(img)).permute(2, 0, 1).unsqueeze(0).float() + ) + frame = F.interpolate( + img_tensor, size=target_size, mode="bilinear", align_corners=False + )[0] + return ( + frame.to(dtype=torch.uint8, device=DEVICE) + .permute(1, 2, 0) + .contiguous() + ) + except Exception as e: + logger.error(f"Failed to load seed from file {file_path}: {e}") + return None + + def load_seed_from_base64( + self, base64_data: str, target_size: tuple[int, int] = (360, 640) + ) -> torch.Tensor: + """Load a seed frame from base64 encoded data.""" + try: + img_data = base64.b64decode(base64_data) + img = Image.open(io.BytesIO(img_data)).convert("RGB") + import numpy as np + + img_tensor = ( + torch.from_numpy(np.array(img)).permute(2, 0, 1).unsqueeze(0).float() + ) + frame = F.interpolate( + img_tensor, size=target_size, mode="bilinear", align_corners=False + )[0] + return ( + frame.to(dtype=torch.uint8, device=DEVICE) + .permute(1, 2, 0) + .contiguous() + ) + except Exception as e: + logger.error(f"Failed to load seed from base64: {e}") + return None + + + async def load_engine(self): + """Initialize the WorldEngine with configured model.""" + logger.info("=" * 60) + logger.info("BIOME ENGINE STARTUP") + logger.info("=" * 60) + + logger.info("[1/4] Importing WorldEngine...") + import_start = time.perf_counter() + from world_engine import CtrlInput as CI + from world_engine import WorldEngine + + self.CtrlInput = CI + logger.info( + f"[1/4] WorldEngine imported in {time.perf_counter() - import_start:.2f}s" + ) + + logger.info(f"[2/4] Loading model: {MODEL_URI}") + logger.info(f" Quantization: {QUANT}") + logger.info(f" Device: {DEVICE}") + logger.info(f" N_FRAMES: {N_FRAMES}") + logger.info(f" Prompt: {self.current_prompt[:60]}...") + + # Model config overrides + # scheduler_sigmas: diffusion denoising schedule (MUST end with 0.0) + # ae_uri: VAE model for encoding/decoding frames + model_start = time.perf_counter() + self.engine = WorldEngine( + MODEL_URI, + device=DEVICE, + model_config_overrides={ + "n_frames": N_FRAMES, + "ae_uri": "OpenWorldLabs/owl_vae_f16_c16_distill_v0_nogan", + "scheduler_sigmas": [1.0, 0.8, 0.2, 0.0], + }, + quant=QUANT, + dtype=torch.bfloat16, + ) + logger.info( + f"[2/4] Model loaded in {time.perf_counter() - model_start:.2f}s" + ) + + # Seed frame will be provided by frontend via set_initial_seed message + logger.info( + "[3/4] Seed frame: waiting for client to provide initial seed via base64" + ) + self.seed_frame = None + + logger.info("[4/4] Engine 
initialization complete") + logger.info("=" * 60) + logger.info("SERVER READY - Waiting for WebSocket connections on /ws") + logger.info(" (Client must send set_initial_seed with base64 data)") + logger.info("=" * 60) + + def frame_to_jpeg(self, frame: torch.Tensor, quality: int = JPEG_QUALITY) -> bytes: + """Convert frame tensor to JPEG bytes.""" + if frame.dtype != torch.uint8: + frame = frame.clamp(0, 255).to(torch.uint8) + img = Image.fromarray(frame.cpu().numpy(), mode="RGB") + buf = io.BytesIO() + img.save(buf, format="JPEG", quality=quality) + return buf.getvalue() + + async def generate_frame(self, ctrl_input) -> torch.Tensor: + """Generate next frame using WorldEngine.""" + frame = await asyncio.to_thread(self.engine.gen_frame, ctrl=ctrl_input) + return frame + + async def reset_state(self): + """Reset engine state.""" + await asyncio.to_thread(self.engine.reset) + await asyncio.to_thread(self.engine.append_frame, self.seed_frame) + await asyncio.to_thread(self.engine.set_prompt, self.current_prompt) + + async def warmup(self): + """Perform initial warmup to compile CUDA graphs.""" + + def do_warmup(): + warmup_start = time.perf_counter() + + logger.info("[5/5] Step 1: Resetting engine state...") + reset_start = time.perf_counter() + self.engine.reset() + logger.info( + f"[5/5] Step 1: Reset complete in {time.perf_counter() - reset_start:.2f}s" + ) + + logger.info("[5/5] Step 2: Appending seed frame...") + append_start = time.perf_counter() + self.engine.append_frame(self.seed_frame) + logger.info( + f"[5/5] Step 2: Seed frame appended in {time.perf_counter() - append_start:.2f}s" + ) + + logger.info("[5/5] Step 3: Setting prompt...") + prompt_start = time.perf_counter() + self.engine.set_prompt(self.current_prompt) + logger.info( + f"[5/5] Step 3: Prompt set in {time.perf_counter() - prompt_start:.2f}s" + ) + + logger.info( + "[5/5] Step 4: Generating first frame (compiling CUDA graphs)..." + ) + gen_start = time.perf_counter() + _ = self.engine.gen_frame( + ctrl=self.CtrlInput(button=set(), mouse=(0.0, 0.0)) + ) + logger.info( + f"[5/5] Step 4: First frame generated in {time.perf_counter() - gen_start:.2f}s" + ) + + return time.perf_counter() - warmup_start + + logger.info("=" * 60) + logger.info( + "[5/5] WARMUP - First client connected, initializing CUDA graphs..." + ) + logger.info("=" * 60) + + warmup_time = await asyncio.to_thread(do_warmup) + + logger.info("=" * 60) + logger.info(f"[5/5] WARMUP COMPLETE - Total time: {warmup_time:.2f}s") + logger.info("=" * 60) + + self.engine_warmed_up = True diff --git a/src-tauri/server-components/pyproject.toml b/src-tauri/server-components/pyproject.toml index d288f06..f81c79f 100644 --- a/src-tauri/server-components/pyproject.toml +++ b/src-tauri/server-components/pyproject.toml @@ -7,8 +7,11 @@ dependencies = [ "pillow", "fastapi>=0.128.0", "uvicorn>=0.40.0", + "httpx>=0.27.0", "websockets>=15.0.1", "hf-xet>=1.0.0", + "transformers>=4.30.0", + "timm>=0.9.0", ] [dependency-groups] diff --git a/src-tauri/server-components/safety.py b/src-tauri/server-components/safety.py new file mode 100644 index 0000000..57e4c24 --- /dev/null +++ b/src-tauri/server-components/safety.py @@ -0,0 +1,277 @@ +""" +Safety module - NSFW image detection for seed validation. + +Uses Freepik/nsfw_image_detector model to check images for inappropriate content. 
+""" + +import gc +import logging +import threading +from typing import List, Dict + +import torch +import torch.nn.functional as F +from PIL import Image +from transformers import AutoModelForImageClassification +from timm.data.transforms_factory import create_transform +from torchvision.transforms import Compose +from timm.data import resolve_data_config +from timm.models import get_pretrained_cfg + +logger = logging.getLogger(__name__) + + +class SafetyChecker: + """NSFW content detector for seed images.""" + + def __init__(self): + self.model = None + self.processor = None + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self._lock = threading.Lock() # Prevent concurrent model access + logger.info(f"SafetyChecker initialized (device: {self.device})") + + def _load_model(self): + """Lazy load model on first check.""" + if self.model is None: + logger.info("Loading NSFW detection model...") + + if self.device == "cuda": + load_start = torch.cuda.Event(enable_timing=True) + load_end = torch.cuda.Event(enable_timing=True) + load_start.record() + + self.model = AutoModelForImageClassification.from_pretrained( + "Freepik/nsfw_image_detector", dtype=torch.bfloat16 + ).to(self.device) + + cfg = get_pretrained_cfg("eva02_base_patch14_448.mim_in22k_ft_in22k_in1k") + self.processor = create_transform(**resolve_data_config(cfg.__dict__)) + + if self.device == "cuda": + load_end.record() + torch.cuda.synchronize() + load_time = load_start.elapsed_time(load_end) / 1000 # Convert to seconds + logger.info(f"NSFW detection model loaded in {load_time:.2f}s") + else: + logger.info("NSFW detection model loaded") + + def unload_model(self): + """Unload model from memory to free resources.""" + if self.model is not None: + logger.info("Unloading NSFW detection model...") + + # Move model to CPU before deletion if it was on GPU + if self.device == "cuda": + self.model.cpu() + + # Delete model and processor + del self.model + del self.processor + self.model = None + self.processor = None + + # Clear CUDA cache if available + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Force garbage collection + gc.collect() + + logger.info("NSFW detection model unloaded") + + def check_image(self, image_path: str) -> Dict[str, any]: + """ + Check single image for NSFW content. + + Args: + image_path: Path to image file + + Returns: + { + 'is_safe': bool, + 'scores': { + 'neutral': float, + 'low': float, + 'medium': float, + 'high': float + } + } + """ + with self._lock: + self._load_model() + + try: + img = Image.open(image_path) + # Convert to RGB to handle RGBA/RGB mode differences + if img.mode != "RGB": + img = img.convert("RGB") + scores = self.predict_batch_values([img])[0] + is_safe = scores["low"] < 0.5 # Strict threshold + result = {"is_safe": is_safe, "scores": scores} + except Exception as e: + logger.error(f"Failed to check image {image_path}: {e}") + # Default to unsafe on error (conservative approach) + result = { + "is_safe": False, + "scores": {"neutral": 0.0, "low": 1.0, "medium": 0.0, "high": 0.0}, + } + finally: + # Unload model after check to free memory + self.unload_model() + + return result + + def check_batch( + self, image_paths: List[str], batch_size: int = 8 + ) -> List[Dict[str, any]]: + """ + Check multiple images efficiently with proper batching. 
+ + Args: + image_paths: List of paths to image files + batch_size: Number of images to process at once (default 8 to avoid GPU OOM) + + Returns: + List of results matching check_image() format + """ + if not image_paths: + return [] + + with self._lock: + self._load_model() + + try: + # First pass: load all images and track which ones failed + images = [] + for path in image_paths: + try: + img = Image.open(path) + # Convert to RGB to handle RGBA/RGB mode differences + if img.mode != "RGB": + img = img.convert("RGB") + images.append(img) + except Exception as e: + logger.error(f"Failed to load image {path}: {e}") + images.append(None) + + # Process valid images in batches + valid_images = [img for img in images if img is not None] + all_scores = [] + + for i in range(0, len(valid_images), batch_size): + batch = valid_images[i : i + batch_size] + batch_scores = self.predict_batch_values(batch) + all_scores.extend(batch_scores) + + # Build results, matching order of input paths + results = [] + score_idx = 0 + for img in images: + if img is None: + # Failed to load - mark as unsafe + results.append( + { + "is_safe": False, + "scores": { + "neutral": 0.0, + "low": 1.0, + "medium": 0.0, + "high": 0.0, + }, + } + ) + else: + scores = all_scores[score_idx] + results.append({"is_safe": scores["low"] < 0.5, "scores": scores}) + score_idx += 1 + + return results + except Exception as e: + logger.error(f"Failed to check batch: {e}") + # Return all unsafe on batch failure + return [ + { + "is_safe": False, + "scores": {"neutral": 0.0, "low": 1.0, "medium": 0.0, "high": 0.0}, + } + for _ in image_paths + ] + finally: + # Unload model after batch check to free memory + self.unload_model() + + def predict_batch_values( + self, img_batch: List[Image.Image] + ) -> List[Dict[str, float]]: + """ + Process a batch of images and return prediction scores for each NSFW category. + + Args: + img_batch: List of PIL images + + Returns: + List of score dictionaries with cumulative probabilities: + [ + { + 'neutral': float, # Probability of being neutral (only this category) + 'low': float, # Probability of being low or higher (cumulative) + 'medium': float, # Probability of being medium or higher (cumulative) + 'high': float # Probability of being high (cumulative) + } + ] + """ + idx_to_label = {0: "neutral", 1: "low", 2: "medium", 3: "high"} + + # Prepare batch + inputs = torch.stack([self.processor(img) for img in img_batch]).to( + self.device + ) + output = [] + + with torch.inference_mode(): + logits = self.model(inputs).logits + batch_probs = F.log_softmax(logits, dim=-1) + batch_probs = torch.exp(batch_probs).cpu() + + for i in range(len(batch_probs)): + element_probs = batch_probs[i] + output_img = {} + danger_cum_sum = 0 + + # Cumulative sum from high to low (reverse order) + for j in range(len(element_probs) - 1, -1, -1): + danger_cum_sum += element_probs[j] + if j == 0: + danger_cum_sum = element_probs[j] # Neutral is not cumulative + output_img[idx_to_label[j]] = danger_cum_sum.item() + + output.append(output_img) + + return output + + def prediction( + self, + img_batch: List[Image.Image], + class_to_predict: str, + threshold: float = 0.5, + ) -> List[bool]: + """ + Predict if images meet or exceed a specific NSFW threshold. 
+ + Args: + img_batch: List of PIL images + class_to_predict: One of "low", "medium", "high" + threshold: Probability threshold (0.0 to 1.0) + + Returns: + List of booleans indicating if each image meets the threshold + """ + if class_to_predict not in ["low", "medium", "high"]: + raise ValueError("class_to_predict must be one of: low, medium, high") + + if not 0 <= threshold <= 1: + raise ValueError("threshold must be between 0 and 1") + + output = self.predict_batch_values(img_batch) + return [output[i][class_to_predict] >= threshold for i in range(len(output))] diff --git a/src-tauri/server-components/server.py b/src-tauri/server-components/server.py index 5db469a..4c6abe6 100644 --- a/src-tauri/server-components/server.py +++ b/src-tauri/server-components/server.py @@ -1,10 +1,14 @@ """ -Low-latency WebSocket server for WorldEngine frame streaming. +Tauri <> Python Communication Bridge + +Low-latency WebSocket server that orchestrates WorldEngine and Safety modules. +This server acts as a unified interface for both world generation and safety checking. Usage: - python examples/websocket_server.py + python server.py --host 0.0.0.0 --port 7987 Client connects via WebSocket to ws://localhost:7987/ws +Safety checks via HTTP POST to http://localhost:7987/safety/check_batch """ # Immediate startup logging before any imports that could fail @@ -15,13 +19,17 @@ import asyncio import base64 -import io +import hashlib import json import logging +import os +import pickle +import shutil import time -import urllib.request +import zipfile from contextlib import asynccontextmanager -from dataclasses import dataclass +from pathlib import Path +from typing import Optional logging.basicConfig( level=logging.INFO, @@ -29,14 +37,13 @@ datefmt="%H:%M:%S", stream=sys.stdout, ) -logger = logging.getLogger("websocket_server") +logger = logging.getLogger("biome_server") print("[BIOME] Basic imports done", flush=True) try: print("[BIOME] Importing torch...", flush=True) import torch - import torch.nn.functional as F print(f"[BIOME] torch {torch.__version__} imported", flush=True) @@ -54,8 +61,20 @@ import uvicorn from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.responses import JSONResponse + from pydantic import BaseModel print("[BIOME] FastAPI imported", flush=True) + + print("[BIOME] Importing Engine Manager module...", flush=True) + from engine_manager import WorldEngineManager, Session, BUTTON_CODES + + print("[BIOME] Engine Manager module imported", flush=True) + + print("[BIOME] Importing Safety module...", flush=True) + from safety import SafetyChecker + + print("[BIOME] Safety module imported", flush=True) + except Exception as e: print(f"[BIOME] FATAL: Import failed: {e}", flush=True) import traceback @@ -64,216 +83,546 @@ sys.exit(1) # ============================================================================ -# Configuration +# Global Module Instances # ============================================================================ -MODEL_URI = "Overworld/Waypoint-1-Small" -QUANT = "w8a8" -N_FRAMES = 4096 -DEVICE = "cuda" -JPEG_QUALITY = 85 - -BUTTON_CODES = {} -# A-Z keys -for i in range(65, 91): - BUTTON_CODES[chr(i)] = i -# 0-9 keys -for i in range(10): - BUTTON_CODES[str(i)] = ord(str(i)) -# Special keys -BUTTON_CODES["UP"] = 0x26 -BUTTON_CODES["DOWN"] = 0x28 -BUTTON_CODES["LEFT"] = 0x25 -BUTTON_CODES["RIGHT"] = 0x27 -BUTTON_CODES["SHIFT"] = 0x10 -BUTTON_CODES["CTRL"] = 0x11 -BUTTON_CODES["SPACE"] = 0x20 -BUTTON_CODES["TAB"] = 0x09 -BUTTON_CODES["ENTER"] = 0x0D 
-BUTTON_CODES["MOUSE_LEFT"] = 0x01 -BUTTON_CODES["MOUSE_RIGHT"] = 0x02 -BUTTON_CODES["MOUSE_MIDDLE"] = 0x04 - - - -# Default prompt - describes the expected visual style -DEFAULT_PROMPT = ("First-person shooter gameplay footage from a true POV perspective, " -"the camera locked to the player's eyes as assault rifles, carbines, " -"machine guns, laser-sighted firearms, bullet-fed weapons, magazines, " -"barrels, muzzles, tracers, ammo, and launchers dominate the frame, " -"with constant gun handling, recoil, muzzle flash, shell ejection, " -"and ballistic impacts. Continuous real-time FPS motion with no cuts, " -"weapon-centric framing, realistic gun physics, authentic firearm " -"materials, high-caliber ammunition, laser optics, iron sights, and " -"relentless gun-driven action, rendered in ultra-realistic 4K at 60fps.") +world_engine = None +safety_checker = None +safe_seeds_cache = {} # Maps filename -> {hash, is_safe, path} +rescan_lock = None # Prevent concurrent rescans (initialized in lifespan) + +# ============================================================================ +# Seed Management Configuration +# ============================================================================ + +# Server-side seed storage paths +SEEDS_BASE_DIR = Path(__file__).parent.parent / "world_engine" / "seeds" +DEFAULT_SEEDS_DIR = SEEDS_BASE_DIR / "default" +UPLOADS_DIR = SEEDS_BASE_DIR / "uploads" +CACHE_FILE = Path(__file__).parent.parent / "world_engine" / ".seeds_cache.bin" + +# Local seeds directory (relative to project root for standalone usage) +LOCAL_SEEDS_DIR = Path(__file__).parent.parent.parent / "seeds" + +DEFAULT_SEEDS_URL = "https://github.com/Wayfarer-Labs/Biome/releases/latest/download/seeds.zip" # ============================================================================ -# Engine Setup +# Seed Management Functions # ============================================================================ -engine = None -seed_frame = None -CtrlInput = None -current_prompt = DEFAULT_PROMPT -engine_warmed_up = False + +def ensure_seed_directories(): + """Create seed directory structure if it doesn't exist.""" + DEFAULT_SEEDS_DIR.mkdir(parents=True, exist_ok=True) + UPLOADS_DIR.mkdir(parents=True, exist_ok=True) + logger.info(f"Seed directories initialized: {SEEDS_BASE_DIR}") -def load_seed_from_base64(base64_data: str, target_size: tuple[int, int] = (360, 640)) -> torch.Tensor: - """Load a seed frame from base64 encoded data.""" +async def download_default_seeds(): + """Download and extract default seeds on first startup, or use local seeds if available.""" + if list(DEFAULT_SEEDS_DIR.glob("*.png")): + logger.info("Default seeds already exist, skipping setup") + return + + # Check if local seeds directory exists (for standalone usage) + if LOCAL_SEEDS_DIR.exists() and list(LOCAL_SEEDS_DIR.glob("*.png")): + logger.info(f"Found local seeds directory at {LOCAL_SEEDS_DIR}") + try: + # Copy seeds from local directory to default directory + seed_files = list(LOCAL_SEEDS_DIR.glob("*.png")) + logger.info(f"Copying {len(seed_files)} seed files to {DEFAULT_SEEDS_DIR}") + + for seed_file in seed_files: + dest = DEFAULT_SEEDS_DIR / seed_file.name + shutil.copy2(seed_file, dest) + logger.info(f" Copied {seed_file.name}") + + logger.info("Local seeds copied successfully") + return + except Exception as e: + logger.error(f"Failed to copy local seeds: {e}") + logger.info("Will attempt to download instead...") + + # No local seeds found, attempt download + logger.info("No local seeds found, downloading from 
remote...") try: - img_data = base64.b64decode(base64_data) - img = Image.open(io.BytesIO(img_data)).convert("RGB") - import numpy as np + import httpx + + async with httpx.AsyncClient(timeout=60.0) as client: + logger.info(f"Downloading default seeds from {DEFAULT_SEEDS_URL}") + response = await client.get(DEFAULT_SEEDS_URL) + response.raise_for_status() + + # Save zip temporarily + zip_path = SEEDS_BASE_DIR / "seeds.zip" + zip_path.write_bytes(response.content) + logger.info(f"Downloaded {len(response.content)} bytes") + + # Extract to default directory + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(DEFAULT_SEEDS_DIR) + logger.info(f"Extracted default seeds to {DEFAULT_SEEDS_DIR}") + + # Cleanup + zip_path.unlink() - img_tensor = ( - torch.from_numpy(np.array(img)).permute(2, 0, 1).unsqueeze(0).float() - ) - frame = F.interpolate( - img_tensor, size=target_size, mode="bilinear", align_corners=False - )[0] - return frame.to(dtype=torch.uint8, device=DEVICE).permute(1, 2, 0).contiguous() except Exception as e: - print(f"[ERROR] Failed to load seed from base64: {e}") - return None + logger.error(f"Failed to download default seeds: {e}") + logger.info(f"Please manually place seed images in {DEFAULT_SEEDS_DIR} or {LOCAL_SEEDS_DIR}") + +def load_seeds_cache() -> dict: + """Load seeds cache from binary file.""" + if not CACHE_FILE.exists(): + logger.info("No cache file found, will create new one") + return {"files": {}, "last_scan": None} -def load_seed_from_url(url, target_size=(360, 640)): - """Load a seed frame from URL (used for prompt_with_seed)""" try: - with urllib.request.urlopen(url, timeout=10) as response: - img_data = response.read() - img = Image.open(io.BytesIO(img_data)).convert("RGB") - import numpy as np + with open(CACHE_FILE, "rb") as f: + cache = pickle.load(f) + logger.info(f"Loaded cache with {len(cache.get('files', {}))} seeds") + return cache + except Exception as e: + logger.error(f"Failed to load cache: {e}") + return {"files": {}, "last_scan": None} - img_tensor = ( - torch.from_numpy(np.array(img)).permute(2, 0, 1).unsqueeze(0).float() - ) - frame = F.interpolate( - img_tensor, size=target_size, mode="bilinear", align_corners=False - )[0] - return frame.to(dtype=torch.uint8, device=DEVICE).permute(1, 2, 0).contiguous() + +def save_seeds_cache(cache: dict): + """Save seeds cache to binary file.""" + try: + with open(CACHE_FILE, "wb") as f: + pickle.dump(cache, f) + logger.info(f"Saved cache with {len(cache.get('files', {}))} seeds") except Exception as e: - print(f"[ERROR] Failed to load seed from URL: {e}") - return None + logger.error(f"Failed to save cache: {e}") + + +async def rescan_seeds() -> dict: + """Scan seed directories and run safety checks on all images.""" + logger.info("Starting seed directory scan...") + cache = {"files": {}, "last_scan": time.time()} + + # Scan both default and uploads directories + all_seeds = list(DEFAULT_SEEDS_DIR.glob("*.png")) + list(UPLOADS_DIR.glob("*.png")) + logger.info(f"Found {len(all_seeds)} seed images") + + if not all_seeds: + save_seeds_cache(cache) + logger.info("Scan complete: 0 seeds processed") + return cache + + # Compute hashes for all files + logger.info("Computing file hashes...") + hash_tasks = [asyncio.to_thread(compute_file_hash, str(p)) for p in all_seeds] + file_hashes = await asyncio.gather(*hash_tasks, return_exceptions=True) + + # Run batch safety check (model loads once, processes in batches, then unloads) + logger.info("Running batch safety check...") + image_paths = [str(p) for 
p in all_seeds] + safety_results = await asyncio.to_thread(safety_checker.check_batch, image_paths) + + # Build cache from results + checked_at = time.time() + for i, seed_path in enumerate(all_seeds): + filename = seed_path.name + file_hash = file_hashes[i] if not isinstance(file_hashes[i], Exception) else "" + safety_result = safety_results[i] + + if isinstance(file_hashes[i], Exception): + logger.error(f"Failed to hash {filename}: {file_hashes[i]}") + cache["files"][filename] = { + "hash": "", + "is_safe": False, + "path": str(seed_path), + "error": str(file_hashes[i]), + "checked_at": checked_at, + } + else: + cache["files"][filename] = { + "hash": file_hash, + "is_safe": safety_result.get("is_safe", False), + "path": str(seed_path), + "scores": safety_result.get("scores", {}), + "checked_at": checked_at, + } + + status = "✓ SAFE" if safety_result.get("is_safe") else "✗ UNSAFE" + logger.info(f" {filename}: {status}") + + save_seeds_cache(cache) + logger.info(f"Scan complete: {len(cache['files'])} seeds processed") + return cache -def load_engine(): - """Initialize the WorldEngine with configured model.""" - global engine, seed_frame, CtrlInput +# ============================================================================ +# Application Lifecycle +# ============================================================================ + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Startup/shutdown lifecycle handler.""" + global world_engine, safety_checker, safe_seeds_cache, rescan_lock logger.info("=" * 60) - logger.info("BIOME ENGINE STARTUP") + logger.info("BIOME SERVER STARTUP") logger.info("=" * 60) - logger.info("[1/4] Importing WorldEngine...") - import_start = time.perf_counter() - from world_engine import CtrlInput as CI - from world_engine import WorldEngine + # Initialize lock for rescan operations + rescan_lock = asyncio.Lock() - CtrlInput = CI - logger.info( - f"[1/4] WorldEngine imported in {time.perf_counter() - import_start:.2f}s" - ) + # Initialize modules + logger.info("Initializing WorldEngine...") + world_engine = WorldEngineManager() - logger.info(f"[2/4] Loading model: {MODEL_URI}") - logger.info(f" Quantization: {QUANT}") - logger.info(f" Device: {DEVICE}") - logger.info(f" N_FRAMES: {N_FRAMES}") - logger.info(f" Prompt: {DEFAULT_PROMPT[:60]}...") - - # Model config overrides - # scheduler_sigmas: diffusion denoising schedule (MUST end with 0.0) - # ae_uri: VAE model for encoding/decoding frames - model_start = time.perf_counter() - engine = WorldEngine( - MODEL_URI, - device=DEVICE, - model_config_overrides={ - "n_frames": N_FRAMES, - "ae_uri": "OpenWorldLabs/owl_vae_f16_c16_distill_v0_nogan", - "scheduler_sigmas": [1.0, 0.8, 0.2, 0.0], - }, - quant=QUANT, - dtype=torch.bfloat16, - ) - logger.info(f"[2/4] Model loaded in {time.perf_counter() - model_start:.2f}s") + logger.info("Initializing Safety Checker...") + safety_checker = SafetyChecker() + + # Load WorldEngine on startup + await world_engine.load_engine() - # Seed frame will be provided by frontend via set_initial_seed message - logger.info("[3/4] Seed frame: waiting for client to provide initial seed via base64") - seed_frame = None + # Initialize seed management system + logger.info("Initializing server-side seed storage...") + ensure_seed_directories() + await download_default_seeds() + + # Load or create seed cache + cache = load_seeds_cache() + if not cache.get("files"): + logger.info("Cache empty, scanning seed directories...") + cache = await rescan_seeds() + else: + logger.info(f"Using 
cached seed data ({len(cache.get('files', {}))} seeds)") + + # Update global cache (map filename -> metadata) + safe_seeds_cache = cache.get("files", {}) - logger.info("[4/4] Engine initialization complete") logger.info("=" * 60) - logger.info("SERVER READY - Waiting for WebSocket connections on /ws") - logger.info(" (Client must send set_initial_seed with base64 data)") + logger.info("[SERVER] Ready - WorldEngine and Safety modules loaded") + logger.info(f"[SERVER] {len(safe_seeds_cache)} seeds available") logger.info("=" * 60) + print("SERVER READY", flush=True) # Signal for Rust to detect + + yield + + # Cleanup + logger.info("[SERVER] Shutting down") + + +app = FastAPI(title="Biome Server", lifespan=lifespan) # ============================================================================ -# Frame Encoding +# Utilities # ============================================================================ -def frame_to_jpeg(frame: torch.Tensor, quality: int = JPEG_QUALITY) -> bytes: - """Convert frame tensor to JPEG bytes.""" - if frame.dtype != torch.uint8: - frame = frame.clamp(0, 255).to(torch.uint8) - img = Image.fromarray(frame.cpu().numpy(), mode="RGB") - buf = io.BytesIO() - img.save(buf, format="JPEG", quality=quality) - return buf.getvalue() +def compute_file_hash(file_path: str) -> str: + """Compute SHA256 hash of a file.""" + sha256 = hashlib.sha256() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + sha256.update(chunk) + return sha256.hexdigest() # ============================================================================ -# Session Management +# Health Endpoints # ============================================================================ -@dataclass -class Session: - """Tracks state for a single WebSocket connection.""" +@app.get("/health") +async def health(): + """Health check for Tauri backend.""" + return JSONResponse( + { + "status": "ok", + "world_engine": { + "loaded": world_engine.engine is not None, + "warmed_up": world_engine.engine_warmed_up, + "has_seed": world_engine.seed_frame is not None, + }, + "safety": {"loaded": safety_checker.model is not None}, + } + ) + + +# ============================================================================ +# Safety Endpoints +# ============================================================================ + + +class CheckImageRequest(BaseModel): + path: str - frame_count: int = 0 - max_frames: int = N_FRAMES - 2 + +class CheckBatchRequest(BaseModel): + paths: list[str] + + +@app.post("/safety/check_image") +async def check_image(request: CheckImageRequest): + """Check single image for NSFW content.""" + try: + result = safety_checker.check_image(request.path) + return JSONResponse(result) + except Exception as e: + logger.error(f"Safety check failed: {e}") + return JSONResponse({"error": str(e)}, status_code=500) + + +@app.post("/safety/check_batch") +async def check_batch(request: CheckBatchRequest): + """Check multiple images for NSFW content.""" + try: + results = safety_checker.check_batch(request.paths) + return JSONResponse({"results": results}) + except Exception as e: + logger.error(f"Safety batch check failed: {e}") + return JSONResponse({"error": str(e)}, status_code=500) + + +class SetCacheRequest(BaseModel): + seeds: dict[str, dict] # filename -> {hash, is_safe, path} + + +@app.post("/safety/set_cache") +async def set_cache(request: SetCacheRequest): + """Receive safety cache from Rust on startup.""" + global safe_seeds_cache + safe_seeds_cache = request.seeds + logger.info(f"Safety 
cache updated: {len(safe_seeds_cache)} seeds loaded") + return JSONResponse({"status": "ok", "count": len(safe_seeds_cache)}) # ============================================================================ -# FastAPI Application +# Seed Management Endpoints # ============================================================================ -@asynccontextmanager -async def lifespan(app: FastAPI): - """Lifespan event handler for startup and shutdown.""" - # Startup - load_engine() - yield - # Shutdown (if needed in the future) +@app.get("/seeds/list") +async def list_seeds(): + """Return list of all available seeds with metadata (only safe ones).""" + safe_only = { + filename: { + "filename": filename, + "hash": data["hash"], + "is_safe": data["is_safe"], + "checked_at": data.get("checked_at", 0), + } + for filename, data in safe_seeds_cache.items() + if data.get("is_safe", False) + } + return JSONResponse({"seeds": safe_only, "count": len(safe_only)}) -app = FastAPI(title="WorldEngine WebSocket Server", lifespan=lifespan) +@app.get("/seeds/image/{filename}") +async def get_seed_image(filename: str): + """Serve full PNG seed image.""" + from fastapi.responses import FileResponse + # Validate filename is in cache and safe + if filename not in safe_seeds_cache: + return JSONResponse({"error": "Seed not found"}, status_code=404) -@app.get("/health") -async def health(): - return JSONResponse( - { - "status": "healthy", - "model": MODEL_URI, - "quant": QUANT, - "engine_loaded": engine is not None, + seed_data = safe_seeds_cache[filename] + if not seed_data.get("is_safe", False): + return JSONResponse({"error": "Seed marked unsafe"}, status_code=403) + + file_path = seed_data.get("path", "") + if not os.path.exists(file_path): + return JSONResponse({"error": "Seed file not found"}, status_code=404) + + return FileResponse(file_path, media_type="image/png") + + +@app.get("/seeds/thumbnail/{filename}") +async def get_seed_thumbnail(filename: str): + """Serve 80x80 JPEG thumbnail of seed image.""" + import io + + # Validate filename is in cache and safe + if filename not in safe_seeds_cache: + return JSONResponse({"error": "Seed not found"}, status_code=404) + + seed_data = safe_seeds_cache[filename] + if not seed_data.get("is_safe", False): + return JSONResponse({"error": "Seed marked unsafe"}, status_code=403) + + file_path = seed_data.get("path", "") + if not os.path.exists(file_path): + return JSONResponse({"error": "Seed file not found"}, status_code=404) + + try: + # Generate thumbnail + img = await asyncio.to_thread(Image.open, file_path) + img.thumbnail((80, 80)) + + # Convert to JPEG in memory + buffer = io.BytesIO() + await asyncio.to_thread(img.save, buffer, format="JPEG", quality=85) + buffer.seek(0) + + from fastapi.responses import StreamingResponse + + return StreamingResponse(buffer, media_type="image/jpeg") + except Exception as e: + logger.error(f"Failed to generate thumbnail for {filename}: {e}") + return JSONResponse({"error": "Thumbnail generation failed"}, status_code=500) + + +class UploadSeedRequest(BaseModel): + filename: str + data: str # base64 encoded PNG + + +@app.post("/seeds/upload") +async def upload_seed(request: UploadSeedRequest): + """Upload a custom seed image (will be safety checked).""" + global safe_seeds_cache + + filename = request.filename + if not filename.endswith(".png"): + return JSONResponse({"error": "Only PNG files supported"}, status_code=400) + + # Decode base64 + try: + image_data = base64.b64decode(request.data) + except Exception as e: + return 
JSONResponse({"error": f"Invalid base64 data: {e}"}, status_code=400) + + # Save to uploads directory + file_path = UPLOADS_DIR / filename + await asyncio.to_thread(file_path.write_bytes, image_data) + logger.info(f"Uploaded seed saved to {file_path}") + + # Compute hash + file_hash = await asyncio.to_thread(compute_file_hash, str(file_path)) + + # Run safety check + try: + safety_result = await asyncio.to_thread( + safety_checker.check_image, str(file_path) + ) + is_safe = safety_result.get("is_safe", False) + + # Update cache + safe_seeds_cache[filename] = { + "hash": file_hash, + "is_safe": is_safe, + "path": str(file_path), + "scores": safety_result.get("scores", {}), + "checked_at": time.time(), } - ) + + # Save to disk + cache = load_seeds_cache() + cache["files"] = safe_seeds_cache + save_seeds_cache(cache) + + status_msg = "SAFE" if is_safe else "UNSAFE" + logger.info(f"Uploaded seed {filename}: {status_msg}") + + return JSONResponse( + { + "status": "ok", + "filename": filename, + "hash": file_hash, + "is_safe": is_safe, + "scores": safety_result.get("scores", {}), + } + ) + + except Exception as e: + logger.error(f"Safety check failed for uploaded seed: {e}") + # Delete the file if safety check failed + if file_path.exists(): + file_path.unlink() + return JSONResponse( + {"error": f"Safety check failed: {e}"}, status_code=500 + ) + + +@app.post("/seeds/rescan") +async def rescan_seeds_endpoint(): + """Trigger a rescan of all seed directories.""" + global safe_seeds_cache + + # Prevent concurrent rescans + if rescan_lock.locked(): + logger.warning("Rescan already in progress, rejecting duplicate request") + return JSONResponse( + { + "status": "busy", + "message": "Rescan already in progress", + }, + status_code=409, + ) + + async with rescan_lock: + logger.info("Manual rescan triggered") + cache = await rescan_seeds() + safe_seeds_cache = cache.get("files", {}) + + safe_count = sum(1 for data in safe_seeds_cache.values() if data.get("is_safe")) + return JSONResponse( + { + "status": "ok", + "total_seeds": len(safe_seeds_cache), + "safe_seeds": safe_count, + } + ) + + +@app.delete("/seeds/{filename}") +async def delete_seed(filename: str): + """Delete a custom seed (only from uploads directory).""" + global safe_seeds_cache + + if filename not in safe_seeds_cache: + return JSONResponse({"error": "Seed not found"}, status_code=404) + + seed_data = safe_seeds_cache[filename] + file_path = Path(seed_data.get("path", "")) + + # Only allow deleting from uploads directory + if not str(file_path).startswith(str(UPLOADS_DIR)): + return JSONResponse( + {"error": "Cannot delete default seeds"}, status_code=403 + ) + + try: + if file_path.exists(): + await asyncio.to_thread(file_path.unlink) + del safe_seeds_cache[filename] + + # Update cache file + cache = load_seeds_cache() + cache["files"] = safe_seeds_cache + save_seeds_cache(cache) + + logger.info(f"Deleted seed: {filename}") + return JSONResponse({"status": "ok", "deleted": filename}) + + except Exception as e: + logger.error(f"Failed to delete seed {filename}: {e}") + return JSONResponse({"error": str(e)}, status_code=500) + + +# ============================================================================ +# WorldEngine WebSocket +# ============================================================================ # Status codes (client maps these to display text) class Status: - WAITING_FOR_SEED = "waiting_for_seed" # Waiting for initial seed from client - INIT = "init" # Engine resetting - LOADING = "loading" # Loading seed frame - READY 
= "ready" # Ready for game loop - RESET = "reset" # Session reset + WAITING_FOR_SEED = "waiting_for_seed" + INIT = "init" + LOADING = "loading" + READY = "ready" + RESET = "reset" + WARMUP = "warmup" @app.websocket("/ws") @@ -290,10 +639,14 @@ async def websocket_endpoint(websocket: WebSocket): Client -> Server: {"type": "control", "buttons": [str], "mouse_dx": float, "mouse_dy": float, "ts": float} {"type": "reset"} + {"type": "set_initial_seed", "filename": str} + {"type": "prompt", "prompt": str} + {"type": "prompt_with_seed", "filename": str} + {"type": "pause"} + {"type": "resume"} - Status codes: init, loading, ready, reset + Status codes: waiting_for_seed, init, loading, ready, reset, warmup """ - global seed_frame, current_prompt, engine_warmed_up client_host = websocket.client.host if websocket.client else "unknown" logger.info(f"Client connected: {client_host}") @@ -304,9 +657,7 @@ async def send_json(data: dict): await websocket.send_text(json.dumps(data)) async def reset_engine(): - await asyncio.to_thread(engine.reset) - await asyncio.to_thread(engine.append_frame, seed_frame) - await asyncio.to_thread(engine.set_prompt, current_prompt) + await world_engine.reset_state() session.frame_count = 0 await send_json({"type": "status", "code": Status.RESET}) logger.info(f"[{client_host}] Engine Reset") @@ -317,93 +668,104 @@ async def reset_engine(): logger.info(f"[{client_host}] Waiting for initial seed from client...") # Wait for set_initial_seed message - while seed_frame is None: + while world_engine.seed_frame is None: try: raw = await asyncio.wait_for(websocket.receive_text(), timeout=60.0) msg = json.loads(raw) msg_type = msg.get("type") if msg_type == "set_initial_seed": - seed_base64 = msg.get("seed_base64") - if seed_base64: - logger.info(f"[{client_host}] Received initial seed ({len(seed_base64)} chars)") - loaded_frame = load_seed_from_base64(seed_base64) - if loaded_frame is not None: - seed_frame = loaded_frame - logger.info(f"[{client_host}] Initial seed loaded successfully") - else: - await send_json({"type": "error", "message": "Failed to decode seed image"}) - else: - await send_json({"type": "error", "message": "No seed_base64 provided"}) - else: - logger.info(f"[{client_host}] Ignoring message type '{msg_type}' while waiting for seed") + filename = msg.get("filename") - except asyncio.TimeoutError: - await send_json({"type": "error", "message": "Timeout waiting for initial seed"}) - return + if not filename: + await send_json( + {"type": "error", "message": "Missing filename"} + ) + continue - # Warmup on first connection AFTER seed is loaded (CUDA graphs require same thread context) - if not engine_warmed_up: - logger.info("=" * 60) - logger.info( - "[5/5] WARMUP - First client connected, initializing CUDA graphs..." 
- ) - logger.info("=" * 60) - await send_json({"type": "status", "code": "warmup"}) - - def do_warmup(): - warmup_start = time.perf_counter() - - logger.info("[5/5] Step 1: Resetting engine state...") - reset_start = time.perf_counter() - engine.reset() - logger.info( - f"[5/5] Step 1: Reset complete in {time.perf_counter() - reset_start:.2f}s" - ) + # Verify seed is in safety cache and is safe + if filename not in safe_seeds_cache: + logger.warning(f"[{client_host}] Seed '{filename}' not in safety cache") + await send_json( + {"type": "error", "message": f"Seed '{filename}' not in safety cache"} + ) + continue - logger.info("[5/5] Step 2: Appending seed frame...") - append_start = time.perf_counter() - engine.append_frame(seed_frame) - logger.info( - f"[5/5] Step 2: Seed frame appended in {time.perf_counter() - append_start:.2f}s" - ) + cached_entry = safe_seeds_cache[filename] - logger.info("[5/5] Step 3: Setting prompt...") - prompt_start = time.perf_counter() - engine.set_prompt(current_prompt) - logger.info( - f"[5/5] Step 3: Prompt set in {time.perf_counter() - prompt_start:.2f}s" - ) + if not cached_entry.get("is_safe", False): + logger.warning(f"[{client_host}] Seed '{filename}' marked as unsafe") + await send_json( + {"type": "error", "message": f"Seed '{filename}' marked as unsafe"} + ) + continue - logger.info( - "[5/5] Step 4: Generating first frame (compiling CUDA graphs)..." - ) - gen_start = time.perf_counter() - _ = engine.gen_frame(ctrl=CtrlInput(button=set(), mouse=(0.0, 0.0))) - logger.info( - f"[5/5] Step 4: First frame generated in {time.perf_counter() - gen_start:.2f}s" - ) + # Get cached hash and file path + cached_hash = cached_entry.get("hash", "") + file_path = cached_entry.get("path", "") - return time.perf_counter() - warmup_start + # Verify file exists + if not os.path.exists(file_path): + logger.error(f"[{client_host}] Seed file not found: {file_path}") + await send_json( + {"type": "error", "message": f"Seed file not found: {filename}"} + ) + continue - warmup_time = await asyncio.to_thread(do_warmup) - logger.info("=" * 60) - logger.info(f"[5/5] WARMUP COMPLETE - Total time: {warmup_time:.2f}s") - logger.info("=" * 60) - engine_warmed_up = True + # Verify file integrity (check if file on disk matches cached hash) + actual_hash = await asyncio.to_thread(compute_file_hash, file_path) + if actual_hash != cached_hash: + logger.warning( + f"[{client_host}] File integrity check failed for '{filename}' - file may have been modified" + ) + await send_json( + {"type": "error", "message": "File integrity verification failed - please rescan seeds"} + ) + continue + + # All checks passed - load the seed + logger.info(f"[{client_host}] Loading initial seed '{filename}'") + loaded_frame = await asyncio.to_thread( + world_engine.load_seed_from_file, file_path + ) + + if loaded_frame is not None: + world_engine.seed_frame = loaded_frame + logger.info(f"[{client_host}] Initial seed loaded successfully") + else: + await send_json( + {"type": "error", "message": "Failed to load seed image"} + ) + else: + logger.info( + f"[{client_host}] Ignoring message type '{msg_type}' while waiting for seed" + ) + + except asyncio.TimeoutError: + await send_json( + {"type": "error", "message": "Timeout waiting for initial seed"} + ) + return + + # Warmup on first connection AFTER seed is loaded + if not world_engine.engine_warmed_up: + await send_json({"type": "status", "code": Status.WARMUP}) + await world_engine.warmup() await send_json({"type": "status", "code": Status.INIT}) 
logger.info(f"[{client_host}] Calling engine.reset()...") - await asyncio.to_thread(engine.reset) + await asyncio.to_thread(world_engine.engine.reset) await send_json({"type": "status", "code": Status.LOADING}) logger.info(f"[{client_host}] Calling append_frame...") - await asyncio.to_thread(engine.append_frame, seed_frame) + await asyncio.to_thread(world_engine.engine.append_frame, world_engine.seed_frame) # Send initial frame so client has something to display - jpeg = await asyncio.to_thread(frame_to_jpeg, seed_frame) + jpeg = await asyncio.to_thread( + world_engine.frame_to_jpeg, world_engine.seed_frame + ) await send_json( { "type": "frame", @@ -443,12 +805,11 @@ async def get_latest_control(): except asyncio.TimeoutError: # No more messages in queue - # if skipped_count > 0: - # logger.info(f"[{client_host}] Skipped {skipped_count} queued inputs, using latest") return latest_control_msg except WebSocketDisconnect: raise + # Main game loop while True: try: msg = await get_latest_control() @@ -465,61 +826,133 @@ async def get_latest_control(): logger.info(f"[{client_host}] Reset requested") await reset_engine() continue + case "pause": - # don't really have to do anything special for pausing paused = True logger.info("[RECV] Paused") + case "resume": - # don't really have to do anything special for resuming paused = False logger.info("[RECV] Resumed") + case "prompt": new_prompt = msg.get("prompt", "").strip() logger.info(f"[RECV] Prompt received: '{new_prompt[:50]}...'") try: - current_prompt = new_prompt if new_prompt else DEFAULT_PROMPT + from engine_manager import DEFAULT_PROMPT + + world_engine.current_prompt = ( + new_prompt if new_prompt else DEFAULT_PROMPT + ) await reset_engine() except Exception as e: - logger.info(f"[GEN] Failed to set prompt: {e}") + logger.error(f"[GEN] Failed to set prompt: {e}") + case "prompt_with_seed": - new_prompt = msg.get("prompt", "").strip() - seed_url = msg.get("seed_url") - logger.info( - f"[RECV] Prompt with seed: '{new_prompt}', URL: {seed_url}" - ) + # Load new seed mid-session (server verifies against cache) + filename = msg.get("filename") + logger.info(f"[RECV] prompt_with_seed: filename={filename}") + try: - if seed_url: - url_frame = load_seed_from_url(seed_url) - if url_frame is not None: - seed_frame = url_frame - logger.info("[RECV] Seed frame loaded from URL") - current_prompt = new_prompt if new_prompt else DEFAULT_PROMPT - logger.info( - "[RECV] Seed frame prompt loaded from URL, resetting engine" + if not filename: + await send_json( + { + "type": "error", + "message": "Missing filename", + } + ) + continue + + # Check if seed is in safety cache + if filename not in safe_seeds_cache: + logger.warning( + f"[RECV] Seed '{filename}' not in safety cache" + ) + await send_json( + { + "type": "error", + "message": f"Seed '{filename}' not in safety cache", + } + ) + continue + + cached_entry = safe_seeds_cache[filename] + + # Verify is_safe flag + if not cached_entry.get("is_safe", False): + logger.warning( + f"[RECV] Seed '{filename}' marked as unsafe in cache" + ) + await send_json( + { + "type": "error", + "message": f"Seed '{filename}' marked as unsafe", + } + ) + continue + + # Get cached hash and file path + cached_hash = cached_entry.get("hash", "") + file_path = cached_entry.get("path", "") + + # Verify file exists + if not os.path.exists(file_path): + logger.error(f"[RECV] Seed file not found: {file_path}") + await send_json( + { + "type": "error", + "message": f"Seed file not found: {filename}", + } + ) + continue + + # Verify 
file integrity (check if file on disk matches cached hash) + actual_hash = await asyncio.to_thread( + compute_file_hash, file_path ) - await reset_engine() - except Exception as e: - logger.info(f"[GEN] Failed to set prompt: {e}") - case "set_initial_seed": - # Allow updating the seed mid-session - seed_base64 = msg.get("seed_base64") - logger.info(f"[RECV] set_initial_seed received ({len(seed_base64) if seed_base64 else 0} chars)") - try: - if seed_base64: - loaded_frame = load_seed_from_base64(seed_base64) - if loaded_frame is not None: - seed_frame = loaded_frame - logger.info("[RECV] Seed frame updated from base64") - await reset_engine() - else: - await send_json({"type": "error", "message": "Failed to decode seed image"}) + if actual_hash != cached_hash: + logger.warning( + f"[RECV] File integrity check failed for '{filename}' - file may have been modified" + ) + await send_json( + { + "type": "error", + "message": "File integrity verification failed - please rescan seeds", + } + ) + continue + + # All checks passed - load the seed + logger.info(f"[RECV] Loading seed '{filename}' from {file_path}") + loaded_frame = await asyncio.to_thread( + world_engine.load_seed_from_file, file_path + ) + + if loaded_frame is not None: + world_engine.seed_frame = loaded_frame + logger.info(f"[RECV] Seed '{filename}' loaded successfully") + await reset_engine() else: - await send_json({"type": "error", "message": "No seed_base64 provided"}) + await send_json( + { + "type": "error", + "message": f"Failed to load seed image: {filename}", + } + ) + except Exception as e: - logger.info(f"[GEN] Failed to set seed: {e}") + logger.error(f"[GEN] Failed to set seed: {e}") + await send_json( + { + "type": "error", + "message": f"Failed to set seed: {str(e)}", + } + ) + case "control": if paused: continue + buttons = { BUTTON_CODES[b.upper()] for b in msg.get("buttons", []) @@ -533,16 +966,18 @@ async def get_latest_control(): logger.info(f"[{client_host}] Auto-reset at frame limit") await reset_engine() - ctrl = CtrlInput(button=buttons, mouse=(mouse_dx, mouse_dy)) + ctrl = world_engine.CtrlInput( + button=buttons, mouse=(mouse_dx, mouse_dy) + ) t0 = time.perf_counter() - frame = await asyncio.to_thread(engine.gen_frame, ctrl=ctrl) + frame = await world_engine.generate_frame(ctrl) gen_time = (time.perf_counter() - t0) * 1000 session.frame_count += 1 # Encode and send frame with timing info - jpeg = await asyncio.to_thread(frame_to_jpeg, frame) + jpeg = await asyncio.to_thread(world_engine.frame_to_jpeg, frame) await send_json( { "type": "frame", @@ -576,7 +1011,7 @@ async def get_latest_control(): if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description="WorldEngine WebSocket Server") + parser = argparse.ArgumentParser(description="Biome Server") parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") parser.add_argument("--port", type=int, default=7987, help="Port to bind to") args = parser.parse_args() diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index af49715..a748342 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -1,7 +1,7 @@ use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD}; use serde::{Deserialize, Serialize}; use std::fs::{self, File, OpenOptions}; -use std::io::{self, BufRead, BufReader, Cursor, Read, Write}; +use std::io::{self, BufRead, BufReader, Cursor, Write}; use std::path::{Path, PathBuf}; use std::process::{Child, Command, Stdio}; use std::sync::Mutex; @@ -14,8 +14,6 @@ use 
tar::Archive; const CONFIG_FILENAME: &str = "config.json"; const WORLD_ENGINE_DIR: &str = "world_engine"; -const DEFAULT_SEEDS_DIR: &str = "default_seeds"; -const CUSTOM_SEEDS_DIR: &str = "custom_seeds"; const UV_VERSION: &str = "0.9.26"; // Port 7987 = 'O' (79) + 'W' (87) in ASCII const STANDALONE_PORT: u16 = 7987; @@ -62,6 +60,8 @@ fn get_app_handle() -> Option<&'static tauri::AppHandle> { APP_HANDLE.get() } +// Safety cache structures removed - now handled server-side + /// Get the executable's directory (for portable data storage) fn get_exe_dir() -> Result { let exe_path = @@ -671,157 +671,131 @@ async fn open_engine_dir(app: tauri::AppHandle) -> Result<(), String> { } // ============================================================================ -// Seeds Management +// Seeds Management (Server-Authoritative) // ============================================================================ -// Get the default seeds directory path (bundled seeds, next to executable) -fn get_default_seeds_dir(_app: &tauri::AppHandle) -> Result { - let exe_dir = get_exe_dir()?; - Ok(exe_dir.join(DEFAULT_SEEDS_DIR)) -} - -// Get the custom seeds directory path (user seeds, next to executable) -fn get_custom_seeds_dir(_app: &tauri::AppHandle) -> Result { - let exe_dir = get_exe_dir()?; - Ok(exe_dir.join(CUSTOM_SEEDS_DIR)) -} +const SERVER_BASE_URL: &str = "http://localhost:7987"; -/// Find which directory contains a seed file (checks custom first, then default) -fn find_seed_path(app: &tauri::AppHandle, filename: &str) -> Result { - // Validate filename doesn't contain path traversal - if filename.contains("..") || filename.contains('/') || filename.contains('\\') { - return Err(format!("Invalid seed filename: {}", filename)); - } +/// Initialize seeds (triggers server-side rescan) +#[tauri::command] +async fn initialize_seeds(_app: tauri::AppHandle) -> Result { + log::info!("Triggering server-side seed scan..."); - let custom_dir = get_custom_seeds_dir(app)?; - let custom_path = custom_dir.join(filename); - if custom_path.exists() { - return Ok(custom_path); - } + let client = reqwest::Client::new(); + let response = client + .post(format!("{}/seeds/rescan", SERVER_BASE_URL)) + .send() + .await + .map_err(|e| format!("Failed to contact server: {}", e))?; - let default_dir = get_default_seeds_dir(app)?; - let default_path = default_dir.join(filename); - if default_path.exists() { - return Ok(default_path); + if !response.status().is_success() { + return Err(format!("Server returned error: {}", response.status())); } - Err(format!("Seed file not found: {}", filename)) -} - -/// Initialize seeds directories (creates custom_seeds folder for user seeds) -#[tauri::command] -async fn initialize_seeds(app: tauri::AppHandle) -> Result { - let custom_seeds_dir = get_custom_seeds_dir(&app)?; - - // Create custom seeds directory if it doesn't exist - if !custom_seeds_dir.exists() { - fs::create_dir_all(&custom_seeds_dir) - .map_err(|e| format!("Failed to create custom seeds dir: {}", e))?; - } + let result: serde_json::Value = response + .json() + .await + .map_err(|e| format!("Failed to parse response: {}", e))?; Ok(format!( - "Seeds initialized: custom_seeds directory at {}", - custom_seeds_dir.display() + "Seeds initialized: {} total, {} safe", + result["total_seeds"].as_u64().unwrap_or(0), + result["safe_seeds"].as_u64().unwrap_or(0) )) } -/// List available seed filenames (png/jpg/jpeg) from both default and custom directories +/// List available seeds from server #[tauri::command] -async fn list_seeds(app: 
tauri::AppHandle) -> Result, String> { - use std::collections::HashSet; - - let mut seeds: HashSet = HashSet::new(); - - // Helper to collect seeds from a directory - let collect_seeds = |dir: &Path, seeds: &mut HashSet| { - let Ok(entries) = fs::read_dir(dir) else { - return; - }; - for entry in entries.flatten() { - let path = entry.path(); - let Some(ext) = path.extension() else { - continue; - }; - let ext_lower = ext.to_string_lossy().to_lowercase(); - if (ext_lower == "png" || ext_lower == "jpg" || ext_lower == "jpeg") - && let Some(filename) = path.file_name() - { - seeds.insert(filename.to_string_lossy().to_string()); - } - } - }; +async fn list_seeds(_app: tauri::AppHandle) -> Result, String> { + let client = reqwest::Client::new(); + let response = client + .get(format!("{}/seeds/list", SERVER_BASE_URL)) + .send() + .await + .map_err(|e| format!("Failed to contact server: {}", e))?; - // Collect from default seeds - if let Ok(default_dir) = get_default_seeds_dir(&app) { - collect_seeds(&default_dir, &mut seeds); + if !response.status().is_success() { + return Err(format!("Server returned error: {}", response.status())); } - // Collect from custom seeds - if let Ok(custom_dir) = get_custom_seeds_dir(&app) { - collect_seeds(&custom_dir, &mut seeds); - } + let result: serde_json::Value = response + .json() + .await + .map_err(|e| format!("Failed to parse response: {}", e))?; - let mut seeds_vec: Vec = seeds.into_iter().collect(); - seeds_vec.sort(); - Ok(seeds_vec) + let seeds_obj = result["seeds"] + .as_object() + .ok_or_else(|| "Invalid response format".to_string())?; + + let mut filenames: Vec = seeds_obj.keys().cloned().collect(); + filenames.sort(); + Ok(filenames) } -/// Read a seed file and return base64 encoded data +/// Read a seed image as base64 from server #[tauri::command] -async fn read_seed_as_base64(app: tauri::AppHandle, filename: String) -> Result { - let seed_path = find_seed_path(&app, &filename)?; +async fn read_seed_as_base64(_app: tauri::AppHandle, filename: String) -> Result { + let client = reqwest::Client::new(); + let response = client + .get(format!("{}/seeds/image/{}", SERVER_BASE_URL, filename)) + .send() + .await + .map_err(|e| format!("Failed to contact server: {}", e))?; - let mut file = - File::open(&seed_path).map_err(|e| format!("Failed to open seed file: {}", e))?; + if !response.status().is_success() { + return Err(format!("Server returned error: {}", response.status())); + } - let mut buffer = Vec::new(); - file.read_to_end(&mut buffer) - .map_err(|e| format!("Failed to read seed file: {}", e))?; + let bytes = response + .bytes() + .await + .map_err(|e| format!("Failed to read image data: {}", e))?; - Ok(BASE64_STANDARD.encode(&buffer)) + Ok(BASE64_STANDARD.encode(&bytes)) } -/// Read a seed file and return a small thumbnail as base64 encoded JPEG +/// Read a seed thumbnail as base64 from server #[tauri::command] async fn read_seed_thumbnail( - app: tauri::AppHandle, + _app: tauri::AppHandle, filename: String, - max_size: Option, + _max_size: Option, ) -> Result { - let seed_path = find_seed_path(&app, &filename)?; - - // Load and resize image - let img = image::open(&seed_path).map_err(|e| format!("Failed to open image: {}", e))?; + let client = reqwest::Client::new(); + let response = client + .get(format!("{}/seeds/thumbnail/{}", SERVER_BASE_URL, filename)) + .send() + .await + .map_err(|e| format!("Failed to contact server: {}", e))?; - let max_dim = max_size.unwrap_or(80); - let thumbnail = img.thumbnail(max_dim, max_dim); + if 
!response.status().is_success() { + return Err(format!("Server returned error: {}", response.status())); + } - // Encode as JPEG - let mut buffer = Vec::new(); - let mut cursor = Cursor::new(&mut buffer); - thumbnail - .write_to(&mut cursor, image::ImageFormat::Jpeg) - .map_err(|e| format!("Failed to encode thumbnail: {}", e))?; + let bytes = response + .bytes() + .await + .map_err(|e| format!("Failed to read thumbnail data: {}", e))?; - Ok(BASE64_STANDARD.encode(&buffer)) + Ok(BASE64_STANDARD.encode(&bytes)) } -/// Get the custom seeds directory path (where users add their seeds) +/// Get the seeds directory path (server-side) #[tauri::command] -fn get_seeds_dir_path(app: tauri::AppHandle) -> Result { - let seeds_dir = get_custom_seeds_dir(&app)?; - Ok(seeds_dir.to_string_lossy().to_string()) +fn get_seeds_dir_path(_app: tauri::AppHandle) -> Result { + // Return server-side path for information purposes + Ok("world_engine/seeds/uploads".to_string()) } -/// Open the custom seeds directory in file explorer +/// Open the seeds directory in file explorer (server-side) #[tauri::command] -async fn open_seeds_dir(app: tauri::AppHandle) -> Result<(), String> { - let seeds_dir = get_custom_seeds_dir(&app)?; +async fn open_seeds_dir(_app: tauri::AppHandle) -> Result<(), String> { + let exe_dir = get_exe_dir()?; + let seeds_dir = exe_dir.join("world_engine").join("seeds").join("uploads"); // Create directory if it doesn't exist if !seeds_dir.exists() { - fs::create_dir_all(&seeds_dir) - .map_err(|e| format!("Failed to create custom seeds dir: {}", e))?; + fs::create_dir_all(&seeds_dir).map_err(|e| format!("Failed to create seeds dir: {}", e))?; } // Open File Explorer with seeds directory @@ -829,6 +803,8 @@ async fn open_seeds_dir(app: tauri::AppHandle) -> Result<(), String> { .map_err(|e| format!("Failed to open seeds directory: {}", e)) } +// Safety cache functions removed - now handled server-side + #[tauri::command] async fn start_engine_server(app: tauri::AppHandle, port: u16) -> Result { let engine_dir = get_engine_dir(&app)?; diff --git a/src/components/BottomPanel.jsx b/src/components/BottomPanel.jsx index a38c7a2..e2e27c0 100644 --- a/src/components/BottomPanel.jsx +++ b/src/components/BottomPanel.jsx @@ -204,13 +204,12 @@ const BottomPanel = ({ isOpen, isHidden, onToggleHidden }) => { loadBatch() }, [activeTab, seeds]) - // Handle seed selection - reset server, load full-size image, recapture cursor + // Handle seed selection - reset server and send filename (server loads from its storage) const handleSeedClick = async (filename) => { setSelectedSeed(filename) try { - const base64 = await invoke('read_seed_as_base64', { filename }) reset() - sendInitialSeed(base64) + sendInitialSeed(filename) requestPointerLock() } catch (err) { console.error('Failed to apply seed:', err) diff --git a/src/context/StreamingContext.jsx b/src/context/StreamingContext.jsx index a4ff396..739d8a1 100644 --- a/src/context/StreamingContext.jsx +++ b/src/context/StreamingContext.jsx @@ -132,23 +132,12 @@ export const StreamingProvider = ({ children }) => { // Send initial seed when server is waiting for it useEffect(() => { if (statusCode === 'waiting_for_seed' && isConnected) { - log.info('Server waiting for seed, sending default seed...') - getDefaultSeedBase64() - .then((seedBase64) => { - sendInitialSeed(seedBase64) - log.info('Initial seed sent to server') - }) - .catch((err) => { - log.error('Failed to get default seed:', err) - const errorMessage = err.message || String(err) - if 
(errorMessage.includes('default.png')) {
-            setEngineError('Required file "default.png" not found in seeds folder. Please add a default.png image.')
-          } else {
-            setEngineError('Failed to load seed image: ' + errorMessage)
-          }
-        })
+      log.info('Server waiting for seed, sending default seed filename...')
+      // Just send the filename - server has the file
+      sendInitialSeed('default.png')
+      log.info('Initial seed filename sent to server')
     }
-  }, [statusCode, isConnected, getDefaultSeedBase64, sendInitialSeed])
+  }, [statusCode, isConnected, sendInitialSeed])
 
   // Pointer lock controls
   const requestPointerLock = useCallback(() => {
diff --git a/src/hooks/useSeeds.js b/src/hooks/useSeeds.js
index 8c51b70..6958875 100644
--- a/src/hooks/useSeeds.js
+++ b/src/hooks/useSeeds.js
@@ -15,8 +15,22 @@ export const useSeeds = () => {
     setIsLoading(true)
     setError(null)
     try {
+      // Show loading message during safety scan
+      log.info('Initializing seeds and running safety checks...')
       const result = await invoke('initialize_seeds')
       log.info('Seeds initialized:', result)
+
+      // Parse result for unsafe count (new format: "Seeds initialized: X total, Y safe")
+      const match = result.match(/(\d+) total, (\d+) safe/)
+      if (match) {
+        const total = parseInt(match[1])
+        const safe = parseInt(match[2])
+        const unsafe = total - safe
+        if (unsafe > 0) {
+          log.warn(`${unsafe} seed images hidden due to safety check`)
+        }
+      }
+
       // Refresh the list after initialization
       const seedList = await invoke('list_seeds')
       setSeeds(seedList)
diff --git a/src/hooks/useWebSocket.js b/src/hooks/useWebSocket.js
index bff0d8d..7a1ccbc 100644
--- a/src/hooks/useWebSocket.js
+++ b/src/hooks/useWebSocket.js
@@ -165,17 +165,17 @@ export const useWebSocket = () => {
     }
   }, [])
 
-  const sendPromptWithSeed = useCallback((prompt, seedUrl) => {
+  const sendPromptWithSeed = useCallback((filename) => {
     if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN) {
-      wsRef.current.send(JSON.stringify({ type: 'prompt_with_seed', prompt, seed_url: seedUrl }))
-      log.info('Prompt with seed sent:', prompt, seedUrl)
+      wsRef.current.send(JSON.stringify({ type: 'prompt_with_seed', filename }))
+      log.info('Prompt with seed sent:', filename)
     }
   }, [])
 
-  const sendInitialSeed = useCallback((seedBase64) => {
+  const sendInitialSeed = useCallback((filename) => {
     if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN) {
-      wsRef.current.send(JSON.stringify({ type: 'set_initial_seed', seed_base64: seedBase64 }))
-      log.info('Initial seed sent:', seedBase64.length, 'chars')
+      wsRef.current.send(JSON.stringify({ type: 'set_initial_seed', filename }))
+      log.info('Initial seed sent:', filename)
     }
   }, [])
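
The upload endpoint and both WebSocket seed handlers above gate every load on compute_file_hash, which is defined outside these hunks. A minimal sketch of the assumed behavior (a SHA-256 digest of the raw file bytes; the real implementation may differ):

    import hashlib

    def compute_file_hash(path: str, chunk_size: int = 1 << 20) -> str:
        """Hash a file in chunks so large seed images never need to fit in memory."""
        digest = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                digest.update(chunk)
        return digest.hexdigest()

With the move to filename-based seeds, the desktop client never ships image bytes over the socket; it only names a file the server has already scanned. A standalone sketch of that handshake, assuming the server is running on localhost:7987, that the third-party requests and websockets packages are available (neither is part of this repo), and that the server announces the waiting_for_seed status on connect, as StreamingContext.jsx expects:

    import asyncio
    import json

    import requests
    import websockets

    BASE_URL = "http://localhost:7987"

    async def main() -> None:
        # Ask the server which seeds it knows about (response keys are filenames).
        seeds = requests.get(f"{BASE_URL}/seeds/list", timeout=10).json()["seeds"]
        filename = sorted(seeds)[0] if seeds else "default.png"

        async with websockets.connect("ws://localhost:7987/ws") as ws:
            # Wait until the server says it needs a seed.
            while True:
                msg = json.loads(await ws.recv())
                if msg.get("type") == "status" and msg.get("code") == "waiting_for_seed":
                    break

            # Send only the filename; the server loads, hashes, and verifies the file itself.
            await ws.send(json.dumps({"type": "set_initial_seed", "filename": filename}))

            # Drain warmup/init/loading statuses until the initial frame arrives.
            while True:
                msg = json.loads(await ws.recv())
                if msg.get("type") == "frame":
                    print("received initial frame")
                    break
                if msg.get("type") == "error":
                    raise RuntimeError(msg["message"])

            # Drive a single step with an empty control input.
            await ws.send(json.dumps({
                "type": "control",
                "buttons": [],
                "mouse_dx": 0.0,
                "mouse_dy": 0.0,
                "ts": 0.0,
            }))

    asyncio.run(main())

Keeping verification on the server is the point of the change: because seed files never leave the server's seeds storage, a modified or unscanned image cannot be injected from the UI, and the socket handler rejects any filename that is missing from the safety cache, marked unsafe, or whose on-disk hash no longer matches the cached one.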