# StyleShot / app.py
# Gradio demo for StyleShot: style transfer with optional contour/line-art
# content control on top of Stable Diffusion 1.5.
import spaces
from types import MethodType
import os
import gradio as gr
import torch
import cv2
from annotator.util import resize_image
from annotator.hed import SOFT_HEDdetector
from annotator.lineart import LineartDetector
from diffusers import UNet2DConditionModel, ControlNetModel
from transformers import CLIPVisionModelWithProjection
from huggingface_hub import snapshot_download
from PIL import Image
from ip_adapter import StyleShot, StyleContentStableDiffusionControlNetPipeline
# Runtime device: prefer CUDA when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Edge detectors that turn the optional content image into a control map:
# soft HED for coarse contours, line-art detector for fine detail.
contour_detector = SOFT_HEDdetector()
lineart_detector = LineartDetector()

# Hugging Face repo ids (also used as the local download directories).
base_model_path = "runwayml/stable-diffusion-v1-5"
transformer_block_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
styleshot_model_path = "Gaojunyao/StyleShot"
styleshot_lineart_model_path = "Gaojunyao/StyleShot_lineart"


def _ensure_local(repo_id):
    """Return a local directory containing *repo_id*, downloading it once.

    If a directory named after the repo already exists it is reused;
    otherwise the snapshot is fetched from the Hugging Face Hub into a
    directory of the same name.
    """
    if not os.path.isdir(repo_id):
        local_path = snapshot_download(repo_id, local_dir=repo_id)
        print(f"Downloaded model to {local_path}")
        return local_path
    return repo_id


base_model_path = _ensure_local(base_model_path)
transformer_block_path = _ensure_local(transformer_block_path)
styleshot_model_path = _ensure_local(styleshot_model_path)
styleshot_lineart_model_path = _ensure_local(styleshot_lineart_model_path)


def _build_styleshot(weights_dir):
    """Assemble a StyleShot pipeline from a downloaded weights directory.

    Loads the IP-Adapter checkpoint and style-aware encoder from
    *weights_dir*, builds a ControlNet content-fusion encoder from the base
    SD-1.5 UNet, and wraps everything in a StyleShot instance on *device*.
    """
    ip_ckpt = os.path.join(weights_dir, "pretrained_weight/ip.bin")
    style_aware_encoder_path = os.path.join(weights_dir, "pretrained_weight/style_aware_encoder.bin")
    unet = UNet2DConditionModel.from_pretrained(base_model_path, subfolder="unet")
    content_fusion_encoder = ControlNetModel.from_unet(unet)
    pipe = StyleContentStableDiffusionControlNetPipeline.from_pretrained(
        base_model_path, controlnet=content_fusion_encoder)
    return StyleShot(device, pipe, ip_ckpt, style_aware_encoder_path, transformer_block_path)


# One pipeline per preprocessor type ("Contour" vs. "Lineart" control).
contour_styleshot = _build_styleshot(styleshot_model_path)
lineart_styleshot = _build_styleshot(styleshot_lineart_model_path)
@spaces.GPU(duration=150)
def process(style_image, content_image, prompt, num_samples, image_resolution, condition_scale, style_scale, ddim_steps, guidance_scale, seed, a_prompt, n_prompt, btn1, Contour_Threshold=200):
    """Run StyleShot generation for the Gradio demo.

    Arguments mirror the UI widgets: *style_image* / *content_image* are
    numpy arrays (content is optional), *btn1* selects the preprocessor
    ("Contour", "Lineart" or "Both"), and the remaining arguments are the
    usual diffusion controls.

    Returns a 4-element list:
    ``[contour results, contour control images, lineart results, lineart
    control images]`` — empty lists for whichever pipelines were not run.
    """
    type1 = 'Contour'
    type2 = 'Lineart'

    # Decide which pipelines to run.  Without a content image there is no
    # structural control, so a single (contour) pass suffices.
    if btn1 == type1 or content_image is None:
        style_shots = [contour_styleshot]
        btns = [type1]
    elif btn1 == type2:
        style_shots = [lineart_styleshot]
        btns = [type2]
    elif btn1 == "Both":
        style_shots = [contour_styleshot, lineart_styleshot]
        btns = [type1, type2]

    contour_content_images = []
    contour_results = []
    lineart_content_images = []
    lineart_results = []

    ori_style_image = style_image.copy()
    ori_content_image = content_image.copy() if content_image is not None else None

    for styleshot, btn in zip(style_shots, btns):
        style_image = Image.fromarray(ori_style_image)
        if ori_content_image is not None:
            # Convert the content image into the control map expected by
            # the selected pipeline.
            if btn == type1:
                content_image = resize_image(ori_content_image, image_resolution)
                content_image = contour_detector(content_image, threshold=Contour_Threshold)
            elif btn == type2:
                content_image = resize_image(ori_content_image, image_resolution)
                content_image = lineart_detector(content_image, coarse=False)
            content_image = Image.fromarray(content_image)
        else:
            # No content image: feed the style image as a dummy control input
            # and zero out the ControlNet conditioning so it has no effect.
            content_image = cv2.resize(ori_style_image, (image_resolution, image_resolution))
            content_image = Image.fromarray(content_image)
            condition_scale = 0.0

        # Fix: append the "Added Prompt" (a_prompt) to the user prompt.
        # Previously this combination was computed into an unused local and
        # never passed to generate(), so the UI field had no effect.
        g_images = styleshot.generate(style_image=style_image,
                                      prompt=[[prompt + " " + a_prompt]],
                                      negative_prompt=n_prompt,
                                      scale=style_scale,
                                      num_samples=num_samples,
                                      seed=seed,
                                      num_inference_steps=ddim_steps,
                                      guidance_scale=guidance_scale,
                                      content_image=content_image,
                                      controlnet_conditioning_scale=float(condition_scale))

        if btn == type1:
            contour_content_images = [content_image]
            contour_results = g_images[0]
        elif btn == type2:
            lineart_content_images = [content_image]
            lineart_results = g_images[0]

    # With no content image, only the style-only results are meaningful:
    # hide the dummy control images and the (never-run) lineart outputs.
    if ori_content_image is None:
        contour_content_images = []
        lineart_results = []
        lineart_content_images = []

    return [contour_results, contour_content_images, lineart_results, lineart_content_images]
# --- Gradio UI ---------------------------------------------------------------
# Declarative layout; the nesting of `with` context managers defines the page
# structure, so statement order here is load-bearing.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Styleshot Demo")
    with gr.Row():
        with gr.Column():
            # Required style reference image (numpy array passed to process()).
            style_image = gr.Image(sources=['upload'], type="numpy", label='Style Image')
        with gr.Column():
            with gr.Blocks():
                with gr.Column():
                    # Optional content image; when omitted, process() runs a
                    # style-only pass with ControlNet conditioning disabled.
                    content_image = gr.Image(sources=['upload'], type="numpy", label='Content Image (optional)')
                    # Preprocessor selector: which control map(s) to extract.
                    btn1 = gr.Radio(
                        choices=["Contour", "Lineart", "Both"],
                        interactive=True,
                        label="Preprocessor",
                        value="Both",
                    )
                    gr.Markdown("We recommend using 'Contour' for sparse control and 'Lineart' for detailed control. If you choose 'Both', we will provide results for two types of control. If you choose 'Contour', you can adjust the 'Contour Threshold' under the 'Advanced options' for the level of detail in control. ")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
    with gr.Row():
        run_button = gr.Button(value="Run")
    with gr.Row():
        with gr.Column():
            # Diffusion / control knobs forwarded verbatim to process().
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=4, step=1)
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
                condition_scale = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                # Only used by the Contour preprocessor (HED threshold).
                Contour_Threshold = gr.Slider(label="Contour Threshold", minimum=0, maximum=255, value=200, step=1)
                style_scale = gr.Slider(label="Style Strength", minimum=0, maximum=2, value=1.0, step=0.01)
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
                guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=42, step=1)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
    with gr.Row():
        gr.Markdown("### Results for Contour")
    with gr.Row():
        with gr.Blocks():
            with gr.Row():
                with gr.Column(scale = 1):
                    # Extracted contour control map (1 column).
                    contour_gallery = gr.Gallery(label='Contour Output', show_label=True, elem_id="gallery", columns=[1], rows=[1], height='auto')
                with gr.Column(scale = 4):
                    # Generated images for the contour pipeline (4 columns).
                    image_gallery = gr.Gallery(label='Result for Contour', show_label=True, elem_id="gallery", columns=[4], rows=[1], height='auto')
    with gr.Row():
        gr.Markdown("### Results for Lineart")
    with gr.Row():
        with gr.Blocks():
            with gr.Row():
                with gr.Column(scale = 1):
                    # Extracted line-art control map.
                    line_gallery = gr.Gallery(label='Lineart Output', show_label=True, elem_id="gallery", columns=[1], rows=[1], height='auto')
                with gr.Column(scale = 4):
                    # Generated images for the line-art pipeline.
                    line_image_gallery = gr.Gallery(label='Result for Lineart', show_label=True, elem_id="gallery", columns=[4], rows=[1], height='auto')
    # Inputs must match process()'s positional parameter order; outputs match
    # its return order [contour results, contour maps, lineart results, lineart maps].
    ips = [style_image, content_image, prompt, num_samples, image_resolution, condition_scale, style_scale, ddim_steps, guidance_scale, seed, a_prompt, n_prompt, btn1, Contour_Threshold]
    run_button.click(fn=process, inputs=ips, outputs=[image_gallery, contour_gallery, line_image_gallery, line_gallery])

# Bind on all interfaces so the Space/container can expose the port.
block.launch(server_name='0.0.0.0')