#!/usr/bin/env python from __future__ import annotations import os import gradio as gr import PIL.Image import spaces import torch from transformers import InstructBlipForConditionalGeneration, InstructBlipProcessor DESCRIPTION = "# InstructBLIP" MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024")) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model_id = "Salesforce/instructblip-vicuna-7b" processor = InstructBlipProcessor.from_pretrained(model_id) model = InstructBlipForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto") @spaces.GPU def run( image: PIL.Image.Image, prompt: str, text_decoding_method: str = "Nucleus sampling", num_beams: int = 5, max_length: int = 256, min_length: int = 1, top_p: float = 0.9, repetition_penalty: float = 1.5, length_penalty: float = 1.0, temperature: float = 1.0, ) -> str: h, w = image.size scale = MAX_IMAGE_SIZE / max(h, w) if scale < 1: new_w = int(w * scale) new_h = int(h * scale) image = image.resize((new_w, new_h), resample=PIL.Image.Resampling.LANCZOS) inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16) generated_ids = model.generate( **inputs, do_sample=text_decoding_method == "Nucleus sampling", num_beams=num_beams, max_length=max_length, min_length=min_length, top_p=top_p, repetition_penalty=repetition_penalty, length_penalty=length_penalty, temperature=temperature, ) generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() return generated_caption with gr.Blocks(css="style.css") as demo: gr.Markdown(DESCRIPTION) with gr.Row(): with gr.Column(): input_image = gr.Image(type="pil") prompt = gr.Textbox(label="Prompt") run_button = gr.Button() with gr.Accordion(label="Advanced options", open=False): text_decoding_method = gr.Radio( label="Text Decoding Method", choices=["Beam search", "Nucleus sampling"], value="Nucleus sampling", ) num_beams = gr.Slider( label="Number of Beams", minimum=1, maximum=10, step=1, value=5, ) max_length = gr.Slider( label="Max Length", minimum=1, maximum=512, step=1, value=256, ) min_length = gr.Slider( label="Minimum Length", minimum=1, maximum=64, step=1, value=1, ) top_p = gr.Slider( label="Top P", minimum=0.1, maximum=1.0, step=0.1, value=0.9, ) repetition_penalty = gr.Slider( label="Repetition Penalty", info="Larger value prevents repetition.", minimum=1.0, maximum=5.0, step=0.5, value=1.5, ) length_penalty = gr.Slider( label="Length Penalty", info="Set to larger for longer sequence, used with beam search.", minimum=-1.0, maximum=2.0, step=0.2, value=1.0, ) temperature = gr.Slider( label="Temperature", info="Used with nucleus sampling.", minimum=0.5, maximum=1.0, step=0.1, value=1.0, ) with gr.Column(): output = gr.Textbox(label="Result") gr.on( triggers=[prompt.submit, run_button.click], fn=run, inputs=[ input_image, prompt, text_decoding_method, num_beams, max_length, min_length, top_p, repetition_penalty, length_penalty, temperature, ], outputs=output, api_name="run", ) if __name__ == "__main__": demo.queue(max_size=20).launch()