Yurii Paniv commited on
Commit
a1d768d
1 Parent(s): 312c1fb

Fix data logger

Browse files
Files changed (2) hide show
  1. app.py +20 -12
  2. data_logger.py +41 -0
app.py CHANGED
@@ -4,7 +4,23 @@ from datetime import datetime
4
  from enum import Enum
5
  from ukrainian_tts.tts import TTS, Stress, Voices
6
  from torch.cuda import is_available
7
- from os import environ
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  class StressOption(Enum):
10
  AutomaticStress = "Автоматичні наголоси (за словником) 📖"
@@ -53,18 +69,13 @@ def tts(text: str, voice: str, stress: str):
53
  text if len(text) < text_limit else text[0:text_limit]
54
  ) # mitigate crashes on hf space
55
 
 
 
56
 
57
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
58
  _, text = ukr_tts.tts(text, speaker_name, stress_selected, fp)
59
  return fp.name, text
60
 
61
- if environ["HF_API_TOKEN"] is None:
62
- print("Using default flagging.")
63
- flagging_callback = gr.CSVLogger()
64
- else:
65
- print("Using HuggingFace dataset saver.")
66
- flagging_callback = gr.HuggingFaceDatasetSaver(hf_token=environ["HF_API_TOKEN"], dataset_name="uk-tts-output", private=True)
67
-
68
 
69
  with open("README.md") as file:
70
  article = file.read()
@@ -122,9 +133,6 @@ iface = gr.Interface(
122
  VoiceOption.Lada.value,
123
  StressOption.AutomaticStress.value,
124
  ],
125
- ],
126
- allow_flagging="auto",
127
- flagging_callback=flagging_callback,
128
- flagging_options=None
129
  )
130
  iface.launch(enable_queue=True)
4
  from enum import Enum
5
  from ukrainian_tts.tts import TTS, Stress, Voices
6
  from torch.cuda import is_available
7
+ from os import getenv
8
+ from data_logger import log_data
9
+ from threading import Thread
10
+ from queue import Queue
11
+
12
+
13
+ def check_thread(logging_queue: Queue):
14
+ logging_callback = log_data(hf_token=getenv("HF_API_TOKEN"), dataset_name="uk-tts-output", private=True)
15
+ while True:
16
+ item = logging_queue.get()
17
+ logging_callback(item)
18
+
19
+ if getenv("HF_API_TOKEN") is not None:
20
+ log_queue = Queue()
21
+ t = Thread(target=check_thread, args=(log_queue,))
22
+ t.start()
23
+
24
 
25
  class StressOption(Enum):
26
  AutomaticStress = "Автоматичні наголоси (за словником) 📖"
69
  text if len(text) < text_limit else text[0:text_limit]
70
  ) # mitigate crashes on hf space
71
 
72
+ if getenv("HF_API_TOKEN") is not None:
73
+ log_queue.put([text, speaker_name, stress_selected, str(datetime.utcnow())])
74
 
75
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
76
  _, text = ukr_tts.tts(text, speaker_name, stress_selected, fp)
77
  return fp.name, text
78
 
 
 
 
 
 
 
 
79
 
80
  with open("README.md") as file:
81
  article = file.read()
133
  VoiceOption.Lada.value,
134
  StressOption.AutomaticStress.value,
135
  ],
136
+ ]
 
 
 
137
  )
138
  iface.launch(enable_queue=True)
data_logger.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gradio import utils
2
+ import os
3
+ import csv
4
+ import huggingface_hub
5
+
6
+ def log_data(hf_token: str, dataset_name: str, private=True):
7
+ path_to_dataset_repo = huggingface_hub.create_repo(
8
+ name=dataset_name,
9
+ token=hf_token,
10
+ private=private,
11
+ repo_type="dataset",
12
+ exist_ok=True,
13
+ )
14
+ flagging_dir = "flagged"
15
+ dataset_dir = os.path.join(flagging_dir, dataset_name)
16
+ repo = huggingface_hub.Repository(
17
+ local_dir=dataset_dir,
18
+ clone_from=path_to_dataset_repo,
19
+ use_auth_token=hf_token,
20
+ )
21
+ repo.git_pull(lfs=True)
22
+ log_file = os.path.join(dataset_dir, "data.csv")
23
+
24
+ def log_function(data):
25
+ repo.git_pull(lfs=True)
26
+
27
+ with open(log_file, "a", newline="", encoding="utf-8") as csvfile:
28
+ writer = csv.writer(csvfile)
29
+
30
+ csv_data = data
31
+ writer.writerow(utils.sanitize_list_for_csv(csv_data))
32
+
33
+ with open(log_file, "r", encoding="utf-8") as csvfile:
34
+ line_count = len([None for row in csv.reader(csvfile)]) - 1
35
+
36
+ repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count))
37
+
38
+ return line_count
39
+
40
+ return log_function
41
+