"""Gradio demo: zero-shot image classification with ImageBind.

Scores an uploaded image against a "|"-separated list of candidate texts and
displays the softmax-normalised image/text similarities as a label output.
"""
import data
import torch
import gradio as gr
from models import imagebind_model
from models.imagebind_model import ModalityType

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# The pretrained model is loaded once at import time and reused by every request.
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)


def _parse_labels(text_list):
    """Split a "|"-separated string into unique, non-empty, stripped labels.

    Labels keep their order of first appearance. All surrounding whitespace is
    trimmed, not just spaces. Empty entries (from stray or doubled separators,
    or from an empty/None input) are dropped, so they are never scored as an
    empty candidate text.
    """
    stripped = (label.strip() for label in (text_list or "").split("|"))
    # dict.fromkeys de-duplicates while preserving order. A duplicate would
    # otherwise take a share of the softmax mass but collapse to a single key
    # in the returned dict.
    return list(dict.fromkeys(label for label in stripped if label))


def image_text_zeroshot(image, text_list):
    """Score one image against a set of candidate texts.

    Args:
        image: path to the input image. The Image input below is configured
            with type="filepath", which matches what the vision loader is
            given here.
        text_list: candidate texts separated by "|", e.g. "A dog|A car".

    Returns:
        dict mapping each unique candidate text to its softmax score over the
        image/text similarity values (the scores sum to 1).

    Raises:
        gr.Error: if no image was supplied or no non-empty candidate text
            remains after parsing.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")
    labels = _parse_labels(text_list)
    if not labels:
        raise gr.Error("Please provide at least one candidate text.")

    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(labels, device),
        ModalityType.VISION: data.load_and_transform_vision_data([image], device),
    }

    with torch.no_grad():
        embeddings = model(inputs)

    # Similarity of the image embedding to each text embedding, softmaxed
    # across the texts (last dim). squeeze(0) drops the single-image batch
    # dimension. This assumes embeddings are (batch, dim); TODO confirm.
    scores = torch.softmax(
        embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T,
        dim=-1,
    ).squeeze(0).tolist()

    return dict(zip(labels, scores))


inputs = [
    # "filepath" makes Gradio pass a str path. The vision loader is fed this
    # value as a path, so a file object (the old type='file') is not suitable.
    gr.Image(type="filepath", label="Input image"),
    gr.Textbox(lines=1, label="Candidate texts"),
]

iface = gr.Interface(
    image_text_zeroshot,
    inputs,
    "label",
    examples=[
        ["assets/dog_image.jpg", "A dog|A car|A bird"],
        ["assets/car_image.jpg", "A dog|A car|A bird"],
        ["assets/bird_image.jpg", "A dog|A car|A bird"],
    ],
    description="""Zeroshot test""",
    title="Zero-shot Classification",
)

iface.launch()