import os

import gradio as gr
import spaces
import torch
from huggingface_hub import login
from transformers import AutoImageProcessor, AutoModelForVision2Seq, AutoTokenizer

# Authenticate with the Hugging Face Hub (HF_KEY is expected as a Space secret).
login(os.environ["HF_KEY"])

device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_ID = "stabilityai/japanese-stable-vlm"
model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, trust_remote_code=True, device_map="auto")
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Task-specific instructions (in Japanese, as the model expects):
#   caption: "Describe the image in detail."
#   tag:     "Using the given words, describe the image in detail."
#   vqa:     "Based on the given image, answer the question."
TASK2INSTRUCTION = {
    "caption": "画像を詳細に述べてください。",
    "tag": "与えられた単語を使って、画像を詳細に述べてください。",
    "vqa": "与えられた画像を下に、質問に答えてください。",
}


# Build the instruction-following prompt format the model was trained on.
def build_prompt(task="caption", input=None, sep="\n\n### "):
    assert task in TASK2INSTRUCTION, f"Please choose from {list(TASK2INSTRUCTION.keys())}"
    if task in ["tag", "vqa"]:
        assert input is not None, "Please fill in `input`!"
        if task == "tag" and isinstance(input, list):
            input = "、".join(input)
    else:
        assert input is None, f"`{task}` mode doesn't support an `input` argument"
    # System message: "Below is a combination of an instruction describing a task
    # and input that provides context. Write a response that appropriately
    # fulfills the request."
    sys_msg = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
    p = sys_msg
    roles = ["指示", "応答"]  # "instruction", "response"
    instruction = TASK2INSTRUCTION[task]
    msgs = [": \n" + instruction, ": \n"]
    if input:
        roles.insert(1, "入力")  # "input"
        msgs.insert(1, ": \n" + input)
    for role, msg in zip(roles, msgs):
        p += sep + role + msg
    return p


# Generate text from the image and prompt on a ZeroGPU worker.
@spaces.GPU(duration=120)
def generate_text(image, task, input_text=None):
    # Gradio passes an empty string when the textbox is left blank, but
    # build_prompt expects None; the caption task takes no extra text at all.
    input_text = None if task == "caption" else (input_text or None)
    prompt = build_prompt(task=task, input=input_text)
    inputs = processor(images=image, return_tensors="pt")
    text_encoding = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs.update(text_encoding)
    outputs = model.generate(
        **inputs.to(device=device, dtype=model.dtype),
        do_sample=False,
        num_beams=5,
        max_new_tokens=128,
        min_length=1,
        repetition_penalty=1.5,
    )
    generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip()
    return generated_text


# Define the Gradio interface
image_input = gr.Image(label="Upload an image", type="pil")
task_input = gr.Radio(choices=["caption", "tag", "vqa"], value="caption", label="Select a task")
text_input = gr.Textbox(label="Enter text (for tag or vqa tasks)")
output = gr.Textbox(label="Generated text")

interface = gr.Interface(
    fn=generate_text,
    inputs=[image_input, task_input, text_input],
    outputs=output,
    examples=[
        ["examples/example_1.jpeg", "caption", None],
        ["examples/example_2.jpg", "tag", "寿司、箸"],  # "sushi, chopsticks"
        ["examples/example_3.jpg", "vqa", "この画像を説明する"],  # "Describe this image"
    ],
)

interface.launch()