ms180 committed on
Commit
dbdb417
1 Parent(s): 1efbfe3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +268 -268
app.py CHANGED
@@ -1,268 +1,268 @@
1
- import glob
2
- import os
3
- import shutil
4
- import sys
5
- import re
6
- import tempfile
7
- import zipfile
8
- from pathlib import Path
9
-
10
- import gradio as gr
11
-
12
- from finetune import finetune_model, baseline_model
13
-
14
- from language import languages
15
- from task import tasks
16
- import matplotlib.pyplot as plt
17
-
18
-
19
- os.environ['TEMP_DIR'] = tempfile.mkdtemp()
20
-
21
- def load_markdown():
22
- with open("intro.md", "r") as f:
23
- return f.read()
24
-
25
-
26
- def read_logs():
27
- try:
28
- with open(f"output.log", "r") as f:
29
- return f.read()
30
- except:
31
- return None
32
-
33
-
34
- def plot_loss_acc(temp_dir, log_every):
35
- sys.stdout.flush()
36
- lines = []
37
- with open("output.log", "r") as f:
38
- for line in f.readlines():
39
- if re.match(r"^\[\d+\] - loss: \d+\.\d+ - acc: \d+\.\d+$", line):
40
- lines.append(line)
41
-
42
- losses = []
43
- acces = []
44
- if len(lines) == 0:
45
- return None, None
46
-
47
- for line in lines:
48
- _, loss, acc = line.split(" - ")
49
- losses.append(float(loss.split(":")[1].strip()))
50
- acces.append(float(acc.split(":")[1].strip()))
51
-
52
- x = [i * log_every for i in range(1, len(losses) + 1)]
53
-
54
- plt.plot(x, losses, label="loss")
55
- plt.xlim(log_every // 2, x[-1] + log_every // 2)
56
- plt.savefig(f"{temp_dir}/loss.png")
57
- plt.clf()
58
- plt.plot(x, acces, label="acc")
59
- plt.xlim(log_every // 2, x[-1] + log_every // 2)
60
- plt.savefig(f"{temp_dir}/acc.png")
61
- plt.clf()
62
- return f"{temp_dir}/acc.png", f"{temp_dir}/loss.png"
63
-
64
-
65
- def upload_file(fileobj, temp_dir):
66
- """
67
- Upload a file and check the uploaded zip file.
68
- """
69
- # First check if a file is a zip file.
70
- if not zipfile.is_zipfile(fileobj.name):
71
- raise gr.Error("Please upload a zip file.")
72
-
73
- # Then unzip file
74
- shutil.unpack_archive(fileobj.name, temp_dir)
75
-
76
- # check zip file
77
- if not os.path.exists(os.path.join(temp_dir, "text")):
78
- raise gr.Error("Please upload a valid zip file.")
79
-
80
- if not os.path.exists(os.path.join(temp_dir, "text_ctc")):
81
- raise gr.Error("Please upload a valid zip file.")
82
-
83
- if not os.path.exists(os.path.join(temp_dir, "audio")):
84
- raise gr.Error("Please upload a valid zip file.")
85
-
86
- # check if all texts and audio matches
87
- audio_ids = []
88
- with open(os.path.join(temp_dir, "text"), "r") as f:
89
- for line in f.readlines():
90
- audio_ids.append(line.split(maxsplit=1)[0])
91
-
92
- with open(os.path.join(temp_dir, "text_ctc"), "r") as f:
93
- ctc_audio_ids = []
94
- for line in f.readlines():
95
- ctc_audio_ids.append(line.split(maxsplit=1)[0])
96
-
97
- if len(audio_ids) != len(ctc_audio_ids):
98
- raise gr.Error(
99
- f"Length of `text` ({len(audio_ids)}) and `text_ctc` ({len(ctc_audio_ids)}) is different."
100
- )
101
-
102
- if set(audio_ids) != set(ctc_audio_ids):
103
- raise gr.Error(f"`text` and `text_ctc` have different audio ids.")
104
-
105
- for audio_id in glob.glob(os.path.join(temp_dir, "audio", "*")):
106
- if not Path(audio_id).stem in audio_ids:
107
- raise gr.Error(f"Audio id {audio_id} is not in `text` or `text_ctc`.")
108
-
109
- gr.Info("Successfully uploaded and validated zip file.")
110
-
111
- return [fileobj]
112
-
113
-
114
- with gr.Blocks(title="OWSM-finetune") as demo:
115
- tempdir_path = gr.State(os.environ['TEMP_DIR'])
116
- gr.Markdown(
117
- """# OWSM finetune demo!
118
-
119
- Finetune `owsm_v3.1_ebf_base` with your own dataset!
120
- Due to resource limitation, you can only train 50 epochs on maximum.
121
-
122
- ## Upload dataset and define settings
123
- """
124
- )
125
-
126
- # main contents
127
- with gr.Row():
128
- with gr.Column():
129
- file_output = gr.File()
130
- upload_button = gr.UploadButton("Click to Upload a File", file_count="single")
131
- upload_button.upload(
132
- upload_file, [upload_button, tempdir_path], [file_output]
133
- )
134
-
135
- with gr.Column():
136
- lang = gr.Dropdown(
137
- languages["espnet/owsm_v3.1_ebf_base"],
138
- label="Language",
139
- info="Choose language!",
140
- value="jpn",
141
- interactive=True,
142
- )
143
- task = gr.Dropdown(
144
- tasks["espnet/owsm_v3.1_ebf_base"],
145
- label="Task",
146
- info="Choose task!",
147
- value="asr",
148
- interactive=True,
149
- )
150
-
151
- gr.Markdown("## Set training settings")
152
-
153
- with gr.Row():
154
- with gr.Column():
155
- log_every = gr.Number(value=10, label="log_every", interactive=True)
156
- max_epoch = gr.Slider(1, 10, step=1, label="max_epoch", interactive=True)
157
- scheduler = gr.Dropdown(
158
- ["warmuplr"], label="warmup", value="warmuplr", interactive=True
159
- )
160
- warmup_steps = gr.Number(
161
- value=100, label="warmup_steps", interactive=True
162
- )
163
-
164
- with gr.Column():
165
- optimizer = gr.Dropdown(
166
- ["adam", "adamw", "sgd", "adadelta", "adagrad", "adamax", "asgd", "rmsprop"],
167
- label="optimizer",
168
- value="adam",
169
- interactive=True
170
- )
171
- learning_rate = gr.Number(
172
- value=1e-4, label="learning_rate", interactive=True
173
- )
174
- weight_decay = gr.Number(
175
- value=0.000001, label="weight_decay", interactive=True
176
- )
177
-
178
- gr.Markdown("## Logs and plots")
179
-
180
- with gr.Row():
181
- with gr.Column():
182
- log_output = gr.Textbox(
183
- show_label=False,
184
- interactive=False,
185
- max_lines=23,
186
- lines=23,
187
- )
188
- demo.load(read_logs, None, log_output, every=2)
189
-
190
- with gr.Column():
191
- log_acc = gr.Image(label="Accuracy", show_label=True, interactive=False)
192
- log_loss = gr.Image(label="Loss", show_label=True, interactive=False)
193
- demo.load(plot_loss_acc, [tempdir_path, log_every], [log_acc, log_loss], every=10)
194
-
195
- with gr.Row():
196
- with gr.Column():
197
- ref_text = gr.Textbox(
198
- label="Reference text",
199
- show_label=True,
200
- interactive=False,
201
- max_lines=10,
202
- lines=10,
203
- )
204
- with gr.Column():
205
- base_text = gr.Textbox(
206
- label="Baseline text",
207
- show_label=True,
208
- interactive=False,
209
- max_lines=10,
210
- lines=10,
211
- )
212
-
213
- with gr.Row():
214
- with gr.Column():
215
- hyp_text = gr.Textbox(
216
- label="Hypothesis text",
217
- show_label=True,
218
- interactive=False,
219
- max_lines=10,
220
- lines=10,
221
- )
222
- with gr.Column():
223
- trained_model = gr.File(
224
- label="Trained model",
225
- interactive=False,
226
- )
227
-
228
- with gr.Row():
229
- with gr.Column():
230
- baseline_btn = gr.Button("Run Baseline", variant="secondary")
231
- baseline_btn.click(
232
- baseline_model,
233
- [
234
- lang,
235
- task,
236
- tempdir_path,
237
- ],
238
- [ref_text, base_text]
239
- )
240
- with gr.Column():
241
- finetune_btn = gr.Button("Finetune Model", variant="primary")
242
- finetune_btn.click(
243
- finetune_model,
244
- [
245
- lang,
246
- task,
247
- tempdir_path,
248
- log_every,
249
- max_epoch,
250
- scheduler,
251
- warmup_steps,
252
- optimizer,
253
- learning_rate,
254
- weight_decay,
255
- ],
256
- [trained_model, hyp_text]
257
- )
258
-
259
- gr.Markdown(load_markdown())
260
-
261
- if __name__ == "__main__":
262
- try:
263
- demo.queue().launch()
264
- except:
265
- print("Unexpected error:", sys.exc_info()[0])
266
- raise
267
- finally:
268
- shutil.rmtree(os.environ['TEMP_DIR'])
 
1
+ import glob
2
+ import os
3
+ import shutil
4
+ import sys
5
+ import re
6
+ import tempfile
7
+ import zipfile
8
+ from pathlib import Path
9
+
10
+ import gradio as gr
11
+
12
+ from finetune import finetune_model, baseline_model
13
+
14
+ from language import languages
15
+ from task import tasks
16
+ import matplotlib.pyplot as plt
17
+
18
+
19
# Create one scratch directory at import time and publish its path through the
# TEMP_DIR environment variable, where the rest of this module reads it back.
os.environ["TEMP_DIR"] = tempfile.mkdtemp()
20
+
21
def load_markdown():
    """Return the text of ``intro.md`` from the current working directory.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the locale of the machine running the app.

    Returns:
        str: The complete contents of ``intro.md``.

    Raises:
        OSError: If ``intro.md`` cannot be opened.
    """
    with open("intro.md", "r", encoding="utf-8") as f:
        return f.read()
24
+
25
+
26
def read_logs():
    """Return the current contents of ``output.log``, or ``None``.

    The log is looked up relative to the current working directory.

    Returns:
        str | None: The whole log text, or ``None`` when the file cannot be
        opened or decoded (for example because it has not been created yet).
    """
    try:
        with open("output.log", "r") as f:
            return f.read()
    except (OSError, UnicodeError):
        # Best effort: an unreadable log yields "no content" rather than an
        # error. Only I/O and decoding failures are swallowed, so interrupts
        # and programming errors still propagate.
        return None
32
+
33
+
34
def plot_loss_acc(temp_dir, log_every):
    """Render loss and accuracy curves parsed from ``output.log``.

    Only lines of the exact form ``[<int>] - loss: <float> - acc: <float>``
    are used. The k-th matching line (1-based) is plotted at
    ``x = k * log_every``.

    Args:
        temp_dir: Directory that receives ``loss.png`` and ``acc.png``.
        log_every: Spacing between consecutive logged points on the x axis.

    Returns:
        tuple: ``(acc_png_path, loss_png_path)``, or ``(None, None)`` when the
        log is missing, unreadable, or holds no matching line yet.
    """
    # Flush buffered stdout before reading the log.
    # NOTE(review): presumably stdout is what feeds output.log -- confirm.
    sys.stdout.flush()

    metric_line = re.compile(r"^\[\d+\] - loss: (\d+\.\d+) - acc: (\d+\.\d+)$")
    losses = []
    acces = []
    try:
        with open("output.log", "r") as f:
            for line in f:
                found = metric_line.match(line)
                if found:
                    losses.append(float(found.group(1)))
                    acces.append(float(found.group(2)))
    except (OSError, UnicodeError):
        # The log may not exist yet; treat that as "nothing to plot", the
        # same way read_logs tolerates a missing file.
        return None, None

    if not losses:
        return None, None

    x = [i * log_every for i in range(1, len(losses) + 1)]
    # Half a logging interval of padding on both sides of the data.
    x_min = log_every // 2
    x_max = x[-1] + log_every // 2

    loss_path = f"{temp_dir}/loss.png"
    acc_path = f"{temp_dir}/acc.png"

    plt.plot(x, losses, label="loss")
    plt.xlim(x_min, x_max)
    plt.savefig(loss_path)
    plt.clf()  # reset the shared pyplot figure before drawing the next curve

    plt.plot(x, acces, label="acc")
    plt.xlim(x_min, x_max)
    plt.savefig(acc_path)
    plt.clf()

    return acc_path, loss_path
63
+
64
+
65
def _read_utterance_ids(path):
    """Return the first whitespace-delimited token of each non-blank line of *path*."""
    ids = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            fields = line.split(maxsplit=1)
            # Blank lines (e.g. a stray newline at the end) carry no id.
            if fields:
                ids.append(fields[0])
    return ids


def upload_file(fileobj, temp_dir):
    """Unpack an uploaded zip archive into *temp_dir* and validate its layout.

    The archive must contain the entries ``text``, ``text_ctc`` and ``audio``.
    The first token of every line in ``text`` and ``text_ctc`` is taken as an
    audio id. Both files must list the same number of ids and the same set of
    ids, and the stem of every entry under ``audio/`` must be one of them.

    Args:
        fileobj: Uploaded file handle; its ``name`` attribute is used as the
            path of the archive on disk.
        temp_dir: Directory the archive is extracted into.

    Returns:
        list: ``[fileobj]``.

    Raises:
        gr.Error: If the upload is not a zip file or any check above fails.
    """
    archive_path = fileobj.name

    if not zipfile.is_zipfile(archive_path):
        raise gr.Error("Please upload a zip file.")

    # The content was just verified to be a zip archive, so name the format
    # explicitly instead of letting shutil guess it from the file extension.
    shutil.unpack_archive(archive_path, temp_dir, format="zip")

    for required in ("text", "text_ctc", "audio"):
        if not os.path.exists(os.path.join(temp_dir, required)):
            raise gr.Error("Please upload a valid zip file.")

    audio_ids = _read_utterance_ids(os.path.join(temp_dir, "text"))
    ctc_audio_ids = _read_utterance_ids(os.path.join(temp_dir, "text_ctc"))

    if len(audio_ids) != len(ctc_audio_ids):
        raise gr.Error(
            f"Length of `text` ({len(audio_ids)}) and `text_ctc` ({len(ctc_audio_ids)}) is different."
        )

    # A set gives constant-time membership tests in the audio loop below.
    known_ids = set(audio_ids)
    if known_ids != set(ctc_audio_ids):
        raise gr.Error("`text` and `text_ctc` have different audio ids.")

    for audio_path in glob.glob(os.path.join(temp_dir, "audio", "*")):
        if Path(audio_path).stem not in known_ids:
            raise gr.Error(f"Audio id {audio_path} is not in `text` or `text_ctc`.")

    gr.Info("Successfully uploaded and validated zip file.")

    return [fileobj]
112
+
113
+
114
# NOTE(review): the Row/Column nesting below was reconstructed from a flattened
# extract of this file -- confirm it against the rendered layout.
with gr.Blocks(title="OWSM-finetune") as demo:
    # State component seeded with the TEMP_DIR path; it is handed to the
    # callbacks below as their temp-directory argument.
    tempdir_path = gr.State(os.environ["TEMP_DIR"])
    gr.Markdown(
        """# OWSM finetune demo!

Finetune `owsm_v3.1_ebf_base` with your own dataset!
Due to resource limitation, you can only train 10 epochs on maximum.

## Upload dataset and define settings
"""
    )

    # --- Dataset upload and language/task selection -----------------------
    with gr.Row():
        with gr.Column():
            file_output = gr.File()
            upload_button = gr.UploadButton(
                "Click to Upload a File", file_count="single"
            )
            upload_button.upload(
                fn=upload_file,
                inputs=[upload_button, tempdir_path],
                outputs=[file_output],
            )

        with gr.Column():
            lang = gr.Dropdown(
                choices=languages["espnet/owsm_v3.1_ebf_base"],
                value="jpn",
                label="Language",
                info="Choose language!",
                interactive=True,
            )
            task = gr.Dropdown(
                choices=tasks["espnet/owsm_v3.1_ebf_base"],
                value="asr",
                label="Task",
                info="Choose task!",
                interactive=True,
            )

    gr.Markdown("## Set training settings")

    # --- Training hyper-parameters ----------------------------------------
    with gr.Row():
        with gr.Column():
            log_every = gr.Number(value=10, label="log_every", interactive=True)
            max_epoch = gr.Slider(
                minimum=1, maximum=10, step=1, label="max_epoch", interactive=True
            )
            scheduler = gr.Dropdown(
                choices=["warmuplr"],
                value="warmuplr",
                label="warmup",
                interactive=True,
            )
            warmup_steps = gr.Number(value=100, label="warmup_steps", interactive=True)

        with gr.Column():
            optimizer = gr.Dropdown(
                choices=[
                    "adam",
                    "adamw",
                    "sgd",
                    "adadelta",
                    "adagrad",
                    "adamax",
                    "asgd",
                    "rmsprop",
                ],
                value="adam",
                label="optimizer",
                interactive=True,
            )
            learning_rate = gr.Number(
                value=1e-4, label="learning_rate", interactive=True
            )
            weight_decay = gr.Number(
                value=0.000001, label="weight_decay", interactive=True
            )

    gr.Markdown("## Logs and plots")

    # --- Log text and loss/accuracy plots ----------------------------------
    with gr.Row():
        with gr.Column():
            log_output = gr.Textbox(
                lines=23,
                max_lines=23,
                show_label=False,
                interactive=False,
            )
            # Registered with every=2 to refresh the log text.
            demo.load(fn=read_logs, inputs=None, outputs=log_output, every=2)

        with gr.Column():
            log_acc = gr.Image(label="Accuracy", show_label=True, interactive=False)
            log_loss = gr.Image(label="Loss", show_label=True, interactive=False)
            # Registered with every=10 to redraw both curves.
            demo.load(
                fn=plot_loss_acc,
                inputs=[tempdir_path, log_every],
                outputs=[log_acc, log_loss],
                every=10,
            )

    # --- Reference vs. baseline text ---------------------------------------
    with gr.Row():
        with gr.Column():
            ref_text = gr.Textbox(
                label="Reference text",
                show_label=True,
                lines=10,
                max_lines=10,
                interactive=False,
            )
        with gr.Column():
            base_text = gr.Textbox(
                label="Baseline text",
                show_label=True,
                lines=10,
                max_lines=10,
                interactive=False,
            )

    # --- Fine-tuned hypothesis and model file ------------------------------
    with gr.Row():
        with gr.Column():
            hyp_text = gr.Textbox(
                label="Hypothesis text",
                show_label=True,
                lines=10,
                max_lines=10,
                interactive=False,
            )
        with gr.Column():
            trained_model = gr.File(label="Trained model", interactive=False)

    # --- Action buttons ----------------------------------------------------
    with gr.Row():
        with gr.Column():
            baseline_btn = gr.Button("Run Baseline", variant="secondary")
            baseline_btn.click(
                fn=baseline_model,
                inputs=[lang, task, tempdir_path],
                outputs=[ref_text, base_text],
            )
        with gr.Column():
            finetune_btn = gr.Button("Finetune Model", variant="primary")
            finetune_btn.click(
                fn=finetune_model,
                inputs=[
                    lang,
                    task,
                    tempdir_path,
                    log_every,
                    max_epoch,
                    scheduler,
                    warmup_steps,
                    optimizer,
                    learning_rate,
                    weight_decay,
                ],
                outputs=[trained_model, hyp_text],
            )

    # The contents of intro.md are rendered at the bottom of the page.
    gr.Markdown(load_markdown())
260
+
261
if __name__ == "__main__":
    # Serve the app with queuing enabled; always remove the scratch directory
    # named by TEMP_DIR on the way out.
    try:
        demo.queue().launch()
    except Exception:
        # Report the exception type, then let it propagate unchanged.
        # KeyboardInterrupt/SystemExit are deliberately not caught, so a normal
        # shutdown is not reported as an "Unexpected error".
        print("Unexpected error:", sys.exc_info()[0])
        raise
    finally:
        # ignore_errors keeps a cleanup failure (e.g. directory already gone)
        # from masking an exception raised by launch().
        shutil.rmtree(os.environ["TEMP_DIR"], ignore_errors=True)