import os

import numpy as np
import torch
import gradio as gr
from PIL import Image as PILIMAGE
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# CLIP model fine-tuned on NABirds; the stock openai/clip-vit-base-patch32
# processor and tokenizer are compatible with it.
model = CLIPModel.from_pretrained("vesteinn/clip-nabirds").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")


def load_class_names(dataset_path=""):
    """Read NABirds class names from classes.txt ("<id> <name...>" per line)."""
    names = {}
    with open(os.path.join(dataset_path, "classes.txt")) as f:
        for line in f:
            pieces = line.strip().split()
            class_id = pieces[0]
            names[class_id] = " ".join(pieces[1:])
    return names


def get_labels():
    """Build one CLIP text prompt per class."""
    class_names = load_class_names(".")
    return [f"This is a photo of {name}." for name in class_names.values()]


def encode_text(text):
    """Encode a single prompt into a normalized CLIP text embedding."""
    with torch.no_grad():
        inputs = tokenizer([text], padding=True, return_tensors="pt").to(device)
        text_features = model.get_text_features(**inputs)
        # Normalize so the dot product with a normalized image embedding
        # is a cosine similarity.
        text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features.cpu().numpy()


ALL_LABELS = get_labels()

# Cache the label embeddings on disk; encoding every prompt at startup is slow.
try:
    LABEL_FEATURES = np.load("label_features.npy")
except FileNotFoundError:
    LABEL_FEATURES = np.vstack([encode_text(label) for label in ALL_LABELS])
    np.save("label_features.npy", LABEL_FEATURES)


def encode_image(image):
    """Encode an RGB numpy array into a normalized CLIP image embedding."""
    image = PILIMAGE.fromarray(image.astype("uint8"), "RGB")
    with torch.no_grad():
        pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
        image_features = model.get_image_features(pixel_values.to(device))
        image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features.cpu().numpy()


def similarity(feature, label_features):
    """Cosine similarities between one image embedding and all label embeddings."""
    return list((feature @ label_features.T).squeeze(0))


def find_best_matches(image):
    """Return the label whose text embedding best matches the image."""
    image_features = encode_image(image)
    similarities = similarity(image_features, LABEL_FEATURES)
    best = sorted(
        zip(similarities, range(LABEL_FEATURES.shape[0])),
        key=lambda x: x[0],
        reverse=True,
    )
    return ALL_LABELS[best[0][1]]


examples = [
    ["bj.jpg"], ["duckly.jpg"], ["some.jpg"], ["turdus.jpg"],
    ["seag.jpg"], ["thursh.jpg"], ["woodcock.jpeg"], ["dipper.jpeg"],
]

# Gradio 2.x API (gr.inputs / gr.outputs, enable_queue).
gr.Interface(
    fn=find_best_matches,
    inputs=[gr.inputs.Image(label="Image to classify", optional=False)],
    examples=examples,
    theme="grass",
    outputs=gr.outputs.Label(),
    enable_queue=True,
    title="North American Bird Classifier",
    description="This application can classify North American birds.",
).launch()
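# --- Optional sanity check (a minimal sketch) ---
# The classifier can be exercised without the web UI. This assumes one of the
# bundled example images (e.g. "bj.jpg") sits next to the script; run it in
# place of the gr.Interface(...).launch() call above:
#
#     img = np.asarray(PILIMAGE.open("bj.jpg").convert("RGB"))
#     print(find_best_matches(img))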