# Provenance (scraped from the Hugging Face file viewer):
#   author: multimodalart
#   commit: 4bacbb7 "Update app.py" (verified)
#   size:   6.8 kB
import gradio as gr
import numpy as np
import random
import torch
from PIL import Image
import os
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.unet_2d_condition import UNet2DConditionModel
from diffusers import AutoencoderKL, EulerDiscreteScheduler
from huggingface_hub import snapshot_download
import spaces
device = "cuda"

# Weights are stored next to (not inside) the directory containing this script.
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
ckpt_dir = f'{root_dir}/weights/Kolors'
ip_adapter_dir = f"{root_dir}/weights/Kolors-IP-Adapter-Plus"

# Fetch the base Kolors checkpoint and the IP-Adapter-Plus weights into the local folders.
snapshot_download(repo_id="Kwai-Kolors/Kolors", local_dir=ckpt_dir)
snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-Plus", local_dir=ip_adapter_dir)

# Load the pipeline components; every network is cast to fp16 and moved to `device`.
text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    f'{ip_adapter_dir}/image_encoder',
    ignore_mismatched_sizes=True,
).to(dtype=torch.float16, device=device)

# Reference images are resized/cropped to 336 px before encoding
# (presumably the image encoder's expected input resolution — confirm against its config).
ip_img_size = 336
clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)

pipe = StableDiffusionXLPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    image_encoder=image_encoder,
    feature_extractor=clip_image_processor,
    force_zeros_for_empty_prompt=False,
).to(device)

# NOTE(review): this must run before `load_ip_adapter`. Loading the adapter appears to replace
# `unet.encoder_hid_proj` with the image projection. Keeping the original module under
# `text_encoder_hid_proj` presumably lets the Kolors UNet still project the ChatGLM text
# embeddings. Confirm against kolors' UNet2DConditionModel.
if hasattr(pipe.unet, 'encoder_hid_proj'):
    pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj

pipe.load_ip_adapter(ip_adapter_dir, subfolder="", weight_name=["ip_adapter_plus_general.bin"])

# Largest int32: upper bound for randomly drawn seeds and for the seed slider.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound for the width/height sliders.
MAX_IMAGE_SIZE = 1024
@spaces.GPU
def infer(
    prompt,
    ip_adapter_image,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=5.0,
    num_inference_steps=50,
    ip_adapter_scale=0.5,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image from a text prompt, conditioned on a reference image.

    Args:
        prompt: Text prompt passed to the Kolors pipeline.
        ip_adapter_image: Reference image for the IP-Adapter. The UI delivers
            a PIL image because the input component uses ``type="pil"``.
        negative_prompt: Negative prompt passed to the pipeline.
        seed: Seed for the generator. Ignored when ``randomize_seed`` is true.
        randomize_seed: If true, draw a fresh seed in ``[0, MAX_SEED]``.
        width, height: Output size in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps.
        ip_adapter_scale: Weight of the image conditioning.
        progress: Gradio progress tracker. ``track_tqdm=True`` mirrors the
            pipeline's tqdm bar in the UI.

    Returns:
        ``(image, seed)``: the first generated image and the seed actually
        used, so the UI can display that seed.

    Raises:
        gr.Error: If no reference image was supplied.
    """
    if ip_adapter_image is None:
        raise gr.Error("Please provide an IP-Adapter image.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may deliver floats. manual_seed and the size/step arguments need ints.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Do not rebind the module-level `pipe` / `image_encoder` names here. Assigning to
    # either name anywhere in this function would make it local, and the read below
    # would then raise UnboundLocalError.
    pipe.to(device)
    pipe.image_encoder = image_encoder.to(device)
    pipe.set_ip_adapter_scale([ip_adapter_scale])

    image = pipe(
        prompt=prompt,
        ip_adapter_image=[ip_adapter_image],
        negative_prompt=negative_prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        generator=generator,
    ).images[0]
    return image, seed
# Example rows for the UI: [prompt, reference-image URL]. Only the prompt and the
# IP-Adapter image are filled in; all other settings keep their defaults.
examples = [
    [
        "A photo of a lavender cat",
        "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4d/Cat_November_2010-1a.jpg/640px-Cat_November_2010-1a.jpg",
    ],
    [
        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
        "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b5/Astronaut_EVA.jpg/640px-Astronaut_EVA.jpg",
    ],
    [
        "An astronaut riding a green horse",
        "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f7/Haflinger_in-motion.jpg/640px-Haflinger_in-motion.jpg",
    ],
    [
        "A delicious ceviche cheesecake slice",
        "https://upload.wikimedia.org/wikipedia/commons/thumb/9/9c/Ceviche_mixto.jpg/640px-Ceviche_mixto.jpg",
    ],
]
# Stylesheet for the page: center the main column and cap its width at 720px.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 720px;\n"
    "}\n"
)
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Page header. Plain string: the original f-string had no placeholders.
        gr.Markdown("""
# Kolors Demo
Demo of the Kolors model with IP-Adapter integration
""")

        # Single-line prompt box with the Run button beside it.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        # Reference image (passed to `infer` as a PIL image) next to the generated result.
        with gr.Row():
            ip_adapter_image = gr.Image(label="IP-Adapter Image", type="pil")
            result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )

            # `infer` also writes back into this slider (it is an output below),
            # so the seed actually used is shown after each run.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=5.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )

            ip_adapter_scale = gr.Slider(
                label="IP-Adapter Scale",
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.5,
            )

        # Examples fill only the prompt and reference image; `infer` uses its defaults for
        # the rest. "lazy" defers computing and caching an example until it is first used.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt, ip_adapter_image],
            outputs=[result, seed],
            cache_examples="lazy",
        )

    # Generate on a Run click or on submitting the prompt box.
    # The input order must match infer's positional parameters.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, ip_adapter_image, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, ip_adapter_scale],
        outputs=[result, seed],
    )

demo.queue().launch()