Taejin committed on
Commit
bf24ae8
1 Parent(s): 30dd9ff

Adding files

Browse files

Signed-off-by: Taejin Park <tango4j@gmail.com>

__pycache__/app.cpython-310.pyc ADDED
Binary file (7.58 kB). View file
 
__pycache__/app.cpython-39.pyc ADDED
Binary file (7.96 kB). View file
 
__pycache__/app_new.cpython-310.pyc ADDED
Binary file (6.49 kB). View file
 
__pycache__/app_new.cpython-39.pyc ADDED
Binary file (7.45 kB). View file
 
__pycache__/content.cpython-310.pyc ADDED
Binary file (5.47 kB). View file
 
__pycache__/scorer.cpython-310.pyc ADDED
Binary file (1.94 kB). View file
 
app.py CHANGED
@@ -27,6 +27,8 @@ api = HfApi()
27
 
28
  YEAR_VERSION = "2024"
29
 
 
 
30
  def read_json_file(filepath):
31
  with open(filepath) as infile:
32
  data_dict = json.load(infile)
@@ -38,17 +40,17 @@ def save_json_file(filepath, data_dict):
38
 
39
  os.makedirs("scored", exist_ok=True)
40
 
41
- test_data_files = {"test": "contextual_test.csv"}
42
- test_dataset = load_dataset(TEST_DATASET, data_files=test_data_files , token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
43
 
44
- val_data_files = {"val": "contextual_val.csv"}
45
- val_dataset = load_dataset(VAL_DATASET, data_files=val_data_files , token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
46
 
47
- results_data_files = {"test": "contextual_test_results.csv", "val": "contextual_val_results.csv"}
48
- results = load_dataset(RESULTS_DATASET, data_files=results_data_files, token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
49
 
50
- contacts_data_files = {"contacts": "contacts.csv"}
51
- contact_infos = load_dataset(CONTACT_DATASET, data_files=contacts_data_files, token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
52
 
53
  def get_dataframe_from_results(results, split):
54
  df = results[split].to_pandas()
@@ -56,13 +58,13 @@ def get_dataframe_from_results(results, split):
56
  df = df.sort_values(by=["All"], ascending=False)
57
  return df
58
 
59
- test_dataset_dataframe = test_dataset["test"].to_pandas()
60
- val_dataset_dataframe = val_dataset["val"].to_pandas()
61
 
62
- contacts_dataframe = contact_infos["contacts"].to_pandas()
63
 
64
- val_results_dataframe = get_dataframe_from_results(results=results, split="val")
65
- test_results_dataframe = get_dataframe_from_results(results=results, split="test")
66
 
67
  def restart_space():
68
  api.restart_space(repo_id=LEADERBOARD_PATH, token=TOKEN)
 
27
 
28
  YEAR_VERSION = "2024"
29
 
30
+ results = {"dev": {"cpWER": 0, "WER": 0}}
31
+
32
  def read_json_file(filepath):
33
  with open(filepath) as infile:
34
  data_dict = json.load(infile)
 
40
 
41
  os.makedirs("scored", exist_ok=True)
42
 
43
+ # test_data_files = {"test": "contextual_test.csv"}
44
+ # test_dataset = load_dataset(TEST_DATASET, data_files=test_data_files , token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
45
 
46
+ # val_data_files = {"val": "contextual_val.csv"}
47
+ # val_dataset = load_dataset(VAL_DATASET, data_files=val_data_files , token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
48
 
49
+ # results_data_files = {"test": "contextual_test_results.csv", "val": "contextual_val_results.csv"}
50
+ # results = load_dataset(RESULTS_DATASET, data_files=results_data_files, token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
51
 
52
+ # contacts_data_files = {"contacts": "contacts.csv"}
53
+ # contact_infos = load_dataset(CONTACT_DATASET, data_files=contacts_data_files, token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
54
 
55
  def get_dataframe_from_results(results, split):
56
  df = results[split].to_pandas()
 
58
  df = df.sort_values(by=["All"], ascending=False)
59
  return df
60
 
61
+ # test_dataset_dataframe = test_dataset["test"].to_pandas()
62
+ # val_dataset_dataframe = val_dataset["val"].to_pandas()
63
 
64
+ # contacts_dataframe = contact_infos["contacts"].to_pandas()
65
 
66
+ # val_results_dataframe = get_dataframe_from_results(results=results, split="val")
67
+ # test_results_dataframe = get_dataframe_from_results(results=results, split="test")
68
 
69
  def restart_space():
70
  api.restart_space(repo_id=LEADERBOARD_PATH, token=TOKEN)
app_new.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import csv
4
+ import datetime
5
+ from email.utils import parseaddr
6
+
7
+ import gradio as gr
8
+ import pandas as pd
9
+ import numpy as np
10
+
11
+ from datasets import load_dataset
12
+ from apscheduler.schedulers.background import BackgroundScheduler
13
+ from huggingface_hub import HfApi
14
+
15
+ from scorer import instruction_scorer
16
+ from content import format_error, format_warning, format_log, TITLE, INTRODUCTION_TEXT, SUBMISSION_TEXT, CITATION_BUTTON_LABEL, CITATION_BUTTON_TEXT, model_hyperlink
17
+
18
+ TOKEN = os.environ.get("TOKEN", None)
19
+ # OWNER="ucla-contextual"
20
+ OWNER="Taejin"
21
+ # TEST_DATASET = f"{OWNER}/contextual_test"
22
+ # VAL_DATASET = f"{OWNER}/contextual_val"
23
+ # SUBMISSION_DATASET = f"{OWNER}/submissions_internal"
24
+ # CONTACT_DATASET = f"{OWNER}/contact_info"
25
+ # RESULTS_DATASET = f"{OWNER}/results"
26
+ # LEADERBOARD_PATH = f"{OWNER}/leaderboard"
27
+
28
+ RESULTS_DATASET = f"{OWNER}/spk_tag_results"
29
+ LEADERBOARD_PATH = f"{OWNER}/leaderboard"
30
+ SUBMISSION_DATASET = f"{OWNER}/submission_leaderboard"
31
+ api = HfApi()
32
+
33
+ YEAR_VERSION = "2024"
34
+
35
def read_json_file(filepath):
    """Parse *filepath* as JSON and return the resulting object."""
    with open(filepath) as fp:
        return json.load(fp)
39
+
40
def save_json_file(filepath, data_dict):
    """Serialize *data_dict* to *filepath* as JSON, replacing any existing file."""
    with open(filepath, "w") as sink:
        json.dump(data_dict, sink)
43
+
44
+ os.makedirs("scored", exist_ok=True)
45
+
46
+ # test_data_files = {"test": "contextual_test.csv"}
47
+ # test_dataset = load_dataset(TEST_DATASET, data_files=test_data_files , token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
48
+
49
+ # val_data_files = {"val": "contextual_val.csv"}
50
+ # val_dataset = load_dataset(VAL_DATASET, data_files=val_data_files , token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
51
+
52
+ # results_data_files = {"test": "contextual_test_results.csv", "val": "contextual_val_results.csv"}
53
+ # results = load_dataset(RESULTS_DATASET, data_files=results_data_files, token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
54
+
55
+ # contacts_data_files = {"contacts": "contacts.csv"}
56
+ # contact_infos = load_dataset(CONTACT_DATASET, data_files=contacts_data_files, token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
57
+
58
+ # BASE_PATH="entry_data"
59
+
60
+
61
+
62
+ # results_data_files = {"dev": f"{BASE_PATH}/dev_set_data.csv", "val": "contextual_val_results.csv"}
63
+ results_data_files = {"dev": "dev_set_data.csv"}
64
+ results = load_dataset(RESULTS_DATASET, data_files=results_data_files, token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
65
+
66
+ # contacts_data_files = {"contacts": "contacts.csv"}
67
+ # contact_infos = load_dataset(CONTACT_DATASET, data_files=contacts_data_files, token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
68
+
69
def get_dataframe_from_results(results, split):
    """Return the leaderboard table for one split of *results*.

    Args:
        results: Mapping from split name to a dataset object exposing
            ``to_pandas()`` (e.g. a ``datasets.DatasetDict``).
        split: Split key to extract (``"dev"`` for this leaderboard).

    Returns:
        pandas.DataFrame sorted by ``cpWER`` in ascending order, so the
        best (lowest error-rate) systems appear first.
    """
    df = results[split].to_pandas()
    # cpWER is an error rate: lower is better, so the leaderboard must sort
    # ascending. The previous descending sort was inherited from an
    # accuracy-based leaderboard (sorted by "All", where higher is better)
    # and ranked the worst system first.
    df = df.sort_values(by=["cpWER"], ascending=True)
    return df
74
+
75
+
76
+
77
+ # test_dataset_dataframe = test_dataset["test"].to_pandas()
78
+ # val_dataset_dataframe = val_dataset["val"].to_pandas()
79
+
80
+ # contacts_dataframe = contact_infos["contacts"].to_pandas()
81
+
82
+ # val_results_dataframe = get_dataframe_from_results(results=results, split="val")
83
+ # test_results_dataframe = get_dataframe_from_results(results=results, split="test")
84
+
85
def restart_space():
    """Restart the leaderboard Space via the HF Hub API (scheduled hourly below)."""
    api.restart_space(repo_id=LEADERBOARD_PATH, token=TOKEN)
87
+
88
# Gradio datatypes for the leaderboard table columns (markdown for text
# columns, number for metric columns).
# NOTE(review): six entries here but eval_entry below produces five columns
# (System_name, Method, Organisation, cpWER, WER) — confirm against the
# columns of dev_set_data.csv.
TYPES = ["markdown", "markdown", "markdown", "markdown", "number", "number"]

# Initial leaderboard contents rendered when the app starts.
dev_dataset_dataframe= get_dataframe_from_results(results=results, split="dev")
94
+
95
def add_new_eval(
    system_name: str,
    method: str,
    path_to_file: str,
    organisation: str,
    mail: str,
):
    """Validate, score, and publish a leaderboard submission.

    Steps: validate the form fields, upload the raw submission file to the
    submissions dataset, score it against the dev reference, upload the
    scored file, append the entry to the dev leaderboard CSV, and push the
    CSV back to the results dataset.

    Args:
        system_name: Display name of the submitted system.
        method: Short description of the method used.
        path_to_file: Uploaded gradio file object (``.name`` is the local path).
        organisation: Submitting organisation or team name.
        mail: Contact email (validated very loosely).

    Returns:
        A formatted log message for display in the UI.

    Raises:
        gr.Error: If any field is empty, the email lacks an ``@``, the file
            is missing, or the submission JSON does not contain exactly one
            top-level key with exactly 100 predictions.
    """
    print("printing all inputs:", system_name, method, path_to_file, organisation, mail)

    if len(system_name) == 0:
        print("system_name none")
        raise gr.Error("Please provide a system_name name. Field empty!")

    if len(method) == 0:
        print("method none")
        raise gr.Error("Please provide a method. Field empty!")

    if len(organisation) == 0:
        print("org none")
        raise gr.Error("Please provide organisation information. Field empty!")

    # Very basic email parsing
    _, parsed_mail = parseaddr(mail)
    if "@" not in parsed_mail:
        print("email here")
        raise gr.Error("Please provide a valid email address.")

    if path_to_file is None:
        print("file missing here")
        raise gr.Error("Please attach a file.")

    tmp_file_output = read_json_file(path_to_file.name)

    # Expected format: a single top-level key mapping to exactly 100 predictions.
    if len(tmp_file_output.keys()) != 1:
        print("file format wrong here")
        raise gr.Error("Submission file format incorrect. Please refer to the format description!")

    tmp_output_key = list(tmp_file_output.keys())[0]
    if len(tmp_file_output[tmp_output_key].keys()) != 100:
        print("file not 100 here")
        raise gr.Error("File must contain exactly 100 predictions.")

    # Save submitted file
    time_atm = datetime.datetime.today()
    api.upload_file(
        repo_id=SUBMISSION_DATASET,
        path_or_fileobj=path_to_file.name,
        path_in_repo=f"{organisation}/{system_name}/{YEAR_VERSION}_raw_{time_atm}.json",
        repo_type="dataset",
        token=TOKEN,
    )

    # Compute score against the dev-set reference transcript.
    ref_file_path = "seglst_files/err_dev.ref.seglst.json"
    scores = instruction_scorer(file_path_input=path_to_file.name, ref_file_path=ref_file_path, system_name=system_name)

    path_or_fileobj = f"scored/{organisation}_{system_name}.json"
    save_json_file(path_or_fileobj, scores)

    # Save scored file
    api.upload_file(
        repo_id=SUBMISSION_DATASET,
        path_or_fileobj=path_or_fileobj,
        path_in_repo=f"{organisation}/{system_name}/{YEAR_VERSION}_scored_{time_atm}.json",
        repo_type="dataset",
        token=TOKEN,
    )

    # Actual submission
    eval_entry = {
        "System_name": system_name,
        "Method": method,
        "Organisation": organisation,
        "cpWER": scores["cpWER"],
        "WER": scores["WER"],
    }

    dev_set_data_csv = "dev_set_data.csv"

    # BUG FIX: the module-level `results` dataset is loaded with only a
    # "dev" split (results_data_files = {"dev": "dev_set_data.csv"}), so the
    # previous split="val" lookup raised a KeyError on every submission.
    dev_results_dataframe = get_dataframe_from_results(results=results, split="dev")
    dev_results_dataframe = pd.concat([dev_results_dataframe, pd.DataFrame([eval_entry])], ignore_index=True)
    dev_results_dataframe.to_csv(dev_set_data_csv, index=False)

    api.upload_file(
        repo_id=RESULTS_DATASET,
        path_or_fileobj=dev_set_data_csv,
        path_in_repo=dev_set_data_csv,
        repo_type="dataset",
        token=TOKEN,
    )

    return format_log(f"System_name {system_name} submitted by {organisation} successfully! \nPlease refresh the val leaderboard, and wait a bit to see the score displayed")
214
+
215
+
216
+ # def refresh():
217
+ # results_data_files = {"test": "contextual_test_results.csv", "val": "contextual_val_results.csv"}
218
+ # results = load_dataset(RESULTS_DATASET, data_files=
219
+ # results_data_files, token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
220
+ # val_results_dataframe = get_dataframe_from_results(results=results, split="val")
221
+ # test_results_dataframe = get_dataframe_from_results(results=results, split="test")
222
+ # return val_results_dataframe, test_results_dataframe
223
+
224
def refresh():
    """Re-download the dev results CSV from the Hub and return the refreshed leaderboard dataframe."""
    data_files = {"dev": "dev_set_data.csv"}
    refreshed = load_dataset(
        RESULTS_DATASET,
        data_files=data_files,
        token=TOKEN,
        download_mode="force_redownload",
        ignore_verifications=True,
    )
    return get_dataframe_from_results(results=refreshed, split="dev")
231
+
232
def upload_file(files):
    """Map each uploaded file object to its local filesystem path (``.name``)."""
    return [uploaded.name for uploaded in files]
235
+
236
+
237
+
238
+
239
# ---------------------------------------------------------------------------
# Gradio UI: leaderboard display plus a submission form.
# NOTE(review): the original indentation was lost in extraction; the
# submission widgets are assumed to sit inside the submission accordion —
# confirm against the deployed Space.
# ---------------------------------------------------------------------------
demo = gr.Blocks()
with demo:
    # Page header.
    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")

    # Collapsible information sections.
    with gr.Row():
        with gr.Accordion("🧐 Introduction", open=False):
            gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")

    with gr.Row():
        with gr.Accordion("🎯 Submission Guidelines", open=False):
            gr.Markdown(SUBMISSION_TEXT, elem_classes="markdown-text")

    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            citation_button = gr.TextArea(
                value=CITATION_BUTTON_TEXT,
                label=CITATION_BUTTON_LABEL,
                elem_id="citation-button",
            )

    # Dev-split leaderboard table, seeded from the dataset loaded at startup.
    with gr.Tab("Results: Dev"):
        leaderboard_table_dev = gr.components.Dataframe(
            value=dev_dataset_dataframe, datatype=TYPES, interactive=False,
            column_widths=["20%"]
        )

    # Manual refresh re-downloads the results CSV and replaces the table.
    refresh_button = gr.Button("Refresh")
    refresh_button.click(
        refresh,
        inputs=[],
        outputs=[
            leaderboard_table_dev,
        ],
    )

    # Submission form: metadata fields plus the prediction-file upload.
    with gr.Accordion("Submit a new system_name for evaluation"):
        with gr.Row():
            with gr.Column():
                system_name_textbox = gr.Textbox(label="System name", type='text')
                method_textbox = gr.Textbox(label="Method (LLM with prompt, beam-search, etc)", type='text')
            with gr.Column():
                organisation = gr.Textbox(label="Organisation or Team Name", type='text')
                mail = gr.Textbox(label="Contact email (will be stored privately, & used if there is an issue with your submission)", type='email')
                file_output = gr.File()

        submit_button = gr.Button("Submit Eval")
        submission_result = gr.Markdown()
        submit_button.click(
            add_new_eval,
            [
                system_name_textbox,
                method_textbox,
                file_output,
                organisation,
                mail
            ],
            submission_result,
        )

# Restart the Space every hour so the leaderboard stays fresh.
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, "interval", seconds=3600)
scheduler.start()
demo.launch(debug=True)
app_old.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import csv
4
+ import datetime
5
+ from email.utils import parseaddr
6
+
7
+ import gradio as gr
8
+ import pandas as pd
9
+ import numpy as np
10
+
11
+ from datasets import load_dataset
12
+ from apscheduler.schedulers.background import BackgroundScheduler
13
+ from huggingface_hub import HfApi
14
+
15
+ from scorer import instruction_scorer
16
+ from content import format_error, format_warning, format_log, TITLE, INTRODUCTION_TEXT, SUBMISSION_TEXT, CITATION_BUTTON_LABEL, CITATION_BUTTON_TEXT, model_hyperlink
17
+
18
+ TOKEN = os.environ.get("TOKEN", None)
19
+ OWNER="ucla-contextual"
20
+ TEST_DATASET = f"{OWNER}/contextual_test"
21
+ VAL_DATASET = f"{OWNER}/contextual_val"
22
+ SUBMISSION_DATASET = f"{OWNER}/submissions_internal"
23
+ CONTACT_DATASET = f"{OWNER}/contact_info"
24
+ RESULTS_DATASET = f"{OWNER}/results"
25
+ LEADERBOARD_PATH = f"{OWNER}/leaderboard"
26
+ api = HfApi()
27
+
28
+ YEAR_VERSION = "2024"
29
+
30
def read_json_file(filepath):
    """Parse *filepath* as JSON and return the resulting object."""
    with open(filepath) as infile:
        data_dict = json.load(infile)
    return data_dict
34
+
35
def save_json_file(filepath, data_dict):
    """Serialize *data_dict* to *filepath* as JSON, replacing any existing file."""
    with open(filepath, "w") as outfile:
        json.dump(data_dict, outfile)
38
+
39
+ os.makedirs("scored", exist_ok=True)
40
+
41
+ # test_data_files = {"test": "contextual_test.csv"}
42
+ # test_dataset = load_dataset(TEST_DATASET, data_files=test_data_files , token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
43
+
44
+ # val_data_files = {"val": "contextual_val.csv"}
45
+ # val_dataset = load_dataset(VAL_DATASET, data_files=val_data_files , token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
46
+
47
+ # results_data_files = {"test": "contextual_test_results.csv", "val": "contextual_val_results.csv"}
48
+ # results = load_dataset(RESULTS_DATASET, data_files=results_data_files, token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
49
+
50
+ # contacts_data_files = {"contacts": "contacts.csv"}
51
+ # contact_infos = load_dataset(CONTACT_DATASET, data_files=contacts_data_files, token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
52
+
53
def get_dataframe_from_results(results, split):
    """Build the leaderboard table for *split*: drop the URL column and rank by overall score, highest first."""
    table = results[split].to_pandas()
    table.drop(columns=['URL'], inplace=True)
    return table.sort_values(by=["All"], ascending=False)
58
+
59
+ # test_dataset_dataframe = test_dataset["test"].to_pandas()
60
+ # val_dataset_dataframe = val_dataset["val"].to_pandas()
61
+
62
+ # contacts_dataframe = contact_infos["contacts"].to_pandas()
63
+
64
+ # val_results_dataframe = get_dataframe_from_results(results=results, split="val")
65
+ # test_results_dataframe = get_dataframe_from_results(results=results, split="test")
66
+
67
def restart_space():
    """Restart the leaderboard Space via the HF Hub API (scheduled hourly below)."""
    api.restart_space(repo_id=LEADERBOARD_PATH, token=TOKEN)
69
+
70
+ TYPES = ["markdown", "markdown", "markdown", "number", "number", "number","number", "number", "number", "number", "number", "number"]
71
+
72
def add_new_eval(
    model: str,
    method: str,
    url: str,
    path_to_file: str,
    organisation: str,
    mail: str,
):
    """Validate, score, and publish a model submission (legacy contextual leaderboard).

    Steps: validate the form fields, reject duplicate model/org pairs, upload
    the raw submission file, score it, upload the scored file, append the
    entry to the val results CSV, record the contact info, and push both CSVs
    back to their datasets.

    NOTE(review): this legacy function reads the module globals `results`,
    `val_dataset_dataframe`, and `contact_infos`, whose definitions are
    commented out at module level in this file — calling it as-is would raise
    NameError. Confirm this file is dead code superseded by app_new.py.

    Raises:
        gr.Error: on empty fields, invalid email, duplicate submission,
            missing file, or a malformed submission file.
    """
    print("printing all inputs:", model, method, url, path_to_file, organisation, mail)

    if len(model)==0:
        print("model none")
        raise gr.Error("Please provide a model name. Field empty!")

    if len(method)==0:
        print("method none")
        raise gr.Error("Please provide a method. Field empty!")

    if len(organisation)==0:
        print("org none")
        raise gr.Error("Please provide organisation information. Field empty!")

    # Very basic email parsing
    _, parsed_mail = parseaddr(mail)
    if not "@" in parsed_mail:
        print("email here")
        raise gr.Error("Please provide a valid email address.")

    # Check if the combination model/org already exists and prints a warning message if yes
    if model.lower() in set([m.lower() for m in results["val"]["Model"]]) and organisation.lower() in set([o.lower() for o in results["val"]["Organisation"]]):
        print("model org combo here")
        raise gr.Error("This model has been already submitted.")

    if path_to_file is None:
        print("file missing here")
        raise gr.Error("Please attach a file.")

    tmp_file_output = read_json_file(path_to_file.name)

    # Expected format: one top-level key mapping to exactly 100 predictions.
    if len(tmp_file_output.keys())!=1:
        print("file format wrong here")
        raise gr.Error("Submission file format incorrect. Please refer to the format description!")

    tmp_output_key = list(tmp_file_output.keys())[0]
    if len(tmp_file_output[tmp_output_key].keys())!=100:
        print("file not 100 here")
        raise gr.Error("File must contain exactly 100 predictions.")

    # Save submitted file
    time_atm = datetime.datetime.today()
    api.upload_file(
        repo_id=SUBMISSION_DATASET,
        path_or_fileobj=path_to_file.name,
        path_in_repo=f"{organisation}/{model}/{YEAR_VERSION}_raw_{time_atm}.json",
        repo_type="dataset",
        token=TOKEN
    )

    # Compute score
    file_path = path_to_file.name
    scores = instruction_scorer(val_dataset_dataframe, file_path , model)

    path_or_fileobj=f"scored/{organisation}_{model}.json"
    save_json_file(path_or_fileobj, scores)

    # Save scored file
    api.upload_file(
        repo_id=SUBMISSION_DATASET,
        path_or_fileobj=path_or_fileobj,
        path_in_repo=f"{organisation}/{model}/{YEAR_VERSION}_scored_{time_atm}.json",
        repo_type="dataset",
        token=TOKEN
    )

    # Actual submission: one leaderboard row built from the per-category scores.
    eval_entry = {
        "Model": model,
        "Method":method,
        "Organisation": organisation,
        "URL": url,
        "All":scores["average"],
        "Time":scores["time"],
        "Shopping":scores["shopping"],
        "Navigation":scores["navigation-transportation"],
        "Abstract":scores["abstract"],
        "Application Usage":scores["app"],
        "Web Usage":scores["web"],
        "Infographic":scores["infographics"],
        "Miscellaneous Natural Scenes": scores["misc"]
    }

    # Append the new row to the val leaderboard and push the CSV back.
    val_results_dataframe = get_dataframe_from_results(results=results, split="val")
    val_results_dataframe = pd.concat([val_results_dataframe, pd.DataFrame([eval_entry])], ignore_index=True)
    val_results_dataframe.to_csv('contextual_val_results.csv', index=False)

    api.upload_file(
        repo_id=RESULTS_DATASET,
        path_or_fileobj="contextual_val_results.csv",
        path_in_repo=f"contextual_val_results.csv",
        repo_type="dataset",
        token=TOKEN
    )

    # Record the submitter's contact details in the private contacts dataset.
    contact_info = {
        "Model": model,
        "URL": url,
        "Organisation": organisation,
        "Mail": mail,
    }

    contacts_dataframe = contact_infos["contacts"].to_pandas()
    contacts_dataframe = pd.concat([contacts_dataframe, pd.DataFrame([contact_info])], ignore_index=True)
    contacts_dataframe.to_csv('contacts.csv', index=False)

    api.upload_file(
        repo_id=CONTACT_DATASET,
        path_or_fileobj="contacts.csv",
        path_in_repo=f"contacts.csv",
        repo_type="dataset",
        token=TOKEN
    )

    return format_log(f"Model {model} submitted by {organisation} successfully! \nPlease refresh the val leaderboard, and wait a bit to see the score displayed")
196
+
197
+
198
def refresh():
    """Re-download the val/test results CSVs and return both leaderboard dataframes."""
    results_data_files = {"test": "contextual_test_results.csv", "val": "contextual_val_results.csv"}
    results = load_dataset(RESULTS_DATASET, data_files=
    results_data_files, token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
    val_results_dataframe = get_dataframe_from_results(results=results, split="val")
    test_results_dataframe = get_dataframe_from_results(results=results, split="test")
    return val_results_dataframe, test_results_dataframe
205
+
206
def upload_file(files):
    """Return the local filesystem path (``.name``) of each uploaded file object."""
    file_paths = [file.name for file in files]
    return file_paths
209
+
210
+
211
+ demo = gr.Blocks()
212
+ with demo:
213
+ gr.HTML(TITLE)
214
+ # gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
215
+
216
+ with gr.Row():
217
+ with gr.Accordion("🧐 Introduction", open=False):
218
+ gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
219
+
220
+ with gr.Row():
221
+ with gr.Accordion("🎯 Submission Guidelines", open=False):
222
+ gr.Markdown(SUBMISSION_TEXT, elem_classes="markdown-text")
223
+
224
+ with gr.Row():
225
+ with gr.Accordion("📙 Citation", open=False):
226
+ citation_button = gr.TextArea(
227
+ value=CITATION_BUTTON_TEXT,
228
+ label=CITATION_BUTTON_LABEL,
229
+ elem_id="citation-button",
230
+ )
231
+ with gr.Tab("Results: Test"):
232
+ leaderboard_table_test = gr.components.Dataframe(
233
+ value=test_results_dataframe, datatype=TYPES, interactive=False,
234
+ column_widths=["20%"]
235
+ )
236
+ with gr.Tab("Results: Val"):
237
+ leaderboard_table_val = gr.components.Dataframe(
238
+ value=val_results_dataframe, datatype=TYPES, interactive=False,
239
+ column_widths=["20%"]
240
+ )
241
+
242
+ refresh_button = gr.Button("Refresh")
243
+ refresh_button.click(
244
+ refresh,
245
+ inputs=[],
246
+ outputs=[
247
+ leaderboard_table_val,
248
+ leaderboard_table_test,
249
+ ],
250
+ )
251
+ with gr.Accordion("Submit a new model for evaluation"):
252
+ with gr.Row():
253
+ with gr.Column():
254
+ model_name_textbox = gr.Textbox(label="Model name", type='text')
255
+ method_textbox = gr.Textbox(label="Method (LMM or Aug LLM or any other)", type='text')
256
+ url_textbox = gr.Textbox(label="URL to model information", type='text')
257
+ with gr.Column():
258
+ organisation = gr.Textbox(label="Organisation", type='text')
259
+ mail = gr.Textbox(label="Contact email (will be stored privately, & used if there is an issue with your submission)", type='email')
260
+ file_output = gr.File()
261
+
262
+
263
+ submit_button = gr.Button("Submit Eval")
264
+ submission_result = gr.Markdown()
265
+ submit_button.click(
266
+ add_new_eval,
267
+ [
268
+ model_name_textbox,
269
+ method_textbox,
270
+ url_textbox,
271
+ file_output,
272
+ organisation,
273
+ mail
274
+ ],
275
+ submission_result,
276
+ )
277
+
278
+ scheduler = BackgroundScheduler()
279
+ scheduler.add_job(restart_space, "interval", seconds=3600)
280
+ scheduler.start()
281
+ demo.launch(debug=True)
beam_search_utils.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ from typing import Dict, List
3
+ from pydiardecode import build_diardecoder
4
+ import numpy as np
5
+ import copy
6
+ import os
7
+ import json
8
+ import concurrent.futures
9
+ import kenlm
10
+
11
+ __INFO_TAG__ = "[BeamSearchUtil INFO]"
12
+
13
class SpeakerTaggingBeamSearchDecoder:
    """Beam-search re-alignment of per-word speaker tags using a KenLM language model.

    Sessions are dicts keyed by a unique ID, each holding a ``'words'`` list of
    per-word dicts (at minimum ``'word'`` and ``'speaker'`` keys). Long word
    sequences can be chunked (``divide_chunks``) for parallel decoding and the
    chunk outputs merged back (``merge_div_inputs``).
    """

    def __init__(self, loaded_kenlm_model: kenlm, cfg: dict):
        # cfg supplies decoding parameters; 'beam_width' is read in
        # realign_words_with_lm.
        self.realigning_lm_params = cfg
        self.realigning_lm = self._load_realigning_LM(loaded_kenlm_model=loaded_kenlm_model)
        # Separator used to encode (uniq_id, chunk_index, chunk_count) into a
        # single chunk key, e.g. "sess1@0@4".
        self._SPLITSYM = "@"

    def _load_realigning_LM(self, loaded_kenlm_model: kenlm):
        """
        Load ARPA language model for realigning speaker labels for words.

        NOTE(review): currently returns None (no build_diardecoder call), so
        realign_words_with_lm would fail on ``None.decode_beams``. Confirm
        whether the decoder construction is intentionally stubbed out here.
        """
        diar_decoder = None
        return diar_decoder

    def realign_words_with_lm(self, word_dict_seq_list: List[Dict[str, float]], speaker_count: int = None, port_num=None) -> List[Dict[str, float]]:
        """Run beam-search decoding over one word sequence.

        The candidate speaker set is taken from the word dicts themselves when
        *speaker_count* is None, otherwise it is "speaker_0".."speaker_{n-1}".
        Returns the list of beams produced by ``decode_beams``.
        """
        if speaker_count is None:
            spk_list = []
            for k, line_dict in enumerate(word_dict_seq_list):
                _, spk_label = line_dict['word'], line_dict['speaker']
                spk_list.append(spk_label)
        else:
            spk_list = [ f"speaker_{k}" for k in range(speaker_count)]

        realigned_list = self.realigning_lm.decode_beams(beam_width=self.realigning_lm_params['beam_width'],
                                                         speaker_list=sorted(list(set(spk_list))),
                                                         word_dict_seq_list=word_dict_seq_list,
                                                         port_num=port_num)
        return realigned_list

    def beam_search_diarization(
        self,
        trans_info_dict: Dict[str, Dict[str, list]],
        port_num: List[int] = None,
    ) -> Dict[str, Dict[str, float]]:
        """
        Match the diarization result with the ASR output.
        The words and the timestamps for the corresponding words are matched in a for loop.

        Args:
            trans_info_dict: Session dicts keyed by unique ID; each must carry
                'words' and 'speaker_count'.
            port_num: Optional port list forwarded to the decoder.

        Returns:
            trans_info_dict (dict):
                Dictionary containing word timestamps, speaker labels and words from all sessions.
                Each session is indexed by a unique ID. Modified in place.
        """
        for uniq_id, session_dict in tqdm(trans_info_dict.items(), total=len(trans_info_dict), disable=True):
            word_dict_seq_list = session_dict['words']
            output_beams = self.realign_words_with_lm(word_dict_seq_list=word_dict_seq_list, speaker_count=session_dict['speaker_count'], port_num=port_num)
            # Keep only the best beam; index [2] of a beam holds the realigned
            # word-dict sequence.
            word_dict_seq_list = output_beams[0][2]
            trans_info_dict[uniq_id]['words'] = word_dict_seq_list
        return trans_info_dict

    def merge_div_inputs(self, div_trans_info_dict, org_trans_info_dict, win_len=250, word_window=16, limit_max_spks=8):
        """
        Merge the outputs of parallel processing.

        Chunk keys of the form "uniq@idx@count" are regrouped per session and
        re-concatenated in order; each chunk after the first drops its leading
        *word_window* words (the overlap added by divide_chunks). Sessions that
        were never chunked (e.g. skipped for exceeding the speaker limit) fall
        back to their original word list. Also writes the merged words back
        into org_trans_info_dict.
        """
        uniq_id_list = list(org_trans_info_dict.keys())
        sub_div_dict = {}
        for seq_id in div_trans_info_dict.keys():
            div_info = seq_id.split(self._SPLITSYM)
            uniq_id, sub_idx, total_count = div_info[0], int(div_info[1]), int(div_info[2])
            if uniq_id not in sub_div_dict:
                sub_div_dict[uniq_id] = [None] * total_count
            sub_div_dict[uniq_id][sub_idx] = div_trans_info_dict[seq_id]['words']

        processed_trans_info_dict = {}
        for uniq_id in uniq_id_list:
            processed_trans_info_dict[uniq_id] = {'words': []}

            if uniq_id in sub_div_dict:
                for k, div_words in enumerate(sub_div_dict[uniq_id]):
                    if k == 0:
                        # First chunk has no leading overlap; cap at win_len.
                        div_words = div_words[:win_len]
                    else:
                        # Later chunks drop the overlap words re-decoded for context.
                        div_words = div_words[word_window:]
                    processed_trans_info_dict[uniq_id]['words'].extend(div_words)

                org_trans_info_dict[uniq_id]['words'] = processed_trans_info_dict[uniq_id]['words']
            else:
                processed_trans_info_dict[uniq_id]['words'] = org_trans_info_dict[uniq_id]['words']
        return processed_trans_info_dict

    def divide_chunks(self, trans_info_dict, win_len, word_window, limit_max_spks, port):
        """
        Divide word sequence into chunks of length `win_len` for parallel processing.

        Each chunk after the first is extended backwards by *word_window* words
        of overlap. Sessions with more than *limit_max_spks* speakers are
        skipped. Mutates the per-session dicts (removes 'status',
        'transcription', 'sentences'). Returns a dict keyed by
        "uniq@idx@count" chunk IDs.

        Args:
            trans_info_dict: Session dicts keyed by unique ID.
            win_len (int, optional): Chunk length in words; when None it is
                derived from the worker count.
            word_window: Overlap (in words) prepended to each non-first chunk.
            limit_max_spks: Sessions above this speaker count are skipped.
            port: Worker-port list; its length sets the worker count when > 1.
        """
        if len(port) > 1:
            num_workers = len(port)
        else:
            num_workers = 25
        div_trans_info_dict = {}
        for uniq_id in trans_info_dict.keys():

            uniq_trans = trans_info_dict[uniq_id]
            # Strip fields the decoder does not consume before fan-out.
            if 'status' in uniq_trans:
                del uniq_trans['status']
            if 'transcription' in uniq_trans:
                del uniq_trans['transcription']
            if 'sentences' in uniq_trans:
                del uniq_trans['sentences']
            word_seq = uniq_trans['words']
            num_spks = len(set([x['speaker'] for x in word_seq]))
            if num_spks > limit_max_spks:
                continue

            div_word_seq = []
            if win_len is None:
                win_len = int(np.ceil(len(word_seq)/num_workers))
            n_chunks = int(np.ceil(len(word_seq)/win_len))

            for k in range(n_chunks):
                # Extend each non-first chunk backwards by word_window words.
                div_word_seq.append(word_seq[max(k*win_len - word_window, 0):(k+1)*win_len])

            total_count = len(div_word_seq)
            for k, w_seq in enumerate(div_word_seq):
                seq_id = uniq_id + f"{self._SPLITSYM}{k}{self._SPLITSYM}{total_count}"
                div_trans_info_dict[seq_id] = dict(uniq_trans)
                div_trans_info_dict[seq_id]['words'] = w_seq
        return div_trans_info_dict
138
+
139
def run_mp_beam_search_decoding(
    speaker_beam_search_decoder,
    loaded_kenlm_model,
    div_trans_info_dict,
    org_trans_info_dict,
    div_mp,
    win_len,
    word_window,
    limit_max_spks,
    port=None,
    use_ngram=False
):
    """
    Run beam-search diarization for every chunk in parallel worker processes.

    Each chunk is submitted as a single-session dict to
    `speaker_beam_search_decoder.beam_search_diarization`; ports are assigned
    round-robin. When `div_mp` is True the chunked results are merged back
    into whole sessions via `merge_div_inputs`.

    Note: `loaded_kenlm_model` is not used directly here; it is part of the
    decoder instance and kept in the signature for interface compatibility.

    Returns:
        dict: Decoded (and, if `div_mp`, merged) transcription info dict.
    """
    if len(port) > 1:
        port = [int(p) for p in port]
    if use_ngram:
        # n-gram scoring runs in-process; no LM server ports are needed.
        port = [None]
        num_workers = 24
    else:
        num_workers = len(port)
    uniq_id_list = sorted(div_trans_info_dict.keys())
    print(f"{__INFO_TAG__} Number of unique chunks to process: {len(uniq_id_list)}")

    output_trans_info_dict = {}
    # Context-manage the executor so worker processes are shut down even if a
    # submitted task (or result collection) raises.
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = []
        for count, uniq_id in enumerate(uniq_id_list):
            print(f"{__INFO_TAG__} Running beam search decoding for {uniq_id}...")
            # Round-robin port assignment across workers.
            port_num = port[count % len(port)] if port is not None else None
            uniq_trans_info_dict = {uniq_id: div_trans_info_dict[uniq_id]}
            futures.append(
                executor.submit(
                    speaker_beam_search_decoder.beam_search_diarization,
                    uniq_trans_info_dict,
                    port_num=port_num,
                )
            )

        pbar = tqdm(total=len(uniq_id_list), desc="Running beam search decoding", unit="files")
        for done_future in concurrent.futures.as_completed(futures):
            pbar.update()
            output_trans_info_dict.update(done_future.result())
        pbar.close()

    if div_mp:
        output_trans_info_dict = speaker_beam_search_decoder.merge_div_inputs(
            div_trans_info_dict=output_trans_info_dict,
            org_trans_info_dict=org_trans_info_dict,
            win_len=win_len,
            word_window=word_window,
            limit_max_spks=limit_max_spks,
        )
    return output_trans_info_dict
190
+
191
def count_num_of_spks(json_trans_list):
    """
    Map each speaker label in a seglst-style sentence list to an integer index.

    The labels are sorted before enumeration so the mapping is deterministic
    across runs; iterating a plain set varies with string-hash randomization,
    which made downstream speaker indices irreproducible.

    Args:
        json_trans_list (list): Sentence dicts, each with a 'speaker' key.

    Returns:
        dict: {speaker_label: index} with indices 0..N-1 in sorted-label order.
    """
    spk_set = {sentence_dict['speaker'] for sentence_dict in json_trans_list}
    return {spk_str: idx for idx, spk_str in enumerate(sorted(spk_set))}
197
+
198
def add_placeholder_speaker_softmax(json_trans_list, peak_prob=0.94, max_spks=4):
    """
    Convert a seglst-style sentence list into a NeMo-style word-level dict,
    attaching a synthetic per-word speaker softmax: the labeled speaker gets
    `peak_prob` and the remaining mass is spread evenly over the other
    `max_spks - 1` slots.

    Args:
        json_trans_list (list): Sentence dicts with 'words' (space-joined
            string) and 'speaker' keys.
        peak_prob (float): Probability assigned to the labeled speaker.
        max_spks (int): Length of the softmax vector.

    Returns:
        dict: {'words', 'status', 'sentences', 'speaker_count', 'transcription'}.

    Raises:
        ValueError: If `peak_prob` is outside [0, 1], or the transcript
            contains more speakers than `max_spks` (previously a late
            IndexError).
    """
    if peak_prob > 1 or peak_prob < 0:
        raise ValueError(f"peak_prob must be between 0 and 1 but got {peak_prob}")
    speaker_map = count_num_of_spks(json_trans_list)
    if len(speaker_map) > max_spks:
        raise ValueError(
            f"Found {len(speaker_map)} speakers but max_spks is {max_spks}"
        )
    # Off-peak probability per slot; guard max_spks == 1 (would divide by zero).
    off_peak = (1 - peak_prob) / (max_spks - 1) if max_spks > 1 else 0.0
    base_array = np.ones(max_spks) * off_peak
    word_dict_seq_list = []
    stt_sec, end_sec = None, None  # word-level timestamps are unknown here
    for sentence_dict in json_trans_list:
        speaker_idx = speaker_map[sentence_dict['speaker']]
        for word in sentence_dict['words'].split():
            # ndarray.copy() is sufficient (and cheaper than deepcopy).
            speaker_softmax = base_array.copy()
            speaker_softmax[speaker_idx] = peak_prob
            word_dict_seq_list.append({'word': word,
                                       'start_time': stt_sec,
                                       'end_time': end_sec,
                                       'speaker': speaker_idx,
                                       'speaker_softmax': speaker_softmax}
                                      )
    return {'words': word_dict_seq_list,
            'status': "success",
            'sentences': json_trans_list,
            'speaker_count': len(speaker_map),
            'transcription': None}
225
+
226
def convert_nemo_json_to_seglst(trans_info_dict):
    """
    Turn NeMo-style word-level output into seglst segment lists, one list per
    session. Consecutive words carrying the same speaker label are joined into
    a single segment; start/end times are zeroed since none are tracked here.

    Args:
        trans_info_dict (dict): Session-indexed dict; each entry has a 'words'
            list of dicts with 'word' and 'speaker' keys.

    Returns:
        dict: {session_id: [segment dicts with session_id/words/start_time/
            end_time/speaker]}.
    """
    seg_lst_dict = {}
    spk_wise_trans_sessions = {}
    for session_id, session_entry in trans_info_dict.items():
        # Per-speaker running transcription (kept for parity with the original
        # implementation; it is a local side product and not returned).
        per_spk_text = {}
        spk_wise_trans_sessions[session_id] = per_spk_text

        segments = []
        word_seq = session_entry['words']
        n_words = len(word_seq)
        last_spk, running = None, ''
        for idx, word_entry in enumerate(word_seq):
            spk = word_entry['speaker']
            token = word_entry['word']

            # Speaker-wise transcription accumulation.
            if spk in per_spk_text:
                per_spk_text[spk] = f"{per_spk_text[spk]} {token}"
            else:
                per_spk_text[spk] = token

            # Segment-wise accumulation: flush on a speaker turn boundary.
            if last_spk is not None and spk != last_spk:
                segments.append({'session_id': session_id,
                                 'words': running.strip(),
                                 'start_time': 0.0,
                                 'end_time': 0.0,
                                 'speaker': last_spk,
                                 })
                running = token
            else:
                running = f"{running} {token}"
            last_spk = spk

            # Final word: flush whatever is accumulated for the current speaker.
            if idx == n_words - 1:
                segments.append({'session_id': session_id,
                                 'words': running.strip(),
                                 'start_time': 0.0,
                                 'end_time': 0.0,
                                 'speaker': spk,
                                 })
        seg_lst_dict[session_id] = segments
    return seg_lst_dict
269
+
270
def load_input_jsons(input_error_src_list_path, ext_str=".seglst.json", peak_prob=0.94, max_spks=4):
    """
    Load the error-source seglst JSON files listed (one path per line) in
    `input_error_src_list_path` and convert each into the NeMo-style word
    dict with placeholder speaker softmax values.

    Args:
        input_error_src_list_path (str): Text file listing seglst JSON paths.
        ext_str (str): Filename suffix stripped to obtain the session id.
        peak_prob (float): Peak probability for the placeholder softmax.
        max_spks (int): Softmax vector length.

    Returns:
        dict: {uniq_id: nemo-style session dict}.

    Raises:
        FileNotFoundError: If any listed JSON file does not exist.
    """
    trans_info_dict = {}
    # Context-manage the list file (the original `open(...).readlines()`
    # leaked the file handle).
    with open(input_error_src_list_path) as list_file:
        json_filepath_list = list_file.readlines()
    for json_path in json_filepath_list:
        json_path = json_path.strip()
        uniq_id = os.path.split(json_path)[-1].split(ext_str)[0]
        if not os.path.exists(json_path):
            raise FileNotFoundError(f"{json_path} does not exist. Aborting.")
        with open(json_path, "r") as file:
            json_trans = json.load(file)
        nemo_json_dict = add_placeholder_speaker_softmax(json_trans, peak_prob=peak_prob, max_spks=max_spks)
        trans_info_dict[uniq_id] = nemo_json_dict
    return trans_info_dict
284
+
285
def load_reference_jsons(reference_seglst_list_path, ext_str=".seglst.json"):
    """
    Load the reference seglst JSON files listed (one path per line) in
    `reference_seglst_list_path`, stamping each sentence dict with its
    session id.

    Args:
        reference_seglst_list_path (str): Text file listing seglst JSON paths.
        ext_str (str): Filename suffix stripped to obtain the session id.

    Returns:
        dict: {uniq_id: [sentence dicts with 'session_id' added]}.

    Raises:
        FileNotFoundError: If any listed JSON file does not exist.
    """
    reference_info_dict = {}
    # Context-manage the list file (the original `open(...).readlines()`
    # leaked the file handle).
    with open(reference_seglst_list_path) as list_file:
        json_filepath_list = list_file.readlines()
    for json_path in json_filepath_list:
        json_path = json_path.strip()
        uniq_id = os.path.split(json_path)[-1].split(ext_str)[0]
        if not os.path.exists(json_path):
            raise FileNotFoundError(f"{json_path} does not exist. Aborting.")
        with open(json_path, "r") as file:
            json_trans = json.load(file)
        json_trans_uniq_id = []
        for sentence_dict in json_trans:
            sentence_dict['session_id'] = uniq_id
            json_trans_uniq_id.append(sentence_dict)
        reference_info_dict[uniq_id] = json_trans_uniq_id
    return reference_info_dict
302
+
303
def write_seglst_jsons(
    seg_lst_sessions_dict: dict,
    input_error_src_list_path: str,
    diar_out_path: str,
    ext_str: str,
    write_individual_seglst_jsons=True
):
    """
    Writes the segment list (seglst) JSON files to the output directory.

    Parameters:
        seg_lst_sessions_dict (dict): A dictionary containing session IDs as keys and their corresponding segment lists as values.
        input_error_src_list_path (str): The path to the input error source list file.
        diar_out_path (str): The path to the output directory where the seglst JSON files will be written.
        ext_str (str): Tag for the combined seglst JSON filename (e.g. 'hyp' for hypothesis or 'ref' for reference); replaces 'src'/'ref' in the list-file name.
        write_individual_seglst_jsons (bool, optional): A flag indicating whether to write individual seglst JSON files for each session. Defaults to True.

    Returns:
        None
    """
    total_infer_list = []
    total_output_filename = os.path.split(input_error_src_list_path)[-1].replace(".list", "")
    for session_id, seg_lst_list in seg_lst_sessions_dict.items():
        total_infer_list.extend(seg_lst_list)
        if write_individual_seglst_jsons:
            print(f"{__INFO_TAG__} Writing {diar_out_path}/{session_id}.seglst.json")
            with open(f'{diar_out_path}/{session_id}.seglst.json', 'w') as file:
                json.dump(seg_lst_list, file, indent=4)  # indent=4 for pretty printing

    total_output_filename = total_output_filename.replace("src", ext_str).replace("ref", ext_str)
    write_fn = f"{diar_out_path}/{total_output_filename}.seglst.json"
    # Announce the combined file actually being written (the original message
    # repeated the last per-session path here, which was misleading).
    print(f"{__INFO_TAG__} Writing {write_fn}")
    if os.path.exists(write_fn):
        print(f"{__INFO_TAG__} {write_fn} already exists. Deleting it.")
        os.remove(write_fn)
    with open(write_fn, 'w') as file:
        json.dump(total_infer_list, file, indent=4)  # indent=4 for pretty printing
entry_data/dev_set_data.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ system_name,method,organisation,mail,cpWER,WER
2
+ baseline_system,beam_search_ngram,SLT_Task2,tango4j@gmail.com,0.24536675570166427,0.21231591
entry_data/dev_set_data_1.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ baseline_system,0.24536675570166427,0.21231591
2
+ baseline_system_2,0.01234,0.1234
hyper_optim.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import optuna
2
+ import os
3
+ import tempfile
4
+ import time
5
+ import json
6
+ import subprocess
7
+ import logging
8
+ from beam_search_utils import (
9
+ write_seglst_jsons,
10
+ run_mp_beam_search_decoding,
11
+ convert_nemo_json_to_seglst,
12
+ SpeakerTaggingBeamSearchDecoder,
13
+ )
14
+
15
+ from speaker_tagging_cpwer_jsons import process_session_data
16
+
17
def evaluate(cfg, temp_out_dir, asrdiar_file_name, source_info_dict, hypothesis_sessions_dict, reference_info_dict):
    """
    Score the hypothesis speaker tagging against the reference with meeteval cpWER.

    Writes hyp/ref/src seglst JSON files into `temp_out_dir`, runs the
    `meeteval-wer cpwer` CLI on the hyp/ref pair, and returns the resulting
    cpWER read from the JSON file meeteval writes next to the hypothesis file.

    Returns:
        float: Hypothesis cpWER.

    Raises:
        subprocess.CalledProcessError: If the meeteval-wer command fails.
        FileNotFoundError: If the expected cpWER output JSON is missing.
    """
    write_seglst_jsons(hypothesis_sessions_dict, input_error_src_list_path=cfg.input_error_src_list_path, diar_out_path=temp_out_dir, ext_str='hyp')
    write_seglst_jsons(reference_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=temp_out_dir, ext_str='ref')
    write_seglst_jsons(source_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=temp_out_dir, ext_str='src')

    # Construct the file paths
    hyp_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.hyp.seglst.json")
    ref_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.ref.seglst.json")

    # meeteval-wer writes <hyp>.seglst_cpwer.json next to the hypothesis file.
    output_cpwer_hyp_json_file = os.path.join(temp_out_dir, f"{asrdiar_file_name}.hyp.seglst_cpwer.json")

    # check=True surfaces a meeteval failure immediately instead of the
    # confusing missing-output-file error below.
    cmd_hyp = [
        "meeteval-wer",
        "cpwer",
        "-h", hyp_seglst_json,
        "-r", ref_seglst_json
    ]
    subprocess.run(cmd_hyp, check=True)

    # Read the JSON file and report the cpWER.
    try:
        with open(output_cpwer_hyp_json_file, "r") as file:
            data_h = json.load(file)
    except FileNotFoundError:
        raise FileNotFoundError(f"Output JSON: {output_cpwer_hyp_json_file}\nfile not found.")
    print("Hypothesis cpWER:", data_h["error_rate"])
    cpwer = data_h["error_rate"]
    logging.info(f"-> HYPOTHESIS cpWER={cpwer:.4f}")
    return cpwer
51
+
52
def evaluate_diff(cfg, temp_out_dir, asrdiar_file_name, source_info_dict, hypothesis_sessions_dict, reference_info_dict):
    """
    Score both the hypothesis and the (uncorrected) source against the
    reference with meeteval cpWER and return the difference
    ``hyp_cpwer - src_cpwer`` (negative means the hypothesis improved over
    the error source).

    Also logs the average per-session cpWER difference computed by
    `process_session_data` from meeteval's per-recording outputs.

    Returns:
        float: hyp cpWER minus src cpWER.

    Raises:
        subprocess.CalledProcessError: If a meeteval-wer command fails.
        FileNotFoundError: If an expected cpWER output JSON is missing.
    """
    write_seglst_jsons(hypothesis_sessions_dict, input_error_src_list_path=cfg.input_error_src_list_path, diar_out_path=temp_out_dir, ext_str='hyp')
    write_seglst_jsons(reference_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=temp_out_dir, ext_str='ref')
    write_seglst_jsons(source_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=temp_out_dir, ext_str='src')

    # Construct the file paths
    src_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.src.seglst.json")
    hyp_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.hyp.seglst.json")
    ref_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.ref.seglst.json")

    # Run meeteval-wer on hyp vs ref and src vs ref; check=True surfaces a
    # CLI failure immediately instead of a missing-output-file error below.
    cmd_hyp = [
        "meeteval-wer",
        "cpwer",
        "-h", hyp_seglst_json,
        "-r", ref_seglst_json
    ]
    subprocess.run(cmd_hyp, check=True)

    cmd_src = [
        "meeteval-wer",
        "cpwer",
        "-h", src_seglst_json,
        "-r", ref_seglst_json
    ]
    subprocess.run(cmd_src, check=True)

    # meeteval writes these result JSONs next to the scored files.
    output_cpwer_hyp_json_file = os.path.join(temp_out_dir, f"{asrdiar_file_name}.hyp.seglst_cpwer.json")
    output_cpwer_src_json_file = os.path.join(temp_out_dir, f"{asrdiar_file_name}.src.seglst_cpwer.json")
    output_cpwer_hyp_json_file_per_reco = os.path.join(temp_out_dir, f"{asrdiar_file_name}.hyp.seglst_cpwer_per_reco.json")
    output_cpwer_src_json_file_per_reco = os.path.join(temp_out_dir, f"{asrdiar_file_name}.src.seglst_cpwer_per_reco.json")

    avg_cpwer_diff = process_session_data(output_cpwer_hyp_json_file_per_reco, output_cpwer_src_json_file_per_reco)

    try:
        with open(output_cpwer_hyp_json_file, "r") as file:
            data_h = json.load(file)
        hyp_cpwer = data_h["error_rate"]
        logging.info(f"-> HYPOTHESIS cpWER={hyp_cpwer:.4f}")
    except FileNotFoundError:
        raise FileNotFoundError(f"Output JSON: {output_cpwer_hyp_json_file}\nfile not found.")

    try:
        # data_s (not data_h again) to avoid shadowing the hypothesis payload.
        with open(output_cpwer_src_json_file, "r") as file:
            data_s = json.load(file)
        src_cpwer = data_s["error_rate"]
        logging.info(f"-> SOURCE cpWER={src_cpwer:.4f}")
    except FileNotFoundError:
        raise FileNotFoundError(f"Output JSON: {output_cpwer_src_json_file}\nfile not found.")
    diff_cpwer = (hyp_cpwer - src_cpwer)
    logging.info(f"-> Average cpWER DIFF={avg_cpwer_diff:.4f}")
    logging.info(f"-> HYPOTHESIS Improved cpWER={diff_cpwer:.4f}")
    return diff_cpwer
105
+
106
+
107
def optuna_suggest_params(cfg, trial):
    """Sample beam-search hyper-parameters from an Optuna `trial` onto `cfg`.

    `cfg` is mutated in place and also returned. The sampling order is part of
    the Optuna parameter stream, so keep it stable.
    """
    cfg.alpha = trial.suggest_float("alpha", 0.5, 1.5)
    cfg.beta = trial.suggest_float("beta", 0.02, 0.4)
    cfg.beam_width = trial.suggest_int("beam_width", 2, 12)
    cfg.word_window = trial.suggest_int("word_window", 10, 50, step=10)
    # n-gram LM rescoring is always enabled during the search.
    cfg.use_ngram = True
    cfg.parallel_chunk_word_len = trial.suggest_int("parallel_chunk_word_len", 50, 250, step=25)
    # Degenerate range: peak_prob is effectively pinned at 0.96 while still
    # being recorded as a trial parameter.
    cfg.peak_prob = trial.suggest_float("peak_prob", 0.96, 0.96)
    return cfg
116
+
117
def beamsearch_objective(
    trial,
    cfg,
    speaker_beam_search_decoder,
    loaded_kenlm_model,
    org_trans_info_dict,
    source_info_dict,
    reference_info_dict,
):
    """Optuna objective: run chunked beam-search speaker tagging and score it.

    When `trial` is not None, hyper-parameters are sampled onto `cfg` and a
    fresh decoder is built; otherwise the passed-in decoder and `cfg` are used
    as-is. Returns the value from `evaluate_diff` (hyp cpWER minus src cpWER;
    lower is better, so the study minimizes it). All intermediate seglst files
    go into a temporary directory that is removed afterwards.
    """
    with tempfile.TemporaryDirectory(dir=cfg.temp_out_dir, prefix="GenSEC_") as local_temp_out_dir:
        start_time2 = time.time()

        if trial is not None:
            cfg = optuna_suggest_params(cfg, trial)
            # Rebuild the decoder so it picks up the trial's hyper-parameters.
            speaker_beam_search_decoder = SpeakerTaggingBeamSearchDecoder(loaded_kenlm_model=loaded_kenlm_model, cfg=cfg)
        # Split sessions into overlapping word chunks for parallel decoding.
        div_trans_info_dict = speaker_beam_search_decoder.divide_chunks(trans_info_dict=org_trans_info_dict,
                                                                        win_len=cfg.parallel_chunk_word_len,
                                                                        word_window=cfg.word_window,
                                                                        limit_max_spks=cfg.limit_max_spks,
                                                                        port=cfg.port,)
        # Decode all chunks in parallel and merge them back (div_mp=True).
        result_trans_info_dict = run_mp_beam_search_decoding(speaker_beam_search_decoder,
                                                             loaded_kenlm_model=loaded_kenlm_model,
                                                             div_trans_info_dict=div_trans_info_dict,
                                                             org_trans_info_dict=org_trans_info_dict,
                                                             div_mp=True,
                                                             win_len=cfg.parallel_chunk_word_len,
                                                             word_window=cfg.word_window,
                                                             limit_max_spks=cfg.limit_max_spks,
                                                             port=cfg.port,
                                                             use_ngram=cfg.use_ngram,
                                                             )
        hypothesis_sessions_dict = convert_nemo_json_to_seglst(result_trans_info_dict)
        cpwer = evaluate_diff(cfg, local_temp_out_dir, cfg.asrdiar_file_name, source_info_dict, hypothesis_sessions_dict, reference_info_dict)
        logging.info(f"Beam Search time taken for trial {trial}: {(time.time() - start_time2)/60:.2f} mins")
        if trial is not None:
            logging.info(f"Trial: {trial.number}")
            logging.info(f"[ cpWER={cpwer:.4f} ]")
            logging.info("-----------------------------------------------")
        return cpwer
156
+
157
+
158
def optuna_hyper_optim(
    cfg,
    speaker_beam_search_decoder,
    loaded_kenlm_model,
    org_trans_info_dict,
    source_info_dict,
    reference_info_dict,
):
    """
    Optuna hyper-parameter optimization entry point.

    Creates (or resumes, via `load_if_exists=True`) the study named
    `cfg.optuna_study_name` in `cfg.storage`, routes logging to
    `cfg.output_log_file` (if set) and to the console, and minimizes the
    beam-search cpWER-difference objective for `cfg.optuna_n_trials` trials.

    Parameters:
        cfg (dict): A dictionary containing the configuration parameters.
    """
    def _objective(trial):
        # Each trial re-runs chunked beam search with trial-sampled params.
        return beamsearch_objective(
            trial=trial,
            cfg=cfg,
            speaker_beam_search_decoder=speaker_beam_search_decoder,
            loaded_kenlm_model=loaded_kenlm_model,
            org_trans_info_dict=org_trans_info_dict,
            source_info_dict=source_info_dict,
            reference_info_dict=reference_info_dict,
        )

    study = optuna.create_study(
        direction="minimize",
        study_name=cfg.optuna_study_name,
        storage=cfg.storage,
        load_if_exists=True
    )
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)  # Setup the root logger.
    if cfg.output_log_file is not None:
        root_logger.addHandler(logging.FileHandler(cfg.output_log_file, mode="a"))
    root_logger.addHandler(logging.StreamHandler())
    optuna.logging.enable_propagation()  # Propagate logs to the root logger.
    study.optimize(_objective, n_trials=cfg.optuna_n_trials)
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  datasets==2.14.5
 
2
  gradio==4.19.2
3
  huggingface-hub==0.19.3
4
  numpy==1.24.2
 
1
  datasets==2.14.5
2
+ meeteval
3
  gradio==4.19.2
4
  huggingface-hub==0.19.3
5
  numpy==1.24.2
scorer.py CHANGED
@@ -1,12 +1,35 @@
1
  import json
2
- import re
3
- import string
4
- import warnings
5
- import pandas as pd
6
- import numpy as np
7
  import os
8
 
9
- def instruction_scorer(data, judgment_file, model_name):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  df = data
12
  img_dict = {}
 
1
  import json
2
+ import tempfile
3
+ import json
4
+ import subprocess
5
+ import logging
 
6
  import os
7
 
8
+
9
def instruction_scorer(file_path_input, ref_file_path, system_name):
    """
    Score a submitted hypothesis seglst file against the reference with
    meeteval cpWER.

    Args:
        file_path_input (str): Path to the hypothesis seglst JSON file.
        ref_file_path (str): Path to the reference seglst JSON file.
        system_name (str): Submitted system name (currently unused; kept for
            interface compatibility with the leaderboard app).

    Returns:
        dict: {"cpWER": float, "WER": float} — "WER" currently mirrors cpWER
        as a placeholder until a separate word-error metric is wired in.

    Raises:
        subprocess.CalledProcessError: If the meeteval-wer command fails.
    """
    cmd_hyp = [
        "meeteval-wer",
        "cpwer",
        "-h", file_path_input,
        "-r", ref_file_path,
    ]
    # check=True makes a meeteval failure raise here instead of surfacing
    # later as a missing-output-file error.
    subprocess.run(cmd_hyp, check=True)

    # Read the JSON file and print the cpWER.
    # NOTE(review): this assumes meeteval's result JSON lands as
    # "err_dev.hyp.seglst_cpwer.json" in the current working directory —
    # confirm against where the app invokes the scorer and the hyp file path.
    asrdiar_file_name = "err_dev"
    output_cpwer_hyp_json_file = os.path.join(f"{asrdiar_file_name}.hyp.seglst_cpwer.json")
    with open(output_cpwer_hyp_json_file, "r") as temp_file:
        data_h = json.load(temp_file)
    print("Hypothesis cpWER:", data_h["error_rate"])
    cpwer = data_h["error_rate"]
    logging.info(f"-> HYPOTHESIS cpWER={cpwer:.4f}")

    scores_dict = {"cpWER": cpwer, "WER": cpwer}
    return scores_dict
29
+
30
+
31
+
32
+ def __instruction_scorer(data, judgment_file, model_name):
33
 
34
  df = data
35
  img_dict = {}
seglst_files/err_dev.hyp.seglst.json ADDED
The diff for this file is too large to render. See raw diff
 
seglst_files/err_dev.ref.list ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev/session_e992c01d.seglst.json
2
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev/session_17dba297.seglst.json
3
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev/session_e6e6ca6b.seglst.json
4
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev/session_197ddec4.seglst.json
5
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev/session_ac417036.seglst.json
6
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev/session_0edd751f.seglst.json
7
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev/session_327770bf.seglst.json
8
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev/session_1b20cec4.seglst.json
9
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev/session_fa752d9e.seglst.json
10
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev/session_ed8a6f55.seglst.json
11
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev/session_75e7876e.seglst.json
12
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev/session_405fe47b.seglst.json
13
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev/session_7fe82ea3.seglst.json
seglst_files/err_dev.ref.seglst.json ADDED
The diff for this file is too large to render. See raw diff
 
seglst_files/err_dev.src.list ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev/session_e992c01d.seglst.json
2
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev/session_17dba297.seglst.json
3
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev/session_e6e6ca6b.seglst.json
4
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev/session_197ddec4.seglst.json
5
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev/session_ac417036.seglst.json
6
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev/session_0edd751f.seglst.json
7
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev/session_327770bf.seglst.json
8
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev/session_1b20cec4.seglst.json
9
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev/session_fa752d9e.seglst.json
10
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev/session_ed8a6f55.seglst.json
11
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev/session_75e7876e.seglst.json
12
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev/session_405fe47b.seglst.json
13
+ /home/taejinp/projects/update_llm_speaker_tagging/llm_speaker_tagging/SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev/session_7fe82ea3.seglst.json
seglst_files/err_dev.src.seglst.json ADDED
The diff for this file is too large to render. See raw diff