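# app.py: Gradio demo that zero-shot-classifies photos of North American birds
# with a CLIP model fine-tuned for the NABirds class list.
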
import torch
import numpy as np
import pandas as pd
import gradio as gr
from io import BytesIO
from PIL import Image as PILIMAGE
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
import os
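
# Load the NABirds fine-tuned CLIP checkpoint; the processor and tokenizer come
# from the base openai/clip-vit-base-patch32 model.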
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("vesteinn/clip-nabirds").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
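
# classes.txt (NABirds metadata) is expected to hold one class per line in the
# form "<class_id> <class name ...>".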
def load_class_names(dataset_path=''):
    names = {}
    with open(os.path.join(dataset_path, 'classes.txt')) as f:
        for line in f:
            pieces = line.strip().split()
            class_id = pieces[0]
            names[class_id] = ' '.join(pieces[1:])
    return names

def get_labels():
    labels = []
    class_names = load_class_names(".")
    for _, name in class_names.items():
        labels.append(f"This is a photo of {name}.")
    return labels

def encode_text(text):
    with torch.no_grad():
        inputs = tokenizer([text], padding=True, return_tensors="pt").to(device)
        text_encoded = model.get_text_features(**inputs)
        # Match the image-side normalization so dot products are cosine similarities.
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
    return text_encoded.cpu().numpy()

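# Text features for every class prompt are computed once and cached on disk;
# delete label_features.np to force a rebuild (e.g. after changing encode_text).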
ALL_LABELS = get_labels()
try:
    LABEL_FEATURES = np.load("label_features.np")
except FileNotFoundError:
    LABEL_FEATURES = np.vstack([encode_text(label) for label in ALL_LABELS])
    np.save(open("label_features.np", "wb"), LABEL_FEATURES)

def encode_image(image):
    image = PILIMAGE.fromarray(image.astype('uint8'), 'RGB')
    with torch.no_grad():
        photo_preprocessed = processor(text=None, images=image, return_tensors="pt", padding=True)["pixel_values"]
        search_photo_feature = model.get_image_features(photo_preprocessed.to(device))
        search_photo_feature /= search_photo_feature.norm(dim=-1, keepdim=True)
    return search_photo_feature.cpu().numpy()

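# With both image and text features L2-normalized, the dot product below is the
# cosine similarity between the photo and each class prompt.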
def similarity(feature, label_features):
    return list((feature @ label_features.T).squeeze(0))

def find_best_matches(image):
    image_features = encode_image(image)
    similarities = similarity(image_features, LABEL_FEATURES)
    best_idx = int(np.argmax(similarities))
    return ALL_LABELS[best_idx]

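# Quick local sanity check (assumes an example image such as bj.jpg sits next
# to app.py):
#   print(find_best_matches(np.array(PILIMAGE.open("bj.jpg").convert("RGB"))))
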
examples = [['bj.jpg'], ['duckly.jpg'], ['some.jpg'], ['turdus.jpg'],
            ['seag.jpg'], ['thursh.jpg'], ['woodcock.jpeg'], ['dipper.jpeg']]

gr.Interface(
    fn=find_best_matches,
    inputs=gr.Image(label="Image to classify"),
    outputs=gr.Label(),
    examples=examples,
    title="North American Bird Classifier",
    description="This application can classify North American birds.",
).queue().launch()