diff --git a/.gitmodules b/.gitmodules
index c6b0a7b..d215578 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -8,3 +8,8 @@
url = https://github.com/TrickyGo/Pano2Room.git
branch = main
shallow = true
+[submodule "thirdparty/sam3d"]
+ path = thirdparty/sam3d
+ url = https://github.com/HochCC/sam-3d-objects.git
+ branch = main
+ shallow = true
diff --git a/README.md b/README.md
index 80902a6..9a14480 100644
--- a/README.md
+++ b/README.md
@@ -37,11 +37,12 @@
```sh
git clone https://github.com/HorizonRobotics/EmbodiedGen.git
cd EmbodiedGen
-git checkout v0.1.6
+git checkout v0.1.7
git submodule update --init --recursive --progress
conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env.
conda activate embodiedgen
-bash install.sh basic
+bash install.sh basic # around 20 mins
+# Optional: `bash install.sh extra` for scene3d-cli
```
### ✅ Starting from Docker
@@ -94,12 +95,14 @@ CUDA_VISIBLE_DEVICES=0 nohup python apps/image_to_3d.py > /dev/null 2>&1 &
### ⚡ API
Generate physically plausible 3D assets from image input via the command-line API.
```sh
-img3d-cli --image_path apps/assets/example_image/sample_00.jpg apps/assets/example_image/sample_01.jpg apps/assets/example_image/sample_19.jpg \
+img3d-cli --image_path apps/assets/example_image/sample_00.jpg apps/assets/example_image/sample_01.jpg \
--n_retry 1 --output_root outputs/imageto3d
# See result(.urdf/mesh.obj/mesh.glb/gs.ply) in ${output_root}/sample_xx/result
```
+Either [SAM3D](https://github.com/facebookresearch/sam-3d-objects) or [TRELLIS](https://github.com/microsoft/TRELLIS) can be used as the 3D generation model; edit `IMAGE3D_MODEL` in `embodied_gen/scripts/imageto3d.py` to switch between them.
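+
+A minimal sketch of the switch (`IMAGE3D_MODEL` defaults to `"SAM3D"` in this release):
+```python
+# embodied_gen/scripts/imageto3d.py
+IMAGE3D_MODEL = "SAM3D"  # or "TRELLIS"
+```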
+
---
@@ -133,7 +136,7 @@ text3d-cli --prompts "small bronze figurine of a lion" "A globe with wooden base
Text-to-image model based on the Kolors model.
```sh
bash embodied_gen/scripts/textto3d.sh \
- --prompts "small bronze figurine of a lion" "A globe with wooden base and latitude and longitude lines" "橙色电动手钻,有磨损细节" \
+ --prompts "A globe with wooden base and latitude and longitude lines" "橙色电动手钻,有磨损细节" \
--output_root outputs/textto3d_k
```
ps: models with more permissive licenses found in `embodied_gen/models/image_comm_model.py`
@@ -191,7 +194,11 @@ CUDA_VISIBLE_DEVICES=0 scene3d-cli \
⚙️ Articulated Object Generation
-🚧 *Coming Soon*
+See our NeurIPS 2025 paper:
+[[Arxiv Paper]](https://arxiv.org/abs/2505.20460) |
+[[Gradio Demo]](https://huggingface.co/spaces/HorizonRobotics/DIPO) |
+[[Code]](https://github.com/RQ-Wu/DIPO)
+
@@ -239,6 +246,7 @@ Remove `--insert_robot` if you don't consider the robot pose in layout generatio
CUDA_VISIBLE_DEVICES=0 nohup layout-cli \
--task_descs "apps/assets/example_layout/task_list.txt" \
--bg_list "outputs/bg_scenes/scene_list.txt" \
+--n_image_retry 4 --n_asset_retry 3 --n_pipe_retry 2 \
--output_root "outputs/layouts_gens" --insert_robot > layouts_gens.log &
```
@@ -325,7 +333,7 @@ If you use EmbodiedGen in your research or projects, please cite:
## 🙌 Acknowledgement
EmbodiedGen builds upon the following amazing projects and models:
-🌟 [Trellis](https://github.com/microsoft/TRELLIS) | 🌟 [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0) | 🌟 [Segment Anything](https://github.com/facebookresearch/segment-anything) | 🌟 [Rembg](https://github.com/danielgatis/rembg) | 🌟 [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) | 🌟 [Stable Diffusion x4](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) | 🌟 [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) | 🌟 [Kolors](https://github.com/Kwai-Kolors/Kolors) | 🌟 [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 🌟 [Aesthetic Score](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html) | 🌟 [Pano2Room](https://github.com/TrickyGo/Pano2Room) | 🌟 [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage) | 🌟 [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) | 🌟 [diffusers](https://github.com/huggingface/diffusers) | 🌟 [gsplat](https://github.com/nerfstudio-project/gsplat) | 🌟 [QWEN-2.5VL](https://github.com/QwenLM/Qwen2.5-VL) | 🌟 [GPT4o](https://platform.openai.com/docs/models/gpt-4o) | 🌟 [SD3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) | 🌟 [ManiSkill](https://github.com/haosulab/ManiSkill)
+🌟 [Trellis](https://github.com/microsoft/TRELLIS) | 🌟 [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0) | 🌟 [Segment Anything](https://github.com/facebookresearch/segment-anything) | 🌟 [Rembg](https://github.com/danielgatis/rembg) | 🌟 [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) | 🌟 [Stable Diffusion x4](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) | 🌟 [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) | 🌟 [Kolors](https://github.com/Kwai-Kolors/Kolors) | 🌟 [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 🌟 [Aesthetic Score](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html) | 🌟 [Pano2Room](https://github.com/TrickyGo/Pano2Room) | 🌟 [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage) | 🌟 [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) | 🌟 [diffusers](https://github.com/huggingface/diffusers) | 🌟 [gsplat](https://github.com/nerfstudio-project/gsplat) | 🌟 [QWEN-2.5VL](https://github.com/QwenLM/Qwen2.5-VL) | 🌟 [GPT4o](https://platform.openai.com/docs/models/gpt-4o) | 🌟 [SD3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) | 🌟 [ManiSkill](https://github.com/haosulab/ManiSkill) | 🌟 [SAM3D](https://github.com/facebookresearch/sam-3d-objects)
---
diff --git a/apps/app_style.py b/apps/app_style.py
index a552f9f..313ccd1 100644
--- a/apps/app_style.py
+++ b/apps/app_style.py
@@ -1,10 +1,26 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
from gradio.themes import Soft
from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc
lighting_css = """
"""
diff --git a/apps/common.py b/apps/common.py
index 3fcac18..55e30da 100644
--- a/apps/common.py
+++ b/apps/common.py
@@ -14,6 +14,11 @@
# implied. See the License for the specific language governing
# permissions and limitations under the License.
+import spaces
+from embodied_gen.utils.monkey_patches import monkey_path_trellis
+
+monkey_path_trellis()
+
import gc
import logging
import os
@@ -25,18 +30,16 @@
import cv2
import gradio as gr
import numpy as np
-import spaces
import torch
-import torch.nn.functional as F
import trimesh
-from easydict import EasyDict as edict
from PIL import Image
from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3
from embodied_gen.data.differentiable_render import entrypoint as render_api
-from embodied_gen.data.utils import resize_pil, trellis_preprocess, zip_files
+from embodied_gen.data.utils import trellis_preprocess, zip_files
from embodied_gen.models.delight_model import DelightingModel
from embodied_gen.models.gs_model import GaussianOperator
+from embodied_gen.models.sam3d import Sam3dInference
from embodied_gen.models.segment_model import (
BMGG14Remover,
RembgRemover,
@@ -53,10 +56,11 @@
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.process_media import (
filter_image_small_connected_components,
+ keep_largest_connected_component,
merge_images_video,
)
from embodied_gen.utils.tags import VERSION
-from embodied_gen.utils.trender import render_video
+from embodied_gen.utils.trender import pack_state, render_video, unpack_state
from embodied_gen.validators.quality_checkers import (
BaseChecker,
ImageAestheticChecker,
@@ -69,15 +73,6 @@
current_dir = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_dir, ".."))
from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
-from thirdparty.TRELLIS.trellis.representations import (
- Gaussian,
- MeshExtractResult,
-)
-from thirdparty.TRELLIS.trellis.representations.gaussian.general_utils import (
- build_scaling_rotation,
- inverse_sigmoid,
- strip_symmetric,
-)
from thirdparty.TRELLIS.trellis.utils import postprocessing_utils
logging.basicConfig(
@@ -85,64 +80,24 @@
)
logger = logging.getLogger(__name__)
-
-os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
- "~/.cache/torch_extensions"
-)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
-os.environ["SPCONV_ALGO"] = "native"
MAX_SEED = 100000
-
-def patched_setup_functions(self):
- def inverse_softplus(x):
- return x + torch.log(-torch.expm1(-x))
-
- def build_covariance_from_scaling_rotation(
- scaling, scaling_modifier, rotation
- ):
- L = build_scaling_rotation(scaling_modifier * scaling, rotation)
- actual_covariance = L @ L.transpose(1, 2)
- symm = strip_symmetric(actual_covariance)
- return symm
-
- if self.scaling_activation_type == "exp":
- self.scaling_activation = torch.exp
- self.inverse_scaling_activation = torch.log
- elif self.scaling_activation_type == "softplus":
- self.scaling_activation = F.softplus
- self.inverse_scaling_activation = inverse_softplus
-
- self.covariance_activation = build_covariance_from_scaling_rotation
- self.opacity_activation = torch.sigmoid
- self.inverse_opacity_activation = inverse_sigmoid
- self.rotation_activation = F.normalize
-
- self.scale_bias = self.inverse_scaling_activation(
- torch.tensor(self.scaling_bias)
- ).to(self.device)
- self.rots_bias = torch.zeros((4)).to(self.device)
- self.rots_bias[0] = 1
- self.opacity_bias = self.inverse_opacity_activation(
- torch.tensor(self.opacity_bias)
- ).to(self.device)
-
-
-Gaussian.setup_functions = patched_setup_functions
-
-
# DELIGHT = DelightingModel()
# IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
# IMAGESR_MODEL = ImageStableSR()
-if os.getenv("GRADIO_APP") == "imageto3d":
+if os.getenv("GRADIO_APP").startswith("imageto3d"):
RBG_REMOVER = RembgRemover()
RBG14_REMOVER = BMGG14Remover()
SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
- PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
- "microsoft/TRELLIS-image-large"
- )
- # PIPELINE.cuda()
+ if "sam3d" in os.getenv("GRADIO_APP"):
+ PIPELINE = Sam3dInference()
+ else:
+ PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
+ "microsoft/TRELLIS-image-large"
+ )
+ # PIPELINE.cuda()
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
AESTHETIC_CHECKER = ImageAestheticChecker()
@@ -151,13 +106,16 @@ def build_covariance_from_scaling_rotation(
os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
)
os.makedirs(TMP_DIR, exist_ok=True)
-elif os.getenv("GRADIO_APP") == "textto3d":
+elif os.getenv("GRADIO_APP").startswith("textto3d"):
RBG_REMOVER = RembgRemover()
RBG14_REMOVER = BMGG14Remover()
- PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
- "microsoft/TRELLIS-image-large"
- )
- # PIPELINE.cuda()
+ if "sam3d" in os.getenv("GRADIO_APP"):
+ PIPELINE = Sam3dInference()
+ else:
+ PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
+ "microsoft/TRELLIS-image-large"
+ )
+ # PIPELINE.cuda()
text_model_dir = "weights/Kolors"
PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
@@ -201,18 +159,23 @@ def end_session(req: gr.Request) -> None:
@spaces.GPU
def preprocess_image_fn(
- image: str | np.ndarray | Image.Image, rmbg_tag: str = "rembg"
+ image: str | np.ndarray | Image.Image,
+ rmbg_tag: str = "rembg",
+ preprocess: bool = True,
) -> tuple[Image.Image, Image.Image]:
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)
- image_cache = resize_pil(image.copy(), 1024)
+ image_cache = image.copy() # resize_pil(image.copy(), 1024)
bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
image = bg_remover(image)
- image = trellis_preprocess(image)
+ image = keep_largest_connected_component(image)
+
+ if preprocess:
+ image = trellis_preprocess(image)
return image, image_cache
@@ -264,50 +227,6 @@ def get_cached_image(image_path: str) -> Image.Image:
return Image.open(image_path).resize((512, 512))
-@spaces.GPU
-def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
- return {
- "gaussian": {
- **gs.init_params,
- "_xyz": gs._xyz.cpu().numpy(),
- "_features_dc": gs._features_dc.cpu().numpy(),
- "_scaling": gs._scaling.cpu().numpy(),
- "_rotation": gs._rotation.cpu().numpy(),
- "_opacity": gs._opacity.cpu().numpy(),
- },
- "mesh": {
- "vertices": mesh.vertices.cpu().numpy(),
- "faces": mesh.faces.cpu().numpy(),
- },
- }
-
-
-def unpack_state(state: dict, device: str = "cpu") -> tuple[Gaussian, dict]:
- gs = Gaussian(
- aabb=state["gaussian"]["aabb"],
- sh_degree=state["gaussian"]["sh_degree"],
- mininum_kernel_size=state["gaussian"]["mininum_kernel_size"],
- scaling_bias=state["gaussian"]["scaling_bias"],
- opacity_bias=state["gaussian"]["opacity_bias"],
- scaling_activation=state["gaussian"]["scaling_activation"],
- device=device,
- )
- gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device=device)
- gs._features_dc = torch.tensor(
- state["gaussian"]["_features_dc"], device=device
- )
- gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device=device)
- gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device=device)
- gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device=device)
-
- mesh = edict(
- vertices=torch.tensor(state["mesh"]["vertices"], device=device),
- faces=torch.tensor(state["mesh"]["faces"], device=device),
- )
-
- return gs, mesh
-
-
def get_seed(randomize_seed: bool, seed: int, max_seed: int = MAX_SEED) -> int:
return np.random.randint(0, max_seed) if randomize_seed else seed
@@ -349,11 +268,11 @@ def select_point(
def image_to_3d(
image: Image.Image,
seed: int,
- ss_guidance_strength: float,
ss_sampling_steps: int,
- slat_guidance_strength: float,
slat_sampling_steps: int,
raw_image_cache: Image.Image,
+ ss_guidance_strength: float,
+ slat_guidance_strength: float,
sam_image: Image.Image = None,
is_sam_image: bool = False,
req: gr.Request = None,
@@ -361,39 +280,48 @@ def image_to_3d(
if is_sam_image:
seg_image = filter_image_small_connected_components(sam_image)
seg_image = Image.fromarray(seg_image, mode="RGBA")
- seg_image = trellis_preprocess(seg_image)
else:
seg_image = image
if isinstance(seg_image, np.ndarray):
seg_image = Image.fromarray(seg_image)
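+    # Run whichever 3D generation backend was selected at module init: SAM3D (Sam3dInference) or TRELLIS.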
+ if isinstance(PIPELINE, Sam3dInference):
+ outputs = PIPELINE.run(
+ seg_image,
+ seed=seed,
+ stage1_inference_steps=ss_sampling_steps,
+ stage2_inference_steps=slat_sampling_steps,
+ )
+ else:
+ PIPELINE.cuda()
+ seg_image = trellis_preprocess(seg_image)
+ outputs = PIPELINE.run(
+ seg_image,
+ seed=seed,
+ formats=["gaussian", "mesh"],
+ preprocess_image=False,
+ sparse_structure_sampler_params={
+ "steps": ss_sampling_steps,
+ "cfg_strength": ss_guidance_strength,
+ },
+ slat_sampler_params={
+ "steps": slat_sampling_steps,
+ "cfg_strength": slat_guidance_strength,
+ },
+ )
+        # Move the TRELLIS pipeline back to CPU to free GPU memory.
+        PIPELINE.cpu()
+
+ gs_model = outputs["gaussian"][0]
+ mesh_model = outputs["mesh"][0]
+ color_images = render_video(gs_model, r=1.85)["color"]
+ normal_images = render_video(mesh_model, r=1.85)["normal"]
+
output_root = os.path.join(TMP_DIR, str(req.session_hash))
os.makedirs(output_root, exist_ok=True)
seg_image.save(f"{output_root}/seg_image.png")
raw_image_cache.save(f"{output_root}/raw_image.png")
- PIPELINE.cuda()
- outputs = PIPELINE.run(
- seg_image,
- seed=seed,
- formats=["gaussian", "mesh"],
- preprocess_image=False,
- sparse_structure_sampler_params={
- "steps": ss_sampling_steps,
- "cfg_strength": ss_guidance_strength,
- },
- slat_sampler_params={
- "steps": slat_sampling_steps,
- "cfg_strength": slat_guidance_strength,
- },
- )
- # Set to cpu for memory saving.
- PIPELINE.cpu()
-
- gs_model = outputs["gaussian"][0]
- mesh_model = outputs["mesh"][0]
- color_images = render_video(gs_model)["color"]
- normal_images = render_video(mesh_model)["normal"]
video_path = os.path.join(output_root, "gs_mesh.mp4")
merge_images_video(color_images, normal_images, video_path)
@@ -405,56 +333,13 @@ def image_to_3d(
return state, video_path
-@spaces.GPU
-def extract_3d_representations(
- state: dict, enable_delight: bool, texture_size: int, req: gr.Request
-):
- output_root = TMP_DIR
- output_root = os.path.join(output_root, str(req.session_hash))
- gs_model, mesh_model = unpack_state(state, device="cuda")
-
- mesh = postprocessing_utils.to_glb(
- gs_model,
- mesh_model,
- simplify=0.9,
- texture_size=1024,
- verbose=True,
- )
- filename = "sample"
- gs_path = os.path.join(output_root, f"{filename}_gs.ply")
- gs_model.save_ply(gs_path)
-
- # Rotate mesh and GS by 90 degrees around Z-axis.
- rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
- # Addtional rotation for GS to align mesh.
- gs_rot = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) @ np.array(
- rot_matrix
- )
- pose = GaussianOperator.trans_to_quatpose(gs_rot)
- aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
- GaussianOperator.resave_ply(
- in_ply=gs_path,
- out_ply=aligned_gs_path,
- instance_pose=pose,
- )
-
- mesh.vertices = mesh.vertices @ np.array(rot_matrix)
- mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
- mesh.export(mesh_obj_path)
- mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
- mesh.export(mesh_glb_path)
-
- torch.cuda.empty_cache()
-
- return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
-
-
def extract_3d_representations_v2(
state: dict,
enable_delight: bool,
texture_size: int,
req: gr.Request,
):
+    """Texture back-projection (super-resolution based variant)."""
output_root = TMP_DIR
user_dir = os.path.join(output_root, str(req.session_hash))
gs_model, mesh_model = unpack_state(state, device="cpu")
@@ -521,6 +406,7 @@ def extract_3d_representations_v3(
texture_size: int,
req: gr.Request,
):
+    """Texture back-projection (optimization-based variant)."""
output_root = TMP_DIR
user_dir = os.path.join(output_root, str(req.session_hash))
gs_model, mesh_model = unpack_state(state, device="cpu")
@@ -688,6 +574,7 @@ def text2image_fn(
image_wh: int | tuple[int, int] = [1024, 1024],
rmbg_tag: str = "rembg",
seed: int = None,
+ enable_pre_resize: bool = True,
n_sample: int = 3,
req: gr.Request = None,
):
@@ -715,7 +602,9 @@ def text2image_fn(
for idx in range(len(images)):
image = images[idx]
- images[idx], _ = preprocess_image_fn(image, rmbg_tag)
+ images[idx], _ = preprocess_image_fn(
+ image, rmbg_tag, enable_pre_resize
+ )
save_paths = []
for idx, image in enumerate(images):
@@ -841,6 +730,7 @@ def backproject_texture_v2(
texture_size: int,
enable_delight: bool = True,
fix_mesh: bool = False,
+ no_mesh_post_process: bool = False,
uuid: str = "sample",
req: gr.Request = None,
) -> str:
@@ -857,6 +747,7 @@ def backproject_texture_v2(
skip_fix_mesh=not fix_mesh,
delight=enable_delight,
texture_wh=[texture_size, texture_size],
+ no_mesh_post_process=no_mesh_post_process,
)
output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
diff --git a/apps/image_to_3d.py b/apps/image_to_3d.py
index d8c1681..5bf5fc9 100644
--- a/apps/image_to_3d.py
+++ b/apps/image_to_3d.py
@@ -17,7 +17,9 @@
import os
-os.environ["GRADIO_APP"] = "imageto3d"
+# GRADIO_APP == "imageto3d_sam3d": SAM3D object model (default).
+# GRADIO_APP == "imageto3d": TRELLIS model.
+os.environ["GRADIO_APP"] = "imageto3d_sam3d"
from glob import glob
import gradio as gr
@@ -37,6 +39,16 @@
start_session,
)
+app_name = os.getenv("GRADIO_APP")
+if app_name == "imageto3d_sam3d":
+ enable_pre_resize = False
+ sample_step = 25
+ bg_rm_model_name = "rembg" # "rembg", "rmbg14"
+elif app_name == "imageto3d":
+ enable_pre_resize = True
+ sample_step = 12
+ bg_rm_model_name = "rembg" # "rembg", "rmbg14"
+
with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
gr.HTML(image_css, visible=False)
gr.HTML(lighting_css, visible=False)
@@ -67,7 +79,7 @@
)
with gr.Row():
- with gr.Column(scale=2):
+ with gr.Column(scale=3):
with gr.Tabs() as input_tabs:
with gr.Tab(
label="Image(auto seg)", id=0
@@ -142,7 +154,7 @@
)
rmbg_tag = gr.Radio(
choices=["rembg", "rmbg14"],
- value="rembg",
+ value=bg_rm_model_name,
label="Background Removal Model",
)
with gr.Row():
@@ -163,7 +175,11 @@
step=0.1,
)
ss_sampling_steps = gr.Slider(
- 1, 50, label="Sampling Steps", value=12, step=1
+ 1,
+ 50,
+ label="Sampling Steps",
+ value=sample_step,
+ step=1,
)
gr.Markdown("Visual Appearance Generation")
with gr.Row():
@@ -175,7 +191,11 @@
step=0.1,
)
slat_sampling_steps = gr.Slider(
- 1, 50, label="Sampling Steps", value=12, step=1
+ 1,
+ 50,
+ label="Sampling Steps",
+ value=sample_step,
+ step=1,
)
generate_btn = gr.Button(
@@ -242,7 +262,7 @@
has quality inspection, open with an editor to view details.
"""
)
-
+ enable_pre_resize = gr.State(enable_pre_resize)
with gr.Row() as single_image_example:
examples = gr.Examples(
label="Image Gallery",
@@ -252,7 +272,7 @@
glob("apps/assets/example_image/*")
)
],
- inputs=[image_prompt, rmbg_tag],
+ inputs=[image_prompt, rmbg_tag, enable_pre_resize],
fn=preprocess_image_fn,
outputs=[image_prompt, raw_image_cache],
run_on_click=True,
@@ -274,16 +294,16 @@
run_on_click=True,
examples_per_page=10,
)
- with gr.Column(scale=1):
+ with gr.Column(scale=2):
gr.Markdown("
")
video_output = gr.Video(
label="Generated 3D Asset",
autoplay=True,
loop=True,
- height=300,
+ height=400,
)
model_output_gs = gr.Model3D(
- label="Gaussian Representation", height=300, interactive=False
+ label="Gaussian Representation", height=350, interactive=False
)
aligned_gs = gr.Textbox(visible=False)
gr.Markdown(
@@ -292,9 +312,9 @@
with gr.Row():
model_output_mesh = gr.Model3D(
label="Mesh Representation",
- height=300,
+ height=350,
interactive=False,
- clear_color=[0.8, 0.8, 0.8, 1],
+ clear_color=[0, 0, 0, 1],
elem_id="lighter_mesh",
)
@@ -320,7 +340,7 @@
image_prompt.upload(
preprocess_image_fn,
- inputs=[image_prompt, rmbg_tag],
+ inputs=[image_prompt, rmbg_tag, enable_pre_resize],
outputs=[image_prompt, raw_image_cache],
)
image_prompt.change(
@@ -437,11 +457,11 @@
inputs=[
image_prompt,
seed,
- ss_guidance_strength,
ss_sampling_steps,
- slat_guidance_strength,
slat_sampling_steps,
raw_image_cache,
+ ss_guidance_strength,
+ slat_guidance_strength,
image_seg_sam,
is_samimage,
],
diff --git a/apps/text_to_3d.py b/apps/text_to_3d.py
index 8c9012c..e5a176d 100644
--- a/apps/text_to_3d.py
+++ b/apps/text_to_3d.py
@@ -17,8 +17,9 @@
import os
-os.environ["GRADIO_APP"] = "textto3d"
-
+# GRADIO_APP == "textto3d_sam3d": SAM3D object model (default).
+# GRADIO_APP == "textto3d": TRELLIS model.
+os.environ["GRADIO_APP"] = "textto3d_sam3d"
import gradio as gr
from app_style import custom_theme, image_css, lighting_css
@@ -37,6 +38,14 @@
text2image_fn,
)
+app_name = os.getenv("GRADIO_APP")
+if app_name == "textto3d_sam3d":
+ enable_pre_resize = False
+ sample_step = 25
+elif app_name == "textto3d":
+ enable_pre_resize = True
+ sample_step = 12
+
with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
gr.HTML(image_css, visible=False)
gr.HTML(lighting_css, visible=False)
@@ -101,11 +110,11 @@
)
rmbg_tag = gr.Radio(
choices=["rembg", "rmbg14"],
- value="rembg",
+ value="rmbg14",
label="Background Removal Model",
)
ip_adapt_scale = gr.Slider(
- 0, 1, label="IP-adapter Scale", value=0.3, step=0.05
+ 0, 1, label="IP-adapter Scale", value=0.7, step=0.05
)
img_guidance_scale = gr.Slider(
1, 30, label="Text Guidance Scale", value=12, step=0.2
@@ -162,7 +171,11 @@
step=0.1,
)
ss_sampling_steps = gr.Slider(
- 1, 50, label="Sampling Steps", value=12, step=1
+ 1,
+ 50,
+ label="Sampling Steps",
+ value=sample_step,
+ step=1,
)
gr.Markdown("Visual Appearance Generation")
with gr.Row():
@@ -174,7 +187,11 @@
step=0.1,
)
slat_sampling_steps = gr.Slider(
- 1, 50, label="Sampling Steps", value=12, step=1
+ 1,
+ 50,
+ label="Sampling Steps",
+ value=sample_step,
+ step=1,
)
generate_btn = gr.Button(
@@ -265,7 +282,7 @@
visible=False,
)
gr.Markdown(
- "Generated image may be poor quality due to auto seg."
+ "Generated images may be of poor quality due to auto segmentation. "
"Retry by adjusting text prompt, seed or switch seg model in `Image Gen Settings`."
)
with gr.Row():
@@ -285,7 +302,7 @@
model_output_mesh = gr.Model3D(
label="Mesh Representation",
- clear_color=[0.8, 0.8, 0.8, 1],
+ clear_color=[0, 0, 0, 1],
height=300,
interactive=False,
elem_id="lighter_mesh",
@@ -323,6 +340,7 @@
)
output_buf = gr.State()
+ enable_pre_resize = gr.State(enable_pre_resize)
demo.load(start_session)
demo.unload(end_session)
@@ -389,6 +407,7 @@
img_resolution,
rmbg_tag,
seed,
+ enable_pre_resize,
],
outputs=[
image_sample1,
@@ -420,11 +439,11 @@
inputs=[
select_img,
seed,
- ss_guidance_strength,
ss_sampling_steps,
- slat_guidance_strength,
slat_sampling_steps,
raw_image_cache,
+ ss_guidance_strength,
+ slat_guidance_strength,
],
outputs=[output_buf, video_output],
).success(
diff --git a/apps/texture_edit.py b/apps/texture_edit.py
index 722ce6c..01afe13 100644
--- a/apps/texture_edit.py
+++ b/apps/texture_edit.py
@@ -267,7 +267,7 @@ def active_btn_by_content(mesh_content: gr.Model3D, text_content: gr.Textbox):
demo.load(start_session)
demo.unload(end_session)
-
+ no_mesh_post_process = gr.State(True)
mesh_input.change(
lambda: tuple(
[
@@ -368,6 +368,7 @@ def active_btn_by_content(mesh_content: gr.Model3D, text_content: gr.Textbox):
texture_size,
project_delight,
fix_mesh,
+ no_mesh_post_process,
],
outputs=[mesh_output, mesh_outpath, download_btn],
).success(
diff --git a/apps/visualize_asset.py b/apps/visualize_asset.py
index 5e9b94b..233df4a 100644
--- a/apps/visualize_asset.py
+++ b/apps/visualize_asset.py
@@ -27,7 +27,6 @@
import uuid
import xml.etree.ElementTree as ET
from pathlib import Path
-from typing import Any, Dict, Tuple
import gradio as gr
import pandas as pd
@@ -255,8 +254,7 @@ def search_assets(query: str, top_k: int):
return items, gr.update(interactive=True), top_assets
-# --- Mesh extraction ---
-def _extract_mesh_paths(row) -> Tuple[str | None, str | None, str]:
+def _extract_mesh_paths(row) -> tuple[str | None, str | None, str]:
desc = row["description"]
urdf_path = os.path.join(DATA_ROOT, row["urdf_path"])
asset_dir = os.path.join(DATA_ROOT, row["asset_dir"])
diff --git a/docs/acknowledgement.md b/docs/acknowledgement.md
index d588194..b69fdae 100644
--- a/docs/acknowledgement.md
+++ b/docs/acknowledgement.md
@@ -1,7 +1,7 @@
# 🙌 Acknowledgement
EmbodiedGen builds upon the following amazing projects and models:
-🌟 [Trellis](https://github.com/microsoft/TRELLIS) | 🌟 [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0) | 🌟 [Segment Anything](https://github.com/facebookresearch/segment-anything) | 🌟 [Rembg](https://github.com/danielgatis/rembg) | 🌟 [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) | 🌟 [Stable Diffusion x4](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) | 🌟 [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) | 🌟 [Kolors](https://github.com/Kwai-Kolors/Kolors) | 🌟 [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 🌟 [Aesthetic Score](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html) | 🌟 [Pano2Room](https://github.com/TrickyGo/Pano2Room) | 🌟 [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage) | 🌟 [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) | 🌟 [diffusers](https://github.com/huggingface/diffusers) | 🌟 [gsplat](https://github.com/nerfstudio-project/gsplat) | 🌟 [QWEN-2.5VL](https://github.com/QwenLM/Qwen2.5-VL) | 🌟 [GPT4o](https://platform.openai.com/docs/models/gpt-4o) | 🌟 [SD3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) | 🌟 [ManiSkill](https://github.com/haosulab/ManiSkill)
+🌟 [Trellis](https://github.com/microsoft/TRELLIS) | 🌟 [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0) | 🌟 [Segment Anything](https://github.com/facebookresearch/segment-anything) | 🌟 [Rembg](https://github.com/danielgatis/rembg) | 🌟 [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) | 🌟 [Stable Diffusion x4](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) | 🌟 [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) | 🌟 [Kolors](https://github.com/Kwai-Kolors/Kolors) | 🌟 [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 🌟 [Aesthetic Score](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html) | 🌟 [Pano2Room](https://github.com/TrickyGo/Pano2Room) | 🌟 [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage) | 🌟 [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) | 🌟 [diffusers](https://github.com/huggingface/diffusers) | 🌟 [gsplat](https://github.com/nerfstudio-project/gsplat) | 🌟 [QWEN-2.5VL](https://github.com/QwenLM/Qwen2.5-VL) | 🌟 [GPT4o](https://platform.openai.com/docs/models/gpt-4o) | 🌟 [SD3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) | 🌟 [ManiSkill](https://github.com/haosulab/ManiSkill) | 🌟 [SAM3D](https://github.com/facebookresearch/sam-3d-objects)
---
diff --git a/docs/install.md b/docs/install.md
index cf01f06..e3c3534 100644
--- a/docs/install.md
+++ b/docs/install.md
@@ -7,11 +7,12 @@ hide:
```sh
git clone https://github.com/HorizonRobotics/EmbodiedGen.git
cd EmbodiedGen
-git checkout v0.1.6
+git checkout v0.1.7
git submodule update --init --recursive --progress
conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env.
conda activate embodiedgen
-bash install.sh basic
+bash install.sh basic # around 20 mins
+# Optional: `bash install.sh extra` for scene3d-cli
```
Please `huggingface-cli login` to ensure that the ckpts can be downloaded automatically afterwards.
diff --git a/docs/services/image_to_3d.md b/docs/services/image_to_3d.md
index 4e1a1c8..674d243 100644
--- a/docs/services/image_to_3d.md
+++ b/docs/services/image_to_3d.md
@@ -67,6 +67,8 @@ python apps/image_to_3d.py
CUDA_VISIBLE_DEVICES=0 nohup python apps/image_to_3d.py > /dev/null 2>&1 &
```
+Either [SAM3D](https://github.com/facebookresearch/sam-3d-objects) or [TRELLIS](https://github.com/microsoft/TRELLIS) can be used as the 3D generation model; edit `GRADIO_APP` in `apps/image_to_3d.py` to switch between them.
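+
+A minimal sketch of the switch (`"imageto3d_sam3d"` is the default in this release):
+```python
+# Near the top of apps/image_to_3d.py
+os.environ["GRADIO_APP"] = "imageto3d_sam3d"  # or "imageto3d" for TRELLIS
+```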
+
---
!!! tip "Getting Started"
diff --git a/docs/tutorials/image_to_3d.md b/docs/tutorials/image_to_3d.md
index 2f7178a..705ed35 100644
--- a/docs/tutorials/image_to_3d.md
+++ b/docs/tutorials/image_to_3d.md
@@ -5,10 +5,11 @@ Generate **physically plausible 3D assets** from a single input image, supportin
---
## ⚡ Command-Line Usage
+Either [SAM3D](https://github.com/facebookresearch/sam-3d-objects) or [TRELLIS](https://github.com/microsoft/TRELLIS) can be used as the 3D generation model; edit `IMAGE3D_MODEL` in `embodied_gen/scripts/imageto3d.py` to switch between them.
```bash
img3d-cli --image_path apps/assets/example_image/sample_00.jpg \
-apps/assets/example_image/sample_01.jpg apps/assets/example_image/sample_19.jpg \
+apps/assets/example_image/sample_01.jpg \
--n_retry 1 --output_root outputs/imageto3d
```
diff --git a/docs/tutorials/layout_gen.md b/docs/tutorials/layout_gen.md
index 3109ff1..6eeb642 100644
--- a/docs/tutorials/layout_gen.md
+++ b/docs/tutorials/layout_gen.md
@@ -60,6 +60,7 @@ You can also run multiple tasks via a task list file in the backend.
CUDA_VISIBLE_DEVICES=0 nohup layout-cli \
--task_descs "apps/assets/example_layout/task_list.txt" \
--bg_list "outputs/bg_scenes/scene_list.txt" \
+ --n_image_retry 4 --n_asset_retry 3 --n_pipe_retry 2 \
--output_root "outputs/layouts_gens" \
--insert_robot > layouts_gens.log &
```
diff --git a/docs/tutorials/text_to_3d.md b/docs/tutorials/text_to_3d.md
index 3a81366..0c4b0dc 100644
--- a/docs/tutorials/text_to_3d.md
+++ b/docs/tutorials/text_to_3d.md
@@ -74,8 +74,8 @@ You will get the following results:
Kolors Model CLI (Supports Chinese & English Prompts):
```bash
bash embodied_gen/scripts/textto3d.sh \
- --prompts "small bronze figurine of a lion" "A globe with wooden base and latitude and longitude lines" "橙色电动手钻,有磨损细节" \
- --output_root outputs/textto3d_k
+ --prompts "A globe with wooden base and latitude and longitude lines" "橙色电动手钻,有磨损细节" \
+ --output_root outputs/textto3d_k
```
> Models with more permissive licenses can be found in `embodied_gen/models/image_comm_model.py`.
diff --git a/embodied_gen/data/asset_converter.py b/embodied_gen/data/asset_converter.py
index 71ef27e..7b09b70 100644
--- a/embodied_gen/data/asset_converter.py
+++ b/embodied_gen/data/asset_converter.py
@@ -1,3 +1,20 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
from __future__ import annotations
import logging
diff --git a/embodied_gen/data/backproject_v2.py b/embodied_gen/data/backproject_v2.py
index 5908013..92c0a76 100644
--- a/embodied_gen/data/backproject_v2.py
+++ b/embodied_gen/data/backproject_v2.py
@@ -274,6 +274,7 @@ class TextureBacker:
mask_thresh (float, optional): Threshold for visibility masks.
smooth_texture (bool, optional): Apply post-processing to texture.
inpaint_smooth (bool, optional): Apply inpainting smoothing.
+        mesh_post_process (bool, optional): Whether to let trimesh post-process the mesh; set False to keep the vertices unmodified.
Example:
```py
@@ -308,6 +309,7 @@ def __init__(
mask_thresh: float = 0.5,
smooth_texture: bool = True,
inpaint_smooth: bool = False,
+ mesh_post_process: bool = True,
) -> None:
self.camera_params = camera_params
self.renderer = None
@@ -318,6 +320,7 @@ def __init__(
self.mask_thresh = mask_thresh
self.smooth_texture = smooth_texture
self.inpaint_smooth = inpaint_smooth
+ self.mesh_post_process = mesh_post_process
self.bake_angle_thresh = bake_angle_thresh
self.bake_unreliable_kernel_size = int(
@@ -668,7 +671,12 @@ def __call__(
mesh, self.scale, self.center
)
textured_mesh = save_mesh_with_mtl(
- vertices, faces, uv_map, texture_np, output_path
+ vertices,
+ faces,
+ uv_map,
+ texture_np,
+ output_path,
+ mesh_process=self.mesh_post_process,
)
return textured_mesh
@@ -766,6 +774,7 @@ def parse_args():
help="Disable saving delight image",
)
parser.add_argument("--n_max_faces", type=int, default=30000)
+ parser.add_argument("--no_mesh_post_process", action="store_true")
args, unknown = parser.parse_known_args()
return args
@@ -856,6 +865,7 @@ def entrypoint(
render_wh=args.resolution_hw,
texture_wh=args.texture_wh,
smooth_texture=not args.no_smooth_texture,
+ mesh_post_process=not args.no_mesh_post_process,
)
textured_mesh = texture_backer(multiviews, mesh, args.output_path)
diff --git a/embodied_gen/data/backproject_v3.py b/embodied_gen/data/backproject_v3.py
index b22b497..81cea59 100644
--- a/embodied_gen/data/backproject_v3.py
+++ b/embodied_gen/data/backproject_v3.py
@@ -353,8 +353,8 @@ def parse_args():
parser.add_argument(
"--distance",
type=float,
- default=5,
- help="Camera distance (default: 5)",
+ default=4.5,
+ help="Camera distance (default: 4.5)",
)
parser.add_argument(
"--resolution_hw",
@@ -400,8 +400,8 @@ def parse_args():
parser.add_argument(
"--mesh_sipmlify_ratio",
type=float,
- default=0.9,
- help="Mesh simplification ratio (default: 0.9)",
+ default=0.85,
+ help="Mesh simplification ratio (default: 0.85)",
)
parser.add_argument(
"--delight", action="store_true", help="Use delighting model."
@@ -500,7 +500,7 @@ def entrypoint(
faces = mesh.faces.astype(np.int32)
vertices = vertices.astype(np.float32)
- if not args.skip_fix_mesh and len(faces) > 10 * args.n_max_faces:
+ if not args.skip_fix_mesh:
mesh_fixer = MeshFixer(vertices, faces, args.device)
vertices, faces = mesh_fixer(
filter_ratio=args.mesh_sipmlify_ratio,
@@ -512,7 +512,7 @@ def entrypoint(
if len(faces) > args.n_max_faces:
mesh_fixer = MeshFixer(vertices, faces, args.device)
vertices, faces = mesh_fixer(
- filter_ratio=max(0.05, args.mesh_sipmlify_ratio - 0.2),
+ filter_ratio=max(0.1, args.mesh_sipmlify_ratio - 0.1),
max_hole_size=0.04,
resolution=1024,
num_views=1000,
diff --git a/embodied_gen/data/utils.py b/embodied_gen/data/utils.py
index fa2f7d5..74f96c6 100644
--- a/embodied_gen/data/utils.py
+++ b/embodied_gen/data/utils.py
@@ -15,10 +15,13 @@
# permissions and limitations under the License.
+import logging
import math
import os
-import random
+import time
import zipfile
+from contextlib import contextmanager
+from dataclasses import dataclass, field
from shutil import rmtree
from typing import List, Tuple, Union
@@ -28,20 +31,9 @@
import nvdiffrast.torch as dr
import torch
import torch.nn.functional as F
-from PIL import Image, ImageEnhance
-
-try:
- from kolors.models.modeling_chatglm import ChatGLMModel
- from kolors.models.tokenization_chatglm import ChatGLMTokenizer
-except ImportError:
- ChatGLMTokenizer = None
- ChatGLMModel = None
-import logging
-from dataclasses import dataclass, field
-
import trimesh
from kaolin.render.camera import Camera
-from torch import nn
+from PIL import Image, ImageEnhance
logger = logging.getLogger(__name__)
@@ -50,10 +42,8 @@
"DiffrastRender",
"save_images",
"render_pbr",
- "prelabel_text_feature",
"calc_vertex_normals",
"normalize_vertices_array",
- "load_mesh_to_unit_cube",
"as_list",
"CameraSetting",
"import_kaolin_mesh",
@@ -67,6 +57,7 @@
"trellis_preprocess",
"delete_dir",
"kaolin_to_opencv_view",
+ "model_device_ctx",
]
@@ -520,114 +511,6 @@ def render_pbr(
return image, albedo, diffuse, normal
-def _move_to_target_device(data, device: str):
- if isinstance(data, dict):
- for key, value in data.items():
- data[key] = _move_to_target_device(value, device)
- elif isinstance(data, torch.Tensor):
- return data.to(device)
-
- return data
-
-
-def _encode_prompt(
- prompt_batch,
- text_encoders,
- tokenizers,
- proportion_empty_prompts=0,
- is_train=True,
-):
- prompt_embeds_list = []
-
- captions = []
- for caption in prompt_batch:
- if random.random() < proportion_empty_prompts:
- captions.append("")
- elif isinstance(caption, str):
- captions.append(caption)
- elif isinstance(caption, (list, np.ndarray)):
- captions.append(random.choice(caption) if is_train else caption[0])
-
- with torch.no_grad():
- for tokenizer, text_encoder in zip(tokenizers, text_encoders):
- text_inputs = tokenizer(
- captions,
- padding="max_length",
- max_length=256,
- truncation=True,
- return_tensors="pt",
- ).to(text_encoder.device)
-
- output = text_encoder(
- input_ids=text_inputs.input_ids,
- attention_mask=text_inputs.attention_mask,
- position_ids=text_inputs.position_ids,
- output_hidden_states=True,
- )
-
- # We are only interested in the pooled output of the text encoder.
- prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
- pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()
- bs_embed, seq_len, _ = prompt_embeds.shape
- prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
- prompt_embeds_list.append(prompt_embeds)
-
- prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
- pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
-
- return prompt_embeds, pooled_prompt_embeds
-
-
-def load_llm_models(pretrained_model_name_or_path: str, device: str):
- tokenizer = ChatGLMTokenizer.from_pretrained(
- pretrained_model_name_or_path,
- subfolder="text_encoder",
- )
- text_encoder = ChatGLMModel.from_pretrained(
- pretrained_model_name_or_path,
- subfolder="text_encoder",
- ).to(device)
-
- text_encoders = [
- text_encoder,
- ]
- tokenizers = [
- tokenizer,
- ]
-
- logger.info(f"Load model from {pretrained_model_name_or_path} done.")
-
- return tokenizers, text_encoders
-
-
-def prelabel_text_feature(
- prompt_batch: List[str],
- output_dir: str,
- tokenizers: nn.Module,
- text_encoders: nn.Module,
-) -> List[str]:
- os.makedirs(output_dir, exist_ok=True)
-
- # prompt_batch ["text..."]
- prompt_embeds, pooled_prompt_embeds = _encode_prompt(
- prompt_batch, text_encoders, tokenizers
- )
-
- prompt_embeds = _move_to_target_device(prompt_embeds, device="cpu")
- pooled_prompt_embeds = _move_to_target_device(
- pooled_prompt_embeds, device="cpu"
- )
-
- data_dict = dict(
- prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds
- )
-
- save_path = os.path.join(output_dir, "text_feat.pth")
- torch.save(data_dict, save_path)
-
- return save_path
-
-
def _calc_face_normals(
vertices: torch.Tensor, # V,3 first vertex may be unreferenced
faces: torch.Tensor, # F,3 long, first face may be all zero
@@ -683,25 +566,6 @@ def normalize_vertices_array(
return vertices, scale, center
-def load_mesh_to_unit_cube(
- mesh_file: str,
- mesh_scale: float = 1.0,
-) -> tuple[trimesh.Trimesh, float, list[float]]:
- if not os.path.exists(mesh_file):
- raise FileNotFoundError(f"mesh_file path {mesh_file} not exists.")
-
- mesh = trimesh.load(mesh_file)
- if isinstance(mesh, trimesh.Scene):
- mesh = trimesh.utils.concatenate(mesh)
-
- vertices, scale, center = normalize_vertices_array(
- mesh.vertices, mesh_scale
- )
- mesh.vertices = vertices
-
- return mesh, scale, center
-
-
def as_list(obj):
if isinstance(obj, (list, tuple)):
return obj
@@ -862,6 +726,7 @@ def save_mesh_with_mtl(
texture: Union[Image.Image, np.ndarray],
output_path: str,
material_base=(250, 250, 250, 255),
+ mesh_process: bool = True,
) -> trimesh.Trimesh:
if isinstance(texture, np.ndarray):
texture = Image.fromarray(texture)
@@ -870,6 +735,7 @@ def save_mesh_with_mtl(
vertices,
faces,
visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture),
+        process=mesh_process,  # process=False prevents trimesh from modifying the vertices.
)
mesh.visual.material = trimesh.visual.material.SimpleMaterial(
image=texture,
@@ -998,8 +864,9 @@ def gamma_shs(shs: torch.Tensor, gamma: float) -> torch.Tensor:
def resize_pil(image: Image.Image, max_size: int = 1024) -> Image.Image:
- max_size = max(image.size)
- scale = min(1, 1024 / max_size)
+ current_max_dim = max(image.size)
+ scale = min(1, max_size / current_max_dim)
+
if scale < 1:
new_size = (int(image.width * scale), int(image.height * scale))
image = image.resize(new_size, Image.Resampling.LANCZOS)
@@ -1068,3 +935,34 @@ def delete_dir(folder_path: str, keep_subs: list[str] = None) -> None:
rmtree(item_path)
else:
os.remove(item_path)
+
+
+@contextmanager
+def model_device_ctx(
+ *models,
+ src_device: str = "cpu",
+ dst_device: str = "cuda",
+ verbose: bool = False,
+):
+ start = time.perf_counter()
+ for m in models:
+ if m is None:
+ continue
+ m.to(dst_device)
+ to_cuda_time = time.perf_counter() - start
+
+ try:
+ yield
+ finally:
+ start = time.perf_counter()
+ for m in models:
+ if m is None:
+ continue
+ m.to(src_device)
+ to_cpu_time = time.perf_counter() - start
+
+ if verbose:
+ model_names = [m.__class__.__name__ for m in models]
+ logger.debug(
+ f"[model_device_ctx] {model_names} to cuda: {to_cuda_time:.1f}s, to cpu: {to_cpu_time:.1f}s"
+ )
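+
+
+# Illustrative usage (a sketch; `delight_model` and `sr_model` are placeholder names):
+#
+#   with model_device_ctx(delight_model, sr_model, verbose=True):
+#       ...  # run GPU-heavy inference; the models are moved back to `src_device` on exit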
diff --git a/embodied_gen/models/sam3d.py b/embodied_gen/models/sam3d.py
new file mode 100644
index 0000000..4b28e40
--- /dev/null
+++ b/embodied_gen/models/sam3d.py
@@ -0,0 +1,152 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+from embodied_gen.utils.monkey_patches import monkey_patch_sam3d
+
+monkey_patch_sam3d()
+import os
+import sys
+
+import numpy as np
+from hydra.utils import instantiate
+from modelscope import snapshot_download
+from omegaconf import OmegaConf
+from PIL import Image
+
+current_file_path = os.path.abspath(__file__)
+current_dir = os.path.dirname(current_file_path)
+sys.path.append(os.path.join(current_dir, "../.."))
+from loguru import logger
+from thirdparty.sam3d.sam3d_objects.pipeline.inference_pipeline_pointmap import (
+ InferencePipelinePointMap,
+)
+
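+# Suppress loguru logging emitted by the SAM-3D-Objects dependency.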
+logger.remove()
+logger.add(lambda _: None, level="ERROR")
+
+
+__all__ = ["Sam3dInference"]
+
+
+class Sam3dInference:
+ """Wrapper for the SAM-3D-Objects inference pipeline.
+
+ This class handles loading the SAM-3D-Objects model, configuring it for inference,
+ and running the pipeline on input images (optionally with masks and pointmaps).
+ It supports distillation options and inference step customization.
+
+ Args:
+ local_dir (str): Directory to store or load model weights and configs.
+ compile (bool): Whether to compile the model for faster inference.
+
+ Methods:
+ merge_mask_to_rgba(image, mask):
+ Merges a binary mask into the alpha channel of an RGB image.
+
+ run(image, mask=None, seed=None, pointmap=None, use_stage1_distillation=False,
+ use_stage2_distillation=False, stage1_inference_steps=25, stage2_inference_steps=25):
+ Runs the inference pipeline and returns the output dictionary.
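+
+    Example (a minimal sketch, mirroring the ``__main__`` demo below; ``image`` is a background-removed PIL image):
+        >>> pipeline = Sam3dInference()
+        >>> outputs = pipeline.run(image, seed=42)
+        >>> outputs["gs"].save_ply("outputs/splat.ply")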
+ """
+
+ def __init__(
+ self, local_dir: str = "weights/sam-3d-objects", compile: bool = False
+ ) -> None:
+ if not os.path.exists(local_dir):
+ snapshot_download("facebook/sam-3d-objects", local_dir=local_dir)
+ config_file = os.path.join(local_dir, "checkpoints/pipeline.yaml")
+ config = OmegaConf.load(config_file)
+ config.rendering_engine = "nvdiffrast"
+ config.compile_model = compile
+ config.workspace_dir = os.path.dirname(config_file)
+        # Generate 4 Gaussians per pixel instead of 32 for more efficient storage.
+ config["slat_decoder_gs_config_path"] = config.pop(
+ "slat_decoder_gs_4_config_path", "slat_decoder_gs_4.yaml"
+ )
+ config["slat_decoder_gs_ckpt_path"] = config.pop(
+ "slat_decoder_gs_4_ckpt_path", "slat_decoder_gs_4.ckpt"
+ )
+ self.pipeline: InferencePipelinePointMap = instantiate(config)
+
+ def merge_mask_to_rgba(
+ self, image: np.ndarray, mask: np.ndarray
+ ) -> np.ndarray:
+ mask = mask.astype(np.uint8) * 255
+ mask = mask[..., None]
+ rgba_image = np.concatenate([image[..., :3], mask], axis=-1)
+
+ return rgba_image
+
+ def run(
+ self,
+ image: np.ndarray | Image.Image,
+ mask: np.ndarray = None,
+ seed: int = None,
+ pointmap: np.ndarray = None,
+ use_stage1_distillation: bool = False,
+ use_stage2_distillation: bool = False,
+ stage1_inference_steps: int = 25,
+ stage2_inference_steps: int = 25,
+ ) -> dict:
+ if isinstance(image, Image.Image):
+ image = np.array(image)
+ if mask is not None:
+ image = self.merge_mask_to_rgba(image, mask)
+ return self.pipeline.run(
+ image,
+ None,
+ seed,
+ stage1_only=False,
+ with_mesh_postprocess=False,
+ with_texture_baking=False,
+ with_layout_postprocess=False,
+ use_vertex_color=True,
+ use_stage1_distillation=use_stage1_distillation,
+ use_stage2_distillation=use_stage2_distillation,
+ stage1_inference_steps=stage1_inference_steps,
+ stage2_inference_steps=stage2_inference_steps,
+ pointmap=pointmap,
+ )
+
+
+if __name__ == "__main__":
+ pipeline = Sam3dInference()
+
+ from time import time
+
+ import torch
+ from embodied_gen.models.segment_model import RembgRemover
+
+ input_image = "apps/assets/example_image/sample_00.jpg"
+ output_gs = "outputs/splat.ply"
+ remover = RembgRemover()
+ clean_image = remover(input_image)
+
+ if torch.cuda.is_available():
+ torch.cuda.reset_peak_memory_stats()
+ torch.cuda.empty_cache()
+
+ start = time()
+ output = pipeline.run(clean_image, seed=42)
+    print(f"Running cost: {round(time() - start, 1)}s")
+
+ if torch.cuda.is_available():
+ max_memory = torch.cuda.max_memory_allocated() / (1024**3)
+ print(f"(Max VRAM): {max_memory:.2f} GB")
+
+    print(f"VRAM after inference: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
+
+ output["gs"].save_ply(output_gs)
+ print(f"Saved to {output_gs}")
diff --git a/embodied_gen/models/segment_model.py b/embodied_gen/models/segment_model.py
index 6f54cfa..a82afa6 100644
--- a/embodied_gen/models/segment_model.py
+++ b/embodied_gen/models/segment_model.py
@@ -43,6 +43,7 @@
"SAMRemover",
"SAMPredictor",
"RembgRemover",
+ "BMGG14Remover",
"get_segmented_image_by_agent",
]
@@ -376,7 +377,7 @@ def __init__(self) -> None:
def __call__(
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
- ):
+ ) -> Image.Image:
"""Removes background from an image.
Args:
@@ -496,13 +497,18 @@ def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
# input_image = "outputs/text2image/tmp/bucket.jpeg"
# output_image = "outputs/text2image/tmp/bucket_seg.png"
- remover = SAMRemover(model_type="vit_h")
- remover = RembgRemover()
- clean_image = remover(input_image)
- clean_image.save(output_image)
- get_segmented_image_by_agent(
- Image.open(input_image), remover, remover, None, "./test_seg.png"
- )
+ # remover = SAMRemover(model_type="vit_h")
+ # remover = RembgRemover()
+ # clean_image = remover(input_image)
+ # clean_image.save(output_image)
+ # get_segmented_image_by_agent(
+ # Image.open(input_image), remover, remover, None, "./test_seg.png"
+ # )
remover = BMGG14Remover()
- remover("embodied_gen/models/test_seg.jpg", "./seg.png")
+ clean_image = remover("./camera.jpeg", "./seg.png")
+ from embodied_gen.utils.process_media import (
+ keep_largest_connected_component,
+ )
+
+ keep_largest_connected_component(clean_image).save("./seg_post.png")
diff --git a/embodied_gen/scripts/gen_scene3d.py b/embodied_gen/scripts/gen_scene3d.py
index 42d2527..c1ab5ff 100644
--- a/embodied_gen/scripts/gen_scene3d.py
+++ b/embodied_gen/scripts/gen_scene3d.py
@@ -1,3 +1,20 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
import logging
import os
import random
diff --git a/embodied_gen/scripts/gen_texture.py b/embodied_gen/scripts/gen_texture.py
index a0023a8..d28336f 100644
--- a/embodied_gen/scripts/gen_texture.py
+++ b/embodied_gen/scripts/gen_texture.py
@@ -1,3 +1,20 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
import os
import shutil
from dataclasses import dataclass
@@ -94,6 +111,7 @@ def entrypoint() -> None:
delight=cfg.delight,
no_save_delight_img=True,
texture_wh=[cfg.texture_size, cfg.texture_size],
+ no_mesh_post_process=True,
)
drender_api(
mesh_path=f"{output_root}/texture_mesh/{uuid}.obj",
diff --git a/embodied_gen/scripts/imageto3d.py b/embodied_gen/scripts/imageto3d.py
index 61a14b6..dfbc8e4 100644
--- a/embodied_gen/scripts/imageto3d.py
+++ b/embodied_gen/scripts/imageto3d.py
@@ -14,30 +14,30 @@
# implied. See the License for the specific language governing
# permissions and limitations under the License.
-
import argparse
import os
import random
-import sys
from glob import glob
from shutil import copy, copytree, rmtree
import numpy as np
-import torch
import trimesh
from PIL import Image
from embodied_gen.data.backproject_v3 import entrypoint as backproject_api
-from embodied_gen.data.utils import delete_dir, trellis_preprocess
+from embodied_gen.data.utils import delete_dir
+# from embodied_gen.models.sr_model import ImageRealESRGAN
# from embodied_gen.models.delight_model import DelightingModel
from embodied_gen.models.gs_model import GaussianOperator
from embodied_gen.models.segment_model import RembgRemover
-
-# from embodied_gen.models.sr_model import ImageRealESRGAN
from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
from embodied_gen.utils.gpt_clients import GPT_CLIENT
+from embodied_gen.utils.inference import image3d_model_infer
from embodied_gen.utils.log import logger
-from embodied_gen.utils.process_media import merge_images_video
+from embodied_gen.utils.process_media import (
+ combine_images_to_grid,
+ merge_images_video,
+)
from embodied_gen.utils.tags import VERSION
from embodied_gen.utils.trender import render_video
from embodied_gen.validators.quality_checkers import (
@@ -48,26 +48,24 @@
)
from embodied_gen.validators.urdf_convertor import URDFGenerator
-current_file_path = os.path.abspath(__file__)
-current_dir = os.path.dirname(current_file_path)
-sys.path.append(os.path.join(current_dir, "../.."))
-from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
+# random.seed(0)
+IMAGE3D_MODEL = "SAM3D" # TRELLIS or SAM3D
+logger.info(f"Loading {IMAGE3D_MODEL} as Image3D Models...")
+if IMAGE3D_MODEL == "TRELLIS":
+ from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
-os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
- "~/.cache/torch_extensions"
-)
-os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
-os.environ["SPCONV_ALGO"] = "native"
-random.seed(0)
+ PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
+ "microsoft/TRELLIS-image-large"
+ )
+ # PIPELINE.cuda()
+elif IMAGE3D_MODEL == "SAM3D":
+ from embodied_gen.models.sam3d import Sam3dInference
+
+ PIPELINE = Sam3dInference()
-logger.info("Loading Image3D Models...")
# DELIGHT = DelightingModel()
# IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
RBG_REMOVER = RembgRemover()
-PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
- "microsoft/TRELLIS-image-large"
-)
-# PIPELINE.cuda()
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
AESTHETIC_CHECKER = ImageAestheticChecker()
@@ -151,7 +149,6 @@ def entrypoint(**kwargs):
# Segmentation: Get segmented image using Rembg.
seg_path = f"{output_root}/{filename}_cond.png"
seg_image = RBG_REMOVER(image) if image.mode != "RGBA" else image
- seg_image = trellis_preprocess(seg_image)
seg_image.save(seg_path)
seed = args.seed
@@ -162,27 +159,8 @@ def entrypoint(**kwargs):
logger.info(
f"Try: {try_idx + 1}/{args.n_retry}, Seed: {seed}, Prompt: {seg_path}"
)
- # Run the pipeline
try:
- PIPELINE.cuda()
- outputs = PIPELINE.run(
- seg_image,
- preprocess_image=False,
- seed=(
- random.randint(0, 100000) if seed is None else seed
- ),
- # Optional parameters
- # sparse_structure_sampler_params={
- # "steps": 12,
- # "cfg_strength": 7.5,
- # },
- # slat_sampler_params={
- # "steps": 12,
- # "cfg_strength": 3,
- # },
- )
- PIPELINE.cpu()
- torch.cuda.empty_cache()
+ outputs = image3d_model_infer(PIPELINE, seg_image, seed)
except Exception as e:
logger.error(
f"[Pipeline Failed] process {image_path}: {e}, skip."
@@ -215,14 +193,13 @@ def entrypoint(**kwargs):
render_gs_api(
input_gs=aligned_gs_path,
output_path=color_path,
- elevation=[20, -10, 60, -50],
- num_images=12,
+ elevation=[30, -30],
+ num_images=4,
)
-
color_img = Image.open(color_path)
- keep_height = int(color_img.height * 2 / 3)
- crop_img = color_img.crop((0, 0, color_img.width, keep_height))
- geo_flag, geo_result = GEO_CHECKER([crop_img], text=asset_node)
+ geo_flag, geo_result = GEO_CHECKER(
+ [color_img], text=asset_node
+ )
logger.warning(
f"{GEO_CHECKER.__class__.__name__}: {geo_result} for {seg_path}"
)
@@ -232,8 +209,8 @@ def entrypoint(**kwargs):
seed = random.randint(0, 100000) if seed is not None else None
# Render the video for generated 3D asset.
- color_images = render_video(gs_model)["color"]
- normal_images = render_video(mesh_model)["normal"]
+ color_images = render_video(gs_model, r=1.85)["color"]
+ normal_images = render_video(mesh_model, r=1.85)["normal"]
video_path = os.path.join(output_root, "gs_mesh.mp4")
merge_images_video(color_images, normal_images, video_path)
@@ -312,7 +289,7 @@ def entrypoint(**kwargs):
image_paths = glob(f"{image_dir}/*.png")
images_list = []
for checker in CHECKERS:
- images = image_paths
+ images = combine_images_to_grid(image_paths)
if isinstance(checker, ImageSegChecker):
images = [
f"{output_root}/{filename}_raw.png",
@@ -334,9 +311,12 @@ def entrypoint(**kwargs):
f"{result_dir}/{urdf_convertor.output_mesh_dir}",
)
copy(video_path, f"{result_dir}/video.mp4")
+
if not args.keep_intermediate:
delete_dir(output_root, keep_subs=["result"])
+ logger.info(f"Saved results for {image_path} in {result_dir}")
+
except Exception as e:
logger.error(f"Failed to process {image_path}: {e}, skip.")
continue
diff --git a/embodied_gen/scripts/render_gs.py b/embodied_gen/scripts/render_gs.py
index 3a3d7a2..e00c548 100644
--- a/embodied_gen/scripts/render_gs.py
+++ b/embodied_gen/scripts/render_gs.py
@@ -27,7 +27,6 @@
from embodied_gen.data.utils import (
CameraSetting,
init_kal_camera,
- normalize_vertices_array,
)
from embodied_gen.models.gs_model import load_gs_model
from embodied_gen.utils.process_media import combine_images_to_grid
diff --git a/embodied_gen/scripts/textto3d.py b/embodied_gen/scripts/textto3d.py
index 4e96063..c5fe5f3 100644
--- a/embodied_gen/scripts/textto3d.py
+++ b/embodied_gen/scripts/textto3d.py
@@ -30,6 +30,7 @@
from embodied_gen.utils.log import logger
from embodied_gen.utils.process_media import (
check_object_edge_truncated,
+ combine_images_to_grid,
render_asset3d,
)
from embodied_gen.validators.quality_checkers import (
@@ -51,7 +52,6 @@
__all__ = [
- "text_to_image",
"text_to_3d",
]
@@ -176,12 +176,12 @@ def text_to_3d(**kwargs) -> dict:
image_path = render_asset3d(
mesh_path,
output_root=f"{node_save_dir}/result",
- num_images=6,
+ num_images=4,
elevation=(30, -30),
output_subdir="renders",
no_index_file=True,
)
-
+ image_path = combine_images_to_grid(image_path)
check_text = asset_type if asset_type is not None else prompt
qa_flag, qa_result = TXTGEN_CHECKER(check_text, image_path)
logger.warning(
diff --git a/embodied_gen/utils/gpt_clients.py b/embodied_gen/utils/gpt_clients.py
index 47f5ce2..32a9ea9 100644
--- a/embodied_gen/utils/gpt_clients.py
+++ b/embodied_gen/utils/gpt_clients.py
@@ -21,13 +21,14 @@
from io import BytesIO
from typing import Optional
+import openai
import yaml
from openai import AzureOpenAI, OpenAI # pip install openai
from PIL import Image
from tenacity import (
retry,
+ retry_if_not_exception_type,
stop_after_attempt,
- stop_after_delay,
wait_random_exponential,
)
from embodied_gen.utils.process_media import combine_images_to_grid
@@ -106,8 +107,9 @@ def __init__(
logger.info(f"Using GPT model: {self.model_name}.")
@retry(
- wait=wait_random_exponential(min=1, max=20),
- stop=(stop_after_attempt(10) | stop_after_delay(30)),
+ retry=retry_if_not_exception_type(openai.BadRequestError),
+ wait=wait_random_exponential(min=1, max=10),
+ stop=stop_after_attempt(5),
)
def completion_with_backoff(self, **kwargs):
"""Performs a chat completion request with retry/backoff."""
@@ -246,3 +248,8 @@ def check_connection(self) -> None:
model_name=model_name,
check_connection=False,
)
+
+
+if __name__ == "__main__":
+ response = GPT_CLIENT.query("What is the capital of China?")
+ print(f"Response: {response}")
diff --git a/embodied_gen/utils/inference.py b/embodied_gen/utils/inference.py
new file mode 100644
index 0000000..5e19e93
--- /dev/null
+++ b/embodied_gen/utils/inference.py
@@ -0,0 +1,59 @@
+from embodied_gen.utils.monkey_patches import monkey_patch_trellis
+
+monkey_patch_trellis()
+import random
+from typing import Any
+
+import torch
+from PIL import Image
+from embodied_gen.data.utils import trellis_preprocess
+from embodied_gen.models.sam3d import Sam3dInference
+from embodied_gen.utils.trender import pack_state, unpack_state
+from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
+
+__all__ = [
+ "image3d_model_infer",
+]
+
+
+def image3d_model_infer(
+    pipe: TrellisImageTo3DPipeline | Sam3dInference,
+    seg_image: Image.Image,
+    seed: int | None = None,
+    **kwargs,
+) -> dict[str, Any]:
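+    """Run image-to-3D inference with either a TRELLIS or SAM3D pipeline.
+
+    For TRELLIS the pipeline is moved to GPU, the segmented image is
+    preprocessed, and the pipeline is moved back to CPU afterwards. For
+    SAM3D the Gaussian output is repacked via ``pack_state`` and
+    ``unpack_state`` so it matches the TRELLIS format used downstream.
+
+    Returns:
+        dict: Pipeline outputs containing the generated Gaussian and mesh.
+    """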
+ if isinstance(pipe, TrellisImageTo3DPipeline):
+ pipe.cuda()
+ seg_image = trellis_preprocess(seg_image)
+ outputs = pipe.run(
+ seg_image,
+ preprocess_image=False,
+ seed=(random.randint(0, 100000) if seed is None else seed),
+ # Optional parameters
+ # sparse_structure_sampler_params={
+ # "steps": 12,
+ # "cfg_strength": 7.5,
+ # },
+ # slat_sampler_params={
+ # "steps": 12,
+ # "cfg_strength": 3,
+ # },
+ **kwargs,
+ )
+ pipe.cpu()
+ elif isinstance(pipe, Sam3dInference):
+ outputs = pipe.run(
+ seg_image,
+ seed=(random.randint(0, 100000) if seed is None else seed),
+ # stage1_inference_steps=25,
+ # stage2_inference_steps=25,
+ **kwargs,
+ )
+ state = pack_state(outputs["gaussian"][0], outputs["mesh"][0])
+ # Align GS3D from SAM3D with TRELLIS format.
+ outputs["gaussian"][0], _ = unpack_state(state, device="cuda")
+ else:
+ raise ValueError(f"Unsupported pipeline type: {type(pipe)}")
+
+ torch.cuda.empty_cache()
+
+ return outputs
diff --git a/embodied_gen/utils/monkey_patches.py b/embodied_gen/utils/monkey_patches.py
index b5d35cf..6e3b033 100644
--- a/embodied_gen/utils/monkey_patches.py
+++ b/embodied_gen/utils/monkey_patches.py
@@ -25,6 +25,73 @@
from PIL import Image
from torchvision import transforms
+__all__ = [
+    "monkey_patch_trellis",
+    "monkey_patch_pano2room",
+    "monkey_patch_maniskill",
+    "monkey_patch_sam3d",
+]
+
+
+def monkey_patch_trellis():
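+    """Prepare the TRELLIS third-party package for local inference.
+
+    Adds the repository root to ``sys.path``, sets the TRELLIS-related
+    environment variables (torch extensions cache, spconv algorithm,
+    xformers attention backend), and patches ``Gaussian.setup_functions``
+    so the scaling activation may be either ``exp`` or ``softplus``.
+    """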
+ import torch.nn.functional as F
+
+ current_file_path = os.path.abspath(__file__)
+ current_dir = os.path.dirname(current_file_path)
+ sys.path.append(os.path.join(current_dir, "../.."))
+
+ from thirdparty.TRELLIS.trellis.representations import Gaussian
+ from thirdparty.TRELLIS.trellis.representations.gaussian.general_utils import (
+ build_scaling_rotation,
+ inverse_sigmoid,
+ strip_symmetric,
+ )
+
+ os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
+ "~/.cache/torch_extensions"
+ )
+ os.environ["SPCONV_ALGO"] = "auto" # Can be 'native' or 'auto'
+    os.environ["ATTN_BACKEND"] = (
+ "xformers" # Can be 'flash-attn' or 'xformers'
+ )
+ from thirdparty.TRELLIS.trellis.modules.sparse import set_attn
+
+ set_attn("xformers")
+
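+    # Replacement for Gaussian.setup_functions that also supports a
+    # softplus scaling activation, presumably so Gaussians produced by
+    # non-TRELLIS models (e.g. SAM3D) can be loaded with the same class.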
+ def patched_setup_functions(self):
+ def inverse_softplus(x):
+ return x + torch.log(-torch.expm1(-x))
+
+ def build_covariance_from_scaling_rotation(
+ scaling, scaling_modifier, rotation
+ ):
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
+ actual_covariance = L @ L.transpose(1, 2)
+ symm = strip_symmetric(actual_covariance)
+ return symm
+
+ if self.scaling_activation_type == "exp":
+ self.scaling_activation = torch.exp
+ self.inverse_scaling_activation = torch.log
+ elif self.scaling_activation_type == "softplus":
+ self.scaling_activation = F.softplus
+ self.inverse_scaling_activation = inverse_softplus
+
+ self.covariance_activation = build_covariance_from_scaling_rotation
+ self.opacity_activation = torch.sigmoid
+ self.inverse_opacity_activation = inverse_sigmoid
+ self.rotation_activation = F.normalize
+
+ self.scale_bias = self.inverse_scaling_activation(
+ torch.tensor(self.scaling_bias)
+ ).to(self.device)
+ self.rots_bias = torch.zeros((4)).to(self.device)
+ self.rots_bias[0] = 1
+ self.opacity_bias = self.inverse_opacity_activation(
+ torch.tensor(self.opacity_bias)
+ ).to(self.device)
+
+ Gaussian.setup_functions = patched_setup_functions
+
def monkey_patch_pano2room():
current_file_path = os.path.abspath(__file__)
@@ -216,3 +283,363 @@ def get_rgba_tensor(camera, return_alpha):
ManiSkillScene.get_human_render_camera_images = (
get_human_render_camera_images
)
+
+
+def monkey_patch_sam3d():
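+    """Patch the SAM3D inference pipeline for use inside EmbodiedGen.
+
+    Adds ``thirdparty/sam3d`` to ``sys.path`` and overrides
+    ``InferencePipelinePointMap.run`` and ``InferencePipeline.__init__``
+    so sub-models are built on CPU and moved to the GPU only while they
+    are used, via ``model_device_ctx``.
+    """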
+ from typing import Optional, Union
+
+ from embodied_gen.data.utils import model_device_ctx
+ from embodied_gen.utils.log import logger
+
+ os.environ["LIDRA_SKIP_INIT"] = "true"
+
+ current_file_path = os.path.abspath(__file__)
+ current_dir = os.path.dirname(current_file_path)
+ sam3d_root = os.path.abspath(
+ os.path.join(current_dir, "../../thirdparty/sam3d")
+ )
+ if sam3d_root not in sys.path:
+ sys.path.insert(0, sam3d_root)
+
+ def patch_pointmap_infer_pipeline():
+ from copy import deepcopy
+
+ try:
+ from sam3d_objects.pipeline.inference_pipeline_pointmap import (
+ InferencePipelinePointMap,
+ )
+ except ImportError:
+ logger.error(
+ "[MonkeyPatch]: Could not import sam3d_objects directly. Check paths."
+ )
+ return
+
+ def patch_run(
+ self,
+ image: Union[None, Image.Image, np.ndarray],
+ mask: Union[None, Image.Image, np.ndarray] = None,
+ seed: Optional[int] = None,
+ stage1_only=False,
+ with_mesh_postprocess=True,
+ with_texture_baking=True,
+ with_layout_postprocess=True,
+ use_vertex_color=False,
+ stage1_inference_steps=None,
+ stage2_inference_steps=None,
+ use_stage1_distillation=False,
+ use_stage2_distillation=False,
+ pointmap=None,
+ decode_formats=None,
+ estimate_plane=False,
+ ) -> dict:
+ image = self.merge_image_and_mask(image, mask)
+ with self.device:
+ pointmap_dict = self.compute_pointmap(image, pointmap)
+ pointmap = pointmap_dict["pointmap"]
+ pts = type(self)._down_sample_img(pointmap)
+ pts_colors = type(self)._down_sample_img(
+ pointmap_dict["pts_color"]
+ )
+
+ if estimate_plane:
+ return self.estimate_plane(pointmap_dict, image)
+
+ ss_input_dict = self.preprocess_image(
+ image, self.ss_preprocessor, pointmap=pointmap
+ )
+
+ slat_input_dict = self.preprocess_image(
+ image, self.slat_preprocessor
+ )
+ if seed is not None:
+ torch.manual_seed(seed)
+
+ with model_device_ctx(
+ self.models["ss_generator"],
+ self.models["ss_decoder"],
+ self.condition_embedders["ss_condition_embedder"],
+ ):
+ ss_return_dict = self.sample_sparse_structure(
+ ss_input_dict,
+ inference_steps=stage1_inference_steps,
+ use_distillation=use_stage1_distillation,
+ )
+
+ # We could probably use the decoder from the models themselves
+ pointmap_scale = ss_input_dict.get("pointmap_scale", None)
+ pointmap_shift = ss_input_dict.get("pointmap_shift", None)
+ ss_return_dict.update(
+ self.pose_decoder(
+ ss_return_dict,
+ scene_scale=pointmap_scale,
+ scene_shift=pointmap_shift,
+ )
+ )
+
+ ss_return_dict["scale"] = (
+ ss_return_dict["scale"]
+ * ss_return_dict["downsample_factor"]
+ )
+
+ if stage1_only:
+ logger.info("Finished!")
+ ss_return_dict["voxel"] = (
+ ss_return_dict["coords"][:, 1:] / 64 - 0.5
+ )
+ return {
+ **ss_return_dict,
+ "pointmap": pts.cpu().permute((1, 2, 0)), # HxWx3
+ "pointmap_colors": pts_colors.cpu().permute(
+ (1, 2, 0)
+ ), # HxWx3
+ }
+ # return ss_return_dict
+
+ coords = ss_return_dict["coords"]
+ with model_device_ctx(
+ self.models["slat_generator"],
+ self.condition_embedders["slat_condition_embedder"],
+ ):
+ slat = self.sample_slat(
+ slat_input_dict,
+ coords,
+ inference_steps=stage2_inference_steps,
+ use_distillation=use_stage2_distillation,
+ )
+
+ with model_device_ctx(
+ self.models["slat_decoder_mesh"],
+ self.models["slat_decoder_gs"],
+ self.models["slat_decoder_gs_4"],
+ ):
+ outputs = self.decode_slat(
+ slat,
+ (
+ self.decode_formats
+ if decode_formats is None
+ else decode_formats
+ ),
+ )
+
+ outputs = self.postprocess_slat_output(
+ outputs,
+ with_mesh_postprocess,
+ with_texture_baking,
+ use_vertex_color,
+ )
+ glb = outputs.get("glb", None)
+
+ try:
+ if (
+ with_layout_postprocess
+ and self.layout_post_optimization_method is not None
+ ):
+ assert (
+ glb is not None
+ ), "require mesh to run postprocessing"
+ logger.info(
+ "Running layout post optimization method..."
+ )
+ postprocessed_pose = self.run_post_optimization(
+ deepcopy(glb),
+ pointmap_dict["intrinsics"],
+ ss_return_dict,
+ ss_input_dict,
+ )
+ ss_return_dict.update(postprocessed_pose)
+ except Exception as e:
+ logger.error(
+ f"Error during layout post optimization: {e}",
+ exc_info=True,
+ )
+
+ result = {
+ **ss_return_dict,
+ **outputs,
+ "pointmap": pts.cpu().permute((1, 2, 0)),
+ "pointmap_colors": pts_colors.cpu().permute((1, 2, 0)),
+ }
+ return result
+
+ InferencePipelinePointMap.run = patch_run
+
+ def patch_infer_init():
+ import torch
+
+ try:
+ from sam3d_objects.pipeline import preprocess_utils
+ from sam3d_objects.pipeline.inference_pipeline_pointmap import (
+ InferencePipeline,
+ )
+ from sam3d_objects.pipeline.inference_utils import (
+ SLAT_MEAN,
+ SLAT_STD,
+ )
+ except ImportError:
+            logger.error(
+                "[MonkeyPatch] Could not import sam3d_objects for infer pipeline."
+            )
+ return
+
+ def patch_init(
+ self,
+ ss_generator_config_path,
+ ss_generator_ckpt_path,
+ slat_generator_config_path,
+ slat_generator_ckpt_path,
+ ss_decoder_config_path,
+ ss_decoder_ckpt_path,
+ slat_decoder_gs_config_path,
+ slat_decoder_gs_ckpt_path,
+ slat_decoder_mesh_config_path,
+ slat_decoder_mesh_ckpt_path,
+ slat_decoder_gs_4_config_path=None,
+ slat_decoder_gs_4_ckpt_path=None,
+ ss_encoder_config_path=None,
+ ss_encoder_ckpt_path=None,
+ decode_formats=["gaussian", "mesh"],
+ dtype="bfloat16",
+ pad_size=1.0,
+ version="v0",
+ device="cuda",
+ ss_preprocessor=preprocess_utils.get_default_preprocessor(),
+ slat_preprocessor=preprocess_utils.get_default_preprocessor(),
+ ss_condition_input_mapping=["image"],
+ slat_condition_input_mapping=["image"],
+ pose_decoder_name="default",
+ workspace_dir="",
+ downsample_ss_dist=0, # the distance we use to downsample
+ ss_inference_steps=25,
+ ss_rescale_t=3,
+ ss_cfg_strength=7,
+ ss_cfg_interval=[0, 500],
+ ss_cfg_strength_pm=0.0,
+ slat_inference_steps=25,
+ slat_rescale_t=3,
+ slat_cfg_strength=5,
+ slat_cfg_interval=[0, 500],
+ rendering_engine: str = "nvdiffrast", # nvdiffrast OR pytorch3d,
+ shape_model_dtype=None,
+ compile_model=False,
+ slat_mean=SLAT_MEAN,
+ slat_std=SLAT_STD,
+ ):
+ self.rendering_engine = rendering_engine
+ self.device = torch.device(device)
+ self.compile_model = compile_model
+ with self.device:
+ self.decode_formats = decode_formats
+ self.pad_size = pad_size
+ self.version = version
+ self.ss_condition_input_mapping = ss_condition_input_mapping
+ self.slat_condition_input_mapping = (
+ slat_condition_input_mapping
+ )
+ self.workspace_dir = workspace_dir
+ self.downsample_ss_dist = downsample_ss_dist
+ self.ss_inference_steps = ss_inference_steps
+ self.ss_rescale_t = ss_rescale_t
+ self.ss_cfg_strength = ss_cfg_strength
+ self.ss_cfg_interval = ss_cfg_interval
+ self.ss_cfg_strength_pm = ss_cfg_strength_pm
+ self.slat_inference_steps = slat_inference_steps
+ self.slat_rescale_t = slat_rescale_t
+ self.slat_cfg_strength = slat_cfg_strength
+ self.slat_cfg_interval = slat_cfg_interval
+
+ self.dtype = self._get_dtype(dtype)
+ if shape_model_dtype is None:
+ self.shape_model_dtype = self.dtype
+ else:
+ self.shape_model_dtype = self._get_dtype(shape_model_dtype)
+
+ # Setup preprocessors
+ self.pose_decoder = self.init_pose_decoder(
+ ss_generator_config_path, pose_decoder_name
+ )
+ self.ss_preprocessor = self.init_ss_preprocessor(
+ ss_preprocessor, ss_generator_config_path
+ )
+ self.slat_preprocessor = slat_preprocessor
+
+ raw_device = self.device
+ self.device = torch.device("cpu")
+ ss_generator = self.init_ss_generator(
+ ss_generator_config_path, ss_generator_ckpt_path
+ )
+ slat_generator = self.init_slat_generator(
+ slat_generator_config_path, slat_generator_ckpt_path
+ )
+ ss_decoder = self.init_ss_decoder(
+ ss_decoder_config_path, ss_decoder_ckpt_path
+ )
+ ss_encoder = self.init_ss_encoder(
+ ss_encoder_config_path, ss_encoder_ckpt_path
+ )
+ slat_decoder_gs = self.init_slat_decoder_gs(
+ slat_decoder_gs_config_path, slat_decoder_gs_ckpt_path
+ )
+ slat_decoder_gs_4 = self.init_slat_decoder_gs(
+ slat_decoder_gs_4_config_path, slat_decoder_gs_4_ckpt_path
+ )
+ slat_decoder_mesh = self.init_slat_decoder_mesh(
+ slat_decoder_mesh_config_path, slat_decoder_mesh_ckpt_path
+ )
+
+ # Load conditioner embedder so that we only load it once
+ ss_condition_embedder = self.init_ss_condition_embedder(
+ ss_generator_config_path, ss_generator_ckpt_path
+ )
+ slat_condition_embedder = self.init_slat_condition_embedder(
+ slat_generator_config_path, slat_generator_ckpt_path
+ )
+ self.device = raw_device
+
+ self.condition_embedders = {
+ "ss_condition_embedder": ss_condition_embedder,
+ "slat_condition_embedder": slat_condition_embedder,
+ }
+
+ # override generator and condition embedder setting
+ self.override_ss_generator_cfg_config(
+ ss_generator,
+ cfg_strength=ss_cfg_strength,
+ inference_steps=ss_inference_steps,
+ rescale_t=ss_rescale_t,
+ cfg_interval=ss_cfg_interval,
+ cfg_strength_pm=ss_cfg_strength_pm,
+ )
+ self.override_slat_generator_cfg_config(
+ slat_generator,
+ cfg_strength=slat_cfg_strength,
+ inference_steps=slat_inference_steps,
+ rescale_t=slat_rescale_t,
+ cfg_interval=slat_cfg_interval,
+ )
+
+ self.models = torch.nn.ModuleDict(
+ {
+ "ss_generator": ss_generator,
+ "slat_generator": slat_generator,
+ "ss_encoder": ss_encoder,
+ "ss_decoder": ss_decoder,
+ "slat_decoder_gs": slat_decoder_gs,
+ "slat_decoder_gs_4": slat_decoder_gs_4,
+ "slat_decoder_mesh": slat_decoder_mesh,
+ }
+ )
+ logger.info("Loading SAM3D model weights completed.")
+
+ if self.compile_model:
+ logger.info("Compiling model...")
+ self._compile()
+ logger.info("Model compilation completed!")
+ self.slat_mean = torch.tensor(slat_mean)
+ self.slat_std = torch.tensor(slat_std)
+
+ InferencePipeline.__init__ = patch_init
+
+ patch_pointmap_infer_pipeline()
+ patch_infer_init()
+
+ return
diff --git a/embodied_gen/utils/process_media.py b/embodied_gen/utils/process_media.py
index 8feb7ec..3a68ca1 100644
--- a/embodied_gen/utils/process_media.py
+++ b/embodied_gen/utils/process_media.py
@@ -96,7 +96,7 @@ def render_asset3d(
image_paths = render_asset3d(
mesh_path="path_to_mesh.obj",
output_root="path_to_save_dir",
- num_images=6,
+ num_images=4,
elevation=(30, -30),
output_subdir="renders",
no_index_file=True,
@@ -230,6 +230,29 @@ def filter_image_small_connected_components(
return image
+def keep_largest_connected_component(pil_img: Image.Image) -> Image.Image:
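+    """Zero the alpha channel of all but the largest connected component.
+
+    The input is converted to RGBA, the alpha mask is labelled with
+    ``cv2.connectedComponentsWithStats``, and only the largest foreground
+    component keeps its original alpha values.
+    """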
+ if pil_img.mode != "RGBA":
+ pil_img = pil_img.convert("RGBA")
+
+ img_arr = np.array(pil_img)
+ alpha_channel = img_arr[:, :, 3]
+
+ _, binary_mask = cv2.threshold(alpha_channel, 0, 255, cv2.THRESH_BINARY)
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
+ binary_mask, connectivity=8
+ )
+ if num_labels < 2:
+ return pil_img
+
+ largest_label = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
+ new_alpha = np.where(labels == largest_label, alpha_channel, 0).astype(
+ np.uint8
+ )
+ img_arr[:, :, 3] = new_alpha
+
+ return Image.fromarray(img_arr)
+
+
def combine_images_to_grid(
images: list[str | Image.Image],
cat_row_col: tuple[int, int] = None,
@@ -439,7 +462,7 @@ def render(
plt.axis("off")
legend_handles = [
- Patch(facecolor=color, edgecolor='black', label=role)
+ Patch(facecolor=color, edgecolor="black", label=role)
for role, color in self.role_colors.items()
]
plt.legend(
@@ -465,7 +488,7 @@ def load_scene_dict(file_path: str) -> dict:
dict: Mapping from scene ID to description.
"""
scene_dict = {}
- with open(file_path, "r", encoding='utf-8') as f:
+ with open(file_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line or ":" not in line:
@@ -487,7 +510,7 @@ def is_image_file(filename: str) -> bool:
"""
mime_type, _ = mimetypes.guess_type(filename)
- return mime_type is not None and mime_type.startswith('image')
+ return mime_type is not None and mime_type.startswith("image")
def parse_text_prompts(prompts: list[str]) -> list[str]:
diff --git a/embodied_gen/utils/tags.py b/embodied_gen/utils/tags.py
index 9302331..b03d010 100644
--- a/embodied_gen/utils/tags.py
+++ b/embodied_gen/utils/tags.py
@@ -1 +1 @@
-VERSION = "v0.1.6"
+VERSION = "v0.1.7"
diff --git a/embodied_gen/utils/trender.py b/embodied_gen/utils/trender.py
index 53acc50..f2a845f 100644
--- a/embodied_gen/utils/trender.py
+++ b/embodied_gen/utils/trender.py
@@ -16,29 +16,35 @@
import os
import sys
+from collections import defaultdict
import numpy as np
import spaces
import torch
+from easydict import EasyDict as edict
from tqdm import tqdm
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_dir, "../.."))
-from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
-from thirdparty.TRELLIS.trellis.representations import MeshExtractResult
+from thirdparty.TRELLIS.trellis.renderers import GaussianRenderer, MeshRenderer
+from thirdparty.TRELLIS.trellis.representations import (
+ Gaussian,
+ MeshExtractResult,
+)
from thirdparty.TRELLIS.trellis.utils.render_utils import (
- render_frames,
yaw_pitch_r_fov_to_extrinsics_intrinsics,
)
__all__ = [
"render_video",
+ "pack_state",
+ "unpack_state",
]
@spaces.GPU
-def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
+def render_mesh_frames(sample, extrinsics, intrinsics, options={}, **kwargs):
renderer = MeshRenderer()
renderer.rendering_options.resolution = options.get("resolution", 512)
renderer.rendering_options.near = options.get("near", 1)
@@ -60,6 +66,57 @@ def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
return rets
+@spaces.GPU
+def render_gs_frames(
+ sample,
+ extrinsics,
+ intrinsics,
+ options=None,
+ colors_overwrite=None,
+ verbose=True,
+ **kwargs,
+):
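+    """Render a Gaussian splat sample for each camera pose.
+
+    Uses the TRELLIS ``GaussianRenderer`` with the given extrinsics and
+    intrinsics, and returns a dict with uint8 ``color`` frames and the
+    corresponding ``depth`` arrays (``None`` when no depth is produced).
+    """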
+ def to_img(tensor):
+ return np.clip(
+ tensor.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
+ ).astype(np.uint8)
+
+ def to_numpy(tensor):
+ return tensor.detach().cpu().numpy()
+
+ renderer = GaussianRenderer()
+ renderer.pipe.kernel_size = kwargs.get("kernel_size", 0.1)
+ renderer.pipe.use_mip_gaussian = True
+
+ defaults = {
+ "resolution": 512,
+ "near": 0.8,
+ "far": 1.6,
+ "bg_color": (0, 0, 0),
+ "ssaa": 1,
+ }
+ final_options = {**defaults, **(options or {})}
+
+ for k, v in final_options.items():
+ if hasattr(renderer.rendering_options, k):
+ setattr(renderer.rendering_options, k, v)
+
+ outputs = defaultdict(list)
+ iterator = zip(extrinsics, intrinsics)
+ if verbose:
+ iterator = tqdm(iterator, total=len(extrinsics), desc="Rendering")
+
+ for extr, intr in iterator:
+ res = renderer.render(
+ sample, extr, intr, colors_overwrite=colors_overwrite
+ )
+ outputs["color"].append(to_img(res["color"]))
+        depth = res.get("percent_depth")
+        if depth is None:
+            depth = res.get("depth")
+ outputs["depth"].append(to_numpy(depth) if depth is not None else None)
+
+ return dict(outputs)
+
+
@spaces.GPU
def render_video(
sample,
@@ -77,7 +134,9 @@ def render_video(
yaws, pitch, r, fov
)
render_fn = (
- render_mesh if isinstance(sample, MeshExtractResult) else render_frames
+ render_mesh_frames
+ if sample.__class__.__name__ == "MeshExtractResult"
+ else render_gs_frames
)
result = render_fn(
sample,
@@ -88,3 +147,47 @@ def render_video(
)
return result
+
+
+@spaces.GPU
+def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
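+    """Serialize a Gaussian and its extracted mesh into numpy arrays.
+
+    The returned dict is device-independent and can be restored with
+    ``unpack_state``.
+    """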
+ return {
+ "gaussian": {
+ **gs.init_params,
+ "_xyz": gs._xyz.cpu().numpy(),
+ "_features_dc": gs._features_dc.cpu().numpy(),
+ "_scaling": gs._scaling.cpu().numpy(),
+ "_rotation": gs._rotation.cpu().numpy(),
+ "_opacity": gs._opacity.cpu().numpy(),
+ },
+ "mesh": {
+ "vertices": mesh.vertices.cpu().numpy(),
+ "faces": mesh.faces.cpu().numpy(),
+ },
+ }
+
+
+def unpack_state(state: dict, device: str = "cpu") -> tuple[Gaussian, dict]:
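+    """Rebuild a Gaussian and a lightweight mesh from a packed state dict.
+
+    Tensors are created on ``device``; the mesh is returned as an
+    ``EasyDict`` with ``vertices`` and ``faces``.
+    """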
+ gs = Gaussian(
+ aabb=state["gaussian"]["aabb"],
+ sh_degree=state["gaussian"]["sh_degree"],
+ mininum_kernel_size=state["gaussian"]["mininum_kernel_size"],
+ scaling_bias=state["gaussian"]["scaling_bias"],
+ opacity_bias=state["gaussian"]["opacity_bias"],
+ scaling_activation=state["gaussian"]["scaling_activation"],
+ device=device,
+ )
+ gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device=device)
+ gs._features_dc = torch.tensor(
+ state["gaussian"]["_features_dc"], device=device
+ )
+ gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device=device)
+ gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device=device)
+ gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device=device)
+
+ mesh = edict(
+ vertices=torch.tensor(state["mesh"]["vertices"], device=device),
+ faces=torch.tensor(state["mesh"]["faces"], device=device),
+ )
+
+ return gs, mesh
diff --git a/embodied_gen/validators/aesthetic_predictor.py b/embodied_gen/validators/aesthetic_predictor.py
index 6e77449..9feedde 100644
--- a/embodied_gen/validators/aesthetic_predictor.py
+++ b/embodied_gen/validators/aesthetic_predictor.py
@@ -125,7 +125,11 @@ def predict(self, image_path):
Returns:
float: Predicted aesthetic score.
"""
- pil_image = Image.open(image_path)
+ if isinstance(image_path, str):
+ pil_image = Image.open(image_path)
+ else:
+ pil_image = image_path
+
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
with torch.no_grad():
diff --git a/embodied_gen/validators/quality_checkers.py b/embodied_gen/validators/quality_checkers.py
index 0e5ff7e..3cec795 100644
--- a/embodied_gen/validators/quality_checkers.py
+++ b/embodied_gen/validators/quality_checkers.py
@@ -126,6 +126,30 @@ def __init__(
super().__init__(prompt, verbose)
self.gpt_client = gpt_client
if self.prompt is None:
+ # Old version for TRELLIS.
+ # self.prompt = """
+ # You are an expert in evaluating the geometry quality of generated 3D asset.
+ # You will be given rendered views of a generated 3D asset, type {}, with black background.
+ # Your task is to evaluate the quality of the 3D asset generation,
+ # including geometry, structure, and appearance, based on the rendered views.
+ # Criteria:
+ # - Is the object in the image a single, complete, and well-formed instance,
+ # without truncation, missing parts, overlapping duplicates, or redundant geometry?
+ # - Minor flaws, asymmetries, or simplifications (e.g., less detail on sides or back,
+ # soft edges) are acceptable if the object is structurally sound and recognizable.
+ # - Only evaluate geometry. Do not assess texture quality.
+ # - The asset should not contain any unrelated elements, such as
+ # ground planes, platforms, or background props (e.g., paper, flooring).
+
+ # If all the above criteria are met, return "YES". Otherwise, return
+ # "NO" followed by a brief explanation (no more than 20 words).
+
+ # Example:
+ # Images show a yellow cup standing on a flat white plane -> NO
+ # -> Response: NO: extra white surface under the object.
+ # Image shows a chair with simplified back legs and soft edges -> YES
+ # """
+
self.prompt = """
You are an expert in evaluating the geometry quality of generated 3D asset.
You will be given rendered views of a generated 3D asset, type {}, with black background.
@@ -137,16 +161,13 @@ def __init__(
- Minor flaws, asymmetries, or simplifications (e.g., less detail on sides or back,
soft edges) are acceptable if the object is structurally sound and recognizable.
- Only evaluate geometry. Do not assess texture quality.
- - The asset should not contain any unrelated elements, such as
- ground planes, platforms, or background props (e.g., paper, flooring).
- If all the above criteria are met, return "YES". Otherwise, return
+ If all the above criteria are met, return "YES" only. Otherwise, return
"NO" followed by a brief explanation (no more than 20 words).
Example:
- Images show a yellow cup standing on a flat white plane -> NO
- -> Response: NO: extra white surface under the object.
- Image shows a chair with simplified back legs and soft edges → YES
+            Image shows a chair with one leg missing -> NO: the chair is missing a leg.
+ Image shows a geometrically complete cup -> YES
"""
def query(
diff --git a/embodied_gen/validators/urdf_convertor.py b/embodied_gen/validators/urdf_convertor.py
index 3f070be..8a48e94 100644
--- a/embodied_gen/validators/urdf_convertor.py
+++ b/embodied_gen/validators/urdf_convertor.py
@@ -27,7 +27,10 @@
from scipy.spatial.transform import Rotation
from embodied_gen.data.convex_decomposer import decompose_convex_mesh
from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
-from embodied_gen.utils.process_media import render_asset3d
+from embodied_gen.utils.process_media import (
+ combine_images_to_grid,
+ render_asset3d,
+)
from embodied_gen.utils.tags import VERSION
logging.basicConfig(level=logging.INFO)
@@ -482,7 +485,7 @@ def __call__(
output_subdir=self.output_render_dir,
no_index_file=True,
)
-
+ # image_path = combine_images_to_grid(image_path)
response = self.gpt_client.query(text_prompt, image_path)
# logger.info(response)
if response is None:
diff --git a/install/install_basic.sh b/install/install_basic.sh
index 63d4af4..ccbf861 100644
--- a/install/install_basic.sh
+++ b/install/install_basic.sh
@@ -8,7 +8,7 @@ PIP_INSTALL_PACKAGES=(
"torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu118"
"xformers==0.0.27.post2 --index-url https://download.pytorch.org/whl/cu118"
"-r requirements.txt --use-deprecated=legacy-resolver"
- "flash-attn==2.7.0.post2"
+ # "flash-attn==2.7.0.post2"
"utils3d@git+https://github.com/EasternJournalist/utils3d.git@9a4eb15"
"clip@git+https://github.com/openai/CLIP.git"
"segment-anything@git+https://github.com/facebookresearch/segment-anything.git@dca509f"
@@ -16,6 +16,8 @@ PIP_INSTALL_PACKAGES=(
"kolors@git+https://github.com/HochCC/Kolors.git"
"kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0"
"git+https://github.com/nerfstudio-project/gsplat.git@v1.5.3"
+ "git+https://github.com/facebookresearch/pytorch3d.git@stable"
+ "MoGe@git+https://github.com/microsoft/MoGe.git@a8c3734"
)
for pkg in "${PIP_INSTALL_PACKAGES[@]}"; do
diff --git a/install/install_extra.sh b/install/install_extra.sh
index af63f7b..302e0ad 100644
--- a/install/install_extra.sh
+++ b/install/install_extra.sh
@@ -4,21 +4,17 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
source "$SCRIPT_DIR/_utils.sh"
PYTHON_PACKAGES_NODEPS=(
- "timm"
"txt2panoimg@git+https://github.com/HochCC/SD-T2I-360PanoImage"
)
PYTHON_PACKAGES=(
- "ninja"
- "fused-ssim@git+https://github.com/rahul-goel/fused-ssim#egg=328dc98"
+ "fused-ssim@git+https://github.com/rahul-goel/fused-ssim#egg=328dc98 --no-build-isolation"
"git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch"
- "git+https://github.com/facebookresearch/pytorch3d.git@stable"
"kornia"
"h5py"
"albumentations==0.5.2"
"webdataset"
"icecream"
- "open3d"
"pyequilib"
)
diff --git a/pyproject.toml b/pyproject.toml
index be3cc60..8ed7bc7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@ packages = ["embodied_gen"]
[project]
name = "embodied_gen"
-version = "v0.1.6"
+version = "v0.1.7"
readme = "README.md"
license = "Apache-2.0"
license-files = ["LICENSE", "NOTICE"]
diff --git a/requirements.txt b/requirements.txt
index 05fbfc3..c72fa6b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -42,4 +42,13 @@ coacd
mani_skill==3.0.0b21
typing_extensions==4.14.1
ninja
-packaging
\ No newline at end of file
+packaging
+lightning
+astor
+optree
+loguru
+seaborn
+hydra-core
+modelscope
+timm
+open3d
\ No newline at end of file
diff --git a/tests/test_examples/test_quality_checkers.py b/tests/test_examples/test_quality_checkers.py
index 8b604d6..207c415 100644
--- a/tests/test_examples/test_quality_checkers.py
+++ b/tests/test_examples/test_quality_checkers.py
@@ -21,7 +21,10 @@
import pytest
from embodied_gen.utils.gpt_clients import GPT_CLIENT
-from embodied_gen.utils.process_media import render_asset3d
+from embodied_gen.utils.process_media import (
+ combine_images_to_grid,
+ render_asset3d,
+)
from embodied_gen.validators.quality_checkers import (
ImageAestheticChecker,
ImageSegChecker,
@@ -166,12 +169,13 @@ def test_textgen_checker(textalign_checker, mesh_path, text_desc):
image_list = render_asset3d(
mesh_path,
output_root=output_root,
- num_images=6,
+ num_images=4,
elevation=(30, -30),
output_subdir="renders",
no_index_file=True,
with_mtl=False,
)
+ image_list = combine_images_to_grid(image_list)
flag, result = textalign_checker(text_desc, image_list)
logger.info(f"textalign_checker: {flag}, {result}")