File size: 2,133 Bytes
4ebb492
 
 
 
 
 
 
 
e0de7f3
b53eff8
 
 
4ebb492
e0de7f3
 
 
4ebb492
e0de7f3
 
 
4ebb492
fcc4e47
 
4ebb492
d2b0ea9
 
 
 
 
4ebb492
 
 
 
 
 
 
b53eff8
4ebb492
 
 
e0de7f3
4ebb492
 
 
e0de7f3
 
 
4ebb492
 
e0de7f3
f9ba4ce
b53eff8
 
 
4ebb492
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, SiglipModel
import faiss
import numpy as np
from huggingface_hub import hf_hub_download
from datasets import load_dataset
import pandas as pd
import requests
from io import BytesIO


# Fetch the prebuilt FAISS index and the artwork metadata CSV from the Hub
# repository into the current working directory.
hf_hub_download(
    repo_id="merve/siglip-faiss-wikiart",
    filename="siglip_10k.index",
    local_dir="./",
)
hf_hub_download(
    repo_id="merve/siglip-faiss-wikiart",
    filename="wikiart_10k.csv",
    local_dir="./",
)

# Load the search index and the metadata table from the files fetched above.
index = faiss.read_index("./siglip_10k.index")
df = pd.read_csv("./wikiart_10k.csv")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the SigLIP processor and model, then move the model to the chosen device.
processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
model = model.to(device)

def read_image_from_url(url, timeout=30):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` can block indefinitely on a stalled
            connection, which would hang the calling request.

    Returns:
        The downloaded image, converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    # Fail on 4xx/5xx here so an error page is reported as an HTTP failure
    # rather than as an undecodable image further down.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
    
def extract_features_siglip(image):
    """Embed an image with the module-level SigLIP model.

    The image is preprocessed by the module-level ``processor`` (requesting
    PyTorch tensors), moved to ``device``, and passed to
    ``model.get_image_features``. All of this runs with gradient tracking
    disabled.

    Args:
        image: The image to embed, forwarded to the processor as ``images``.

    Returns:
        The result of ``model.get_image_features`` for the processed input.
    """
    with torch.no_grad():
        batch = processor(images=image, return_tensors="pt")
        batch = batch.to(device)
        return model.get_image_features(**batch)

def infer(input_image, top_k=3):
    """Return the artworks whose embeddings are closest to the drawn image.

    The editor's composite image is embedded with ``extract_features_siglip``,
    L2-normalised, and used to query the module-level FAISS ``index``. Each
    hit is resolved to a URL through the ``"Link"`` column of ``df`` and
    downloaded with ``read_image_from_url``.

    Args:
        input_image: Mapping produced by the image editor; only its
            ``"composite"`` entry is read.
        top_k: Number of nearest neighbours to request from the index.

    Returns:
        A list of downloaded images, in the order the index returned them.
        The list is empty when there is no composite image to search with.
    """
    # Nothing to embed when no value (or no composite image) is submitted.
    # NOTE(review): presumably this is the case when the user submits an
    # untouched canvas - confirm against the editor component's behaviour.
    if not input_image or input_image.get("composite") is None:
        return []

    sketch = input_image["composite"].convert("RGB")
    query = extract_features_siglip(sketch).detach().cpu().numpy()
    # normalize_L2 works in place and needs a contiguous float32 matrix.
    query = np.ascontiguousarray(query, dtype=np.float32)
    faiss.normalize_L2(query)

    _, indices = index.search(query, top_k)
    # FAISS pads missing results with label -1; skip those so that
    # ``df.iloc[-1]`` does not silently return the last row instead.
    return [
        read_image_from_url(df.iloc[row]["Link"])
        for row in indices[0]
        if row >= 0
    ]


# Text shown under the title of the demo page.
description = (
    "This is an application where you can draw an image and find the closest "
    "artwork among 10k art from wikiart dataset. "
    "This is built on 🤗 transformers integration of SIGLIP model by Google, "
    "and FAISS for indexing."
)

# Drawing canvas that hands its value to `infer` as PIL images.
sketchpad = gr.ImageEditor(type="pil")

# Wire the canvas to `infer`, show the results in a gallery, and start the app.
gr.Interface(
    fn=infer,
    inputs=sketchpad,
    outputs="gallery",
    description=description,
    title="Draw to Search Art 🖼️",
).launch()