|
|
|
|
|
import gradio as gr |
|
import PIL.Image |
|
import spaces |
|
import torch |
|
from transformers import AutoModel, AutoProcessor, GenerationConfig |
|
|
|
DESCRIPTION = "# MIL-UT/Asagi-14B" |
|
|
|
model_id = "MIL-UT/Asagi-14B" |
|
processor = AutoProcessor.from_pretrained(model_id) |
|
model = AutoModel.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True) |
|
|
|
|
|
TEMPLATE = ( |
|
"以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n\n" |
|
"### 指示:\n<image>\n{prompt}\n\n### 応答:\n" |
|
) |
|
|
|
|
|
@spaces.GPU |
|
def run( |
|
image: PIL.Image.Image, |
|
prompt: str, |
|
max_new_tokens: int = 256, |
|
temperature: float = 0.7, |
|
) -> str: |
|
prompt = TEMPLATE.format(prompt=prompt) |
|
|
|
inputs = processor(text=prompt, images=image, return_tensors="pt") |
|
inputs_text = processor.tokenizer(prompt, return_tensors="pt") |
|
inputs["input_ids"] = inputs_text["input_ids"] |
|
inputs["attention_mask"] = inputs_text["attention_mask"] |
|
for k, v in inputs.items(): |
|
if v.dtype == torch.float32: |
|
inputs[k] = v.to(model.dtype) |
|
inputs = {k: inputs[k].to(model.device) for k in inputs if k != "token_type_ids"} |
|
|
|
generation_config = GenerationConfig( |
|
max_new_tokens=max_new_tokens, |
|
temperature=temperature, |
|
do_sample=temperature > 0, |
|
num_beams=5, |
|
) |
|
|
|
output = model.generate(**inputs, generation_config=generation_config) |
|
generated_text = processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] |
|
|
|
|
|
if "<image>" in prompt: |
|
prompt = prompt.replace("<image>", " ") |
|
return generated_text.replace(prompt, "") |
|
|
|
|
|
examples = [ |
|
[ |
|
"https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/shibuya.jpg", |
|
"この画像を見て、次の質問に詳細かつ具体的に答えてください。この写真はどこで撮影されたものか教えてください。また、画像の内容についても詳しく説明してください。", |
|
], |
|
[ |
|
"https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/bridge.jpg", |
|
"この画像を見て、次の指示に詳細かつ具体的に答えてください。この写真の内容について詳しく教えてください。", |
|
], |
|
[ |
|
"https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/tower.jpg", |
|
"この画像を見て、次の質問に詳細かつ具体的に答えてください。この写真について評価してください。", |
|
], |
|
[ |
|
"https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/shrine.jpg", |
|
"この画像を見て、次の質問に詳細かつ具体的に答えてください。この写真の神社について、細かいところまで詳しく説明してください。", |
|
], |
|
[ |
|
"https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/garden.jpg", |
|
"この画像を見て、次の指示に詳細かつ具体的に答えてください。これは日本庭園の中でも、どのような形式に分類される庭園ですか?また、その理由は何ですか?", |
|
], |
|
[ |
|
"https://raw.githubusercontent.com/uehara-mech/uehara-mech.github.io/refs/heads/master/images/slope.jpg", |
|
"この画像を見て、次の質問に詳細に答えてください。この画像の場所を舞台とした小説のあらすじを書いてください。", |
|
], |
|
] |
|
|
|
with gr.Blocks(css_paths="style.css") as demo: |
|
gr.Markdown(DESCRIPTION) |
|
with gr.Row(): |
|
with gr.Column(): |
|
image = gr.Image(label="Input Image") |
|
prompt = gr.Textbox(label="Prompt") |
|
run_button = gr.Button() |
|
with gr.Accordion("Advanced options", open=False): |
|
max_new_tokens = gr.Slider( |
|
label="Max new tokens", |
|
minimum=1, |
|
maximum=1024, |
|
step=1, |
|
value=256, |
|
) |
|
temperature = gr.Slider( |
|
label="Temperature", |
|
minimum=0.1, |
|
maximum=2.0, |
|
step=0.1, |
|
value=0.7, |
|
) |
|
with gr.Column(): |
|
output = gr.Textbox(label="Output") |
|
|
|
gr.Examples(examples=examples, inputs=[image, prompt]) |
|
|
|
run_button.click( |
|
fn=run, |
|
inputs=[image, prompt, max_new_tokens, temperature], |
|
outputs=output, |
|
) |
|
|
|
if __name__ == "__main__": |
|
demo.launch() |
|
|