import time

import gradio as gr
from transformers import BlipForConditionalGeneration, BlipProcessor

def get_image_captioning_tab():
    # Load both BLIP-based captioning models and their processors once, when the tab is built.
    salesforce_model_name = "Salesforce/blip-image-captioning-base"
    salesforce_model = BlipForConditionalGeneration.from_pretrained(salesforce_model_name)
    salesforce_processor = BlipProcessor.from_pretrained(salesforce_model_name)

    noamrot_model_name = "noamrot/FuseCap_Image_Captioning"
    noamrot_model = BlipForConditionalGeneration.from_pretrained(noamrot_model_name)
    noamrot_processor = BlipProcessor.from_pretrained(noamrot_model_name)

    # Map each dropdown choice to its (model, processor) pair.
    model_map = {
        salesforce_model_name: (salesforce_model, salesforce_processor),
        noamrot_model_name: (noamrot_model, noamrot_processor),
    }

    def gradio_process(model_name, image, text):
        model, processor = model_map[model_name]
        start = time.time()
        # The optional text acts as a caption prefix for BLIP's conditional generation.
        inputs = processor(image, text, return_tensors="pt")
        out = model.generate(**inputs)
        result = processor.decode(out[0], skip_special_tokens=True)
        time_spent = time.time() - start
        return [result, time_spent]

    with gr.TabItem("Image Captioning") as image_captioning_tab:
        gr.Markdown("# Image Captioning")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Upload Image", type="pil")
                input_text = gr.Textbox(label="Caption prefix (optional)")
                # Preselect the first model so the button works without an explicit choice.
                model_selector = gr.Dropdown(
                    [salesforce_model_name, noamrot_model_name],
                    label="Select Model",
                    value=salesforce_model_name,
                )

                process_btn = gr.Button("Generate caption")

            with gr.Column():
                elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
                output_text = gr.Textbox(label="Generated caption")

        # Wire the button: inputs are passed positionally to gradio_process,
        # and its [caption, seconds] return value fills the two output boxes.
        process_btn.click(
            fn=gradio_process,
            inputs=[
                model_selector,
                input_image,
                input_text,
            ],
            outputs=[output_text, elapsed_result],
        )

    return image_captioning_tab
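

# Usage sketch (an assumption, not part of the original module): how the tab factory
# might be mounted in an app. gr.TabItem must be created inside a gr.Tabs container,
# which itself lives inside a gr.Blocks context.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Tabs():
            get_image_captioning_tab()
    demo.launch()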