import argparse
import numpy as np
import torch
from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from transformers import AutoModelForImageSegmentation

from step1x3d_texture.models.attention_processor import (
    DecoupledMVRowColSelfAttnProcessor2_0,
)
from step1x3d_texture.pipelines.ig2mv_sdxl_pipeline import IG2MVSDXLPipeline
from step1x3d_texture.schedulers.scheduling_shift_snr import ShiftSNRScheduler
from step1x3d_texture.utils import (
    get_orthogonal_camera,
    make_image_grid,
    tensor_to_image,
)
from step1x3d_texture.utils.render import NVDiffRastContextWrapper, load_mesh, render
from step1x3d_texture.differentiable_renderer.mesh_render import MeshRender

import trimesh
import xatlas
import scipy.sparse
from scipy.sparse.linalg import spsolve

from step1x3d_geometry.models.pipelines.pipeline_utils import smart_load_model


class Step1X3DTextureConfig:
    """Default hyper-parameters for multi-view generation and texture baking."""

    def __init__(self):
        # prepare pipeline params
        self.base_model = "stabilityai/stable-diffusion-xl-base-1.0"
        self.vae_model = "madebyollin/sdxl-vae-fp16-fix"
        self.unet_model = None
        self.lora_model = None
        self.adapter_path = "stepfun-ai/Step1X-3D"
        self.scheduler = None
        self.num_views = 6
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16
        self.lora_scale = None

        # run pipeline params
        self.text = "high quality"
        self.num_inference_steps = 50
        self.guidance_scale = 3.0
        self.seed = -1
        self.reference_conditioning_scale = 1.0
        self.negative_prompt = "watermark, ugly, deformed, noisy, blurry, low contrast"
        self.azimuth_deg = [0, 45, 90, 180, 270, 315]

        # texture baker params
        self.selected_camera_azims = [0, 90, 180, 270, 180, 180]
        self.selected_camera_elevs = [0, 0, 0, 0, 90, -90]
        self.selected_view_weights = [1, 0.1, 0.5, 0.1, 0.05, 0.05]
        self.camera_distance = 1.8
        self.render_size = 2048
        self.texture_size = 2048
        self.bake_exp = 4
        self.merge_method = "fast"
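
# Example (illustrative, not from the original module): the defaults above can be
# overridden on a config instance before constructing the pipeline, e.g. to trade
# texture quality for speed. All attribute names are the ones defined above.
#
#     config = Step1X3DTextureConfig()
#     config.num_inference_steps = 30   # fewer denoising steps
#     config.texture_size = 1024        # smaller baked UV texture
#     pipeline = Step1X3DTexturePipeline(config)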


class Step1X3DTexturePipeline:
    """Texture pipeline: multi-view image generation (IG2MV) followed by UV texture baking."""

    def __init__(self, config):
        self.config = config
        # Differentiable renderer used for back-projection and texture baking
        self.mesh_render = MeshRender(
            default_resolution=self.config.render_size,
            texture_size=self.config.texture_size,
            camera_distance=self.config.camera_distance,
        )
        # Multi-view generation pipeline (SDXL with decoupled multi-view attention adapter)
        self.ig2mv_pipe = self.prepare_ig2mv_pipeline(
            base_model=self.config.base_model,
            vae_model=self.config.vae_model,
            unet_model=self.config.unet_model,
            lora_model=self.config.lora_model,
            adapter_path=self.config.adapter_path,
            scheduler=self.config.scheduler,
            num_views=self.config.num_views,
            device=self.config.device,
            dtype=self.config.dtype,
        )

    @classmethod
    def from_pretrained(cls, model_path, subfolder):
        config = Step1X3DTextureConfig()
        local_model_path = smart_load_model(model_path, subfolder=subfolder)
        print(f'Local model path: {local_model_path}')
        config.adapter_path = local_model_path
        return cls(config)

    def mesh_uv_wrap(self, mesh):
        if isinstance(mesh, trimesh.Scene):
            mesh = mesh.to_geometry()
        # Generate a UV parameterization with xatlas and remap vertices/faces to it
        vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
        mesh.vertices = mesh.vertices[vmapping]
        mesh.faces = indices
        mesh.visual.uv = uvs
        return mesh

    def prepare_ig2mv_pipeline(
        self,
        base_model,
        vae_model,
        unet_model,
        lora_model,
        adapter_path,
        scheduler,
        num_views,
        device,
        dtype,
    ):
        # Load vae and unet if provided
        pipe_kwargs = {}
        if vae_model is not None:
            pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model)
        if unet_model is not None:
            pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)
        print('VAE Loaded!')

        # Prepare pipeline
        pipe = IG2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs)
        print('Base model Loaded!')

        # Load scheduler if provided
        scheduler_class = None
        if scheduler == "ddpm":
            scheduler_class = DDPMScheduler
        elif scheduler == "lcm":
            scheduler_class = LCMScheduler
        pipe.scheduler = ShiftSNRScheduler.from_scheduler(
            pipe.scheduler,
            shift_mode="interpolated",
            shift_scale=8.0,
            scheduler_class=scheduler_class,
        )
        print('Scheduler Loaded!')

        pipe.init_custom_adapter(
            num_views=num_views,
            self_attn_processor=DecoupledMVRowColSelfAttnProcessor2_0,
        )
        print(f'Loading adapter from {adapter_path}/step1x-3d-ig2v.safetensors')
        pipe.load_custom_adapter(adapter_path, "step1x-3d-ig2v.safetensors")
        print('Adapter loaded successfully!')
        pipe.to(device=device, dtype=dtype)
        pipe.cond_encoder.to(device=device, dtype=dtype)

        # load lora if provided
        if lora_model is not None:
            model_, name_ = lora_model.rsplit("/", 1)
            pipe.load_lora_weights(model_, weight_name=name_)

        return pipe

    def remove_bg(self, image, net, transform, device):
        image_size = image.size
        input_images = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            preds = net(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(image_size)
        image.putalpha(mask)
        return image

    def preprocess_image(self, image, height, width):
        image = np.array(image)
        alpha = image[..., 3] > 0
        H, W = alpha.shape
        # get the bounding box of alpha
        y, x = np.where(alpha)
        y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
        x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
        image_center = image[y0:y1, x0:x1]
        # resize the longer side to H * 0.9
        H, W, _ = image_center.shape
        if H > W:
            W = int(W * (height * 0.9) / H)
            H = int(height * 0.9)
        else:
            H = int(H * (width * 0.9) / W)
            W = int(width * 0.9)
        image_center = np.array(Image.fromarray(image_center).resize((W, H)))
        # pad to H, W
        start_h = (height - H) // 2
        start_w = (width - W) // 2
        image = np.zeros((height, width, 4), dtype=np.uint8)
        image[start_h : start_h + H, start_w : start_w + W] = image_center
        image = image.astype(np.float32) / 255.0
        image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
        image = (image * 255).clip(0, 255).astype(np.uint8)
        image = Image.fromarray(image)
        return image

    def run_ig2mv_pipeline(
        self,
        pipe,
        mesh,
        num_views,
        text,
        image,
        height,
        width,
        num_inference_steps,
        guidance_scale,
        seed,
        remove_bg_fn=None,
        reference_conditioning_scale=1.0,
        negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
        lora_scale=1.0,
        device="cuda",
    ):
        # Prepare cameras
        cameras = get_orthogonal_camera(
            elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
            distance=[1.8] * num_views,
            left=-0.55,
            right=0.55,
            bottom=-0.55,
            top=0.55,
            azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
            device=device,
        )
        ctx = NVDiffRastContextWrapper(device=device, context_type="cuda")

        mesh, mesh_bp = load_mesh(mesh, rescale=True, device=device)
        render_out = render(
            ctx,
            mesh,
            cameras,
            height=height,
            width=width,
            render_attr=False,
            normal_background=0.0,
        )
        pos_images = tensor_to_image((render_out.pos + 0.5).clamp(0, 1), batched=True)
        normal_images = tensor_to_image(
            (render_out.normal / 2 + 0.5).clamp(0, 1), batched=True
        )
        control_images = (
            torch.cat(
                [
                    (render_out.pos + 0.5).clamp(0, 1),
                    (render_out.normal / 2 + 0.5).clamp(0, 1),
                ],
                dim=-1,
            )
            .permute(0, 3, 1, 2)
            .to(device)
        )

        # Prepare image
        reference_image = Image.open(image) if isinstance(image, str) else image
        if len(reference_image.split()) == 1:
            reference_image = reference_image.convert("RGBA")
        if remove_bg_fn is not None and reference_image.mode == "RGB":
            reference_image = remove_bg_fn(reference_image)
            reference_image = self.preprocess_image(reference_image, height, width)
        elif reference_image.mode == "RGBA":
            reference_image = self.preprocess_image(reference_image, height, width)

        pipe_kwargs = {}
        if seed != -1 and isinstance(seed, int):
            pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)

        images = pipe(
            text,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_views,
            control_image=control_images,
            control_conditioning_scale=1.0,
            reference_image=reference_image,
            reference_conditioning_scale=reference_conditioning_scale,
            negative_prompt=negative_prompt,
            cross_attention_kwargs={"scale": lora_scale},
            mesh=mesh_bp,
            **pipe_kwargs,
        ).images

        return images, pos_images, normal_images, reference_image, mesh, mesh_bp

    def bake_from_multiview(
        self,
        render,
        views,
        camera_elevs,
        camera_azims,
        view_weights,
        method="graphcut",
        bake_exp=4,
    ):
        project_textures, project_weighted_cos_maps = [], []
        project_boundary_maps = []
        for view, camera_elev, camera_azim, weight in zip(
            views, camera_elevs, camera_azims, view_weights
        ):
            project_texture, project_cos_map, project_boundary_map = (
                render.back_project(view, camera_elev, camera_azim)
            )
            project_cos_map = weight * (project_cos_map**bake_exp)
            project_textures.append(project_texture)
            project_weighted_cos_maps.append(project_cos_map)
            project_boundary_maps.append(project_boundary_map)

        if method == "fast":
            texture, ori_trust_map = render.fast_bake_texture(
                project_textures, project_weighted_cos_maps
            )
        else:
            raise ValueError(f"no method {method}")
        return texture, ori_trust_map > 1e-8

    def texture_inpaint(self, render, texture, mask):
        texture_np = render.uv_inpaint(texture, mask)
        texture = torch.tensor(texture_np / 255).float().to(texture.device)
        return texture

    def __call__(self, image, mesh, remove_bg=True, seed=2025):
        """Generate a textured mesh from a reference image and an untextured input mesh."""
        if remove_bg:
            birefnet = AutoModelForImageSegmentation.from_pretrained(
                "ZhengPeng7/BiRefNet", trust_remote_code=True
            )
            birefnet.to(self.config.device)
            transform_image = transforms.Compose(
                [
                    transforms.Resize((1024, 1024)),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ]
            )
            remove_bg_fn = lambda x: self.remove_bg(
                x, birefnet, transform_image, self.config.device
            )
        else:
            remove_bg_fn = None

        if isinstance(mesh, trimesh.Scene):
            mesh = mesh.to_geometry()

        # multi-view generation pipeline
        images, pos_images, normal_images, reference_image, textured_mesh, mesh_bp = (
            self.run_ig2mv_pipeline(
                self.ig2mv_pipe,
                mesh=mesh,
                num_views=self.config.num_views,
                text=self.config.text,
                image=image,
                height=768,
                width=768,
                num_inference_steps=self.config.num_inference_steps,
                guidance_scale=self.config.guidance_scale,
                seed=seed if seed is not None else self.config.seed,
                lora_scale=self.config.lora_scale,
                reference_conditioning_scale=self.config.reference_conditioning_scale,
                negative_prompt=self.config.negative_prompt,
                device=self.config.device,
                remove_bg_fn=remove_bg_fn,
            )
        )
        for i in range(len(images)):
            images[i] = images[i].resize(
                (self.config.render_size, self.config.render_size),
                Image.Resampling.LANCZOS,
            )

        mesh = self.mesh_uv_wrap(mesh_bp)
        self.mesh_render.load_mesh(mesh, auto_center=False, scale_factor=1.0)

        # texture baker
        texture, mask = self.bake_from_multiview(
            self.mesh_render,
            images,
            self.config.selected_camera_elevs,
            self.config.selected_camera_azims,
            self.config.selected_view_weights,
            method="fast",
        )
        mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)

        # texture inpaint
        texture = self.texture_inpaint(self.mesh_render, texture, mask_np)
        self.mesh_render.set_texture(texture)
        textured_mesh = self.mesh_render.save_mesh()
        return textured_mesh
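

# --------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# subfolder name and the input/output file paths below are assumptions, and
# the object returned by save_mesh() is assumed to expose an .export() method
# (e.g. a trimesh.Trimesh); adapt these to the actual Step1X-3D release layout.
if __name__ == "__main__":
    # Build the pipeline from the published weights (subfolder name assumed).
    pipeline = Step1X3DTexturePipeline.from_pretrained(
        "stepfun-ai/Step1X-3D", subfolder="Step1X-3D-Texture"
    )

    # Load an untextured mesh and the reference image that drives texturing.
    input_mesh = trimesh.load("untextured_mesh.glb")
    reference_image = Image.open("reference.png")

    # Generate multi-view images, bake them into a UV texture, and export.
    textured_mesh = pipeline(reference_image, input_mesh, remove_bg=True, seed=2025)
    textured_mesh.export("textured_mesh.glb")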