Upload folder using huggingface_hub
- .gitignore +1 -0
- ReadMe.md +47 -0
- data/Multitask_case_dev_VST +0 -0
- data/Multitask_case_dev_san +0 -0
- data/Multitask_case_poetry_san +0 -0
- data/Multitask_case_prose_san +0 -0
- data/Multitask_case_test_VST +0 -0
- data/Multitask_case_test_san +0 -0
- data/Multitask_case_train_VST +0 -0
- data/Multitask_case_train_san +0 -0
- data/Multitask_label_dev_VST +0 -0
- data/Multitask_label_dev_san +0 -0
- data/Multitask_label_poetry_san +0 -0
- data/Multitask_label_prose_san +0 -0
- data/Multitask_label_test_VST +0 -0
- data/Multitask_label_test_san +0 -0
- data/Multitask_label_train_VST +0 -0
- data/Multitask_label_train_san +0 -0
- data/Multitask_morph_dev_VST +0 -0
- data/Multitask_morph_dev_san +0 -0
- data/Multitask_morph_poetry_san +0 -0
- data/Multitask_morph_prose_san +0 -0
- data/Multitask_morph_test_VST +0 -0
- data/Multitask_morph_test_san +0 -0
- data/Multitask_morph_train_VST +0 -0
- data/Multitask_morph_train_san +0 -0
- data/combined_1300_test.txt +0 -0
- data/ud_pos_ner_dp_dev_VST +0 -0
- data/ud_pos_ner_dp_dev_san +0 -0
- data/ud_pos_ner_dp_poetry_VST +0 -0
- data/ud_pos_ner_dp_poetry_san +0 -0
- data/ud_pos_ner_dp_prose_VST +0 -0
- data/ud_pos_ner_dp_prose_san +0 -0
- data/ud_pos_ner_dp_test_VST +0 -0
- data/ud_pos_ner_dp_test_san +0 -0
- data/ud_pos_ner_dp_train_VST +0 -0
- data/ud_pos_ner_dp_train_san +0 -0
- data/ud_pos_ner_dp_train_san_org +0 -0
- examples/BiAFF_macro_UAS_LAS.py +108 -0
- examples/BiAFF_write_1300_combined.py +48 -0
- examples/GraphParser.py +599 -0
- examples/GraphParser_MTL_POS.py +633 -0
- examples/SequenceTagger.py +589 -0
- examples/VST_Pred_Prepare.py +34 -0
- examples/VST_macro_score.py +107 -0
- examples/eval/conll03eval.v2 +336 -0
- examples/eval/conll06eval.pl +1826 -0
- examples/macro_UAS_LAS.py +107 -0
- examples/write_1300_combined.py +48 -0
- run_STBC.sh +75 -0
.gitignore
ADDED
@@ -0,0 +1 @@
./saved_models
ReadMe.md
ADDED
@@ -0,0 +1,47 @@
Official code for the paper ["Systematic Investigation of Strategies Tailored for Low-Resource Settings for Low-Resource Dependency Parsing"](https://arxiv.org/abs/2201.11374).
If you use this code, please cite our paper.

## Requirements

* Python 3.7
* Pytorch 1.1.0
* Cuda 9.0
* Gensim 3.8.1

We assume that you have installed conda beforehand.

```
conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=9.0 -c pytorch
pip install gensim==3.8.1
```

## Pretrained embeddings for Sanskrit

* Pretrained FastText embeddings for STBC/VST can be obtained from [here](https://drive.google.com/drive/folders/1SwdEqikTq-N2vOL7QSUX2vqi3faZE7bq?usp=sharing). Make sure the `.txt` file is placed under `data/`; a loading sketch is given below.
* The main results are reported on systems trained by combining the train and dev splits.
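The parsers themselves load these vectors through `utils.load_word_embeddings` when `--word_embedding fasttext --word_path <file>` is passed (see `examples/GraphParser.py`). As a quick sanity check outside the training scripts, the downloaded file can also be inspected with Gensim; this is only a minimal sketch that assumes the released `.txt` follows the standard word2vec text format and uses a hypothetical filename:

```python
# Minimal sanity check of the downloaded vectors (hypothetical filename under data/).
# Assumes the .txt file is in word2vec text format; adjust if the release differs.
from gensim.models import KeyedVectors

vectors = KeyedVectors.load_word2vec_format("data/sanskrit_fasttext.txt", binary=False)
print("dimension:", vectors.vector_size, "| vocabulary size:", len(vectors.vocab))
```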

## How to train model for Sanskrit

The proposed system runs in two stages, (1) pretraining and (2) integration; simply run the bash script `run_STBC.sh` or `run_VST.sh` for the respective dataset. These scripts reproduce the results reported in Section 3 and Table 2 of the paper.

```bash
bash run_STBC.sh
```
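After a run finishes, the helper scripts under `examples/` merge the gold and predicted parses and report word-level and sentence-level (macro) UAS/LAS. The sketch below chains the two BiAFF variants for a hypothetical run directory under `./saved_models/`; both scripts take the directory name as their only argument (analogous `write_1300_combined.py` and `macro_UAS_LAS.py` scripts also exist).

```python
# Hedged scoring sketch. "STBC_run" is a hypothetical directory name under
# ./saved_models/ -- substitute whatever directory your training run created.
import subprocess

run_dir = "STBC_run"
# 1) Merge gold and predicted heads/labels into <run_dir>/combined_1300_test.txt.
subprocess.run(["python", "examples/BiAFF_write_1300_combined.py", run_dir], check=True)
# 2) Compute word-level and sentence-level UAS/LAS from the merged file.
subprocess.run(["python", "examples/BiAFF_macro_UAS_LAS.py", run_dir], check=True)
```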

## Citations

```
@misc{sandhan_systematic,
  doi = {10.48550/ARXIV.2201.11374},
  url = {https://arxiv.org/abs/2201.11374},
  author = {Sandhan, Jivnesh and Behera, Laxmidhar and Goyal, Pawan},
  keywords = {Computation and Language (cs.CL), FOS: Computer and information sciences},
  title = {Systematic Investigation of Strategies Tailored for Low-Resource Settings for Low-Resource Dependency Parsing},
  publisher = {arXiv},
  year = {2022},
  copyright = {Creative Commons Attribution 4.0 International}
}
```

## Acknowledgements

Our ensembled system is built on top of the ["DCST implementation"](https://github.com/rotmanguy/DCST).
data/Multitask_*, data/ud_pos_ner_dp_*, data/combined_1300_test.txt
ADDED
The diffs for these data files (listed above) are too large to render; see the raw diff.
examples/BiAFF_macro_UAS_LAS.py
ADDED
@@ -0,0 +1,108 @@
import sys


def load_results(filename):
    """Read a combined gold/pred file and group the rows into sentences."""
    results = []
    sent = []
    with open(filename, 'r') as fp:
        for i, line in enumerate(fp):
            if i == 0:  # skip the header row
                continue
            splits = line.strip().split('\t')
            if len(line.strip()) == 0:  # blank line marks a sentence boundary
                if len(sent) != 0:
                    results.append(sent)
                    sent = []
                continue
            gold_head = splits[-4]
            gold_label = splits[-3]
            pred_head = splits[-2]
            pred_label = splits[-1]
            sent.append((gold_head, gold_label, pred_head, pred_label))
    print('Total Number of sentences ' + str(len(results)))
    return results


def calculate_las_uas(gold_heads, gold_labels, pred_heads, pred_labels):
    """Count correct heads (UAS) and correct head+label pairs (LAS) for one sentence."""
    u_correct = 0
    l_correct = 0
    u_total = 0
    l_total = 0

    for i in range(len(gold_heads)):
        if gold_heads[i] == pred_heads[i]:
            u_correct += 1
        u_total += 1
        l_total += 1
        if gold_heads[i] == pred_heads[i] and gold_labels[i] == pred_labels[i]:
            l_correct += 1
    return u_correct, u_total, l_correct, l_total


def calculate_stats(results, path):
    """Aggregate word-level and sentence-level UAS/LAS and write them next to the input file."""
    u_correct = 0
    l_correct = 0
    u_total = 0
    l_total = 0

    sent_uas = []
    sent_las = []

    for i in range(len(results)):
        gold_heads, gold_labels, pred_heads, pred_labels = zip(*results[i])
        u_c, u_t, l_c, l_t = calculate_las_uas(gold_heads, gold_labels, pred_heads, pred_labels)
        if u_t > 0:
            uas = float(u_c) / u_t
            las = float(l_c) / l_t
            sent_uas.append(uas)
            sent_las.append(las)
            u_correct += u_c
            l_correct += l_c
            u_total += u_t
            l_total += l_t

    UAS = float(u_correct) / u_total
    LAS = float(l_correct) / l_total
    path = path.replace('combined_1300_test.txt', 'Macro-UAS-LAS-score.txt')
    f = open(path, 'w')
    f.write('Word level UAS : ' + str(UAS) + '\n')
    f.write('Word level LAS : ' + str(LAS) + '\n')
    f.write('Sentence level UAS : ' + str(float(sum(sent_uas)) / len(sent_uas)) + '\n')
    f.write('Sentence level LAS : ' + str(float(sum(sent_las)) / len(sent_las)) + '\n')
    f.close()
    print('Word level UAS : ' + str(UAS))
    print('Word level LAS : ' + str(LAS))
    print('Sentence level UAS : ' + str(float(sum(sent_uas)) / len(sent_uas)))
    print('Sentence level LAS : ' + str(float(sum(sent_las)) / len(sent_las)))

    return sent_uas, sent_las, UAS, LAS


def write_results(sent_uas, sent_las, filename_uas, filename_las):
    """Dump per-sentence UAS/LAS scores, one value per line."""
    fp_uas = open(filename_uas, 'w')
    fp_las = open(filename_las, 'w')

    for i in range(len(sent_uas)):
        fp_uas.write(str(sent_uas[i]) + '\n')
        fp_las.write(str(sent_las[i]) + '\n')

    fp_uas.close()
    fp_las.close()


if __name__ == "__main__":
    dirs = sys.argv[1]
    # results_2 = load_results(sys.argv[2])
    ##path = "Predictions/Yap/"+dirs
    # path = "/home/jivnesh/Documents/San-SOTA/saved_models/"+dirs+"/final_ensembled/dev_combined_1000.txt"
    path = "./saved_models/" + dirs + "/combined_1300_test.txt"
    result = load_results(path)

    sent_uas1, sent_las1, UAS1, LAS1 = calculate_stats(result, path)
    # sent_uas2, sent_las2, UAS2, LAS2 = calculate_stats(results_2)

    write_results(sent_uas1, sent_las1, 'results1_uas.txt', 'results1_las.txt')
    # write_results(sent_uas2, sent_las2, 'results2_uas.txt', 'results2_las.txt')
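The word-level and sentence-level numbers written by this script are micro and macro averages of the same per-token counts. The toy example below uses made-up counts (not results from this repository) purely to show how the two averages can differ:

```python
# Micro (word-level) vs. macro (sentence-level) UAS on two hypothetical sentences:
# each tuple is (correct heads, tokens) for one sentence.
sent_counts = [(9, 10), (1, 2)]

micro_uas = sum(c for c, _ in sent_counts) / sum(t for _, t in sent_counts)  # 10/12 = 0.833...
macro_uas = sum(c / t for c, t in sent_counts) / len(sent_counts)            # (0.9 + 0.5)/2 = 0.7

print(micro_uas, macro_uas)
```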
examples/BiAFF_write_1300_combined.py
ADDED
@@ -0,0 +1,48 @@
import sys


def write_combined(dirs):
    """Merge gold and predicted parses (test + prose splits) into one tab-separated file."""
    path = "./saved_models/" + dirs + '/'
    f = open(path + 'domain_san_test_model_domain_san_data_domain_san_gold.txt', 'r')
    gold = f.readlines()
    f.close()
    f = open(path + 'domain_san_test_model_domain_san_data_domain_san_pred.txt', 'r')
    pred = f.readlines()
    f.close()

    # Append the predicted head and label columns to each gold row of the test split.
    for i in range(len(gold)):
        if gold[i] == '\n':
            continue
        if gold[i].split('\t')[0] == pred[i].split('\t')[0]:
            gold[i] = gold[i].replace('\n', '\t')
            gold[i] = gold[i] + '\t'.join(pred[i].split('\t')[-2:])

    f = open(path + 'domain_san_prose_model_domain_san_data_domain_san_gold.txt', 'r')
    prose_gold = f.readlines()
    f.close()
    f = open(path + 'domain_san_prose_model_domain_san_data_domain_san_pred.txt', 'r')
    prose_pred = f.readlines()
    f.close()

    # Do the same for the prose split and append its rows to the combined list.
    for i in range(len(prose_gold)):
        if prose_gold[i] == '\n':
            gold.append('\n')
            continue
        if prose_gold[i].split('\t')[0] == prose_pred[i].split('\t')[0]:
            line = prose_gold[i].replace('\n', '\t')
            line = line + '\t'.join(prose_pred[i].split('\t')[-2:])
            gold.append(line)
    gold.insert(0, 'word_id\tword\tpostag\tlemma\tgold_head\tgold_label\tpred_head\tpred_label\n\n')

    f = open(path + 'combined_1300_test.txt', 'w')
    for line in gold:
        f.write(line)
    f.close()


if __name__ == "__main__":

    dir_path = sys.argv[1]

    write_combined(dir_path)
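The merged `combined_1300_test.txt` starts with the header row written by `write_combined`, followed by one tab-separated row per token and a blank line between sentences, with the predicted head and label appended as the last two columns. The snippet below builds a tiny file of that shape with hypothetical tokens, only to illustrate the layout that `BiAFF_macro_UAS_LAS.load_results` expects:

```python
# Hypothetical one-sentence file in the combined layout (illustrative values only).
header = "word_id\tword\tpostag\tlemma\tgold_head\tgold_label\tpred_head\tpred_label"
rows = [
    "1\tramah\tNOUN\trama\t2\tkarta\t2\tkarta",       # head and label both predicted correctly
    "2\tgacchati\tVERB\tgam\t0\troot\t0\troot",
]
with open("combined_toy.txt", "w") as f:
    f.write(header + "\n\n")            # write_combined also puts a blank line after the header
    f.write("\n".join(rows) + "\n\n")   # trailing blank line marks the sentence boundary
```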
examples/GraphParser.py
ADDED
@@ -0,0 +1,599 @@
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import sys
|
| 3 |
+
from os import path, makedirs
|
| 4 |
+
|
| 5 |
+
sys.path.append(".")
|
| 6 |
+
sys.path.append("..")
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
from copy import deepcopy
|
| 10 |
+
import json
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from collections import namedtuple
|
| 15 |
+
from utils.io_ import seeds, Writer, get_logger, prepare_data, rearrange_splits
|
| 16 |
+
from utils.models.parsing_gating import BiAffine_Parser_Gated
|
| 17 |
+
from utils import load_word_embeddings
|
| 18 |
+
from utils.tasks import parse
|
| 19 |
+
import time
|
| 20 |
+
from torch.nn.utils import clip_grad_norm_
|
| 21 |
+
from torch.optim import Adam, SGD
|
| 22 |
+
import uuid
|
| 23 |
+
|
| 24 |
+
uid = uuid.uuid4().hex[:6]
|
| 25 |
+
|
| 26 |
+
logger = get_logger('GraphParser')
|
| 27 |
+
|
| 28 |
+
def read_arguments():
|
| 29 |
+
args_ = argparse.ArgumentParser(description='Sovling GraphParser')
|
| 30 |
+
args_.add_argument('--dataset', choices=['ontonotes', 'ud'], help='Dataset', required=True)
|
| 31 |
+
args_.add_argument('--domain', help='domain/language', required=True)
|
| 32 |
+
args_.add_argument('--rnn_mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn',
|
| 33 |
+
required=True)
|
| 34 |
+
args_.add_argument('--gating',action='store_true', help='use gated mechanism')
|
| 35 |
+
args_.add_argument('--num_gates', type=int, default=0, help='number of gates for gating mechanism')
|
| 36 |
+
args_.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
|
| 37 |
+
args_.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
|
| 38 |
+
args_.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
|
| 39 |
+
args_.add_argument('--arc_space', type=int, default=128, help='Dimension of tag space')
|
| 40 |
+
args_.add_argument('--arc_tag_space', type=int, default=128, help='Dimension of tag space')
|
| 41 |
+
args_.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
|
| 42 |
+
args_.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
|
| 43 |
+
args_.add_argument('--kernel_size', type=int, default=3, help='Size of Kernel for CNN')
|
| 44 |
+
args_.add_argument('--use_pos', action='store_true', help='use part-of-speech embedding.')
|
| 45 |
+
args_.add_argument('--use_char', action='store_true', help='use character embedding and CNN.')
|
| 46 |
+
args_.add_argument('--word_dim', type=int, default=300, help='Dimension of word embeddings')
|
| 47 |
+
args_.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
|
| 48 |
+
args_.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
|
| 49 |
+
args_.add_argument('--initializer', choices=['xavier'], help='initialize model parameters')
|
| 50 |
+
args_.add_argument('--opt', choices=['adam', 'sgd'], help='optimization algorithm')
|
| 51 |
+
args_.add_argument('--momentum', type=float, default=0.9, help='momentum of optimizer')
|
| 52 |
+
args_.add_argument('--betas', nargs=2, type=float, default=[0.9, 0.9], help='betas of optimizer')
|
| 53 |
+
args_.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
|
| 54 |
+
args_.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
|
| 55 |
+
args_.add_argument('--schedule', type=int, help='schedule for learning rate decay')
|
| 56 |
+
args_.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
|
| 57 |
+
args_.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
|
| 58 |
+
args_.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for adam')
|
| 59 |
+
args_.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
|
| 60 |
+
args_.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
|
| 61 |
+
args_.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
|
| 62 |
+
args_.add_argument('--arc_decode', choices=['mst', 'greedy'], help='arc decoding algorithm', required=True)
|
| 63 |
+
args_.add_argument('--unk_replace', type=float, default=0.,
|
| 64 |
+
help='The rate to replace a singleton word with UNK')
|
| 65 |
+
args_.add_argument('--punct_set', nargs='+', type=str, help='List of punctuations')
|
| 66 |
+
args_.add_argument('--word_embedding', choices=['random', 'glove', 'fasttext', 'word2vec'],
|
| 67 |
+
help='Embedding for words')
|
| 68 |
+
args_.add_argument('--word_path', help='path for word embedding dict - in case word_embedding is not random')
|
| 69 |
+
args_.add_argument('--freeze_word_embeddings', action='store_true', help='frozen the word embedding (disable fine-tuning).')
|
| 70 |
+
args_.add_argument('--freeze_sequence_taggers', action='store_true', help='frozen the BiLSTMs of the pre-trained taggers.')
|
| 71 |
+
args_.add_argument('--char_embedding', choices=['random','hellwig'], help='Embedding for characters',
|
| 72 |
+
required=True)
|
| 73 |
+
args_.add_argument('--pos_embedding', choices=['random','one_hot'], help='Embedding for pos',
|
| 74 |
+
required=True)
|
| 75 |
+
args_.add_argument('--char_path', help='path for character embedding dict')
|
| 76 |
+
args_.add_argument('--pos_path', help='path for pos embedding dict')
|
| 77 |
+
args_.add_argument('--set_num_training_samples', type=int, help='downsampling training set to a fixed number of samples')
|
| 78 |
+
args_.add_argument('--model_path', help='path for saving model file.', required=True)
|
| 79 |
+
args_.add_argument('--load_path', help='path for loading saved source model file.', default=None)
|
| 80 |
+
args_.add_argument('--load_sequence_taggers_paths', nargs='+', help='path for loading saved sequence_tagger saved_models files.', default=None)
|
| 81 |
+
args_.add_argument('--strict',action='store_true', help='if True loaded model state should contin '
|
| 82 |
+
'exactly the same keys as current model')
|
| 83 |
+
args_.add_argument('--eval_mode', action='store_true', help='evaluating model without training it')
|
| 84 |
+
args = args_.parse_args()
|
| 85 |
+
args_dict = {}
|
| 86 |
+
args_dict['dataset'] = args.dataset
|
| 87 |
+
args_dict['domain'] = args.domain
|
| 88 |
+
args_dict['rnn_mode'] = args.rnn_mode
|
| 89 |
+
args_dict['gating'] = args.gating
|
| 90 |
+
args_dict['num_gates'] = args.num_gates
|
| 91 |
+
args_dict['arc_decode'] = args.arc_decode
|
| 92 |
+
# args_dict['splits'] = ['train', 'dev', 'test']
|
| 93 |
+
args_dict['splits'] = ['train', 'dev', 'test','poetry','prose']
|
| 94 |
+
args_dict['model_path'] = args.model_path
|
| 95 |
+
if not path.exists(args_dict['model_path']):
|
| 96 |
+
makedirs(args_dict['model_path'])
|
| 97 |
+
args_dict['data_paths'] = {}
|
| 98 |
+
if args_dict['dataset'] == 'ontonotes':
|
| 99 |
+
data_path = 'data/onto_pos_ner_dp'
|
| 100 |
+
else:
|
| 101 |
+
data_path = 'data/ud_pos_ner_dp'
|
| 102 |
+
for split in args_dict['splits']:
|
| 103 |
+
args_dict['data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain']
|
| 104 |
+
###################################
|
| 105 |
+
args_dict['data_paths']['poetry'] = 'data/ud_pos_ner_dp' + '_' + 'poetry' + '_' + args_dict['domain']
|
| 106 |
+
args_dict['data_paths']['prose'] = 'data/ud_pos_ner_dp' + '_' + 'prose' + '_' + args_dict['domain']
|
| 107 |
+
###################################
|
| 108 |
+
args_dict['alphabet_data_paths'] = {}
|
| 109 |
+
for split in args_dict['splits']:
|
| 110 |
+
if args_dict['dataset'] == 'ontonotes':
|
| 111 |
+
args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + 'all'
|
| 112 |
+
else:
|
| 113 |
+
if '_' in args_dict['domain']:
|
| 114 |
+
args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain'].split('_')[0]
|
| 115 |
+
else:
|
| 116 |
+
args_dict['alphabet_data_paths'][split] = args_dict['data_paths'][split]
|
| 117 |
+
args_dict['model_name'] = 'domain_' + args_dict['domain']
|
| 118 |
+
args_dict['full_model_name'] = path.join(args_dict['model_path'],args_dict['model_name'])
|
| 119 |
+
args_dict['load_path'] = args.load_path
|
| 120 |
+
args_dict['load_sequence_taggers_paths'] = args.load_sequence_taggers_paths
|
| 121 |
+
if args_dict['load_sequence_taggers_paths'] is not None:
|
| 122 |
+
args_dict['gating'] = True
|
| 123 |
+
args_dict['num_gates'] = len(args_dict['load_sequence_taggers_paths']) + 1
|
| 124 |
+
else:
|
| 125 |
+
if not args_dict['gating']:
|
| 126 |
+
args_dict['num_gates'] = 0
|
| 127 |
+
args_dict['strict'] = args.strict
|
| 128 |
+
args_dict['num_epochs'] = args.num_epochs
|
| 129 |
+
args_dict['batch_size'] = args.batch_size
|
| 130 |
+
args_dict['hidden_size'] = args.hidden_size
|
| 131 |
+
args_dict['arc_space'] = args.arc_space
|
| 132 |
+
args_dict['arc_tag_space'] = args.arc_tag_space
|
| 133 |
+
args_dict['num_layers'] = args.num_layers
|
| 134 |
+
args_dict['num_filters'] = args.num_filters
|
| 135 |
+
args_dict['kernel_size'] = args.kernel_size
|
| 136 |
+
args_dict['learning_rate'] = args.learning_rate
|
| 137 |
+
args_dict['initializer'] = nn.init.xavier_uniform_ if args.initializer == 'xavier' else None
|
| 138 |
+
args_dict['opt'] = args.opt
|
| 139 |
+
args_dict['momentum'] = args.momentum
|
| 140 |
+
args_dict['betas'] = tuple(args.betas)
|
| 141 |
+
args_dict['epsilon'] = args.epsilon
|
| 142 |
+
args_dict['decay_rate'] = args.decay_rate
|
| 143 |
+
args_dict['clip'] = args.clip
|
| 144 |
+
args_dict['gamma'] = args.gamma
|
| 145 |
+
args_dict['schedule'] = args.schedule
|
| 146 |
+
args_dict['p_rnn'] = tuple(args.p_rnn)
|
| 147 |
+
args_dict['p_in'] = args.p_in
|
| 148 |
+
args_dict['p_out'] = args.p_out
|
| 149 |
+
args_dict['unk_replace'] = args.unk_replace
|
| 150 |
+
args_dict['set_num_training_samples'] = args.set_num_training_samples
|
| 151 |
+
args_dict['punct_set'] = None
|
| 152 |
+
if args.punct_set is not None:
|
| 153 |
+
args_dict['punct_set'] = set(args.punct_set)
|
| 154 |
+
logger.info("punctuations(%d): %s" % (len(args_dict['punct_set']), ' '.join(args_dict['punct_set'])))
|
| 155 |
+
args_dict['freeze_word_embeddings'] = args.freeze_word_embeddings
|
| 156 |
+
args_dict['freeze_sequence_taggers'] = args.freeze_sequence_taggers
|
| 157 |
+
args_dict['word_embedding'] = args.word_embedding
|
| 158 |
+
args_dict['word_path'] = args.word_path
|
| 159 |
+
args_dict['use_char'] = args.use_char
|
| 160 |
+
args_dict['char_embedding'] = args.char_embedding
|
| 161 |
+
args_dict['char_path'] = args.char_path
|
| 162 |
+
args_dict['pos_embedding'] = args.pos_embedding
|
| 163 |
+
args_dict['pos_path'] = args.pos_path
|
| 164 |
+
args_dict['use_pos'] = args.use_pos
|
| 165 |
+
args_dict['pos_dim'] = args.pos_dim
|
| 166 |
+
args_dict['word_dict'] = None
|
| 167 |
+
args_dict['word_dim'] = args.word_dim
|
| 168 |
+
if args_dict['word_embedding'] != 'random' and args_dict['word_path']:
|
| 169 |
+
args_dict['word_dict'], args_dict['word_dim'] = load_word_embeddings.load_embedding_dict(args_dict['word_embedding'],
|
| 170 |
+
args_dict['word_path'])
|
| 171 |
+
args_dict['char_dict'] = None
|
| 172 |
+
args_dict['char_dim'] = args.char_dim
|
| 173 |
+
if args_dict['char_embedding'] != 'random':
|
| 174 |
+
args_dict['char_dict'], args_dict['char_dim'] = load_word_embeddings.load_embedding_dict(args_dict['char_embedding'],
|
| 175 |
+
args_dict['char_path'])
|
| 176 |
+
args_dict['pos_dict'] = None
|
| 177 |
+
if args_dict['pos_embedding'] != 'random':
|
| 178 |
+
args_dict['pos_dict'], args_dict['pos_dim'] = load_word_embeddings.load_embedding_dict(args_dict['pos_embedding'],
|
| 179 |
+
args_dict['pos_path'])
|
| 180 |
+
args_dict['alphabet_path'] = path.join(args_dict['model_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/')
|
| 181 |
+
args_dict['model_name'] = path.join(args_dict['model_path'], args_dict['model_name'])
|
| 182 |
+
args_dict['eval_mode'] = args.eval_mode
|
| 183 |
+
args_dict['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 184 |
+
args_dict['word_status'] = 'frozen' if args.freeze_word_embeddings else 'fine tune'
|
| 185 |
+
args_dict['char_status'] = 'enabled' if args.use_char else 'disabled'
|
| 186 |
+
args_dict['pos_status'] = 'enabled' if args.use_pos else 'disabled'
|
| 187 |
+
logger.info("Saving arguments to file")
|
| 188 |
+
save_args(args, args_dict['full_model_name'])
|
| 189 |
+
logger.info("Creating Alphabets")
|
| 190 |
+
alphabet_dict = creating_alphabets(args_dict['alphabet_path'], args_dict['alphabet_data_paths'], args_dict['word_dict'])
|
| 191 |
+
args_dict = {**args_dict, **alphabet_dict}
|
| 192 |
+
ARGS = namedtuple('ARGS', args_dict.keys())
|
| 193 |
+
my_args = ARGS(**args_dict)
|
| 194 |
+
return my_args
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def creating_alphabets(alphabet_path, alphabet_data_paths, word_dict):
|
| 198 |
+
train_paths = alphabet_data_paths['train']
|
| 199 |
+
extra_paths = [v for k,v in alphabet_data_paths.items() if k != 'train']
|
| 200 |
+
alphabet_dict = {}
|
| 201 |
+
alphabet_dict['alphabets'] = prepare_data.create_alphabets(alphabet_path,
|
| 202 |
+
train_paths,
|
| 203 |
+
extra_paths=extra_paths,
|
| 204 |
+
max_vocabulary_size=100000,
|
| 205 |
+
embedd_dict=word_dict)
|
| 206 |
+
for k, v in alphabet_dict['alphabets'].items():
|
| 207 |
+
num_key = 'num_' + k.split('_')[0]
|
| 208 |
+
alphabet_dict[num_key] = v.size()
|
| 209 |
+
logger.info("%s : %d" % (num_key, alphabet_dict[num_key]))
|
| 210 |
+
return alphabet_dict
|
| 211 |
+
|
| 212 |
+
def construct_embedding_table(alphabet, tokens_dict, dim, token_type='word'):
|
| 213 |
+
if tokens_dict is None:
|
| 214 |
+
return None
|
| 215 |
+
scale = np.sqrt(3.0 / dim)
|
| 216 |
+
table = np.empty([alphabet.size(), dim], dtype=np.float32)
|
| 217 |
+
table[prepare_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
|
| 218 |
+
oov_tokens = 0
|
| 219 |
+
for token, index in alphabet.items():
|
| 220 |
+
if token in tokens_dict:
|
| 221 |
+
embedding = tokens_dict[token]
|
| 222 |
+
elif token.lower() in tokens_dict:
|
| 223 |
+
embedding = tokens_dict[token.lower()]
|
| 224 |
+
else:
|
| 225 |
+
embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
|
| 226 |
+
oov_tokens += 1
|
| 227 |
+
table[index, :] = embedding
|
| 228 |
+
print('token type : %s, number of oov: %d' % (token_type, oov_tokens))
|
| 229 |
+
table = torch.from_numpy(table)
|
| 230 |
+
return table
|
| 231 |
+
|
| 232 |
+
def save_args(args, full_model_name):
|
| 233 |
+
arg_path = full_model_name + '.arg.json'
|
| 234 |
+
argparse_dict = vars(args)
|
| 235 |
+
with open(arg_path, 'w') as f:
|
| 236 |
+
json.dump(argparse_dict, f)
|
| 237 |
+
|
| 238 |
+
def generate_optimizer(args, lr, params):
|
| 239 |
+
params = filter(lambda param: param.requires_grad, params)
|
| 240 |
+
if args.opt == 'adam':
|
| 241 |
+
return Adam(params, lr=lr, betas=args.betas, weight_decay=args.gamma, eps=args.epsilon)
|
| 242 |
+
elif args.opt == 'sgd':
|
| 243 |
+
return SGD(params, lr=lr, momentum=args.momentum, weight_decay=args.gamma, nesterov=True)
|
| 244 |
+
else:
|
| 245 |
+
raise ValueError('Unknown optimization algorithm: %s' % args.opt)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def save_checkpoint(model, optimizer, opt, dev_eval_dict, test_eval_dict, full_model_name):
|
| 249 |
+
path_name = full_model_name + '.pt'
|
| 250 |
+
print('Saving model to: %s' % path_name)
|
| 251 |
+
state = {'model_state_dict': model.state_dict(),
|
| 252 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 253 |
+
'opt': opt,
|
| 254 |
+
'dev_eval_dict': dev_eval_dict,
|
| 255 |
+
'test_eval_dict': test_eval_dict}
|
| 256 |
+
torch.save(state, path_name)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def load_checkpoint(args, model, optimizer, dev_eval_dict, test_eval_dict, start_epoch, load_path, strict=True):
|
| 260 |
+
print('Loading saved model from: %s' % load_path)
|
| 261 |
+
checkpoint = torch.load(load_path, map_location=args.device)
|
| 262 |
+
if checkpoint['opt'] != args.opt:
|
| 263 |
+
raise ValueError('loaded optimizer type is: %s instead of: %s' % (checkpoint['opt'], args.opt))
|
| 264 |
+
model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
|
| 265 |
+
|
| 266 |
+
if strict:
|
| 267 |
+
generate_optimizer(args, args.learning_rate, model.parameters())
|
| 268 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 269 |
+
for state in optimizer.state.values():
|
| 270 |
+
for k, v in state.items():
|
| 271 |
+
if isinstance(v, torch.Tensor):
|
| 272 |
+
state[k] = v.to(args.device)
|
| 273 |
+
dev_eval_dict = checkpoint['dev_eval_dict']
|
| 274 |
+
test_eval_dict = checkpoint['test_eval_dict']
|
| 275 |
+
start_epoch = dev_eval_dict['in_domain']['epoch']
|
| 276 |
+
return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def build_model_and_optimizer(args):
|
| 280 |
+
word_table = construct_embedding_table(args.alphabets['word_alphabet'], args.word_dict, args.word_dim, token_type='word')
|
| 281 |
+
char_table = construct_embedding_table(args.alphabets['char_alphabet'], args.char_dict, args.char_dim, token_type='char')
|
| 282 |
+
pos_table = construct_embedding_table(args.alphabets['pos_alphabet'], args.pos_dict, args.pos_dim, token_type='pos')
|
| 283 |
+
model = BiAffine_Parser_Gated(args.word_dim, args.num_word, args.char_dim, args.num_char,
|
| 284 |
+
args.use_pos, args.use_char, args.pos_dim, args.num_pos,
|
| 285 |
+
args.num_filters, args.kernel_size, args.rnn_mode,
|
| 286 |
+
args.hidden_size, args.num_layers, args.num_arc,
|
| 287 |
+
args.arc_space, args.arc_tag_space, args.num_gates,
|
| 288 |
+
embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table,
|
| 289 |
+
p_in=args.p_in, p_out=args.p_out, p_rnn=args.p_rnn,
|
| 290 |
+
biaffine=True, arc_decode=args.arc_decode, initializer=args.initializer)
|
| 291 |
+
print(model)
|
| 292 |
+
optimizer = generate_optimizer(args, args.learning_rate, model.parameters())
|
| 293 |
+
start_epoch = 0
|
| 294 |
+
dev_eval_dict = {'in_domain': initialize_eval_dict()}
|
| 295 |
+
test_eval_dict = {'in_domain': initialize_eval_dict()}
|
| 296 |
+
if args.load_path:
|
| 297 |
+
model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = \
|
| 298 |
+
load_checkpoint(args, model, optimizer,
|
| 299 |
+
dev_eval_dict, test_eval_dict,
|
| 300 |
+
start_epoch, args.load_path, strict=args.strict)
|
| 301 |
+
if args.load_sequence_taggers_paths:
|
| 302 |
+
pretrained_dict = {}
|
| 303 |
+
model_dict = model.state_dict()
|
| 304 |
+
for idx, path in enumerate(args.load_sequence_taggers_paths):
|
| 305 |
+
print('Loading saved sequence_tagger from: %s' % path)
|
| 306 |
+
checkpoint = torch.load(path, map_location=args.device)
|
| 307 |
+
for k, v in checkpoint['model_state_dict'].items():
|
| 308 |
+
if 'rnn_encoder.' in k:
|
| 309 |
+
pretrained_dict['extra_rnn_encoders.' + str(idx) + '.' + k.replace('rnn_encoder.', '')] = v
|
| 310 |
+
model_dict.update(pretrained_dict)
|
| 311 |
+
model.load_state_dict(model_dict)
|
| 312 |
+
if args.freeze_sequence_taggers:
|
| 313 |
+
print('Freezing Classifiers')
|
| 314 |
+
for name, parameter in model.named_parameters():
|
| 315 |
+
if 'extra_rnn_encoders' in name:
|
| 316 |
+
parameter.requires_grad = False
|
| 317 |
+
if args.freeze_word_embeddings:
|
| 318 |
+
model.rnn_encoder.word_embedd.weight.requires_grad = False
|
| 319 |
+
# model.rnn_encoder.char_embedd.weight.requires_grad = False
|
| 320 |
+
# model.rnn_encoder.pos_embedd.weight.requires_grad = False
|
| 321 |
+
device = args.device
|
| 322 |
+
model.to(device)
|
| 323 |
+
return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def initialize_eval_dict():
|
| 327 |
+
eval_dict = {}
|
| 328 |
+
eval_dict['dp_uas'] = 0.0
|
| 329 |
+
eval_dict['dp_las'] = 0.0
|
| 330 |
+
eval_dict['epoch'] = 0
|
| 331 |
+
eval_dict['dp_ucorrect'] = 0.0
|
| 332 |
+
eval_dict['dp_lcorrect'] = 0.0
|
| 333 |
+
eval_dict['dp_total'] = 0.0
|
| 334 |
+
eval_dict['dp_ucomplete_match'] = 0.0
|
| 335 |
+
eval_dict['dp_lcomplete_match'] = 0.0
|
| 336 |
+
eval_dict['dp_ucorrect_nopunc'] = 0.0
|
| 337 |
+
eval_dict['dp_lcorrect_nopunc'] = 0.0
|
| 338 |
+
eval_dict['dp_total_nopunc'] = 0.0
|
| 339 |
+
eval_dict['dp_ucomplete_match_nopunc'] = 0.0
|
| 340 |
+
eval_dict['dp_lcomplete_match_nopunc'] = 0.0
|
| 341 |
+
eval_dict['dp_root_correct'] = 0.0
|
| 342 |
+
eval_dict['dp_total_root'] = 0.0
|
| 343 |
+
eval_dict['dp_total_inst'] = 0.0
|
| 344 |
+
eval_dict['dp_total'] = 0.0
|
| 345 |
+
eval_dict['dp_total_inst'] = 0.0
|
| 346 |
+
eval_dict['dp_total_nopunc'] = 0.0
|
| 347 |
+
eval_dict['dp_total_root'] = 0.0
|
| 348 |
+
return eval_dict
|
| 349 |
+
|
| 350 |
+
def in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch,
|
| 351 |
+
best_model, best_optimizer, patient):
|
| 352 |
+
# In-domain evaluation
|
| 353 |
+
curr_dev_eval_dict = evaluation(args, datasets['dev'], 'dev', model, args.domain, epoch, 'current_results')
|
| 354 |
+
is_best_in_domain = dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] <= curr_dev_eval_dict['dp_lcorrect_nopunc'] or \
|
| 355 |
+
(dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] == curr_dev_eval_dict['dp_lcorrect_nopunc'] and
|
| 356 |
+
dev_eval_dict['in_domain']['dp_ucorrect_nopunc'] <= curr_dev_eval_dict['dp_ucorrect_nopunc'])
|
| 357 |
+
|
| 358 |
+
if is_best_in_domain:
|
| 359 |
+
for key, value in curr_dev_eval_dict.items():
|
| 360 |
+
dev_eval_dict['in_domain'][key] = value
|
| 361 |
+
curr_test_eval_dict = evaluation(args, datasets['test'], 'test', model, args.domain, epoch, 'current_results')
|
| 362 |
+
for key, value in curr_test_eval_dict.items():
|
| 363 |
+
test_eval_dict['in_domain'][key] = value
|
| 364 |
+
best_model = deepcopy(model)
|
| 365 |
+
best_optimizer = deepcopy(optimizer)
|
| 366 |
+
patient = 0
|
| 367 |
+
else:
|
| 368 |
+
patient += 1
|
| 369 |
+
if epoch == args.num_epochs:
|
| 370 |
+
# save in-domain checkpoint
|
| 371 |
+
if args.set_num_training_samples is not None:
|
| 372 |
+
splits_to_write = datasets.keys()
|
| 373 |
+
else:
|
| 374 |
+
splits_to_write = ['dev', 'test']
|
| 375 |
+
for split in splits_to_write:
|
| 376 |
+
if split == 'dev':
|
| 377 |
+
eval_dict = dev_eval_dict['in_domain']
|
| 378 |
+
elif split == 'test':
|
| 379 |
+
eval_dict = test_eval_dict['in_domain']
|
| 380 |
+
else:
|
| 381 |
+
eval_dict = None
|
| 382 |
+
write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)
|
| 383 |
+
print("Saving best model")
|
| 384 |
+
save_checkpoint(best_model, best_optimizer, args.opt, dev_eval_dict, test_eval_dict, args.full_model_name)
|
| 385 |
+
|
| 386 |
+
print('\n')
|
| 387 |
+
return dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def evaluation(args, data, split, model, domain, epoch, str_res='results'):
|
| 391 |
+
# evaluate performance on data
|
| 392 |
+
model.eval()
|
| 393 |
+
|
| 394 |
+
eval_dict = initialize_eval_dict()
|
| 395 |
+
eval_dict['epoch'] = epoch
|
| 396 |
+
for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
|
| 397 |
+
word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
|
| 398 |
+
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
|
| 399 |
+
heads_pred, arc_tags_pred, _ = model.decode(out_arc, out_arc_tag, mask=masks, length=lengths,
|
| 400 |
+
leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
|
| 401 |
+
lengths = lengths.cpu().numpy()
|
| 402 |
+
word = word.data.cpu().numpy()
|
| 403 |
+
pos = pos.data.cpu().numpy()
|
| 404 |
+
ner = ner.data.cpu().numpy()
|
| 405 |
+
heads = heads.data.cpu().numpy()
|
| 406 |
+
arc_tags = arc_tags.data.cpu().numpy()
|
| 407 |
+
heads_pred = heads_pred.data.cpu().numpy()
|
| 408 |
+
arc_tags_pred = arc_tags_pred.data.cpu().numpy()
|
| 409 |
+
stats, stats_nopunc, stats_root, num_inst = parse.eval_(word, pos, heads_pred, arc_tags_pred, heads,
|
| 410 |
+
arc_tags, args.alphabets['word_alphabet'], args.alphabets['pos_alphabet'],
|
| 411 |
+
lengths, punct_set=args.punct_set, symbolic_root=True)
|
| 412 |
+
ucorr, lcorr, total, ucm, lcm = stats
|
| 413 |
+
ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
|
| 414 |
+
corr_root, total_root = stats_root
|
| 415 |
+
eval_dict['dp_ucorrect'] += ucorr
|
| 416 |
+
eval_dict['dp_lcorrect'] += lcorr
|
| 417 |
+
eval_dict['dp_total'] += total
|
| 418 |
+
eval_dict['dp_ucomplete_match'] += ucm
|
| 419 |
+
eval_dict['dp_lcomplete_match'] += lcm
|
| 420 |
+
eval_dict['dp_ucorrect_nopunc'] += ucorr_nopunc
|
| 421 |
+
eval_dict['dp_lcorrect_nopunc'] += lcorr_nopunc
|
| 422 |
+
eval_dict['dp_total_nopunc'] += total_nopunc
|
| 423 |
+
eval_dict['dp_ucomplete_match_nopunc'] += ucm_nopunc
|
| 424 |
+
eval_dict['dp_lcomplete_match_nopunc'] += lcm_nopunc
|
| 425 |
+
eval_dict['dp_root_correct'] += corr_root
|
| 426 |
+
eval_dict['dp_total_root'] += total_root
|
| 427 |
+
eval_dict['dp_total_inst'] += num_inst
|
| 428 |
+
|
| 429 |
+
eval_dict['dp_uas'] = eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
|
| 430 |
+
eval_dict['dp_las'] = eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
|
| 431 |
+
print_results(eval_dict, split, domain, str_res)
|
| 432 |
+
return eval_dict
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def print_results(eval_dict, split, domain, str_res='results'):
|
| 436 |
+
print('----------------------------------------------------------------------------------------------------------------------------')
|
| 437 |
+
print('Testing model on domain %s' % domain)
|
| 438 |
+
print('--------------- Dependency Parsing - %s ---------------' % split)
|
| 439 |
+
print(
|
| 440 |
+
str_res + ' on ' + split + ' W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
|
| 441 |
+
eval_dict['dp_ucorrect'], eval_dict['dp_lcorrect'], eval_dict['dp_total'],
|
| 442 |
+
eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'],
|
| 443 |
+
eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'],
|
| 444 |
+
eval_dict['dp_ucomplete_match'] * 100 / eval_dict['dp_total_inst'],
|
| 445 |
+
eval_dict['dp_lcomplete_match'] * 100 / eval_dict['dp_total_inst'],
|
| 446 |
+
eval_dict['epoch']))
|
| 447 |
+
print(
|
| 448 |
+
str_res + ' on ' + split + ' Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
|
| 449 |
+
eval_dict['dp_ucorrect_nopunc'], eval_dict['dp_lcorrect_nopunc'], eval_dict['dp_total_nopunc'],
|
| 450 |
+
eval_dict['dp_ucorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
|
| 451 |
+
eval_dict['dp_lcorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
|
| 452 |
+
eval_dict['dp_ucomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
|
| 453 |
+
eval_dict['dp_lcomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
|
| 454 |
+
eval_dict['epoch']))
|
| 455 |
+
print(str_res + ' on ' + split + ' Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
|
| 456 |
+
eval_dict['dp_root_correct'], eval_dict['dp_total_root'],
|
| 457 |
+
eval_dict['dp_root_correct'] * 100 / eval_dict['dp_total_root'], eval_dict['epoch']))
|
| 458 |
+
print('\n')
|
| 459 |
+
|
| 460 |
+
def write_results(args, data, data_domain, split, model, model_domain, eval_dict):
|
| 461 |
+
str_file = args.full_model_name + '_' + split + '_model_domain_' + model_domain + '_data_domain_' + data_domain
|
| 462 |
+
res_filename = str_file + '_res.txt'
|
| 463 |
+
pred_filename = str_file + '_pred.txt'
|
| 464 |
+
gold_filename = str_file + '_gold.txt'
|
| 465 |
+
if eval_dict is not None:
|
| 466 |
+
# save results dictionary into a file
|
| 467 |
+
with open(res_filename, 'w') as f:
|
| 468 |
+
json.dump(eval_dict, f)
|
| 469 |
+
|
| 470 |
+
# save predictions and gold labels into files
|
| 471 |
+
pred_writer = Writer(args.alphabets)
|
| 472 |
+
gold_writer = Writer(args.alphabets)
|
| 473 |
+
pred_writer.start(pred_filename)
|
| 474 |
+
gold_writer.start(gold_filename)
|
| 475 |
+
for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
|
| 476 |
+
word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
|
| 477 |
+
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
|
| 478 |
+
heads_pred, arc_tags_pred, _ = model.decode(out_arc, out_arc_tag, mask=masks, length=lengths,
|
| 479 |
+
leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
|
| 480 |
+
lengths = lengths.cpu().numpy()
|
| 481 |
+
word = word.data.cpu().numpy()
|
| 482 |
+
pos = pos.data.cpu().numpy()
|
| 483 |
+
ner = ner.data.cpu().numpy()
|
| 484 |
+
heads = heads.data.cpu().numpy()
|
| 485 |
+
arc_tags = arc_tags.data.cpu().numpy()
|
| 486 |
+
heads_pred = heads_pred.data.cpu().numpy()
|
| 487 |
+
arc_tags_pred = arc_tags_pred.data.cpu().numpy()
|
| 488 |
+
# writing predictions
|
| 489 |
+
pred_writer.write(word, pos, ner, heads_pred, arc_tags_pred, lengths, symbolic_root=True)
|
| 490 |
+
# writing gold labels
|
| 491 |
+
gold_writer.write(word, pos, ner, heads, arc_tags, lengths, symbolic_root=True)
|
| 492 |
+
|
| 493 |
+
pred_writer.close()
|
| 494 |
+
gold_writer.close()
|
| 495 |
+
|
| 496 |
+
def main():
|
| 497 |
+
logger.info("Reading and creating arguments")
|
| 498 |
+
args = read_arguments()
|
| 499 |
+
logger.info("Reading Data")
|
| 500 |
+
datasets = {}
|
| 501 |
+
for split in args.splits:
|
| 502 |
+
print("Splits are:",split)
|
| 503 |
+
dataset = prepare_data.read_data_to_variable(args.data_paths[split], args.alphabets, args.device,
|
| 504 |
+
symbolic_root=True)
|
| 505 |
+
datasets[split] = dataset
|
| 506 |
+
if args.set_num_training_samples is not None:
|
| 507 |
+
print('Setting train and dev to %d samples' % args.set_num_training_samples)
|
| 508 |
+
datasets = rearrange_splits.rearranging_splits(datasets, args.set_num_training_samples)
|
| 509 |
+
logger.info("Creating Networks")
|
| 510 |
+
num_data = sum(datasets['train'][1])
|
| 511 |
+
model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = build_model_and_optimizer(args)
|
| 512 |
+
best_model = deepcopy(model)
|
| 513 |
+
best_optimizer = deepcopy(optimizer)
|
| 514 |
+
|
| 515 |
+
logger.info('Training INFO of in domain %s' % args.domain)
|
| 516 |
+
logger.info('Training on Dependecy Parsing')
|
| 517 |
+
logger.info("train: gamma: %f, batch: %d, clip: %.2f, unk replace: %.2f" % (args.gamma, args.batch_size, args.clip, args.unk_replace))
|
| 518 |
+
logger.info('number of training samples for %s is: %d' % (args.domain, num_data))
|
| 519 |
+
logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (args.p_in, args.p_out, args.p_rnn))
|
| 520 |
+
logger.info("num_epochs: %d" % (args.num_epochs))
|
| 521 |
+
print('\n')
|
| 522 |
+
|
| 523 |
+
if not args.eval_mode:
|
| 524 |
+
logger.info("Training")
|
| 525 |
+
num_batches = prepare_data.calc_num_batches(datasets['train'], args.batch_size)
|
| 526 |
+
lr = args.learning_rate
|
| 527 |
+
patient = 0
|
| 528 |
+
decay = 0
|
| 529 |
+
for epoch in range(start_epoch + 1, args.num_epochs + 1):
|
| 530 |
+
print('Epoch %d (Training: rnn mode: %s, optimizer: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, decay=%d)): ' % (
|
| 531 |
+
epoch, args.rnn_mode, args.opt, lr, args.epsilon, args.decay_rate, args.schedule, decay))
|
| 532 |
+
model.train()
|
| 533 |
+
total_loss = 0.0
|
| 534 |
+
total_arc_loss = 0.0
|
| 535 |
+
total_arc_tag_loss = 0.0
|
| 536 |
+
total_train_inst = 0.0
|
| 537 |
+
|
| 538 |
+
train_iter = prepare_data.iterate_batch_rand_bucket_choosing(
|
| 539 |
+
datasets['train'], args.batch_size, args.device, unk_replace=args.unk_replace)
|
| 540 |
+
start_time = time.time()
|
| 541 |
+
batch_num = 0
|
| 542 |
+
for batch_num, batch in enumerate(train_iter):
|
| 543 |
+
batch_num = batch_num + 1
|
| 544 |
+
optimizer.zero_grad()
|
| 545 |
+
# compute loss of main task
|
| 546 |
+
word, char, pos, ner_tags, heads, arc_tags, auto_label, masks, lengths = batch
|
| 547 |
+
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
|
| 548 |
+
loss_arc, loss_arc_tag = model.loss(out_arc, out_arc_tag, heads, arc_tags, mask=masks, length=lengths)
|
| 549 |
+
loss = loss_arc + loss_arc_tag
|
| 550 |
+
|
| 551 |
+
# update losses
|
| 552 |
+
num_insts = masks.data.sum() - word.size(0)
|
| 553 |
+
total_arc_loss += loss_arc.item() * num_insts
|
| 554 |
+
total_arc_tag_loss += loss_arc_tag.item() * num_insts
|
| 555 |
+
total_loss += loss.item() * num_insts
|
| 556 |
+
total_train_inst += num_insts
|
| 557 |
+
# optimize parameters
|
| 558 |
+
loss.backward()
|
| 559 |
+
clip_grad_norm_(model.parameters(), args.clip)
|
| 560 |
+
optimizer.step()
|
| 561 |
+
|
| 562 |
+
time_ave = (time.time() - start_time) / batch_num
|
| 563 |
+
time_left = (num_batches - batch_num) * time_ave
|
| 564 |
+
|
| 565 |
+
# update log
|
| 566 |
+
if batch_num % 50 == 0:
|
| 567 |
+
log_info = 'train: %d/%d, domain: %s, total loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, time left: %.2fs' % \
|
| 568 |
+
(batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
|
| 569 |
+
total_arc_tag_loss / total_train_inst, time_left)
|
| 570 |
+
sys.stdout.write(log_info)
|
| 571 |
+
sys.stdout.write('\n')
|
| 572 |
+
sys.stdout.flush()
|
| 573 |
+
print('\n')
|
| 574 |
+
print('train: %d/%d, domain: %s, total_loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, time: %.2fs' %
|
| 575 |
+
(batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
|
| 576 |
+
total_arc_tag_loss / total_train_inst, time.time() - start_time))
|
| 577 |
+
|
| 578 |
+
dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient = in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, best_model, best_optimizer, patient)
|
| 579 |
+
if patient >= args.schedule:
|
| 580 |
+
lr = args.learning_rate / (1.0 + epoch * args.decay_rate)
|
| 581 |
+
optimizer = generate_optimizer(args, lr, model.parameters())
|
| 582 |
+
print('updated learning rate to %.6f' % lr)
|
| 583 |
+
patient = 0
|
| 584 |
+
print_results(test_eval_dict['in_domain'], 'test', args.domain, 'best_results')
|
| 585 |
+
print('\n')
|
| 586 |
+
for split in datasets.keys():
|
| 587 |
+
eval_dict = evaluation(args, datasets[split], split, best_model, args.domain, epoch, 'best_results')
|
| 588 |
+
write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
|
| 589 |
+
|
| 590 |
+
else:
|
| 591 |
+
logger.info("Evaluating")
|
| 592 |
+
epoch = start_epoch
|
| 593 |
+
for split in ['train', 'dev', 'test','poetry','prose']:
|
| 594 |
+
eval_dict = evaluation(args, datasets[split], split, model, args.domain, epoch, 'best_results')
|
| 595 |
+
write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
if __name__ == '__main__':
|
| 599 |
+
main()
|
examples/GraphParser_MTL_POS.py
ADDED
@@ -0,0 +1,633 @@
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import sys
|
| 3 |
+
from os import path, makedirs
|
| 4 |
+
|
| 5 |
+
sys.path.append(".")
|
| 6 |
+
sys.path.append("..")
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
from copy import deepcopy
|
| 10 |
+
import json
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from collections import namedtuple
|
| 15 |
+
from utils.io_ import seeds, Writer, get_logger, prepare_data, rearrange_splits
|
| 16 |
+
from utils.models.parsing_gating_mtl_pos import BiAffine_Parser_Gated
|
| 17 |
+
# from utils.models.sequence_tagger import Sequence_Tagger
|
| 18 |
+
from utils import load_word_embeddings
|
| 19 |
+
from utils.tasks import parse
|
| 20 |
+
import time
|
| 21 |
+
from torch.nn.utils import clip_grad_norm_
|
| 22 |
+
from torch.optim import Adam, SGD
|
| 23 |
+
import uuid
|
| 24 |
+
|
| 25 |
+
uid = uuid.uuid4().hex[:6]
|
| 26 |
+
|
| 27 |
+
logger = get_logger('GraphParser')
|
| 28 |
+
|
| 29 |
+
def read_arguments():
|
| 30 |
+
args_ = argparse.ArgumentParser(description='Solving GraphParser')
|
| 31 |
+
args_.add_argument('--dataset', choices=['ontonotes', 'ud'], help='Dataset', required=True)
|
| 32 |
+
args_.add_argument('--domain', help='domain/language', required=True)
|
| 33 |
+
args_.add_argument('--rnn_mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn',
|
| 34 |
+
required=True)
|
| 35 |
+
args_.add_argument('--gating', action='store_true', help='use the gating mechanism')
|
| 36 |
+
args_.add_argument('--num_gates', type=int, default=0, help='number of gates for gating mechanism')
|
| 37 |
+
args_.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
|
| 38 |
+
args_.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
|
| 39 |
+
args_.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
|
| 40 |
+
args_.add_argument('--arc_space', type=int, default=128, help='Dimension of tag space')
|
| 41 |
+
args_.add_argument('--arc_tag_space', type=int, default=128, help='Dimension of tag space')
|
| 42 |
+
args_.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
|
| 43 |
+
args_.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
|
| 44 |
+
args_.add_argument('--kernel_size', type=int, default=3, help='Size of Kernel for CNN')
|
| 45 |
+
args_.add_argument('--use_pos', action='store_true', help='use part-of-speech embedding.')
|
| 46 |
+
args_.add_argument('--use_char', action='store_true', help='use character embedding and CNN.')
|
| 47 |
+
args_.add_argument('--word_dim', type=int, default=300, help='Dimension of word embeddings')
|
| 48 |
+
args_.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
|
| 49 |
+
args_.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
|
| 50 |
+
args_.add_argument('--initializer', choices=['xavier'], help='initialize model parameters')
|
| 51 |
+
args_.add_argument('--opt', choices=['adam', 'sgd'], help='optimization algorithm')
|
| 52 |
+
args_.add_argument('--momentum', type=float, default=0.9, help='momentum of optimizer')
|
| 53 |
+
args_.add_argument('--betas', nargs=2, type=float, default=[0.9, 0.9], help='betas of optimizer')
|
| 54 |
+
args_.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
|
| 55 |
+
args_.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
|
| 56 |
+
args_.add_argument('--schedule', type=int, help='schedule for learning rate decay')
|
| 57 |
+
args_.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
|
| 58 |
+
args_.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
|
| 59 |
+
args_.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for adam')
|
| 60 |
+
args_.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
|
| 61 |
+
args_.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
|
| 62 |
+
args_.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
|
| 63 |
+
args_.add_argument('--arc_decode', choices=['mst', 'greedy'], help='arc decoding algorithm', required=True)
|
| 64 |
+
args_.add_argument('--unk_replace', type=float, default=0.,
|
| 65 |
+
help='The rate to replace a singleton word with UNK')
|
| 66 |
+
args_.add_argument('--punct_set', nargs='+', type=str, help='List of punctuations')
|
| 67 |
+
args_.add_argument('--word_embedding', choices=['random', 'glove', 'fasttext', 'word2vec'],
|
| 68 |
+
help='Embedding for words')
|
| 69 |
+
args_.add_argument('--word_path', help='path for word embedding dict - in case word_embedding is not random')
|
| 70 |
+
args_.add_argument('--freeze_word_embeddings', action='store_true', help='freeze the word embeddings (disable fine-tuning).')
|
| 71 |
+
args_.add_argument('--freeze_sequence_taggers', action='store_true', help='freeze the BiLSTMs of the pre-trained taggers.')
|
| 72 |
+
args_.add_argument('--char_embedding', choices=['random','hellwig'], help='Embedding for characters',
|
| 73 |
+
required=True)
|
| 74 |
+
args_.add_argument('--pos_embedding', choices=['random','one_hot'], help='Embedding for pos',
|
| 75 |
+
required=True)
|
| 76 |
+
args_.add_argument('--char_path', help='path for character embedding dict')
|
| 77 |
+
args_.add_argument('--pos_path', help='path for pos embedding dict')
|
| 78 |
+
args_.add_argument('--set_num_training_samples', type=int, help='downsample the training set to a fixed number of samples')
|
| 79 |
+
args_.add_argument('--model_path', help='path for saving model file.', required=True)
|
| 80 |
+
args_.add_argument('--load_path', help='path for loading saved source model file.', default=None)
|
| 81 |
+
args_.add_argument('--load_sequence_taggers_paths', nargs='+', help='paths for loading saved sequence tagger model files.', default=None)
|
| 82 |
+
args_.add_argument('--strict', action='store_true', help='if True the loaded model state should contain '
|
| 83 |
+
'exactly the same keys as current model')
|
| 84 |
+
args_.add_argument('--eval_mode', action='store_true', help='evaluating model without training it')
|
| 85 |
+
args = args_.parse_args()
|
| 86 |
+
args_dict = {}
|
| 87 |
+
args_dict['dataset'] = args.dataset
|
| 88 |
+
args_dict['domain'] = args.domain
|
| 89 |
+
args_dict['rnn_mode'] = args.rnn_mode
|
| 90 |
+
args_dict['gating'] = args.gating
|
| 91 |
+
args_dict['num_gates'] = args.num_gates
|
| 92 |
+
args_dict['arc_decode'] = args.arc_decode
|
| 93 |
+
# args_dict['splits'] = ['train', 'dev', 'test']
|
| 94 |
+
args_dict['splits'] = ['train', 'dev', 'test', 'poetry', 'prose']
|
| 95 |
+
args_dict['model_path'] = args.model_path
|
| 96 |
+
if not path.exists(args_dict['model_path']):
|
| 97 |
+
makedirs(args_dict['model_path'])
|
| 98 |
+
args_dict['data_paths'] = {}
|
| 99 |
+
if args_dict['dataset'] == 'ontonotes':
|
| 100 |
+
data_path = 'data/onto_pos_ner_dp'
|
| 101 |
+
else:
|
| 102 |
+
data_path = 'data/ud_pos_ner_dp'
|
| 103 |
+
for split in args_dict['splits']:
|
| 104 |
+
args_dict['data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain']
|
| 105 |
+
###################################
|
| 106 |
+
args_dict['data_paths']['poetry'] = data_path + '_' + 'poetry' + '_' + args_dict['domain']
|
| 107 |
+
args_dict['data_paths']['prose'] = data_path + '_' + 'prose' + '_' + args_dict['domain']
|
| 108 |
+
# args_dict['data_paths']['poetry'] = 'data/Shishu_300' + '_' + 'poetry' + '_' + args_dict['domain']
|
| 109 |
+
# args_dict['data_paths']['prose'] = 'data/Shishu_300' + '_' + 'prose' + '_' + args_dict['domain']
|
| 110 |
+
|
| 111 |
+
###################################
|
| 112 |
+
args_dict['alphabet_data_paths'] = {}
|
| 113 |
+
for split in args_dict['splits']:
|
| 114 |
+
if args_dict['dataset'] == 'ontonotes':
|
| 115 |
+
args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + 'all'
|
| 116 |
+
else:
|
| 117 |
+
if '_' in args_dict['domain']:
|
| 118 |
+
args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain'].split('_')[0]
|
| 119 |
+
else:
|
| 120 |
+
args_dict['alphabet_data_paths'][split] = args_dict['data_paths'][split]
|
| 121 |
+
args_dict['model_name'] = 'domain_' + args_dict['domain']
|
| 122 |
+
args_dict['full_model_name'] = path.join(args_dict['model_path'],args_dict['model_name'])
|
| 123 |
+
args_dict['load_path'] = args.load_path
|
| 124 |
+
args_dict['load_sequence_taggers_paths'] = args.load_sequence_taggers_paths
|
| 125 |
+
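# When pre-trained sequence taggers are supplied, gating is switched on automatically and one gate is allocated per tagger encoder plus one for the parser's own encoder.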
if args_dict['load_sequence_taggers_paths'] is not None:
|
| 126 |
+
args_dict['gating'] = True
|
| 127 |
+
args_dict['num_gates'] = len(args_dict['load_sequence_taggers_paths']) + 1
|
| 128 |
+
else:
|
| 129 |
+
if not args_dict['gating']:
|
| 130 |
+
args_dict['num_gates'] = 0
|
| 131 |
+
args_dict['strict'] = args.strict
|
| 132 |
+
args_dict['num_epochs'] = args.num_epochs
|
| 133 |
+
args_dict['batch_size'] = args.batch_size
|
| 134 |
+
args_dict['hidden_size'] = args.hidden_size
|
| 135 |
+
args_dict['arc_space'] = args.arc_space
|
| 136 |
+
args_dict['arc_tag_space'] = args.arc_tag_space
|
| 137 |
+
args_dict['num_layers'] = args.num_layers
|
| 138 |
+
args_dict['num_filters'] = args.num_filters
|
| 139 |
+
args_dict['kernel_size'] = args.kernel_size
|
| 140 |
+
args_dict['learning_rate'] = args.learning_rate
|
| 141 |
+
args_dict['initializer'] = nn.init.xavier_uniform_ if args.initializer == 'xavier' else None
|
| 142 |
+
args_dict['opt'] = args.opt
|
| 143 |
+
args_dict['momentum'] = args.momentum
|
| 144 |
+
args_dict['betas'] = tuple(args.betas)
|
| 145 |
+
args_dict['epsilon'] = args.epsilon
|
| 146 |
+
args_dict['decay_rate'] = args.decay_rate
|
| 147 |
+
args_dict['clip'] = args.clip
|
| 148 |
+
args_dict['gamma'] = args.gamma
|
| 149 |
+
args_dict['schedule'] = args.schedule
|
| 150 |
+
args_dict['p_rnn'] = tuple(args.p_rnn)
|
| 151 |
+
args_dict['p_in'] = args.p_in
|
| 152 |
+
args_dict['p_out'] = args.p_out
|
| 153 |
+
args_dict['unk_replace'] = args.unk_replace
|
| 154 |
+
args_dict['set_num_training_samples'] = args.set_num_training_samples
|
| 155 |
+
args_dict['punct_set'] = None
|
| 156 |
+
if args.punct_set is not None:
|
| 157 |
+
args_dict['punct_set'] = set(args.punct_set)
|
| 158 |
+
logger.info("punctuations(%d): %s" % (len(args_dict['punct_set']), ' '.join(args_dict['punct_set'])))
|
| 159 |
+
args_dict['freeze_word_embeddings'] = args.freeze_word_embeddings
|
| 160 |
+
args_dict['freeze_sequence_taggers'] = args.freeze_sequence_taggers
|
| 161 |
+
args_dict['word_embedding'] = args.word_embedding
|
| 162 |
+
args_dict['word_path'] = args.word_path
|
| 163 |
+
args_dict['use_char'] = args.use_char
|
| 164 |
+
args_dict['char_embedding'] = args.char_embedding
|
| 165 |
+
args_dict['char_path'] = args.char_path
|
| 166 |
+
args_dict['pos_embedding'] = args.pos_embedding
|
| 167 |
+
args_dict['pos_path'] = args.pos_path
|
| 168 |
+
args_dict['use_pos'] = args.use_pos
|
| 169 |
+
args_dict['pos_dim'] = args.pos_dim
|
| 170 |
+
args_dict['word_dict'] = None
|
| 171 |
+
args_dict['word_dim'] = args.word_dim
|
| 172 |
+
if args_dict['word_embedding'] != 'random' and args_dict['word_path']:
|
| 173 |
+
args_dict['word_dict'], args_dict['word_dim'] = load_word_embeddings.load_embedding_dict(args_dict['word_embedding'],
|
| 174 |
+
args_dict['word_path'])
|
| 175 |
+
args_dict['char_dict'] = None
|
| 176 |
+
args_dict['char_dim'] = args.char_dim
|
| 177 |
+
if args_dict['char_embedding'] != 'random':
|
| 178 |
+
args_dict['char_dict'], args_dict['char_dim'] = load_word_embeddings.load_embedding_dict(args_dict['char_embedding'],
|
| 179 |
+
args_dict['char_path'])
|
| 180 |
+
args_dict['pos_dict'] = None
|
| 181 |
+
if args_dict['pos_embedding'] != 'random':
|
| 182 |
+
args_dict['pos_dict'], args_dict['pos_dim'] = load_word_embeddings.load_embedding_dict(args_dict['pos_embedding'],
|
| 183 |
+
args_dict['pos_path'])
|
| 184 |
+
args_dict['alphabet_path'] = path.join(args_dict['model_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/')
|
| 185 |
+
args_dict['model_name'] = path.join(args_dict['model_path'], args_dict['model_name'])
|
| 186 |
+
args_dict['eval_mode'] = args.eval_mode
|
| 187 |
+
args_dict['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 188 |
+
args_dict['word_status'] = 'frozen' if args.freeze_word_embeddings else 'fine tune'
|
| 189 |
+
args_dict['char_status'] = 'enabled' if args.use_char else 'disabled'
|
| 190 |
+
args_dict['pos_status'] = 'enabled' if args.use_pos else 'disabled'
|
| 191 |
+
logger.info("Saving arguments to file")
|
| 192 |
+
save_args(args, args_dict['full_model_name'])
|
| 193 |
+
logger.info("Creating Alphabets")
|
| 194 |
+
alphabet_dict = creating_alphabets(args_dict['alphabet_path'], args_dict['alphabet_data_paths'], args_dict['word_dict'])
|
| 195 |
+
args_dict = {**args_dict, **alphabet_dict}
|
| 196 |
+
ARGS = namedtuple('ARGS', args_dict.keys())
|
| 197 |
+
my_args = ARGS(**args_dict)
|
| 198 |
+
return my_args
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
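# Build the token alphabets from the train split, extending them with the remaining splits so all splits share one index space; sizes are logged as num_word, num_char, etc.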
def creating_alphabets(alphabet_path, alphabet_data_paths, word_dict):
|
| 202 |
+
train_paths = alphabet_data_paths['train']
|
| 203 |
+
extra_paths = [v for k,v in alphabet_data_paths.items() if k != 'train']
|
| 204 |
+
alphabet_dict = {}
|
| 205 |
+
alphabet_dict['alphabets'] = prepare_data.create_alphabets(alphabet_path,
|
| 206 |
+
train_paths,
|
| 207 |
+
extra_paths=extra_paths,
|
| 208 |
+
max_vocabulary_size=100000,
|
| 209 |
+
embedd_dict=word_dict)
|
| 210 |
+
for k, v in alphabet_dict['alphabets'].items():
|
| 211 |
+
num_key = 'num_' + k.split('_')[0]
|
| 212 |
+
alphabet_dict[num_key] = v.size()
|
| 213 |
+
logger.info("%s : %d" % (num_key, alphabet_dict[num_key]))
|
| 214 |
+
return alphabet_dict
|
| 215 |
+
|
| 216 |
+
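# Build an embedding matrix aligned with an alphabet: known tokens (matched as-is, then lower-cased) take their pre-trained vectors, everything else is drawn uniformly from [-sqrt(3/dim), sqrt(3/dim)].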
def construct_embedding_table(alphabet, tokens_dict, dim, token_type='word'):
|
| 217 |
+
if tokens_dict is None:
|
| 218 |
+
return None
|
| 219 |
+
scale = np.sqrt(3.0 / dim)
|
| 220 |
+
table = np.empty([alphabet.size(), dim], dtype=np.float32)
|
| 221 |
+
table[prepare_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
|
| 222 |
+
oov_tokens = 0
|
| 223 |
+
for token, index in alphabet.items():
|
| 224 |
+
if token in tokens_dict:
|
| 225 |
+
embedding = tokens_dict[token]
|
| 226 |
+
elif token.lower() in tokens_dict:
|
| 227 |
+
embedding = tokens_dict[token.lower()]
|
| 228 |
+
else:
|
| 229 |
+
embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
|
| 230 |
+
oov_tokens += 1
|
| 231 |
+
table[index, :] = embedding
|
| 232 |
+
print('token type : %s, number of oov: %d' % (token_type, oov_tokens))
|
| 233 |
+
table = torch.from_numpy(table)
|
| 234 |
+
return table
|
| 235 |
+
|
| 236 |
+
def save_args(args, full_model_name):
|
| 237 |
+
arg_path = full_model_name + '.arg.json'
|
| 238 |
+
argparse_dict = vars(args)
|
| 239 |
+
with open(arg_path, 'w') as f:
|
| 240 |
+
json.dump(argparse_dict, f)
|
| 241 |
+
|
| 242 |
+
def generate_optimizer(args, lr, params):
|
| 243 |
+
params = filter(lambda param: param.requires_grad, params)
|
| 244 |
+
if args.opt == 'adam':
|
| 245 |
+
return Adam(params, lr=lr, betas=args.betas, weight_decay=args.gamma, eps=args.epsilon)
|
| 246 |
+
elif args.opt == 'sgd':
|
| 247 |
+
return SGD(params, lr=lr, momentum=args.momentum, weight_decay=args.gamma, nesterov=True)
|
| 248 |
+
else:
|
| 249 |
+
raise ValueError('Unknown optimization algorithm: %s' % args.opt)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def save_checkpoint(model, optimizer, opt, dev_eval_dict, test_eval_dict, full_model_name):
|
| 253 |
+
path_name = full_model_name + '.pt'
|
| 254 |
+
print('Saving model to: %s' % path_name)
|
| 255 |
+
state = {'model_state_dict': model.state_dict(),
|
| 256 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 257 |
+
'opt': opt,
|
| 258 |
+
'dev_eval_dict': dev_eval_dict,
|
| 259 |
+
'test_eval_dict': test_eval_dict}
|
| 260 |
+
torch.save(state, path_name)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
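# Restore a saved checkpoint: model weights, optionally the optimizer state (moved onto the target device), plus the stored evaluation dictionaries and starting epoch.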
def load_checkpoint(args, model, optimizer, dev_eval_dict, test_eval_dict, start_epoch, load_path, strict=False):
|
| 264 |
+
print('Loading saved model from: %s' % load_path)
|
| 265 |
+
checkpoint = torch.load(load_path, map_location=args.device)
|
| 266 |
+
if checkpoint['opt'] != args.opt:
|
| 267 |
+
raise ValueError('loaded optimizer type is: %s instead of: %s' % (checkpoint['opt'], args.opt))
|
| 268 |
+
model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
|
| 269 |
+
|
| 270 |
+
if strict:
|
| 271 |
+
generate_optimizer(args, args.learning_rate, model.parameters())
|
| 272 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 273 |
+
for state in optimizer.state.values():
|
| 274 |
+
for k, v in state.items():
|
| 275 |
+
if isinstance(v, torch.Tensor):
|
| 276 |
+
state[k] = v.to(args.device)
|
| 277 |
+
dev_eval_dict = checkpoint['dev_eval_dict']
|
| 278 |
+
test_eval_dict = checkpoint['test_eval_dict']
|
| 279 |
+
start_epoch = dev_eval_dict['in_domain']['epoch']
|
| 280 |
+
return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def build_model_and_optimizer(args):
|
| 284 |
+
word_table = construct_embedding_table(args.alphabets['word_alphabet'], args.word_dict, args.word_dim, token_type='word')
|
| 285 |
+
char_table = construct_embedding_table(args.alphabets['char_alphabet'], args.char_dict, args.char_dim, token_type='char')
|
| 286 |
+
pos_table = construct_embedding_table(args.alphabets['pos_alphabet'], args.pos_dict, args.pos_dim, token_type='pos')
|
| 287 |
+
model = BiAffine_Parser_Gated(args.word_dim, args.num_word, args.char_dim, args.num_char,
|
| 288 |
+
args.use_pos, args.use_char, args.pos_dim, args.num_pos,
|
| 289 |
+
args.num_filters, args.kernel_size, args.rnn_mode,
|
| 290 |
+
args.hidden_size, args.num_layers, args.num_arc,
|
| 291 |
+
args.arc_space, args.arc_tag_space, args.num_gates,
|
| 292 |
+
embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table,
|
| 293 |
+
p_in=args.p_in, p_out=args.p_out, p_rnn=args.p_rnn,
|
| 294 |
+
biaffine=True, arc_decode=args.arc_decode, initializer=args.initializer)
|
| 295 |
+
# Tagger = Sequence_Tagger(args.word_dim, args.num_word, args.char_dim, args.num_char,
|
| 296 |
+
# args.use_pos, args.use_char, args.pos_dim, args.num_pos,
|
| 297 |
+
# args.num_filters, args.kernel_size, args.rnn_mode,
|
| 298 |
+
# args.hidden_size, args.num_layers, args.arc_tag_space, 519,
|
| 299 |
+
# embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table,
|
| 300 |
+
# p_in=args.p_in, p_out=args.p_out, p_rnn=args.p_rnn,
|
| 301 |
+
# initializer=args.initializer)
|
| 302 |
+
print(model)
|
| 303 |
+
# print(Tagger)
|
| 304 |
+
# Tagger.rnn_encoder = model.rnn_encoder
|
| 305 |
+
optimizer = generate_optimizer(args, args.learning_rate, model.parameters())
|
| 306 |
+
start_epoch = 0
|
| 307 |
+
dev_eval_dict = {'in_domain': initialize_eval_dict()}
|
| 308 |
+
test_eval_dict = {'in_domain': initialize_eval_dict()}
|
| 309 |
+
if args.load_path:
|
| 310 |
+
model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = \
|
| 311 |
+
load_checkpoint(args, model, optimizer,
|
| 312 |
+
dev_eval_dict, test_eval_dict,
|
| 313 |
+
start_epoch, args.load_path, strict=False)
|
| 314 |
+
if args.load_sequence_taggers_paths:
|
| 315 |
+
pretrained_dict = {}
|
| 316 |
+
model_dict = model.state_dict()
|
| 317 |
+
for idx, tagger_path in enumerate(args.load_sequence_taggers_paths):
|
| 318 |
+
print('Loading saved sequence_tagger from: %s' % tagger_path)
|
| 319 |
+
checkpoint = torch.load(tagger_path, map_location=args.device)
|
| 320 |
+
for k, v in checkpoint['model_state_dict'].items():
|
| 321 |
+
if 'rnn_encoder.' in k:
|
| 322 |
+
pretrained_dict['extra_rnn_encoders.' + str(idx) + '.' + k.replace('rnn_encoder.', '')] = v
|
| 323 |
+
model_dict.update(pretrained_dict)
|
| 324 |
+
model.load_state_dict(model_dict)
|
| 325 |
+
if args.freeze_sequence_taggers:
|
| 326 |
+
print('Freezing Classifiers')
|
| 327 |
+
for name, parameter in model.named_parameters():
|
| 328 |
+
if 'extra_rnn_encoders' in name:
|
| 329 |
+
parameter.requires_grad = False
|
| 330 |
+
if args.freeze_word_embeddings:
|
| 331 |
+
model.rnn_encoder.word_embedd.weight.requires_grad = False
|
| 332 |
+
# model.rnn_encoder.char_embedd.weight.requires_grad = False
|
| 333 |
+
# model.rnn_encoder.pos_embedd.weight.requires_grad = False
|
| 334 |
+
device = args.device
|
| 335 |
+
model.to(device)
|
| 336 |
+
# Tagger.to(device)
|
| 337 |
+
return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def initialize_eval_dict():
|
| 341 |
+
eval_dict = {}
|
| 342 |
+
eval_dict['dp_uas'] = 0.0
|
| 343 |
+
eval_dict['dp_las'] = 0.0
|
| 344 |
+
eval_dict['epoch'] = 0
|
| 345 |
+
eval_dict['dp_ucorrect'] = 0.0
|
| 346 |
+
eval_dict['dp_lcorrect'] = 0.0
|
| 347 |
+
eval_dict['dp_total'] = 0.0
|
| 348 |
+
eval_dict['dp_ucomplete_match'] = 0.0
|
| 349 |
+
eval_dict['dp_lcomplete_match'] = 0.0
|
| 350 |
+
eval_dict['dp_ucorrect_nopunc'] = 0.0
|
| 351 |
+
eval_dict['dp_lcorrect_nopunc'] = 0.0
|
| 352 |
+
eval_dict['dp_total_nopunc'] = 0.0
|
| 353 |
+
eval_dict['dp_ucomplete_match_nopunc'] = 0.0
|
| 354 |
+
eval_dict['dp_lcomplete_match_nopunc'] = 0.0
|
| 355 |
+
eval_dict['dp_root_correct'] = 0.0
|
| 356 |
+
eval_dict['dp_total_root'] = 0.0
|
| 357 |
+
eval_dict['dp_total_inst'] = 0.0
|
| 358 |
+
eval_dict['dp_total'] = 0.0
|
| 359 |
+
eval_dict['dp_total_inst'] = 0.0
|
| 360 |
+
eval_dict['dp_total_nopunc'] = 0.0
|
| 361 |
+
eval_dict['dp_total_root'] = 0.0
|
| 362 |
+
return eval_dict
|
| 363 |
+
|
| 364 |
+
def in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch,
|
| 365 |
+
best_model, best_optimizer, patient):
|
| 366 |
+
# In-domain evaluation
|
| 367 |
+
curr_dev_eval_dict = evaluation(args, datasets['dev'], 'dev', model, args.domain, epoch, 'current_results')
|
| 368 |
+
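# Keep a new best model when the labeled-correct count (excluding punctuation) on dev matches or improves on the previous best, breaking ties by the unlabeled-correct count; otherwise the patience counter grows.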
is_best_in_domain = dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] <= curr_dev_eval_dict['dp_lcorrect_nopunc'] or \
|
| 369 |
+
(dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] == curr_dev_eval_dict['dp_lcorrect_nopunc'] and
|
| 370 |
+
dev_eval_dict['in_domain']['dp_ucorrect_nopunc'] <= curr_dev_eval_dict['dp_ucorrect_nopunc'])
|
| 371 |
+
|
| 372 |
+
if is_best_in_domain:
|
| 373 |
+
for key, value in curr_dev_eval_dict.items():
|
| 374 |
+
dev_eval_dict['in_domain'][key] = value
|
| 375 |
+
curr_test_eval_dict = evaluation(args, datasets['test'], 'test', model, args.domain, epoch, 'current_results')
|
| 376 |
+
for key, value in curr_test_eval_dict.items():
|
| 377 |
+
test_eval_dict['in_domain'][key] = value
|
| 378 |
+
best_model = deepcopy(model)
|
| 379 |
+
best_optimizer = deepcopy(optimizer)
|
| 380 |
+
patient = 0
|
| 381 |
+
else:
|
| 382 |
+
patient += 1
|
| 383 |
+
if epoch == args.num_epochs:
|
| 384 |
+
# save in-domain checkpoint
|
| 385 |
+
if args.set_num_training_samples is not None:
|
| 386 |
+
splits_to_write = datasets.keys()
|
| 387 |
+
else:
|
| 388 |
+
splits_to_write = ['dev', 'test']
|
| 389 |
+
for split in splits_to_write:
|
| 390 |
+
if split == 'dev':
|
| 391 |
+
eval_dict = dev_eval_dict['in_domain']
|
| 392 |
+
elif split == 'test':
|
| 393 |
+
eval_dict = test_eval_dict['in_domain']
|
| 394 |
+
else:
|
| 395 |
+
eval_dict = None
|
| 396 |
+
write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)
|
| 397 |
+
print("Saving best model")
|
| 398 |
+
save_checkpoint(best_model, best_optimizer, args.opt, dev_eval_dict, test_eval_dict, args.full_model_name)
|
| 399 |
+
|
| 400 |
+
print('\n')
|
| 401 |
+
return dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def evaluation(args, data, split, model, domain, epoch, str_res='results'):
|
| 405 |
+
# evaluate performance on data
|
| 406 |
+
model.eval()
|
| 407 |
+
|
| 408 |
+
eval_dict = initialize_eval_dict()
|
| 409 |
+
eval_dict['epoch'] = epoch
|
| 410 |
+
for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
|
| 411 |
+
word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
|
| 412 |
+
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
|
| 413 |
+
heads_pred, arc_tags_pred, _ = model.decode(out_arc, out_arc_tag, mask=masks, length=lengths,
|
| 414 |
+
leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
|
| 415 |
+
lengths = lengths.cpu().numpy()
|
| 416 |
+
word = word.data.cpu().numpy()
|
| 417 |
+
pos = pos.data.cpu().numpy()
|
| 418 |
+
ner = ner.data.cpu().numpy()
|
| 419 |
+
heads = heads.data.cpu().numpy()
|
| 420 |
+
arc_tags = arc_tags.data.cpu().numpy()
|
| 421 |
+
heads_pred = heads_pred.data.cpu().numpy()
|
| 422 |
+
arc_tags_pred = arc_tags_pred.data.cpu().numpy()
|
| 423 |
+
stats, stats_nopunc, stats_root, num_inst = parse.eval_(word, pos, heads_pred, arc_tags_pred, heads,
|
| 424 |
+
arc_tags, args.alphabets['word_alphabet'], args.alphabets['pos_alphabet'],
|
| 425 |
+
lengths, punct_set=args.punct_set, symbolic_root=True)
|
| 426 |
+
ucorr, lcorr, total, ucm, lcm = stats
|
| 427 |
+
ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
|
| 428 |
+
corr_root, total_root = stats_root
|
| 429 |
+
eval_dict['dp_ucorrect'] += ucorr
|
| 430 |
+
eval_dict['dp_lcorrect'] += lcorr
|
| 431 |
+
eval_dict['dp_total'] += total
|
| 432 |
+
eval_dict['dp_ucomplete_match'] += ucm
|
| 433 |
+
eval_dict['dp_lcomplete_match'] += lcm
|
| 434 |
+
eval_dict['dp_ucorrect_nopunc'] += ucorr_nopunc
|
| 435 |
+
eval_dict['dp_lcorrect_nopunc'] += lcorr_nopunc
|
| 436 |
+
eval_dict['dp_total_nopunc'] += total_nopunc
|
| 437 |
+
eval_dict['dp_ucomplete_match_nopunc'] += ucm_nopunc
|
| 438 |
+
eval_dict['dp_lcomplete_match_nopunc'] += lcm_nopunc
|
| 439 |
+
eval_dict['dp_root_correct'] += corr_root
|
| 440 |
+
eval_dict['dp_total_root'] += total_root
|
| 441 |
+
eval_dict['dp_total_inst'] += num_inst
|
| 442 |
+
|
| 443 |
+
eval_dict['dp_uas'] = eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
|
| 444 |
+
eval_dict['dp_las'] = eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
|
| 445 |
+
print_results(eval_dict, split, domain, str_res)
|
| 446 |
+
return eval_dict
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def print_results(eval_dict, split, domain, str_res='results'):
|
| 450 |
+
print('----------------------------------------------------------------------------------------------------------------------------')
|
| 451 |
+
print('Testing model on domain %s' % domain)
|
| 452 |
+
print('--------------- Dependency Parsing - %s ---------------' % split)
|
| 453 |
+
print(
|
| 454 |
+
str_res + ' on ' + split + ' W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
|
| 455 |
+
eval_dict['dp_ucorrect'], eval_dict['dp_lcorrect'], eval_dict['dp_total'],
|
| 456 |
+
eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'],
|
| 457 |
+
eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'],
|
| 458 |
+
eval_dict['dp_ucomplete_match'] * 100 / eval_dict['dp_total_inst'],
|
| 459 |
+
eval_dict['dp_lcomplete_match'] * 100 / eval_dict['dp_total_inst'],
|
| 460 |
+
eval_dict['epoch']))
|
| 461 |
+
print(
|
| 462 |
+
str_res + ' on ' + split + ' Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
|
| 463 |
+
eval_dict['dp_ucorrect_nopunc'], eval_dict['dp_lcorrect_nopunc'], eval_dict['dp_total_nopunc'],
|
| 464 |
+
eval_dict['dp_ucorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
|
| 465 |
+
eval_dict['dp_lcorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
|
| 466 |
+
eval_dict['dp_ucomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
|
| 467 |
+
eval_dict['dp_lcomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
|
| 468 |
+
eval_dict['epoch']))
|
| 469 |
+
print(str_res + ' on ' + split + ' Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
|
| 470 |
+
eval_dict['dp_root_correct'], eval_dict['dp_total_root'],
|
| 471 |
+
eval_dict['dp_root_correct'] * 100 / eval_dict['dp_total_root'], eval_dict['epoch']))
|
| 472 |
+
print('\n')
|
| 473 |
+
|
| 474 |
+
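# Dump the evaluation dictionary to JSON (when given) and write the model's predictions and the gold annotations to separate files named after the model and data domains.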
def write_results(args, data, data_domain, split, model, model_domain, eval_dict):
|
| 475 |
+
str_file = args.full_model_name + '_' + split + '_model_domain_' + model_domain + '_data_domain_' + data_domain
|
| 476 |
+
res_filename = str_file + '_res.txt'
|
| 477 |
+
pred_filename = str_file + '_pred.txt'
|
| 478 |
+
gold_filename = str_file + '_gold.txt'
|
| 479 |
+
if eval_dict is not None:
|
| 480 |
+
# save results dictionary into a file
|
| 481 |
+
with open(res_filename, 'w') as f:
|
| 482 |
+
json.dump(eval_dict, f)
|
| 483 |
+
|
| 484 |
+
# save predictions and gold labels into files
|
| 485 |
+
pred_writer = Writer(args.alphabets)
|
| 486 |
+
gold_writer = Writer(args.alphabets)
|
| 487 |
+
pred_writer.start(pred_filename)
|
| 488 |
+
gold_writer.start(gold_filename)
|
| 489 |
+
for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
|
| 490 |
+
word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
|
| 491 |
+
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
|
| 492 |
+
heads_pred, arc_tags_pred, _ = model.decode(out_arc, out_arc_tag, mask=masks, length=lengths,
|
| 493 |
+
leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
|
| 494 |
+
lengths = lengths.cpu().numpy()
|
| 495 |
+
word = word.data.cpu().numpy()
|
| 496 |
+
pos = pos.data.cpu().numpy()
|
| 497 |
+
ner = ner.data.cpu().numpy()
|
| 498 |
+
heads = heads.data.cpu().numpy()
|
| 499 |
+
arc_tags = arc_tags.data.cpu().numpy()
|
| 500 |
+
heads_pred = heads_pred.data.cpu().numpy()
|
| 501 |
+
arc_tags_pred = arc_tags_pred.data.cpu().numpy()
|
| 502 |
+
# writing predictions
|
| 503 |
+
pred_writer.write(word, pos, ner, heads_pred, arc_tags_pred, lengths, symbolic_root=True)
|
| 504 |
+
# writing gold labels
|
| 505 |
+
gold_writer.write(word, pos, ner, heads, arc_tags, lengths, symbolic_root=True)
|
| 506 |
+
|
| 507 |
+
pred_writer.close()
|
| 508 |
+
gold_writer.close()
|
| 509 |
+
|
| 510 |
+
def main():
|
| 511 |
+
logger.info("Reading and creating arguments")
|
| 512 |
+
args = read_arguments()
|
| 513 |
+
logger.info("Reading Data")
|
| 514 |
+
datasets = {}
|
| 515 |
+
for split in args.splits:
|
| 516 |
+
print("Splits are:",split)
|
| 517 |
+
dataset = prepare_data.read_data_to_variable(args.data_paths[split], args.alphabets, args.device,
|
| 518 |
+
symbolic_root=True)
|
| 519 |
+
datasets[split] = dataset
|
| 520 |
+
if args.set_num_training_samples is not None:
|
| 521 |
+
print('Setting train and dev to %d samples' % args.set_num_training_samples)
|
| 522 |
+
datasets = rearrange_splits.rearranging_splits(datasets, args.set_num_training_samples)
|
| 523 |
+
logger.info("Creating Networks")
|
| 524 |
+
num_data = sum(datasets['train'][1])
|
| 525 |
+
model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = build_model_and_optimizer(args)
|
| 526 |
+
best_model = deepcopy(model)
|
| 527 |
+
best_optimizer = deepcopy(optimizer)
|
| 528 |
+
|
| 529 |
+
logger.info('Training info for in-domain %s' % args.domain)
|
| 530 |
+
logger.info('Training on Dependency Parsing')
|
| 531 |
+
logger.info("train: gamma: %f, batch: %d, clip: %.2f, unk replace: %.2f" % (args.gamma, args.batch_size, args.clip, args.unk_replace))
|
| 532 |
+
logger.info('number of training samples for %s is: %d' % (args.domain, num_data))
|
| 533 |
+
logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (args.p_in, args.p_out, args.p_rnn))
|
| 534 |
+
logger.info("num_epochs: %d" % (args.num_epochs))
|
| 535 |
+
print('\n')
|
| 536 |
+
|
| 537 |
+
if not args.eval_mode:
|
| 538 |
+
logger.info("Training")
|
| 539 |
+
num_batches = prepare_data.calc_num_batches(datasets['train'], args.batch_size)
|
| 540 |
+
lr = args.learning_rate
|
| 541 |
+
patient = 0
|
| 542 |
+
decay = 0
|
| 543 |
+
for epoch in range(start_epoch + 1, args.num_epochs + 1):
|
| 544 |
+
print('Epoch %d (Training: rnn mode: %s, optimizer: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, decay=%d)): ' % (
|
| 545 |
+
epoch, args.rnn_mode, args.opt, lr, args.epsilon, args.decay_rate, args.schedule, decay))
|
| 546 |
+
model.train()
|
| 547 |
+
total_loss = 0.0
|
| 548 |
+
total_mtl_pos_loss = 0.0
|
| 549 |
+
total_mtl_case_loss = 0.0
|
| 550 |
+
total_mtl_label_loss = 0.0
|
| 551 |
+
total_arc_loss = 0.0
|
| 552 |
+
total_arc_tag_loss = 0.0
|
| 553 |
+
total_train_inst = 0.0
|
| 554 |
+
|
| 555 |
+
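# Draw training batches from randomly chosen length buckets; singleton words are replaced with UNK at rate args.unk_replace.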
train_iter = prepare_data.iterate_batch_rand_bucket_choosing(
|
| 556 |
+
datasets['train'], args.batch_size, args.device, unk_replace=args.unk_replace)
|
| 557 |
+
start_time = time.time()
|
| 558 |
+
batch_num = 0
|
| 559 |
+
for batch_num, batch in enumerate(train_iter):
|
| 560 |
+
batch_num = batch_num + 1
|
| 561 |
+
optimizer.zero_grad()
|
| 562 |
+
# compute loss of main task
|
| 563 |
+
word, char, pos, ner_tags, heads, arc_tags, auto_label, masks, lengths = batch
|
| 564 |
+
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
|
| 565 |
+
loss_arc, loss_arc_tag = model.loss(out_arc, out_arc_tag, heads, arc_tags, mask=masks, length=lengths)
|
| 566 |
+
############################################################
|
| 567 |
+
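# Auxiliary multi-task loss: tag prediction from the shared encoder (tracked below as the MTL POS/morph loss); the case and label auxiliary losses are left commented out.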
loss_morph = model.loss_morph(word, char, pos, mask=masks, length=lengths)
|
| 568 |
+
# loss_case = model.loss_case(word, char, ner_tags, mask=masks, length=lengths)
|
| 569 |
+
# loss_label = model.loss_label(word, char, arc_tags, mask=masks, length=lengths)
|
| 570 |
+
|
| 571 |
+
## Adding multi-tasking
|
| 572 |
+
# Tag_output, Tag_masks, Tag_lengths = Tagger.forward(word, char, pos, mask=masks, length=lengths)
|
| 573 |
+
# Tag_loss = Tagger.loss(Tag_output, pos, mask=Tag_masks, length=Tag_lengths)
|
| 574 |
+
# # update losses
|
| 575 |
+
# Tag_num_insts = Tag_masks.data.sum() - word.size(0)
|
| 576 |
+
# Tag_total_loss += Tag_loss.item() * Tag_num_insts
|
| 577 |
+
# Tag_total_train_inst += Tag_num_insts
|
| 578 |
+
|
| 579 |
+
#############################################################
|
| 580 |
+
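# Joint objective: arc loss + arc-label loss + auxiliary tagging loss, summed with equal weights.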
loss = loss_arc + loss_arc_tag + loss_morph
|
| 581 |
+
|
| 582 |
+
# update losses
|
| 583 |
+
num_insts = masks.data.sum() - word.size(0)
|
| 584 |
+
total_arc_loss += loss_arc.item() * num_insts
|
| 585 |
+
total_arc_tag_loss += loss_arc_tag.item() * num_insts
|
| 586 |
+
total_mtl_pos_loss += loss_morph.item() * num_insts
|
| 587 |
+
# total_mtl_case_loss += loss_case.item() * num_insts
|
| 588 |
+
# total_mtl_label_loss += loss_label.item() * num_insts
|
| 589 |
+
total_loss += loss.item() * num_insts
|
| 590 |
+
total_train_inst += num_insts
|
| 591 |
+
# optimize parameters
|
| 592 |
+
loss.backward()
|
| 593 |
+
clip_grad_norm_(model.parameters(), args.clip)
|
| 594 |
+
optimizer.step()
|
| 595 |
+
|
| 596 |
+
time_ave = (time.time() - start_time) / batch_num
|
| 597 |
+
time_left = (num_batches - batch_num) * time_ave
|
| 598 |
+
|
| 599 |
+
# update log
|
| 600 |
+
if batch_num % 50 == 0:
|
| 601 |
+
log_info = 'train: %d/%d, domain: %s, total loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, morph_loss: %.2f, case_loss: %.2f, label_loss: %.2f, time left: %.2fs' % \
|
| 602 |
+
(batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
|
| 603 |
+
total_arc_tag_loss / total_train_inst, total_mtl_pos_loss / total_train_inst, total_mtl_case_loss / total_train_inst, total_mtl_label_loss / total_train_inst, time_left)
|
| 604 |
+
sys.stdout.write(log_info)
|
| 605 |
+
sys.stdout.write('\n')
|
| 606 |
+
sys.stdout.flush()
|
| 607 |
+
print('\n')
|
| 608 |
+
print('train: %d/%d, domain: %s, total_loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, morph_loss: %.2f, case_loss: %.2f, label_loss: %.2f, time: %.2fs' %
|
| 609 |
+
(batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
|
| 610 |
+
total_arc_tag_loss / total_train_inst, total_mtl_pos_loss / total_train_inst, total_mtl_case_loss / total_train_inst, total_mtl_label_loss / total_train_inst, time.time() - start_time))
|
| 611 |
+
|
| 612 |
+
dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient = in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, best_model, best_optimizer, patient)
|
| 613 |
+
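# If dev performance has not improved for args.schedule consecutive epochs, decay the learning rate as lr = lr0 / (1 + epoch * decay_rate), rebuild the optimizer, and reset the patience counter.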
if patient >= args.schedule:
|
| 614 |
+
lr = args.learning_rate / (1.0 + epoch * args.decay_rate)
|
| 615 |
+
optimizer = generate_optimizer(args, lr, model.parameters())
|
| 616 |
+
print('updated learning rate to %.6f' % lr)
|
| 617 |
+
patient = 0
|
| 618 |
+
print_results(test_eval_dict['in_domain'], 'test', args.domain, 'best_results')
|
| 619 |
+
print('\n')
|
| 620 |
+
for split in datasets.keys():
|
| 621 |
+
eval_dict = evaluation(args, datasets[split], split, best_model, args.domain, epoch, 'best_results')
|
| 622 |
+
write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)
|
| 623 |
+
|
| 624 |
+
else:
|
| 625 |
+
logger.info("Evaluating")
|
| 626 |
+
epoch = start_epoch
|
| 627 |
+
for split in ['train', 'dev', 'test', 'poetry', 'prose']:
|
| 628 |
+
eval_dict = evaluation(args, datasets[split], split, model, args.domain, epoch, 'best_results')
|
| 629 |
+
write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
if __name__ == '__main__':
|
| 633 |
+
main()
|
examples/SequenceTagger.py
ADDED
|
@@ -0,0 +1,589 @@
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import sys
|
| 3 |
+
from os import path, makedirs, system, remove
|
| 4 |
+
|
| 5 |
+
sys.path.append(".")
|
| 6 |
+
sys.path.append("..")
|
| 7 |
+
|
| 8 |
+
import time
|
| 9 |
+
import argparse
|
| 10 |
+
import uuid
|
| 11 |
+
import json
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
from collections import namedtuple
|
| 16 |
+
from copy import deepcopy
|
| 17 |
+
from torch.nn.utils import clip_grad_norm_
|
| 18 |
+
from torch.optim import Adam, SGD
|
| 19 |
+
from utils.io_ import seeds, Writer, get_logger, Index2Instance, prepare_data, write_extra_labels
|
| 20 |
+
from utils.models.sequence_tagger import Sequence_Tagger
|
| 21 |
+
from utils import load_word_embeddings
|
| 22 |
+
from utils.tasks.seqeval import accuracy_score, f1_score, precision_score, recall_score,classification_report
|
| 23 |
+
|
| 24 |
+
uid = uuid.uuid4().hex[:6]
|
| 25 |
+
|
| 26 |
+
logger = get_logger('SequenceTagger')
|
| 27 |
+
|
| 28 |
+
def read_arguments():
|
| 29 |
+
args_ = argparse.ArgumentParser(description='Solving SequenceTagger')
|
| 30 |
+
args_.add_argument('--dataset', choices=['ontonotes', 'ud'], help='Dataset', required=True)
|
| 31 |
+
args_.add_argument('--domain', help='domain', required=True)
|
| 32 |
+
args_.add_argument('--rnn_mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn',
|
| 33 |
+
required=True)
|
| 34 |
+
args_.add_argument('--task', default='distance_from_the_root', choices=['distance_from_the_root', 'number_of_children',\
|
| 35 |
+
'relative_pos_based', 'language_model','add_label','add_head_coarse_pos','Multitask_POS_predict','Multitask_case_predict',\
|
| 36 |
+
'Multitask_label_predict','Multitask_coarse_predict','MRL_case','MRL_POS','MRL_no','MRL_label',\
|
| 37 |
+
'predict_coarse_of_modifier','predict_ma_tag_of_modifier','add_head_ma','predict_case_of_modifier'], help='sequence_tagger task')
|
| 38 |
+
args_.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
|
| 39 |
+
args_.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
|
| 40 |
+
args_.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
|
| 41 |
+
args_.add_argument('--tag_space', type=int, default=128, help='Dimension of tag space')
|
| 42 |
+
args_.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
|
| 43 |
+
args_.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
|
| 44 |
+
args_.add_argument('--kernel_size', type=int, default=3, help='Size of Kernel for CNN')
|
| 45 |
+
args_.add_argument('--use_pos', action='store_true', help='use part-of-speech embedding.')
|
| 46 |
+
args_.add_argument('--use_char', action='store_true', help='use character embedding and CNN.')
|
| 47 |
+
args_.add_argument('--word_dim', type=int, default=300, help='Dimension of word embeddings')
|
| 48 |
+
args_.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
|
| 49 |
+
args_.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
|
| 50 |
+
args_.add_argument('--initializer', choices=['xavier'], help='initialize model parameters')
|
| 51 |
+
args_.add_argument('--opt', choices=['adam', 'sgd'], help='optimization algorithm')
|
| 52 |
+
args_.add_argument('--momentum', type=float, default=0.9, help='momentum of optimizer')
|
| 53 |
+
args_.add_argument('--betas', nargs=2, type=float, default=[0.9, 0.9], help='betas of optimizer')
|
| 54 |
+
args_.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
|
| 55 |
+
args_.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
|
| 56 |
+
args_.add_argument('--schedule', type=int, help='schedule for learning rate decay')
|
| 57 |
+
args_.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
|
| 58 |
+
args_.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
|
| 59 |
+
args_.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for adam')
|
| 60 |
+
args_.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
|
| 61 |
+
args_.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
|
| 62 |
+
args_.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
|
| 63 |
+
args_.add_argument('--unk_replace', type=float, default=0.,
|
| 64 |
+
help='The rate to replace a singleton word with UNK')
|
| 65 |
+
args_.add_argument('--punct_set', nargs='+', type=str, help='List of punctuations')
|
| 66 |
+
args_.add_argument('--word_embedding', choices=['random', 'glove', 'fasttext', 'word2vec'],
|
| 67 |
+
help='Embedding for words')
|
| 68 |
+
args_.add_argument('--word_path', help='path for word embedding dict - in case word_embedding is not random')
|
| 69 |
+
args_.add_argument('--freeze_word_embeddings', action='store_true', help='freeze the word embeddings (disable fine-tuning).')
|
| 70 |
+
args_.add_argument('--char_embedding', choices=['random','hellwig'], help='Embedding for characters',
|
| 71 |
+
required=True)
|
| 72 |
+
args_.add_argument('--pos_embedding', choices=['random','one_hot'], help='Embedding for pos',
|
| 73 |
+
required=True)
|
| 74 |
+
args_.add_argument('--char_path', help='path for character embedding dict')
|
| 75 |
+
args_.add_argument('--pos_path', help='path for pos embedding dict')
|
| 76 |
+
args_.add_argument('--use_unlabeled_data', action='store_true', help='flag to use unlabeled data.')
|
| 77 |
+
args_.add_argument('--use_labeled_data', action='store_true', help='flag to use labeled data.')
|
| 78 |
+
args_.add_argument('--model_path', help='path for saving model file.', required=True)
|
| 79 |
+
args_.add_argument('--parser_path', help='path for loading parser files.', default=None)
|
| 80 |
+
args_.add_argument('--load_path', help='path for loading saved source model file.', default=None)
|
| 81 |
+
args_.add_argument('--strict', action='store_true', help='if True the loaded model state should contain '
|
| 82 |
+
'exactly the same keys as current model')
|
| 83 |
+
args_.add_argument('--eval_mode', action='store_true', help='evaluating model without training it')
|
| 84 |
+
args = args_.parse_args()
|
| 85 |
+
args_dict = {}
|
| 86 |
+
args_dict['dataset'] = args.dataset
|
| 87 |
+
args_dict['domain'] = args.domain
|
| 88 |
+
args_dict['task'] = args.task
|
| 89 |
+
args_dict['rnn_mode'] = args.rnn_mode
|
| 90 |
+
args_dict['load_path'] = args.load_path
|
| 91 |
+
args_dict['strict'] = args.strict
|
| 92 |
+
args_dict['model_path'] = args.model_path
|
| 93 |
+
if not path.exists(args_dict['model_path']):
|
| 94 |
+
makedirs(args_dict['model_path'])
|
| 95 |
+
args_dict['parser_path'] = args.parser_path
|
| 96 |
+
args_dict['model_name'] = 'domain_' + args_dict['domain']
|
| 97 |
+
args_dict['full_model_name'] = path.join(args_dict['model_path'],args_dict['model_name'])
|
| 98 |
+
args_dict['use_unlabeled_data'] = args.use_unlabeled_data
|
| 99 |
+
args_dict['use_labeled_data'] = args.use_labeled_data
|
| 100 |
+
print(args_dict['parser_path'])
|
| 101 |
+
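# Dispatch on the auxiliary task name: each branch calls the matching write_extra_labels helper to produce task-specific label files from the parser output.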
if args_dict['task'] == 'number_of_children':
|
| 102 |
+
args_dict['data_paths'] = write_extra_labels.add_number_of_children(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 103 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 104 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 105 |
+
elif args_dict['task'] == 'distance_from_the_root':
|
| 106 |
+
args_dict['data_paths'] = write_extra_labels.add_distance_from_the_root(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 107 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 108 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 109 |
+
elif args_dict['task'] == 'Multitask_label_predict':
|
| 110 |
+
args_dict['data_paths'] = write_extra_labels.Multitask_label_predict(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 111 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 112 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 113 |
+
elif args_dict['task'] == 'Multitask_coarse_predict':
|
| 114 |
+
args_dict['data_paths'] = write_extra_labels.Multitask_coarse_predict(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 115 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 116 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 117 |
+
elif args_dict['task'] == 'Multitask_POS_predict':
|
| 118 |
+
args_dict['data_paths'] = write_extra_labels.Multitask_POS_predict(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 119 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 120 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 121 |
+
elif args_dict['task'] == 'relative_pos_based':
|
| 122 |
+
args_dict['data_paths'] = write_extra_labels.add_relative_pos_based(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 123 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 124 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 125 |
+
elif args_dict['task'] == 'add_label':
|
| 126 |
+
args_dict['data_paths'] = write_extra_labels.add_label(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 127 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 128 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 129 |
+
elif args_dict['task'] == 'add_relative_TAG':
|
| 130 |
+
args_dict['data_paths'] = write_extra_labels.add_relative_TAG(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 131 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 132 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 133 |
+
elif args_dict['task'] == 'add_head_coarse_pos':
|
| 134 |
+
args_dict['data_paths'] = write_extra_labels.add_head_coarse_pos(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 135 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 136 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 137 |
+
elif args_dict['task'] == 'predict_ma_tag_of_modifier':
|
| 138 |
+
args_dict['data_paths'] = write_extra_labels.predict_ma_tag_of_modifier(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 139 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 140 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 141 |
+
elif args_dict['task'] == 'Multitask_case_predict':
|
| 142 |
+
args_dict['data_paths'] = write_extra_labels.Multitask_case_predict(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 143 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 144 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 145 |
+
elif args_dict['task'] == 'predict_coarse_of_modifier':
|
| 146 |
+
args_dict['data_paths'] = write_extra_labels.predict_coarse_of_modifier(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 147 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 148 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 149 |
+
elif args_dict['task'] == 'predict_case_of_modifier':
|
| 150 |
+
args_dict['data_paths'] = write_extra_labels.predict_case_of_modifier(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 151 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 152 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 153 |
+
elif args_dict['task'] == 'add_head_ma':
|
| 154 |
+
args_dict['data_paths'] = write_extra_labels.add_head_ma(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 155 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 156 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 157 |
+
elif args_dict['task'] == 'MRL_case':
|
| 158 |
+
args_dict['data_paths'] = write_extra_labels.MRL_case(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 159 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 160 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 161 |
+
elif args_dict['task'] == 'MRL_POS':
|
| 162 |
+
args_dict['data_paths'] = write_extra_labels.MRL_POS(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 163 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 164 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 165 |
+
elif args_dict['task'] == 'MRL_no':
|
| 166 |
+
args_dict['data_paths'] = write_extra_labels.MRL_no(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 167 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 168 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 169 |
+
elif args_dict['task'] == 'MRL_label':
|
| 170 |
+
args_dict['data_paths'] = write_extra_labels.MRL_label(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 171 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 172 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 173 |
+
else:  # default task: 'language_model'
|
| 174 |
+
args_dict['data_paths'] = write_extra_labels.add_language_model(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
|
| 175 |
+
use_unlabeled_data=args_dict['use_unlabeled_data'],
|
| 176 |
+
use_labeled_data=args_dict['use_labeled_data'])
|
| 177 |
+
args_dict['splits'] = args_dict['data_paths'].keys()
|
| 178 |
+
alphabet_data_paths = deepcopy(args_dict['data_paths'])
|
| 179 |
+
if args_dict['dataset'] == 'ontonotes':
|
| 180 |
+
data_path = 'data/onto_pos_ner_dp'
|
| 181 |
+
else:
|
| 182 |
+
data_path = 'data/ud_pos_ner_dp'
|
| 183 |
+
# Adding more resources to make sure equal alphabet size for all domains
|
| 184 |
+
for split in args_dict['splits']:
|
| 185 |
+
if args_dict['dataset'] == 'ontonotes':
|
| 186 |
+
alphabet_data_paths['additional_' + split] = data_path + '_' + split + '_' + 'all'
|
| 187 |
+
else:
|
| 188 |
+
if '_' in args_dict['domain']:
|
| 189 |
+
alphabet_data_paths[split] = data_path + '_' + split + '_' + args_dict['domain'].split('_')[0]
|
| 190 |
+
else:
|
| 191 |
+
alphabet_data_paths[split] = args_dict['data_paths'][split]
|
| 192 |
+
args_dict['alphabet_data_paths'] = alphabet_data_paths
|
| 193 |
+
args_dict['num_epochs'] = args.num_epochs
|
| 194 |
+
args_dict['batch_size'] = args.batch_size
|
| 195 |
+
args_dict['hidden_size'] = args.hidden_size
|
| 196 |
+
args_dict['tag_space'] = args.tag_space
|
| 197 |
+
args_dict['num_layers'] = args.num_layers
|
| 198 |
+
args_dict['num_filters'] = args.num_filters
|
| 199 |
+
args_dict['kernel_size'] = args.kernel_size
|
| 200 |
+
args_dict['learning_rate'] = args.learning_rate
|
| 201 |
+
args_dict['initializer'] = nn.init.xavier_uniform_ if args.initializer == 'xavier' else None
|
| 202 |
+
args_dict['opt'] = args.opt
|
| 203 |
+
args_dict['momentum'] = args.momentum
|
| 204 |
+
args_dict['betas'] = tuple(args.betas)
|
| 205 |
+
args_dict['epsilon'] = args.epsilon
|
| 206 |
+
args_dict['decay_rate'] = args.decay_rate
|
| 207 |
+
args_dict['clip'] = args.clip
|
| 208 |
+
args_dict['gamma'] = args.gamma
|
| 209 |
+
args_dict['schedule'] = args.schedule
|
| 210 |
+
args_dict['p_rnn'] = tuple(args.p_rnn)
|
| 211 |
+
args_dict['p_in'] = args.p_in
|
| 212 |
+
args_dict['p_out'] = args.p_out
|
| 213 |
+
args_dict['unk_replace'] = args.unk_replace
|
| 214 |
+
args_dict['punct_set'] = None
|
| 215 |
+
if args.punct_set is not None:
|
| 216 |
+
args_dict['punct_set'] = set(args.punct_set)
|
| 217 |
+
logger.info("punctuations(%d): %s" % (len(args_dict['punct_set']), ' '.join(args_dict['punct_set'])))
|
| 218 |
+
args_dict['freeze_word_embeddings'] = args.freeze_word_embeddings
|
| 219 |
+
args_dict['word_embedding'] = args.word_embedding
|
| 220 |
+
args_dict['word_path'] = args.word_path
|
| 221 |
+
args_dict['use_char'] = args.use_char
|
| 222 |
+
args_dict['char_embedding'] = args.char_embedding
|
| 223 |
+
args_dict['pos_embedding'] = args.pos_embedding
|
| 224 |
+
args_dict['char_path'] = args.char_path
|
| 225 |
+
args_dict['pos_path'] = args.pos_path
|
| 226 |
+
args_dict['use_pos'] = args.use_pos
|
| 227 |
+
args_dict['pos_dim'] = args.pos_dim
|
| 228 |
+
args_dict['word_dict'] = None
|
| 229 |
+
args_dict['word_dim'] = args.word_dim
|
| 230 |
+
if args_dict['word_embedding'] != 'random' and args_dict['word_path']:
|
| 231 |
+
args_dict['word_dict'], args_dict['word_dim'] = load_word_embeddings.load_embedding_dict(args_dict['word_embedding'],
|
| 232 |
+
args_dict['word_path'])
|
| 233 |
+
args_dict['char_dict'] = None
|
| 234 |
+
args_dict['char_dim'] = args.char_dim
|
| 235 |
+
if args_dict['char_embedding'] != 'random':
|
| 236 |
+
args_dict['char_dict'], args_dict['char_dim'] = load_word_embeddings.load_embedding_dict(args_dict['char_embedding'],
|
| 237 |
+
args_dict['char_path'])
|
| 238 |
+
args_dict['pos_dict'] = None
|
| 239 |
+
if args_dict['pos_embedding'] != 'random':
|
| 240 |
+
args_dict['pos_dict'], args_dict['pos_dim'] = load_word_embeddings.load_embedding_dict(args_dict['pos_embedding'],
|
| 241 |
+
args_dict['pos_path'])
|
| 242 |
+
args_dict['alphabet_path'] = path.join(args_dict['model_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/')
|
| 243 |
+
args_dict['alphabet_parser_path'] = path.join(args_dict['parser_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/')
|
| 244 |
+
args_dict['model_name'] = path.join(args_dict['model_path'], args_dict['model_name'])
|
| 245 |
+
args_dict['eval_mode'] = args.eval_mode
|
| 246 |
+
args_dict['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 247 |
+
args_dict['word_status'] = 'frozen' if args.freeze_word_embeddings else 'fine tune'
|
| 248 |
+
args_dict['char_status'] = 'enabled' if args.use_char else 'disabled'
|
| 249 |
+
args_dict['pos_status'] = 'enabled' if args.use_pos else 'disabled'
|
| 250 |
+
logger.info("Saving arguments to file")
|
| 251 |
+
save_args(args, args_dict['full_model_name'])
|
| 252 |
+
logger.info("Creating Alphabets")
|
| 253 |
+
alphabet_dict = creating_alphabets(args_dict['alphabet_path'], args_dict['alphabet_parser_path'], args_dict['alphabet_data_paths'])
|
| 254 |
+
args_dict = {**args_dict, **alphabet_dict}
|
| 255 |
+
ARGS = namedtuple('ARGS', args_dict.keys())
|
| 256 |
+
my_args = ARGS(**args_dict)
|
| 257 |
+
return my_args
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def creating_alphabets(alphabet_path, alphabet_parser_path, alphabet_data_paths):
|
| 261 |
+
data_paths_list = alphabet_data_paths.values()
|
| 262 |
+
alphabet_dict = {}
|
| 263 |
+
alphabet_dict['alphabets'] = prepare_data.create_alphabets_for_sequence_tagger(alphabet_path, alphabet_parser_path, data_paths_list)
|
| 264 |
+
for k, v in alphabet_dict['alphabets'].items():
|
| 265 |
+
num_key = 'num_' + k.split('_alphabet')[0]
|
| 266 |
+
alphabet_dict[num_key] = v.size()
|
| 267 |
+
logger.info("%s : %d" % (num_key, alphabet_dict[num_key]))
|
| 268 |
+
return alphabet_dict
|
| 269 |
+
|
| 270 |
+
def construct_embedding_table(alphabet, tokens_dict, dim, token_type='word'):
|
| 271 |
+
if tokens_dict is None:
|
| 272 |
+
return None
|
| 273 |
+
scale = np.sqrt(3.0 / dim)
|
| 274 |
+
table = np.empty([alphabet.size(), dim], dtype=np.float32)
|
| 275 |
+
table[prepare_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
|
| 276 |
+
oov_tokens = 0
|
| 277 |
+
for token, index in alphabet.items():
|
| 278 |
+
if token in tokens_dict:
|
| 279 |
+
embedding = tokens_dict[token]
|
| 280 |
+
elif token.lower() in tokens_dict:
|
| 281 |
+
embedding = tokens_dict[token.lower()]
|
| 282 |
+
else:
|
| 283 |
+
embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
|
| 284 |
+
oov_tokens += 1
|
| 285 |
+
table[index, :] = embedding
|
| 286 |
+
print('token type : %s, number of oov: %d' % (token_type, oov_tokens))
|
| 287 |
+
table = torch.from_numpy(table)
|
| 288 |
+
return table
|
| 289 |
+
|
| 290 |
+
def get_free_gpu():
|
| 291 |
+
system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free > tmp.txt')
|
| 292 |
+
memory_available = [int(x.split()[2]) for x in open('tmp.txt', 'r').readlines()]
|
| 293 |
+
remove("tmp.txt")
|
| 294 |
+
free_device = 'cuda:' + str(np.argmax(memory_available))
|
| 295 |
+
return free_device
|
| 296 |
+
|
| 297 |
+
def save_args(args, full_model_name):
|
| 298 |
+
arg_path = full_model_name + '.arg.json'
|
| 299 |
+
argparse_dict = vars(args)
|
| 300 |
+
with open(arg_path, 'w') as f:
|
| 301 |
+
json.dump(argparse_dict, f)
|
| 302 |
+
|
| 303 |
+
def generate_optimizer(args, lr, params):
|
| 304 |
+
params = filter(lambda param: param.requires_grad, params)
|
| 305 |
+
if args.opt == 'adam':
|
| 306 |
+
return Adam(params, lr=lr, betas=args.betas, weight_decay=args.gamma, eps=args.epsilon)
|
| 307 |
+
elif args.opt == 'sgd':
|
| 308 |
+
return SGD(params, lr=lr, momentum=args.momentum, weight_decay=args.gamma, nesterov=True)
|
| 309 |
+
else:
|
| 310 |
+
raise ValueError('Unknown optimization algorithm: %s' % args.opt)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def save_checkpoint(model, optimizer, opt, dev_eval_dict, test_eval_dict, full_model_name):
|
| 314 |
+
path_name = full_model_name + '.pt'
|
| 315 |
+
print('Saving model to: %s' % path_name)
|
| 316 |
+
state = {'model_state_dict': model.state_dict(),
|
| 317 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 318 |
+
'opt': opt, 'dev_eval_dict': dev_eval_dict, 'test_eval_dict': test_eval_dict}
|
| 319 |
+
torch.save(state, path_name)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def load_checkpoint(args, model, optimizer, dev_eval_dict, test_eval_dict, start_epoch, load_path, strict=True):
|
| 323 |
+
print('Loading saved model from: %s' % load_path)
|
| 324 |
+
checkpoint = torch.load(load_path, map_location=args.device)
|
| 325 |
+
if checkpoint['opt'] != args.opt:
|
| 326 |
+
raise ValueError('loaded optimizer type is: %s instead of: %s' % (checkpoint['opt'], args.opt))
|
| 327 |
+
model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
|
| 328 |
+
if strict:
|
| 329 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 330 |
+
dev_eval_dict = checkpoint['dev_eval_dict']
|
| 331 |
+
test_eval_dict = checkpoint['test_eval_dict']
|
| 332 |
+
start_epoch = dev_eval_dict['in_domain']['epoch']
|
| 333 |
+
return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def build_model_and_optimizer(args):
|
| 337 |
+
word_table = construct_embedding_table(args.alphabets['word_alphabet'], args.word_dict, args.word_dim, token_type='word')
|
| 338 |
+
char_table = construct_embedding_table(args.alphabets['char_alphabet'], args.char_dict, args.char_dim, token_type='char')
|
| 339 |
+
pos_table = construct_embedding_table(args.alphabets['pos_alphabet'], args.pos_dict, args.pos_dim, token_type='pos')
|
| 340 |
+
model = Sequence_Tagger(args.word_dim, args.num_word, args.char_dim, args.num_char,
|
| 341 |
+
args.use_pos, args.use_char, args.pos_dim, args.num_pos,
|
| 342 |
+
args.num_filters, args.kernel_size, args.rnn_mode,
|
| 343 |
+
args.hidden_size, args.num_layers, args.tag_space, args.num_auto_label,
|
| 344 |
+
embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table,
|
| 345 |
+
p_in=args.p_in, p_out=args.p_out, p_rnn=args.p_rnn,
|
| 346 |
+
initializer=args.initializer)
|
| 347 |
+
optimizer = generate_optimizer(args, args.learning_rate, model.parameters())
|
| 348 |
+
start_epoch = 0
|
| 349 |
+
dev_eval_dict = {'in_domain': initialize_eval_dict()}
|
| 350 |
+
test_eval_dict = {'in_domain': initialize_eval_dict()}
|
| 351 |
+
if args.load_path:
|
| 352 |
+
model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = \
|
| 353 |
+
load_checkpoint(args, model, optimizer,
|
| 354 |
+
dev_eval_dict, test_eval_dict,
|
| 355 |
+
start_epoch, args.load_path, strict=args.strict)
|
| 356 |
+
if args.freeze_word_embeddings:
|
| 357 |
+
model.rnn_encoder.word_embedd.weight.requires_grad = False
|
| 358 |
+
# model.rnn_encoder.char_embedd.weight.requires_grad = False
|
| 359 |
+
# model.rnn_encoder.pos_embedd.weight.requires_grad = False
|
| 360 |
+
device = args.device
|
| 361 |
+
model.to(device)
|
| 362 |
+
return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def initialize_eval_dict():
|
| 366 |
+
eval_dict = {}
|
| 367 |
+
eval_dict['auto_label_accuracy'] = 0.0
|
| 368 |
+
eval_dict['auto_label_precision'] = 0.0
|
| 369 |
+
eval_dict['auto_label_recall'] = 0.0
|
| 370 |
+
eval_dict['auto_label_f1'] = 0.0
|
| 371 |
+
return eval_dict
|
| 372 |
+
|
| 373 |
+
def in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch,
|
| 374 |
+
best_model, best_optimizer, patient):
|
| 375 |
+
# In-domain evaluation
|
| 376 |
+
curr_dev_eval_dict = evaluation(args, datasets['dev'], 'dev', model, args.domain, epoch, 'current_results')
|
| 377 |
+
is_best_in_domain = dev_eval_dict['in_domain']['auto_label_f1'] <= curr_dev_eval_dict['auto_label_f1']
|
| 378 |
+
|
| 379 |
+
if is_best_in_domain:
|
| 380 |
+
for key, value in curr_dev_eval_dict.items():
|
| 381 |
+
dev_eval_dict['in_domain'][key] = value
|
| 382 |
+
curr_test_eval_dict = evaluation(args, datasets['test'], 'test', model, args.domain, epoch, 'current_results')
|
| 383 |
+
for key, value in curr_test_eval_dict.items():
|
| 384 |
+
test_eval_dict['in_domain'][key] = value
|
| 385 |
+
best_model = deepcopy(model)
|
| 386 |
+
best_optimizer = deepcopy(optimizer)
|
| 387 |
+
patient = 0
|
| 388 |
+
else:
|
| 389 |
+
patient += 1
|
| 390 |
+
if epoch == args.num_epochs:
|
| 391 |
+
# save in-domain checkpoint
|
| 392 |
+
for split in ['dev', 'test']:
|
| 393 |
+
eval_dict = dev_eval_dict['in_domain'] if split == 'dev' else test_eval_dict['in_domain']
|
| 394 |
+
write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)
|
| 395 |
+
save_checkpoint(best_model, best_optimizer, args.opt, dev_eval_dict, test_eval_dict, args.full_model_name)
|
| 396 |
+
|
| 397 |
+
print('\n')
|
| 398 |
+
return dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient, curr_dev_eval_dict
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def evaluation(args, data, split, model, domain, epoch, str_res='results'):
|
| 402 |
+
# evaluate performance on data
|
| 403 |
+
model.eval()
|
| 404 |
+
auto_label_idx2inst = Index2Instance(args.alphabets['auto_label_alphabet'])
|
| 405 |
+
eval_dict = initialize_eval_dict()
|
| 406 |
+
eval_dict['epoch'] = epoch
|
| 407 |
+
pred_labels = []
|
| 408 |
+
gold_labels = []
|
| 409 |
+
for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
|
| 410 |
+
word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
|
| 411 |
+
output, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
|
| 412 |
+
auto_label_preds = model.decode(output, mask=masks, length=lengths, leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
|
| 413 |
+
lengths = lengths.cpu().numpy()
|
| 414 |
+
word = word.data.cpu().numpy()
|
| 415 |
+
pos = pos.data.cpu().numpy()
|
| 416 |
+
ner = ner.data.cpu().numpy()
|
| 417 |
+
heads = heads.data.cpu().numpy()
|
| 418 |
+
arc_tags = arc_tags.data.cpu().numpy()
|
| 419 |
+
auto_label = auto_label.data.cpu().numpy()
|
| 420 |
+
auto_label_preds = auto_label_preds.data.cpu().numpy()
|
| 421 |
+
gold_labels += auto_label_idx2inst.index2instance(auto_label, lengths, symbolic_root=True)
|
| 422 |
+
pred_labels += auto_label_idx2inst.index2instance(auto_label_preds, lengths, symbolic_root=True)
|
| 423 |
+
|
| 424 |
+
eval_dict['auto_label_accuracy'] = accuracy_score(gold_labels, pred_labels) * 100
|
| 425 |
+
eval_dict['auto_label_precision'] = precision_score(gold_labels, pred_labels) * 100
|
| 426 |
+
eval_dict['auto_label_recall'] = recall_score(gold_labels, pred_labels) * 100
|
| 427 |
+
eval_dict['auto_label_f1'] = f1_score(gold_labels, pred_labels) * 100
|
| 428 |
+
eval_dict['classification_report'] = classification_report(gold_labels, pred_labels)
|
| 429 |
+
print_results(eval_dict, split, domain, str_res)
|
| 430 |
+
return eval_dict
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def print_results(eval_dict, split, domain, str_res='results'):
|
| 434 |
+
print('----------------------------------------------------------------------------------------------------------------------------')
|
| 435 |
+
print('Testing model on domain %s' % domain)
|
| 436 |
+
print('--------------- sequence_tagger - %s ---------------' % split)
|
| 437 |
+
print(
|
| 438 |
+
str_res + ' on ' + split + ' accuracy: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)'
|
| 439 |
+
% (eval_dict['auto_label_accuracy'], eval_dict['auto_label_precision'], eval_dict['auto_label_recall'], eval_dict['auto_label_f1'],
|
| 440 |
+
eval_dict['epoch']))
|
| 441 |
+
print(eval_dict['classification_report'])
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def write_results(args, data, data_domain, split, model, model_domain, eval_dict):
|
| 445 |
+
str_file = args.full_model_name + '_' + split + '_model_domain_' + model_domain + '_data_domain_' + data_domain
|
| 446 |
+
res_filename = str_file + '_res.txt'
|
| 447 |
+
pred_filename = str_file + '_pred.txt'
|
| 448 |
+
gold_filename = str_file + '_gold.txt'
|
| 449 |
+
|
| 450 |
+
# save results dictionary into a file
|
| 451 |
+
with open(res_filename, 'w') as f:
|
| 452 |
+
json.dump(eval_dict, f)
|
| 453 |
+
|
| 454 |
+
# save predictions and gold labels into files
|
| 455 |
+
pred_writer = Writer(args.alphabets)
|
| 456 |
+
gold_writer = Writer(args.alphabets)
|
| 457 |
+
pred_writer.start(pred_filename)
|
| 458 |
+
gold_writer.start(gold_filename)
|
| 459 |
+
for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
|
| 460 |
+
word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
|
| 461 |
+
output, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
|
| 462 |
+
auto_label_preds = model.decode(output, mask=masks, length=lengths,
|
| 463 |
+
leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
|
| 464 |
+
lengths = lengths.cpu().numpy()
|
| 465 |
+
word = word.data.cpu().numpy()
|
| 466 |
+
pos = pos.data.cpu().numpy()
|
| 467 |
+
ner = ner.data.cpu().numpy()
|
| 468 |
+
heads = heads.data.cpu().numpy()
|
| 469 |
+
arc_tags = arc_tags.data.cpu().numpy()
|
| 470 |
+
auto_label_preds = auto_label_preds.data.cpu().numpy()
|
| 471 |
+
# writing predictions
|
| 472 |
+
pred_writer.write(word, pos, ner, heads, arc_tags, lengths, auto_label=auto_label_preds, symbolic_root=True)
|
| 473 |
+
# writing gold labels
|
| 474 |
+
gold_writer.write(word, pos, ner, heads, arc_tags, lengths, auto_label=auto_label, symbolic_root=True)
|
| 475 |
+
|
| 476 |
+
pred_writer.close()
|
| 477 |
+
gold_writer.close()
|
| 478 |
+
|
| 479 |
+
def main():
|
| 480 |
+
logger.info("Reading and creating arguments")
|
| 481 |
+
args = read_arguments()
|
| 482 |
+
logger.info("Reading Data")
|
| 483 |
+
datasets = {}
|
| 484 |
+
for split in args.splits:
|
| 485 |
+
dataset = prepare_data.read_data_to_variable(args.data_paths[split], args.alphabets, args.device,
|
| 486 |
+
symbolic_root=True)
|
| 487 |
+
datasets[split] = dataset
|
| 488 |
+
|
| 489 |
+
logger.info("Creating Networks")
|
| 490 |
+
num_data = sum(datasets['train'][1])
|
| 491 |
+
model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = build_model_and_optimizer(args)
|
| 492 |
+
best_model = deepcopy(model)
|
| 493 |
+
best_optimizer = deepcopy(optimizer)
|
| 494 |
+
logger.info('Training INFO of in domain %s' % args.domain)
|
| 495 |
+
logger.info('Training on Dependecy Parsing')
|
| 496 |
+
print(model)
|
| 497 |
+
logger.info("train: gamma: %f, batch: %d, clip: %.2f, unk replace: %.2f" % (args.gamma, args.batch_size, args.clip, args.unk_replace))
|
| 498 |
+
logger.info('number of training samples for %s is: %d' % (args.domain, num_data))
|
| 499 |
+
logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (args.p_in, args.p_out, args.p_rnn))
|
| 500 |
+
logger.info("num_epochs: %d" % (args.num_epochs))
|
| 501 |
+
print('\n')
|
| 502 |
+
|
| 503 |
+
if not args.eval_mode:
|
| 504 |
+
logger.info("Training")
|
| 505 |
+
num_batches = prepare_data.calc_num_batches(datasets['train'], args.batch_size)
|
| 506 |
+
lr = args.learning_rate
|
| 507 |
+
patient = 0
|
| 508 |
+
terminal_patient = 0
|
| 509 |
+
decay = 0
|
| 510 |
+
for epoch in range(start_epoch + 1, args.num_epochs + 1):
|
| 511 |
+
print('Epoch %d (Training: rnn mode: %s, optimizer: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, decay=%d)): ' % (
|
| 512 |
+
epoch, args.rnn_mode, args.opt, lr, args.epsilon, args.decay_rate, args.schedule, decay))
|
| 513 |
+
model.train()
|
| 514 |
+
total_loss = 0.0
|
| 515 |
+
total_train_inst = 0.0
|
| 516 |
+
|
| 517 |
+
iter = prepare_data.iterate_batch_rand_bucket_choosing(
|
| 518 |
+
datasets['train'], args.batch_size, args.device, unk_replace=args.unk_replace)
|
| 519 |
+
start_time = time.time()
|
| 520 |
+
batch_num = 0
|
| 521 |
+
for batch_num, batch in enumerate(iter):
|
| 522 |
+
batch_num = batch_num + 1
|
| 523 |
+
optimizer.zero_grad()
|
| 524 |
+
# compute loss of main task
|
| 525 |
+
word, char, pos, ner_tags, heads, arc_tags, auto_label, masks, lengths = batch
|
| 526 |
+
output, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
|
| 527 |
+
loss = model.loss(output, auto_label, mask=masks, length=lengths)
|
| 528 |
+
|
| 529 |
+
# update losses
|
| 530 |
+
num_insts = masks.data.sum() - word.size(0)
|
| 531 |
+
total_loss += loss.item() * num_insts
|
| 532 |
+
total_train_inst += num_insts
|
| 533 |
+
# optimize parameters
|
| 534 |
+
loss.backward()
|
| 535 |
+
clip_grad_norm_(model.parameters(), args.clip)
|
| 536 |
+
optimizer.step()
|
| 537 |
+
|
| 538 |
+
time_ave = (time.time() - start_time) / batch_num
|
| 539 |
+
time_left = (num_batches - batch_num) * time_ave
|
| 540 |
+
|
| 541 |
+
# update log
|
| 542 |
+
if batch_num % 50 == 0:
|
| 543 |
+
log_info = 'train: %d/%d, domain: %s, total loss: %.2f, time left: %.2fs' % \
|
| 544 |
+
(batch_num, num_batches, args.domain, total_loss / total_train_inst, time_left)
|
| 545 |
+
sys.stdout.write(log_info)
|
| 546 |
+
sys.stdout.write('\n')
|
| 547 |
+
sys.stdout.flush()
|
| 548 |
+
print('\n')
|
| 549 |
+
print('train: %d/%d, domain: %s, total_loss: %.2f, time: %.2fs' %
|
| 550 |
+
(batch_num, num_batches, args.domain, total_loss / total_train_inst, time.time() - start_time))
|
| 551 |
+
|
| 552 |
+
dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient,curr_dev_eval_dict = in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, best_model, best_optimizer, patient)
|
| 553 |
+
store ={'dev_eval_dict':curr_dev_eval_dict }
|
| 554 |
+
#############################################
|
| 555 |
+
str_file = args.full_model_name + '_' +'all_epochs'
|
| 556 |
+
with open(str_file,'a') as f:
|
| 557 |
+
f.write(str(store)+'\n')
|
| 558 |
+
if patient == 0:
|
| 559 |
+
terminal_patient = 0
|
| 560 |
+
else:
|
| 561 |
+
terminal_patient += 1
|
| 562 |
+
if terminal_patient >= 4 * args.schedule:
|
| 563 |
+
# Save best model and terminate learning
|
| 564 |
+
cur_epoch = epoch
|
| 565 |
+
epoch = args.num_epochs
|
| 566 |
+
in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, best_model,
|
| 567 |
+
best_optimizer, patient)
|
| 568 |
+
log_info = 'Terminating training in epoch %d' % (cur_epoch)
|
| 569 |
+
sys.stdout.write(log_info)
|
| 570 |
+
sys.stdout.write('\n')
|
| 571 |
+
sys.stdout.flush()
|
| 572 |
+
return
|
| 573 |
+
if patient >= args.schedule:
|
| 574 |
+
lr = args.learning_rate / (1.0 + epoch * args.decay_rate)
|
| 575 |
+
optimizer = generate_optimizer(args, lr, model.parameters())
|
| 576 |
+
print('updated learning rate to %.6f' % lr)
|
| 577 |
+
patient = 0
|
| 578 |
+
print_results(test_eval_dict['in_domain'], 'test', args.domain, 'best_results')
|
| 579 |
+
print('\n')
|
| 580 |
+
|
| 581 |
+
else:
|
| 582 |
+
logger.info("Evaluating")
|
| 583 |
+
epoch = start_epoch
|
| 584 |
+
for split in ['train', 'dev', 'test']:
|
| 585 |
+
evaluation(args, datasets[split], split, model, args.domain, epoch, 'best_results')
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
if __name__ == '__main__':
|
| 589 |
+
main()
|
examples/VST_Pred_Prepare.py
ADDED
|
@@ -0,0 +1,34 @@
import sys

def write_combined(dirs):
    path = "./saved_models/"+dirs+'/final_ensembled_TranSeq/'
    f = open(path+'domain_VST_test_model_domain_VST_data_domain_VST_gold.txt','r')
    gold = f.readlines()
    f.close()
    f = open(path+'domain_VST_test_model_domain_VST_data_domain_VST_pred.txt','r')
    pred = f.readlines()
    f.close()

    for i in range(len(gold)):
        if gold[i] == '\n':
            continue
        if gold[i].split('\t')[0] == pred[i].split('\t')[0]:
            gold[i] = gold[i].replace('\n','\t')
            gold[i] = gold[i]+'\t'.join(pred[i].split('\t')[-2:])

    gold.insert(0,'word_id\tword\tpostag\tlemma\tgold_head\tgold_label\tpred_head\tpred_label\n\n')

    f = open(path+'VST_test.txt','w')
    for line in gold:
        f.write(line)
    f.close()


if __name__=="__main__":

    dir_path = sys.argv[1]

    write_combined(dir_path)
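Usage sketch for the script above (the run-directory name `my_vst_run` is illustrative; any directory under `./saved_models/` containing `final_ensembled_TranSeq/domain_VST_test_model_domain_VST_data_domain_VST_gold.txt` and the matching `_pred.txt` should work):

```
# merge gold and predicted parses into one tab-separated file, VST_test.txt
python examples/VST_Pred_Prepare.py my_vst_run
```

The merged file is written next to the inputs with the columns word_id, word, postag, lemma, gold_head, gold_label, pred_head, pred_label.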
examples/VST_macro_score.py
ADDED
|
@@ -0,0 +1,107 @@
import sys


def load_results(filename):

    results = []
    sent = []
    with open(filename, 'r') as fp:
        for i, line in enumerate(fp):
            if i == 0:
                continue
            splits = line.strip().split('\t')
            if len(line.strip()) == 0:
                if len(sent) != 0:
                    results.append(sent)
                    sent = []
                continue
            gold_head = splits[-4]
            gold_label = splits[-3]
            pred_head = splits[-2]
            pred_label = splits[-1]
            sent.append((gold_head, gold_label, pred_head, pred_label))
    print('Total Number of sentences ' + str(len(results)))
    return results

def calculate_las_uas(gold_heads, gold_labels, pred_heads, pred_labels):

    u_correct = 0
    l_correct = 0
    u_total = 0
    l_total = 0

    for i in range(len(gold_heads)):
        if gold_heads[i] == pred_heads[i]:
            u_correct +=1
        u_total +=1
        l_total +=1
        if gold_heads[i] == pred_heads[i] and gold_labels[i] == pred_labels[i]:
            l_correct +=1
    return u_correct, u_total, l_correct, l_total


def calculate_stats(results,path):
    u_correct = 0
    l_correct = 0
    u_total = 0
    l_total = 0

    sent_uas = []
    sent_las = []

    for i in range(len(results)):
        gold_heads, gold_labels, pred_heads, pred_labels = zip(*results[i])
        u_c, u_t, l_c, l_t = calculate_las_uas(gold_heads, gold_labels, pred_heads, pred_labels)
        if u_t >0:
            uas = float(u_c)/u_t
            las = float(l_c)/l_t
            sent_uas.append(uas)
            sent_las.append(las)
        u_correct += u_c
        l_correct += l_c
        u_total += u_t
        l_total += l_t

    UAS = float(u_correct)/u_total
    LAS = float(l_correct)/l_total
    path = path.replace('VST_test.txt','Macro-UAS-LAS-score.txt')
    f = open(path,'w')
    f.write('Word level UAS : ' + str(UAS) +'\n')
    f.write('Word level LAS : ' + str(LAS)+'\n')
    f.write('Sentence level UAS : ' + str(float(sum(sent_uas))/len(sent_uas))+'\n')
    f.write('Sentence level LAS : ' + str(float(sum(sent_las))/len(sent_las))+'\n')
    f.close()
    print('Word level UAS : ' + str(UAS))
    print('Word level LAS : ' + str(LAS))
    print('Sentence level UAS : ' + str(float(sum(sent_uas))/len(sent_uas)))
    print('Sentence level LAS : ' + str(float(sum(sent_las))/len(sent_las)))

    return sent_uas, sent_las, UAS, LAS

def write_results(sent_uas, sent_las, filename_uas, filename_las):

    fp_uas = open(filename_uas, 'w')
    fp_las = open(filename_las, 'w')

    for i in range(len(sent_uas)):
        fp_uas.write(str(sent_uas[i]) + '\n')
        fp_las.write(str(sent_las[i]) + '\n')

    fp_uas.close()
    fp_las.close()


if __name__=="__main__":
    dirs = sys.argv[1]
    # results_2 = load_results(sys.argv[2])
    ##path = "Predictions/Yap/"+dirs
    path = "./saved_models/"+dirs+"/final_ensembled_TranSeq/VST_test.txt"
    result = load_results(path)

    sent_uas1, sent_las1, UAS1, LAS1 = calculate_stats(result,path)
    # sent_uas2, sent_las2, UAS2, LAS2 = calculate_stats(results_2)

    write_results(sent_uas1, sent_las1, 'results1_uas.txt', 'results1_las.txt')
    # write_results(sent_uas2, sent_las2, 'results2_uas.txt', 'results2_las.txt')
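A matching usage sketch for the scoring script above, run after `VST_Pred_Prepare.py` with the same illustrative run-directory name:

```
# compute word-level and sentence-level (macro) UAS/LAS from VST_test.txt
python examples/VST_macro_score.py my_vst_run
```

The scores are printed and written to `Macro-UAS-LAS-score.txt` beside `VST_test.txt`; per-sentence scores go to `results1_uas.txt` and `results1_las.txt` in the working directory.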
examples/eval/conll03eval.v2
ADDED
|
@@ -0,0 +1,336 @@
| 1 |
+
#!/usr/bin/perl -w
|
| 2 |
+
# conlleval: evaluate result of processing CoNLL-2000 shared task
|
| 3 |
+
# usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file
|
| 4 |
+
# README: http://cnts.uia.ac.be/conll2000/chunking/output.html
|
| 5 |
+
# options: l: generate LaTeX output for tables like in
|
| 6 |
+
# http://cnts.uia.ac.be/conll2003/ner/example.tex
|
| 7 |
+
# r: accept raw result tags (without B- and I- prefix;
|
| 8 |
+
# assumes one word per chunk)
|
| 9 |
+
# d: alternative delimiter tag (default is single space)
|
| 10 |
+
# o: alternative outside tag (default is O)
|
| 11 |
+
# note: the file should contain lines with items separated
|
| 12 |
+
# by $delimiter characters (default space). The final
|
| 13 |
+
# two items should contain the correct tag and the
|
| 14 |
+
# guessed tag in that order. Sentences should be
|
| 15 |
+
# separated from each other by empty lines or lines
|
| 16 |
+
# with $boundary fields (default -X-).
|
| 17 |
+
# url: http://lcg-www.uia.ac.be/conll2000/chunking/
|
| 18 |
+
# started: 1998-09-25
|
| 19 |
+
# version: 2004-01-26
|
| 20 |
+
# author: Erik Tjong Kim Sang <erikt@uia.ua.ac.be>
|
| 21 |
+
|
| 22 |
+
use strict;
|
| 23 |
+
|
| 24 |
+
my $false = 0;
|
| 25 |
+
my $true = 42;
|
| 26 |
+
|
| 27 |
+
my $boundary = "-X-"; # sentence boundary
|
| 28 |
+
my $correct; # current corpus chunk tag (I,O,B)
|
| 29 |
+
my $correctChunk = 0; # number of correctly identified chunks
|
| 30 |
+
my $correctTags = 0; # number of correct chunk tags
|
| 31 |
+
my $correctType; # type of current corpus chunk tag (NP,VP,etc.)
|
| 32 |
+
my $delimiter = " "; # field delimiter
|
| 33 |
+
my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979)
|
| 34 |
+
my $firstItem; # first feature (for sentence boundary checks)
|
| 35 |
+
my $foundCorrect = 0; # number of chunks in corpus
|
| 36 |
+
my $foundGuessed = 0; # number of identified chunks
|
| 37 |
+
my $guessed; # current guessed chunk tag
|
| 38 |
+
my $guessedType; # type of current guessed chunk tag
|
| 39 |
+
my $i; # miscellaneous counter
|
| 40 |
+
my $inCorrect = $false; # currently processed chunk is correct until now
|
| 41 |
+
my $lastCorrect = "O"; # previous chunk tag in corpus
|
| 42 |
+
my $latex = 0; # generate LaTeX formatted output
|
| 43 |
+
my $lastCorrectType = ""; # type of previously identified chunk tag
|
| 44 |
+
my $lastGuessed = "O"; # previously identified chunk tag
|
| 45 |
+
my $lastGuessedType = ""; # type of previous chunk tag in corpus
|
| 46 |
+
my $lastType; # temporary storage for detecting duplicates
|
| 47 |
+
my $line; # line
|
| 48 |
+
my $nbrOfFeatures = -1; # number of features per line
|
| 49 |
+
my $precision = 0.0; # precision score
|
| 50 |
+
my $oTag = "O"; # outside tag, default O
|
| 51 |
+
my $raw = 0; # raw input: add B to every token
|
| 52 |
+
my $recall = 0.0; # recall score
|
| 53 |
+
my $tokenCounter = 0; # token counter (ignores sentence breaks)
|
| 54 |
+
|
| 55 |
+
my %correctChunk = (); # number of correctly identified chunks per type
|
| 56 |
+
my %foundCorrect = (); # number of chunks in corpus per type
|
| 57 |
+
my %foundGuessed = (); # number of identified chunks per type
|
| 58 |
+
|
| 59 |
+
my @features; # features on line
|
| 60 |
+
my @sortedTypes; # sorted list of chunk type names
|
| 61 |
+
|
| 62 |
+
# sanity check
|
| 63 |
+
while (@ARGV and $ARGV[0] =~ /^-/) {
|
| 64 |
+
if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); }
|
| 65 |
+
elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); }
|
| 66 |
+
elsif ($ARGV[0] eq "-d") {
|
| 67 |
+
shift(@ARGV);
|
| 68 |
+
if (not defined $ARGV[0]) {
|
| 69 |
+
die "conlleval: -d requires delimiter character";
|
| 70 |
+
}
|
| 71 |
+
$delimiter = shift(@ARGV);
|
| 72 |
+
} elsif ($ARGV[0] eq "-o") {
|
| 73 |
+
shift(@ARGV);
|
| 74 |
+
if (not defined $ARGV[0]) {
|
| 75 |
+
die "conlleval: -o requires delimiter character";
|
| 76 |
+
}
|
| 77 |
+
$oTag = shift(@ARGV);
|
| 78 |
+
} else { die "conlleval: unknown argument $ARGV[0]\n"; }
|
| 79 |
+
}
|
| 80 |
+
if (@ARGV) { die "conlleval: unexpected command line argument\n"; }
|
| 81 |
+
# process input
|
| 82 |
+
while (<STDIN>) {
|
| 83 |
+
chomp($line = $_);
|
| 84 |
+
@features = split(/$delimiter/,$line);
|
| 85 |
+
if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; }
|
| 86 |
+
elsif ($nbrOfFeatures != $#features and @features != 0) {
|
| 87 |
+
printf STDERR "unexpected number of features: %d (%d)\n",
|
| 88 |
+
$#features+1,$nbrOfFeatures+1;
|
| 89 |
+
exit(1);
|
| 90 |
+
}
|
| 91 |
+
if (@features == 0 or
|
| 92 |
+
$features[0] eq $boundary) { @features = ($boundary,"O","O"); }
|
| 93 |
+
if (@features < 2) {
|
| 94 |
+
die "conlleval: unexpected number of features in line $line\n";
|
| 95 |
+
}
|
| 96 |
+
if ($raw) {
|
| 97 |
+
if ($features[$#features] eq $oTag) { $features[$#features] = "O"; }
|
| 98 |
+
if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; }
|
| 99 |
+
if ($features[$#features] ne "O") {
|
| 100 |
+
$features[$#features] = "B-$features[$#features]";
|
| 101 |
+
}
|
| 102 |
+
if ($features[$#features-1] ne "O") {
|
| 103 |
+
$features[$#features-1] = "B-$features[$#features-1]";
|
| 104 |
+
}
|
| 105 |
+
}
|
| 106 |
+
# 20040126 ET code which allows hyphens in the types
|
| 107 |
+
if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
|
| 108 |
+
$guessed = $1;
|
| 109 |
+
$guessedType = $2;
|
| 110 |
+
} else {
|
| 111 |
+
$guessed = $features[$#features];
|
| 112 |
+
$guessedType = "";
|
| 113 |
+
}
|
| 114 |
+
pop(@features);
|
| 115 |
+
if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
|
| 116 |
+
$correct = $1;
|
| 117 |
+
$correctType = $2;
|
| 118 |
+
} else {
|
| 119 |
+
$correct = $features[$#features];
|
| 120 |
+
$correctType = "";
|
| 121 |
+
}
|
| 122 |
+
pop(@features);
|
| 123 |
+
# ($guessed,$guessedType) = split(/-/,pop(@features));
|
| 124 |
+
# ($correct,$correctType) = split(/-/,pop(@features));
|
| 125 |
+
$guessedType = $guessedType ? $guessedType : "";
|
| 126 |
+
$correctType = $correctType ? $correctType : "";
|
| 127 |
+
$firstItem = shift(@features);
|
| 128 |
+
|
| 129 |
+
# 1999-06-26 sentence breaks should always be counted as out of chunk
|
| 130 |
+
if ( $firstItem eq $boundary ) { $guessed = "O"; }
|
| 131 |
+
|
| 132 |
+
if ($inCorrect) {
|
| 133 |
+
if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
|
| 134 |
+
&endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
|
| 135 |
+
$lastGuessedType eq $lastCorrectType) {
|
| 136 |
+
$inCorrect=$false;
|
| 137 |
+
$correctChunk++;
|
| 138 |
+
$correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
|
| 139 |
+
$correctChunk{$lastCorrectType}+1 : 1;
|
| 140 |
+
} elsif (
|
| 141 |
+
&endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) !=
|
| 142 |
+
&endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or
|
| 143 |
+
$guessedType ne $correctType ) {
|
| 144 |
+
$inCorrect=$false;
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
|
| 149 |
+
&startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
|
| 150 |
+
$guessedType eq $correctType) { $inCorrect = $true; }
|
| 151 |
+
|
| 152 |
+
if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) {
|
| 153 |
+
$foundCorrect++;
|
| 154 |
+
$foundCorrect{$correctType} = $foundCorrect{$correctType} ?
|
| 155 |
+
$foundCorrect{$correctType}+1 : 1;
|
| 156 |
+
}
|
| 157 |
+
if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) {
|
| 158 |
+
$foundGuessed++;
|
| 159 |
+
$foundGuessed{$guessedType} = $foundGuessed{$guessedType} ?
|
| 160 |
+
$foundGuessed{$guessedType}+1 : 1;
|
| 161 |
+
}
|
| 162 |
+
if ( $firstItem ne $boundary ) {
|
| 163 |
+
if ( $correct eq $guessed and $guessedType eq $correctType ) {
|
| 164 |
+
$correctTags++;
|
| 165 |
+
}
|
| 166 |
+
$tokenCounter++;
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
$lastGuessed = $guessed;
|
| 170 |
+
$lastCorrect = $correct;
|
| 171 |
+
$lastGuessedType = $guessedType;
|
| 172 |
+
$lastCorrectType = $correctType;
|
| 173 |
+
}
|
| 174 |
+
if ($inCorrect) {
|
| 175 |
+
$correctChunk++;
|
| 176 |
+
$correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
|
| 177 |
+
$correctChunk{$lastCorrectType}+1 : 1;
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
if (not $latex) {
|
| 181 |
+
# compute overall precision, recall and FB1 (default values are 0.0)
|
| 182 |
+
$precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
|
| 183 |
+
$recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
|
| 184 |
+
$FB1 = 2*$precision*$recall/($precision+$recall)
|
| 185 |
+
if ($precision+$recall > 0);
|
| 186 |
+
|
| 187 |
+
# print overall performance
|
| 188 |
+
printf "processed $tokenCounter tokens with $foundCorrect phrases; ";
|
| 189 |
+
printf "found: $foundGuessed phrases; correct: $correctChunk.\n";
|
| 190 |
+
if ($tokenCounter>0) {
|
| 191 |
+
printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter;
|
| 192 |
+
printf "precision: %6.2f%%; ",$precision;
|
| 193 |
+
printf "recall: %6.2f%%; ",$recall;
|
| 194 |
+
printf "FB1: %6.2f\n",$FB1;
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
# sort chunk type names
|
| 199 |
+
undef($lastType);
|
| 200 |
+
@sortedTypes = ();
|
| 201 |
+
foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) {
|
| 202 |
+
if (not($lastType) or $lastType ne $i) {
|
| 203 |
+
push(@sortedTypes,($i));
|
| 204 |
+
}
|
| 205 |
+
$lastType = $i;
|
| 206 |
+
}
|
| 207 |
+
# print performance per chunk type
|
| 208 |
+
if (not $latex) {
|
| 209 |
+
for $i (@sortedTypes) {
|
| 210 |
+
$correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
|
| 211 |
+
if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; }
|
| 212 |
+
else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
|
| 213 |
+
if (not($foundCorrect{$i})) { $recall = 0.0; }
|
| 214 |
+
else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
|
| 215 |
+
if ($precision+$recall == 0.0) { $FB1 = 0.0; }
|
| 216 |
+
else { $FB1 = 2*$precision*$recall/($precision+$recall); }
|
| 217 |
+
printf "%17s: ",$i;
|
| 218 |
+
printf "precision: %6.2f%%; ",$precision;
|
| 219 |
+
printf "recall: %6.2f%%; ",$recall;
|
| 220 |
+
printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i};
|
| 221 |
+
}
|
| 222 |
+
} else {
|
| 223 |
+
print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline";
|
| 224 |
+
for $i (@sortedTypes) {
|
| 225 |
+
$correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
|
| 226 |
+
if (not($foundGuessed{$i})) { $precision = 0.0; }
|
| 227 |
+
else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
|
| 228 |
+
if (not($foundCorrect{$i})) { $recall = 0.0; }
|
| 229 |
+
else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
|
| 230 |
+
if ($precision+$recall == 0.0) { $FB1 = 0.0; }
|
| 231 |
+
else { $FB1 = 2*$precision*$recall/($precision+$recall); }
|
| 232 |
+
printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\",
|
| 233 |
+
$i,$precision,$recall,$FB1;
|
| 234 |
+
}
|
| 235 |
+
print "\\hline\n";
|
| 236 |
+
$precision = 0.0;
|
| 237 |
+
$recall = 0;
|
| 238 |
+
$FB1 = 0.0;
|
| 239 |
+
$precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
|
| 240 |
+
$recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
|
| 241 |
+
$FB1 = 2*$precision*$recall/($precision+$recall)
|
| 242 |
+
if ($precision+$recall > 0);
|
| 243 |
+
printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n",
|
| 244 |
+
$precision,$recall,$FB1;
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
exit 0;
|
| 248 |
+
|
| 249 |
+
# endOfChunk: checks if a chunk ended between the previous and current word
|
| 250 |
+
# arguments: previous and current chunk tags, previous and current types
|
| 251 |
+
# note: this code is capable of handling other chunk representations
|
| 252 |
+
# than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
|
| 253 |
+
# Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
|
| 254 |
+
|
| 255 |
+
sub endOfChunk {
|
| 256 |
+
my $prevTag = shift(@_);
|
| 257 |
+
my $tag = shift(@_);
|
| 258 |
+
my $prevType = shift(@_);
|
| 259 |
+
my $type = shift(@_);
|
| 260 |
+
my $chunkEnd = $false;
|
| 261 |
+
|
| 262 |
+
if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; }
|
| 263 |
+
if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; }
|
| 264 |
+
if ( $prevTag eq "B" and $tag eq "S" ) { $chunkEnd = $true; }
|
| 265 |
+
|
| 266 |
+
if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; }
|
| 267 |
+
if ( $prevTag eq "I" and $tag eq "S" ) { $chunkEnd = $true; }
|
| 268 |
+
if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; }
|
| 269 |
+
|
| 270 |
+
if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; }
|
| 271 |
+
if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; }
|
| 272 |
+
if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; }
|
| 273 |
+
if ( $prevTag eq "E" and $tag eq "S" ) { $chunkEnd = $true; }
|
| 274 |
+
if ( $prevTag eq "E" and $tag eq "B" ) { $chunkEnd = $true; }
|
| 275 |
+
|
| 276 |
+
if ( $prevTag eq "S" and $tag eq "E" ) { $chunkEnd = $true; }
|
| 277 |
+
if ( $prevTag eq "S" and $tag eq "I" ) { $chunkEnd = $true; }
|
| 278 |
+
if ( $prevTag eq "S" and $tag eq "O" ) { $chunkEnd = $true; }
|
| 279 |
+
if ( $prevTag eq "S" and $tag eq "S" ) { $chunkEnd = $true; }
|
| 280 |
+
if ( $prevTag eq "S" and $tag eq "B" ) { $chunkEnd = $true; }
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
if ($prevTag ne "O" and $prevTag ne "." and $prevType ne $type) {
|
| 284 |
+
$chunkEnd = $true;
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
# corrected 1998-12-22: these chunks are assumed to have length 1
|
| 288 |
+
if ( $prevTag eq "]" ) { $chunkEnd = $true; }
|
| 289 |
+
if ( $prevTag eq "[" ) { $chunkEnd = $true; }
|
| 290 |
+
|
| 291 |
+
return($chunkEnd);
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
# startOfChunk: checks if a chunk started between the previous and current word
|
| 295 |
+
# arguments: previous and current chunk tags, previous and current types
|
| 296 |
+
# note: this code is capable of handling other chunk representations
|
| 297 |
+
# than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
|
| 298 |
+
# Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
|
| 299 |
+
|
| 300 |
+
sub startOfChunk {
|
| 301 |
+
my $prevTag = shift(@_);
|
| 302 |
+
my $tag = shift(@_);
|
| 303 |
+
my $prevType = shift(@_);
|
| 304 |
+
my $type = shift(@_);
|
| 305 |
+
my $chunkStart = $false;
|
| 306 |
+
|
| 307 |
+
if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; }
|
| 308 |
+
if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; }
|
| 309 |
+
if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; }
|
| 310 |
+
if ( $prevTag eq "S" and $tag eq "B" ) { $chunkStart = $true; }
|
| 311 |
+
if ( $prevTag eq "E" and $tag eq "B" ) { $chunkStart = $true; }
|
| 312 |
+
|
| 313 |
+
if ( $prevTag eq "B" and $tag eq "S" ) { $chunkStart = $true; }
|
| 314 |
+
if ( $prevTag eq "I" and $tag eq "S" ) { $chunkStart = $true; }
|
| 315 |
+
if ( $prevTag eq "O" and $tag eq "S" ) { $chunkStart = $true; }
|
| 316 |
+
if ( $prevTag eq "S" and $tag eq "S" ) { $chunkStart = $true; }
|
| 317 |
+
if ( $prevTag eq "E" and $tag eq "S" ) { $chunkStart = $true; }
|
| 318 |
+
|
| 319 |
+
if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; }
|
| 320 |
+
if ( $prevTag eq "S" and $tag eq "I" ) { $chunkStart = $true; }
|
| 321 |
+
if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; }
|
| 322 |
+
|
| 323 |
+
if ( $prevTag eq "S" and $tag eq "E" ) { $chunkStart = $true; }
|
| 324 |
+
if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; }
|
| 325 |
+
if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; }
|
| 326 |
+
|
| 327 |
+
if ($tag ne "O" and $tag ne "." and $prevType ne $type) {
|
| 328 |
+
$chunkStart = $true;
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
# corrected 1998-12-22: these chunks are assumed to have length 1
|
| 332 |
+
if ( $tag eq "[" ) { $chunkStart = $true; }
|
| 333 |
+
if ( $tag eq "]" ) { $chunkStart = $true; }
|
| 334 |
+
|
| 335 |
+
return($chunkStart);
|
| 336 |
+
}
|
examples/eval/conll06eval.pl
ADDED
|
@@ -0,0 +1,1826 @@
| 1 |
+
#!/usr/bin/env perl
|
| 2 |
+
|
| 3 |
+
# Author: Yuval Krymolowski
|
| 4 |
+
# Addition of precision and recall
|
| 5 |
+
# and of frame confusion list: Sabine Buchholz
|
| 6 |
+
# Addition of DEPREL + ATTACHMENT:
|
| 7 |
+
# Prokopis Prokopidis (prokopis at ilsp dot gr)
|
| 8 |
+
# Acknowledgements:
|
| 9 |
+
# to Markus Kuhn for suggesting the use of
|
| 10 |
+
# the Unicode category property
|
| 11 |
+
|
| 12 |
+
if ($] < 5.008001)
|
| 13 |
+
{
|
| 14 |
+
printf STDERR <<EOM
|
| 15 |
+
|
| 16 |
+
This script requires PERL 5.8.1 for running.
|
| 17 |
+
The new version is needed for proper handling
|
| 18 |
+
of Unicode characters.
|
| 19 |
+
|
| 20 |
+
Please obtain a new version or contact the shared task team
|
| 21 |
+
if you are unable to upgrade PERL.
|
| 22 |
+
|
| 23 |
+
EOM
|
| 24 |
+
;
|
| 25 |
+
exit(1) ;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
require Encode;
|
| 29 |
+
|
| 30 |
+
use strict ;
|
| 31 |
+
use warnings;
|
| 32 |
+
use Getopt::Std ;
|
| 33 |
+
|
| 34 |
+
my ($usage) = <<EOT
|
| 35 |
+
|
| 36 |
+
CoNLL-X evaluation script:
|
| 37 |
+
|
| 38 |
+
[perl] eval.pl [OPTIONS] -g <gold standard> -s <system output>
|
| 39 |
+
|
| 40 |
+
This script evaluates a system output with respect to a gold standard.
|
| 41 |
+
Both files should be in UTF-8 encoded CoNLL-X tabular format.
|
| 42 |
+
|
| 43 |
+
Punctuation tokens (those where all characters have the Unicode
|
| 44 |
+
category property "Punctuation") are ignored for scoring (unless the
|
| 45 |
+
-p flag is used).
|
| 46 |
+
|
| 47 |
+
The output breaks down the errors according to their type and context.
|
| 48 |
+
|
| 49 |
+
Optional parameters:
|
| 50 |
+
-o FILE : output: print output to FILE (default is standard output)
|
| 51 |
+
-q : quiet: only print overall performance, without the details
|
| 52 |
+
-b : evalb: produce output in a format similar to evalb
|
| 53 |
+
(http://nlp.cs.nyu.edu/evalb/); use together with -q
|
| 54 |
+
-p : punctuation: also score on punctuation (default is not to score on it)
|
| 55 |
+
-v : version: show the version number
|
| 56 |
+
-h : help: print this help text and exit
|
| 57 |
+
|
| 58 |
+
EOT
|
| 59 |
+
;
|
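# Illustrative invocation (added note; the CoNLL file names are placeholders,
# not files shipped with this repository):
#
#   perl conll06eval.pl -g gold.conll -s system.conll -o eval_report.txt
#   perl conll06eval.pl -q -g gold.conll -s system.conll    # overall scores only
#
# Both inputs must be UTF-8, whitespace-separated CoNLL-X files with one token
# per line and a blank line between sentences, as described in the usage text.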
| 60 |
+
|
| 61 |
+
my ($line_num) ;
|
| 62 |
+
my ($sep) = '0x01' ;
|
| 63 |
+
|
| 64 |
+
my ($START) = '.S' ;
|
| 65 |
+
my ($END) = '.E' ;
|
| 66 |
+
|
| 67 |
+
my ($con_err_num) = 3 ;
|
| 68 |
+
my ($freq_err_num) = 10 ;
|
| 69 |
+
my ($spec_err_loc_con) = 8 ;
|
| 70 |
+
|
| 71 |
+
################################################################################
|
| 72 |
+
### subfunctions ###
|
| 73 |
+
################################################################################
|
| 74 |
+
|
| 75 |
+
# Whether a string consists entirely of characters with the Unicode
|
| 76 |
+
# category property "Punctuation" (see "man perlunicode")
|
| 77 |
+
sub is_uni_punct
|
| 78 |
+
{
|
| 79 |
+
my ($word) = @_ ;
|
| 80 |
+
|
| 81 |
+
return scalar(Encode::decode_utf8($word)=~ /^\p{Punctuation}+$/) ;
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
# The length of a unicode string, excluding non-spacing marks
|
| 85 |
+
# (for example vowel marks in Arabic)
|
| 86 |
+
|
| 87 |
+
sub uni_len
|
| 88 |
+
{
|
| 89 |
+
my ($word) = @_ ;
|
| 90 |
+
my ($ch, $l) ;
|
| 91 |
+
|
| 92 |
+
$l = 0 ;
|
| 93 |
+
foreach $ch (split(//, Encode::decode_utf8($word)))
|
| 94 |
+
{
|
| 95 |
+
if ($ch !~ /^\p{NonspacingMark}/)
|
| 96 |
+
{
|
| 97 |
+
$l++ ;
|
| 98 |
+
}
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
return $l ;
|
| 102 |
+
}
|
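# Illustrative note (not from the original script): uni_len() counts only base
# characters, so any character matching \p{NonspacingMark} (e.g. a combining
# vowel sign) adds nothing to the length; uni_len("abc") is simply 3. The
# result is used further below to keep the report's column widths aligned.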
| 103 |
+
|
| 104 |
+
sub filter_context_counts
|
| 105 |
+
{ # filter_context_counts
|
| 106 |
+
|
| 107 |
+
my ($vec, $num, $max_len) = @_ ;
|
| 108 |
+
my ($con, $l, $thresh) ;
|
| 109 |
+
|
| 110 |
+
$thresh = (sort {$b <=> $a} values %{$vec})[$num-1] ;
|
| 111 |
+
|
| 112 |
+
foreach $con (keys %{$vec})
|
| 113 |
+
{
|
| 114 |
+
if (${$vec}{$con} < $thresh)
|
| 115 |
+
{
|
| 116 |
+
delete ${$vec}{$con} ;
|
| 117 |
+
next ;
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
$l = uni_len($con) ;
|
| 121 |
+
|
| 122 |
+
if ($l > ${$max_len})
|
| 123 |
+
{
|
| 124 |
+
${$max_len} = $l ;
|
| 125 |
+
}
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
} # filter_context_counts
|
| 129 |
+
|
| 130 |
+
sub print_context
|
| 131 |
+
{ # print_context
|
| 132 |
+
|
| 133 |
+
my ($counts, $counts_pos, $max_con_len, $max_con_pos_len) = @_ ;
|
| 134 |
+
my (@v_con, @v_con_pos, $con, $con_pos, $i, $n) ;
|
| 135 |
+
|
| 136 |
+
printf OUT " %-*s | %-4s | %-4s | %-4s | %-4s", $max_con_pos_len, 'CPOS', 'any', 'head', 'dep', 'both' ;
|
| 137 |
+
printf OUT " ||" ;
|
| 138 |
+
printf OUT " %-*s | %-4s | %-4s | %-4s | %-4s", $max_con_len, 'word', 'any', 'head', 'dep', 'both' ;
|
| 139 |
+
printf OUT "\n" ;
|
| 140 |
+
printf OUT " %s-+------+------+------+-----", '-' x $max_con_pos_len;
|
| 141 |
+
printf OUT "--++" ;
|
| 142 |
+
printf OUT "--%s-+------+------+------+-----", '-' x $max_con_len;
|
| 143 |
+
printf OUT "\n" ;
|
| 144 |
+
|
| 145 |
+
@v_con = sort {${$counts}{tot}{$b} <=> ${$counts}{tot}{$a}} keys %{${$counts}{tot}} ;
|
| 146 |
+
@v_con_pos = sort {${$counts_pos}{tot}{$b} <=> ${$counts_pos}{tot}{$a}} keys %{${$counts_pos}{tot}} ;
|
| 147 |
+
|
| 148 |
+
$n = scalar @v_con ;
|
| 149 |
+
if (scalar @v_con_pos > $n)
|
| 150 |
+
{
|
| 151 |
+
$n = scalar @v_con_pos ;
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
foreach $i (0 .. $n-1)
|
| 155 |
+
{
|
| 156 |
+
if (defined $v_con_pos[$i])
|
| 157 |
+
{
|
| 158 |
+
$con_pos = $v_con_pos[$i] ;
|
| 159 |
+
printf OUT " %-*s | %4d | %4d | %4d | %4d",
|
| 160 |
+
$max_con_pos_len, $con_pos, ${$counts_pos}{tot}{$con_pos},
|
| 161 |
+
${$counts_pos}{err_head}{$con_pos}, ${$counts_pos}{err_dep}{$con_pos},
|
| 162 |
+
${$counts_pos}{err_dep}{$con_pos}+${$counts_pos}{err_head}{$con_pos}-${$counts_pos}{tot}{$con_pos} ;
|
| 163 |
+
}
|
| 164 |
+
else
|
| 165 |
+
{
|
| 166 |
+
printf OUT " %-*s | %4s | %4s | %4s | %4s",
|
| 167 |
+
$max_con_pos_len, ' ', ' ', ' ', ' ', ' ' ;
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
printf OUT " ||" ;
|
| 171 |
+
|
| 172 |
+
if (defined $v_con[$i])
|
| 173 |
+
{
|
| 174 |
+
$con = $v_con[$i] ;
|
| 175 |
+
printf OUT " %-*s | %4d | %4d | %4d | %4d",
|
| 176 |
+
$max_con_len+length($con)-uni_len($con), $con, ${$counts}{tot}{$con},
|
| 177 |
+
${$counts}{err_head}{$con}, ${$counts}{err_dep}{$con},
|
| 178 |
+
${$counts}{err_dep}{$con}+${$counts}{err_head}{$con}-${$counts}{tot}{$con} ;
|
| 179 |
+
}
|
| 180 |
+
else
|
| 181 |
+
{
|
| 182 |
+
printf OUT " %-*s | %4s | %4s | %4s | %4s",
|
| 183 |
+
$max_con_len, ' ', ' ', ' ', ' ', ' ' ;
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
printf OUT "\n" ;
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
printf OUT " %s-+------+------+------+-----", '-' x $max_con_pos_len;
|
| 190 |
+
printf OUT "--++" ;
|
| 191 |
+
printf OUT "--%s-+------+------+------+-----", '-' x $max_con_len;
|
| 192 |
+
printf OUT "\n" ;
|
| 193 |
+
|
| 194 |
+
printf OUT "\n\n" ;
|
| 195 |
+
|
| 196 |
+
} # print_context
|
| 197 |
+
|
| 198 |
+
sub num_as_word
|
| 199 |
+
{
|
| 200 |
+
my ($num) = @_ ;
|
| 201 |
+
|
| 202 |
+
$num = abs($num) ;
|
| 203 |
+
|
| 204 |
+
if ($num == 1)
|
| 205 |
+
{
|
| 206 |
+
return ('one word') ;
|
| 207 |
+
}
|
| 208 |
+
elsif ($num == 2)
|
| 209 |
+
{
|
| 210 |
+
return ('two words') ;
|
| 211 |
+
}
|
| 212 |
+
elsif ($num == 3)
|
| 213 |
+
{
|
| 214 |
+
return ('three words') ;
|
| 215 |
+
}
|
| 216 |
+
elsif ($num == 4)
|
| 217 |
+
{
|
| 218 |
+
return ('four words') ;
|
| 219 |
+
}
|
| 220 |
+
else
|
| 221 |
+
{
|
| 222 |
+
return ($num.' words') ;
|
| 223 |
+
}
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
sub describe_err
|
| 227 |
+
{ # describe_err
|
| 228 |
+
|
| 229 |
+
my ($head_err, $head_aft_bef, $dep_err) = @_ ;
|
| 230 |
+
my ($dep_g, $dep_s, $desc) ;
|
| 231 |
+
my ($head_aft_bef_g, $head_aft_bef_s) = split(//, $head_aft_bef) ;
|
| 232 |
+
|
| 233 |
+
if ($head_err eq '-')
|
| 234 |
+
{
|
| 235 |
+
$desc = 'correct head' ;
|
| 236 |
+
|
| 237 |
+
if ($head_aft_bef_s eq '0')
|
| 238 |
+
{
|
| 239 |
+
$desc .= ' (0)' ;
|
| 240 |
+
}
|
| 241 |
+
elsif ($head_aft_bef_s eq 'e')
|
| 242 |
+
{
|
| 243 |
+
$desc .= ' (the focus word)' ;
|
| 244 |
+
}
|
| 245 |
+
elsif ($head_aft_bef_s eq 'a')
|
| 246 |
+
{
|
| 247 |
+
$desc .= ' (after the focus word)' ;
|
| 248 |
+
}
|
| 249 |
+
elsif ($head_aft_bef_s eq 'b')
|
| 250 |
+
{
|
| 251 |
+
$desc .= ' (before the focus word)' ;
|
| 252 |
+
}
|
| 253 |
+
}
|
| 254 |
+
elsif ($head_aft_bef_s eq '0')
|
| 255 |
+
{
|
| 256 |
+
$desc = 'head = 0 instead of ' ;
|
| 257 |
+
if ($head_aft_bef_g eq 'a')
|
| 258 |
+
{
|
| 259 |
+
$desc.= 'after ' ;
|
| 260 |
+
}
|
| 261 |
+
if ($head_aft_bef_g eq 'b')
|
| 262 |
+
{
|
| 263 |
+
$desc.= 'before ' ;
|
| 264 |
+
}
|
| 265 |
+
$desc .= 'the focus word' ;
|
| 266 |
+
}
|
| 267 |
+
elsif ($head_aft_bef_g eq '0')
|
| 268 |
+
{
|
| 269 |
+
$desc = 'head is ' ;
|
| 270 |
+
if ($head_aft_bef_s eq 'a')
|
| 271 |
+
{
|
| 272 |
+
$desc.= 'after ' ;
|
| 273 |
+
}
|
| 274 |
+
if ($head_aft_bef_s eq 'b')
|
| 275 |
+
{
|
| 276 |
+
$desc.= 'before ' ;
|
| 277 |
+
}
|
| 278 |
+
$desc .= 'the focus word instead of 0' ;
|
| 279 |
+
}
|
| 280 |
+
else
|
| 281 |
+
{
|
| 282 |
+
$desc = num_as_word($head_err) ;
|
| 283 |
+
if ($head_err < 0)
|
| 284 |
+
{
|
| 285 |
+
$desc .= ' before' ;
|
| 286 |
+
}
|
| 287 |
+
else
|
| 288 |
+
{
|
| 289 |
+
$desc .= ' after' ;
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
$desc = 'head '.$desc.' the correct head ' ;
|
| 293 |
+
|
| 294 |
+
if ($head_aft_bef_s eq '0')
|
| 295 |
+
{
|
| 296 |
+
$desc .= '(0' ;
|
| 297 |
+
}
|
| 298 |
+
elsif ($head_aft_bef_s eq 'e')
|
| 299 |
+
{
|
| 300 |
+
$desc .= '(the focus word' ;
|
| 301 |
+
}
|
| 302 |
+
elsif ($head_aft_bef_s eq 'a')
|
| 303 |
+
{
|
| 304 |
+
$desc .= '(after the focus word' ;
|
| 305 |
+
}
|
| 306 |
+
elsif ($head_aft_bef_s eq 'b')
|
| 307 |
+
{
|
| 308 |
+
$desc .= '(before the focus word' ;
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
if ($head_aft_bef_g ne $head_aft_bef_s)
|
| 312 |
+
{
|
| 313 |
+
$desc .= ' instead of' ;
|
| 314 |
+
if ($head_aft_bef_s eq '0')
|
| 315 |
+
{
|
| 316 |
+
$desc .= '0' ;
|
| 317 |
+
}
|
| 318 |
+
elsif ($head_aft_bef_s eq 'e')
|
| 319 |
+
{
|
| 320 |
+
$desc .= 'the focus word' ;
|
| 321 |
+
}
|
| 322 |
+
elsif ($head_aft_bef_s eq 'a')
|
| 323 |
+
{
|
| 324 |
+
$desc .= 'after the focus word' ;
|
| 325 |
+
}
|
| 326 |
+
elsif ($head_aft_bef_s eq 'b')
|
| 327 |
+
{
|
| 328 |
+
$desc .= 'before the focus word' ;
|
| 329 |
+
}
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
$desc .= ')' ;
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
$desc .= ', ' ;
|
| 336 |
+
|
| 337 |
+
if ($dep_err eq '-')
|
| 338 |
+
{
|
| 339 |
+
$desc .= 'correct dependency' ;
|
| 340 |
+
}
|
| 341 |
+
else
|
| 342 |
+
{
|
| 343 |
+
($dep_g, $dep_s) = ($dep_err =~ /^(.*)->(.*)$/) ;
|
| 344 |
+
$desc .= sprintf('dependency "%s" instead of "%s"', $dep_s, $dep_g) ;
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
return($desc) ;
|
| 348 |
+
|
| 349 |
+
} # describe_err
|
| 350 |
+
|
| 351 |
+
sub get_context
|
| 352 |
+
{ # get_context
|
| 353 |
+
|
| 354 |
+
my ($sent, $i_w) = @_ ;
|
| 355 |
+
my ($w_2, $w_1, $w1, $w2) ;
|
| 356 |
+
my ($p_2, $p_1, $p1, $p2) ;
|
| 357 |
+
|
| 358 |
+
if ($i_w >= 2)
|
| 359 |
+
{
|
| 360 |
+
$w_2 = ${${$sent}[$i_w-2]}{word} ;
|
| 361 |
+
$p_2 = ${${$sent}[$i_w-2]}{pos} ;
|
| 362 |
+
}
|
| 363 |
+
else
|
| 364 |
+
{
|
| 365 |
+
$w_2 = $START ;
|
| 366 |
+
$p_2 = $START ;
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
if ($i_w >= 1)
|
| 370 |
+
{
|
| 371 |
+
$w_1 = ${${$sent}[$i_w-1]}{word} ;
|
| 372 |
+
$p_1 = ${${$sent}[$i_w-1]}{pos} ;
|
| 373 |
+
}
|
| 374 |
+
else
|
| 375 |
+
{
|
| 376 |
+
$w_1 = $START ;
|
| 377 |
+
$p_1 = $START ;
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
if ($i_w <= scalar @{$sent}-2)
|
| 381 |
+
{
|
| 382 |
+
$w1 = ${${$sent}[$i_w+1]}{word} ;
|
| 383 |
+
$p1 = ${${$sent}[$i_w+1]}{pos} ;
|
| 384 |
+
}
|
| 385 |
+
else
|
| 386 |
+
{
|
| 387 |
+
$w1 = $END ;
|
| 388 |
+
$p1 = $END ;
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
if ($i_w <= scalar @{$sent}-3)
|
| 392 |
+
{
|
| 393 |
+
$w2 = ${${$sent}[$i_w+2]}{word} ;
|
| 394 |
+
$p2 = ${${$sent}[$i_w+2]}{pos} ;
|
| 395 |
+
}
|
| 396 |
+
else
|
| 397 |
+
{
|
| 398 |
+
$w2 = $END ;
|
| 399 |
+
$p2 = $END ;
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
return ($w_2, $w_1, $w1, $w2, $p_2, $p_1, $p1, $p2) ;
|
| 403 |
+
|
| 404 |
+
} # get_context
|
| 405 |
+
|
| 406 |
+
sub read_sent
|
| 407 |
+
{ # read_sent
|
| 408 |
+
|
| 409 |
+
my ($sent_gold, $sent_sys) = @_ ;
|
| 410 |
+
my ($line_g, $line_s, $new_sent) ;
|
| 411 |
+
my (%fields_g, %fields_s) ;
|
| 412 |
+
|
| 413 |
+
$new_sent = 1 ;
|
| 414 |
+
|
| 415 |
+
@{$sent_gold} = () ;
|
| 416 |
+
@{$sent_sys} = () ;
|
| 417 |
+
|
| 418 |
+
while (1)
|
| 419 |
+
{ # main reading loop
|
| 420 |
+
|
| 421 |
+
$line_g = <GOLD> ;
|
| 422 |
+
$line_s = <SYS> ;
|
| 423 |
+
|
| 424 |
+
$line_num++ ;
|
| 425 |
+
|
| 426 |
+
# system output has fewer lines than gold standard
|
| 427 |
+
if ((defined $line_g) && (! defined $line_s))
|
| 428 |
+
{
|
| 429 |
+
printf STDERR "line mismatch, line %d:\n", $line_num ;
|
| 430 |
+
printf STDERR " gold: %s", $line_g ;
|
| 431 |
+
printf STDERR " sys : past end of file\n" ;
|
| 432 |
+
exit(1) ;
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
# system output has more lines than gold standard
|
| 436 |
+
if ((! defined $line_g) && (defined $line_s))
|
| 437 |
+
{
|
| 438 |
+
printf STDERR "line mismatch, line %d:\n", $line_num ;
|
| 439 |
+
printf STDERR " gold: past end of file\n" ;
|
| 440 |
+
printf STDERR " sys : %s", $line_s ;
|
| 441 |
+
exit(1) ;
|
| 442 |
+
}
|
| 443 |
+
|
| 444 |
+
# end of file reached for both
|
| 445 |
+
if ((! defined $line_g) && (! defined $line_s))
|
| 446 |
+
{
|
| 447 |
+
return (1) ;
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
# one contains end of sentence but other one does not
|
| 451 |
+
if (($line_g =~ /^\s+$/) != ($line_s =~ /^\s+$/))
|
| 452 |
+
{
|
| 453 |
+
printf STDERR "line mismatch, line %d:\n", $line_num ;
|
| 454 |
+
printf STDERR " gold: %s", $line_g ;
|
| 455 |
+
printf STDERR " sys : %s", $line_s ;
|
| 456 |
+
exit(1) ;
|
| 457 |
+
}
|
| 458 |
+
|
| 459 |
+
# end of sentence reached
|
| 460 |
+
if ($line_g =~ /^\s+$/)
|
| 461 |
+
{
|
| 462 |
+
return(0) ;
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
# now both lines contain information
|
| 466 |
+
|
| 467 |
+
if ($new_sent)
|
| 468 |
+
{
|
| 469 |
+
$new_sent = 0 ;
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
# 'official' column names
|
| 473 |
+
# options.output = ['id','form','lemma','cpostag','postag',
|
| 474 |
+
# 'feats','head','deprel','phead','pdeprel']
|
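# Illustrative note (not from the original script): a CoNLL-X token line has
# ten whitespace-separated columns in the order listed above. A made-up row
# (word and labels are hypothetical):
#
#   1   ramah   rama   N   NN   _   2   karta   _   _
#
# The split below keeps indices 1, 3, 6 and 7, i.e. FORM, CPOSTAG, HEAD and
# DEPREL, which is all the evaluation needs.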
| 475 |
+
|
| 476 |
+
@fields_g{'word', 'pos', 'head', 'dep'} = (split (/\s+/, $line_g))[1, 3, 6, 7] ;
|
| 477 |
+
|
| 478 |
+
push @{$sent_gold}, { %fields_g } ;
|
| 479 |
+
|
| 480 |
+
@fields_s{'word', 'pos', 'head', 'dep'} = (split (/\s+/, $line_s))[1, 3, 6, 7] ;
|
| 481 |
+
|
| 482 |
+
if (($fields_g{word} ne $fields_s{word})
|
| 483 |
+
||
|
| 484 |
+
($fields_g{pos} ne $fields_s{pos}))
|
| 485 |
+
{
|
| 486 |
+
printf STDERR "Word/pos mismatch, line %d:\n", $line_num ;
|
| 487 |
+
printf STDERR " gold: %s", $line_g ;
|
| 488 |
+
printf STDERR " sys : %s", $line_s ;
|
| 489 |
+
exit(1) ;
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
push @{$sent_sys}, { %fields_s } ;
|
| 493 |
+
|
| 494 |
+
} # main reading loop
|
| 495 |
+
|
| 496 |
+
} # read_sent
|
| 497 |
+
|
| 498 |
+
################################################################################
|
| 499 |
+
### main ###
|
| 500 |
+
################################################################################
|
| 501 |
+
|
| 502 |
+
our ($opt_g, $opt_s, $opt_o, $opt_h, $opt_v, $opt_q, $opt_p, $opt_b) ;
|
| 503 |
+
|
| 504 |
+
my ($sent_num, $eof, $word_num, @err_sent) ;
|
| 505 |
+
my (@sent_gold, @sent_sys, @starts) ;
|
| 506 |
+
my ($word, $pos, $wp, $head_g, $dep_g, $head_s, $dep_s) ;
|
| 507 |
+
my (%counts, $err_head, $err_dep, $con, $con1, $con_pos, $con_pos1, $thresh) ;
|
| 508 |
+
my ($head_err, $dep_err, @cur_err, %err_counts, $err_counter, $err_desc) ;
|
| 509 |
+
my ($loc_con, %loc_con_err_counts, %err_desc) ;
|
| 510 |
+
my ($head_aft_bef_g, $head_aft_bef_s, $head_aft_bef) ;
|
| 511 |
+
my ($con_bef, $con_aft, $con_bef_2, $con_aft_2, @bits, @e_bits, @v_con, @v_con_pos) ;
|
| 512 |
+
my ($con_pos_bef, $con_pos_aft, $con_pos_bef_2, $con_pos_aft_2) ;
|
| 513 |
+
my ($max_word_len, $max_pos_len, $max_con_len, $max_con_pos_len) ;
|
| 514 |
+
my ($max_word_spec_len, $max_con_bef_len, $max_con_aft_len) ;
|
| 515 |
+
my (%freq_err, $err) ;
|
| 516 |
+
|
| 517 |
+
my ($i, $j, $i_w, $l, $n_args) ;
|
| 518 |
+
my ($w_2, $w_1, $w1, $w2) ;
|
| 519 |
+
my ($wp_2, $wp_1, $wp1, $wp2) ;
|
| 520 |
+
my ($p_2, $p_1, $p1, $p2) ;
|
| 521 |
+
|
| 522 |
+
my ($short_output) ;
|
| 523 |
+
my ($score_on_punct) ;
|
| 524 |
+
$counts{punct} = 0; # initialize
|
| 525 |
+
|
| 526 |
+
getopts("g:o:s:qvhpb") ;
|
| 527 |
+
|
| 528 |
+
if (defined $opt_v)
|
| 529 |
+
{
|
| 530 |
+
my $id = '$Id: eval.pl,v 1.9 2006/05/09 20:30:01 yuval Exp $';
|
| 531 |
+
my @parts = split ' ',$id;
|
| 532 |
+
print "Version $parts[2]\n";
|
| 533 |
+
exit(0);
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
if ((defined $opt_h) || ((! defined $opt_g) && (! defined $opt_s)))
|
| 537 |
+
{
|
| 538 |
+
die $usage ;
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
+
if (! defined $opt_g)
|
| 542 |
+
{
|
| 543 |
+
die "Gold standard file (-g) missing\n" ;
|
| 544 |
+
}
|
| 545 |
+
|
| 546 |
+
if (! defined $opt_s)
|
| 547 |
+
{
|
| 548 |
+
die "System output file (-s) missing\n" ;
|
| 549 |
+
}
|
| 550 |
+
|
| 551 |
+
if (! defined $opt_o)
|
| 552 |
+
{
|
| 553 |
+
$opt_o = '-' ;
|
| 554 |
+
}
|
| 555 |
+
|
| 556 |
+
if (defined $opt_q)
|
| 557 |
+
{
|
| 558 |
+
$short_output = 1 ;
|
| 559 |
+
} else {
|
| 560 |
+
$short_output = 0 ;
|
| 561 |
+
}
|
| 562 |
+
|
| 563 |
+
if (defined $opt_p)
|
| 564 |
+
{
|
| 565 |
+
$score_on_punct = 1 ;
|
| 566 |
+
} else {
|
| 567 |
+
$score_on_punct = 0 ;
|
| 568 |
+
}
|
| 569 |
+
|
| 570 |
+
$line_num = 0 ;
|
| 571 |
+
$sent_num = 0 ;
|
| 572 |
+
$eof = 0 ;
|
| 573 |
+
|
| 574 |
+
@err_sent = () ;
|
| 575 |
+
@starts = () ;
|
| 576 |
+
|
| 577 |
+
%{$err_sent[0]} = () ;
|
| 578 |
+
|
| 579 |
+
$max_pos_len = length('CPOS') ;
|
| 580 |
+
|
| 581 |
+
################################################################################
|
| 582 |
+
### reading input ###
|
| 583 |
+
################################################################################
|
| 584 |
+
|
| 585 |
+
open (GOLD, "<$opt_g") || die "Could not open gold standard file $opt_g\n" ;
|
| 586 |
+
open (SYS, "<$opt_s") || die "Could not open system output file $opt_s\n" ;
|
| 587 |
+
open (OUT, ">$opt_o") || die "Could not open output file $opt_o\n" ;
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
if (defined $opt_b) { # produce output similar to evalb
|
| 591 |
+
print OUT " Sent. Attachment Correct Scoring \n";
|
| 592 |
+
print OUT " ID Tokens - Unlab. Lab. HEAD HEAD+DEPREL tokens - - - -\n";
|
| 593 |
+
print OUT " ============================================================================\n";
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
while (! $eof)
|
| 598 |
+
{ # main reading loop
|
| 599 |
+
|
| 600 |
+
$starts[$sent_num] = $line_num+1 ;
|
| 601 |
+
$eof = read_sent(\@sent_gold, \@sent_sys) ;
|
| 602 |
+
|
| 603 |
+
$sent_num++ ;
|
| 604 |
+
|
| 605 |
+
%{$err_sent[$sent_num]} = () ;
|
| 606 |
+
$word_num = scalar @sent_gold ;
|
| 607 |
+
|
| 608 |
+
# for accuracy per sentence
|
| 609 |
+
my %sent_counts = ( tot => 0,
|
| 610 |
+
err_any => 0,
|
| 611 |
+
err_head => 0
|
| 612 |
+
);
|
| 613 |
+
|
| 614 |
+
# printf "$sent_num $word_num\n" ;
|
| 615 |
+
|
| 616 |
+
my @frames_g = ('** '); # the initial frame for the virtual root
|
| 617 |
+
my @frames_s = ('** '); # the initial frame for the virtual root
|
| 618 |
+
foreach $i_w (0 .. $word_num-1)
|
| 619 |
+
{ # loop on words
|
| 620 |
+
push @frames_g, ''; # initialize
|
| 621 |
+
push @frames_s, ''; # initialize
|
| 622 |
+
}
|
| 623 |
+
|
| 624 |
+
foreach $i_w (0 .. $word_num-1)
|
| 625 |
+
{ # loop on words
|
| 626 |
+
|
| 627 |
+
($word, $pos, $head_g, $dep_g)
|
| 628 |
+
= @{$sent_gold[$i_w]}{'word', 'pos', 'head', 'dep'} ;
|
| 629 |
+
$wp = $word.' / '.$pos ;
|
| 630 |
+
|
| 631 |
+
# printf "%d: %s %s %s %s\n", $i_w, $word, $pos, $head_g, $dep_g ;
|
| 632 |
+
|
| 633 |
+
if ((! $score_on_punct) && is_uni_punct($word))
|
| 634 |
+
{
|
| 635 |
+
$counts{punct}++ ;
|
| 636 |
+
# ignore punctuation
|
| 637 |
+
next ;
|
| 638 |
+
}
|
| 639 |
+
|
| 640 |
+
if (length($pos) > $max_pos_len)
|
| 641 |
+
{
|
| 642 |
+
$max_pos_len = length($pos) ;
|
| 643 |
+
}
|
| 644 |
+
|
| 645 |
+
($head_s, $dep_s) = @{$sent_sys[$i_w]}{'head', 'dep'} ;
|
| 646 |
+
|
| 647 |
+
$counts{tot}++ ;
|
| 648 |
+
$counts{word}{$wp}{tot}++ ;
|
| 649 |
+
$counts{pos}{$pos}{tot}++ ;
|
| 650 |
+
$counts{head}{$head_g-$i_w-1}{tot}++ ;
|
| 651 |
+
|
| 652 |
+
# for frame confusions
|
| 653 |
+
# add child to frame of parent
|
| 654 |
+
$frames_g[$head_g] .= "$dep_g ";
|
| 655 |
+
$frames_s[$head_s] .= "$dep_s ";
|
| 656 |
+
# add to frame of token itself
|
| 657 |
+
$frames_g[$i_w+1] .= "*$dep_g* "; # $i_w+1 because $i_w starts counting at zero
|
| 658 |
+
$frames_s[$i_w+1] .= "*$dep_g* ";
|
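# Illustrative note (not from the original script): a token's "frame" is the
# blank-separated list of deprels assigned to its dependents, plus the token's
# own gold deprel wrapped in *...*. The gold deprel is used in both the gold
# and the system frame, so only disagreements among the dependents show up
# later in the "Frame confusions" section of the report.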
| 659 |
+
|
| 660 |
+
# for precision and recall of DEPREL
|
| 661 |
+
$counts{dep}{$dep_g}{tot}++ ; # counts for gold standard deprels
|
| 662 |
+
$counts{dep2}{$dep_g}{$dep_s}++ ; # counts for confusions
|
| 663 |
+
$counts{dep_s}{$dep_s}{tot}++ ; # counts for system deprels
|
| 664 |
+
$counts{all_dep}{$dep_g} = 1 ; # list of all deprels that occur ...
|
| 665 |
+
$counts{all_dep}{$dep_s} = 1 ; # ... in either gold or system output
|
| 666 |
+
|
| 667 |
+
# for precision and recall of HEAD direction
|
| 668 |
+
my $dir_g;
|
| 669 |
+
if ($head_g == 0) {
|
| 670 |
+
$dir_g = 'to_root';
|
| 671 |
+
} elsif ($head_g < $i_w+1) { # $i_w+1 because $i_w starts counting at zero
|
| 672 |
+
# also below
|
| 673 |
+
$dir_g = 'left';
|
| 674 |
+
} elsif ($head_g > $i_w+1) {
|
| 675 |
+
$dir_g = 'right';
|
| 676 |
+
} else {
|
| 677 |
+
# token links to itself; should never happen in correct gold standard
|
| 678 |
+
$dir_g = 'self';
|
| 679 |
+
}
|
| 680 |
+
my $dir_s;
|
| 681 |
+
if ($head_s == 0) {
|
| 682 |
+
$dir_s = 'to_root';
|
| 683 |
+
} elsif ($head_s < $i_w+1) {
|
| 684 |
+
$dir_s = 'left';
|
| 685 |
+
} elsif ($head_s > $i_w+1) {
|
| 686 |
+
$dir_s = 'right';
|
| 687 |
+
} else {
|
| 688 |
+
# token links to itself; should not happen in good system
|
| 689 |
+
# (but not forbidden in shared task)
|
| 690 |
+
$dir_s = 'self';
|
| 691 |
+
}
|
| 692 |
+
$counts{dir_g}{$dir_g}{tot}++ ; # counts for gold standard head direction
|
| 693 |
+
$counts{dir2}{$dir_g}{$dir_s}++ ; # counts for confusions
|
| 694 |
+
$counts{dir_s}{$dir_s}{tot}++ ; # counts for system head direction
|
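# Illustrative example (not from the original script): with tokens numbered
# from 1, a token at position 5 whose head is token 2 is binned as 'left',
# a head of 0 as 'to_root', and a head equal to the token's own position as
# 'self' (which a well-formed tree should never produce).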
| 695 |
+
|
| 696 |
+
# for precision and recall of HEAD distance
|
| 697 |
+
my $dist_g;
|
| 698 |
+
if ($head_g == 0) {
|
| 699 |
+
$dist_g = 'to_root';
|
| 700 |
+
} elsif ( abs($head_g - ($i_w+1)) <= 1 ) {
|
| 701 |
+
$dist_g = '1'; # includes the 'self' cases
|
| 702 |
+
} elsif ( abs($head_g - ($i_w+1)) <= 2 ) {
|
| 703 |
+
$dist_g = '2';
|
| 704 |
+
} elsif ( abs($head_g - ($i_w+1)) <= 6 ) {
|
| 705 |
+
$dist_g = '3-6';
|
| 706 |
+
} else {
|
| 707 |
+
$dist_g = '7-...';
|
| 708 |
+
}
|
| 709 |
+
my $dist_s;
|
| 710 |
+
if ($head_s == 0) {
|
| 711 |
+
$dist_s = 'to_root';
|
| 712 |
+
} elsif ( abs($head_s - ($i_w+1)) <= 1 ) {
|
| 713 |
+
$dist_s = '1'; # includes the 'self' cases
|
| 714 |
+
} elsif ( abs($head_s - ($i_w+1)) <= 2 ) {
|
| 715 |
+
$dist_s = '2';
|
| 716 |
+
} elsif ( abs($head_s - ($i_w+1)) <= 6 ) {
|
| 717 |
+
$dist_s = '3-6';
|
| 718 |
+
} else {
|
| 719 |
+
$dist_s = '7-...';
|
| 720 |
+
}
|
| 721 |
+
$counts{dist_g}{$dist_g}{tot}++ ; # counts for gold standard head distance
|
| 722 |
+
$counts{dist2}{$dist_g}{$dist_s}++ ; # counts for confusions
|
| 723 |
+
$counts{dist_s}{$dist_s}{tot}++ ; # counts for system head distance
|
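# Illustrative example (not from the original script): a token at position 3
# whose head is token 7 has |7-3| = 4 and therefore falls into the '3-6'
# distance bin; attachment to the artificial root (head 0) is always binned
# as 'to_root'.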
| 724 |
+
|
| 725 |
+
|
| 726 |
+
$err_head = ($head_g ne $head_s) ; # error in head
|
| 727 |
+
$err_dep = ($dep_g ne $dep_s) ; # error in deprel
|
| 728 |
+
|
| 729 |
+
$head_err = '-' ;
|
| 730 |
+
$dep_err = '-' ;
|
| 731 |
+
|
| 732 |
+
# for accuracy per sentence
|
| 733 |
+
$sent_counts{tot}++ ;
|
| 734 |
+
if ($err_dep || $err_head) {
|
| 735 |
+
$sent_counts{err_any}++ ;
|
| 736 |
+
}
|
| 737 |
+
if ($err_head) {
|
| 738 |
+
$sent_counts{err_head}++ ;
|
| 739 |
+
}
|
| 740 |
+
|
| 741 |
+
# total counts and counts for CPOS involved in errors
|
| 742 |
+
|
| 743 |
+
if ($head_g eq '0')
|
| 744 |
+
{
|
| 745 |
+
$head_aft_bef_g = '0' ;
|
| 746 |
+
}
|
| 747 |
+
elsif ($head_g eq $i_w+1)
|
| 748 |
+
{
|
| 749 |
+
$head_aft_bef_g = 'e' ;
|
| 750 |
+
}
|
| 751 |
+
else
|
| 752 |
+
{
|
| 753 |
+
$head_aft_bef_g = ($head_g <= $i_w+1 ? 'b' : 'a') ;
|
| 754 |
+
}
|
| 755 |
+
|
| 756 |
+
if ($head_s eq '0')
|
| 757 |
+
{
|
| 758 |
+
$head_aft_bef_s = '0' ;
|
| 759 |
+
}
|
| 760 |
+
elsif ($head_s eq $i_w+1)
|
| 761 |
+
{
|
| 762 |
+
$head_aft_bef_s = 'e' ;
|
| 763 |
+
}
|
| 764 |
+
else
|
| 765 |
+
{
|
| 766 |
+
$head_aft_bef_s = ($head_s <= $i_w+1 ? 'b' : 'a') ;
|
| 767 |
+
}
|
| 768 |
+
|
| 769 |
+
$head_aft_bef = $head_aft_bef_g.$head_aft_bef_s ;
|
| 770 |
+
|
| 771 |
+
if ($err_head)
|
| 772 |
+
{
|
| 773 |
+
if ($head_aft_bef_s eq '0')
|
| 774 |
+
{
|
| 775 |
+
$head_err = 0 ;
|
| 776 |
+
}
|
| 777 |
+
else
|
| 778 |
+
{
|
| 779 |
+
$head_err = $head_s-$head_g ;
|
| 780 |
+
}
|
| 781 |
+
|
| 782 |
+
$err_sent[$sent_num]{head}++ ;
|
| 783 |
+
$counts{err_head}{tot}++ ;
|
| 784 |
+
$counts{err_head}{$head_err}++ ;
|
| 785 |
+
|
| 786 |
+
$counts{word}{err_head}{$wp}++ ;
|
| 787 |
+
$counts{pos}{$pos}{err_head}{tot}++ ;
|
| 788 |
+
$counts{pos}{$pos}{err_head}{$head_err}++ ;
|
| 789 |
+
}
|
| 790 |
+
|
| 791 |
+
if ($err_dep)
|
| 792 |
+
{
|
| 793 |
+
$dep_err = $dep_g.'->'.$dep_s ;
|
| 794 |
+
$err_sent[$sent_num]{dep}++ ;
|
| 795 |
+
$counts{err_dep}{tot}++ ;
|
| 796 |
+
$counts{err_dep}{$dep_err}++ ;
|
| 797 |
+
|
| 798 |
+
$counts{word}{err_dep}{$wp}++ ;
|
| 799 |
+
$counts{pos}{$pos}{err_dep}{tot}++ ;
|
| 800 |
+
$counts{pos}{$pos}{err_dep}{$dep_err}++ ;
|
| 801 |
+
|
| 802 |
+
if ($err_head)
|
| 803 |
+
{
|
| 804 |
+
$counts{err_both}++ ;
|
| 805 |
+
$counts{pos}{$pos}{err_both}++ ;
|
| 806 |
+
}
|
| 807 |
+
}
|
| 808 |
+
|
| 809 |
+
### DEPREL + ATTACHMENT
|
| 810 |
+
if ((!$err_dep) && ($err_head)) {
|
| 811 |
+
$counts{err_head_corr_dep}{tot}++ ;
|
| 812 |
+
$counts{err_head_corr_dep}{$dep_s}++ ;
|
| 813 |
+
}
|
| 814 |
+
### DEPREL + ATTACHMENT
|
| 815 |
+
|
| 816 |
+
# counts for words involved in errors
|
| 817 |
+
|
| 818 |
+
if (! ($err_head || $err_dep))
|
| 819 |
+
{
|
| 820 |
+
next ;
|
| 821 |
+
}
|
| 822 |
+
|
| 823 |
+
$err_sent[$sent_num]{word}++ ;
|
| 824 |
+
$counts{err_any}++ ;
|
| 825 |
+
$counts{word}{err_any}{$wp}++ ;
|
| 826 |
+
$counts{pos}{$pos}{err_any}++ ;
|
| 827 |
+
|
| 828 |
+
($w_2, $w_1, $w1, $w2, $p_2, $p_1, $p1, $p2) = get_context(\@sent_gold, $i_w) ;
|
| 829 |
+
|
| 830 |
+
if ($w_2 ne $START)
|
| 831 |
+
{
|
| 832 |
+
$wp_2 = $w_2.' / '.$p_2 ;
|
| 833 |
+
}
|
| 834 |
+
else
|
| 835 |
+
{
|
| 836 |
+
$wp_2 = $w_2 ;
|
| 837 |
+
}
|
| 838 |
+
|
| 839 |
+
if ($w_1 ne $START)
|
| 840 |
+
{
|
| 841 |
+
$wp_1 = $w_1.' / '.$p_1 ;
|
| 842 |
+
}
|
| 843 |
+
else
|
| 844 |
+
{
|
| 845 |
+
$wp_1 = $w_1 ;
|
| 846 |
+
}
|
| 847 |
+
|
| 848 |
+
if ($w1 ne $END)
|
| 849 |
+
{
|
| 850 |
+
$wp1 = $w1.' / '.$p1 ;
|
| 851 |
+
}
|
| 852 |
+
else
|
| 853 |
+
{
|
| 854 |
+
$wp1 = $w1 ;
|
| 855 |
+
}
|
| 856 |
+
|
| 857 |
+
if ($w2 ne $END)
|
| 858 |
+
{
|
| 859 |
+
$wp2 = $w2.' / '.$p2 ;
|
| 860 |
+
}
|
| 861 |
+
else
|
| 862 |
+
{
|
| 863 |
+
$wp2 = $w2 ;
|
| 864 |
+
}
|
| 865 |
+
|
| 866 |
+
$con_bef = $wp_1 ;
|
| 867 |
+
$con_bef_2 = $wp_2.' + '.$wp_1 ;
|
| 868 |
+
$con_aft = $wp1 ;
|
| 869 |
+
$con_aft_2 = $wp1.' + '.$wp2 ;
|
| 870 |
+
|
| 871 |
+
$con_pos_bef = $p_1 ;
|
| 872 |
+
$con_pos_bef_2 = $p_2.'+'.$p_1 ;
|
| 873 |
+
$con_pos_aft = $p1 ;
|
| 874 |
+
$con_pos_aft_2 = $p1.'+'.$p2 ;
|
| 875 |
+
|
| 876 |
+
if ($w_1 ne $START)
|
| 877 |
+
{
|
| 878 |
+
# do not count '.S' as a word context
|
| 879 |
+
$counts{con_bef_2}{tot}{$con_bef_2}++ ;
|
| 880 |
+
$counts{con_bef_2}{err_head}{$con_bef_2} += $err_head ;
|
| 881 |
+
$counts{con_bef_2}{err_dep}{$con_bef_2} += $err_dep ;
|
| 882 |
+
$counts{con_bef}{tot}{$con_bef}++ ;
|
| 883 |
+
$counts{con_bef}{err_head}{$con_bef} += $err_head ;
|
| 884 |
+
$counts{con_bef}{err_dep}{$con_bef} += $err_dep ;
|
| 885 |
+
}
|
| 886 |
+
|
| 887 |
+
if ($w1 ne $END)
|
| 888 |
+
{
|
| 889 |
+
# do not count '.E' as a word context
|
| 890 |
+
$counts{con_aft_2}{tot}{$con_aft_2}++ ;
|
| 891 |
+
$counts{con_aft_2}{err_head}{$con_aft_2} += $err_head ;
|
| 892 |
+
$counts{con_aft_2}{err_dep}{$con_aft_2} += $err_dep ;
|
| 893 |
+
$counts{con_aft}{tot}{$con_aft}++ ;
|
| 894 |
+
$counts{con_aft}{err_head}{$con_aft} += $err_head ;
|
| 895 |
+
$counts{con_aft}{err_dep}{$con_aft} += $err_dep ;
|
| 896 |
+
}
|
| 897 |
+
|
| 898 |
+
$counts{con_pos_bef_2}{tot}{$con_pos_bef_2}++ ;
|
| 899 |
+
$counts{con_pos_bef_2}{err_head}{$con_pos_bef_2} += $err_head ;
|
| 900 |
+
$counts{con_pos_bef_2}{err_dep}{$con_pos_bef_2} += $err_dep ;
|
| 901 |
+
$counts{con_pos_bef}{tot}{$con_pos_bef}++ ;
|
| 902 |
+
$counts{con_pos_bef}{err_head}{$con_pos_bef} += $err_head ;
|
| 903 |
+
$counts{con_pos_bef}{err_dep}{$con_pos_bef} += $err_dep ;
|
| 904 |
+
|
| 905 |
+
$counts{con_pos_aft_2}{tot}{$con_pos_aft_2}++ ;
|
| 906 |
+
$counts{con_pos_aft_2}{err_head}{$con_pos_aft_2} += $err_head ;
|
| 907 |
+
$counts{con_pos_aft_2}{err_dep}{$con_pos_aft_2} += $err_dep ;
|
| 908 |
+
$counts{con_pos_aft}{tot}{$con_pos_aft}++ ;
|
| 909 |
+
$counts{con_pos_aft}{err_head}{$con_pos_aft} += $err_head ;
|
| 910 |
+
$counts{con_pos_aft}{err_dep}{$con_pos_aft} += $err_dep ;
|
| 911 |
+
|
| 912 |
+
$err = $head_err.$sep.$head_aft_bef.$sep.$dep_err ;
|
| 913 |
+
$freq_err{$err}++ ;
|
| 914 |
+
|
| 915 |
+
} # loop on words
|
| 916 |
+
|
| 917 |
+
foreach $i_w (0 .. $word_num) # including one for the virtual root
|
| 918 |
+
{ # loop on words
|
| 919 |
+
if ($frames_g[$i_w] ne $frames_s[$i_w]) {
|
| 920 |
+
$counts{frame2}{"$frames_g[$i_w]/ $frames_s[$i_w]"}++ ;
|
| 921 |
+
}
|
| 922 |
+
}
|
| 923 |
+
|
| 924 |
+
if (defined $opt_b) { # produce output similar to evalb
|
| 925 |
+
if ($word_num > 0) {
|
| 926 |
+
my ($unlabeled,$labeled) = ('NaN', 'NaN');
|
| 927 |
+
if ($sent_counts{tot} > 0) { # there are scoring tokens
|
| 928 |
+
$unlabeled = 100-$sent_counts{err_head}*100.0/$sent_counts{tot};
|
| 929 |
+
$labeled = 100-$sent_counts{err_any} *100.0/$sent_counts{tot};
|
| 930 |
+
}
|
| 931 |
+
printf OUT " %4d %4d 0 %6.2f %6.2f %4d %4d %4d 0 0 0 0\n",
|
| 932 |
+
$sent_num, $word_num,
|
| 933 |
+
$unlabeled, $labeled,
|
| 934 |
+
$sent_counts{tot}-$sent_counts{err_head},
|
| 935 |
+
$sent_counts{tot}-$sent_counts{err_any},
|
| 936 |
+
$sent_counts{tot},;
|
| 937 |
+
}
|
| 938 |
+
}
|
| 939 |
+
|
| 940 |
+
} # main reading loop
|
| 941 |
+
|
| 942 |
+
################################################################################
|
| 943 |
+
### printing output ###
|
| 944 |
+
################################################################################
|
| 945 |
+
|
| 946 |
+
if (defined $opt_b) { # produce output similar to evalb
|
| 947 |
+
print OUT "\n\n";
|
| 948 |
+
}
|
| 949 |
+
printf OUT " Labeled attachment score: %d / %d * 100 = %.2f %%\n",
|
| 950 |
+
$counts{tot}-$counts{err_any}, $counts{tot}, 100-$counts{err_any}*100.0/$counts{tot} ;
|
| 951 |
+
printf OUT " Unlabeled attachment score: %d / %d * 100 = %.2f %%\n",
|
| 952 |
+
$counts{tot}-$counts{err_head}{tot}, $counts{tot}, 100-$counts{err_head}{tot}*100.0/$counts{tot} ;
|
| 953 |
+
printf OUT " Label accuracy score: %d / %d * 100 = %.2f %%\n",
|
| 954 |
+
$counts{tot}-$counts{err_dep}{tot}, $counts{tot}, 100-$counts{err_dep}{tot}*100.0/$counts{tot} ;
|
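# Illustrative note (not from the original script): with T = $counts{tot}
# scoring tokens (punctuation excluded unless -p), the three lines above are
#   LAS = (T - err_any)  / T * 100   (head and deprel both correct)
#   UAS = (T - err_head) / T * 100   (head correct)
#   LA  = (T - err_dep)  / T * 100   (deprel correct)
# using the error counters accumulated in the main reading loop.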
| 955 |
+
|
| 956 |
+
if ($short_output)
|
| 957 |
+
{
|
| 958 |
+
exit(0) ;
|
| 959 |
+
}
|
| 960 |
+
printf OUT "\n %s\n\n", '=' x 80 ;
|
| 961 |
+
printf OUT " Evaluation of the results in %s\n vs. gold standard %s:\n\n", $opt_s, $opt_g ;
|
| 962 |
+
|
| 963 |
+
printf OUT " Legend: '%s' - the beginning of a sentence, '%s' - the end of a sentence\n\n", $START, $END ;
|
| 964 |
+
|
| 965 |
+
printf OUT " Number of non-scoring tokens: $counts{punct}\n\n";
|
| 966 |
+
|
| 967 |
+
printf OUT " The overall accuracy and its distribution over CPOSTAGs\n\n" ;
|
| 968 |
+
printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
|
| 969 |
+
|
| 970 |
+
printf OUT " %-10s | %-5s | %-5s | %% | %-5s | %% | %-5s | %%\n",
|
| 971 |
+
'Accuracy', 'words', 'right', 'right', 'both' ;
|
| 972 |
+
printf OUT " %-10s | %-5s | %-5s | | %-5s | | %-5s |\n",
|
| 973 |
+
' ', ' ', 'head', ' dep', 'right' ;
|
| 974 |
+
|
| 975 |
+
printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
|
| 976 |
+
|
| 977 |
+
printf OUT " %-10s | %5d | %5d | %3.0f%% | %5d | %3.0f%% | %5d | %3.0f%%\n",
|
| 978 |
+
'total', $counts{tot},
|
| 979 |
+
$counts{tot}-$counts{err_head}{tot}, 100-$counts{err_head}{tot}*100.0/$counts{tot},
|
| 980 |
+
$counts{tot}-$counts{err_dep}{tot}, 100-$counts{err_dep}{tot}*100.0/$counts{tot},
|
| 981 |
+
$counts{tot}-$counts{err_any}, 100-$counts{err_any}*100.0/$counts{tot} ;
|
| 982 |
+
|
| 983 |
+
printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
|
| 984 |
+
|
| 985 |
+
foreach $pos (sort {$counts{pos}{$b}{tot} <=> $counts{pos}{$a}{tot}} keys %{$counts{pos}})
|
| 986 |
+
{
|
| 987 |
+
if (! defined($counts{pos}{$pos}{err_head}{tot}))
|
| 988 |
+
{
|
| 989 |
+
$counts{pos}{$pos}{err_head}{tot} = 0 ;
|
| 990 |
+
}
|
| 991 |
+
if (! defined($counts{pos}{$pos}{err_dep}{tot}))
|
| 992 |
+
{
|
| 993 |
+
$counts{pos}{$pos}{err_dep}{tot} = 0 ;
|
| 994 |
+
}
|
| 995 |
+
if (! defined($counts{pos}{$pos}{err_any}))
|
| 996 |
+
{
|
| 997 |
+
$counts{pos}{$pos}{err_any} = 0 ;
|
| 998 |
+
}
|
| 999 |
+
|
| 1000 |
+
printf OUT " %-10s | %5d | %5d | %3.0f%% | %5d | %3.0f%% | %5d | %3.0f%%\n",
|
| 1001 |
+
$pos, $counts{pos}{$pos}{tot},
|
| 1002 |
+
$counts{pos}{$pos}{tot}-$counts{pos}{$pos}{err_head}{tot}, 100-$counts{pos}{$pos}{err_head}{tot}*100.0/$counts{pos}{$pos}{tot},
|
| 1003 |
+
$counts{pos}{$pos}{tot}-$counts{pos}{$pos}{err_dep}{tot}, 100-$counts{pos}{$pos}{err_dep}{tot}*100.0/$counts{pos}{$pos}{tot},
|
| 1004 |
+
$counts{pos}{$pos}{tot}-$counts{pos}{$pos}{err_any}, 100-$counts{pos}{$pos}{err_any}*100.0/$counts{pos}{$pos}{tot} ;
|
| 1005 |
+
}
|
| 1006 |
+
|
| 1007 |
+
printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
|
| 1008 |
+
|
| 1009 |
+
printf OUT "\n\n" ;
|
| 1010 |
+
|
| 1011 |
+
printf OUT " The overall error rate and its distribution over CPOSTAGs\n\n" ;
|
| 1012 |
+
printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
|
| 1013 |
+
|
| 1014 |
+
printf OUT " %-10s | %-5s | %-5s | %% | %-5s | %% | %-5s | %%\n",
|
| 1015 |
+
'Error', 'words', 'head', ' dep', 'both' ;
|
| 1016 |
+
printf OUT " %-10s | %-5s | %-5s | | %-5s | | %-5s |\n",
|
| 1017 |
+
|
| 1018 |
+
'Rate', ' ', 'err', ' err', 'wrong' ;
|
| 1019 |
+
|
| 1020 |
+
printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
|
| 1021 |
+
|
| 1022 |
+
printf OUT " %-10s | %5d | %5d | %3.0f%% | %5d | %3.0f%% | %5d | %3.0f%%\n",
|
| 1023 |
+
'total', $counts{tot},
|
| 1024 |
+
$counts{err_head}{tot}, $counts{err_head}{tot}*100.0/$counts{tot},
|
| 1025 |
+
$counts{err_dep}{tot}, $counts{err_dep}{tot}*100.0/$counts{tot},
|
| 1026 |
+
$counts{err_both}, $counts{err_both}*100.0/$counts{tot} ;
|
| 1027 |
+
|
| 1028 |
+
printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
|
| 1029 |
+
|
| 1030 |
+
foreach $pos (sort {$counts{pos}{$b}{tot} <=> $counts{pos}{$a}{tot}} keys %{$counts{pos}})
|
| 1031 |
+
{
|
| 1032 |
+
if (! defined($counts{pos}{$pos}{err_both}))
|
| 1033 |
+
{
|
| 1034 |
+
$counts{pos}{$pos}{err_both} = 0 ;
|
| 1035 |
+
}
|
| 1036 |
+
|
| 1037 |
+
printf OUT " %-10s | %5d | %5d | %3.0f%% | %5d | %3.0f%% | %5d | %3.0f%%\n",
|
| 1038 |
+
$pos, $counts{pos}{$pos}{tot},
|
| 1039 |
+
$counts{pos}{$pos}{err_head}{tot}, $counts{pos}{$pos}{err_head}{tot}*100.0/$counts{pos}{$pos}{tot},
|
| 1040 |
+
$counts{pos}{$pos}{err_dep}{tot}, $counts{pos}{$pos}{err_dep}{tot}*100.0/$counts{pos}{$pos}{tot},
|
| 1041 |
+
$counts{pos}{$pos}{err_both}, $counts{pos}{$pos}{err_both}*100.0/$counts{pos}{$pos}{tot} ;
|
| 1042 |
+
|
| 1043 |
+
}
|
| 1044 |
+
|
| 1045 |
+
printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
|
| 1046 |
+
|
| 1047 |
+
### added by Sabine Buchholz
|
| 1048 |
+
printf OUT "\n\n";
|
| 1049 |
+
printf OUT " Precision and recall of DEPREL\n\n";
|
| 1050 |
+
printf OUT " ----------------+------+---------+--------+------------+---------------\n";
|
| 1051 |
+
printf OUT " deprel | gold | correct | system | recall (%%) | precision (%%) \n";
|
| 1052 |
+
printf OUT " ----------------+------+---------+--------+------------+---------------\n";
|
| 1053 |
+
foreach my $dep (sort keys %{$counts{all_dep}}) {
|
| 1054 |
+
# initialize
|
| 1055 |
+
my ($tot_corr, $tot_g, $tot_s, $prec, $rec) = (0, 0, 0, 'NaN', 'NaN');
|
| 1056 |
+
|
| 1057 |
+
if (defined($counts{dep2}{$dep}{$dep})) {
|
| 1058 |
+
$tot_corr = $counts{dep2}{$dep}{$dep};
|
| 1059 |
+
}
|
| 1060 |
+
if (defined($counts{dep}{$dep}{tot})) {
|
| 1061 |
+
$tot_g = $counts{dep}{$dep}{tot};
|
| 1062 |
+
$rec = sprintf("%.2f",$tot_corr / $tot_g * 100);
|
| 1063 |
+
}
|
| 1064 |
+
if (defined($counts{dep_s}{$dep}{tot})) {
|
| 1065 |
+
$tot_s = $counts{dep_s}{$dep}{tot};
|
| 1066 |
+
$prec = sprintf("%.2f",$tot_corr / $tot_s * 100);
|
| 1067 |
+
}
|
| 1068 |
+
printf OUT " %-15s | %4d | %7d | %6d | %10s | %13s\n",
|
| 1069 |
+
$dep, $tot_g, $tot_corr, $tot_s, $rec, $prec;
|
| 1070 |
+
}
|
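# Illustrative note (not from the original script): for every deprel label L
# seen in either file, the table above reports
#   recall(L)    = correct(L) / gold(L)   * 100
#   precision(L) = correct(L) / system(L) * 100
# where correct(L) = $counts{dep2}{L}{L}, i.e. tokens whose gold label L was
# also predicted; 'NaN' is printed when the corresponding denominator is absent.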
| 1071 |
+
|
| 1072 |
+
### DEPREL + ATTACHMENT:
|
| 1073 |
+
### Same as Sabine's DEPREL apart from $tot_corr calculation
|
| 1074 |
+
printf OUT "\n\n";
|
| 1075 |
+
printf OUT " Precision and recall of DEPREL + ATTACHMENT\n\n";
|
| 1076 |
+
printf OUT " ----------------+------+---------+--------+------------+---------------\n";
|
| 1077 |
+
printf OUT " deprel | gold | correct | system | recall (%%) | precision (%%) \n";
|
| 1078 |
+
printf OUT " ----------------+------+---------+--------+------------+---------------\n";
|
| 1079 |
+
foreach my $dep (sort keys %{$counts{all_dep}}) {
|
| 1080 |
+
# initialize
|
| 1081 |
+
my ($tot_corr, $tot_g, $tot_s, $prec, $rec) = (0, 0, 0, 'NaN', 'NaN');
|
| 1082 |
+
|
| 1083 |
+
if (defined($counts{dep2}{$dep}{$dep})) {
|
| 1084 |
+
if (defined($counts{err_head_corr_dep}{$dep})) {
|
| 1085 |
+
$tot_corr = $counts{dep2}{$dep}{$dep} - $counts{err_head_corr_dep}{$dep};
|
| 1086 |
+
} else {
|
| 1087 |
+
$tot_corr = $counts{dep2}{$dep}{$dep};
|
| 1088 |
+
}
|
| 1089 |
+
}
|
| 1090 |
+
if (defined($counts{dep}{$dep}{tot})) {
|
| 1091 |
+
$tot_g = $counts{dep}{$dep}{tot};
|
| 1092 |
+
$rec = sprintf("%.2f",$tot_corr / $tot_g * 100);
|
| 1093 |
+
}
|
| 1094 |
+
if (defined($counts{dep_s}{$dep}{tot})) {
|
| 1095 |
+
$tot_s = $counts{dep_s}{$dep}{tot};
|
| 1096 |
+
$prec = sprintf("%.2f",$tot_corr / $tot_s * 100);
|
| 1097 |
+
}
|
| 1098 |
+
printf OUT " %-15s | %4d | %7d | %6d | %10s | %13s\n",
|
| 1099 |
+
$dep, $tot_g, $tot_corr, $tot_s, $rec, $prec;
|
| 1100 |
+
}
|
| 1101 |
+
### DEPREL + ATTACHMENT
|
| 1102 |
+
|
| 1103 |
+
printf OUT "\n\n";
|
| 1104 |
+
printf OUT " Precision and recall of binned HEAD direction\n\n";
|
| 1105 |
+
printf OUT " ----------------+------+---------+--------+------------+---------------\n";
|
| 1106 |
+
printf OUT " direction | gold | correct | system | recall (%%) | precision (%%) \n";
|
| 1107 |
+
printf OUT " ----------------+------+---------+--------+------------+---------------\n";
|
| 1108 |
+
foreach my $dir ('to_root', 'left', 'right', 'self') {
|
| 1109 |
+
# initialize
|
| 1110 |
+
my ($tot_corr, $tot_g, $tot_s, $prec, $rec) = (0, 0, 0, 'NaN', 'NaN');
|
| 1111 |
+
|
| 1112 |
+
if (defined($counts{dir2}{$dir}{$dir})) {
|
| 1113 |
+
$tot_corr = $counts{dir2}{$dir}{$dir};
|
| 1114 |
+
}
|
| 1115 |
+
if (defined($counts{dir_g}{$dir}{tot})) {
|
| 1116 |
+
$tot_g = $counts{dir_g}{$dir}{tot};
|
| 1117 |
+
$rec = sprintf("%.2f",$tot_corr / $tot_g * 100);
|
| 1118 |
+
}
|
| 1119 |
+
if (defined($counts{dir_s}{$dir}{tot})) {
|
| 1120 |
+
$tot_s = $counts{dir_s}{$dir}{tot};
|
| 1121 |
+
$prec = sprintf("%.2f",$tot_corr / $tot_s * 100);
|
| 1122 |
+
}
|
| 1123 |
+
printf OUT " %-15s | %4d | %7d | %6d | %10s | %13s\n",
|
| 1124 |
+
$dir, $tot_g, $tot_corr, $tot_s, $rec, $prec;
|
| 1125 |
+
}
|
| 1126 |
+
|
| 1127 |
+
printf OUT "\n\n";
|
| 1128 |
+
printf OUT " Precision and recall of binned HEAD distance\n\n";
|
| 1129 |
+
printf OUT " ----------------+------+---------+--------+------------+---------------\n";
|
| 1130 |
+
printf OUT " distance | gold | correct | system | recall (%%) | precision (%%) \n";
|
| 1131 |
+
printf OUT " ----------------+------+---------+--------+------------+---------------\n";
|
| 1132 |
+
foreach my $dist ('to_root', '1', '2', '3-6', '7-...') {
|
| 1133 |
+
# initialize
|
| 1134 |
+
my ($tot_corr, $tot_g, $tot_s, $prec, $rec) = (0, 0, 0, 'NaN', 'NaN');
|
| 1135 |
+
|
| 1136 |
+
if (defined($counts{dist2}{$dist}{$dist})) {
|
| 1137 |
+
$tot_corr = $counts{dist2}{$dist}{$dist};
|
| 1138 |
+
}
|
| 1139 |
+
if (defined($counts{dist_g}{$dist}{tot})) {
|
| 1140 |
+
$tot_g = $counts{dist_g}{$dist}{tot};
|
| 1141 |
+
$rec = sprintf("%.2f",$tot_corr / $tot_g * 100);
|
| 1142 |
+
}
|
| 1143 |
+
if (defined($counts{dist_s}{$dist}{tot})) {
|
| 1144 |
+
$tot_s = $counts{dist_s}{$dist}{tot};
|
| 1145 |
+
$prec = sprintf("%.2f",$tot_corr / $tot_s * 100);
|
| 1146 |
+
}
|
| 1147 |
+
printf OUT " %-15s | %4d | %7d | %6d | %10s | %13s\n",
|
| 1148 |
+
$dist, $tot_g, $tot_corr, $tot_s, $rec, $prec;
|
| 1149 |
+
}
|
| 1150 |
+
|
| 1151 |
+
printf OUT "\n\n";
|
| 1152 |
+
printf OUT " Frame confusions (gold versus system; *...* marks the head token)\n\n";
|
| 1153 |
+
foreach my $frame (sort {$counts{frame2}{$b} <=> $counts{frame2}{$a}} keys %{$counts{frame2}})
|
| 1154 |
+
{
|
| 1155 |
+
if ($counts{frame2}{$frame} >= 5) # (make 5 a changeable threshold later)
|
| 1156 |
+
{
|
| 1157 |
+
printf OUT " %3d %s\n", $counts{frame2}{$frame}, $frame;
|
| 1158 |
+
}
|
| 1159 |
+
}
|
| 1160 |
+
### end of: added by Sabine Buchholz
|
| 1161 |
+
|
| 1162 |
+
|
| 1163 |
+
#
|
| 1164 |
+
# Leave only the 5 words most involved in errors
|
| 1165 |
+
#
|
| 1166 |
+
|
| 1167 |
+
|
| 1168 |
+
$thresh = (sort {$b <=> $a} values %{$counts{word}{err_any}})[4] ;
|
| 1169 |
+
|
| 1170 |
+
# ensure enough space for title
|
| 1171 |
+
$max_word_len = length('word') ;
|
| 1172 |
+
|
| 1173 |
+
foreach $word (keys %{$counts{word}{err_any}})
|
| 1174 |
+
{
|
| 1175 |
+
if ($counts{word}{err_any}{$word} < $thresh)
|
| 1176 |
+
{
|
| 1177 |
+
delete $counts{word}{err_any}{$word} ;
|
| 1178 |
+
next ;
|
| 1179 |
+
}
|
| 1180 |
+
|
| 1181 |
+
$l = uni_len($word) ;
|
| 1182 |
+
if ($l > $max_word_len)
|
| 1183 |
+
{
|
| 1184 |
+
$max_word_len = $l ;
|
| 1185 |
+
}
|
| 1186 |
+
}
|
| 1187 |
+
|
| 1188 |
+
# filter a case when the difference between the error counts
|
| 1189 |
+
# for 2-word and 1-word contexts is small
|
| 1190 |
+
# (leave the 2-word context)
|
| 1191 |
+
|
| 1192 |
+
foreach $con (keys %{$counts{con_aft_2}{tot}})
|
| 1193 |
+
{
|
| 1194 |
+
($w1) = split(/\+/, $con) ;
|
| 1195 |
+
|
| 1196 |
+
if (defined $counts{con_aft}{tot}{$w1} &&
|
| 1197 |
+
$counts{con_aft}{tot}{$w1}-$counts{con_aft_2}{tot}{$con} <= 1)
|
| 1198 |
+
{
|
| 1199 |
+
delete $counts{con_aft}{tot}{$w1} ;
|
| 1200 |
+
}
|
| 1201 |
+
}
|
| 1202 |
+
|
| 1203 |
+
foreach $con (keys %{$counts{con_bef_2}{tot}})
|
| 1204 |
+
{
|
| 1205 |
+
($w_2, $w_1) = split(/\+/, $con) ;
|
| 1206 |
+
|
| 1207 |
+
if (defined $counts{con_bef}{tot}{$w_1} &&
|
| 1208 |
+
$counts{con_bef}{tot}{$w_1}-$counts{con_bef_2}{tot}{$con} <= 1)
|
| 1209 |
+
{
|
| 1210 |
+
delete $counts{con_bef}{tot}{$w_1} ;
|
| 1211 |
+
}
|
| 1212 |
+
}
|
| 1213 |
+
|
| 1214 |
+
foreach $con_pos (keys %{$counts{con_pos_aft_2}{tot}})
|
| 1215 |
+
{
|
| 1216 |
+
($p1) = split(/\+/, $con_pos) ;
|
| 1217 |
+
|
| 1218 |
+
if (defined($counts{con_pos_aft}{tot}{$p1}) &&
|
| 1219 |
+
$counts{con_pos_aft}{tot}{$p1}-$counts{con_pos_aft_2}{tot}{$con_pos} <= 1)
|
| 1220 |
+
{
|
| 1221 |
+
delete $counts{con_pos_aft}{tot}{$p1} ;
|
| 1222 |
+
}
|
| 1223 |
+
}
|
| 1224 |
+
|
| 1225 |
+
foreach $con_pos (keys %{$counts{con_pos_bef_2}{tot}})
|
| 1226 |
+
{
|
| 1227 |
+
($p_2, $p_1) = split(/\+/, $con_pos) ;
|
| 1228 |
+
|
| 1229 |
+
if (defined($counts{con_pos_bef}{tot}{$p_1}) &&
|
| 1230 |
+
$counts{con_pos_bef}{tot}{$p_1}-$counts{con_pos_bef_2}{tot}{$con_pos} <= 1)
|
| 1231 |
+
{
|
| 1232 |
+
delete $counts{con_pos_bef}{tot}{$p_1} ;
|
| 1233 |
+
}
|
| 1234 |
+
}
|
| 1235 |
+
|
| 1236 |
+
# for each context type, take the three contexts most involved in errors
|
| 1237 |
+
|
| 1238 |
+
$max_con_len = 0 ;
|
| 1239 |
+
|
| 1240 |
+
filter_context_counts($counts{con_bef_2}{tot}, $con_err_num, \$max_con_len) ;
|
| 1241 |
+
|
| 1242 |
+
filter_context_counts($counts{con_bef}{tot}, $con_err_num, \$max_con_len) ;
|
| 1243 |
+
|
| 1244 |
+
filter_context_counts($counts{con_aft}{tot}, $con_err_num, \$max_con_len) ;
|
| 1245 |
+
|
| 1246 |
+
filter_context_counts($counts{con_aft_2}{tot}, $con_err_num, \$max_con_len) ;
|
| 1247 |
+
|
| 1248 |
+
# for each CPOS context type, take the three CPOS contexts most involved in errors
|
| 1249 |
+
|
| 1250 |
+
$max_con_pos_len = 0 ;
|
| 1251 |
+
|
| 1252 |
+
$thresh = (sort {$b <=> $a} values %{$counts{con_pos_bef_2}{tot}})[$con_err_num-1] ;
|
| 1253 |
+
|
| 1254 |
+
foreach $con_pos (keys %{$counts{con_pos_bef_2}{tot}})
|
| 1255 |
+
{
|
| 1256 |
+
if ($counts{con_pos_bef_2}{tot}{$con_pos} < $thresh)
|
| 1257 |
+
{
|
| 1258 |
+
delete $counts{con_pos_bef_2}{tot}{$con_pos} ;
|
| 1259 |
+
next ;
|
| 1260 |
+
}
|
| 1261 |
+
if (length($con_pos) > $max_con_pos_len)
|
| 1262 |
+
{
|
| 1263 |
+
$max_con_pos_len = length($con_pos) ;
|
| 1264 |
+
}
|
| 1265 |
+
}
|
| 1266 |
+
|
| 1267 |
+
$thresh = (sort {$b <=> $a} values %{$counts{con_pos_bef}{tot}})[$con_err_num-1] ;
|
| 1268 |
+
|
| 1269 |
+
foreach $con_pos (keys %{$counts{con_pos_bef}{tot}})
|
| 1270 |
+
{
|
| 1271 |
+
if ($counts{con_pos_bef}{tot}{$con_pos} < $thresh)
|
| 1272 |
+
{
|
| 1273 |
+
delete $counts{con_pos_bef}{tot}{$con_pos} ;
|
| 1274 |
+
next ;
|
| 1275 |
+
}
|
| 1276 |
+
if (length($con_pos) > $max_con_pos_len)
|
| 1277 |
+
{
|
| 1278 |
+
$max_con_pos_len = length($con_pos) ;
|
| 1279 |
+
}
|
| 1280 |
+
}
|
| 1281 |
+
|
| 1282 |
+
$thresh = (sort {$b <=> $a} values %{$counts{con_pos_aft}{tot}})[$con_err_num-1] ;
|
| 1283 |
+
|
| 1284 |
+
foreach $con_pos (keys %{$counts{con_pos_aft}{tot}})
|
| 1285 |
+
{
|
| 1286 |
+
if ($counts{con_pos_aft}{tot}{$con_pos} < $thresh)
|
| 1287 |
+
{
|
| 1288 |
+
delete $counts{con_pos_aft}{tot}{$con_pos} ;
|
| 1289 |
+
next ;
|
| 1290 |
+
}
|
| 1291 |
+
if (length($con_pos) > $max_con_pos_len)
|
| 1292 |
+
{
|
| 1293 |
+
$max_con_pos_len = length($con_pos) ;
|
| 1294 |
+
}
|
| 1295 |
+
}
|
| 1296 |
+
|
| 1297 |
+
$thresh = (sort {$b <=> $a} values %{$counts{con_pos_aft_2}{tot}})[$con_err_num-1] ;
|
| 1298 |
+
|
| 1299 |
+
foreach $con_pos (keys %{$counts{con_pos_aft_2}{tot}})
|
| 1300 |
+
{
|
| 1301 |
+
if ($counts{con_pos_aft_2}{tot}{$con_pos} < $thresh)
|
| 1302 |
+
{
|
| 1303 |
+
delete $counts{con_pos_aft_2}{tot}{$con_pos} ;
|
| 1304 |
+
next ;
|
| 1305 |
+
}
|
| 1306 |
+
if (length($con_pos) > $max_con_pos_len)
|
| 1307 |
+
{
|
| 1308 |
+
$max_con_pos_len = length($con_pos) ;
|
| 1309 |
+
}
|
| 1310 |
+
}
|
| 1311 |
+
|
| 1312 |
+
# printing
|
| 1313 |
+
|
| 1314 |
+
# ------------- focus words
|
| 1315 |
+
|
| 1316 |
+
printf OUT "\n\n" ;
|
| 1317 |
+
printf OUT " %d focus words where most of the errors occur:\n\n", scalar keys %{$counts{word}{err_any}} ;
|
| 1318 |
+
|
| 1319 |
+
printf OUT " %-*s | %-4s | %-4s | %-4s | %-4s\n", $max_word_len, ' ', 'any', 'head', 'dep', 'both' ;
|
| 1320 |
+
printf OUT " %s-+------+------+------+------\n", '-' x $max_word_len;
|
| 1321 |
+
|
| 1322 |
+
foreach $word (sort {$counts{word}{err_any}{$b} <=> $counts{word}{err_any}{$a}} keys %{$counts{word}{err_any}})
|
| 1323 |
+
{
|
| 1324 |
+
if (!defined($counts{word}{err_head}{$word}))
|
| 1325 |
+
{
|
| 1326 |
+
$counts{word}{err_head}{$word} = 0 ;
|
| 1327 |
+
}
|
| 1328 |
+
if (! defined($counts{word}{err_dep}{$word}))
|
| 1329 |
+
{
|
| 1330 |
+
$counts{word}{err_dep}{$word} = 0 ;
|
| 1331 |
+
}
|
| 1332 |
+
if (! defined($counts{word}{err_any}{$word}))
|
| 1333 |
+
{
|
| 1334 |
+
$counts{word}{err_any}{$word} = 0;
|
| 1335 |
+
}
|
| 1336 |
+
printf OUT " %-*s | %4d | %4d | %4d | %4d\n",
|
| 1337 |
+
$max_word_len+length($word)-uni_len($word), $word, $counts{word}{err_any}{$word},
|
| 1338 |
+
$counts{word}{err_head}{$word},
|
| 1339 |
+
$counts{word}{err_dep}{$word},
|
| 1340 |
+
$counts{word}{err_dep}{$word}+$counts{word}{err_head}{$word}-$counts{word}{err_any}{$word} ;
|
| 1341 |
+
}
|
| 1342 |
+
|
| 1343 |
+
printf OUT " %s-+------+------+------+------\n", '-' x $max_word_len;
|
| 1344 |
+
|
| 1345 |
+
# ------------- contexts
|
| 1346 |
+
|
| 1347 |
+
printf OUT "\n\n" ;
|
| 1348 |
+
|
| 1349 |
+
printf OUT " one-token preceeding contexts where most of the errors occur:\n\n" ;
|
| 1350 |
+
|
| 1351 |
+
print_context($counts{con_bef}, $counts{con_pos_bef}, $max_con_len, $max_con_pos_len) ;
|
| 1352 |
+
|
| 1353 |
+
printf OUT " two-token preceeding contexts where most of the errors occur:\n\n" ;
|
| 1354 |
+
|
| 1355 |
+
print_context($counts{con_bef_2}, $counts{con_pos_bef_2}, $max_con_len, $max_con_pos_len) ;
|
| 1356 |
+
|
| 1357 |
+
printf OUT " one-token following contexts where most of the errors occur:\n\n" ;
|
| 1358 |
+
|
| 1359 |
+
print_context($counts{con_aft}, $counts{con_pos_aft}, $max_con_len, $max_con_pos_len) ;
|
| 1360 |
+
|
| 1361 |
+
printf OUT " two-token following contexts where most of the errors occur:\n\n" ;
|
| 1362 |
+
|
| 1363 |
+
print_context($counts{con_aft_2}, $counts{con_pos_aft_2}, $max_con_len, $max_con_pos_len) ;
|
| 1364 |
+
|
| 1365 |
+
# ------------- Sentences
|
| 1366 |
+
|
| 1367 |
+
printf OUT " Sentence with the highest number of word errors:\n" ;
|
| 1368 |
+
$i = (sort { (defined($err_sent[$b]{word}) && $err_sent[$b]{word})
|
| 1369 |
+
<=> (defined($err_sent[$a]{word}) && $err_sent[$a]{word}) } 1 .. $sent_num)[0] ;
|
| 1370 |
+
printf OUT " Sentence %d line %d, ", $i, $starts[$i-1] ;
|
| 1371 |
+
printf OUT "%d head errors, %d dependency errors, %d word errors\n",
|
| 1372 |
+
$err_sent[$i]{head}, $err_sent[$i]{dep}, $err_sent[$i]{word} ;
|
| 1373 |
+
|
| 1374 |
+
printf OUT "\n\n" ;
|
| 1375 |
+
|
| 1376 |
+
printf OUT " Sentence with the highest number of head errors:\n" ;
|
| 1377 |
+
$i = (sort { (defined($err_sent[$b]{head}) && $err_sent[$b]{head})
|
| 1378 |
+
<=> (defined($err_sent[$a]{head}) && $err_sent[$a]{head}) } 1 .. $sent_num)[0] ;
|
| 1379 |
+
printf OUT " Sentence %d line %d, ", $i, $starts[$i-1] ;
|
| 1380 |
+
printf OUT "%d head errors, %d dependency errors, %d word errors\n",
|
| 1381 |
+
$err_sent[$i]{head}, $err_sent[$i]{dep}, $err_sent[$i]{word} ;
|
| 1382 |
+
|
| 1383 |
+
printf OUT "\n\n" ;
|
| 1384 |
+
|
| 1385 |
+
printf OUT " Sentence with the highest number of dependency errors:\n" ;
|
| 1386 |
+
$i = (sort { (defined($err_sent[$b]{dep}) && $err_sent[$b]{dep})
|
| 1387 |
+
<=> (defined($err_sent[$a]{dep}) && $err_sent[$a]{dep}) } 1 .. $sent_num)[0] ;
|
| 1388 |
+
printf OUT " Sentence %d line %d, ", $i, $starts[$i-1] ;
|
| 1389 |
+
printf OUT "%d head errors, %d dependency errors, %d word errors\n",
|
| 1390 |
+
$err_sent[$i]{head}, $err_sent[$i]{dep}, $err_sent[$i]{word} ;
|
| 1391 |
+
|
| 1392 |
+
#
|
| 1393 |
+
# Second pass, collect statistics of the frequent errors
|
| 1394 |
+
#
|
| 1395 |
+
|
| 1396 |
+
# filter the errors, leave the most frequent $freq_err_num errors
|
| 1397 |
+
|
| 1398 |
+
$i = 0 ;
|
| 1399 |
+
|
| 1400 |
+
$thresh = (sort {$b <=> $a} values %freq_err)[$freq_err_num-1] ;
|
| 1401 |
+
|
| 1402 |
+
foreach $err (keys %freq_err)
|
| 1403 |
+
{
|
| 1404 |
+
if ($freq_err{$err} < $thresh)
|
| 1405 |
+
{
|
| 1406 |
+
delete $freq_err{$err} ;
|
| 1407 |
+
}
|
| 1408 |
+
}
|
| 1409 |
+
|
| 1410 |
+
# in case there are several errors with the threshold count
|
| 1411 |
+
|
| 1412 |
+
$freq_err_num = scalar keys %freq_err ;
|
| 1413 |
+
|
| 1414 |
+
%err_counts = () ;
|
| 1415 |
+
|
| 1416 |
+
$eof = 0 ;
|
| 1417 |
+
|
| 1418 |
+
seek (GOLD, 0, 0) ;
|
| 1419 |
+
seek (SYS, 0, 0) ;
|
| 1420 |
+
|
| 1421 |
+
while (! $eof)
|
| 1422 |
+
{ # second reading loop
|
| 1423 |
+
|
| 1424 |
+
$eof = read_sent(\@sent_gold, \@sent_sys) ;
|
| 1425 |
+
$sent_num++ ;
|
| 1426 |
+
|
| 1427 |
+
$word_num = scalar @sent_gold ;
|
| 1428 |
+
|
| 1429 |
+
# printf "$sent_num $word_num\n" ;
|
| 1430 |
+
|
| 1431 |
+
foreach $i_w (0 .. $word_num-1)
|
| 1432 |
+
{ # loop on words
|
| 1433 |
+
($word, $pos, $head_g, $dep_g)
|
| 1434 |
+
= @{$sent_gold[$i_w]}{'word', 'pos', 'head', 'dep'} ;
|
| 1435 |
+
|
| 1436 |
+
# printf "%d: %s %s %s %s\n", $i_w, $word, $pos, $head_g, $dep_g ;
|
| 1437 |
+
|
| 1438 |
+
if ((! $score_on_punct) && is_uni_punct($word))
|
| 1439 |
+
{
|
| 1440 |
+
# ignore punctuation
|
| 1441 |
+
next ;
|
| 1442 |
+
}
|
| 1443 |
+
|
| 1444 |
+
($head_s, $dep_s) = @{$sent_sys[$i_w]}{'head', 'dep'} ;
|
| 1445 |
+
|
| 1446 |
+
$err_head = ($head_g ne $head_s) ;
|
| 1447 |
+
$err_dep = ($dep_g ne $dep_s) ;
|
| 1448 |
+
|
| 1449 |
+
$head_err = '-' ;
|
| 1450 |
+
$dep_err = '-' ;
|
| 1451 |
+
|
| 1452 |
+
if ($head_g eq '0')
|
| 1453 |
+
{
|
| 1454 |
+
$head_aft_bef_g = '0' ;
|
| 1455 |
+
}
|
| 1456 |
+
elsif ($head_g eq $i_w+1)
|
| 1457 |
+
{
|
| 1458 |
+
$head_aft_bef_g = 'e' ;
|
| 1459 |
+
}
|
| 1460 |
+
else
|
| 1461 |
+
{
|
| 1462 |
+
$head_aft_bef_g = ($head_g <= $i_w+1 ? 'b' : 'a') ;
|
| 1463 |
+
}
|
| 1464 |
+
|
| 1465 |
+
if ($head_s eq '0')
|
| 1466 |
+
{
|
| 1467 |
+
$head_aft_bef_s = '0' ;
|
| 1468 |
+
}
|
| 1469 |
+
elsif ($head_s eq $i_w+1)
|
| 1470 |
+
{
|
| 1471 |
+
$head_aft_bef_s = 'e' ;
|
| 1472 |
+
}
|
| 1473 |
+
else
|
| 1474 |
+
{
|
| 1475 |
+
$head_aft_bef_s = ($head_s <= $i_w+1 ? 'b' : 'a') ;
|
| 1476 |
+
}
|
| 1477 |
+
|
| 1478 |
+
$head_aft_bef = $head_aft_bef_g.$head_aft_bef_s ;
|
| 1479 |
+
|
| 1480 |
+
if ($err_head)
|
| 1481 |
+
{
|
| 1482 |
+
if ($head_aft_bef_s eq '0')
|
| 1483 |
+
{
|
| 1484 |
+
$head_err = 0 ;
|
| 1485 |
+
}
|
| 1486 |
+
else
|
| 1487 |
+
{
|
| 1488 |
+
$head_err = $head_s-$head_g ;
|
| 1489 |
+
}
|
| 1490 |
+
}
|
| 1491 |
+
|
| 1492 |
+
if ($err_dep)
|
| 1493 |
+
{
|
| 1494 |
+
$dep_err = $dep_g.'->'.$dep_s ;
|
| 1495 |
+
}
|
| 1496 |
+
|
| 1497 |
+
if (! ($err_head || $err_dep))
|
| 1498 |
+
{
|
| 1499 |
+
next ;
|
| 1500 |
+
}
|
| 1501 |
+
|
| 1502 |
+
# handle only the most frequent errors
|
| 1503 |
+
|
| 1504 |
+
$err = $head_err.$sep.$head_aft_bef.$sep.$dep_err ;
|
| 1505 |
+
|
| 1506 |
+
if (! exists $freq_err{$err})
|
| 1507 |
+
{
|
| 1508 |
+
next ;
|
| 1509 |
+
}
|
| 1510 |
+
|
| 1511 |
+
($w_2, $w_1, $w1, $w2, $p_2, $p_1, $p1, $p2) = get_context(\@sent_gold, $i_w) ;
|
| 1512 |
+
|
| 1513 |
+
$con_bef = $w_1 ;
|
| 1514 |
+
$con_bef_2 = $w_2.' + '.$w_1 ;
|
| 1515 |
+
$con_aft = $w1 ;
|
| 1516 |
+
$con_aft_2 = $w1.' + '.$w2 ;
|
| 1517 |
+
|
| 1518 |
+
$con_pos_bef = $p_1 ;
|
| 1519 |
+
$con_pos_bef_2 = $p_2.'+'.$p_1 ;
|
| 1520 |
+
$con_pos_aft = $p1 ;
|
| 1521 |
+
$con_pos_aft_2 = $p1.'+'.$p2 ;
|
| 1522 |
+
|
| 1523 |
+
@cur_err = ($con_pos_bef, $con_bef, $word, $pos, $con_pos_aft, $con_aft) ;
|
| 1524 |
+
|
| 1525 |
+
# printf "# %-25s %-15s %-10s %-25s %-3s %-30s\n",
|
| 1526 |
+
# $con_bef, $word, $pos, $con_aft, $head_err, $dep_err ;
|
| 1527 |
+
|
| 1528 |
+
@bits = (0, 0, 0, 0, 0, 0) ;
|
| 1529 |
+
$j = 0 ;
|
| 1530 |
+
|
| 1531 |
+
while ($j == 0)
|
| 1532 |
+
{
|
| 1533 |
+
for ($i = 0; $i <= $#bits; $i++)
|
| 1534 |
+
{
|
| 1535 |
+
if ($bits[$i] == 0)
|
| 1536 |
+
{
|
| 1537 |
+
$bits[$i] = 1 ;
|
| 1538 |
+
$j = 0 ;
|
| 1539 |
+
last ;
|
| 1540 |
+
}
|
| 1541 |
+
else
|
| 1542 |
+
{
|
| 1543 |
+
$bits[$i] = 0 ;
|
| 1544 |
+
$j = 1 ;
|
| 1545 |
+
}
|
| 1546 |
+
}
|
| 1547 |
+
|
| 1548 |
+
@e_bits = @cur_err ;
|
| 1549 |
+
|
| 1550 |
+
for ($i = 0; $i <= $#bits; $i++)
|
| 1551 |
+
{
|
| 1552 |
+
if (! $bits[$i])
|
| 1553 |
+
{
|
| 1554 |
+
$e_bits[$i] = '*' ;
|
| 1555 |
+
}
|
| 1556 |
+
}
|
| 1557 |
+
|
| 1558 |
+
# include also the last case which is the most general
|
| 1559 |
+
# (wildcards for everything)
|
| 1560 |
+
$err_counts{$err}{join($sep, @e_bits)}++ ;
|
| 1561 |
+
|
| 1562 |
+
}
|
| 1563 |
+
|
| 1564 |
+
} # loop on words
|
| 1565 |
+
} # second reading loop
|
| 1566 |
+
|
| 1567 |
+
printf OUT "\n\n" ;
|
| 1568 |
+
printf OUT " Specific errors, %d most frequent errors:", $freq_err_num ;
|
| 1569 |
+
printf OUT "\n %s\n", '=' x 41 ;
|
| 1570 |
+
|
| 1571 |
+
|
| 1572 |
+
# deleting local contexts which are too general
|
| 1573 |
+
|
| 1574 |
+
foreach $err (keys %err_counts)
|
| 1575 |
+
{
|
| 1576 |
+
foreach $loc_con (sort {$err_counts{$err}{$b} <=> $err_counts{$err}{$a}}
|
| 1577 |
+
keys %{$err_counts{$err}})
|
| 1578 |
+
{
|
| 1579 |
+
@cur_err = split(/\Q$sep\E/, $loc_con) ;
|
| 1580 |
+
|
| 1581 |
+
# In this loop, one or two elements of the local context are
|
| 1582 |
+
# replaced with '*' to make it more general. If the entry for
|
| 1583 |
+
# the general context has the same count it is removed.
|
| 1584 |
+
|
| 1585 |
+
foreach $i (0 .. $#cur_err)
|
| 1586 |
+
{
|
| 1587 |
+
$w1 = $cur_err[$i] ;
|
| 1588 |
+
if ($cur_err[$i] eq '*')
|
| 1589 |
+
{
|
| 1590 |
+
next ;
|
| 1591 |
+
}
|
| 1592 |
+
$cur_err[$i] = '*' ;
|
| 1593 |
+
$con1 = join($sep, @cur_err) ;
|
| 1594 |
+
if ( defined($err_counts{$err}{$con1}) && defined($err_counts{$err}{$loc_con})
|
| 1595 |
+
&& ($err_counts{$err}{$con1} == $err_counts{$err}{$loc_con}))
|
| 1596 |
+
{
|
| 1597 |
+
delete $err_counts{$err}{$con1} ;
|
| 1598 |
+
}
|
| 1599 |
+
for ($j = $i+1; $j <=$#cur_err; $j++)
|
| 1600 |
+
{
|
| 1601 |
+
if ($cur_err[$j] eq '*')
|
| 1602 |
+
{
|
| 1603 |
+
next ;
|
| 1604 |
+
}
|
| 1605 |
+
$w2 = $cur_err[$j] ;
|
| 1606 |
+
$cur_err[$j] = '*' ;
|
| 1607 |
+
$con1 = join($sep, @cur_err) ;
|
| 1608 |
+
if ( defined($err_counts{$err}{$con1}) && defined($err_counts{$err}{$loc_con})
|
| 1609 |
+
&& ($err_counts{$err}{$con1} == $err_counts{$err}{$loc_con}))
|
| 1610 |
+
{
|
| 1611 |
+
delete $err_counts{$err}{$con1} ;
|
| 1612 |
+
}
|
| 1613 |
+
$cur_err[$j] = $w2 ;
|
| 1614 |
+
}
|
| 1615 |
+
$cur_err[$i] = $w1 ;
|
| 1616 |
+
}
|
| 1617 |
+
}
|
| 1618 |
+
}
|
| 1619 |
+
|
| 1620 |
+
# Leaving only the topmost local contexts for each error
|
| 1621 |
+
|
| 1622 |
+
foreach $err (keys %err_counts)
|
| 1623 |
+
{
|
| 1624 |
+
$thresh = (sort {$b <=> $a} values %{$err_counts{$err}})[$spec_err_loc_con-1] || 0 ;
|
| 1625 |
+
|
| 1626 |
+
# of the threshold is too low, take the 2nd highest count
|
| 1627 |
+
# (the highest may be the total which is the generic case
|
| 1628 |
+
# and not relevant for printing)
|
| 1629 |
+
|
| 1630 |
+
if ($thresh < 5)
|
| 1631 |
+
{
|
| 1632 |
+
$thresh = (sort {$b <=> $a} values %{$err_counts{$err}})[1] ;
|
| 1633 |
+
}
|
| 1634 |
+
|
| 1635 |
+
foreach $loc_con (keys %{$err_counts{$err}})
|
| 1636 |
+
{
|
| 1637 |
+
if ($err_counts{$err}{$loc_con} < $thresh)
|
| 1638 |
+
{
|
| 1639 |
+
delete $err_counts{$err}{$loc_con} ;
|
| 1640 |
+
}
|
| 1641 |
+
else
|
| 1642 |
+
{
|
| 1643 |
+
if ($loc_con ne join($sep, ('*', '*', '*', '*', '*', '*')))
|
| 1644 |
+
{
|
| 1645 |
+
$loc_con_err_counts{$loc_con}{$err} = $err_counts{$err}{$loc_con} ;
|
| 1646 |
+
}
|
| 1647 |
+
}
|
| 1648 |
+
}
|
| 1649 |
+
}
|
| 1650 |
+
|
| 1651 |
+
# printing an error summary
|
| 1652 |
+
|
| 1653 |
+
# calculating the context field length
|
| 1654 |
+
|
| 1655 |
+
$max_word_spec_len= length('word') ;
|
| 1656 |
+
$max_con_aft_len = length('word') ;
|
| 1657 |
+
$max_con_bef_len = length('word') ;
|
| 1658 |
+
$max_con_pos_len = length('CPOS') ;
|
| 1659 |
+
|
| 1660 |
+
foreach $err (keys %err_counts)
|
| 1661 |
+
{
|
| 1662 |
+
foreach $loc_con (sort keys %{$err_counts{$err}})
|
| 1663 |
+
{
|
| 1664 |
+
($con_pos_bef, $con_bef, $word, $pos, $con_pos_aft, $con_aft) =
|
| 1665 |
+
split(/\Q$sep\E/, $loc_con) ;
|
| 1666 |
+
|
| 1667 |
+
$l = uni_len($word) ;
|
| 1668 |
+
if ($l > $max_word_spec_len)
|
| 1669 |
+
{
|
| 1670 |
+
$max_word_spec_len = $l ;
|
| 1671 |
+
}
|
| 1672 |
+
|
| 1673 |
+
$l = uni_len($con_bef) ;
|
| 1674 |
+
if ($l > $max_con_bef_len)
|
| 1675 |
+
{
|
| 1676 |
+
$max_con_bef_len = $l ;
|
| 1677 |
+
}
|
| 1678 |
+
|
| 1679 |
+
$l = uni_len($con_aft) ;
|
| 1680 |
+
if ($l > $max_con_aft_len)
|
| 1681 |
+
{
|
| 1682 |
+
$max_con_aft_len = $l ;
|
| 1683 |
+
}
|
| 1684 |
+
|
| 1685 |
+
if (length($con_pos_aft) > $max_con_pos_len)
|
| 1686 |
+
{
|
| 1687 |
+
$max_con_pos_len = length($con_pos_aft) ;
|
| 1688 |
+
}
|
| 1689 |
+
|
| 1690 |
+
if (length($con_pos_bef) > $max_con_pos_len)
|
| 1691 |
+
{
|
| 1692 |
+
$max_con_pos_len = length($con_pos_bef) ;
|
| 1693 |
+
}
|
| 1694 |
+
}
|
| 1695 |
+
}
|
| 1696 |
+
|
| 1697 |
+
$err_counter = 0 ;
|
| 1698 |
+
|
| 1699 |
+
foreach $err (sort {$freq_err{$b} <=> $freq_err{$a}} keys %freq_err)
|
| 1700 |
+
{
|
| 1701 |
+
|
| 1702 |
+
($head_err, $head_aft_bef, $dep_err) = split(/\Q$sep\E/, $err) ;
|
| 1703 |
+
|
| 1704 |
+
$err_counter++ ;
|
| 1705 |
+
$err_desc{$err} = sprintf("%2d. ", $err_counter).
|
| 1706 |
+
describe_err($head_err, $head_aft_bef, $dep_err) ;
|
| 1707 |
+
|
| 1708 |
+
# printf OUT " %-3s %-30s %d\n", $head_err, $dep_err, $freq_err{$err} ;
|
| 1709 |
+
printf OUT "\n" ;
|
| 1710 |
+
printf OUT " %s : %d times\n", $err_desc{$err}, $freq_err{$err} ;
|
| 1711 |
+
|
| 1712 |
+
printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-+------\n",
|
| 1713 |
+
'-' x $max_con_pos_len, '-' x $max_con_bef_len,
|
| 1714 |
+
'-' x $max_pos_len, '-' x $max_word_spec_len,
|
| 1715 |
+
'-' x $max_con_pos_len, '-' x $max_con_aft_len ;
|
| 1716 |
+
|
| 1717 |
+
printf OUT " %-*s | %-*s | %-*s | %s\n",
|
| 1718 |
+
$max_con_pos_len+$max_con_bef_len+3, ' Before',
|
| 1719 |
+
$max_word_spec_len+$max_pos_len+3, ' Focus',
|
| 1720 |
+
$max_con_pos_len+$max_con_aft_len+3, ' After',
|
| 1721 |
+
'Count' ;
|
| 1722 |
+
|
| 1723 |
+
printf OUT " %-*s %-*s | %-*s %-*s | %-*s %-*s |\n",
|
| 1724 |
+
$max_con_pos_len, 'CPOS', $max_con_bef_len, 'word',
|
| 1725 |
+
$max_pos_len, 'CPOS', $max_word_spec_len, 'word',
|
| 1726 |
+
$max_con_pos_len, 'CPOS', $max_con_aft_len, 'word' ;
|
| 1727 |
+
|
| 1728 |
+
printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-+------\n",
|
| 1729 |
+
'-' x $max_con_pos_len, '-' x $max_con_bef_len,
|
| 1730 |
+
'-' x $max_pos_len, '-' x $max_word_spec_len,
|
| 1731 |
+
'-' x $max_con_pos_len, '-' x $max_con_aft_len ;
|
| 1732 |
+
|
| 1733 |
+
foreach $loc_con (sort {$err_counts{$err}{$b} <=> $err_counts{$err}{$a}}
|
| 1734 |
+
keys %{$err_counts{$err}})
|
| 1735 |
+
{
|
| 1736 |
+
if ($loc_con eq join($sep, ('*', '*', '*', '*', '*', '*')))
|
| 1737 |
+
{
|
| 1738 |
+
next ;
|
| 1739 |
+
}
|
| 1740 |
+
|
| 1741 |
+
$con1 = $loc_con ;
|
| 1742 |
+
$con1 =~ s/\*/ /g ;
|
| 1743 |
+
|
| 1744 |
+
($con_pos_bef, $con_bef, $word, $pos, $con_pos_aft, $con_aft) =
|
| 1745 |
+
split(/\Q$sep\E/, $con1) ;
|
| 1746 |
+
|
| 1747 |
+
printf OUT " %-*s | %-*s | %-*s | %-*s | %-*s | %-*s | %3d\n",
|
| 1748 |
+
$max_con_pos_len, $con_pos_bef, $max_con_bef_len+length($con_bef)-uni_len($con_bef), $con_bef,
|
| 1749 |
+
$max_pos_len, $pos, $max_word_spec_len+length($word)-uni_len($word), $word,
|
| 1750 |
+
$max_con_pos_len, $con_pos_aft, $max_con_aft_len+length($con_aft)-uni_len($con_aft), $con_aft,
|
| 1751 |
+
$err_counts{$err}{$loc_con} ;
|
| 1752 |
+
}
|
| 1753 |
+
|
| 1754 |
+
printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-+------\n",
|
| 1755 |
+
'-' x $max_con_pos_len, '-' x $max_con_bef_len,
|
| 1756 |
+
'-' x $max_pos_len, '-' x $max_word_spec_len,
|
| 1757 |
+
'-' x $max_con_pos_len, '-' x $max_con_aft_len ;
|
| 1758 |
+
|
| 1759 |
+
}
|
| 1760 |
+
|
| 1761 |
+
printf OUT "\n\n" ;
|
| 1762 |
+
printf OUT " Local contexts involved in several frequent errors:" ;
|
| 1763 |
+
printf OUT "\n %s\n", '=' x 51 ;
|
| 1764 |
+
printf OUT "\n\n" ;
|
| 1765 |
+
|
| 1766 |
+
foreach $loc_con (sort {scalar keys %{$loc_con_err_counts{$b}} <=>
|
| 1767 |
+
scalar keys %{$loc_con_err_counts{$a}}}
|
| 1768 |
+
keys %loc_con_err_counts)
|
| 1769 |
+
{
|
| 1770 |
+
|
| 1771 |
+
if (scalar keys %{$loc_con_err_counts{$loc_con}} == 1)
|
| 1772 |
+
{
|
| 1773 |
+
next ;
|
| 1774 |
+
}
|
| 1775 |
+
|
| 1776 |
+
printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-\n",
|
| 1777 |
+
'-' x $max_con_pos_len, '-' x $max_con_bef_len,
|
| 1778 |
+
'-' x $max_pos_len, '-' x $max_word_spec_len,
|
| 1779 |
+
'-' x $max_con_pos_len, '-' x $max_con_aft_len ;
|
| 1780 |
+
|
| 1781 |
+
printf OUT " %-*s | %-*s | %-*s \n",
|
| 1782 |
+
$max_con_pos_len+$max_con_bef_len+3, ' Before',
|
| 1783 |
+
$max_word_spec_len+$max_pos_len+3, ' Focus',
|
| 1784 |
+
$max_con_pos_len+$max_con_aft_len+3, ' After' ;
|
| 1785 |
+
|
| 1786 |
+
printf OUT " %-*s %-*s | %-*s %-*s | %-*s %-*s \n",
|
| 1787 |
+
$max_con_pos_len, 'CPOS', $max_con_bef_len, 'word',
|
| 1788 |
+
$max_pos_len, 'CPOS', $max_word_spec_len, 'word',
|
| 1789 |
+
$max_con_pos_len, 'CPOS', $max_con_aft_len, 'word' ;
|
| 1790 |
+
|
| 1791 |
+
printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-\n",
|
| 1792 |
+
'-' x $max_con_pos_len, '-' x $max_con_bef_len,
|
| 1793 |
+
'-' x $max_pos_len, '-' x $max_word_spec_len,
|
| 1794 |
+
'-' x $max_con_pos_len, '-' x $max_con_aft_len ;
|
| 1795 |
+
|
| 1796 |
+
$con1 = $loc_con ;
|
| 1797 |
+
$con1 =~ s/\*/ /g ;
|
| 1798 |
+
|
| 1799 |
+
($con_pos_bef, $con_bef, $word, $pos, $con_pos_aft, $con_aft) =
|
| 1800 |
+
split(/\Q$sep\E/, $con1) ;
|
| 1801 |
+
|
| 1802 |
+
printf OUT " %-*s | %-*s | %-*s | %-*s | %-*s | %-*s \n",
|
| 1803 |
+
$max_con_pos_len, $con_pos_bef, $max_con_bef_len+length($con_bef)-uni_len($con_bef), $con_bef,
|
| 1804 |
+
$max_pos_len, $pos, $max_word_spec_len+length($word)-uni_len($word), $word,
|
| 1805 |
+
$max_con_pos_len, $con_pos_aft, $max_con_aft_len+length($con_aft)-uni_len($con_aft), $con_aft ;
|
| 1806 |
+
|
| 1807 |
+
printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-\n",
|
| 1808 |
+
'-' x $max_con_pos_len, '-' x $max_con_bef_len,
|
| 1809 |
+
'-' x $max_pos_len, '-' x $max_word_spec_len,
|
| 1810 |
+
'-' x $max_con_pos_len, '-' x $max_con_aft_len ;
|
| 1811 |
+
|
| 1812 |
+
foreach $err (sort {$loc_con_err_counts{$loc_con}{$b} <=>
|
| 1813 |
+
$loc_con_err_counts{$loc_con}{$a}}
|
| 1814 |
+
keys %{$loc_con_err_counts{$loc_con}})
|
| 1815 |
+
{
|
| 1816 |
+
printf OUT " %s : %d times\n", $err_desc{$err},
|
| 1817 |
+
$loc_con_err_counts{$loc_con}{$err} ;
|
| 1818 |
+
}
|
| 1819 |
+
|
| 1820 |
+
printf OUT "\n" ;
|
| 1821 |
+
}
|
| 1822 |
+
|
| 1823 |
+
close GOLD ;
|
| 1824 |
+
close SYS ;
|
| 1825 |
+
|
| 1826 |
+
close OUT ;
|
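Note: the `@bits` loop in the second reading pass above is a 6-bit binary counter over the six local-context fields (CPOS and word before, focus CPOS and word, CPOS and word after), so every error occurrence is counted once for each wildcard generalisation of its context, including the all-wildcard case. A minimal Python sketch of the same enumeration (illustrative only; the function name and `sep` argument are not part of the repository):

    from itertools import product

    def wildcard_contexts(fields, sep):
        # fields: the six local-context values; sep: whatever separator the
        # eval script uses as $sep. Yields all 2**6 = 64 generalisations,
        # replacing unselected fields with '*'.
        for bits in product((0, 1), repeat=len(fields)):
            yield sep.join(f if keep else '*' for f, keep in zip(fields, bits))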
examples/macro_UAS_LAS.py
ADDED
@@ -0,0 +1,107 @@
import sys


def load_results(filename):

    results = []
    sent = []
    with open(filename, 'r') as fp:
        for i, line in enumerate(fp):
            if i == 0:
                continue
            splits = line.strip().split('\t')
            if len(line.strip()) == 0:
                if len(sent) != 0:
                    results.append(sent)
                    sent = []
                continue
            gold_head = splits[-4]
            gold_label = splits[-3]
            pred_head = splits[-2]
            pred_label = splits[-1]
            sent.append((gold_head, gold_label, pred_head, pred_label))
    print('Total Number of sentences ' + str(len(results)))
    return results

def calculate_las_uas(gold_heads, gold_labels, pred_heads, pred_labels):

    u_correct = 0
    l_correct = 0
    u_total = 0
    l_total = 0

    for i in range(len(gold_heads)):
        if gold_heads[i] == pred_heads[i]:
            u_correct += 1
        u_total += 1
        l_total += 1
        if gold_heads[i] == pred_heads[i] and gold_labels[i] == pred_labels[i]:
            l_correct += 1
    return u_correct, u_total, l_correct, l_total


def calculate_stats(results, path):
    u_correct = 0
    l_correct = 0
    u_total = 0
    l_total = 0

    sent_uas = []
    sent_las = []

    for i in range(len(results)):
        gold_heads, gold_labels, pred_heads, pred_labels = zip(*results[i])
        u_c, u_t, l_c, l_t = calculate_las_uas(gold_heads, gold_labels, pred_heads, pred_labels)
        if u_t > 0:
            uas = float(u_c)/u_t
            las = float(l_c)/l_t
            sent_uas.append(uas)
            sent_las.append(las)
        u_correct += u_c
        l_correct += l_c
        u_total += u_t
        l_total += l_t

    UAS = float(u_correct)/u_total
    LAS = float(l_correct)/l_total
    path = path.replace('combined_1300_test.txt', 'Macro-UAS-LAS-score.txt')
    f = open(path, 'w')
    f.write('Word level UAS : ' + str(UAS) + '\n')
    f.write('Word level LAS : ' + str(LAS) + '\n')
    f.write('Sentence level UAS : ' + str(float(sum(sent_uas))/len(sent_uas)) + '\n')
    f.write('Sentence level LAS : ' + str(float(sum(sent_las))/len(sent_las)) + '\n')
    f.close()
    print('Word level UAS : ' + str(UAS))
    print('Word level LAS : ' + str(LAS))
    print('Sentence level UAS : ' + str(float(sum(sent_uas))/len(sent_uas)))
    print('Sentence level LAS : ' + str(float(sum(sent_las))/len(sent_las)))

    return sent_uas, sent_las, UAS, LAS

def write_results(sent_uas, sent_las, filename_uas, filename_las):

    fp_uas = open(filename_uas, 'w')
    fp_las = open(filename_las, 'w')

    for i in range(len(sent_uas)):
        fp_uas.write(str(sent_uas[i]) + '\n')
        fp_las.write(str(sent_las[i]) + '\n')

    fp_uas.close()
    fp_las.close()


if __name__ == "__main__":
    dirs = sys.argv[1]
    # results_2 = load_results(sys.argv[2])
    ##path = "Predictions/Yap/"+dirs
    path = "./saved_models/"+dirs+"/final_ensembled/combined_1300_test.txt"
    result = load_results(path)

    sent_uas1, sent_las1, UAS1, LAS1 = calculate_stats(result, path)
    # sent_uas2, sent_las2, UAS2, LAS2 = calculate_stats(results_2)

    write_results(sent_uas1, sent_las1, 'results1_uas.txt', 'results1_las.txt')
    # write_results(sent_uas2, sent_las2, 'results2_uas.txt', 'results2_las.txt')
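`macro_UAS_LAS.py` is run as `python examples/macro_UAS_LAS.py <model_dir>`; it reads `./saved_models/<model_dir>/final_ensembled/combined_1300_test.txt`, writes `Macro-UAS-LAS-score.txt` next to it, and dumps per-sentence scores to `results1_uas.txt`/`results1_las.txt` in the working directory. It reports both word-level (micro-averaged) and sentence-level (macro-averaged) UAS/LAS, which can differ when sentence lengths vary. A small illustrative check with made-up counts:

    # Sentence 1: 10 tokens, 9 correct heads; sentence 2: 2 tokens, 1 correct head.
    word_level_uas = (9 + 1) / (10 + 2)        # ~0.833, micro average over tokens
    sentence_level_uas = (9/10 + 1/2) / 2      # 0.70, macro average over sentences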
examples/write_1300_combined.py
ADDED
@@ -0,0 +1,48 @@
import sys

def write_combined(dirs):
    path = "./saved_models/"+dirs+"/final_ensembled/"
    f = open(path+'domain_san_test_model_domain_san_data_domain_san_gold.txt', 'r')
    gold = f.readlines()
    f.close()
    f = open(path+'domain_san_test_model_domain_san_data_domain_san_pred.txt', 'r')
    pred = f.readlines()
    f.close()

    for i in range(len(gold)):
        if gold[i] == '\n':
            continue
        if gold[i].split('\t')[0] == pred[i].split('\t')[0]:
            gold[i] = gold[i].replace('\n', '\t')
            gold[i] = gold[i]+'\t'.join(pred[i].split('\t')[-2:])

    f = open(path+'domain_san_prose_model_domain_san_data_domain_san_gold.txt', 'r')
    prose_gold = f.readlines()
    f.close()
    f = open(path+'domain_san_prose_model_domain_san_data_domain_san_pred.txt', 'r')
    prose_pred = f.readlines()
    f.close()

    for i in range(len(prose_gold)):
        if prose_gold[i] == '\n':
            gold.append('\n')
            continue
        if prose_gold[i].split('\t')[0] == prose_pred[i].split('\t')[0]:
            line = prose_gold[i].replace('\n', '\t')
            line = line+'\t'.join(prose_pred[i].split('\t')[-2:])
            gold.append(line)
    gold.insert(0, 'word_id\tword\tpostag\tlemma\tgold_head\tgold_label\tpred_head\tpred_label\n\n')

    f = open(path+'combined_1300_test.txt', 'w')
    for line in gold:
        f.write(line)
    f.close()


if __name__ == "__main__":

    dir_path = sys.argv[1]

    write_combined(dir_path)
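`write_1300_combined.py` merges the test and prose gold/prediction files produced under `saved_models/<model_dir>/final_ensembled/` into a single tab-separated `combined_1300_test.txt`: every gold token line is extended with the predicted head and label, sentences are separated by blank lines, and the header row lists `word_id word postag lemma gold_head gold_label pred_head pred_label`. An illustrative token row (tab-separated; the values are invented, not taken from the data):

    1	rāmaḥ	NOUN	rāma	2	karta	2	karta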
run_STBC.sh
ADDED
@@ -0,0 +1,75 @@
domain="san"
# Path of pretrained embedding file
word_path=data/cc.STBC.300.txt
# Path to store model and predictions
saved_models=saved_models
declare -i num_epochs=100
declare -i word_dim=300
start_time=`date +%s`
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
model_path="STBC_"$current_time
touch $saved_models/base_log.txt

###############################################################
# Running the base Biaffine Parser
echo "#################################################################"
echo "Currently base model in progress..."
echo "#################################################################"
python examples/GraphParser_MTL_POS.py --dataset ud --domain $domain --rnn_mode LSTM \
    --num_epochs $num_epochs --batch_size 16 --hidden_size 512 --arc_space 512 \
    --arc_tag_space 128 --num_layers 2 --num_filters 100 --use_char --use_pos \
    --word_dim $word_dim --char_dim 100 --pos_dim 100 --initializer xavier --opt adam \
    --learning_rate 0.002 --decay_rate 0.5 --schedule 6 --clip 5.0 --gamma 0.0 \
    --epsilon 1e-6 --p_rnn 0.33 0.33 --p_in 0.33 --p_out 0.33 --arc_decode mst \
    --punct_set '.' '``' ':' ',' --word_embedding fasttext --char_embedding random --pos_embedding random --word_path $word_path \
    --model_path $saved_models/$model_path 2>&1 | tee $saved_models/base_log.txt

mv $saved_models/base_log.txt $saved_models/$model_path/base_log.txt
python examples/BiAFF_write_1300_combined.py $model_path
python examples/BiAFF_macro_UAS_LAS.py $model_path

# ###################################################################
# Pretraining step: Running the Sequence Tagger
# Auxiliary tasks : 'Multitask_case_predict' 'Multitask_POS_predict' 'Multitask_label_predict'
for task in 'Multitask_POS_predict' 'Multitask_label_predict' 'Multitask_case_predict'; do
    touch $saved_models/$model_path/log.txt
    echo "#################################################################"
    echo "Currently $task in progress..."
    echo "#################################################################"
    python examples/SequenceTagger.py --dataset ud --domain $domain --task $task \
        --rnn_mode LSTM --num_epochs $num_epochs --batch_size 16 --hidden_size 512 \
        --tag_space 128 --num_layers 2 --num_filters 100 --use_char --use_pos --char_dim 100 \
        --pos_dim 100 --initializer xavier --opt adam --learning_rate 0.002 --decay_rate 0.5 \
        --schedule 6 --clip 5.0 --gamma 0.0 --epsilon 1e-6 --p_rnn 0.33 0.33 \
        --p_in 0.33 --p_out 0.33 --punct_set '.' '``' ':' ',' \
        --word_dim $word_dim --word_embedding fasttext --word_path $word_path --pos_embedding random \
        --parser_path $saved_models/$model_path/ \
        --use_unlabeled_data --char_embedding random \
        --model_path $saved_models/$model_path/$task/ 2>&1 | tee $saved_models/$model_path/log.txt
    mv $saved_models/$model_path/log.txt $saved_models/$model_path/$task/log.txt
done

######################################################################
## Integration step: The final ensembled proposed system
echo "#################################################################"
echo "Currently final model in progress..."
echo "#################################################################"
touch $saved_models/$model_path/log.txt
python examples/GraphParser_MTL_POS.py --dataset ud --domain $domain --rnn_mode LSTM \
    --num_epochs $num_epochs --batch_size 16 --hidden_size 512 \
    --arc_space 512 --arc_tag_space 128 --num_layers 2 --num_filters 100 --use_char --use_pos \
    --word_dim $word_dim --char_dim 100 --pos_dim 100 --initializer xavier --opt adam \
    --learning_rate 0.002 --decay_rate 0.5 --schedule 6 --clip 5.0 --gamma 0.0 --epsilon 1e-6 \
    --p_rnn 0.33 0.33 --p_in 0.33 --p_out 0.33 --arc_decode mst --pos_embedding random \
    --punct_set '.' '``' ':' ',' --word_embedding fasttext --char_embedding random --word_path $word_path \
    --gating --num_gates 4 \
    --load_sequence_taggers_paths $saved_models/$model_path/Multitask_case_predict/domain_$domain.pt \
    $saved_models/$model_path/Multitask_POS_predict/domain_$domain.pt \
    $saved_models/$model_path/Multitask_label_predict/domain_$domain.pt \
    --model_path $saved_models/$model_path/final_ensembled 2>&1 | tee $saved_models/$model_path/log.txt
mv $saved_models/$model_path/log.txt $saved_models/$model_path/final_ensembled/log.txt
python examples/write_1300_combined.py $model_path
python examples/macro_UAS_LAS.py $model_path
end_time=`date +%s`
echo execution time was `expr $end_time - $start_time` s.
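Usage note: `run_STBC.sh` uses repository-relative paths (`data/`, `examples/`, `saved_models/`), so it is meant to be launched from the repository root, e.g. `bash run_STBC.sh`. The `saved_models/` directory is git-ignored, so create it with `mkdir -p saved_models` before the first run if it does not exist. The base parser, the three sequence-tagger pretraining runs, and the final gated ensemble each write their logs and predictions under `saved_models/STBC_<timestamp>/`.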