# ipkol / app.py
# fantaxy's picture
# Update app.py
# 35e3bb1 verified
# raw
# history blame
# 6.71 kB
import gradio as gr
import numpy as np
import random
import torch
from PIL import Image
import os
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor, pipeline
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.unet_2d_condition import UNet2DConditionModel
from diffusers import AutoencoderKL, EulerDiscreteScheduler
from huggingface_hub import snapshot_download
import spaces
# Device the models are moved to, and the on-disk locations of the weights
# (a "weights" folder under the current working directory).
device = "cuda"
root_dir = os.getcwd()
ckpt_dir = f'{root_dir}/weights/Kolors'

# Download the base Kolors checkpoint and the IP-Adapter-Plus weights into
# their respective local folders.
for _repo_id, _local_dir in (
    ("Kwai-Kolors/Kolors", ckpt_dir),
    ("Kwai-Kolors/Kolors-IP-Adapter-Plus", f"{root_dir}/weights/Kolors-IP-Adapter-Plus"),
):
    snapshot_download(repo_id=_repo_id, local_dir=_local_dir)
# Load models
# Each component is read from its sub-folder of the downloaded checkpoint,
# then converted to half precision and moved to `device` in a second step.

# The text encoder and its tokenizer live in the same sub-folder.
_text_encoder_dir = f'{ckpt_dir}/text_encoder'
text_encoder = ChatGLMModel.from_pretrained(_text_encoder_dir, torch_dtype=torch.float16)
text_encoder = text_encoder.half().to(device)
tokenizer = ChatGLMTokenizer.from_pretrained(_text_encoder_dir)

vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None)
vae = vae.half().to(device)

scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")

unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None)
unet = unet.half().to(device)

# The image encoder ships with the IP-Adapter-Plus weights rather than with
# the base checkpoint.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    f'{root_dir}/weights/Kolors-IP-Adapter-Plus/image_encoder',
    ignore_mismatched_sizes=True,
)
image_encoder = image_encoder.to(dtype=torch.float16, device=device)

# Reference images are resized and cropped to this edge length before encoding.
ip_img_size = 336
clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
# Assemble the diffusion pipeline from the already-loaded components and move
# it to `device`.
_pipeline_components = dict(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    image_encoder=image_encoder,
    feature_extractor=clip_image_processor,
    force_zeros_for_empty_prompt=False,
)
pipe = StableDiffusionXLPipeline(**_pipeline_components).to(device)

# Expose the UNet's existing `encoder_hid_proj` under the name
# `text_encoder_hid_proj` before the IP-Adapter is loaded.
# NOTE(review): presumably load_ip_adapter replaces `encoder_hid_proj` and the
# pipeline then looks up the text projection under the new name -- confirm
# against the kolors pipeline implementation.
if hasattr(pipe.unet, 'encoder_hid_proj'):
    pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj

pipe.load_ip_adapter(
    f'{root_dir}/weights/Kolors-IP-Adapter-Plus',
    subfolder="",
    weight_name=["ip_adapter_plus_general.bin"],
)

# Korean -> English translation model used on the user's prompt.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
# Largest seed offered to the user: the maximum signed 32-bit integer.
_INT32_LIMITS = np.iinfo(np.int32)
MAX_SEED = _INT32_LIMITS.max

# Upper bound (in pixels) for the width and height sliders.
MAX_IMAGE_SIZE = 1024
def _contains_hangul(text):
    """Return True when *text* contains at least one Hangul character."""
    return any(
        "\uac00" <= ch <= "\ud7a3"  # precomposed Hangul syllables
        or "\u1100" <= ch <= "\u11ff"  # Hangul jamo
        or "\u3130" <= ch <= "\u318f"  # Hangul compatibility jamo
        for ch in text
    )


@spaces.GPU
def infer(prompt, ip_adapter_image, ip_adapter_scale=0.5, negative_prompt="", seed=100, randomize_seed=False, width=1024, height=1024, guidance_scale=5.0, num_inference_steps=50, progress=gr.Progress(track_tqdm=True)):
    """Generate one image from a text prompt and a reference image.

    Args:
        prompt: Text prompt. It is translated from Korean to English only
            when it actually contains Hangul; otherwise it is passed to the
            pipeline unchanged (an empty or missing prompt becomes "").
        ip_adapter_image: Reference image handed to the pipeline as the
            IP-Adapter image. Required.
        ip_adapter_scale: Strength of the IP-Adapter conditioning.
        negative_prompt: Negative prompt, passed through untranslated.
        seed: Seed for the random generator; ignored when ``randomize_seed``
            is true.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]`` when true.
        width: Output width passed to the pipeline.
        height: Output height passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        A ``(image, seed)`` tuple: the first generated image and the seed
        that was actually used.

    Raises:
        gr.Error: If no reference image was supplied.
    """
    # Fail early with a readable message instead of an opaque error from
    # inside the image-encoding step of the pipeline.
    if ip_adapter_image is None:
        raise gr.Error("Please provide an IP-Adapter reference image.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # manual_seed() needs an int; UI number widgets may deliver a float.
    seed = int(seed)

    # Translate prompt if it's in Korean. Empty and non-Korean prompts skip
    # the ko->en model, which would otherwise rewrite them arbitrarily.
    prompt = prompt or ""
    if _contains_hangul(prompt):
        translated_prompt = translator(prompt, src_lang="ko", tgt_lang="en")[0]['translation_text']
    else:
        translated_prompt = prompt

    generator = torch.Generator(device=device).manual_seed(seed)

    # Make sure the pipeline and the image encoder are on the target device
    # and that the pipeline uses the module-level encoder.
    pipe.to(device)
    image_encoder.to(device)
    pipe.image_encoder = image_encoder
    pipe.set_ip_adapter_scale([ip_adapter_scale])

    image = pipe(
        prompt=translated_prompt,
        ip_adapter_image=[ip_adapter_image],
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        generator=generator,
    ).images[0]
    return image, seed
# Example rows for the demo, in the order
# [prompt, reference image file, IP-Adapter scale].
# The Korean prompts were stored mis-encoded (UTF-8 bytes read through a
# single-byte code page); they are restored to proper Hangul here.
examples = [
    ["강아지", "minta.jpeg", 0.4],  # "puppy"
    ["환하게 웃어라", "trump.png", 0.5],  # "smile brightly"
    ["올빼미", "forest.png", 0.5],  # "owl"
    ["", "meow.jpeg", 1.0],  # no prompt, full image influence
]
# Custom stylesheet for the page: centres the main column, caps its width,
# and pins the result image to the top of a full-height container.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 720px;\n"
    "}\n"
    "#result img{\n"
    "object-position: top;\n"
    "}\n"
    "#result .image-container{\n"
    "height: 100%\n"
    "}\n"
)
# Build the Gradio UI. The Korean labels and placeholders were stored
# mis-encoded (UTF-8 bytes read through a single-byte code page) and are
# restored to proper Hangul here.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Page title: "Kolors IP-Adapter - image reference and variation".
        gr.Markdown("\n# Kolors IP-Adapter - 이미지 참조 및 변형\n")

        # Prompt box with the run button on the same row.
        with gr.Row():
            prompt = gr.Text(
                label="프롬프트",
                show_label=False,
                max_lines=1,
                placeholder="프롬프트를 입력하세요",
                container=False,
            )
            run_button = gr.Button("실행", scale=0)

        # Reference image and its influence slider on the left, result on the right.
        with gr.Row():
            with gr.Column():
                ip_adapter_image = gr.Image(label="IP-어댑터 이미지", type="pil")
                ip_adapter_scale = gr.Slider(
                    label="이미지 영향 척도",
                    info="변형을 생성하려면 1을 사용하세요",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.5,
                )
            result = gr.Image(label="결과", elem_id="result")

        # Advanced settings, collapsed by default.
        with gr.Accordion("고급 설정", open=False):
            negative_prompt = gr.Text(
                label="부정적 프롬프트",
                max_lines=1,
                placeholder="부정적 프롬프트를 입력하세요",
            )
            seed = gr.Slider(
                label="시드",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="시드 무작위화", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="너비",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="높이",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="가이던스 척도",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=5.0,
                )
                num_inference_steps = gr.Slider(
                    label="추론 단계 수",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )

        # Examples supply only the first three inputs; infer() falls back to
        # its own defaults for the remaining parameters.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt, ip_adapter_image, ip_adapter_scale],
            outputs=[result, seed],
            cache_examples="lazy"
        )

    # Run inference when the button is clicked or the prompt is submitted.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, ip_adapter_image, ip_adapter_scale, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed]
    )

# Launch the app
demo.launch(share=True)