Spaces:
Sleeping
Sleeping
import argparse | |
import numpy as np | |
import torch | |
from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel | |
from PIL import Image | |
from torchvision import transforms | |
from tqdm import tqdm | |
from transformers import AutoModelForImageSegmentation | |
import logging | |
# Root-logger setup for this script: INFO and above, each line prefixed with its timestamp.
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)
from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline | |
from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler | |
from mvadapter.utils import ( | |
get_orthogonal_camera, | |
get_plucker_embeds_from_cameras_ortho, | |
make_image_grid, | |
) | |
def prepare_pipeline(
    base_model,
    vae_model,
    unet_model,
    lora_model,
    adapter_path,
    scheduler,
    num_views,
    device,
    dtype,
):
    """Build an MV-Adapter image-to-multiview SDXL pipeline ready for inference.

    Args:
        base_model: Model id/path given to ``MVAdapterI2MVSDXLPipeline.from_pretrained``.
        vae_model: Optional VAE id/path; when None the base model's VAE is used.
        unet_model: Optional UNet id/path; when None the base model's UNet is used.
        lora_model: Optional ``"<location>/<weight file>"`` string. It is split on
            the last ``/`` into the location and the ``weight_name`` passed to
            ``load_lora_weights``. A value containing no ``/`` raises ValueError
            from the tuple unpacking.
        adapter_path: Location holding ``mvadapter_i2mv_sdxl.safetensors``.
        scheduler: ``"ddpm"`` or ``"lcm"`` picks the ``scheduler_class`` forwarded
            to ``ShiftSNRScheduler.from_scheduler``. Any other value forwards None.
        num_views: Number of views the custom adapter is initialised for.
        device: Device the pipeline and its condition encoder are moved to.
        dtype: Dtype the pipeline and its condition encoder are cast to.

    Returns:
        The configured pipeline, with VAE slicing enabled.
    """
    # Optional component overrides; loading order (VAE, then UNet) is kept.
    overrides = {}
    if vae_model is not None:
        overrides["vae"] = AutoencoderKL.from_pretrained(vae_model)
    if unet_model is not None:
        overrides["unet"] = UNet2DConditionModel.from_pretrained(unet_model)

    pipe: MVAdapterI2MVSDXLPipeline = MVAdapterI2MVSDXLPipeline.from_pretrained(
        base_model, **overrides
    )

    # Wrap the pipeline's own scheduler in ShiftSNRScheduler. The optional
    # scheduler_class (None unless "ddpm"/"lcm") is forwarded unchanged.
    scheduler_classes = {"ddpm": DDPMScheduler, "lcm": LCMScheduler}
    pipe.scheduler = ShiftSNRScheduler.from_scheduler(
        pipe.scheduler,
        shift_mode="interpolated",
        shift_scale=8.0,
        scheduler_class=scheduler_classes.get(scheduler),
    )

    # Install the multi-view adapter, then move everything to device/dtype.
    pipe.init_custom_adapter(num_views=num_views)
    pipe.load_custom_adapter(adapter_path, weight_name="mvadapter_i2mv_sdxl.safetensors")
    pipe.to(device=device, dtype=dtype)
    pipe.cond_encoder.to(device=device, dtype=dtype)

    if lora_model is not None:
        lora_location, lora_weight_file = lora_model.rsplit("/", 1)
        pipe.load_lora_weights(lora_location, weight_name=lora_weight_file)

    # Decode with the VAE in slices to reduce peak memory.
    pipe.enable_vae_slicing()
    return pipe
def remove_bg(image: Image.Image, net, transform, device, mask: Image.Image = None):
    """Use a ready-made mask image as the alpha channel of ``image``.

    NOTE(review): a second ``remove_bg`` (the ndarray-mask variant) is defined
    further down in this module and rebinds the name at import time. This
    version is therefore never the one callers get.

    Args:
        image: Input image. It is modified in place by ``putalpha`` and also returned.
        net: Ignored; present only for signature compatibility.
        transform: Ignored; present only for signature compatibility.
        device: Ignored; present only for signature compatibility.
        mask: Optional PIL image whose pixel values become the alpha channel.
            It is presumably 8-bit, with 0 = transparent and 255 = opaque. A mask
            holding only 0/1 values would leave the image almost fully
            transparent. The mask is resized to ``image.size`` when the sizes
            differ. If None, ``image`` is returned untouched.

    Returns:
        PIL.Image.Image: The same ``image`` object.
    """
    if mask is not None:
        wanted = image.size
        alpha = mask if mask.size == wanted else mask.resize(wanted)
        image.putalpha(alpha)
    return image
# def remove_bg(image, net, transform, device): | |
# image_size = image.size | |
# input_images = transform(image).unsqueeze(0).to(device) | |
# with torch.no_grad(): | |
# preds = net(input_images)[0].sigmoid().cpu() | |
# #preds = net(input_images)[-1] if isinstance(net(input_images), list) else net(input_images) | |
# pred = preds[0].squeeze() | |
# pred_pil = transforms.ToPILImage()(pred) | |
# mask = pred_pil.resize(image_size) | |
# image.putalpha(mask) | |
# return image | |
# def remove_bg(image: Image.Image, net, transform, device): | |
# """ | |
# Applies a pre-existing mask to an image to make the background transparent. | |
# Args: | |
# image (PIL.Image.Image): The input image. | |
# net: Pre-trained neural network (not used but kept for compatibility). | |
# transform: Image transformation object (not used but kept for compatibility). | |
# device: Device used for inference (not used but kept for compatibility). | |
# Returns: | |
# PIL.Image.Image: The modified image with transparent background. | |
# """ | |
# image_size = image.size | |
# input_images = transform(image).unsqueeze(0).to(device) | |
# with torch.no_grad(): | |
# preds = net(input_images)[-1].sigmoid().cpu() | |
# pred = preds[0].squeeze() | |
# pred_pil = transforms.ToPILImage()(pred) | |
# # Resize the mask to match the original image size | |
# mask = pred_pil.resize(image_size, Image.LANCZOS) | |
# # Create a new image with the same size and mode as the original | |
# output_image = Image.new("RGBA", image_size) | |
# # Apply the mask to the original image | |
# image.putalpha(mask) | |
# # Composite the original image with the mask | |
# output_image.paste(image, (0, 0), image) | |
# return output_image | |
def remove_bg(image: Image.Image, net, transform, device, mask: np.ndarray = None):
    """Make the background of ``image`` transparent using a precomputed mask.

    No segmentation is run here. ``net``, ``transform`` and ``device`` are
    accepted only for call-site compatibility and are ignored; the caller must
    produce ``mask``.

    Args:
        image: Input image. It is NOT modified; an RGBA copy is used instead.
        net: Unused; kept for compatibility.
        transform: Unused; kept for compatibility.
        device: Unused; kept for compatibility.
        mask: Foreground mask of shape (H, W) or (H, W, 1). Accepted encodings:
            - boolean;
            - integer/float values within [0, 1], scaled to 0-255, so soft
              masks keep their gradation;
            - values on the 0-255 scale, used as-is.
            The mask is resized to ``image.size`` with LANCZOS, so it does not
            need to match the image resolution. If None, ``image`` is returned
            unchanged.

    Returns:
        PIL.Image.Image: A new RGBA image whose alpha channel is the mask.
        Fully transparent pixels are (0, 0, 0, 0). When ``mask`` is None the
        original ``image`` object is returned.

    Raises:
        ValueError: If ``mask`` is not of shape (H, W) or (H, W, 1).
    """
    if mask is None:
        return image

    mask = np.asarray(mask)
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = mask[..., 0]
    if mask.ndim != 2:
        raise ValueError(
            f"remove_bg: expected a (H, W) or (H, W, 1) mask, got shape {mask.shape}"
        )

    if mask.dtype == bool:
        alpha = mask.astype(np.uint8) * 255
    else:
        # Scale in float. The old `uint8_mask * 255` wrapped around for any
        # value > 1, corrupting masks already on the 0-255 scale.
        values = mask.astype(np.float32)
        if values.size and values.max() <= 1.0:
            values = values * 255.0  # 0/1 or [0, 1] soft mask
        alpha = np.clip(np.rint(values), 0, 255).astype(np.uint8)

    # 2-D uint8 array -> 8-bit "L" image, resized to the image resolution
    # (e.g. when the mask comes from a fixed-size segmentation output).
    mask_pil = Image.fromarray(alpha).resize(image.size, Image.LANCZOS)

    # Work on an RGBA copy so the caller's image is not changed in place.
    foreground = image.convert("RGBA")
    foreground.putalpha(mask_pil)

    # Compositing over a transparent canvas keeps alpha exactly equal to the
    # mask and zeroes fully transparent pixels. The previous
    # `paste(image, (0, 0), image)` also blended the alpha band with itself,
    # which squared soft-edge alpha (a*a/255) and darkened edge colours.
    background = Image.new("RGBA", image.size)
    return Image.alpha_composite(background, foreground)
# def preprocess_image(image: Image.Image, height, width): | |
# alpha = image[..., 3] > 0 | |
# # alpha = image | |
# #if image.mode in ("RGBA", "LA"): | |
# # image = np.array(image) | |
# # alpha = image[..., 3] # Extract the alpha channel | |
# #elif image.mode in ("RGB"): | |
# # image = np.array(image) | |
# # Create default alpha for non-alpha images | |
# # alpha = np.ones(image[..., 0].shape, dtype=np.uint8) * 255 # Create | |
# H, W = alpha.shape | |
# # get the bounding box of alpha | |
# y, x = np.where(alpha) | |
# y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H) | |
# x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W) | |
# image_center = image[y0:y1, x0:x1] | |
# # resize the longer side to H * 0.9 | |
# H, W, _ = image_center.shape | |
# if H > W: | |
# W = int(W * (height * 0.9) / H) | |
# H = int(height * 0.9) | |
# else: | |
# H = int(H * (width * 0.9) / W) | |
# W = int(width * 0.9) | |
# image_center = np.array(Image.fromarray(image_center).resize((W, H))) | |
# # pad to H, W | |
# start_h = (height - H) // 2 | |
# start_w = (width - W) // 2 | |
# image = np.zeros((height, width, 4), dtype=np.uint8) | |
# image[start_h : start_h + H, start_w : start_w + W] = image_center | |
# image = image.astype(np.float32) / 255.0 | |
# image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5 | |
# image = (image * 255).clip(0, 255).astype(np.uint8) | |
# image = Image.fromarray(image) | |
# return image | |
def preprocess_image(image: Image.Image, height, width):
    """Crop ``image`` to its visible content, rescale it, and center it on a canvas.

    The content bounding box comes from the alpha channel: every pixel with
    non-zero alpha is foreground, and images without alpha count as fully
    opaque. The crop is resized so it fills 90% of the target along the
    binding dimension, keeping its aspect ratio. It is then placed in the
    middle of a fully transparent ``height`` x ``width`` RGBA canvas.

    Args:
        image: Input image in any PIL mode (RGB, RGBA, L, LA, P, ...). It is
            converted to RGBA first.
        height: Output height in pixels.
        width: Output width in pixels.

    Returns:
        PIL.Image.Image: An RGBA image of size (width, height).

    Raises:
        ValueError: If no pixel has non-zero alpha (nothing to crop).
    """
    # Normalising to RGBA gives every input mode the same 4-channel layout.
    # Previously RGB/LA crops (3/2 channels) could not be assigned into the
    # 4-channel canvas, and 2-D grayscale arrays broke the alpha extraction.
    # RGB input gets an opaque alpha channel, so the whole frame is content.
    image_np = np.array(image.convert("RGBA"))
    alpha = image_np[..., 3] > 0
    img_h, img_w = alpha.shape

    ys, xs = np.nonzero(alpha)
    if ys.size == 0:
        raise ValueError("preprocess_image: image is fully transparent; nothing to crop")
    # One-pixel margin on the top/left edges. The "+ 1" on the max side only
    # makes the (exclusive) slice end include the last content row/column.
    y0, y1 = max(ys.min() - 1, 0), min(ys.max() + 1, img_h)
    x0, x1 = max(xs.min() - 1, 0), min(xs.max() + 1, img_w)
    image_center = image_np[y0:y1, x0:x1]

    # Fit along whichever side is relatively larger w.r.t. the target. For a
    # square target this is exactly "the longer side", as before; for
    # non-square targets it also guarantees the crop fits the canvas.
    crop_h, crop_w = image_center.shape[:2]
    if crop_h * width > crop_w * height:
        new_w = max(1, int(crop_w * (height * 0.9) / crop_h))
        new_h = int(height * 0.9)
    else:
        new_h = max(1, int(crop_h * (width * 0.9) / crop_w))
        new_w = int(width * 0.9)
    image_center = np.array(Image.fromarray(image_center).resize((new_w, new_h)))

    # Center on a transparent (all-zero) RGBA canvas.
    start_h = (height - new_h) // 2
    start_w = (width - new_w) // 2
    padded_image = np.zeros((height, width, 4), dtype=np.uint8)
    padded_image[start_h:start_h + new_h, start_w:start_w + new_w] = image_center
    return Image.fromarray(padded_image)
def run_pipeline(
    pipe,
    num_views,
    text,
    image,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    remove_bg_fn=None,
    reference_conditioning_scale=1.0,
    negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
    lora_scale=1.0,
    device="cuda",
    azimuth_deg=None,
):
    """Generate ``num_views`` views of the object in ``image`` with ``pipe``.

    Args:
        pipe: Pipeline returned by ``prepare_pipeline``.
        num_views: Number of views to generate; one camera is built per view.
        text: Prompt passed to the pipeline.
        image: Reference image, given as a file path or a PIL image.
        height: Output height passed to the pipeline and ``preprocess_image``.
        width: Output width passed to the pipeline and ``preprocess_image``.
            It is also passed to ``get_plucker_embeds_from_cameras_ortho``
            (presumably as the embedding resolution).
        num_inference_steps: Number of sampling steps.
        guidance_scale: Classifier-free guidance scale.
        seed: Random seed. ``-1`` or a non-int leaves the generator unset.
        remove_bg_fn: Optional callable applied to the reference image, whose
            result is then passed through ``preprocess_image``.
        reference_conditioning_scale: Forwarded to the pipeline.
        negative_prompt: Forwarded to the pipeline.
        lora_scale: Forwarded via ``cross_attention_kwargs={"scale": ...}``.
        device: Device for the cameras and the seeded generator.
        azimuth_deg: Optional camera azimuths in degrees, one per view. Defaults
            to ``[0, 45, 90, 180, 270, 315]`` (the fixed layout used before
            this parameter existed). Each value is offset by -90 before use.

    Returns:
        tuple: ``(images, reference_image)``, i.e. the pipeline's ``.images``
        and the (possibly preprocessed) reference image that was used.

    Raises:
        ValueError: If the number of azimuths differs from ``num_views``.
    """
    # Previously elevation/azimuth were hard-coded to 6 entries while distance
    # used num_views, so any other view count gave mismatched camera lists.
    if azimuth_deg is None:
        azimuth_deg = [0, 45, 90, 180, 270, 315]
    azimuth_deg = list(azimuth_deg)
    if len(azimuth_deg) != num_views:
        raise ValueError(
            f"run_pipeline: num_views={num_views} but {len(azimuth_deg)} azimuths "
            "were given; pass azimuth_deg with one angle per view"
        )

    # Prepare cameras: one orthographic camera per view, all at zero elevation.
    cameras = get_orthogonal_camera(
        elevation_deg=[0] * num_views,
        distance=[1.8] * num_views,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        azimuth_deg=[x - 90 for x in azimuth_deg],
        device=device,
    )
    plucker_embeds = get_plucker_embeds_from_cameras_ortho(
        cameras.c2w, [1.1] * num_views, width
    )
    # (x + 1) / 2 then clamp to [0, 1]; presumably the embeddings lie in [-1, 1].
    control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1)

    # Prepare the reference image. Copy it out of the file so the handle is
    # closed right away instead of lingering until garbage collection.
    if isinstance(image, str):
        with Image.open(image) as opened:
            reference_image = opened.copy()
    else:
        reference_image = image
    logging.info("Initial reference_image mode: %s", reference_image.mode)
    if remove_bg_fn is not None:
        logging.info("Using remove_bg_fn")
        reference_image = remove_bg_fn(reference_image)
        reference_image = preprocess_image(reference_image, height, width)
    elif reference_image.mode == "RGBA":
        logging.info("Image is RGBA, preprocessing directly")
        reference_image = preprocess_image(reference_image, height, width)
    logging.info("Final reference_image mode: %s", reference_image.mode)

    pipe_kwargs = {}
    if isinstance(seed, int) and seed != -1:
        pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)

    images = pipe(
        text,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_views,
        control_image=control_images,
        control_conditioning_scale=1.0,
        reference_image=reference_image,
        reference_conditioning_scale=reference_conditioning_scale,
        negative_prompt=negative_prompt,
        cross_attention_kwargs={"scale": lora_scale},
        **pipe_kwargs,
    ).images

    return images, reference_image
if __name__ == "__main__":
    import os

    parser = argparse.ArgumentParser()
    # Models
    parser.add_argument(
        "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
    )
    parser.add_argument(
        "--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix"
    )
    parser.add_argument("--unet_model", type=str, default=None)
    parser.add_argument("--scheduler", type=str, default=None)
    parser.add_argument("--lora_model", type=str, default=None)
    parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter")
    parser.add_argument("--num_views", type=int, default=6)
    # Device
    parser.add_argument("--device", type=str, default="cuda")
    # Inference
    parser.add_argument("--image", type=str, required=True)
    parser.add_argument("--text", type=str, default="high quality")
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--guidance_scale", type=float, default=3.0)
    parser.add_argument("--seed", type=int, default=-1)
    parser.add_argument("--lora_scale", type=float, default=1.0)
    parser.add_argument("--reference_conditioning_scale", type=float, default=1.0)
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="watermark, ugly, deformed, noisy, blurry, low contrast",
    )
    parser.add_argument("--output", type=str, default="output.png")
    # Extra
    parser.add_argument("--remove_bg", action="store_true", help="Remove background")
    args = parser.parse_args()

    pipe = prepare_pipeline(
        base_model=args.base_model,
        vae_model=args.vae_model,
        unet_model=args.unet_model,
        lora_model=args.lora_model,
        adapter_path=args.adapter_path,
        scheduler=args.scheduler,
        num_views=args.num_views,
        device=args.device,
        dtype=torch.float16,
    )

    if args.remove_bg:
        birefnet = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        )
        birefnet.to(args.device)
        birefnet.eval()
        transform_image = transforms.Compose(
            [
                transforms.Resize((1024, 1024)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        def remove_bg_fn(img):
            """Segment the foreground with BiRefNet and make the rest transparent.

            Previously a lambda called ``remove_bg`` without a mask, so the
            model was loaded but never run and the image came back unchanged.
            """
            # The transform normalises 3 channels, so always feed RGB, even for
            # RGBA/grayscale/palette inputs.
            batch = transform_image(img.convert("RGB")).unsqueeze(0).to(args.device)
            with torch.no_grad():
                preds = birefnet(batch)[-1].sigmoid().cpu()
            # Binary (H, W) bool foreground mask at the model's 1024x1024
            # resolution; remove_bg resizes it to the image size.
            mask = preds[0].squeeze().numpy() > 0.5
            return remove_bg(img, birefnet, transform_image, args.device, mask=mask)
    else:
        remove_bg_fn = None

    images, reference_image = run_pipeline(
        pipe,
        num_views=args.num_views,
        text=args.text,
        image=args.image,
        height=768,
        width=768,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        lora_scale=args.lora_scale,
        reference_conditioning_scale=args.reference_conditioning_scale,
        negative_prompt=args.negative_prompt,
        device=args.device,
        remove_bg_fn=remove_bg_fn,
    )
    make_image_grid(images, rows=1).save(args.output)
    # splitext only strips a real extension. The old rsplit(".") turned
    # "./out/result" into "" and wrote "_reference.png" to the CWD.
    reference_image.save(os.path.splitext(args.output)[0] + "_reference.png")