|
import gradio as gr |
|
import spaces |
|
import os |
|
import numpy as np |
|
import trimesh |
|
import time |
|
import traceback |
|
import torch |
|
from PIL import Image |
|
import cv2 |
|
import shutil |
|
from segment_anything import SamAutomaticMaskGenerator, build_sam |
|
from omegaconf import OmegaConf |
|
|
|
from modules.bbox_gen.models.autogressive_bbox_gen import BboxGen |
|
from modules.part_synthesis.process_utils import save_parts_outputs |
|
from modules.inference_utils import load_img_mask, prepare_bbox_gen_input, prepare_part_synthesis_input, gen_mesh_from_bounds, vis_voxel_coords, merge_parts |
|
from modules.part_synthesis.pipelines import OmniPartImageTo3DPipeline |
|
from modules.label_2d_mask.visualizer import Visualizer |
|
from transformers import AutoModelForImageSegmentation |
|
|
|
from modules.label_2d_mask.label_parts import ( |
|
prepare_image, |
|
get_sam_mask, |
|
get_mask, |
|
clean_segment_edges, |
|
resize_and_pad_to_square, |
|
size_th as DEFAULT_SIZE_TH |
|
) |
|
|
|
|
|
# Use the GPU when available; every model loaded below is moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision for inference. NOTE(review): DTYPE is not referenced in this
# chunk — it may be used elsewhere in the file, or be dead; confirm.
DTYPE = torch.float16

# Upper bound for user-facing random seeds (int32 max).
MAX_SEED = np.iinfo(np.int32).max

# Per-session scratch directory next to this script; each request writes under
# TMP_ROOT/<session_hash>/.
TMP_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")

os.makedirs(TMP_ROOT, exist_ok=True)

# Lazily-initialized model singletons; populated once by prepare_models() and
# then shared across requests.
sam_mask_generator = None

rmbg_model = None

bbox_gen_model = None

part_synthesis_pipeline = None

# Minimum SAM mask area (in pixels) to keep; overwritten per request with the
# user-supplied threshold in process_image().
size_th = DEFAULT_SIZE_TH
|
|
|
|
|
def prepare_models(sam_ckpt_path, partfield_ckpt_path, bbox_gen_ckpt_path):
    """Lazily load every model the app needs into the module-level singletons.

    Safe to call more than once: each model is only constructed if its global
    is still ``None``.

    Args:
        sam_ckpt_path: Path to the SAM checkpoint.
        partfield_ckpt_path: Path to the PartField encoder checkpoint
            (injected into the BboxGen config).
        bbox_gen_ckpt_path: Path to the BboxGen state-dict checkpoint.
    """
    global sam_mask_generator, rmbg_model, bbox_gen_model, part_synthesis_pipeline

    if sam_mask_generator is None:
        print("Loading SAM model...")
        sam = build_sam(checkpoint=sam_ckpt_path).to(device=DEVICE)
        sam_mask_generator = SamAutomaticMaskGenerator(sam)

    if rmbg_model is None:
        print("Loading BriaRMBG 2.0 model...")
        rmbg_model = AutoModelForImageSegmentation.from_pretrained(
            'briaai/RMBG-2.0', trust_remote_code=True)
        rmbg_model.to(DEVICE)
        rmbg_model.eval()

    if part_synthesis_pipeline is None:
        print("Loading PartSynthesis model...")
        part_synthesis_pipeline = OmniPartImageTo3DPipeline.from_pretrained('omnipart/OmniPart')
        part_synthesis_pipeline.to(DEVICE)

    if bbox_gen_model is None:
        print("Loading BboxGen model...")
        cfg = OmegaConf.load("configs/bbox_gen.yaml").model.args
        cfg.partfield_encoder_path = partfield_ckpt_path
        bbox_gen_model = BboxGen(cfg)
        # strict=False: checkpoint may omit keys (e.g. the frozen encoder).
        bbox_gen_model.load_state_dict(torch.load(bbox_gen_ckpt_path), strict=False)
        bbox_gen_model.to(DEVICE)
        bbox_gen_model.eval().half()

    print("Models ready")
|
|
|
|
|
@spaces.GPU
def process_image(image_path, threshold, req: gr.Request):
    """Process an input image and generate its initial SAM part segmentation.

    Removes the background, pads the image to a square on white, runs SAM
    automatic mask generation, and writes visualization/debug images into the
    per-session temp directory.

    Args:
        image_path: Path to the uploaded image file.
        threshold: Minimum mask area (pixels); stored in the module-level
            ``size_th`` so later calls (e.g. apply_merge) reuse it.
        req: Gradio request, used only for the session hash.

    Returns:
        Tuple of (initial segmentation image path, pre-merge visualization
        path, state dict carrying the image, group ids and file paths).
    """
    global size_th

    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)

    # NOTE: split(".")[0] truncates at the FIRST dot, so "a.b.png" -> "a".
    img_name = os.path.basename(image_path).split(".")[0]

    size_th = threshold

    img = Image.open(image_path).convert("RGB")
    processed_image = prepare_image(img, rmbg_net=rmbg_model.to(DEVICE))

    # Square-pad the cutout, then composite onto white so SAM sees RGB.
    processed_image = resize_and_pad_to_square(processed_image)
    white_bg = Image.new("RGBA", processed_image.size, (255, 255, 255, 255))
    white_bg_img = Image.alpha_composite(white_bg, processed_image.convert("RGBA"))
    image = np.array(white_bg_img.convert('RGB'))

    rgba_path = os.path.join(user_dir, f"{img_name}_processed.png")
    processed_image.save(rgba_path)

    print("Generating raw SAM masks without post-processing...")
    raw_masks = sam_mask_generator.generate(image)

    # Debug visualization of the raw masks on a white canvas.
    # (A dead `np.copy(image)` that was immediately overwritten was removed.)
    # NOTE(review): raw_sam_vis is built but never saved or returned — kept
    # for parity with the original flow; confirm whether it can be dropped.
    raw_sam_vis = np.ones_like(image) * 255
    sorted_masks = sorted(raw_masks, key=lambda x: x["area"], reverse=True)
    for i, mask_data in enumerate(sorted_masks):
        if mask_data["area"] < size_th:
            continue
        # Deterministic pseudo-random palette keyed on the mask index.
        color_r = (i * 50 + 80) % 256
        color_g = (i * 120 + 40) % 256
        color_b = (i * 180 + 20) % 256
        color = np.array([color_r, color_g, color_b])
        raw_sam_vis[mask_data["segmentation"]] = color

    visual = Visualizer(image)

    # Post-processed grouping: merges/filters SAM masks into part ids.
    group_ids, pre_merge_im = get_sam_mask(
        image,
        sam_mask_generator,
        visual,
        merge_groups=None,
        rgba_image=processed_image,
        img_name=img_name,
        save_dir=user_dir,
        size_threshold=size_th
    )

    pre_merge_path = os.path.join(user_dir, f"{img_name}_mask_pre_merge.png")
    Image.fromarray(pre_merge_im).save(pre_merge_path)
    pre_split_vis = np.ones_like(image) * 255

    # Visualize each labeled group (ids >= 0) with its id drawn at the centroid.
    unique_ids = np.unique(group_ids)
    unique_ids = unique_ids[unique_ids >= 0]

    for i, unique_id in enumerate(unique_ids):
        color_r = (i * 50 + 80) % 256
        color_g = (i * 120 + 40) % 256
        color_b = (i * 180 + 20) % 256
        color = np.array([color_r, color_g, color_b])

        mask = (group_ids == unique_id)
        pre_split_vis[mask] = color

        y_indices, x_indices = np.where(mask)
        if len(y_indices) > 0:
            center_y = int(np.mean(y_indices))
            center_x = int(np.mean(x_indices))
            cv2.putText(pre_split_vis, str(unique_id),
                        (center_x, center_y), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (0, 0, 0), 1, cv2.LINE_AA)

    pre_split_path = os.path.join(user_dir, f"{img_name}_pre_split.png")
    Image.fromarray(pre_split_vis).save(pre_split_path)
    print(f"Pre-split segmentation (before disconnected parts handling) saved to {pre_split_path}")

    # Writes f"{img_name}_mask_segments_2.png" into user_dir.
    get_mask(group_ids, image, ids=2, img_name=img_name, save_dir=user_dir)

    init_seg_path = os.path.join(user_dir, f"{img_name}_mask_segments_2.png")

    # Flatten any alpha onto white so the gallery shows an opaque image.
    seg_img = Image.open(init_seg_path)
    if seg_img.mode == 'RGBA':
        white_bg = Image.new('RGBA', seg_img.size, (255, 255, 255, 255))
        seg_img = Image.alpha_composite(white_bg, seg_img)
        seg_img.save(init_seg_path)

    # State must be JSON-serializable for Gradio, hence the .tolist() calls.
    state = {
        "image": image.tolist(),
        "processed_image": rgba_path,
        "group_ids": group_ids.tolist() if isinstance(group_ids, np.ndarray) else group_ids,
        "original_group_ids": group_ids.tolist() if isinstance(group_ids, np.ndarray) else group_ids,
        "img_name": img_name,
        "pre_split_path": pre_split_path,
    }

    return init_seg_path, pre_merge_path, state
|
|
|
|
|
def apply_merge(merge_input, state, req: gr.Request):
    """Apply merge parameters and generate the merged segmentation.

    Args:
        merge_input: Semicolon-separated merge groups, each a comma-separated
            list of segment ids, e.g. ``"1,2;3,4,5"``.
        state: State dict produced by ``process_image``.
        req: Gradio request, used only for the session hash.

    Returns:
        Tuple of (merged segmentation image path or None on failure, state).
    """
    if not state:
        # BUGFIX: previously returned (None, None, state) — three values —
        # while the success path returns two, breaking Gradio output wiring.
        return None, state

    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))

    image = np.array(state["image"])
    # Always merge from the ORIGINAL ids so repeated merges don't compound.
    group_ids = np.array(state["original_group_ids"])
    img_name = state["img_name"]

    processed_image = Image.open(state["processed_image"])

    unique_ids = np.unique(group_ids)
    unique_ids = unique_ids[unique_ids >= 0]
    print(f"Original segment IDs (used for merging): {sorted(unique_ids.tolist())}")

    # Parse "1,2;3,4" into [[1, 2], [3, 4]], warning about unknown ids.
    merge_groups = None
    try:
        if merge_input:
            merge_groups = []
            group_sets = merge_input.split(';')
            for group_set in group_sets:
                ids = [int(x) for x in group_set.split(',')]
                if ids:
                    existing_ids = [id for id in ids if id in unique_ids]
                    missing_ids = [id for id in ids if id not in unique_ids]

                    if missing_ids:
                        print(f"Warning: These IDs don't exist in the segmentation: {missing_ids}")

                    # Keep the full group (including missing ids) as long as
                    # at least one id is valid; downstream ignores unknowns.
                    if existing_ids:
                        merge_groups.append(ids)
                        print(f"Valid merge group: {ids} (missing: {missing_ids if missing_ids else 'none'})")
                    else:
                        print(f"Skipping merge group with no valid IDs: {ids}")

            print(f"Using merge groups: {merge_groups}")
    except Exception as e:
        print(f"Error parsing merge groups: {e}")
        # BUGFIX: match the success path's two-value return (was 3 values).
        return None, state

    visual = Visualizer(image)

    new_group_ids, merged_im = get_sam_mask(
        image,
        sam_mask_generator,
        visual,
        merge_groups=merge_groups,
        existing_group_ids=group_ids,
        rgba_image=processed_image,
        skip_split=True,
        img_name=img_name,
        save_dir=user_dir,
        size_threshold=size_th
    )

    new_unique_ids = np.unique(new_group_ids)
    new_unique_ids = new_unique_ids[new_unique_ids >= 0]
    print(f"New segment IDs (after merging): {new_unique_ids.tolist()}")

    new_group_ids = clean_segment_edges(new_group_ids)

    # Writes f"{img_name}_mask_segments_3.png" into user_dir.
    get_mask(new_group_ids, image, ids=3, img_name=img_name, save_dir=user_dir)

    merged_seg_path = os.path.join(user_dir, f"{img_name}_mask_segments_3.png")

    # Persist the id map as float EXR (ids shifted +1 so background becomes 0).
    # NOTE(review): assumes a 518x518 mask and an OpenCV build with OpenEXR
    # support (may require OPENCV_IO_ENABLE_OPENEXR=1) — confirm.
    save_mask = new_group_ids + 1
    save_mask = save_mask.reshape(518, 518, 1).repeat(3, axis=-1)
    cv2.imwrite(os.path.join(user_dir, f"{img_name}_mask.exr"), save_mask.astype(np.float32))

    state["group_ids"] = new_group_ids.tolist() if isinstance(new_group_ids, np.ndarray) else new_group_ids
    state["save_mask_path"] = os.path.join(user_dir, f"{img_name}_mask.exr")

    return merged_seg_path, state
|
|
|
|
|
def explode_mesh(mesh, explosion_scale=0.4):
    """Return a scene with each part translated radially away from the
    common center ("exploded view").

    Args:
        mesh: A ``trimesh.Scene`` (multi-part) or ``trimesh.Trimesh``.
        explosion_scale: Distance each part is moved along its direction
            from the global center.

    Returns:
        A new ``trimesh.Scene`` with offset transforms, or the input scene
        unchanged when there is nothing to explode.
    """
    if isinstance(mesh, trimesh.Scene):
        scene = mesh
    elif isinstance(mesh, trimesh.Trimesh):
        print("Warning: Single mesh provided, can't create exploded view")
        return trimesh.Scene(mesh)
    else:
        print(f"Warning: Unexpected mesh type: {type(mesh)}")
        scene = mesh

    if len(scene.geometry) <= 1:
        print("Only one geometry found - nothing to explode")
        return scene

    print(f"[EXPLODE_MESH] Starting mesh explosion with scale {explosion_scale}")
    print(f"[EXPLODE_MESH] Processing {len(scene.geometry)} parts")

    # World-space centroid per part, keyed by geometry NAME.
    # BUGFIX: the previous implementation paired centers with geometries by
    # enumeration index, but only geometries with vertices contributed a
    # center — any vertex-less geometry earlier in iteration order shifted
    # every later part onto the wrong center.
    part_centers = {}
    for geometry_name, geometry in scene.geometry.items():
        if hasattr(geometry, 'vertices'):
            transform = scene.graph[geometry_name][0]
            vertices_global = trimesh.transformations.transform_points(
                geometry.vertices, transform)
            part_centers[geometry_name] = np.mean(vertices_global, axis=0)
            print(f"[EXPLODE_MESH] Part {geometry_name}: center = {part_centers[geometry_name]}")

    if not part_centers:
        print("No valid geometries with vertices found")
        return scene

    global_center = np.mean(np.array(list(part_centers.values())), axis=0)
    print(f"[EXPLODE_MESH] Global center: {global_center}")

    exploded_scene = trimesh.Scene()
    for geometry_name, geometry in scene.geometry.items():
        if not hasattr(geometry, 'vertices'):
            # Preserves original behavior: vertex-less geometries are dropped.
            continue

        direction = part_centers[geometry_name] - global_center
        direction_norm = np.linalg.norm(direction)
        if direction_norm > 1e-6:
            direction = direction / direction_norm
        else:
            # Part sits exactly at the global center: pick a random unit
            # direction so coincident parts still separate (non-deterministic).
            direction = np.random.randn(3)
            direction = direction / np.linalg.norm(direction)
        offset = direction * explosion_scale

        new_transform = scene.graph[geometry_name][0].copy()
        new_transform[:3, 3] = new_transform[:3, 3] + offset

        exploded_scene.add_geometry(
            geometry,
            transform=new_transform,
            geom_name=geometry_name
        )
        print(f"[EXPLODE_MESH] Part {geometry_name}: moved by {np.linalg.norm(offset):.4f}")

    print("[EXPLODE_MESH] Mesh explosion complete")
    return exploded_scene
|
|
|
@spaces.GPU(duration=90)
def generate_parts(state, seed, cfg_strength, req: gr.Request):
    """Run the full part-generation pipeline for the current session.

    Stages: (1) sample sparse voxel coords from the image, (2) generate part
    bounding boxes with BboxGen, (3) synthesize per-part geometry (SLAT),
    (4) merge parts and build an exploded view. All intermediate artifacts
    are written into the per-session temp directory.

    Args:
        state: State dict from apply_merge (needs "processed_image" and
            "save_mask_path").
        seed: Sampling seed passed to both pipeline stages.
        cfg_strength: Classifier-free-guidance strength for the SLAT stage.
        req: Gradio request, used only for the session hash.

    Returns:
        Paths to (bbox mesh, merged mesh, exploded mesh, merged gaussian
        splat, exploded gaussian splat) — all GLB/PLY files in user_dir.
    """
    explode_factor=0.3
    img_path = state["processed_image"]
    mask_path = state["save_mask_path"]
    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
    img_white_bg, img_black_bg, ordered_mask_input, img_mask_vis = load_img_mask(img_path, mask_path)
    img_mask_vis.save(os.path.join(user_dir, "img_mask_vis.png"))

    # Stage 1: sparse structure sampling.
    # NOTE(review): cfg_strength here is hard-coded to 7.5; the user-supplied
    # cfg_strength only affects the SLAT stage below — confirm intended.
    voxel_coords = part_synthesis_pipeline.get_coords(img_black_bg, num_samples=1, seed=seed, sparse_structure_sampler_params={"steps": 25, "cfg_strength": 7.5})
    voxel_coords = voxel_coords.cpu().numpy()
    np.save(os.path.join(user_dir, "voxel_coords.npy"), voxel_coords)
    voxel_coords_ply = vis_voxel_coords(voxel_coords)
    voxel_coords_ply.export(os.path.join(user_dir, "voxel_coords_vis.ply"))
    print("[INFO] Voxel coordinates saved")

    # Stage 2: per-part bounding boxes from voxels + 2D masks.
    bbox_gen_input = prepare_bbox_gen_input(os.path.join(user_dir, "voxel_coords.npy"), img_white_bg, ordered_mask_input)
    bbox_gen_output = bbox_gen_model.generate(bbox_gen_input)
    np.save(os.path.join(user_dir, "bboxes.npy"), bbox_gen_output['bboxes'][0])
    bboxes_vis = gen_mesh_from_bounds(bbox_gen_output['bboxes'][0])
    bboxes_vis.export(os.path.join(user_dir, "bboxes_vis.glb"))
    print("[INFO] BboxGen output saved")

    # Stage 3: part synthesis (SLAT) conditioned on coords, layouts and masks.
    part_synthesis_input = prepare_part_synthesis_input(os.path.join(user_dir, "voxel_coords.npy"), os.path.join(user_dir, "bboxes.npy"), ordered_mask_input)

    # Free VRAM from the earlier stages before the heavy SLAT sampling.
    torch.cuda.empty_cache()

    part_synthesis_output = part_synthesis_pipeline.get_slat(
        img_black_bg,
        part_synthesis_input['coords'],
        [part_synthesis_input['part_layouts']],
        part_synthesis_input['masks'],
        seed=seed,
        slat_sampler_params={"steps": 25, "cfg_strength": cfg_strength},
        formats=['mesh', 'gaussian'],
        preprocess_image=False,
    )
    save_parts_outputs(
        part_synthesis_output,
        output_dir=user_dir,
        simplify_ratio=0.0,
        save_video=False,
        save_glb=True,
        textured=False,
    )
    # Stage 4: combine per-part outputs into merged mesh / gaussian files.
    merge_parts(user_dir)
    print("[INFO] PartSynthesis output saved")

    bbox_mesh_path = os.path.join(user_dir, "bboxes_vis.glb")
    whole_mesh_path = os.path.join(user_dir, "mesh_segment.glb")

    # Build the exploded view from the merged segmented mesh.
    combined_mesh = trimesh.load(whole_mesh_path)
    exploded_mesh_result = explode_mesh(combined_mesh, explosion_scale=explode_factor)
    exploded_mesh_result.export(os.path.join(user_dir, "exploded_parts.glb"))

    exploded_mesh_path = os.path.join(user_dir, "exploded_parts.glb")
    combined_gs_path = os.path.join(user_dir, "merged_gs.ply")
    # NOTE(review): exploded_gs.ply is expected to exist in user_dir
    # (presumably produced by merge_parts) — not created in this function.
    exploded_gs_path = os.path.join(user_dir, "exploded_gs.ply")

    return bbox_mesh_path, whole_mesh_path, exploded_mesh_path, combined_gs_path, exploded_gs_path
|
|