# open-source-models-hg/image_captioning.py
import time

import gradio as gr
from transformers import BlipForConditionalGeneration, BlipProcessor


def get_image_captioning_tab():
    # Load both captioning models and their processors once, when the tab is built
    salesforce_model_name = "Salesforce/blip-image-captioning-base"
    salesforce_model = BlipForConditionalGeneration.from_pretrained(salesforce_model_name)
    salesforce_processor = BlipProcessor.from_pretrained(salesforce_model_name)

    noamrot_model_name = "noamrot/FuseCap_Image_Captioning"
    noamrot_model = BlipForConditionalGeneration.from_pretrained(noamrot_model_name)
    noamrot_processor = BlipProcessor.from_pretrained(noamrot_model_name)

    # Map the dropdown value (model name) to its (model, processor) pair
    model_map = {
        salesforce_model_name: (salesforce_model, salesforce_processor),
        noamrot_model_name: (noamrot_model, noamrot_processor),
    }

    def gradio_process(model_name, image, text):
        # Look up the selected model, generate a caption, and time the generation step
        model, processor = model_map[model_name]
        start = time.time()
        inputs = processor(image, text, return_tensors="pt")
        out = model.generate(**inputs)
        result = processor.decode(out[0], skip_special_tokens=True)
        end = time.time()
        time_spent = end - start
        return [result, time_spent]

    with gr.TabItem("Image Captioning") as image_captioning_tab:
        gr.Markdown("# Image Captioning")
        with gr.Row():
            with gr.Column():
                # Input components
                input_image = gr.Image(label="Upload Image", type="pil")
                input_text = gr.Textbox(label="Caption prefix (optional)")
                model_selector = gr.Dropdown(
                    [salesforce_model_name, noamrot_model_name],
                    value=salesforce_model_name,  # default so the button works before a manual selection
                    label="Select Model",
                )
                # Process button
                process_btn = gr.Button("Generate caption")
            with gr.Column():
                # Output components
                elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
                output_text = gr.Textbox(label="Generated caption")

        # Connect the input components to the processing function
        process_btn.click(
            fn=gradio_process,
            inputs=[
                model_selector,
                input_image,
                input_text,
            ],
            outputs=[output_text, elapsed_result],
        )

    return image_captioning_tab
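

# Minimal usage sketch (not from the original file): the surrounding app presumably
# builds a gr.Blocks with a gr.Tabs container and calls get_image_captioning_tab()
# inside it. Something along these lines should work for a quick local check.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Tabs():
            get_image_captioning_tab()
    demo.launch()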