poyum committed on
Commit
f709e5e
·
1 Parent(s): 2dfb14b

gradio space init

Files changed (6)
  1. disrpt_eval_2025.py +517 -0
  2. disrpt_io.py +846 -0
  3. eval.py +760 -0
  4. pipeline.py +142 -0
  5. reading.py +512 -0
  6. utils.py +216 -0
disrpt_eval_2025.py ADDED
@@ -0,0 +1,517 @@
+ """
+ Script to evaluate segmentation f-score and the proportion of perfectly segmented discourse units from two files. Two input formats are permitted:
+
+ * One token per line, with ten columns, no sentence breaks (default *.tok format) - segmentation indicated in column 10
+ * The same, but with blank lines between sentences (*.conll format)
+
+ Token columns follow the CoNLL-U format, with token IDs in the first column and pipe-separated key=value pairs in the last column.
+ Document boundaries are indicated by a comment: # newdoc_id = ...
+ The evaluation uses micro-averaged F-scores per corpus (not a document macro-average).
+
+ Example:
+
+ # newdoc_id = GUM_bio_byron
+ 1 Education _ _ _ _ _ _ _ Seg=B-seg
+ 2 and _ _ _ _ _ _ _ _
+ 3 early _ _ _ _ _ _ _ _
+ 4 loves _ _ _ _ _ _ _ _
+ 5 Byron _ _ _ _ _ _ _ Seg=B-seg
+ 6 received _ _ _ _ _ _ _ _
+
+ Or:
+
+ # newdoc_id = GUM_bio_byron
+ # sent_id = GUM_bio_byron-1
+ # text = Education and early loves
+ 1 Education education NOUN NN Number=Sing 0 root _ Seg=B-seg
+ 2 and and CCONJ CC _ 4 cc _ _
+ 3 early early ADJ JJ Degree=Pos 4 amod _ _
+ 4 loves love NOUN NNS Number=Plur 1 conj _ _
+
+ # sent_id = GUM_bio_byron-2
+ # text = Byron received his early formal education at Aberdeen Grammar School, and in August 1799 entered the school of Dr. William Glennie, in Dulwich. [17]
+ 1 Byron Byron PROPN NNP Number=Sing 2 nsubj _ Seg=B-seg
+ 2 received receive VERB VBD Mood=Ind|Tense=Past|VerbForm=Fin 0 root _ _
+
+ For PDTB-style corpora, we calculate exact span-wise f-scores for the BIO encoding, without partial credit. In other words,
+ predicting an incorrect span with partial overlap costs the same as missing a gold span and predicting an incorrect span
+ somewhere else in the corpus. Note also that spans must begin with B-conn - predicted spans beginning with I-conn are ignored.
+ The file format for PDTB-style corpora is similar, but with different labels:
+
+ 1 Fidelity Fidelity PROPN NNP _ 6 nsubj _ _
+ 2 , , PUNCT , _ 6 punct _ _
+ 3 for for ADP IN _ 4 case _ Conn=B-conn
+ 4 example example NOUN NN _ 6 obl _ Conn=I-conn
+ 5 , , PUNCT , _ 6 punct _ _
+ 6 prepared prepare VERB VBN _ 0 root _ _
+ 7 ads ad NOUN NNS _ 6 obj _ _
+
+ Arguments:
+ * goldfile: shared task gold test data
+ * predfile: same format, with predicted segment positions in column 10 - note **the number of tokens must match**
+ * string_input: if specified, file names are replaced by strings holding the file contents
+ * no_boundary_edu: specify to eval only intra-sentence EDUs
+ """
+
+ """ TODO
+ - OK labels: passed as arguments, not hard-coded
+ - OK option to skip sentence beginnings: cf. the "BIO no B" script
+ - OK print the results more cleanly: without the odd "o"
+ - OK make 2 classes, edu and connectives (conn: future experiment for evaluating extended connectives vs. connective heads)
+ - cleaner solution for the label column?
+ - make an Eval class and turn the two evaluators into subclasses of Eval
+ """
+
+ __author__ = "Amir Zeldes, Janet Liu, Laura Rivière"
+ __license__ = "Apache 2.0"
+ __version__ = "2.0.0"
+
+ import io, os, sys, argparse
+ import json
+ from sklearn.metrics import accuracy_score, classification_report
+
+ # MWE and ellipsis: no label, or "_"
+ # TODO:
+ # print scores *100: 0.6825 => 68.25
+ # documentation (automatic generation?)
+ # unit tests
+
+ class Evaluation:
+     """
+     Generic class for evaluation between 2 files.
+     :load data, basic checks, basic metrics, print results.
+     """
+     def __init__(self, name: str) -> None:
+         self.output = dict()
+         self.name = name
+         self.report = ""
+         self.fill_output('doc_name', self.name)
+
+     def get_data(self, infile: str, str_i=False) -> str:
+         """
+         Store data from a file or a string.
+         """
+         if str_i == False:
+             data = io.open(infile, encoding="utf-8").read().strip().replace("\r", "")
+         else:
+             data = infile.strip()
+         return data
+
+     def fill_output(self, key: str, value) -> None:
+         """
+         Fill the results dict that will be printed.
+         """
+         self.output[key] = value
+
+     def check_tokens_number(self, g: list, p: list) -> None:
+         """
+         Check that both compared files have the same number of tokens/labels.
+         """
+         if len(g) != len(p):
+             self.report += "\nFATAL: different number of tokens detected in gold and pred:\n"
+             self.report += ">>> In " + self.name + ": " + str(len(g)) + " gold tokens but " + str(len(p)) + " predicted tokens\n\n"
+             sys.stderr.write(self.report)
+             sys.exit(1)
+
+     def check_identical_tokens(self, g: list, p: list) -> None:
+         """
+         Check that tokens/features are identical.
+         """
+         for i, tok in enumerate(g):
+             if tok != p[i]:
+                 self.report += "\nWARN: token strings do not match in gold and pred:\n"
+                 self.report += ">>> First instance in " + self.name + " token " + str(i) + "\n"
+                 self.report += "Gold: " + tok + " but Pred: " + p[i] + "\n\n"
+                 sys.stderr.write(self.report)
+                 break
+
+     def compute_PRF_metrics(self, tp: int, fp: int, fn: int) -> None:
+         """
+         Compute Precision, Recall and F-score from True Positive, False Positive and False Negative counts.
+         Save the results in the output dict.
+         """
+         try:
+             precision = tp / (float(tp) + fp)
+         except ZeroDivisionError:
+             precision = 0
+
+         try:
+             recall = tp / (float(tp) + fn)
+         except ZeroDivisionError:
+             recall = 0
+
+         try:
+             f_score = 2 * (precision * recall) / (precision + recall)
+         except ZeroDivisionError:
+             f_score = 0
+
+         self.fill_output("gold_count", tp + fn)
+         self.fill_output("pred_count", tp + fp)
+         self.fill_output("precision", precision)
+         self.fill_output("recall", recall)
+         self.fill_output("f_score", f_score)
+
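+     # Worked example (illustrative numbers, not from any corpus): with tp=8, fp=2, fn=4,
+     # precision = 8/10 = 0.8, recall = 8/12 ≈ 0.667, and
+     # f_score = 2 * (0.8 * 0.667) / (0.8 + 0.667) ≈ 0.727.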
+     def compute_accuracy(self, g: list, p: list, k: str) -> None:
+         """
+         Compute the accuracy of a list of predicted items against a gold list of items.
+         :g: gold list
+         :p: predicted list
+         :k: name detail of the accuracy
+         """
+         self.fill_output(f"{k}_accuracy", accuracy_score(g, p))
+         self.fill_output(f"{k}_gold_count", len(g))
+         self.fill_output(f"{k}_pred_count", len(p))
+
+     def classif_report(self, g: list, p: list, key: str) -> None:
+         """
+         Compute Precision, Recall and F-score for each label occurring in the gold list.
+         """
+         stats_dict = classification_report(g, p, labels=sorted(set(g)), zero_division=0.0, output_dict=True)
+         self.fill_output(f'{key}_classification_report', stats_dict)
+
+     def print_results(self) -> None:
+         """
+         Print the dict of saved results.
+         """
+         # for k in self.output.keys():
+         #     print(f">> {k} : {self.output[k]}")
+
+         print(json.dumps(self.output, indent=4))
+
+
+ class RelationsEvaluation(Evaluation):
+     """
+     Specific evaluation class for relation classification.
+     The evaluation uses the simple accuracy score per corpus.
+     :rels disrpt-style data.
+     :default eval on the last column, "label"
+     :option eval on the relation type (pdtb: implicit, explicit...), column "rel_type"
+     """
+
+     HEADER = "doc\tunit1_toks\tunit2_toks\tunit1_txt\tunit2_txt\tu1_raw\tu2_raw\ts1_toks\ts2_toks\tunit1_sent\tunit2_sent\tdir\trel_type\torig_label\tlabel"
+     # HEADER_23 = "doc\tunit1_toks\tunit2_toks\tunit1_txt\tunit2_txt\ts1_toks\ts2_toks\tunit1_sent\tunit2_sent\tdir\torig_label\tlabel"
+
+     LABEL_ID = -1
+     TYPE_ID = -3
+     DISRPT_TYPES = ['Implicit', 'Explicit', 'AltLex', 'AltLexC', 'Hypophora']
+
+     def __init__(self, name: str, gold_path: str, pred_path: str, str_i=False, rel_type=False) -> None:
+         super().__init__(name)
+         """
+         :param gold_file: Gold shared task file
+         :param pred_file: File with predictions
+         :param string_input: If True, file names are replaced by strings with the file contents (for import inside other scripts)
+         :param rel_type: If True, scores are computed on the types column, not the label (relevant for PDTB)
+         """
+         self.mode = "rel"
+         self.g_path = gold_path
+         self.p_path = pred_path
+         self.opt_str_i = str_i
+         self.opt_rel_t = rel_type
+         self.key = "labels"
+
+         self.fill_output("options", {"s": self.opt_str_i, "rt": self.opt_rel_t})
+
+     def compute_scores(self) -> None:
+         """
+         Get the lists of data to compare, then compute the metrics.
+         """
+         gold_units, gold_labels, gold_types = self.parse_rels_data(self.g_path, self.opt_str_i, self.opt_rel_t)
+         pred_units, pred_labels, pred_types = self.parse_rels_data(self.p_path, self.opt_str_i, self.opt_rel_t)
+         self.check_tokens_number(gold_labels, pred_labels)
+         self.check_identical_tokens(gold_units, pred_units)
+
+         self.compute_accuracy(gold_labels, pred_labels, self.key)
+         self.classif_report(gold_labels, pred_labels, self.key)
+
+         if self.opt_rel_t:
+             self.get_types_scores(gold_labels, pred_labels, gold_types)
+
+     def get_types_scores(self, g: list, p: list, tg: list) -> None:
+         """
+         Score predictions against gold labels, broken down by relation type.
+         """
+         for t in self.DISRPT_TYPES:
+             gold_t = []
+             pred_t = []
+             for i, _ in enumerate(g):
+                 if tg[i] == t.lower():
+                     gold_t.append(g[i])
+                     pred_t.append(p[i])
+
+             self.compute_accuracy(gold_t, pred_t, f"types_{t}")
+
+     def parse_rels_data(self, path: str, str_i: bool, rel_t: bool) -> tuple[list[str], list[str], list[str]]:
+         """
+         Rels format from DISRPT = a header, then one relation classification instance per line.
+         :LREC_2024_header = 15 columns.
+         """
+         data = self.get_data(path, str_i)
+         header = data.split("\n")[0]
+         assert header == self.HEADER, "Unrecognized .rels header."
+         # column_ID = self.TYPE_ID if rel_t == True else self.LABEL_ID
+
+         rels = data.split("\n")[1:]
+         labels = [line.split("\t")[self.LABEL_ID] for line in rels]  # .lower()
+         units = [" ".join(line.split("\t")[:3]) for line in rels]
+         types = [line.split("\t")[self.TYPE_ID] for line in rels] if rel_t == True else []
+
+         return units, labels, types
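+     # Illustrative sketch (hypothetical row, middle columns elided): a .rels line ending in
+     # "...\t1>2\texplicit\tcontingency.cause\tcause" yields
+     # label = line.split("\t")[-1] -> "cause" and type = line.split("\t")[-3] -> "explicit",
+     # while `units` keeps the first three columns (doc and the two unit token spans).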
+
+
+ class ConnectivesEvaluation(Evaluation):
+     """
+     Specific evaluation class for PDTB connective detection.
+     :parse conllu-style data
+     :eval on strict connective spans
+     """
+     LAB_CONN_B = "Conn=B-conn"  # "Seg=B-Conn"
+     LAB_CONN_I = "Conn=I-conn"  # "Seg=I-Conn"
+     LAB_CONN_O = "Conn=O"  # "_"
+
+     def __init__(self, name: str, gold_path: str, pred_path: str, str_i=False) -> None:
+         super().__init__(name)
+         """
+         :param gold_file: Gold shared task file
+         :param pred_file: File with predictions
+         :param string_input: If True, file names are replaced by strings with the file contents (for import inside other scripts)
+         """
+         self.mode = "conn"
+         self.seg_type = "connective spans"
+         self.g_path = gold_path
+         self.p_path = pred_path
+         self.opt_str_i = str_i
+
+         self.fill_output('seg_type', self.seg_type)
+         self.fill_output("options", {"s": self.opt_str_i})
+
+     def compute_scores(self) -> None:
+         """
+         Get the lists of data to compare, then compute the metrics.
+         """
+         gold_tokens, gold_labels, gold_spans = self.parse_conn_data(self.g_path, self.opt_str_i)
+         pred_tokens, pred_labels, pred_spans = self.parse_conn_data(self.p_path, self.opt_str_i)
+
+         self.output['tok_count'] = len(gold_tokens)
+
+         self.check_tokens_number(gold_tokens, pred_tokens)
+         self.check_identical_tokens(gold_tokens, pred_tokens)
+         tp, fp, fn = self.compare_spans(gold_spans, pred_spans)
+         self.compute_PRF_metrics(tp, fp, fn)
+
+     def compare_spans(self, gold_spans: tuple, pred_spans: tuple) -> tuple[int, int, int]:
+         """
+         Compare exact spans.
+         """
+         true_positive = 0
+         false_positive = 0
+         false_negative = 0
+
+         for span in gold_spans:  # not verified
+             if span in pred_spans:
+                 true_positive += 1
+             else:
+                 false_negative += 1
+         for span in pred_spans:
+             if span not in gold_spans:
+                 false_positive += 1
+
+         return true_positive, false_positive, false_negative
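+     # Example of exact span matching (illustrative): with gold_spans = [(3, 4), (10, 10)]
+     # and pred_spans = [(3, 4), (10, 11)], (3, 4) counts as a true positive, while the
+     # overlapping-but-unequal (10, 10)/(10, 11) pair yields one false negative plus one
+     # false positive - no partial credit, as stated in the module docstring.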
+
+     def parse_conn_data(self, path: str, str_i: bool) -> tuple[list, list, list]:
+         """
+         LABEL = in the last column
+         """
+         data = self.get_data(path, str_i)
+         tokens = []
+         labels = []
+         spans = []
+         counter = 0
+         span_start = -1
+         span_end = -1
+         for line in data.split("\n"):  # this loop is the same as in version 1
+             if line.startswith("#") or line == "":
+                 continue
+             else:
+                 fields = line.split("\t")  # Token
+                 label = fields[-1]
+                 if "-" in fields[0] or "." in fields[0]:  # multi-word expression or ellipsis: no prediction should be there
+                     continue
+                 elif self.LAB_CONN_B in label:
+                     if span_start > -1:  # add the previous span
+                         if span_end == -1:
+                             span_end = span_start
+                         spans.append((span_start, span_end))
+                         span_end = -1
+                     label = self.LAB_CONN_B
+                     span_start = counter
+                 elif self.LAB_CONN_I in label:
+                     label = self.LAB_CONN_I
+                     span_end = counter
+                 else:
+                     label = "_"
+                     if span_start > -1:  # add the previous span
+                         if span_end == -1:
+                             span_end = span_start
+                         spans.append((span_start, span_end))
+                         span_start = -1
+                         span_end = -1
+
+                 tokens.append(fields[1])
+                 labels.append(label)
+                 counter += 1
+
+         if span_start > -1 and span_end > -1:  # add the last span
+             spans.append((span_start, span_end))
+
+         if not self.LAB_CONN_B in labels:
+             print(f"Unrecognized labels. Expecting: {self.LAB_CONN_B}, {self.LAB_CONN_I}, {self.LAB_CONN_O}...")
+             print("(or the predictions may simply contain no B-conn label at all)")
+
+         return tokens, labels, spans
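+     # For instance, on the "Fidelity, for example, ..." snippet from the module docstring,
+     # tokens 3 ("for", Conn=B-conn) and 4 ("example", Conn=I-conn) produce the single
+     # 0-indexed span (2, 3) in `spans`, since `counter` counts non-MWE tokens from 0.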
+
+
+ class SegmentationEvaluation(Evaluation):
+     """
+     Specific evaluation class for EDU segmentation.
+     :parse conllu-style data
+     :eval on first-token identification
+     """
+     LAB_SEG_B = "Seg=B-seg"  # "BeginSeg=Yes"
+     LAB_SEG_I = "Seg=O"  # "_"
+
+     def __init__(self, name: str, gold_path: str, pred_path: str, str_i=False, no_b=False) -> None:
+         super().__init__(name)
+         """
+         :param gold_file: Gold shared task file
+         :param pred_file: File with predictions
+         :param string_input: If True, file names are replaced by strings with the file contents (for import inside other scripts)
+         """
+         self.mode = "edu"
+         self.seg_type = "EDUs"
+         self.g_path = gold_path
+         self.p_path = pred_path
+         self.opt_str_i = str_i
+         self.no_b = True if "conllu" in gold_path.split(os.sep)[-1] and no_b == True else False  # relevant only for conllu
+
+         self.fill_output('seg_type', self.seg_type)
+         self.fill_output("options", {"s": self.opt_str_i})
+
+     def compute_scores(self) -> None:
+         """
+         Get the lists of data to compare, then compute the metrics.
+         """
+         gold_tokens, gold_labels, gold_spans = self.parse_edu_data(self.g_path, self.opt_str_i, self.no_b)
+         pred_tokens, pred_labels, pred_spans = self.parse_edu_data(self.p_path, self.opt_str_i, self.no_b)
+
+         self.output['tok_count'] = len(gold_tokens)
+
+         self.check_tokens_number(gold_tokens, pred_tokens)
+         self.check_identical_tokens(gold_tokens, pred_tokens)
+         tp, fp, fn = self.compare_labels(gold_labels, pred_labels)
+         self.compute_PRF_metrics(tp, fp, fn)
+
+     def compare_labels(self, gold_labels: list, pred_labels: list) -> tuple[int, int, int]:
+         """
+         Count token-level true positives, false positives and false negatives over the two label lists.
+         """
+         true_positive = 0
+         false_positive = 0
+         false_negative = 0
+
+         for i, gold_label in enumerate(gold_labels):  # not verified
+             pred_label = pred_labels[i]
+             if gold_label == pred_label:
+                 if gold_label == "_":
+                     continue
+                 else:
+                     true_positive += 1
+             else:
+                 if pred_label == "_":
+                     false_negative += 1
+                 else:
+                     if gold_label == "_":
+                         false_positive += 1
+                     else:  # I-Conn/B-Conn mismatch?
+                         false_positive += 1
+
+         return true_positive, false_positive, false_negative
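+     # Toy example (illustrative): gold = ["Seg=B-seg", "_", "_"] vs. pred = ["_", "Seg=B-seg", "_"]
+     # gives tp=0, fn=1 (missed boundary at index 0) and fp=1 (spurious boundary at index 1),
+     # so precision = recall = f_score = 0 for this sequence.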
+
+     def parse_edu_data(self, path: str, str_i: bool, no_b: bool) -> tuple[list, list, list]:
+         """
+         LABEL = in the last column
+         """
+         data = self.get_data(path, str_i)
+         tokens = []
+         labels = []
+         spans = []
+         counter = 0
+         span_start = -1
+         span_end = -1
+         for line in data.split("\n"):  # this loop is the same as in version 1
+             if line.startswith("#") or line == "":
+                 continue
+             else:
+                 fields = line.split("\t")  # Token
+                 label = fields[-1]
+                 if "-" in fields[0] or "." in fields[0]:  # multi-word expression or ellipsis: no prediction should be there
+                     continue
+                 elif no_b == True and fields[0] == "1":
+                     label = "_"
+                 elif self.LAB_SEG_B in label:
+                     label = self.LAB_SEG_B
+                 else:
+                     label = "_"  # 🚩
+                     if span_start > -1:  # add the previous span
+                         if span_end == -1:
+                             span_end = span_start
+                         spans.append((span_start, span_end))
+                         span_start = -1
+                         span_end = -1
+
+                 tokens.append(fields[1])
+                 labels.append(label)
+                 counter += 1
+
+         if span_start > -1 and span_end > -1:  # add the last span
+             spans.append((span_start, span_end))
+
+         if not self.LAB_SEG_B in labels:
+             exit(f"Unrecognized labels. Expecting: {self.LAB_SEG_B}, {self.LAB_SEG_I}...")
+
+         return tokens, labels, spans
+
+
+ if __name__ == "__main__":
+
+     p = argparse.ArgumentParser()
+     p.add_argument("-g", "--goldfile", required=True, help="Shared task gold file in .tok, .conll or .rels format.")
+     p.add_argument("-p", "--predfile", required=True, help="Corresponding file with system predictions.")
+     p.add_argument("-t", "--task", required=True, choices=['S', 'C', 'R'], help="Choose one of the three options: S (EDU Segmentation), C (Connective Detection), R (Relation Classification)")
+     p.add_argument("-s", "--string_input", action="store_true", help="Whether inputs are file names or strings.")
+     p.add_argument("-nb", "--no_boundary_edu", default=False, action='store_true', help="Do not count EDUs that start at the beginning of a sentence.")
+     p.add_argument("-rt", "--rel_type", default=False, action='store_true', help="Evaluate relation types instead of labels.")
+
+     # help(Evaluation)
+     # help(SegmentationEvaluation)
+     # help(ConnectivesEvaluation)
+     # help(RelationsEvaluation)
+
+     opts = p.parse_args()
+
+     name = opts.goldfile.split(os.sep)[-1] if os.path.isfile(opts.goldfile) else f"string_input: {opts.goldfile[0:20]}..."
+
+     if opts.task == "R":
+         my_eval = RelationsEvaluation(name, opts.goldfile, opts.predfile, opts.string_input, opts.rel_type)
+     elif opts.task == "C":
+         my_eval = ConnectivesEvaluation(name, opts.goldfile, opts.predfile, opts.string_input)
+     elif opts.task == "S":
+         my_eval = SegmentationEvaluation(name, opts.goldfile, opts.predfile, opts.string_input, opts.no_boundary_edu)
+
+     my_eval.compute_scores()
+     my_eval.print_results()
disrpt_io.py ADDED
@@ -0,0 +1,846 @@
+ """
+ Classes to read/write disrpt-like files
+ + analysis of sentence splitting: "gold" sentences vs. stanza/spacy sentences
+ - ersatz
+
+ Disrpt is a discourse analysis campaign with (as of 2023):
+ - discourse segmentation information, in a conll-like format
+ - discourse connective information (also conll-like)
+ - discourse relation pairs, in a specific format
+
+ data are separated by corpus and language, with conventional names
+ of the form language.framework.corpusname
+ e.g. fra.sdrt.annodis
+
+ TODO:
+ - refactor how sentences are stored with a dictionary: "conllu" / "tok" / "split"
+   [ok] dictionary
+   ? refactor creation of corpus/documents to allow for updates (or load tok+conllu at once)
+ - [ok] the Italian luna corpus has different meta tags, with an extra level: newdoc_id/newturn_id/newutterance_id
+ - [ok] check behaviour on languages without pretrained models / what candidates?
+   - nl, pt, it -> en?
+   - thai -> multilingual
+ - test different candidate sets for splitting locations:
+   - [done] all -> too underspecified and too slow
+   - [ok] en on all but zho+thai
+   - [done] en instead of multilingual?
+     bad scores on zho
+ - [ok] fix bad characters: BOM, replacement char etc
+   special char for the apostrophe, cf.
+   data_clean/eng.dep.scidtb/eng.dep.scidtb_train.tok / newdoc_id = P16-1030 char problem for the possessive
+   ��antagonist��
+
+   Basque problem: "Osasun-zientzietako Ikertzaileen II ." token count ...
+   Iru�eko etc
+ - Turkish problem: tur.pdtb.tdb/tur.pdtb.tdb_train: BOM ? '\ufeff' -> 'Makale'
+   plus an extra blank token in train (785)?
+   774 olduğunu _ _ _ _ _ _ _ _
+   775 söylüyor _ _ _ _ _ _ _ _
+   776 : _ _ _ _ _ _ _ _
+   777 Türkiye _ _ _ _ _ _ _ _
+   778 demokrasi _ _ _ _ _ _ _ _
+   779 istiyor _ _ _ _ _ _ _ _
+   780 ÖDPGenel _ _ _ _ _ _ _ _
+   781 Başkanı _ _ _ _ _ _ _ _
+   782 Ufuk _ _ _ _ _ _ _ _
+   783 Uras'tan _ _ _ _ _ _ _ _
+   784 : _ _ _ _ _ _ _ _
+   785 _ _ _ _ _ _ _ _
+   786 Türkiye _ _ _ _ _ _ _ _
+   787 , _ _ _ _ _ _ _ _
+   788 AİHM'de _ _
+ - Chinese problems:
+   zh: ?是 is this "?" listed in ersatz?
+   ??hosto2
+   sctb 3.巴斯克
+ %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+
+ - specific preprocessing:
+   annodis/gum: titles
+   gum/rrt: biblio / articles
+   scidtb ?
+ - different sentence splitters
+   - [ok] ersatz
+   - trankit
+   - [abandoned] stanza: FIXME: lots of errors made by stanza, e.g. splits within words (might be due to bad input tokenization)
+ - [done] write docs in disrpt format (after transformation, for instance)
+ - [done] eval of sentence beginnings (precision)
+ - [done] (done in the split_sentence script) eval / nb of conll sentences ~= sentence recall
+ - eval sentence length (max)
+ - [moot] clean the main script: arguments/argparse -> separate script
+ - [done] method for sentence splitting (for tok)
+ - [done] iterate over all docs in a corpus
+ - [done] choose the language automatically according to the corpus name
+ - ? method for sentence re-splitting for conllu? needs a way of indexing tokens for later re-eval? or the eval script does not care?
+
+
+ candidate sets for splitting:
+
+ - multilingual (default) is as described in the ersatz paper == [EOS punctuation][!number]
+ - en requires a space following punctuation
+ - all: a space between any two characters
+ - custom ones can be written using the determiner.Split() base class
+
+
+
+ """
+ import sys, os
+ import dataclasses
+ from itertools import chain
+ from collections import Counter
+ from copy import copy, deepcopy
+ from tqdm import tqdm
+ # import progressbar
+ # from ersatz import split, utils
+ # import trankit
+ # import stanza
+ # from stanza.pipeline.core import DownloadMethod
+
+ from transformers import pipeline
+
+ from wtpsplit import SaT
+
+
+ # needed to track the mistakes made in the preprocessing of the disrpt dataset, whose origin is unknown
+ BOM = '\ufeff'
+ REPL_CHAR = "\ufffd"  # �
+
+ test_doc_seg = """# newdoc id = geop_3_space
+ 1 La le DET _ Definite=Def|Gender=Fem|Number=Sing|PronType=Art 2 det _ BeginSeg=Yes
+ 2 Space space PROPN _ _ 0 root _ _
+ 3 Launcher Launcher PROPN _ _ 2 flat:name _ _
+ 4 Initiative initiative PROPN _ _ 2 flat:name _ _
+ 5 . . PUNCT _ _ 2 punct _ _
+
+ 1 Le le DET _ Definite=Def|Gender=Masc|Number=Sing|PronType=Art 2 det _ BeginSeg=Yes
+ 2 programme programme NOUN _ Gender=Masc|Number=Sing 10 nsubj _ _
+ 3 de de ADP _ _ 4 case _ _
+ 4 Space space PROPN _ _ 2 nmod _ _
+ 5 Launcher Launcher PROPN _ _ 4 flat:name _ _
+ 6 Initiative initiative PROPN _ _ 4 flat:name _ _
+ 7 ( ( PUNCT _ _ 8 punct _ BeginSeg=Yes
+ 8 SLI SLI PROPN _ _ 4 appos _ _
+ 9 ) ) PUNCT _ _ 8 punct _ _
+ 10 vise viser VERB _ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root _ BeginSeg=Yes
+ 11 à à ADP _ _ 12 mark _ _
+ 12 développer développer VERB _ VerbForm=Inf 10 ccomp _ _
+ 13 un un DET _ Definite=Ind|Gender=Masc|Number=Sing|PronType=Art 14 det _ _
+ 14 système système NOUN _ Gender=Masc|Number=Sing 12 obj _ _
+ 15 de de ADP _ _ 16 case _ _
+ 16 lanceur lanceur NOUN _ Gender=Masc|Number=Sing 14 nmod _ _
+ 17 réutilisable réutilisable ADJ _ Gender=Masc|Number=Sing 16 amod _ _
+ 18 entièrement entièrement ADV _ _ 19 advmod _ _
+ 19 inédit inédit ADJ _ Gender=Masc|Number=Sing 14 amod _ _
+ 20 . . PUNCT _ _ 10 punct _ _
+
+ # newdoc id = ling_fuchs_section2
+ 1 Théorie théorie PROPN _ _ 0 root _ BeginSeg=Yes
+ 2 psychomécanique psychomécanique ADJ _ Gender=Masc|Number=Sing 1 amod _ _
+ 3 et et CCONJ _ _ 4 cc _ _
+ 4 cognition cognition NOUN _ Gender=Fem|Number=Sing 1 conj _ _
+ 5 . . PUNCT _ _ 1 punct _ _
+ """
+
+ # Token is just a simple record type
+ Token = dataclasses.make_dataclass("Token", "id form lemma pos xpos morph head_id dep_type extra label".split(),
+                                    namespace={'__repr__': lambda self: self.form,
+                                               'format': lambda self: ("\t".join(map(str, dataclasses.astuple(self)))),
+                                               # ignored for now because we just get rid of MWEs when reading a disrpt file,
+                                               # but this could change in the future
+                                               # 'is_MWE': lambda self: type(self.id) is str and "-" in self.id,
+                                               }
+                                    )
+
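+ # Usage sketch (the field values are made up for illustration):
+ # tok = Token(1, "Education", "education", "NOUN", "NN", "Number=Sing", 0, "root", "_", "Seg=B-seg")
+ # repr(tok) -> "Education"; tok.format() -> the tab-separated ten-column conllu-style line.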
+
+ class Sentence:
+
+     def __init__(self, token_list, meta):
+         self.toks = token_list
+         self.meta = meta
+         # Added by Firmin or Chloé?
+         self.label_start = ["Seg=B-conn", "Seg=B-seg"]
+         self.label_end = ["Seg=I-conn", "Seg=O"]
+
+     def __iter__(self):
+         return iter(self.toks)
+
+     def __len__(self):
+         return len(self.toks)
+
+     def display(self, segment=False):
+         """if the segment option is set to True, print the sentence with EDU boundaries marked"""
+         if segment:
+             output = [f"{'|' if token.label=='Seg=B-seg' else ''}{token.form}" for token in self]
+             # output = [f"{'|' if token.label=='BeginSeg=Yes' else ''}{token.form}" for token in self]
+             return " ".join(output) + "|"
+         else:
+             return self.meta["text"]
+
+     def __contains__(self, word):
+         for token in self.toks:
+             if token.form == word:
+                 return True
+         return False
+
+     def __repr__(self):
+         return self.display()
+
+     def format(self):
+         meta = f"# sent_id = {self.meta['sent_id']}\n" + f"# text = {self.meta['text']}\n"
+         output = "\n".join([t.format() for t in self.toks])
+         return meta + output
+
+ # not necessary because of trankit auto-mode, but probably safer at some point
+ # why don't they use normalized language codes!?
+ TRANKIT_LANG_MAP = {
+     "de": "german",
+     "en": "english",
+     # to be tested
+     "gum": "english-gum",
+     "fr": "french",
+     "it": "italian",
+     "sp": "spanish",
+     "es": "spanish",
+     "eu": "basque",
+     "zh": "chinese",
+     "ru": "russian",
+     "tr": "turkish",
+     "pt": "portuguese",
+     "fa": "persian",
+     "nl": "dutch",
+ }
+
+ lg_map = {"sp": "es",
+           "po": "pt",
+           "tu": "tr"}
+
+
+ def get_language(lang, model):
+     # NB: ersatz_languages is expected to come from the (currently commented-out) ersatz import
+     lang = lang[:2]
+     if lang in lg_map:
+         lang = lg_map[lang]
+     if model == "ersatz":
+         if lang not in ersatz_languages:
+             lang = "default-multilingual"
+     if model == "trankit":
+         lang = TRANKIT_LANG_MAP.get(lang, "auto")
+     return lang
+
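+ # For instance, following the language.framework.corpus naming convention:
+ # get_language("spa", "trankit") -> "spanish" and get_language("tur", "trankit") -> "turkish"
+ # (both pass through lg_map first: "sp" -> "es", "tu" -> "tr").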
+ # This is taken from ersatz https://github.com/rewicks/ersatz/blob/master/ersatz/candidates.py
+ # sentence-ending punctuation
+ # U+0964 । Po DEVANAGARI DANDA
+ # U+061F ؟ Po ARABIC QUESTION MARK
+ # U+002E . Po FULL STOP
+ # U+3002 。 Po IDEOGRAPHIC FULL STOP
+ # U+0021 ! Po EXCLAMATION MARK
+ # U+06D4 ۔ Po ARABIC FULL STOP
+ # U+17D4 ។ Po KHMER SIGN KHAN
+ # U+003F ? Po QUESTION MARK
+ # U+2026 ... Po Ellipsis
+ # U+30FB
+ # U+002A *
+
+ # other acceptable punctuation
+ # U+3011 】 Pe RIGHT BLACK LENTICULAR BRACKET
+ # U+00BB » Pf RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK
+ # U+201D " Pf RIGHT DOUBLE QUOTATION MARK
+ # U+300F 』 Pe RIGHT WHITE CORNER BRACKET
+ # U+2018 ‘ Pi LEFT SINGLE QUOTATION MARK
+ # U+0022 " Po QUOTATION MARK
+ # U+300D 」 Pe RIGHT CORNER BRACKET
+ # U+201C " Pi LEFT DOUBLE QUOTATION MARK
+ # U+0027 ' Po APOSTROPHE
+ # U+2019 ’ Pf RIGHT SINGLE QUOTATION MARK
+ # U+0029 ) Pe RIGHT PARENTHESIS
+
+ ending_punc = {
+     '\u0964',
+     '\u061F',
+     '\u002E',
+     '\u3002',
+     '\u0021',
+     '\u06D4',
+     '\u17D4',
+     '\u003F',
+     '\uFF61',
+     '\uFF0E',
+     '\u2026',
+ }
+
+ closing_punc = {
+     '\u3011',
+     '\u00BB',
+     '\u201D',
+     '\u300F',
+     '\u2018',
+     '\u0022',
+     '\u300D',
+     '\u201C',
+     '\u0027',
+     '\u2019',
+     '\u0029'
+ }
+
+ list_set = {
+     '\u30fb',
+     '\uFF65',
+     '\u002a',  # asterisk
+     '\u002d',
+     '\u4e00'
+ }
+
+
+ class Document:
+     _hard_punct = {"default": {".", ";", "?", "!"} | ending_punc,
+                    "zh": {"。", "?"}
+                    }
+
+     def __init__(self, sentence_list, meta, src="conllu"):
+         self.sentences = {src: sentence_list}
+         self.meta = meta
+
+     def __repr__(self):
+         # ADDED (chloe): the if/else on file type
+         if "tok" in self.sentences:
+             return "\n".join(map(repr, self.sentences.get("conllu", self.sentences["tok"])))
+         elif "conllu" in self.sentences:
+             return "\n".join(map(repr, self.sentences["conllu"]))
+         else:
+             sys.exit("Unknown type of file: " + str(self.sentences.keys()))
+
+     def get_sentences(self, src="tok"):
+         return self.sentences[src]
+
+     def baseline_split(self, lang="default"):
+         """default split for languages where we have issues re-aligning tokens for various reasons
+
+         this just splits at every token that is a hard punctuation mark
+
+         FIXME: this is not complete
+         """
+         sentence_id = 1
+         sentences = []
+         current = []
+         orig_doc = self.sentences["tok"][0]
+         for token in orig_doc:
+             current.append(token)
+             if token.lemma in self._hard_punct.get(lang, self._hard_punct["default"]):
+                 meta = {"doc_id": orig_doc.meta["doc_id"],
+                         "sent_id": sentence_id,
+                         "text": " ".join([x.form for x in current])
+                         }
+                 sentences.append(Sentence(current, meta))
+                 current = []
+                 sentence_id += 1
+         if current != []:
+             meta = {"doc_id": orig_doc.meta["doc_id"],
+                     "sent_id": sentence_id,
+                     "text": " ".join([x.form for x in current])
+                     }
+             sentences.append(Sentence(current, meta))
+         return sentences
+
+
+     def cutoff_split(self, cutoff=120, lang="default"):
+         """
+         default split for corpora with little or no punctuation (transcriptions etc.)
+
+         just start a new sentence as soon as the current one holds more than `cutoff` tokens
+         """
+         sentence_id = 1
+         sentences = []
+         current = []
+         current_cpt = 1
+         orig_doc = self.sentences["tok"][0]
+         meta = {"doc_id": orig_doc.meta["doc_id"],
+                 "sent_id": sentence_id,
+                 }
+         for token in orig_doc:
+             token.id = current_cpt
+             current_cpt += 1
+             current.append(token)
+             # print(token, token.id)
+             if len(current) >= cutoff:
+                 # print(orig_doc.meta["doc_id"], token, current)
+                 meta = {"doc_id": orig_doc.meta["doc_id"],
+                         "sent_id": sentence_id,
+                         "text": " ".join([x.form for x in current])
+                         }
+                 sentences.append(Sentence(current, meta))
+                 current = []
+                 sentence_id += 1
+                 current_cpt = 1
+         if current != []:
+             meta = {"doc_id": orig_doc.meta["doc_id"],
+                     "sent_id": sentence_id,
+                     "text": " ".join([x.form for x in current])
+                     }
+             sentences.append(Sentence(current, meta))
+         return sentences
+
+     def ersatz_split(self, doc, lang='default-multilingual', candidates="en"):
+         result = split(model=lang,
+                        text=doc, output=None,
+                        batch_size=16,
+                        candidates=candidates,  # 'multilingual'
+                        cpu=True, columns=None, delimiter='\t')
+         return result
+
+     def stanza_split(self, orig_doc, lang):
+         nlp = stanza.Pipeline(lang=lang, processors='tokenize', download_method=DownloadMethod.REUSE_RESOURCES)
+         doc = nlp(orig_doc)
+         sentences = []
+         for s in doc.sentences:
+             sentences.append(" ".join([t.text for t in s.tokens]))
+         return sentences
+         # for i, sentence in enumerate(doc.sentences): for token in sentence.tokens / token.text
+
+     def trankit_split(self, orig_doc, lang, pipeline):
+         trk_sentences = pipeline.ssplit(orig_doc)
+         sentences = []
+         for s in trk_sentences["sentences"]:
+             sentences.append(s["text"])
+         return sentences
+
+     def sat_split(self, orig_doc, sat_model):
+         sat_sentences = sat_model.split(str(orig_doc))
+         sentences = []
+         for s in sat_sentences:
+             sentences.append(s)
+         return sentences
+
+     # TODO: debug option to turn warnings on/off
+     def _remap_tokens(self, split_sentences):
+         """remap tokens from sentence splitting to the original token information"""
+         # return split_sentences
+         # if this fails, there's been a bug: the token count differs between the original text and the total
+         # over the split sentences
+         # TODO: this is bound to happen, but the output should keep the original token count; how?
+         # TODO: REALIGN by detecting split tokens
+         orig_token_nb = sum(map(len, self.sentences["tok"]))
+         split_token_nb = len(list(chain(*[x.split() for x in split_sentences])))
+         try:
+             assert orig_token_nb == split_token_nb
+         except AssertionError:
+             print("WARNING wrong nb of tokens", orig_token_nb, "initially but", split_token_nb, "after split", file=sys.stderr)
+             # raise NotImplementedError
+         new_sentences = []
+         position = 0
+         skip_first_token = False
+         # will only work when splitting tok files, not when resplitting conllu
+         orig_doc = self.sentences["tok"][0]
+         for i, s in enumerate(split_sentences):
+             new_toks = s.split()
+             if skip_first_token:  # see below
+                 new_toks = new_toks[1:]
+             toks = orig_doc.toks[position:position + len(new_toks)]
+             meta = {"doc_id": orig_doc.meta["doc_id"],
+                     "sent_id": i + 1,
+                     "text": " ".join([x.form for x in toks])
+                     }
+             new_tok_position = position
+             shift = 0  # advance through the new tokens in case of erroneous splits
+             # actual nb of tokens to advance in the original document;
+             # the new tokens might include a mistakenly split token (tricky)
+             new_toks_length = len(new_toks)
+             for j in range(len(toks)):
+                 toks[j].id = j + 1
+                 new_j = j + shift
+                 try:
+                     assert toks[j].form == new_toks[new_j]
+                     # a split token has been detected, meaning it had a punctuation sign in it and makes a "fake" sentence;
+                     # it will be recovered in the current sentence, so it should be skipped in the next one
+                     skip_first_token = False
+                 except AssertionError:
+                     # TODO: check that the next token can be recovered
+                     # problem with differing Chinese punctuation codes?
+                     # print(f"WARNING === Token mismatch: {j,toks[j].form,new_toks[new_j]} \n {toks} \n {new_toks}", file=sys.stderr)
+                     # first case: within the same sentence (unlikely if a token was split by a punctuation sign)
+                     if j != len(toks) - 1:
+                         if len(toks[j].form) != len(new_toks[new_j]):  # if same length this is probably just an encoding problem (Chinese cases), so just ignore it
+                             # print(f"INFO: split token still within the sentence {j,toks[j].form,new_toks[new_j]} ... should not happen", file=sys.stderr)
+                             if toks[j].form == new_toks[new_j] + new_toks[new_j + 1]:
+                                 # print(f"INFO: split token correctly identified as {j,toks[j].form,new_toks[new_j]+new_toks[new_j+1]} ... advancing to the next one", file=sys.stderr)
+                                 shift = shift + 1
+                     # second case: the sentence ends here and the next token is in the next split sentence, which necessarily exists (?)
+                     else:
+                         if i + 1 < len(split_sentences):
+                             next_sentence = split_sentences[i + 1]
+                             next_token = split_sentences[i + 1].split()[0]
+                             skip_first_token = True
+                             if toks[j].form == new_toks[new_j] + next_token:
+                                 pass
+                                 # print(f"INFO: token can be recovered: ", end="", file=sys.stderr)
+                             else:
+                                 pass
+                                 # print(f"INFO: token can still not be recovered: ", end="", file=sys.stderr)
+                                 # print(toks[j].form, new_toks[new_j] + next_token, file=sys.stderr)
+                         else:
+                             pass
+                             # print(f"WARNING === unmatched token at end of document", new_toks[new_j], file=sys.stderr)
+                             # in theory this should not happen
+                             # does the next starting position have to be put back? no
+                             # position = position - 1
+             if len(toks) > 0:  # joining the first token might have generated an empty sentence
+                 new_sentences.append(Sentence(toks, meta))
+                 position = position + len(new_toks) - shift
+             else:
+                 skip_first_token = False
+         split_token_nb = sum([len(s.toks) for s in new_sentences])
+         # print("split_token_nb", split_token_nb)
+         try:
+             assert orig_token_nb == split_token_nb
+         except AssertionError:
+             print("ERROR wrong nb of tokens", orig_token_nb, "originally but", split_token_nb, "after split+remap", file=sys.stderr)
+             sys.exit()
+         return new_sentences
+
+
+     def sentence_split(self, model="ersatz", lang="default-multilingual", **kwargs):
+         """
+         call the sentence splitter on the actual document, read as one sentence from a tok file.
+         kwargs might contain an open "pipeline" (e.g. a trankit pipeline) to pass downstream for splitting sentences, so that it is not re-created for each paragraph
+         """
+         # if we split, the doc has been read as only one sentence
+         # we ignore multi-word expressions at reading time, but if this needs to change, it will impact this line:
+         doc = [x.form for x in self.sentences["tok"][0]]  # if not(x.is_MWE())]
+         doc = " ".join(doc)
+         if model == "ersatz":
+             # empirically seems better: "en" for all alphabet-based languages
+             # (candidates = candidate locations for sentence splitting,
+             # not to be confused with the language of the model)
+             candidates = "en" if lang not in {"zh", "th"} else "multilingual"
+             new_sentences = self.ersatz_split(doc, lang=lang, candidates=candidates)
+         elif model == "stanza":
+             new_sentences = self.stanza_split(doc, lang=lang)
+         elif model == "trankit":  # the initialized pipeline is passed on here
+             new_sentences = self.trankit_split(doc, lang=lang, **kwargs)
+         elif model == "baseline":
+             new_sentences = self.baseline_split(lang=lang)
+             self.sentences["split"] = new_sentences
+         elif model == "sat":
+             sat_model = kwargs.get("sat_model")
+             if sat_model is None:
+                 raise ValueError("sat_model must be provided for SAT sentence splitting.")
+             new_sentences = self.sat_split(doc, sat_model)
+             self.sentences["split"] = new_sentences
+         elif model == "cutoff":  # FIXME there should be a way to pass on the cutoff
+             new_sentences = self.cutoff_split(lang=lang)
+             self.sentences["split"] = new_sentences
+         else:
+             raise NotImplementedError
+         if model != "baseline" and model != "cutoff":
+             self.sentences["split"] = self._remap_tokens(new_sentences)
+         return self.sentences["split"]
+
+     def search_word(self, word):
+         return [s for s in self.sentences.get("split", []) if word in s]
+
+     def format(self, mode="split"):
+         """format the document in disrpt format
+         mode = original (sentences) or split (split_sentences)
+         """
+         target = self.sentences[mode]
+
+         output = "\n".join([s.format() + "\n" for s in target])
+         meta = f"# doc_id = {self.meta}\n"
+         return meta + output  # +"\n"
+
+
+ class Corpus:
+     META_types = {"newdoc_id": "doc_id",
+                   "newdoc id": "doc_id",
+                   "doc_id": "doc_id",
+                   "sent_id": "sent_id",
+                   "newturn_id": "newturn_id",
+                   "newutterance": "newutterance",
+                   "newutterance_id": "newutterance_id",
+                   "text": "text",
+                   }
+
+     def __init__(self, data=None):
+         """input to the constructor is a string
+         """
+         if data:
+             self.docs = self._parse(data.split("\n"))
+
+     @staticmethod
+     def _meta_parse(data_line):
+         """parse comments, as they contain meta information"""
+         if not ("=" in data_line):  # not a meta line
+             return "", ""
+         info, value = data_line[1:].strip().split("=", 1)
+         info = info.strip()
+         if info in Corpus.META_types:
+             meta_type = Corpus.META_types[info]
+         else:  # TODO should send a warning
+             # print("WARNING: bad meta line", info, value, data_line, file=sys.stderr) -> this just floods the output
+             meta_type, value = "", ""
+         return meta_type, value.strip()
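+     # For instance, _meta_parse("# newdoc_id = GUM_bio_byron") returns ("doc_id", "GUM_bio_byron"),
+     # while a comment with no "=" (not a meta line) returns ("", "").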
+
+     def search_doc(self, docid):
+         return [x for x in self.docs if x.meta == docid]
+
+     def _parse(self, data_lines, src="tok"):
+         """parse disrpt segmentation/connective files"""
+         curr_token_list = []
+         sentences = []
+         docs = []
+         s_idx = 0
+         doc_idx = 0
+         meta = {}
+
+         for data_line in data_lines:
+             data_line = data_line.strip()
+             if data_line:
+                 # comments always include some meta info of the form "metatype = value", minimally the document id
+                 if data_line.startswith("#"):
+                     meta_type, value = Corpus._meta_parse(data_line)
+                     # start of a new doc: save the previous one if it exists
+                     if meta_type == "doc_id":
+                         # print(doc_idx)
+                         if doc_idx > 0:
+                             # print(src)
+                             docs.append(Document(sentences, meta["doc_id"], src=src))
+                             sentences = []
+                             curr_token_list = []
+                             s_idx = 0
+                             meta = {}
+                         doc_idx += 1
+                     if meta_type != "":
+                         meta[meta_type] = value
+                 else:
+                     token, label = self.parse_token(meta, data_line)
+                     # print(token, label)
+                     # if this is a MWE, just ignore it. MWEs have ids combining the original token ids, e.g. "30-31"
+                     # TODO: refactor into parse_token + a boolean flag if ok
+                     if not ("-" in token[0]) and not ("." in token[0]):
+                         curr_token_list.append(Token(*token, label))
+             else:  # end of sentence
+                 meta["text"] = " ".join((x.form for x in curr_token_list))
+                 s_idx += 1
+                 # some corpora don't have ids for sentences
+                 if "sent_id" not in meta:
+                     meta["sent_id"] = s_idx
+                 sentences.append(Sentence(curr_token_list, meta))
+                 curr_token_list = []
+                 meta = {"doc_id": meta["doc_id"]}
+         if len(curr_token_list) > 0 or len(sentences) > 0:  # final sentence of the final document
+             meta["text"] = " ".join((x.form for x in curr_token_list))
+             sentences.append(Sentence(curr_token_list, meta))
+             # print("="*50)
+             # print(meta.keys())
+             # print(len(curr_token_list), len(sentences))
+             docs.append(Document(sentences, meta["doc_id"], src=src))
+             # print(src)
+         return docs
+
+     def format(self, file=None, mode="split"):
+         output = "\n\n".join([doc.format(mode=mode) for doc in self.docs])
+         if file:
+             os.makedirs(os.path.dirname(file), exist_ok=True)
+             with open(file, "w", encoding="utf-8") as f:
+                 f.write(output)
+         return output
+
+     def parse_token(self, meta, data_line):
+         *token, label = data_line.split("\t")
+         if len(token) == 8:
+             print("ERROR: missing label ", meta, token, file=sys.stderr)
+             token = token + [label]
+             label = '_'
+         # needed because of errors in the source of some corpora (Russian with the BOM kept as a token; also bad reading of some chars);
+         # to prevent token counts/tokenization from failing, they are replaced with '_'
+         # token[1] is the form of the token
+         if token[1] == BOM:
+             token[1] = "_"
+         # if token[1] == '200�000':
+         #     print("GOTCHA")
+         token[1] = token[1].replace(REPL_CHAR, "_")
+         label_set = set(label.split("|"))
+         label = (label_set & set(self.LABELS))
+         if label == set():
+             label = "_"
+         else:
+             label = label.pop()
+         return token, label
+
+     def from_file(self, filepath):
+         """
+         reads a conllu or tok file
+         conllu has sentences, tok does not
+
+         option to pass a string instead of a file path, mostly for testing
+
+         TODO: should be a static method
+         """
+         self.filepath = filepath
+         basename = os.path.basename(filepath)
+         src = basename.split(".")[-1]  # tok or conllu or split
+         # print("src = ", src)
+         with open(filepath, "r", encoding="utf8") as f:
+             data_lines = f.readlines()
+         self.docs = self._parse(data_lines, src=src)
+         # for sent in self.docs:
+         #     print(sent)
+
+     def from_string(self, text: str, src="conllu"):
+         """
+         Read directly from a string (useful for tests or dynamic generation).
+         src can be 'conllu', 'tok', or 'split' to indicate the format.
+         """
+         self.filepath = None
+         if isinstance(text, str):
+             data_lines = text.strip().splitlines()
+         else:
+             raise ValueError("from_string expects a string")
+         self.docs = self._parse(data_lines, src=src)
+
+     def format(self, mode="split", file=sys.stdout):
+         # NB: this definition shadows the format() defined above
+         if type(file) == str:
+             os.makedirs(os.path.dirname(file), exist_ok=True)
+             file = open(file, "w")
+         for d in self.docs:
+             print(d.format(mode=mode), file=file)
+
+     def align(self, filepath):
+         """load the conllu file corresponding to a tok file"""
+         pass
+
+     def sentence_split(self, model="ersatz", lang="default-multilingual", **kwargs):
+         """apply a sentence splitter to each document, assuming the corpus was read from
+         a .tok file
+
+         kwargs might contain an open "pipeline" (e.g. a trankit pipeline) to pass downstream for splitting sentences, so that it is not re-created for each paragraph
+         """
+         for doc in tqdm(self.docs):
+             doc.sentence_split(model=model, lang=lang, **kwargs)
+
+     def eval_sentences(self, mode="split"):
+         """eval sentence beginnings as segment beginnings
+         TODO rename -> precision
+
+         only .tok for now, but this could be used to eval the re-split of conllu;
+         more complex for pdtb: need to align tok and conllu
+         """
+         tp = 0
+         total_s = 0
+         labels = []
+         for doc in self.docs:
+             for s in doc.get_sentences(mode):
+                 if len(s.toks) == 0:
+                     print("WARNING empty sentence in ", s.meta, file=sys.stderr)
+                     break
+                 tp += (s.toks[0].label == "Seg=B-seg")
+                 # tp += (s.toks[0].label == "BeginSeg=Yes")
+                 total_s += 1
+                 labels.extend([x.label for x in s])
+         counts = Counter(labels)
+         # return tp, total_s, counts["BeginSeg=Yes"]
+         return tp, total_s, counts["Seg=B-seg"]
+
+
+ class SegmentCorpus(Corpus):
+     LABELS = ["Seg=O", "Seg=B-seg"]
+
+
+ class ConnectiveCorpus(Corpus):
+     LABELS = ['Conn=O', 'Conn=B-conn', 'Conn=I-conn']
+     id2label = {i: label for i, label in enumerate(LABELS)}
+     label2id = {v: k for k, v in id2label.items()}
+
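+ # Usage sketch, reusing the test_doc_seg string defined above (its labels use the older
+ # BeginSeg=Yes scheme, so under the current LABELS they fall back to "_"):
+ # corpus = SegmentCorpus()
+ # corpus.from_string(test_doc_seg, src="conllu")
+ # print(corpus.docs[0].sentences["conllu"][0].format())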
+
+ class RelationCorpus(Corpus):
+
+     def from_file(self, filepath):
+         pass
+
+
+ # ersatz existing language-specific models
+ # for ersatz 1.0.0:
+ # ['en', 'ar', 'cs', 'de', 'es', 'et', 'fi', 'fr', 'gu', 'hi', 'iu', 'ja',
+ #  'kk', 'km', 'lt', 'lv', 'pl', 'ps', 'ro', 'ru', 'ta', 'tr', 'zh', 'default-multilingual']
+ # missing disrpt languages / what candidates? nl, pt, it -> en? thai -> multilingual
+
+
+ if __name__ == "__main__":
+     # testing
+     import sys, os
+     from pathlib import PurePath
+     # from ersatz import split, utils
+     # ersatz existing language-specific models
+     # languages = utils.MODELS.keys()
+
+     if len(sys.argv) > 1:
+         test_path = sys.argv[1]
+     else:
+         test_path = "../jiant/tests/test_data/eng.pdtb.pdtb/eng.pdtb.pdtb_debug.tok"
+
+     basename = os.path.basename(test_path)
+     lang = basename.split(".")[0]
+     # lang = get_language(lang, "trankit")
+
+     path = PurePath(test_path)
+     # output_path = str(path.with_suffix(".split"))
+     output_path = "out"
+
+     if "pdtb" in test_path:
+         corpus = ConnectiveCorpus()
+     else:
+         corpus = SegmentCorpus()
+     corpus.from_file(test_path)
+
+     sat = SaT("sat-3l")  # 3l is better with French guillemets
+
+     # print(corpus.docs[0].sentences[11].display(segment=True))
+     print(sat.split("This is a test This is another test."))
+     doc1 = corpus.docs[0]
+     s0 = doc1.sentences["tok"][0]
+     print(doc1)
+     print(list(sat.split(str(doc1))))
+     # list(res)
+     # pipe = pipeline("token-classification", model="segment-any-text/sat-1l")
+     # res = doc1.sentence_split(model="sat")
+
+     # ------------------------------------------
+     # -- From the SaT docs:
+     # https://github.com/segment-any-text/wtpsplit?tab=readme-ov-file#usage
+     # sat = SaT("sat-3l")
+     # optionally run on GPU for better performance
+     # also supports TPUs via e.g. sat.to("xla:0"); in that case pass `pad_last_batch=True` to sat.split
+     # sat.half().to("cuda")
+
+     # print(sat.split("This is a test This is another test."))
+     # returns ["This is a test ", "This is another test."]
+
+     # # do this instead of calling sat.split on every text individually, for much better performance
+     # sat.split(["This is a test This is another test.", "And some more texts..."])
+     # # returns an iterator yielding lists of sentences for every text
+
+     # # use the '-sm' models for general sentence segmentation tasks
+     # sat_sm = SaT("sat-3l-sm")
+     # sat_sm.half().to("cuda")  # optional, see above
+     # sat_sm.split("this is a test this is another test")
+     # # returns ["this is a test ", "this is another test"]
+
+     # # use trained lora modules for strong adaptation to language & domain/style
+     # sat_adapted = SaT("sat-3l", style_or_domain="ud", language="en")
+     # sat_adapted.half().to("cuda")  # optional, see above
+     # sat_adapted.split("This is a test This is another test.")
+     # # returns ['This is a test ', 'This is another test']
+
+     # check that the number of tokens is preserved by sentence splitting
+     # # assert sum(map(len, doc1.sentences)) == len(list(chain(*[x.split() for x in res])))
+     # pipeline = trankit.Pipeline(lang, gpu=True)
+     # corpus.sentence_split(model="trankit", lang=lang, pipeline=pipeline)
+     corpus.sentence_split(model="sat", sat_model=sat)
+     tp, tot, n_seg = corpus.eval_sentences()
+     print(tp, tot, n_seg)
+     # print(corpus.docs[0].split_sentences[0].toks[0].format())
+     corpus.format(file=output_path)
eval.py ADDED
@@ -0,0 +1,760 @@
+ #!/usr/bin/python
+ # -*- coding: utf-8 -*-
+
+ import os, sys
+ import numpy as np
+ import transformers
+
+ import utils
+
+ import reading
+
+
+ SUBTOKEN_START = '##'
+
+ '''
+ TODOs:
+
+ - for now, if the dataset is cached, we can't use word ids, and the predictions
+   written are not based on the original eval file, thus not exactly the same number
+   of tokens (contractions are ignored) --> doesn't work in the disrpt eval script
+
+ Change in the newest version of transformers:
+ from seqeval.metrics import accuracy_score
+ from seqeval.metrics import classification_report
+ from seqeval.metrics import f1_score
+ '''
+
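+ # Minimal seqeval sketch for the imports mentioned above (toy BIO sequences, not project data):
+ # from seqeval.metrics import f1_score
+ # gold = [["B-seg", "O", "O", "B-seg"]]
+ # pred = [["B-seg", "O", "B-seg", "B-seg"]]
+ # f1_score(gold, pred)  # span-level F1 over the two sequences: 2 tp, 1 fp, 0 fn -> 0.8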
29
+
30
+ def simple_eval( dataset_eval, model_checkpoint, tokenizer, output_path,
31
+ config, trace=False ):
32
+ '''
33
+ Run the pre-trained model on the (dev) dataset to get predictions,
34
+ then write the predictions in an output file.
35
+
36
+ Parameters:
37
+ -----------
38
+ dataset_eval: DatasetDisc
39
+ The dataset to evaluate
40
+ model_checkpoint: str
41
+ path to the saved model
42
+ tokenizer: Tokenizer
43
+ tokenizer of the saved model (TODO: retrieve from model? or should be removed?)
44
+ output_path: str
45
+ path to the output directory where prediction files will be written
46
+ config: dict
47
+ Configuration dictionary (task, tokenizer settings, ...)
48
+ '''
49
+ # Retrieve predictions (list of list of 0 and 1)
50
+ print("\n-- PREDICT on:", dataset_eval.annotations_file )
51
+ model_checkpoint = os.path.normpath(model_checkpoint)
52
+ print("model_checkpoint", model_checkpoint)
53
+ preds_from_model, label_ids, metrics = retrieve_predictions( model_checkpoint,
54
+ dataset_eval, output_path, tokenizer, config )
55
+
56
+ print("preds_from_model.shape", preds_from_model.shape)
57
+ print("label_ids.shape", label_ids.shape)
58
+
59
+ # - Compute metrics
60
+ print("\n-- COMPUTE METRICS" )
61
+ compute_metrics = utils.prepare_compute_metrics( dataset_eval.LABEL_NAMES_BIO )
62
+ metrics=compute_metrics([preds_from_model, label_ids])
63
+ max_preds_from_model = np.argmax(preds_from_model, axis=-1)
64
+
65
+ # - Write predictions:
66
+ pred_file = os.path.join( output_path, dataset_eval.basename+'.preds' )
67
+ print("\n-- WRITE PREDS in:", pred_file )
68
+ pred_file_success = True
69
+
70
+ try:
71
+ try:
72
+ # * retrieving the original words: will fail if cache not emptied
73
+ print( "Write predictions based on words")
74
+ predictions = align_tokens_labels_from_wordids( max_preds_from_model, dataset_eval,
75
+ tokenizer)
76
+
77
+ write_pred_file( dataset_eval.annotations_file, pred_file, predictions, trace=trace )
78
+ except IndexError:
79
+ # if error, we print the predictions with tokens, trying to merge subtokens
80
+ # based on SUBTOKEN_START and labels at -100
81
+ print( "Write predictions based on model tokenisation" )
82
+ aligned_tokens, aligned_golds, aligned_preds = align_tokens_labels_from_subtokens(
83
+ max_preds_from_model, dataset_eval, tokenizer, pred_file, trace=trace )
84
+ write_pred_file_from_scratch( aligned_tokens, aligned_golds, aligned_preds,
85
+ pred_file, trace=trace )
86
+ except Exception as e:
87
+ print( "Problem when trying to write predictions in file", pred_file )
88
+ print( "Exception:", e )
89
+ print("we skip the prediction writing step")
90
+ pred_file_success=False
91
+
92
+ if pred_file_success:
93
+ print( "\n-- EVAL DISRPT script" )
94
+ clean_pred_path = pred_file.replace('.preds', '.cleaned.preds')
95
+ utils.clean_pred_file(pred_file, clean_pred_path)
96
+ utils.compute_metrics_dirspt( dataset_eval, clean_pred_path, task=config['task'] )
97
+ # except:
98
+ # print("Problem when trying to compute scores with DISRPT eval script")
99
+ return metrics
102
+
103
+
104
+ def write_pred_file(annotations_file, pred_file, predictions, trace=False):
105
+ '''
106
+ Write a file containing the predictions based on the original annotation file.
107
+ It takes each line in the original evaluation file and appends the prediction at
108
+ the end. Predictions and original tokens need to be perfectly aligned.
109
+
110
+ Parameters:
111
+ -----------
112
+ annotations_file: str | file path OR raw text
113
+ Path to the original evaluation file, or the text content itself
114
+ pred_file: str
115
+ Path to the output prediction file
116
+ predictions: list of str
117
+ Flat list of all predictions (DISRPT format) for all tokens in eval
118
+ '''
119
+ count_pred_B, count_gold_B = 0, 0
120
+ count_line_dash = 0
121
+ count_line_dot = 0
122
+
123
+
124
+
125
+ # --- Determine whether annotations_file is a path or raw text
126
+ if os.path.isfile(annotations_file):
127
+ with open(annotations_file, 'r', encoding='utf-8') as fin:
128
+ mylines = fin.readlines()
129
+ else:
130
+ # Treat it as a raw string
131
+ mylines = annotations_file.strip().splitlines()
132
+
133
+
134
+ os.makedirs(os.path.dirname(pred_file), exist_ok=True)
135
+ with open(pred_file, 'w', encoding='utf-8') as fout:
136
+ count = 0
137
+ if trace:
138
+ print("len(predictions)", len(predictions))
139
+ for l in mylines:
140
+ l = l.strip()
141
+ if l.startswith("#"): # Keep metadata
142
+ fout.write(l + '\n')
143
+ elif l == '' or l == '\n': # keep line break
144
+ fout.write('\n')
145
+ elif '-' in l.split('\t')[0]: # Keep lines for contractions but no label
146
+ if trace:
147
+ print("WARNING: line with - in token, no label will be added")
148
+ count_line_dash += 1
149
+ fout.write(l + '\t' + '_' + '\n')
150
+ # strange case in GUM
151
+ elif '.' in l.split('\t')[0]: # Keep lines no label
152
+ count_line_dot += 1
153
+ if trace:
154
+ print("WARNING: line with . in token, no label will be added")
155
+ fout.write(l + '\t' + '_' + '\n')
156
+ else:
157
+ if 'B' in predictions[count]:
158
+ count_pred_B += 1
159
+ if 'Seg=B-seg' in l or 'Conn=B-conn' in l:
160
+ count_gold_B += 1
161
+ fout.write(l + '\t' + predictions[count] + '\n')
162
+ count += 1
163
+
164
+ print("Count the number of predictions corresponding to a B", count_pred_B, "vs Gold B", count_gold_B)
165
+ print("Count the number of lines with - in token", count_line_dash)
166
+ print("Count the number of lines with . in token", count_line_dot)
167
+
168
+
169
+
170
+ def write_pred_file_from_scratch( aligned_tokens, aligned_golds, aligned_preds, pred_file, trace=False ):
171
+ '''
172
+ Write a prediction file based on an alignment between tokenisation and predictions.
173
+ Since we are not sure that we retrieved the exact alignment, the writing here is not based
174
+ on the original annotation file, but we use a similar format:
175
+ # Sent ID
176
+ tok_ID token gold_label pred_label
177
+
178
+ The use of the DISRPT script will show whether the alignment worked or not ...
179
+
180
+ Parameters:
181
+ ----------
182
+ aligned_XX: list of list of str
183
+ The tokens / preds / golds for each sentence
184
+ '''
185
+ count_pred_B, count_gold_B = 0, 0
186
+ with open( pred_file, 'w' ) as fout:
187
+ if trace:
188
+ print( 'len tokens', len(aligned_tokens))
189
+ print("len(predictions)", len(aligned_preds))
190
+ print( 'len(golds)', len(aligned_golds))
191
+ for s, tok_sent in enumerate( aligned_tokens ):
192
+ fout.write( "# sent_id = "+str(s)+"\n" )
193
+ for i, tok in enumerate( tok_sent ):
194
+ g = aligned_golds[s][i]
195
+ p = aligned_preds[s][i]
196
+ fout.write( '\t'.join([str(i), tok, g, p])+'\n' )
197
+ if 'B' in p:
198
+ count_pred_B += 1
199
+ if 'Seg=B-seg' in g or 'Conn=B-conn' in g:
200
+ count_gold_B += 1
201
+ fout.write( "\n" )
202
+ print("Count the number of predictions corresponding to a B", count_pred_B, "vs Gold B", count_gold_B)
203
+
204
+
205
+
206
+
207
+
208
+
209
+ def align_tokens_labels_from_wordids( preds_from_model, dataset_eval, tokenizer, trace=False ):
210
+ '''
211
+ Write predictions for segmentation or connective tasks in an output file.
212
+ The output is the same as the input gold file, with an additional column
213
+ corresponding to the predicted label.
214
+
215
+ Easiest way (?): use word_ids information to merge the words that have been split and
216
+ retrieve the original tokens from the input .tok / .conllu files and run
217
+ evaluation --> but not kept in the cached datasets
218
+
219
+ Parameters:
220
+ -----------
221
+ preds_from_model: list of int
222
+ The predicted labels (numeric ids)
223
+ dataset_eval: DatasetDisc
224
+ Dataset for evaluation
225
+ tokenizer: Tokenizer
226
+ Tokenizer of the model (used to decode inputs for tracing)
227
+
228
+ Return:
229
+ -------
230
+ predictions: list of String
231
+ The predicted labels (DISRPT format) for each original input word
232
+ '''
233
+
234
+ word_ids = dataset_eval.all_word_ids
235
+ id2label = dataset_eval.id2label
236
+ predictions = []
237
+ for i in range( preds_from_model.shape[0] ):
238
+ sent_input_ids = dataset_eval.tokenized_datasets['input_ids'][i]
239
+ tokens = dataset_eval.dataset['tokens'][i]
240
+ sent_tokens = tokenizer.decode(sent_input_ids[1:-1])
241
+ aligned_preds = _merge_tokens_preds_sent( word_ids[i], preds_from_model[i], tokens )
242
+ if trace:
243
+ print( '\n', i, sent_tokens )
244
+ print( sent_input_ids )
245
+ print( preds_from_model[i])
246
+ print( ' '.join( tokens ) )
247
+ print( "aligned_preds", aligned_preds )
248
+ for k, tok in enumerate( tokens ):
249
+ # Skip special tokens
250
+ if tok.startswith('[LANG=') or tok.startswith('[FRAME='):
251
+ if trace:
252
+ print(f"Skip special token: {tok}")
253
+ continue
254
+ label = aligned_preds[k]
255
+ predictions.append( id2label[label] )
256
+ return predictions
257
+
258
+ def _merge_tokens_preds_sent( word_ids, preds, tokens ):
259
+ '''
260
+ The tokenizer splits the tokens into subtokens, with labels added on subwords.
261
+ For evaluation, we need to merge the subtokens, and keep only the labels on
262
+ the plain tokens.
263
+ The function takes the whole input_ids and predictions for one sentence and
264
+ return the merged version.
265
+ We also get rid of tokens and associated labels for [CLS] and [SEP] and don't
266
+ keep predictions for padding tokens.
267
+ TODO: here inspired by the method used to split the labels, but we could cut the
268
+ two 'continue' branches (kept for debugging)
269
+
270
+ input_ids: list
271
+ list of ids of (sub)tokens as produced by the (BERT like) tokenizer
272
+ preds: list
273
+ the predictions of the model
274
+ '''
275
+ aligned_toks = []
276
+ count = 0
277
+ new_labels = []
278
+ current_word = None
279
+ for i, word_id in enumerate( word_ids ):
280
+ count += 1
281
+ if word_id != current_word:
282
+ # New word
283
+ current_word = word_id
284
+ if word_id is not None:
285
+ new_labels.append( preds[i] )
286
+ aligned_toks.append( tokens[word_id] )
287
+ elif word_id is None:
288
+ # Special token
289
+ continue
290
+ else:
291
+ # Same word as previous token
292
+ continue
293
+ if len(new_labels) != len(aligned_toks) or len(new_labels) != len(tokens):
294
+ print( "WARNING, something wrong, not the same nb of tokens and predictions")
295
+ print( len(new_labels), len(aligned_toks), len(tokens) )
296
+ return new_labels
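+ # Worked example (a sketch with made-up ids): for tokens ["Byron", "received"]
+ # split by the tokenizer into [CLS] By ##ron received [SEP], word_ids is
+ # [None, 0, 0, 1, None]; with preds [0, 2, 1, 0, 0], only the prediction on the
+ # first subtoken of each word is kept, so the function returns [2, 0].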
297
+
298
+
299
+ def map_labels_list( list_labels, id2label ):
300
+ return [id2label[l] for l in list_labels]
301
+
302
+ def align_tokens_labels_from_subtokens( preds_from_model, dataset_eval, tokenizer, pred_file, trace=False ):
303
+ '''
304
+ Align tokens and labels (merging subtokens, assigning the right label)
305
+ based on the specific characters for starting a subtoken (e.g. ## for BERT)
306
+ and label -100 assigned to contractions of MWE (e.g. it's).
307
+ But not completely sure that we get the exact alignment with original words here.
308
+ '''
309
+ aligned_tokens, aligned_golds, aligned_preds = [], [], []
310
+ id2label = dataset_eval.id2label
311
+ tokenized_dataset = dataset_eval.tokenized_datasets
312
+ # print("\ndataset_eval.tokenized_datasets", dataset_eval.tokenized_datasets)
313
+ # print("preds_from_model.shape", preds_from_model.shape)
314
+ # For each sentence
315
+ with open(pred_file, 'w') as fout:
316
+ # Iterate on sentences
317
+ for i in range( preds_from_model.shape[0] ):
318
+ # fout.write( "new_sent_"+str(i)+'\n' )
319
+ sent_input_ids = dataset_eval.tokenized_datasets['input_ids'][i]
320
+ sent_gold_labels = tokenized_dataset['labels'][i]
321
+ sent_pred_labels = preds_from_model[i]
322
+ aligned_t, aligned_g, aligned_p = _retrieve_tokens_from_sent( sent_input_ids, sent_pred_labels,
323
+ sent_gold_labels, tokenizer, trace=trace )
324
+ aligned_tokens.append(aligned_t)
325
+ aligned_golds.append( map_labels_list(aligned_g, id2label) )
326
+ aligned_preds.append( map_labels_list(aligned_p, id2label) )
327
+ return aligned_tokens, aligned_golds, aligned_preds
328
+
329
+ def _retrieve_tokens_from_sent( sent_input_ids, preds_from_model, sent_gold_labels, tokenizer, trace=False ):
330
+ # tokenized_dataset = dataset.tokenized_datasets
331
+ cur_token, cur_pred, cur_gold = None, None, None
332
+ tokens, golds, preds = [], [], []
333
+ if trace:
334
+ print( '\n\nlen(sent_input_ids', len(sent_input_ids))
335
+ print( 'len(preds_from_model)', len(preds_from_model) ) #with padding
336
+ print( 'len(sent_gold_labels)', len(sent_gold_labels))
337
+ # Ignore first and last token / labels
338
+ for j, input_id in enumerate( sent_input_ids[1:-1] ):
339
+ gold_label = sent_gold_labels[j+1]
340
+ pred_label = preds_from_model[j+1]
341
+ subtoken = tokenizer.decode( input_id )
342
+ if trace:
343
+ print( subtoken, gold_label, pred_label )
344
+ # Deal with tokens split into subtokens, keep label of the first subtoken
345
+ if subtoken.startswith( SUBTOKEN_START ) or gold_label == -100:
346
+ if cur_token is None:
347
+ print( "WARNING: first subtoken without a token, probably a contraction or MWE")
348
+ cur_token=""
349
+ cur_token += subtoken
350
+ else:
351
+ if cur_token is not None:
352
+ tokens.append( cur_token )
353
+ golds.append(cur_gold)
354
+ preds.append(cur_pred)
355
+ cur_token = subtoken
356
+ cur_pred = pred_label
357
+ cur_gold = gold_label
358
+ # add last one
359
+ tokens.append( cur_token )
360
+ golds.append(cur_gold)
361
+ preds.append(cur_pred)
362
+ if trace:
363
+ print( "\ntokens:", len(tokens), tokens )
364
+ print( "golds", len(golds), golds )
365
+ print( "preds", len(preds), preds )
366
+ for i, tok in enumerate(tokens):
367
+ print( tok, golds[i], preds[i])
368
+ return tokens, golds, preds
369
+
370
+ def retrieve_predictions(model_checkpoint, dataset_eval, output_path, tokenizer, config):
371
+ """
372
+ Load the trainer in eval mode and compute predictions
373
+ on dataset_eval (can be a HuggingFace dataset OR a list of sentences)
374
+ """
376
+
377
+ model_path = model_checkpoint
378
+ if os.path.isfile(model_checkpoint):
379
+ print(f"[INFO] Le chemin du modèle pointe vers un fichier, utilisation du dossier parent: {os.path.dirname(model_checkpoint)}")
380
+ model_path = os.path.dirname(model_checkpoint)
381
+
382
+ config_file = os.path.join(model_path, "config.json")
383
+ if not os.path.exists(config_file):
384
+ raise FileNotFoundError(f"Aucun fichier config.json trouvé dans {model_path}.")
385
+
386
+ # Load model
387
+ model = transformers.AutoModelForTokenClassification.from_pretrained(model_path)
388
+
389
+ # Collator
390
+ data_collator = transformers.DataCollatorForTokenClassification(
391
+ tokenizer=tokenizer,
392
+ padding=config["tok_config"]["padding"]
393
+ )
394
+ compute_metrics = utils.prepare_compute_metrics(
395
+ getattr(dataset_eval, "LABEL_NAMES_BIO", None) or []
396
+ )
397
+
398
+ # Mode eval
399
+ model.eval()
400
+
401
+ test_args = transformers.TrainingArguments(
402
+ output_dir=output_path,
403
+ do_train=False,
404
+ do_predict=True,
405
+ dataloader_drop_last=False,
406
+ report_to=config.get("report_to", "none"),
407
+ )
408
+
409
+ trainer = transformers.Trainer(
410
+ model=model,
411
+ args=test_args,
412
+ data_collator=data_collator,
413
+ compute_metrics=compute_metrics,
414
+ )
415
+
416
+ # If dataset_eval is just a list of sentences, build a Dataset from it
417
+ from datasets import Dataset
418
+
419
+ if isinstance(dataset_eval, list):
420
+ dataset_eval = Dataset.from_dict({"text": dataset_eval})
421
+ def tokenize(batch):
422
+ return tokenizer(batch["text"], truncation=True, padding=True)
423
+ dataset_eval = dataset_eval.map(tokenize, batched=True)
424
+
425
+
426
+ predictions, label_ids, metrics = trainer.predict(dataset_eval)
427
+ else:
428
+ # - Make predictions on eval dataset
429
+ predictions, label_ids, metrics = trainer.predict(dataset_eval.tokenized_datasets)
430
+ return predictions, label_ids, metrics
431
+
432
+
433
+
434
+ # --------------------------------------------------------------------------
435
+ # --------------------------------------------------------------------------
436
+ if __name__=="__main__":
437
+ import argparse, os
438
+ import shutil
439
+
440
+ path = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")
441
+
442
+ if os.path.exists(path):
443
+ shutil.rmtree(path)
444
+ print(f"Le dossier '{path}' a été supprimé.")
445
+ else:
446
+ print(f"Le dossier '{path}' n'existe pas.")
447
+
448
+ parser = argparse.ArgumentParser(
449
+ description='DISCUT: Discourse segmentation and connective detection'
450
+ )
451
+
452
+ # EVAL file
453
+ parser.add_argument("-t", "--test",
454
+ help="Eval file. Default: data_test/eng.sample.rstdt/eng.sample.rstdt_dev.conllu",
455
+ default="data_test/eng.sample.rstdt/eng.sample.rstdt_dev.conllu")
456
+
457
+ # PRE FINE-TUNED MODEL
458
+ parser.add_argument("-m", "--model",
459
+ help="path to the directory where is the Model file.",
460
+ default=None)
461
+
462
+ # OUTPUT DIRECTORY
463
+ parser.add_argument("-o", "--output",
464
+ help="Directory where models and pred will be saved. Default: /home/cbraud/experiments/expe_discut_2025/",
465
+ default="./data/temp_expe/")
466
+
467
+ # CONFIG FILE FROM THE FINE TUNED MODEL
468
+ parser.add_argument("-c", "--config",
469
+ help="Config file. Default: ./config_seg.json",
470
+ default="./config_seg.json")
471
+
472
+ # TRACE / VERBOSITY
473
+ parser.add_argument( '-v', '--trace',
474
+ action='store_true',
475
+ default=False,
476
+ help="Whether to print full messages. If used, it will override the value in config file.")
477
+
478
+ # TODO Add an option for choosing the tool to split the sentences
479
+
480
+ args = parser.parse_args()
481
+
482
+ eval_path = args.test
483
+ output_path = args.output
484
+ if not os.path.isdir( output_path ):
485
+ os.makedirs(output_path, exist_ok=True )
486
+ config_file = args.config
487
+ model = args.model
488
+ trace = args.trace
489
+
490
+ print( '\n-[DISCUT]--PROGRAM (eval) ARGUMENTS')
491
+ print( '| Mode', 'eval' )
492
+ if not model:
493
+ sys.exit( "Please provide a path to a model for eval mode.")
494
+ print( '| Test_path:', eval_path )
495
+ print( "| Output_path:", output_path )
496
+ if model:
497
+ print( "| Model:", model )
498
+ print( '| Config:', config_file )
499
+
500
+ print( '\n-[DISCUT]--CONFIG INFO')
501
+ config = utils.read_config( config_file )
502
+ utils.print_config( config )
503
+
504
+ print( "\n-[DISCUT]--READING DATASET")
505
+ ###
506
+ datasets = {}
507
+ datasets['dev'], tokenizer = reading.read_dataset( eval_path, output_path, config )
508
+
509
+ # model also in config[best_model_path]
510
+ metrics=simple_eval( datasets['dev'], model, tokenizer, output_path, config, trace=trace )
511
+
512
+
513
+
514
+
515
+
516
+
517
+
518
+
519
+
520
+
521
+
522
+
523
+
524
+
525
+
526
+
527
+
528
+
529
+
530
+
531
+ # # TODO clean, probably unused arguments here
532
+ # def simple_eval_deprecated( dataset_eval, model_checkpoint, tokenizer, output_path,
533
+ # config ):
534
+ # '''
535
+ # Run the pre-trained model on the (dev) dataset to get predictions,
536
+ # then write the predictions in an output file.
537
+
538
+ # Parameters:
539
+ # -----------
540
+ # datasets: dict of DatasetDisc
541
+ # The datasets read
542
+ # model_checkpoint: str
543
+ # path to the saved model
544
+ # tokenizer: Tokenizer
545
+ # tokenizer of the saved model (TODO: retrieve from model? or should be removed?)
546
+ # output_path: str
547
+ # path to the output directory where prediction files will be written
548
+ # data_collator: DataCollator
549
+ # (TODO: retrieve from model?)
550
+ # '''
551
+ # # tokenized_dataset = dataset_eval.tokenized_datasets
552
+ # dev_dataset = dataset_eval.dataset
553
+
554
+ # LABEL_NAMES = dataset_eval.LABEL_NAMES_BIO
555
+ # # TODO check if needed
556
+ # word_ids = dataset_eval.all_word_ids
557
+ # model = transformers.AutoModelForTokenClassification.from_pretrained(
558
+ # model_checkpoint
559
+ # )
560
+ # data_collator = transformers.DataCollatorForTokenClassification(
561
+ # tokenizer=tokenizer,
562
+ # padding=config["tok_config"]["padding"] )
563
+
564
+ # compute_metrics = utils.prepare_compute_metrics(LABEL_NAMES)
565
+
566
+ # # TODO is it useful to have both .eval() and test_args ?
567
+ # model.eval()
568
+
569
+ # test_args = transformers.TrainingArguments(
570
+ # output_dir = output_path,
571
+ # do_train = False,
572
+ # do_predict = True,
573
+ # #per_device_eval_batch_size = BATCH_SIZE,
574
+ # dataloader_drop_last = False
575
+ # )
576
+
577
+ # trainer = transformers.Trainer(
578
+ # model=model,
579
+ # args=test_args,
580
+ # data_collator=data_collator,
581
+ # compute_metrics=compute_metrics,
582
+ # )
583
+ # predictions, label_ids, metrics = trainer.predict(dataset_eval.tokenized_datasets)
584
+ # preds = np.argmax(predictions, axis=-1)
585
+
586
+ # compute_metrics([predictions, label_ids])
587
+
588
+ # # Try to write predictions: will fail if cache not emptied
589
+ # # because we need word_ids not saved in cache TODO check...
590
+ # pred_file = os.path.join( output_path, dataset_eval.basename+'.preds' )
591
+ # try:
592
+ # write_predictions_words( preds, dataset_eval.tokenized_datasets,
593
+ # tokenizer, pred_file, dataset_eval.id2label,
594
+ # word_ids, dev_dataset, dataset_eval )
595
+ # except IndexError:
596
+ # # if error, we print the predictions with tokens as is
597
+ # write_predictions_subtokens( preds, dataset_eval.tokenized_datasets,
598
+ # tokenizer, pred_file, dataset_eval.id2label )
599
+ # # Test DISRPT eval script
600
+ # print( "\nPerformance computed using disrpt eval script on", dataset_eval.annotations_file,
601
+ # pred_file )
602
+ # if config['task'] == 'seg':
603
+ # my_eval = disrpt_eval_2025.SegmentationEvaluation("temp_test_disrpt_eval_seg",
604
+ # dataset_eval.annotations_file,
605
+ # pred_file )
606
+ # elif config['task'] == 'conn':
607
+ # my_eval = disrpt_eval_2025.ConnectivesEvaluation("temp_test_disrpt_eval_conn",
608
+ # dataset_eval.annotations_file,
609
+ # pred_file )
610
+ # else:
611
+ # raise NotImplementedError
612
+ # my_eval.compute_scores()
613
+ # my_eval.print_results()
614
+
615
+
616
+ # # TODO: dd????
617
+ # # TODO : only for SEG/CONN --> to rename (and make a generic function)
618
+ # def write_predictions_words_deprecated( preds, dev, tokenizer, pred_file, id2label, word_ids,
619
+ # dev_dataset, dd, trace=False ):
620
+ # '''
621
+ # Write predictions for segmentation or connective tasks in an output files.
622
+ # The output is the same as the input gold file, with an additional column
623
+ # corresponding to the predicted label.
624
+
625
+ # ?? We need the word_ids information to merge the words that been split et
626
+ # retrieve the original tokens from the input .tok / .conllu files and run
627
+ # evaluation.
628
+
629
+ # Parameters:
630
+ # -----------
631
+ # preds: list of int
632
+ # The predicted labels (numeric ids)
633
+ # dev: Dataset
634
+ # tokenized_dev
635
+ # pred_file: str
636
+ # Path to the file where predictions will be written
637
+ # id2label: dict
638
+ # Convert from ids to labels
639
+ # word_ids: list?
640
+ # Word ids, None for task rel
641
+ # dev_dataset : Dataset
642
+ # Dataset for the dev set
643
+ # dd : str?
644
+ # dset
645
+ # '''
646
+ # predictions = []
647
+ # for i in range( preds.shape[0] ):
648
+ # sent_input_ids = dev['input_ids'][i]
649
+ # tokens = dev_dataset['tokens'][i]
650
+ # # sentence text
651
+ # sent_tokens = tokenizer.decode(sent_input_ids[1:-1])
652
+ # # list of decoded subtokens
653
+ # #sub_tokens = [tokenizer.decode(tok_id) for tok_id in sent_input_ids]
654
+ # # Merge subtokens and retrieve corresp. pred labels
655
+ # # i.e. we ignore: CLS, SEP, PAD and labels on ##subtoks
656
+ # aligned_preds = merge_tokens_preds_sent( word_ids[i], preds[i], tokens )
657
+ # if trace:
658
+ # print( '\n', i, sent_tokens )
659
+ # print( sent_input_ids )
660
+ # print( preds[i])
661
+ # print( ' '.join( tokens ) )
662
+ # print( "aligned_preds", aligned_preds )
663
+ # # sentence id, but TODO: retrieve doc ids
664
+ # #f.write( "# sent_id = "+str(i)+"\n" )
665
+ # # Write the original sentence text
666
+ # #f.write( "# text = "+sent_tokens+"\n" )
667
+ # # indices should start at 1
668
+ # for k, tok in enumerate( tokens ):
669
+ # label = aligned_preds[k]
670
+ # predictions.append( id2label[label] )
671
+ # #f.write( "\t".join( [str(k+1), tok, "_","_","_","_","_","_","_", id2label[label] ] )+"\n" )
672
+ # #f.write("\n")
673
+ # print("PREDICTIONS", predictions)
674
+ # count_pred_B, count_gold_B = 0, 0
675
+ # with open( dd.annotations_file, 'r' ) as fin:
676
+ # with open( pred_file, 'w' ) as fout:
677
+ # mylines = fin.readlines()
678
+ # count = 0
679
+ # if trace:
680
+ # print("len(predictions)", len(predictions))
681
+ # for l in mylines:
682
+ # l = l.strip()
683
+ # if l.startswith("#"):
684
+ # fout.write( l+'\n')
685
+ # elif l == '' or l == '\n':
686
+ # fout.write('\n')
687
+ # elif '-' in l.split('\t')[0]:
688
+ # fout.write( l+'\t'+'_'+'\n')
689
+ # else:
690
+ # if 'B' in predictions[count]:
691
+ # count_pred_B += 1
692
+ # if 'Seg=B-seg' in l or 'Conn=B-conn' in l:
693
+ # count_gold_B += 1
694
+ # fout.write( l+'\t'+predictions[count]+'\n')
695
+ # count += 1
696
+
697
+
698
+ # print("Count the number of predictions corresponding to a B", count_pred_B, "vs Gold B", count_gold_B)
699
+
700
+
701
+ # # TODO: dd????
702
+ # # TODO : only for SEG/CONN --> to rename (and make a generic function)
703
+ # def write_predictions_words( preds_from_model, dataset_eval, tokenizer, pred_file, trace=True ):
704
+ # '''
705
+ # Write predictions for segmentation or connective tasks in an output files.
706
+ # The output is the same as the input gold file, with an additional column
707
+ # corresponding to the predicted label.
708
+
709
+ # ?? We need the word_ids information to merge the words that been split et
710
+ # retrieve the original tokens from the input .tok / .conllu files and run
711
+ # evaluation.
712
+
713
+ # Parameters:
714
+ # -----------
715
+ # preds_from_model: list of int
716
+ # The predicted labels (numeric ids)
717
+ # dev: Dataset
718
+ # tokenized_dev
719
+ # pred_file: str
720
+ # Path to the file where predictions will be written
721
+ # id2label: dict
722
+ # Convert from ids to labels
723
+ # word_ids: list?
724
+ # Word ids, None for task rel
725
+ # dev_dataset : Dataset
726
+ # Dataset for the dev set
727
+ # dd : str?
728
+ # dset
729
+ # '''
730
+ # word_ids = dataset_eval.all_word_ids
731
+ # id2label = dataset_eval.id2label
732
+ # predictions = []
733
+ # for i in range( preds_from_model.shape[0] ):
734
+ # sent_input_ids = dataset_eval.tokenized_datasets['input_ids'][i]
735
+ # tokens = dataset_eval.dataset['tokens'][i]
736
+ # # sentence text
737
+ # sent_tokens = tokenizer.decode(sent_input_ids[1:-1])
738
+ # # list of decoded subtokens
739
+ # #sub_tokens = [tokenizer.decode(tok_id) for tok_id in sent_input_ids]
740
+ # # Merge subtokens and retrieve corresp. pred labels
741
+ # # i.e. we ignore: CLS, SEP, PAD and labels on ##subtoks
742
+ # aligned_preds = merge_tokens_preds_sent( word_ids[i], preds_from_model[i], tokens )
743
+ # if trace:
744
+ # print( '\n', i, sent_tokens )
745
+ # print( sent_input_ids )
746
+ # print( preds_from_model[i])
747
+ # print( ' '.join( tokens ) )
748
+ # print( "aligned_preds", aligned_preds )
749
+ # # sentence id, but TODO: retrieve doc ids
750
+ # #f.write( "# sent_id = "+str(i)+"\n" )
751
+ # # Write the original sentence text
752
+ # #f.write( "# text = "+sent_tokens+"\n" )
753
+ # # indices should start at 1
754
+ # for k, tok in enumerate( tokens ):
755
+ # label = aligned_preds[k]
756
+ # predictions.append( id2label[label] )
757
+ # #f.write( "\t".join( [str(k+1), tok, "_","_","_","_","_","_","_", id2label[label] ] )+"\n" )
758
+ # #f.write("\n")
759
+ # # print("PREDICTIONS", predictions)
760
+ # write_pred_file( dataset_eval.annotations_file, pred_file, predictions )
pipeline.py ADDED
@@ -0,0 +1,142 @@
1
+ from transformers import Pipeline, AutoModelForTokenClassification
2
+ import numpy as np
3
+ from eval import retrieve_predictions, align_tokens_labels_from_wordids
4
+ from reading import read_dataset
5
+ from utils import read_config
6
+
7
+
8
+
9
+ def write_sentences_to_format(sentences: list[str], filename: str):
10
+ """
11
+ Write sentences to a file, one word per line, in the format:
12
+ index<TAB>word<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>Seg=...
13
+ """
14
+
15
+ if not sentences:
16
+ return ""
17
+ if isinstance(sentences, str):
18
+ sentences=[sentences]
19
+ import sys
20
+ sys.stderr.write("Warning: only one sentence provided as a string instead of a list of sentences.\n")
21
+
22
+ full="# newdoc_id = GUM_academic_discrimination\n"
23
+ for sentence in sentences:
24
+ words = sentence.strip().split()
25
+ for i, word in enumerate(words, start=1):
26
+ # The first word, or a capitalized word, gets B-seg; otherwise O
27
+ seg_label = "B-seg" if i == 1 or word[0].isupper() else "O"
28
+ line = f"{i}\t{word}\t_\t_\t_\t_\t_\t_\t_\tSeg={seg_label}\n"
29
+ full+=line
30
+ if filename:
31
+ with open(filename, "w", encoding="utf-8") as f:
32
+ f.write(full)
33
+
34
+ return full
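+ # Example (illustrative): write_sentences_to_format(["this is a test"], None)
+ # returns:
+ # # newdoc_id = GUM_academic_discrimination
+ # 1	this	_	_	_	_	_	_	_	Seg=B-seg
+ # 2	is	_	_	_	_	_	_	_	Seg=O
+ # 3	a	_	_	_	_	_	_	_	Seg=O
+ # 4	test	_	_	_	_	_	_	_	Seg=O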
35
+
36
+
37
+ class DiscoursePipeline(Pipeline):
38
+ def __init__(self, model, tokenizer, config:dict, output_folder="./pipe_out",sat_model:str="sat-3l", **kwargs):
39
+ auto_model = AutoModelForTokenClassification.from_pretrained(model)
40
+ super().__init__(model=auto_model, tokenizer=tokenizer, **kwargs)
41
+ self.config = {"model_checkpoint": model, "sent_spliter":"sat","task":"seg","type":"tok","trace":False,"report_to":"none","sat_model":sat_model,"tok_config":{
42
+ "padding":"max_length",
43
+ "truncation":True,
44
+ "max_length": 512
45
+ }}
46
+ self.model = model  # keep the checkpoint path: retrieve_predictions reloads the model from it
47
+ self.output_folder = output_folder
48
+
49
+ def _sanitize_parameters(self, **kwargs):
50
+ # Allows passing optional parameters such as add_lang_token etc.
51
+ preprocess_params = {}
52
+ forward_params = {}
53
+ postprocess_params = {}
54
+ return preprocess_params, forward_params, postprocess_params
55
+
56
+ def preprocess(self, text:str):
57
+ self.original_text=text
58
+ formatted_text=write_sentences_to_format(text.split("\n"), filename=None)
59
+ dataset, _ = read_dataset(
60
+ formatted_text,
61
+ output_path=self.output_folder,
62
+ config=self.config,
63
+ add_lang_token=True,
64
+ add_frame_token=True,
65
+ )
66
+ return {"dataset": dataset}
67
+
68
+ def _forward(self, inputs):
69
+ dataset = inputs["dataset"]
70
+ preds_from_model, label_ids, _ = retrieve_predictions(
71
+ self.model, dataset, self.output_folder, self.tokenizer, self.config
72
+ )
73
+ return {"preds": preds_from_model, "labels": label_ids, "dataset": dataset}
74
+
75
+ def postprocess(self, outputs):
76
+ preds = np.argmax(outputs["preds"], axis=-1)
77
+ predictions = align_tokens_labels_from_wordids(preds, outputs["dataset"], self.tokenizer)
78
+ edus=text_to_edus(self.original_text, predictions)
79
+ return edus
80
+
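+ # Usage sketch (hedged: the checkpoint path is a placeholder; assumes a
+ # fine-tuned token-classification checkpoint and its matching tokenizer):
+ # from transformers import AutoTokenizer
+ # tok = AutoTokenizer.from_pretrained("path/to/checkpoint")
+ # pipe = DiscoursePipeline("path/to/checkpoint", tok, config={})
+ # edus = pipe("Byron received his early formal education at Aberdeen .")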
81
+ def get_plain_text_from_format(formatted_text:str) -> str:
82
+ """
83
+ Read conllu- or tok-formatted text and return the tokens as a plain string.
84
+ """
85
+ formatted_text=formatted_text.split("\n")
86
+ s=""
87
+ for line in formatted_text:
88
+ if not line.startswith("#"):
89
+ if len(line.split("\t"))>1:
90
+ s+=line.split("\t")[1]+" "
91
+ return s.strip()
92
+
93
+
94
+ def get_preds_from_format(formatted_text:str) -> str:
95
+ """
96
+ Read conllu- or tok-formatted text and return the prediction labels (last column) as a string.
97
+ """
98
+ formatted_text=formatted_text.split("\n")
99
+ s=""
100
+ for line in formatted_text:
101
+ if not line.startswith("#"):
102
+ if len(line.split("\t"))>1:
103
+ s+=line.split("\t")[-1]+" "
104
+ return s.strip()
105
+
106
+
107
+ def text_to_edus(text: str, labels: list[str]) -> list[str]:
108
+ """
109
+ Split a raw text into EDUs from a sequence of BIO labels.
110
+
111
+ Args:
112
+ text (str): The raw text (a sequence of space-separated words).
113
+ labels (list[str]): The sequence of BIO labels (B, I, O),
114
+ the same length as the number of tokens in the text.
115
+
116
+ Returns:
117
+ list[str]: The list of EDUs (each EDU is a substring of the text).
118
+ """
119
+ words = text.strip().split()
120
+ if len(words) != len(labels):
121
+ raise ValueError(f"Longueur mismatch: {len(words)} mots vs {len(labels)} labels")
122
+
123
+ edus = []
124
+ current_edu = []
125
+
126
+ for word, label in zip(words, labels):
127
+ if label == "Conn=O" or label == "Seg=O":
128
+ current_edu.append(word)
129
+
130
+ elif label == "Conn=B-conn" or label == "Seg=B-seg":
131
+ # Close the current EDU if one is open
132
+ if current_edu:
133
+
134
+ edus.append(" ".join(current_edu))
135
+ current_edu = []
136
+ current_edu.append(word)
137
+
138
+ # If an EDU is still open, close it
139
+ if current_edu:
140
+ edus.append(" ".join(current_edu))
141
+
142
+ return edus
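+ # Example (doctest-style sketch):
+ # >>> text_to_edus("Byron received his education while he was young",
+ # ...              ["Seg=B-seg", "Seg=O", "Seg=O", "Seg=O",
+ # ...               "Seg=B-seg", "Seg=O", "Seg=O", "Seg=O"])
+ # ['Byron received his education', 'while he was young']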
reading.py ADDED
@@ -0,0 +1,512 @@
1
+ #!/usr/bin/python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import os, sys
5
+
6
+ import datasets
7
+ import transformers
8
+
9
+ import disrpt_io
10
+ import utils
11
+
12
+ # TODO: remove once the issue of loading languages is dealt with
13
+ ##from ersatz import utils
14
+ ##LANGUAGES = utils.MODELS.keys()
15
+ LANGUAGES = []
16
+
17
+ def read_dataset( input_path, output_path, config, add_lang_token=True,add_frame_token=True,lang_token="",frame_token="" ):
18
+ '''
19
+ - Read the file in input_path
20
+ - Return a Dataset corresponding to the file
21
+
22
+ Parameters
23
+ ----------
24
+ input_path : str
25
+ Path to the dataset
26
+ output_path : str
27
+ Path to an output directory that can be used to write new split files
28
+ config : dict
29
+ Configuration dictionary (including the model checkpoint)
30
+ add_lang_token : bool
31
+ If True, add a special language token at the beginning of each sequence
32
+
33
+ Returns
34
+ -------
35
+ Dataset
36
+ Contains the Dataset built from train_path and dev_path in train mode,
37
+ only the dev / test paths otherwise
38
+ Tokenizer
39
+ The tokenizer used for the dataset
40
+ '''
41
+ model_checkpoint = config["model_checkpoint"]
42
+ # -- Init tokenizer
43
+ tokenizer = transformers.AutoTokenizer.from_pretrained( model_checkpoint )
44
+ # -- Read and tokenize
45
+ dataset = DatasetSeq( input_path, output_path, config, tokenizer, add_lang_token=add_lang_token,add_frame_token=add_frame_token,lang_token=lang_token,frame_token=frame_token )
46
+ dataset.read_and_tokenize()
47
+ # TODO move in class? or do elsewhere
48
+ LABEL_NAMES_BIO = retrieve_bio_labels( dataset ) # TODO should do it only once for all
49
+ dataset.set_label_names_bio(LABEL_NAMES_BIO)
50
+ return dataset, tokenizer
51
+
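+ # Usage sketch (illustrative paths; the config dict must provide at least
+ # "model_checkpoint", "type", "task", "trace", "tok_config" and "sent_spliter"):
+ # config = utils.read_config("./config_seg.json")
+ # dataset, tokenizer = read_dataset(
+ #     "data_test/eng.sample.rstdt/eng.sample.rstdt_dev.conllu",
+ #     "./data/temp_expe/", config)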
52
+ # --------------------------------------------------------------------------
53
+ # DatasetDict
54
+
55
+ class DatasetDisc( ):
56
+ def __init__(self, annotations_file, output_path, config, tokenizer, dset=None ):
57
+ """
58
+ Here we save the location of our input file,
59
+ load the data, i.e. retrieve the list of texts and associated labels,
60
+ build the vocabulary if none is given,
61
+ and define the pipelines used to prepare the data
62
+ """
63
+ self.annotations_file = annotations_file
64
+ if isinstance(annotations_file, str) and not os.path.isfile(annotations_file):
65
+ print("this is a string dataset")
66
+ self.basename = "input"
67
+ else:
68
+ self.basename = os.path.basename( self.annotations_file )
69
+ self.dset = self.basename.split(".")[2].split('_')[1]
70
+ self.corpus_name = self.basename.split('_')[0]
71
+
72
+ self.tokenizer = tokenizer
73
+ self.config = config
74
+ # If a sentence splitter is used, the files with the new segmentation will be saved here
75
+ self.output_path = output_path
76
+
77
+ # Retrieve info from config: TODO check against info from dir name?
78
+ self.mode = config["type"]
79
+ self.task = config["task"]
80
+ self.trace = config["trace"]
81
+ self.tok_config = config["tok_config"]
82
+ self.sent_spliter = config["sent_spliter"]
83
+
84
+ # Additional fields
85
+ self.id2label, self.label2id = {}, {}
86
+
87
+ # -- Use disrpt_io to read the file and retrieve annotated data
88
+ self.corpus = init_corpus( self.task ) # initialize a Corpus instance, depending on the task
89
+
90
+
91
+
92
+
93
+
94
+ def read_and_tokenize( self ):
95
+ print("\n-- READ FROM FILE:", self.annotations_file )
96
+ try:
97
+ self.read_annotations( )
98
+ except Exception as err:
99
+ print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
100
+ raise
101
+ # print( "Problem when reading", self.annotations_file )
102
+
103
+ #print("\n-- SET LABELS")
104
+ self.set_labels( )
105
+ print( "self.label2id", self.label2id )
106
+
107
+ #print("\n-- TOKENIZE DATASET")
108
+ self.tokenize_dataset()
109
+ if self.trace:
110
+ if self.dset:
111
+ print( "\n-- FINISHED READING", self.dset, "PRINTING TRACE --")
112
+ self.print_trace()
113
+
114
+ def tokenize_datasets( self ):
115
+ # Specific to subclasses
116
+ raise NotImplementedError
117
+
118
+ def set_labels( self ):
119
+ # Specific to subclasses
120
+ raise NotImplementedError
121
+
122
+ # outside the class?
123
+ # TODO use **kwags instead?
124
+ def read_annotations( self ):
125
+ '''
126
+ Generate a Corpus object based on the input_file.
127
+ Since .tok files are not segmented into sentences, a sentence splitter
128
+ is used (here, SaT from wtpsplit)
129
+ '''
130
+ if os.path.isfile(self.annotations_file):
131
+ self.corpus.from_file(self.annotations_file)
132
+ lang = os.path.basename(self.annotations_file).split(".")[0]
133
+ frame = os.path.basename(self.annotations_file).split(".")[1]
134
+ base = os.path.basename(self.annotations_file)
135
+ else:
136
+ # assume it is raw text already in the expected format
137
+ src = self.mode if self.mode in ["tok", "conllu", "split"] else "conllu"
138
+ self.corpus.from_string(self.annotations_file,src=src)
139
+ lang = self.lang_token
140
+ frame = self.frame_token
141
+ base = "input.text"
142
+
143
+
144
+
145
+ #print(f"[DEBUG] lang? {lang}")
146
+ for doc in self.corpus.docs:
147
+ doc.lang = lang
148
+ doc.frame = frame
149
+ # print(corpus)
150
+ # Split corpus into sentences using Ersatz
151
+ if self.mode == 'tok':
152
+ kwargs={}
153
+ from wtpsplit import SaT
154
+ sat_version="sat-3l"
155
+ if "sat_model" in self.config:
156
+ sat_version=self.config["sat_model"]
157
+
158
+ sat_model = SaT(sat_version)
159
+ kwargs["sat_model"] = sat_model
160
+ self.corpus.sentence_split(model = self.sent_spliter, lang="default-multilingual",sat_model=sat_model)
161
+ # Writing files with the split sentences
162
+ parts = base.split(".")[:-1]
163
+ split_filename = ".".join(parts) + ".split"
164
+ split_file = os.path.join(self.output_path, split_filename)
165
+ self.corpus.format(file=split_file)
166
+ # no need for sentence splitting if mode = conllu or split, no need to write files
167
+
168
+ def print_trace( self ):
169
+ print( "\n| Annotation_file: ", self.annotations_file )
170
+ print( '| Output_path:', self.output_path )
171
+ print( '| Nb_of_instances:', len(self.dataset), "(", len(self.dataset['labels']), ")" )
172
+ # "(", len(self.dataset['tokens']), len(self.dataset['labels']), ")" )
173
+
174
+ def print_stats( self ):
175
+ print( "| Annotation_file: ", self.annotations_file )
176
+ if self.dset: print( "| Data_split: ", self.dset )
177
+ print( "| Task: ", self.task )
178
+ print( "| Lang: ", self.lang )
179
+ print( "| Mode: ", self.mode )
180
+ print( "| Label_names: ", self.LABEL_NAMES)
181
+ #print( "---Number_of_documents", len( self.corpus.docs ) )
182
+ print( "| Number_of_instances: ", len(self.dataset) )
183
+ # TODO : add number of docs: not computed for .rels for now
184
+
185
+ # -------------------------------------------------------------------------------------------------
186
+ class DatasetSeq(DatasetDisc):
187
+ def __init__( self, annotations_file, output_path, config, tokenizer, add_lang_token=True, add_frame_token=True,
188
+ dset=None,lang_token="",frame_token="" ):
189
+ """
190
+ Class for tasks corresponding to a sequence labeling problem
191
+ (seg, conn).
192
+ Here we save the location of our input file,
193
+ load the data, i.e. retrieve the list of texts and associated
194
+ labels,
195
+ build the vocabulary if none is given,
196
+ and define the pipelines used to prepare the data """
197
+ DatasetDisc.__init__( self, annotations_file, output_path, config,
198
+ tokenizer )
199
+ self.add_lang_token = add_lang_token
200
+ self.add_frame_token=add_frame_token
201
+ self.lang_token = lang_token
202
+ self.frame_token=frame_token
203
+
204
+ if self.mode == 'tok' and self.output_path == None:
205
+ self.output_path = os.path.dirname( self.annotations_file )
206
+ self.output_path = os.path.join( self.output_path,
207
+ self.basename.replace("."+self.mode, ".split") )
208
+
209
+ self.sent_spliter = None
210
+ if "sent_spliter" in self.config:
211
+ self.sent_spliter = self.config["sent_spliter"] #only for seg
212
+
213
+ self.LABEL_NAMES_BIO = None
214
+ # # TODO not used, really a good idea?
215
+ # self.data_collator = transformers.DataCollatorForTokenClassification(tokenizer=self.tokenizer,
216
+ # padding=self.tok_config["padding"] )
217
+
218
+ def tokenize_dataset( self ):
219
+ # -- Create a HuggingFace Dataset object
220
+ if self.trace:
221
+ print(f"\n-- Creating dataset from generator (add_lang_token={self.add_lang_token})")
222
+ self.dataset = datasets.Dataset.from_generator(
223
+ gen,
224
+ gen_kwargs={"corpus": self.corpus, "label2id": self.label2id, "mode": self.mode, "add_lang_token": self.add_lang_token,"add_frame_token":self.add_frame_token},
225
+ )
226
+ if self.trace:
227
+ print( self.dataset[0])
228
+ # Keep track of the alignment between words and subtokens, even if not ##
229
+ # BERT*-style tokenizers also split on punctuation, even when given a list of words
230
+ self.all_word_ids = []
231
+ # Align labels according to tokenized subwords
232
+ if self.trace:
233
+ print( "\n-- Mapping dataset labels and subwords ")
234
+ self.tokenized_datasets = self.dataset.map(
235
+ tokenize_and_align_labels,
236
+ fn_kwargs = {"tokenizer":self.tokenizer,
237
+ "id2label":self.id2label,
238
+ "label2id":self.label2id,
239
+ "all_word_ids":self.all_word_ids,
240
+ "config":self.config},
241
+ batched=True,
242
+ remove_columns=self.dataset.column_names,
243
+ )
244
+ if self.trace:
245
+ print( self.tokenized_datasets[0])
246
+
247
+
248
+ def set_labels(self):
249
+ self.LABEL_NAMES = self.corpus.LABELS
250
+ self.id2label = {i: label for i, label in enumerate( self.LABEL_NAMES )}
251
+ self.label2id = {v: k for k,v in self.id2label.items()}
252
+
253
+ def set_label_names_bio( self, LABEL_NAMES_BIO ):
254
+ self.LABEL_NAMES_BIO = LABEL_NAMES_BIO
255
+
256
+
257
+ def print_trace( self ):
258
+ super().print_trace()
259
+ print( '\n--First sentence: original tokens and labels.\n')
260
+ print( self.dataset[0]['tokens'] )
261
+ print( self.dataset[0]['labels'] )
262
+ print( "\n---First sentence: tokenized version:\n")
263
+ print( self.tokenized_datasets[0] )
264
+ # print( '\nSource word ids:', len(self.all_word_ids) )
265
+
266
+ # # TODO prepaper a compute_stats before printing, to allow partial printing without trace mode
267
+ # def print_stats( self ):
268
+ # super().print_stats()
269
+ # print( "| Number_of_documents", len( self.corpus.docs ) )
270
+
271
+
272
+ def init_corpus( task ):
273
+ if task.strip().lower() == 'conn':
274
+ return disrpt_io.ConnectiveCorpus()
275
+ elif task == 'seg':
276
+ return disrpt_io.SegmentCorpus()
277
+ else:
278
+ raise NotImplementedError
279
+
280
+ def gen( corpus, label2id, mode, add_lang_token=True,add_frame_token=True ):
281
+ # Add special language / framework tokens at the beginning of each sequence
282
+ source = "split"
283
+ if mode == 'conllu':
284
+ source = "conllu"
285
+ for doc in corpus.docs:
286
+ lang = getattr(doc, 'lang', 'xx')
287
+ lang_token = f"[LANG={lang}]"
288
+
289
+ frame = getattr(doc, 'frame', 'xx')
290
+ frame_token = f"[FRAME={frame}]"
291
+ sent_list = doc.sentences[source] if source in doc.sentences else doc.sentences
292
+ for sentence in sent_list:
293
+ labels = []
294
+ tokens = []
295
+ if add_lang_token:
296
+ tokens.append(lang_token)
297
+ labels.append(-100)
298
+ if add_frame_token:
299
+ tokens.append(frame_token)
300
+ labels.append(-100)
301
+ #print(f"[DEBUG] Ajout du token frame {frame_token} pour la phrase: {' '.join([t.form for t in sentence.toks])}")
302
+ for t in sentence.toks:
303
+ tokens.append(t.form)
304
+ if t.label == '_':
305
+ if 'O' in label2id:
306
+ labels.append(label2id['O'])
307
+ else:
308
+ labels.append(list(label2id.values())[0])
309
+ else:
310
+ labels.append(label2id[t.label])
311
+ yield {
312
+ "tokens": tokens,
313
+ "labels": labels
314
+ }
315
+
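+ # Each yielded item is one sentence, e.g. (a sketch with made-up label ids):
+ # {"tokens": ["[LANG=eng]", "[FRAME=rstdt]", "Byron", "received"],
+ #  "labels": [-100, -100, 1, 0]}
+ # The [LANG=...] / [FRAME=...] meta-tokens carry -100 so they are ignored by
+ # the loss and stripped again before evaluation.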
316
+
317
+ def get_tokenizer( model_checkpoint ):
318
+ return transformers.AutoTokenizer.from_pretrained(model_checkpoint)
319
+
320
+ def tokenize_and_align_labels( dataset, tokenizer, id2label, label2id, all_word_ids, config ):
321
+ '''
322
+ (Done in batches)
323
+ To preprocess our whole dataset, we need to tokenize all the inputs and
324
+ apply align_labels_with_tokens() on all the labels.
325
+ (with HF, we can use Dataset.map to process batches)
326
+ The word_ids() function needs to get the index of the example we want
327
+ the word IDs of when the inputs to the tokenizer are lists of texts
328
+ (or in our case, list of lists of words), so we add that too:
329
+ "tok_config"
330
+ '''
331
+ tokenized_inputs = tokenizer(
332
+ dataset["tokens"],
333
+ truncation=config["tok_config"]['truncation'],
334
+ padding=config["tok_config"]['padding'],
335
+ max_length=config["tok_config"]['max_length'],
336
+ is_split_into_words=True
337
+ )
338
+ # tokenized_inputs = tokenizer(
339
+ # dataset["tokens"], truncation=True, padding=True, is_split_into_words=True
340
+ # )
341
+ all_labels = dataset["labels"]
342
+ new_labels = []
343
+ #print( "tokenized_inputs.word_ids()", tokenized_inputs.word_ids() )
344
+ #print( [tokenizer.decode(tok) for tok in tokenized_inputs['input_ids']])
345
+ ##with progressbar.ProgressBar(max_value=len(all_labels)) as bar:
346
+ ##for i in tqdm(range(len(all_labels))):
347
+ for i, labels in enumerate(all_labels):
348
+ word_ids = tokenized_inputs.word_ids(i)
349
+ new_labels.append(align_labels_with_tokens(labels, word_ids, id2label, label2id, tokenizer, tokenized_inputs ))
350
+ # Used to fill the all_word_ids field of the Dataset object, but should probably be done somewhere else
351
+ all_word_ids.append( word_ids )
352
+ ##bar.update(i)
353
+ tokenized_inputs["labels"] = new_labels
354
+ return tokenized_inputs
355
+
356
+ def align_labels_with_tokens(labels, word_ids, id2label, label2id, tokenizer, tokenized_inputs):
357
+ '''
358
+ BERT like tokenization will create new tokens, we need to align labels.
359
+ Special tokens get a label of -100. This is because by default -100 is an
360
+ index that is ignored in the loss function we will use (cross entropy).
361
+ Then, each token gets the same label as the token that started the word
362
+ it’s inside, since they are part of the same entity. For tokens inside a
363
+ word but not at the beginning, we replace the B- with I- (since the token
364
+ does not begin the entity). [Taken from HF website course on NER]
365
+ '''
366
+ count = 0
367
+ new_labels = []
368
+ current_word = None
369
+ for word_id in word_ids:
370
+ count += 1
371
+ if word_id == 0: # or maybe 1
372
+ #TODO
373
+ #add lang token -100
374
+ pass
375
+ if word_id != current_word:
376
+ # Start of a new word!
377
+ current_word = word_id
378
+ label = -100 if word_id is None else labels[word_id]
379
+ new_labels.append(label)
380
+ elif word_id is None:
381
+ # Special token
382
+ new_labels.append(-100)
383
+ else:
384
+ # Same word as previous token
385
+ label = labels[word_id]
386
+ # Only look for 'B-' if label != -100
387
+ if label != -100 and 'B-' in id2label[label]:
388
+ label = -100
389
+ new_labels.append(label)
390
+ return new_labels
391
+
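+ # Worked example (a sketch; ids are illustrative, with 1 = a 'B-...' tag and
+ # 0 = 'O'): for word labels [1, 0] on ["washing", "machines"] and word_ids
+ # [None, 0, 0, 1, None] ([CLS] wash ##ing machines [SEP]), the result is
+ # [-100, 1, -100, 0, -100]: special tokens get -100, and the continuation
+ # subtoken '##ing' gets -100 because its word label is a 'B-...' tag.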
392
+
393
+ def retrieve_bio_labels( dataset ):
394
+ '''
395
+ Needed for compute_metrics, I think? It seems to be using a classic metric for the BIO
396
+ scheme, thus we create a mapping to BIO labels, i.e.:
397
+ '_' --> 'O'
398
+ 'Seg=B-Conn' --> 'B'
399
+ 'Seg=I-Conn' --> 'I'
400
+ Should also work for segmentation TODO: check
401
+ datasets: dict: DatasetSeq instances for train/dev/test
402
+ Return: list: original label names
403
+ list: label names mapped to BIO
404
+ '''
405
+ # need a Dataset instance to retrieve the original label sets
406
+ task = dataset.task
407
+ LABEL_NAMES_BIO = []
408
+ LABEL_NAMES = dataset.LABEL_NAMES
409
+ label2idx, idx2newl = {}, {}
410
+ if task in ["conn", "seg"]:
411
+ for i,l in enumerate( LABEL_NAMES ):
412
+ label2idx[l] = i
413
+ for l in label2idx:
414
+ nl = ''
415
+ if 'B' in l:
416
+ nl = 'B'
417
+ elif 'I' in l:
418
+ nl = 'I'
419
+ else:
420
+ nl = 'O'
421
+ idx2newl[label2idx[l]] = nl
422
+ for i in sorted(idx2newl):
423
+ LABEL_NAMES_BIO.append(idx2newl[i])
424
+ #label_names = ['O', 'B', 'I']
425
+ return LABEL_NAMES_BIO
426
+
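+ # Example (sketch): if dataset.LABEL_NAMES == ['_', 'Seg=B-seg'], then
+ # retrieve_bio_labels(dataset) returns ['O', 'B'], which is the label set
+ # seqeval expects in compute_metrics.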
427
+ # def _compute_distrib( dataset, id2label ):
428
+ # distrib = {}
429
+ # multi = []
430
+ # for inst in dataset:
431
+ # label = id2label[inst['label']]
432
+ # if label in distrib:
433
+ # distrib[label] += 1
434
+ # else:
435
+ # distrib[label] = 1
436
+ # len_labels = len( inst["all_labels"])
437
+ # if len_labels > 1:
438
+ # #count_multi += 1
439
+ # multi.append( len_labels )
440
+ # return distrib, multi
441
+
442
+ # Defines the language code for the sentence splitter, should be done in disrpt_io?
443
+ def set_language( lang ):
444
+ #lang = "default-multilingual" #default value
445
+ # patch
446
+ if lang=="sp": lang="es"
447
+ if lang not in LANGUAGES:
448
+ lang = "default-multilingual"
449
+ return lang
450
+
451
+
452
+ # ------------------------------------------------------------------
453
+ if __name__=="__main__":
454
+ import argparse, os
455
+
456
+ parser = argparse.ArgumentParser(
457
+ description='DISCUT: reading data from disrpt_io and converting to HuggingFace'
458
+ )
459
+ # TRAIN AND DEV are (list of) FILES or DIRECTORIES
460
+ parser.add_argument("-t", "--train",
461
+ help="Training file. Default: data_test/eng.sample.rstdt/eng.sample.rstdt_train.conllu",
462
+ default="data_test/eng.sample.rstdt/eng.sample.rstdt_train.conllu")
463
+
464
+ parser.add_argument("-d", "--dev",
465
+ help="Dev file. Default: data/eng.sample.rstdt/eng.sample.rstdt_dev.conllu",
466
+ default="data_test/eng.sample.rstdt/eng.sample.rstdt_dev.conllu")
467
+
468
+ # OUTPUT DIRECTORY
469
+ parser.add_argument("-o", "--output",
470
+ help="Directory where models and pred will be saved. Default: /home/cbraud/experiments/expe_discut_2025/",
471
+ default="")
472
+
473
+ # CONFIG FILE
474
+ parser.add_argument("-c", "--config",
475
+ help="Config file. Default: ./config_seg.json",
476
+ default="./config_seg.json")
477
+
478
+ # TRACE / VERBOSITY
479
+ parser.add_argument( '-v', '--trace',
480
+ action='store_true',
481
+ default=False,
482
+ help="Whether to print full messages. If used, it will override the value in config file.")
483
+
484
+ args = parser.parse_args()
485
+
486
+ train_path = args.train
487
+ dev_path = args.dev
488
+ print(dev_path)
489
+ if not os.path.isfile(dev_path):
490
+ print( "ERROR with dev file:", dev_path)
491
+ output_path = args.output
492
+ config_file = args.config
493
+ #eval = args.eval
494
+ trace = args.trace
495
+
496
+ print( '\n-[JEDIS]--PROGRAM (reader) ARGUMENTS')
497
+ print( '| Train_path', train_path )
498
+ print( '| Dev_path', dev_path )
499
+ print( "| Output_path", output_path )
500
+ print( '| Config', config_file )
501
+
502
+ print( '\n-[JEDIS]--CONFIG INFO')
503
+ config = utils.read_config( config_file )
504
+ utils.print_config(config)
505
+ # WE override the config file if the user says no trace in arguments
506
+ # easier than modifying the config files each time
507
+ if not trace:
508
+ config['trace'] = False
509
+
510
+ print( "\n-[JEDIS]--READING DATASETS" )
511
+ # dictionary containing train (if mode=='train') and/or dev (test) Dataset instances
512
+ datasets, tokenizer = read_dataset( train_path, output_path, config, add_lang_token=True )
utils.py ADDED
@@ -0,0 +1,216 @@
1
+ #!/usr/bin/python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import os, sys
5
+ import json
6
+ import numpy as np
7
+ from pathlib import Path
8
+ import itertools
9
+
10
+ import evaluate
11
+ import disrpt_eval_2025
12
+ #from .disrpt_eval_2025 import *
13
+
14
+ # TODO : should be conditioned on the task or the metric indicated in the config file ??
15
+ def prepare_compute_metrics(LABEL_NAMES):
16
+ '''
17
+ Return the method to be used in the trainer loop.
18
+ For seg or conn, based on seqeval, and here ignore tokens with label
19
+ -100 (okay ?)
20
+
21
+ Parameters:
22
+ ------------
23
+ LABEL_NAMES: list
24
+ Needed only for BIO labels, converted to the right labels for seqeval
28
+
29
+ Returns:
30
+ ---------
31
+ compute_metrics: function
32
+ '''
33
+ def compute_metrics(eval_preds):
34
+ nonlocal LABEL_NAMES
35
+ # nonlocal task
36
+ # Retrieve gold and predictions
37
+ logits, labels = eval_preds
38
+
39
+ predictions = np.argmax(logits, axis=-1)
40
+ metric = evaluate.load("seqeval")
41
+ # Remove ignored index (special tokens) and convert to labels
42
+ true_labels = [[LABEL_NAMES[l] for l in label if l != -100] for label in labels]
43
+ true_predictions = [
44
+ [LABEL_NAMES[p] for (p, l) in zip(prediction, label) if l != -100]
45
+ for prediction, label in zip(predictions, labels)
46
+ ]
47
+ all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
48
+ print_metrics( all_metrics )
49
+ return {
50
+ "precision": all_metrics["overall_precision"],
51
+ "recall": all_metrics["overall_recall"],
52
+ "f1": all_metrics["overall_f1"],
53
+ "accuracy": all_metrics["overall_accuracy"],
54
+ }
55
+ return compute_metrics
56
+
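+ # Usage sketch (assumes a DatasetSeq whose BIO label names are already set):
+ # compute_metrics = prepare_compute_metrics(dataset.LABEL_NAMES_BIO)
+ # then pass it to transformers.Trainer(compute_metrics=compute_metrics, ...)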
57
+
58
+ def print_metrics( all_metrics ):
59
+ #print( all_metrics )
60
+ for p,v in all_metrics.items():
61
+ if '_' in p:
62
+ print( p, v )
63
+ else:
64
+ print( p+' = '+str(v))
65
+
+ def compute_metrics_dirspt( dataset_eval, pred_file, task='seg' ):
+     '''Score a prediction file against the gold annotations with the official DISRPT evaluation script.'''
+     print( "\nPerformance computed using disrpt eval script on", dataset_eval.annotations_file,
+            pred_file )
+     if task == 'seg':
+         #clean_pred_file(pred_file, os.path.basename(pred_file)+"cleaned.preds")
+         my_eval = disrpt_eval_2025.SegmentationEvaluation("temp_test_disrpt_eval_seg",
+                                                           dataset_eval.annotations_file,
+                                                           pred_file )
+     elif task == 'conn':
+         my_eval = disrpt_eval_2025.ConnectivesEvaluation("temp_test_disrpt_eval_conn",
+                                                          dataset_eval.annotations_file,
+                                                          pred_file )
+     else:
+         raise NotImplementedError( "Unknown task for DISRPT evaluation: " + task )
+     my_eval.compute_scores()
+     my_eval.print_results()
+
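+ # Illustrative call only (the dataset and file names below are hypothetical):
+ #     compute_metrics_dirspt(dev_dataset, "out/eng.rst.gum_dev.preds", task='seg')
+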
+ def clean_pred_file(pred_path: str, out_path: str):
+     '''Copy a prediction file, dropping [LANG=...] / [FRAME=...] meta-token rows.'''
+     c = 0
+     with open(pred_path, "r", encoding="utf8") as fin, open(out_path, "w", encoding="utf8") as fout:
+         for line in fin:
+             # keep comments and sentence-separating blank lines as-is
+             if line.strip() == "" or line.startswith("#"):
+                 fout.write(line)
+                 continue
+             fields = line.strip().split("\t")
+             token = fields[1]
+             if token.startswith("[LANG=") or token.startswith("[FRAME="):
+                 c += 1
+                 continue  # skip meta-tokens
+             fout.write(line)
+     print(f"Removed {c} meta-tokens")
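+
+ # Descriptive note: rows whose FORM column (field 2) is a meta-token such as
+ # [LANG=eng] or [FRAME=...] are dropped; comments, blank lines and regular
+ # token rows are copied through unchanged.
+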
+ # -------------------------------------------------------------------------------------------------
+ # ------ UTILS FUNCTIONS
+ # -------------------------------------------------------------------------------------------------
+ def read_config( config_file ):
+     '''Read the JSON config file for training and normalize the frozen-layer spec.'''
+     with open(config_file) as f:
+         config = json.load(f)
+     if 'frozen' in config['trainer_config']:
+         config['trainer_config']["frozen"] = update_frozen_set( config['trainer_config']["frozen"] )
+     return config
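+
+ # A minimal config sketch (keys inferred from how `config` is used in this
+ # module; values are illustrative placeholders, not the project's defaults):
+ #     {
+ #       "model_name": "xlm-roberta-base",
+ #       "task": "seg",
+ #       "trace": true,
+ #       "wandb": "my-project", "wandb_ent": "my-entity",
+ #       "trainer_config": {"batch_size": 16, "learning_rate": 2e-05, "frozen": ["0-3", "12"]}
+ #     }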
+
+ def update_frozen_set( freeze ):
+     # Make a set from the list of frozen layer specs:
+     #   []                  --> nothing frozen
+     #   ["3"]               --> only layer 3 frozen
+     #   ["0", "3"]          --> only layers 0 and 3
+     #   ["0-3", "12", "15"] --> layers 0 to 3 included, plus layers 12 and 15
+     frozen = set()
+     for spec in freeze:
+         if "-" in spec:  # a range, e.g. "1-9"
+             b, e = spec.split("-")
+             frozen |= set(range(int(b), int(e) + 1))
+         else:
+             frozen.add(int(spec))
+     return frozen
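+
+ # Worked example (illustrative): update_frozen_set(["0-3", "12"]) returns
+ # {0, 1, 2, 3, 12}, i.e. the set of encoder layers to freeze during training.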
+
+ def print_config(config):
+     '''Print one "| key: value" line per entry of the config dictionary.'''
+     print('\n'.join([ '| ' + k + ": " + str(v) for (k, v) in config.items() ]))
+
+ # -------------------------------------------------------------------------------------------------
+ def retrieve_files_dataset( input_path, list_dataset, mode='conllu', dset='train' ):
+     '''Return the *_<dset> files under input_path, restricted to the corpora in list_dataset if non-empty.'''
+     if mode == 'conllu':
+         pat = ".[cC][oO][nN][lL][lL][uU]"  # case-insensitive .conllu extension
+     elif mode == 'tok':
+         pat = ".[tT][oO][kK]"              # case-insensitive .tok extension
+     else:
+         sys.exit('Unknown mode for file extension: ' + mode)
+     if len(list_dataset) == 0:
+         return list(Path(input_path).rglob("*_" + dset + pat))
+     else:
+         # file names look like eng.pdtb.pdtb_train.conllu
+         matched = []
+         for subdir in os.listdir( input_path ):
+             if subdir in list_dataset:
+                 matched.extend( list(Path(os.path.join(input_path, subdir)).rglob("*_" + dset + pat)) )
+         return matched
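+
+ # Illustrative call, assuming a layout with one sub-directory per corpus:
+ #     retrieve_files_dataset("data", ["eng.pdtb.pdtb"], mode='conllu', dset='train')
+ #     --> [PosixPath('data/eng.pdtb.pdtb/eng.pdtb.pdtb_train.conllu')]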
+
+
+ # -------------------------------------------------------------------------------------------------
+ # https://wandb.ai/site
+ def init_wandb( config, model_checkpoint, annotations_file ):
+     '''
+     Initialize a new WANDB run to keep track of the experiments.
+
+     Parameters
+     ----------
+     config : dict
+         Used to retrieve the names of the entity and project (from the config file)
+     model_checkpoint : str
+         Name of the PLM used
+     annotations_file : str
+         Path to the training file
+
+     Returns
+     -------
+     None
+     '''
+     print("HERE WE INITIALIZE A WANDB PROJECT")
+
+     import wandb
+     proj_wandb = config["wandb"]
+     ent_wandb = config["wandb_ent"]
+     # Start a new wandb run to track this script;
+     # the project name must be set before initializing the trainer
+     wandb.init(
+         # set the wandb project where this run will be logged
+         project=proj_wandb,
+         entity=ent_wandb,
+         # track hyperparameters and run metadata
+         config={
+             "model_checkpoint": model_checkpoint,
+             "dataset": annotations_file,
+         }
+     )
+     wandb.define_metric("epoch")
+     wandb.define_metric("f1", step_metric="batch")
+     wandb.define_metric("f1", step_metric="epoch")
+
+ def set_name_output_dir( output_dir, config, corpus_name ):
+     '''
+     Set the path of the target directory used to store models. The name contains
+     info about the corpus, the PLM, the task and the hyperparameter values.
+
+     Parameters
+     ----------
+     output_dir : str
+         Path to the output directory provided by the user
+     config : dict
+         Configuration information
+     corpus_name : str
+         Name of the corpus
+
+     Returns
+     -------
+     str : Path to the output directory
+     '''
+     # Write the learning rate in decimal form, to avoid scientific notation
+     hyperparam = [
+         config['trainer_config']['batch_size'],
+         np.format_float_positional(config['trainer_config']['learning_rate'])
+     ]
+     output_dir = os.path.join( output_dir,
+                                '_'.join( [
+                                    corpus_name,
+                                    config["model_name"],
+                                    config["task"],
+                                    '_'.join([str(p) for p in hyperparam])
+                                ] ) )
+     return output_dir
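+
+ # Illustrative result, assuming model_name="xlm-roberta-base", task="seg",
+ # batch_size=16 and learning_rate=2e-05 in the config:
+ #     set_name_output_dir("models", config, "eng.rst.gum")
+ #     --> "models/eng.rst.gum_xlm-roberta-base_seg_16_0.00002"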