LayBraid
add requirements.txt
998ea00
raw history blame
No virus
940 Bytes
import clip
import gradio as gr
import os
import torch
from torchvision.datasets import CIFAR100
from transformers import CLIPProcessor, CLIPModel
# Hugging Face checkpoint shared by the model and its processor so they stay in sync.
_CHECKPOINT = "openai/clip-vit-base-patch32"

model = CLIPModel.from_pretrained(_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CHECKPOINT)

# CIFAR-100 test split, downloaded into ~/.cache when missing; in this file only
# its class names appear to be used.
_cache_dir = os.path.expanduser("~/.cache")
cifar100 = CIFAR100(root=_cache_dir, download=True, train=False)

# One tokenized "a photo of a <class>" prompt per CIFAR-100 class, concatenated
# along the first dimension into a single tensor.
# NOTE(review): text_inputs is not referenced anywhere else in this file as shown
# (send_inputs passes the bare class names instead) — confirm before removing.
text_inputs = torch.cat(
    [clip.tokenize(f"a photo of a {label}") for label in cifar100.classes]
)
# TODO: debug this so that the output is displayed correctly.
# TODO: finish displaying the result.
def send_inputs(img):
    """Score *img* against every CIFAR-100 class name with CLIP (zero-shot).

    The bare class names from ``cifar100.classes`` are used as the text
    candidates. The templated prompts in ``text_inputs`` are not used here.

    Args:
        img: Image to classify, in any form ``CLIPProcessor`` accepts
            (presumably the numpy array Gradio's ``"image"`` input supplies —
            TODO confirm).

    Returns:
        Tensor of probabilities: softmax over ``logits_per_image`` along
        dim 1, i.e. one column per entry of ``cifar100.classes``.
    """
    inputs = processor(text=cifar100.classes, images=img, return_tensors="pt", padding=True)
    # Inference only: don't build an autograd graph. It wastes memory and
    # leaves a grad_fn on the result, which shows up in its printed repr.
    with torch.no_grad():
        outputs = model(**inputs)
        probs = outputs.logits_per_image.softmax(dim=1)
    print(probs)
    return probs
def probs_to_label(probs, classes):
    """Convert a CLIP probability tensor into a ``{class name: probability}`` dict.

    Args:
        probs: Tensor of per-class probabilities as returned by ``send_inputs``.
            If it has more than one dimension, only the first row (the first
            image) is used.
        classes: Class names, in the same order as the probability columns.

    Returns:
        Dict mapping each class name to its probability as a plain ``float``,
        which is the format ``gr.Label`` expects.
    """
    row = probs[0] if probs.dim() > 1 else probs
    # detach() so this also works on tensors that still track gradients.
    return {name: float(p) for name, p in zip(classes, row.detach().tolist())}


if __name__ == "__main__":
    def classify(img):
        """Gradio callback: classify *img* and return a label dict for display."""
        if img is None:  # the form was submitted without an image
            return {}
        return probs_to_label(send_inputs(img), cifar100.classes)

    # gr.Label renders the top classes with their confidences instead of the
    # raw tensor repr that a plain "text" output would show.
    gr.Interface(
        fn=classify,
        inputs=["image"],
        outputs=gr.Label(num_top_classes=5),
    ).launch()