# `spaces` is imported before anything that touches torch/CUDA.
import spaces
import huggingface_hub

# Fetch the IP-Adapter image encoder (models/) and SDXL adapter weights (sdxl_models/) into ./
huggingface_hub.snapshot_download(
    repo_id='h94/IP-Adapter',
    allow_patterns=[
        'models/**',
        'sdxl_models/**',
    ],
    local_dir='./',
    local_dir_use_symlinks=False,
)

import gradio as gr
from diffusers import StableDiffusionXLControlNetInpaintPipeline, ControlNetModel
from rembg import remove
from PIL import Image, ImageChops, ImageEnhance
import torch
from ip_adapter import IPAdapterXL
from ip_adapter.utils import register_cross_attention_hook, get_net_attn_map, attnmaps2images
import numpy as np
import os
import glob
import cv2
import argparse
from diffusers.models.attention_processor import AttnProcessor2_0

import DPT.util.io

from torchvision.transforms import Compose

from DPT.dpt.models import DPTDepthModel
from DPT.dpt.midas_net import MidasNet_large
from DPT.dpt.transforms import Resize, NormalizeImage, PrepareForNet

# ---------------------------------------------------------------------------
# Get ZeST ready: SDXL ControlNet-inpaint pipeline wrapped by IP-Adapter.
# ---------------------------------------------------------------------------

base_model_path = "Lykon/dreamshaper-xl-lightning"
image_encoder_path = "models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"
device = "cuda"

torch.cuda.empty_cache()

# load SDXL pipeline
controlnet = ControlNetModel.from_pretrained(
    controlnet_path,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
).to(device)
pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    use_safetensors=True,
    torch_dtype=torch.float16,
    add_watermarker=False,
).to(device)
pipe.unet = register_cross_attention_hook(pipe.unet)
pipe.unet.set_attn_processor(AttnProcessor2_0())

ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)

# ---------------------------------------------------------------------------
# Get depth model ready: DPT hybrid (vitb_rn50_384). It is left on the CPU here.
# ---------------------------------------------------------------------------

model_path = "DPT/weights/dpt_hybrid-midas-501f0c75.pt"
net_w = net_h = 384
model = DPTDepthModel(
    path=model_path,
    backbone="vitb_rn50_384",
    non_negative=True,
    enable_attention_hooks=False,
)
# mean/std of 0.5 maps inputs in [0, 1] to [-1, 1].
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

transform = Compose(
    [
        Resize(
            net_w,
            net_h,
            resize_target=None,
            keep_aspect_ratio=True,
            ensure_multiple_of=32,
            resize_method="minimal",
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        normalization,
        PrepareForNet(),
    ]
)

model.eval()


def _compute_depth_map(image):
    """Run DPT on ``image`` and return a 1024x1024 single-channel 8-bit depth image.

    Args:
        image: RGB PIL image.

    Returns:
        PIL image built from a uint8 array, min-max normalised to 0-255
        (all zeros when the prediction is constant).
    """
    # Scale to [0, 1] before the 0.5/0.5 normalisation above; feeding raw
    # 0-255 values would push the network input to roughly [-1, 509].
    img = np.array(image) / 255.0
    img_input = transform({"image": img})["image"]

    with torch.no_grad():
        sample = torch.from_numpy(img_input).unsqueeze(0)
        prediction = model(sample)
        # Upsample back to the input resolution.
        prediction = (
            torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=img.shape[:2],
                mode="bicubic",
                align_corners=False,
            )
            .squeeze()
            .cpu()
            .numpy()
        )

    depth_min = prediction.min()
    depth_max = prediction.max()
    bits = 2
    max_val = (2 ** (8 * bits)) - 1
    if depth_max - depth_min > np.finfo("float").eps:
        out = max_val * (prediction - depth_min) / (depth_max - depth_min)
    else:
        # Constant prediction: avoid dividing by ~0. (Previously referenced an
        # undefined name `depth`, raising NameError.)
        out = np.zeros(prediction.shape, dtype=prediction.dtype)
    # 16-bit range -> 8-bit.
    out = (out / 256).astype('uint8')
    return Image.fromarray(out).resize((1024, 1024))


def _decolor_foreground(image):
    """Turn the foreground of ``image`` grayscale and return it with its mask.

    Args:
        image: RGB PIL image.

    Returns:
        ``(init_img, target_mask)``. ``init_img`` has grayscale pixels where the
        mask is white and the original pixels elsewhere. ``target_mask`` is an
        RGB image whose pixels are either 0 or 255.
    """
    rm_bg = remove(image)
    # Binarise the cut-out's alpha channel: any non-zero alpha counts as
    # foreground. Thresholding the RGB values instead made foreground pixels
    # with a zero channel (e.g. saturated or black colours) grey or background.
    alpha = rm_bg.convert("RGBA").getchannel("A")
    target_mask = alpha.point(lambda a: 0 if a < 1 else 255).convert('RGB')
    invert_target_mask = ImageChops.invert(target_mask)

    gray_target_image = image.convert('L').convert('RGB')
    factor = 1.0  # Try adjusting this to get the desired brightness
    gray_target_image = ImageEnhance.Brightness(gray_target_image).enhance(factor)

    # darker() with a 0/255 mask keeps pixels under white and zeroes those under black,
    # so lighter() of the two masked images stitches them together.
    grayscale_img = ImageChops.darker(gray_target_image, target_mask)
    img_black_mask = ImageChops.darker(image, invert_target_mask)
    init_img = ImageChops.lighter(img_black_mask, grayscale_img)
    return init_img, target_mask


@spaces.GPU()
def greet(input_image, material_exemplar):
    """Transfer the material of ``material_exemplar`` onto ``input_image``'s foreground.

    Args:
        input_image: PIL image whose foreground object is re-materialed.
        material_exemplar: PIL image providing the material appearance.

    Returns:
        The first image produced by ``ip_model.generate``.

    Raises:
        gr.Error: if either image was not provided.
    """
    if input_image is None or material_exemplar is None:
        raise gr.Error("Please upload both an input image and a material exemplar.")

    # The ImageChops calls combine the input with RGB masks, and the depth
    # transform is set up for 3-channel input, so normalise the mode first.
    input_image = input_image.convert("RGB")
    material_exemplar = material_exemplar.convert("RGB")

    depth_map = _compute_depth_map(input_image)
    init_img, target_mask = _decolor_foreground(input_image)

    # Process material exemplar and resize all images.
    ip_image = material_exemplar.resize((1024, 1024))
    init_img = init_img.resize((1024, 1024))
    mask = target_mask.resize((1024, 1024))

    num_samples = 1
    images = ip_model.generate(
        guidance_scale=2,
        pil_image=ip_image,
        image=init_img,
        control_image=depth_map,
        mask_image=mask,
        controlnet_conditioning_scale=0.9,
        num_samples=num_samples,
        num_inference_steps=4,
        seed=42,
    )

    return images[0]


css = """
#col-container{
    margin: 0 auto;
    max-width: 960px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
# ZeST: Zero-Shot Material Transfer from a Single Image

Upload two images -- input image and material exemplar. (both 1024*1024 for better results)
ZeST extracts the material from the exemplar and casts it onto the input image following the original lighting cues.

""")

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(type="pil", label="input image")
                    input_image2 = gr.Image(type="pil", label="material exemplar")
                submit_btn = gr.Button("Submit")
                gr.Examples(
                    examples=[["demo_assets/input_imgs/pumpkin.png", "demo_assets/material_exemplars/cup_glaze.png"]],
                    inputs=[input_image, input_image2],
                )
            with gr.Column():
                output_image = gr.Image(label="transfer result")

    submit_btn.click(fn=greet, inputs=[input_image, input_image2], outputs=[output_image])

demo.queue().launch()