evelyncsb's picture
Update app.py
7f0b913
raw
history blame
No virus
1.49 kB
import gradio as gr
import os
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from collections import OrderedDict
import torch
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
import torch.nn as nn
# Run on the first CUDA GPU when one is visible to torch; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Load the pretrained ImageBind "huge" model once at import time, switch it to
# inference mode, and move its weights to the selected device. nn.Module.eval()
# and nn.Module.to() both return the module itself, so `model` is that module.
model = imagebind_model.imagebind_huge(pretrained=True).eval().to(device)
def image_text_zeroshot(image, text_list):
image_paths = [image]
labels = [label.strip(" ") for label in text_list.strip(" ").split("|")]
inputs = {
ModalityType.TEXT: data.load_and_transform_text(labels, device),
ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
}
with torch.no_grad():
embeddings = model(inputs)
scores = (
torch.softmax(
embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1
)
.squeeze(0)
.tolist()
)
score_dict = {label: score for label, score in zip(labels, scores)}
return score_dict
def main():
    """Build and launch the Gradio demo for ImageBind image/text zero-shot scoring.

    Gradio passes the input components' values to the wrapped function
    positionally. The components are therefore listed in the same order as the
    parameters of ``image_text_zeroshot(image, text_list)``.
    """
    inputs = [
        gr.Image(type="filepath", label="Input image"),
        gr.Textbox(lines=1, label="texts"),
    ]
    iface = gr.Interface(
        # Pass the callable itself; Gradio invokes it on every request.
        image_text_zeroshot,
        inputs,
        "label",
        description="""...""",
        title="ImageBind",
    )
    iface.launch()


if __name__ == "__main__":
    main()