import time

import gradio as gr
from transformers import BlipForConditionalGeneration, BlipProcessor


def get_image_captioning_tab():
    """Build and return the "Image Captioning" Gradio tab.

    The tab lets the user upload an image, optionally provide a caption
    prefix, pick one of two BLIP checkpoints, and generate a caption.
    Models are loaded lazily on first use rather than at tab-construction
    time, so building the UI stays fast and only the checkpoint the user
    actually selects is ever loaded into memory.

    Returns:
        gr.TabItem: the fully wired "Image Captioning" tab.
    """
    salesforce_model_name = "Salesforce/blip-image-captioning-base"
    noamrot_model_name = "noamrot/FuseCap_Image_Captioning"

    # Lazy cache: model name -> (model, processor). Populated on first
    # request for each checkpoint; avoids eagerly downloading/loading both
    # multi-hundred-MB models when the tab is merely constructed.
    _loaded = {}

    def _get_model(model_name):
        """Return the (model, processor) pair for *model_name*, loading once."""
        if model_name not in _loaded:
            model = BlipForConditionalGeneration.from_pretrained(model_name)
            processor = BlipProcessor.from_pretrained(model_name)
            _loaded[model_name] = (model, processor)
        return _loaded[model_name]

    def gradio_process(model_name, image, text):
        """Generate a caption for *image* with the selected model.

        Args:
            model_name: Hugging Face model id chosen in the dropdown.
            image: PIL image from the upload widget.
            text: optional caption prefix for conditional generation.

        Returns:
            [caption, seconds_elapsed] — order matches the wiring of the
            (output_text, elapsed_result) output components below.
        """
        if not model_name:
            # The dropdown has no default value; surface a friendly message
            # instead of letting a KeyError traceback reach the UI.
            raise gr.Error("Please select a model first.")
        model, processor = _get_model(model_name)
        start = time.time()
        inputs = processor(image, text, return_tensors="pt")
        out = model.generate(**inputs)
        result = processor.decode(out[0], skip_special_tokens=True)
        time_spent = time.time() - start
        return [result, time_spent]

    with gr.TabItem("Image Captioning") as image_captioning_tab:
        gr.Markdown("# Image Captioning")
        with gr.Row():
            with gr.Column():
                # Input components
                input_image = gr.Image(label="Upload Image", type="pil")
                input_text = gr.Textbox(label="Caption")
                model_selector = gr.Dropdown(
                    [salesforce_model_name, noamrot_model_name],
                    label="Select Model",
                )
                # Process button
                process_btn = gr.Button("Generate caption")
            with gr.Column():
                # Output components
                elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
                output_text = gr.Textbox(label="Generated caption")
        # Connect the input components to the processing function
        process_btn.click(
            fn=gradio_process,
            inputs=[model_selector, input_image, input_text],
            outputs=[output_text, elapsed_result],
        )
    return image_captioning_tab