toshas commited on
Commit
1e6b2f7
1 Parent(s): 4467dbe

store examples in named directories

Browse files

clean up flagging sharing functionality

app.py CHANGED
@@ -36,7 +36,8 @@ from huggingface_hub import login
36
  from tqdm import tqdm
37
 
38
  from extrude import extrude_depth_3d
39
- from flagging import FlagMethod, HuggingFaceDatasetSaver
 
40
  from marigold_depth_estimation_lcm import MarigoldDepthConsistencyPipeline
41
 
42
  warnings.filterwarnings(
@@ -533,7 +534,7 @@ def run_demo_server(pipe, hf_writer=None):
533
  "Share", variant="stop", scale=1
534
  )
535
 
536
- gr.Examples(
537
  fn=process_pipe_image,
538
  examples=[
539
  os.path.join("files", "image", name)
@@ -568,6 +569,7 @@ def run_demo_server(pipe, hf_writer=None):
568
  inputs=[image_input],
569
  outputs=[image_output_slider, image_output_files],
570
  cache_examples=True,
 
571
  )
572
 
573
  with gr.Tab("Video"):
@@ -592,7 +594,7 @@ def run_demo_server(pipe, hf_writer=None):
592
  elem_id="download",
593
  interactive=False,
594
  )
595
- gr.Examples(
596
  fn=process_pipe_video,
597
  examples=[
598
  os.path.join("files", "video", name)
@@ -605,6 +607,7 @@ def run_demo_server(pipe, hf_writer=None):
605
  inputs=[video_input],
606
  outputs=[video_output_video, video_output_files],
607
  cache_examples=True,
 
608
  )
609
 
610
  with gr.Tab("Bas-relief (3D)"):
@@ -729,7 +732,7 @@ def run_demo_server(pipe, hf_writer=None):
729
  elem_id="download",
730
  interactive=False,
731
  )
732
- gr.Examples(
733
  fn=process_pipe_bas,
734
  examples=[
735
  [
@@ -795,6 +798,7 @@ def run_demo_server(pipe, hf_writer=None):
795
  ],
796
  outputs=[bas_output_viewer, bas_output_files],
797
  cache_examples=True,
 
798
  )
799
 
800
  ### Image tab
 
36
  from tqdm import tqdm
37
 
38
  from extrude import extrude_depth_3d
39
+ from gradio_patches.examples import Examples
40
+ from gradio_patches.flagging import FlagMethod, HuggingFaceDatasetSaver
41
  from marigold_depth_estimation_lcm import MarigoldDepthConsistencyPipeline
42
 
43
  warnings.filterwarnings(
 
534
  "Share", variant="stop", scale=1
535
  )
536
 
537
+ Examples(
538
  fn=process_pipe_image,
539
  examples=[
540
  os.path.join("files", "image", name)
 
569
  inputs=[image_input],
570
  outputs=[image_output_slider, image_output_files],
571
  cache_examples=True,
572
+ directory_name="examples_image",
573
  )
574
 
575
  with gr.Tab("Video"):
 
594
  elem_id="download",
595
  interactive=False,
596
  )
597
+ Examples(
598
  fn=process_pipe_video,
599
  examples=[
600
  os.path.join("files", "video", name)
 
607
  inputs=[video_input],
608
  outputs=[video_output_video, video_output_files],
609
  cache_examples=True,
610
+ directory_name="examples_video",
611
  )
612
 
613
  with gr.Tab("Bas-relief (3D)"):
 
732
  elem_id="download",
733
  interactive=False,
734
  )
735
+ Examples(
736
  fn=process_pipe_bas,
737
  examples=[
738
  [
 
798
  ],
799
  outputs=[bas_output_viewer, bas_output_files],
800
  cache_examples=True,
801
+ directory_name="examples_bas",
802
  )
803
 
804
  ### Image tab
gradio_patches/examples.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import gradio
4
+ from gradio.utils import get_cache_folder
5
+
6
+
7
+ class Examples(gradio.helpers.Examples):
8
+ def __init__(self, *args, directory_name=None, **kwargs):
9
+ super().__init__(*args, **kwargs, _initiated_directly=False)
10
+ if directory_name is not None:
11
+ self.cached_folder = get_cache_folder() / directory_name
12
+ self.cached_file = Path(self.cached_folder) / "log.csv"
13
+ self.create()
flagging.py → gradio_patches/flagging.py RENAMED
@@ -1,157 +1,22 @@
1
  from __future__ import annotations
2
 
3
- import csv
4
  import json
5
  import time
6
  import uuid
7
- from abc import ABC, abstractmethod
8
  from collections import OrderedDict
9
  from datetime import datetime, timezone
10
  from pathlib import Path
11
- from typing import TYPE_CHECKING, Any
12
 
13
- import filelock
 
14
  import huggingface_hub
 
15
  from gradio_client import utils as client_utils
16
- from gradio_client.documentation import document
17
-
18
- import gradio as gr
19
- from gradio import utils
20
-
21
- if TYPE_CHECKING:
22
- from gradio.components import Component
23
-
24
-
25
- class FlaggingCallback(ABC):
26
- """
27
- An abstract class for defining the methods that any FlaggingCallback should have.
28
- """
29
-
30
- @abstractmethod
31
- def setup(self, components: list[Component], flagging_dir: str):
32
- """
33
- This method should be overridden and ensure that everything is set up correctly for flag().
34
- This method gets called once at the beginning of the Interface.launch() method.
35
- Parameters:
36
- components: Set of components that will provide flagged data.
37
- flagging_dir: A string, typically containing the path to the directory where the flagging file should be stored (provided as an argument to Interface.__init__()).
38
- """
39
- pass
40
-
41
- @abstractmethod
42
- def flag(
43
- self,
44
- flag_data: list[Any],
45
- flag_option: str = "",
46
- username: str | None = None,
47
- ) -> int:
48
- """
49
- This method should be overridden by the FlaggingCallback subclass and may contain optional additional arguments.
50
- This gets called every time the <flag> button is pressed.
51
- Parameters:
52
- interface: The Interface object that is being used to launch the flagging interface.
53
- flag_data: The data to be flagged.
54
- flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
55
- username (optional): The username of the user that is flagging the data, if logged in.
56
- Returns:
57
- (int) The total number of samples that have been flagged.
58
- """
59
- pass
60
-
61
-
62
- @document()
63
- class HuggingFaceDatasetSaver(FlaggingCallback):
64
- """
65
- A callback that saves each flagged sample (both the input and output data) to a HuggingFace dataset.
66
-
67
- Example:
68
- import gradio as gr
69
- hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "image-classification-mistakes")
70
- def image_classifier(inp):
71
- return {'cat': 0.3, 'dog': 0.7}
72
- demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
73
- allow_flagging="manual", flagging_callback=hf_writer)
74
- Guides: using-flagging
75
- """
76
-
77
- def __init__(
78
- self,
79
- hf_token: str,
80
- dataset_name: str,
81
- private: bool = False,
82
- info_filename: str = "dataset_info.json",
83
- separate_dirs: bool = False,
84
- ):
85
- """
86
- Parameters:
87
- hf_token: The HuggingFace token to use to create (and write the flagged sample to) the HuggingFace dataset (defaults to the registered one).
88
- dataset_name: The repo_id of the dataset to save the data to, e.g. "image-classifier-1" or "username/image-classifier-1".
89
- private: Whether the dataset should be private (defaults to False).
90
- info_filename: The name of the file to save the dataset info (defaults to "dataset_infos.json").
91
- separate_dirs: If True, each flagged item will be saved in a separate directory. This makes the flagging more robust to concurrent editing, but may be less convenient to use.
92
- """
93
- self.hf_token = hf_token
94
- self.dataset_id = dataset_name # TODO: rename parameter (but ensure backward compatibility somehow)
95
- self.dataset_private = private
96
- self.info_filename = info_filename
97
- self.separate_dirs = separate_dirs
98
-
99
- def setup(self, components: list[Component], flagging_dir: str):
100
- """
101
- Params:
102
- flagging_dir (str): local directory where the dataset is cloned,
103
- updated, and pushed from.
104
- """
105
- # Setup dataset on the Hub
106
- self.dataset_id = huggingface_hub.create_repo(
107
- repo_id=self.dataset_id,
108
- token=self.hf_token,
109
- private=self.dataset_private,
110
- repo_type="dataset",
111
- exist_ok=True,
112
- ).repo_id
113
- path_glob = "**/*.jsonl" if self.separate_dirs else "data.csv"
114
- huggingface_hub.metadata_update(
115
- repo_id=self.dataset_id,
116
- repo_type="dataset",
117
- metadata={
118
- "configs": [
119
- {
120
- "config_name": "default",
121
- "data_files": [{"split": "train", "path": path_glob}],
122
- }
123
- ]
124
- },
125
- overwrite=True,
126
- token=self.hf_token,
127
- )
128
-
129
- # Setup flagging dir
130
- self.components = components
131
- self.dataset_dir = (
132
- Path(flagging_dir).absolute() / self.dataset_id.split("/")[-1]
133
- )
134
- self.dataset_dir.mkdir(parents=True, exist_ok=True)
135
- self.infos_file = self.dataset_dir / self.info_filename
136
 
137
- # Download remote files to local
138
- remote_files = [self.info_filename]
139
- if not self.separate_dirs:
140
- # No separate dirs => means all data is in the same CSV file => download it to get its current content
141
- remote_files.append("data.csv")
142
-
143
- for filename in remote_files:
144
- try:
145
- huggingface_hub.hf_hub_download(
146
- repo_id=self.dataset_id,
147
- repo_type="dataset",
148
- filename=filename,
149
- local_dir=self.dataset_dir,
150
- token=self.hf_token,
151
- )
152
- except huggingface_hub.utils.EntryNotFoundError:
153
- pass
154
 
 
155
  def flag(
156
  self,
157
  flag_data: list[Any],
@@ -188,93 +53,6 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
188
  username=username or "",
189
  )
190
 
191
- def _flag_in_dir(
192
- self,
193
- data_file: Path,
194
- components_dir: Path,
195
- path_in_repo: str | None,
196
- flag_data: list[Any],
197
- flag_option: str = "",
198
- username: str = "",
199
- ) -> int:
200
- # Deserialize components (write images/audio to files)
201
- features, row = self._deserialize_components(
202
- components_dir, flag_data, flag_option, username
203
- )
204
-
205
- # Write generic info to dataset_infos.json + upload
206
- with filelock.FileLock(str(self.infos_file) + ".lock"):
207
- if not self.infos_file.exists():
208
- self.infos_file.write_text(
209
- json.dumps({"flagged": {"features": features}})
210
- )
211
-
212
- huggingface_hub.upload_file(
213
- repo_id=self.dataset_id,
214
- repo_type="dataset",
215
- token=self.hf_token,
216
- path_in_repo=self.infos_file.name,
217
- path_or_fileobj=self.infos_file,
218
- )
219
-
220
- headers = list(features.keys())
221
-
222
- if not self.separate_dirs:
223
- with filelock.FileLock(components_dir / ".lock"):
224
- sample_nb = self._save_as_csv(data_file, headers=headers, row=row)
225
- sample_name = str(sample_nb)
226
- huggingface_hub.upload_folder(
227
- repo_id=self.dataset_id,
228
- repo_type="dataset",
229
- commit_message=f"Flagged sample #{sample_name}",
230
- path_in_repo=path_in_repo,
231
- ignore_patterns="*.lock",
232
- folder_path=components_dir,
233
- token=self.hf_token,
234
- )
235
- else:
236
- sample_name = self._save_as_jsonl(data_file, headers=headers, row=row)
237
- sample_nb = len(
238
- [path for path in self.dataset_dir.iterdir() if path.is_dir()]
239
- )
240
- huggingface_hub.upload_folder(
241
- repo_id=self.dataset_id,
242
- repo_type="dataset",
243
- commit_message=f"Flagged sample #{sample_name}",
244
- path_in_repo=path_in_repo,
245
- ignore_patterns="*.lock",
246
- folder_path=components_dir,
247
- token=self.hf_token,
248
- )
249
-
250
- return sample_nb
251
-
252
- @staticmethod
253
- def _save_as_csv(data_file: Path, headers: list[str], row: list[Any]) -> int:
254
- """Save data as CSV and return the sample name (row number)."""
255
- is_new = not data_file.exists()
256
-
257
- with data_file.open("a", newline="", encoding="utf-8") as csvfile:
258
- writer = csv.writer(csvfile)
259
-
260
- # Write CSV headers if new file
261
- if is_new:
262
- writer.writerow(utils.sanitize_list_for_csv(headers))
263
-
264
- # Write CSV row for flagged sample
265
- writer.writerow(utils.sanitize_list_for_csv(row))
266
-
267
- with data_file.open(encoding="utf-8") as csvfile:
268
- return sum(1 for _ in csv.reader(csvfile)) - 1
269
-
270
- @staticmethod
271
- def _save_as_jsonl(data_file: Path, headers: list[str], row: list[Any]) -> str:
272
- """Save data as JSONL and return the sample name (uuid)."""
273
- Path.mkdir(data_file.parent, parents=True, exist_ok=True)
274
- with open(data_file, "w") as f:
275
- json.dump(dict(zip(headers, row)), f)
276
- return data_file.parent.name
277
-
278
  def _deserialize_components(
279
  self,
280
  data_dir: Path,
 
1
  from __future__ import annotations
2
 
3
+ import datetime
4
  import json
5
  import time
6
  import uuid
 
7
  from collections import OrderedDict
8
  from datetime import datetime, timezone
9
  from pathlib import Path
10
+ from typing import Any
11
 
12
+ import gradio
13
+ import gradio as gr
14
  import huggingface_hub
15
+ from gradio import FlaggingCallback
16
  from gradio_client import utils as client_utils
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ class HuggingFaceDatasetSaver(gradio.HuggingFaceDatasetSaver):
20
  def flag(
21
  self,
22
  flag_data: list[Any],
 
53
  username=username or "",
54
  )
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def _deserialize_components(
57
  self,
58
  data_dir: Path,