rxavier commited on
Commit
297a61e
1 Parent(s): 7fdac21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -4,8 +4,8 @@ from off_topic import OffTopicDetector
4
 
5
  detector = OffTopicDetector("openai/clip-vit-base-patch32")
6
 
7
- def validate(item_id: str, threshold: float):
8
- images, domain, probas, valid_probas, invalid_probas = detector.predict_probas_item(item_id)
9
  valid_images = [x for i, x in enumerate(images) if valid_probas[i].squeeze() >= threshold]
10
  invalid_images = [x for i, x in enumerate(images) if valid_probas[i].squeeze() < threshold]
11
  return f"## Domain: {domain}", valid_images, invalid_images
@@ -15,19 +15,21 @@ with gr.Blocks() as demo:
15
  # Off topic image detector
16
  ### This app takes an item ID and classifies its pictures as valid/invalid depending on whether they relate to the domain in which it's been listed.
17
  Input an item ID or select one of the preloaded examples below.""")
18
- item_id = gr.Textbox(label="Item ID")
19
- threshold = gr.Number(label="Threshold", value=0.25, precision=2)
 
 
20
  submit = gr.Button("Submit")
21
  gr.HTML("<hr>")
22
  domain = gr.Markdown()
23
  valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
24
  gr.HTML("<hr>")
25
  invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")
26
- submit.click(inputs=[item_id, threshold], outputs=[domain, valid, invalid], fn=validate)
27
  gr.HTML("<hr>")
28
  gr.Examples(
29
- examples=[["MLC572974424", 0.25], ["MLU449951849", 0.25], ["MLA1293465558", 0.25],
30
- ["MLB3184663685", 0.25], ["MLC1392230619", 0.25], ["MCO546152796", 0.25]],
31
  inputs=[item_id, threshold],
32
  outputs=[domain, valid, invalid],
33
  fn=validate,
 
4
 
5
  detector = OffTopicDetector("openai/clip-vit-base-patch32")
6
 
7
+ def validate(item_id: str, use_title: bool, threshold: float):
8
+ images, domain, probas, valid_probas, invalid_probas = detector.predict_probas_item(item_id, use_title=use_title)
9
  valid_images = [x for i, x in enumerate(images) if valid_probas[i].squeeze() >= threshold]
10
  invalid_images = [x for i, x in enumerate(images) if valid_probas[i].squeeze() < threshold]
11
  return f"## Domain: {domain}", valid_images, invalid_images
 
15
  # Off topic image detector
16
  ### This app takes an item ID and classifies its pictures as valid/invalid depending on whether they relate to the domain in which it's been listed.
17
  Input an item ID or select one of the preloaded examples below.""")
18
+ with gr.Row():
19
+ item_id = gr.Textbox(label="Item ID")
20
+ use_title = gr.Checkbox(label="Use item title", value=True)
21
+ threshold = gr.Number(label="Threshold", value=0.25, precision=2)
22
  submit = gr.Button("Submit")
23
  gr.HTML("<hr>")
24
  domain = gr.Markdown()
25
  valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
26
  gr.HTML("<hr>")
27
  invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")
28
+ submit.click(inputs=[item_id, use_title, threshold], outputs=[domain, valid, invalid], fn=validate)
29
  gr.HTML("<hr>")
30
  gr.Examples(
31
+ examples=[["MLC572974424", True, 0.25], ["MLU449951849", True, 0.25], ["MLA1293465558", True, 0.25],
32
+ ["MLB3184663685", True, 0.25], ["MLC1392230619", True, 0.25], ["MCO546152796", True, 0.25]],
33
  inputs=[item_id, threshold],
34
  outputs=[domain, valid, invalid],
35
  fn=validate,