import os

import numpy as np
import torch
import gradio as gr
from PIL import Image as PILIMAGE
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
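
# This app performs zero-shot CLIP classification: every bird class becomes a
# text prompt, the prompts are encoded once, and an uploaded photo gets the
# label whose text embedding is closest to the image embedding.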
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("vesteinn/clip-nabirds").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

def load_class_names(dataset_path=''):
    """Read class names from a NABirds-style classes.txt file."""
    names = {}
    with open(os.path.join(dataset_path, 'classes.txt')) as f:
        for line in f:
            # Each line is "<class_id> <class name ...>".
            pieces = line.strip().split()
            class_id = pieces[0]
            names[class_id] = ' '.join(pieces[1:])
    return names
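
# classes.txt uses the NABirds layout; illustrative (made-up) lines:
#   295 Western Meadowlark
#   296 Eastern Meadowlark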

def get_labels():
    """Build one CLIP text prompt per class name."""
    class_names = load_class_names(".")
    return [f"This is a photo of {name}." for name in class_names.values()]

def encode_text(text):
    """Encode a label prompt into a normalized CLIP text embedding."""
    with torch.no_grad():
        inputs = tokenizer([text], padding=True, return_tensors="pt").to(device)
        text_features = model.get_text_features(**inputs)
        # Normalize so the dot product with image features is a cosine similarity.
        text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features.cpu().numpy()

ALL_LABELS = get_labels()

# Encode all label prompts once and cache them on disk.
try:
    LABEL_FEATURES = np.load("label_features.npy")
except FileNotFoundError:
    LABEL_FEATURES = np.vstack([encode_text(label) for label in ALL_LABELS])
    np.save("label_features.npy", LABEL_FEATURES)
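
# If the class list or the prompt template changes, delete label_features.npy
# so the cache is rebuilt on the next start.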

def encode_image(image):
    """Encode an uploaded image into a normalized CLIP image embedding."""
    image = PILIMAGE.fromarray(image.astype('uint8'), 'RGB')
    with torch.no_grad():
        pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
        image_features = model.get_image_features(pixel_values.to(device))
        image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features.cpu().numpy()
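
# Both embeddings are unit-normalized, so the dot product below is exactly the
# cosine similarity CLIP uses to score image/label pairs.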

def similarity(feature, label_features):
    """Cosine similarities between one image embedding and every label embedding."""
    return (feature @ label_features.T).squeeze(0)


def find_best_matches(image):
    """Return the label prompt that scores highest against the image."""
    image_features = encode_image(image)
    similarities = similarity(image_features, LABEL_FEATURES)
    idx = int(np.argmax(similarities))
    return ALL_LABELS[idx]
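
# Quick local sanity check (hypothetical file name; any RGB photo works):
#   img = np.asarray(PILIMAGE.open("some_bird.jpg").convert("RGB"))
#   print(find_best_matches(img))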

examples = [['bj.jpg'], ['duckly.jpg'], ['some.jpg'], ['turdus.jpg'],
            ['seag.jpg'], ['thursh.jpg'], ['woodcock.jpeg'], ['dipper.jpeg']]

gr.Interface(
    fn=find_best_matches,
    inputs=gr.Image(label="Image to classify", type="numpy"),
    outputs=gr.Label(),
    examples=examples,
    title="North American Bird Classifier",
    description="This application classifies North American birds.",
).queue().launch()