File size: 2,439 Bytes
4ebb492
 
 
 
 
 
 
 
e0de7f3
b53eff8
 
35191e2
b53eff8
4ebb492
e0de7f3
ccb786c
 
4ebb492
e0de7f3
74ba09f
 
4ebb492
fcc4e47
 
4ebb492
d2b0ea9
 
 
 
35191e2
d86dfd1
4ebb492
 
 
 
 
 
def1b42
4ebb492
b53eff8
4ebb492
 
 
e0de7f3
4ebb492
 
 
e0de7f3
 
 
4ebb492
 
e0de7f3
f487b28
b53eff8
 
 
4ebb492
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, SiglipModel
import faiss
import numpy as np
from huggingface_hub import hf_hub_download
from datasets import load_dataset
import pandas as pd
import requests
from io import BytesIO
import spaces


# Hub repo holding the prebuilt FAISS index and the matching WikiArt metadata CSV.
INDEX_REPO_ID = "merve/siglip-faiss-wikiart"
# SigLIP checkpoint used to embed queries. NOTE(review): presumably the same
# checkpoint that built the index — confirm, or search results will be meaningless.
SIGLIP_CHECKPOINT = "google/siglip-base-patch16-224"

# Download the index and dataset table. hf_hub_download returns the local path of
# the file it fetched, so that path is used below instead of a second hand-written
# copy of each filename that could drift out of sync with the download call.
index_path = hf_hub_download(INDEX_REPO_ID, "siglip_10k_latest.index", local_dir="./")
csv_path = hf_hub_download(INDEX_REPO_ID, "wikiart_10k_latest.csv", local_dir="./")

# Read the index and the dataset table. Row order of `df` is assumed to match
# the vector ids stored in `index` (row i <-> id i) — confirm against the indexing notebook.
index = faiss.read_index(index_path)
df = pd.read_csv(csv_path)

# Load the SigLIP processor and model, placing the model on the GPU when one is visible.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = AutoProcessor.from_pretrained(SIGLIP_CHECKPOINT)
model = SiglipModel.from_pretrained(SIGLIP_CHECKPOINT).to(device)

def read_image_from_url(url, timeout=10):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image file.
        timeout: Timeout in seconds passed to ``requests.get``. Without one,
            ``requests`` can wait indefinitely on a stalled server and hang
            the whole request handler.

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status. This
            reports the real failure instead of letting PIL fail later with an
            opaque "cannot identify image file" error on an HTML error page.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")

# Not decorated with @spaces.GPU: in this file it is only called from `infer`,
# which already requests the GPU for the whole search.
def extract_features_siglip(image):
    """Embed one image with the SigLIP image tower.

    Args:
        image: An image accepted by the SigLIP ``processor`` (a PIL image here).

    Returns:
        The tensor returned by ``model.get_image_features``, left on ``device``.
    """
    # Preprocessing only builds input tensors; no autograd graph is involved.
    batch = processor(images=image, return_tensors="pt")
    batch = batch.to(device)
    # Retrieval never needs gradients, so run the forward pass without autograd.
    with torch.no_grad():
        return model.get_image_features(**batch)

@spaces.GPU
def infer(input_image, k=3):
    """Return the ``k`` indexed artworks closest to the drawn/uploaded image.

    Args:
        input_image: Value of the ``gr.ImageEditor(type="pil")`` input; its
            ``"composite"`` entry holds the flattened editor image. It may be
            ``None`` (or carry a ``None`` composite) when the user submits
            without drawing or uploading anything.
        k: Number of neighbours to request from the FAISS index (default 3).

    Returns:
        A list of up to ``k`` RGB PIL images for the gallery, in the order
        FAISS returns them. An empty list when there is no input image.
    """
    # Nothing to search with: return an empty gallery instead of crashing on None.
    if input_image is None or input_image.get("composite") is None:
        return []

    query = extract_features_siglip(input_image["composite"].convert("RGB"))
    # FAISS expects a C-contiguous float32 matrix (one row per query).
    query = np.ascontiguousarray(query.detach().cpu().numpy(), dtype=np.float32)
    # Normalises the query rows in place. NOTE(review): presumably the index
    # holds normalised vectors too (cosine search) — confirm against the indexing notebook.
    faiss.normalize_L2(query)
    _, neighbour_ids = index.search(query, k)

    # FAISS pads missing results with id -1; df.iloc[-1] would silently return
    # the last row, so those slots are skipped.
    return [
        read_image_from_url(df.iloc[row]["Link"])
        for row in neighbour_ids[0]
        if row >= 0
    ]


# User-facing blurb shown above the app (rendered as Markdown by Gradio).
description = (
    "This is an application where you can draw or upload an image and find the "
    "closest artwork among 10k art from wikiart dataset. "
    "This is built on 🤗 transformers integration of "
    "[SigLIP](https://github.com/merveenoyan/siglip?tab=readme-ov-file#siglip-projects-) "
    "model by Google, and FAISS for indexing. "
    "In this "
    "[link](https://github.com/merveenoyan/siglip?tab=readme-ov-file#siglip-projects-) "
    "you can also find the notebook to index the dataset using SigLIP."
)

# Drawing/upload canvas; `type="pil"` hands PIL images to `infer`.
sketchpad = gr.ImageEditor(type="pil")

# Wire the canvas to the search function, show results in a gallery, and serve the app.
gr.Interface(
    fn=infer,
    inputs=sketchpad,
    outputs="gallery",
    description=description,
    title="Draw to Search Art 🖼️",
).launch()