Asagi-14B / app.py
hayas's picture
Add files
b3d9f11
#!/usr/bin/env python
import gradio as gr
import PIL.Image
import spaces
import torch
from transformers import AutoModel, AutoProcessor, GenerationConfig
DESCRIPTION = "# MIL-UT/Asagi-14B"
model_id = "MIL-UT/Asagi-14B"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
TEMPLATE = (
"以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n\n"
"### 指示:\n<image>\n{prompt}\n\n### 応答:\n"
)
@spaces.GPU
def run(
image: PIL.Image.Image,
prompt: str,
max_new_tokens: int = 256,
temperature: float = 0.7,
) -> str:
prompt = TEMPLATE.format(prompt=prompt)
inputs = processor(text=prompt, images=image, return_tensors="pt")
inputs_text = processor.tokenizer(prompt, return_tensors="pt")
inputs["input_ids"] = inputs_text["input_ids"]
inputs["attention_mask"] = inputs_text["attention_mask"]
for k, v in inputs.items():
if v.dtype == torch.float32:
inputs[k] = v.to(model.dtype)
inputs = {k: inputs[k].to(model.device) for k in inputs if k != "token_type_ids"}
generation_config = GenerationConfig(
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=temperature > 0,
num_beams=5,
)
output = model.generate(**inputs, generation_config=generation_config)
generated_text = processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
# do not print the prompt
if "<image>" in prompt:
prompt = prompt.replace("<image>", " ")
return generated_text.replace(prompt, "")
examples = [
[
"https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/shibuya.jpg",
"この画像を見て、次の質問に詳細かつ具体的に答えてください。この写真はどこで撮影されたものか教えてください。また、画像の内容についても詳しく説明してください。",
],
[
"https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/bridge.jpg",
"この画像を見て、次の指示に詳細かつ具体的に答えてください。この写真の内容について詳しく教えてください。",
],
[
"https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/tower.jpg",
"この画像を見て、次の質問に詳細かつ具体的に答えてください。この写真について評価してください。",
],
[
"https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/shrine.jpg",
"この画像を見て、次の質問に詳細かつ具体的に答えてください。この写真の神社について、細かいところまで詳しく説明してください。",
],
[
"https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/garden.jpg",
"この画像を見て、次の指示に詳細かつ具体的に答えてください。これは日本庭園の中でも、どのような形式に分類される庭園ですか?また、その理由は何ですか?", # noqa: RUF001
],
[
"https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/slope.jpg",
"この画像を見て、次の質問に詳細に答えてください。この画像の場所を舞台とした小説のあらすじを書いてください。",
],
]
with gr.Blocks(css_paths="style.css") as demo:
gr.Markdown(DESCRIPTION)
with gr.Row():
with gr.Column():
image = gr.Image(label="Input Image")
prompt = gr.Textbox(label="Prompt")
run_button = gr.Button()
with gr.Accordion("Advanced options", open=False):
max_new_tokens = gr.Slider(
label="Max new tokens",
minimum=1,
maximum=1024,
step=1,
value=256,
)
temperature = gr.Slider(
label="Temperature",
minimum=0.1,
maximum=2.0,
step=0.1,
value=0.7,
)
with gr.Column():
output = gr.Textbox(label="Output")
gr.Examples(examples=examples, inputs=[image, prompt])
run_button.click(
fn=run,
inputs=[image, prompt, max_new_tokens, temperature],
outputs=output,
)
if __name__ == "__main__":
demo.launch()