Slep's picture
Clean and add model link
7fe54e9
raw
history blame contribute delete
No virus
3.05 kB
import pandas as pd
import torch
import faiss
import gradio as gr
import base64
from PIL import Image
from io import BytesIO
from src.model import ConditionalViT, B16_Params, categories
from src.transform import valid_tf
from src.process_images import process_img, make_img_html
from src.examples import ExamplesHandler
from src.js_loader import JavaScriptLoader
# Load Model
# Categorical CondViT-B/16, conditioned on `len(categories)` categories.
m = ConditionalViT(**B16_Params, n_categories=len(categories))
_state_dict = torch.load("./artifacts/cat_condvit_b16.pth", map_location="cpu")
m.load_state_dict(_state_dict)
del _state_dict  # only needed to populate `m`; keep no module-level reference
m.eval()  # switch to evaluation mode before serving queries

# Load data
# FAISS index searched by `retrieval`, and the gallery table handed to
# `process_img` to turn result ids into images.
index = faiss.read_index("./artifacts/gallery_index.faiss")
gal_imgs = pd.read_parquet("./artifacts/gallery_imgs.parquet")

# Query-image preprocessing at 224x224.
tfs = valid_tf((224, 224))
# Number of nearest neighbours retrieved per query.
K = 5
# Preset queries, grouped by query image. They are flattened below into one
# [image_path, category_name] row per example, in the same column order as the
# query inputs (image, category).
_EXAMPLE_GROUPS = (
    ("examples/3.jpg", ("Outwear", "Lower Body", "Feet")),
    ("examples/757.jpg", ("Bags", "Upper Body")),
    ("examples/769.jpg", ("Upper Body",)),
    ("examples/1811.jpg", ("Lower Body", "Bags")),
)
examples = [
    [path, category_name]
    for path, names_for_path in _EXAMPLE_GROUPS
    for category_name in names_for_path
]
@torch.inference_mode()
def retrieval(image, category):
    """Render the ``K`` gallery items nearest to a conditioned query as HTML.

    Args:
        image: Query image as delivered by the ``gr.Image(type="pil")`` input,
            or ``None`` when no image has been provided.
        category: Position of the selected entry in ``categories`` (the
            dropdown uses ``type="index"``), or ``None`` when nothing is
            selected.

    Returns:
        The HTML fragments of the retrieved items joined with newlines, or
        ``None`` when either input is missing.
    """
    if image is None or category is None:
        return None
    # A batch of one preprocessed image plus its category index.
    q_emb = m(tfs(image).unsqueeze(0), torch.tensor([category]))
    _, ids = index.search(q_emb, K)
    # faiss pads the id array with -1 when it finds fewer than K neighbours.
    # Skip those slots instead of handing a negative position to process_img,
    # which could otherwise resolve to an unrelated (e.g. last) gallery row.
    imgs = [process_img(idx, gal_imgs) for idx in ids[0] if idx >= 0]
    html = [make_img_html(i) for i in imgs]
    html.append("<p></p>")  # Avoid Gradio's last-child{margin-bottom:0!important;}
    return "\n".join(html)
# Register the page's custom client-side functions before the Blocks app is
# built. NOTE(review): presumably these connect the HTML example cards
# (#html_examples) to the hidden preset examples; see src/custom_functions.js.
JavaScriptLoader("src/custom_functions.js")

_HEADER_MD = """
# Conditional ViT Demo
[[`Paper`](https://arxiv.org/abs/2306.02928)]
[[`Code`](https://github.com/Simon-Lepage/CondViT-LRVSF)]
[[`Dataset`](https://huggingface.co/datasets/Slep/LAION-RVS-Fashion)]
[[`Model`](https://huggingface.co/Slep/CondViT-B16-cat)]
*Running on 2 vCPU, 16Go RAM.*
- **Model :** Categorical CondViT-B/16
- **Gallery :** 93K images.
"""

# Equivalent to nesting `with gr.Column():` inside `with gr.Blocks(...)`.
with gr.Blocks(css="src/style.css") as demo, gr.Column():
    gr.Markdown(_HEADER_MD)

    # Query: image on the left, category selector and submit on the right.
    with gr.Row():
        img = gr.Image(elem_id="query_img", label="Query Image", type="pil")
        with gr.Column():
            cat = gr.Dropdown(
                label="Category",
                choices=categories,
                # retrieval() receives the selected position in `categories`.
                type="index",
                value="Upper Body",
                elem_id="dropdown",
            )
            submit = gr.Button("Submit")

    # Preset examples, plus a custom HTML rendering of the same list.
    gr.Examples(
        examples,
        inputs=[img, cat],
        fn=retrieval,
        examples_per_page=100,
        elem_id="preset_examples",
    )
    gr.HTML(
        elem_id="html_examples",
        label="examples",
        value=ExamplesHandler(examples).to_html(),
    )

    # Results
    gr.Markdown("# Retrieved Items")
    out = gr.HTML(elem_id="html_output", label="Results")
    submit.click(retrieval, inputs=[img, cat], outputs=out)

demo.launch()