Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
REPLICATE_API_TOKEN=
13 changes: 11 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ FodyWeavers.xsd
*.sln.iml

# env files from the flux server
FLUX_inpainting_server/env/lib
inpainting/flux_inpainting_server/env/lib
tmp*

# SLURM output
Expand All @@ -411,4 +411,13 @@ wandb/*
*.pt
*.pth

.vscode/
.vscode/

# Environment variables
.env

# Output scenes
scenes

# Blender
blender-3.6*/
25 changes: 23 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ SynCity generates complex and immersive 3D worlds from text prompts and does not

### Prerequisites
* **System**: The code was tested on Ubuntu 22.04. We expect it to run on other Linux-based distributions too.
* **Hardware**: An NVIDIA GPU with at least **48GB** of memory is required. We have used A40 and A6000 GPUs.
* **Hardware**: If the inpainting server is deployed locally, an NVIDIA GPU with at least **48GB** of memory is required. We have used A40 and A6000 GPUs. If you use the inpainting service from replicate, you will need a GPU with at least **16GB** of memory for trellis generation.
* **Software**:
- The [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive) is needed to compile certain submodules. We have tested CUDA versions 11.8 and 12.4.
- [Conda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install) is used to create the environment to run the code. This environment uses Python version 3.10.
Expand All @@ -33,10 +33,23 @@ source ./setup.sh --new-env --basic --xformers --diffoctreerast --spconv --mipga
```
Make sure to have set the environment variable `CUDA_HOME`, which should point to your CUDA Toolkit installation. If you run into issues while running this setup script, please refer to [the README in the TRELLIS](https://github.com/microsoft/TRELLIS/blob/main/README.md#installation-steps) repository, which provides additional guidance.

3. Set up the FLUX inpainting server:

3. Set up the FLUX inpainting backend:

**Option1**: Set up the FLUX inpainting server locally (the server will require around **30GB+ VRAM**):
```
./inpainting_server.sh --install
```
**Option2**: Use replicate web deployment (pay as you go, about 0.03$ for every image and 0.4$ every 3x3 scene).
```
cp .env.example .env

# fill in your replicate id, which can be obtained here
# https://replicate.com/account/api-tokens

# install requirements for replicate
pip install dotenv replicate
```

4. Download [Blender 3.6.19](https://www.blender.org/download/release/Blender3.6/blender-3.6.19-linux-x64.tar.xz), extract it into the root directory of this project, and make sure `blender-3.6.19-linux-x64/blender` can be executed on your system.

Expand All @@ -51,15 +64,23 @@ The process to generate a world is split into two straightforward steps.
### Step 1: Generating Tiles
The tiles are generated using an instruction file, which contain the prompts to generate each tile (see some examples in the `instructions` folder). To generate a set of tiles that will be saved to `scenes/solarpunk`, run:
```
# option1: locally deploy the inpainting server
python run_pipeline.py --instructions instructions/3x3/solarpunk.json --prefix scenes/solarpunk --gradio_url=http://127.0.0.1:7860

# option2: use the replicate inpainting service
python run_pipeline.py --instructions instructions/3x3/solarpunk.json --prefix scenes/solarpunk --parallel=False --inpainter_type=flux_replicate
```

This script will parallelize tile generation where possible if multiple GPUs are available. If the script is stalling for longer than a minute, consider running the tile generation synchronously (`--parallel=False`). Furthermore, if a single tile keeps being regenerated, consider interrupting the script and replacing the offending tile's prompt. Then, restart the script with `--skip_existing=True` to ensure it will not overwrite existing tiles. Alternatively, see the ["Advanced Usage"](#advanced_usage) section on how to adjust the tile rejection criteria.

### Step 2: Blending Tiles
To create smooth transitions between tiles and refine their boundary regions, run the blending script:
```
# option1: locally deploy the inpainting server
python blend_gaussians.py --compute_rescaled --stitch_images --stitch_slats --gradio_url=http://127.0.0.1:7860 --prefix scenes/solarpunk

# option2: use the replicate inpainting service
python blend_gaussians.py --compute_rescaled --stitch_images --stitch_slats --inpainter_type=flux_replicate --prefix scenes/solarpunk
```

This script will create a `.ply` file with the Gaussians of the entire grid as well as a video rendering.
Expand Down
6 changes: 4 additions & 2 deletions blend_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch.nn.functional as F
from PIL import Image
from tqdm.auto import tqdm
from typing import Literal

import trellis.models as models
from tile_cutting import z_preserving_crop
Expand Down Expand Up @@ -477,6 +478,7 @@ def merge_gaussians(
stitch_images: bool = False,
stitch_slats: bool = False,
use_cached: bool = False,
inpainter_type: Literal["flux_local", "flux_replicate", "sdxl_replicate"] = "flux_local",
gradio_url='http://127.0.0.1:7860',
blender_path: str = 'blender-3.6.19-linux-x64/blender',
seed: int = 429
Expand Down Expand Up @@ -577,9 +579,9 @@ def merge_gaussians(
rescaled_tiles = dill.load(open(os.path.join(grid_path, 'rescaled_tiles.pkl'), 'rb'))

if stitch_images:
from FLUX_inpainting_server.inpaint import Inpainter
from inpainting import Inpainter
import time
inpainter = Inpainter(gradio_url)
inpainter = Inpainter(inpainter_type, gradio_url)

VIEW_TYPE = 'zoom_out'
prompts = json.load(open(os.path.join(grid_path, 'instructions.json')))
Expand Down
22 changes: 22 additions & 0 deletions inpainting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from dotenv import load_dotenv
from typing import Literal
from PIL import Image
from .flux_inpainter_server.inpaint import Inpainter as FluxInpainter
from .replicate_inpainter import ReplicateFluxInpainter, ReplicateSDXLInpainter

# load api keys of replicate
load_dotenv()

class Inpainter:
def __init__(self, inpainter_type: Literal["flux_local", "flux_replicate", "sdxl_replicate"] = "flux_local", gradio_url: str = ""):
if inpainter_type == "flux_local":
self.inpainter = FluxInpainter(gradio_url)
elif inpainter_type == "flux_replicate":
self.inpainter = ReplicateFluxInpainter()
elif inpainter_type == "sdxl_replicate":
self.inpainter = ReplicateSDXLInpainter()
else:
raise ValueError(f"Invalid inpainter_type: {inpainter_type}")

def __call__(self, image:Image.Image, mask:Image.Image, seed:int, prompt:str):
return self.inpainter(image, mask, seed, prompt)
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
image_path='https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/bucket.png',
mask_path='https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/bucket_mask.jpeg',
prompt='a person wearing a white shoe, carrying a white bucket with text "FLUX" on it'
prompt_detailed = 'an ivy-covered red brick building with classical columns and arched windows, on top of a base, east coast university, ivy-clad red brick buildings, cobblestone paths, gentle autumn light, soft warm lighting, realistic textures, subtle gradients, isometric perspective, academic charm, and meticulous detailing'

# Build pipeline
controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16)
Expand Down Expand Up @@ -48,3 +49,20 @@

result.save('flux_inpaint.png')
print("Successfully inpaint image")

result = pipe(
prompt=prompt_detailed,
height=size[1],
width=size[0],
control_image=image,
control_mask=mask,
num_inference_steps=28,
generator=generator,
controlnet_conditioning_scale=0.9,
guidance_scale=3.5,
negative_prompt="",
true_guidance_scale=3.5
).images[0]

result.save('flux_inpaint_detailed.png')
print("Successfully inpaint image")
4 changes: 4 additions & 0 deletions inpainting/replicate_inpainter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .sdxl import ReplicateSDXLInpainter
from .flux import ReplicateFluxInpainter

__all__ = ["ReplicateSDXLInpainter", "ReplicateFluxInpainter"]
52 changes: 52 additions & 0 deletions inpainting/replicate_inpainter/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os, os.path as osp
import replicate
import numpy as np
from abc import ABC, abstractmethod
from PIL import Image

TMP_DIR = "./tmp_flux"
TMP_PATH = osp.join(TMP_DIR, "result.png")

class BaseReplicateInpainter(ABC):
REPLICATE_ID = ""
def __init__(self) -> None:
pass

@abstractmethod
def _build_extra_inputs(self):
pass

def run(self, image: Image, mask: Image, seed: int, prompt: str):
inputs = {
"image": image,
"mask": mask,
"seed": seed,
"prompt": prompt,
}
inputs.update(self._build_extra_inputs())
print(inputs.keys())
return replicate.run(
self.REPLICATE_ID,
input=inputs
)

def __call__(self, image: Image, mask: Image, seed: int, prompt: str):
os.makedirs(TMP_DIR, exist_ok=True)
image = image.convert("RGB")
mask_rgb = Image.new('RGB', mask.size)
mask_rgb.paste(mask)
image_tmp_path = osp.join(TMP_DIR, "image.png")
mask_tmp_path = osp.join(TMP_DIR, "mask.png")
image.save(image_tmp_path)
mask_rgb.save(mask_tmp_path)

print(np.array(mask_rgb).shape, np.array(image).shape)
# exit(0)

output = self.run(open(image_tmp_path, "rb"), open(mask_tmp_path, "rb"), seed, prompt)
# output = self.run(image, mask, seed, prompt)
assert len(output) == 1

with open(TMP_PATH, "wb") as file:
file.write(output[0].read())
return Image.open(TMP_PATH)
21 changes: 21 additions & 0 deletions inpainting/replicate_inpainter/flux.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from .base import BaseReplicateInpainter


class ReplicateFluxInpainter(BaseReplicateInpainter):
"""
Replicate playground: https://replicate.com/black-forest-labs/flux-fill-dev
Speed and Cost: 0.04$ / image, 9.6s / image
"""

REPLICATE_ID = "fishwowater/flux-dev-controlnet-inpainting-beta:27d3ff35f58b4409775de5a0b36e99b4c6d2d7fc7fe772b35170951db678ec63"

def _build_extra_inputs(self):
return {
# default guidance scale
"guidance_scale": 3.5,
"true_guidance_scale": 3.5,
"controlnet_conditioning_scale": 0.9,
# default values for inference steps
"num_inference_steps": 24,
"output_quality": 100,
}
19 changes: 19 additions & 0 deletions inpainting/replicate_inpainter/sdxl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from .base import BaseReplicateInpainter



class ReplicateSDXLInpainter(BaseReplicateInpainter):
"""
Replicate playground: https://replicate.com/lucataco/sdxl-inpainting
Speed and Cost: 0.0023$ / image, 1.9s / image
"""

REPLICATE_ID = "lucataco/sdxl-inpainting:a5b13068cc81a89a4fbeefeccc774869fcb34df4dbc92c1555e0f2771d49dde7"

def _build_extra_inputs(self):
return {
# default values on the playground
"guidance_scale": 8.0,
"steps": 20,
"strength": 0.7,
}
2 changes: 1 addition & 1 deletion inpainting_server.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
# Installation as shown here: https://huggingface.co/spaces/ameerazam08/FLUX.1-dev-Inpainting-Model-Beta-GPU?docker=true
cd FLUX_inpainting_server
cd inpainting_server/flux_inpainter
python -m venv env
source env/bin/activate

Expand Down
15 changes: 8 additions & 7 deletions run_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from lpips import LPIPS, im2tensor
from PIL import Image

from FLUX_inpainting_server.inpaint import Inpainter
from inpainting import Inpainter

# a brief explanation of the orthographic scale:
# the ortho scale determines the size of the world in the image
Expand Down Expand Up @@ -301,7 +301,7 @@ def inpaint_tile(
base_ortho_scale: float = 1.75,
):
if isinstance(server, str):
inpainter = Inpainter(server)
inpainter = Inpainter("flux_local", server)
else:
inpainter = server

Expand Down Expand Up @@ -510,7 +510,7 @@ def rebased_inpainted_tile(inpainted_image_path, base_slab_path, is_left_tile: b

return merged

def worker(prefix, tile_dict, gradio_url, blender_path, gpu_queue, generated_grid, first_tile_path, tile_mq, task_id, config, init_seed=429, verbose=True):
def worker(prefix, tile_dict, gradio_url, inpainter_type, blender_path, gpu_queue, generated_grid, first_tile_path, tile_mq, task_id, config, init_seed=429, verbose=True):
if not verbose:
sys.stdout = open("/dev/null", 'w')
sys.stderr = open("/dev/null", 'w')
Expand Down Expand Up @@ -575,7 +575,7 @@ def load_pipeline_in_thread(event, queue):

seed = init_seed + task_id + attempts

inpainting_server = Inpainter(gradio_url)
inpainting_server = Inpainter(inpainter_type, gradio_url)

tile_mq.put({"pos": pos, "state": States.INPAINTING, "task_id": task_id})

Expand Down Expand Up @@ -688,6 +688,7 @@ def main(
workers_per_gpu: int = 1,
seed: int = 1429,
gradio_url: str = 'http://127.0.0.1:7860',
inpainter_type: Literal["flux_local", "flux_replicate", "sdxl_replicate"] = "flux_local",
blender_path: str = 'blender-3.6.19-linux-x64/blender',
resample: Tuple[int, int] = None,
resample_prompt: str = None,
Expand Down Expand Up @@ -802,12 +803,12 @@ def queue_available_jobs():
task_id = len(results)

if parallel:
res = pool.apply_async(worker, (prefix, tile, gradio_url, blender_path, gpu_queue, generated_grid, first_tile_path, tile_mq, task_id, config, seed), error_callback=partial(announce_crash, tile, task_id))
res = pool.apply_async(worker, (prefix, tile, gradio_url, inpainter_type, blender_path, gpu_queue, generated_grid, first_tile_path, tile_mq, task_id, config, seed), error_callback=partial(announce_crash, tile, task_id))
else:
# peek at the gpu queue
gpu_id = gpu_queue.get()
gpu_queue.put(gpu_id)
res = worker(prefix, tile, gradio_url, blender_path, gpu_queue, generated_grid, first_tile_path, tile_mq, task_id, config, seed, verbose=True)
res = worker(prefix, tile, gradio_url, inpainter_type, blender_path, gpu_queue, generated_grid, first_tile_path, tile_mq, task_id, config, seed, verbose=True)
gpu_queue.put(gpu_id)

results.append(res)
Expand Down Expand Up @@ -862,4 +863,4 @@ def queue_available_jobs():

if __name__ == '__main__':
import fire
fire.Fire(main)
fire.Fire(main)