# joy-caption-ko / app.py
import spaces
import gradio as gr
from huggingface_hub import InferenceClient
from torch import nn
from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
from pathlib import Path
import torch
from PIL import Image
import os
import re
# Paths and settings
CLIP_PATH = "google/siglip-so400m-patch14-384"
VLM_PROMPT = "A descriptive caption for this image:\n"
MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
CHECKPOINT_PATH = Path("wpkklhc6")
TITLE = "<h1><center>JoyCaption Pre-Alpha (2024-07-30a)</center></h1>"
HF_TOKEN = os.environ.get("HF_TOKEN", None)
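# Note: meta-llama/Meta-Llama-3.1-8B is a gated repository, so HF_TOKEN (picked up from the
# environment by huggingface_hub) must grant access to it, and CHECKPOINT_PATH is expected to
# contain the trained image_adapter.pt weights loaded further below.
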
# Image adapter definition
class ImageAdapter(nn.Module):
    """Two-layer MLP that projects CLIP/SigLIP vision features into the LLM's embedding space."""
    def __init__(self, input_features: int, output_features: int):
        super().__init__()
        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)

    def forward(self, vision_outputs: torch.Tensor):
        x = self.linear1(vision_outputs)
        x = self.activation(x)
        x = self.linear2(x)
        return x

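# Note (shapes assuming the default checkpoints above): the SigLIP vision tower emits one hidden
# state per image patch, so the adapter maps (batch, num_patches, clip_hidden_size) to
# (batch, num_patches, llm_hidden_size), and the result is spliced directly into the Llama
# embedding sequence in stream_chat() below.
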
# Load the CLIP (SigLIP) vision model
print("Loading CLIP")
clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
clip_model = AutoModel.from_pretrained(CLIP_PATH)
clip_model = clip_model.vision_model
clip_model.eval()
clip_model.requires_grad_(False)
clip_model.to("cuda")
# Load the tokenizer
print("Loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), f"Tokenizer is of type {type(tokenizer)}"
# Load the language model (LLM)
print("Loading LLM")
text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
text_model.eval()
# Load the image adapter
print("Loading image adapter")
image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size)
image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu"))
image_adapter.eval()
image_adapter.to("cuda")
# Caption generation function (@spaces.GPU allocates GPU time for this call on Hugging Face Spaces)
@spaces.GPU()
@torch.no_grad()
def stream_chat(input_image: Image.Image):
    torch.cuda.empty_cache()

    # Preprocess the image
    image = clip_processor(images=input_image, return_tensors='pt').pixel_values.to('cuda')

    # Tokenize the prompt
    prompt = tokenizer.encode(VLM_PROMPT, return_tensors='pt', add_special_tokens=False).to('cuda')

    # Embed the image: take the penultimate hidden layer of the vision tower and project it
    # into the LLM embedding space with the trained adapter
    with torch.amp.autocast_mode.autocast('cuda', enabled=True):
        vision_outputs = clip_model(pixel_values=image, output_hidden_states=True)
        image_features = vision_outputs.hidden_states[-2]
        embedded_images = image_adapter(image_features).to('cuda')

    # Embed the prompt and the BOS token
    prompt_embeds = text_model.model.embed_tokens(prompt)
    embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device='cuda', dtype=torch.int64))

    # Assemble the input sequence: <BOS> <image embeddings> <prompt>
    inputs_embeds = torch.cat([
        embedded_bos.expand(embedded_images.shape[0], -1, -1),
        embedded_images,
        prompt_embeds,
    ], dim=1)

    # Build matching input_ids (zeros stand in for the image positions) and an attention mask
    input_ids = torch.cat([
        torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to('cuda'),
        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long).to('cuda'),
        prompt.to('cuda'),
    ], dim=1)
    attention_mask = torch.ones_like(input_ids)

    # Generate the caption, then strip the prompt tokens and a trailing EOS, if present
    generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5)
    generate_ids = generate_ids[:, input_ids.shape[1]:]
    if generate_ids[0][-1] == tokenizer.eos_token_id:
        generate_ids = generate_ids[:, :-1]

    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
    return caption.strip()

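# Illustrative usage outside Gradio (not executed here; "example.jpg" is a placeholder path):
#   caption = stream_chat(Image.open("example.jpg"))
#   print(caption)
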
# Dictionary mapping the Korean UI options below to English phrases
translation_dict = {
    "한국 남자": "Korean man",
    "한국 여자": "Korean woman",
    "삭발": "shaved",
    "숏컷": "short cut",
    "미디엄": "medium",
    "롱헤어": "long hair",
    "레이어드": "layered",
    "밥": "bob",
    "펌": "perm",
    "업스타일": "upstyle",
    "포니테일": "ponytail",
    "브레이드": "braid",
    "컬": "curl",
    "웨이브": "wave",
    "블랙": "black",
    "브라운": "brown",
    "블론드": "blonde",
    "레드": "red",
    "애쉬": "ash",
    "퍼플": "purple",
    "핑크": "pink",
    "블루": "blue",
    "그린": "green",
    "오렌지": "orange",
    "화이트": "white",
    "헤어밴드": "headband",
    "머리핀": "hairpin",
    "리본": "ribbon",
    "스크런치": "scrunchie",
    "헤어클립": "hairclip",
    "티아라": "tiara",
    "꽃장식": "flower decoration",
}
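# The keys above are the exact option strings offered by the Gradio radios below; translate()
# falls back to the raw option when a key is missing (e.g. when a radio is left unselected).
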
# Gender-dependent word replacement
def translate(option):
    return translation_dict.get(option, option)


def replace_gender_specific_words(caption, gender_prefix):
    # Simple regex heuristic: swap common gendered words to match the selected gender.
    # (Pronouns are handled coarsely; e.g. object-case "her" is always mapped to "his".)
    if gender_prefix == "Korean man":
        caption = re.sub(r'\bwoman\b', "man", caption, flags=re.IGNORECASE)
        caption = re.sub(r'\bgirl\b', "boy", caption, flags=re.IGNORECASE)
        caption = re.sub(r'\blady\b', "gentleman", caption, flags=re.IGNORECASE)
        caption = re.sub(r'\bshe\b', "he", caption, flags=re.IGNORECASE)
        caption = re.sub(r'\bher\b', "his", caption, flags=re.IGNORECASE)
        caption = re.sub(r'\bherself\b', "himself", caption, flags=re.IGNORECASE)
    elif gender_prefix == "Korean woman":
        caption = re.sub(r'\bman\b', "woman", caption, flags=re.IGNORECASE)
        caption = re.sub(r'\bboy\b', "girl", caption, flags=re.IGNORECASE)
        caption = re.sub(r'\bgentleman\b', "lady", caption, flags=re.IGNORECASE)
        caption = re.sub(r'\bhe\b', "she", caption, flags=re.IGNORECASE)
        caption = re.sub(r'\bhis\b', "her", caption, flags=re.IGNORECASE)
        caption = re.sub(r'\bhimself\b', "herself", caption, flags=re.IGNORECASE)
    return caption

def replace_gender_words(caption, gender, age, hair_length, hair_style, hair_color, hair_accessory):
    gender_prefix = translate(gender)
    hair_length_en = translate(hair_length)
    hair_style_en = translate(hair_style)
    hair_color_en = translate(hair_color)
    hair_accessory_en = translate(hair_accessory)
    hair_description = f"{hair_length_en} hair with {hair_style_en}, {hair_color_en} color"
    if hair_accessory_en:
        hair_description += f", wearing a {hair_accessory_en}"
    caption = replace_gender_specific_words(caption, gender_prefix)
    return f"{gender_prefix}, age {age}, {hair_description}: {caption}"

# Recaption function: caption the image, then prepend the user-selected attributes
def recaption(input_image: Image.Image, prefix: str, age: int, hair_length: str, hair_style: str, hair_color: str, hair_accessory: str):
    original_caption = stream_chat(input_image)
    updated_caption = replace_gender_words(original_caption, prefix, age, hair_length, hair_style, hair_color, hair_accessory)
    return updated_caption

# Gradio interface
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            run_button = gr.Button("Caption")
            recaption_button = gr.Button("Recaption")
            gender_selection = gr.Radio(choices=["한국 남자", "한국 여자"], label="성별 선택")
            age_slider = gr.Slider(minimum=1, maximum=100, step=1, label="Age", value=25)
            hair_length = gr.Radio(choices=["삭발", "숏컷", "미디엄", "롱헤어", "레이어드", "밥"], label="헤어 길이")
            hair_style = gr.Radio(choices=["펌", "업스타일", "포니테일", "브레이드", "컬", "웨이브"], label="헤어 스타일")
            hair_color = gr.Radio(choices=["블랙", "브라운", "블론드", "레드", "애쉬", "퍼플", "핑크", "블루", "그린", "오렌지", "화이트"], label="헤어 색상")
            hair_accessory = gr.Radio(choices=["헤어밴드", "머리핀", "리본", "스크런치", "헤어클립", "티아라", "꽃장식"], label="헤어 액세서리")
        with gr.Column():
            output_caption = gr.Textbox(label="Caption")
            new_caption_output = gr.Textbox(label="Recaptioned Caption", placeholder="New caption will appear here")
    run_button.click(fn=stream_chat, inputs=[input_image], outputs=[output_caption])
    recaption_button.click(fn=recaption, inputs=[input_image, gender_selection, age_slider, hair_length, hair_style, hair_color, hair_accessory], outputs=[new_caption_output])

if __name__ == "__main__":
    demo.launch()
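# Running this file locally is possible in principle (a sketch, assuming a CUDA GPU, access to the
# gated meta-llama/Meta-Llama-3.1-8B repository via HF_TOKEN, and the wpkklhc6/image_adapter.pt
# checkpoint next to this file):
#   HF_TOKEN=<your token> python app.py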