import os
import torch
import torch.nn as nn
import gradio as gr
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler

from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import decode_latent_images, create_custom_cameras

from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import get_zero123plus_input_cameras
from src.utils.mesh_util import save_obj


def load_models():
    """Initialize and load the diffusion pipeline and the reconstruction model."""
    config = OmegaConf.load('configs/instant-nerf-large-best.yaml')
    model_config = config.model_config
    infer_config = config.infer_config

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load diffusion pipeline
    print('Loading diffusion pipeline...')
    pipeline = DiffusionPipeline.from_pretrained(
        "sudo-ai/zero123plus-v1.2",
        custom_pipeline="zero123plus",
        torch_dtype=torch.float16
    )
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config,
        timestep_spacing='trailing'
    )

    # Expand the UNet's conv_in from 4 to 8 input channels. The extra four
    # channels start zero-initialized, so the expanded layer behaves exactly
    # like the original whenever those channels are zero.
    in_channels = 8
    out_channels = pipeline.unet.conv_in.out_channels
    pipeline.unet.register_to_config(in_channels=in_channels)
    with torch.no_grad():
        new_conv_in = nn.Conv2d(
            in_channels,
            out_channels,
            pipeline.unet.conv_in.kernel_size,
            pipeline.unet.conv_in.stride,
            pipeline.unet.conv_in.padding
        )
        new_conv_in.weight.zero_()
        new_conv_in.weight[:, :4, :, :].copy_(pipeline.unet.conv_in.weight)
        # Also carry over the bias so the expanded layer matches the
        # pretrained one before the custom checkpoint is loaded.
        new_conv_in.bias.copy_(pipeline.unet.conv_in.bias)
    pipeline.unet.conv_in = new_conv_in

    # Load custom UNet weights
    print('Loading custom UNet...')
    unet_path = "best_21.ckpt"
    state_dict = torch.load(unet_path, map_location='cpu')
    # Strip the Lightning wrapper prefix so the keys match the bare UNet
    if 'state_dict' in state_dict:
        new_state_dict = {key.replace('unet.unet.', ''): value
                          for key, value in state_dict['state_dict'].items()}
        pipeline.unet.load_state_dict(new_state_dict, strict=False)
    else:
        pipeline.unet.load_state_dict(state_dict, strict=False)

    # Move to the target device; this also casts the freshly created conv_in to fp16
    pipeline = pipeline.to(device, torch.float16)

    # Load reconstruction model
    print('Loading reconstruction model...')
    model = instantiate_from_config(model_config)
    model_path = hf_hub_download(
        repo_id="TencentARC/InstantMesh",
        filename="instant_nerf_large.ckpt",
        repo_type="model"
    )
    state_dict = torch.load(model_path, map_location='cpu')['state_dict']
    state_dict = {k[14:]: v for k, v in state_dict.items()
                  if k.startswith('lrm_generator.') and 'source_camera' not in k}
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device)
    model.eval()

    return pipeline, model, infer_config
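
# A minimal, self-contained sketch (hypothetical 4->8 channel shapes, not the
# real UNet) of why the zero-initialization above is safe: with the extra
# input-channel weights zeroed and the bias carried over, the expanded conv
# produces the same output as the original whenever the extra channels are zero.
def _check_conv_in_expansion():
    old = nn.Conv2d(4, 320, kernel_size=3, stride=1, padding=1)
    new = nn.Conv2d(8, 320, kernel_size=3, stride=1, padding=1)
    with torch.no_grad():
        new.weight.zero_()
        new.weight[:, :4, :, :].copy_(old.weight)
        new.bias.copy_(old.bias)
        x = torch.randn(1, 4, 64, 64)
        x8 = torch.cat([x, torch.zeros_like(x)], dim=1)
        assert torch.allclose(old(x), new(x8), atol=1e-5)
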
def views_to_layout(views):
    """Tile six (3, 320, 320) view tensors into a single 960x640 (portrait,
    3 rows x 2 columns) PIL layout, the format the zero123plus pipeline expects."""
    views = views.reshape(3, 2, 3, 320, 320)   # (row, col, C, H, W)
    views = views.permute(0, 2, 3, 1, 4)       # (row, C, H, col, W)
    views = views.reshape(3, 3, 320, 640)      # merge columns into width
    views = views.permute(1, 0, 2, 3)          # (C, row, H, W)
    views = views.reshape(3, 960, 640)         # merge rows into height
    layout = views.permute(1, 2, 0)            # (H, W, C)
    layout = (layout.numpy() * 255).astype(np.uint8)
    return Image.fromarray(layout)


def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=None):
    """Build a 960x640 input layout from uploaded files and run refinement.
    (Not wired into the Gradio demo below; useful for file-based workflows.)"""
    if not isinstance(input_images, list):
        raise ValueError("Expected a list of images")

    if len(input_images) == 1:
        img = Image.open(input_images[0].name).convert('RGB')
        if img.size == (640, 960):
            # Already a 640-wide x 960-tall layout; use it directly
            input_image = img
        else:
            # Single view: replicate it six times and tile into a layout
            img = img.resize((320, 320))
            img_array = np.array(img) / 255.0
            images = np.stack([img_array] * 6)
            images = torch.from_numpy(images).float().permute(0, 3, 1, 2)  # (6, C, H, W)
            input_image = views_to_layout(images)
    else:
        # Multiple individual views
        images = []
        for img_file in input_images:
            img = Image.open(img_file.name).convert('RGB')
            img = img.resize((320, 320))
            images.append(np.array(img) / 255.0)
        # Pad with blank views up to six
        while len(images) < 6:
            images.append(np.zeros_like(images[0]))
        images = np.stack(images[:6])
        images = torch.from_numpy(images).float().permute(0, 3, 1, 2)  # (6, C, H, W)
        input_image = views_to_layout(images)

    # Generate refined output
    output = pipeline.refine(
        input_image,
        prompt=prompt,
        num_inference_steps=int(steps),
        guidance_scale=guidance_scale
    ).images[0]

    return output, input_image


def create_mesh(refined_image, model, infer_config):
    """Generate a mesh from the refined 960x640 layout image."""
    # Convert PIL image to a (C, 960, 640) float tensor
    image = np.array(refined_image) / 255.0
    image = torch.from_numpy(image).float().permute(2, 0, 1)

    # Slice the 3x2 grid back into 6 individual views
    image = image.reshape(3, 3, 320, 640)      # (C, row, H, W)
    image = image.permute(1, 0, 2, 3)          # (row, C, H, W)
    image = image.reshape(3, 3, 320, 2, 320)   # split width into columns
    image = image.permute(0, 3, 1, 2, 4)       # (row, col, C, H, W)
    image = image.reshape(6, 3, 320, 320)

    # Add batch dimension
    image = image.unsqueeze(0)

    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to("cuda")
    image = image.to("cuda")

    with torch.no_grad():
        planes = model.forward_planes(image, input_cameras)
        mesh_out = model.extract_mesh(planes, **infer_config)
        vertices, faces, vertex_colors = mesh_out

    return vertices, faces, vertex_colors
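
# Illustrative round-trip check (defined for documentation, not called
# anywhere): tiling six views into the 960x640 grid with the reshape sequence
# from views_to_layout() and slicing them back out as in create_mesh()
# recovers the original views in the same order.
def _check_layout_round_trip():
    views = torch.arange(6 * 3 * 320 * 320, dtype=torch.float32).reshape(6, 3, 320, 320)
    # Assemble (views_to_layout, minus the PIL conversion)
    grid = (views.reshape(3, 2, 3, 320, 320).permute(0, 2, 3, 1, 4)
                 .reshape(3, 3, 320, 640).permute(1, 0, 2, 3).reshape(3, 960, 640))
    # Split (create_mesh)
    back = (grid.reshape(3, 3, 320, 640).permute(1, 0, 2, 3)
                .reshape(3, 3, 320, 2, 320).permute(0, 3, 1, 2, 4).reshape(6, 3, 320, 320))
    assert torch.equal(views, back)
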
class ShapERenderer:
    def __init__(self, device):
        print("Loading Shap-E models...")
        self.device = device
        self.xm = load_model('transmitter', device=device)
        self.model = load_model('text300M', device=device)
        self.diffusion = diffusion_from_config(load_config('diffusion'))
        print("Shap-E models loaded!")

    def generate_views(self, prompt, guidance_scale=15.0, num_steps=64):
        # Generate latents with the text-to-3D model
        batch_size = 1
        guidance_scale = float(guidance_scale)
        latents = sample_latents(
            batch_size=batch_size,
            model=self.model,
            diffusion=self.diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=dict(texts=[prompt] * batch_size),
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=num_steps,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )

        # Render the six views with the same camera angles used by refine.py
        size = 320  # size of each rendered view
        images = []
        azimuths = [30, 90, 150, 210, 270, 330]
        elevations = [20, -10, 20, -10, 20, -10]
        for azimuth, elevation in zip(azimuths, elevations):
            cameras = create_custom_cameras(
                size, self.device,
                azimuths=[azimuth], elevations=[elevation],
                fov_degrees=30, distance=3.0
            )
            rendered_image = decode_latent_images(
                self.xm,
                latents[0],
                rendering_mode='stf',
                cameras=cameras
            )
            images.append(rendered_image.detach().cpu().numpy())

        # Convert to uint8 (assumes the renderer returns values in 0..255)
        images = [image.astype(np.uint8) for image in images]

        # Tile the six views into the 960x640 (3 rows x 2 columns) layout
        layout = np.zeros((960, 640, 3), dtype=np.uint8)
        for i, img in enumerate(images):
            row = i // 2  # two views per row
            col = i % 2
            layout[row*320:(row+1)*320, col*320:(col+1)*320] = img

        return Image.fromarray(layout), images
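
# Note: create_custom_cameras and a tensor return value from
# decode_latent_images are not part of stock shap-e, whose decode_latent_images
# returns a list of PIL images; this script assumes a patched fork. A small
# adapter that tolerates both return types might look like this (a sketch,
# not used above):
def _decode_to_array(xm, latent, cameras, rendering_mode='stf'):
    frames = decode_latent_images(xm, latent, cameras=cameras, rendering_mode=rendering_mode)
    if isinstance(frames, (list, tuple)):
        # Stock shap-e: a list of PIL images; take the single rendered frame
        return np.asarray(frames[0], dtype=np.uint8)
    # Patched fork: a tensor of pixel values
    return frames.detach().cpu().numpy().astype(np.uint8)
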
class RefinerInterface:
    def __init__(self):
        print("Initializing InstantMesh models...")
        self.pipeline, self.model, self.infer_config = load_models()
        print("InstantMesh models loaded!")

    def refine_model(self, input_image, prompt, steps=75, guidance_scale=7.5):
        """Main refinement function."""
        input_image = Image.fromarray(input_image)

        # If we received a landscape 960-wide x 640-tall layout (2 rows x 3
        # columns), rearrange it into the portrait 640-wide x 960-tall layout
        # (3 rows x 2 columns) the pipeline expects.
        if input_image.width == 960 and input_image.height == 640:
            input_array = np.array(input_image)
            new_layout = np.zeros((960, 640, 3), dtype=np.uint8)
            for i in range(6):
                src_row, src_col = i // 3, i % 3  # 2x3 source grid
                dst_row, dst_col = i // 2, i % 2  # 3x2 destination grid
                new_layout[dst_row*320:(dst_row+1)*320, dst_col*320:(dst_col+1)*320] = \
                    input_array[src_row*320:(src_row+1)*320, src_col*320:(src_col+1)*320]
            input_image = Image.fromarray(new_layout)

        # Run the pipeline (expects the 960-tall x 640-wide layout)
        refined_output = self.pipeline.refine(
            input_image,
            prompt=prompt,
            num_inference_steps=int(steps),
            guidance_scale=guidance_scale
        ).images[0]

        # Generate the mesh from the refined layout
        vertices, faces, vertex_colors = create_mesh(
            refined_output, self.model, self.infer_config
        )

        # Save a temporary mesh file
        os.makedirs("temp", exist_ok=True)
        temp_obj = os.path.join("temp", "refined_mesh.obj")
        save_obj(vertices, faces, vertex_colors, temp_obj)

        # The refined output is already in the portrait layout used for
        # display, so it can be returned as-is.
        return refined_output, temp_obj


def create_demo():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    shap_e = ShapERenderer(device)
    refiner = RefinerInterface()

    with gr.Blocks() as demo:
        gr.Markdown("# Shap-E to InstantMesh Pipeline")

        # First row: controls
        with gr.Row():
            with gr.Column():
                # Shap-E inputs
                shape_prompt = gr.Textbox(
                    label="Shap-E Prompt",
                    placeholder="Enter text to generate initial 3D model..."
                )
                shape_guidance = gr.Slider(
                    minimum=1, maximum=30, value=15.0,
                    label="Shap-E Guidance Scale"
                )
                shape_steps = gr.Slider(
                    minimum=16, maximum=128, value=64, step=16,
                    label="Shap-E Steps"
                )
                generate_btn = gr.Button("Generate Views")

            with gr.Column():
                # Refinement inputs
                refine_prompt = gr.Textbox(
                    label="Refinement Prompt",
                    placeholder="Enter prompt to guide refinement..."
                )
                refine_steps = gr.Slider(
                    minimum=30, maximum=100, value=75, step=1,
                    label="Refinement Steps"
                )
                refine_guidance = gr.Slider(
                    minimum=1, maximum=20, value=7.5,
                    label="Refinement Guidance Scale"
                )
                refine_btn = gr.Button("Refine")

        # Second row: image panels side by side (portrait 640x960 layouts)
        with gr.Row():
            shape_output = gr.Image(
                label="Generated Views",
                width=640,
                height=960
            )
            refined_output = gr.Image(
                label="Refined Output",
                width=640,
                height=960
            )

        # Third row: 3D mesh viewer
        with gr.Row():
            mesh_output = gr.Model3D(
                label="3D Mesh",
                clear_color=[1.0, 1.0, 1.0, 1.0],
                width=1280,
                height=600
            )

        # Event handlers
        def generate(prompt, guidance_scale, num_steps):
            with torch.no_grad():
                layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
            return layout

        def refine(input_image, prompt, steps, guidance_scale):
            refined_img, mesh_path = refiner.refine_model(
                input_image, prompt, steps, guidance_scale
            )
            return refined_img, mesh_path

        generate_btn.click(
            fn=generate,
            inputs=[shape_prompt, shape_guidance, shape_steps],
            outputs=[shape_output]
        )
        refine_btn.click(
            fn=refine,
            inputs=[shape_output, refine_prompt, refine_steps, refine_guidance],
            outputs=[refined_output, mesh_output]
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True)
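
# Headless usage sketch (hypothetical; assumes a CUDA machine and the
# checkpoints referenced above, and drives both stages without the UI):
#
#   renderer = ShapERenderer(torch.device('cuda'))
#   layout, _ = renderer.generate_views("a ceramic teapot", guidance_scale=15.0, num_steps=64)
#   refiner = RefinerInterface()
#   refined, mesh_path = refiner.refine_model(np.array(layout), "a ceramic teapot")
#   refined.save("temp/refined_views.png")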