import os
from pathlib import Path

import numpy as np
import pandas as pd
import requests
import torch
from torch import nn
import gradio as gr
from PIL import Image, ImageFile
from transformers import CLIPModel, CLIPProcessor

# Let PIL open partially downloaded/truncated thumbnails instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

LABELS = Path('class_names.txt').read_text().splitlines()

# Small CNN that classifies a 28x28 grayscale sketch into one of LABELS.
class_model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(1152, 256),
    nn.ReLU(),
    nn.Linear(256, len(LABELS)),
)
state_dict = torch.load('pytorch_model.bin', map_location='cpu')
class_model.load_state_dict(state_dict, strict=False)
class_model.eval()

# CLIP model and processor used for both text and image queries.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Precomputed collection metadata and CLIP image embeddings (one row per image),
# L2-normalized so that a dot product against a query vector ranks by cosine similarity.
df = pd.read_csv('clip.csv')
embeddings_npy = np.load('clip.npy')
embeddings = np.divide(embeddings_npy, np.sqrt(np.sum(embeddings_npy**2, axis=1, keepdims=True)))


def compute_text_embeddings(list_of_strings):
    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    return model.get_text_features(**inputs)


def compute_image_embeddings(list_of_images):
    inputs = processor(images=list_of_images, return_tensors="pt", padding=True)
    return model.get_image_features(**inputs)


def load_image(image, same_height=False):
    # Convert the incoming array to an RGB PIL image and resize so the
    # shortest side (or the height, if same_height=True) is 224 pixels.
    im = Image.fromarray(np.uint8(image))
    if im.mode != 'RGB':
        im = im.convert('RGB')
    if same_height:
        ratio = 224 / im.size[1]
    else:
        ratio = 224 / min(im.size)
    return im.resize((int(im.size[0] * ratio), int(im.size[1] * ratio)))


def download_img(identifier, url):
    # Cache each thumbnail locally so repeated queries don't re-download it.
    local_path = f"{identifier}.jpg"
    if not os.path.isfile(local_path):
        img_data = requests.get(url).content
        with open(local_path, 'wb') as handler:
            handler.write(img_data)
    return local_path


def predict(image=None, text=None, sketch=None):
    if image is not None:
        # Image query: embed the uploaded image directly with CLIP.
        input_embeddings = compute_image_embeddings([load_image(image)]).detach().numpy()
        topk = {"local": 100}
    else:
        if text:
            # Text query: use the string as-is.
            query = text
            topk = {text: 100}
        else:
            # Sketch query: classify the drawing, then search with the top label.
            x = torch.tensor(sketch, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
            with torch.no_grad():
                out = class_model(x)
            probabilities = torch.nn.functional.softmax(out[0], dim=0)
            values, indices = torch.topk(probabilities, 5)
            query = LABELS[indices[0]]
            topk = {LABELS[i]: v.item() for i, v in zip(indices, values)}
        input_embeddings = compute_text_embeddings([query]).detach().numpy()

    # Rank the collection by similarity to the query and keep the top n_results.
    n_results = 3
    results = np.argsort((embeddings @ input_embeddings.T)[:, 0])[-1:-n_results - 1:-1]
    outputs = [download_img(df.iloc[i]['id'], df.iloc[i]['thumbnail']) for i in results]
    outputs.insert(0, topk)
    print(outputs)
    return outputs


def predict_text(text):
    return predict(None, text, None)


title = "Type to search in the Nasjonalbiblioteket"
description = "Find images in the Nasjonalbiblioteket image collections based on what you type"

interface = gr.Interface(
    fn=predict_text,
    inputs=["text"],
    outputs=[
        gr.outputs.Label(num_top_classes=3),
        gr.outputs.Image(type="filepath"),
        gr.outputs.Image(type="filepath"),
        gr.outputs.Image(type="filepath"),
    ],
    title=title,
    description=description,
    # live=True,
    examples=[
        ["kids playing in the snow"],
        ["walking in the dark"],
        ["woman sitting on a chair while drinking a beer"],
        ["nice view out the window on a train"],
    ],
)

interface.launch(debug=True)