gradio space init

Files changed:
- disrpt_eval_2025.py  +517 -0
- disrpt_io.py         +846 -0
- eval.py              +760 -0
- pipeline.py          +142 -0
- reading.py           +512 -0
- utils.py             +216 -0

disrpt_eval_2025.py (ADDED)
@@ -0,0 +1,517 @@
"""
Script to evaluate segmentation F-score and the proportion of perfectly segmented discourse units from two files. Two input formats are permitted:

* One token per line, with ten columns, no sentence breaks (default *.tok format) - segmentation indicated in column 10
* The same, but with blank lines between sentences (*.conll format)

Token columns follow the CoNLL-U format, with token IDs in the first column and pipe-separated key=value pairs in the last column.
Document boundaries are indicated by a comment: # newdoc_id = ...
The evaluation uses micro-averaged F-scores per corpus (not a document-level macro-average).

Example:

# newdoc_id = GUM_bio_byron
1 Education _ _ _ _ _ _ _ Seg=B-seg
2 and _ _ _ _ _ _ _ _
3 early _ _ _ _ _ _ _ _
4 loves _ _ _ _ _ _ _ _
5 Byron _ _ _ _ _ _ _ Seg=B-seg
6 received _ _ _ _ _ _ _ _

Or:

# newdoc_id = GUM_bio_byron
# sent_id = GUM_bio_byron-1
# text = Education and early loves
1 Education education NOUN NN Number=Sing 0 root _ Seg=B-seg
2 and and CCONJ CC _ 4 cc _ _
3 early early ADJ JJ Degree=Pos 4 amod _ _
4 loves love NOUN NNS Number=Plur 1 conj _ _

# sent_id = GUM_bio_byron-2
# text = Byron received his early formal education at Aberdeen Grammar School, and in August 1799 entered the school of Dr. William Glennie, in Dulwich. [17]
1 Byron Byron PROPN NNP Number=Sing 2 nsubj _ Seg=B-seg
2 received receive VERB VBD Mood=Ind|Tense=Past|VerbForm=Fin 0 root _ _

For PDTB-style corpora, we calculate exact span-wise F-scores for the BIO encoding, without partial credit. In other words,
predicting an incorrect span with partial overlap is the same as missing a gold span and predicting an incorrect span
somewhere else in the corpus. Note also that spans must begin with B-conn - predicted spans beginning with I-conn are ignored.
The file format for PDTB-style corpora is similar, but with different labels:

1 Fidelity Fidelity PROPN NNP _ 6 nsubj _ _
2 , , PUNCT , _ 6 punct _ _
3 for for ADP IN _ 4 case _ Conn=B-conn
4 example example NOUN NN _ 6 obl _ Conn=I-conn
5 , , PUNCT , _ 6 punct _ _
6 prepared prepare VERB VBN _ 0 root _ _
7 ads ad NOUN NNS _ 6 obj _ _

Arguments:
* goldfile: shared task gold test data
* predfile: same format, with predicted segment positions in column 10 - note that the **number of tokens must match**
* string_input: if specified, the file arguments are strings with the file contents instead of file names
* no_boundaries: specify to eval only intra-sentence EDUs
"""

""" TODO
- OK labels: passed as an argument, not hard-coded
- OK option without sentence-initial boundaries: cf. the "BIO no B" script
- OK print the results more cleanly: without the odd "o"
- OK make 2 classes, edu and connectives (conn: future experiment to eval the extended connective vs the head of the connective)
- cleaner solution for the label column?
- make an Eval class and turn the 2 evals into subclasses of it
"""

__author__ = "Amir Zeldes, Janet Liu, Laura Rivière"
__license__ = "Apache 2.0"
__version__ = "2.0.0"

import io, os, sys, argparse
import json
from sklearn.metrics import accuracy_score, classification_report

# MWE and ellipsis: no label, or "_"
# TODO:
# print scores *100: 0.6825 => 68.25
# documentation (automatic generation?)
# unit tests

class Evaluation:
    """
    Generic class for evaluation between 2 files.
    :load data, basic checks, basic metrics, print results.
    """
    def __init__(self, name: str) -> None:
        self.output = dict()
        self.name = name
        self.report = ""
        self.fill_output('doc_name', self.name)

    def get_data(self, infile: str, str_i=False) -> str:
        """
        Load data from a file, or directly from a string.
        """
        if str_i == False:
            data = io.open(infile, encoding="utf-8").read().strip().replace("\r", "")
        else:
            data = infile.strip()
        return data

    def fill_output(self, key: str, value) -> None:
        """
        Fill the results dict that will be printed.
        """
        self.output[key] = value

    def check_tokens_number(self, g: list, p: list) -> None:
        """
        Check that both compared files have the same number of tokens/labels.
        """
        if len(g) != len(p):
            self.report += "\nFATAL: different number of tokens detected in gold and pred:\n"
            self.report += ">>> In " + self.name + ": " + str(len(g)) + " gold tokens but " + str(len(p)) + " predicted tokens\n\n"
            sys.stderr.write(self.report)
            sys.exit(0)

    def check_identical_tokens(self, g: list, p: list) -> None:
        """
        Check that tokens/features are identical.
        """
        for i, tok in enumerate(g):
            if tok != p[i]:
                self.report += "\nWARN: token strings do not match in gold and pred:\n"
                self.report += ">>> First instance in " + self.name + " token " + str(i) + "\n"
                self.report += "Gold: " + tok + " but Pred: " + p[i] + "\n\n"
                sys.stderr.write(self.report)
                break

    def compute_PRF_metrics(self, tp: int, fp: int, fn: int) -> None:
        """
        Compute Precision, Recall and F-score from True Positive, False Positive and False Negative counts.
        Save the results in the output dict.
        """
        try:
            precision = tp / (float(tp) + fp)
        except ZeroDivisionError:
            precision = 0

        try:
            recall = tp / (float(tp) + fn)
        except ZeroDivisionError:
            recall = 0

        try:
            f_score = 2 * (precision * recall) / (precision + recall)
        except ZeroDivisionError:
            f_score = 0

        self.fill_output("gold_count", tp + fn)
        self.fill_output("pred_count", tp + fp)
        self.fill_output("precision", precision)
        self.fill_output("recall", recall)
        self.fill_output("f_score", f_score)

    def compute_accuracy(self, g: list, p: list, k: str) -> None:
        """
        Compute the accuracy of a list of predicted items against a gold list of items.
        :g: gold list
        :p: predicted list
        :k: name detail of the accuracy
        """
        self.fill_output(f"{k}_accuracy", accuracy_score(g, p))
        self.fill_output(f"{k}_gold_count", len(g))
        self.fill_output(f"{k}_pred_count", len(p))

    def classif_report(self, g: list, p: list, key: str) -> None:
        """
        Compute Precision, Recall and F-score for each label occurring in the gold list.
        """
        stats_dict = classification_report(g, p, labels=sorted(set(g)), zero_division=0.0, output_dict=True)
        self.fill_output(f'{key}_classification_report', stats_dict)

    def print_results(self) -> None:
        """
        Print the dict of saved results.
        """
        # for k in self.output.keys():
        #     print(f">> {k} : {self.output[k]}")

        print(json.dumps(self.output, indent=4))


class RelationsEvaluation(Evaluation):
    """
    Specific evaluation class for relation classification.
    The evaluation uses simple accuracy per corpus.
    :rels disrpt-style data.
    :default: eval last column "label"
    :option: eval relation type (pdtb: implicit, explicit...) column "rel_type"
    """

    HEADER = "doc\tunit1_toks\tunit2_toks\tunit1_txt\tunit2_txt\tu1_raw\tu2_raw\ts1_toks\ts2_toks\tunit1_sent\tunit2_sent\tdir\trel_type\torig_label\tlabel"
    # HEADER_23 = "doc\tunit1_toks\tunit2_toks\tunit1_txt\tunit2_txt\ts1_toks\ts2_toks\tunit1_sent\tunit2_sent\tdir\torig_label\tlabel"

    LABEL_ID = -1
    TYPE_ID = -3
    DISRPT_TYPES = ['Implicit', 'Explicit', 'AltLex', 'AltLexC', 'Hypophora']

    def __init__(self, name: str, gold_path: str, pred_path: str, str_i=False, rel_type=False) -> None:
        """
        :param gold_path: Gold shared task file
        :param pred_path: File with predictions
        :param str_i: If True, files are replaced by strings with file contents (for import inside other scripts)
        :param rel_type: If True, scores are computed on the types column, not the label (relevant for PDTB)
        """
        super().__init__(name)
        self.mode = "rel"
        self.g_path = gold_path
        self.p_path = pred_path
        self.opt_str_i = str_i
        self.opt_rel_t = rel_type
        self.key = "labels"

        self.fill_output("options", {"s": self.opt_str_i, "rt": self.opt_rel_t})

    def compute_scores(self) -> None:
        """
        Get the lists of data to compare, then compute the metrics.
        """
        gold_units, gold_labels, gold_types = self.parse_rels_data(self.g_path, self.opt_str_i, self.opt_rel_t)
        pred_units, pred_labels, pred_types = self.parse_rels_data(self.p_path, self.opt_str_i, self.opt_rel_t)
        self.check_tokens_number(gold_labels, pred_labels)
        self.check_identical_tokens(gold_units, pred_units)

        self.compute_accuracy(gold_labels, pred_labels, self.key)
        self.classif_report(gold_labels, pred_labels, self.key)

        if self.opt_rel_t:
            self.get_types_scores(gold_labels, pred_labels, gold_types)

    def get_types_scores(self, g: list, p: list, tg: list) -> None:
        """
        Score predictions against gold labels, broken down by relation type.
        """
        for t in self.DISRPT_TYPES:
            gold_t = []
            pred_t = []
            for i, _ in enumerate(g):
                if tg[i] == t.lower():
                    gold_t.append(g[i])
                    pred_t.append(p[i])

            self.compute_accuracy(gold_t, pred_t, f"types_{t}")

    def parse_rels_data(self, path: str, str_i: bool, rel_t: bool) -> tuple[list[str], list[str], list[str]]:
        """
        Rels format from DISRPT = a header, then one relation classification instance per line.
        :LREC_2024_header = 15 columns.
        """
        data = self.get_data(path, str_i)
        header = data.split("\n")[0]
        assert header == self.HEADER, "Unrecognized .rels header."
        # column_ID = self.TYPE_ID if rel_t == True else self.LABEL_ID

        rels = data.split("\n")[1:]
        labels = [line.split("\t")[self.LABEL_ID] for line in rels]  # .lower()?
        units = [" ".join(line.split("\t")[:3]) for line in rels]
        types = [line.split("\t")[self.TYPE_ID] for line in rels] if rel_t == True else []

        return units, labels, types


class ConnectivesEvaluation(Evaluation):
    """
    Specific evaluation class for PDTB connective detection.
    :parse conllu-style data
    :eval on strict connective spans
    """
    LAB_CONN_B = "Conn=B-conn"  # "Seg=B-Conn" #
    LAB_CONN_I = "Conn=I-conn"  # "Seg=I-Conn" #
    LAB_CONN_O = "Conn=O"  # "_" #

    def __init__(self, name: str, gold_path: str, pred_path: str, str_i=False) -> None:
        """
        :param gold_path: Gold shared task file
        :param pred_path: File with predictions
        :param str_i: If True, files are replaced by strings with file contents (for import inside other scripts)
        """
        super().__init__(name)
        self.mode = "conn"
        self.seg_type = "connective spans"
        self.g_path = gold_path
        self.p_path = pred_path
        self.opt_str_i = str_i

        self.fill_output('seg_type', self.seg_type)
        self.fill_output("options", {"s": self.opt_str_i})

    def compute_scores(self) -> None:
        """
        Get the lists of data to compare, then compute the metrics.
        """
        gold_tokens, gold_labels, gold_spans = self.parse_conn_data(self.g_path, self.opt_str_i)
        pred_tokens, pred_labels, pred_spans = self.parse_conn_data(self.p_path, self.opt_str_i)

        self.output['tok_count'] = len(gold_tokens)

        self.check_tokens_number(gold_tokens, pred_tokens)
        self.check_identical_tokens(gold_tokens, pred_tokens)
        tp, fp, fn = self.compare_spans(gold_spans, pred_spans)
        self.compute_PRF_metrics(tp, fp, fn)

    def compare_spans(self, gold_spans: tuple, pred_spans: tuple) -> tuple[int, int, int]:
        """
        Compare exact spans: only identical (start, end) pairs count as matches.
        """
        true_positive = 0
        false_positive = 0
        false_negative = 0

        for span in gold_spans:  # not verified
            if span in pred_spans:
                true_positive += 1
            else:
                false_negative += 1
        for span in pred_spans:
            if span not in gold_spans:
                false_positive += 1

        return true_positive, false_positive, false_negative
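
    # Worked example (illustrative numbers): with gold_spans = [(0, 1), (5, 5)]
    # and pred_spans = [(0, 1), (6, 7)], the exact-match comparison above yields
    # tp = 1 (the shared span (0, 1)), fn = 1 (gold (5, 5) is missed) and
    # fp = 1 (the spurious (6, 7)): partial overlap earns no credit.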

    def parse_conn_data(self, path: str, str_i: bool) -> tuple[list, list, list]:
        """
        The label is in the last column.
        """
        data = self.get_data(path, str_i)
        tokens = []
        labels = []
        spans = []
        counter = 0
        span_start = -1
        span_end = -1
        for line in data.split("\n"):  # this loop is the same as in version 1
            if line.startswith("#") or line == "":
                continue
            else:
                fields = line.split("\t")  # Token
                label = fields[-1]
                if "-" in fields[0] or "." in fields[0]:  # Multi-Word Expression or ellipsis: no pred should be there
                    continue
                elif self.LAB_CONN_B in label:
                    if span_start > -1:  # add span
                        if span_end == -1:
                            span_end = span_start
                        spans.append((span_start, span_end))
                        span_end = -1
                    label = self.LAB_CONN_B
                    span_start = counter
                elif self.LAB_CONN_I in label:
                    label = self.LAB_CONN_I
                    span_end = counter
                else:
                    label = "_"
                    if span_start > -1:  # add span
                        if span_end == -1:
                            span_end = span_start
                        spans.append((span_start, span_end))
                        span_start = -1
                        span_end = -1

                tokens.append(fields[1])
                labels.append(label)
                counter += 1

        if span_start > -1 and span_end > -1:  # add last span
            spans.append((span_start, span_end))

        if self.LAB_CONN_B not in labels:
            print(f"Unrecognized labels. Expecting: {self.LAB_CONN_B}, {self.LAB_CONN_I}, {self.LAB_CONN_O}...")
            print("(or the model may simply never predict a B)")

        return tokens, labels, spans


class SegmentationEvaluation(Evaluation):
    """
    Specific evaluation class for EDU segmentation.
    :parse conllu-style data
    :eval on identification of EDU-initial tokens
    """
    LAB_SEG_B = "Seg=B-seg"  # "BeginSeg=Yes"
    LAB_SEG_I = "Seg=O"  # "_"

    def __init__(self, name: str, gold_path: str, pred_path: str, str_i=False, no_b=False) -> None:
        """
        :param gold_path: Gold shared task file
        :param pred_path: File with predictions
        :param str_i: If True, files are replaced by strings with file contents (for import inside other scripts)
        """
        super().__init__(name)
        self.mode = "edu"
        self.seg_type = "EDUs"
        self.g_path = gold_path
        self.p_path = pred_path
        self.opt_str_i = str_i
        self.no_b = True if "conllu" in gold_path.split(os.sep)[-1] and no_b == True else False  # only relevant for conllu

        self.fill_output('seg_type', self.seg_type)
        self.fill_output("options", {"s": self.opt_str_i})

    def compute_scores(self) -> None:
        """
        Get the lists of data to compare, then compute the metrics.
        """
        gold_tokens, gold_labels, gold_spans = self.parse_edu_data(self.g_path, self.opt_str_i, self.no_b)
        pred_tokens, pred_labels, pred_spans = self.parse_edu_data(self.p_path, self.opt_str_i, self.no_b)

        self.output['tok_count'] = len(gold_tokens)

        self.check_tokens_number(gold_tokens, pred_tokens)
        self.check_identical_tokens(gold_tokens, pred_tokens)
        tp, fp, fn = self.compare_labels(gold_labels, pred_labels)
        self.compute_PRF_metrics(tp, fp, fn)

    def compare_labels(self, gold_labels: list, pred_labels: list) -> tuple[int, int, int]:
        """
        Count true/false positives and false negatives over token-level segmentation labels.
        """
        true_positive = 0
        false_positive = 0
        false_negative = 0

        for i, gold_label in enumerate(gold_labels):  # not verified
            pred_label = pred_labels[i]
            if gold_label == pred_label:
                if gold_label == "_":
                    continue
                else:
                    true_positive += 1
            else:
                if pred_label == "_":
                    false_negative += 1
                else:
                    if gold_label == "_":
                        false_positive += 1
                    else:  # I-Conn/B-Conn mismatch?
                        false_positive += 1

        return true_positive, false_positive, false_negative

    def parse_edu_data(self, path: str, str_i: bool, no_b: bool) -> tuple[list, list, list]:
        """
        The label is in the last column.
        """
        data = self.get_data(path, str_i)
        tokens = []
        labels = []
        spans = []
        counter = 0
        span_start = -1
        span_end = -1
        for line in data.split("\n"):  # this loop is the same as in version 1
            if line.startswith("#") or line == "":
                continue
            else:
                fields = line.split("\t")  # Token
                label = fields[-1]
                if "-" in fields[0] or "." in fields[0]:  # Multi-Word Expression or ellipsis: no pred should be there
                    continue
                elif no_b == True and fields[0] == "1":
                    label = "_"
                elif self.LAB_SEG_B in label:
                    label = self.LAB_SEG_B
                else:
                    label = "_"  # 🚩
                    if span_start > -1:  # add span
                        if span_end == -1:
                            span_end = span_start
                        spans.append((span_start, span_end))
                        span_start = -1
                        span_end = -1

                tokens.append(fields[1])
                labels.append(label)
                counter += 1

        if span_start > -1 and span_end > -1:  # add last span
            spans.append((span_start, span_end))

        if self.LAB_SEG_B not in labels:
            sys.exit(f"Unrecognized labels. Expecting: {self.LAB_SEG_B}, {self.LAB_SEG_I}...")

        return tokens, labels, spans


if __name__ == "__main__":

    p = argparse.ArgumentParser()
    p.add_argument("-g", "--goldfile", required=True, help="Shared task gold file in .tok, .conll or .rels format.")
    p.add_argument("-p", "--predfile", required=True, help="Corresponding file with system predictions.")
    p.add_argument("-t", "--task", required=True, choices=['S', 'C', 'R'], help="One of: S (EDU segmentation), C (connective detection), R (relation classification).")
    p.add_argument("-s", "--string_input", action="store_true", help="Whether inputs are file names or strings.")
    p.add_argument("-nb", "--no_boundary_edu", default=False, action='store_true', help="Do not count EDUs that start at the beginning of a sentence.")
    p.add_argument("-rt", "--rel_type", default=False, action='store_true', help="Eval relation types instead of labels.")

    # help(Evaluation)
    # help(SegmentationEvaluation)
    # help(ConnectivesEvaluation)
    # help(RelationsEvaluation)

    opts = p.parse_args()

    name = opts.goldfile.split(os.sep)[-1] if os.path.isfile(opts.goldfile) else f"string_input: {opts.goldfile[0:20]}..."

    if opts.task == "R":
        my_eval = RelationsEvaluation(name, opts.goldfile, opts.predfile, opts.string_input, opts.rel_type)
    elif opts.task == "C":
        my_eval = ConnectivesEvaluation(name, opts.goldfile, opts.predfile, opts.string_input)
    elif opts.task == "S":
        my_eval = SegmentationEvaluation(name, opts.goldfile, opts.predfile, opts.string_input, opts.no_boundary_edu)

    my_eval.compute_scores()
    my_eval.print_results()
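
# Hedged sketch of programmatic use (file names and string variables are
# placeholders; string input is what the docstrings above describe as
# "for import inside other scripts"):
#
#   from disrpt_eval_2025 import SegmentationEvaluation
#   ev = SegmentationEvaluation("eng.rst.gum_test.tok", gold_str, pred_str, str_i=True)
#   ev.compute_scores()
#   print(ev.output["f_score"])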

disrpt_io.py (ADDED)
@@ -0,0 +1,846 @@
"""
Classes to read/write disrpt-like files
+ analysis of the sentence splitter / "gold" sentences or stanza/spacy sentences
- ersatz

Disrpt is a discourse analysis campaign with (as of 2023):
- discourse segmentation information, in a conll-like format
- discourse connective information (also conll-like)
- discourse relation pairs, in a specific format

Data are separated by corpus and language, with conventional names
of the form language.framework.corpusname,
e.g. fra.sdrt.annodis

TODO:
- refactor how sentences are stored with a dictionary: "conllu" / "tok" / "split"
  [ok] dictionary
  ? refactor creation of corpus/documents to allow for updates (or load tok+conllu at once)
- [ok] the Italian luna corpus has different meta tags, with an extra level: newdoc_id/newturn_id/newutterance_id
- [ok] check behaviour on languages without pretrained models / what candidates?
  - nl, pt, it -> en?
  - thai -> multilingual
- test different candidate sets for splitting locations:
  - [done] all -> too underspecified and too slow
  - [ok] en on all but zho+thai
  - [done] en instead of multilingual?
    bad scores on zho
- [ok] fix bad characters: BOM, replacement char, etc.
  special char for the apostrophe, cf.
  data_clean/eng.dep.scidtb/eng.dep.scidtb_train.tok / newdoc_id = P16-1030 char problem for the possessive
  ��antagonist��

  Basque problem: "Osasun-zientzietako Ikertzaileen II ." token count ...
  Iru�eko etc.
- Turkish problem: tur.pdtb.tdb/tur.pdtb.tdb_train: BOM? '\ufeff' -> 'Makale'
  + extra blank token in train (785)?
  774 olduğunu _ _ _ _ _ _ _ _
  775 söylüyor _ _ _ _ _ _ _ _
  776 : _ _ _ _ _ _ _ _
  777 Türkiye _ _ _ _ _ _ _ _
  778 demokrasi _ _ _ _ _ _ _ _
  779 istiyor _ _ _ _ _ _ _ _
  780 ÖDPGenel _ _ _ _ _ _ _ _
  781 Başkanı _ _ _ _ _ _ _ _
  782 Ufuk _ _ _ _ _ _ _ _
  783 Uras'tan _ _ _ _ _ _ _ _
  784 : _ _ _ _ _ _ _ _
  785 _ _ _ _ _ _ _ _
  786 Türkiye _ _ _ _ _ _ _ _
  787 , _ _ _ _ _ _ _ _
  788 AİHM'de _ _
- zh problem
  zh: ?是 is this "?" listed in ersatz?
  ??hosto2
  sctb 3.巴斯克
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


- specific preprocessing:
  annodis/gum: titles
  gum/rrt: biblio / articles
  scidtb?
- different sentence splitters
  - [ok] ersatz
  - trankit
  - [abandoned] stanza: FIXME: stanza makes lots of errors, e.g. splits within words (might be due to bad input tokenization)
- [done] write doc in disrpt format (after transformation, for instance)
- [done] eval of beginnings of sentences (precision)
- [done] (done in the split_sentence script) eval / nb of conll sentences ~= sentence recall
- eval sentence length (max)
- [moot] clean main script: arguments/argparse -> separate script
- [done] method for sentence splitting (for tok)
- [done] iterate over all docs in a corpus
- [done] choose language automatically according to corpus name
- ? method for sentence resplitting for conllu? needs a way of indexing tokens for later re-eval? or does the eval script not care?


Candidate sets for splitting:

- multilingual (default) is as described in the ersatz paper == [EOS punctuation][!number]
- en requires a space following punctuation
- all: a space between any two characters
- a custom one can be written using the determiner.Split() base class
"""
import sys, os
import dataclasses
from itertools import chain
from collections import Counter
from copy import copy, deepcopy
from tqdm import tqdm
# import progressbar
# from ersatz import split, utils
# import trankit
# import stanza
# from stanza.pipeline.core import DownloadMethod

from transformers import pipeline

from wtpsplit import SaT


# needed to track the mistakes made in preprocessing of the disrpt dataset, whose origin is unknown
BOM = '\ufeff'
REPL_CHAR = "\ufffd"  # �

test_doc_seg = """# newdoc id = geop_3_space
1 La le DET _ Definite=Def|Gender=Fem|Number=Sing|PronType=Art 2 det _ BeginSeg=Yes
2 Space space PROPN _ _ 0 root _ _
3 Launcher Launcher PROPN _ _ 2 flat:name _ _
4 Initiative initiative PROPN _ _ 2 flat:name _ _
5 . . PUNCT _ _ 2 punct _ _

1 Le le DET _ Definite=Def|Gender=Masc|Number=Sing|PronType=Art 2 det _ BeginSeg=Yes
2 programme programme NOUN _ Gender=Masc|Number=Sing 10 nsubj _ _
3 de de ADP _ _ 4 case _ _
4 Space space PROPN _ _ 2 nmod _ _
5 Launcher Launcher PROPN _ _ 4 flat:name _ _
6 Initiative initiative PROPN _ _ 4 flat:name _ _
7 ( ( PUNCT _ _ 8 punct _ BeginSeg=Yes
8 SLI SLI PROPN _ _ 4 appos _ _
9 ) ) PUNCT _ _ 8 punct _ _
10 vise viser VERB _ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root _ BeginSeg=Yes
11 à à ADP _ _ 12 mark _ _
12 développer développer VERB _ VerbForm=Inf 10 ccomp _ _
13 un un DET _ Definite=Ind|Gender=Masc|Number=Sing|PronType=Art 14 det _ _
14 système système NOUN _ Gender=Masc|Number=Sing 12 obj _ _
15 de de ADP _ _ 16 case _ _
16 lanceur lanceur NOUN _ Gender=Masc|Number=Sing 14 nmod _ _
17 réutilisable réutilisable ADJ _ Gender=Masc|Number=Sing 16 amod _ _
18 entièrement entièrement ADV _ _ 19 advmod _ _
19 inédit inédit ADJ _ Gender=Masc|Number=Sing 14 amod _ _
20 . . PUNCT _ _ 10 punct _ _

# newdoc id = ling_fuchs_section2
1 Théorie théorie PROPN _ _ 0 root _ BeginSeg=Yes
2 psychomécanique psychomécanique ADJ _ Gender=Masc|Number=Sing 1 amod _ _
3 et et CCONJ _ _ 4 cc _ _
4 cognition cognition NOUN _ Gender=Fem|Number=Sing 1 conj _ _
5 . . PUNCT _ _ 1 punct _ _
"""

# Token is just a simple record type
Token = dataclasses.make_dataclass("Token", "id form lemma pos xpos morph head_id dep_type extra label".split(),
                                   namespace={'__repr__': lambda self: self.form,
                                              'format': lambda self: "\t".join(map(str, dataclasses.astuple(self))),
                                              # ignored for now because we just get rid of MWEs when reading a disrpt file,
                                              # but this could be changed in the future
                                              # 'is_MWE': lambda self: type(self.id) is str and "-" in self.id,
                                              }
                                   )
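
# A minimal sketch of the record type above (field values are illustrative):
#
#   t = Token(1, "Education", "education", "NOUN", "NN", "Number=Sing", 0, "root", "_", "Seg=B-seg")
#   repr(t)     # -> "Education"
#   t.format()  # -> the ten fields joined by tabs, CoNLL-U style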


class Sentence:

    def __init__(self, token_list, meta):
        self.toks = token_list
        self.meta = meta
        # Added by Firmin or Chloé?
        self.label_start = ["Seg=B-conn", "Seg=B-seg"]
        self.label_end = ["Seg=I-conn", "Seg=O"]

    def __iter__(self):
        return iter(self.toks)

    def __len__(self):
        return len(self.toks)

    def display(self, segment=False):
        """If the segment option is set to True, print the sentence with EDU boundaries marked."""
        if segment:
            output = [f"{'|' if token.label == 'Seg=B-seg' else ''}{token.form}" for token in self]
            # output = [f"{'|' if token.label == 'BeginSeg=Yes' else ''}{token.form}" for token in self]
            return " ".join(output) + "|"
        else:
            return self.meta["text"]

    def __contains__(self, word):
        for token in self.toks:
            if token.form == word:
                return True
        return False

    def __repr__(self):
        return self.display()

    def format(self):
        meta = f"# sent_id = {self.meta['sent_id']}\n" + f"# text = {self.meta['text']}\n"
        output = "\n".join([t.format() for t in self.toks])
        return meta + output

# not necessary because of trankit's auto mode, but probably safer at some point
# why don't they use normalized language codes!?
TRANKIT_LANG_MAP = {
    "de": "german",
    "en": "english",
    # to be tested
    "gum": "english-gum",
    "fr": "french",
    "it": "italian",
    "sp": "spanish",
    "es": "spanish",
    "eu": "basque",
    "zh": "chinese",
    "ru": "russian",
    "tr": "turkish",
    "pt": "portuguese",
    "fa": "persian",
    "nl": "dutch",
    # blah
}

lg_map = {"sp": "es",
          "po": "pt",
          "tu": "tr"}


def get_language(lang, model):
    lang = lang[:2]
    if lang in lg_map:
        lang = lg_map[lang]
    if model == "ersatz":
        if lang not in ersatz_languages:
            lang = "default-multilingual"
    if model == "trankit":
        lang = TRANKIT_LANG_MAP.get(lang, "auto")
    return lang
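
# Examples, following the maps above: get_language("sp", "trankit") returns
# "spanish" ("sp" is first normalized to "es" via lg_map), and
# get_language("fra", "ersatz") falls back to "default-multilingual" whenever
# "fr" is not in ersatz_languages (which comes from the ersatz package, whose
# import is currently commented out above).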

# This is taken from ersatz: https://github.com/rewicks/ersatz/blob/master/ersatz/candidates.py
# sentence-ending punctuation
# U+0964 । Po DEVANAGARI DANDA
# U+061F ؟ Po ARABIC QUESTION MARK
# U+002E . Po FULL STOP
# U+3002 。 Po IDEOGRAPHIC FULL STOP
# U+0021 ! Po EXCLAMATION MARK
# U+06D4 ۔ Po ARABIC FULL STOP
# U+17D4 ។ Po KHMER SIGN KHAN
# U+003F ? Po QUESTION MARK
# U+2026 ... Po ELLIPSIS
# U+30FB KATAKANA MIDDLE DOT
# U+002A * ASTERISK

# other acceptable punctuation
# U+3011 】 Pe RIGHT BLACK LENTICULAR BRACKET
# U+00BB » Pf RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK
# U+201D " Pf RIGHT DOUBLE QUOTATION MARK
# U+300F 』 Pe RIGHT WHITE CORNER BRACKET
# U+2018 ' Pi LEFT SINGLE QUOTATION MARK
# U+0022 " Po QUOTATION MARK
# U+300D 」 Pe RIGHT CORNER BRACKET
# U+201C " Pi LEFT DOUBLE QUOTATION MARK
# U+0027 ' Po APOSTROPHE
# U+2019 ' Pf RIGHT SINGLE QUOTATION MARK
# U+0029 ) Pe RIGHT PARENTHESIS

ending_punc = {
    '\u0964',
    '\u061F',
    '\u002E',
    '\u3002',
    '\u0021',
    '\u06D4',
    '\u17D4',
    '\u003F',
    '\uFF61',
    '\uFF0E',
    '\u2026',
}

closing_punc = {
    '\u3011',
    '\u00BB',
    '\u201D',
    '\u300F',
    '\u2018',
    '\u0022',
    '\u300D',
    '\u201C',
    '\u0027',
    '\u2019',
    '\u0029'
}

list_set = {
    '\u30fb',
    '\uFF65',
    '\u002a',  # asterisk
    '\u002d',
    '\u4e00'
}


class Document:
    _hard_punct = {"default": {".", ";", "?", "!"} | ending_punc,
                   "zh": {"。", "?"}
                   }

    def __init__(self, sentence_list, meta, src="conllu"):
        self.sentences = {src: sentence_list}
        self.meta = meta

    def __repr__(self):
        # ADDED (chloe): handle both file types
        for src in ("conllu", "tok"):
            if src in self.sentences:
                return "\n".join(map(repr, self.sentences[src]))
        sys.exit("Unknown type of file: " + str(self.sentences.keys()))

    def get_sentences(self, src="tok"):
        return self.sentences[src]

    def baseline_split(self, lang="default"):
        """Default split for languages where we have issues re-aligning tokens for various reasons.

        This just splits at every token that is a hard punctuation mark.

        FIXME: this is not complete
        """
        sentence_id = 1
        sentences = []
        current = []
        orig_doc = self.sentences["tok"][0]
        for token in orig_doc:
            current.append(token)
            if token.lemma in self._hard_punct.get(lang, self._hard_punct["default"]):
                meta = {"doc_id": orig_doc.meta["doc_id"],
                        "sent_id": sentence_id,
                        "text": " ".join([x.form for x in current])
                        }
                sentences.append(Sentence(current, meta))
                current = []
                sentence_id += 1
        if current != []:
            meta = {"doc_id": orig_doc.meta["doc_id"],
                    "sent_id": sentence_id,
                    "text": " ".join([x.form for x in current])
                    }
            sentences.append(Sentence(current, meta))
        return sentences

    def cutoff_split(self, cutoff=120, lang="default"):
        """
        Default split for corpora with little or no punctuation (transcriptions etc.):
        just start a new sentence as soon as there are more than `cutoff` tokens.
        """
        sentence_id = 1
        sentences = []
        current = []
        current_cpt = 1
        orig_doc = self.sentences["tok"][0]
        meta = {"doc_id": orig_doc.meta["doc_id"],
                "sent_id": sentence_id,
                }
        for token in orig_doc:
            token.id = current_cpt
            current_cpt += 1
            current.append(token)
            # print(token, token.id)
            if len(current) >= cutoff:
                # print(orig_doc.meta["doc_id"], token, current)
                meta = {"doc_id": orig_doc.meta["doc_id"],
                        "sent_id": sentence_id,
                        "text": " ".join([x.form for x in current])
                        }
                sentences.append(Sentence(current, meta))
                current = []
                sentence_id += 1
                current_cpt = 1
        if current != []:
            meta = {"doc_id": orig_doc.meta["doc_id"],
                    "sent_id": sentence_id,
                    "text": " ".join([x.form for x in current])
                    }
            sentences.append(Sentence(current, meta))
        return sentences

    def ersatz_split(self, doc, lang='default-multilingual', candidates="en"):
        # NB: `split` comes from the ersatz import at the top, currently commented out
        result = split(model=lang,
                       text=doc, output=None,
                       batch_size=16,
                       candidates=candidates,  # 'multilingual'
                       cpu=True, columns=None, delimiter='\t')
        return result

    def stanza_split(self, orig_doc, lang):
        # NB: requires the stanza import at the top, currently commented out
        nlp = stanza.Pipeline(lang=lang, processors='tokenize', download_method=DownloadMethod.REUSE_RESOURCES)
        doc = nlp(orig_doc)
        sentences = []
        for s in doc.sentences:
            sentences.append(" ".join([t.text for t in s.tokens]))
        return sentences
        # for i, sentence in enumerate(doc.sentences): for token in sentence.tokens / token.text

    def trankit_split(self, orig_doc, lang, pipeline):
        trk_sentences = pipeline.ssplit(orig_doc)
        sentences = []
        for s in trk_sentences["sentences"]:
            sentences.append(s["text"])
        return sentences

    def sat_split(self, orig_doc, sat_model):
        sat_sentences = sat_model.split(str(orig_doc))
        sentences = []
        for s in sat_sentences:
            sentences.append(s)
        return sentences

    # TODO: debug option to turn warnings on/off
    def _remap_tokens(self, split_sentences):
        """Remap tokens produced by sentence splitting to the original token information."""
        # return split_sentences
        # if this fails, there's been a bug: the token count differs between the original text
        # and the total over split sentences
        # TODO: this is bound to happen, but the output should keep the original token count; how?
        # TODO: REALIGN by detecting split tokens
        orig_token_nb = sum(map(len, self.sentences["tok"]))
        split_token_nb = len(list(chain(*[x.split() for x in split_sentences])))
        try:
            assert orig_token_nb == split_token_nb
        except AssertionError:
            print("WARNING wrong nb of tokens", orig_token_nb, "initially but", split_token_nb, "after split", file=sys.stderr)
            # raise NotImplementedError
        new_sentences = []
        position = 0
        skip_first_token = False
        # will only work when splitting tok files, not when resplitting conllu
        orig_doc = self.sentences["tok"][0]
        for i, s in enumerate(split_sentences):
            new_toks = s.split()
            if skip_first_token:  # see below
                new_toks = new_toks[1:]
            toks = orig_doc.toks[position:position + len(new_toks)]
            meta = {"doc_id": orig_doc.meta["doc_id"],
                    "sent_id": i + 1,
                    "text": " ".join([x.form for x in toks])
                    }
            new_tok_position = position
            shift = 0  # advance through the new tokens in case of erroneous splits
            # actual nb of tokens to advance in the original document;
            # the new tokens might include a mistakenly split token (tricky)
            new_toks_length = len(new_toks)
            for j in range(len(toks)):
                toks[j].id = j + 1
                new_j = j + shift
                try:
                    assert toks[j].form == new_toks[new_j]
                    # a split token has been detected, meaning it had a punctuation sign in it and makes a "fake" sentence;
                    # it will be recovered in the current sentence, so it should be skipped in the next one
                    skip_first_token = False
                except AssertionError:
                    # TODO: check that the next token can be recovered
                    # problem with differing Chinese punctuation codes?
                    # print(f"WARNING === Token mismatch: {j, toks[j].form, new_toks[new_j]} \n {toks} \n {new_toks}", file=sys.stderr)
                    # first case: within the same sentence (unlikely if a token was split by a punctuation mark)
                    if j != len(toks) - 1:
                        if len(toks[j].form) != len(new_toks[new_j]):  # if same length this is probably just an encoding problem (Chinese cases), so just ignore it
                            # print(f"INFO: split token still within the sentence {j, toks[j].form, new_toks[new_j]} ... should not happen", file=sys.stderr)
                            if toks[j].form == new_toks[new_j] + new_toks[new_j + 1]:
                                # print(f"INFO: split token correctly identified as {j, toks[j].form, new_toks[new_j] + new_toks[new_j + 1]} ... advancing to the next one", file=sys.stderr)
                                shift = shift + 1
                    # second case: the sentence ends here and the next token is in the next split sentence, which necessarily exists (?)
                    else:
                        if i + 1 < len(split_sentences):
                            next_sentence = split_sentences[i + 1]
                            next_token = split_sentences[i + 1].split()[0]
                            skip_first_token = True
                            if toks[j].form == new_toks[new_j] + next_token:
                                pass
                                # print(f"INFO: token can be recovered: ", end="", file=sys.stderr)
                            else:
                                pass
                                # print(f"INFO: token still cannot be recovered: ", end="", file=sys.stderr)
                                # print(toks[j].form, new_toks[new_j] + next_token, file=sys.stderr)
                        else:
                            pass
                            # print(f"WARNING === unmatched token at end of document", new_toks[new_j], file=sys.stderr)
                            # in theory this should not happen
                            # does the next starting position have to be put back? no
                            # position = position - 1
            if len(toks) > 0:  # joining the first token might have generated an empty sentence
                new_sentences.append(Sentence(toks, meta))
                position = position + len(new_toks) - shift
            else:
                skip_first_token = False
        split_token_nb = sum([len(s.toks) for s in new_sentences])
        # print("split_token_nb", split_token_nb)
        try:
            assert orig_token_nb == split_token_nb
        except AssertionError:
            print("ERROR wrong nb of tokens", orig_token_nb, "originally but", split_token_nb, "after split+remap", file=sys.stderr)
            sys.exit()
        return new_sentences

    def sentence_split(self, model="ersatz", lang="default-multilingual", **kwargs):
        """
        Call the sentence splitter on the actual document, read as a single sentence from a tok file.
        kwargs may contain an already-open "pipeline" (e.g. a trankit pipeline) to pass downstream
        for splitting sentences, so that it is not re-created for each paragraph.
        """
        # if we split, the doc has been read as only one sentence
        # we ignore multi-word expressions at reading time, but if this needs to change, it will impact this line:
        doc = [x.form for x in self.sentences["tok"][0]]  # if not(x.is_MWE())]
        doc = " ".join(doc)
        if model == "ersatz":
            # empirically seems better: "en" for all alphabet-based languages
            # (candidates = candidate locations for sentence splitting,
            # not to be confused with the language of the model)
            candidates = "en" if lang not in {"zh", "th"} else "multilingual"
            new_sentences = self.ersatz_split(doc, lang=lang, candidates=candidates)
        elif model == "stanza":
            new_sentences = self.stanza_split(doc, lang=lang)
        elif model == "trankit":  # the initialized pipeline is passed on here
            new_sentences = self.trankit_split(doc, lang=lang, **kwargs)
        elif model == "baseline":
            new_sentences = self.baseline_split(lang=lang)
            self.sentences["split"] = new_sentences
        elif model == "sat":
            sat_model = kwargs.get("sat_model")
            if sat_model is None:
                raise ValueError("sat_model must be provided for SAT sentence splitting.")
            new_sentences = self.sat_split(doc, sat_model)
            self.sentences["split"] = new_sentences
        elif model == "cutoff":  # FIXME: there should be a way to pass on the cutoff
            new_sentences = self.cutoff_split(lang=lang)
            self.sentences["split"] = new_sentences
        else:
            raise NotImplementedError
        if model != "baseline" and model != "cutoff":
            self.sentences["split"] = self._remap_tokens(new_sentences)
        return self.sentences["split"]

    def search_word(self, word):
        return [s for s in self.sentences.get("split", []) if word in s]

    def format(self, mode="split"):
        """Format the document in disrpt format.
        mode = original (sentences) or split (split_sentences)
        """
        target = self.sentences[mode]

        output = "\n".join([s.format() + "\n" for s in target])
        meta = f"# doc_id = {self.meta}\n"
        return meta + output  # + "\n"


class Corpus:
    META_types = {"newdoc_id": "doc_id",
                  "newdoc id": "doc_id",
                  "doc_id": "doc_id",
                  "sent_id": "sent_id",
                  "newturn_id": "newturn_id",
                  "newutterance": "newutterance",
                  "newutterance_id": "newutterance_id",
                  "text": "text",
                  }

    def __init__(self, data=None):
        """The input to the constructor is a string."""
        if data:
            self.docs = self._parse(data.split("\n"))

    @staticmethod
    def _meta_parse(data_line):
        """Parse comments, as they contain meta information."""
        if not ("=" in data_line):  # not a meta line
            return "", ""
        info, value = data_line[1:].strip().split("=", 1)
        info = info.strip()
        if info in Corpus.META_types:
            meta_type = Corpus.META_types[info]
        else:  # TODO: should send a warning
            # print("WARNING: bad meta line", info, value, data_line, file=sys.stderr) -> this just floods the output
            meta_type, value = "", ""
        return meta_type, value.strip()
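
    # E.g. Corpus._meta_parse("# newdoc_id = GUM_bio_byron") returns
    # ("doc_id", "GUM_bio_byron"); lines without "=" or with unknown keys
    # yield ("", "").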

    def search_doc(self, docid):
        return [x for x in self.docs if x.meta == docid]

    def _parse(self, data_lines, src="tok"):
        """Parse DISRPT segmentation/connective files."""
        curr_token_list = []
        sentences = []
        docs = []
        s_idx = 0
        doc_idx = 0
        meta = {}

        for data_line in data_lines:
            data_line = data_line.strip()
            if data_line:
                # comments always include some meta info of the form "metatype = value", minimally the document id
                if data_line.startswith("#"):
                    meta_type, value = Corpus._meta_parse(data_line)
                    # start of a new doc, save the previous one if it exists
                    if meta_type == "doc_id":
                        # print(doc_idx)
                        if doc_idx > 0:
                            # print(src)
                            docs.append(Document(sentences, meta["doc_id"], src=src))
                            sentences = []
                            curr_token_list = []
                            s_idx = 0
                            meta = {}
                        doc_idx += 1
                    if meta_type != "":
                        meta[meta_type] = value
                else:
                    token, label = self.parse_token(meta, data_line)
                    # print(token, label)
                    # if this is a MWE, just ignore it; MWEs have ids combining original token ids, e.g. "30-31"
                    # TODO: refactor in parse_token + boolean flag if ok
                    if "-" not in token[0] and "." not in token[0]:
                        curr_token_list.append(Token(*token, label))
            else:  # end of sentence
                meta["text"] = " ".join(x.form for x in curr_token_list)
                s_idx += 1
                # some corpora don't have ids for sentences
                if "sent_id" not in meta:
                    meta["sent_id"] = s_idx
                sentences.append(Sentence(curr_token_list, meta))
                curr_token_list = []
                meta = {"doc_id": meta["doc_id"]}
        if len(curr_token_list) > 0 or len(sentences) > 0:  # final sentence for the final document
            meta["text"] = " ".join(x.form for x in curr_token_list)
            sentences.append(Sentence(curr_token_list, meta))
            # print("="*50)
            # print(meta.keys())
            # print(len(curr_token_list), len(sentences))
            docs.append(Document(sentences, meta["doc_id"], src=src))
            # print(src)
        return docs

    def format(self, file=None, mode="split"):
        output = "\n\n".join([doc.format(mode=mode) for doc in self.docs])
        if file:
            os.makedirs(os.path.dirname(file), exist_ok=True)
            with open(file, "w", encoding="utf-8") as f:
                f.write(output)
        return output

    def parse_token(self, meta, data_line):
        *token, label = data_line.split("\t")
        if len(token) == 8:
            print("ERROR: missing label ", meta, token, file=sys.stderr)
            token = token + [label]
            label = '_'
        # needed because of errors in the source of some corpora (Russian with a BOM kept as a token; also bad reading of some chars)
        # to prevent token counts/tokenization from failing, they are replaced with '_'
        # token[1] is the form of the token
        if token[1] == BOM:
            token[1] = "_"
        # if token[1] == '200�000':
        #     print("GOTCHA")
        token[1] = token[1].replace(REPL_CHAR, "_")
        label_set = set(label.split("|"))
        label = label_set & set(self.LABELS)
        if label == set():
            label = "_"
        else:
            label = label.pop()
        return token, label
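    # Example (sketch): only labels from the subclass's LABELS list are kept, so for a
    # SegmentCorpus a column value like "Seg=B-seg|SpaceAfter=No" reduces to "Seg=B-seg",
    # while a value containing no known label falls back to "_".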

    def from_file(self, filepath):
        """
        Reads a conllu or tok file.
        conllu has sentences, tok does not.

        Option to pass a string instead of a file path, mostly for testing.

        TODO: should be a static method
        """
        self.filepath = filepath
        basename = os.path.basename(filepath)
        src = basename.split(".")[-1]  # tok or conllu or split
        # print("src = ", src)
        with open(filepath, "r", encoding="utf8") as f:
            data_lines = f.readlines()
        self.docs = self._parse(data_lines, src=src)
        # for sent in self.docs:
        #     print(sent)

    def from_string(self, text: str, src="conllu"):
        """
        Reads directly from a string (useful for tests or dynamic generation).
        src can be 'conllu', 'tok', or 'split' to indicate the format.
        """
        self.filepath = None
        if isinstance(text, str):
            data_lines = text.strip().splitlines()
        else:
            raise ValueError("from_string expects a string")
        self.docs = self._parse(data_lines, src=src)
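    # Example (sketch, with a made-up doc id "d1"): parsing a minimal two-token document:
    #   corpus = SegmentCorpus()
    #   corpus.from_string("# newdoc id = d1\n"
    #                      "1\tHello\t_\t_\t_\t_\t_\t_\t_\tSeg=B-seg\n"
    #                      "2\tworld\t_\t_\t_\t_\t_\t_\t_\tSeg=O\n", src="tok")
    #   # corpus.docs -> one Document containing a single two-token sentence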

    def format(self, mode="split", file=sys.stdout):
        # NB: this definition shadows the `format` method defined above
        if type(file) == str:
            os.makedirs(os.path.dirname(file), exist_ok=True)
            file = open(file, "w")
        for d in self.docs:
            print(d.format(mode=mode), file=file)

    def align(self, filepath):
        """Load the conllu file for the corresponding tok file."""
        pass

    def sentence_split(self, model="ersatz", lang="default-multilingual", **kwargs):
        """Apply a sentence splitter to the document, assuming it was read from
        a .tok file.

        kwargs might contain an open "pipeline" (e.g. a trankit pipeline) to pass on
        downstream for splitting sentences, so that it is not re-created for each paragraph.
        """
        for doc in tqdm(self.docs):
            doc.sentence_split(model=model, lang=lang, **kwargs)

    def eval_sentences(self, mode="split"):
        """Evaluate sentence beginnings as segment beginnings.
        TODO: rename -> precision

        Only .tok for now, but could be used to evaluate a re-split of conllu;
        more complex for PDTB: tok and conllu would need to be aligned.
        """
        tp = 0
        total_s = 0
        labels = []
        for doc in self.docs:
            for s in doc.get_sentences(mode):
                if len(s.toks) == 0:
                    print("WARNING empty sentence in ", s.meta, file=sys.stderr)
                    break
                tp += (s.toks[0].label == "Seg=B-seg")
                # tp += (s.toks[0].label == "BeginSeg=Yes")
                total_s += 1
                labels.extend([x.label for x in s])
        counts = Counter(labels)
        # return tp, total_s, counts["BeginSeg=Yes"]
        return tp, total_s, counts["Seg=B-seg"]
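    # Example (sketch): turning the returned triple into scores. If eval_sentences()
    # returns (tp, total_s, gold_b) = (80, 100, 120), then the precision of predicted
    # sentence starts against gold segment starts is tp / total_s = 0.8, and the
    # corresponding recall is tp / gold_b, about 0.67.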

class SegmentCorpus(Corpus):
    LABELS = ["Seg=O", "Seg=B-seg"]

class ConnectiveCorpus(Corpus):
    LABELS = ['Conn=O', 'Conn=B-conn', 'Conn=I-conn']
    id2label = {i: label for i, label in enumerate(LABELS)}
    label2id = {v: k for k, v in id2label.items()}

class RelationCorpus(Corpus):

    def from_file(self, filepath):
        pass

# ersatz existing language-specific models
# for ersatz 1.0.0:
# ['en', 'ar', 'cs', 'de', 'es', 'et', 'fi', 'fr', 'gu', 'hi', 'iu', 'ja',
#  'kk', 'km', 'lt', 'lv', 'pl', 'ps', 'ro', 'ru', 'ta', 'tr', 'zh', 'default-multilingual']
# missing DISRPT languages / what candidates? nl, pt, it -> en? Thai -> multilingual


if __name__ == "__main__":
    # testing
    import sys, os
    from pathlib import PurePath
    # from ersatz import split, utils
    # ersatz existing language-specific models
    # languages = utils.MODELS.keys()

    if len(sys.argv) > 1:
        test_path = sys.argv[1]
    else:
        test_path = "../jiant/tests/test_data/eng.pdtb.pdtb/eng.pdtb.pdtb_debug.tok"

    basename = os.path.basename(test_path)
    lang = basename.split(".")[0]
    # lang = get_language(lang, "trankit")

    path = PurePath(test_path)
    # output_path = str(path.with_suffix(".split"))
    output_path = "out"

    if "pdtb" in test_path:
        corpus = ConnectiveCorpus()
    else:
        corpus = SegmentCorpus()
    corpus.from_file(test_path)

    sat = SaT("sat-3l")  # 3l is better with French guillemets

    # print(corpus.docs[0].sentences[11].display(segment=True))
    print(sat.split("This is a test This is another test."))
    doc1 = corpus.docs[0]
    s0 = doc1.sentences["tok"][0]
    print(doc1)
    print(list(sat.split(str(doc1))))
    # list(res)
    # pipe = pipeline("token-classification", model="segment-any-text/sat-1l")
    # res = doc1.sentence_split(model="sat")

    # ------------------------------------------
    # -- From the SaT doc:
    # https://github.com/segment-any-text/wtpsplit?tab=readme-ov-file#usage
    # sat = SaT("sat-3l")
    # optionally run on GPU for better performance
    # also supports TPUs via e.g. sat.to("xla:0"); in that case pass `pad_last_batch=True` to sat.split
    # sat.half().to("cuda")

    # print(sat.split("This is a test This is another test."))
    # returns ["This is a test ", "This is another test."]

    # # do this instead of calling sat.split on every text individually for much better performance
    # sat.split(["This is a test This is another test.", "And some more texts..."])
    # # returns an iterator yielding lists of sentences for every text

    # # use the '-sm' models for general sentence segmentation tasks
    # sat_sm = SaT("sat-3l-sm")
    # sat_sm.half().to("cuda")  # optional, see above
    # sat_sm.split("this is a test this is another test")
    # # returns ["this is a test ", "this is another test"]

    # # use trained LoRA modules for strong adaptation to language & domain/style
    # sat_adapted = SaT("sat-3l", style_or_domain="ud", language="en")
    # sat_adapted.half().to("cuda")  # optional, see above
    # sat_adapted.split("This is a test This is another test.")
    # # returns ['This is a test ', 'This is another test']

    # check that the number of tokens is conserved by sentence splitting
    # assert sum(map(len, doc1.sentences)) == len(list(chain(*[x.split() for x in res])))
    # pipeline = trankit.Pipeline(lang, gpu=True)
    # corpus.sentence_split(model="trankit", lang=lang, pipeline=pipeline)
    corpus.sentence_split(model="sat", sat_model=sat)
    tp, tot, all_b = corpus.eval_sentences()
    print(tp, tot, all_b)
    # print(corpus.docs[0].split_sentences[0].toks[0].format())
    corpus.format(file=output_path)
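    # Expected shape of the written output (illustrative sketch; "some_doc" is a made-up id):
    # one block per document, a "# doc_id = ..." line followed by one token per line,
    # ten tab-separated columns with the label in the last one:
    #   # doc_id = some_doc
    #   1\tThis\t_\t_\t_\t_\t_\t_\t_\tConn=O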
eval.py
ADDED
@@ -0,0 +1,760 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-

import os, sys
import numpy as np
import transformers

import utils

import reading


SUBTOKEN_START = '##'

'''
TODOs:

- for now, if the dataset is cached, we can't use word ids, and the predictions
written are not based on the original eval file, thus not exactly the same number
of tokens (ignores contractions) --> doesn't work in the DISRPT eval script

Change in the newest version of transformers:
from seqeval.metrics import accuracy_score
from seqeval.metrics import classification_report
from seqeval.metrics import f1_score
'''


def simple_eval(dataset_eval, model_checkpoint, tokenizer, output_path,
                config, trace=False):
    '''
    Run the pre-trained model on the (dev) dataset to get predictions,
    then write the predictions to an output file.

    Parameters:
    -----------
    dataset_eval: DatasetDisc
        The dataset read
    model_checkpoint: str
        path to the saved model
    tokenizer: Tokenizer
        tokenizer of the saved model (TODO: retrieve from model? or should it be removed?)
    output_path: str
        path to the output directory where prediction files will be written
    config: dict
        configuration options (task, tokenisation settings, ...)
    '''
    # Retrieve predictions (list of lists of 0s and 1s)
    print("\n-- PREDICT on:", dataset_eval.annotations_file)
    model_checkpoint = os.path.normpath(model_checkpoint)
    print("model_checkpoint", model_checkpoint)
    preds_from_model, label_ids, metrics = retrieve_predictions(model_checkpoint,
                                                                dataset_eval, output_path, tokenizer, config)

    print("preds_from_model.shape", preds_from_model.shape)
    print("label_ids.shape", label_ids.shape)

    # - Compute metrics
    print("\n-- COMPUTE METRICS")
    compute_metrics = utils.prepare_compute_metrics(dataset_eval.LABEL_NAMES_BIO)
    metrics = compute_metrics([preds_from_model, label_ids])
    max_preds_from_model = np.argmax(preds_from_model, axis=-1)

    # - Write predictions:
    pred_file = os.path.join(output_path, dataset_eval.basename + '.preds')
    print("\n-- WRITE PREDS in:", pred_file)
    pred_file_success = True

    try:
        try:
            # * retrieving the original words: will fail if the cache was not emptied
            print("Write predictions based on words")
            predictions = align_tokens_labels_from_wordids(max_preds_from_model, dataset_eval,
                                                           tokenizer)

            write_pred_file(dataset_eval.annotations_file, pred_file, predictions, trace=trace)
        except IndexError:
            # on error, we print the predictions with tokens, trying to merge subtokens
            # based on SUBTOKEN_START and labels at -100
            print("Write predictions based on model tokenisation")
            aligned_tokens, aligned_golds, aligned_preds = align_tokens_labels_from_subtokens(
                max_preds_from_model, dataset_eval, tokenizer, pred_file, trace=trace)
            write_pred_file_from_scratch(aligned_tokens, aligned_golds, aligned_preds,
                                         pred_file, trace=trace)
    except Exception as e:
        print("Problem when trying to write predictions in file", pred_file)
        print("Exception:", e)
        print("we skip the prediction writing step")
        pred_file_success = False

    if pred_file_success:
        print("\n-- EVAL DISRPT script")
        clean_pred_path = pred_file.replace('.preds', '.cleaned.preds')
        utils.clean_pred_file(pred_file, clean_pred_path)
        utils.compute_metrics_dirspt(dataset_eval, clean_pred_path, task=config['task'])
        # except:
        #     print("Problem when trying to compute scores with DISRPT eval script")
    return metrics
    # - Test DISRPT eval script
    # try:


def write_pred_file(annotations_file, pred_file, predictions, trace=False):
    '''
    Write a file containing the predictions, based on the original annotation file.
    It takes each line in the original evaluation file and appends the prediction at
    the end. Predictions and original tokens need to be perfectly aligned.

    Parameters:
    -----------
    annotations_file: str | file path OR raw text
        Path to the original evaluation file, or the text content itself
    pred_file: str
        Path to the output prediction file
    predictions: list of str
        Flat list of all predictions (DISRPT format) for all tokens in eval
    '''
    count_pred_B, count_gold_B = 0, 0
    count_line_dash = 0
    count_line_dot = 0

    # --- Determine whether annotations_file is a path or raw text
    if os.path.isfile(annotations_file):
        with open(annotations_file, 'r', encoding='utf-8') as fin:
            mylines = fin.readlines()
    else:
        # Treat it as a raw string
        mylines = annotations_file.strip().splitlines()

    os.makedirs(os.path.dirname(pred_file), exist_ok=True)
    with open(pred_file, 'w', encoding='utf-8') as fout:
        count = 0
        if trace:
            print("len(predictions)", len(predictions))
        for l in mylines:
            l = l.strip()
            if l.startswith("#"):  # keep metadata
                fout.write(l + '\n')
            elif l == '' or l == '\n':  # keep line breaks
                fout.write('\n')
            elif '-' in l.split('\t')[0]:  # keep lines for contractions, but no label
                if trace:
                    print("WARNING: line with - in token, no label will be added")
                count_line_dash += 1
                fout.write(l + '\t' + '_' + '\n')
            # strange case in GUM
            elif '.' in l.split('\t')[0]:  # keep lines, no label
                count_line_dot += 1
                if trace:
                    print("WARNING: line with . in token, no label will be added")
                fout.write(l + '\t' + '_' + '\n')
            else:
                if 'B' in predictions[count]:
                    count_pred_B += 1
                if 'Seg=B-seg' in l or 'Conn=B-conn' in l:
                    count_gold_B += 1
                fout.write(l + '\t' + predictions[count] + '\n')
                count += 1

    print("Count the number of predictions corresponding to a B", count_pred_B, "vs Gold B", count_gold_B)
    print("Count the number of lines with - in token", count_line_dash)
    print("Count the number of lines with . in token", count_line_dot)


def write_pred_file_from_scratch(aligned_tokens, aligned_golds, aligned_preds, pred_file, trace=False):
    '''
    Write a prediction file based on an alignment between the tokenisation and the predictions.
    Since we are not sure that we retrieved the exact alignment, the writing here is not based
    on the original annotation file, but we use a similar format:
    # Sent ID
    tok_ID token gold_label pred_label

    Running the DISRPT script will show whether the alignment worked or not ...

    Parameters:
    ----------
    aligned_XX: list of list of str
        The tokens / preds / golds for each sentence
    '''
    count_pred_B, count_gold_B = 0, 0
    with open(pred_file, 'w') as fout:
        if trace:
            print('len tokens', len(aligned_tokens))
            print("len(predictions)", len(aligned_preds))
            print('len(golds)', len(aligned_golds))
        for s, tok_sent in enumerate(aligned_tokens):
            fout.write("# sent_id = " + str(s) + "\n")
            for i, tok in enumerate(tok_sent):
                g = aligned_golds[s][i]
                p = aligned_preds[s][i]
                fout.write('\t'.join([str(i), tok, g, p]) + '\n')
                if 'B' in p:
                    count_pred_B += 1
                if 'Seg=B-seg' in g or 'Conn=B-conn' in g:
                    count_gold_B += 1
            fout.write("\n")
    print("Count the number of predictions corresponding to a B", count_pred_B, "vs Gold B", count_gold_B)


def align_tokens_labels_from_wordids(preds_from_model, dataset_eval, tokenizer, trace=False):
    '''
    Build predictions for segmentation or connective tasks, aligned with the original
    input words. The output can then be written as the input gold file with an
    additional column corresponding to the predicted label.

    Easiest way (?): use the word_ids information to merge the words that have been
    split, retrieve the original tokens from the input .tok / .conllu files and run
    the evaluation --> but this information is not kept in the cached datasets.

    Parameters:
    -----------
    preds_from_model: list of int
        The predicted labels (numeric ids)
    dataset_eval: DatasetDisc
        Dataset for evaluation

    Return:
    -------
    predictions: list of str
        The predicted labels (DISRPT format) for each original input word
    '''

    word_ids = dataset_eval.all_word_ids
    id2label = dataset_eval.id2label
    predictions = []
    for i in range(preds_from_model.shape[0]):
        sent_input_ids = dataset_eval.tokenized_datasets['input_ids'][i]
        tokens = dataset_eval.dataset['tokens'][i]
        sent_tokens = tokenizer.decode(sent_input_ids[1:-1])
        aligned_preds = _merge_tokens_preds_sent(word_ids[i], preds_from_model[i], tokens)
        if trace:
            print('\n', i, sent_tokens)
            print(sent_input_ids)
            print(preds_from_model[i])
            print(' '.join(tokens))
            print("aligned_preds", aligned_preds)
        for k, tok in enumerate(tokens):
            # Ignore the special tokens
            if tok.startswith('[LANG=') or tok.startswith('[FRAME='):
                if trace:
                    print(f"Skip special token: {tok}")
                continue
            label = aligned_preds[k]
            predictions.append(id2label[label])
    return predictions

def _merge_tokens_preds_sent(word_ids, preds, tokens):
    '''
    The tokenizer splits the tokens into subtokens, with labels added on subwords.
    For evaluation, we need to merge the subtokens and keep only the labels on
    the plain tokens.
    The function takes the whole word ids and predictions for one sentence and
    returns the merged version.
    We also get rid of the tokens and associated labels for [CLS] and [SEP], and
    don't keep predictions for padding tokens.
    TODO: inspired by the method used to split the labels, but we could cut the
    two `continue`s (kept for debug).

    word_ids: list
        list of word ids of the (sub)tokens as produced by the (BERT-like) tokenizer
    preds: list
        the predictions of the model
    '''
    aligned_toks = []
    count = 0
    new_labels = []
    current_word = None
    for i, word_id in enumerate(word_ids):
        count += 1
        if word_id != current_word:
            # New word
            current_word = word_id
            if word_id is not None:
                new_labels.append(preds[i])
                aligned_toks.append(tokens[word_id])
        elif word_id is None:
            # Special token
            continue
        else:
            # Same word as the previous token
            continue
    if len(new_labels) != len(aligned_toks) or len(new_labels) != len(tokens):
        print("WARNING, something wrong, not the same nb of tokens and predictions")
        print(len(new_labels), len(aligned_toks), len(tokens))
    return new_labels
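# Example (sketch): for word_ids [None, 0, 1, 1, 2, None] (CLS, "The", "play", "##ing",
# "field", SEP) and preds [0, 1, 0, 1, 0, 0], only the prediction on the first subtoken
# of each word is kept, giving merged labels [1, 0, 0] for the three original tokens.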


def map_labels_list(list_labels, id2label):
    return [id2label[l] for l in list_labels]

def align_tokens_labels_from_subtokens(preds_from_model, dataset_eval, tokenizer, pred_file, trace=False):
    '''
    Align tokens and labels (merging subtokens, assigning the right label)
    based on the specific characters starting a subtoken (e.g. ## for BERT)
    and the label -100 assigned to contractions of MWEs (e.g. "it's").
    But we are not completely sure to get the exact alignment with the original words here.
    '''
    aligned_tokens, aligned_golds, aligned_preds = [], [], []
    id2label = dataset_eval.id2label
    tokenized_dataset = dataset_eval.tokenized_datasets
    # print("\ndataset_eval.tokenized_datasets", dataset_eval.tokenized_datasets)
    # print("preds_from_model.shape", preds_from_model.shape)
    # For each sentence
    with open(pred_file, 'w') as fout:
        # Iterate over sentences
        for i in range(preds_from_model.shape[0]):
            # fout.write("new_sent_" + str(i) + '\n')
            sent_input_ids = dataset_eval.tokenized_datasets['input_ids'][i]
            sent_gold_labels = tokenized_dataset['labels'][i]
            sent_pred_labels = preds_from_model[i]
            aligned_t, aligned_g, aligned_p = _retrieve_tokens_from_sent(sent_input_ids, sent_pred_labels,
                                                                         sent_gold_labels, tokenizer, trace=trace)
            aligned_tokens.append(aligned_t)
            aligned_golds.append(map_labels_list(aligned_g, id2label))
            aligned_preds.append(map_labels_list(aligned_p, id2label))
    return aligned_tokens, aligned_golds, aligned_preds

def _retrieve_tokens_from_sent(sent_input_ids, preds_from_model, sent_gold_labels, tokenizer, trace=False):
    # tokenized_dataset = dataset.tokenized_datasets
    cur_token, cur_pred, cur_gold = None, None, None
    tokens, golds, preds = [], [], []
    if trace:
        print('\n\nlen(sent_input_ids)', len(sent_input_ids))
        print('len(preds_from_model)', len(preds_from_model))  # with padding
        print('len(sent_gold_labels)', len(sent_gold_labels))
    # Ignore the first and last tokens / labels
    for j, input_id in enumerate(sent_input_ids[1:-1]):
        gold_label = sent_gold_labels[j + 1]
        pred_label = preds_from_model[j + 1]
        subtoken = tokenizer.decode(input_id)
        if trace:
            print(subtoken, gold_label, pred_label)
        # Deal with tokens split into subtokens: keep the label of the first subtoken
        if subtoken.startswith(SUBTOKEN_START) or gold_label == -100:
            if cur_token is None:
                print("WARNING: first subtoken without a token, probably a contraction or MWE")
                cur_token = ""
            cur_token += subtoken
        else:
            if cur_token is not None:
                tokens.append(cur_token)
                golds.append(cur_gold)
                preds.append(cur_pred)
            cur_token = subtoken
            cur_pred = pred_label
            cur_gold = gold_label
    # add the last one
    tokens.append(cur_token)
    golds.append(cur_gold)
    preds.append(cur_pred)
    if trace:
        print("\ntokens:", len(tokens), tokens)
        print("golds", len(golds), golds)
        print("preds", len(preds), preds)
        for i, tok in enumerate(tokens):
            print(tok, golds[i], preds[i])
    return tokens, golds, preds
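# Example (sketch): with BERT-style subtokens ['[CLS]', 'do', '##ing', 'fine', '[SEP]'],
# the loop concatenates 'do' + '##ing' into one token and keeps the gold/pred labels of
# 'do', yielding tokens ['do##ing', 'fine'] (the '##' marker is kept by the simple
# string concatenation above).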

def retrieve_predictions(model_checkpoint, dataset_eval, output_path, tokenizer, config):
    """
    Load the trainer in eval mode and compute predictions
    on dataset_eval (can be a HuggingFace dataset OR a list of sentences).
    """
    import os, transformers

    model_path = model_checkpoint
    if os.path.isfile(model_checkpoint):
        print(f"[INFO] The model path points to a file, using the parent directory: {os.path.dirname(model_checkpoint)}")
        model_path = os.path.dirname(model_checkpoint)

    config_file = os.path.join(model_path, "config.json")
    if not os.path.exists(config_file):
        raise FileNotFoundError(f"No config.json file found in {model_path}.")

    # Load the model
    model = transformers.AutoModelForTokenClassification.from_pretrained(model_path)

    # Collator
    data_collator = transformers.DataCollatorForTokenClassification(
        tokenizer=tokenizer,
        padding=config["tok_config"]["padding"]
    )
    compute_metrics = utils.prepare_compute_metrics(
        getattr(dataset_eval, "LABEL_NAMES_BIO", None) or []
    )

    # Eval mode
    model.eval()

    test_args = transformers.TrainingArguments(
        output_dir=output_path,
        do_train=False,
        do_predict=True,
        dataloader_drop_last=False,
        report_to=config.get("report_to", "none"),
    )

    trainer = transformers.Trainer(
        model=model,
        args=test_args,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    # If dataset_eval is just a list of sentences, build a Dataset from it
    from datasets import Dataset

    if isinstance(dataset_eval, list):
        dataset_eval = Dataset.from_dict({"text": dataset_eval})

        def tokenize(batch):
            return tokenizer(batch["text"], truncation=True, padding=True)

        dataset_eval = dataset_eval.map(tokenize, batched=True)

        predictions, label_ids, metrics = trainer.predict(dataset_eval)
    else:
        # - Make predictions on the eval dataset
        predictions, label_ids, metrics = trainer.predict(dataset_eval.tokenized_datasets)
    return predictions, label_ids, metrics
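# Example (sketch): thanks to the isinstance(list) branch above, the function can also be
# called on raw sentences, e.g.
#   preds, label_ids, metrics = retrieve_predictions(
#       "path/to/checkpoint", ["This is a test ."], "./out", tokenizer, config)
# ("path/to/checkpoint" is a placeholder). In that case the texts carry no gold labels,
# so only `preds` is meaningful.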


# --------------------------------------------------------------------------
# --------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse, os
    import shutil

    path = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")

    if os.path.exists(path):
        shutil.rmtree(path)
        print(f"The directory '{path}' has been deleted.")
    else:
        print(f"The directory '{path}' does not exist.")

    parser = argparse.ArgumentParser(
        description='DISCUT: Discourse segmentation and connective detection'
    )

    # EVAL file
    parser.add_argument("-t", "--test",
                        help="Eval file. Default: data_test/eng.sample.rstdt/eng.sample.rstdt_dev.conllu",
                        default="data_test/eng.sample.rstdt/eng.sample.rstdt_dev.conllu")

    # PRE-FINE-TUNED MODEL
    parser.add_argument("-m", "--model",
                        help="Path to the directory containing the model file.",
                        default=None)

    # OUTPUT DIRECTORY
    parser.add_argument("-o", "--output",
                        help="Directory where models and predictions will be saved. Default: ./data/temp_expe/",
                        default="./data/temp_expe/")

    # CONFIG FILE FROM THE FINE-TUNED MODEL
    parser.add_argument("-c", "--config",
                        help="Config file. Default: ./config_seg.json",
                        default="./config_seg.json")

    # TRACE / VERBOSITY
    parser.add_argument('-v', '--trace',
                        action='store_true',
                        default=False,
                        help="Whether to print full messages. If used, it will override the value in the config file.")

    # TODO: add an option for choosing the tool used to split the sentences

    args = parser.parse_args()

    eval_path = args.test
    output_path = args.output
    if not os.path.isdir(output_path):
        os.makedirs(output_path, exist_ok=True)
    config_file = args.config
    model = args.model
    trace = args.trace

    print('\n-[DISCUT]--PROGRAM (eval) ARGUMENTS')
    print('| Mode', 'eval')
    if not model:
        sys.exit("Please provide a path to a model for eval mode.")
    print('| Test_path:', eval_path)
    print("| Output_path:", output_path)
    if model:
        print("| Model:", model)
    print('| Config:', config_file)

    print('\n-[DISCUT]--CONFIG INFO')
    config = utils.read_config(config_file)
    utils.print_config(config)

    print("\n-[DISCUT]--READING DATASET")
    datasets = {}
    datasets['dev'], tokenizer = reading.read_dataset(eval_path, output_path, config)

    # the model is also in config[best_model_path]
    metrics = simple_eval(datasets['dev'], model, tokenizer, output_path, config, trace=trace)


# # TODO: clean, probably unused arguments here
# def simple_eval_deprecated( dataset_eval, model_checkpoint, tokenizer, output_path,
#                             config ):
#     '''
#     Run the pre-trained model on the (dev) dataset to get predictions,
#     then write the predictions in an output file.
#
#     Parameters:
#     -----------
#     dataset_eval: DatasetDisc
#         The dataset read
#     model_checkpoint: str
#         path to the saved model
#     tokenizer: Tokenizer
#         tokenizer of the saved model (TODO: retrieve from model? or should it be removed?)
#     output_path: str
#         path to the output directory where prediction files will be written
#     '''
#     # tokenized_dataset = dataset_eval.tokenized_datasets
#     dev_dataset = dataset_eval.dataset
#
#     LABEL_NAMES = dataset_eval.LABEL_NAMES_BIO
#     # TODO: check if needed
#     word_ids = dataset_eval.all_word_ids
#     model = transformers.AutoModelForTokenClassification.from_pretrained(
#         model_checkpoint
#     )
#     data_collator = transformers.DataCollatorForTokenClassification(
#         tokenizer=tokenizer,
#         padding=config["tok_config"]["padding"] )
#
#     compute_metrics = utils.prepare_compute_metrics(LABEL_NAMES)
#
#     # TODO: is it useful to have both .eval() and test_args?
#     model.eval()
#
#     test_args = transformers.TrainingArguments(
#         output_dir = output_path,
#         do_train = False,
#         do_predict = True,
#         # per_device_eval_batch_size = BATCH_SIZE,
#         dataloader_drop_last = False
#     )
#
#     trainer = transformers.Trainer(
#         model=model,
#         args=test_args,
#         data_collator=data_collator,
#         compute_metrics=compute_metrics,
#     )
#     predictions, label_ids, metrics = trainer.predict(dataset_eval.tokenized_datasets)
#     preds = np.argmax(predictions, axis=-1)
#
#     compute_metrics([predictions, label_ids])
#
#     # Try to write predictions: will fail if the cache was not emptied,
#     # because we need word_ids, which are not saved in the cache. TODO: check...
#     pred_file = os.path.join( output_path, dataset_eval.basename+'.preds' )
#     try:
#         write_predictions_words( preds, dataset_eval.tokenized_datasets,
#                                  tokenizer, pred_file, dataset_eval.id2label,
#                                  word_ids, dev_dataset, dataset_eval )
#     except IndexError:
#         # on error, we print the predictions with the tokens as is
#         write_predictions_subtokens( preds, dataset_eval.tokenized_datasets,
#                                      tokenizer, pred_file, dataset_eval.id2label )
#     # Test the DISRPT eval script
#     print( "\nPerformance computed using the DISRPT eval script on", dataset_eval.annotations_file,
#            pred_file )
#     if config['task'] == 'seg':
#         my_eval = disrpt_eval_2025.SegmentationEvaluation("temp_test_disrpt_eval_seg",
#                                                           dataset_eval.annotations_file,
#                                                           pred_file )
#     elif config['task'] == 'conn':
#         my_eval = disrpt_eval_2025.ConnectivesEvaluation("temp_test_disrpt_eval_conn",
#                                                          dataset_eval.annotations_file,
#                                                          pred_file )
#     else:
#         raise NotImplementedError
#     my_eval.compute_scores()
#     my_eval.print_results()


# # TODO: dd????
# # TODO: only for SEG/CONN --> to rename (and make a generic function)
# def write_predictions_words_deprecated( preds, dev, tokenizer, pred_file, id2label, word_ids,
#                                         dev_dataset, dd, trace=False ):
#     '''
#     Write predictions for segmentation or connective tasks in an output file.
#     The output is the same as the input gold file, with an additional column
#     corresponding to the predicted label.
#
#     ?? We need the word_ids information to merge the words that have been split and
#     retrieve the original tokens from the input .tok / .conllu files and run the
#     evaluation.
#
#     Parameters:
#     -----------
#     preds: list of int
#         The predicted labels (numeric ids)
#     dev: Dataset
#         tokenized_dev
#     pred_file: str
#         Path to the file where predictions will be written
#     id2label: dict
#         Converts from ids to labels
#     word_ids: list?
#         Word ids, None for the rel task
#     dev_dataset : Dataset
#         Dataset for the dev set
#     dd : str?
#         dset
#     '''
#     predictions = []
#     for i in range( preds.shape[0] ):
#         sent_input_ids = dev['input_ids'][i]
#         tokens = dev_dataset['tokens'][i]
#         # sentence text
#         sent_tokens = tokenizer.decode(sent_input_ids[1:-1])
#         # list of decoded subtokens
#         # sub_tokens = [tokenizer.decode(tok_id) for tok_id in sent_input_ids]
#         # Merge subtokens and retrieve the corresponding pred labels,
#         # i.e. we ignore: CLS, SEP, PAD and labels on ##subtoks
#         aligned_preds = merge_tokens_preds_sent( word_ids[i], preds[i], tokens )
#         if trace:
#             print( '\n', i, sent_tokens )
#             print( sent_input_ids )
#             print( preds[i])
#             print( ' '.join( tokens ) )
#             print( "aligned_preds", aligned_preds )
#         # sentence id, but TODO: retrieve doc ids
#         # f.write( "# sent_id = "+str(i)+"\n" )
#         # Write the original sentence text
#         # f.write( "# text = "+sent_tokens+"\n" )
#         # indices should start at 1
#         for k, tok in enumerate( tokens ):
#             label = aligned_preds[k]
#             predictions.append( id2label[label] )
#             # f.write( "\t".join( [str(k+1), tok, "_","_","_","_","_","_","_", id2label[label] ] )+"\n" )
#         # f.write("\n")
#     print("PREDICTIONS", predictions)
#     count_pred_B, count_gold_B = 0, 0
#     with open( dd.annotations_file, 'r' ) as fin:
#         with open( pred_file, 'w' ) as fout:
#             mylines = fin.readlines()
#             count = 0
#             if trace:
#                 print("len(predictions)", len(predictions))
#             for l in mylines:
#                 l = l.strip()
#                 if l.startswith("#"):
#                     fout.write( l+'\n')
#                 elif l == '' or l == '\n':
#                     fout.write('\n')
#                 elif '-' in l.split('\t')[0]:
#                     fout.write( l+'\t'+'_'+'\n')
#                 else:
#                     if 'B' in predictions[count]:
#                         count_pred_B += 1
#                     if 'Seg=B-seg' in l or 'Conn=B-conn' in l:
#                         count_gold_B += 1
#                     fout.write( l+'\t'+predictions[count]+'\n')
#                     count += 1
#
#     print("Count the number of predictions corresponding to a B", count_pred_B, "vs Gold B", count_gold_B)


# # TODO: dd????
# # TODO: only for SEG/CONN --> to rename (and make a generic function)
# def write_predictions_words( preds_from_model, dataset_eval, tokenizer, pred_file, trace=True ):
#     '''
#     Write predictions for segmentation or connective tasks in an output file.
#     The output is the same as the input gold file, with an additional column
#     corresponding to the predicted label.
#
#     ?? We need the word_ids information to merge the words that have been split and
#     retrieve the original tokens from the input .tok / .conllu files and run the
#     evaluation.
#
#     Parameters:
#     -----------
#     preds_from_model: list of int
#         The predicted labels (numeric ids)
#     pred_file: str
#         Path to the file where predictions will be written
#     '''
#     word_ids = dataset_eval.all_word_ids
#     id2label = dataset_eval.id2label
#     predictions = []
#     for i in range( preds_from_model.shape[0] ):
#         sent_input_ids = dataset_eval.tokenized_datasets['input_ids'][i]
#         tokens = dataset_eval.dataset['tokens'][i]
#         # sentence text
#         sent_tokens = tokenizer.decode(sent_input_ids[1:-1])
#         # list of decoded subtokens
#         # sub_tokens = [tokenizer.decode(tok_id) for tok_id in sent_input_ids]
#         # Merge subtokens and retrieve the corresponding pred labels,
#         # i.e. we ignore: CLS, SEP, PAD and labels on ##subtoks
#         aligned_preds = merge_tokens_preds_sent( word_ids[i], preds_from_model[i], tokens )
#         if trace:
#             print( '\n', i, sent_tokens )
#             print( sent_input_ids )
#             print( preds_from_model[i])
#             print( ' '.join( tokens ) )
#             print( "aligned_preds", aligned_preds )
#         # sentence id, but TODO: retrieve doc ids
#         # f.write( "# sent_id = "+str(i)+"\n" )
#         # Write the original sentence text
#         # f.write( "# text = "+sent_tokens+"\n" )
#         # indices should start at 1
#         for k, tok in enumerate( tokens ):
#             label = aligned_preds[k]
#             predictions.append( id2label[label] )
#             # f.write( "\t".join( [str(k+1), tok, "_","_","_","_","_","_","_", id2label[label] ] )+"\n" )
#         # f.write("\n")
#     # print("PREDICTIONS", predictions)
#     write_pred_file( dataset_eval.annotations_file, pred_file, predictions )
pipeline.py
ADDED
@@ -0,0 +1,142 @@
from transformers import Pipeline, AutoModelForTokenClassification
import numpy as np
from eval import retrieve_predictions, align_tokens_labels_from_wordids
from reading import read_dataset
from utils import read_config


def write_sentences_to_format(sentences: list[str], filename: str):
    """
    Writes sentences to a file, one word per line, in the format:
    index<TAB>word<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>Seg=...
    """

    if not sentences:
        return ""
    if isinstance(sentences, str):
        sentences = [sentences]
        import sys
        sys.stderr.write("Warning: only one sentence provided as a string instead of a list of sentences.\n")

    full = "# newdoc_id = GUM_academic_discrimination\n"
    for sentence in sentences:
        words = sentence.strip().split()
        for i, word in enumerate(words, start=1):
            # The first word, or any capitalized word, gets B-seg; otherwise O
            seg_label = "B-seg" if i == 1 or word[0].isupper() else "O"
            line = f"{i}\t{word}\t_\t_\t_\t_\t_\t_\t_\tSeg={seg_label}\n"
            full += line
    if filename:
        with open(filename, "w", encoding="utf-8") as f:
            f.write(full)

    return full
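# Example (sketch): write_sentences_to_format(["Hello world"], None) returns
#   "# newdoc_id = GUM_academic_discrimination\n"
#   "1\tHello\t_\t_\t_\t_\t_\t_\t_\tSeg=B-seg\n"
#   "2\tworld\t_\t_\t_\t_\t_\t_\t_\tSeg=O\n"
# ("Hello" is the first word, hence B-seg; the doc id is hard-coded and the Seg
# column is only a placeholder gold label).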


class DiscoursePipeline(Pipeline):
    def __init__(self, model, tokenizer, config: dict, output_folder="./pipe_out", sat_model: str = "sat-3l", **kwargs):
        auto_model = AutoModelForTokenClassification.from_pretrained(model)
        super().__init__(model=auto_model, tokenizer=tokenizer, **kwargs)
        self.config = {"model_checkpoint": model, "sent_spliter": "sat", "task": "seg", "type": "tok",
                       "trace": False, "report_to": "none", "sat_model": sat_model,
                       "tok_config": {
                           "padding": "max_length",
                           "truncation": True,
                           "max_length": 512
                       }}
        self.model = model
        self.output_folder = output_folder

    def _sanitize_parameters(self, **kwargs):
        # Allows passing optional parameters such as add_lang_token etc.
        preprocess_params = {}
        forward_params = {}
        postprocess_params = {}
        return preprocess_params, forward_params, postprocess_params

    def preprocess(self, text: str):
        self.original_text = text
        formatted_text = write_sentences_to_format(text.split("\n"), filename=None)
        dataset, _ = read_dataset(
            formatted_text,
            output_path=self.output_folder,
            config=self.config,
            add_lang_token=True,
            add_frame_token=True,
        )
        return {"dataset": dataset}

    def _forward(self, inputs):
        dataset = inputs["dataset"]
        preds_from_model, label_ids, _ = retrieve_predictions(
            self.model, dataset, self.output_folder, self.tokenizer, self.config
        )
        return {"preds": preds_from_model, "labels": label_ids, "dataset": dataset}

    def postprocess(self, outputs):
        preds = np.argmax(outputs["preds"], axis=-1)
        predictions = align_tokens_labels_from_wordids(preds, outputs["dataset"], self.tokenizer)
        edus = text_to_edus(self.original_text, predictions)
        return edus

def get_plain_text_from_format(formatted_text: str) -> str:
    """
    Reads conllu- or tok-formatted text and returns its token forms as a single string.
    """
    formatted_text = formatted_text.split("\n")
    s = ""
    for line in formatted_text:
        if not line.startswith("#"):
            if len(line.split("\t")) > 1:
                s += line.split("\t")[1] + " "
    return s.strip()


def get_preds_from_format(formatted_text: str) -> str:
    """
    Reads conllu- or tok-formatted text and returns its last column (the labels) as a single string.
    """
    formatted_text = formatted_text.split("\n")
    s = ""
    for line in formatted_text:
        if not line.startswith("#"):
            if len(line.split("\t")) > 1:
                s += line.split("\t")[-1] + " "
    return s.strip()


def text_to_edus(text: str, labels: list[str]) -> list[str]:
    """
    Splits a raw text into EDUs from a sequence of BIO labels.

    Args:
        text (str): The raw text (a sequence of space-separated words).
        labels (list[str]): The sequence of BIO labels (B, I, O),
            of the same length as the number of tokens in the text.

    Returns:
        list[str]: The list of EDUs (each EDU is a substring of the text).
    """
    words = text.strip().split()
    if len(words) != len(labels):
        raise ValueError(f"Length mismatch: {len(words)} words vs {len(labels)} labels")

    edus = []
    current_edu = []

    for word, label in zip(words, labels):
        if label == "Conn=O" or label == "Seg=O":
            current_edu.append(word)

        elif label == "Conn=B-conn" or label == "Seg=B-seg":
            # Close the current EDU if one is open
            if current_edu:
                edus.append(" ".join(current_edu))
                current_edu = []
            current_edu.append(word)
        # NB: "Conn=I-conn" labels fall through both branches and are silently dropped here

    # If an EDU is still open, close it
    if current_edu:
        edus.append(" ".join(current_edu))

    return edus
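# Example (sketch):
#   text_to_edus("I left because it rained",
#                ["Seg=B-seg", "Seg=O", "Seg=B-seg", "Seg=O", "Seg=O"])
# returns ["I left", "because it rained"]: each B label closes the open EDU
# and opens a new one.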
reading.py
ADDED
@@ -0,0 +1,512 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-

import os, sys

import datasets
import transformers

import disrpt_io
import utils

# TODO to rm when dealt with this issue of loading languages
##from ersatz import utils
##LANGUAGES = utils.MODELS.keys()
LANGUAGES = []

def read_dataset( input_path, output_path, config, add_lang_token=True, add_frame_token=True, lang_token="", frame_token="" ):
    '''
    - Read the file in input_path
    - Return a Dataset corresponding to the file

    Parameters
    ----------
    input_path : str
        Path to the dataset
    output_path : str
        Path to an output directory that can be used to write new split files
    config : dict
        Configuration; among other things, gives the model checkpoint used to
        instantiate the tokenizer
    add_lang_token : bool
        If True, add a special language token at the beginning of each sequence

    Returns
    -------
    Dataset
        Contains the Dataset built from train_path and dev_path in train mode,
        only the dev / test paths otherwise
    Tokenizer
        The tokenizer used for the dataset
    '''
    model_checkpoint = config["model_checkpoint"]
    # -- Init tokenizer
    tokenizer = transformers.AutoTokenizer.from_pretrained( model_checkpoint )
    # -- Read and tokenize
    dataset = DatasetSeq( input_path, output_path, config, tokenizer, add_lang_token=add_lang_token, add_frame_token=add_frame_token, lang_token=lang_token, frame_token=frame_token )
    dataset.read_and_tokenize()
    # TODO move in class? or do elsewhere
    LABEL_NAMES_BIO = retrieve_bio_labels( dataset )  # TODO should do it only once for all
    dataset.set_label_names_bio(LABEL_NAMES_BIO)
    return dataset, tokenizer
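# A minimal usage sketch (the config keys are assumptions based on how they
# are read below; the paths reuse the sample data from __main__):
#
#   config = utils.read_config("./config_seg.json")
#   dataset, tokenizer = read_dataset(
#       "data_test/eng.sample.rstdt/eng.sample.rstdt_train.conllu",
#       "out/", config, add_lang_token=True)
#   print(len(dataset.tokenized_datasets))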

# --------------------------------------------------------------------------
# DatasetDict

class DatasetDisc( ):
    def __init__(self, annotations_file, output_path, config, tokenizer, dset=None ):
        """
        Here we save the location of our input file,
        load the data, i.e. retrieve the list of texts and associated labels,
        build the vocabulary if none is given,
        and define the pipelines used to prepare the data
        """
        self.annotations_file = annotations_file
        if isinstance(annotations_file, str) and not os.path.isfile(annotations_file):
            print("this is a string dataset")
            self.basename = "input"
            self.dset = None
        else:
            self.basename = os.path.basename( self.annotations_file )
            self.dset = self.basename.split(".")[2].split('_')[1]
            self.corpus_name = self.basename.split('_')[0]

        self.tokenizer = tokenizer
        self.config = config
        # If a sentence splitter is used, the files with the new segmentation will be saved here
        self.output_path = output_path

        # Retrieve info from config: TODO check against info from dir name?
        self.mode = config["type"]
        self.task = config["task"]
        self.trace = config["trace"]
        self.tok_config = config["tok_config"]
        self.sent_spliter = config["sent_spliter"]

        # Additional fields
        self.id2label, self.label2id = {}, {}

        # -- Use disrpt_io to read the file and retrieve annotated data
        self.corpus = init_corpus( self.task )  # initialize a Corpus instance, depending on the task

    def read_and_tokenize( self ):
        print("\n-- READ FROM FILE:", self.annotations_file )
        try:
            self.read_annotations( )
        except Exception as err:
            print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
            raise
            # print( "Problem when reading", self.annotations_file )

        #print("\n-- SET LABELS")
        self.set_labels( )
        print( "self.label2id", self.label2id )

        #print("\n-- TOKENIZE DATASET")
        self.tokenize_dataset()
        if self.trace:
            if self.dset:
                print( "\n-- FINISHED READING", self.dset, "PRINTING TRACE --")
            self.print_trace()

    def tokenize_dataset( self ):
        # Specific to subclasses
        raise NotImplementedError

    def set_labels( self ):
        # Specific to subclasses
        raise NotImplementedError

    # outside the class?
    # TODO use **kwargs instead?
    def read_annotations( self ):
        '''
        Generate a Corpus object based on the input_file.
        Since .tok files are not segmented into sentences, a sentence splitter
        is used (here, ersatz)
        '''
        if os.path.isfile(self.annotations_file):
            self.corpus.from_file(self.annotations_file)
            lang = os.path.basename(self.annotations_file).split(".")[0]
            frame = os.path.basename(self.annotations_file).split(".")[1]
            base = os.path.basename(self.annotations_file)
        else:
            # assume this is raw text already in the expected format
            src = self.mode if self.mode in ["tok", "conllu", "split"] else "conllu"
            self.corpus.from_string(self.annotations_file, src=src)
            lang = self.lang_token
            frame = self.frame_token
            base = "input.text"

        #print(f"[DEBUG] lang? {lang}")
        for doc in self.corpus.docs:
            doc.lang = lang
            doc.frame = frame
        # print(corpus)
        # Split corpus into sentences using Ersatz
        if self.mode == 'tok':
            kwargs = {}
            from wtpsplit import SaT
            sat_version = "sat-3l"
            if "sat_model" in self.config:
                sat_version = self.config["sat_model"]

            sat_model = SaT(sat_version)
            kwargs["sat_model"] = sat_model
            self.corpus.sentence_split(model=self.sent_spliter, lang="default-multilingual", sat_model=sat_model)
            # Writing files with the split sentences
            parts = base.split(".")[:-1]
            split_filename = ".".join(parts) + ".split"
            split_file = os.path.join(self.output_path, split_filename)
            self.corpus.format(file=split_file)
        # no need for sentence splitting if mode = conllu or split, no need to write files

    def print_trace( self ):
        print( "\n| Annotation_file: ", self.annotations_file )
        print( '| Output_path:', self.output_path )
        print( '| Nb_of_instances:', len(self.dataset), "(", len(self.dataset['labels']), ")" )
        # "(", len(self.dataset['tokens']), len(self.dataset['labels']), ")" )

    def print_stats( self ):
        print( "| Annotation_file: ", self.annotations_file )
        if self.dset: print( "| Data_split: ", self.dset )
        print( "| Task: ", self.task )
        print( "| Lang: ", self.lang )
        print( "| Mode: ", self.mode )
        print( "| Label_names: ", self.LABEL_NAMES)
        #print( "---Number_of_documents", len( self.corpus.docs ) )
        print( "| Number_of_instances: ", len(self.dataset) )
        # TODO : add number of docs: not computed for .rels for now

# -------------------------------------------------------------------------------------------------
class DatasetSeq(DatasetDisc):
    def __init__( self, annotations_file, output_path, config, tokenizer, add_lang_token=True, add_frame_token=True,
                  dset=None, lang_token="", frame_token="" ):
        """
        Class for tasks corresponding to a sequence labeling problem
        (seg, conn).
        Here we save the location of our input file,
        load the data, i.e. retrieve the list of texts and associated labels,
        build the vocabulary if none is given,
        and define the pipelines used to prepare the data
        """
        DatasetDisc.__init__( self, annotations_file, output_path, config, tokenizer )
        self.add_lang_token = add_lang_token
        self.add_frame_token = add_frame_token
        self.lang_token = lang_token
        self.frame_token = frame_token

        if self.mode == 'tok' and self.output_path == None:
            self.output_path = os.path.dirname( self.annotations_file )
            self.output_path = os.path.join( self.output_path,
                                             self.basename.replace("."+self.mode, ".split") )

        self.sent_spliter = None
        if "sent_spliter" in self.config:
            self.sent_spliter = self.config["sent_spliter"]  # only for seg

        self.LABEL_NAMES_BIO = None
        # # TODO not used, really a good idea?
        # self.data_collator = transformers.DataCollatorForTokenClassification(tokenizer=self.tokenizer,
        #                                                                      padding=self.tok_config["padding"] )

    def tokenize_dataset( self ):
        # -- Create a HuggingFace Dataset object
        if self.trace:
            print(f"\n-- Creating dataset from generator (add_lang_token={self.add_lang_token})")
        self.dataset = datasets.Dataset.from_generator(
            gen,
            gen_kwargs={"corpus": self.corpus, "label2id": self.label2id, "mode": self.mode, "add_lang_token": self.add_lang_token, "add_frame_token": self.add_frame_token},
        )
        if self.trace:
            print( self.dataset[0])
        # Keep track of the alignment between words and subtokens, even if not ##
        # BERT* adds a tokenisation based on punctuation even when given a list of words
        self.all_word_ids = []
        # Align labels according to tokenized subwords
        if self.trace:
            print( "\n-- Mapping dataset labels and subwords ")
        self.tokenized_datasets = self.dataset.map(
            tokenize_and_align_labels,
            fn_kwargs = {"tokenizer": self.tokenizer,
                         "id2label": self.id2label,
                         "label2id": self.label2id,
                         "all_word_ids": self.all_word_ids,
                         "config": self.config},
            batched=True,
            remove_columns=self.dataset.column_names,
        )
        if self.trace:
            print( self.tokenized_datasets[0])

    def set_labels(self):
        self.LABEL_NAMES = self.corpus.LABELS
        self.id2label = {i: label for i, label in enumerate( self.LABEL_NAMES )}
        self.label2id = {v: k for k, v in self.id2label.items()}

    def set_label_names_bio( self, LABEL_NAMES_BIO ):
        self.LABEL_NAMES_BIO = LABEL_NAMES_BIO

    def print_trace( self ):
        super().print_trace()
        print( '\n--First sentence: original tokens and labels.\n')
        print( self.dataset[0]['tokens'] )
        print( self.dataset[0]['labels'] )
        print( "\n---First sentence: tokenized version:\n")
        print( self.tokenized_datasets[0] )
        # print( '\nSource word ids:', len(self.all_word_ids) )

    # # TODO prepare a compute_stats before printing, to allow partial printing without trace mode
    # def print_stats( self ):
    #     super().print_stats()
    #     print( "| Number_of_documents", len( self.corpus.docs ) )


def init_corpus( task ):
    if task.strip().lower() == 'conn':
        return disrpt_io.ConnectiveCorpus()
    elif task == 'seg':
        return disrpt_io.SegmentCorpus()
    else:
        raise NotImplementedError

def gen( corpus, label2id, mode, add_lang_token=True, add_frame_token=True ):
    # Add a special language token at the beginning of each sequence
    source = "split"
    if mode == 'conllu':
        source = "conllu"
    for doc in corpus.docs:
        lang = getattr(doc, 'lang', 'xx')
        lang_token = f"[LANG={lang}]"

        frame = getattr(doc, 'frame', 'xx')
        frame_token = f"[FRAME={frame}]"
        sent_list = doc.sentences[source] if source in doc.sentences else doc.sentences
        for sentence in sent_list:
            labels = []
            tokens = []
            if add_lang_token:
                tokens.append(lang_token)
                labels.append(-100)
            if add_frame_token:
                tokens.append(frame_token)
                labels.append(-100)
            #print(f"[DEBUG] Adding frame token {frame_token} for sentence: {' '.join([t.form for t in sentence.toks])}")
            for t in sentence.toks:
                tokens.append(t.form)
                if t.label == '_':
                    if 'O' in label2id:
                        labels.append(label2id['O'])
                    else:
                        labels.append(list(label2id.values())[0])
                else:
                    labels.append(label2id[t.label])
            yield {
                "tokens": tokens,
                "labels": labels
            }
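# A minimal sketch of what gen() yields, assuming a two-token English RST
# document and label2id = {"_": 0, "Seg=B-seg": 1, "Seg=I-seg": 2} (invented
# values for illustration):
#
#   {"tokens": ["[LANG=eng]", "[FRAME=rst]", "John", "left"],
#    "labels": [-100, -100, 1, 0]}
#
# The -100 labels on the [LANG=...] / [FRAME=...] meta-tokens are ignored by
# the cross-entropy loss downstream.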


def get_tokenizer( model_checkpoint ):
    return transformers.AutoTokenizer.from_pretrained(model_checkpoint)

def tokenize_and_align_labels( dataset, tokenizer, id2label, label2id, all_word_ids, config ):
    '''
    (Done in batches)
    To preprocess our whole dataset, we need to tokenize all the inputs and
    apply align_labels_with_tokens() on all the labels.
    (with HF, we could use Dataset.map to process batches)
    The word_ids() function needs to get the index of the example we want
    the word IDs of when the inputs to the tokenizer are lists of texts
    (or in our case, lists of lists of words), so we add that too:
    "tok_config"
    '''
    tokenized_inputs = tokenizer(
        dataset["tokens"],
        truncation=config["tok_config"]['truncation'],
        padding=config["tok_config"]['padding'],
        max_length=config["tok_config"]['max_length'],
        is_split_into_words=True
    )
    # tokenized_inputs = tokenizer(
    #     dataset["tokens"], truncation=True, padding=True, is_split_into_words=True
    # )
    all_labels = dataset["labels"]
    new_labels = []
    #print( "tokenized_inputs.word_ids()", tokenized_inputs.word_ids() )
    #print( [tokenizer.decode(tok) for tok in tokenized_inputs['input_ids']])
    ##with progressbar.ProgressBar(max_value=len(all_labels)) as bar:
    ##for i in tqdm(range(len(all_labels))):
    for i, labels in enumerate(all_labels):
        word_ids = tokenized_inputs.word_ids(i)
        new_labels.append(align_labels_with_tokens(labels, word_ids, id2label, label2id, tokenizer, tokenized_inputs ))
        # Used to fill the self.word_ids field of the Dataset object, but should probably be done somewhere else
        all_word_ids.append( word_ids )
        ##bar.update(i)
    tokenized_inputs["labels"] = new_labels
    return tokenized_inputs

def align_labels_with_tokens(labels, word_ids, id2label, label2id, tokenizer, tokenized_inputs):
    '''
    BERT-like tokenization will create new tokens, so we need to align labels.
    Special tokens get a label of -100. This is because by default -100 is an
    index that is ignored in the loss function we will use (cross entropy).
    Then, each token gets the same label as the token that started the word
    it's inside, since they are part of the same entity. For tokens inside a
    word but not at the beginning, we replace the B- with I- (since the token
    does not begin the entity). [Taken from the HF website course on NER]
    '''
    count = 0
    new_labels = []
    current_word = None
    for word_id in word_ids:
        count += 1
        if word_id == 0:  # or maybe 1
            # TODO
            # add lang token -100
            pass
        if word_id != current_word:
            # Start of a new word!
            current_word = word_id
            label = -100 if word_id is None else labels[word_id]
            new_labels.append(label)
        elif word_id is None:
            # Special token
            new_labels.append(-100)
        else:
            # Same word as previous token
            label = labels[word_id]
            # Only look for 'B-' if label != -100
            if label != -100 and 'B-' in id2label[label]:
                label = -100
            new_labels.append(label)
    return new_labels
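# A worked sketch of the alignment, assuming "unhappy" is split into the
# subwords ["un", "##happy"] while the meta-token stays whole, and an invented
# per-word label list where 1 maps to 'Seg=B-seg':
#
#   words:    ["[LANG=eng]", "unhappy"]     labels: [-100, 1]
#   word_ids: [None, 0, 1, 1, None]         # [CLS], [LANG=eng], un, ##happy, [SEP]
#   result:   [-100, -100, 1, -100, -100]
#
# Note that here the continuation subword of a B- word is masked with -100
# rather than converted to an I- label as the docstring above suggests.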


def retrieve_bio_labels( dataset ):
    '''
    Needed for compute_metrics, I think? It seems to use a classic metric for the BIO
    scheme, thus we create a mapping to BIO labels, i.e.:
    '_' --> 'O'
    'Seg=B-Conn' --> 'B'
    'Seg=I-Conn' --> 'I'
    Should also work for segmentation TODO: check
    dataset: DatasetSeq instance (train/dev/test)
    Return: list: label names mapped to BIO
    '''
    # need a Dataset instance to retrieve the original label sets
    task = dataset.task
    LABEL_NAMES_BIO = []
    LABEL_NAMES = dataset.LABEL_NAMES
    label2idx, idx2newl = {}, {}
    if task in ["conn", "seg"]:
        for i, l in enumerate( LABEL_NAMES ):
            label2idx[l] = i
        for l in label2idx:
            nl = ''
            if 'B' in l:
                nl = 'B'
            elif 'I' in l:
                nl = 'I'
            else:
                nl = 'O'
            idx2newl[label2idx[l]] = nl
        for i in sorted(idx2newl):
            LABEL_NAMES_BIO.append(idx2newl[i])
    #label_names = ['O', 'B', 'I']
    return LABEL_NAMES_BIO
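# A minimal sketch of the mapping, assuming dataset.LABEL_NAMES is
# ['_', 'Seg=B-seg', 'Seg=I-seg'] (an invented segmentation label set):
#
#   retrieve_bio_labels(dataset)  # -> ['O', 'B', 'I']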

# def _compute_distrib( dataset, id2label ):
#     distrib = {}
#     multi = []
#     for inst in dataset:
#         label = id2label[inst['label']]
#         if label in distrib:
#             distrib[label] += 1
#         else:
#             distrib[label] = 1
#         len_labels = len( inst["all_labels"])
#         if len_labels > 1:
#             #count_multi += 1
#             multi.append( len_labels )
#     return distrib, multi

# Defines the language code for the sentence splitter, should be done in disrpt_io?
def set_language( lang ):
    #lang = "default-multilingual" #default value
    # patch
    if lang == "sp": lang = "es"
    if lang not in LANGUAGES:
        lang = "default-multilingual"
    return lang


# ------------------------------------------------------------------
if __name__=="__main__":
    import argparse, os

    parser = argparse.ArgumentParser(
        description='DISCUT: reading data from disrpt_io and converting to HuggingFace'
    )
    # TRAIN AND DEV are (lists of) FILES or DIRECTORIES
    parser.add_argument("-t", "--train",
                        help="Training file. Default: data_test/eng.sample.rstdt/eng.sample.rstdt_train.conllu",
                        default="data_test/eng.sample.rstdt/eng.sample.rstdt_train.conllu")

    parser.add_argument("-d", "--dev",
                        help="Dev file. Default: data_test/eng.sample.rstdt/eng.sample.rstdt_dev.conllu",
                        default="data_test/eng.sample.rstdt/eng.sample.rstdt_dev.conllu")

    # OUTPUT DIRECTORY
    parser.add_argument("-o", "--output",
                        help="Directory where models and predictions will be saved. Default: empty string",
                        default="")

    # CONFIG FILE
    parser.add_argument("-c", "--config",
                        help="Config file. Default: ./config_seg.json",
                        default="./config_seg.json")

    # TRACE / VERBOSITY
    parser.add_argument( '-v', '--trace',
                         action='store_true',
                         default=False,
                         help="Whether to print full messages. If used, it will override the value in the config file.")

    args = parser.parse_args()

    train_path = args.train
    dev_path = args.dev
    print(dev_path)
    if not os.path.isfile(dev_path):
        print( "ERROR with dev file:", dev_path)
    output_path = args.output
    config_file = args.config
    #eval = args.eval
    trace = args.trace

    print( '\n-[JEDIS]--PROGRAM (reader) ARGUMENTS')
    print( '| Train_path', train_path )
    print( '| Dev_path', dev_path )
    print( "| Output_path", output_path )
    print( '| Config', config_file )

    print( '\n-[JEDIS]--CONFIG INFO')
    config = utils.read_config( config_file )
    utils.print_config(config)
    # We override the config file if the user says no trace in arguments,
    # easier than modifying the config files each time
    if not trace:
        config['trace'] = False

    print( "\n-[JEDIS]--READING DATASETS" )
    # dictionary containing train (if mode=='train') and/or dev (test) Dataset instance
    datasets, tokenizer = read_dataset( train_path, dev_path, config, add_lang_token=True )
utils.py
ADDED
|
@@ -0,0 +1,216 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-

import os, sys
import json
import numpy as np
from pathlib import Path
import itertools

import evaluate
import disrpt_eval_2025
#from .disrpt_eval_2025 import *

# TODO : should be conditioned on the task or the metric indicated in the config file ??
def prepare_compute_metrics(LABEL_NAMES):
    '''
    Return the method to be used in the trainer loop.
    For seg or conn, based on seqeval; tokens with label -100 are
    ignored here (okay?)

    Parameters
    ----------
    LABEL_NAMES: Dict
        Needed only for BIO labels, convert to the right labels for seqeval
    task: str (currently unused)
        Should be either 'seg' or 'conn', but could be expanded to other
        sequence / classification tasks

    Returns
    -------
    compute_metrics: function
    '''
    def compute_metrics(eval_preds):
        nonlocal LABEL_NAMES
        # nonlocal task
        # Retrieve gold and predictions
        logits, labels = eval_preds

        predictions = np.argmax(logits, axis=-1)
        metric = evaluate.load("seqeval")
        # Remove ignored index (special tokens) and convert to labels
        true_labels = [[LABEL_NAMES[l] for l in label if l != -100] for label in labels]
        true_predictions = [
            [LABEL_NAMES[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
        print_metrics( all_metrics )
        return {
            "precision": all_metrics["overall_precision"],
            "recall": all_metrics["overall_recall"],
            "f1": all_metrics["overall_f1"],
            "accuracy": all_metrics["overall_accuracy"],
        }
    return compute_metrics
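# A minimal usage sketch with a HuggingFace Trainer (all Trainer arguments
# other than compute_metrics are elided; the label list is an invented BIO set
# matching what retrieve_bio_labels produces):
#
#   compute_metrics = prepare_compute_metrics(['O', 'B', 'I'])
#   trainer = transformers.Trainer(..., compute_metrics=compute_metrics)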


def print_metrics( all_metrics ):
    #print( all_metrics )
    for p, v in all_metrics.items():
        if '_' in p:
            print( p, v )
        else:
            print( p+' = '+str(v))

def compute_metrics_dirspt( dataset_eval, pred_file, task='seg' ):
    print( "\nPerformance computed using the disrpt eval script on", dataset_eval.annotations_file,
           pred_file )
    if task == 'seg':
        #clean_pred_file(pred_file, os.path.basename(pred_file)+"cleaned.preds")
        my_eval = disrpt_eval_2025.SegmentationEvaluation("temp_test_disrpt_eval_seg",
                                                          dataset_eval.annotations_file,
                                                          pred_file )
    elif task == 'conn':
        my_eval = disrpt_eval_2025.ConnectivesEvaluation("temp_test_disrpt_eval_conn",
                                                         dataset_eval.annotations_file,
                                                         pred_file )
    else:
        raise NotImplementedError
    my_eval.compute_scores()
    my_eval.print_results()

def clean_pred_file(pred_path: str, out_path: str):
    c = 0
    with open(pred_path, "r", encoding="utf8") as fin, open(out_path, "w", encoding="utf8") as fout:
        for line in fin:
            if line.strip() == "" or line.startswith("#"):
                fout.write(line)
                continue
            fields = line.strip().split("\t")
            token = fields[1]
            if token.startswith("[LANG=") or token.startswith("[FRAME="):
                c += 1
                continue  # skip meta-tokens
            fout.write(line)
    print(f"we've cleaned {c} tokens")

# -------------------------------------------------------------------------------------------------
# ------ UTILS FUNCTIONS
# -------------------------------------------------------------------------------------------------
def read_config( config_file ):
    '''Read the config file for training'''
    f = open(config_file)
    config = json.load(f)
    if 'frozen' in config['trainer_config']:
        config['trainer_config']["frozen"] = update_frozen_set( config['trainer_config']["frozen"] )
    return config

def update_frozen_set( freeze ):
    # Make a set from the list of frozen layers
    # [] --> nothing frozen
    # [3] --> only layer 3 frozen
    # [0,3] --> only layers 0 and 3
    # [0-3, 12, 15] --> layers 0 to 3 included, + layers 12 and 15
    frozen = set()
    for spec in freeze:
        if "-" in spec:  # e.g. 1-9
            b, e = spec.split("-")
            frozen = frozen | set(range(int(b), int(e)+1))
        else:
            frozen.add(int(spec))
    return frozen
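# A minimal sketch of the expected input/output (the specs are strings, as
# read from the JSON config file):
#
#   update_frozen_set(["0-3", "12", "15"])  # -> {0, 1, 2, 3, 12, 15}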

def print_config(config):
    '''Print info from config dictionary'''
    print('\n'.join([ '| '+k+": "+str(v) for (k, v) in config.items() ]))

# -------------------------------------------------------------------------------------------------
def retrieve_files_dataset( input_path, list_dataset, mode='conllu', dset='train' ):
    if mode == 'conllu':
        pat = ".[cC][oO][nN][lL][lL][uU]"
    elif mode == 'tok':
        pat = ".[tT][oO][kK]"
    else:
        sys.exit('Unknown mode for file extension: '+mode)
    if len(list_dataset) == 0:
        return list(Path(input_path).rglob("*_"+dset+pat))
    else:
        # files e.g. eng.pdtb.pdtb_train.conllu
        matched = []
        for subdir in os.listdir( input_path ):
            if subdir in list_dataset:
                matched.extend( list(Path(os.path.join(input_path, subdir)).rglob("*_"+dset+pat)) )
        return matched


# -------------------------------------------------------------------------------------------------
# https://wandb.ai/site
def init_wandb( config, model_checkpoint, annotations_file ):
    '''
    Initialize a new WANDB project to keep track of the experiments.
    Parameters
    ----------
    config : dict
        Allows to retrieve the name of the entity and project (from config file)
    model_checkpoint :
        Name of the PLM used
    annotations_file : str
        Path to the training file

    Returns
    -------
    None
    '''
    print("HERE WE INITIALIZE A WANDB PROJECT")

    import wandb
    proj_wandb = config["wandb"]
    ent_wandbd = config["wandb_ent"]
    # start a new wandb run to track this script
    # The project name must be set before initializing the trainer
    wandb.init(
        # set the wandb project where this run will be logged
        project=proj_wandb,
        entity=ent_wandbd,
        # track hyperparameters and run metadata
        config={
            "model_checkpoint": model_checkpoint,
            "dataset": annotations_file,
        }
    )
    wandb.define_metric("epoch")
    wandb.define_metric("f1", step_metric="batch")
    wandb.define_metric("f1", step_metric="epoch")

def set_name_output_dir( output_dir, config, corpus_name ):
    '''
    Set the path name for the target directory used to store models. The name should contain
    info about the task, the PLM and the hyperparameter values.

    Parameters
    ----------
    output_dir : str
        Path to the output directory provided by the user
    config: dict
        Configuration information
    corpus_name: str
        Name of the corpus

    Returns
    -------
    Str: Path to the output directory
    '''
    # Retrieve a decimal number for the learning rate, to avoid scientific notation
    hyperparam = [
        config['trainer_config']['batch_size'],
        np.format_float_positional(config['trainer_config']['learning_rate'])
    ]
    output_dir = os.path.join( output_dir,
                               '_'.join( [
                                   corpus_name,
                                   config["model_name"],
                                   config["task"],
                                   '_'.join([str(p) for p in hyperparam])
                               ] ) )
    return output_dir
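# A minimal sketch of the resulting path, assuming config values
# model_name="xlmr", task="seg", batch_size=16, learning_rate=2e-5
# (invented for illustration):
#
#   set_name_output_dir("out", config, "eng.rst.rstdt")
#   # -> "out/eng.rst.rstdt_xlmr_seg_16_0.00002"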