import requests
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch import nn
from transformers import CLIPProcessor, CLIPModel
import gradio as gr

# Class names for the sketch classifier (one label per line).
LABELS = Path('class_names.txt').read_text().splitlines()

# Small CNN sketch classifier; with a 28x28 grayscale input the flattened
# feature size after three 2x2 max-pools is 128 * 3 * 3 = 1152.
class_model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(1152, 256),
    nn.ReLU(),
    nn.Linear(256, len(LABELS)),
)

state_dict = torch.load('pytorch_model.bin', map_location='cpu')
class_model.load_state_dict(state_dict, strict=False)
class_model.eval()

# CLIP text encoder used to turn the predicted label into a query embedding.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Precomputed CLIP image embeddings plus metadata with the image URLs.
df = pd.read_csv('data2.csv')
embeddings_npy = np.load('embeddings.npy')
# L2-normalize each image embedding so the dot product below acts as a cosine similarity.
embeddings = np.divide(
    embeddings_npy,
    np.sqrt(np.sum(embeddings_npy ** 2, axis=1, keepdims=True)),
)


def compute_text_embeddings(list_of_strings):
    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    return model.get_text_features(**inputs)


def download_img(path):
    # Fetch a remote image and save it locally, returning the local filename.
    img_data = requests.get(path).content
    local_path = path.split("/")[-1]
    with open(local_path, 'wb') as handler:
        handler.write(img_data)
    return local_path


def predict(im):
    # The sketchpad delivers a 2D array of pixel values in [0, 255];
    # add batch and channel dimensions and scale to [0, 1].
    x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
    with torch.no_grad():
        out = class_model(x)
    probabilities = torch.nn.functional.softmax(out[0], dim=0)
    values, indices = torch.topk(probabilities, 5)

    # Use the top predicted label as a text query against the image embeddings.
    query = LABELS[indices[0]]
    n_results = 3
    text_embeddings = compute_text_embeddings([query]).detach().numpy()
    results = np.argsort((embeddings @ text_embeddings.T)[:, 0])[-1:-n_results - 1:-1]
    outputs = [download_img(df.iloc[i]['path']) for i in results]
    # Prepend the label/confidence dict for the Label output component.
    outputs.insert(0, {LABELS[i]: v.item() for i, v in zip(indices, values)})
    print(outputs)
    return outputs


title = "Draw to Search"
description = "Using the power of CLIP and a simple small CNN, find images from movies based on what you draw!"

iface = gr.Interface(
    fn=predict,
    inputs='sketchpad',
    outputs=[
        gr.outputs.Label(num_top_classes=3),
        gr.outputs.Image(type="file"),
        gr.outputs.Image(type="file"),
        gr.outputs.Image(type="file"),
    ],
    title=title,
    description=description,
    live=True,
)
iface.launch(debug=True)