merve HF staff commited on
Commit
4ebb492
1 Parent(s): 6f1542f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from PIL import Image
4
+ from transformers import AutoProcessor, SiglipModel
5
+ import faiss
6
+ import numpy as np
7
+ from huggingface_hub import hf_hub_download
8
+ from datasets import load_dataset
9
+
10
+ hf_hub_download("merve/siglip-faiss-wikiart", "siglip_new.index", local_dir="./")
11
+ index = faiss.read_index("./siglip_new.index")
12
+
13
+ dataset = load_dataset("huggan/wikiart")
14
+ device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
15
+ dataset = dataset.with_format("torch", device=device)
16
+
17
+ processor = AutoProcessor.from_pretrained("nielsr/siglip-base-patch16-224")
18
+ model = SiglipModel.from_pretrained("nielsr/siglip-base-patch16-224").to(device)
19
+
20
+
21
+ def extract_features_siglip(image):
22
+ with torch.no_grad():
23
+ inputs = processor(images=image, return_tensors="pt").to(device)
24
+ image_features = model.get_image_features(**inputs)
25
+ return image_features
26
+
27
+ def infer(input_image):
28
+ input_features = extract_features_siglip(input_image)
29
+ input_features = input_features.detach().cpu().numpy()
30
+ input_features = np.float32(input_features)
31
+ faiss.normalize_L2(input_features)
32
+ distances, indices = index2.search(input_features, 9)
33
+ gallery_output = []
34
+ for i,v in enumerate(indices[0]):
35
+ sim = -distances[0][i]
36
+ img_resized = dataset["train"][int(v)]['image']
37
+ gallery_output.append(img_resized)
38
+ return gallery_output
39
+
40
+ gr.Interface(infer, "sketchpad", "gallery", title="Draw to Search Art 🖼️").launch()
41
+
42
+