# LayBraid
# update app
# 9251afa
# raw history blame
# No virus
# 1.04 kB
import clip
import gradio as gr
import os
import torch
from torchvision.datasets import CIFAR100
from transformers import CLIPProcessor, CLIPModel
# Load the pretrained CLIP ViT-B/32 checkpoint and its matching processor by
# Hub id (downloaded on first use, then read from the local cache). The
# processor handles both the caption text and the image inputs (see send_inputs).
model, processor = (
    CLIPModel.from_pretrained("openai/clip-vit-base-patch32"),
    CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32"),
)
# CIFAR-100's test split is fetched (into ~/.cache) only to read its class names.
# NOTE(review): text_inputs is not used by send_inputs, which classifies
# against `test`; this download and prompt list are currently unused at startup.
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# One zero-shot CLIP caption per CIFAR-100 class name, e.g. "a photo of a bear".
text_inputs = [f"a photo of a {name}" for name in cifar100.classes]
print(text_inputs)
# Candidate captions for zero-shot classification; send_inputs returns the
# caption that CLIP scores highest for the uploaded image.
test = ["a photo of a dog", "a photo of a cat"]


def send_inputs(img):
    """Classify *img* as one of the captions in ``test`` using CLIP.

    Args:
        img: The image delivered by the Gradio ``"image"`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        The caption from ``test`` that CLIP scores highest against the image,
        or ``"No image provided."`` when *img* is ``None``.
    """
    # Gradio passes None for an empty image input; the processor would fail on it.
    if img is None:
        return "No image provided."

    inputs = processor(text=test, images=img, return_tensors="pt", padding=True)
    # Inference only: skip autograd bookkeeping to save memory and time per request.
    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_image: one row per image (a single image here), one column
    # per caption in `test`; pick the best-scoring caption for that image.
    probs = outputs.logits_per_image.softmax(dim=1)
    index = probs.argmax(dim=1).item()
    print(test[index])
    return test[index]
if __name__ == "__main__":
    # Web UI: one image input, and the predicted caption shown as text.
    demo = gr.Interface(
        title="CAT OR DOG ???",
        fn=send_inputs,
        inputs=["image"],
        outputs="text",
    )
    demo.launch()