#!/usr/bin/env python
"""Gradio demo: classify wildlife-trade ad listings with InstructBLIP.

Given an ad image and its title, the model is prompted to report the
product type (from a fixed list) and the species mentioned.
"""

from __future__ import annotations

import os
import string

import gradio as gr
import PIL.Image
import torch
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration

DESCRIPTION = "# [BLIP-2 test](https://github.com/salesforce/LAVIS/tree/main/projects/blip2)"
if not torch.cuda.is_available():
    DESCRIPTION += "\n\nRunning on CPU."

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Use half precision only on GPU; many fp16 kernels are unavailable or very
# slow on CPU. The same dtype is used for the model weights and the image
# tensors so they always match (mixing fp16 inputs with fp32 weights raises
# a dtype-mismatch error in the vision encoder).
DTYPE = torch.float16 if device.type == "cuda" else torch.float32

# Upper bound on generated tokens. Without it, generation falls back to the
# model's default length limit, which can cut the
# "Product Type: ... Species: ..." answer short.
MAX_NEW_TOKENS = 64

MODEL_ID = "Salesforce/instructblip-flan-t5-xl"
processor = InstructBlipProcessor.from_pretrained(MODEL_ID)
model = InstructBlipForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=DTYPE
).to(device)


def answer_ad_listing_question(
    image: PIL.Image.Image | None,
    title: str,
) -> str:
    """Ask the model for the product type and species of an ad listing.

    Args:
        image: The listing's image, as delivered by ``gr.Image(type="pil")``.
            ``None`` when the user has not uploaded anything.
        title: The listing's title text; it is embedded verbatim in the prompt.

    Returns:
        The decoded model output, stripped of surrounding whitespace.

    Raises:
        gr.Error: If no image was provided. The error is shown in the UI
            instead of a processor stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image of the listing.")

    # The prompt template with the provided title
    prompt = f"""Given an ad listing with the title '{title}' and image, answer the following questions without any explanation or extra text: Identify the species mentioned in the text, including specific names, e.g., 'Nile crocodile' instead of just 'crocodile'. Select the product type from the following options: Animal fibers, Animal parts (bone or bone-like), Animal parts (fleshy), Coral product, Egg, Extract, Food, Ivory products, Live, Medicine, Nests, Organs and tissues, Powder, Scales or spines, Shells, Skin or leather products, Taxidermy, Insects. The response should be in the format: "Product Type: [type] Species: [species]" """

    # BatchFeature.to casts only floating-point tensors (the pixel values) to
    # DTYPE; integer token ids are just moved to the device.
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, DTYPE)
    generated_ids = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return result


def postprocess_output(output: str) -> str:
    """Post-processing hook for model output; currently returns it unchanged."""
    return output


with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        # Image and ad title input
        image = gr.Image(type="pil")
        ad_title = gr.Textbox(label="Advertisement Title", placeholder="Enter the title here", lines=1)

        # Output section
        answer_output = gr.Textbox(label="Analysis", show_label=True, placeholder="Response.")

        # Submit and clear buttons
        with gr.Row():
            submit_button = gr.Button("Analyze Listing", variant="primary")
            clear_button = gr.Button("Clear")

    # Logic to handle clicking on "Analyze Listing"
    submit_button.click(
        fn=answer_ad_listing_question,
        inputs=[image, ad_title],  # Only the image and ad title are inputs
        outputs=answer_output,
    )

    # Logic to handle clearing the inputs and outputs. An image component is
    # cleared with None; an empty string would be treated as a file path.
    clear_button.click(
        fn=lambda: (None, "", ""),
        inputs=None,
        outputs=[image, ad_title, answer_output],
        queue=False,
    )

if __name__ == "__main__":
    demo.queue(max_size=10).launch()