# Author: merve (Hugging Face staff).
# Last commit fcc4e47: "Update app.py" (2.13 kB; Hugging Face malware scan: no virus found).
import torch
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, SiglipModel
import faiss
import numpy as np
from huggingface_hub import hf_hub_download
from datasets import load_dataset
import pandas as pd
import requests
from io import BytesIO
# Fetch the prebuilt FAISS index and its WikiArt metadata table from the
# "merve/siglip-faiss-wikiart" Hub repo into the current directory.
for _asset in ("siglip_10k.index", "wikiart_10k.csv"):
    hf_hub_download("merve/siglip-faiss-wikiart", _asset, local_dir="./")

# Load the vector index and the table that `infer` indexes by FAISS result id.
# NOTE(review): assumes row i of the CSV describes vector i of the index — confirm.
index = faiss.read_index("./siglip_10k.index")
df = pd.read_csv("./wikiart_10k.csv")

# Run SigLIP on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_SIGLIP_CHECKPOINT = "google/siglip-base-patch16-224"
processor = AutoProcessor.from_pretrained(_SIGLIP_CHECKPOINT)
model = SiglipModel.from_pretrained(_SIGLIP_CHECKPOINT).to(device)
def read_image_from_url(url, timeout=10):
    """Download the image at *url* and return it as an RGB PIL image.

    Args:
        url: HTTP(S) address of the image (in this app, a value from the
            CSV's "Link" column).
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` can block indefinitely on an unresponsive
            host, freezing the request handler.

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection failures or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on HTTP errors; otherwise PIL would try to decode the error
    # page and raise a confusing UnidentifiedImageError instead.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
def extract_features_siglip(image):
    """Embed *image* with the module-level SigLIP model.

    The image is preprocessed by the module-level ``processor`` into PyTorch
    tensors, moved to ``device``, and passed to
    ``model.get_image_features``. Gradient tracking is disabled because this
    is inference only.

    Returns:
        The tensor returned by ``model.get_image_features``.
    """
    with torch.no_grad():
        batch = processor(images=image, return_tensors="pt")
        batch = batch.to(device)
        features = model.get_image_features(**batch)
    return features
def infer(input_image):
    """Return the 3 indexed artworks closest to the user's drawing.

    Args:
        input_image: Value from the ``gr.ImageEditor(type="pil")`` input; its
            "composite" entry holds the drawing as a PIL image.
            NOTE(review): gradio may pass ``None`` (or a ``None`` composite)
            for an empty canvas — confirm for the installed gradio version.

    Returns:
        A list of PIL images for the gallery output, in the order FAISS
        returned them. The list is empty when there is no drawing.
    """
    if input_image is None:
        return []
    composite = input_image["composite"]
    if composite is None:
        return []

    features = extract_features_siglip(composite.convert("RGB"))
    # faiss.normalize_L2 works in place and needs a C-contiguous float32 array.
    query = np.ascontiguousarray(features.detach().cpu().numpy(), dtype=np.float32)
    faiss.normalize_L2(query)
    _, indices = index.search(query, 3)

    gallery_output = []
    for row in indices[0]:
        # FAISS pads with -1 when it finds fewer than k results; df.iloc[-1]
        # would otherwise silently return the last artwork in the table.
        if row < 0:
            continue
        gallery_output.append(read_image_from_url(df.iloc[row]["Link"]))
    return gallery_output
# Blurb shown above the demo.
description = (
    "This is an application where you can draw an image and find the closest "
    "artwork among 10k art from wikiart dataset. This is built on 🤗 transformers "
    "integration of SIGLIP model by Google, and FAISS for indexing."
)

# Drawing canvas whose PIL output is passed to `infer`; results show in a gallery.
sketchpad = gr.ImageEditor(type="pil")
demo = gr.Interface(
    fn=infer,
    inputs=sketchpad,
    outputs="gallery",
    description=description,
    title="Draw to Search Art 🖼️",
)
demo.launch()