LayBraid
update app
33b2c35
raw history blame
No virus
1.01 kB
import clip
import gradio as gr
import os
import torch
from torchvision.datasets import CIFAR100
from transformers import CLIPProcessor, CLIPModel
# Load the CLIP weights and the matching preprocessor from one checkpoint
# name so the two cannot drift apart.
_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CHECKPOINT)

# CIFAR-100 test split rooted at ~/.cache (download=True); in the visible
# code only its class names are read.
cifar100 = CIFAR100(
    root=os.path.expanduser("~/.cache"),
    download=True,
    train=False,
)

# One "a photo of a <class>" prompt per CIFAR-100 class, echoed as it is
# built and then printed as a whole list.
# NOTE(review): text_inputs is not consumed anywhere in the visible code.
text_inputs = []
for label in cifar100.classes:
    prompt = "a photo of a " + label
    print(prompt)
    text_inputs.append(prompt)
print(text_inputs)

# Candidate captions that send_inputs scores an image against.
test = ["a photo of a dog", "a photo of a cat"]
def send_inputs(img):
    """Return the caption from ``test`` that CLIP scores highest for ``img``.

    Args:
        img: The image handed over by the Gradio "image" input; it is passed
            straight to the CLIP processor. Presumably a single image per
            call (``.item()`` below requires a one-element result) -- TODO
            confirm against the Gradio component's output type.

    Returns:
        The string in ``test`` with the highest image-text probability.
    """
    inputs = processor(text=test, images=img, return_tensors="pt", padding=True)

    # Inference only: disable autograd so no graph is recorded and no
    # intermediate activations are kept alive for each request.
    with torch.no_grad():
        outputs = model(**inputs)

    # Image-to-text similarity scores, normalised over the candidate captions.
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)

    # Index of the most probable caption; .item() needs exactly one element.
    result = probs.argmax(dim=1)
    index = result.item()
    print(test[index])
    return test[index]
if __name__ == "__main__":
    # Build the web UI (one image in, one text label out), then serve it.
    demo = gr.Interface(
        fn=send_inputs,
        inputs=["image"],
        outputs="text",
    )
    demo.launch()