rxavier's picture
Update app.py
31e18ef
raw
history blame
No virus
4.34 kB
from typing import Optional
import gradio as gr
from PIL import Image
from off_topic import OffTopicDetector, Translator
# Shared model instances, built once at import time and reused by both request
# handlers below (validate_item and validate_images) through the `detector` global.
# The translator is built from the "Helsinki-NLP/opus-mt-roa-en" checkpoint name
# (presumably Romance languages -> English; confirm in off_topic.Translator).
translator = Translator("Helsinki-NLP/opus-mt-roa-en")
# NOTE(review): the meaning of image_size="V" is defined by OffTopicDetector and is
# not visible from this file - confirm against the off_topic module.
detector = OffTopicDetector("openai/clip-vit-base-patch32", image_size="V", translator=translator)
def validate_item(item_id: str, use_title: bool, threshold: float):
    """Split the pictures of one listing into valid and invalid galleries.

    Args:
        item_id: Identifier of the item whose pictures are scored.
        use_title: Forwarded to the detector as ``use_title``.
        threshold: A picture is valid when its "valid" score is >= this value
            and invalid when its score is < this value.

    Returns:
        A ``(markdown_heading, valid_images, invalid_images)`` tuple; the
        heading shows the domain reported by the detector.
    """
    prediction = detector.predict_probas_item(item_id, use_title=use_title)
    pictures, domain_name, _, valid_scores, _ = prediction

    accepted = []
    rejected = []
    for position, picture in enumerate(pictures):
        score = valid_scores[position].squeeze()
        # The two checks are deliberately independent (not if/else): a score that
        # satisfies neither comparison, such as NaN, ends up in neither gallery.
        if score >= threshold:
            accepted.append(picture)
        if score < threshold:
            rejected.append(picture)

    heading = f"## Domain: {domain_name}"
    return heading, accepted, rejected
def validate_images(img_url_1, img_url_2, img_url_3, domain: str, title: str, threshold: float):
    """Classify up to three picture URLs as valid/invalid for a given domain.

    Args:
        img_url_1: Picture URL; skipped when empty.
        img_url_2: Picture URL; skipped when empty.
        img_url_3: Picture URL; skipped when empty.
        domain: Domain ID made of a site prefix and a domain name separated by
            exactly one hyphen. Underscores in the domain name are turned into
            spaces and the text is lower-cased before it reaches the detector.
        title: Optional item title; an empty string is passed on as ``None``.
        threshold: A picture is valid when its "valid" score is >= this value
            and invalid when its score is < this value.

    Returns:
        A ``(markdown_heading, valid_images, invalid_images)`` tuple. Both
        image lists contain the submitted URLs.

    Raises:
        ValueError: If ``domain`` does not contain exactly one hyphen.
    """
    img_urls = [url for url in (img_url_1, img_url_2, img_url_3) if url]

    parts = domain.split("-")
    if len(parts) != 2:
        raise ValueError(
            f"Domain ID must look like '<SITE>-<DOMAIN>' (exactly one hyphen), got {domain!r}"
        )
    site, domain_name = parts
    heading = f"## Domain: {domain_name}"

    # Nothing to score: skip the model call and show two empty galleries.
    if not img_urls:
        return heading, [], []

    domain_text = domain_name.replace("_", " ").lower()
    if title == "":
        title = None

    _probas, valid_probas, _invalid_probas = detector.predict_probas_url(img_urls, domain_text, site, title)

    # predict_probas_url returns no image objects, so the galleries are filled
    # with the submitted URLs (the previous code referenced an undefined `images`).
    # NOTE(review): assumes valid_probas[i] corresponds to img_urls[i] - confirm
    # against OffTopicDetector.predict_probas_url.
    valid_images = [url for i, url in enumerate(img_urls) if valid_probas[i].squeeze() >= threshold]
    invalid_images = [url for i, url in enumerate(img_urls) if valid_probas[i].squeeze() < threshold]
    return heading, valid_images, invalid_images
# Gradio UI: two tabs with the same output layout (domain heading, gallery of
# valid images, gallery of invalid images), each wired to its own handler.
# NOTE(review): the row/column nesting below was reconstructed from a copy of this
# file that had lost its indentation - confirm against the rendered layout.
with gr.Blocks() as demo:
    gr.Markdown("""
# Off topic image detector
### This app takes an item ID and classifies its pictures as valid/invalid depending on whether they relate to the domain in which it's been listed.
Input an item ID or select one of the preloaded examples below.""")
    # Tab 1: score the pictures of an existing listing, looked up by item ID.
    with gr.Tab("From item_id"):
        with gr.Row():
            item_id = gr.Textbox(label="Item ID")
            with gr.Column():
                use_title = gr.Checkbox(label="Use translated item title", value=True)
                threshold = gr.Number(label="Threshold", value=0.25, precision=2)
        submit = gr.Button("Submit")
        gr.HTML("<hr>")
        domain = gr.Markdown()
        valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
        gr.HTML("<hr>")
        invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")
        submit.click(inputs=[item_id, use_title, threshold], outputs=[domain, valid, invalid], fn=validate_item)
        gr.HTML("<hr>")
        # Preloaded item IDs; fn/outputs are supplied together with
        # cache_examples=True so Gradio can cache the example results
        # (see the gr.Examples documentation).
        gr.Examples(
            examples=[["MLC572974424", True, 0.25], ["MLU449951849", True, 0.25], ["MLA1293465558", True, 0.25],
                      ["MLB3184663685", True, 0.25], ["MLC1392230619", True, 0.25], ["MCO546152796", True, 0.25]],
            inputs=[item_id, use_title, threshold],
            outputs=[domain, valid, invalid],
            fn=validate_item,
            cache_examples=True,
        )
    # Tab 2: score up to three pictures given directly by URL.
    # NOTE: `domain`, `threshold`, `submit`, `valid` and `invalid` are rebound in this
    # tab; the first tab's click handler and examples were already registered above
    # with the earlier component objects.
    with gr.Tab("From image urls"):
        with gr.Row():
            with gr.Column():
                img_url_1 = gr.Textbox(label="Picture URL")
                img_url_2 = gr.Textbox(label="Picture URL")
                img_url_3 = gr.Textbox(label="Picture URL")
            with gr.Column():
                domain = gr.Textbox(label="Domain ID", placeholder="Required")
                title = gr.Textbox(label="Item title", placeholder="Optional")
                threshold = gr.Number(label="Threshold", value=0.25, precision=2)
        submit = gr.Button("Submit")
        gr.HTML("<hr>")
        domain_output = gr.Markdown()
        valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
        gr.HTML("<hr>")
        invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")
        submit.click(inputs=[img_url_1, img_url_2, img_url_3, domain, title, threshold], outputs=[domain_output, valid, invalid], fn=validate_images)
        gr.HTML("<hr>")
        # TODO: examples for this tab are disabled. The draft below is a copy of the
        # item-ID examples: it still uses the item-ID inputs and refers to a
        # `validate` function that is not defined in the visible part of this file.
        #gr.Examples(
        # examples=[["MLC572974424", True, 0.25], ["MLU449951849", True, 0.25], ["MLA1293465558", True, 0.25],
        # ["MLB3184663685", True, 0.25], ["MLC1392230619", True, 0.25], ["MCO546152796", True, 0.25]],
        # inputs=[item_id, use_title, threshold],
        # outputs=[domain, valid, invalid],
        # fn=validate,
        # cache_examples=True,
        #)
# Start the app server; runs at import time since there is no __main__ guard.
demo.launch()