merve HF staff commited on
Commit
b53eff8
1 Parent(s): f9ba4ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -7,6 +7,9 @@ import numpy as np
7
  from huggingface_hub import hf_hub_download
8
  from datasets import load_dataset
9
  import pandas as pd
 
 
 
10
 
11
  # download model and dataset
12
  hf_hub_download("merve/siglip-faiss-wikiart", "siglip_10k.index", local_dir="./")
@@ -31,7 +34,7 @@ def extract_features_siglip(image):
31
  return image_features
32
 
33
  def infer(input_image):
34
- input_features = extract_features_siglip(input_image)
35
  input_features = input_features.detach().cpu().numpy()
36
  input_features = np.float32(input_features)
37
  faiss.normalize_L2(input_features)
@@ -46,6 +49,8 @@ def infer(input_image):
46
 
47
 
48
  description="This is an application where you can draw an image and find the closest artwork among 10k art from wikiart dataset. This is built on 🤗 transformers integration of SIGLIP model by Google, and FAISS for indexing."
49
- gr.Interface(infer, "sketchpad", "gallery", description=description, title="Draw to Search Art 🖼️").launch()
 
 
50
 
51
 
 
7
  from huggingface_hub import hf_hub_download
8
  from datasets import load_dataset
9
  import pandas as pd
10
+ import requests
11
+ from io import BytesIO
12
+
13
 
14
  # download model and dataset
15
  hf_hub_download("merve/siglip-faiss-wikiart", "siglip_10k.index", local_dir="./")
 
34
  return image_features
35
 
36
  def infer(input_image):
37
+ input_features = extract_features_siglip(input_image["composite"].convert("RGB"))
38
  input_features = input_features.detach().cpu().numpy()
39
  input_features = np.float32(input_features)
40
  faiss.normalize_L2(input_features)
 
49
 
50
 
51
  description="This is an application where you can draw an image and find the closest artwork among 10k art from wikiart dataset. This is built on 🤗 transformers integration of SIGLIP model by Google, and FAISS for indexing."
52
+ sketchpad = gr.ImageEditor(type="pil")
53
+
54
+ gr.Interface(infer, sketchpad, "gallery", description=description, title="Draw to Search Art 🖼️").launch()
55
 
56