kevinconka committed
Commit 955daea · Parent(s): c54f19a

save flagged data to HF dataset

Files changed (3):
  1. app.py +19 -19
  2. flagging.py +77 -0
  3. utils.py +36 -6
app.py CHANGED
@@ -1,13 +1,9 @@
 import gradio as gr
-from utils import load_model, load_image_from_url, inference
+from huggingface_hub import get_token
+from utils import load_model, load_image_from_url, inference, load_badges
+from flagging import myHuggingFaceDatasetSaver
 
 
-BADGES = """
-<p align="right">
-  <img alt="Static Badge" src="https://img.shields.io/badge/SEA.AI-beta-blue">
-</p>
-"""
-
 TITLE = """
 <h1> RGB Detection Demo </h1>
 <p align="center">
@@ -19,7 +15,8 @@ Give it a try! Upload an image or enter a URL to an image and click
 NOTICE = """
 See something off? Your feedback makes a difference! Let us know by
 flagging any outcomes that don't seem right. Just click on `Flag`
-to submit the image for review.
+to submit the image for review. Note that by clicking `Flag`, you
+agree to the use of your image for A.I. improvement purposes.
 """
 
 css = """
@@ -36,11 +33,12 @@ model.max_det = 100
 model.agnostic = True  # NMS class-agnostic
 
 # This callback will be used to flag images
-callback = gr.CSVLogger()
+dataset_name = "SEA-AI/crowdsourced-rgb-images"
+hf_writer = myHuggingFaceDatasetSaver(get_token(), dataset_name)
 
 with gr.Blocks(css=css) as demo:
-    gr.Markdown(value=BADGES)
-    gr.Markdown(value=TITLE)
+    badges = gr.HTML(load_badges(dataset_name, trials=1))
+    title = gr.HTML(TITLE)
 
     with gr.Row():
         with gr.Column():
@@ -80,6 +78,7 @@ with gr.Blocks(css=css) as demo:
     img_url.change(load_image_from_url, [img_url], img_input)
     submit.click(lambda image: inference(model, image), [img_input], img_output)
 
+    # event listeners with decorators
     @img_output.change(inputs=[img_output], outputs=[flag, notice])
     def show_hide(img_output):
         visible = img_output is not None
@@ -89,15 +88,16 @@ with gr.Blocks(css=css) as demo:
         }
 
     # This needs to be called prior to the first call to callback.flag()
-    callback.setup([img_input, img_url, img_output], "flagged")
+    hf_writer.setup([img_input], "flagged")
+    img_input.flag
 
    # We can choose which components to flag (in this case, we'll flag all)
-    flag.click(
-        lambda *args: callback.flag(args),
-        [img_input, img_url, img_output],
-        None,
+    flag.click(lambda: gr.Info("Thank you for contributing!")).then(
+        lambda *args: hf_writer.flag(args),
+        [img_input, flag],
+        [],
         preprocess=False,
-    ).then(lambda: gr.Info("Thank you for contributing!"))
-
+    ).then(lambda: load_badges(dataset_name), [], badges)
 
-demo.queue().launch()
+if __name__ == "__main__":
+    demo.queue().launch()
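For context on the wiring above: Gradio flagging callbacks follow a two-step contract. `setup()` registers the components to record (and the local staging directory) once; each `Flag` click then forwards the raw, un-preprocessed payloads to `flag()`, which is why the listener sets `preprocess=False`. Below is a minimal standalone sketch of that pattern, using the built-in `gr.CSVLogger` as a stand-in for `myHuggingFaceDatasetSaver`; component names are illustrative.

```python
import gradio as gr

# Stand-in for myHuggingFaceDatasetSaver; both implement the same
# FlaggingCallback interface (setup once, flag per click).
callback = gr.CSVLogger()

with gr.Blocks() as sketch:
    img = gr.Image()
    flag_btn = gr.Button("Flag")

    # Must run before the first call to callback.flag()
    callback.setup([img], flagging_dir="flagged")

    flag_btn.click(
        lambda *args: callback.flag(args),  # payloads arrive unprocessed
        [img],
        None,
        preprocess=False,
    )

if __name__ == "__main__":
    sketch.launch()
```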
flagging.py ADDED
@@ -0,0 +1,77 @@
+import json
+from collections import OrderedDict
+from pathlib import Path
+from typing import Any
+import gradio as gr
+from gradio.flagging import HuggingFaceDatasetSaver, client_utils
+import huggingface_hub
+
+class myHuggingFaceDatasetSaver(HuggingFaceDatasetSaver):
+    """
+    Custom HuggingFaceDatasetSaver to save images/audio to disk.
+    Gradio's implementation seems to have a bug.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def _deserialize_components(
+        self,
+        data_dir: Path,
+        flag_data: list[Any],
+        flag_option: str = "",
+        username: str = "",
+    ) -> tuple[dict[Any, Any], list[Any]]:
+        """Deserialize components and return the corresponding row for the flagged sample.
+
+        Images/audio are saved to disk as individual files.
+        """
+        # Components that can have a preview on dataset repos
+        file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
+
+        # Generate the row corresponding to the flagged sample
+        features = OrderedDict()
+        row = []
+        for component, sample in zip(self.components, flag_data):
+            # Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
+            label = component.label or ""
+            save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
+            save_dir.mkdir(exist_ok=True, parents=True)
+            deserialized = component.flag(sample, save_dir)
+            if isinstance(component, gr.Image) and isinstance(sample, dict):
+                deserialized = json.loads(deserialized)['path']  # dirty hack
+
+            # Add deserialized object to row
+            features[label] = {"dtype": "string", "_type": "Value"}
+            try:
+                assert Path(deserialized).exists()
+                row.append(str(Path(deserialized).relative_to(self.dataset_dir)))
+            except (AssertionError, TypeError, ValueError):
+                deserialized = "" if deserialized is None else str(deserialized)
+                row.append(deserialized)
+
+            # If component is eligible for a preview, add the URL of the file
+            # Be mindful that images and audio can be None
+            if isinstance(component, tuple(file_preview_types)):  # type: ignore
+                for _component, _type in file_preview_types.items():
+                    if isinstance(component, _component):
+                        features[label + " file"] = {"_type": _type}
+                        break
+                if deserialized:
+                    path_in_repo = str(  # returned filepath is absolute, we want it relative to compute URL
+                        Path(deserialized).relative_to(self.dataset_dir)
+                    ).replace("\\", "/")
+                    row.append(
+                        huggingface_hub.hf_hub_url(
+                            repo_id=self.dataset_id,
+                            filename=path_in_repo,
+                            repo_type="dataset",
+                        )
+                    )
+                else:
+                    row.append("")
+        features["flag"] = {"dtype": "string", "_type": "Value"}
+        features["username"] = {"dtype": "string", "_type": "Value"}
+        row.append(flag_option)
+        row.append(username)
+        return features, row
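The override diverges from Gradio's stock `HuggingFaceDatasetSaver` only in the `gr.Image` branch: `component.flag()` can hand back a JSON-encoded file descriptor instead of a bare path, which would make the `Path(deserialized).exists()` check fail and the image land in the dataset as a string blob. A small illustration of what the `json.loads(deserialized)['path']` hack unwraps; the payload shape is inferred from this code, not from a documented API.

```python
import json
from pathlib import Path

# Inferred shape of what gr.Image.flag() returns for a dict-valued sample:
# a JSON string describing the saved file rather than the path itself.
deserialized = '{"path": "flagged/image/abc123/sample.png", "url": null}'

path = json.loads(deserialized)["path"]  # unwrap to a plain path string
print(Path(path).name)  # -> sample.png
```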
utils.py CHANGED
@@ -1,4 +1,3 @@
-import os
 import requests
 from io import BytesIO
 import numpy as np
@@ -6,14 +5,12 @@ from PIL import Image
 import yolov5
 from yolov5.utils.plots import Annotator, colors
 import gradio as gr
+from huggingface_hub import get_token
+import time
 
 
 def load_model(model_path, img_size=640):
-    HF_TOKEN = os.getenv("HF_TOKEN")
-    if HF_TOKEN is not None:  # assume SECRET variable is set
-        model = yolov5.load(model_path, hf_token=HF_TOKEN)
-    else:
-        model = yolov5.load(model_path)
+    model = yolov5.load(model_path, hf_token=get_token())
     model.img_size = img_size  # add img_size attribute
     return model
 
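`load_model` now delegates token discovery to `huggingface_hub.get_token()`, which resolves the `HF_TOKEN` environment variable or the token cached by `huggingface-cli login` and returns `None` when neither is present, so the old two-branch lookup collapses into a single call. A quick sanity check; the model id is hypothetical, and `hf_token=None` is assumed to behave like the deleted no-token branch.

```python
from huggingface_hub import get_token

token = get_token()  # HF_TOKEN env var, else cached login token, else None
print("token found" if token else "no token: private checkpoints won't load")
# model = yolov5.load("org/model-id", hf_token=token)  # hypothetical model id
```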
 
 
@@ -37,3 +34,36 @@ def inference(model, image):
         # print(f'{cls} {conf:.2f} {box}')
         annotator.box_label(box, "", color=colors(cls, True))
     return annotator.im
+
+
+def count_flagged_images(dataset_name, trials=10):
+    headers = {"Authorization": f"Bearer {get_token()}"}
+    API_URL = f"https://datasets-server.huggingface.co/size?dataset={dataset_name}"
+
+    def query():
+        response = requests.get(API_URL, headers=headers, timeout=5)
+        return response.json()
+
+    for i in range(trials):
+        try:
+            data = query()
+            if "error" not in data and data["size"]["dataset"]["num_rows"] > 0:
+                print(f"[{i+1}/{trials}] {data}")
+                return data["size"]["dataset"]["num_rows"]
+        except Exception:
+            pass
+        print(f"[{i+1}/{trials}] {data}")
+        time.sleep(5)
+
+    return 0
+
+
+def load_badges(dataset_name, trials=10):
+    n = count_flagged_images(dataset_name, trials)
+    return f"""
+    <p style="display: flex">
+        <img alt="" src="https://img.shields.io/badge/SEA.AI-beta-blue">
+        &nbsp;
+        <img alt="" src="https://img.shields.io/badge/%F0%9F%96%BC%EF%B8%8F-{n}-green">
+    </p>
+    """