merve HF staff commited on
Commit
e0de7f3
1 Parent(s): 4ebb492

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -6,18 +6,19 @@ import faiss
6
  import numpy as np
7
  from huggingface_hub import hf_hub_download
8
  from datasets import load_dataset
 
9
 
10
- hf_hub_download("merve/siglip-faiss-wikiart", "siglip_new.index", local_dir="./")
11
- index = faiss.read_index("./siglip_new.index")
 
12
 
13
- dataset = load_dataset("huggan/wikiart")
 
 
14
  device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
15
- dataset = dataset.with_format("torch", device=device)
16
-
17
  processor = AutoProcessor.from_pretrained("nielsr/siglip-base-patch16-224")
18
  model = SiglipModel.from_pretrained("nielsr/siglip-base-patch16-224").to(device)
19
 
20
-
21
  def extract_features_siglip(image):
22
  with torch.no_grad():
23
  inputs = processor(images=image, return_tensors="pt").to(device)
@@ -29,14 +30,17 @@ def infer(input_image):
29
  input_features = input_features.detach().cpu().numpy()
30
  input_features = np.float32(input_features)
31
  faiss.normalize_L2(input_features)
32
- distances, indices = index2.search(input_features, 9)
33
  gallery_output = []
34
  for i,v in enumerate(indices[0]):
35
  sim = -distances[0][i]
36
- img_resized = dataset["train"][int(v)]['image']
37
- gallery_output.append(img_resized)
 
38
  return gallery_output
39
 
 
 
40
  gr.Interface(infer, "sketchpad", "gallery", title="Draw to Search Art 🖼️").launch()
41
 
42
 
 
6
  import numpy as np
7
  from huggingface_hub import hf_hub_download
8
  from datasets import load_dataset
9
+ import pandas as pd
10
 
11
+ # download model and dataset
12
+ hf_hub_download("merve/siglip-faiss-wikiart", "siglip_10k.index", local_dir="./")
13
+ hf_hub_download("merve/siglip-faiss-wikiart", "wikiart_10k.csv", local_dir="./")
14
 
15
+ # read index, dataset and load siglip model and processor
16
+ index = faiss.read_index("./siglip_10k.index")
17
+ df = pd.read_csv("./wikiart_10k.csv")
18
  device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
 
 
19
  processor = AutoProcessor.from_pretrained("nielsr/siglip-base-patch16-224")
20
  model = SiglipModel.from_pretrained("nielsr/siglip-base-patch16-224").to(device)
21
 
 
22
  def extract_features_siglip(image):
23
  with torch.no_grad():
24
  inputs = processor(images=image, return_tensors="pt").to(device)
 
30
  input_features = input_features.detach().cpu().numpy()
31
  input_features = np.float32(input_features)
32
  faiss.normalize_L2(input_features)
33
+ distances, indices = index.search(input_features, 3)
34
  gallery_output = []
35
  for i,v in enumerate(indices[0]):
36
  sim = -distances[0][i]
37
+ image_url = df.iloc[v]["Link"]
38
+ img_retrieved = read_image_from_url(image_url)
39
+ gallery_output.append(img_retrieved)
40
  return gallery_output
41
 
42
+
43
+ description="This is an application where you can draw an image and find the closest artwork among 10k art from wikiart dataset."
44
  gr.Interface(infer, "sketchpad", "gallery", title="Draw to Search Art 🖼️").launch()
45
 
46