Spaces:
Running
Running
kevinconka
commited on
Commit
·
955daea
1
Parent(s):
c54f19a
save flagged data to HF dataset
Browse files- app.py +19 -19
- flagging.py +77 -0
- utils.py +36 -6
app.py
CHANGED
@@ -1,13 +1,9 @@
|
|
1 |
import gradio as gr
|
2 |
-
from
|
|
|
|
|
3 |
|
4 |
|
5 |
-
BADGES = """
|
6 |
-
<p align="right">
|
7 |
-
<img alt="Static Badge" src="https://img.shields.io/badge/SEA.AI-beta-blue">
|
8 |
-
</p>
|
9 |
-
"""
|
10 |
-
|
11 |
TITLE = """
|
12 |
<h1> RGB Detection Demo </h1>
|
13 |
<p align="center">
|
@@ -19,7 +15,8 @@ Give it a try! Upload an image or enter a URL to an image and click
|
|
19 |
NOTICE = """
|
20 |
See something off? Your feedback makes a difference! Let us know by
|
21 |
flagging any outcomes that don't seem right. Just click on `Flag`
|
22 |
-
to submit the image for review.
|
|
|
23 |
"""
|
24 |
|
25 |
css = """
|
@@ -36,11 +33,12 @@ model.max_det = 100
|
|
36 |
model.agnostic = True # NMS class-agnostic
|
37 |
|
38 |
# This callback will be used to flag images
|
39 |
-
|
|
|
40 |
|
41 |
with gr.Blocks(css=css) as demo:
|
42 |
-
gr.
|
43 |
-
gr.
|
44 |
|
45 |
with gr.Row():
|
46 |
with gr.Column():
|
@@ -80,6 +78,7 @@ with gr.Blocks(css=css) as demo:
|
|
80 |
img_url.change(load_image_from_url, [img_url], img_input)
|
81 |
submit.click(lambda image: inference(model, image), [img_input], img_output)
|
82 |
|
|
|
83 |
@img_output.change(inputs=[img_output], outputs=[flag, notice])
|
84 |
def show_hide(img_output):
|
85 |
visible = img_output is not None
|
@@ -89,15 +88,16 @@ with gr.Blocks(css=css) as demo:
|
|
89 |
}
|
90 |
|
91 |
# This needs to be called prior to the first call to callback.flag()
|
92 |
-
|
|
|
93 |
|
94 |
# We can choose which components to flag (in this case, we'll flag all)
|
95 |
-
flag.click(
|
96 |
-
lambda *args:
|
97 |
-
[img_input,
|
98 |
-
|
99 |
preprocess=False,
|
100 |
-
).then(lambda:
|
101 |
-
|
102 |
|
103 |
-
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from huggingface_hub import get_token
|
3 |
+
from utils import load_model, load_image_from_url, inference, load_badges
|
4 |
+
from flagging import myHuggingFaceDatasetSaver
|
5 |
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
TITLE = """
|
8 |
<h1> RGB Detection Demo </h1>
|
9 |
<p align="center">
|
|
|
15 |
NOTICE = """
|
16 |
See something off? Your feedback makes a difference! Let us know by
|
17 |
flagging any outcomes that don't seem right. Just click on `Flag`
|
18 |
+
to submit the image for review. Note that by clicking `Flag`, you
|
19 |
+
agree to the use of your image for A.I. improvement purposes.
|
20 |
"""
|
21 |
|
22 |
css = """
|
|
|
33 |
model.agnostic = True # NMS class-agnostic
|
34 |
|
35 |
# This callback will be used to flag images
|
36 |
+
dataset_name = "SEA-AI/crowdsourced-rgb-images"
|
37 |
+
hf_writer = myHuggingFaceDatasetSaver(get_token(), dataset_name)
|
38 |
|
39 |
with gr.Blocks(css=css) as demo:
|
40 |
+
badges = gr.HTML(load_badges(dataset_name, trials=1))
|
41 |
+
title = gr.HTML(TITLE)
|
42 |
|
43 |
with gr.Row():
|
44 |
with gr.Column():
|
|
|
78 |
img_url.change(load_image_from_url, [img_url], img_input)
|
79 |
submit.click(lambda image: inference(model, image), [img_input], img_output)
|
80 |
|
81 |
+
# event listeners with decorators
|
82 |
@img_output.change(inputs=[img_output], outputs=[flag, notice])
|
83 |
def show_hide(img_output):
|
84 |
visible = img_output is not None
|
|
|
88 |
}
|
89 |
|
90 |
# This needs to be called prior to the first call to callback.flag()
|
91 |
+
hf_writer.setup([img_input], "flagged")
|
92 |
+
img_input.flag
|
93 |
|
94 |
# We can choose which components to flag (in this case, we'll flag all)
|
95 |
+
flag.click(lambda: gr.Info("Thank you for contributing!")).then(
|
96 |
+
lambda *args: hf_writer.flag(args),
|
97 |
+
[img_input, flag],
|
98 |
+
[],
|
99 |
preprocess=False,
|
100 |
+
).then(lambda: load_badges(dataset_name), [], badges)
|
|
|
101 |
|
102 |
+
if __name__ == "__main__":
|
103 |
+
demo.queue().launch()
|
flagging.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from collections import OrderedDict
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Any
|
5 |
+
import gradio as gr
|
6 |
+
from gradio.flagging import HuggingFaceDatasetSaver, client_utils
|
7 |
+
import huggingface_hub
|
8 |
+
|
9 |
+
class myHuggingFaceDatasetSaver(HuggingFaceDatasetSaver):
|
10 |
+
"""
|
11 |
+
Custom HuggingFaceDatasetSaver to save images/audio to disk.
|
12 |
+
Gradio's implementation seems to have a bug.
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(self, *args, **kwargs):
|
16 |
+
super().__init__(*args, **kwargs)
|
17 |
+
|
18 |
+
def _deserialize_components(
|
19 |
+
self,
|
20 |
+
data_dir: Path,
|
21 |
+
flag_data: list[Any],
|
22 |
+
flag_option: str = "",
|
23 |
+
username: str = "",
|
24 |
+
) -> tuple[dict[Any, Any], list[Any]]:
|
25 |
+
"""Deserialize components and return the corresponding row for the flagged sample.
|
26 |
+
|
27 |
+
Images/audio are saved to disk as individual files.
|
28 |
+
"""
|
29 |
+
# Components that can have a preview on dataset repos
|
30 |
+
file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
|
31 |
+
|
32 |
+
# Generate the row corresponding to the flagged sample
|
33 |
+
features = OrderedDict()
|
34 |
+
row = []
|
35 |
+
for component, sample in zip(self.components, flag_data):
|
36 |
+
# Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
|
37 |
+
label = component.label or ""
|
38 |
+
save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
|
39 |
+
save_dir.mkdir(exist_ok=True, parents=True)
|
40 |
+
deserialized = component.flag(sample, save_dir)
|
41 |
+
if isinstance(component, gr.Image) and isinstance(sample, dict):
|
42 |
+
deserialized = json.loads(deserialized)['path'] # dirty hack
|
43 |
+
|
44 |
+
# Add deserialized object to row
|
45 |
+
features[label] = {"dtype": "string", "_type": "Value"}
|
46 |
+
try:
|
47 |
+
assert Path(deserialized).exists()
|
48 |
+
row.append(str(Path(deserialized).relative_to(self.dataset_dir)))
|
49 |
+
except (AssertionError, TypeError, ValueError):
|
50 |
+
deserialized = "" if deserialized is None else str(deserialized)
|
51 |
+
row.append(deserialized)
|
52 |
+
|
53 |
+
# If component is eligible for a preview, add the URL of the file
|
54 |
+
# Be mindful that images and audio can be None
|
55 |
+
if isinstance(component, tuple(file_preview_types)): # type: ignore
|
56 |
+
for _component, _type in file_preview_types.items():
|
57 |
+
if isinstance(component, _component):
|
58 |
+
features[label + " file"] = {"_type": _type}
|
59 |
+
break
|
60 |
+
if deserialized:
|
61 |
+
path_in_repo = str( # returned filepath is absolute, we want it relative to compute URL
|
62 |
+
Path(deserialized).relative_to(self.dataset_dir)
|
63 |
+
).replace("\\", "/")
|
64 |
+
row.append(
|
65 |
+
huggingface_hub.hf_hub_url(
|
66 |
+
repo_id=self.dataset_id,
|
67 |
+
filename=path_in_repo,
|
68 |
+
repo_type="dataset",
|
69 |
+
)
|
70 |
+
)
|
71 |
+
else:
|
72 |
+
row.append("")
|
73 |
+
features["flag"] = {"dtype": "string", "_type": "Value"}
|
74 |
+
features["username"] = {"dtype": "string", "_type": "Value"}
|
75 |
+
row.append(flag_option)
|
76 |
+
row.append(username)
|
77 |
+
return features, row
|
utils.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import os
|
2 |
import requests
|
3 |
from io import BytesIO
|
4 |
import numpy as np
|
@@ -6,14 +5,12 @@ from PIL import Image
|
|
6 |
import yolov5
|
7 |
from yolov5.utils.plots import Annotator, colors
|
8 |
import gradio as gr
|
|
|
|
|
9 |
|
10 |
|
11 |
def load_model(model_path, img_size=640):
|
12 |
-
|
13 |
-
if HF_TOKEN is not None: # assume SECRET variable is set
|
14 |
-
model = yolov5.load(model_path, hf_token=HF_TOKEN)
|
15 |
-
else:
|
16 |
-
model = yolov5.load(model_path)
|
17 |
model.img_size = img_size # add img_size attribute
|
18 |
return model
|
19 |
|
@@ -37,3 +34,36 @@ def inference(model, image):
|
|
37 |
# print(f'{cls} {conf:.2f} {box}')
|
38 |
annotator.box_label(box, "", color=colors(cls, True))
|
39 |
return annotator.im
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import requests
|
2 |
from io import BytesIO
|
3 |
import numpy as np
|
|
|
5 |
import yolov5
|
6 |
from yolov5.utils.plots import Annotator, colors
|
7 |
import gradio as gr
|
8 |
+
from huggingface_hub import get_token
|
9 |
+
import time
|
10 |
|
11 |
|
12 |
def load_model(model_path, img_size=640):
|
13 |
+
model = yolov5.load(model_path, hf_token=get_token())
|
|
|
|
|
|
|
|
|
14 |
model.img_size = img_size # add img_size attribute
|
15 |
return model
|
16 |
|
|
|
34 |
# print(f'{cls} {conf:.2f} {box}')
|
35 |
annotator.box_label(box, "", color=colors(cls, True))
|
36 |
return annotator.im
|
37 |
+
|
38 |
+
|
39 |
+
def count_flagged_images(dataset_name, trials=10):
|
40 |
+
headers = {"Authorization": f"Bearer {get_token()}"}
|
41 |
+
API_URL = f"https://datasets-server.huggingface.co/size?dataset={dataset_name}"
|
42 |
+
|
43 |
+
def query():
|
44 |
+
response = requests.get(API_URL, headers=headers, timeout=5)
|
45 |
+
return response.json()
|
46 |
+
|
47 |
+
for i in range(trials):
|
48 |
+
try:
|
49 |
+
data = query()
|
50 |
+
if "error" not in data and data["size"]["dataset"]["num_rows"] > 0:
|
51 |
+
print(f"[{i+1}/{trials}] {data}")
|
52 |
+
return data["size"]["dataset"]["num_rows"]
|
53 |
+
except Exception:
|
54 |
+
pass
|
55 |
+
print(f"[{i+1}/{trials}] {data}")
|
56 |
+
time.sleep(5)
|
57 |
+
|
58 |
+
return 0
|
59 |
+
|
60 |
+
|
61 |
+
def load_badges(dataset_name, trials=10):
|
62 |
+
n = count_flagged_images(dataset_name, trials)
|
63 |
+
return f"""
|
64 |
+
<p style="display: flex">
|
65 |
+
<img alt="" src="https://img.shields.io/badge/SEA.AI-beta-blue">
|
66 |
+
|
67 |
+
<img alt="" src="https://img.shields.io/badge/%F0%9F%96%BC%EF%B8%8F-{n}-green">
|
68 |
+
</p>
|
69 |
+
"""
|