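"""North American Bird Classifier (Gradio app).

Zero-shot CLIP-style classification: the input image and one
"This is a photo of <species>." prompt per NABirds class are embedded,
and the prompt whose embedding is most cosine-similar to the image
embedding gives the predicted label.
"""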
import os

import gradio as gr
import numpy as np
import torch
from PIL import Image as PILImage
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer


# The image/text encoder is fine-tuned on NABirds (vesteinn/clip-nabirds);
# the stock CLIP ViT-B/32 processor and tokenizer handle pre-processing.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("vesteinn/clip-nabirds").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")


def load_class_names(dataset_path=""):
    """Parse NABirds classes.txt: each line is '<class_id> <class name...>'."""
    names = {}
    with open(os.path.join(dataset_path, "classes.txt")) as f:
        for line in f:
            pieces = line.strip().split()
            class_id = pieces[0]
            names[class_id] = " ".join(pieces[1:])
    return names


def get_labels():
    """Build one text prompt per class, in classes.txt order."""
    class_names = load_class_names(".")
    return [f"This is a photo of {name}." for name in class_names.values()]


def encode_text(text):
    with torch.no_grad():
        inputs = tokenizer([text], padding=True, return_tensors="pt").to(device)
        text_features = model.get_text_features(**inputs)
        # Normalise so the dot products below are cosine similarities,
        # matching the normalised image features.
        text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features.cpu().numpy()


ALL_LABELS = get_labels()
# Encoding every label prompt is slow, so the text features are cached on disk.
try:
    LABEL_FEATURES = np.load("label_features.npy")
except FileNotFoundError:
    LABEL_FEATURES = np.vstack([encode_text(label) for label in ALL_LABELS])
    np.save("label_features.npy", LABEL_FEATURES)


def encode_image(image):
    image = PILImage.fromarray(image.astype("uint8"), "RGB")
    with torch.no_grad():
        pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
        image_features = model.get_image_features(pixel_values.to(device))
        image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features.cpu().numpy()


def similarity(feature, label_features):
    # (1, d) @ (d, n_labels) -> one cosine-similarity score per label.
    return (feature @ label_features.T).squeeze(0)


def find_best_matches(image):
    image_features = encode_image(image)
    similarities = similarity(image_features, LABEL_FEATURES)
    # The prediction is the label with the highest similarity score.
    return ALL_LABELS[int(np.argmax(similarities))]

examples = [["bj.jpg"], ["duckly.jpg"], ["some.jpg"], ["turdus.jpg"],
            ["seag.jpg"], ["thursh.jpg"], ["woodcock.jpeg"], ["dipper.jpeg"]]

# gr.inputs.Image / gr.outputs.Label and enable_queue were removed in newer
# Gradio releases; the current component API is used instead.
gr.Interface(
    fn=find_best_matches,
    inputs=gr.Image(label="Image to classify"),
    outputs=gr.Label(),
    examples=examples,
    title="North American Bird Classifier",
    description="This application classifies photos of North American birds.",
).launch()
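
# Usage note: run this script directly (e.g. `python app.py`; the file name is
# an assumption) and open the local URL Gradio prints. On Hugging Face Spaces
# the app is launched automatically.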