From a8a75f23b8b700eafd7dda96afb60e16bddb6665 Mon Sep 17 00:00:00 2001 From: FishWoWater Date: Tue, 22 Apr 2025 17:51:38 +0800 Subject: [PATCH 1/2] add replicate support --- .env.example | 1 + .gitignore | 13 +- README.md | 25 +- blend_gaussians.py | 6 +- inpainting/__init__.py | 22 + .../flux_inpainter_server}/.gitattributes | 0 .../flux_inpainter_server}/README.md | 0 .../flux_inpainter_server}/app.py | 0 .../flux_inpainter_server}/controlnet_flux.py | 0 .../flux_inpainter_server}/inpaint.py | 0 .../flux_inpainter_server}/main.py | 18 + .../pipeline_flux_controlnet_inpaint.py | 0 .../flux_inpainter_server}/readme.md | 0 .../flux_inpainter_server}/requirements.txt | 0 .../transformer_flux.py | 0 inpainting/replicate_inpainter/__init__.py | 4 + inpainting/replicate_inpainter/base.py | 52 ++ inpainting/replicate_inpainter/flux.py | 21 + inpainting/replicate_inpainter/sdxl.py | 19 + inpainting_server.sh | 2 +- run_pipeline.py | 658 +++++++++++++----- 21 files changed, 648 insertions(+), 193 deletions(-) create mode 100644 .env.example create mode 100644 inpainting/__init__.py rename {FLUX_inpainting_server => inpainting/flux_inpainter_server}/.gitattributes (100%) rename {FLUX_inpainting_server => inpainting/flux_inpainter_server}/README.md (100%) rename {FLUX_inpainting_server => inpainting/flux_inpainter_server}/app.py (100%) rename {FLUX_inpainting_server => inpainting/flux_inpainter_server}/controlnet_flux.py (100%) rename {FLUX_inpainting_server => inpainting/flux_inpainter_server}/inpaint.py (100%) rename {FLUX_inpainting_server => inpainting/flux_inpainter_server}/main.py (69%) rename {FLUX_inpainting_server => inpainting/flux_inpainter_server}/pipeline_flux_controlnet_inpaint.py (100%) rename {FLUX_inpainting_server => inpainting/flux_inpainter_server}/readme.md (100%) rename {FLUX_inpainting_server => inpainting/flux_inpainter_server}/requirements.txt (100%) rename {FLUX_inpainting_server => inpainting/flux_inpainter_server}/transformer_flux.py (100%) create 
mode 100644 inpainting/replicate_inpainter/__init__.py create mode 100644 inpainting/replicate_inpainter/base.py create mode 100644 inpainting/replicate_inpainter/flux.py create mode 100644 inpainting/replicate_inpainter/sdxl.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..2500f2e --- /dev/null +++ b/.env.example @@ -0,0 +1 @@ +REPLICATE_API_TOKEN= \ No newline at end of file diff --git a/.gitignore b/.gitignore index d993b3d..3d2f9fb 100644 --- a/.gitignore +++ b/.gitignore @@ -398,7 +398,7 @@ FodyWeavers.xsd *.sln.iml # env files from the flux server -FLUX_inpainting_server/env/lib +inpainting/flux_inpainter_server/env/lib tmp* # SLURM output @@ -411,4 +411,13 @@ wandb/* *.pt *.pth -.vscode/ \ No newline at end of file +.vscode/ + +# Environment variables +.env + +# Output scenes +scenes + +# Blender +blender-3.6*/ \ No newline at end of file diff --git a/README.md b/README.md index 34423bd..451bf5c 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ SynCity generates complex and immersive 3D worlds from text prompts and does not ### Prerequisites * **System**: The code was tested on Ubuntu 22.04. We expect it to run on other Linux-based distributions too. -* **Hardware**: An NVIDIA GPU with at least **48GB** of memory is required. We have used A40 and A6000 GPUs. +* **Hardware**: If the inpainting server is deployed locally, an NVIDIA GPU with at least **48GB** of memory is required. We have used A40 and A6000 GPUs. If you use the inpainting service from Replicate, you will need a GPU with at least **16GB** of memory for TRELLIS generation. * **Software**: - The [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive) is needed to compile certain submodules. We have tested CUDA versions 11.8 and 12.4. - [Conda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install) is used to create the environment to run the code. This environment uses Python version 3.10. 
@@ -33,10 +33,23 @@ source ./setup.sh --new-env --basic --xformers --diffoctreerast --spconv --mipga ``` Make sure to have set the environment variable `CUDA_HOME`, which should point to your CUDA Toolkit installation. If you run into issues while running this setup script, please refer to [the README in the TRELLIS](https://github.com/microsoft/TRELLIS/blob/main/README.md#installation-steps) repository, which provides additional guidance. -3. Set up the FLUX inpainting server: + +3. Set up the FLUX inpainting backend: + +**Option 1**: Set up the FLUX inpainting server locally (the server will require around **30GB+ VRAM**): ``` ./inpainting_server.sh --install ``` +**Option 2**: Use the hosted Replicate deployment (pay as you go; about $0.03 per image and $0.40 per 3x3 scene). +``` +cp .env.example .env + +# set REPLICATE_API_TOKEN in .env to your Replicate API token, which can be obtained here +# https://replicate.com/account/api-tokens + +# install the requirements for the Replicate backend +pip install python-dotenv replicate +``` 4. Download [Blender 3.6.19](https://www.blender.org/download/release/Blender3.6/blender-3.6.19-linux-x64.tar.xz), extract it into the root directory of this project, and make sure `blender-3.6.19-linux-x64/blender` can be executed on your system. @@ -51,7 +64,11 @@ The process to generate a world is split into two straightforward steps. ### Step 1: Generating Tiles The tiles are generated using an instruction file, which contain the prompts to generate each tile (see some examples in the `instructions` folder). 
To generate a set of tiles that will be saved to `scenes/solarpunk`, run: ``` +# option1: locally deploy the inpainting server python run_pipeline.py --instructions instructions/3x3/solarpunk.json --prefix scenes/solarpunk --gradio_url=http://127.0.0.1:7860 + +# option2: use the replicate inpainting service +python run_pipeline.py --instructions instructions/3x3/solarpunk.json --prefix scenes/solarpunk --parallel=False --inpainter_type=flux_replicate ``` This script will parallelize tile generation where possible if multiple GPUs are available. If the script is stalling for longer than a minute, consider running the tile generation synchronously (`--parallel=False`). Furthermore, if a single tile keeps being regenerated, consider interrupting the script and replacing the offending tile's prompt. Then, restart the script with `--skip_existing=True` to ensure it will not overwrite existing tiles. Alternatively, see the ["Advanced Usage"](#advanced_usage) section on how to adjust the tile rejection criteria. @@ -59,7 +76,11 @@ This script will parallelize tile generation where possible if multiple GPUs are ### Step 2: Blending Tiles To create smooth transitions between tiles and refine their boundary regions, run the blending script: ``` +# option1: locally deploy the inpainting server python blend_gaussians.py --compute_rescaled --stitch_images --stitch_slats --gradio_url=http://127.0.0.1:7860 --prefix scenes/solarpunk + +# option2: use the replicate inpainting service +python blend_gaussians.py --compute_rescaled --stitch_images --stitch_slats --inpainter_type=flux_replicate --prefix scenes/solarpunk ``` This script will create a `.ply` file with the Gaussians of the entire grid as well as a video rendering. 
diff --git a/blend_gaussians.py b/blend_gaussians.py index 680b94d..5fbe1f8 100644 --- a/blend_gaussians.py +++ b/blend_gaussians.py @@ -12,6 +12,7 @@ import torch.nn.functional as F from PIL import Image from tqdm.auto import tqdm +from typing import Literal import trellis.models as models from tile_cutting import z_preserving_crop @@ -477,6 +478,7 @@ def merge_gaussians( stitch_images: bool = False, stitch_slats: bool = False, use_cached: bool = False, + inpainter_type: Literal["flux_local", "flux_replicate", "sdxl_replicate"] = "flux_local", gradio_url='http://127.0.0.1:7860', blender_path: str = 'blender-3.6.19-linux-x64/blender', seed: int = 429 @@ -577,9 +579,9 @@ def merge_gaussians( rescaled_tiles = dill.load(open(os.path.join(grid_path, 'rescaled_tiles.pkl'), 'rb')) if stitch_images: - from FLUX_inpainting_server.inpaint import Inpainter + from inpainting import Inpainter import time - inpainter = Inpainter(gradio_url) + inpainter = Inpainter(inpainter_type, gradio_url) VIEW_TYPE = 'zoom_out' prompts = json.load(open(os.path.join(grid_path, 'instructions.json'))) diff --git a/inpainting/__init__.py b/inpainting/__init__.py new file mode 100644 index 0000000..f70bfdd --- /dev/null +++ b/inpainting/__init__.py @@ -0,0 +1,22 @@ +from dotenv import load_dotenv +from typing import Literal +from PIL import Image +from .flux_inpainter_server.inpaint import Inpainter as FluxInpainter +from .replicate_inpainter import ReplicateFluxInpainter, ReplicateSDXLInpainter + +# load api keys of replicate +load_dotenv() + +class Inpainter: + def __init__(self, inpainter_type: Literal["flux_local", "flux_replicate", "sdxl_replicate"] = "flux_local", gradio_url: str = ""): + if inpainter_type == "flux_local": + self.inpainter = FluxInpainter(gradio_url) + elif inpainter_type == "flux_replicate": + self.inpainter = ReplicateFluxInpainter() + elif inpainter_type == "sdxl_replicate": + self.inpainter = ReplicateSDXLInpainter() + else: + raise ValueError(f"Invalid inpainter_type: 
{inpainter_type}") + + def __call__(self, image:Image.Image, mask:Image.Image, seed:int, prompt:str): + return self.inpainter(image, mask, seed, prompt) diff --git a/FLUX_inpainting_server/.gitattributes b/inpainting/flux_inpainter_server/.gitattributes similarity index 100% rename from FLUX_inpainting_server/.gitattributes rename to inpainting/flux_inpainter_server/.gitattributes diff --git a/FLUX_inpainting_server/README.md b/inpainting/flux_inpainter_server/README.md similarity index 100% rename from FLUX_inpainting_server/README.md rename to inpainting/flux_inpainter_server/README.md diff --git a/FLUX_inpainting_server/app.py b/inpainting/flux_inpainter_server/app.py similarity index 100% rename from FLUX_inpainting_server/app.py rename to inpainting/flux_inpainter_server/app.py diff --git a/FLUX_inpainting_server/controlnet_flux.py b/inpainting/flux_inpainter_server/controlnet_flux.py similarity index 100% rename from FLUX_inpainting_server/controlnet_flux.py rename to inpainting/flux_inpainter_server/controlnet_flux.py diff --git a/FLUX_inpainting_server/inpaint.py b/inpainting/flux_inpainter_server/inpaint.py similarity index 100% rename from FLUX_inpainting_server/inpaint.py rename to inpainting/flux_inpainter_server/inpaint.py diff --git a/FLUX_inpainting_server/main.py b/inpainting/flux_inpainter_server/main.py similarity index 69% rename from FLUX_inpainting_server/main.py rename to inpainting/flux_inpainter_server/main.py index 7419d95..280356c 100644 --- a/FLUX_inpainting_server/main.py +++ b/inpainting/flux_inpainter_server/main.py @@ -10,6 +10,7 @@ image_path='https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/bucket.png', mask_path='https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/bucket_mask.jpeg', prompt='a person wearing a white shoe, carrying a white bucket with text "FLUX" on it' +prompt_detailed = 'an ivy-covered red brick building with classical 
columns and arched windows, on top of a base, east coast university, ivy-clad red brick buildings, cobblestone paths, gentle autumn light, soft warm lighting, realistic textures, subtle gradients, isometric perspective, academic charm, and meticulous detailing' # Build pipeline controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16) @@ -48,3 +49,20 @@ result.save('flux_inpaint.png') print("Successfully inpaint image") + +result = pipe( + prompt=prompt_detailed, + height=size[1], + width=size[0], + control_image=image, + control_mask=mask, + num_inference_steps=28, + generator=generator, + controlnet_conditioning_scale=0.9, + guidance_scale=3.5, + negative_prompt="", + true_guidance_scale=3.5 +).images[0] + +result.save('flux_inpaint_detailed.png') +print("Successfully inpaint image") diff --git a/FLUX_inpainting_server/pipeline_flux_controlnet_inpaint.py b/inpainting/flux_inpainter_server/pipeline_flux_controlnet_inpaint.py similarity index 100% rename from FLUX_inpainting_server/pipeline_flux_controlnet_inpaint.py rename to inpainting/flux_inpainter_server/pipeline_flux_controlnet_inpaint.py diff --git a/FLUX_inpainting_server/readme.md b/inpainting/flux_inpainter_server/readme.md similarity index 100% rename from FLUX_inpainting_server/readme.md rename to inpainting/flux_inpainter_server/readme.md diff --git a/FLUX_inpainting_server/requirements.txt b/inpainting/flux_inpainter_server/requirements.txt similarity index 100% rename from FLUX_inpainting_server/requirements.txt rename to inpainting/flux_inpainter_server/requirements.txt diff --git a/FLUX_inpainting_server/transformer_flux.py b/inpainting/flux_inpainter_server/transformer_flux.py similarity index 100% rename from FLUX_inpainting_server/transformer_flux.py rename to inpainting/flux_inpainter_server/transformer_flux.py diff --git a/inpainting/replicate_inpainter/__init__.py b/inpainting/replicate_inpainter/__init__.py new 
file mode 100644 index 0000000..5a112c7 --- /dev/null +++ b/inpainting/replicate_inpainter/__init__.py @@ -0,0 +1,4 @@ +from .sdxl import ReplicateSDXLInpainter +from .flux import ReplicateFluxInpainter + +__all__ = ["ReplicateSDXLInpainter", "ReplicateFluxInpainter"] diff --git a/inpainting/replicate_inpainter/base.py b/inpainting/replicate_inpainter/base.py new file mode 100644 index 0000000..d7951a9 --- /dev/null +++ b/inpainting/replicate_inpainter/base.py @@ -0,0 +1,52 @@ +import os, os.path as osp +import replicate +import numpy as np +from abc import ABC, abstractmethod +from PIL import Image + +TMP_DIR = "./tmp_flux" +TMP_PATH = osp.join(TMP_DIR, "result.png") + +class BaseReplicateInpainter(ABC): + REPLICATE_ID = "" + def __init__(self) -> None: + pass + + @abstractmethod + def _build_extra_inputs(self): + pass + + def run(self, image: Image, mask: Image, seed: int, prompt: str): + inputs = { + "image": image, + "mask": mask, + "seed": seed, + "prompt": prompt, + } + inputs.update(self._build_extra_inputs()) + print(inputs.keys()) + return replicate.run( + self.REPLICATE_ID, + input=inputs + ) + + def __call__(self, image: Image, mask: Image, seed: int, prompt: str): + os.makedirs(TMP_DIR, exist_ok=True) + image = image.convert("RGB") + mask_rgb = Image.new('RGB', mask.size) + mask_rgb.paste(mask) + image_tmp_path = osp.join(TMP_DIR, "image.png") + mask_tmp_path = osp.join(TMP_DIR, "mask.png") + image.save(image_tmp_path) + mask_rgb.save(mask_tmp_path) + + print(np.array(mask_rgb).shape, np.array(image).shape) + # open the temporary files in a with-block so both handles are closed once the request returns + + with open(image_tmp_path, "rb") as image_file, open(mask_tmp_path, "rb") as mask_file: + output = self.run(image_file, mask_file, seed, prompt) + assert len(output) == 1 + + with open(TMP_PATH, "wb") as file: + file.write(output[0].read()) + return Image.open(TMP_PATH) \ No newline at end of file diff --git a/inpainting/replicate_inpainter/flux.py b/inpainting/replicate_inpainter/flux.py new file mode 100644 index 0000000..a1503d5 --- 
/dev/null +++ b/inpainting/replicate_inpainter/flux.py @@ -0,0 +1,21 @@ +from .base import BaseReplicateInpainter + + +class ReplicateFluxInpainter(BaseReplicateInpainter): + """ + Replicate playground: https://replicate.com/fishwowater/flux-dev-controlnet-inpainting-beta + Speed and Cost: 0.04$ / image, 9.6s / image + """ + + REPLICATE_ID = "fishwowater/flux-dev-controlnet-inpainting-beta:27d3ff35f58b4409775de5a0b36e99b4c6d2d7fc7fe772b35170951db678ec63" + + def _build_extra_inputs(self): + return { + # default guidance scale + "guidance_scale": 3.5, + "true_guidance_scale": 3.5, + "controlnet_conditioning_scale": 0.9, + # default values for inference steps + "num_inference_steps": 24, + "output_quality": 100, + } diff --git a/inpainting/replicate_inpainter/sdxl.py b/inpainting/replicate_inpainter/sdxl.py new file mode 100644 index 0000000..2b117dc --- /dev/null +++ b/inpainting/replicate_inpainter/sdxl.py @@ -0,0 +1,19 @@ +from .base import BaseReplicateInpainter + + + +class ReplicateSDXLInpainter(BaseReplicateInpainter): + """ + Replicate playground: https://replicate.com/lucataco/sdxl-inpainting + Speed and Cost: 0.0023$ / image, 1.9s / image + """ + + REPLICATE_ID = "lucataco/sdxl-inpainting:a5b13068cc81a89a4fbeefeccc774869fcb34df4dbc92c1555e0f2771d49dde7" + + def _build_extra_inputs(self): + return { + # default values on the playground + "guidance_scale": 8.0, + "steps": 20, + "strength": 0.7, + } \ No newline at end of file diff --git a/inpainting_server.sh b/inpainting_server.sh index 0be5d24..ec76e7c 100755 --- a/inpainting_server.sh +++ b/inpainting_server.sh @@ -1,6 +1,6 @@ #!/bin/bash # Installation as shown here: https://huggingface.co/spaces/ameerazam08/FLUX.1-dev-Inpainting-Model-Beta-GPU?docker=true -cd FLUX_inpainting_server +cd inpainting/flux_inpainter_server python -m venv env source env/bin/activate diff --git a/run_pipeline.py b/run_pipeline.py index 63867b8..435548f 100644 --- a/run_pipeline.py +++ b/run_pipeline.py @@ -12,8 +12,8 @@ from queue import Queue 
from typing import Dict, List, Literal, Optional, Tuple -os.environ['ATTN_BACKEND'] = 'xformers' -os.environ['SPCONV_ALGO'] = 'native' # 'auto' is faster but benchmarks on start +os.environ["ATTN_BACKEND"] = "xformers" +os.environ["SPCONV_ALGO"] = "native" # 'auto' is faster but benchmarks on start import cv2 import imageio @@ -24,7 +24,7 @@ from lpips import LPIPS, im2tensor from PIL import Image -from FLUX_inpainting_server.inpaint import Inpainter +from inpainting import Inpainter # a brief explanation of the orthographic scale: # the ortho scale determines the size of the world in the image @@ -44,6 +44,7 @@ # this means we'll max out at 2.0 + 2 * 0.5 = 3.0 MAX_ORTHO_SCALE = INITIAL_ORTHO_SCALE + ORTHO_SCALE_Y_STEP * ORTHO_SCALE_MAX_NUM_Y_STEPS + class States(Enum): BLOCKED = "🔒" READY = "📦" @@ -60,25 +61,38 @@ class States(Enum): def process_instructions(instructions_path: str) -> List[str]: - with open(instructions_path, 'r') as f: + with open(instructions_path, "r") as f: instructions = json.load(f) # sort the instructions by the x and y values - instructions["tiles"] = sorted(instructions["tiles"], key=lambda x: (x["y"], x["x"])) + instructions["tiles"] = sorted( + instructions["tiles"], key=lambda x: (x["y"], x["x"]) + ) - assert instructions["tiles"][0]["x"] == 0 and instructions["tiles"][0]["y"] == 0, "The first tile must be at position (0, 0)" + assert ( + instructions["tiles"][0]["x"] == 0 and instructions["tiles"][0]["y"] == 0 + ), "The first tile must be at position (0, 0)" for tile_instruction in instructions["tiles"]: - tile_instruction["prompt"] = instructions["prompt"].format(**{"tile_prompt": tile_instruction["prompt"]}) + tile_instruction["prompt"] = instructions["prompt"].format( + **{"tile_prompt": tile_instruction["prompt"]} + ) return instructions -def label_obstructing_tiles(grid: List[Dict], tile_pos: Tuple[int], key: str, create_copy: bool = True, default_slice_height = 0.): + +def label_obstructing_tiles( + grid: List[Dict], + tile_pos: 
Tuple[int], + key: str, + create_copy: bool = True, + default_slice_height=0.0, +): if create_copy: grid = json.loads(json.dumps(grid)) for tile in grid: - if tile['x'] == tile_pos[0] and tile['y'] == tile_pos[1] - 1: + if tile["x"] == tile_pos[0] and tile["y"] == tile_pos[1] - 1: if "max_corner" in tile: tile[key] = tile["max_corner"] + tile["translation"][-1] + 5e-3 else: @@ -86,7 +100,16 @@ def label_obstructing_tiles(grid: List[Dict], tile_pos: Tuple[int], key: str, cr return grid -def make_non_overlapping_mask(mask, full_mask: bool = False, overlap_x: int = 8, overlap_y: int = 8, extend_at: Optional[Literal["x", "y", "xy"]] = "xy", add_corner: bool = True, erosion_radius: int = 5): + +def make_non_overlapping_mask( + mask, + full_mask: bool = False, + overlap_x: int = 8, + overlap_y: int = 8, + extend_at: Optional[Literal["x", "y", "xy"]] = "xy", + add_corner: bool = True, + erosion_radius: int = 5, +): mask_np = np.array(mask) # erode the mask by a tiny amount @@ -95,42 +118,47 @@ def make_non_overlapping_mask(mask, full_mask: bool = False, overlap_x: int = 8, # Convert the mask to a binary version (0 and 1) for easier processing. bin_mask = (mask_np > 0).astype(np.uint8) - + # Sum along the vertical direction (axis=0) to get a 1D array per column. col_sums = bin_mask.sum(axis=0) - + # Find the first column that contains any white pixels. nonzero_cols = np.where(col_sums > 0)[0] if len(nonzero_cols) == 0: raise ValueError("The mask does not contain any white pixels.") x1 = nonzero_cols[0] - + # The other edge is the column with the maximum white pixels. x2 = int(np.argmax(col_sums)) - + # For column x1, get the indices (rows) where pixels are white. rows_x1 = np.where(bin_mask[:, x1] > 0)[0] y1_top, y1_bot = int(rows_x1[0]), int(rows_x1[-1]) - + # For column x2, get the indices (rows) where pixels are white. rows_x2 = np.where(bin_mask[:, x2] > 0)[0] y2_top, y2_bot = int(rows_x2[0]), int(rows_x2[-1]) # Create a copy of the mask to modify. 
mask_out = mask_np.copy() - d = 5 # fix white boundaries - h = abs(y1_top - y1_bot) - (0 if not extend_at or "x" not in extend_at else overlap_x) + d = 5 # fix white boundaries + h = abs(y1_top - y1_bot) - ( + 0 if not extend_at or "x" not in extend_at else overlap_x + ) if not full_mask: # Define the four corners of the quadrilateral (in (x, y) order): # Top-left, top-right, bottom-right, bottom-left. - pts = np.array([ - [x1, y1_top - d], - [x2, y2_top - d], - [x2, y2_top + h - d], - [x1, y1_top + h - d] - ], dtype=np.int32) - + pts = np.array( + [ + [x1, y1_top - d], + [x2, y2_top - d], + [x2, y2_top + h - d], + [x1, y1_top + h - d], + ], + dtype=np.int32, + ) + cv2.fillPoly(mask_out, [pts], 0) bot_slope = (y2_bot - y1_bot) / (x2 - x1) @@ -148,62 +176,85 @@ def make_non_overlapping_mask(mask, full_mask: bool = False, overlap_x: int = 8, x_int, y_int = int(x_intersect), int(y_intersect) - pts = np.array([ - [x1, y1_bot], - [x_int, y_int], - [x1, y1_top + h - d] - ], dtype=np.int32) + pts = np.array( + [[x1, y1_bot], [x_int, y_int], [x1, y1_top + h - d]], dtype=np.int32 + ) cv2.fillPoly(mask_out, [pts], 1) if extend_at and "y" in extend_at: - pts = np.array([ - [x1, y1_bot], - [x2, y2_bot], - [x2 - overlap_y, y2_bot + overlap_y * -top_slope], - [x1 - overlap_y, y1_bot + overlap_y * -top_slope] - ], dtype=np.int32) + pts = np.array( + [ + [x1, y1_bot], + [x2, y2_bot], + [x2 - overlap_y, y2_bot + overlap_y * -top_slope], + [x1 - overlap_y, y1_bot + overlap_y * -top_slope], + ], + dtype=np.int32, + ) cv2.fillPoly(mask_out, [pts], 1) - + return mask_out - -def generate_tile_info(blender_path, grid: List[Dict], output_folder: str, resolution: int = 1024): + + +def generate_tile_info( + blender_path, grid: List[Dict], output_folder: str, resolution: int = 1024 +): grid = json.loads(json.dumps(grid)) # Construct command as a list of arguments cmd = [ blender_path, - '-b', - '-P', 'blender_script.py', - '--', - '--output_folder', output_folder, - '--resolution', 
str(resolution), - '--debase', - '--export_tile_info', - '--no_render', + "-b", + "-P", + "blender_script.py", + "--", + "--output_folder", + output_folder, + "--resolution", + str(resolution), + "--debase", + "--export_tile_info", + "--no_render", ] if len(grid) > 0: tile_json = json.dumps(grid) - cmd.extend(['--tiles', tile_json]) + cmd.extend(["--tiles", tile_json]) # Run command with redirected output - with open(os.devnull, 'wb') as devnull: + with open(os.devnull, "wb") as devnull: subprocess.check_call(cmd, stdout=devnull, stderr=devnull) -def render_next_tile(blender_path, grid: List[Dict], output_folder: str, resolution: int = 1024, pos: Tuple[int, int] = (0, 0), ortho_scale: float = 1.75): + +def render_next_tile( + blender_path, + grid: List[Dict], + output_folder: str, + resolution: int = 1024, + pos: Tuple[int, int] = (0, 0), + ortho_scale: float = 1.75, +): grid = json.loads(json.dumps(grid)) # figure out which tiles are at most 2 tiles away (Manhattan distance) from the current tile # and only render tiles that aren't above the current tile (which would mess with the mask) - grid = [tile for tile in grid if tile["x"] <= pos[0] and tile["y"] <= pos[1] and not (tile["x"] == pos[0] and tile["y"] == pos[1])] + grid = [ + tile + for tile in grid + if tile["x"] <= pos[0] + and tile["y"] <= pos[1] + and not (tile["x"] == pos[0] and tile["y"] == pos[1]) + ] # to provide additional context for tiles x=0, we put the y-1 tile (if it exists) # at position (-1, y) as this will not be cropped out and can provide additional context if pos[0] == 0 and pos[1] > 0: # find the tile at y-1 - y_minus_1 = [tile for tile in grid if tile["x"] == 0 and tile["y"] == pos[1] - 1] + y_minus_1 = [ + tile for tile in grid if tile["x"] == 0 and tile["y"] == pos[1] - 1 + ] if len(y_minus_1) > 0: tile_dict = y_minus_1[0] grid.append({**tile_dict, "x": -1, "y": pos[1], "has_slab": False}) @@ -213,109 +264,149 @@ def render_next_tile(blender_path, grid: List[Dict], output_folder: str, 
resolut # Construct command as a list of arguments cmd = [ blender_path, - '-b', - '-P', 'blender_script.py', - '--', - '--output_folder', output_folder, - '--resolution', str(resolution), - '--debase', - f'--next_tile_at={pos[0]},{pos[1]}' + "-b", + "-P", + "blender_script.py", + "--", + "--output_folder", + output_folder, + "--resolution", + str(resolution), + "--debase", + f"--next_tile_at={pos[0]},{pos[1]}", ] if len(grid) > 0: - cmd.extend(['--tiles', json.dumps(grid)]) - - views = [{"yaw": np.radians(-45), "pitch": np.arctan(1/np.sqrt(2)), "radius": 2, "fov": np.radians(47.1), "ortho_scale": ortho_scale}] + cmd.extend(["--tiles", json.dumps(grid)]) + + views = [ + { + "yaw": np.radians(-45), + "pitch": np.arctan(1 / np.sqrt(2)), + "radius": 2, + "fov": np.radians(47.1), + "ortho_scale": ortho_scale, + } + ] cmd.extend(["--views", json.dumps(views)]) # Run command with redirected output - with open(os.devnull, 'wb') as devnull: + with open(os.devnull, "wb") as devnull: subprocess.check_call(cmd, stdout=devnull, stderr=devnull) -def find_orientation_of_tile(blender_path, tile_dict: Dict, conditioning_image: str, output_folder: str, resolution: int=256, rotations: Tuple[int] = (0, 90, 180, 270)): + +def find_orientation_of_tile( + blender_path, + tile_dict: Dict, + conditioning_image: str, + output_folder: str, + resolution: int = 256, + rotations: Tuple[int] = (0, 90, 180, 270), +): # place this singular tile at the origin tile_dict = json.loads(json.dumps(tile_dict)) tile_dict["x"], tile_dict["y"] = 0, 0 for rotation in rotations: tile_dict["rotation"] = rotation - tile_json = json.dumps([tile_dict]) # No need to escape quotes when using list arguments + tile_json = json.dumps( + [tile_dict] + ) # No need to escape quotes when using list arguments cmd = [ blender_path, - '-b', - '-P', 'blender_script.py', - '--', - '--output_folder', os.path.join(output_folder, f"rot_{rotation}"), - '--resolution', str(resolution), - '--tiles', tile_json, - '--rgb_only' + 
"-b", + "-P", + "blender_script.py", + "--", + "--output_folder", + os.path.join(output_folder, f"rot_{rotation}"), + "--resolution", + str(resolution), + "--tiles", + tile_json, + "--rgb_only", ] # Run command with redirected output - with open(os.devnull, 'wb') as devnull: + with open(os.devnull, "wb") as devnull: subprocess.check_call(cmd, stdout=devnull, stderr=devnull) # find the orientation of the tile - available_rotations = glob.glob(f'{output_folder}/rot_*/*.png') + available_rotations = glob.glob(f"{output_folder}/rot_*/*.png") rotations_dict = { - int(fn.split("rot_")[-1].split("/")[0]): Image.open(fn).convert('RGBA') for fn in available_rotations + int(fn.split("rot_")[-1].split("/")[0]): Image.open(fn).convert("RGBA") + for fn in available_rotations } - conditioning = Image.open(conditioning_image).convert('RGBA').resize((resolution, resolution)) - lpips_inp_cond = im2tensor(np.array(conditioning.convert("RGB"))[:, :, ::-1]).to("cuda") + conditioning = ( + Image.open(conditioning_image).convert("RGBA").resize((resolution, resolution)) + ) + lpips_inp_cond = im2tensor(np.array(conditioning.convert("RGB"))[:, :, ::-1]).to( + "cuda" + ) - lpips_fn = LPIPS(net='vgg').cuda() + lpips_fn = LPIPS(net="vgg").cuda() lpips_loss = { rotation: lpips_fn( lpips_inp_cond, - im2tensor(np.array(rotations_dict[rotation].convert("RGB"))[:, :, ::-1]).to("cuda") - ).item() for rotation in rotations_dict + im2tensor(np.array(rotations_dict[rotation].convert("RGB"))[:, :, ::-1]).to( + "cuda" + ), + ).item() + for rotation in rotations_dict } # clean up for rotation in rotations_dict: - shutil.rmtree(f'{output_folder}/rot_{rotation}') + shutil.rmtree(f"{output_folder}/rot_{rotation}") return min(lpips_loss, key=lpips_loss.get) + def process_mask(mask_image: Image) -> Image: - return Image.fromarray((np.floor(np.array(mask_image.convert('L'))/255)).clip(0, 1).astype(np.uint8)*255) + return Image.fromarray( + (np.floor(np.array(mask_image.convert("L")) / 255)).clip(0, 
1).astype(np.uint8) + * 255 + ) + def pil_mask_to_numpy(mask: Image) -> np.ndarray: return (np.asarray(mask) / 255).astype(np.float32) + def numpy_mask_to_pil(mask: np.ndarray) -> Image: - return Image.fromarray((mask*255).astype(np.uint8)).convert('L') + return Image.fromarray((mask * 255).astype(np.uint8)).convert("L") + def inpaint_tile( - server: Inpainter|str, - prompt: str, - input_folder: str, - input_image: str, - output_image: Optional[str] = None, - seed: int = 999, - mode: Literal['single', 'overlap-free'] = 'single', - extend_at: Optional[str] = None, - ortho_scale: float = 1.75, - base_ortho_scale: float = 1.75, -): + server: Inpainter | str, + prompt: str, + input_folder: str, + input_image: str, + output_image: Optional[str] = None, + seed: int = 999, + mode: Literal["single", "overlap-free"] = "single", + extend_at: Optional[str] = None, + ortho_scale: float = 1.75, + base_ortho_scale: float = 1.75, +): if isinstance(server, str): - inpainter = Inpainter(server) + inpainter = Inpainter(inpainter_type="flux_local", gradio_url=server) else: inpainter = server - + image_path = os.path.join(input_folder, input_image) - mask_path = image_path.replace('rgb.png', 'inpaint_mask.png') + mask_path = image_path.replace("rgb.png", "inpaint_mask.png") if output_image is None: - output_path = image_path.replace('rgb.png', 'inpainted.png') + output_path = image_path.replace("rgb.png", "inpainted.png") else: output_path = os.path.join(input_folder, output_image) - base = Image.open(image_path).convert('RGB') + base = Image.open(image_path).convert("RGB") mask = process_mask(Image.open(mask_path)) - MAX_OVERLAP = 4 # in pixels + MAX_OVERLAP = 4 # in pixels # depending on the ortho scale, we allow fewer pixels to overlap overlap = int((base_ortho_scale / ortho_scale) * MAX_OVERLAP) @@ -323,13 +414,39 @@ def inpaint_tile( # we also erode the original mask a bit erosion_radius = int((base_ortho_scale / ortho_scale) * 4) - overlap_mask = 
numpy_mask_to_pil(make_non_overlapping_mask(pil_mask_to_numpy(mask), extend_at=extend_at, full_mask=(mode == "single"), overlap_x=overlap, overlap_y=overlap//2, erosion_radius=erosion_radius)) + overlap_mask = numpy_mask_to_pil( + make_non_overlapping_mask( + pil_mask_to_numpy(mask), + extend_at=extend_at, + full_mask=(mode == "single"), + overlap_x=overlap, + overlap_y=overlap // 2, + erosion_radius=erosion_radius, + ) + ) overlap_mask.save(mask_path.replace("inpaint", "overlap-free")) - numpy_mask_to_pil(make_non_overlapping_mask(pil_mask_to_numpy(mask), extend_at=None, overlap_x=0, overlap_y=0, add_corner=False, full_mask=(mode == "single"), erosion_radius=erosion_radius)).save(mask_path) + numpy_mask_to_pil( + make_non_overlapping_mask( + pil_mask_to_numpy(mask), + extend_at=None, + overlap_x=0, + overlap_y=0, + add_corner=False, + full_mask=(mode == "single"), + erosion_radius=erosion_radius, + ) + ).save(mask_path) image_inpainted = inpainter(base, overlap_mask, seed, prompt) image_inpainted.save(output_path) -def run_trellis(pipe, image_path, seed=1, mesh_path='./assets/house-tile.glb', metric_thresholds={"squareness": 1, "slab_size": 4096, "completeness": 0.95}): + +def run_trellis( + pipe, + image_path, + seed=1, + mesh_path="./assets/house-tile.glb", + metric_thresholds={"squareness": 1, "slab_size": 4096, "completeness": 0.95}, +): from trellis.utils import render_utils, postprocessing_utils import torch @@ -345,29 +462,30 @@ def run_trellis(pipe, image_path, seed=1, mesh_path='./assets/house-tile.glb', m for metric_name in ("squareness", "slab_size", "completeness"): print(f"{metric_name}: {outputs[metric_name]}") - video = render_utils.render_video(outputs['scene']['gaussian'][0])['color'] - imageio.mimsave(image_path.replace('.png', '.mp4'), video, fps=30) + video = render_utils.render_video(outputs["scene"]["gaussian"][0])["color"] + imageio.mimsave(image_path.replace(".png", ".mp4"), video, fps=30) # GLB files can be extracted from the outputs glb = 
postprocessing_utils.to_glb( - outputs['scene']['gaussian'][0], - outputs['scene']['mesh'][0], + outputs["scene"]["gaussian"][0], + outputs["scene"]["mesh"][0], # Optional parameters - simplify=0.95, # Ratio of triangles to remove in the simplification process - texture_size=1024, # Size of the texture used for the GLB + simplify=0.95, # Ratio of triangles to remove in the simplification process + texture_size=1024, # Size of the texture used for the GLB ) glb.export(mesh_path) del glb outputs_to_save = {k: v for k, v in outputs.items() if k not in ["scene"]} - torch.save(outputs_to_save, mesh_path.replace('.glb', '.pt')) + torch.save(outputs_to_save, mesh_path.replace(".glb", ".pt")) - for k in [k for k in outputs['scene'].keys() if k != 'gaussian']: - del outputs['scene'][k] + for k in [k for k in outputs["scene"].keys() if k != "gaussian"]: + del outputs["scene"][k] return outputs + def get_widest_point_y(image, find_last=False): arr = np.array(image) @@ -408,6 +526,7 @@ def get_widest_point_y(image, find_last=False): return widest_y + def center_on_square(square_img, intricate_img): square_ground_y = get_widest_point_y(square_img) intricate_ground_y = get_widest_point_y(intricate_img) @@ -419,14 +538,27 @@ def center_on_square(square_img, intricate_img): return new_intricate -def rebased_inpainted_tile(inpainted_image_path, base_slab_path, is_left_tile: bool = True, scale=0.85, postfix="inpainted", erosion_radius=5, ortho_scale=1.75, base_ortho_scale=1.75, render_resolution=1024): + +def rebased_inpainted_tile( + inpainted_image_path, + base_slab_path, + is_left_tile: bool = True, + scale=0.85, + postfix="inpainted", + erosion_radius=5, + ortho_scale=1.75, + base_ortho_scale=1.75, + render_resolution=1024, +): if ortho_scale != base_ortho_scale: ortho_rescale = base_ortho_scale / ortho_scale crop_size = int(render_resolution * ortho_rescale) crop_size_sides = (render_resolution - crop_size) // 2 # crop the images to the same size - for fn in 
glob.glob(os.path.join(os.path.dirname(inpainted_image_path), "000_*.png")): + for fn in glob.glob( + os.path.join(os.path.dirname(inpainted_image_path), "000_*.png") + ): if "backup" in fn: continue @@ -434,9 +566,16 @@ def rebased_inpainted_tile(inpainted_image_path, base_slab_path, is_left_tile: b shutil.copy(fn, fn.replace(".png", "_backup.png")) img = Image.open(fn) - img = img.crop((crop_size_sides, crop_size_sides, crop_size_sides + crop_size, crop_size_sides + crop_size)) + img = img.crop( + ( + crop_size_sides, + crop_size_sides, + crop_size_sides + crop_size, + crop_size_sides + crop_size, + ) + ) img.save(fn) - + inpainted_image = Image.open(inpainted_image_path) width, height = inpainted_image.size @@ -446,18 +585,38 @@ def rebased_inpainted_tile(inpainted_image_path, base_slab_path, is_left_tile: b base_slab = base_slab.resize((width, height)) # mask out the slab - conditioning_mask = Image.open(inpainted_image_path.replace(postfix, "conditioning_mask")).convert("L") - inpaint_mask = Image.open(inpainted_image_path.replace(postfix, "inpaint_mask")).convert("L") + conditioning_mask = Image.open( + inpainted_image_path.replace(postfix, "conditioning_mask") + ).convert("L") + inpaint_mask = Image.open( + inpainted_image_path.replace(postfix, "inpaint_mask") + ).convert("L") discard_mask = PIL.ImageOps.invert(conditioning_mask) if is_left_tile: - slab_mask = Image.composite(conditioning_mask, Image.new('L', conditioning_mask.size, (0,)), PIL.ImageOps.invert(inpaint_mask)) - slabless = Image.composite(inpainted_image, Image.new('RGBA', inpainted_image.size, (0, 0, 0, 255)), PIL.ImageOps.invert(slab_mask)) + slab_mask = Image.composite( + conditioning_mask, + Image.new("L", conditioning_mask.size, (0,)), + PIL.ImageOps.invert(inpaint_mask), + ) + slabless = Image.composite( + inpainted_image, + Image.new("RGBA", inpainted_image.size, (0, 0, 0, 255)), + PIL.ImageOps.invert(slab_mask), + ) else: - slabless = Image.composite(inpainted_image, Image.new('RGBA', 
inpainted_image.size, (0, 0, 0, 255)), inpaint_mask) - - slabless = Image.composite(slabless, Image.new('RGBA', inpainted_image.size, (0, 0, 0, 255)), PIL.ImageOps.invert(discard_mask)) + slabless = Image.composite( + inpainted_image, + Image.new("RGBA", inpainted_image.size, (0, 0, 0, 255)), + inpaint_mask, + ) + + slabless = Image.composite( + slabless, + Image.new("RGBA", inpainted_image.size, (0, 0, 0, 255)), + PIL.ImageOps.invert(discard_mask), + ) # isolate the object isolated = rembg.remove(slabless, alpha_matting=True, post_process_mask=True) @@ -465,17 +624,23 @@ def rebased_inpainted_tile(inpainted_image_path, base_slab_path, is_left_tile: b # erode the rembg result a bit in case there are is a white gradient around the object isolated_np = (np.asarray(isolated) / 255).astype(np.float32) isolated_np[..., -1] = np.where(isolated_np[..., -1] > 0.5, 1, 0) - isolated_np[..., -1] = cv2.erode(isolated_np[..., -1], np.ones((erosion_radius, erosion_radius), np.uint8)) + isolated_np[..., -1] = cv2.erode( + isolated_np[..., -1], np.ones((erosion_radius, erosion_radius), np.uint8) + ) isolated = Image.fromarray((isolated_np * 255).astype(np.uint8)) # make sure everything on the slab is actually retained # sometimes, rembg will remove some pixels on the slab which will break # the rebasing process - slab_surface = Image.composite(base_slab.split()[-1], Image.new('L', base_slab.size, (0,)), inpaint_mask) + slab_surface = Image.composite( + base_slab.split()[-1], Image.new("L", base_slab.size, (0,)), inpaint_mask + ) # we'll erode this mask to make sure the object is not too close to the edge slab_surface_np = pil_mask_to_numpy(slab_surface) - slab_surface_np = cv2.erode(slab_surface_np, np.ones((erosion_radius, erosion_radius), np.uint8)) + slab_surface_np = cv2.erode( + slab_surface_np, np.ones((erosion_radius, erosion_radius), np.uint8) + ) slab_surface_eroded = numpy_mask_to_pil(slab_surface_np) isolated = Image.composite(inpainted_image, isolated, 
slab_surface_eroded) @@ -493,15 +658,15 @@ def rebased_inpainted_tile(inpainted_image_path, base_slab_path, is_left_tile: b isolated = Image.fromarray((isolated_np * 255).astype(np.uint8)) # reposition the object so it is centered again after scaling - repositioned = Image.new('RGBA', (width, height), (0, 0, 0, 0)) - + repositioned = Image.new("RGBA", (width, height), (0, 0, 0, 0)) + paste_x = (width - scaled_width) // 2 paste_y = (height - scaled_height) // 2 repositioned.paste(isolated, (paste_x, paste_y), mask=isolated) # rebase the object onto the original square, centering it - merged = Image.new('RGBA', (width, height), (0, 0, 0, 0)) + merged = Image.new("RGBA", (width, height), (0, 0, 0, 0)) merged.paste(base_slab, (0, 0)) centered_object = center_on_square(base_slab, repositioned) @@ -510,16 +675,33 @@ def rebased_inpainted_tile(inpainted_image_path, base_slab_path, is_left_tile: b return merged -def worker(prefix, tile_dict, gradio_url, blender_path, gpu_queue, generated_grid, first_tile_path, tile_mq, task_id, config, init_seed=429, verbose=True): + +def worker( + prefix, + tile_dict, + inpainter_type, + gradio_url, + blender_path, + gpu_queue, + generated_grid, + first_tile_path, + tile_mq, + task_id, + config, + init_seed=429, + verbose=True, +): if not verbose: - sys.stdout = open("/dev/null", 'w') - sys.stderr = open("/dev/null", 'w') + sys.stdout = open("/dev/null", "w") + sys.stderr = open("/dev/null", "w") pos = (tile_dict["x"], tile_dict["y"]) pos_str = f"{pos[0]},{pos[1]}" gpu_id = gpu_queue.get() - tile_mq.put({"pos": pos, "state": States.ASSIGNED, "task_id": task_id, "gpu_id": gpu_id}) + tile_mq.put( + {"pos": pos, "state": States.ASSIGNED, "task_id": task_id, "gpu_id": gpu_id} + ) os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) @@ -531,8 +713,12 @@ def worker(prefix, tile_dict, gradio_url, blender_path, gpu_queue, generated_gri # create local copy of generated_grid generated_grid = json.loads(json.dumps(generated_grid)) - current_base_scale = 
BASE_ORTHO_SCALE if (pos[0] == 0 and pos[1] == 0) else INITIAL_ORTHO_SCALE - current_ortho_scale = current_base_scale + ORTHO_SCALE_Y_STEP * min(pos[1], ORTHO_SCALE_MAX_NUM_Y_STEPS) + current_base_scale = ( + BASE_ORTHO_SCALE if (pos[0] == 0 and pos[1] == 0) else INITIAL_ORTHO_SCALE + ) + current_ortho_scale = current_base_scale + ORTHO_SCALE_Y_STEP * min( + pos[1], ORTHO_SCALE_MAX_NUM_Y_STEPS + ) tile_path = os.path.join(prefix, pos_str) @@ -543,12 +729,16 @@ def worker(prefix, tile_dict, gradio_url, blender_path, gpu_queue, generated_gri def load_pipeline_in_thread(event, queue): # Load the pipeline in a separate thread to avoid blocking the main thread - thread_pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large") + thread_pipeline = TrellisImageTo3DPipeline.from_pretrained( + "JeffreyXiang/TRELLIS-image-large" + ) thread_pipeline.to("cuda") queue.put(thread_pipeline) event.set() - - pipeline_thread = threading.Thread(target=load_pipeline_in_thread, args=(has_loaded_pipeline, pipeline_queue)) + + pipeline_thread = threading.Thread( + target=load_pipeline_in_thread, args=(has_loaded_pipeline, pipeline_queue) + ) pipeline_thread.start() pipeline = None @@ -558,7 +748,13 @@ def load_pipeline_in_thread(event, queue): # render again to restore the masks to their original state tile_mq.put({"pos": pos, "state": States.RENDERING, "task_id": task_id}) - render_next_tile(blender_path, grid=generated_grid, output_folder=tile_path, pos=pos, ortho_scale=current_ortho_scale) + render_next_tile( + blender_path, + grid=generated_grid, + output_folder=tile_path, + pos=pos, + ortho_scale=current_ortho_scale, + ) mode = "single" if (pos[0] == 0 and pos[1] == 0) else "overlap-free" @@ -575,12 +771,22 @@ def load_pipeline_in_thread(event, queue): seed = init_seed + task_id + attempts - inpainting_server = Inpainter(gradio_url) + inpainting_server = Inpainter(inpainter_type, gradio_url) tile_mq.put({"pos": pos, "state": States.INPAINTING, "task_id": 
task_id}) try: - inpaint_tile(inpainting_server, tile_dict["prompt"], tile_path, '000_rgb.png', seed=seed, mode=mode, extend_at=extend_at, ortho_scale=current_ortho_scale, base_ortho_scale=current_base_scale) + inpaint_tile( + inpainting_server, + tile_dict["prompt"], + tile_path, + "000_rgb.png", + seed=seed, + mode=mode, + extend_at=extend_at, + ortho_scale=current_ortho_scale, + base_ortho_scale=current_base_scale, + ) except Exception as e: print(f"Encountered an error while inpainting tile {pos_str}: {e}") @@ -589,7 +795,14 @@ def load_pipeline_in_thread(event, queue): tile_mq.put({"pos": pos, "state": States.REBASING, "task_id": task_id}) - rebased = rebased_inpainted_tile(os.path.join(tile_path, "000_inpainted.png"), os.path.join(first_tile_path, "000_rgb.png"), is_left_tile=(pos[0] == 0), base_ortho_scale=BASE_ORTHO_SCALE, ortho_scale=current_ortho_scale, scale=config.rebasing_scale if hasattr(config, "rebasing_scale") else 0.85) + rebased = rebased_inpainted_tile( + os.path.join(tile_path, "000_inpainted.png"), + os.path.join(first_tile_path, "000_rgb.png"), + is_left_tile=(pos[0] == 0), + base_ortho_scale=BASE_ORTHO_SCALE, + ortho_scale=current_ortho_scale, + scale=config.rebasing_scale if hasattr(config, "rebasing_scale") else 0.85, + ) rebased.save(os.path.join(tile_path, "000_rebased.png")) tile_mesh_path = os.path.join(tile_path, "tile.glb") @@ -600,15 +813,21 @@ def load_pipeline_in_thread(event, queue): if not has_loaded_pipeline.is_set(): tile_mq.put({"pos": pos, "state": States.STALLED, "task_id": task_id}) has_loaded_pipeline.wait() - tile_mq.put({"pos": pos, "state": States.GENERATING, "task_id": task_id}) + tile_mq.put( + {"pos": pos, "state": States.GENERATING, "task_id": task_id} + ) if pipeline is None: pipeline = pipeline_queue.get() pipeline_thread.join() - outputs = run_trellis(pipeline, os.path.join(tile_path, "000_rebased.png"), mesh_path=tile_mesh_path) - gs = outputs['scene']['gaussian'][0] + outputs = run_trellis( + pipeline, + 
os.path.join(tile_path, "000_rebased.png"), + mesh_path=tile_mesh_path, + ) + gs = outputs["scene"]["gaussian"][0] break except PoorTileQualityException as e: @@ -621,7 +840,7 @@ def load_pipeline_in_thread(event, queue): tile_mq.put({"pos": pos, "state": States.CRASH, "task_id": task_id}) return - + # Clear memory after each iteration to avoid memory leaks # release the model for k in list(pipeline.models.keys()): @@ -653,21 +872,23 @@ def load_pipeline_in_thread(event, queue): "x": pos[0], "y": pos[1], "seed": seed, - **find_cuts(gaussian_path=os.path.join(tile_path, "mesh.ply")) + **find_cuts(gaussian_path=os.path.join(tile_path, "mesh.ply")), } torch.cuda.empty_cache() tile_mq.put({"pos": pos, "state": States.ORIENTING, "task_id": task_id}) - tile_dict["rotation"] = find_orientation_of_tile(blender_path, tile_dict, os.path.join(tile_path, "000_rebased.png"), tile_path) + tile_dict["rotation"] = find_orientation_of_tile( + blender_path, tile_dict, os.path.join(tile_path, "000_rebased.png"), tile_path + ) - with open(os.path.join(tile_path, "grid.json"), 'w') as f: + with open(os.path.join(tile_path, "grid.json"), "w") as f: json.dump(generated_grid, f) generate_tile_info(blender_path, [tile_dict], tile_path) - with open(os.path.join(tile_path, "tile_info.json"), 'r') as f: + with open(os.path.join(tile_path, "tile_info.json"), "r") as f: tile_info = json.load(f) tile_dict = {**tile_info[0], **tile_dict} @@ -678,31 +899,37 @@ def load_pipeline_in_thread(event, queue): return tile_dict + def main( - instructions: str = "demo.json", - prefix: str = 'run_new_prompts/loop', - parallel: bool = True, - workers: int = -1, - gpu_ids: List[int] = None, - skip_existing: bool = False, - workers_per_gpu: int = 1, - seed: int = 1429, - gradio_url: str = 'http://127.0.0.1:7860', - blender_path: str = 'blender-3.6.19-linux-x64/blender', - resample: Tuple[int, int] = None, - resample_prompt: str = None, - **kwargs - ): - - assert "CUDA_HOME" in os.environ, "CUDA_HOME not set. 
Please restart the script prefixed with 'CUDA_HOME=/path/to/cuda'" - + instructions: str = "demo.json", + prefix: str = "run_new_prompts/loop", + parallel: bool = True, + workers: int = -1, + gpu_ids: List[int] = None, + skip_existing: bool = False, + workers_per_gpu: int = 1, + seed: int = 1429, + gradio_url: str = "http://127.0.0.1:7860", + inpainter_type: Literal[ + "flux_local", "flux_replicate", "sdxl_replicate" + ] = "flux_local", + blender_path: str = "blender-3.6.19-linux-x64/blender", + resample: Tuple[int, int] = None, + resample_prompt: str = None, + **kwargs, +): + + assert ( + "CUDA_HOME" in os.environ + ), "CUDA_HOME not set. Please restart the script prefixed with 'CUDA_HOME=/path/to/cuda'" + os.makedirs(prefix, exist_ok=True) config = edict(kwargs) instructions = process_instructions(instructions) - with open(os.path.join(prefix, "instructions.json"), 'w') as f: + with open(os.path.join(prefix, "instructions.json"), "w") as f: json.dump(instructions["tiles"], f, indent=4) if resample is not None and isinstance(resample, str): @@ -710,21 +937,25 @@ def main( def has_instructions(x, y): return any(tile["x"] == x and tile["y"] == y for tile in instructions["tiles"]) - + def unlock_adjacent_tiles(state_grid, pos): # unlock the next tile on the x axis (but only if we are not waiting for a tile below) - if has_instructions(pos[0] + 1, pos[1]) and ((pos[1] - 1) < 0 or state_grid[pos[0] + 1][pos[1] - 1] == States.DONE): + if has_instructions(pos[0] + 1, pos[1]) and ( + (pos[1] - 1) < 0 or state_grid[pos[0] + 1][pos[1] - 1] == States.DONE + ): if state_grid[pos[0] + 1][pos[1]] == States.BLOCKED: state_grid[pos[0] + 1][pos[1]] = States.READY # unlock the next tile on the y axis, but only if we are the first tile (or we were waiting) - if has_instructions(pos[0], pos[1] + 1) and (pos[0] == 0 or state_grid[pos[0] - 1][pos[1] + 1] == States.DONE): + if has_instructions(pos[0], pos[1] + 1) and ( + pos[0] == 0 or state_grid[pos[0] - 1][pos[1] + 1] == States.DONE + 
): if state_grid[pos[0]][pos[1] + 1] == States.BLOCKED: state_grid[pos[0]][pos[1] + 1] = States.READY if skip_existing or resample is not None: # load the grid from disk - with open(os.path.join(prefix, "grid.json"), 'r') as f: + with open(os.path.join(prefix, "grid.json"), "r") as f: generated_grid = json.load(f) else: @@ -732,7 +963,7 @@ def unlock_adjacent_tiles(state_grid, pos): first_tile_path = os.path.join(prefix, "0,0") - multiprocessing.set_start_method('forkserver', force=True) + multiprocessing.set_start_method("forkserver", force=True) if gpu_ids is None: gpu_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",") @@ -756,10 +987,19 @@ def unlock_adjacent_tiles(state_grid, pos): parallel = False for tile_instruction in instructions["tiles"]: - if tile_instruction["x"] == resample[0] and tile_instruction["y"] == resample[1]: - tile_instruction["prompt"] = instructions["prompt"].format(**{"tile_prompt": resample_prompt}) - - generated_grid = [tile for tile in generated_grid if not (tile["x"] == resample[0] and tile["y"] == resample[1])] + if ( + tile_instruction["x"] == resample[0] + and tile_instruction["y"] == resample[1] + ): + tile_instruction["prompt"] = instructions["prompt"].format( + **{"tile_prompt": resample_prompt} + ) + + generated_grid = [ + tile + for tile in generated_grid + if not (tile["x"] == resample[0] and tile["y"] == resample[1]) + ] if not parallel: # emit all messages @@ -777,11 +1017,13 @@ def unlock_adjacent_tiles(state_grid, pos): # manually unlock the next tiles for tile in generated_grid: state_grid[tile["x"]][tile["y"]] = States.DONE - tile_mq.put({"pos": (tile["x"], tile["y"]), "state": States.DONE, "task_id": -1}) + tile_mq.put( + {"pos": (tile["x"], tile["y"]), "state": States.DONE, "task_id": -1} + ) unlock_adjacent_tiles(state_grid, (tile["x"], tile["y"])) if resample_prompt is not None: - with open(os.path.join(prefix, "instructions.json"), 'w') as f: + with open(os.path.join(prefix, "instructions.json"), "w") as 
f: json.dump(instructions["tiles"], f, indent=4) if parallel: @@ -790,8 +1032,15 @@ def unlock_adjacent_tiles(state_grid, pos): results = [] def announce_crash(tile, task_id, e): - tile_mq.put({"pos": (tile["x"], tile["y"]), "state": States.CRASH, "task_id": task_id, "error": str(e)}) - + tile_mq.put( + { + "pos": (tile["x"], tile["y"]), + "state": States.CRASH, + "task_id": task_id, + "error": str(e), + } + ) + def queue_available_jobs(): for tile in instructions["tiles"]: if state_grid[tile["x"]][tile["y"]] != States.READY: @@ -802,12 +1051,42 @@ def queue_available_jobs(): task_id = len(results) if parallel: - res = pool.apply_async(worker, (prefix, tile, gradio_url, blender_path, gpu_queue, generated_grid, first_tile_path, tile_mq, task_id, config, seed), error_callback=partial(announce_crash, tile, task_id)) + res = pool.apply_async( + worker, + ( + prefix, + tile, + gradio_url, + blender_path, + gpu_queue, + generated_grid, + first_tile_path, + tile_mq, + task_id, + config, + seed, + ), + error_callback=partial(announce_crash, tile, task_id), + ) else: # peek at the gpu queue gpu_id = gpu_queue.get() gpu_queue.put(gpu_id) - res = worker(prefix, tile, gradio_url, blender_path, gpu_queue, generated_grid, first_tile_path, tile_mq, task_id, config, seed, verbose=True) + res = worker( + prefix, + tile, + inpainter_type, + gradio_url, + blender_path, + gpu_queue, + generated_grid, + first_tile_path, + tile_mq, + task_id, + config, + seed, + verbose=True, + ) gpu_queue.put(gpu_id) results.append(res) @@ -816,7 +1095,10 @@ def queue_available_jobs(): while True: # we are done! 
- if all(all(state == States.DONE for state in row.values()) for row in state_grid.values()): + if all( + all(state == States.DONE for state in row.values()) + for row in state_grid.values() + ): break message = tile_mq.get() @@ -834,7 +1116,9 @@ def queue_available_jobs(): if state == States.DONE: if task_id != -1: - generated_grid.append(results[task_id].get() if parallel else results[task_id]) + generated_grid.append( + results[task_id].get() if parallel else results[task_id] + ) # release the gpu gpu_queue.put(gpu_logbook[task_id]) del gpu_logbook[task_id] @@ -842,7 +1126,7 @@ def queue_available_jobs(): unlock_adjacent_tiles(state_grid, pos) # write the grid to disk - with open(os.path.join(prefix, "grid.json"), 'w') as f: + with open(os.path.join(prefix, "grid.json"), "w") as f: json.dump(generated_grid, f, indent=4) elif state == States.CRASH: @@ -860,6 +1144,8 @@ def queue_available_jobs(): m.shutdown() -if __name__ == '__main__': + +if __name__ == "__main__": import fire + fire.Fire(main) From 53fe6628f0b6913618ead3a58c33062521e9d8c1 Mon Sep 17 00:00:00 2001 From: FishWoWater Date: Tue, 22 Apr 2025 18:40:46 +0800 Subject: [PATCH 2/2] add replicate support while undo formatting --- run_pipeline.py | 657 ++++++++++++++---------------------------------- 1 file changed, 186 insertions(+), 471 deletions(-) diff --git a/run_pipeline.py b/run_pipeline.py index 435548f..47bcc0a 100644 --- a/run_pipeline.py +++ b/run_pipeline.py @@ -12,8 +12,8 @@ from queue import Queue from typing import Dict, List, Literal, Optional, Tuple -os.environ["ATTN_BACKEND"] = "xformers" -os.environ["SPCONV_ALGO"] = "native" # 'auto' is faster but benchmarks on start +os.environ['ATTN_BACKEND'] = 'xformers' +os.environ['SPCONV_ALGO'] = 'native' # 'auto' is faster but benchmarks on start import cv2 import imageio @@ -44,7 +44,6 @@ # this means we'll max out at 2.0 + 2 * 0.5 = 3.0 MAX_ORTHO_SCALE = INITIAL_ORTHO_SCALE + ORTHO_SCALE_Y_STEP * ORTHO_SCALE_MAX_NUM_Y_STEPS - class States(Enum): 
BLOCKED = "🔒" READY = "📦" @@ -61,38 +60,25 @@ class States(Enum): def process_instructions(instructions_path: str) -> List[str]: - with open(instructions_path, "r") as f: + with open(instructions_path, 'r') as f: instructions = json.load(f) # sort the instructions by the x and y values - instructions["tiles"] = sorted( - instructions["tiles"], key=lambda x: (x["y"], x["x"]) - ) + instructions["tiles"] = sorted(instructions["tiles"], key=lambda x: (x["y"], x["x"])) - assert ( - instructions["tiles"][0]["x"] == 0 and instructions["tiles"][0]["y"] == 0 - ), "The first tile must be at position (0, 0)" + assert instructions["tiles"][0]["x"] == 0 and instructions["tiles"][0]["y"] == 0, "The first tile must be at position (0, 0)" for tile_instruction in instructions["tiles"]: - tile_instruction["prompt"] = instructions["prompt"].format( - **{"tile_prompt": tile_instruction["prompt"]} - ) + tile_instruction["prompt"] = instructions["prompt"].format(**{"tile_prompt": tile_instruction["prompt"]}) return instructions - -def label_obstructing_tiles( - grid: List[Dict], - tile_pos: Tuple[int], - key: str, - create_copy: bool = True, - default_slice_height=0.0, -): +def label_obstructing_tiles(grid: List[Dict], tile_pos: Tuple[int], key: str, create_copy: bool = True, default_slice_height = 0.): if create_copy: grid = json.loads(json.dumps(grid)) for tile in grid: - if tile["x"] == tile_pos[0] and tile["y"] == tile_pos[1] - 1: + if tile['x'] == tile_pos[0] and tile['y'] == tile_pos[1] - 1: if "max_corner" in tile: tile[key] = tile["max_corner"] + tile["translation"][-1] + 5e-3 else: @@ -100,16 +86,7 @@ def label_obstructing_tiles( return grid - -def make_non_overlapping_mask( - mask, - full_mask: bool = False, - overlap_x: int = 8, - overlap_y: int = 8, - extend_at: Optional[Literal["x", "y", "xy"]] = "xy", - add_corner: bool = True, - erosion_radius: int = 5, -): +def make_non_overlapping_mask(mask, full_mask: bool = False, overlap_x: int = 8, overlap_y: int = 8, extend_at: 
Optional[Literal["x", "y", "xy"]] = "xy", add_corner: bool = True, erosion_radius: int = 5): mask_np = np.array(mask) # erode the mask by a tiny amount @@ -118,47 +95,42 @@ def make_non_overlapping_mask( # Convert the mask to a binary version (0 and 1) for easier processing. bin_mask = (mask_np > 0).astype(np.uint8) - + # Sum along the vertical direction (axis=0) to get a 1D array per column. col_sums = bin_mask.sum(axis=0) - + # Find the first column that contains any white pixels. nonzero_cols = np.where(col_sums > 0)[0] if len(nonzero_cols) == 0: raise ValueError("The mask does not contain any white pixels.") x1 = nonzero_cols[0] - + # The other edge is the column with the maximum white pixels. x2 = int(np.argmax(col_sums)) - + # For column x1, get the indices (rows) where pixels are white. rows_x1 = np.where(bin_mask[:, x1] > 0)[0] y1_top, y1_bot = int(rows_x1[0]), int(rows_x1[-1]) - + # For column x2, get the indices (rows) where pixels are white. rows_x2 = np.where(bin_mask[:, x2] > 0)[0] y2_top, y2_bot = int(rows_x2[0]), int(rows_x2[-1]) # Create a copy of the mask to modify. mask_out = mask_np.copy() - d = 5 # fix white boundaries - h = abs(y1_top - y1_bot) - ( - 0 if not extend_at or "x" not in extend_at else overlap_x - ) + d = 5 # fix white boundaries + h = abs(y1_top - y1_bot) - (0 if not extend_at or "x" not in extend_at else overlap_x) if not full_mask: # Define the four corners of the quadrilateral (in (x, y) order): # Top-left, top-right, bottom-right, bottom-left. 
- pts = np.array( - [ - [x1, y1_top - d], - [x2, y2_top - d], - [x2, y2_top + h - d], - [x1, y1_top + h - d], - ], - dtype=np.int32, - ) - + pts = np.array([ + [x1, y1_top - d], + [x2, y2_top - d], + [x2, y2_top + h - d], + [x1, y1_top + h - d] + ], dtype=np.int32) + cv2.fillPoly(mask_out, [pts], 0) bot_slope = (y2_bot - y1_bot) / (x2 - x1) @@ -176,85 +148,62 @@ def make_non_overlapping_mask( x_int, y_int = int(x_intersect), int(y_intersect) - pts = np.array( - [[x1, y1_bot], [x_int, y_int], [x1, y1_top + h - d]], dtype=np.int32 - ) + pts = np.array([ + [x1, y1_bot], + [x_int, y_int], + [x1, y1_top + h - d] + ], dtype=np.int32) cv2.fillPoly(mask_out, [pts], 1) if extend_at and "y" in extend_at: - pts = np.array( - [ - [x1, y1_bot], - [x2, y2_bot], - [x2 - overlap_y, y2_bot + overlap_y * -top_slope], - [x1 - overlap_y, y1_bot + overlap_y * -top_slope], - ], - dtype=np.int32, - ) + pts = np.array([ + [x1, y1_bot], + [x2, y2_bot], + [x2 - overlap_y, y2_bot + overlap_y * -top_slope], + [x1 - overlap_y, y1_bot + overlap_y * -top_slope] + ], dtype=np.int32) cv2.fillPoly(mask_out, [pts], 1) - + return mask_out - - -def generate_tile_info( - blender_path, grid: List[Dict], output_folder: str, resolution: int = 1024 -): + +def generate_tile_info(blender_path, grid: List[Dict], output_folder: str, resolution: int = 1024): grid = json.loads(json.dumps(grid)) # Construct command as a list of arguments cmd = [ blender_path, - "-b", - "-P", - "blender_script.py", - "--", - "--output_folder", - output_folder, - "--resolution", - str(resolution), - "--debase", - "--export_tile_info", - "--no_render", + '-b', + '-P', 'blender_script.py', + '--', + '--output_folder', output_folder, + '--resolution', str(resolution), + '--debase', + '--export_tile_info', + '--no_render', ] if len(grid) > 0: tile_json = json.dumps(grid) - cmd.extend(["--tiles", tile_json]) + cmd.extend(['--tiles', tile_json]) # Run command with redirected output - with open(os.devnull, "wb") as devnull: + with 
open(os.devnull, 'wb') as devnull: subprocess.check_call(cmd, stdout=devnull, stderr=devnull) - -def render_next_tile( - blender_path, - grid: List[Dict], - output_folder: str, - resolution: int = 1024, - pos: Tuple[int, int] = (0, 0), - ortho_scale: float = 1.75, -): +def render_next_tile(blender_path, grid: List[Dict], output_folder: str, resolution: int = 1024, pos: Tuple[int, int] = (0, 0), ortho_scale: float = 1.75): grid = json.loads(json.dumps(grid)) # figure out which tiles are at most 2 tiles away (Manhattan distance) from the current tile # and only render tiles that aren't above the current tile (which would mess with the mask) - grid = [ - tile - for tile in grid - if tile["x"] <= pos[0] - and tile["y"] <= pos[1] - and not (tile["x"] == pos[0] and tile["y"] == pos[1]) - ] + grid = [tile for tile in grid if tile["x"] <= pos[0] and tile["y"] <= pos[1] and not (tile["x"] == pos[0] and tile["y"] == pos[1])] # to provide additional context for tiles x=0, we put the y-1 tile (if it exists) # at position (-1, y) as this will not be cropped out and can provide additional context if pos[0] == 0 and pos[1] > 0: # find the tile at y-1 - y_minus_1 = [ - tile for tile in grid if tile["x"] == 0 and tile["y"] == pos[1] - 1 - ] + y_minus_1 = [tile for tile in grid if tile["x"] == 0 and tile["y"] == pos[1] - 1] if len(y_minus_1) > 0: tile_dict = y_minus_1[0] grid.append({**tile_dict, "x": -1, "y": pos[1], "has_slab": False}) @@ -264,149 +213,109 @@ def render_next_tile( # Construct command as a list of arguments cmd = [ blender_path, - "-b", - "-P", - "blender_script.py", - "--", - "--output_folder", - output_folder, - "--resolution", - str(resolution), - "--debase", - f"--next_tile_at={pos[0]},{pos[1]}", + '-b', + '-P', 'blender_script.py', + '--', + '--output_folder', output_folder, + '--resolution', str(resolution), + '--debase', + f'--next_tile_at={pos[0]},{pos[1]}' ] if len(grid) > 0: - cmd.extend(["--tiles", json.dumps(grid)]) - - views = [ - { - "yaw": 
np.radians(-45), - "pitch": np.arctan(1 / np.sqrt(2)), - "radius": 2, - "fov": np.radians(47.1), - "ortho_scale": ortho_scale, - } - ] + cmd.extend(['--tiles', json.dumps(grid)]) + + views = [{"yaw": np.radians(-45), "pitch": np.arctan(1/np.sqrt(2)), "radius": 2, "fov": np.radians(47.1), "ortho_scale": ortho_scale}] cmd.extend(["--views", json.dumps(views)]) # Run command with redirected output - with open(os.devnull, "wb") as devnull: + with open(os.devnull, 'wb') as devnull: subprocess.check_call(cmd, stdout=devnull, stderr=devnull) - -def find_orientation_of_tile( - blender_path, - tile_dict: Dict, - conditioning_image: str, - output_folder: str, - resolution: int = 256, - rotations: Tuple[int] = (0, 90, 180, 270), -): +def find_orientation_of_tile(blender_path, tile_dict: Dict, conditioning_image: str, output_folder: str, resolution: int=256, rotations: Tuple[int] = (0, 90, 180, 270)): # place this singular tile at the origin tile_dict = json.loads(json.dumps(tile_dict)) tile_dict["x"], tile_dict["y"] = 0, 0 for rotation in rotations: tile_dict["rotation"] = rotation - tile_json = json.dumps( - [tile_dict] - ) # No need to escape quotes when using list arguments + tile_json = json.dumps([tile_dict]) # No need to escape quotes when using list arguments cmd = [ blender_path, - "-b", - "-P", - "blender_script.py", - "--", - "--output_folder", - os.path.join(output_folder, f"rot_{rotation}"), - "--resolution", - str(resolution), - "--tiles", - tile_json, - "--rgb_only", + '-b', + '-P', 'blender_script.py', + '--', + '--output_folder', os.path.join(output_folder, f"rot_{rotation}"), + '--resolution', str(resolution), + '--tiles', tile_json, + '--rgb_only' ] # Run command with redirected output - with open(os.devnull, "wb") as devnull: + with open(os.devnull, 'wb') as devnull: subprocess.check_call(cmd, stdout=devnull, stderr=devnull) # find the orientation of the tile - available_rotations = glob.glob(f"{output_folder}/rot_*/*.png") + available_rotations = 
glob.glob(f'{output_folder}/rot_*/*.png') rotations_dict = { - int(fn.split("rot_")[-1].split("/")[0]): Image.open(fn).convert("RGBA") - for fn in available_rotations + int(fn.split("rot_")[-1].split("/")[0]): Image.open(fn).convert('RGBA') for fn in available_rotations } - conditioning = ( - Image.open(conditioning_image).convert("RGBA").resize((resolution, resolution)) - ) - lpips_inp_cond = im2tensor(np.array(conditioning.convert("RGB"))[:, :, ::-1]).to( - "cuda" - ) + conditioning = Image.open(conditioning_image).convert('RGBA').resize((resolution, resolution)) + lpips_inp_cond = im2tensor(np.array(conditioning.convert("RGB"))[:, :, ::-1]).to("cuda") - lpips_fn = LPIPS(net="vgg").cuda() + lpips_fn = LPIPS(net='vgg').cuda() lpips_loss = { rotation: lpips_fn( lpips_inp_cond, - im2tensor(np.array(rotations_dict[rotation].convert("RGB"))[:, :, ::-1]).to( - "cuda" - ), - ).item() - for rotation in rotations_dict + im2tensor(np.array(rotations_dict[rotation].convert("RGB"))[:, :, ::-1]).to("cuda") + ).item() for rotation in rotations_dict } # clean up for rotation in rotations_dict: - shutil.rmtree(f"{output_folder}/rot_{rotation}") + shutil.rmtree(f'{output_folder}/rot_{rotation}') return min(lpips_loss, key=lpips_loss.get) - def process_mask(mask_image: Image) -> Image: - return Image.fromarray( - (np.floor(np.array(mask_image.convert("L")) / 255)).clip(0, 1).astype(np.uint8) - * 255 - ) - + return Image.fromarray((np.floor(np.array(mask_image.convert('L'))/255)).clip(0, 1).astype(np.uint8)*255) def pil_mask_to_numpy(mask: Image) -> np.ndarray: return (np.asarray(mask) / 255).astype(np.float32) - def numpy_mask_to_pil(mask: np.ndarray) -> Image: - return Image.fromarray((mask * 255).astype(np.uint8)).convert("L") - + return Image.fromarray((mask*255).astype(np.uint8)).convert('L') def inpaint_tile( - server: Inpainter | str, - prompt: str, - input_folder: str, - input_image: str, - output_image: Optional[str] = None, - seed: int = 999, - mode: Literal["single", 
"overlap-free"] = "single", - extend_at: Optional[str] = None, - ortho_scale: float = 1.75, - base_ortho_scale: float = 1.75, -): + server: Inpainter|str, + prompt: str, + input_folder: str, + input_image: str, + output_image: Optional[str] = None, + seed: int = 999, + mode: Literal['single', 'overlap-free'] = 'single', + extend_at: Optional[str] = None, + ortho_scale: float = 1.75, + base_ortho_scale: float = 1.75, +): if isinstance(server, str): - inpainter = Inpainter(inpainter_type="flux_local", gradio_url=server) + inpainter = Inpainter("flux_local", server) else: inpainter = server - + image_path = os.path.join(input_folder, input_image) - mask_path = image_path.replace("rgb.png", "inpaint_mask.png") + mask_path = image_path.replace('rgb.png', 'inpaint_mask.png') if output_image is None: - output_path = image_path.replace("rgb.png", "inpainted.png") + output_path = image_path.replace('rgb.png', 'inpainted.png') else: output_path = os.path.join(input_folder, output_image) - base = Image.open(image_path).convert("RGB") + base = Image.open(image_path).convert('RGB') mask = process_mask(Image.open(mask_path)) - MAX_OVERLAP = 4 # in pixels + MAX_OVERLAP = 4 # in pixels # depending on the ortho scale, we allow fewer pixels to overlap overlap = int((base_ortho_scale / ortho_scale) * MAX_OVERLAP) @@ -414,39 +323,13 @@ def inpaint_tile( # we also erode the original mask a bit erosion_radius = int((base_ortho_scale / ortho_scale) * 4) - overlap_mask = numpy_mask_to_pil( - make_non_overlapping_mask( - pil_mask_to_numpy(mask), - extend_at=extend_at, - full_mask=(mode == "single"), - overlap_x=overlap, - overlap_y=overlap // 2, - erosion_radius=erosion_radius, - ) - ) + overlap_mask = numpy_mask_to_pil(make_non_overlapping_mask(pil_mask_to_numpy(mask), extend_at=extend_at, full_mask=(mode == "single"), overlap_x=overlap, overlap_y=overlap//2, erosion_radius=erosion_radius)) overlap_mask.save(mask_path.replace("inpaint", "overlap-free")) - numpy_mask_to_pil( - 
make_non_overlapping_mask( - pil_mask_to_numpy(mask), - extend_at=None, - overlap_x=0, - overlap_y=0, - add_corner=False, - full_mask=(mode == "single"), - erosion_radius=erosion_radius, - ) - ).save(mask_path) + numpy_mask_to_pil(make_non_overlapping_mask(pil_mask_to_numpy(mask), extend_at=None, overlap_x=0, overlap_y=0, add_corner=False, full_mask=(mode == "single"), erosion_radius=erosion_radius)).save(mask_path) image_inpainted = inpainter(base, overlap_mask, seed, prompt) image_inpainted.save(output_path) - -def run_trellis( - pipe, - image_path, - seed=1, - mesh_path="./assets/house-tile.glb", - metric_thresholds={"squareness": 1, "slab_size": 4096, "completeness": 0.95}, -): +def run_trellis(pipe, image_path, seed=1, mesh_path='./assets/house-tile.glb', metric_thresholds={"squareness": 1, "slab_size": 4096, "completeness": 0.95}): from trellis.utils import render_utils, postprocessing_utils import torch @@ -462,30 +345,29 @@ def run_trellis( for metric_name in ("squareness", "slab_size", "completeness"): print(f"{metric_name}: {outputs[metric_name]}") - video = render_utils.render_video(outputs["scene"]["gaussian"][0])["color"] - imageio.mimsave(image_path.replace(".png", ".mp4"), video, fps=30) + video = render_utils.render_video(outputs['scene']['gaussian'][0])['color'] + imageio.mimsave(image_path.replace('.png', '.mp4'), video, fps=30) # GLB files can be extracted from the outputs glb = postprocessing_utils.to_glb( - outputs["scene"]["gaussian"][0], - outputs["scene"]["mesh"][0], + outputs['scene']['gaussian'][0], + outputs['scene']['mesh'][0], # Optional parameters - simplify=0.95, # Ratio of triangles to remove in the simplification process - texture_size=1024, # Size of the texture used for the GLB + simplify=0.95, # Ratio of triangles to remove in the simplification process + texture_size=1024, # Size of the texture used for the GLB ) glb.export(mesh_path) del glb outputs_to_save = {k: v for k, v in outputs.items() if k not in ["scene"]} - 
torch.save(outputs_to_save, mesh_path.replace(".glb", ".pt")) + torch.save(outputs_to_save, mesh_path.replace('.glb', '.pt')) - for k in [k for k in outputs["scene"].keys() if k != "gaussian"]: - del outputs["scene"][k] + for k in [k for k in outputs['scene'].keys() if k != 'gaussian']: + del outputs['scene'][k] return outputs - def get_widest_point_y(image, find_last=False): arr = np.array(image) @@ -526,7 +408,6 @@ def get_widest_point_y(image, find_last=False): return widest_y - def center_on_square(square_img, intricate_img): square_ground_y = get_widest_point_y(square_img) intricate_ground_y = get_widest_point_y(intricate_img) @@ -538,27 +419,14 @@ def center_on_square(square_img, intricate_img): return new_intricate - -def rebased_inpainted_tile( - inpainted_image_path, - base_slab_path, - is_left_tile: bool = True, - scale=0.85, - postfix="inpainted", - erosion_radius=5, - ortho_scale=1.75, - base_ortho_scale=1.75, - render_resolution=1024, -): +def rebased_inpainted_tile(inpainted_image_path, base_slab_path, is_left_tile: bool = True, scale=0.85, postfix="inpainted", erosion_radius=5, ortho_scale=1.75, base_ortho_scale=1.75, render_resolution=1024): if ortho_scale != base_ortho_scale: ortho_rescale = base_ortho_scale / ortho_scale crop_size = int(render_resolution * ortho_rescale) crop_size_sides = (render_resolution - crop_size) // 2 # crop the images to the same size - for fn in glob.glob( - os.path.join(os.path.dirname(inpainted_image_path), "000_*.png") - ): + for fn in glob.glob(os.path.join(os.path.dirname(inpainted_image_path), "000_*.png")): if "backup" in fn: continue @@ -566,16 +434,9 @@ def rebased_inpainted_tile( shutil.copy(fn, fn.replace(".png", "_backup.png")) img = Image.open(fn) - img = img.crop( - ( - crop_size_sides, - crop_size_sides, - crop_size_sides + crop_size, - crop_size_sides + crop_size, - ) - ) + img = img.crop((crop_size_sides, crop_size_sides, crop_size_sides + crop_size, crop_size_sides + crop_size)) img.save(fn) - + 
inpainted_image = Image.open(inpainted_image_path) width, height = inpainted_image.size @@ -585,38 +446,18 @@ def rebased_inpainted_tile( base_slab = base_slab.resize((width, height)) # mask out the slab - conditioning_mask = Image.open( - inpainted_image_path.replace(postfix, "conditioning_mask") - ).convert("L") - inpaint_mask = Image.open( - inpainted_image_path.replace(postfix, "inpaint_mask") - ).convert("L") + conditioning_mask = Image.open(inpainted_image_path.replace(postfix, "conditioning_mask")).convert("L") + inpaint_mask = Image.open(inpainted_image_path.replace(postfix, "inpaint_mask")).convert("L") discard_mask = PIL.ImageOps.invert(conditioning_mask) if is_left_tile: - slab_mask = Image.composite( - conditioning_mask, - Image.new("L", conditioning_mask.size, (0,)), - PIL.ImageOps.invert(inpaint_mask), - ) - slabless = Image.composite( - inpainted_image, - Image.new("RGBA", inpainted_image.size, (0, 0, 0, 255)), - PIL.ImageOps.invert(slab_mask), - ) + slab_mask = Image.composite(conditioning_mask, Image.new('L', conditioning_mask.size, (0,)), PIL.ImageOps.invert(inpaint_mask)) + slabless = Image.composite(inpainted_image, Image.new('RGBA', inpainted_image.size, (0, 0, 0, 255)), PIL.ImageOps.invert(slab_mask)) else: - slabless = Image.composite( - inpainted_image, - Image.new("RGBA", inpainted_image.size, (0, 0, 0, 255)), - inpaint_mask, - ) - - slabless = Image.composite( - slabless, - Image.new("RGBA", inpainted_image.size, (0, 0, 0, 255)), - PIL.ImageOps.invert(discard_mask), - ) + slabless = Image.composite(inpainted_image, Image.new('RGBA', inpainted_image.size, (0, 0, 0, 255)), inpaint_mask) + + slabless = Image.composite(slabless, Image.new('RGBA', inpainted_image.size, (0, 0, 0, 255)), PIL.ImageOps.invert(discard_mask)) # isolate the object isolated = rembg.remove(slabless, alpha_matting=True, post_process_mask=True) @@ -624,23 +465,17 @@ def rebased_inpainted_tile( # erode the rembg result a bit in case there are is a white gradient around the 
object isolated_np = (np.asarray(isolated) / 255).astype(np.float32) isolated_np[..., -1] = np.where(isolated_np[..., -1] > 0.5, 1, 0) - isolated_np[..., -1] = cv2.erode( - isolated_np[..., -1], np.ones((erosion_radius, erosion_radius), np.uint8) - ) + isolated_np[..., -1] = cv2.erode(isolated_np[..., -1], np.ones((erosion_radius, erosion_radius), np.uint8)) isolated = Image.fromarray((isolated_np * 255).astype(np.uint8)) # make sure everything on the slab is actually retained # sometimes, rembg will remove some pixels on the slab which will break # the rebasing process - slab_surface = Image.composite( - base_slab.split()[-1], Image.new("L", base_slab.size, (0,)), inpaint_mask - ) + slab_surface = Image.composite(base_slab.split()[-1], Image.new('L', base_slab.size, (0,)), inpaint_mask) # we'll erode this mask to make sure the object is not too close to the edge slab_surface_np = pil_mask_to_numpy(slab_surface) - slab_surface_np = cv2.erode( - slab_surface_np, np.ones((erosion_radius, erosion_radius), np.uint8) - ) + slab_surface_np = cv2.erode(slab_surface_np, np.ones((erosion_radius, erosion_radius), np.uint8)) slab_surface_eroded = numpy_mask_to_pil(slab_surface_np) isolated = Image.composite(inpainted_image, isolated, slab_surface_eroded) @@ -658,15 +493,15 @@ def rebased_inpainted_tile( isolated = Image.fromarray((isolated_np * 255).astype(np.uint8)) # reposition the object so it is centered again after scaling - repositioned = Image.new("RGBA", (width, height), (0, 0, 0, 0)) - + repositioned = Image.new('RGBA', (width, height), (0, 0, 0, 0)) + paste_x = (width - scaled_width) // 2 paste_y = (height - scaled_height) // 2 repositioned.paste(isolated, (paste_x, paste_y), mask=isolated) # rebase the object onto the original square, centering it - merged = Image.new("RGBA", (width, height), (0, 0, 0, 0)) + merged = Image.new('RGBA', (width, height), (0, 0, 0, 0)) merged.paste(base_slab, (0, 0)) centered_object = center_on_square(base_slab, repositioned) @@ 
-675,33 +510,16 @@ def rebased_inpainted_tile( return merged - -def worker( - prefix, - tile_dict, - inpainter_type, - gradio_url, - blender_path, - gpu_queue, - generated_grid, - first_tile_path, - tile_mq, - task_id, - config, - init_seed=429, - verbose=True, -): +def worker(prefix, tile_dict, gradio_url, inpainter_type, blender_path, gpu_queue, generated_grid, first_tile_path, tile_mq, task_id, config, init_seed=429, verbose=True): if not verbose: - sys.stdout = open("/dev/null", "w") - sys.stderr = open("/dev/null", "w") + sys.stdout = open("/dev/null", 'w') + sys.stderr = open("/dev/null", 'w') pos = (tile_dict["x"], tile_dict["y"]) pos_str = f"{pos[0]},{pos[1]}" gpu_id = gpu_queue.get() - tile_mq.put( - {"pos": pos, "state": States.ASSIGNED, "task_id": task_id, "gpu_id": gpu_id} - ) + tile_mq.put({"pos": pos, "state": States.ASSIGNED, "task_id": task_id, "gpu_id": gpu_id}) os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) @@ -713,12 +531,8 @@ def worker( # create local copy of generated_grid generated_grid = json.loads(json.dumps(generated_grid)) - current_base_scale = ( - BASE_ORTHO_SCALE if (pos[0] == 0 and pos[1] == 0) else INITIAL_ORTHO_SCALE - ) - current_ortho_scale = current_base_scale + ORTHO_SCALE_Y_STEP * min( - pos[1], ORTHO_SCALE_MAX_NUM_Y_STEPS - ) + current_base_scale = BASE_ORTHO_SCALE if (pos[0] == 0 and pos[1] == 0) else INITIAL_ORTHO_SCALE + current_ortho_scale = current_base_scale + ORTHO_SCALE_Y_STEP * min(pos[1], ORTHO_SCALE_MAX_NUM_Y_STEPS) tile_path = os.path.join(prefix, pos_str) @@ -729,16 +543,12 @@ def worker( def load_pipeline_in_thread(event, queue): # Load the pipeline in a separate thread to avoid blocking the main thread - thread_pipeline = TrellisImageTo3DPipeline.from_pretrained( - "JeffreyXiang/TRELLIS-image-large" - ) + thread_pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large") thread_pipeline.to("cuda") queue.put(thread_pipeline) event.set() - - pipeline_thread = threading.Thread( - 
target=load_pipeline_in_thread, args=(has_loaded_pipeline, pipeline_queue) - ) + + pipeline_thread = threading.Thread(target=load_pipeline_in_thread, args=(has_loaded_pipeline, pipeline_queue)) pipeline_thread.start() pipeline = None @@ -748,13 +558,7 @@ def load_pipeline_in_thread(event, queue): # render again to restore the masks to their original state tile_mq.put({"pos": pos, "state": States.RENDERING, "task_id": task_id}) - render_next_tile( - blender_path, - grid=generated_grid, - output_folder=tile_path, - pos=pos, - ortho_scale=current_ortho_scale, - ) + render_next_tile(blender_path, grid=generated_grid, output_folder=tile_path, pos=pos, ortho_scale=current_ortho_scale) mode = "single" if (pos[0] == 0 and pos[1] == 0) else "overlap-free" @@ -776,17 +580,7 @@ def load_pipeline_in_thread(event, queue): tile_mq.put({"pos": pos, "state": States.INPAINTING, "task_id": task_id}) try: - inpaint_tile( - inpainting_server, - tile_dict["prompt"], - tile_path, - "000_rgb.png", - seed=seed, - mode=mode, - extend_at=extend_at, - ortho_scale=current_ortho_scale, - base_ortho_scale=current_base_scale, - ) + inpaint_tile(inpainting_server, tile_dict["prompt"], tile_path, '000_rgb.png', seed=seed, mode=mode, extend_at=extend_at, ortho_scale=current_ortho_scale, base_ortho_scale=current_base_scale) except Exception as e: print(f"Encountered an error while inpainting tile {pos_str}: {e}") @@ -795,14 +589,7 @@ def load_pipeline_in_thread(event, queue): tile_mq.put({"pos": pos, "state": States.REBASING, "task_id": task_id}) - rebased = rebased_inpainted_tile( - os.path.join(tile_path, "000_inpainted.png"), - os.path.join(first_tile_path, "000_rgb.png"), - is_left_tile=(pos[0] == 0), - base_ortho_scale=BASE_ORTHO_SCALE, - ortho_scale=current_ortho_scale, - scale=config.rebasing_scale if hasattr(config, "rebasing_scale") else 0.85, - ) + rebased = rebased_inpainted_tile(os.path.join(tile_path, "000_inpainted.png"), os.path.join(first_tile_path, "000_rgb.png"), 
is_left_tile=(pos[0] == 0), base_ortho_scale=BASE_ORTHO_SCALE, ortho_scale=current_ortho_scale, scale=config.rebasing_scale if hasattr(config, "rebasing_scale") else 0.85) rebased.save(os.path.join(tile_path, "000_rebased.png")) tile_mesh_path = os.path.join(tile_path, "tile.glb") @@ -813,21 +600,15 @@ def load_pipeline_in_thread(event, queue): if not has_loaded_pipeline.is_set(): tile_mq.put({"pos": pos, "state": States.STALLED, "task_id": task_id}) has_loaded_pipeline.wait() - tile_mq.put( - {"pos": pos, "state": States.GENERATING, "task_id": task_id} - ) + tile_mq.put({"pos": pos, "state": States.GENERATING, "task_id": task_id}) if pipeline is None: pipeline = pipeline_queue.get() pipeline_thread.join() - outputs = run_trellis( - pipeline, - os.path.join(tile_path, "000_rebased.png"), - mesh_path=tile_mesh_path, - ) - gs = outputs["scene"]["gaussian"][0] + outputs = run_trellis(pipeline, os.path.join(tile_path, "000_rebased.png"), mesh_path=tile_mesh_path) + gs = outputs['scene']['gaussian'][0] break except PoorTileQualityException as e: @@ -840,7 +621,7 @@ def load_pipeline_in_thread(event, queue): tile_mq.put({"pos": pos, "state": States.CRASH, "task_id": task_id}) return - + # Clear memory after each iteration to avoid memory leaks # release the model for k in list(pipeline.models.keys()): @@ -872,23 +653,21 @@ def load_pipeline_in_thread(event, queue): "x": pos[0], "y": pos[1], "seed": seed, - **find_cuts(gaussian_path=os.path.join(tile_path, "mesh.ply")), + **find_cuts(gaussian_path=os.path.join(tile_path, "mesh.ply")) } torch.cuda.empty_cache() tile_mq.put({"pos": pos, "state": States.ORIENTING, "task_id": task_id}) - tile_dict["rotation"] = find_orientation_of_tile( - blender_path, tile_dict, os.path.join(tile_path, "000_rebased.png"), tile_path - ) + tile_dict["rotation"] = find_orientation_of_tile(blender_path, tile_dict, os.path.join(tile_path, "000_rebased.png"), tile_path) - with open(os.path.join(tile_path, "grid.json"), "w") as f: + with 
open(os.path.join(tile_path, "grid.json"), 'w') as f: json.dump(generated_grid, f) generate_tile_info(blender_path, [tile_dict], tile_path) - with open(os.path.join(tile_path, "tile_info.json"), "r") as f: + with open(os.path.join(tile_path, "tile_info.json"), 'r') as f: tile_info = json.load(f) tile_dict = {**tile_info[0], **tile_dict} @@ -899,37 +678,32 @@ def load_pipeline_in_thread(event, queue): return tile_dict - def main( - instructions: str = "demo.json", - prefix: str = "run_new_prompts/loop", - parallel: bool = True, - workers: int = -1, - gpu_ids: List[int] = None, - skip_existing: bool = False, - workers_per_gpu: int = 1, - seed: int = 1429, - gradio_url: str = "http://127.0.0.1:7860", - inpainter_type: Literal[ - "flux_local", "flux_replicate", "sdxl_replicate" - ] = "flux_local", - blender_path: str = "blender-3.6.19-linux-x64/blender", - resample: Tuple[int, int] = None, - resample_prompt: str = None, - **kwargs, -): - - assert ( - "CUDA_HOME" in os.environ - ), "CUDA_HOME not set. Please restart the script prefixed with 'CUDA_HOME=/path/to/cuda'" - + instructions: str = "demo.json", + prefix: str = 'run_new_prompts/loop', + parallel: bool = True, + workers: int = -1, + gpu_ids: List[int] = None, + skip_existing: bool = False, + workers_per_gpu: int = 1, + seed: int = 1429, + gradio_url: str = 'http://127.0.0.1:7860', + inpainter_type: Literal["flux_local", "flux_replicate", "sdxl_replicate"] = "flux_local", + blender_path: str = 'blender-3.6.19-linux-x64/blender', + resample: Tuple[int, int] = None, + resample_prompt: str = None, + **kwargs + ): + + assert "CUDA_HOME" in os.environ, "CUDA_HOME not set. 
Please restart the script prefixed with 'CUDA_HOME=/path/to/cuda'" + os.makedirs(prefix, exist_ok=True) config = edict(kwargs) instructions = process_instructions(instructions) - with open(os.path.join(prefix, "instructions.json"), "w") as f: + with open(os.path.join(prefix, "instructions.json"), 'w') as f: json.dump(instructions["tiles"], f, indent=4) if resample is not None and isinstance(resample, str): @@ -937,25 +711,21 @@ def main( def has_instructions(x, y): return any(tile["x"] == x and tile["y"] == y for tile in instructions["tiles"]) - + def unlock_adjacent_tiles(state_grid, pos): # unlock the next tile on the x axis (but only if we are not waiting for a tile below) - if has_instructions(pos[0] + 1, pos[1]) and ( - (pos[1] - 1) < 0 or state_grid[pos[0] + 1][pos[1] - 1] == States.DONE - ): + if has_instructions(pos[0] + 1, pos[1]) and ((pos[1] - 1) < 0 or state_grid[pos[0] + 1][pos[1] - 1] == States.DONE): if state_grid[pos[0] + 1][pos[1]] == States.BLOCKED: state_grid[pos[0] + 1][pos[1]] = States.READY # unlock the next tile on the y axis, but only if we are the first tile (or we were waiting) - if has_instructions(pos[0], pos[1] + 1) and ( - pos[0] == 0 or state_grid[pos[0] - 1][pos[1] + 1] == States.DONE - ): + if has_instructions(pos[0], pos[1] + 1) and (pos[0] == 0 or state_grid[pos[0] - 1][pos[1] + 1] == States.DONE): if state_grid[pos[0]][pos[1] + 1] == States.BLOCKED: state_grid[pos[0]][pos[1] + 1] = States.READY if skip_existing or resample is not None: # load the grid from disk - with open(os.path.join(prefix, "grid.json"), "r") as f: + with open(os.path.join(prefix, "grid.json"), 'r') as f: generated_grid = json.load(f) else: @@ -963,7 +733,7 @@ def unlock_adjacent_tiles(state_grid, pos): first_tile_path = os.path.join(prefix, "0,0") - multiprocessing.set_start_method("forkserver", force=True) + multiprocessing.set_start_method('forkserver', force=True) if gpu_ids is None: gpu_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",") @@ 
-987,19 +757,10 @@ def unlock_adjacent_tiles(state_grid, pos): parallel = False for tile_instruction in instructions["tiles"]: - if ( - tile_instruction["x"] == resample[0] - and tile_instruction["y"] == resample[1] - ): - tile_instruction["prompt"] = instructions["prompt"].format( - **{"tile_prompt": resample_prompt} - ) - - generated_grid = [ - tile - for tile in generated_grid - if not (tile["x"] == resample[0] and tile["y"] == resample[1]) - ] + if tile_instruction["x"] == resample[0] and tile_instruction["y"] == resample[1]: + tile_instruction["prompt"] = instructions["prompt"].format(**{"tile_prompt": resample_prompt}) + + generated_grid = [tile for tile in generated_grid if not (tile["x"] == resample[0] and tile["y"] == resample[1])] if not parallel: # emit all messages @@ -1017,13 +778,11 @@ def unlock_adjacent_tiles(state_grid, pos): # manually unlock the next tiles for tile in generated_grid: state_grid[tile["x"]][tile["y"]] = States.DONE - tile_mq.put( - {"pos": (tile["x"], tile["y"]), "state": States.DONE, "task_id": -1} - ) + tile_mq.put({"pos": (tile["x"], tile["y"]), "state": States.DONE, "task_id": -1}) unlock_adjacent_tiles(state_grid, (tile["x"], tile["y"])) if resample_prompt is not None: - with open(os.path.join(prefix, "instructions.json"), "w") as f: + with open(os.path.join(prefix, "instructions.json"), 'w') as f: json.dump(instructions["tiles"], f, indent=4) if parallel: @@ -1032,15 +791,8 @@ def unlock_adjacent_tiles(state_grid, pos): results = [] def announce_crash(tile, task_id, e): - tile_mq.put( - { - "pos": (tile["x"], tile["y"]), - "state": States.CRASH, - "task_id": task_id, - "error": str(e), - } - ) - + tile_mq.put({"pos": (tile["x"], tile["y"]), "state": States.CRASH, "task_id": task_id, "error": str(e)}) + def queue_available_jobs(): for tile in instructions["tiles"]: if state_grid[tile["x"]][tile["y"]] != States.READY: @@ -1051,42 +803,12 @@ def queue_available_jobs(): task_id = len(results) if parallel: - res = 
pool.apply_async( - worker, - ( - prefix, - tile, - gradio_url, - blender_path, - gpu_queue, - generated_grid, - first_tile_path, - tile_mq, - task_id, - config, - seed, - ), - error_callback=partial(announce_crash, tile, task_id), - ) + res = pool.apply_async(worker, (prefix, tile, gradio_url, inpainter_type, blender_path, gpu_queue, generated_grid, first_tile_path, tile_mq, task_id, config, seed), error_callback=partial(announce_crash, tile, task_id)) else: # peek at the gpu queue gpu_id = gpu_queue.get() gpu_queue.put(gpu_id) - res = worker( - prefix, - tile, - inpainter_type, - gradio_url, - blender_path, - gpu_queue, - generated_grid, - first_tile_path, - tile_mq, - task_id, - config, - seed, - verbose=True, - ) + res = worker(prefix, tile, gradio_url, inpainter_type, blender_path, gpu_queue, generated_grid, first_tile_path, tile_mq, task_id, config, seed, verbose=True) gpu_queue.put(gpu_id) results.append(res) @@ -1095,10 +817,7 @@ def queue_available_jobs(): while True: # we are done! 
- if all( - all(state == States.DONE for state in row.values()) - for row in state_grid.values() - ): + if all(all(state == States.DONE for state in row.values()) for row in state_grid.values()): break message = tile_mq.get() @@ -1116,9 +835,7 @@ def queue_available_jobs(): if state == States.DONE: if task_id != -1: - generated_grid.append( - results[task_id].get() if parallel else results[task_id] - ) + generated_grid.append(results[task_id].get() if parallel else results[task_id]) # release the gpu gpu_queue.put(gpu_logbook[task_id]) del gpu_logbook[task_id] @@ -1126,7 +843,7 @@ def queue_available_jobs(): unlock_adjacent_tiles(state_grid, pos) # write the grid to disk - with open(os.path.join(prefix, "grid.json"), "w") as f: + with open(os.path.join(prefix, "grid.json"), 'w') as f: json.dump(generated_grid, f, indent=4) elif state == States.CRASH: @@ -1144,8 +861,6 @@ def queue_available_jobs(): m.shutdown() - -if __name__ == "__main__": +if __name__ == '__main__': import fire - - fire.Fire(main) + fire.Fire(main) \ No newline at end of file