# Julien Simon — commit 6753426: "Cleanup"
# (Hub file-view header residue: raw | history blame | 1.55 kB)
import gradio as gr
import torch
from PIL import Image
from transformers import BridgeTowerForImageAndTextRetrieval, BridgeTowerProcessor
# Hub checkpoint shared by the processor and the model.
model_id = "BridgeTower/bridgetower-large-itm-mlm-gaudi"

# Both are created once, at import time, and reused by every call to `process`.
processor, model = (
    BridgeTowerProcessor.from_pretrained(model_id),
    BridgeTowerForImageAndTextRetrieval.from_pretrained(model_id),
)
# Process an image
def process(image, texts):
scores = {}
texts = texts.split(",")
for text in texts:
encoding = processor(image, text, return_tensors="pt")
outputs = model(**encoding)
scores[text] = "{:.2f}".format(outputs.logits[0, 1].item())
# sort scores in descending order
scores = dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))
return scores
# Input components: the image to score against and the candidate texts.
image, texts = (
    gr.Image(label="Image"),
    gr.Text(label="List of comma-separated texts"),
)
# Output component: the ranked scores, rendered as JSON.
scores = gr.JSON(label="Scores")
# Intro text shown above the interface. Adjacent string literals are joined
# at compile time; the trailing spaces keep the sentences and words apart
# (a backslash continuation inside the literal glued them as "image.This").
description = (
    "This Space lets you score a list of texts on an image. "
    "This can be used to find the most relevant text for an image, "
    "or for semantic search on images."
)
# Wire `process` into a Gradio Interface and start serving it.
iface = gr.Interface(
    fn=process,
    inputs=[image, texts],
    outputs=scores,
    description=description,
    # Sample rows shown under the UI: [image file, comma-separated texts].
    examples=[
        [
            "example1.jpg",
            "a metal band on stage, a chamber orchestra on stage, a giant rubber duck, a machine learning meetup",
        ],
        [
            "example2.jpg",
            "medieval art, religious art, a group of angels, a movie poster",
        ],
    ],
    theme="huggingface",
    allow_flagging="never",  # flagging of outputs is disabled
)
iface.launch()