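# OmniPart Gradio demo: interactive 2D part segmentation (SAM + RMBG-2.0),
# user-guided mask merging, and part-aware 3D generation (BboxGen + the
# OmniPart image-to-3D pipeline), with exploded-view exports of the result.
# The Space runs on Hugging Face ZeroGPU (hence the `spaces` import).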
import gradio as gr
import spaces
import os
import numpy as np
import trimesh
import time
import traceback
import torch
from PIL import Image
import cv2
import shutil
from segment_anything import SamAutomaticMaskGenerator, build_sam
from omegaconf import OmegaConf
from modules.bbox_gen.models.autogressive_bbox_gen import BboxGen
from modules.part_synthesis.process_utils import save_parts_outputs
from modules.inference_utils import load_img_mask, prepare_bbox_gen_input, prepare_part_synthesis_input, gen_mesh_from_bounds, vis_voxel_coords, merge_parts
from modules.part_synthesis.pipelines import OmniPartImageTo3DPipeline
from modules.label_2d_mask.visualizer import Visualizer
from transformers import AutoModelForImageSegmentation
from modules.label_2d_mask.label_parts import (
    prepare_image,
    get_sam_mask,
    get_mask,
    clean_segment_edges,
    resize_and_pad_to_square,
    size_th as DEFAULT_SIZE_TH,
)

# Constants
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16
MAX_SEED = np.iinfo(np.int32).max
TMP_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_ROOT, exist_ok=True)

sam_mask_generator = None
rmbg_model = None
bbox_gen_model = None
part_synthesis_pipeline = None
size_th = DEFAULT_SIZE_TH
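
# The heavy models (SAM, RMBG-2.0, the OmniPart pipeline, BboxGen) live in
# module-level globals and are loaded lazily by prepare_models(), so repeated
# Gradio requests reuse the same instances instead of re-initialising them.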
def prepare_models(sam_ckpt_path, partfield_ckpt_path, bbox_gen_ckpt_path):
    global sam_mask_generator, rmbg_model, bbox_gen_model, part_synthesis_pipeline

    if sam_mask_generator is None:
        print("Loading SAM model...")
        sam_model = build_sam(checkpoint=sam_ckpt_path).to(device=DEVICE)
        sam_mask_generator = SamAutomaticMaskGenerator(sam_model)

    if rmbg_model is None:
        print("Loading BriaRMBG 2.0 model...")
        rmbg_model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
        rmbg_model.to(DEVICE)
        rmbg_model.eval()

    if part_synthesis_pipeline is None:
        print("Loading PartSynthesis model...")
        part_synthesis_pipeline = OmniPartImageTo3DPipeline.from_pretrained('omnipart/OmniPart')
        part_synthesis_pipeline.to(DEVICE)

    if bbox_gen_model is None:
        print("Loading BboxGen model...")
        bbox_gen_config = OmegaConf.load("configs/bbox_gen.yaml").model.args
        bbox_gen_config.partfield_encoder_path = partfield_ckpt_path
        bbox_gen_model = BboxGen(bbox_gen_config)
        bbox_gen_model.load_state_dict(torch.load(bbox_gen_ckpt_path), strict=False)
        bbox_gen_model.to(DEVICE)
        bbox_gen_model.eval().half()

    print("Models ready")
def process_image(image_path, threshold, req: gr.Request):
    """Process image and generate initial segmentation"""
    global size_th

    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)

    img_name = os.path.basename(image_path).split(".")[0]
    size_th = threshold

    img = Image.open(image_path).convert("RGB")
    processed_image = prepare_image(img, rmbg_net=rmbg_model.to(DEVICE))
    processed_image = resize_and_pad_to_square(processed_image)

    white_bg = Image.new("RGBA", processed_image.size, (255, 255, 255, 255))
    white_bg_img = Image.alpha_composite(white_bg, processed_image.convert("RGBA"))
    image = np.array(white_bg_img.convert('RGB'))

    rgba_path = os.path.join(user_dir, f"{img_name}_processed.png")
    processed_image.save(rgba_path)

    print("Generating raw SAM masks without post-processing...")
    raw_masks = sam_mask_generator.generate(image)

    raw_sam_vis = np.copy(image)
    raw_sam_vis = np.ones_like(image) * 255
    sorted_masks = sorted(raw_masks, key=lambda x: x["area"], reverse=True)
    for i, mask_data in enumerate(sorted_masks):
        if mask_data["area"] < size_th:
            continue
        color_r = (i * 50 + 80) % 256
        color_g = (i * 120 + 40) % 256
        color_b = (i * 180 + 20) % 256
        color = np.array([color_r, color_g, color_b])
        mask = mask_data["segmentation"]
        raw_sam_vis[mask] = color

    visual = Visualizer(image)
    group_ids, pre_merge_im = get_sam_mask(
        image,
        sam_mask_generator,
        visual,
        merge_groups=None,
        rgba_image=processed_image,
        img_name=img_name,
        save_dir=user_dir,
        size_threshold=size_th,
    )
    pre_merge_path = os.path.join(user_dir, f"{img_name}_mask_pre_merge.png")
    Image.fromarray(pre_merge_im).save(pre_merge_path)

    pre_split_vis = np.ones_like(image) * 255
    unique_ids = np.unique(group_ids)
    unique_ids = unique_ids[unique_ids >= 0]
    for i, unique_id in enumerate(unique_ids):
        color_r = (i * 50 + 80) % 256
        color_g = (i * 120 + 40) % 256
        color_b = (i * 180 + 20) % 256
        color = np.array([color_r, color_g, color_b])
        mask = (group_ids == unique_id)
        pre_split_vis[mask] = color
        y_indices, x_indices = np.where(mask)
        if len(y_indices) > 0:
            center_y = int(np.mean(y_indices))
            center_x = int(np.mean(x_indices))
            cv2.putText(pre_split_vis, str(unique_id),
                        (center_x, center_y), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (0, 0, 0), 1, cv2.LINE_AA)

    pre_split_path = os.path.join(user_dir, f"{img_name}_pre_split.png")
    Image.fromarray(pre_split_vis).save(pre_split_path)
    print(f"Pre-split segmentation (before disconnected parts handling) saved to {pre_split_path}")

    get_mask(group_ids, image, ids=2, img_name=img_name, save_dir=user_dir)
    init_seg_path = os.path.join(user_dir, f"{img_name}_mask_segments_2.png")

    seg_img = Image.open(init_seg_path)
    if seg_img.mode == 'RGBA':
        white_bg = Image.new('RGBA', seg_img.size, (255, 255, 255, 255))
        seg_img = Image.alpha_composite(white_bg, seg_img)
        seg_img.save(init_seg_path)

    state = {
        "image": image.tolist(),
        "processed_image": rgba_path,
        "group_ids": group_ids.tolist() if isinstance(group_ids, np.ndarray) else group_ids,
        "original_group_ids": group_ids.tolist() if isinstance(group_ids, np.ndarray) else group_ids,
        "img_name": img_name,
        "pre_split_path": pre_split_path,
    }

    return init_seg_path, pre_merge_path, state
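
# The merge string is parsed below as semicolon-separated groups of
# comma-separated segment IDs, e.g. "1,2;5,6" merges segments 1 and 2 into one
# part and 5 and 6 into another (IDs come from the numbered pre-split image).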
def apply_merge(merge_input, state, req: gr.Request):
    """Apply merge parameters and generate merged segmentation"""
    global sam_mask_generator

    if not state:
        return None, state

    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))

    # Convert back from list to numpy array
    image = np.array(state["image"])
    # Use original group IDs instead of the most recent ones
    group_ids = np.array(state["original_group_ids"])
    img_name = state["img_name"]
    # Load processed image from path
    processed_image = Image.open(state["processed_image"])

    # Display the original IDs before merging, sorted for easier reading
    unique_ids = np.unique(group_ids)
    unique_ids = unique_ids[unique_ids >= 0]  # Exclude background
    print(f"Original segment IDs (used for merging): {sorted(unique_ids.tolist())}")

    # Parse merge groups
    merge_groups = None
    try:
        if merge_input:
            merge_groups = []
            group_sets = merge_input.split(';')
            for group_set in group_sets:
                ids = [int(x) for x in group_set.split(',')]
                if ids:
                    # Validate that these IDs exist in the segmentation
                    existing_ids = [id for id in ids if id in unique_ids]
                    missing_ids = [id for id in ids if id not in unique_ids]
                    if missing_ids:
                        print(f"Warning: These IDs don't exist in the segmentation: {missing_ids}")
                    # Only add the group if it has at least one valid ID
                    if existing_ids:
                        merge_groups.append(ids)
                        print(f"Valid merge group: {ids} (missing: {missing_ids if missing_ids else 'none'})")
                    else:
                        print(f"Skipping merge group with no valid IDs: {ids}")
            print(f"Using merge groups: {merge_groups}")
    except Exception as e:
        print(f"Error parsing merge groups: {e}")
        return None, state

    # Initialize visualizer
    visual = Visualizer(image)

    # Generate merged segmentation starting from the original IDs;
    # skip_split=True prevents re-splitting after merging
    new_group_ids, merged_im = get_sam_mask(
        image,
        sam_mask_generator,
        visual,
        merge_groups=merge_groups,
        existing_group_ids=group_ids,
        rgba_image=processed_image,
        skip_split=True,
        img_name=img_name,
        save_dir=user_dir,
        size_threshold=size_th,
    )

    # Display the new IDs after merging for future reference
    new_unique_ids = np.unique(new_group_ids)
    new_unique_ids = new_unique_ids[new_unique_ids >= 0]  # Exclude background
    print(f"New segment IDs (after merging): {new_unique_ids.tolist()}")

    # Clean edges
    new_group_ids = clean_segment_edges(new_group_ids)

    # Save merged segmentation visualization
    get_mask(new_group_ids, image, ids=3, img_name=img_name, save_dir=user_dir)
    merged_seg_path = os.path.join(user_dir, f"{img_name}_mask_segments_3.png")

    # Export the merged mask (IDs shifted by +1) as a 518x518 three-channel EXR
    # for the 3D stages
    save_mask = new_group_ids + 1
    save_mask = save_mask.reshape(518, 518, 1).repeat(3, axis=-1)
    cv2.imwrite(os.path.join(user_dir, f"{img_name}_mask.exr"), save_mask.astype(np.float32))

    # Update state with the new group IDs but keep the original IDs unchanged
    state["group_ids"] = new_group_ids.tolist() if isinstance(new_group_ids, np.ndarray) else new_group_ids
    state["save_mask_path"] = os.path.join(user_dir, f"{img_name}_mask.exr")

    return merged_seg_path, state
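
# Exploded view: each part is translated away from the global centroid along
# the direction from that centroid to the part's own centroid, scaled by
# explosion_scale; parts that coincide with the centroid get a random direction.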
def explode_mesh(mesh, explosion_scale=0.4):
    if isinstance(mesh, trimesh.Scene):
        scene = mesh
    elif isinstance(mesh, trimesh.Trimesh):
        print("Warning: Single mesh provided, can't create exploded view")
        scene = trimesh.Scene(mesh)
        return scene
    else:
        print(f"Warning: Unexpected mesh type: {type(mesh)}")
        scene = mesh

    if len(scene.geometry) <= 1:
        print("Only one geometry found - nothing to explode")
        return scene

    print(f"[EXPLODE_MESH] Starting mesh explosion with scale {explosion_scale}")
    print(f"[EXPLODE_MESH] Processing {len(scene.geometry)} parts")

    exploded_scene = trimesh.Scene()
    part_centers = []
    geometry_names = []

    for geometry_name, geometry in scene.geometry.items():
        if hasattr(geometry, 'vertices'):
            transform = scene.graph[geometry_name][0]
            vertices_global = trimesh.transformations.transform_points(
                geometry.vertices, transform)
            center = np.mean(vertices_global, axis=0)
            part_centers.append(center)
            geometry_names.append(geometry_name)
            print(f"[EXPLODE_MESH] Part {geometry_name}: center = {center}")

    if not part_centers:
        print("No valid geometries with vertices found")
        return scene

    part_centers = np.array(part_centers)
    global_center = np.mean(part_centers, axis=0)
    print(f"[EXPLODE_MESH] Global center: {global_center}")

    for i, (geometry_name, geometry) in enumerate(scene.geometry.items()):
        if hasattr(geometry, 'vertices'):
            if i < len(part_centers):
                part_center = part_centers[i]
                direction = part_center - global_center
                direction_norm = np.linalg.norm(direction)
                if direction_norm > 1e-6:
                    direction = direction / direction_norm
                else:
                    direction = np.random.randn(3)
                    direction = direction / np.linalg.norm(direction)
                offset = direction * explosion_scale
            else:
                offset = np.zeros(3)

            original_transform = scene.graph[geometry_name][0].copy()
            new_transform = original_transform.copy()
            new_transform[:3, 3] = new_transform[:3, 3] + offset

            exploded_scene.add_geometry(
                geometry,
                transform=new_transform,
                geom_name=geometry_name
            )
            print(f"[EXPLODE_MESH] Part {geometry_name}: moved by {np.linalg.norm(offset):.4f}")

    print("[EXPLODE_MESH] Mesh explosion complete")
    return exploded_scene
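
# Stages 2 and 3 -- 3D generation: sample a sparse voxel structure from the
# image, predict per-part bounding boxes with BboxGen, synthesize the part
# latents (mesh and Gaussians), merge the parts, and export an exploded view.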
def generate_parts(state, seed, cfg_strength, req: gr.Request):
    explode_factor = 0.3
    img_path = state["processed_image"]
    mask_path = state["save_mask_path"]
    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))

    img_white_bg, img_black_bg, ordered_mask_input, img_mask_vis = load_img_mask(img_path, mask_path)
    img_mask_vis.save(os.path.join(user_dir, "img_mask_vis.png"))

    voxel_coords = part_synthesis_pipeline.get_coords(img_black_bg, num_samples=1, seed=seed, sparse_structure_sampler_params={"steps": 25, "cfg_strength": 7.5})
    voxel_coords = voxel_coords.cpu().numpy()
    np.save(os.path.join(user_dir, "voxel_coords.npy"), voxel_coords)
    voxel_coords_ply = vis_voxel_coords(voxel_coords)
    voxel_coords_ply.export(os.path.join(user_dir, "voxel_coords_vis.ply"))
    print("[INFO] Voxel coordinates saved")

    bbox_gen_input = prepare_bbox_gen_input(os.path.join(user_dir, "voxel_coords.npy"), img_white_bg, ordered_mask_input)
    bbox_gen_output = bbox_gen_model.generate(bbox_gen_input)
    np.save(os.path.join(user_dir, "bboxes.npy"), bbox_gen_output['bboxes'][0])
    bboxes_vis = gen_mesh_from_bounds(bbox_gen_output['bboxes'][0])
    bboxes_vis.export(os.path.join(user_dir, "bboxes_vis.glb"))
    print("[INFO] BboxGen output saved")

    part_synthesis_input = prepare_part_synthesis_input(os.path.join(user_dir, "voxel_coords.npy"), os.path.join(user_dir, "bboxes.npy"), ordered_mask_input)

    torch.cuda.empty_cache()
    part_synthesis_output = part_synthesis_pipeline.get_slat(
        img_black_bg,
        part_synthesis_input['coords'],
        [part_synthesis_input['part_layouts']],
        part_synthesis_input['masks'],
        seed=seed,
        slat_sampler_params={"steps": 25, "cfg_strength": cfg_strength},
        formats=['mesh', 'gaussian'],
        preprocess_image=False,
    )
    save_parts_outputs(
        part_synthesis_output,
        output_dir=user_dir,
        simplify_ratio=0.0,
        save_video=False,
        save_glb=True,
        textured=False,
    )
    merge_parts(user_dir)
    print("[INFO] PartSynthesis output saved")

    bbox_mesh_path = os.path.join(user_dir, "bboxes_vis.glb")
    whole_mesh_path = os.path.join(user_dir, "mesh_segment.glb")

    combined_mesh = trimesh.load(whole_mesh_path)
    exploded_mesh_result = explode_mesh(combined_mesh, explosion_scale=explode_factor)
    exploded_mesh_result.export(os.path.join(user_dir, "exploded_parts.glb"))
    exploded_mesh_path = os.path.join(user_dir, "exploded_parts.glb")

    combined_gs_path = os.path.join(user_dir, "merged_gs.ply")
    exploded_gs_path = os.path.join(user_dir, "exploded_gs.ply")

    return bbox_mesh_path, whole_mesh_path, exploded_mesh_path, combined_gs_path, exploded_gs_path
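
# ---------------------------------------------------------------------------
# The Space's actual Blocks UI is not included in this excerpt. The sketch
# below is a hypothetical, minimal wiring of the handlers above, meant only to
# illustrate how the three stages chain together; component names, labels,
# slider ranges, and checkpoint paths are illustrative assumptions, not the
# Space's real layout. The `gr.Request` parameters are injected by Gradio and
# are not listed as inputs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Placeholder checkpoint paths -- the real Space resolves its own checkpoints.
    prepare_models(
        sam_ckpt_path="ckpts/sam_vit_h_4b8939.pth",
        partfield_ckpt_path="ckpts/partfield_encoder.ckpt",
        bbox_gen_ckpt_path="ckpts/bbox_gen.pt",
    )

    with gr.Blocks() as demo:
        state = gr.State({})

        # Stage 1: segment the input image.
        input_image = gr.Image(type="filepath", label="Input image")
        area_threshold = gr.Number(value=DEFAULT_SIZE_TH, label="Min segment area")
        segment_btn = gr.Button("Segment")
        init_seg = gr.Image(label="Initial segmentation")
        pre_merge = gr.Image(label="Pre-merge masks")
        segment_btn.click(process_image, [input_image, area_threshold], [init_seg, pre_merge, state])

        # Stage 2: merge segments by ID, e.g. "1,2;5,6".
        merge_input = gr.Textbox(label="Merge groups")
        merge_btn = gr.Button("Apply merge")
        merged_seg = gr.Image(label="Merged segmentation")
        merge_btn.click(apply_merge, [merge_input, state], [merged_seg, state])

        # Stage 3: generate the part-aware 3D outputs.
        seed = gr.Slider(0, MAX_SEED, value=0, step=1, label="Seed")
        cfg = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="CFG strength")
        generate_btn = gr.Button("Generate parts")
        bbox_view = gr.Model3D(label="Bounding boxes")
        mesh_view = gr.Model3D(label="Merged mesh")
        exploded_view = gr.Model3D(label="Exploded mesh")
        merged_gs = gr.File(label="Merged Gaussians (.ply)")
        exploded_gs = gr.File(label="Exploded Gaussians (.ply)")
        generate_btn.click(
            generate_parts,
            [state, seed, cfg],
            [bbox_view, mesh_view, exploded_view, merged_gs, exploded_gs],
        )

    demo.launch()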