File size: 2,953 Bytes
7cfb866
 
 
 
 
 
b422fae
 
7cfb866
2ad4686
7cfb866
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f0ac21
7cfb866
 
 
b4abc27
7cfb866
b4abc27
7cfb866
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import numpy as np
import pandas as pd
import gradio as gr
from io import BytesIO
from PIL import Image as PILIMAGE
#from IPython.display import Image
#from IPython.core.display import HTML
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
import os


# Pick the GPU when CUDA is available; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Fine-tuned CLIP weights; the image processor and tokenizer are loaded from
# the base checkpoint below (presumably the one the fine-tune started from —
# NOTE(review): confirm they match).
_CLIP_BASE_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained("vesteinn/clip-nabirds").to(device)
processor = CLIPProcessor.from_pretrained(_CLIP_BASE_CHECKPOINT)
tokenizer = CLIPTokenizer.from_pretrained(_CLIP_BASE_CHECKPOINT)


def load_class_names(dataset_path=''):
    """Read ``classes.txt`` into a mapping from class id to class name.

    Each non-blank line holds a class id followed by the (possibly
    multi-word) class name, separated by whitespace.

    Args:
        dataset_path: Directory containing ``classes.txt``. The default ``''``
            means the current working directory.

    Returns:
        dict mapping the class id (a string) to the class name, in file order.
    """
    names = {}
    path = os.path.join(dataset_path, 'classes.txt')
    # Explicit UTF-8 so parsing does not depend on the platform locale.
    with open(path, encoding='utf-8') as f:
        for line in f:
            # str.split() with no argument already ignores surrounding
            # whitespace, so no separate strip() is needed.
            pieces = line.split()
            # Skip blank lines (e.g. a trailing empty line) instead of
            # crashing with IndexError on the missing id.
            if not pieces:
                continue
            class_id, *name_parts = pieces
            names[class_id] = ' '.join(name_parts)
    return names


def get_labels():
    """Build one CLIP text prompt per class listed in ``./classes.txt``.

    Returns:
        list of strings of the form ``"This is a photo of <name>."``, in the
        same order as the classes appear in the file. Callers index this list
        by row position of the label feature matrix, so the order matters.
    """
    class_names = load_class_names(".")
    # Only the names go into the prompt; the class ids are not needed.
    return [f"This is a photo of {name}." for name in class_names.values()]


def encode_text(text):
    """Embed a single text prompt with the CLIP text tower.

    Args:
        text: The prompt string to encode.

    Returns:
        numpy array with one row of (unnormalized) text features, on the CPU.
    """
    # The model lives on `device`; the tokenized tensors must follow it or the
    # forward pass fails whenever CUDA is available.
    inputs = tokenizer([text], padding=True, return_tensors="pt").to(device)
    # Run the forward pass itself under no_grad so no autograd graph is built.
    with torch.no_grad():
        text_encoded = model.get_text_features(**inputs)
    # .cpu() is required before .numpy() when the tensor is on a GPU.
    return text_encoded.cpu().numpy()


ALL_LABELS = get_labels()

# On-disk cache of the text embeddings of ALL_LABELS (one row per label, in
# ALL_LABELS order).
_LABEL_FEATURES_PATH = "label_features.np"

try:
    LABEL_FEATURES = np.load(_LABEL_FEATURES_PATH)
except (OSError, ValueError, EOFError):
    # Missing, empty or unreadable cache: recompute below.
    LABEL_FEATURES = None

# A cache written for a different classes.txt would silently misalign rows
# with ALL_LABELS, so only trust it when its shape matches.
if (
    LABEL_FEATURES is None
    or LABEL_FEATURES.ndim != 2
    or LABEL_FEATURES.shape[0] != len(ALL_LABELS)
):
    LABEL_FEATURES = np.vstack([encode_text(label) for label in ALL_LABELS])
    # Saving through an explicit file handle keeps the exact file name (a bare
    # path would get ".npy" appended); `with` guarantees it is closed/flushed.
    with open(_LABEL_FEATURES_PATH, "wb") as f:
        np.save(f, LABEL_FEATURES)


def encode_image(image):
    """Embed one image with the CLIP vision tower.

    Args:
        image: Image as a numpy array (presumably what the Gradio image input
            passes — confirm). Grayscale (HxW) and RGBA (HxWx4) arrays are
            converted to RGB before preprocessing.

    Returns:
        numpy array with one row of L2-normalized image features, on the CPU.
    """
    # Let PIL infer the mode from the array shape, then force RGB, so that
    # grayscale or RGBA inputs are converted rather than misread as RGB bytes.
    pil_image = PILIMAGE.fromarray(image.astype('uint8')).convert('RGB')
    with torch.no_grad():
        pixel_values = processor(
            text=None, images=pil_image, return_tensors="pt", padding=True
        )["pixel_values"]
        features = model.get_image_features(pixel_values.to(device))
        # Scale each embedding to unit L2 norm along the feature axis.
        features /= features.norm(dim=-1, keepdim=True)
    return features.cpu().numpy()


def similarity(feature, label_features):
    """Return the cosine similarity between a query embedding and each label.

    Both sides are L2-normalized here, so the scores are true cosines even if
    the label embeddings were not normalized when computed. A raw dot product
    against unnormalized label vectors would favour labels with larger norms
    and distort the ranking.

    Args:
        feature: Query embedding of shape (1, D) or (D,).
        label_features: Matrix of label embeddings of shape (N, D).

    Returns:
        list of N scalar similarity scores, one per row of ``label_features``.
    """
    eps = 1e-12  # avoids division by zero for all-zero vectors
    feature = np.atleast_2d(feature)
    label_features = np.atleast_2d(label_features)
    feature = feature / np.maximum(
        np.linalg.norm(feature, axis=-1, keepdims=True), eps)
    label_features = label_features / np.maximum(
        np.linalg.norm(label_features, axis=-1, keepdims=True), eps)
    return list((feature @ label_features.T).squeeze(0))


def find_best_matches(image):
    """Classify *image* and return the prompt of the best-matching label.

    Args:
        image: Image array from the Gradio input; forwarded to ``encode_image``.

    Returns:
        The highest-scoring entry of ``ALL_LABELS`` (a prompt string of the
        form ``"This is a photo of <name>."``).
    """
    image_features = encode_image(image)
    similarities = similarity(image_features, LABEL_FEATURES)
    # argmax picks the first index on ties, which is the same result as a
    # stable descending sort followed by [0], but in O(N) instead of O(N log N).
    best_index = int(np.argmax(similarities))
    return ALL_LABELS[best_index]

# Sample images offered in the UI; each inner list holds one example's inputs.
example_files = ['bj.jpg', 'duckly.jpg', 'some.jpg', 'turdus.jpg', 'seag.jpg',
                 'thursh.jpg', 'woodcock.jpeg', 'dipper.jpeg']
examples = [[path] for path in example_files]

# NOTE(review): gr.inputs / gr.outputs / enable_queue belong to the legacy
# Gradio API — confirm the pinned gradio version still provides them.
image_input = gr.inputs.Image(label="Image to classify", optional=False)
label_output = gr.outputs.Label()

demo = gr.Interface(
    fn=find_best_matches,
    inputs=[image_input],
    outputs=label_output,
    examples=examples,
    theme="grass",
    enable_queue=True,
    title="North American Bird Classifier",
    description="This application can classify North American Birds.",
)
# Starts the web server as soon as the module is executed.
demo.launch()