import os
from pathlib import Path
from urllib.request import urlretrieve

import numpy as np
import pandas as pd
import requests
import torch
import gradio as gr
from PIL import Image, ImageFile
from torch import nn
from transformers import CLIPModel, CLIPProcessor

# Allow PIL to load partially downloaded thumbnails without raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Download sample images used as Gradio examples
urlretrieve("https://huggingface.co/spaces/NbAiLab/maken-clip-image/resolve/main/Gibraltar_Barbary_Macaque.jpg", "monkey.jpg")
urlretrieve("https://huggingface.co/spaces/NbAiLab/maken-clip-image/resolve/main/buying-a-sailboat-checklist.jpg", "sailboat.jpg")
urlretrieve("https://huggingface.co/spaces/NbAiLab/maken-clip-image/resolve/main/lG5mI_9Co1obw2TiY0e-oChlXfEQY3tsRaIjpYjERqs.jpg", "bicycle.jpg")

sample_images = [
    ["monkey.jpg"],
    ["sailboat.jpg"],
    ["bicycle.jpg"],
]

LABELS = Path('class_names.txt').read_text().splitlines()

# Small CNN that classifies grayscale sketches into one of the labels.
# Three 2x poolings reduce a 28x28 input to 3x3, so the flattened size is 128 * 3 * 3 = 1152.
class_model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(1152, 256),
    nn.ReLU(),
    nn.Linear(256, len(LABELS)),
)
state_dict = torch.load('pytorch_model.bin', map_location='cpu')
class_model.load_state_dict(state_dict, strict=False)
class_model.eval()

# CLIP model and processor used to embed query images and text.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Precomputed CLIP embeddings of the collection, L2-normalised so that a dot
# product with a query embedding ranks images by cosine similarity.
df = pd.read_csv('clip.csv')
embeddings_npy = np.load('clip.npy')
embeddings = np.divide(embeddings_npy, np.sqrt(np.sum(embeddings_npy**2, axis=1, keepdims=True)))


def compute_text_embeddings(list_of_strings):
    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    return model.get_text_features(**inputs)


def compute_image_embeddings(list_of_images):
    inputs = processor(images=list_of_images, return_tensors="pt", padding=True)
    return model.get_image_features(**inputs)


def load_image(image, same_height=False):
    # The Gradio image component provides a numpy array rather than a path.
    im = Image.fromarray(np.uint8(image))
    if im.mode != 'RGB':
        im = im.convert('RGB')
    if same_height:
        ratio = 224 / im.size[1]
    else:
        ratio = 224 / min(im.size)
    return im.resize((int(im.size[0] * ratio), int(im.size[1] * ratio)))


def download_img(identifier, url):
    # Cache thumbnails locally so repeated queries do not re-download them.
    local_path = f"{identifier}.jpg"
    if not os.path.isfile(local_path):
        img_data = requests.get(url).content
        with open(local_path, 'wb') as handler:
            handler.write(img_data)
    return local_path


def predict(image=None, text=None, sketch=None):
    if image is not None:
        # Image query: embed the uploaded image with CLIP.
        input_embeddings = compute_image_embeddings([load_image(image)]).detach().numpy()
        topk = {"local": 100}
    else:
        if text:
            # Text query: use the string directly.
            query = text
            topk = {text: 100}
        else:
            # Sketch query: scale the drawing to [0, 1] before classification.
            x = torch.tensor(sketch, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
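            # Classify the drawing with the small CNN above and use the most likely
            # label as a CLIP text query. This assumes the sketchpad delivers a small
            # grayscale array (28x28 in the original Space) matching the classifier's
            # 1152 flattened features.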
            with torch.no_grad():
                out = class_model(x)
            probabilities = torch.nn.functional.softmax(out[0], dim=0)
            values, indices = torch.topk(probabilities, 5)
            query = LABELS[indices[0]]
            topk = {LABELS[i]: v.item() for i, v in zip(indices, values)}
        # Text and sketch queries both go through the CLIP text encoder.
        input_embeddings = compute_text_embeddings([query]).detach().numpy()

    # Rank the collection by similarity to the query embedding and keep the
    # top n_results in descending order.
    n_results = 3
    results = np.argsort((embeddings @ input_embeddings.T)[:, 0])[-1:-n_results - 1:-1]
    outputs = [download_img(df.iloc[i]['id'], df.iloc[i]['thumbnail']) for i in results]
    outputs.insert(0, topk)
    print(outputs)
    return outputs


def predict_image(image):
    return predict(image, None, None)


def predict_text(image=None, text=None, sketch=None):
    return predict(None, text, None)


def predict_sketch(image=None, text=None, sketch=None):
    # A sketchpad input arrives as the first argument, so it is routed to predict()'s sketch slot.
    return predict(None, None, image)


title = "Upload an image to search in the Nasjonalbiblioteket"
description = "Find images in the Nasjonalbiblioteket image collections based on images you upload"

interface = gr.Interface(
    fn=predict_image,
    inputs=["image"],
    outputs=[
        gr.outputs.Label(num_top_classes=3),
        gr.outputs.Image(type="file"),
        gr.outputs.Image(type="file"),
        gr.outputs.Image(type="file"),
    ],
    title=title,
    description=description,
    examples=sample_images,
    # live=True
)
interface.launch(debug=True)
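

# The helper below is not part of the Gradio UI above; it is a minimal sketch of how
# the same precomputed index could be queried directly from Python. The name
# text_search is illustrative only, and it assumes df/embeddings are loaded as above
# with 'id' and 'thumbnail' columns in clip.csv.
def text_search(query, n_results=3):
    """Return (id, thumbnail_url, score) for the images closest to a text query."""
    q = compute_text_embeddings([query]).detach().numpy()
    # Normalise the query so the dot product is a cosine similarity, like the stored embeddings.
    q = q / np.linalg.norm(q, axis=1, keepdims=True)
    scores = (embeddings @ q.T)[:, 0]
    top = np.argsort(scores)[-1:-n_results - 1:-1]
    return [(df.iloc[i]['id'], df.iloc[i]['thumbnail'], float(scores[i])) for i in top]

# Example (only reachable once the blocking launch() call above returns):
# print(text_search("sailboat"))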