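"""
Gradio app: text prompt -> Shap-E multi-view renders -> zero123plus-based refinement -> InstantMesh reconstruction.

Rough flow, as implemented below:
  1. Shap-E ('text300M' + 'transmitter') samples a latent from the text prompt and renders
     six fixed views, tiled into a 960x640 layout (3 rows x 2 columns of 320x320 views).
  2. A zero123plus diffusion pipeline (with a custom UNet checkpoint) refines that layout.
  3. The InstantMesh reconstruction model lifts the refined views to a mesh, saved as
     temp/refined_mesh.obj and displayed in a Model3D viewer.
"""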
import os
import torch
import torch.nn as nn
import gradio as gr
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from einops import rearrange
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, create_custom_cameras
from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
    FOV_to_intrinsics,
    get_zero123plus_input_cameras,
    get_circular_camera_poses,
    spherical_camera_pose
)
from src.utils.mesh_util import save_obj, save_glb
from src.utils.infer_util import remove_background, resize_foreground

def load_models():
    """Initialize and load all required models"""
    config = OmegaConf.load('configs/instant-nerf-large-best.yaml')
    model_config = config.model_config
    infer_config = config.infer_config

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load diffusion pipeline
    print('Loading diffusion pipeline...')
    pipeline = DiffusionPipeline.from_pretrained(
        "sudo-ai/zero123plus-v1.2",
        custom_pipeline="zero123plus",
        torch_dtype=torch.float16
    )
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config, timestep_spacing='trailing'
    )

    # Modify UNet to handle 8 input channels instead of 4
    in_channels = 8
    out_channels = pipeline.unet.conv_in.out_channels
    pipeline.unet.register_to_config(in_channels=in_channels)
    with torch.no_grad():
        new_conv_in = nn.Conv2d(
            in_channels, out_channels, pipeline.unet.conv_in.kernel_size,
            pipeline.unet.conv_in.stride, pipeline.unet.conv_in.padding
        )
        new_conv_in.weight.zero_()
        new_conv_in.weight[:, :4, :, :].copy_(pipeline.unet.conv_in.weight)
        pipeline.unet.conv_in = new_conv_in

    # Load custom UNet
    print('Loading custom UNet...')
    unet_path = "best_21.ckpt"
    state_dict = torch.load(unet_path, map_location='cpu')

    # Process the state dict to match the model keys
    if 'state_dict' in state_dict:
        new_state_dict = {key.replace('unet.unet.', ''): value
                          for key, value in state_dict['state_dict'].items()}
        pipeline.unet.load_state_dict(new_state_dict, strict=False)
    else:
        pipeline.unet.load_state_dict(state_dict, strict=False)

    pipeline = pipeline.to(device).to(torch_dtype=torch.float16)

    # Load reconstruction model
    print('Loading reconstruction model...')
    model = instantiate_from_config(model_config)
    model_path = hf_hub_download(
        repo_id="TencentARC/InstantMesh",
        filename="instant_nerf_large.ckpt",
        repo_type="model"
    )
    state_dict = torch.load(model_path, map_location='cpu')['state_dict']
    state_dict = {k[14:]: v for k, v in state_dict.items()
                  if k.startswith('lrm_generator.') and 'source_camera' not in k}
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device)
    model.eval()

    return pipeline, model, infer_config
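
# NOTE: load_models() expects 'configs/instant-nerf-large-best.yaml' and the custom UNet
# checkpoint 'best_21.ckpt' to already be present in the working directory; only the
# InstantMesh reconstruction weights are fetched from the Hugging Face Hub at runtime.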

def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=None):
    """Process input images and run refinement"""
    device = pipeline.device

    if isinstance(input_images, list):
        if len(input_images) == 1:
            # Check if this is a pre-arranged layout
            img = Image.open(input_images[0].name).convert('RGB')
            if img.size == (640, 960):
                # This is already a layout, use it directly
                input_image = img
            else:
                # Single view - tile it into all 6 slots
                img = img.resize((320, 320))
                img_array = np.array(img) / 255.0
                images = [img_array] * 6
                images = np.stack(images)

                # Convert to tensor and build the 3x2 (960x640) layout
                images = torch.from_numpy(images).float()
                images = images.permute(0, 3, 1, 2)          # (6, 3, 320, 320)
                images = images.reshape(3, 2, 3, 320, 320)   # (row, col, c, h, w)
                images = images.permute(0, 2, 3, 1, 4)       # (row, c, h, col, w)
                images = images.reshape(3, 3, 320, 640)      # (row, c, h, W)
                images = images.permute(1, 0, 2, 3)          # (c, row, h, W): channels first
                images = images.reshape(1, 3, 960, 640)      # (1, c, H, W), rows stacked into H

                # Convert back to PIL
                images = images.permute(0, 2, 3, 1)[0]
                images = (images.numpy() * 255).astype(np.uint8)
                input_image = Image.fromarray(images)
        else:
            # Multiple individual views
            images = []
            for img_file in input_images:
                img = Image.open(img_file.name).convert('RGB')
                img = img.resize((320, 320))
                img = np.array(img) / 255.0
                images.append(img)

            # Pad to 6 images if needed
            while len(images) < 6:
                images.append(np.zeros_like(images[0]))
            images = np.stack(images[:6])

            # Convert to tensor and build the 3x2 (960x640) layout
            images = torch.from_numpy(images).float()
            images = images.permute(0, 3, 1, 2)          # (6, 3, 320, 320)
            images = images.reshape(3, 2, 3, 320, 320)   # (row, col, c, h, w)
            images = images.permute(0, 2, 3, 1, 4)       # (row, c, h, col, w)
            images = images.reshape(3, 3, 320, 640)      # (row, c, h, W)
            images = images.permute(1, 0, 2, 3)          # (c, row, h, W): channels first
            images = images.reshape(1, 3, 960, 640)      # (1, c, H, W), rows stacked into H

            # Convert back to PIL
            images = images.permute(0, 2, 3, 1)[0]
            images = (images.numpy() * 255).astype(np.uint8)
            input_image = Image.fromarray(images)
    else:
        raise ValueError("Expected a list of images")

    # Generate refined output
    output = pipeline.refine(
        input_image,
        prompt=prompt,
        num_inference_steps=int(steps),
        guidance_scale=guidance_scale
    ).images[0]

    return output, input_image
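
# Layout convention used throughout: six 320x320 views tiled into a single 960x640 image
# (3 rows x 2 columns, views ordered row-major). create_mesh() below undoes this tiling to
# recover the (6, 3, 320, 320) view stack expected by the reconstruction model.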

def create_mesh(refined_image, model, infer_config):
    """Generate mesh from refined image"""
    # Convert PIL image to tensor
    image = np.array(refined_image) / 255.0
    image = torch.from_numpy(image).float().permute(2, 0, 1)   # (3, 960, 640)

    # Split the 3x2 layout back into 6 individual views
    image = image.reshape(3, 960, 640)
    image = image.reshape(3, 3, 320, 640)      # (c, row, h, W)
    image = image.permute(1, 0, 2, 3)          # (row, c, h, W)
    image = image.reshape(3, 3, 320, 2, 320)   # (row, c, h, col, w)
    image = image.permute(0, 3, 1, 2, 4)       # (row, col, c, h, w)
    image = image.reshape(6, 3, 320, 320)      # (view, c, h, w)

    # Add batch dimension
    image = image.unsqueeze(0)

    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to("cuda")
    image = image.to("cuda")

    with torch.no_grad():
        planes = model.forward_planes(image, input_cameras)
        mesh_out = model.extract_mesh(planes, **infer_config)
        vertices, faces, vertex_colors = mesh_out

    return vertices, faces, vertex_colors

class ShapERenderer:
    def __init__(self, device):
        print("Loading Shap-E models...")
        self.device = device
        self.xm = load_model('transmitter', device=device)
        self.model = load_model('text300M', device=device)
        self.diffusion = diffusion_from_config(load_config('diffusion'))
        print("Shap-E models loaded!")

    def generate_views(self, prompt, guidance_scale=15.0, num_steps=64):
        # Generate latents using the text-to-3D model
        batch_size = 1
        guidance_scale = float(guidance_scale)
        latents = sample_latents(
            batch_size=batch_size,
            model=self.model,
            diffusion=self.diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=dict(texts=[prompt] * batch_size),
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=num_steps,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )

        # Render the 6 views we need with specific viewing angles
        size = 320  # Size of each rendered image
        images = []

        # Define our 6 specific camera positions to match refine.py
        azimuths = [30, 90, 150, 210, 270, 330]
        elevations = [20, -10, 20, -10, 20, -10]

        for i, (azimuth, elevation) in enumerate(zip(azimuths, elevations)):
            cameras = create_custom_cameras(size, self.device, azimuths=[azimuth],
                                            elevations=[elevation], fov_degrees=30, distance=3.0)
            rendered_image = decode_latent_images(
                self.xm,
                latents[0],
                rendering_mode='stf',
                cameras=cameras
            )
            images.append(rendered_image.detach().cpu().numpy())

        # Convert images to uint8
        images = [image.astype(np.uint8) for image in images]

        # Arrange the 6 views into a 3x2 grid (960 px tall x 640 px wide)
        layout = np.zeros((960, 640, 3), dtype=np.uint8)
        for i, img in enumerate(images):
            row = i // 2  # 2 views per row, 3 rows
            col = i % 2
            layout[row*320:(row+1)*320, col*320:(col+1)*320] = img

        return Image.fromarray(layout), images

class RefinerInterface:
    def __init__(self):
        print("Initializing InstantMesh models...")
        self.pipeline, self.model, self.infer_config = load_models()
        print("InstantMesh models loaded!")

    def refine_model(self, input_image, prompt, steps=75, guidance_scale=7.5):
        """Main refinement function"""
        # Process image and get refined output
        input_image = Image.fromarray(input_image)

        # If the input arrives as a 2x3 grid (640 tall x 960 wide), rearrange it into
        # the 3x2 grid (960 tall x 640 wide) the pipeline expects
        if input_image.width == 960 and input_image.height == 640:
            input_array = np.array(input_image)
            new_layout = np.zeros((960, 640, 3), dtype=np.uint8)
            for i in range(6):
                src_row = i // 3
                src_col = i % 3
                dst_row = i // 2
                dst_col = i % 2
                new_layout[dst_row*320:(dst_row+1)*320, dst_col*320:(dst_col+1)*320] = \
                    input_array[src_row*320:(src_row+1)*320, src_col*320:(src_col+1)*320]
            input_image = Image.fromarray(new_layout)

        # Process with the pipeline (expects the 960-tall x 640-wide layout)
        refined_output = self.pipeline.refine(
            input_image,
            prompt=prompt,
            num_inference_steps=int(steps),
            guidance_scale=guidance_scale
        ).images[0]

        # Generate mesh from the refined multi-view layout
        vertices, faces, vertex_colors = create_mesh(
            refined_output,
            self.model,
            self.infer_config
        )

        # Save temporary mesh file
        os.makedirs("temp", exist_ok=True)
        temp_obj = os.path.join("temp", "refined_mesh.obj")
        save_obj(vertices, faces, vertex_colors, temp_obj)

        # The refined output is already in the layout used for display
        # (3 rows x 2 columns, 960 px tall x 640 px wide), so return it unchanged
        return refined_output, temp_obj

def create_demo():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    shap_e = ShapERenderer(device)
    refiner = RefinerInterface()

    with gr.Blocks() as demo:
        gr.Markdown("# Shap-E to InstantMesh Pipeline")

        # First row: Controls
        with gr.Row():
            with gr.Column():
                # Shap-E inputs
                shape_prompt = gr.Textbox(
                    label="Shap-E Prompt",
                    placeholder="Enter text to generate initial 3D model..."
                )
                shape_guidance = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=15.0,
                    label="Shap-E Guidance Scale"
                )
                shape_steps = gr.Slider(
                    minimum=16,
                    maximum=128,
                    value=64,
                    step=16,
                    label="Shap-E Steps"
                )
                generate_btn = gr.Button("Generate Views")

            with gr.Column():
                # Refinement inputs
                refine_prompt = gr.Textbox(
                    label="Refinement Prompt",
                    placeholder="Enter prompt to guide refinement..."
                )
                refine_steps = gr.Slider(
                    minimum=30,
                    maximum=100,
                    value=75,
                    step=1,
                    label="Refinement Steps"
                )
                refine_guidance = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=7.5,
                    label="Refinement Guidance Scale"
                )
                refine_btn = gr.Button("Refine")

        # Second row: Image panels side by side
        with gr.Row():
            shape_output = gr.Image(
                label="Generated Views",
                width=640,   # matches the 640-wide x 960-tall view layout
                height=960
            )
            refined_output = gr.Image(
                label="Refined Output",
                width=640,
                height=960
            )

        # Third row: 3D mesh panel below
        with gr.Row():
            mesh_output = gr.Model3D(
                label="3D Mesh",
                clear_color=[1.0, 1.0, 1.0, 1.0],
                width=1280,  # Full width
                height=600   # Taller for better visualization
            )

        # Set up event handlers
        def generate(prompt, guidance_scale, num_steps):
            with torch.no_grad():
                layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
            return layout

        def refine(input_image, prompt, steps, guidance_scale):
            refined_img, mesh_path = refiner.refine_model(
                input_image,
                prompt,
                steps,
                guidance_scale
            )
            return refined_img, mesh_path

        generate_btn.click(
            fn=generate,
            inputs=[shape_prompt, shape_guidance, shape_steps],
            outputs=[shape_output]
        )

        refine_btn.click(
            fn=refine,
            inputs=[shape_output, refine_prompt, refine_steps, refine_guidance],
            outputs=[refined_output, mesh_output]
        )

    return demo
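
# Running this file directly builds the Gradio UI and launches it with share=True, which
# creates a temporary public link in addition to the local server; use share=False for
# local-only access.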
if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True)