import os

import gradio as gr
import numpy as np
import torch
from PIL import Image as PILIMAGE
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("vesteinn/clip-nabirds").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")


def load_class_names(dataset_path=""):
    """Read NABirds class names from classes.txt ("<class_id> <name...>" per line)."""
    names = {}
    with open(os.path.join(dataset_path, "classes.txt")) as f:
        for line in f:
            pieces = line.strip().split()
            class_id = pieces[0]
            names[class_id] = " ".join(pieces[1:])
    return names


def get_labels():
    """Turn every class name into a CLIP-style text prompt."""
    class_names = load_class_names(".")
    return [f"This is a photo of {name}." for name in class_names.values()]


def encode_text(text):
    """Encode a single prompt into a unit-normalised (1, dim) numpy array."""
    with torch.no_grad():
        inputs = tokenizer([text], padding=True, return_tensors="pt").to(device)
        text_encoded = model.get_text_features(**inputs)
        # Normalise so the dot product against the (already normalised)
        # image features is a cosine similarity.
        text_encoded = text_encoded / text_encoded.norm(dim=-1, keepdim=True)
    return text_encoded.cpu().numpy()


ALL_LABELS = get_labels()

# Encoding every prompt is slow, so cache the label features on disk.
LABEL_FEATURES_PATH = "label_features.npy"
try:
    LABEL_FEATURES = np.load(LABEL_FEATURES_PATH)
except FileNotFoundError:
    LABEL_FEATURES = np.vstack([encode_text(label) for label in ALL_LABELS])
    np.save(LABEL_FEATURES_PATH, LABEL_FEATURES)


def encode_image(image):
    """Encode an (H, W, 3) uint8 array from Gradio into a unit-normalised feature."""
    image = PILIMAGE.fromarray(image.astype("uint8"), "RGB")
    with torch.no_grad():
        pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
        image_features = model.get_image_features(pixel_values.to(device))
        image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features.cpu().numpy()


def similarity(feature, label_features):
    """Cosine similarities between one image feature and all label features."""
    return (feature @ label_features.T).squeeze(0)


def find_best_matches(image):
    """Return the label whose prompt is closest to the image in CLIP space."""
    image_features = encode_image(image)
    similarities = similarity(image_features, LABEL_FEATURES)
    best_idx = int(np.argmax(similarities))
    return ALL_LABELS[best_idx]


examples = [
    ["bj.jpg"],
    ["duckly.jpg"],
    ["some.jpg"],
    ["turdus.jpg"],
    ["seag.jpg"],
    ["thursh.jpg"],
    ["woodcock.jpeg"],
    ["dipper.jpeg"],
]

gr.Interface(
    fn=find_best_matches,
    inputs=[gr.inputs.Image(label="Image to classify", optional=False)],
    outputs=gr.outputs.Label(),
    examples=examples,
    theme="grass",
    enable_queue=True,
    title="North American Bird Classifier",
    description="This application classifies North American birds.",
).launch()
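
# Optional local smoke test: a minimal sketch for classifying one image without
# starting the Gradio server. "bj.jpg" is just one of the example files listed
# above and is assumed to exist in the working directory. Interface.launch()
# blocks, so run this in place of the launch call rather than after it.
#
# sample = np.array(PILIMAGE.open("bj.jpg").convert("RGB"))
# print(find_best_matches(sample))  # e.g. "This is a photo of <species>."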