Spaces:
Runtime error
Runtime error
store examples in named directories
Browse filesclean up flagging sharing functionality
- app.py +8 -4
- gradio_patches/examples.py +13 -0
- flagging.py → gradio_patches/flagging.py +6 -228
app.py
CHANGED
|
@@ -36,7 +36,8 @@ from huggingface_hub import login
|
|
| 36 |
from tqdm import tqdm
|
| 37 |
|
| 38 |
from extrude import extrude_depth_3d
|
| 39 |
-
from
|
|
|
|
| 40 |
from marigold_depth_estimation_lcm import MarigoldDepthConsistencyPipeline
|
| 41 |
|
| 42 |
warnings.filterwarnings(
|
|
@@ -533,7 +534,7 @@ def run_demo_server(pipe, hf_writer=None):
|
|
| 533 |
"Share", variant="stop", scale=1
|
| 534 |
)
|
| 535 |
|
| 536 |
-
|
| 537 |
fn=process_pipe_image,
|
| 538 |
examples=[
|
| 539 |
os.path.join("files", "image", name)
|
|
@@ -568,6 +569,7 @@ def run_demo_server(pipe, hf_writer=None):
|
|
| 568 |
inputs=[image_input],
|
| 569 |
outputs=[image_output_slider, image_output_files],
|
| 570 |
cache_examples=True,
|
|
|
|
| 571 |
)
|
| 572 |
|
| 573 |
with gr.Tab("Video"):
|
|
@@ -592,7 +594,7 @@ def run_demo_server(pipe, hf_writer=None):
|
|
| 592 |
elem_id="download",
|
| 593 |
interactive=False,
|
| 594 |
)
|
| 595 |
-
|
| 596 |
fn=process_pipe_video,
|
| 597 |
examples=[
|
| 598 |
os.path.join("files", "video", name)
|
|
@@ -605,6 +607,7 @@ def run_demo_server(pipe, hf_writer=None):
|
|
| 605 |
inputs=[video_input],
|
| 606 |
outputs=[video_output_video, video_output_files],
|
| 607 |
cache_examples=True,
|
|
|
|
| 608 |
)
|
| 609 |
|
| 610 |
with gr.Tab("Bas-relief (3D)"):
|
|
@@ -729,7 +732,7 @@ def run_demo_server(pipe, hf_writer=None):
|
|
| 729 |
elem_id="download",
|
| 730 |
interactive=False,
|
| 731 |
)
|
| 732 |
-
|
| 733 |
fn=process_pipe_bas,
|
| 734 |
examples=[
|
| 735 |
[
|
|
@@ -795,6 +798,7 @@ def run_demo_server(pipe, hf_writer=None):
|
|
| 795 |
],
|
| 796 |
outputs=[bas_output_viewer, bas_output_files],
|
| 797 |
cache_examples=True,
|
|
|
|
| 798 |
)
|
| 799 |
|
| 800 |
### Image tab
|
|
|
|
| 36 |
from tqdm import tqdm
|
| 37 |
|
| 38 |
from extrude import extrude_depth_3d
|
| 39 |
+
from gradio_patches.examples import Examples
|
| 40 |
+
from gradio_patches.flagging import FlagMethod, HuggingFaceDatasetSaver
|
| 41 |
from marigold_depth_estimation_lcm import MarigoldDepthConsistencyPipeline
|
| 42 |
|
| 43 |
warnings.filterwarnings(
|
|
|
|
| 534 |
"Share", variant="stop", scale=1
|
| 535 |
)
|
| 536 |
|
| 537 |
+
Examples(
|
| 538 |
fn=process_pipe_image,
|
| 539 |
examples=[
|
| 540 |
os.path.join("files", "image", name)
|
|
|
|
| 569 |
inputs=[image_input],
|
| 570 |
outputs=[image_output_slider, image_output_files],
|
| 571 |
cache_examples=True,
|
| 572 |
+
directory_name="examples_image",
|
| 573 |
)
|
| 574 |
|
| 575 |
with gr.Tab("Video"):
|
|
|
|
| 594 |
elem_id="download",
|
| 595 |
interactive=False,
|
| 596 |
)
|
| 597 |
+
Examples(
|
| 598 |
fn=process_pipe_video,
|
| 599 |
examples=[
|
| 600 |
os.path.join("files", "video", name)
|
|
|
|
| 607 |
inputs=[video_input],
|
| 608 |
outputs=[video_output_video, video_output_files],
|
| 609 |
cache_examples=True,
|
| 610 |
+
directory_name="examples_video",
|
| 611 |
)
|
| 612 |
|
| 613 |
with gr.Tab("Bas-relief (3D)"):
|
|
|
|
| 732 |
elem_id="download",
|
| 733 |
interactive=False,
|
| 734 |
)
|
| 735 |
+
Examples(
|
| 736 |
fn=process_pipe_bas,
|
| 737 |
examples=[
|
| 738 |
[
|
|
|
|
| 798 |
],
|
| 799 |
outputs=[bas_output_viewer, bas_output_files],
|
| 800 |
cache_examples=True,
|
| 801 |
+
directory_name="examples_bas",
|
| 802 |
)
|
| 803 |
|
| 804 |
### Image tab
|
gradio_patches/examples.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import gradio
|
| 4 |
+
from gradio.utils import get_cache_folder
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Examples(gradio.helpers.Examples):
|
| 8 |
+
def __init__(self, *args, directory_name=None, **kwargs):
|
| 9 |
+
super().__init__(*args, **kwargs, _initiated_directly=False)
|
| 10 |
+
if directory_name is not None:
|
| 11 |
+
self.cached_folder = get_cache_folder() / directory_name
|
| 12 |
+
self.cached_file = Path(self.cached_folder) / "log.csv"
|
| 13 |
+
self.create()
|
flagging.py → gradio_patches/flagging.py
RENAMED
|
@@ -1,157 +1,22 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
import
|
| 4 |
import json
|
| 5 |
import time
|
| 6 |
import uuid
|
| 7 |
-
from abc import ABC, abstractmethod
|
| 8 |
from collections import OrderedDict
|
| 9 |
from datetime import datetime, timezone
|
| 10 |
from pathlib import Path
|
| 11 |
-
from typing import
|
| 12 |
|
| 13 |
-
import
|
|
|
|
| 14 |
import huggingface_hub
|
|
|
|
| 15 |
from gradio_client import utils as client_utils
|
| 16 |
-
from gradio_client.documentation import document
|
| 17 |
-
|
| 18 |
-
import gradio as gr
|
| 19 |
-
from gradio import utils
|
| 20 |
-
|
| 21 |
-
if TYPE_CHECKING:
|
| 22 |
-
from gradio.components import Component
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class FlaggingCallback(ABC):
|
| 26 |
-
"""
|
| 27 |
-
An abstract class for defining the methods that any FlaggingCallback should have.
|
| 28 |
-
"""
|
| 29 |
-
|
| 30 |
-
@abstractmethod
|
| 31 |
-
def setup(self, components: list[Component], flagging_dir: str):
|
| 32 |
-
"""
|
| 33 |
-
This method should be overridden and ensure that everything is set up correctly for flag().
|
| 34 |
-
This method gets called once at the beginning of the Interface.launch() method.
|
| 35 |
-
Parameters:
|
| 36 |
-
components: Set of components that will provide flagged data.
|
| 37 |
-
flagging_dir: A string, typically containing the path to the directory where the flagging file should be stored (provided as an argument to Interface.__init__()).
|
| 38 |
-
"""
|
| 39 |
-
pass
|
| 40 |
-
|
| 41 |
-
@abstractmethod
|
| 42 |
-
def flag(
|
| 43 |
-
self,
|
| 44 |
-
flag_data: list[Any],
|
| 45 |
-
flag_option: str = "",
|
| 46 |
-
username: str | None = None,
|
| 47 |
-
) -> int:
|
| 48 |
-
"""
|
| 49 |
-
This method should be overridden by the FlaggingCallback subclass and may contain optional additional arguments.
|
| 50 |
-
This gets called every time the <flag> button is pressed.
|
| 51 |
-
Parameters:
|
| 52 |
-
interface: The Interface object that is being used to launch the flagging interface.
|
| 53 |
-
flag_data: The data to be flagged.
|
| 54 |
-
flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
|
| 55 |
-
username (optional): The username of the user that is flagging the data, if logged in.
|
| 56 |
-
Returns:
|
| 57 |
-
(int) The total number of samples that have been flagged.
|
| 58 |
-
"""
|
| 59 |
-
pass
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
@document()
|
| 63 |
-
class HuggingFaceDatasetSaver(FlaggingCallback):
|
| 64 |
-
"""
|
| 65 |
-
A callback that saves each flagged sample (both the input and output data) to a HuggingFace dataset.
|
| 66 |
-
|
| 67 |
-
Example:
|
| 68 |
-
import gradio as gr
|
| 69 |
-
hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "image-classification-mistakes")
|
| 70 |
-
def image_classifier(inp):
|
| 71 |
-
return {'cat': 0.3, 'dog': 0.7}
|
| 72 |
-
demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
|
| 73 |
-
allow_flagging="manual", flagging_callback=hf_writer)
|
| 74 |
-
Guides: using-flagging
|
| 75 |
-
"""
|
| 76 |
-
|
| 77 |
-
def __init__(
|
| 78 |
-
self,
|
| 79 |
-
hf_token: str,
|
| 80 |
-
dataset_name: str,
|
| 81 |
-
private: bool = False,
|
| 82 |
-
info_filename: str = "dataset_info.json",
|
| 83 |
-
separate_dirs: bool = False,
|
| 84 |
-
):
|
| 85 |
-
"""
|
| 86 |
-
Parameters:
|
| 87 |
-
hf_token: The HuggingFace token to use to create (and write the flagged sample to) the HuggingFace dataset (defaults to the registered one).
|
| 88 |
-
dataset_name: The repo_id of the dataset to save the data to, e.g. "image-classifier-1" or "username/image-classifier-1".
|
| 89 |
-
private: Whether the dataset should be private (defaults to False).
|
| 90 |
-
info_filename: The name of the file to save the dataset info (defaults to "dataset_infos.json").
|
| 91 |
-
separate_dirs: If True, each flagged item will be saved in a separate directory. This makes the flagging more robust to concurrent editing, but may be less convenient to use.
|
| 92 |
-
"""
|
| 93 |
-
self.hf_token = hf_token
|
| 94 |
-
self.dataset_id = dataset_name # TODO: rename parameter (but ensure backward compatibility somehow)
|
| 95 |
-
self.dataset_private = private
|
| 96 |
-
self.info_filename = info_filename
|
| 97 |
-
self.separate_dirs = separate_dirs
|
| 98 |
-
|
| 99 |
-
def setup(self, components: list[Component], flagging_dir: str):
|
| 100 |
-
"""
|
| 101 |
-
Params:
|
| 102 |
-
flagging_dir (str): local directory where the dataset is cloned,
|
| 103 |
-
updated, and pushed from.
|
| 104 |
-
"""
|
| 105 |
-
# Setup dataset on the Hub
|
| 106 |
-
self.dataset_id = huggingface_hub.create_repo(
|
| 107 |
-
repo_id=self.dataset_id,
|
| 108 |
-
token=self.hf_token,
|
| 109 |
-
private=self.dataset_private,
|
| 110 |
-
repo_type="dataset",
|
| 111 |
-
exist_ok=True,
|
| 112 |
-
).repo_id
|
| 113 |
-
path_glob = "**/*.jsonl" if self.separate_dirs else "data.csv"
|
| 114 |
-
huggingface_hub.metadata_update(
|
| 115 |
-
repo_id=self.dataset_id,
|
| 116 |
-
repo_type="dataset",
|
| 117 |
-
metadata={
|
| 118 |
-
"configs": [
|
| 119 |
-
{
|
| 120 |
-
"config_name": "default",
|
| 121 |
-
"data_files": [{"split": "train", "path": path_glob}],
|
| 122 |
-
}
|
| 123 |
-
]
|
| 124 |
-
},
|
| 125 |
-
overwrite=True,
|
| 126 |
-
token=self.hf_token,
|
| 127 |
-
)
|
| 128 |
-
|
| 129 |
-
# Setup flagging dir
|
| 130 |
-
self.components = components
|
| 131 |
-
self.dataset_dir = (
|
| 132 |
-
Path(flagging_dir).absolute() / self.dataset_id.split("/")[-1]
|
| 133 |
-
)
|
| 134 |
-
self.dataset_dir.mkdir(parents=True, exist_ok=True)
|
| 135 |
-
self.infos_file = self.dataset_dir / self.info_filename
|
| 136 |
|
| 137 |
-
# Download remote files to local
|
| 138 |
-
remote_files = [self.info_filename]
|
| 139 |
-
if not self.separate_dirs:
|
| 140 |
-
# No separate dirs => means all data is in the same CSV file => download it to get its current content
|
| 141 |
-
remote_files.append("data.csv")
|
| 142 |
-
|
| 143 |
-
for filename in remote_files:
|
| 144 |
-
try:
|
| 145 |
-
huggingface_hub.hf_hub_download(
|
| 146 |
-
repo_id=self.dataset_id,
|
| 147 |
-
repo_type="dataset",
|
| 148 |
-
filename=filename,
|
| 149 |
-
local_dir=self.dataset_dir,
|
| 150 |
-
token=self.hf_token,
|
| 151 |
-
)
|
| 152 |
-
except huggingface_hub.utils.EntryNotFoundError:
|
| 153 |
-
pass
|
| 154 |
|
|
|
|
| 155 |
def flag(
|
| 156 |
self,
|
| 157 |
flag_data: list[Any],
|
|
@@ -188,93 +53,6 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
|
| 188 |
username=username or "",
|
| 189 |
)
|
| 190 |
|
| 191 |
-
def _flag_in_dir(
|
| 192 |
-
self,
|
| 193 |
-
data_file: Path,
|
| 194 |
-
components_dir: Path,
|
| 195 |
-
path_in_repo: str | None,
|
| 196 |
-
flag_data: list[Any],
|
| 197 |
-
flag_option: str = "",
|
| 198 |
-
username: str = "",
|
| 199 |
-
) -> int:
|
| 200 |
-
# Deserialize components (write images/audio to files)
|
| 201 |
-
features, row = self._deserialize_components(
|
| 202 |
-
components_dir, flag_data, flag_option, username
|
| 203 |
-
)
|
| 204 |
-
|
| 205 |
-
# Write generic info to dataset_infos.json + upload
|
| 206 |
-
with filelock.FileLock(str(self.infos_file) + ".lock"):
|
| 207 |
-
if not self.infos_file.exists():
|
| 208 |
-
self.infos_file.write_text(
|
| 209 |
-
json.dumps({"flagged": {"features": features}})
|
| 210 |
-
)
|
| 211 |
-
|
| 212 |
-
huggingface_hub.upload_file(
|
| 213 |
-
repo_id=self.dataset_id,
|
| 214 |
-
repo_type="dataset",
|
| 215 |
-
token=self.hf_token,
|
| 216 |
-
path_in_repo=self.infos_file.name,
|
| 217 |
-
path_or_fileobj=self.infos_file,
|
| 218 |
-
)
|
| 219 |
-
|
| 220 |
-
headers = list(features.keys())
|
| 221 |
-
|
| 222 |
-
if not self.separate_dirs:
|
| 223 |
-
with filelock.FileLock(components_dir / ".lock"):
|
| 224 |
-
sample_nb = self._save_as_csv(data_file, headers=headers, row=row)
|
| 225 |
-
sample_name = str(sample_nb)
|
| 226 |
-
huggingface_hub.upload_folder(
|
| 227 |
-
repo_id=self.dataset_id,
|
| 228 |
-
repo_type="dataset",
|
| 229 |
-
commit_message=f"Flagged sample #{sample_name}",
|
| 230 |
-
path_in_repo=path_in_repo,
|
| 231 |
-
ignore_patterns="*.lock",
|
| 232 |
-
folder_path=components_dir,
|
| 233 |
-
token=self.hf_token,
|
| 234 |
-
)
|
| 235 |
-
else:
|
| 236 |
-
sample_name = self._save_as_jsonl(data_file, headers=headers, row=row)
|
| 237 |
-
sample_nb = len(
|
| 238 |
-
[path for path in self.dataset_dir.iterdir() if path.is_dir()]
|
| 239 |
-
)
|
| 240 |
-
huggingface_hub.upload_folder(
|
| 241 |
-
repo_id=self.dataset_id,
|
| 242 |
-
repo_type="dataset",
|
| 243 |
-
commit_message=f"Flagged sample #{sample_name}",
|
| 244 |
-
path_in_repo=path_in_repo,
|
| 245 |
-
ignore_patterns="*.lock",
|
| 246 |
-
folder_path=components_dir,
|
| 247 |
-
token=self.hf_token,
|
| 248 |
-
)
|
| 249 |
-
|
| 250 |
-
return sample_nb
|
| 251 |
-
|
| 252 |
-
@staticmethod
|
| 253 |
-
def _save_as_csv(data_file: Path, headers: list[str], row: list[Any]) -> int:
|
| 254 |
-
"""Save data as CSV and return the sample name (row number)."""
|
| 255 |
-
is_new = not data_file.exists()
|
| 256 |
-
|
| 257 |
-
with data_file.open("a", newline="", encoding="utf-8") as csvfile:
|
| 258 |
-
writer = csv.writer(csvfile)
|
| 259 |
-
|
| 260 |
-
# Write CSV headers if new file
|
| 261 |
-
if is_new:
|
| 262 |
-
writer.writerow(utils.sanitize_list_for_csv(headers))
|
| 263 |
-
|
| 264 |
-
# Write CSV row for flagged sample
|
| 265 |
-
writer.writerow(utils.sanitize_list_for_csv(row))
|
| 266 |
-
|
| 267 |
-
with data_file.open(encoding="utf-8") as csvfile:
|
| 268 |
-
return sum(1 for _ in csv.reader(csvfile)) - 1
|
| 269 |
-
|
| 270 |
-
@staticmethod
|
| 271 |
-
def _save_as_jsonl(data_file: Path, headers: list[str], row: list[Any]) -> str:
|
| 272 |
-
"""Save data as JSONL and return the sample name (uuid)."""
|
| 273 |
-
Path.mkdir(data_file.parent, parents=True, exist_ok=True)
|
| 274 |
-
with open(data_file, "w") as f:
|
| 275 |
-
json.dump(dict(zip(headers, row)), f)
|
| 276 |
-
return data_file.parent.name
|
| 277 |
-
|
| 278 |
def _deserialize_components(
|
| 279 |
self,
|
| 280 |
data_dir: Path,
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
import datetime
|
| 4 |
import json
|
| 5 |
import time
|
| 6 |
import uuid
|
|
|
|
| 7 |
from collections import OrderedDict
|
| 8 |
from datetime import datetime, timezone
|
| 9 |
from pathlib import Path
|
| 10 |
+
from typing import Any
|
| 11 |
|
| 12 |
+
import gradio
|
| 13 |
+
import gradio as gr
|
| 14 |
import huggingface_hub
|
| 15 |
+
from gradio import FlaggingCallback
|
| 16 |
from gradio_client import utils as client_utils
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
+
class HuggingFaceDatasetSaver(gradio.HuggingFaceDatasetSaver):
|
| 20 |
def flag(
|
| 21 |
self,
|
| 22 |
flag_data: list[Any],
|
|
|
|
| 53 |
username=username or "",
|
| 54 |
)
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
def _deserialize_components(
|
| 57 |
self,
|
| 58 |
data_dir: Path,
|