import gradio as gr
import os
from torchvision.datasets import CIFAR100
from transformers import CLIPProcessor, CLIPModel

# Load the pretrained CLIP model and its matching processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Download CIFAR-100 so its class names can be turned into text prompts
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# One "a photo of a ..." prompt per CIFAR-100 class
# (kept for reference; the demo below only scores the two prompts in `test`)
text_inputs = ["a photo of a " + c for c in cifar100.classes]
print(text_inputs)

# Prompts the uploaded image is actually scored against
test = ["a photo of a dog", "a photo of a cat"]


def send_inputs(img):
    # Preprocess the image and tokenize the text prompts
    inputs = processor(text=test, images=img, return_tensors="pt", padding=True)

    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image  # image-text similarity scores
    probs = logits_per_image.softmax(dim=1)      # normalize across the prompts

    # Return the prompt with the highest probability
    index = probs.argmax(dim=1).item()
    print(test[index])
    return test[index]


if __name__ == "__main__":
    gr.Interface(title="CAT OR DOG ???", fn=send_inputs, inputs=["image"], outputs="text").launch()