quantumiracle-git commited on
Commit
973ca84
·
1 Parent(s): 348b569

Upload hfserver.py

Browse files
Files changed (1) hide show
  1. hfserver.py +342 -0
hfserver.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import csv
4
+ import datetime
5
+ import io
6
+ import json
7
+ import os
8
+ from abc import ABC, abstractmethod
9
+ from typing import TYPE_CHECKING, Any, List, Optional
10
+
11
+ import gradio as gr
12
+ from gradio import encryptor, utils
13
+
14
+ if TYPE_CHECKING:
15
+ from gradio.components import Component
16
+
17
+
18
+ class FlaggingCallback(ABC):
19
+ """
20
+ An abstract class for defining the methods that any FlaggingCallback should have.
21
+ """
22
+
23
+ @abstractmethod
24
+ def setup(self, components: List[Component], flagging_dir: str):
25
+ """
26
+ This method should be overridden and ensure that everything is set up correctly for flag().
27
+ This method gets called once at the beginning of the Interface.launch() method.
28
+ Parameters:
29
+ components: Set of components that will provide flagged data.
30
+ flagging_dir: A string, typically containing the path to the directory where the flagging file should be storied (provided as an argument to Interface.__init__()).
31
+ """
32
+ pass
33
+
34
+ @abstractmethod
35
+ def flag(
36
+ self,
37
+ flag_data: List[Any],
38
+ flag_option: Optional[str] = None,
39
+ flag_index: Optional[int] = None,
40
+ username: Optional[str] = None,
41
+ ) -> int:
42
+ """
43
+ This method should be overridden by the FlaggingCallback subclass and may contain optional additional arguments.
44
+ This gets called every time the <flag> button is pressed.
45
+ Parameters:
46
+ interface: The Interface object that is being used to launch the flagging interface.
47
+ flag_data: The data to be flagged.
48
+ flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
49
+ flag_index (optional): The index of the sample that is being flagged.
50
+ username (optional): The username of the user that is flagging the data, if logged in.
51
+ Returns:
52
+ (int) The total number of samples that have been flagged.
53
+ """
54
+ pass
55
+
56
+
57
+ class SimpleCSVLogger(FlaggingCallback):
58
+ """
59
+ A simple example implementation of the FlaggingCallback abstract class
60
+ provided for illustrative purposes.
61
+ """
62
+
63
+ def setup(self, components: List[Component], flagging_dir: str):
64
+ self.components = components
65
+ self.flagging_dir = flagging_dir
66
+ os.makedirs(flagging_dir, exist_ok=True)
67
+
68
+ def flag(
69
+ self,
70
+ flag_data: List[Any],
71
+ flag_option: Optional[str] = None,
72
+ flag_index: Optional[int] = None,
73
+ username: Optional[str] = None,
74
+ ) -> int:
75
+ flagging_dir = self.flagging_dir
76
+ log_filepath = os.path.join(flagging_dir, "log.csv")
77
+
78
+ csv_data = []
79
+ for component, sample in zip(self.components, flag_data):
80
+ csv_data.append(
81
+ component.save_flagged(
82
+ flagging_dir,
83
+ component.label,
84
+ sample,
85
+ None,
86
+ )
87
+ )
88
+
89
+ with open(log_filepath, "a", newline="") as csvfile:
90
+ writer = csv.writer(csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
91
+ writer.writerow(csv_data)
92
+
93
+ with open(log_filepath, "r") as csvfile:
94
+ line_count = len([None for row in csv.reader(csvfile)]) - 1
95
+ return line_count
96
+
97
+
98
+ class CSVLogger(FlaggingCallback):
99
+ """
100
+ The default implementation of the FlaggingCallback abstract class.
101
+ Logs the input and output data to a CSV file. Supports encryption.
102
+ """
103
+
104
+ def setup(
105
+ self,
106
+ components: List[Component],
107
+ flagging_dir: str,
108
+ encryption_key: Optional[str] = None,
109
+ ):
110
+ self.components = components
111
+ self.flagging_dir = flagging_dir
112
+ self.encryption_key = encryption_key
113
+ os.makedirs(flagging_dir, exist_ok=True)
114
+
115
+ def flag(
116
+ self,
117
+ flag_data: List[Any],
118
+ flag_option: Optional[str] = None,
119
+ flag_index: Optional[int] = None,
120
+ username: Optional[str] = None,
121
+ ) -> int:
122
+ flagging_dir = self.flagging_dir
123
+ log_filepath = os.path.join(flagging_dir, "log.csv")
124
+ is_new = not os.path.exists(log_filepath)
125
+
126
+ if flag_index is None:
127
+ csv_data = []
128
+ for component, sample in zip(self.components, flag_data):
129
+ csv_data.append(
130
+ component.save_flagged(
131
+ flagging_dir,
132
+ component.label,
133
+ sample,
134
+ self.encryption_key,
135
+ )
136
+ if sample is not None
137
+ else ""
138
+ )
139
+ csv_data.append(flag_option if flag_option is not None else "")
140
+ csv_data.append(username if username is not None else "")
141
+ csv_data.append(str(datetime.datetime.now()))
142
+ if is_new:
143
+ headers = [component.label for component in self.components] + [
144
+ "flag",
145
+ "username",
146
+ "timestamp",
147
+ ]
148
+
149
+ def replace_flag_at_index(file_content):
150
+ file_content = io.StringIO(file_content)
151
+ content = list(csv.reader(file_content))
152
+ header = content[0]
153
+ flag_col_index = header.index("flag")
154
+ content[flag_index][flag_col_index] = flag_option
155
+ output = io.StringIO()
156
+ writer = csv.writer(output, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
157
+ writer.writerows(content)
158
+ return output.getvalue()
159
+
160
+ if self.encryption_key:
161
+ output = io.StringIO()
162
+ if not is_new:
163
+ with open(log_filepath, "rb") as csvfile:
164
+ encrypted_csv = csvfile.read()
165
+ decrypted_csv = encryptor.decrypt(
166
+ self.encryption_key, encrypted_csv
167
+ )
168
+ file_content = decrypted_csv.decode()
169
+ if flag_index is not None:
170
+ file_content = replace_flag_at_index(file_content)
171
+ output.write(file_content)
172
+ writer = csv.writer(output, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
173
+ if flag_index is None:
174
+ if is_new:
175
+ writer.writerow(headers)
176
+ writer.writerow(csv_data)
177
+ with open(log_filepath, "wb") as csvfile:
178
+ csvfile.write(
179
+ encryptor.encrypt(self.encryption_key, output.getvalue().encode())
180
+ )
181
+ else:
182
+ if flag_index is None:
183
+ with open(log_filepath, "a", newline="") as csvfile:
184
+ writer = csv.writer(
185
+ csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'"
186
+ )
187
+ if is_new:
188
+ writer.writerow(headers)
189
+ writer.writerow(csv_data)
190
+ else:
191
+ with open(log_filepath) as csvfile:
192
+ file_content = csvfile.read()
193
+ file_content = replace_flag_at_index(file_content)
194
+ with open(
195
+ log_filepath, "w", newline=""
196
+ ) as csvfile: # newline parameter needed for Windows
197
+ csvfile.write(file_content)
198
+ with open(log_filepath, "r") as csvfile:
199
+ line_count = len([None for row in csv.reader(csvfile)]) - 1
200
+ return line_count
201
+
202
+
203
+ class HuggingFaceDatasetSaver(FlaggingCallback):
204
+ """
205
+ A FlaggingCallback that saves flagged data to a HuggingFace dataset.
206
+ """
207
+
208
+ def __init__(
209
+ self,
210
+ hf_foken: str,
211
+ dataset_name: str,
212
+ organization: Optional[str] = None,
213
+ private: bool = False,
214
+ verbose: bool = True,
215
+ ):
216
+ """
217
+ Params:
218
+ hf_token (str): The token to use to access the huggingface API.
219
+ dataset_name (str): The name of the dataset to save the data to, e.g.
220
+ "image-classifier-1"
221
+ organization (str): The name of the organization to which to attach
222
+ the datasets. If None, the dataset attaches to the user only.
223
+ private (bool): If the dataset does not already exist, whether it
224
+ should be created as a private dataset or public. Private datasets
225
+ may require paid huggingface.co accounts
226
+ verbose (bool): Whether to print out the status of the dataset
227
+ creation.
228
+ """
229
+ self.hf_foken = hf_foken
230
+ self.dataset_name = dataset_name
231
+ self.organization_name = organization
232
+ self.dataset_private = private
233
+ self.verbose = verbose
234
+
235
+ def setup(self, components: List[Component], flagging_dir: str):
236
+ """
237
+ Params:
238
+ flagging_dir (str): local directory where the dataset is cloned,
239
+ updated, and pushed from.
240
+ """
241
+ try:
242
+ import huggingface_hub
243
+ except (ImportError, ModuleNotFoundError):
244
+ raise ImportError(
245
+ "Package `huggingface_hub` not found is needed "
246
+ "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
247
+ )
248
+ path_to_dataset_repo = huggingface_hub.create_repo(
249
+ # name=self.dataset_name, https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/hf_api.py
250
+ repo_id=self.dataset_name,
251
+ token=self.hf_foken,
252
+ private=self.dataset_private,
253
+ repo_type="dataset",
254
+ exist_ok=True,
255
+ )
256
+ self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
257
+ self.components = components
258
+ self.flagging_dir = flagging_dir
259
+ self.dataset_dir = os.path.join(flagging_dir, self.dataset_name)
260
+ self.repo = huggingface_hub.Repository(
261
+ local_dir=self.dataset_dir,
262
+ clone_from=path_to_dataset_repo,
263
+ use_auth_token=self.hf_foken,
264
+ )
265
+ self.repo.git_pull()
266
+
267
+ # Should filename be user-specified?
268
+ self.log_file = os.path.join(self.dataset_dir, "data.csv")
269
+ self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json")
270
+
271
+ def flag(
272
+ self,
273
+ flag_data: List[Any],
274
+ flag_option: Optional[str] = None,
275
+ flag_index: Optional[int] = None,
276
+ username: Optional[str] = None,
277
+ ) -> int:
278
+ is_new = not os.path.exists(self.log_file)
279
+ infos = {"flagged": {"features": {}}}
280
+
281
+ with open(self.log_file, "a", newline="") as csvfile:
282
+ writer = csv.writer(csvfile)
283
+
284
+ # File previews for certain input and output types
285
+ file_preview_types = {
286
+ gr.inputs.Audio: "Audio",
287
+ gr.outputs.Audio: "Audio",
288
+ gr.inputs.Image: "Image",
289
+ gr.outputs.Image: "Image",
290
+ }
291
+
292
+ # Generate the headers and dataset_infos
293
+ if is_new:
294
+ headers = []
295
+
296
+ for component, sample in zip(self.components, flag_data):
297
+ headers.append(component.label)
298
+ headers.append(component.label)
299
+ infos["flagged"]["features"][component.label] = {
300
+ "dtype": "string",
301
+ "_type": "Value",
302
+ }
303
+ if isinstance(component, tuple(file_preview_types)):
304
+ headers.append(component.label + " file")
305
+ for _component, _type in file_preview_types.items():
306
+ if isinstance(component, _component):
307
+ infos["flagged"]["features"][
308
+ component.label + " file"
309
+ ] = {"_type": _type}
310
+ break
311
+
312
+ headers.append("flag")
313
+ infos["flagged"]["features"]["flag"] = {
314
+ "dtype": "string",
315
+ "_type": "Value",
316
+ }
317
+
318
+ writer.writerow(headers)
319
+
320
+ # Generate the row corresponding to the flagged sample
321
+ csv_data = []
322
+ for component, sample in zip(self.components, flag_data):
323
+ filepath = component.save_flagged(
324
+ self.dataset_dir, component.label, sample, None
325
+ )
326
+ csv_data.append(filepath)
327
+ if isinstance(component, tuple(file_preview_types)):
328
+ csv_data.append(
329
+ "{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
330
+ )
331
+ csv_data.append(flag_option if flag_option is not None else "")
332
+ writer.writerow(csv_data)
333
+
334
+ if is_new:
335
+ json.dump(infos, open(self.infos_file, "w"))
336
+
337
+ with open(self.log_file, "r") as csvfile:
338
+ line_count = len([None for row in csv.reader(csvfile)]) - 1
339
+
340
+ self.repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count))
341
+
342
+ return line_count