import gradio as gr
import torch
from PIL import Image
from transformers import BridgeTowerForImageAndTextRetrieval, BridgeTowerProcessor

model_id = "BridgeTower/bridgetower-large-itm-mlm-gaudi"
processor = BridgeTowerProcessor.from_pretrained(model_id)
model = BridgeTowerForImageAndTextRetrieval.from_pretrained(model_id)


def _split_texts(raw):
    """Split a comma-separated string into stripped, non-empty candidate texts."""
    # `raw or ""` tolerates a None/empty textbox value; stripping removes the
    # spaces users naturally type after commas ("a, b" -> ["a", "b"]).
    return [part.strip() for part in (raw or "").split(",") if part.strip()]


# Process an image
def process(image, texts):
    """Score each comma-separated text against ``image``.

    Args:
        image: the value produced by the ``gr.Image`` input component,
            passed straight to ``processor``.
        texts: a comma-separated string of candidate texts.

    Returns:
        A dict mapping each candidate text to its score formatted with two
        decimals, ordered from highest to lowest score.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please provide an image.")

    raw_scores = {}
    # Inference only: disable autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for text in _split_texts(texts):
            encoding = processor(image, text, return_tensors="pt")
            outputs = model(**encoding)
            # Column 1 of the logits is used as the score, as in the original;
            # presumably the "match" class of the ITM head — confirm in model docs.
            raw_scores[text] = outputs.logits[0, 1].item()

    # Sort on the numeric value, not the formatted string: string comparison
    # mis-orders values such as "10.00" vs "9.00" and negative scores.
    ranked = sorted(raw_scores.items(), key=lambda item: item[1], reverse=True)
    return {text: "{:.2f}".format(score) for text, score in ranked}


# Inputs
image = gr.Image(label="Image")
texts = gr.Text(label="List of comma-separated texts")

# Output
scores = gr.JSON(label="Scores")

# Implicit concatenation with an explicit space; the previous backslash line
# continuation inside the literal glued the two sentences together.
description = (
    "This Space lets you score a list of texts on an image. "
    "This can be used to find the most relevant text for an image, "
    "or for semantic search on images."
)

iface = gr.Interface(
    theme="huggingface",
    description=description,
    fn=process,
    inputs=[image, texts],
    outputs=scores,
    examples=[
        [
            "example1.jpg",
            "a metal band on stage, a chamber orchestra on stage, a giant rubber duck, a machine learning meetup",
        ],
        [
            "example2.jpg",
            "medieval art, religious art, a group of angels, a movie poster",
        ],
    ],
    allow_flagging="never",
)
iface.launch()