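# Header text copied from the hosting web page (appears to be a Hugging Face
# file view), kept as comments for provenance:
# lotrlol's picture
# Duplicate from flava/zero-shot-image-classification
# a3c131d
# raw
# history blame contribute delete
# No virus
# 1.57 kB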
import gradio as gr
from transformers import FlavaModel, BertTokenizer, FlavaFeatureExtractor
import numpy as np
from PIL import Image
import torch
# Sample image path; not referenced anywhere else in this file.
images = "dog.jpg"

# Single pretrained checkpoint shared by the model and both preprocessors.
_CHECKPOINT = "facebook/flava-full"

# Load the FLAVA model once at import time and switch it to eval mode.
model = FlavaModel.from_pretrained(_CHECKPOINT)
model.eval()

# Image preprocessor and text tokenizer matching the same checkpoint.
fe = FlavaFeatureExtractor.from_pretrained(_CHECKPOINT)
tokenizer = BertTokenizer.from_pretrained(_CHECKPOINT)
def shot(image, labels_text):
    """Score an image against a comma-separated list of labels with FLAVA.

    Each label is wrapped in the prompt "This is a photo of a {label}". The
    prompts and the image are embedded with the module-level ``model``, and
    the text/image dot products are turned into probabilities with a softmax
    across the labels.

    Args:
        image: Image from the Gradio "image" input. It is passed through
            ``np.uint8`` and ``Image.fromarray``, so presumably an HxWxC
            array (TODO confirm against the Gradio version in use).
        labels_text: Comma-separated label names, e.g. "dog, cat, bird".
            Surrounding whitespace is trimmed, and empty or duplicate
            entries are ignored.

    Returns:
        dict mapping each label to its softmax probability (values sum to 1).

    Raises:
        ValueError: if no non-empty label is given, or if ``image`` is None.
    """
    # Parse the labels first so bad text input fails before any model work.
    stripped = (label.strip() for label in (labels_text or "").split(","))
    # dict.fromkeys removes duplicates while keeping first-seen order. Without
    # this, a repeated label would be scored twice but appear once in the
    # returned dict, so the probabilities would no longer sum to 1.
    labels = list(dict.fromkeys(label for label in stripped if label))
    if not labels:
        raise ValueError("Please provide at least one non-empty label.")
    if image is None:
        raise ValueError("Please provide an image.")

    pil_image = Image.fromarray(np.uint8(image)).convert("RGB")
    prompts = [f"This is a photo of a {label}" for label in labels]

    image_input = fe([pil_image], return_tensors="pt")
    text_inputs = tokenizer(prompts, padding="max_length", return_tensors="pt")

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        # Keep only the first token's embedding of each sequence
        # (presumably the [CLS]-style summary token; confirm with FLAVA docs).
        image_embeddings = model.get_image_features(**image_input)[:, 0, :]
        text_embeddings = model.get_text_features(**text_inputs)[:, 0, :]

    # (num_labels, dim) @ (dim, 1) -> (num_labels, 1). Selecting the single
    # column keeps a 1-D tensor, so the softmax runs across labels even when
    # there is only one label.
    logits = (text_embeddings @ image_embeddings.T)[:, 0]
    probs = torch.nn.functional.softmax(logits, dim=0)
    return dict(zip(labels, probs.tolist()))
# Example rows shown under the demo: (image file, comma-separated labels).
EXAMPLES = [
    ["dog.jpg", "dog,cat,bird"],
    ["germany.jpg", "germany,belgium,colombia"],
    ["rocket.jpg", "car,rocket,train"],
]

# Wire `shot` to an image + text input and a label output, then start the app.
iface = gr.Interface(
    fn=shot,
    inputs=["image", "text"],
    outputs="label",
    examples=EXAMPLES,
    title="FLAVA Zero-shot Image Classification",
    description="Add a picture and a list of labels separated by commas",
)
iface.launch()