shivrajanand committed on
Commit a7b3936 · verified · 1 Parent(s): 2139f37

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitignore +1 -0
  2. ReadMe.md +47 -0
  3. data/Multitask_case_dev_VST +0 -0
  4. data/Multitask_case_dev_san +0 -0
  5. data/Multitask_case_poetry_san +0 -0
  6. data/Multitask_case_prose_san +0 -0
  7. data/Multitask_case_test_VST +0 -0
  8. data/Multitask_case_test_san +0 -0
  9. data/Multitask_case_train_VST +0 -0
  10. data/Multitask_case_train_san +0 -0
  11. data/Multitask_label_dev_VST +0 -0
  12. data/Multitask_label_dev_san +0 -0
  13. data/Multitask_label_poetry_san +0 -0
  14. data/Multitask_label_prose_san +0 -0
  15. data/Multitask_label_test_VST +0 -0
  16. data/Multitask_label_test_san +0 -0
  17. data/Multitask_label_train_VST +0 -0
  18. data/Multitask_label_train_san +0 -0
  19. data/Multitask_morph_dev_VST +0 -0
  20. data/Multitask_morph_dev_san +0 -0
  21. data/Multitask_morph_poetry_san +0 -0
  22. data/Multitask_morph_prose_san +0 -0
  23. data/Multitask_morph_test_VST +0 -0
  24. data/Multitask_morph_test_san +0 -0
  25. data/Multitask_morph_train_VST +0 -0
  26. data/Multitask_morph_train_san +0 -0
  27. data/combined_1300_test.txt +0 -0
  28. data/ud_pos_ner_dp_dev_VST +0 -0
  29. data/ud_pos_ner_dp_dev_san +0 -0
  30. data/ud_pos_ner_dp_poetry_VST +0 -0
  31. data/ud_pos_ner_dp_poetry_san +0 -0
  32. data/ud_pos_ner_dp_prose_VST +0 -0
  33. data/ud_pos_ner_dp_prose_san +0 -0
  34. data/ud_pos_ner_dp_test_VST +0 -0
  35. data/ud_pos_ner_dp_test_san +0 -0
  36. data/ud_pos_ner_dp_train_VST +0 -0
  37. data/ud_pos_ner_dp_train_san +0 -0
  38. data/ud_pos_ner_dp_train_san_org +0 -0
  39. examples/BiAFF_macro_UAS_LAS.py +108 -0
  40. examples/BiAFF_write_1300_combined.py +48 -0
  41. examples/GraphParser.py +599 -0
  42. examples/GraphParser_MTL_POS.py +633 -0
  43. examples/SequenceTagger.py +589 -0
  44. examples/VST_Pred_Prepare.py +34 -0
  45. examples/VST_macro_score.py +107 -0
  46. examples/eval/conll03eval.v2 +336 -0
  47. examples/eval/conll06eval.pl +1826 -0
  48. examples/macro_UAS_LAS.py +107 -0
  49. examples/write_1300_combined.py +48 -0
  50. run_STBC.sh +75 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ ./saved_models
ReadMe.md ADDED
@@ -0,0 +1,47 @@
+ Official code for the paper ["Systematic Investigation of Strategies Tailored for Low-Resource Settings for Low-Resource Dependency Parsing"](https://arxiv.org/abs/2201.11374).
+ If you use this code, please cite our paper.
+
+ ## Requirements
+
+ * Python 3.7
+ * PyTorch 1.1.0
+ * CUDA 9.0
+ * Gensim 3.8.1
+
+ We assume that you have installed conda beforehand.
+
+ ```
+ conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=9.0 -c pytorch
+ pip install gensim==3.8.1
+ ```
+
+ ## Pretrained embeddings for Sanskrit
+ * Pretrained FastText embeddings for STBC/VST can be obtained from [here](https://drive.google.com/drive/folders/1SwdEqikTq-N2vOL7QSUX2vqi3faZE7bq?usp=sharing). Make sure that the `.txt` file is placed in `data/`.
+ * The main results are reported on systems trained by combining the train and dev splits.
+
+
+ ## How to train the model for Sanskrit
+ The proposed system is run in two steps: (1) pretraining and (2) integration. To launch both, simply run the bash script `run_STBC.sh` or `run_VST.sh` for the respective dataset. These scripts reproduce the results reported in Section 3 and Table 2.
+
+ ```bash
+ bash run_STBC.sh
+ ```
+
+ ## Citations
+ ```
+ @misc{sandhan_systematic,
+   doi = {10.48550/ARXIV.2201.11374},
+   url = {https://arxiv.org/abs/2201.11374},
+   author = {Sandhan, Jivnesh and Behera, Laxmidhar and Goyal, Pawan},
+   keywords = {Computation and Language (cs.CL), FOS: Computer and information sciences, FOS: Computer and information sciences},
+   title = {Systematic Investigation of Strategies Tailored for Low-Resource Settings for Low-Resource Dependency Parsing},
+   publisher = {arXiv},
+   year = {2022},
+   copyright = {Creative Commons Attribution 4.0 International}
+ }
+ ```
+
+ ## Acknowledgements
+ Our ensemble system is built on top of the ["DCST Implementation"](https://github.com/rotmanguy/DCST).
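For the embeddings step in the ReadMe above, a quick way to sanity-check the downloaded FastText file against the pinned Gensim 3.8.1 is sketched below. It assumes the `.txt` file is in the standard word2vec text format (first line: vocabulary size and dimension); the file name `data/fasttext_sanskrit_300.txt` is only a placeholder for whatever the Drive folder actually contains.

```python
# Minimal sanity check (Gensim 3.8.x API); the file name below is a placeholder.
from gensim.models import KeyedVectors

vectors = KeyedVectors.load_word2vec_format('data/fasttext_sanskrit_300.txt', binary=False)
print(vectors.vector_size)   # embedding dimension, expected to match --word_dim (300 by default)
print(len(vectors.vocab))    # vocabulary size
```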
data/Multitask_case_dev_VST ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_case_dev_san ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_case_poetry_san ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_case_prose_san ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_case_test_VST ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_case_test_san ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_case_train_VST ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_case_train_san ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_label_dev_VST ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_label_dev_san ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_label_poetry_san ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_label_prose_san ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_label_test_VST ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_label_test_san ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_label_train_VST ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_label_train_san ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_morph_dev_VST ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_morph_dev_san ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_morph_poetry_san ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_morph_prose_san ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_morph_test_VST ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_morph_test_san ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_morph_train_VST ADDED
The diff for this file is too large to render. See raw diff
 
data/Multitask_morph_train_san ADDED
The diff for this file is too large to render. See raw diff
 
data/combined_1300_test.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_dev_VST ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_dev_san ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_poetry_VST ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_poetry_san ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_prose_VST ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_prose_san ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_test_VST ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_test_san ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_train_VST ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_train_san ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_train_san_org ADDED
The diff for this file is too large to render. See raw diff
 
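The `data/` files added above follow two naming patterns: `ud_pos_ner_dp_<split>_<domain>` for the dependency-parsing splits and `Multitask_<task>_<split>_<domain>` for the auxiliary case/label/morph data, with domains `san` and `VST`. A small sketch spelling out the scheme (illustration only; the names it produces are the ones listed above):

```python
# Enumerate the expected data/ file names from the two naming patterns above.
from itertools import product

splits = ['train', 'dev', 'test']
domains = ['san', 'VST']

parsing = ['data/ud_pos_ner_dp_%s_%s' % (s, d) for s, d in product(splits, domains)]
auxiliary = ['data/Multitask_%s_%s_%s' % (t, s, d)
             for t, s, d in product(['case', 'label', 'morph'], splits, domains)]

# Both domains also ship 'poetry' and 'prose' parsing splits (e.g. ud_pos_ner_dp_poetry_VST);
# the Multitask poetry/prose files are present for 'san' only.
print(parsing[0])      # data/ud_pos_ner_dp_train_san
print(auxiliary[0])    # data/Multitask_case_train_san
```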
examples/BiAFF_macro_UAS_LAS.py ADDED
@@ -0,0 +1,108 @@
+ import sys
+
+
+ def load_results(filename):
+
+     results = []
+     sent = []
+     with open(filename, 'r') as fp:
+         for i, line in enumerate(fp):
+             if i == 0:
+                 continue
+             splits = line.strip().split('\t')
+             if len(line.strip()) == 0:
+                 if len(sent) != 0:
+                     results.append(sent)
+                     sent = []
+                 continue
+             gold_head = splits[-4]
+             gold_label = splits[-3]
+             pred_head = splits[-2]
+             pred_label = splits[-1]
+             sent.append((gold_head, gold_label, pred_head, pred_label))
+     print('Total Number of sentences ' + str(len(results)))
+     return results
+
+ def calculate_las_uas(gold_heads, gold_labels, pred_heads, pred_labels):
+
+     u_correct = 0
+     l_correct = 0
+     u_total = 0
+     l_total = 0
+
+     for i in range(len(gold_heads)):
+         if gold_heads[i] == pred_heads[i]:
+             u_correct +=1
+         u_total +=1
+         l_total +=1
+         if gold_heads[i] == pred_heads[i] and gold_labels[i] == pred_labels[i]:
+             l_correct +=1
+     return u_correct, u_total, l_correct, l_total
+
+
+ def calculate_stats(results, path):
+     u_correct = 0
+     l_correct = 0
+     u_total = 0
+     l_total = 0
+
+     sent_uas = []
+     sent_las = []
+
+     for i in range(len(results)):
+         gold_heads, gold_labels, pred_heads, pred_labels = zip(*results[i])
+         u_c, u_t, l_c, l_t = calculate_las_uas(gold_heads, gold_labels, pred_heads, pred_labels)
+         if u_t > 0:
+             uas = float(u_c)/u_t
+             las = float(l_c)/l_t
+             sent_uas.append(uas)
+             sent_las.append(las)
+         u_correct += u_c
+         l_correct += l_c
+         u_total += u_t
+         l_total += l_t
+
+     UAS = float(u_correct)/u_total
+     LAS = float(l_correct)/l_total
+     path = path.replace('combined_1300_test.txt', 'Macro-UAS-LAS-score.txt')
+     f = open(path, 'w')
+     f.write('Word level UAS : ' + str(UAS) + '\n')
+     f.write('Word level LAS : ' + str(LAS) + '\n')
+     f.write('Sentence level UAS : ' + str(float(sum(sent_uas))/len(sent_uas)) + '\n')
+     f.write('Sentence level LAS : ' + str(float(sum(sent_las))/len(sent_las)) + '\n')
+     f.close()
+     print('Word level UAS : ' + str(UAS))
+     print('Word level LAS : ' + str(LAS))
+     print('Sentence level UAS : ' + str(float(sum(sent_uas))/len(sent_uas)))
+     print('Sentence level LAS : ' + str(float(sum(sent_las))/len(sent_las)))
+
+     return sent_uas, sent_las, UAS, LAS
+
+ def write_results(sent_uas, sent_las, filename_uas, filename_las):
+
+     fp_uas = open(filename_uas, 'w')
+     fp_las = open(filename_las, 'w')
+
+     for i in range(len(sent_uas)):
+         fp_uas.write(str(sent_uas[i]) + '\n')
+         fp_las.write(str(sent_las[i]) + '\n')
+
+     fp_uas.close()
+     fp_las.close()
+
+
+ if __name__=="__main__":
+     dirs = sys.argv[1]
+     # results_2 = load_results(sys.argv[2])
+     ##path = "Predictions/Yap/"+dirs
+     # path = "/home/jivnesh/Documents/San-SOTA/saved_models/"+dirs+"/final_ensembled/dev_combined_1000.txt"
+     path = "./saved_models/"+dirs+"/combined_1300_test.txt"
+     result = load_results(path)
+
+     sent_uas1, sent_las1, UAS1, LAS1 = calculate_stats(result, path)
+     # sent_uas2, sent_las2, UAS2, LAS2 = calculate_stats(results_2)
+
+     write_results(sent_uas1, sent_las1, 'results1_uas.txt', 'results1_las.txt')
+     # write_results(sent_uas2, sent_las2, 'results2_uas.txt', 'results2_las.txt')
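For orientation: this script takes a single argument, the model directory name under `./saved_models/`, reads the `combined_1300_test.txt` written by `BiAFF_write_1300_combined.py`, and writes word-level (micro) and sentence-level (macro) UAS/LAS to `Macro-UAS-LAS-score.txt` in the same directory. The sketch below shows the per-token arithmetic behind `calculate_las_uas` on toy values; the directory name in the comment is hypothetical.

```python
# Typical invocation (hypothetical directory name):
#   python examples/BiAFF_macro_UAS_LAS.py san_biaff
#
# Toy illustration of the UAS/LAS counts: UAS credits a token when its predicted
# head matches the gold head; LAS additionally requires the label to match.
gold_heads,  pred_heads  = ['2', '0', '2'], ['2', '0', '3']
gold_labels, pred_labels = ['karta', 'root', 'karma'], ['karma', 'root', 'karma']

uas = sum(g == p for g, p in zip(gold_heads, pred_heads)) / len(gold_heads)
las = sum(gh == ph and gl == pl for gh, gl, ph, pl in
          zip(gold_heads, gold_labels, pred_heads, pred_labels)) / len(gold_heads)
print(uas, las)   # 0.666..., 0.333...
```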
examples/BiAFF_write_1300_combined.py ADDED
@@ -0,0 +1,48 @@
+ import sys
+
+ def write_combined(dirs):
+     path = "./saved_models/"+dirs+'/'
+     f = open(path+'domain_san_test_model_domain_san_data_domain_san_gold.txt','r')
+     gold = f.readlines()
+     f.close()
+     f = open(path+'domain_san_test_model_domain_san_data_domain_san_pred.txt','r')
+     pred = f.readlines()
+     f.close()
+
+     for i in range(len(gold)):
+         if gold[i] == '\n':
+             continue
+         if gold[i].split('\t')[0] == pred[i].split('\t')[0]:
+             gold[i] = gold[i].replace('\n','\t')
+             gold[i] = gold[i]+'\t'.join(pred[i].split('\t')[-2:])
+
+     f = open(path+'domain_san_prose_model_domain_san_data_domain_san_gold.txt','r')
+     prose_gold = f.readlines()
+     f.close()
+     f = open(path+'domain_san_prose_model_domain_san_data_domain_san_pred.txt','r')
+     prose_pred = f.readlines()
+     f.close()
+
+     for i in range(len(prose_gold)):
+         if prose_gold[i] == '\n':
+             gold.append('\n')
+             continue
+         if prose_gold[i].split('\t')[0] == prose_pred[i].split('\t')[0]:
+             line = prose_gold[i].replace('\n','\t')
+             line = line+'\t'.join(prose_pred[i].split('\t')[-2:])
+             gold.append(line)
+     gold.insert(0,'word_id\tword\tpostag\tlemma\tgold_head\tgold_label\tpred_head\tpred_label\n\n')
+
+     f = open(path+'combined_1300_test.txt','w')
+     for line in gold:
+         f.write(line)
+     f.close()
+
+
+ if __name__=="__main__":
+
+     dir_path = sys.argv[1]
+
+     write_combined(dir_path)
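As a pointer for reproducing the evaluation pipeline: the four gold/pred files this script expects are the ones `GraphParser.py` writes for the Sanskrit (`san`) test and prose splits, and the combined TSV it emits is exactly what `BiAFF_macro_UAS_LAS.py` consumes. Below is a small sketch of the expected layout; the model directory name is a placeholder, while the file names are taken verbatim from the script above.

```python
# Placeholder model directory; the file names below are hard-coded in the script above.
model_dir = './saved_models/san_biaff/'

inputs = [
    'domain_san_test_model_domain_san_data_domain_san_gold.txt',
    'domain_san_test_model_domain_san_data_domain_san_pred.txt',
    'domain_san_prose_model_domain_san_data_domain_san_gold.txt',
    'domain_san_prose_model_domain_san_data_domain_san_pred.txt',
]

# Output: a single TSV whose header row is
#   word_id  word  postag  lemma  gold_head  gold_label  pred_head  pred_label
output = model_dir + 'combined_1300_test.txt'
```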
examples/GraphParser.py ADDED
@@ -0,0 +1,599 @@
1
+ from __future__ import print_function
2
+ import sys
3
+ from os import path, makedirs
4
+
5
+ sys.path.append(".")
6
+ sys.path.append("..")
7
+
8
+ import argparse
9
+ from copy import deepcopy
10
+ import json
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from collections import namedtuple
15
+ from utils.io_ import seeds, Writer, get_logger, prepare_data, rearrange_splits
16
+ from utils.models.parsing_gating import BiAffine_Parser_Gated
17
+ from utils import load_word_embeddings
18
+ from utils.tasks import parse
19
+ import time
20
+ from torch.nn.utils import clip_grad_norm_
21
+ from torch.optim import Adam, SGD
22
+ import uuid
23
+
24
+ uid = uuid.uuid4().hex[:6]
25
+
26
+ logger = get_logger('GraphParser')
27
+
28
+ def read_arguments():
29
+ args_ = argparse.ArgumentParser(description='Solving GraphParser')
30
+ args_.add_argument('--dataset', choices=['ontonotes', 'ud'], help='Dataset', required=True)
31
+ args_.add_argument('--domain', help='domain/language', required=True)
32
+ args_.add_argument('--rnn_mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn',
33
+ required=True)
34
+ args_.add_argument('--gating',action='store_true', help='use gated mechanism')
35
+ args_.add_argument('--num_gates', type=int, default=0, help='number of gates for gating mechanism')
36
+ args_.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
37
+ args_.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
38
+ args_.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
39
+ args_.add_argument('--arc_space', type=int, default=128, help='Dimension of tag space')
40
+ args_.add_argument('--arc_tag_space', type=int, default=128, help='Dimension of tag space')
41
+ args_.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
42
+ args_.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
43
+ args_.add_argument('--kernel_size', type=int, default=3, help='Size of Kernel for CNN')
44
+ args_.add_argument('--use_pos', action='store_true', help='use part-of-speech embedding.')
45
+ args_.add_argument('--use_char', action='store_true', help='use character embedding and CNN.')
46
+ args_.add_argument('--word_dim', type=int, default=300, help='Dimension of word embeddings')
47
+ args_.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
48
+ args_.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
49
+ args_.add_argument('--initializer', choices=['xavier'], help='initialize model parameters')
50
+ args_.add_argument('--opt', choices=['adam', 'sgd'], help='optimization algorithm')
51
+ args_.add_argument('--momentum', type=float, default=0.9, help='momentum of optimizer')
52
+ args_.add_argument('--betas', nargs=2, type=float, default=[0.9, 0.9], help='betas of optimizer')
53
+ args_.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
54
+ args_.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
55
+ args_.add_argument('--schedule', type=int, help='schedule for learning rate decay')
56
+ args_.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
57
+ args_.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
58
+ args_.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for adam')
59
+ args_.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
60
+ args_.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
61
+ args_.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
62
+ args_.add_argument('--arc_decode', choices=['mst', 'greedy'], help='arc decoding algorithm', required=True)
63
+ args_.add_argument('--unk_replace', type=float, default=0.,
64
+ help='The rate to replace a singleton word with UNK')
65
+ args_.add_argument('--punct_set', nargs='+', type=str, help='List of punctuations')
66
+ args_.add_argument('--word_embedding', choices=['random', 'glove', 'fasttext', 'word2vec'],
67
+ help='Embedding for words')
68
+ args_.add_argument('--word_path', help='path for word embedding dict - in case word_embedding is not random')
69
+ args_.add_argument('--freeze_word_embeddings', action='store_true', help='freeze the word embeddings (disable fine-tuning).')
70
+ args_.add_argument('--freeze_sequence_taggers', action='store_true', help='freeze the BiLSTMs of the pre-trained taggers.')
71
+ args_.add_argument('--char_embedding', choices=['random','hellwig'], help='Embedding for characters',
72
+ required=True)
73
+ args_.add_argument('--pos_embedding', choices=['random','one_hot'], help='Embedding for pos',
74
+ required=True)
75
+ args_.add_argument('--char_path', help='path for character embedding dict')
76
+ args_.add_argument('--pos_path', help='path for pos embedding dict')
77
+ args_.add_argument('--set_num_training_samples', type=int, help='downsampling training set to a fixed number of samples')
78
+ args_.add_argument('--model_path', help='path for saving model file.', required=True)
79
+ args_.add_argument('--load_path', help='path for loading saved source model file.', default=None)
80
+ args_.add_argument('--load_sequence_taggers_paths', nargs='+', help='path for loading saved sequence_tagger saved_models files.', default=None)
81
+ args_.add_argument('--strict',action='store_true', help='if True loaded model state should contain '
82
+ 'exactly the same keys as current model')
83
+ args_.add_argument('--eval_mode', action='store_true', help='evaluating model without training it')
84
+ args = args_.parse_args()
85
+ args_dict = {}
86
+ args_dict['dataset'] = args.dataset
87
+ args_dict['domain'] = args.domain
88
+ args_dict['rnn_mode'] = args.rnn_mode
89
+ args_dict['gating'] = args.gating
90
+ args_dict['num_gates'] = args.num_gates
91
+ args_dict['arc_decode'] = args.arc_decode
92
+ # args_dict['splits'] = ['train', 'dev', 'test']
93
+ args_dict['splits'] = ['train', 'dev', 'test','poetry','prose']
94
+ args_dict['model_path'] = args.model_path
95
+ if not path.exists(args_dict['model_path']):
96
+ makedirs(args_dict['model_path'])
97
+ args_dict['data_paths'] = {}
98
+ if args_dict['dataset'] == 'ontonotes':
99
+ data_path = 'data/onto_pos_ner_dp'
100
+ else:
101
+ data_path = 'data/ud_pos_ner_dp'
102
+ for split in args_dict['splits']:
103
+ args_dict['data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain']
104
+ ###################################
105
+ args_dict['data_paths']['poetry'] = 'data/ud_pos_ner_dp' + '_' + 'poetry' + '_' + args_dict['domain']
106
+ args_dict['data_paths']['prose'] = 'data/ud_pos_ner_dp' + '_' + 'prose' + '_' + args_dict['domain']
107
+ ###################################
108
+ args_dict['alphabet_data_paths'] = {}
109
+ for split in args_dict['splits']:
110
+ if args_dict['dataset'] == 'ontonotes':
111
+ args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + 'all'
112
+ else:
113
+ if '_' in args_dict['domain']:
114
+ args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain'].split('_')[0]
115
+ else:
116
+ args_dict['alphabet_data_paths'][split] = args_dict['data_paths'][split]
117
+ args_dict['model_name'] = 'domain_' + args_dict['domain']
118
+ args_dict['full_model_name'] = path.join(args_dict['model_path'],args_dict['model_name'])
119
+ args_dict['load_path'] = args.load_path
120
+ args_dict['load_sequence_taggers_paths'] = args.load_sequence_taggers_paths
121
+ if args_dict['load_sequence_taggers_paths'] is not None:
122
+ args_dict['gating'] = True
123
+ args_dict['num_gates'] = len(args_dict['load_sequence_taggers_paths']) + 1
124
+ else:
125
+ if not args_dict['gating']:
126
+ args_dict['num_gates'] = 0
127
+ args_dict['strict'] = args.strict
128
+ args_dict['num_epochs'] = args.num_epochs
129
+ args_dict['batch_size'] = args.batch_size
130
+ args_dict['hidden_size'] = args.hidden_size
131
+ args_dict['arc_space'] = args.arc_space
132
+ args_dict['arc_tag_space'] = args.arc_tag_space
133
+ args_dict['num_layers'] = args.num_layers
134
+ args_dict['num_filters'] = args.num_filters
135
+ args_dict['kernel_size'] = args.kernel_size
136
+ args_dict['learning_rate'] = args.learning_rate
137
+ args_dict['initializer'] = nn.init.xavier_uniform_ if args.initializer == 'xavier' else None
138
+ args_dict['opt'] = args.opt
139
+ args_dict['momentum'] = args.momentum
140
+ args_dict['betas'] = tuple(args.betas)
141
+ args_dict['epsilon'] = args.epsilon
142
+ args_dict['decay_rate'] = args.decay_rate
143
+ args_dict['clip'] = args.clip
144
+ args_dict['gamma'] = args.gamma
145
+ args_dict['schedule'] = args.schedule
146
+ args_dict['p_rnn'] = tuple(args.p_rnn)
147
+ args_dict['p_in'] = args.p_in
148
+ args_dict['p_out'] = args.p_out
149
+ args_dict['unk_replace'] = args.unk_replace
150
+ args_dict['set_num_training_samples'] = args.set_num_training_samples
151
+ args_dict['punct_set'] = None
152
+ if args.punct_set is not None:
153
+ args_dict['punct_set'] = set(args.punct_set)
154
+ logger.info("punctuations(%d): %s" % (len(args_dict['punct_set']), ' '.join(args_dict['punct_set'])))
155
+ args_dict['freeze_word_embeddings'] = args.freeze_word_embeddings
156
+ args_dict['freeze_sequence_taggers'] = args.freeze_sequence_taggers
157
+ args_dict['word_embedding'] = args.word_embedding
158
+ args_dict['word_path'] = args.word_path
159
+ args_dict['use_char'] = args.use_char
160
+ args_dict['char_embedding'] = args.char_embedding
161
+ args_dict['char_path'] = args.char_path
162
+ args_dict['pos_embedding'] = args.pos_embedding
163
+ args_dict['pos_path'] = args.pos_path
164
+ args_dict['use_pos'] = args.use_pos
165
+ args_dict['pos_dim'] = args.pos_dim
166
+ args_dict['word_dict'] = None
167
+ args_dict['word_dim'] = args.word_dim
168
+ if args_dict['word_embedding'] != 'random' and args_dict['word_path']:
169
+ args_dict['word_dict'], args_dict['word_dim'] = load_word_embeddings.load_embedding_dict(args_dict['word_embedding'],
170
+ args_dict['word_path'])
171
+ args_dict['char_dict'] = None
172
+ args_dict['char_dim'] = args.char_dim
173
+ if args_dict['char_embedding'] != 'random':
174
+ args_dict['char_dict'], args_dict['char_dim'] = load_word_embeddings.load_embedding_dict(args_dict['char_embedding'],
175
+ args_dict['char_path'])
176
+ args_dict['pos_dict'] = None
177
+ if args_dict['pos_embedding'] != 'random':
178
+ args_dict['pos_dict'], args_dict['pos_dim'] = load_word_embeddings.load_embedding_dict(args_dict['pos_embedding'],
179
+ args_dict['pos_path'])
180
+ args_dict['alphabet_path'] = path.join(args_dict['model_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/')
181
+ args_dict['model_name'] = path.join(args_dict['model_path'], args_dict['model_name'])
182
+ args_dict['eval_mode'] = args.eval_mode
183
+ args_dict['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
184
+ args_dict['word_status'] = 'frozen' if args.freeze_word_embeddings else 'fine tune'
185
+ args_dict['char_status'] = 'enabled' if args.use_char else 'disabled'
186
+ args_dict['pos_status'] = 'enabled' if args.use_pos else 'disabled'
187
+ logger.info("Saving arguments to file")
188
+ save_args(args, args_dict['full_model_name'])
189
+ logger.info("Creating Alphabets")
190
+ alphabet_dict = creating_alphabets(args_dict['alphabet_path'], args_dict['alphabet_data_paths'], args_dict['word_dict'])
191
+ args_dict = {**args_dict, **alphabet_dict}
192
+ ARGS = namedtuple('ARGS', args_dict.keys())
193
+ my_args = ARGS(**args_dict)
194
+ return my_args
195
+
196
+
197
+ def creating_alphabets(alphabet_path, alphabet_data_paths, word_dict):
198
+ train_paths = alphabet_data_paths['train']
199
+ extra_paths = [v for k,v in alphabet_data_paths.items() if k != 'train']
200
+ alphabet_dict = {}
201
+ alphabet_dict['alphabets'] = prepare_data.create_alphabets(alphabet_path,
202
+ train_paths,
203
+ extra_paths=extra_paths,
204
+ max_vocabulary_size=100000,
205
+ embedd_dict=word_dict)
206
+ for k, v in alphabet_dict['alphabets'].items():
207
+ num_key = 'num_' + k.split('_')[0]
208
+ alphabet_dict[num_key] = v.size()
209
+ logger.info("%s : %d" % (num_key, alphabet_dict[num_key]))
210
+ return alphabet_dict
211
+
212
+ def construct_embedding_table(alphabet, tokens_dict, dim, token_type='word'):
213
+ if tokens_dict is None:
214
+ return None
215
+ scale = np.sqrt(3.0 / dim)
216
+ table = np.empty([alphabet.size(), dim], dtype=np.float32)
217
+ table[prepare_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
218
+ oov_tokens = 0
219
+ for token, index in alphabet.items():
220
+ if token in tokens_dict:
221
+ embedding = tokens_dict[token]
222
+ elif token.lower() in tokens_dict:
223
+ embedding = tokens_dict[token.lower()]
224
+ else:
225
+ embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
226
+ oov_tokens += 1
227
+ table[index, :] = embedding
228
+ print('token type : %s, number of oov: %d' % (token_type, oov_tokens))
229
+ table = torch.from_numpy(table)
230
+ return table
231
+
232
+ def save_args(args, full_model_name):
233
+ arg_path = full_model_name + '.arg.json'
234
+ argparse_dict = vars(args)
235
+ with open(arg_path, 'w') as f:
236
+ json.dump(argparse_dict, f)
237
+
238
+ def generate_optimizer(args, lr, params):
239
+ params = filter(lambda param: param.requires_grad, params)
240
+ if args.opt == 'adam':
241
+ return Adam(params, lr=lr, betas=args.betas, weight_decay=args.gamma, eps=args.epsilon)
242
+ elif args.opt == 'sgd':
243
+ return SGD(params, lr=lr, momentum=args.momentum, weight_decay=args.gamma, nesterov=True)
244
+ else:
245
+ raise ValueError('Unknown optimization algorithm: %s' % args.opt)
246
+
247
+
248
+ def save_checkpoint(model, optimizer, opt, dev_eval_dict, test_eval_dict, full_model_name):
249
+ path_name = full_model_name + '.pt'
250
+ print('Saving model to: %s' % path_name)
251
+ state = {'model_state_dict': model.state_dict(),
252
+ 'optimizer_state_dict': optimizer.state_dict(),
253
+ 'opt': opt,
254
+ 'dev_eval_dict': dev_eval_dict,
255
+ 'test_eval_dict': test_eval_dict}
256
+ torch.save(state, path_name)
257
+
258
+
259
+ def load_checkpoint(args, model, optimizer, dev_eval_dict, test_eval_dict, start_epoch, load_path, strict=True):
260
+ print('Loading saved model from: %s' % load_path)
261
+ checkpoint = torch.load(load_path, map_location=args.device)
262
+ if checkpoint['opt'] != args.opt:
263
+ raise ValueError('loaded optimizer type is: %s instead of: %s' % (checkpoint['opt'], args.opt))
264
+ model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
265
+
266
+ if strict:
267
+ generate_optimizer(args, args.learning_rate, model.parameters())
268
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
269
+ for state in optimizer.state.values():
270
+ for k, v in state.items():
271
+ if isinstance(v, torch.Tensor):
272
+ state[k] = v.to(args.device)
273
+ dev_eval_dict = checkpoint['dev_eval_dict']
274
+ test_eval_dict = checkpoint['test_eval_dict']
275
+ start_epoch = dev_eval_dict['in_domain']['epoch']
276
+ return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
277
+
278
+
279
+ def build_model_and_optimizer(args):
280
+ word_table = construct_embedding_table(args.alphabets['word_alphabet'], args.word_dict, args.word_dim, token_type='word')
281
+ char_table = construct_embedding_table(args.alphabets['char_alphabet'], args.char_dict, args.char_dim, token_type='char')
282
+ pos_table = construct_embedding_table(args.alphabets['pos_alphabet'], args.pos_dict, args.pos_dim, token_type='pos')
283
+ model = BiAffine_Parser_Gated(args.word_dim, args.num_word, args.char_dim, args.num_char,
284
+ args.use_pos, args.use_char, args.pos_dim, args.num_pos,
285
+ args.num_filters, args.kernel_size, args.rnn_mode,
286
+ args.hidden_size, args.num_layers, args.num_arc,
287
+ args.arc_space, args.arc_tag_space, args.num_gates,
288
+ embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table,
289
+ p_in=args.p_in, p_out=args.p_out, p_rnn=args.p_rnn,
290
+ biaffine=True, arc_decode=args.arc_decode, initializer=args.initializer)
291
+ print(model)
292
+ optimizer = generate_optimizer(args, args.learning_rate, model.parameters())
293
+ start_epoch = 0
294
+ dev_eval_dict = {'in_domain': initialize_eval_dict()}
295
+ test_eval_dict = {'in_domain': initialize_eval_dict()}
296
+ if args.load_path:
297
+ model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = \
298
+ load_checkpoint(args, model, optimizer,
299
+ dev_eval_dict, test_eval_dict,
300
+ start_epoch, args.load_path, strict=args.strict)
301
+ if args.load_sequence_taggers_paths:
302
+ pretrained_dict = {}
303
+ model_dict = model.state_dict()
304
+ for idx, path in enumerate(args.load_sequence_taggers_paths):
305
+ print('Loading saved sequence_tagger from: %s' % path)
306
+ checkpoint = torch.load(path, map_location=args.device)
307
+ for k, v in checkpoint['model_state_dict'].items():
308
+ if 'rnn_encoder.' in k:
309
+ pretrained_dict['extra_rnn_encoders.' + str(idx) + '.' + k.replace('rnn_encoder.', '')] = v
310
+ model_dict.update(pretrained_dict)
311
+ model.load_state_dict(model_dict)
312
+ if args.freeze_sequence_taggers:
313
+ print('Freezing Classifiers')
314
+ for name, parameter in model.named_parameters():
315
+ if 'extra_rnn_encoders' in name:
316
+ parameter.requires_grad = False
317
+ if args.freeze_word_embeddings:
318
+ model.rnn_encoder.word_embedd.weight.requires_grad = False
319
+ # model.rnn_encoder.char_embedd.weight.requires_grad = False
320
+ # model.rnn_encoder.pos_embedd.weight.requires_grad = False
321
+ device = args.device
322
+ model.to(device)
323
+ return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
324
+
325
+
326
+ def initialize_eval_dict():
327
+ eval_dict = {}
328
+ eval_dict['dp_uas'] = 0.0
329
+ eval_dict['dp_las'] = 0.0
330
+ eval_dict['epoch'] = 0
331
+ eval_dict['dp_ucorrect'] = 0.0
332
+ eval_dict['dp_lcorrect'] = 0.0
333
+ eval_dict['dp_total'] = 0.0
334
+ eval_dict['dp_ucomplete_match'] = 0.0
335
+ eval_dict['dp_lcomplete_match'] = 0.0
336
+ eval_dict['dp_ucorrect_nopunc'] = 0.0
337
+ eval_dict['dp_lcorrect_nopunc'] = 0.0
338
+ eval_dict['dp_total_nopunc'] = 0.0
339
+ eval_dict['dp_ucomplete_match_nopunc'] = 0.0
340
+ eval_dict['dp_lcomplete_match_nopunc'] = 0.0
341
+ eval_dict['dp_root_correct'] = 0.0
342
+ eval_dict['dp_total_root'] = 0.0
343
+ eval_dict['dp_total_inst'] = 0.0
344
+ eval_dict['dp_total'] = 0.0
345
+ eval_dict['dp_total_inst'] = 0.0
346
+ eval_dict['dp_total_nopunc'] = 0.0
347
+ eval_dict['dp_total_root'] = 0.0
348
+ return eval_dict
349
+
350
+ def in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch,
351
+ best_model, best_optimizer, patient):
352
+ # In-domain evaluation
353
+ curr_dev_eval_dict = evaluation(args, datasets['dev'], 'dev', model, args.domain, epoch, 'current_results')
354
+ is_best_in_domain = dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] <= curr_dev_eval_dict['dp_lcorrect_nopunc'] or \
355
+ (dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] == curr_dev_eval_dict['dp_lcorrect_nopunc'] and
356
+ dev_eval_dict['in_domain']['dp_ucorrect_nopunc'] <= curr_dev_eval_dict['dp_ucorrect_nopunc'])
357
+
358
+ if is_best_in_domain:
359
+ for key, value in curr_dev_eval_dict.items():
360
+ dev_eval_dict['in_domain'][key] = value
361
+ curr_test_eval_dict = evaluation(args, datasets['test'], 'test', model, args.domain, epoch, 'current_results')
362
+ for key, value in curr_test_eval_dict.items():
363
+ test_eval_dict['in_domain'][key] = value
364
+ best_model = deepcopy(model)
365
+ best_optimizer = deepcopy(optimizer)
366
+ patient = 0
367
+ else:
368
+ patient += 1
369
+ if epoch == args.num_epochs:
370
+ # save in-domain checkpoint
371
+ if args.set_num_training_samples is not None:
372
+ splits_to_write = datasets.keys()
373
+ else:
374
+ splits_to_write = ['dev', 'test']
375
+ for split in splits_to_write:
376
+ if split == 'dev':
377
+ eval_dict = dev_eval_dict['in_domain']
378
+ elif split == 'test':
379
+ eval_dict = test_eval_dict['in_domain']
380
+ else:
381
+ eval_dict = None
382
+ write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)
383
+ print("Saving best model")
384
+ save_checkpoint(best_model, best_optimizer, args.opt, dev_eval_dict, test_eval_dict, args.full_model_name)
385
+
386
+ print('\n')
387
+ return dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient
388
+
389
+
390
+ def evaluation(args, data, split, model, domain, epoch, str_res='results'):
391
+ # evaluate performance on data
392
+ model.eval()
393
+
394
+ eval_dict = initialize_eval_dict()
395
+ eval_dict['epoch'] = epoch
396
+ for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
397
+ word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
398
+ out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
399
+ heads_pred, arc_tags_pred, _ = model.decode(out_arc, out_arc_tag, mask=masks, length=lengths,
400
+ leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
401
+ lengths = lengths.cpu().numpy()
402
+ word = word.data.cpu().numpy()
403
+ pos = pos.data.cpu().numpy()
404
+ ner = ner.data.cpu().numpy()
405
+ heads = heads.data.cpu().numpy()
406
+ arc_tags = arc_tags.data.cpu().numpy()
407
+ heads_pred = heads_pred.data.cpu().numpy()
408
+ arc_tags_pred = arc_tags_pred.data.cpu().numpy()
409
+ stats, stats_nopunc, stats_root, num_inst = parse.eval_(word, pos, heads_pred, arc_tags_pred, heads,
410
+ arc_tags, args.alphabets['word_alphabet'], args.alphabets['pos_alphabet'],
411
+ lengths, punct_set=args.punct_set, symbolic_root=True)
412
+ ucorr, lcorr, total, ucm, lcm = stats
413
+ ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
414
+ corr_root, total_root = stats_root
415
+ eval_dict['dp_ucorrect'] += ucorr
416
+ eval_dict['dp_lcorrect'] += lcorr
417
+ eval_dict['dp_total'] += total
418
+ eval_dict['dp_ucomplete_match'] += ucm
419
+ eval_dict['dp_lcomplete_match'] += lcm
420
+ eval_dict['dp_ucorrect_nopunc'] += ucorr_nopunc
421
+ eval_dict['dp_lcorrect_nopunc'] += lcorr_nopunc
422
+ eval_dict['dp_total_nopunc'] += total_nopunc
423
+ eval_dict['dp_ucomplete_match_nopunc'] += ucm_nopunc
424
+ eval_dict['dp_lcomplete_match_nopunc'] += lcm_nopunc
425
+ eval_dict['dp_root_correct'] += corr_root
426
+ eval_dict['dp_total_root'] += total_root
427
+ eval_dict['dp_total_inst'] += num_inst
428
+
429
+ eval_dict['dp_uas'] = eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
430
+ eval_dict['dp_las'] = eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
431
+ print_results(eval_dict, split, domain, str_res)
432
+ return eval_dict
433
+
434
+
435
+ def print_results(eval_dict, split, domain, str_res='results'):
436
+ print('----------------------------------------------------------------------------------------------------------------------------')
437
+ print('Testing model on domain %s' % domain)
438
+ print('--------------- Dependency Parsing - %s ---------------' % split)
439
+ print(
440
+ str_res + ' on ' + split + ' W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
441
+ eval_dict['dp_ucorrect'], eval_dict['dp_lcorrect'], eval_dict['dp_total'],
442
+ eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'],
443
+ eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'],
444
+ eval_dict['dp_ucomplete_match'] * 100 / eval_dict['dp_total_inst'],
445
+ eval_dict['dp_lcomplete_match'] * 100 / eval_dict['dp_total_inst'],
446
+ eval_dict['epoch']))
447
+ print(
448
+ str_res + ' on ' + split + ' Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
449
+ eval_dict['dp_ucorrect_nopunc'], eval_dict['dp_lcorrect_nopunc'], eval_dict['dp_total_nopunc'],
450
+ eval_dict['dp_ucorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
451
+ eval_dict['dp_lcorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
452
+ eval_dict['dp_ucomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
453
+ eval_dict['dp_lcomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
454
+ eval_dict['epoch']))
455
+ print(str_res + ' on ' + split + ' Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
456
+ eval_dict['dp_root_correct'], eval_dict['dp_total_root'],
457
+ eval_dict['dp_root_correct'] * 100 / eval_dict['dp_total_root'], eval_dict['epoch']))
458
+ print('\n')
459
+
460
+ def write_results(args, data, data_domain, split, model, model_domain, eval_dict):
461
+ str_file = args.full_model_name + '_' + split + '_model_domain_' + model_domain + '_data_domain_' + data_domain
462
+ res_filename = str_file + '_res.txt'
463
+ pred_filename = str_file + '_pred.txt'
464
+ gold_filename = str_file + '_gold.txt'
465
+ if eval_dict is not None:
466
+ # save results dictionary into a file
467
+ with open(res_filename, 'w') as f:
468
+ json.dump(eval_dict, f)
469
+
470
+ # save predictions and gold labels into files
471
+ pred_writer = Writer(args.alphabets)
472
+ gold_writer = Writer(args.alphabets)
473
+ pred_writer.start(pred_filename)
474
+ gold_writer.start(gold_filename)
475
+ for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
476
+ word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
477
+ out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
478
+ heads_pred, arc_tags_pred, _ = model.decode(out_arc, out_arc_tag, mask=masks, length=lengths,
479
+ leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
480
+ lengths = lengths.cpu().numpy()
481
+ word = word.data.cpu().numpy()
482
+ pos = pos.data.cpu().numpy()
483
+ ner = ner.data.cpu().numpy()
484
+ heads = heads.data.cpu().numpy()
485
+ arc_tags = arc_tags.data.cpu().numpy()
486
+ heads_pred = heads_pred.data.cpu().numpy()
487
+ arc_tags_pred = arc_tags_pred.data.cpu().numpy()
488
+ # writing predictions
489
+ pred_writer.write(word, pos, ner, heads_pred, arc_tags_pred, lengths, symbolic_root=True)
490
+ # writing gold labels
491
+ gold_writer.write(word, pos, ner, heads, arc_tags, lengths, symbolic_root=True)
492
+
493
+ pred_writer.close()
494
+ gold_writer.close()
495
+
496
+ def main():
497
+ logger.info("Reading and creating arguments")
498
+ args = read_arguments()
499
+ logger.info("Reading Data")
500
+ datasets = {}
501
+ for split in args.splits:
502
+ print("Splits are:",split)
503
+ dataset = prepare_data.read_data_to_variable(args.data_paths[split], args.alphabets, args.device,
504
+ symbolic_root=True)
505
+ datasets[split] = dataset
506
+ if args.set_num_training_samples is not None:
507
+ print('Setting train and dev to %d samples' % args.set_num_training_samples)
508
+ datasets = rearrange_splits.rearranging_splits(datasets, args.set_num_training_samples)
509
+ logger.info("Creating Networks")
510
+ num_data = sum(datasets['train'][1])
511
+ model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = build_model_and_optimizer(args)
512
+ best_model = deepcopy(model)
513
+ best_optimizer = deepcopy(optimizer)
514
+
515
+ logger.info('Training INFO of in domain %s' % args.domain)
516
+ logger.info('Training on Dependency Parsing')
517
+ logger.info("train: gamma: %f, batch: %d, clip: %.2f, unk replace: %.2f" % (args.gamma, args.batch_size, args.clip, args.unk_replace))
518
+ logger.info('number of training samples for %s is: %d' % (args.domain, num_data))
519
+ logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (args.p_in, args.p_out, args.p_rnn))
520
+ logger.info("num_epochs: %d" % (args.num_epochs))
521
+ print('\n')
522
+
523
+ if not args.eval_mode:
524
+ logger.info("Training")
525
+ num_batches = prepare_data.calc_num_batches(datasets['train'], args.batch_size)
526
+ lr = args.learning_rate
527
+ patient = 0
528
+ decay = 0
529
+ for epoch in range(start_epoch + 1, args.num_epochs + 1):
530
+ print('Epoch %d (Training: rnn mode: %s, optimizer: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, decay=%d)): ' % (
531
+ epoch, args.rnn_mode, args.opt, lr, args.epsilon, args.decay_rate, args.schedule, decay))
532
+ model.train()
533
+ total_loss = 0.0
534
+ total_arc_loss = 0.0
535
+ total_arc_tag_loss = 0.0
536
+ total_train_inst = 0.0
537
+
538
+ train_iter = prepare_data.iterate_batch_rand_bucket_choosing(
539
+ datasets['train'], args.batch_size, args.device, unk_replace=args.unk_replace)
540
+ start_time = time.time()
541
+ batch_num = 0
542
+ for batch_num, batch in enumerate(train_iter):
543
+ batch_num = batch_num + 1
544
+ optimizer.zero_grad()
545
+ # compute loss of main task
546
+ word, char, pos, ner_tags, heads, arc_tags, auto_label, masks, lengths = batch
547
+ out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
548
+ loss_arc, loss_arc_tag = model.loss(out_arc, out_arc_tag, heads, arc_tags, mask=masks, length=lengths)
549
+ loss = loss_arc + loss_arc_tag
550
+
551
+ # update losses
552
+ num_insts = masks.data.sum() - word.size(0)
553
+ total_arc_loss += loss_arc.item() * num_insts
554
+ total_arc_tag_loss += loss_arc_tag.item() * num_insts
555
+ total_loss += loss.item() * num_insts
556
+ total_train_inst += num_insts
557
+ # optimize parameters
558
+ loss.backward()
559
+ clip_grad_norm_(model.parameters(), args.clip)
560
+ optimizer.step()
561
+
562
+ time_ave = (time.time() - start_time) / batch_num
563
+ time_left = (num_batches - batch_num) * time_ave
564
+
565
+ # update log
566
+ if batch_num % 50 == 0:
567
+ log_info = 'train: %d/%d, domain: %s, total loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, time left: %.2fs' % \
568
+ (batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
569
+ total_arc_tag_loss / total_train_inst, time_left)
570
+ sys.stdout.write(log_info)
571
+ sys.stdout.write('\n')
572
+ sys.stdout.flush()
573
+ print('\n')
574
+ print('train: %d/%d, domain: %s, total_loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, time: %.2fs' %
575
+ (batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
576
+ total_arc_tag_loss / total_train_inst, time.time() - start_time))
577
+
578
+ dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient = in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, best_model, best_optimizer, patient)
579
+ if patient >= args.schedule:
580
+ lr = args.learning_rate / (1.0 + epoch * args.decay_rate)
581
+ optimizer = generate_optimizer(args, lr, model.parameters())
582
+ print('updated learning rate to %.6f' % lr)
583
+ patient = 0
584
+ print_results(test_eval_dict['in_domain'], 'test', args.domain, 'best_results')
585
+ print('\n')
586
+ for split in datasets.keys():
587
+ eval_dict = evaluation(args, datasets[split], split, best_model, args.domain, epoch, 'best_results')
588
+ write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
589
+
590
+ else:
591
+ logger.info("Evaluating")
592
+ epoch = start_epoch
593
+ for split in ['train', 'dev', 'test','poetry','prose']:
594
+ eval_dict = evaluation(args, datasets[split], split, model, args.domain, epoch, 'best_results')
595
+ write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
596
+
597
+
598
+ if __name__ == '__main__':
599
+ main()
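Two details of `read_arguments()` above that are worth calling out: for `--dataset ud` the input files are assembled as `data/ud_pos_ner_dp_<split>_<domain>` (matching the files added under `data/` in this commit, including the extra `poetry` and `prose` splits), and passing `--load_sequence_taggers_paths` implicitly turns gating on and sets the gate count. A short sketch mirroring that logic; the tagger checkpoint paths are hypothetical.

```python
# Mirrors the path and gating setup in read_arguments() for --dataset ud --domain san.
splits = ['train', 'dev', 'test', 'poetry', 'prose']
data_paths = {split: 'data/ud_pos_ner_dp' + '_' + split + '_' + 'san' for split in splits}
print(data_paths['train'])   # data/ud_pos_ner_dp_train_san

# Hypothetical tagger checkpoints; in the script, supplying them forces gating=True
# and sets num_gates = len(paths) + 1 (the extra gate presumably covers the parser's own encoder).
load_sequence_taggers_paths = ['saved_models/case_tagger/domain_san.pt',
                               'saved_models/morph_tagger/domain_san.pt']
gating = True
num_gates = len(load_sequence_taggers_paths) + 1
print(num_gates)   # 3
```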
examples/GraphParser_MTL_POS.py ADDED
@@ -0,0 +1,633 @@
1
+ from __future__ import print_function
2
+ import sys
3
+ from os import path, makedirs
4
+
5
+ sys.path.append(".")
6
+ sys.path.append("..")
7
+
8
+ import argparse
9
+ from copy import deepcopy
10
+ import json
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from collections import namedtuple
15
+ from utils.io_ import seeds, Writer, get_logger, prepare_data, rearrange_splits
16
+ from utils.models.parsing_gating_mtl_pos import BiAffine_Parser_Gated
17
+ # from utils.models.sequence_tagger import Sequence_Tagger
18
+ from utils import load_word_embeddings
19
+ from utils.tasks import parse
20
+ import time
21
+ from torch.nn.utils import clip_grad_norm_
22
+ from torch.optim import Adam, SGD
23
+ import uuid
24
+
25
+ uid = uuid.uuid4().hex[:6]
26
+
27
+ logger = get_logger('GraphParser')
28
+
29
+ def read_arguments():
30
+ args_ = argparse.ArgumentParser(description='Solving GraphParser')
31
+ args_.add_argument('--dataset', choices=['ontonotes', 'ud'], help='Dataset', required=True)
32
+ args_.add_argument('--domain', help='domain/language', required=True)
33
+ args_.add_argument('--rnn_mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn',
34
+ required=True)
35
+ args_.add_argument('--gating',action='store_true', help='use gated mechanism')
36
+ args_.add_argument('--num_gates', type=int, default=0, help='number of gates for gating mechanism')
37
+ args_.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
38
+ args_.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
39
+ args_.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
40
+ args_.add_argument('--arc_space', type=int, default=128, help='Dimension of tag space')
41
+ args_.add_argument('--arc_tag_space', type=int, default=128, help='Dimension of tag space')
42
+ args_.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
43
+ args_.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
44
+ args_.add_argument('--kernel_size', type=int, default=3, help='Size of Kernel for CNN')
45
+ args_.add_argument('--use_pos', action='store_true', help='use part-of-speech embedding.')
46
+ args_.add_argument('--use_char', action='store_true', help='use character embedding and CNN.')
47
+ args_.add_argument('--word_dim', type=int, default=300, help='Dimension of word embeddings')
48
+ args_.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
49
+ args_.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
50
+ args_.add_argument('--initializer', choices=['xavier'], help='initialize model parameters')
51
+ args_.add_argument('--opt', choices=['adam', 'sgd'], help='optimization algorithm')
52
+ args_.add_argument('--momentum', type=float, default=0.9, help='momentum of optimizer')
53
+ args_.add_argument('--betas', nargs=2, type=float, default=[0.9, 0.9], help='betas of optimizer')
54
+ args_.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
55
+ args_.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
56
+ args_.add_argument('--schedule', type=int, help='schedule for learning rate decay')
57
+ args_.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
58
+ args_.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
59
+ args_.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for adam')
60
+ args_.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
61
+ args_.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
62
+ args_.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
63
+ args_.add_argument('--arc_decode', choices=['mst', 'greedy'], help='arc decoding algorithm', required=True)
64
+ args_.add_argument('--unk_replace', type=float, default=0.,
65
+ help='The rate to replace a singleton word with UNK')
66
+ args_.add_argument('--punct_set', nargs='+', type=str, help='List of punctuations')
67
+ args_.add_argument('--word_embedding', choices=['random', 'glove', 'fasttext', 'word2vec'],
68
+ help='Embedding for words')
69
+ args_.add_argument('--word_path', help='path for word embedding dict - in case word_embedding is not random')
70
+ args_.add_argument('--freeze_word_embeddings', action='store_true', help='freeze the word embeddings (disable fine-tuning).')
71
+ args_.add_argument('--freeze_sequence_taggers', action='store_true', help='freeze the BiLSTMs of the pre-trained taggers.')
72
+ args_.add_argument('--char_embedding', choices=['random','hellwig'], help='Embedding for characters',
73
+ required=True)
74
+ args_.add_argument('--pos_embedding', choices=['random','one_hot'], help='Embedding for pos',
75
+ required=True)
76
+ args_.add_argument('--char_path', help='path for character embedding dict')
77
+ args_.add_argument('--pos_path', help='path for pos embedding dict')
78
+ args_.add_argument('--set_num_training_samples', type=int, help='downsampling training set to a fixed number of samples')
79
+ args_.add_argument('--model_path', help='path for saving model file.', required=True)
80
+ args_.add_argument('--load_path', help='path for loading saved source model file.', default=None)
81
+ args_.add_argument('--load_sequence_taggers_paths', nargs='+', help='path for loading saved sequence_tagger saved_models files.', default=None)
82
+ args_.add_argument('--strict',action='store_true', help='if True loaded model state should contain '
83
+ 'exactly the same keys as current model')
84
+ args_.add_argument('--eval_mode', action='store_true', help='evaluating model without training it')
85
+ args = args_.parse_args()
86
+ args_dict = {}
87
+ args_dict['dataset'] = args.dataset
88
+ args_dict['domain'] = args.domain
89
+ args_dict['rnn_mode'] = args.rnn_mode
90
+ args_dict['gating'] = args.gating
91
+ args_dict['num_gates'] = args.num_gates
92
+ args_dict['arc_decode'] = args.arc_decode
93
+ # args_dict['splits'] = ['train', 'dev', 'test']
94
+ args_dict['splits'] = ['train', 'dev', 'test','poetry','prose']
95
+ args_dict['model_path'] = args.model_path
96
+ if not path.exists(args_dict['model_path']):
97
+ makedirs(args_dict['model_path'])
98
+ args_dict['data_paths'] = {}
99
+ if args_dict['dataset'] == 'ontonotes':
100
+ data_path = 'data/onto_pos_ner_dp'
101
+ else:
102
+ data_path = 'data/ud_pos_ner_dp'
103
+ for split in args_dict['splits']:
104
+ args_dict['data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain']
105
+ ###################################
106
+ args_dict['data_paths']['poetry'] = data_path + '_' + 'poetry' + '_' + args_dict['domain']
107
+ args_dict['data_paths']['prose'] = data_path + '_' + 'prose' + '_' + args_dict['domain']
108
+ # args_dict['data_paths']['poetry'] = 'data/Shishu_300' + '_' + 'poetry' + '_' + args_dict['domain']
109
+ # args_dict['data_paths']['prose'] = 'data/Shishu_300' + '_' + 'prose' + '_' + args_dict['domain']
110
+
111
+ ###################################
112
+ args_dict['alphabet_data_paths'] = {}
113
+ for split in args_dict['splits']:
114
+ if args_dict['dataset'] == 'ontonotes':
115
+ args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + 'all'
116
+ else:
117
+ if '_' in args_dict['domain']:
118
+ args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain'].split('_')[0]
119
+ else:
120
+ args_dict['alphabet_data_paths'][split] = args_dict['data_paths'][split]
121
+ args_dict['model_name'] = 'domain_' + args_dict['domain']
122
+ args_dict['full_model_name'] = path.join(args_dict['model_path'],args_dict['model_name'])
123
+ args_dict['load_path'] = args.load_path
124
+ args_dict['load_sequence_taggers_paths'] = args.load_sequence_taggers_paths
125
+ if args_dict['load_sequence_taggers_paths'] is not None:
126
+ args_dict['gating'] = True
127
+ args_dict['num_gates'] = len(args_dict['load_sequence_taggers_paths']) + 1
128
+ else:
129
+ if not args_dict['gating']:
130
+ args_dict['num_gates'] = 0
131
+ args_dict['strict'] = args.strict
132
+ args_dict['num_epochs'] = args.num_epochs
133
+ args_dict['batch_size'] = args.batch_size
134
+ args_dict['hidden_size'] = args.hidden_size
135
+ args_dict['arc_space'] = args.arc_space
136
+ args_dict['arc_tag_space'] = args.arc_tag_space
137
+ args_dict['num_layers'] = args.num_layers
138
+ args_dict['num_filters'] = args.num_filters
139
+ args_dict['kernel_size'] = args.kernel_size
140
+ args_dict['learning_rate'] = args.learning_rate
141
+ args_dict['initializer'] = nn.init.xavier_uniform_ if args.initializer == 'xavier' else None
142
+ args_dict['opt'] = args.opt
143
+ args_dict['momentum'] = args.momentum
144
+ args_dict['betas'] = tuple(args.betas)
145
+ args_dict['epsilon'] = args.epsilon
146
+ args_dict['decay_rate'] = args.decay_rate
147
+ args_dict['clip'] = args.clip
148
+ args_dict['gamma'] = args.gamma
149
+ args_dict['schedule'] = args.schedule
150
+ args_dict['p_rnn'] = tuple(args.p_rnn)
151
+ args_dict['p_in'] = args.p_in
152
+ args_dict['p_out'] = args.p_out
153
+ args_dict['unk_replace'] = args.unk_replace
154
+ args_dict['set_num_training_samples'] = args.set_num_training_samples
155
+ args_dict['punct_set'] = None
156
+ if args.punct_set is not None:
157
+ args_dict['punct_set'] = set(args.punct_set)
158
+ logger.info("punctuations(%d): %s" % (len(args_dict['punct_set']), ' '.join(args_dict['punct_set'])))
159
+ args_dict['freeze_word_embeddings'] = args.freeze_word_embeddings
160
+ args_dict['freeze_sequence_taggers'] = args.freeze_sequence_taggers
161
+ args_dict['word_embedding'] = args.word_embedding
162
+ args_dict['word_path'] = args.word_path
163
+ args_dict['use_char'] = args.use_char
164
+ args_dict['char_embedding'] = args.char_embedding
165
+ args_dict['char_path'] = args.char_path
166
+ args_dict['pos_embedding'] = args.pos_embedding
167
+ args_dict['pos_path'] = args.pos_path
168
+ args_dict['use_pos'] = args.use_pos
169
+ args_dict['pos_dim'] = args.pos_dim
170
+ args_dict['word_dict'] = None
171
+ args_dict['word_dim'] = args.word_dim
172
+ if args_dict['word_embedding'] != 'random' and args_dict['word_path']:
173
+ args_dict['word_dict'], args_dict['word_dim'] = load_word_embeddings.load_embedding_dict(args_dict['word_embedding'],
174
+ args_dict['word_path'])
175
+ args_dict['char_dict'] = None
176
+ args_dict['char_dim'] = args.char_dim
177
+ if args_dict['char_embedding'] != 'random':
178
+ args_dict['char_dict'], args_dict['char_dim'] = load_word_embeddings.load_embedding_dict(args_dict['char_embedding'],
179
+ args_dict['char_path'])
180
+ args_dict['pos_dict'] = None
181
+ if args_dict['pos_embedding'] != 'random':
182
+ args_dict['pos_dict'], args_dict['pos_dim'] = load_word_embeddings.load_embedding_dict(args_dict['pos_embedding'],
183
+ args_dict['pos_path'])
184
+ args_dict['alphabet_path'] = path.join(args_dict['model_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/')
185
+ args_dict['model_name'] = path.join(args_dict['model_path'], args_dict['model_name'])
186
+ args_dict['eval_mode'] = args.eval_mode
187
+ args_dict['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
188
+ args_dict['word_status'] = 'frozen' if args.freeze_word_embeddings else 'fine tune'
189
+ args_dict['char_status'] = 'enabled' if args.use_char else 'disabled'
190
+ args_dict['pos_status'] = 'enabled' if args.use_pos else 'disabled'
191
+ logger.info("Saving arguments to file")
192
+ save_args(args, args_dict['full_model_name'])
193
+ logger.info("Creating Alphabets")
194
+ alphabet_dict = creating_alphabets(args_dict['alphabet_path'], args_dict['alphabet_data_paths'], args_dict['word_dict'])
195
+ args_dict = {**args_dict, **alphabet_dict}
196
+ ARGS = namedtuple('ARGS', args_dict.keys())
197
+ my_args = ARGS(**args_dict)
198
+ return my_args
199
+
200
+
201
+ def creating_alphabets(alphabet_path, alphabet_data_paths, word_dict):
202
+ train_paths = alphabet_data_paths['train']
203
+ extra_paths = [v for k,v in alphabet_data_paths.items() if k != 'train']
204
+ alphabet_dict = {}
205
+ alphabet_dict['alphabets'] = prepare_data.create_alphabets(alphabet_path,
206
+ train_paths,
207
+ extra_paths=extra_paths,
208
+ max_vocabulary_size=100000,
209
+ embedd_dict=word_dict)
210
+ for k, v in alphabet_dict['alphabets'].items():
211
+ num_key = 'num_' + k.split('_')[0]
212
+ alphabet_dict[num_key] = v.size()
213
+ logger.info("%s : %d" % (num_key, alphabet_dict[num_key]))
214
+ return alphabet_dict
215
+
216
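+ # Build the embedding table for an alphabet: use the pretrained vector when the
+ # token (or its lowercase form) is in tokens_dict, otherwise draw a random vector
+ # from U(-sqrt(3/dim), sqrt(3/dim)) and count the token as OOV.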
+ def construct_embedding_table(alphabet, tokens_dict, dim, token_type='word'):
217
+ if tokens_dict is None:
218
+ return None
219
+ scale = np.sqrt(3.0 / dim)
220
+ table = np.empty([alphabet.size(), dim], dtype=np.float32)
221
+ table[prepare_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
222
+ oov_tokens = 0
223
+ for token, index in alphabet.items():
224
+ if token in tokens_dict:
225
+ embedding = tokens_dict[token]
226
+ elif token.lower() in tokens_dict:
227
+ embedding = tokens_dict[token.lower()]
228
+ else:
229
+ embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
230
+ oov_tokens += 1
231
+ table[index, :] = embedding
232
+ print('token type : %s, number of oov: %d' % (token_type, oov_tokens))
233
+ table = torch.from_numpy(table)
234
+ return table
235
+
236
+ def save_args(args, full_model_name):
237
+ arg_path = full_model_name + '.arg.json'
238
+ argparse_dict = vars(args)
239
+ with open(arg_path, 'w') as f:
240
+ json.dump(argparse_dict, f)
241
+
242
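+ # Build the optimizer (Adam or Nesterov SGD) over trainable parameters only;
+ # parameters with requires_grad=False (e.g. frozen embeddings) are filtered out.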
+ def generate_optimizer(args, lr, params):
243
+ params = filter(lambda param: param.requires_grad, params)
244
+ if args.opt == 'adam':
245
+ return Adam(params, lr=lr, betas=args.betas, weight_decay=args.gamma, eps=args.epsilon)
246
+ elif args.opt == 'sgd':
247
+ return SGD(params, lr=lr, momentum=args.momentum, weight_decay=args.gamma, nesterov=True)
248
+ else:
249
+ raise ValueError('Unknown optimization algorithm: %s' % args.opt)
250
+
251
+
252
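+ # A checkpoint bundles model and optimizer state with the best dev/test
+ # evaluation dictionaries, so the recorded epoch can be restored on reload.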
+ def save_checkpoint(model, optimizer, opt, dev_eval_dict, test_eval_dict, full_model_name):
253
+ path_name = full_model_name + '.pt'
254
+ print('Saving model to: %s' % path_name)
255
+ state = {'model_state_dict': model.state_dict(),
256
+ 'optimizer_state_dict': optimizer.state_dict(),
257
+ 'opt': opt,
258
+ 'dev_eval_dict': dev_eval_dict,
259
+ 'test_eval_dict': test_eval_dict}
260
+ torch.save(state, path_name)
261
+
262
+
263
+ def load_checkpoint(args, model, optimizer, dev_eval_dict, test_eval_dict, start_epoch, load_path, strict=False):
264
+ print('Loading saved model from: %s' % load_path)
265
+ checkpoint = torch.load(load_path, map_location=args.device)
266
+ if checkpoint['opt'] != args.opt:
267
+ raise ValueError('loaded optimizer type is: %s instead of: %s' % (checkpoint['opt'], args.opt))
268
+ model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
269
+
270
+ if strict:
271
+ generate_optimizer(args, args.learning_rate, model.parameters())
272
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
273
+ for state in optimizer.state.values():
274
+ for k, v in state.items():
275
+ if isinstance(v, torch.Tensor):
276
+ state[k] = v.to(args.device)
277
+ dev_eval_dict = checkpoint['dev_eval_dict']
278
+ test_eval_dict = checkpoint['test_eval_dict']
279
+ start_epoch = dev_eval_dict['in_domain']['epoch']
280
+ return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
281
+
282
+
283
+ def build_model_and_optimizer(args):
284
+ word_table = construct_embedding_table(args.alphabets['word_alphabet'], args.word_dict, args.word_dim, token_type='word')
285
+ char_table = construct_embedding_table(args.alphabets['char_alphabet'], args.char_dict, args.char_dim, token_type='char')
286
+ pos_table = construct_embedding_table(args.alphabets['pos_alphabet'], args.pos_dict, args.pos_dim, token_type='pos')
287
+ model = BiAffine_Parser_Gated(args.word_dim, args.num_word, args.char_dim, args.num_char,
288
+ args.use_pos, args.use_char, args.pos_dim, args.num_pos,
289
+ args.num_filters, args.kernel_size, args.rnn_mode,
290
+ args.hidden_size, args.num_layers, args.num_arc,
291
+ args.arc_space, args.arc_tag_space, args.num_gates,
292
+ embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table,
293
+ p_in=args.p_in, p_out=args.p_out, p_rnn=args.p_rnn,
294
+ biaffine=True, arc_decode=args.arc_decode, initializer=args.initializer)
295
+ # Tagger = Sequence_Tagger(args.word_dim, args.num_word, args.char_dim, args.num_char,
296
+ # args.use_pos, args.use_char, args.pos_dim, args.num_pos,
297
+ # args.num_filters, args.kernel_size, args.rnn_mode,
298
+ # args.hidden_size, args.num_layers, args.arc_tag_space, 519,
299
+ # embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table,
300
+ # p_in=args.p_in, p_out=args.p_out, p_rnn=args.p_rnn,
301
+ # initializer=args.initializer)
302
+ print(model)
303
+ # print(Tagger)
304
+ # Tagger.rnn_encoder = model.rnn_encoder
305
+ optimizer = generate_optimizer(args, args.learning_rate, model.parameters())
306
+ start_epoch = 0
307
+ dev_eval_dict = {'in_domain': initialize_eval_dict()}
308
+ test_eval_dict = {'in_domain': initialize_eval_dict()}
309
+ if args.load_path:
310
+ model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = \
311
+ load_checkpoint(args, model, optimizer,
312
+ dev_eval_dict, test_eval_dict,
313
+ start_epoch, args.load_path, strict=False)
314
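+ # Copy the encoder weights of each pretrained sequence tagger into the parser's
+ # extra_rnn_encoders.<idx> slots; these auxiliary encoders feed the gating
+ # mechanism (num_gates = number of taggers + 1).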
+ if args.load_sequence_taggers_paths:
315
+ pretrained_dict = {}
316
+ model_dict = model.state_dict()
317
+ for idx, path in enumerate(args.load_sequence_taggers_paths):
318
+ print('Loading saved sequence_tagger from: %s' % path)
319
+ checkpoint = torch.load(path, map_location=args.device)
320
+ for k, v in checkpoint['model_state_dict'].items():
321
+ if 'rnn_encoder.' in k:
322
+ pretrained_dict['extra_rnn_encoders.' + str(idx) + '.' + k.replace('rnn_encoder.', '')] = v
323
+ model_dict.update(pretrained_dict)
324
+ model.load_state_dict(model_dict)
325
+ if args.freeze_sequence_taggers:
326
+ print('Freezing Classifiers')
327
+ for name, parameter in model.named_parameters():
328
+ if 'extra_rnn_encoders' in name:
329
+ parameter.requires_grad = False
330
+ if args.freeze_word_embeddings:
331
+ model.rnn_encoder.word_embedd.weight.requires_grad = False
332
+ # model.rnn_encoder.char_embedd.weight.requires_grad = False
333
+ # model.rnn_encoder.pos_embedd.weight.requires_grad = False
334
+ device = args.device
335
+ model.to(device)
336
+ # Tagger.to(device)
337
+ return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
338
+
339
+
340
+ def initialize_eval_dict():
341
+ eval_dict = {}
342
+ eval_dict['dp_uas'] = 0.0
343
+ eval_dict['dp_las'] = 0.0
344
+ eval_dict['epoch'] = 0
345
+ eval_dict['dp_ucorrect'] = 0.0
346
+ eval_dict['dp_lcorrect'] = 0.0
347
+ eval_dict['dp_total'] = 0.0
348
+ eval_dict['dp_ucomplete_match'] = 0.0
349
+ eval_dict['dp_lcomplete_match'] = 0.0
350
+ eval_dict['dp_ucorrect_nopunc'] = 0.0
351
+ eval_dict['dp_lcorrect_nopunc'] = 0.0
352
+ eval_dict['dp_total_nopunc'] = 0.0
353
+ eval_dict['dp_ucomplete_match_nopunc'] = 0.0
354
+ eval_dict['dp_lcomplete_match_nopunc'] = 0.0
355
+ eval_dict['dp_root_correct'] = 0.0
356
+ eval_dict['dp_total_root'] = 0.0
357
+ eval_dict['dp_total_inst'] = 0.0
362
+ return eval_dict
363
+
364
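+ # Model selection on dev: a checkpoint becomes the new best when its labelled
+ # correct count (no punctuation) is at least the previous best; otherwise the
+ # patience counter used for learning-rate decay is increased.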
+ def in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch,
365
+ best_model, best_optimizer, patient):
366
+ # In-domain evaluation
367
+ curr_dev_eval_dict = evaluation(args, datasets['dev'], 'dev', model, args.domain, epoch, 'current_results')
368
+ is_best_in_domain = dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] <= curr_dev_eval_dict['dp_lcorrect_nopunc'] or \
369
+ (dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] == curr_dev_eval_dict['dp_lcorrect_nopunc'] and
370
+ dev_eval_dict['in_domain']['dp_ucorrect_nopunc'] <= curr_dev_eval_dict['dp_ucorrect_nopunc'])
371
+
372
+ if is_best_in_domain:
373
+ for key, value in curr_dev_eval_dict.items():
374
+ dev_eval_dict['in_domain'][key] = value
375
+ curr_test_eval_dict = evaluation(args, datasets['test'], 'test', model, args.domain, epoch, 'current_results')
376
+ for key, value in curr_test_eval_dict.items():
377
+ test_eval_dict['in_domain'][key] = value
378
+ best_model = deepcopy(model)
379
+ best_optimizer = deepcopy(optimizer)
380
+ patient = 0
381
+ else:
382
+ patient += 1
383
+ if epoch == args.num_epochs:
384
+ # save in-domain checkpoint
385
+ if args.set_num_training_samples is not None:
386
+ splits_to_write = datasets.keys()
387
+ else:
388
+ splits_to_write = ['dev', 'test']
389
+ for split in splits_to_write:
390
+ if split == 'dev':
391
+ eval_dict = dev_eval_dict['in_domain']
392
+ elif split == 'test':
393
+ eval_dict = test_eval_dict['in_domain']
394
+ else:
395
+ eval_dict = None
396
+ write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)
397
+ print("Saving best model")
398
+ save_checkpoint(best_model, best_optimizer, args.opt, dev_eval_dict, test_eval_dict, args.full_model_name)
399
+
400
+ print('\n')
401
+ return dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient
402
+
403
+
404
+ def evaluation(args, data, split, model, domain, epoch, str_res='results'):
405
+ # evaluate performance on data
406
+ model.eval()
407
+
408
+ eval_dict = initialize_eval_dict()
409
+ eval_dict['epoch'] = epoch
410
+ for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
411
+ word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
412
+ out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
413
+ heads_pred, arc_tags_pred, _ = model.decode(out_arc, out_arc_tag, mask=masks, length=lengths,
414
+ leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
415
+ lengths = lengths.cpu().numpy()
416
+ word = word.data.cpu().numpy()
417
+ pos = pos.data.cpu().numpy()
418
+ ner = ner.data.cpu().numpy()
419
+ heads = heads.data.cpu().numpy()
420
+ arc_tags = arc_tags.data.cpu().numpy()
421
+ heads_pred = heads_pred.data.cpu().numpy()
422
+ arc_tags_pred = arc_tags_pred.data.cpu().numpy()
423
+ stats, stats_nopunc, stats_root, num_inst = parse.eval_(word, pos, heads_pred, arc_tags_pred, heads,
424
+ arc_tags, args.alphabets['word_alphabet'], args.alphabets['pos_alphabet'],
425
+ lengths, punct_set=args.punct_set, symbolic_root=True)
426
+ ucorr, lcorr, total, ucm, lcm = stats
427
+ ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
428
+ corr_root, total_root = stats_root
429
+ eval_dict['dp_ucorrect'] += ucorr
430
+ eval_dict['dp_lcorrect'] += lcorr
431
+ eval_dict['dp_total'] += total
432
+ eval_dict['dp_ucomplete_match'] += ucm
433
+ eval_dict['dp_lcomplete_match'] += lcm
434
+ eval_dict['dp_ucorrect_nopunc'] += ucorr_nopunc
435
+ eval_dict['dp_lcorrect_nopunc'] += lcorr_nopunc
436
+ eval_dict['dp_total_nopunc'] += total_nopunc
437
+ eval_dict['dp_ucomplete_match_nopunc'] += ucm_nopunc
438
+ eval_dict['dp_lcomplete_match_nopunc'] += lcm_nopunc
439
+ eval_dict['dp_root_correct'] += corr_root
440
+ eval_dict['dp_total_root'] += total_root
441
+ eval_dict['dp_total_inst'] += num_inst
442
+
443
+ eval_dict['dp_uas'] = eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
444
+ eval_dict['dp_las'] = eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
445
+ print_results(eval_dict, split, domain, str_res)
446
+ return eval_dict
447
+
448
+
449
+ def print_results(eval_dict, split, domain, str_res='results'):
450
+ print('----------------------------------------------------------------------------------------------------------------------------')
451
+ print('Testing model on domain %s' % domain)
452
+ print('--------------- Dependency Parsing - %s ---------------' % split)
453
+ print(
454
+ str_res + ' on ' + split + ' W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
455
+ eval_dict['dp_ucorrect'], eval_dict['dp_lcorrect'], eval_dict['dp_total'],
456
+ eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'],
457
+ eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'],
458
+ eval_dict['dp_ucomplete_match'] * 100 / eval_dict['dp_total_inst'],
459
+ eval_dict['dp_lcomplete_match'] * 100 / eval_dict['dp_total_inst'],
460
+ eval_dict['epoch']))
461
+ print(
462
+ str_res + ' on ' + split + ' Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
463
+ eval_dict['dp_ucorrect_nopunc'], eval_dict['dp_lcorrect_nopunc'], eval_dict['dp_total_nopunc'],
464
+ eval_dict['dp_ucorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
465
+ eval_dict['dp_lcorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
466
+ eval_dict['dp_ucomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
467
+ eval_dict['dp_lcomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
468
+ eval_dict['epoch']))
469
+ print(str_res + ' on ' + split + ' Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
470
+ eval_dict['dp_root_correct'], eval_dict['dp_total_root'],
471
+ eval_dict['dp_root_correct'] * 100 / eval_dict['dp_total_root'], eval_dict['epoch']))
472
+ print('\n')
473
+
474
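+ # Dump the evaluation dictionary to a *_res.txt JSON file and write predicted
+ # and gold trees to separate CoNLL-style files via the Writer helpers.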
+ def write_results(args, data, data_domain, split, model, model_domain, eval_dict):
475
+ str_file = args.full_model_name + '_' + split + '_model_domain_' + model_domain + '_data_domain_' + data_domain
476
+ res_filename = str_file + '_res.txt'
477
+ pred_filename = str_file + '_pred.txt'
478
+ gold_filename = str_file + '_gold.txt'
479
+ if eval_dict is not None:
480
+ # save results dictionary into a file
481
+ with open(res_filename, 'w') as f:
482
+ json.dump(eval_dict, f)
483
+
484
+ # save predictions and gold labels into files
485
+ pred_writer = Writer(args.alphabets)
486
+ gold_writer = Writer(args.alphabets)
487
+ pred_writer.start(pred_filename)
488
+ gold_writer.start(gold_filename)
489
+ for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
490
+ word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
491
+ out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
492
+ heads_pred, arc_tags_pred, _ = model.decode(out_arc, out_arc_tag, mask=masks, length=lengths,
493
+ leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
494
+ lengths = lengths.cpu().numpy()
495
+ word = word.data.cpu().numpy()
496
+ pos = pos.data.cpu().numpy()
497
+ ner = ner.data.cpu().numpy()
498
+ heads = heads.data.cpu().numpy()
499
+ arc_tags = arc_tags.data.cpu().numpy()
500
+ heads_pred = heads_pred.data.cpu().numpy()
501
+ arc_tags_pred = arc_tags_pred.data.cpu().numpy()
502
+ # writing predictions
503
+ pred_writer.write(word, pos, ner, heads_pred, arc_tags_pred, lengths, symbolic_root=True)
504
+ # writing gold labels
505
+ gold_writer.write(word, pos, ner, heads, arc_tags, lengths, symbolic_root=True)
506
+
507
+ pred_writer.close()
508
+ gold_writer.close()
509
+
510
+ def main():
511
+ logger.info("Reading and creating arguments")
512
+ args = read_arguments()
513
+ logger.info("Reading Data")
514
+ datasets = {}
515
+ for split in args.splits:
516
+ print("Splits are:",split)
517
+ dataset = prepare_data.read_data_to_variable(args.data_paths[split], args.alphabets, args.device,
518
+ symbolic_root=True)
519
+ datasets[split] = dataset
520
+ if args.set_num_training_samples is not None:
521
+ print('Setting train and dev to %d samples' % args.set_num_training_samples)
522
+ datasets = rearrange_splits.rearranging_splits(datasets, args.set_num_training_samples)
523
+ logger.info("Creating Networks")
524
+ num_data = sum(datasets['train'][1])
525
+ model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = build_model_and_optimizer(args)
526
+ best_model = deepcopy(model)
527
+ best_optimizer = deepcopy(optimizer)
528
+
529
+ logger.info('Training INFO of in domain %s' % args.domain)
530
+ logger.info('Training on Dependency Parsing')
531
+ logger.info("train: gamma: %f, batch: %d, clip: %.2f, unk replace: %.2f" % (args.gamma, args.batch_size, args.clip, args.unk_replace))
532
+ logger.info('number of training samples for %s is: %d' % (args.domain, num_data))
533
+ logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (args.p_in, args.p_out, args.p_rnn))
534
+ logger.info("num_epochs: %d" % (args.num_epochs))
535
+ print('\n')
536
+
537
+ if not args.eval_mode:
538
+ logger.info("Training")
539
+ num_batches = prepare_data.calc_num_batches(datasets['train'], args.batch_size)
540
+ lr = args.learning_rate
541
+ patient = 0
542
+ decay = 0
543
+ for epoch in range(start_epoch + 1, args.num_epochs + 1):
544
+ print('Epoch %d (Training: rnn mode: %s, optimizer: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, decay=%d)): ' % (
545
+ epoch, args.rnn_mode, args.opt, lr, args.epsilon, args.decay_rate, args.schedule, decay))
546
+ model.train()
547
+ total_loss = 0.0
548
+ total_mtl_pos_loss = 0.0
549
+ total_mtl_case_loss = 0.0
550
+ total_mtl_label_loss = 0.0
551
+ total_arc_loss = 0.0
552
+ total_arc_tag_loss = 0.0
553
+ total_train_inst = 0.0
554
+
555
+ train_iter = prepare_data.iterate_batch_rand_bucket_choosing(
556
+ datasets['train'], args.batch_size, args.device, unk_replace=args.unk_replace)
557
+ start_time = time.time()
558
+ batch_num = 0
559
+ for batch_num, batch in enumerate(train_iter):
560
+ batch_num = batch_num + 1
561
+ optimizer.zero_grad()
562
+ # compute loss of main task
563
+ word, char, pos, ner_tags, heads, arc_tags, auto_label, masks, lengths = batch
564
+ out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
565
+ loss_arc, loss_arc_tag = model.loss(out_arc, out_arc_tag, heads, arc_tags, mask=masks, length=lengths)
566
+ ############################################################
567
+ loss_morph = model.loss_morph(word, char, pos, mask=masks, length=lengths)
568
+ # loss_case = model.loss_case(word, char, ner_tags, mask=masks, length=lengths)
569
+ # loss_label = model.loss_label(word, char, arc_tags, mask=masks, length=lengths)
570
+
571
+ ## Adding multi-tasking
572
+ # Tag_output, Tag_masks, Tag_lengths = Tagger.forward(word, char, pos, mask=masks, length=lengths)
573
+ # Tag_loss = Tagger.loss(Tag_output, pos, mask=Tag_masks, length=Tag_lengths)
574
+ # # update losses
575
+ # Tag_num_insts = Tag_masks.data.sum() - word.size(0)
576
+ # Tag_total_loss += Tag_loss.item() * Tag_num_insts
577
+ # Tag_total_train_inst += Tag_num_insts
578
+
579
+ #############################################################
580
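+ # Multi-task objective: biaffine arc loss + dependency-label loss + the auxiliary
+ # morphological tagging loss (the case and label auxiliaries are left commented out).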
+ loss = loss_arc + loss_arc_tag + loss_morph
581
+
582
+ # update losses
583
+ num_insts = masks.data.sum() - word.size(0)
584
+ total_arc_loss += loss_arc.item() * num_insts
585
+ total_arc_tag_loss += loss_arc_tag.item() * num_insts
586
+ total_mtl_pos_loss += loss_morph.item() * num_insts
587
+ # total_mtl_case_loss += loss_case.item() * num_insts
588
+ # total_mtl_label_loss += loss_label.item() * num_insts
589
+ total_loss += loss.item() * num_insts
590
+ total_train_inst += num_insts
591
+ # optimize parameters
592
+ loss.backward()
593
+ clip_grad_norm_(model.parameters(), args.clip)
594
+ optimizer.step()
595
+
596
+ time_ave = (time.time() - start_time) / batch_num
597
+ time_left = (num_batches - batch_num) * time_ave
598
+
599
+ # update log
600
+ if batch_num % 50 == 0:
601
+ log_info = 'train: %d/%d, domain: %s, total loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, morph_loss: %.2f, case_loss: %.2f, label_loss: %.2f, time left: %.2fs' % \
602
+ (batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
603
+ total_arc_tag_loss / total_train_inst, total_mtl_pos_loss / total_train_inst, total_mtl_case_loss / total_train_inst, total_mtl_label_loss / total_train_inst, time_left)
604
+ sys.stdout.write(log_info)
605
+ sys.stdout.write('\n')
606
+ sys.stdout.flush()
607
+ print('\n')
608
+ print('train: %d/%d, domain: %s, total_loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, morph_loss: %.2f, case_loss: %.2f, label_loss: %.2f, time: %.2fs' %
609
+ (batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
610
+ total_arc_tag_loss / total_train_inst, total_mtl_pos_loss / total_train_inst, total_mtl_case_loss / total_train_inst, total_mtl_label_loss / total_train_inst, time.time() - start_time))
611
+
612
+ dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient = in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, best_model, best_optimizer, patient)
613
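+ # If dev score has not improved for `schedule` epochs, decay the learning rate as
+ # lr = learning_rate / (1 + epoch * decay_rate) and rebuild the optimizer, e.g.
+ # 0.01 / (1 + 20 * 0.05) = 0.005 at epoch 20 for an initial rate of 0.01.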
+ if patient >= args.schedule:
614
+ lr = args.learning_rate / (1.0 + epoch * args.decay_rate)
615
+ optimizer = generate_optimizer(args, lr, model.parameters())
616
+ print('updated learning rate to %.6f' % lr)
617
+ patient = 0
618
+ print_results(test_eval_dict['in_domain'], 'test', args.domain, 'best_results')
619
+ print('\n')
620
+ for split in datasets.keys():
621
+ eval_dict = evaluation(args, datasets[split], split, best_model, args.domain, epoch, 'best_results')
622
+ write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)  # write with the same best_model used for evaluation above
623
+
624
+ else:
625
+ logger.info("Evaluating")
626
+ epoch = start_epoch
627
+ for split in ['train', 'dev', 'test','poetry','prose']:
628
+ eval_dict = evaluation(args, datasets[split], split, model, args.domain, epoch, 'best_results')
629
+ write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
630
+
631
+
632
+ if __name__ == '__main__':
633
+ main()
examples/SequenceTagger.py ADDED
@@ -0,0 +1,589 @@
1
+ from __future__ import print_function
2
+ import sys
3
+ from os import path, makedirs, system, remove
4
+
5
+ sys.path.append(".")
6
+ sys.path.append("..")
7
+
8
+ import time
9
+ import argparse
10
+ import uuid
11
+ import json
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ from collections import namedtuple
16
+ from copy import deepcopy
17
+ from torch.nn.utils import clip_grad_norm_
18
+ from torch.optim import Adam, SGD
19
+ from utils.io_ import seeds, Writer, get_logger, Index2Instance, prepare_data, write_extra_labels
20
+ from utils.models.sequence_tagger import Sequence_Tagger
21
+ from utils import load_word_embeddings
22
+ from utils.tasks.seqeval import accuracy_score, f1_score, precision_score, recall_score,classification_report
23
+
24
+ uid = uuid.uuid4().hex[:6]
25
+
26
+ logger = get_logger('SequenceTagger')
27
+
28
+ def read_arguments():
29
+ args_ = argparse.ArgumentParser(description='Solving SequenceTagger')
30
+ args_.add_argument('--dataset', choices=['ontonotes', 'ud'], help='Dataset', required=True)
31
+ args_.add_argument('--domain', help='domain', required=True)
32
+ args_.add_argument('--rnn_mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn',
33
+ required=True)
34
+ args_.add_argument('--task', default='distance_from_the_root', choices=['distance_from_the_root', 'number_of_children',\
35
+ 'relative_pos_based', 'language_model','add_label','add_head_coarse_pos','Multitask_POS_predict','Multitask_case_predict',\
36
+ 'Multitask_label_predict','Multitask_coarse_predict','MRL_case','MRL_POS','MRL_no','MRL_label',\
37
+ 'predict_coarse_of_modifier','predict_ma_tag_of_modifier','add_head_ma','predict_case_of_modifier'], help='sequence_tagger task')
38
+ args_.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
39
+ args_.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
40
+ args_.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
41
+ args_.add_argument('--tag_space', type=int, default=128, help='Dimension of tag space')
42
+ args_.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
43
+ args_.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
44
+ args_.add_argument('--kernel_size', type=int, default=3, help='Size of Kernel for CNN')
45
+ args_.add_argument('--use_pos', action='store_true', help='use part-of-speech embedding.')
46
+ args_.add_argument('--use_char', action='store_true', help='use character embedding and CNN.')
47
+ args_.add_argument('--word_dim', type=int, default=300, help='Dimension of word embeddings')
48
+ args_.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
49
+ args_.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
50
+ args_.add_argument('--initializer', choices=['xavier'], help='initialize model parameters')
51
+ args_.add_argument('--opt', choices=['adam', 'sgd'], help='optimization algorithm')
52
+ args_.add_argument('--momentum', type=float, default=0.9, help='momentum of optimizer')
53
+ args_.add_argument('--betas', nargs=2, type=float, default=[0.9, 0.9], help='betas of optimizer')
54
+ args_.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
55
+ args_.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
56
+ args_.add_argument('--schedule', type=int, help='schedule for learning rate decay')
57
+ args_.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
58
+ args_.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
59
+ args_.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for adam')
60
+ args_.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
61
+ args_.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
62
+ args_.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
63
+ args_.add_argument('--unk_replace', type=float, default=0.,
64
+ help='The rate to replace a singleton word with UNK')
65
+ args_.add_argument('--punct_set', nargs='+', type=str, help='List of punctuations')
66
+ args_.add_argument('--word_embedding', choices=['random', 'glove', 'fasttext', 'word2vec'],
67
+ help='Embedding for words')
68
+ args_.add_argument('--word_path', help='path for word embedding dict - in case word_embedding is not random')
69
+ args_.add_argument('--freeze_word_embeddings', action='store_true', help='frozen the word embedding (disable fine-tuning).')
70
+ args_.add_argument('--char_embedding', choices=['random','hellwig'], help='Embedding for characters',
71
+ required=True)
72
+ args_.add_argument('--pos_embedding', choices=['random','one_hot'], help='Embedding for pos',
73
+ required=True)
74
+ args_.add_argument('--char_path', help='path for character embedding dict')
75
+ args_.add_argument('--pos_path', help='path for pos embedding dict')
76
+ args_.add_argument('--use_unlabeled_data', action='store_true', help='flag to use unlabeled data.')
77
+ args_.add_argument('--use_labeled_data', action='store_true', help='flag to use labeled data.')
78
+ args_.add_argument('--model_path', help='path for saving model file.', required=True)
79
+ args_.add_argument('--parser_path', help='path for loading parser files.', default=None)
80
+ args_.add_argument('--load_path', help='path for loading saved source model file.', default=None)
81
+ args_.add_argument('--strict',action='store_true', help='if True loaded model state should contain '
82
+ 'exactly the same keys as current model')
83
+ args_.add_argument('--eval_mode', action='store_true', help='evaluating model without training it')
84
+ args = args_.parse_args()
85
+ args_dict = {}
86
+ args_dict['dataset'] = args.dataset
87
+ args_dict['domain'] = args.domain
88
+ args_dict['task'] = args.task
89
+ args_dict['rnn_mode'] = args.rnn_mode
90
+ args_dict['load_path'] = args.load_path
91
+ args_dict['strict'] = args.strict
92
+ args_dict['model_path'] = args.model_path
93
+ if not path.exists(args_dict['model_path']):
94
+ makedirs(args_dict['model_path'])
95
+ args_dict['parser_path'] = args.parser_path
96
+ args_dict['model_name'] = 'domain_' + args_dict['domain']
97
+ args_dict['full_model_name'] = path.join(args_dict['model_path'],args_dict['model_name'])
98
+ args_dict['use_unlabeled_data'] = args.use_unlabeled_data
99
+ args_dict['use_labeled_data'] = args.use_labeled_data
100
+ print(args_dict['parser_path'])
101
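+ # Each auxiliary task prepares its own training files via a write_extra_labels
+ # helper, which derives the labels to be predicted (e.g. number of children,
+ # case, coarse POS of the head) from the parser's data and output paths.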
+ if args_dict['task'] == 'number_of_children':
102
+ args_dict['data_paths'] = write_extra_labels.add_number_of_children(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
103
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
104
+ use_labeled_data=args_dict['use_labeled_data'])
105
+ elif args_dict['task'] == 'distance_from_the_root':
106
+ args_dict['data_paths'] = write_extra_labels.add_distance_from_the_root(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
107
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
108
+ use_labeled_data=args_dict['use_labeled_data'])
109
+ elif args_dict['task'] == 'Multitask_label_predict':
110
+ args_dict['data_paths'] = write_extra_labels.Multitask_label_predict(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
111
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
112
+ use_labeled_data=args_dict['use_labeled_data'])
113
+ elif args_dict['task'] == 'Multitask_coarse_predict':
114
+ args_dict['data_paths'] = write_extra_labels.Multitask_coarse_predict(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
115
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
116
+ use_labeled_data=args_dict['use_labeled_data'])
117
+ elif args_dict['task'] == 'Multitask_POS_predict':
118
+ args_dict['data_paths'] = write_extra_labels.Multitask_POS_predict(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
119
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
120
+ use_labeled_data=args_dict['use_labeled_data'])
121
+ elif args_dict['task'] == 'relative_pos_based':
122
+ args_dict['data_paths'] = write_extra_labels.add_relative_pos_based(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
123
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
124
+ use_labeled_data=args_dict['use_labeled_data'])
125
+ elif args_dict['task'] == 'add_label':
126
+ args_dict['data_paths'] = write_extra_labels.add_label(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
127
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
128
+ use_labeled_data=args_dict['use_labeled_data'])
129
+ elif args_dict['task'] == 'add_relative_TAG':
130
+ args_dict['data_paths'] = write_extra_labels.add_relative_TAG(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
131
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
132
+ use_labeled_data=args_dict['use_labeled_data'])
133
+ elif args_dict['task'] == 'add_head_coarse_pos':
134
+ args_dict['data_paths'] = write_extra_labels.add_head_coarse_pos(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
135
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
136
+ use_labeled_data=args_dict['use_labeled_data'])
137
+ elif args_dict['task'] == 'predict_ma_tag_of_modifier':
138
+ args_dict['data_paths'] = write_extra_labels.predict_ma_tag_of_modifier(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
139
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
140
+ use_labeled_data=args_dict['use_labeled_data'])
141
+ elif args_dict['task'] == 'Multitask_case_predict':
142
+ args_dict['data_paths'] = write_extra_labels.Multitask_case_predict(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
143
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
144
+ use_labeled_data=args_dict['use_labeled_data'])
145
+ elif args_dict['task'] == 'predict_coarse_of_modifier':
146
+ args_dict['data_paths'] = write_extra_labels.predict_coarse_of_modifier(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
147
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
148
+ use_labeled_data=args_dict['use_labeled_data'])
149
+ elif args_dict['task'] == 'predict_case_of_modifier':
150
+ args_dict['data_paths'] = write_extra_labels.predict_case_of_modifier(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
151
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
152
+ use_labeled_data=args_dict['use_labeled_data'])
153
+ elif args_dict['task'] == 'add_head_ma':
154
+ args_dict['data_paths'] = write_extra_labels.add_head_ma(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
155
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
156
+ use_labeled_data=args_dict['use_labeled_data'])
157
+ elif args_dict['task'] == 'MRL_case':
158
+ args_dict['data_paths'] = write_extra_labels.MRL_case(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
159
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
160
+ use_labeled_data=args_dict['use_labeled_data'])
161
+ elif args_dict['task'] == 'MRL_POS':
162
+ args_dict['data_paths'] = write_extra_labels.MRL_POS(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
163
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
164
+ use_labeled_data=args_dict['use_labeled_data'])
165
+ elif args_dict['task'] == 'MRL_no':
166
+ args_dict['data_paths'] = write_extra_labels.MRL_no(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
167
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
168
+ use_labeled_data=args_dict['use_labeled_data'])
169
+ elif args_dict['task'] == 'MRL_label':
170
+ args_dict['data_paths'] = write_extra_labels.MRL_label(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
171
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
172
+ use_labeled_data=args_dict['use_labeled_data'])
173
+ else: #args_dict['task'] == 'language_model':
174
+ args_dict['data_paths'] = write_extra_labels.add_language_model(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
175
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
176
+ use_labeled_data=args_dict['use_labeled_data'])
177
+ args_dict['splits'] = args_dict['data_paths'].keys()
178
+ alphabet_data_paths = deepcopy(args_dict['data_paths'])
179
+ if args_dict['dataset'] == 'ontonotes':
180
+ data_path = 'data/onto_pos_ner_dp'
181
+ else:
182
+ data_path = 'data/ud_pos_ner_dp'
183
+ # Adding more resources to make sure equal alphabet size for all domains
184
+ for split in args_dict['splits']:
185
+ if args_dict['dataset'] == 'ontonotes':
186
+ alphabet_data_paths['additional_' + split] = data_path + '_' + split + '_' + 'all'
187
+ else:
188
+ if '_' in args_dict['domain']:
189
+ alphabet_data_paths[split] = data_path + '_' + split + '_' + args_dict['domain'].split('_')[0]
190
+ else:
191
+ alphabet_data_paths[split] = args_dict['data_paths'][split]
192
+ args_dict['alphabet_data_paths'] = alphabet_data_paths
193
+ args_dict['num_epochs'] = args.num_epochs
194
+ args_dict['batch_size'] = args.batch_size
195
+ args_dict['hidden_size'] = args.hidden_size
196
+ args_dict['tag_space'] = args.tag_space
197
+ args_dict['num_layers'] = args.num_layers
198
+ args_dict['num_filters'] = args.num_filters
199
+ args_dict['kernel_size'] = args.kernel_size
200
+ args_dict['learning_rate'] = args.learning_rate
201
+ args_dict['initializer'] = nn.init.xavier_uniform_ if args.initializer == 'xavier' else None
202
+ args_dict['opt'] = args.opt
203
+ args_dict['momentum'] = args.momentum
204
+ args_dict['betas'] = tuple(args.betas)
205
+ args_dict['epsilon'] = args.epsilon
206
+ args_dict['decay_rate'] = args.decay_rate
207
+ args_dict['clip'] = args.clip
208
+ args_dict['gamma'] = args.gamma
209
+ args_dict['schedule'] = args.schedule
210
+ args_dict['p_rnn'] = tuple(args.p_rnn)
211
+ args_dict['p_in'] = args.p_in
212
+ args_dict['p_out'] = args.p_out
213
+ args_dict['unk_replace'] = args.unk_replace
214
+ args_dict['punct_set'] = None
215
+ if args.punct_set is not None:
216
+ args_dict['punct_set'] = set(args.punct_set)
217
+ logger.info("punctuations(%d): %s" % (len(args_dict['punct_set']), ' '.join(args_dict['punct_set'])))
218
+ args_dict['freeze_word_embeddings'] = args.freeze_word_embeddings
219
+ args_dict['word_embedding'] = args.word_embedding
220
+ args_dict['word_path'] = args.word_path
221
+ args_dict['use_char'] = args.use_char
222
+ args_dict['char_embedding'] = args.char_embedding
223
+ args_dict['pos_embedding'] = args.pos_embedding
224
+ args_dict['char_path'] = args.char_path
225
+ args_dict['pos_path'] = args.pos_path
226
+ args_dict['use_pos'] = args.use_pos
227
+ args_dict['pos_dim'] = args.pos_dim
228
+ args_dict['word_dict'] = None
229
+ args_dict['word_dim'] = args.word_dim
230
+ if args_dict['word_embedding'] != 'random' and args_dict['word_path']:
231
+ args_dict['word_dict'], args_dict['word_dim'] = load_word_embeddings.load_embedding_dict(args_dict['word_embedding'],
232
+ args_dict['word_path'])
233
+ args_dict['char_dict'] = None
234
+ args_dict['char_dim'] = args.char_dim
235
+ if args_dict['char_embedding'] != 'random':
236
+ args_dict['char_dict'], args_dict['char_dim'] = load_word_embeddings.load_embedding_dict(args_dict['char_embedding'],
237
+ args_dict['char_path'])
238
+ args_dict['pos_dict'] = None
239
+ if args_dict['pos_embedding'] != 'random':
240
+ args_dict['pos_dict'], args_dict['pos_dim'] = load_word_embeddings.load_embedding_dict(args_dict['pos_embedding'],
241
+ args_dict['pos_path'])
242
+ args_dict['alphabet_path'] = path.join(args_dict['model_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/')
243
+ args_dict['alphabet_parser_path'] = path.join(args_dict['parser_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/')
244
+ args_dict['model_name'] = path.join(args_dict['model_path'], args_dict['model_name'])
245
+ args_dict['eval_mode'] = args.eval_mode
246
+ args_dict['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
247
+ args_dict['word_status'] = 'frozen' if args.freeze_word_embeddings else 'fine tune'
248
+ args_dict['char_status'] = 'enabled' if args.use_char else 'disabled'
249
+ args_dict['pos_status'] = 'enabled' if args.use_pos else 'disabled'
250
+ logger.info("Saving arguments to file")
251
+ save_args(args, args_dict['full_model_name'])
252
+ logger.info("Creating Alphabets")
253
+ alphabet_dict = creating_alphabets(args_dict['alphabet_path'], args_dict['alphabet_parser_path'], args_dict['alphabet_data_paths'])
254
+ args_dict = {**args_dict, **alphabet_dict}
255
+ ARGS = namedtuple('ARGS', args_dict.keys())
256
+ my_args = ARGS(**args_dict)
257
+ return my_args
258
+
259
+
260
+ def creating_alphabets(alphabet_path, alphabet_parser_path, alphabet_data_paths):
261
+ data_paths_list = alphabet_data_paths.values()
262
+ alphabet_dict = {}
263
+ alphabet_dict['alphabets'] = prepare_data.create_alphabets_for_sequence_tagger(alphabet_path, alphabet_parser_path, data_paths_list)
264
+ for k, v in alphabet_dict['alphabets'].items():
265
+ num_key = 'num_' + k.split('_alphabet')[0]
266
+ alphabet_dict[num_key] = v.size()
267
+ logger.info("%s : %d" % (num_key, alphabet_dict[num_key]))
268
+ return alphabet_dict
269
+
270
+ def construct_embedding_table(alphabet, tokens_dict, dim, token_type='word'):
271
+ if tokens_dict is None:
272
+ return None
273
+ scale = np.sqrt(3.0 / dim)
274
+ table = np.empty([alphabet.size(), dim], dtype=np.float32)
275
+ table[prepare_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
276
+ oov_tokens = 0
277
+ for token, index in alphabet.items():
278
+ if token in tokens_dict:
279
+ embedding = tokens_dict[token]
280
+ elif token.lower() in tokens_dict:
281
+ embedding = tokens_dict[token.lower()]
282
+ else:
283
+ embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
284
+ oov_tokens += 1
285
+ table[index, :] = embedding
286
+ print('token type : %s, number of oov: %d' % (token_type, oov_tokens))
287
+ table = torch.from_numpy(table)
288
+ return table
289
+
290
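+ # Select the CUDA device with the most free memory by parsing nvidia-smi output.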
+ def get_free_gpu():
291
+ system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free > tmp.txt')
292
+ memory_available = [int(x.split()[2]) for x in open('tmp.txt', 'r').readlines()]
293
+ remove("tmp.txt")
294
+ free_device = 'cuda:' + str(np.argmax(memory_available))
295
+ return free_device
296
+
297
+ def save_args(args, full_model_name):
298
+ arg_path = full_model_name + '.arg.json'
299
+ argparse_dict = vars(args)
300
+ with open(arg_path, 'w') as f:
301
+ json.dump(argparse_dict, f)
302
+
303
+ def generate_optimizer(args, lr, params):
304
+ params = filter(lambda param: param.requires_grad, params)
305
+ if args.opt == 'adam':
306
+ return Adam(params, lr=lr, betas=args.betas, weight_decay=args.gamma, eps=args.epsilon)
307
+ elif args.opt == 'sgd':
308
+ return SGD(params, lr=lr, momentum=args.momentum, weight_decay=args.gamma, nesterov=True)
309
+ else:
310
+ raise ValueError('Unknown optimization algorithm: %s' % args.opt)
311
+
312
+
313
+ def save_checkpoint(model, optimizer, opt, dev_eval_dict, test_eval_dict, full_model_name):
314
+ path_name = full_model_name + '.pt'
315
+ print('Saving model to: %s' % path_name)
316
+ state = {'model_state_dict': model.state_dict(),
317
+ 'optimizer_state_dict': optimizer.state_dict(),
318
+ 'opt': opt, 'dev_eval_dict': dev_eval_dict, 'test_eval_dict': test_eval_dict}
319
+ torch.save(state, path_name)
320
+
321
+
322
+ def load_checkpoint(args, model, optimizer, dev_eval_dict, test_eval_dict, start_epoch, load_path, strict=True):
323
+ print('Loading saved model from: %s' % load_path)
324
+ checkpoint = torch.load(load_path, map_location=args.device)
325
+ if checkpoint['opt'] != args.opt:
326
+ raise ValueError('loaded optimizer type is: %s instead of: %s' % (checkpoint['opt'], args.opt))
327
+ model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
328
+ if strict:
329
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
330
+ dev_eval_dict = checkpoint['dev_eval_dict']
331
+ test_eval_dict = checkpoint['test_eval_dict']
332
+ start_epoch = dev_eval_dict['in_domain']['epoch']
333
+ return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
334
+
335
+
336
+ def build_model_and_optimizer(args):
337
+ word_table = construct_embedding_table(args.alphabets['word_alphabet'], args.word_dict, args.word_dim, token_type='word')
338
+ char_table = construct_embedding_table(args.alphabets['char_alphabet'], args.char_dict, args.char_dim, token_type='char')
339
+ pos_table = construct_embedding_table(args.alphabets['pos_alphabet'], args.pos_dict, args.pos_dim, token_type='pos')
340
+ model = Sequence_Tagger(args.word_dim, args.num_word, args.char_dim, args.num_char,
341
+ args.use_pos, args.use_char, args.pos_dim, args.num_pos,
342
+ args.num_filters, args.kernel_size, args.rnn_mode,
343
+ args.hidden_size, args.num_layers, args.tag_space, args.num_auto_label,
344
+ embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table,
345
+ p_in=args.p_in, p_out=args.p_out, p_rnn=args.p_rnn,
346
+ initializer=args.initializer)
347
+ optimizer = generate_optimizer(args, args.learning_rate, model.parameters())
348
+ start_epoch = 0
349
+ dev_eval_dict = {'in_domain': initialize_eval_dict()}
350
+ test_eval_dict = {'in_domain': initialize_eval_dict()}
351
+ if args.load_path:
352
+ model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = \
353
+ load_checkpoint(args, model, optimizer,
354
+ dev_eval_dict, test_eval_dict,
355
+ start_epoch, args.load_path, strict=args.strict)
356
+ if args.freeze_word_embeddings:
357
+ model.rnn_encoder.word_embedd.weight.requires_grad = False
358
+ # model.rnn_encoder.char_embedd.weight.requires_grad = False
359
+ # model.rnn_encoder.pos_embedd.weight.requires_grad = False
360
+ device = args.device
361
+ model.to(device)
362
+ return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
363
+
364
+
365
+ def initialize_eval_dict():
366
+ eval_dict = {}
367
+ eval_dict['auto_label_accuracy'] = 0.0
368
+ eval_dict['auto_label_precision'] = 0.0
369
+ eval_dict['auto_label_recall'] = 0.0
370
+ eval_dict['auto_label_f1'] = 0.0
371
+ return eval_dict
372
+
373
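+ # Model selection for the tagger: keep the checkpoint whose dev F1 is at least
+ # the previous best; otherwise increase the patience counter.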
+ def in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch,
374
+ best_model, best_optimizer, patient):
375
+ # In-domain evaluation
376
+ curr_dev_eval_dict = evaluation(args, datasets['dev'], 'dev', model, args.domain, epoch, 'current_results')
377
+ is_best_in_domain = dev_eval_dict['in_domain']['auto_label_f1'] <= curr_dev_eval_dict['auto_label_f1']
378
+
379
+ if is_best_in_domain:
380
+ for key, value in curr_dev_eval_dict.items():
381
+ dev_eval_dict['in_domain'][key] = value
382
+ curr_test_eval_dict = evaluation(args, datasets['test'], 'test', model, args.domain, epoch, 'current_results')
383
+ for key, value in curr_test_eval_dict.items():
384
+ test_eval_dict['in_domain'][key] = value
385
+ best_model = deepcopy(model)
386
+ best_optimizer = deepcopy(optimizer)
387
+ patient = 0
388
+ else:
389
+ patient += 1
390
+ if epoch == args.num_epochs:
391
+ # save in-domain checkpoint
392
+ for split in ['dev', 'test']:
393
+ eval_dict = dev_eval_dict['in_domain'] if split == 'dev' else test_eval_dict['in_domain']
394
+ write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)
395
+ save_checkpoint(best_model, best_optimizer, args.opt, dev_eval_dict, test_eval_dict, args.full_model_name)
396
+
397
+ print('\n')
398
+ return dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient, curr_dev_eval_dict
399
+
400
+
401
+ def evaluation(args, data, split, model, domain, epoch, str_res='results'):
402
+ # evaluate performance on data
403
+ model.eval()
404
+ auto_label_idx2inst = Index2Instance(args.alphabets['auto_label_alphabet'])
405
+ eval_dict = initialize_eval_dict()
406
+ eval_dict['epoch'] = epoch
407
+ pred_labels = []
408
+ gold_labels = []
409
+ for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
410
+ word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
411
+ output, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
412
+ auto_label_preds = model.decode(output, mask=masks, length=lengths, leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
413
+ lengths = lengths.cpu().numpy()
414
+ word = word.data.cpu().numpy()
415
+ pos = pos.data.cpu().numpy()
416
+ ner = ner.data.cpu().numpy()
417
+ heads = heads.data.cpu().numpy()
418
+ arc_tags = arc_tags.data.cpu().numpy()
419
+ auto_label = auto_label.data.cpu().numpy()
420
+ auto_label_preds = auto_label_preds.data.cpu().numpy()
421
+ gold_labels += auto_label_idx2inst.index2instance(auto_label, lengths, symbolic_root=True)
422
+ pred_labels += auto_label_idx2inst.index2instance(auto_label_preds, lengths, symbolic_root=True)
423
+
424
+ eval_dict['auto_label_accuracy'] = accuracy_score(gold_labels, pred_labels) * 100
425
+ eval_dict['auto_label_precision'] = precision_score(gold_labels, pred_labels) * 100
426
+ eval_dict['auto_label_recall'] = recall_score(gold_labels, pred_labels) * 100
427
+ eval_dict['auto_label_f1'] = f1_score(gold_labels, pred_labels) * 100
428
+ eval_dict['classification_report'] = classification_report(gold_labels, pred_labels)
429
+ print_results(eval_dict, split, domain, str_res)
430
+ return eval_dict
431
+
432
+
433
+ def print_results(eval_dict, split, domain, str_res='results'):
434
+ print('----------------------------------------------------------------------------------------------------------------------------')
435
+ print('Testing model on domain %s' % domain)
436
+ print('--------------- sequence_tagger - %s ---------------' % split)
437
+ print(
438
+ str_res + ' on ' + split + ' accuracy: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)'
439
+ % (eval_dict['auto_label_accuracy'], eval_dict['auto_label_precision'], eval_dict['auto_label_recall'], eval_dict['auto_label_f1'],
440
+ eval_dict['epoch']))
441
+ print(eval_dict['classification_report'])
442
+
443
+
444
+ def write_results(args, data, data_domain, split, model, model_domain, eval_dict):
445
+ str_file = args.full_model_name + '_' + split + '_model_domain_' + model_domain + '_data_domain_' + data_domain
446
+ res_filename = str_file + '_res.txt'
447
+ pred_filename = str_file + '_pred.txt'
448
+ gold_filename = str_file + '_gold.txt'
449
+
450
+ # save results dictionary into a file
451
+ with open(res_filename, 'w') as f:
452
+ json.dump(eval_dict, f)
453
+
454
+ # save predictions and gold labels into files
455
+ pred_writer = Writer(args.alphabets)
456
+ gold_writer = Writer(args.alphabets)
457
+ pred_writer.start(pred_filename)
458
+ gold_writer.start(gold_filename)
459
+ for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
460
+ word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
461
+ output, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
462
+ auto_label_preds = model.decode(output, mask=masks, length=lengths,
463
+ leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
464
+ lengths = lengths.cpu().numpy()
465
+ word = word.data.cpu().numpy()
466
+ pos = pos.data.cpu().numpy()
467
+ ner = ner.data.cpu().numpy()
468
+ heads = heads.data.cpu().numpy()
469
+ arc_tags = arc_tags.data.cpu().numpy()
470
+ auto_label_preds = auto_label_preds.data.cpu().numpy()
471
+ # writing predictions
472
+ pred_writer.write(word, pos, ner, heads, arc_tags, lengths, auto_label=auto_label_preds, symbolic_root=True)
473
+ # writing gold labels
474
+ gold_writer.write(word, pos, ner, heads, arc_tags, lengths, auto_label=auto_label, symbolic_root=True)
475
+
476
+ pred_writer.close()
477
+ gold_writer.close()
478
+
479
+ def main():
480
+ logger.info("Reading and creating arguments")
481
+ args = read_arguments()
482
+ logger.info("Reading Data")
483
+ datasets = {}
484
+ for split in args.splits:
485
+ dataset = prepare_data.read_data_to_variable(args.data_paths[split], args.alphabets, args.device,
486
+ symbolic_root=True)
487
+ datasets[split] = dataset
488
+
489
+ logger.info("Creating Networks")
490
+ num_data = sum(datasets['train'][1])
491
+ model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = build_model_and_optimizer(args)
492
+ best_model = deepcopy(model)
493
+ best_optimizer = deepcopy(optimizer)
494
+ logger.info('Training INFO of in domain %s' % args.domain)
495
+ logger.info('Training on Dependency Parsing')
496
+ print(model)
497
+ logger.info("train: gamma: %f, batch: %d, clip: %.2f, unk replace: %.2f" % (args.gamma, args.batch_size, args.clip, args.unk_replace))
498
+ logger.info('number of training samples for %s is: %d' % (args.domain, num_data))
499
+ logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (args.p_in, args.p_out, args.p_rnn))
500
+ logger.info("num_epochs: %d" % (args.num_epochs))
501
+ print('\n')
502
+
503
+ if not args.eval_mode:
504
+ logger.info("Training")
505
+ num_batches = prepare_data.calc_num_batches(datasets['train'], args.batch_size)
506
+ lr = args.learning_rate
507
+ patient = 0
508
+ terminal_patient = 0
509
+ decay = 0
510
+ for epoch in range(start_epoch + 1, args.num_epochs + 1):
511
+ print('Epoch %d (Training: rnn mode: %s, optimizer: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, decay=%d)): ' % (
512
+ epoch, args.rnn_mode, args.opt, lr, args.epsilon, args.decay_rate, args.schedule, decay))
513
+ model.train()
514
+ total_loss = 0.0
515
+ total_train_inst = 0.0
516
+
517
+ train_iter = prepare_data.iterate_batch_rand_bucket_choosing(
518
+ datasets['train'], args.batch_size, args.device, unk_replace=args.unk_replace)
519
+ start_time = time.time()
520
+ batch_num = 0
521
+ for batch_num, batch in enumerate(train_iter):
522
+ batch_num = batch_num + 1
523
+ optimizer.zero_grad()
524
+ # compute loss of main task
525
+ word, char, pos, ner_tags, heads, arc_tags, auto_label, masks, lengths = batch
526
+ output, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
527
+ loss = model.loss(output, auto_label, mask=masks, length=lengths)
528
+
529
+ # update losses
530
+ num_insts = masks.data.sum() - word.size(0)
531
+ total_loss += loss.item() * num_insts
532
+ total_train_inst += num_insts
533
+ # optimize parameters
534
+ loss.backward()
535
+ clip_grad_norm_(model.parameters(), args.clip)
536
+ optimizer.step()
537
+
538
+ time_ave = (time.time() - start_time) / batch_num
539
+ time_left = (num_batches - batch_num) * time_ave
540
+
541
+ # update log
542
+ if batch_num % 50 == 0:
543
+ log_info = 'train: %d/%d, domain: %s, total loss: %.2f, time left: %.2fs' % \
544
+ (batch_num, num_batches, args.domain, total_loss / total_train_inst, time_left)
545
+ sys.stdout.write(log_info)
546
+ sys.stdout.write('\n')
547
+ sys.stdout.flush()
548
+ print('\n')
549
+ print('train: %d/%d, domain: %s, total_loss: %.2f, time: %.2fs' %
550
+ (batch_num, num_batches, args.domain, total_loss / total_train_inst, time.time() - start_time))
551
+
552
+ dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient, curr_dev_eval_dict = in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, best_model, best_optimizer, patient)
553
+ store = {'dev_eval_dict': curr_dev_eval_dict}
554
+ #############################################
555
+ str_file = args.full_model_name + '_' + 'all_epochs'
556
+ with open(str_file, 'a') as f:
557
+ f.write(str(store) + '\n')
558
+ if patient == 0:
559
+ terminal_patient = 0
560
+ else:
561
+ terminal_patient += 1
562
+ if terminal_patient >= 4 * args.schedule:
563
+ # Save best model and terminate learning
564
+ cur_epoch = epoch
565
+ epoch = args.num_epochs
566
+ in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, best_model,
567
+ best_optimizer, patient)
568
+ log_info = 'Terminating training in epoch %d' % (cur_epoch)
569
+ sys.stdout.write(log_info)
570
+ sys.stdout.write('\n')
571
+ sys.stdout.flush()
572
+ return
573
+ if patient >= args.schedule:
574
+ lr = args.learning_rate / (1.0 + epoch * args.decay_rate)
575
+ optimizer = generate_optimizer(args, lr, model.parameters())
576
+ print('updated learning rate to %.6f' % lr)
577
+ patient = 0
578
+ print_results(test_eval_dict['in_domain'], 'test', args.domain, 'best_results')
579
+ print('\n')
580
+
581
+ else:
582
+ logger.info("Evaluating")
583
+ epoch = start_epoch
584
+ for split in ['train', 'dev', 'test']:
585
+ evaluation(args, datasets[split], split, model, args.domain, epoch, 'best_results')
586
+
587
+
588
+ if __name__ == '__main__':
589
+ main()
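For readers skimming the training loop above: the schedule is patience-driven, with the learning rate decayed after `schedule` epochs without a dev improvement and training terminated after `4 * schedule` such epochs. Below is a small self-contained sketch of that logic; the hyperparameter values are made-up stand-ins for the script's `args.*` fields.

```python
# Minimal sketch of the patience / LR-decay schedule in the loop above.
# Hyperparameter values below are hypothetical stand-ins for args.*.
learning_rate, decay_rate, schedule, num_epochs = 0.002, 0.05, 5, 40

lr, patient, terminal_patient = learning_rate, 0, 0
for epoch in range(1, num_epochs + 1):
    improved = epoch <= 3                 # stand-in for "dev score improved this epoch"
    patient = 0 if improved else patient + 1
    terminal_patient = 0 if improved else terminal_patient + 1
    if terminal_patient >= 4 * schedule:  # no improvement for 4*schedule epochs: stop
        print('Terminating training in epoch %d' % epoch)
        break
    if patient >= schedule:               # no improvement for `schedule` epochs: decay LR
        lr = learning_rate / (1.0 + epoch * decay_rate)
        patient = 0
        print('updated learning rate to %.6f' % lr)
```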
examples/VST_Pred_Prepare.py ADDED
@@ -0,0 +1,34 @@
1
+ import sys
2
+
3
+ def write_combined(dirs):
4
+ path = "./saved_models/"+dirs+'/final_ensembled_TranSeq/'
5
+ f = open(path+'domain_VST_test_model_domain_VST_data_domain_VST_gold.txt','r')
6
+ gold = f.readlines()
7
+ f.close()
8
+ f = open(path+'domain_VST_test_model_domain_VST_data_domain_VST_pred.txt','r')
9
+ pred = f.readlines()
10
+ f.close()
11
+
12
+ for i in range(len(gold)):
13
+ if gold[i] == '\n':
14
+ continue
15
+ if gold[i].split('\t')[0] == pred[i].split('\t')[0]:
16
+ gold[i] = gold[i].replace('\n','\t')
17
+ gold[i] = gold[i]+'\t'.join(pred[i].split('\t')[-2:])
18
+
19
+
20
+ gold.insert(0,'word_id\tword\tpostag\tlemma\tgold_head\tgold_label\tpred_head\tpred_label\n\n')
21
+
22
+
23
+ f = open(path+'VST_test.txt','w')
24
+ for line in gold:
25
+ f.write(line)
26
+ f.close()
27
+
28
+
29
+ if __name__=="__main__":
30
+
31
+ dir_path = sys.argv[1]
32
+
33
+ write_combined(dir_path)
34
+
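VST_Pred_Prepare.py above simply appends the last two tab-separated fields of each prediction line (predicted head and label) to the matching gold line and prepends a header. A self-contained toy illustration of that merge follows; the rows are invented, not repository data.

```python
# Toy illustration of the merge performed by write_combined() above.
# The gold/pred rows are invented; the real script reads *_gold.txt / *_pred.txt files.
gold = [
    "1\tramah\tNOUN\trama\t2\tkarta\n",
    "2\tgacchati\tVERB\tgam\t0\troot\n",
    "\n",
]
pred = [
    "1\tramah\tNOUN\trama\t2\tkarta\n",
    "2\tgacchati\tVERB\tgam\t0\tkarma\n",
    "\n",
]

merged = []
for g, p in zip(gold, pred):
    if g == "\n":                               # keep sentence boundaries as-is
        merged.append(g)
        continue
    if g.split("\t")[0] == p.split("\t")[0]:    # same token id in both files
        # append the predicted head and label as two extra columns
        merged.append(g.rstrip("\n") + "\t" + "\t".join(p.split("\t")[-2:]))

header = "word_id\tword\tpostag\tlemma\tgold_head\tgold_label\tpred_head\tpred_label\n\n"
print(header + "".join(merged))
```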
examples/VST_macro_score.py ADDED
@@ -0,0 +1,107 @@
1
+ import sys
2
+
3
+
4
+ def load_results(filename):
5
+
6
+ results = []
7
+ sent = []
8
+ with open(filename, 'r') as fp:
9
+ for i, line in enumerate(fp):
10
+ if i == 0:
11
+ continue
12
+ splits = line.strip().split('\t')
13
+ if len(line.strip()) == 0:
14
+ if len(sent) != 0:
15
+ results.append(sent)
16
+ sent = []
17
+ continue
18
+ gold_head = splits[-4]
19
+ gold_label = splits[-3]
20
+ pred_head = splits[-2]
21
+ pred_label = splits[-1]
22
+ sent.append((gold_head, gold_label, pred_head, pred_label))
23
+ print('Total Number of sentences ' + str(len(results)))
24
+ return results
25
+
26
+ def calculate_las_uas(gold_heads, gold_labels, pred_heads, pred_labels):
27
+
28
+ u_correct = 0
29
+ l_correct = 0
30
+ u_total = 0
31
+ l_total = 0
32
+
33
+ for i in range(len(gold_heads)):
34
+ if gold_heads[i] == pred_heads[i]:
35
+ u_correct +=1
36
+ u_total +=1
37
+ l_total +=1
38
+ if gold_heads[i] == pred_heads[i] and gold_labels[i] == pred_labels[i]:
39
+ l_correct +=1
40
+ return u_correct, u_total, l_correct, l_total
41
+
42
+
43
+ def calculate_stats(results,path):
44
+ u_correct = 0
45
+ l_correct = 0
46
+ u_total = 0
47
+ l_total = 0
48
+
49
+ sent_uas = []
50
+ sent_las = []
51
+
52
+ for i in range(len(results)):
53
+ gold_heads, gold_labels, pred_heads, pred_labels = zip(*results[i])
54
+ u_c, u_t, l_c, l_t = calculate_las_uas(gold_heads, gold_labels, pred_heads, pred_labels)
55
+ if u_t >0:
56
+ uas = float(u_c)/u_t
57
+ las = float(l_c)/l_t
58
+ sent_uas.append(uas)
59
+ sent_las.append(las)
60
+ u_correct += u_c
61
+ l_correct += l_c
62
+ u_total += u_t
63
+ l_total += l_t
64
+
65
+ UAS = float(u_correct)/u_total
66
+ LAS = float(l_correct)/l_total
67
+ path = path.replace('VST_test.txt','Macro-UAS-LAS-score.txt')
68
+ f = open(path,'w')
69
+ f.write('Word level UAS : ' + str(UAS) +'\n')
70
+ f.write('Word level LAS : ' + str(LAS)+'\n')
71
+ f.write('Sentence level UAS : ' + str(float(sum(sent_uas))/len(sent_uas))+'\n')
72
+ f.write('Sentence level LAS : ' + str(float(sum(sent_las))/len(sent_las))+'\n')
73
+ f.close()
74
+ print('Word level UAS : ' + str(UAS))
75
+ print('Word level LAS : ' + str(LAS))
76
+ print('Sentence level UAS : ' + str(float(sum(sent_uas))/len(sent_uas)))
77
+ print('Sentence level LAS : ' + str(float(sum(sent_las))/len(sent_las)))
78
+
79
+ return sent_uas, sent_las, UAS, LAS
80
+
81
+ def write_results(sent_uas, sent_las, filename_uas, filename_las):
82
+
83
+ fp_uas = open(filename_uas, 'w')
84
+ fp_las = open(filename_las, 'w')
85
+
86
+ for i in range(len(sent_uas)):
87
+ fp_uas.write(str(sent_uas[i]) + '\n')
88
+ fp_las.write(str(sent_las[i]) + '\n')
89
+
90
+ fp_uas.close()
91
+ fp_las.close()
92
+
93
+
94
+ if __name__=="__main__":
95
+ dirs = sys.argv[1]
96
+ # results_2 = load_results(sys.argv[2])
97
+ ##path = "Predictions/Yap/"+dirs
98
+ path = "./saved_models/"+dirs+"/final_ensembled_TranSeq/VST_test.txt"
99
+ result = load_results(path)
100
+
101
+
102
+ sent_uas1, sent_las1, UAS1, LAS1 = calculate_stats(result,path)
103
+ # sent_uas2, sent_las2, UAS2, LAS2 = calculate_stats(results_2)
104
+
105
+
106
+ write_results(sent_uas1, sent_las1, 'results1_uas.txt', 'results1_las.txt')
107
+ # write_results(sent_uas2, sent_las2, 'results2_uas.txt', 'results2_las.txt')
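VST_macro_score.py reports both word-level (micro-averaged) and sentence-level (macro-averaged) UAS/LAS. A tiny self-contained example of the difference is sketched below, using two invented sentences; each tuple is (gold_head, gold_label, pred_head, pred_label) for one token.

```python
# Toy micro vs. macro UAS/LAS, mirroring calculate_stats() above; data is invented.
sentences = [
    [("2", "karta", "2", "karta"), ("0", "root", "0", "root")],
    [("3", "karma", "3", "karma"), ("3", "karta", "1", "karta"), ("0", "root", "0", "root")],
]

sent_uas, sent_las = [], []
u_correct = l_correct = total = 0
for sent in sentences:
    u = sum(gh == ph for gh, _, ph, _ in sent)                  # correct heads
    l = sum(gh == ph and gl == pl for gh, gl, ph, pl in sent)   # correct heads + labels
    n = len(sent)
    sent_uas.append(u / n)
    sent_las.append(l / n)
    u_correct, l_correct, total = u_correct + u, l_correct + l, total + n

print("Word level UAS     :", u_correct / total)                 # micro average
print("Word level LAS     :", l_correct / total)
print("Sentence level UAS :", sum(sent_uas) / len(sent_uas))     # macro average
print("Sentence level LAS :", sum(sent_las) / len(sent_las))
```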
examples/eval/conll03eval.v2 ADDED
@@ -0,0 +1,336 @@
1
+ #!/usr/bin/perl -w
2
+ # conlleval: evaluate result of processing CoNLL-2000 shared task
3
+ # usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file
4
+ # README: http://cnts.uia.ac.be/conll2000/chunking/output.html
5
+ # options: l: generate LaTeX output for tables like in
6
+ # http://cnts.uia.ac.be/conll2003/ner/example.tex
7
+ # r: accept raw result tags (without B- and I- prefix;
8
+ # assumes one word per chunk)
9
+ # d: alternative delimiter tag (default is single space)
10
+ # o: alternative outside tag (default is O)
11
+ # note: the file should contain lines with items separated
12
+ # by $delimiter characters (default space). The final
13
+ # two items should contain the correct tag and the
14
+ # guessed tag in that order. Sentences should be
15
+ # separated from each other by empty lines or lines
16
+ # with $boundary fields (default -X-).
17
+ # url: http://lcg-www.uia.ac.be/conll2000/chunking/
18
+ # started: 1998-09-25
19
+ # version: 2004-01-26
20
+ # author: Erik Tjong Kim Sang <erikt@uia.ua.ac.be>
21
+
22
+ use strict;
23
+
24
+ my $false = 0;
25
+ my $true = 42;
26
+
27
+ my $boundary = "-X-"; # sentence boundary
28
+ my $correct; # current corpus chunk tag (I,O,B)
29
+ my $correctChunk = 0; # number of correctly identified chunks
30
+ my $correctTags = 0; # number of correct chunk tags
31
+ my $correctType; # type of current corpus chunk tag (NP,VP,etc.)
32
+ my $delimiter = " "; # field delimiter
33
+ my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979)
34
+ my $firstItem; # first feature (for sentence boundary checks)
35
+ my $foundCorrect = 0; # number of chunks in corpus
36
+ my $foundGuessed = 0; # number of identified chunks
37
+ my $guessed; # current guessed chunk tag
38
+ my $guessedType; # type of current guessed chunk tag
39
+ my $i; # miscellaneous counter
40
+ my $inCorrect = $false; # currently processed chunk is correct until now
41
+ my $lastCorrect = "O"; # previous chunk tag in corpus
42
+ my $latex = 0; # generate LaTeX formatted output
43
+ my $lastCorrectType = ""; # type of previously identified chunk tag
44
+ my $lastGuessed = "O"; # previously identified chunk tag
45
+ my $lastGuessedType = ""; # type of previous chunk tag in corpus
46
+ my $lastType; # temporary storage for detecting duplicates
47
+ my $line; # line
48
+ my $nbrOfFeatures = -1; # number of features per line
49
+ my $precision = 0.0; # precision score
50
+ my $oTag = "O"; # outside tag, default O
51
+ my $raw = 0; # raw input: add B to every token
52
+ my $recall = 0.0; # recall score
53
+ my $tokenCounter = 0; # token counter (ignores sentence breaks)
54
+
55
+ my %correctChunk = (); # number of correctly identified chunks per type
56
+ my %foundCorrect = (); # number of chunks in corpus per type
57
+ my %foundGuessed = (); # number of identified chunks per type
58
+
59
+ my @features; # features on line
60
+ my @sortedTypes; # sorted list of chunk type names
61
+
62
+ # sanity check
63
+ while (@ARGV and $ARGV[0] =~ /^-/) {
64
+ if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); }
65
+ elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); }
66
+ elsif ($ARGV[0] eq "-d") {
67
+ shift(@ARGV);
68
+ if (not defined $ARGV[0]) {
69
+ die "conlleval: -d requires delimiter character";
70
+ }
71
+ $delimiter = shift(@ARGV);
72
+ } elsif ($ARGV[0] eq "-o") {
73
+ shift(@ARGV);
74
+ if (not defined $ARGV[0]) {
75
+ die "conlleval: -o requires delimiter character";
76
+ }
77
+ $oTag = shift(@ARGV);
78
+ } else { die "conlleval: unknown argument $ARGV[0]\n"; }
79
+ }
80
+ if (@ARGV) { die "conlleval: unexpected command line argument\n"; }
81
+ # process input
82
+ while (<STDIN>) {
83
+ chomp($line = $_);
84
+ @features = split(/$delimiter/,$line);
85
+ if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; }
86
+ elsif ($nbrOfFeatures != $#features and @features != 0) {
87
+ printf STDERR "unexpected number of features: %d (%d)\n",
88
+ $#features+1,$nbrOfFeatures+1;
89
+ exit(1);
90
+ }
91
+ if (@features == 0 or
92
+ $features[0] eq $boundary) { @features = ($boundary,"O","O"); }
93
+ if (@features < 2) {
94
+ die "conlleval: unexpected number of features in line $line\n";
95
+ }
96
+ if ($raw) {
97
+ if ($features[$#features] eq $oTag) { $features[$#features] = "O"; }
98
+ if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; }
99
+ if ($features[$#features] ne "O") {
100
+ $features[$#features] = "B-$features[$#features]";
101
+ }
102
+ if ($features[$#features-1] ne "O") {
103
+ $features[$#features-1] = "B-$features[$#features-1]";
104
+ }
105
+ }
106
+ # 20040126 ET code which allows hyphens in the types
107
+ if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
108
+ $guessed = $1;
109
+ $guessedType = $2;
110
+ } else {
111
+ $guessed = $features[$#features];
112
+ $guessedType = "";
113
+ }
114
+ pop(@features);
115
+ if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
116
+ $correct = $1;
117
+ $correctType = $2;
118
+ } else {
119
+ $correct = $features[$#features];
120
+ $correctType = "";
121
+ }
122
+ pop(@features);
123
+ # ($guessed,$guessedType) = split(/-/,pop(@features));
124
+ # ($correct,$correctType) = split(/-/,pop(@features));
125
+ $guessedType = $guessedType ? $guessedType : "";
126
+ $correctType = $correctType ? $correctType : "";
127
+ $firstItem = shift(@features);
128
+
129
+ # 1999-06-26 sentence breaks should always be counted as out of chunk
130
+ if ( $firstItem eq $boundary ) { $guessed = "O"; }
131
+
132
+ if ($inCorrect) {
133
+ if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
134
+ &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
135
+ $lastGuessedType eq $lastCorrectType) {
136
+ $inCorrect=$false;
137
+ $correctChunk++;
138
+ $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
139
+ $correctChunk{$lastCorrectType}+1 : 1;
140
+ } elsif (
141
+ &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) !=
142
+ &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or
143
+ $guessedType ne $correctType ) {
144
+ $inCorrect=$false;
145
+ }
146
+ }
147
+
148
+ if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
149
+ &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
150
+ $guessedType eq $correctType) { $inCorrect = $true; }
151
+
152
+ if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) {
153
+ $foundCorrect++;
154
+ $foundCorrect{$correctType} = $foundCorrect{$correctType} ?
155
+ $foundCorrect{$correctType}+1 : 1;
156
+ }
157
+ if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) {
158
+ $foundGuessed++;
159
+ $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ?
160
+ $foundGuessed{$guessedType}+1 : 1;
161
+ }
162
+ if ( $firstItem ne $boundary ) {
163
+ if ( $correct eq $guessed and $guessedType eq $correctType ) {
164
+ $correctTags++;
165
+ }
166
+ $tokenCounter++;
167
+ }
168
+
169
+ $lastGuessed = $guessed;
170
+ $lastCorrect = $correct;
171
+ $lastGuessedType = $guessedType;
172
+ $lastCorrectType = $correctType;
173
+ }
174
+ if ($inCorrect) {
175
+ $correctChunk++;
176
+ $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
177
+ $correctChunk{$lastCorrectType}+1 : 1;
178
+ }
179
+
180
+ if (not $latex) {
181
+ # compute overall precision, recall and FB1 (default values are 0.0)
182
+ $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
183
+ $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
184
+ $FB1 = 2*$precision*$recall/($precision+$recall)
185
+ if ($precision+$recall > 0);
186
+
187
+ # print overall performance
188
+ printf "processed $tokenCounter tokens with $foundCorrect phrases; ";
189
+ printf "found: $foundGuessed phrases; correct: $correctChunk.\n";
190
+ if ($tokenCounter>0) {
191
+ printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter;
192
+ printf "precision: %6.2f%%; ",$precision;
193
+ printf "recall: %6.2f%%; ",$recall;
194
+ printf "FB1: %6.2f\n",$FB1;
195
+ }
196
+ }
197
+
198
+ # sort chunk type names
199
+ undef($lastType);
200
+ @sortedTypes = ();
201
+ foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) {
202
+ if (not($lastType) or $lastType ne $i) {
203
+ push(@sortedTypes,($i));
204
+ }
205
+ $lastType = $i;
206
+ }
207
+ # print performance per chunk type
208
+ if (not $latex) {
209
+ for $i (@sortedTypes) {
210
+ $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
211
+ if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; }
212
+ else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
213
+ if (not($foundCorrect{$i})) { $recall = 0.0; }
214
+ else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
215
+ if ($precision+$recall == 0.0) { $FB1 = 0.0; }
216
+ else { $FB1 = 2*$precision*$recall/($precision+$recall); }
217
+ printf "%17s: ",$i;
218
+ printf "precision: %6.2f%%; ",$precision;
219
+ printf "recall: %6.2f%%; ",$recall;
220
+ printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i};
221
+ }
222
+ } else {
223
+ print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline";
224
+ for $i (@sortedTypes) {
225
+ $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
226
+ if (not($foundGuessed{$i})) { $precision = 0.0; }
227
+ else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
228
+ if (not($foundCorrect{$i})) { $recall = 0.0; }
229
+ else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
230
+ if ($precision+$recall == 0.0) { $FB1 = 0.0; }
231
+ else { $FB1 = 2*$precision*$recall/($precision+$recall); }
232
+ printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\",
233
+ $i,$precision,$recall,$FB1;
234
+ }
235
+ print "\\hline\n";
236
+ $precision = 0.0;
237
+ $recall = 0;
238
+ $FB1 = 0.0;
239
+ $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
240
+ $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
241
+ $FB1 = 2*$precision*$recall/($precision+$recall)
242
+ if ($precision+$recall > 0);
243
+ printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n",
244
+ $precision,$recall,$FB1;
245
+ }
246
+
247
+ exit 0;
248
+
249
+ # endOfChunk: checks if a chunk ended between the previous and current word
250
+ # arguments: previous and current chunk tags, previous and current types
251
+ # note: this code is capable of handling other chunk representations
252
+ # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
253
+ # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
254
+
255
+ sub endOfChunk {
256
+ my $prevTag = shift(@_);
257
+ my $tag = shift(@_);
258
+ my $prevType = shift(@_);
259
+ my $type = shift(@_);
260
+ my $chunkEnd = $false;
261
+
262
+ if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; }
263
+ if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; }
264
+ if ( $prevTag eq "B" and $tag eq "S" ) { $chunkEnd = $true; }
265
+
266
+ if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; }
267
+ if ( $prevTag eq "I" and $tag eq "S" ) { $chunkEnd = $true; }
268
+ if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; }
269
+
270
+ if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; }
271
+ if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; }
272
+ if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; }
273
+ if ( $prevTag eq "E" and $tag eq "S" ) { $chunkEnd = $true; }
274
+ if ( $prevTag eq "E" and $tag eq "B" ) { $chunkEnd = $true; }
275
+
276
+ if ( $prevTag eq "S" and $tag eq "E" ) { $chunkEnd = $true; }
277
+ if ( $prevTag eq "S" and $tag eq "I" ) { $chunkEnd = $true; }
278
+ if ( $prevTag eq "S" and $tag eq "O" ) { $chunkEnd = $true; }
279
+ if ( $prevTag eq "S" and $tag eq "S" ) { $chunkEnd = $true; }
280
+ if ( $prevTag eq "S" and $tag eq "B" ) { $chunkEnd = $true; }
281
+
282
+
283
+ if ($prevTag ne "O" and $prevTag ne "." and $prevType ne $type) {
284
+ $chunkEnd = $true;
285
+ }
286
+
287
+ # corrected 1998-12-22: these chunks are assumed to have length 1
288
+ if ( $prevTag eq "]" ) { $chunkEnd = $true; }
289
+ if ( $prevTag eq "[" ) { $chunkEnd = $true; }
290
+
291
+ return($chunkEnd);
292
+ }
293
+
294
+ # startOfChunk: checks if a chunk started between the previous and current word
295
+ # arguments: previous and current chunk tags, previous and current types
296
+ # note: this code is capable of handling other chunk representations
297
+ # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
298
+ # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
299
+
300
+ sub startOfChunk {
301
+ my $prevTag = shift(@_);
302
+ my $tag = shift(@_);
303
+ my $prevType = shift(@_);
304
+ my $type = shift(@_);
305
+ my $chunkStart = $false;
306
+
307
+ if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; }
308
+ if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; }
309
+ if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; }
310
+ if ( $prevTag eq "S" and $tag eq "B" ) { $chunkStart = $true; }
311
+ if ( $prevTag eq "E" and $tag eq "B" ) { $chunkStart = $true; }
312
+
313
+ if ( $prevTag eq "B" and $tag eq "S" ) { $chunkStart = $true; }
314
+ if ( $prevTag eq "I" and $tag eq "S" ) { $chunkStart = $true; }
315
+ if ( $prevTag eq "O" and $tag eq "S" ) { $chunkStart = $true; }
316
+ if ( $prevTag eq "S" and $tag eq "S" ) { $chunkStart = $true; }
317
+ if ( $prevTag eq "E" and $tag eq "S" ) { $chunkStart = $true; }
318
+
319
+ if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; }
320
+ if ( $prevTag eq "S" and $tag eq "I" ) { $chunkStart = $true; }
321
+ if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; }
322
+
323
+ if ( $prevTag eq "S" and $tag eq "E" ) { $chunkStart = $true; }
324
+ if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; }
325
+ if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; }
326
+
327
+ if ($tag ne "O" and $tag ne "." and $prevType ne $type) {
328
+ $chunkStart = $true;
329
+ }
330
+
331
+ # corrected 1998-12-22: these chunks are assumed to have length 1
332
+ if ( $tag eq "[" ) { $chunkStart = $true; }
333
+ if ( $tag eq "]" ) { $chunkStart = $true; }
334
+
335
+ return($chunkStart);
336
+ }
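The header comment of conll03eval.v2 describes its input: one token per line, whitespace-separated fields, the gold tag and the guessed tag as the final two items, and blank lines between sentences. A minimal sketch of writing such a file from Python (file name, tokens, and tags below are invented) and the corresponding call:

```python
# Write a tiny file in the format conll03eval.v2 expects
# (last two columns = gold tag, predicted tag); names below are invented.
rows = [
    ("ramah",    "NOUN", "B-PER", "B-PER"),
    ("ayodhyam", "NOUN", "B-LOC", "O"),
    ("gacchati", "VERB", "O",     "O"),
]
with open("ner_eval_input.txt", "w") as fp:
    for word, pos, gold, pred in rows:
        fp.write(" ".join((word, pos, gold, pred)) + "\n")
    fp.write("\n")   # sentence boundary

# then, from the shell:
#   perl examples/eval/conll03eval.v2 < ner_eval_input.txt
```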
examples/eval/conll06eval.pl ADDED
@@ -0,0 +1,1826 @@
1
+ #!/usr/bin/env perl
2
+
3
+ # Author: Yuval Krymolowski
4
+ # Addition of precision and recall
5
+ # and of frame confusion list: Sabine Buchholz
6
+ # Addition of DEPREL + ATTACHMENT:
7
+ # Prokopis Prokopidis (prokopis at ilsp dot gr)
8
+ # Acknowledgements:
9
+ # to Markus Kuhn for suggesting the use of
10
+ # the Unicode category property
11
+
12
+ if ($] < 5.008001)
13
+ {
14
+ printf STDERR <<EOM
15
+
16
+ This script requires PERL 5.8.1 for running.
17
+ The new version is needed for proper handling
18
+ of Unicode characters.
19
+
20
+ Please obtain a new version or contact the shared task team
21
+ if you are unable to upgrade PERL.
22
+
23
+ EOM
24
+ ;
25
+ exit(1) ;
26
+ }
27
+
28
+ require Encode;
29
+
30
+ use strict ;
31
+ use warnings;
32
+ use Getopt::Std ;
33
+
34
+ my ($usage) = <<EOT
35
+
36
+ CoNLL-X evaluation script:
37
+
38
+ [perl] eval.pl [OPTIONS] -g <gold standard> -s <system output>
39
+
40
+ This script evaluates a system output with respect to a gold standard.
41
+ Both files should be in UTF-8 encoded CoNLL-X tabular format.
42
+
43
+ Punctuation tokens (those where all characters have the Unicode
44
+ category property "Punctuation") are ignored for scoring (unless the
45
+ -p flag is used).
46
+
47
+ The output breaks down the errors according to their type and context.
48
+
49
+ Optional parameters:
50
+ -o FILE : output: print output to FILE (default is standard output)
51
+ -q : quiet: only print overall performance, without the details
52
+ -b : evalb: produce output in a format similar to evalb
53
+ (http://nlp.cs.nyu.edu/evalb/); use together with -q
54
+ -p : punctuation: also score on punctuation (default is not to score on it)
55
+ -v : version: show the version number
56
+ -h : help: print this help text and exit
57
+
58
+ EOT
59
+ ;
60
+
61
+ my ($line_num) ;
62
+ my ($sep) = '0x01' ;
63
+
64
+ my ($START) = '.S' ;
65
+ my ($END) = '.E' ;
66
+
67
+ my ($con_err_num) = 3 ;
68
+ my ($freq_err_num) = 10 ;
69
+ my ($spec_err_loc_con) = 8 ;
70
+
71
+ ################################################################################
72
+ ### subfunctions ###
73
+ ################################################################################
74
+
75
+ # Whether a string consists entirely of characters with the Unicode
76
+ # category property "Punctuation" (see "man perlunicode")
77
+ sub is_uni_punct
78
+ {
79
+ my ($word) = @_ ;
80
+
81
+ return scalar(Encode::decode_utf8($word)=~ /^\p{Punctuation}+$/) ;
82
+ }
83
+
84
+ # The length of a unicode string, excluding non-spacing marks
85
+ # (for example vowel marks in Arabic)
86
+
87
+ sub uni_len
88
+ {
89
+ my ($word) = @_ ;
90
+ my ($ch, $l) ;
91
+
92
+ $l = 0 ;
93
+ foreach $ch (split(//, Encode::decode_utf8($word)))
94
+ {
95
+ if ($ch !~ /^\p{NonspacingMark}/)
96
+ {
97
+ $l++ ;
98
+ }
99
+ }
100
+
101
+ return $l ;
102
+ }
103
+
104
+ sub filter_context_counts
105
+ { # filter_context_counts
106
+
107
+ my ($vec, $num, $max_len) = @_ ;
108
+ my ($con, $l, $thresh) ;
109
+
110
+ $thresh = (sort {$b <=> $a} values %{$vec})[$num-1] ;
111
+
112
+ foreach $con (keys %{$vec})
113
+ {
114
+ if (${$vec}{$con} < $thresh)
115
+ {
116
+ delete ${$vec}{$con} ;
117
+ next ;
118
+ }
119
+
120
+ $l = uni_len($con) ;
121
+
122
+ if ($l > ${$max_len})
123
+ {
124
+ ${$max_len} = $l ;
125
+ }
126
+ }
127
+
128
+ } # filter_context_counts
129
+
130
+ sub print_context
131
+ { # print_context
132
+
133
+ my ($counts, $counts_pos, $max_con_len, $max_con_pos_len) = @_ ;
134
+ my (@v_con, @v_con_pos, $con, $con_pos, $i, $n) ;
135
+
136
+ printf OUT " %-*s | %-4s | %-4s | %-4s | %-4s", $max_con_pos_len, 'CPOS', 'any', 'head', 'dep', 'both' ;
137
+ printf OUT " ||" ;
138
+ printf OUT " %-*s | %-4s | %-4s | %-4s | %-4s", $max_con_len, 'word', 'any', 'head', 'dep', 'both' ;
139
+ printf OUT "\n" ;
140
+ printf OUT " %s-+------+------+------+-----", '-' x $max_con_pos_len;
141
+ printf OUT "--++" ;
142
+ printf OUT "--%s-+------+------+------+-----", '-' x $max_con_len;
143
+ printf OUT "\n" ;
144
+
145
+ @v_con = sort {${$counts}{tot}{$b} <=> ${$counts}{tot}{$a}} keys %{${$counts}{tot}} ;
146
+ @v_con_pos = sort {${$counts_pos}{tot}{$b} <=> ${$counts_pos}{tot}{$a}} keys %{${$counts_pos}{tot}} ;
147
+
148
+ $n = scalar @v_con ;
149
+ if (scalar @v_con_pos > $n)
150
+ {
151
+ $n = scalar @v_con_pos ;
152
+ }
153
+
154
+ foreach $i (0 .. $n-1)
155
+ {
156
+ if (defined $v_con_pos[$i])
157
+ {
158
+ $con_pos = $v_con_pos[$i] ;
159
+ printf OUT " %-*s | %4d | %4d | %4d | %4d",
160
+ $max_con_pos_len, $con_pos, ${$counts_pos}{tot}{$con_pos},
161
+ ${$counts_pos}{err_head}{$con_pos}, ${$counts_pos}{err_dep}{$con_pos},
162
+ ${$counts_pos}{err_dep}{$con_pos}+${$counts_pos}{err_head}{$con_pos}-${$counts_pos}{tot}{$con_pos} ;
163
+ }
164
+ else
165
+ {
166
+ printf OUT " %-*s | %4s | %4s | %4s | %4s",
167
+ $max_con_pos_len, ' ', ' ', ' ', ' ', ' ' ;
168
+ }
169
+
170
+ printf OUT " ||" ;
171
+
172
+ if (defined $v_con[$i])
173
+ {
174
+ $con = $v_con[$i] ;
175
+ printf OUT " %-*s | %4d | %4d | %4d | %4d",
176
+ $max_con_len+length($con)-uni_len($con), $con, ${$counts}{tot}{$con},
177
+ ${$counts}{err_head}{$con}, ${$counts}{err_dep}{$con},
178
+ ${$counts}{err_dep}{$con}+${$counts}{err_head}{$con}-${$counts}{tot}{$con} ;
179
+ }
180
+ else
181
+ {
182
+ printf OUT " %-*s | %4s | %4s | %4s | %4s",
183
+ $max_con_len, ' ', ' ', ' ', ' ', ' ' ;
184
+ }
185
+
186
+ printf OUT "\n" ;
187
+ }
188
+
189
+ printf OUT " %s-+------+------+------+-----", '-' x $max_con_pos_len;
190
+ printf OUT "--++" ;
191
+ printf OUT "--%s-+------+------+------+-----", '-' x $max_con_len;
192
+ printf OUT "\n" ;
193
+
194
+ printf OUT "\n\n" ;
195
+
196
+ } # print_context
197
+
198
+ sub num_as_word
199
+ {
200
+ my ($num) = @_ ;
201
+
202
+ $num = abs($num) ;
203
+
204
+ if ($num == 1)
205
+ {
206
+ return ('one word') ;
207
+ }
208
+ elsif ($num == 2)
209
+ {
210
+ return ('two words') ;
211
+ }
212
+ elsif ($num == 3)
213
+ {
214
+ return ('three words') ;
215
+ }
216
+ elsif ($num == 4)
217
+ {
218
+ return ('four words') ;
219
+ }
220
+ else
221
+ {
222
+ return ($num.' words') ;
223
+ }
224
+ }
225
+
226
+ sub describe_err
227
+ { # describe_err
228
+
229
+ my ($head_err, $head_aft_bef, $dep_err) = @_ ;
230
+ my ($dep_g, $dep_s, $desc) ;
231
+ my ($head_aft_bef_g, $head_aft_bef_s) = split(//, $head_aft_bef) ;
232
+
233
+ if ($head_err eq '-')
234
+ {
235
+ $desc = 'correct head' ;
236
+
237
+ if ($head_aft_bef_s eq '0')
238
+ {
239
+ $desc .= ' (0)' ;
240
+ }
241
+ elsif ($head_aft_bef_s eq 'e')
242
+ {
243
+ $desc .= ' (the focus word)' ;
244
+ }
245
+ elsif ($head_aft_bef_s eq 'a')
246
+ {
247
+ $desc .= ' (after the focus word)' ;
248
+ }
249
+ elsif ($head_aft_bef_s eq 'b')
250
+ {
251
+ $desc .= ' (before the focus word)' ;
252
+ }
253
+ }
254
+ elsif ($head_aft_bef_s eq '0')
255
+ {
256
+ $desc = 'head = 0 instead of ' ;
257
+ if ($head_aft_bef_g eq 'a')
258
+ {
259
+ $desc.= 'after ' ;
260
+ }
261
+ if ($head_aft_bef_g eq 'b')
262
+ {
263
+ $desc.= 'before ' ;
264
+ }
265
+ $desc .= 'the focus word' ;
266
+ }
267
+ elsif ($head_aft_bef_g eq '0')
268
+ {
269
+ $desc = 'head is ' ;
270
+ if ($head_aft_bef_g eq 'a')
271
+ {
272
+ $desc.= 'after ' ;
273
+ }
274
+ if ($head_aft_bef_g eq 'b')
275
+ {
276
+ $desc.= 'before ' ;
277
+ }
278
+ $desc .= 'the focus word instead of 0' ;
279
+ }
280
+ else
281
+ {
282
+ $desc = num_as_word($head_err) ;
283
+ if ($head_err < 0)
284
+ {
285
+ $desc .= ' before' ;
286
+ }
287
+ else
288
+ {
289
+ $desc .= ' after' ;
290
+ }
291
+
292
+ $desc = 'head '.$desc.' the correct head ' ;
293
+
294
+ if ($head_aft_bef_s eq '0')
295
+ {
296
+ $desc .= '(0' ;
297
+ }
298
+ elsif ($head_aft_bef_s eq 'e')
299
+ {
300
+ $desc .= '(the focus word' ;
301
+ }
302
+ elsif ($head_aft_bef_s eq 'a')
303
+ {
304
+ $desc .= '(after the focus word' ;
305
+ }
306
+ elsif ($head_aft_bef_s eq 'b')
307
+ {
308
+ $desc .= '(before the focus word' ;
309
+ }
310
+
311
+ if ($head_aft_bef_g ne $head_aft_bef_s)
312
+ {
313
+ $desc .= ' instead of' ;
314
+ if ($head_aft_bef_s eq '0')
315
+ {
316
+ $desc .= '0' ;
317
+ }
318
+ elsif ($head_aft_bef_s eq 'e')
319
+ {
320
+ $desc .= 'the focus word' ;
321
+ }
322
+ elsif ($head_aft_bef_s eq 'a')
323
+ {
324
+ $desc .= 'after the focus word' ;
325
+ }
326
+ elsif ($head_aft_bef_s eq 'b')
327
+ {
328
+ $desc .= 'before the focus word' ;
329
+ }
330
+ }
331
+
332
+ $desc .= ')' ;
333
+ }
334
+
335
+ $desc .= ', ' ;
336
+
337
+ if ($dep_err eq '-')
338
+ {
339
+ $desc .= 'correct dependency' ;
340
+ }
341
+ else
342
+ {
343
+ ($dep_g, $dep_s) = ($dep_err =~ /^(.*)->(.*)$/) ;
344
+ $desc .= sprintf('dependency "%s" instead of "%s"', $dep_s, $dep_g) ;
345
+ }
346
+
347
+ return($desc) ;
348
+
349
+ } # describe_err
350
+
351
+ sub get_context
352
+ { # get_context
353
+
354
+ my ($sent, $i_w) = @_ ;
355
+ my ($w_2, $w_1, $w1, $w2) ;
356
+ my ($p_2, $p_1, $p1, $p2) ;
357
+
358
+ if ($i_w >= 2)
359
+ {
360
+ $w_2 = ${${$sent}[$i_w-2]}{word} ;
361
+ $p_2 = ${${$sent}[$i_w-2]}{pos} ;
362
+ }
363
+ else
364
+ {
365
+ $w_2 = $START ;
366
+ $p_2 = $START ;
367
+ }
368
+
369
+ if ($i_w >= 1)
370
+ {
371
+ $w_1 = ${${$sent}[$i_w-1]}{word} ;
372
+ $p_1 = ${${$sent}[$i_w-1]}{pos} ;
373
+ }
374
+ else
375
+ {
376
+ $w_1 = $START ;
377
+ $p_1 = $START ;
378
+ }
379
+
380
+ if ($i_w <= scalar @{$sent}-2)
381
+ {
382
+ $w1 = ${${$sent}[$i_w+1]}{word} ;
383
+ $p1 = ${${$sent}[$i_w+1]}{pos} ;
384
+ }
385
+ else
386
+ {
387
+ $w1 = $END ;
388
+ $p1 = $END ;
389
+ }
390
+
391
+ if ($i_w <= scalar @{$sent}-3)
392
+ {
393
+ $w2 = ${${$sent}[$i_w+2]}{word} ;
394
+ $p2 = ${${$sent}[$i_w+2]}{pos} ;
395
+ }
396
+ else
397
+ {
398
+ $w2 = $END ;
399
+ $p2 = $END ;
400
+ }
401
+
402
+ return ($w_2, $w_1, $w1, $w2, $p_2, $p_1, $p1, $p2) ;
403
+
404
+ } # get_context
405
+
406
+ sub read_sent
407
+ { # read_sent
408
+
409
+ my ($sent_gold, $sent_sys) = @_ ;
410
+ my ($line_g, $line_s, $new_sent) ;
411
+ my (%fields_g, %fields_s) ;
412
+
413
+ $new_sent = 1 ;
414
+
415
+ @{$sent_gold} = () ;
416
+ @{$sent_sys} = () ;
417
+
418
+ while (1)
419
+ { # main reading loop
420
+
421
+ $line_g = <GOLD> ;
422
+ $line_s = <SYS> ;
423
+
424
+ $line_num++ ;
425
+
426
+ # system output has fewer lines than gold standard
427
+ if ((defined $line_g) && (! defined $line_s))
428
+ {
429
+ printf STDERR "line mismatch, line %d:\n", $line_num ;
430
+ printf STDERR " gold: %s", $line_g ;
431
+ printf STDERR " sys : past end of file\n" ;
432
+ exit(1) ;
433
+ }
434
+
435
+ # system output has more lines than gold standard
436
+ if ((! defined $line_g) && (defined $line_s))
437
+ {
438
+ printf STDERR "line mismatch, line %d:\n", $line_num ;
439
+ printf STDERR " gold: past end of file\n" ;
440
+ printf STDERR " sys : %s", $line_s ;
441
+ exit(1) ;
442
+ }
443
+
444
+ # end of file reached for both
445
+ if ((! defined $line_g) && (! defined $line_s))
446
+ {
447
+ return (1) ;
448
+ }
449
+
450
+ # one contains end of sentence but other one does not
451
+ if (($line_g =~ /^\s+$/) != ($line_s =~ /^\s+$/))
452
+ {
453
+ printf STDERR "line mismatch, line %d:\n", $line_num ;
454
+ printf STDERR " gold: %s", $line_g ;
455
+ printf STDERR " sys : %s", $line_s ;
456
+ exit(1) ;
457
+ }
458
+
459
+ # end of sentence reached
460
+ if ($line_g =~ /^\s+$/)
461
+ {
462
+ return(0) ;
463
+ }
464
+
465
+ # now both lines contain information
466
+
467
+ if ($new_sent)
468
+ {
469
+ $new_sent = 0 ;
470
+ }
471
+
472
+ # 'official' column names
473
+ # options.output = ['id','form','lemma','cpostag','postag',
474
+ # 'feats','head','deprel','phead','pdeprel']
475
+
476
+ @fields_g{'word', 'pos', 'head', 'dep'} = (split (/\s+/, $line_g))[1, 3, 6, 7] ;
477
+
478
+ push @{$sent_gold}, { %fields_g } ;
479
+
480
+ @fields_s{'word', 'pos', 'head', 'dep'} = (split (/\s+/, $line_s))[1, 3, 6, 7] ;
481
+
482
+ if (($fields_g{word} ne $fields_s{word})
483
+ ||
484
+ ($fields_g{pos} ne $fields_s{pos}))
485
+ {
486
+ printf STDERR "Word/pos mismatch, line %d:\n", $line_num ;
487
+ printf STDERR " gold: %s", $line_g ;
488
+ printf STDERR " sys : %s", $line_s ;
489
+ exit(1) ;
490
+ }
491
+
492
+ push @{$sent_sys}, { %fields_s } ;
493
+
494
+ } # main reading loop
495
+
496
+ } # read_sent
497
+
498
+ ################################################################################
499
+ ### main ###
500
+ ################################################################################
501
+
502
+ our ($opt_g, $opt_s, $opt_o, $opt_h, $opt_v, $opt_q, $opt_p, $opt_b) ;
503
+
504
+ my ($sent_num, $eof, $word_num, @err_sent) ;
505
+ my (@sent_gold, @sent_sys, @starts) ;
506
+ my ($word, $pos, $wp, $head_g, $dep_g, $head_s, $dep_s) ;
507
+ my (%counts, $err_head, $err_dep, $con, $con1, $con_pos, $con_pos1, $thresh) ;
508
+ my ($head_err, $dep_err, @cur_err, %err_counts, $err_counter, $err_desc) ;
509
+ my ($loc_con, %loc_con_err_counts, %err_desc) ;
510
+ my ($head_aft_bef_g, $head_aft_bef_s, $head_aft_bef) ;
511
+ my ($con_bef, $con_aft, $con_bef_2, $con_aft_2, @bits, @e_bits, @v_con, @v_con_pos) ;
512
+ my ($con_pos_bef, $con_pos_aft, $con_pos_bef_2, $con_pos_aft_2) ;
513
+ my ($max_word_len, $max_pos_len, $max_con_len, $max_con_pos_len) ;
514
+ my ($max_word_spec_len, $max_con_bef_len, $max_con_aft_len) ;
515
+ my (%freq_err, $err) ;
516
+
517
+ my ($i, $j, $i_w, $l, $n_args) ;
518
+ my ($w_2, $w_1, $w1, $w2) ;
519
+ my ($wp_2, $wp_1, $wp1, $wp2) ;
520
+ my ($p_2, $p_1, $p1, $p2) ;
521
+
522
+ my ($short_output) ;
523
+ my ($score_on_punct) ;
524
+ $counts{punct} = 0; # initialize
525
+
526
+ getopts("g:o:s:qvhpb") ;
527
+
528
+ if (defined $opt_v)
529
+ {
530
+ my $id = '$Id: eval.pl,v 1.9 2006/05/09 20:30:01 yuval Exp $';
531
+ my @parts = split ' ',$id;
532
+ print "Version $parts[2]\n";
533
+ exit(0);
534
+ }
535
+
536
+ if ((defined $opt_h) || ((! defined $opt_g) && (! defined $opt_s)))
537
+ {
538
+ die $usage ;
539
+ }
540
+
541
+ if (! defined $opt_g)
542
+ {
543
+ die "Gold standard file (-g) missing\n" ;
544
+ }
545
+
546
+ if (! defined $opt_s)
547
+ {
548
+ die "System output file (-s) missing\n" ;
549
+ }
550
+
551
+ if (! defined $opt_o)
552
+ {
553
+ $opt_o = '-' ;
554
+ }
555
+
556
+ if (defined $opt_q)
557
+ {
558
+ $short_output = 1 ;
559
+ } else {
560
+ $short_output = 0 ;
561
+ }
562
+
563
+ if (defined $opt_p)
564
+ {
565
+ $score_on_punct = 1 ;
566
+ } else {
567
+ $score_on_punct = 0 ;
568
+ }
569
+
570
+ $line_num = 0 ;
571
+ $sent_num = 0 ;
572
+ $eof = 0 ;
573
+
574
+ @err_sent = () ;
575
+ @starts = () ;
576
+
577
+ %{$err_sent[0]} = () ;
578
+
579
+ $max_pos_len = length('CPOS') ;
580
+
581
+ ################################################################################
582
+ ### reading input ###
583
+ ################################################################################
584
+
585
+ open (GOLD, "<$opt_g") || die "Could not open gold standard file $opt_g\n" ;
586
+ open (SYS, "<$opt_s") || die "Could not open system output file $opt_s\n" ;
587
+ open (OUT, ">$opt_o") || die "Could not open output file $opt_o\n" ;
588
+
589
+
590
+ if (defined $opt_b) { # produce output similar to evalb
591
+ print OUT " Sent. Attachment Correct Scoring \n";
592
+ print OUT " ID Tokens - Unlab. Lab. HEAD HEAD+DEPREL tokens - - - -\n";
593
+ print OUT " ============================================================================\n";
594
+ }
595
+
596
+
597
+ while (! $eof)
598
+ { # main reading loop
599
+
600
+ $starts[$sent_num] = $line_num+1 ;
601
+ $eof = read_sent(\@sent_gold, \@sent_sys) ;
602
+
603
+ $sent_num++ ;
604
+
605
+ %{$err_sent[$sent_num]} = () ;
606
+ $word_num = scalar @sent_gold ;
607
+
608
+ # for accuracy per sentence
609
+ my %sent_counts = ( tot => 0,
610
+ err_any => 0,
611
+ err_head => 0
612
+ );
613
+
614
+ # printf "$sent_num $word_num\n" ;
615
+
616
+ my @frames_g = ('** '); # the initial frame for the virtual root
617
+ my @frames_s = ('** '); # the initial frame for the virtual root
618
+ foreach $i_w (0 .. $word_num-1)
619
+ { # loop on words
620
+ push @frames_g, ''; # initialize
621
+ push @frames_s, ''; # initialize
622
+ }
623
+
624
+ foreach $i_w (0 .. $word_num-1)
625
+ { # loop on words
626
+
627
+ ($word, $pos, $head_g, $dep_g)
628
+ = @{$sent_gold[$i_w]}{'word', 'pos', 'head', 'dep'} ;
629
+ $wp = $word.' / '.$pos ;
630
+
631
+ # printf "%d: %s %s %s %s\n", $i_w, $word, $pos, $head_g, $dep_g ;
632
+
633
+ if ((! $score_on_punct) && is_uni_punct($word))
634
+ {
635
+ $counts{punct}++ ;
636
+ # ignore punctuations
637
+ next ;
638
+ }
639
+
640
+ if (length($pos) > $max_pos_len)
641
+ {
642
+ $max_pos_len = length($pos) ;
643
+ }
644
+
645
+ ($head_s, $dep_s) = @{$sent_sys[$i_w]}{'head', 'dep'} ;
646
+
647
+ $counts{tot}++ ;
648
+ $counts{word}{$wp}{tot}++ ;
649
+ $counts{pos}{$pos}{tot}++ ;
650
+ $counts{head}{$head_g-$i_w-1}{tot}++ ;
651
+
652
+ # for frame confusions
653
+ # add child to frame of parent
654
+ $frames_g[$head_g] .= "$dep_g ";
655
+ $frames_s[$head_s] .= "$dep_s ";
656
+ # add to frame of token itself
657
+ $frames_g[$i_w+1] .= "*$dep_g* "; # $i_w+1 because $i_w starts counting at zero
658
+ $frames_s[$i_w+1] .= "*$dep_g* ";
659
+
660
+ # for precision and recall of DEPREL
661
+ $counts{dep}{$dep_g}{tot}++ ; # counts for gold standard deprels
662
+ $counts{dep2}{$dep_g}{$dep_s}++ ; # counts for confusions
663
+ $counts{dep_s}{$dep_s}{tot}++ ; # counts for system deprels
664
+ $counts{all_dep}{$dep_g} = 1 ; # list of all deprels that occur ...
665
+ $counts{all_dep}{$dep_s} = 1 ; # ... in either gold or system output
666
+
667
+ # for precision and recall of HEAD direction
668
+ my $dir_g;
669
+ if ($head_g == 0) {
670
+ $dir_g = 'to_root';
671
+ } elsif ($head_g < $i_w+1) { # $i_w+1 because $i_w starts counting at zero
672
+ # also below
673
+ $dir_g = 'left';
674
+ } elsif ($head_g > $i_w+1) {
675
+ $dir_g = 'right';
676
+ } else {
677
+ # token links to itself; should never happen in correct gold standard
678
+ $dir_g = 'self';
679
+ }
680
+ my $dir_s;
681
+ if ($head_s == 0) {
682
+ $dir_s = 'to_root';
683
+ } elsif ($head_s < $i_w+1) {
684
+ $dir_s = 'left';
685
+ } elsif ($head_s > $i_w+1) {
686
+ $dir_s = 'right';
687
+ } else {
688
+ # token links to itself; should not happen in good system
689
+ # (but not forbidden in shared task)
690
+ $dir_s = 'self';
691
+ }
692
+ $counts{dir_g}{$dir_g}{tot}++ ; # counts for gold standard head direction
693
+ $counts{dir2}{$dir_g}{$dir_s}++ ; # counts for confusions
694
+ $counts{dir_s}{$dir_s}{tot}++ ; # counts for system head direction
695
+
696
+ # for precision and recall of HEAD distance
697
+ my $dist_g;
698
+ if ($head_g == 0) {
699
+ $dist_g = 'to_root';
700
+ } elsif ( abs($head_g - ($i_w+1)) <= 1 ) {
701
+ $dist_g = '1'; # includes the 'self' cases
702
+ } elsif ( abs($head_g - ($i_w+1)) <= 2 ) {
703
+ $dist_g = '2';
704
+ } elsif ( abs($head_g - ($i_w+1)) <= 6 ) {
705
+ $dist_g = '3-6';
706
+ } else {
707
+ $dist_g = '7-...';
708
+ }
709
+ my $dist_s;
710
+ if ($head_s == 0) {
711
+ $dist_s = 'to_root';
712
+ } elsif ( abs($head_s - ($i_w+1)) <= 1 ) {
713
+ $dist_s = '1'; # includes the 'self' cases
714
+ } elsif ( abs($head_s - ($i_w+1)) <= 2 ) {
715
+ $dist_s = '2';
716
+ } elsif ( abs($head_s - ($i_w+1)) <= 6 ) {
717
+ $dist_s = '3-6';
718
+ } else {
719
+ $dist_s = '7-...';
720
+ }
721
+ $counts{dist_g}{$dist_g}{tot}++ ; # counts for gold standard head distance
722
+ $counts{dist2}{$dist_g}{$dist_s}++ ; # counts for confusions
723
+ $counts{dist_s}{$dist_s}{tot}++ ; # counts for system head distance
724
+
725
+
726
+ $err_head = ($head_g ne $head_s) ; # error in head
727
+ $err_dep = ($dep_g ne $dep_s) ; # error in deprel
728
+
729
+ $head_err = '-' ;
730
+ $dep_err = '-' ;
731
+
732
+ # for accuracy per sentence
733
+ $sent_counts{tot}++ ;
734
+ if ($err_dep || $err_head) {
735
+ $sent_counts{err_any}++ ;
736
+ }
737
+ if ($err_head) {
738
+ $sent_counts{err_head}++ ;
739
+ }
740
+
741
+ # total counts and counts for CPOS involved in errors
742
+
743
+ if ($head_g eq '0')
744
+ {
745
+ $head_aft_bef_g = '0' ;
746
+ }
747
+ elsif ($head_g eq $i_w+1)
748
+ {
749
+ $head_aft_bef_g = 'e' ;
750
+ }
751
+ else
752
+ {
753
+ $head_aft_bef_g = ($head_g <= $i_w+1 ? 'b' : 'a') ;
754
+ }
755
+
756
+ if ($head_s eq '0')
757
+ {
758
+ $head_aft_bef_s = '0' ;
759
+ }
760
+ elsif ($head_s eq $i_w+1)
761
+ {
762
+ $head_aft_bef_s = 'e' ;
763
+ }
764
+ else
765
+ {
766
+ $head_aft_bef_s = ($head_s <= $i_w+1 ? 'b' : 'a') ;
767
+ }
768
+
769
+ $head_aft_bef = $head_aft_bef_g.$head_aft_bef_s ;
770
+
771
+ if ($err_head)
772
+ {
773
+ if ($head_aft_bef_s eq '0')
774
+ {
775
+ $head_err = 0 ;
776
+ }
777
+ else
778
+ {
779
+ $head_err = $head_s-$head_g ;
780
+ }
781
+
782
+ $err_sent[$sent_num]{head}++ ;
783
+ $counts{err_head}{tot}++ ;
784
+ $counts{err_head}{$head_err}++ ;
785
+
786
+ $counts{word}{err_head}{$wp}++ ;
787
+ $counts{pos}{$pos}{err_head}{tot}++ ;
788
+ $counts{pos}{$pos}{err_head}{$head_err}++ ;
789
+ }
790
+
791
+ if ($err_dep)
792
+ {
793
+ $dep_err = $dep_g.'->'.$dep_s ;
794
+ $err_sent[$sent_num]{dep}++ ;
795
+ $counts{err_dep}{tot}++ ;
796
+ $counts{err_dep}{$dep_err}++ ;
797
+
798
+ $counts{word}{err_dep}{$wp}++ ;
799
+ $counts{pos}{$pos}{err_dep}{tot}++ ;
800
+ $counts{pos}{$pos}{err_dep}{$dep_err}++ ;
801
+
802
+ if ($err_head)
803
+ {
804
+ $counts{err_both}++ ;
805
+ $counts{pos}{$pos}{err_both}++ ;
806
+ }
807
+ }
808
+
809
+ ### DEPREL + ATTACHMENT
810
+ if ((!$err_dep) && ($err_head)) {
811
+ $counts{err_head_corr_dep}{tot}++ ;
812
+ $counts{err_head_corr_dep}{$dep_s}++ ;
813
+ }
814
+ ### DEPREL + ATTACHMENT
815
+
816
+ # counts for words involved in errors
817
+
818
+ if (! ($err_head || $err_dep))
819
+ {
820
+ next ;
821
+ }
822
+
823
+ $err_sent[$sent_num]{word}++ ;
824
+ $counts{err_any}++ ;
825
+ $counts{word}{err_any}{$wp}++ ;
826
+ $counts{pos}{$pos}{err_any}++ ;
827
+
828
+ ($w_2, $w_1, $w1, $w2, $p_2, $p_1, $p1, $p2) = get_context(\@sent_gold, $i_w) ;
829
+
830
+ if ($w_2 ne $START)
831
+ {
832
+ $wp_2 = $w_2.' / '.$p_2 ;
833
+ }
834
+ else
835
+ {
836
+ $wp_2 = $w_2 ;
837
+ }
838
+
839
+ if ($w_1 ne $START)
840
+ {
841
+ $wp_1 = $w_1.' / '.$p_1 ;
842
+ }
843
+ else
844
+ {
845
+ $wp_1 = $w_1 ;
846
+ }
847
+
848
+ if ($w1 ne $END)
849
+ {
850
+ $wp1 = $w1.' / '.$p1 ;
851
+ }
852
+ else
853
+ {
854
+ $wp1 = $w1 ;
855
+ }
856
+
857
+ if ($w2 ne $END)
858
+ {
859
+ $wp2 = $w2.' / '.$p2 ;
860
+ }
861
+ else
862
+ {
863
+ $wp2 = $w2 ;
864
+ }
865
+
866
+ $con_bef = $wp_1 ;
867
+ $con_bef_2 = $wp_2.' + '.$wp_1 ;
868
+ $con_aft = $wp1 ;
869
+ $con_aft_2 = $wp1.' + '.$wp2 ;
870
+
871
+ $con_pos_bef = $p_1 ;
872
+ $con_pos_bef_2 = $p_2.'+'.$p_1 ;
873
+ $con_pos_aft = $p1 ;
874
+ $con_pos_aft_2 = $p1.'+'.$p2 ;
875
+
876
+ if ($w_1 ne $START)
877
+ {
878
+ # do not count '.S' as a word context
879
+ $counts{con_bef_2}{tot}{$con_bef_2}++ ;
880
+ $counts{con_bef_2}{err_head}{$con_bef_2} += $err_head ;
881
+ $counts{con_bef_2}{err_dep}{$con_bef_2} += $err_dep ;
882
+ $counts{con_bef}{tot}{$con_bef}++ ;
883
+ $counts{con_bef}{err_head}{$con_bef} += $err_head ;
884
+ $counts{con_bef}{err_dep}{$con_bef} += $err_dep ;
885
+ }
886
+
887
+ if ($w1 ne $END)
888
+ {
889
+ # do not count '.E' as a word context
890
+ $counts{con_aft_2}{tot}{$con_aft_2}++ ;
891
+ $counts{con_aft_2}{err_head}{$con_aft_2} += $err_head ;
892
+ $counts{con_aft_2}{err_dep}{$con_aft_2} += $err_dep ;
893
+ $counts{con_aft}{tot}{$con_aft}++ ;
894
+ $counts{con_aft}{err_head}{$con_aft} += $err_head ;
895
+ $counts{con_aft}{err_dep}{$con_aft} += $err_dep ;
896
+ }
897
+
898
+ $counts{con_pos_bef_2}{tot}{$con_pos_bef_2}++ ;
899
+ $counts{con_pos_bef_2}{err_head}{$con_pos_bef_2} += $err_head ;
900
+ $counts{con_pos_bef_2}{err_dep}{$con_pos_bef_2} += $err_dep ;
901
+ $counts{con_pos_bef}{tot}{$con_pos_bef}++ ;
902
+ $counts{con_pos_bef}{err_head}{$con_pos_bef} += $err_head ;
903
+ $counts{con_pos_bef}{err_dep}{$con_pos_bef} += $err_dep ;
904
+
905
+ $counts{con_pos_aft_2}{tot}{$con_pos_aft_2}++ ;
906
+ $counts{con_pos_aft_2}{err_head}{$con_pos_aft_2} += $err_head ;
907
+ $counts{con_pos_aft_2}{err_dep}{$con_pos_aft_2} += $err_dep ;
908
+ $counts{con_pos_aft}{tot}{$con_pos_aft}++ ;
909
+ $counts{con_pos_aft}{err_head}{$con_pos_aft} += $err_head ;
910
+ $counts{con_pos_aft}{err_dep}{$con_pos_aft} += $err_dep ;
911
+
912
+ $err = $head_err.$sep.$head_aft_bef.$sep.$dep_err ;
913
+ $freq_err{$err}++ ;
914
+
915
+ } # loop on words
916
+
917
+ foreach $i_w (0 .. $word_num) # including one for the virtual root
918
+ { # loop on words
919
+ if ($frames_g[$i_w] ne $frames_s[$i_w]) {
920
+ $counts{frame2}{"$frames_g[$i_w]/ $frames_s[$i_w]"}++ ;
921
+ }
922
+ }
923
+
924
+ if (defined $opt_b) { # produce output similar to evalb
925
+ if ($word_num > 0) {
926
+ my ($unlabeled,$labeled) = ('NaN', 'NaN');
927
+ if ($sent_counts{tot} > 0) { # there are scoring tokens
928
+ $unlabeled = 100-$sent_counts{err_head}*100.0/$sent_counts{tot};
929
+ $labeled = 100-$sent_counts{err_any} *100.0/$sent_counts{tot};
930
+ }
931
+ printf OUT " %4d %4d 0 %6.2f %6.2f %4d %4d %4d 0 0 0 0\n",
932
+ $sent_num, $word_num,
933
+ $unlabeled, $labeled,
934
+ $sent_counts{tot}-$sent_counts{err_head},
935
+ $sent_counts{tot}-$sent_counts{err_any},
936
+ $sent_counts{tot},;
937
+ }
938
+ }
939
+
940
+ } # main reading loop
941
+
942
+ ################################################################################
943
+ ### printing output ###
944
+ ################################################################################
945
+
946
+ if (defined $opt_b) { # produce output similar to evalb
947
+ print OUT "\n\n";
948
+ }
949
+ printf OUT " Labeled attachment score: %d / %d * 100 = %.2f %%\n",
950
+ $counts{tot}-$counts{err_any}, $counts{tot}, 100-$counts{err_any}*100.0/$counts{tot} ;
951
+ printf OUT " Unlabeled attachment score: %d / %d * 100 = %.2f %%\n",
952
+ $counts{tot}-$counts{err_head}{tot}, $counts{tot}, 100-$counts{err_head}{tot}*100.0/$counts{tot} ;
953
+ printf OUT " Label accuracy score: %d / %d * 100 = %.2f %%\n",
954
+ $counts{tot}-$counts{err_dep}{tot}, $counts{tot}, 100-$counts{err_dep}{tot}*100.0/$counts{tot} ;
955
+
956
+ if ($short_output)
957
+ {
958
+ exit(0) ;
959
+ }
960
+ printf OUT "\n %s\n\n", '=' x 80 ;
961
+ printf OUT " Evaluation of the results in %s\n vs. gold standard %s:\n\n", $opt_s, $opt_g ;
962
+
963
+ printf OUT " Legend: '%s' - the beginning of a sentence, '%s' - the end of a sentence\n\n", $START, $END ;
964
+
965
+ printf OUT " Number of non-scoring tokens: $counts{punct}\n\n";
966
+
967
+ printf OUT " The overall accuracy and its distribution over CPOSTAGs\n\n" ;
968
+ printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
969
+
970
+ printf OUT " %-10s | %-5s | %-5s | %% | %-5s | %% | %-5s | %%\n",
971
+ 'Accuracy', 'words', 'right', 'right', 'both' ;
972
+ printf OUT " %-10s | %-5s | %-5s | | %-5s | | %-5s |\n",
973
+ ' ', ' ', 'head', ' dep', 'right' ;
974
+
975
+ printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
976
+
977
+ printf OUT " %-10s | %5d | %5d | %3.0f%% | %5d | %3.0f%% | %5d | %3.0f%%\n",
978
+ 'total', $counts{tot},
979
+ $counts{tot}-$counts{err_head}{tot}, 100-$counts{err_head}{tot}*100.0/$counts{tot},
980
+ $counts{tot}-$counts{err_dep}{tot}, 100-$counts{err_dep}{tot}*100.0/$counts{tot},
981
+ $counts{tot}-$counts{err_any}, 100-$counts{err_any}*100.0/$counts{tot} ;
982
+
983
+ printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
984
+
985
+ foreach $pos (sort {$counts{pos}{$b}{tot} <=> $counts{pos}{$a}{tot}} keys %{$counts{pos}})
986
+ {
987
+ if (! defined($counts{pos}{$pos}{err_head}{tot}))
988
+ {
989
+ $counts{pos}{$pos}{err_head}{tot} = 0 ;
990
+ }
991
+ if (! defined($counts{pos}{$pos}{err_dep}{tot}))
992
+ {
993
+ $counts{pos}{$pos}{err_dep}{tot} = 0 ;
994
+ }
995
+ if (! defined($counts{pos}{$pos}{err_any}))
996
+ {
997
+ $counts{pos}{$pos}{err_any} = 0 ;
998
+ }
999
+
1000
+ printf OUT " %-10s | %5d | %5d | %3.0f%% | %5d | %3.0f%% | %5d | %3.0f%%\n",
1001
+ $pos, $counts{pos}{$pos}{tot},
1002
+ $counts{pos}{$pos}{tot}-$counts{pos}{$pos}{err_head}{tot}, 100-$counts{pos}{$pos}{err_head}{tot}*100.0/$counts{pos}{$pos}{tot},
1003
+ $counts{pos}{$pos}{tot}-$counts{pos}{$pos}{err_dep}{tot}, 100-$counts{pos}{$pos}{err_dep}{tot}*100.0/$counts{pos}{$pos}{tot},
1004
+ $counts{pos}{$pos}{tot}-$counts{pos}{$pos}{err_any}, 100-$counts{pos}{$pos}{err_any}*100.0/$counts{pos}{$pos}{tot} ;
1005
+ }
1006
+
1007
+ printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
1008
+
1009
+ printf OUT "\n\n" ;
1010
+
1011
+ printf OUT " The overall error rate and its distribution over CPOSTAGs\n\n" ;
1012
+ printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
1013
+
1014
+ printf OUT " %-10s | %-5s | %-5s | %% | %-5s | %% | %-5s | %%\n",
1015
+ 'Error', 'words', 'head', ' dep', 'both' ;
1016
+ printf OUT " %-10s | %-5s | %-5s | | %-5s | | %-5s |\n",
1017
+
1018
+ 'Rate', ' ', 'err', ' err', 'wrong' ;
1019
+
1020
+ printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
1021
+
1022
+ printf OUT " %-10s | %5d | %5d | %3.0f%% | %5d | %3.0f%% | %5d | %3.0f%%\n",
1023
+ 'total', $counts{tot},
1024
+ $counts{err_head}{tot}, $counts{err_head}{tot}*100.0/$counts{tot},
1025
+ $counts{err_dep}{tot}, $counts{err_dep}{tot}*100.0/$counts{tot},
1026
+ $counts{err_both}, $counts{err_both}*100.0/$counts{tot} ;
1027
+
1028
+ printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
1029
+
1030
+ foreach $pos (sort {$counts{pos}{$b}{tot} <=> $counts{pos}{$a}{tot}} keys %{$counts{pos}})
1031
+ {
1032
+ if (! defined($counts{pos}{$pos}{err_both}))
1033
+ {
1034
+ $counts{pos}{$pos}{err_both} = 0 ;
1035
+ }
1036
+
1037
+ printf OUT " %-10s | %5d | %5d | %3.0f%% | %5d | %3.0f%% | %5d | %3.0f%%\n",
1038
+ $pos, $counts{pos}{$pos}{tot},
1039
+ $counts{pos}{$pos}{err_head}{tot}, $counts{pos}{$pos}{err_head}{tot}*100.0/$counts{pos}{$pos}{tot},
1040
+ $counts{pos}{$pos}{err_dep}{tot}, $counts{pos}{$pos}{err_dep}{tot}*100.0/$counts{pos}{$pos}{tot},
1041
+ $counts{pos}{$pos}{err_both}, $counts{pos}{$pos}{err_both}*100.0/$counts{pos}{$pos}{tot} ;
1042
+
1043
+ }
1044
+
1045
+ printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
1046
+
1047
+ ### added by Sabine Buchholz
1048
+ printf OUT "\n\n";
1049
+ printf OUT " Precision and recall of DEPREL\n\n";
1050
+ printf OUT " ----------------+------+---------+--------+------------+---------------\n";
1051
+ printf OUT " deprel | gold | correct | system | recall (%%) | precision (%%) \n";
1052
+ printf OUT " ----------------+------+---------+--------+------------+---------------\n";
1053
+ foreach my $dep (sort keys %{$counts{all_dep}}) {
1054
+ # initialize
1055
+ my ($tot_corr, $tot_g, $tot_s, $prec, $rec) = (0, 0, 0, 'NaN', 'NaN');
1056
+
1057
+ if (defined($counts{dep2}{$dep}{$dep})) {
1058
+ $tot_corr = $counts{dep2}{$dep}{$dep};
1059
+ }
1060
+ if (defined($counts{dep}{$dep}{tot})) {
1061
+ $tot_g = $counts{dep}{$dep}{tot};
1062
+ $rec = sprintf("%.2f",$tot_corr / $tot_g * 100);
1063
+ }
1064
+ if (defined($counts{dep_s}{$dep}{tot})) {
1065
+ $tot_s = $counts{dep_s}{$dep}{tot};
1066
+ $prec = sprintf("%.2f",$tot_corr / $tot_s * 100);
1067
+ }
1068
+ printf OUT " %-15s | %4d | %7d | %6d | %10s | %13s\n",
1069
+ $dep, $tot_g, $tot_corr, $tot_s, $rec, $prec;
1070
+ }
1071
+
1072
+ ### DEPREL + ATTACHMENT:
1073
+ ### Same as Sabine's DEPREL apart from $tot_corr calculation
1074
+ printf OUT "\n\n";
1075
+ printf OUT " Precision and recall of DEPREL + ATTACHMENT\n\n";
1076
+ printf OUT " ----------------+------+---------+--------+------------+---------------\n";
1077
+ printf OUT " deprel | gold | correct | system | recall (%%) | precision (%%) \n";
1078
+ printf OUT " ----------------+------+---------+--------+------------+---------------\n";
1079
+ foreach my $dep (sort keys %{$counts{all_dep}}) {
1080
+ # initialize
1081
+ my ($tot_corr, $tot_g, $tot_s, $prec, $rec) = (0, 0, 0, 'NaN', 'NaN');
1082
+
1083
+ if (defined($counts{dep2}{$dep}{$dep})) {
1084
+ if (defined($counts{err_head_corr_dep}{$dep})) {
1085
+ $tot_corr = $counts{dep2}{$dep}{$dep} - $counts{err_head_corr_dep}{$dep};
1086
+ } else {
1087
+ $tot_corr = $counts{dep2}{$dep}{$dep};
1088
+ }
1089
+ }
1090
+ if (defined($counts{dep}{$dep}{tot})) {
1091
+ $tot_g = $counts{dep}{$dep}{tot};
1092
+ $rec = sprintf("%.2f",$tot_corr / $tot_g * 100);
1093
+ }
1094
+ if (defined($counts{dep_s}{$dep}{tot})) {
1095
+ $tot_s = $counts{dep_s}{$dep}{tot};
1096
+ $prec = sprintf("%.2f",$tot_corr / $tot_s * 100);
1097
+ }
1098
+ printf OUT " %-15s | %4d | %7d | %6d | %10s | %13s\n",
1099
+ $dep, $tot_g, $tot_corr, $tot_s, $rec, $prec;
1100
+ }
1101
+ ### DEPREL + ATTACHMENT
1102
+
1103
+ printf OUT "\n\n";
1104
+ printf OUT " Precision and recall of binned HEAD direction\n\n";
1105
+ printf OUT " ----------------+------+---------+--------+------------+---------------\n";
1106
+ printf OUT " direction | gold | correct | system | recall (%%) | precision (%%) \n";
1107
+ printf OUT " ----------------+------+---------+--------+------------+---------------\n";
1108
+ foreach my $dir ('to_root', 'left', 'right', 'self') {
1109
+ # initialize
1110
+ my ($tot_corr, $tot_g, $tot_s, $prec, $rec) = (0, 0, 0, 'NaN', 'NaN');
1111
+
1112
+ if (defined($counts{dir2}{$dir}{$dir})) {
1113
+ $tot_corr = $counts{dir2}{$dir}{$dir};
1114
+ }
1115
+ if (defined($counts{dir_g}{$dir}{tot})) {
1116
+ $tot_g = $counts{dir_g}{$dir}{tot};
1117
+ $rec = sprintf("%.2f",$tot_corr / $tot_g * 100);
1118
+ }
1119
+ if (defined($counts{dir_s}{$dir}{tot})) {
1120
+ $tot_s = $counts{dir_s}{$dir}{tot};
1121
+ $prec = sprintf("%.2f",$tot_corr / $tot_s * 100);
1122
+ }
1123
+ printf OUT " %-15s | %4d | %7d | %6d | %10s | %13s\n",
1124
+ $dir, $tot_g, $tot_corr, $tot_s, $rec, $prec;
1125
+ }
1126
+
1127
+ printf OUT "\n\n";
1128
+ printf OUT " Precision and recall of binned HEAD distance\n\n";
1129
+ printf OUT " ----------------+------+---------+--------+------------+---------------\n";
1130
+ printf OUT " distance | gold | correct | system | recall (%%) | precision (%%) \n";
1131
+ printf OUT " ----------------+------+---------+--------+------------+---------------\n";
1132
+ foreach my $dist ('to_root', '1', '2', '3-6', '7-...') {
1133
+ # initialize
1134
+ my ($tot_corr, $tot_g, $tot_s, $prec, $rec) = (0, 0, 0, 'NaN', 'NaN');
1135
+
1136
+ if (defined($counts{dist2}{$dist}{$dist})) {
1137
+ $tot_corr = $counts{dist2}{$dist}{$dist};
1138
+ }
1139
+ if (defined($counts{dist_g}{$dist}{tot})) {
1140
+ $tot_g = $counts{dist_g}{$dist}{tot};
1141
+ $rec = sprintf("%.2f",$tot_corr / $tot_g * 100);
1142
+ }
1143
+ if (defined($counts{dist_s}{$dist}{tot})) {
1144
+ $tot_s = $counts{dist_s}{$dist}{tot};
1145
+ $prec = sprintf("%.2f",$tot_corr / $tot_s * 100);
1146
+ }
1147
+ printf OUT " %-15s | %4d | %7d | %6d | %10s | %13s\n",
1148
+ $dist, $tot_g, $tot_corr, $tot_s, $rec, $prec;
1149
+ }
1150
+
1151
+ printf OUT "\n\n";
1152
+ printf OUT " Frame confusions (gold versus system; *...* marks the head token)\n\n";
1153
+ foreach my $frame (sort {$counts{frame2}{$b} <=> $counts{frame2}{$a}} keys %{$counts{frame2}})
1154
+ {
1155
+ if ($counts{frame2}{$frame} >= 5) # (make 5 a changeable threshold later)
1156
+ {
1157
+ printf OUT " %3d %s\n", $counts{frame2}{$frame}, $frame;
1158
+ }
1159
+ }
1160
+ ### end of: added by Sabine Buchholz
1161
+
1162
+
1163
+ #
1164
+ # Leave only the 5 words most involved in errors
1165
+ #
1166
+
1167
+
1168
+ $thresh = (sort {$b <=> $a} values %{$counts{word}{err_any}})[4] ;
1169
+
1170
+ # ensure enough space for title
1171
+ $max_word_len = length('word') ;
1172
+
1173
+ foreach $word (keys %{$counts{word}{err_any}})
1174
+ {
1175
+ if ($counts{word}{err_any}{$word} < $thresh)
1176
+ {
1177
+ delete $counts{word}{err_any}{$word} ;
1178
+ next ;
1179
+ }
1180
+
1181
+ $l = uni_len($word) ;
1182
+ if ($l > $max_word_len)
1183
+ {
1184
+ $max_word_len = $l ;
1185
+ }
1186
+ }
1187
+
1188
+ # filter a case when the difference between the error counts
1189
+ # for 2-word and 1-word contexts is small
1190
+ # (leave the 2-word context)
1191
+
1192
+ foreach $con (keys %{$counts{con_aft_2}{tot}})
1193
+ {
1194
+ ($w1) = split(/\+/, $con) ;
1195
+
1196
+ if (defined $counts{con_aft}{tot}{$w1} &&
1197
+ $counts{con_aft}{tot}{$w1}-$counts{con_aft_2}{tot}{$con} <= 1)
1198
+ {
1199
+ delete $counts{con_aft}{tot}{$w1} ;
1200
+ }
1201
+ }
1202
+
1203
+ foreach $con (keys %{$counts{con_bef_2}{tot}})
1204
+ {
1205
+ ($w_2, $w_1) = split(/\+/, $con) ;
1206
+
1207
+ if (defined $counts{con_bef}{tot}{$w_1} &&
1208
+ $counts{con_bef}{tot}{$w_1}-$counts{con_bef_2}{tot}{$con} <= 1)
1209
+ {
1210
+ delete $counts{con_bef}{tot}{$w_1} ;
1211
+ }
1212
+ }
1213
+
1214
+ foreach $con_pos (keys %{$counts{con_pos_aft_2}{tot}})
1215
+ {
1216
+ ($p1) = split(/\+/, $con_pos) ;
1217
+
1218
+ if (defined($counts{con_pos_aft}{tot}{$p1}) &&
1219
+ $counts{con_pos_aft}{tot}{$p1}-$counts{con_pos_aft_2}{tot}{$con_pos} <= 1)
1220
+ {
1221
+ delete $counts{con_pos_aft}{tot}{$p1} ;
1222
+ }
1223
+ }
1224
+
1225
+ foreach $con_pos (keys %{$counts{con_pos_bef_2}{tot}})
1226
+ {
1227
+ ($p_2, $p_1) = split(/\+/, $con_pos) ;
1228
+
1229
+ if (defined($counts{con_pos_bef}{tot}{$p_1}) &&
1230
+ $counts{con_pos_bef}{tot}{$p_1}-$counts{con_pos_bef_2}{tot}{$con_pos} <= 1)
1231
+ {
1232
+ delete $counts{con_pos_bef}{tot}{$p_1} ;
1233
+ }
1234
+ }
1235
+
1236
+ # for each context type, take the three contexts most involved in errors
1237
+
1238
+ $max_con_len = 0 ;
1239
+
1240
+ filter_context_counts($counts{con_bef_2}{tot}, $con_err_num, \$max_con_len) ;
1241
+
1242
+ filter_context_counts($counts{con_bef}{tot}, $con_err_num, \$max_con_len) ;
1243
+
1244
+ filter_context_counts($counts{con_aft}{tot}, $con_err_num, \$max_con_len) ;
1245
+
1246
+ filter_context_counts($counts{con_aft_2}{tot}, $con_err_num, \$max_con_len) ;
1247
+
1248
+ # for each CPOS context type, take the three CPOS contexts most involved in errors
1249
+
1250
+ $max_con_pos_len = 0 ;
1251
+
1252
+ $thresh = (sort {$b <=> $a} values %{$counts{con_pos_bef_2}{tot}})[$con_err_num-1] ;
1253
+
1254
+ foreach $con_pos (keys %{$counts{con_pos_bef_2}{tot}})
1255
+ {
1256
+ if ($counts{con_pos_bef_2}{tot}{$con_pos} < $thresh)
1257
+ {
1258
+ delete $counts{con_pos_bef_2}{tot}{$con_pos} ;
1259
+ next ;
1260
+ }
1261
+ if (length($con_pos) > $max_con_pos_len)
1262
+ {
1263
+ $max_con_pos_len = length($con_pos) ;
1264
+ }
1265
+ }
1266
+
1267
+ $thresh = (sort {$b <=> $a} values %{$counts{con_pos_bef}{tot}})[$con_err_num-1] ;
1268
+
1269
+ foreach $con_pos (keys %{$counts{con_pos_bef}{tot}})
1270
+ {
1271
+ if ($counts{con_pos_bef}{tot}{$con_pos} < $thresh)
1272
+ {
1273
+ delete $counts{con_pos_bef}{tot}{$con_pos} ;
1274
+ next ;
1275
+ }
1276
+ if (length($con_pos) > $max_con_pos_len)
1277
+ {
1278
+ $max_con_pos_len = length($con_pos) ;
1279
+ }
1280
+ }
1281
+
1282
+ $thresh = (sort {$b <=> $a} values %{$counts{con_pos_aft}{tot}})[$con_err_num-1] ;
1283
+
1284
+ foreach $con_pos (keys %{$counts{con_pos_aft}{tot}})
1285
+ {
1286
+ if ($counts{con_pos_aft}{tot}{$con_pos} < $thresh)
1287
+ {
1288
+ delete $counts{con_pos_aft}{tot}{$con_pos} ;
1289
+ next ;
1290
+ }
1291
+ if (length($con_pos) > $max_con_pos_len)
1292
+ {
1293
+ $max_con_pos_len = length($con_pos) ;
1294
+ }
1295
+ }
1296
+
1297
+ $thresh = (sort {$b <=> $a} values %{$counts{con_pos_aft_2}{tot}})[$con_err_num-1] ;
1298
+
1299
+ foreach $con_pos (keys %{$counts{con_pos_aft_2}{tot}})
1300
+ {
1301
+ if ($counts{con_pos_aft_2}{tot}{$con_pos} < $thresh)
1302
+ {
1303
+ delete $counts{con_pos_aft_2}{tot}{$con_pos} ;
1304
+ next ;
1305
+ }
1306
+ if (length($con_pos) > $max_con_pos_len)
1307
+ {
1308
+ $max_con_pos_len = length($con_pos) ;
1309
+ }
1310
+ }
1311
+
1312
+ # printing
1313
+
1314
+ # ------------- focus words
1315
+
1316
+ printf OUT "\n\n" ;
1317
+ printf OUT " %d focus words where most of the errors occur:\n\n", scalar keys %{$counts{word}{err_any}} ;
1318
+
1319
+ printf OUT " %-*s | %-4s | %-4s | %-4s | %-4s\n", $max_word_len, ' ', 'any', 'head', 'dep', 'both' ;
1320
+ printf OUT " %s-+------+------+------+------\n", '-' x $max_word_len;
1321
+
1322
+ foreach $word (sort {$counts{word}{err_any}{$b} <=> $counts{word}{err_any}{$a}} keys %{$counts{word}{err_any}})
1323
+ {
1324
+ if (!defined($counts{word}{err_head}{$word}))
1325
+ {
1326
+ $counts{word}{err_head}{$word} = 0 ;
1327
+ }
1328
+ if (! defined($counts{word}{err_dep}{$word}))
1329
+ {
1330
+ $counts{word}{err_dep}{$word} = 0 ;
1331
+ }
1332
+ if (! defined($counts{word}{err_any}{$word}))
1333
+ {
1334
+ $counts{word}{err_any}{$word} = 0;
1335
+ }
1336
+ printf OUT " %-*s | %4d | %4d | %4d | %4d\n",
1337
+ $max_word_len+length($word)-uni_len($word), $word, $counts{word}{err_any}{$word},
1338
+ $counts{word}{err_head}{$word},
1339
+ $counts{word}{err_dep}{$word},
1340
+ $counts{word}{err_dep}{$word}+$counts{word}{err_head}{$word}-$counts{word}{err_any}{$word} ;
1341
+ }
1342
+
1343
+ printf OUT " %s-+------+------+------+------\n", '-' x $max_word_len;
1344
+
1345
+ # ------------- contexts
1346
+
1347
+ printf OUT "\n\n" ;
1348
+
1349
+ printf OUT " one-token preceding contexts where most of the errors occur:\n\n" ;
1350
+
1351
+ print_context($counts{con_bef}, $counts{con_pos_bef}, $max_con_len, $max_con_pos_len) ;
1352
+
1353
+ printf OUT " two-token preceding contexts where most of the errors occur:\n\n" ;
1354
+
1355
+ print_context($counts{con_bef_2}, $counts{con_pos_bef_2}, $max_con_len, $max_con_pos_len) ;
1356
+
1357
+ printf OUT " one-token following contexts where most of the errors occur:\n\n" ;
1358
+
1359
+ print_context($counts{con_aft}, $counts{con_pos_aft}, $max_con_len, $max_con_pos_len) ;
1360
+
1361
+ printf OUT " two-token following contexts where most of the errors occur:\n\n" ;
1362
+
1363
+ print_context($counts{con_aft_2}, $counts{con_pos_aft_2}, $max_con_len, $max_con_pos_len) ;
1364
+
1365
+ # ------------- Sentences
1366
+
1367
+ printf OUT " Sentence with the highest number of word errors:\n" ;
1368
+ $i = (sort { (defined($err_sent[$b]{word}) && $err_sent[$b]{word})
1369
+ <=> (defined($err_sent[$a]{word}) && $err_sent[$a]{word}) } 1 .. $sent_num)[0] ;
1370
+ printf OUT " Sentence %d line %d, ", $i, $starts[$i-1] ;
1371
+ printf OUT "%d head errors, %d dependency errors, %d word errors\n",
1372
+ $err_sent[$i]{head}, $err_sent[$i]{dep}, $err_sent[$i]{word} ;
1373
+
1374
+ printf OUT "\n\n" ;
1375
+
1376
+ printf OUT " Sentence with the highest number of head errors:\n" ;
1377
+ $i = (sort { (defined($err_sent[$b]{head}) && $err_sent[$b]{head})
1378
+ <=> (defined($err_sent[$a]{head}) && $err_sent[$a]{head}) } 1 .. $sent_num)[0] ;
1379
+ printf OUT " Sentence %d line %d, ", $i, $starts[$i-1] ;
1380
+ printf OUT "%d head errors, %d dependency errors, %d word errors\n",
1381
+ $err_sent[$i]{head}, $err_sent[$i]{dep}, $err_sent[$i]{word} ;
1382
+
1383
+ printf OUT "\n\n" ;
1384
+
1385
+ printf OUT " Sentence with the highest number of dependency errors:\n" ;
1386
+ $i = (sort { (defined($err_sent[$b]{dep}) && $err_sent[$b]{dep})
1387
+ <=> (defined($err_sent[$a]{dep}) && $err_sent[$a]{dep}) } 1 .. $sent_num)[0] ;
1388
+ printf OUT " Sentence %d line %d, ", $i, $starts[$i-1] ;
1389
+ printf OUT "%d head errors, %d dependency errors, %d word errors\n",
1390
+ $err_sent[$i]{head}, $err_sent[$i]{dep}, $err_sent[$i]{word} ;
1391
+
1392
+ #
1393
+ # Second pass, collect statistics of the frequent errors
1394
+ #
1395
+
1396
+ # filter the errors, leave the most frequent $freq_err_num errors
1397
+
1398
+ $i = 0 ;
1399
+
1400
+ $thresh = (sort {$b <=> $a} values %freq_err)[$freq_err_num-1] ;
1401
+
1402
+ foreach $err (keys %freq_err)
1403
+ {
1404
+ if ($freq_err{$err} < $thresh)
1405
+ {
1406
+ delete $freq_err{$err} ;
1407
+ }
1408
+ }
1409
+
1410
+ # in case there are several errors with the threshold count
1411
+
1412
+ $freq_err_num = scalar keys %freq_err ;
1413
+
1414
+ %err_counts = () ;
1415
+
1416
+ $eof = 0 ;
1417
+
1418
+ seek (GOLD, 0, 0) ;
1419
+ seek (SYS, 0, 0) ;
1420
+
1421
+ while (! $eof)
1422
+ { # second reading loop
1423
+
1424
+ $eof = read_sent(\@sent_gold, \@sent_sys) ;
1425
+ $sent_num++ ;
1426
+
1427
+ $word_num = scalar @sent_gold ;
1428
+
1429
+ # printf "$sent_num $word_num\n" ;
1430
+
1431
+ foreach $i_w (0 .. $word_num-1)
1432
+ { # loop on words
1433
+ ($word, $pos, $head_g, $dep_g)
1434
+ = @{$sent_gold[$i_w]}{'word', 'pos', 'head', 'dep'} ;
1435
+
1436
+ # printf "%d: %s %s %s %s\n", $i_w, $word, $pos, $head_g, $dep_g ;
1437
+
1438
+ if ((! $score_on_punct) && is_uni_punct($word))
1439
+ {
1440
+ # ignore punctuations
1441
+ next ;
1442
+ }
1443
+
1444
+ ($head_s, $dep_s) = @{$sent_sys[$i_w]}{'head', 'dep'} ;
1445
+
1446
+ $err_head = ($head_g ne $head_s) ;
1447
+ $err_dep = ($dep_g ne $dep_s) ;
1448
+
1449
+ $head_err = '-' ;
1450
+ $dep_err = '-' ;
1451
+
1452
+ if ($head_g eq '0')
1453
+ {
1454
+ $head_aft_bef_g = '0' ;
1455
+ }
1456
+ elsif ($head_g eq $i_w+1)
1457
+ {
1458
+ $head_aft_bef_g = 'e' ;
1459
+ }
1460
+ else
1461
+ {
1462
+ $head_aft_bef_g = ($head_g <= $i_w+1 ? 'b' : 'a') ;
1463
+ }
1464
+
1465
+ if ($head_s eq '0')
1466
+ {
1467
+ $head_aft_bef_s = '0' ;
1468
+ }
1469
+ elsif ($head_s eq $i_w+1)
1470
+ {
1471
+ $head_aft_bef_s = 'e' ;
1472
+ }
1473
+ else
1474
+ {
1475
+ $head_aft_bef_s = ($head_s <= $i_w+1 ? 'b' : 'a') ;
1476
+ }
1477
+
1478
+ $head_aft_bef = $head_aft_bef_g.$head_aft_bef_s ;
1479
+
1480
+ if ($err_head)
1481
+ {
1482
+ if ($head_aft_bef_s eq '0')
1483
+ {
1484
+ $head_err = 0 ;
1485
+ }
1486
+ else
1487
+ {
1488
+ $head_err = $head_s-$head_g ;
1489
+ }
1490
+ }
1491
+
1492
+ if ($err_dep)
1493
+ {
1494
+ $dep_err = $dep_g.'->'.$dep_s ;
1495
+ }
1496
+
1497
+ if (! ($err_head || $err_dep))
1498
+ {
1499
+ next ;
1500
+ }
1501
+
1502
+ # handle only the most frequent errors
1503
+
1504
+ $err = $head_err.$sep.$head_aft_bef.$sep.$dep_err ;
1505
+
1506
+ if (! exists $freq_err{$err})
1507
+ {
1508
+ next ;
1509
+ }
1510
+
1511
+ ($w_2, $w_1, $w1, $w2, $p_2, $p_1, $p1, $p2) = get_context(\@sent_gold, $i_w) ;
1512
+
1513
+ $con_bef = $w_1 ;
1514
+ $con_bef_2 = $w_2.' + '.$w_1 ;
1515
+ $con_aft = $w1 ;
1516
+ $con_aft_2 = $w1.' + '.$w2 ;
1517
+
1518
+ $con_pos_bef = $p_1 ;
1519
+ $con_pos_bef_2 = $p_2.'+'.$p_1 ;
1520
+ $con_pos_aft = $p1 ;
1521
+ $con_pos_aft_2 = $p1.'+'.$p2 ;
1522
+
1523
+ @cur_err = ($con_pos_bef, $con_bef, $word, $pos, $con_pos_aft, $con_aft) ;
1524
+
1525
+ # printf "# %-25s %-15s %-10s %-25s %-3s %-30s\n",
1526
+ # $con_bef, $word, $pos, $con_aft, $head_err, $dep_err ;
1527
+
1528
+ @bits = (0, 0, 0, 0, 0, 0) ;
1529
+ $j = 0 ;
1530
+
1531
+ while ($j == 0)
1532
+ {
1533
+ for ($i = 0; $i <= $#bits; $i++)
1534
+ {
1535
+ if ($bits[$i] == 0)
1536
+ {
1537
+ $bits[$i] = 1 ;
1538
+ $j = 0 ;
1539
+ last ;
1540
+ }
1541
+ else
1542
+ {
1543
+ $bits[$i] = 0 ;
1544
+ $j = 1 ;
1545
+ }
1546
+ }
1547
+
1548
+ @e_bits = @cur_err ;
1549
+
1550
+ for ($i = 0; $i <= $#bits; $i++)
1551
+ {
1552
+ if (! $bits[$i])
1553
+ {
1554
+ $e_bits[$i] = '*' ;
1555
+ }
1556
+ }
1557
+
1558
+ # include also the last case which is the most general
1559
+ # (wildcards for everything)
1560
+ $err_counts{$err}{join($sep, @e_bits)}++ ;
1561
+
1562
+ }
1563
+
1564
+ } # loop on words
1565
+ } # second reading loop
1566
+
1567
+ printf OUT "\n\n" ;
1568
+ printf OUT " Specific errors, %d most frequent errors:", $freq_err_num ;
1569
+ printf OUT "\n %s\n", '=' x 41 ;
1570
+
1571
+
1572
+ # deleting local contexts which are too general
1573
+
1574
+ foreach $err (keys %err_counts)
1575
+ {
1576
+ foreach $loc_con (sort {$err_counts{$err}{$b} <=> $err_counts{$err}{$a}}
1577
+ keys %{$err_counts{$err}})
1578
+ {
1579
+ @cur_err = split(/\Q$sep\E/, $loc_con) ;
1580
+
1581
+ # In this loop, one or two elements of the local context are
1582
+ # replaced with '*' to make it more general. If the entry for
1583
+ # the general context has the same count it is removed.
1584
+
1585
+ foreach $i (0 .. $#cur_err)
1586
+ {
1587
+ $w1 = $cur_err[$i] ;
1588
+ if ($cur_err[$i] eq '*')
1589
+ {
1590
+ next ;
1591
+ }
1592
+ $cur_err[$i] = '*' ;
1593
+ $con1 = join($sep, @cur_err) ;
1594
+ if ( defined($err_counts{$err}{$con1}) && defined($err_counts{$err}{$loc_con})
1595
+ && ($err_counts{$err}{$con1} == $err_counts{$err}{$loc_con}))
1596
+ {
1597
+ delete $err_counts{$err}{$con1} ;
1598
+ }
1599
+ for ($j = $i+1; $j <=$#cur_err; $j++)
1600
+ {
1601
+ if ($cur_err[$j] eq '*')
1602
+ {
1603
+ next ;
1604
+ }
1605
+ $w2 = $cur_err[$j] ;
1606
+ $cur_err[$j] = '*' ;
1607
+ $con1 = join($sep, @cur_err) ;
1608
+ if ( defined($err_counts{$err}{$con1}) && defined($err_counts{$err}{$loc_con})
1609
+ && ($err_counts{$err}{$con1} == $err_counts{$err}{$loc_con}))
1610
+ {
1611
+ delete $err_counts{$err}{$con1} ;
1612
+ }
1613
+ $cur_err[$j] = $w2 ;
1614
+ }
1615
+ $cur_err[$i] = $w1 ;
1616
+ }
1617
+ }
1618
+ }
1619
+
1620
+ # Leaving only the topmost local contexts for each error
1621
+
1622
+ foreach $err (keys %err_counts)
1623
+ {
1624
+ $thresh = (sort {$b <=> $a} values %{$err_counts{$err}})[$spec_err_loc_con-1] || 0 ;
1625
+
1626
+ # if the threshold is too low, take the 2nd highest count
1627
+ # (the highest may be the total which is the generic case
1628
+ # and not relevant for printing)
1629
+
1630
+ if ($thresh < 5)
1631
+ {
1632
+ $thresh = (sort {$b <=> $a} values %{$err_counts{$err}})[1] ;
1633
+ }
1634
+
1635
+ foreach $loc_con (keys %{$err_counts{$err}})
1636
+ {
1637
+ if ($err_counts{$err}{$loc_con} < $thresh)
1638
+ {
1639
+ delete $err_counts{$err}{$loc_con} ;
1640
+ }
1641
+ else
1642
+ {
1643
+ if ($loc_con ne join($sep, ('*', '*', '*', '*', '*', '*')))
1644
+ {
1645
+ $loc_con_err_counts{$loc_con}{$err} = $err_counts{$err}{$loc_con} ;
1646
+ }
1647
+ }
1648
+ }
1649
+ }
1650
+
1651
+ # printing an error summary
1652
+
1653
+ # calculating the context field length
1654
+
1655
+ $max_word_spec_len= length('word') ;
1656
+ $max_con_aft_len = length('word') ;
1657
+ $max_con_bef_len = length('word') ;
1658
+ $max_con_pos_len = length('CPOS') ;
1659
+
1660
+ foreach $err (keys %err_counts)
1661
+ {
1662
+ foreach $loc_con (sort keys %{$err_counts{$err}})
1663
+ {
1664
+ ($con_pos_bef, $con_bef, $word, $pos, $con_pos_aft, $con_aft) =
1665
+ split(/\Q$sep\E/, $loc_con) ;
1666
+
1667
+ $l = uni_len($word) ;
1668
+ if ($l > $max_word_spec_len)
1669
+ {
1670
+ $max_word_spec_len = $l ;
1671
+ }
1672
+
1673
+ $l = uni_len($con_bef) ;
1674
+ if ($l > $max_con_bef_len)
1675
+ {
1676
+ $max_con_bef_len = $l ;
1677
+ }
1678
+
1679
+ $l = uni_len($con_aft) ;
1680
+ if ($l > $max_con_aft_len)
1681
+ {
1682
+ $max_con_aft_len = $l ;
1683
+ }
1684
+
1685
+ if (length($con_pos_aft) > $max_con_pos_len)
1686
+ {
1687
+ $max_con_pos_len = length($con_pos_aft) ;
1688
+ }
1689
+
1690
+ if (length($con_pos_bef) > $max_con_pos_len)
1691
+ {
1692
+ $max_con_pos_len = length($con_pos_bef) ;
1693
+ }
1694
+ }
1695
+ }
1696
+
1697
+ $err_counter = 0 ;
1698
+
1699
+ foreach $err (sort {$freq_err{$b} <=> $freq_err{$a}} keys %freq_err)
1700
+ {
1701
+
1702
+ ($head_err, $head_aft_bef, $dep_err) = split(/\Q$sep\E/, $err) ;
1703
+
1704
+ $err_counter++ ;
1705
+ $err_desc{$err} = sprintf("%2d. ", $err_counter).
1706
+ describe_err($head_err, $head_aft_bef, $dep_err) ;
1707
+
1708
+ # printf OUT " %-3s %-30s %d\n", $head_err, $dep_err, $freq_err{$err} ;
1709
+ printf OUT "\n" ;
1710
+ printf OUT " %s : %d times\n", $err_desc{$err}, $freq_err{$err} ;
1711
+
1712
+ printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-+------\n",
1713
+ '-' x $max_con_pos_len, '-' x $max_con_bef_len,
1714
+ '-' x $max_pos_len, '-' x $max_word_spec_len,
1715
+ '-' x $max_con_pos_len, '-' x $max_con_aft_len ;
1716
+
1717
+ printf OUT " %-*s | %-*s | %-*s | %s\n",
1718
+ $max_con_pos_len+$max_con_bef_len+3, ' Before',
1719
+ $max_word_spec_len+$max_pos_len+3, ' Focus',
1720
+ $max_con_pos_len+$max_con_aft_len+3, ' After',
1721
+ 'Count' ;
1722
+
1723
+ printf OUT " %-*s %-*s | %-*s %-*s | %-*s %-*s |\n",
1724
+ $max_con_pos_len, 'CPOS', $max_con_bef_len, 'word',
1725
+ $max_pos_len, 'CPOS', $max_word_spec_len, 'word',
1726
+ $max_con_pos_len, 'CPOS', $max_con_aft_len, 'word' ;
1727
+
1728
+ printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-+------\n",
1729
+ '-' x $max_con_pos_len, '-' x $max_con_bef_len,
1730
+ '-' x $max_pos_len, '-' x $max_word_spec_len,
1731
+ '-' x $max_con_pos_len, '-' x $max_con_aft_len ;
1732
+
1733
+ foreach $loc_con (sort {$err_counts{$err}{$b} <=> $err_counts{$err}{$a}}
1734
+ keys %{$err_counts{$err}})
1735
+ {
1736
+ if ($loc_con eq join($sep, ('*', '*', '*', '*', '*', '*')))
1737
+ {
1738
+ next ;
1739
+ }
1740
+
1741
+ $con1 = $loc_con ;
1742
+ $con1 =~ s/\*/ /g ;
1743
+
1744
+ ($con_pos_bef, $con_bef, $word, $pos, $con_pos_aft, $con_aft) =
1745
+ split(/\Q$sep\E/, $con1) ;
1746
+
1747
+ printf OUT " %-*s | %-*s | %-*s | %-*s | %-*s | %-*s | %3d\n",
1748
+ $max_con_pos_len, $con_pos_bef, $max_con_bef_len+length($con_bef)-uni_len($con_bef), $con_bef,
1749
+ $max_pos_len, $pos, $max_word_spec_len+length($word)-uni_len($word), $word,
1750
+ $max_con_pos_len, $con_pos_aft, $max_con_aft_len+length($con_aft)-uni_len($con_aft), $con_aft,
1751
+ $err_counts{$err}{$loc_con} ;
1752
+ }
1753
+
1754
+ printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-+------\n",
1755
+ '-' x $max_con_pos_len, '-' x $max_con_bef_len,
1756
+ '-' x $max_pos_len, '-' x $max_word_spec_len,
1757
+ '-' x $max_con_pos_len, '-' x $max_con_aft_len ;
1758
+
1759
+ }
1760
+
1761
+ printf OUT "\n\n" ;
1762
+ printf OUT " Local contexts involved in several frequent errors:" ;
1763
+ printf OUT "\n %s\n", '=' x 51 ;
1764
+ printf OUT "\n\n" ;
1765
+
1766
+ foreach $loc_con (sort {scalar keys %{$loc_con_err_counts{$b}} <=>
1767
+ scalar keys %{$loc_con_err_counts{$a}}}
1768
+ keys %loc_con_err_counts)
1769
+ {
1770
+
1771
+ if (scalar keys %{$loc_con_err_counts{$loc_con}} == 1)
1772
+ {
1773
+ next ;
1774
+ }
1775
+
1776
+ printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-\n",
1777
+ '-' x $max_con_pos_len, '-' x $max_con_bef_len,
1778
+ '-' x $max_pos_len, '-' x $max_word_spec_len,
1779
+ '-' x $max_con_pos_len, '-' x $max_con_aft_len ;
1780
+
1781
+ printf OUT " %-*s | %-*s | %-*s \n",
1782
+ $max_con_pos_len+$max_con_bef_len+3, ' Before',
1783
+ $max_word_spec_len+$max_pos_len+3, ' Focus',
1784
+ $max_con_pos_len+$max_con_aft_len+3, ' After' ;
1785
+
1786
+ printf OUT " %-*s %-*s | %-*s %-*s | %-*s %-*s \n",
1787
+ $max_con_pos_len, 'CPOS', $max_con_bef_len, 'word',
1788
+ $max_pos_len, 'CPOS', $max_word_spec_len, 'word',
1789
+ $max_con_pos_len, 'CPOS', $max_con_aft_len, 'word' ;
1790
+
1791
+ printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-\n",
1792
+ '-' x $max_con_pos_len, '-' x $max_con_bef_len,
1793
+ '-' x $max_pos_len, '-' x $max_word_spec_len,
1794
+ '-' x $max_con_pos_len, '-' x $max_con_aft_len ;
1795
+
1796
+ $con1 = $loc_con ;
1797
+ $con1 =~ s/\*/ /g ;
1798
+
1799
+ ($con_pos_bef, $con_bef, $word, $pos, $con_pos_aft, $con_aft) =
1800
+ split(/\Q$sep\E/, $con1) ;
1801
+
1802
+ printf OUT " %-*s | %-*s | %-*s | %-*s | %-*s | %-*s \n",
1803
+ $max_con_pos_len, $con_pos_bef, $max_con_bef_len+length($con_bef)-uni_len($con_bef), $con_bef,
1804
+ $max_pos_len, $pos, $max_word_spec_len+length($word)-uni_len($word), $word,
1805
+ $max_con_pos_len, $con_pos_aft, $max_con_aft_len+length($con_aft)-uni_len($con_aft), $con_aft ;
1806
+
1807
+ printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-\n",
1808
+ '-' x $max_con_pos_len, '-' x $max_con_bef_len,
1809
+ '-' x $max_pos_len, '-' x $max_word_spec_len,
1810
+ '-' x $max_con_pos_len, '-' x $max_con_aft_len ;
1811
+
1812
+ foreach $err (sort {$loc_con_err_counts{$loc_con}{$b} <=>
1813
+ $loc_con_err_counts{$loc_con}{$a}}
1814
+ keys %{$loc_con_err_counts{$loc_con}})
1815
+ {
1816
+ printf OUT " %s : %d times\n", $err_desc{$err},
1817
+ $loc_con_err_counts{$loc_con}{$err} ;
1818
+ }
1819
+
1820
+ printf OUT "\n" ;
1821
+ }
1822
+
1823
+ close GOLD ;
1824
+ close SYS ;
1825
+
1826
+ close OUT ;
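The Perl file above appears to be the CoNLL shared-task dependency evaluation script (the blocks marked "added by Sabine Buchholz" come from that script); this final part writes the per-CPOSTAG error breakdown, DEPREL precision/recall, context statistics and the most frequent specific errors to the filehandle OUT. A minimal sketch for driving it from Python, assuming the script lives at examples/eval/conll06eval.pl as in this upload and keeps the usual -g/-s/-o options of the eval.pl family (the run_conll_eval wrapper is hypothetical; check the usage block near the top of the script for the exact flags):

```
# Sketch only: run the bundled CoNLL evaluation script on a gold/system pair.
# The -g/-s/-o flags are assumed from the standard eval.pl interface.
import subprocess

def run_conll_eval(gold_path, system_path, report_path):
    subprocess.check_call([
        "perl", "examples/eval/conll06eval.pl",
        "-g", gold_path, "-s", system_path, "-o", report_path,
    ])
```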
examples/macro_UAS_LAS.py ADDED
@@ -0,0 +1,107 @@
+ import sys
+
+
+ def load_results(filename):
+
+     results = []
+     sent = []
+     with open(filename, 'r') as fp:
+         for i, line in enumerate(fp):
+             if i == 0:
+                 continue
+             splits = line.strip().split('\t')
+             if len(line.strip()) == 0:
+                 if len(sent) != 0:
+                     results.append(sent)
+                     sent = []
+                 continue
+             gold_head = splits[-4]
+             gold_label = splits[-3]
+             pred_head = splits[-2]
+             pred_label = splits[-1]
+             sent.append((gold_head, gold_label, pred_head, pred_label))
+     print('Total Number of sentences ' + str(len(results)))
+     return results
+
+ def calculate_las_uas(gold_heads, gold_labels, pred_heads, pred_labels):
+
+     u_correct = 0
+     l_correct = 0
+     u_total = 0
+     l_total = 0
+
+     for i in range(len(gold_heads)):
+         if gold_heads[i] == pred_heads[i]:
+             u_correct +=1
+         u_total +=1
+         l_total +=1
+         if gold_heads[i] == pred_heads[i] and gold_labels[i] == pred_labels[i]:
+             l_correct +=1
+     return u_correct, u_total, l_correct, l_total
+
+
+ def calculate_stats(results,path):
+     u_correct = 0
+     l_correct = 0
+     u_total = 0
+     l_total = 0
+
+     sent_uas = []
+     sent_las = []
+
+     for i in range(len(results)):
+         gold_heads, gold_labels, pred_heads, pred_labels = zip(*results[i])
+         u_c, u_t, l_c, l_t = calculate_las_uas(gold_heads, gold_labels, pred_heads, pred_labels)
+         if u_t >0:
+             uas = float(u_c)/u_t
+             las = float(l_c)/l_t
+             sent_uas.append(uas)
+             sent_las.append(las)
+         u_correct += u_c
+         l_correct += l_c
+         u_total += u_t
+         l_total += l_t
+
+     UAS = float(u_correct)/u_total
+     LAS = float(l_correct)/l_total
+     path = path.replace('combined_1300_test.txt','Macro-UAS-LAS-score.txt')
+     f = open(path,'w')
+     f.write('Word level UAS : ' + str(UAS) +'\n')
+     f.write('Word level LAS : ' + str(LAS)+'\n')
+     f.write('Sentence level UAS : ' + str(float(sum(sent_uas))/len(sent_uas))+'\n')
+     f.write('Sentence level LAS : ' + str(float(sum(sent_las))/len(sent_las))+'\n')
+     f.close()
+     print('Word level UAS : ' + str(UAS))
+     print('Word level LAS : ' + str(LAS))
+     print('Sentence level UAS : ' + str(float(sum(sent_uas))/len(sent_uas)))
+     print('Sentence level LAS : ' + str(float(sum(sent_las))/len(sent_las)))
+
+     return sent_uas, sent_las, UAS, LAS
+
+ def write_results(sent_uas, sent_las, filename_uas, filename_las):
+
+     fp_uas = open(filename_uas, 'w')
+     fp_las = open(filename_las, 'w')
+
+     for i in range(len(sent_uas)):
+         fp_uas.write(str(sent_uas[i]) + '\n')
+         fp_las.write(str(sent_las[i]) + '\n')
+
+     fp_uas.close()
+     fp_las.close()
+
+
+ if __name__=="__main__":
+     dirs = sys.argv[1]
+     # results_2 = load_results(sys.argv[2])
+     ##path = "Predictions/Yap/"+dirs
+     path = "./saved_models/"+dirs+"/final_ensembled/combined_1300_test.txt"
+     result = load_results(path)
+
+
+     sent_uas1, sent_las1, UAS1, LAS1 = calculate_stats(result,path)
+     # sent_uas2, sent_las2, UAS2, LAS2 = calculate_stats(results_2)
+
+
+     write_results(sent_uas1, sent_las1, 'results1_uas.txt', 'results1_las.txt')
+     # write_results(sent_uas2, sent_las2, 'results2_uas.txt', 'results2_las.txt')
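A note on the two sets of numbers this scorer prints: the "Word level" UAS/LAS are micro-averages over all tokens of combined_1300_test.txt (correct heads divided by all scored words), while the "Sentence level" scores are macro-averages, i.e. the mean of the per-sentence accuracies collected in sent_uas / sent_las. A toy illustration with made-up counts shows how the two can diverge:

```
# Two sentences: 4 of 5 heads correct in the first, 1 of 1 in the second.
correct = [4, 1]
total = [5, 1]

micro_uas = float(sum(correct)) / sum(total)   # 5/6 = 0.8333...  ("Word level")
macro_uas = sum(float(c) / t for c, t in zip(correct, total)) / len(total)
# (0.8 + 1.0) / 2 = 0.9                        ("Sentence level")
```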
examples/write_1300_combined.py ADDED
@@ -0,0 +1,48 @@
+ import sys
+
+ def write_combined(dirs):
+     path = "./saved_models/"+dirs+"/final_ensembled/"
+     f = open(path+'domain_san_test_model_domain_san_data_domain_san_gold.txt','r')
+     gold = f.readlines()
+     f.close()
+     f = open(path+'domain_san_test_model_domain_san_data_domain_san_pred.txt','r')
+     pred = f.readlines()
+     f.close()
+
+     for i in range(len(gold)):
+         if gold[i] == '\n':
+             continue
+         if gold[i].split('\t')[0] == pred[i].split('\t')[0]:
+             gold[i] = gold[i].replace('\n','\t')
+             gold[i] = gold[i]+'\t'.join(pred[i].split('\t')[-2:])
+
+     f = open(path+'domain_san_prose_model_domain_san_data_domain_san_gold.txt','r')
+     prose_gold = f.readlines()
+     f.close()
+     f = open(path+'domain_san_prose_model_domain_san_data_domain_san_pred.txt','r')
+     prose_pred = f.readlines()
+     f.close()
+
+     for i in range(len(prose_gold)):
+         if prose_gold[i] == '\n':
+             gold.append('\n')
+             continue
+         if prose_gold[i].split('\t')[0] == prose_pred[i].split('\t')[0]:
+             line = prose_gold[i].replace('\n','\t')
+             line = line+'\t'.join(prose_pred[i].split('\t')[-2:])
+             gold.append(line)
+     gold.insert(0,'word_id\tword\tpostag\tlemma\tgold_head\tgold_label\tpred_head\tpred_label\n\n')
+
+
+     f = open(path+'combined_1300_test.txt','w')
+     for line in gold:
+         f.write(line)
+     f.close()
+
+
+ if __name__=="__main__":
+
+     dir_path = sys.argv[1]
+
+     write_combined(dir_path)
+
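The merged file written above is what examples/macro_UAS_LAS.py consumes: a one-line header, tab-separated rows with the system's head and label appended as the last two columns of each gold row, and a blank line between sentences. A small sanity check, sketched here with a hypothetical helper name, that only verifies the four trailing fields the scorer actually reads:

```
# Sketch: every data row of combined_1300_test.txt should end in
# gold_head, gold_label, pred_head, pred_label.
def check_combined(path):
    with open(path) as fp:
        next(fp)  # skip the header row inserted by write_combined()
        for n, line in enumerate(fp, start=2):
            if not line.strip():
                continue  # blank line = sentence boundary
            fields = line.rstrip('\n').split('\t')
            assert len(fields) >= 4, 'row %d has only %d fields' % (n, len(fields))
```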
run_STBC.sh ADDED
@@ -0,0 +1,75 @@
1
+
2
+ domain="san"
3
+ # Path of pretrained embedding file
4
+ word_path=data/cc.STBC.300.txt
5
+ # Path to store model and predictions
6
+ saved_models=saved_models
7
+ declare -i num_epochs=100
8
+ declare -i word_dim=300
9
+ start_time=`date +%s`
10
+ current_time=$(date "+%Y.%m.%d-%H.%M.%S")
11
+ model_path="STBC_"$current_time
12
+ touch $saved_models/base_log.txt
13
+
14
+ ###############################################################
15
+ # Running the base Biaffine Parser
16
+ echo "#################################################################"
17
+ echo "Currently base model in progress..."
18
+ echo "#################################################################"
19
+ python examples/GraphParser_MTL_POS.py --dataset ud --domain $domain --rnn_mode LSTM \
20
+ --num_epochs $num_epochs --batch_size 16 --hidden_size 512 --arc_space 512 \
21
+ --arc_tag_space 128 --num_layers 2 --num_filters 100 --use_char --use_pos \
22
+ --word_dim $word_dim --char_dim 100 --pos_dim 100 --initializer xavier --opt adam \
23
+ --learning_rate 0.002 --decay_rate 0.5 --schedule 6 --clip 5.0 --gamma 0.0 \
24
+ --epsilon 1e-6 --p_rnn 0.33 0.33 --p_in 0.33 --p_out 0.33 --arc_decode mst \
25
+ --punct_set '.' '``' ':' ',' --word_embedding fasttext --char_embedding random --pos_embedding random --word_path $word_path \
26
+ --model_path $saved_models/$model_path 2>&1 | tee $saved_models/base_log.txt
27
+
28
+ mv $saved_models/base_log.txt $saved_models/$model_path/base_log.txt
29
+ python examples/BiAFF_write_1300_combined.py $model_path
30
+ python examples/BiAFF_macro_UAS_LAS.py $model_path
31
+
32
+ # ###################################################################
33
+ # Pretraining step: Running the Sequence Tagger
34
+ # Auxiliary tasks : 'Multitask_case_predict' 'Multitask_POS_predict' 'Multitask_label_predict'
35
+ for task in 'Multitask_POS_predict' 'Multitask_label_predict' 'Multitask_case_predict'; do
36
+ touch $saved_models/$model_path/log.txt
37
+ echo "#################################################################"
38
+ echo "Currently $task in progress..."
39
+ echo "#################################################################"
40
+ python examples/SequenceTagger.py --dataset ud --domain $domain --task $task \
41
+ --rnn_mode LSTM --num_epochs $num_epochs --batch_size 16 --hidden_size 512 \
42
+ --tag_space 128 --num_layers 2 --num_filters 100 --use_char --use_pos --char_dim 100 \
43
+ --pos_dim 100 --initializer xavier --opt adam --learning_rate 0.002 --decay_rate 0.5 \
44
+ --schedule 6 --clip 5.0 --gamma 0.0 --epsilon 1e-6 --p_rnn 0.33 0.33 \
45
+ --p_in 0.33 --p_out 0.33 --punct_set '.' '``' ':' ',' \
46
+ --word_dim $word_dim --word_embedding fasttext --word_path $word_path --pos_embedding random \
47
+ --parser_path $saved_models/$model_path/ \
48
+ --use_unlabeled_data --char_embedding random \
49
+ --model_path $saved_models/$model_path/$task/ 2>&1 | tee $saved_models/$model_path/log.txt
50
+ mv $saved_models/$model_path/log.txt $saved_models/$model_path/$task/log.txt
51
+ done
52
+
53
+ ######################################################################
54
+ ## Integration step: The final ensembled proposed system
55
+ echo "#################################################################"
56
+ echo "Currently final model in progress..."
57
+ echo "#################################################################"
58
+ touch $saved_models/$model_path/log.txt
59
+ python examples/GraphParser_MTL_POS.py --dataset ud --domain $domain --rnn_mode LSTM \
60
+ --num_epochs $num_epochs --batch_size 16 --hidden_size 512 \
61
+ --arc_space 512 --arc_tag_space 128 --num_layers 2 --num_filters 100 --use_char --use_pos \
62
+ --word_dim $word_dim --char_dim 100 --pos_dim 100 --initializer xavier --opt adam \
63
+ --learning_rate 0.002 --decay_rate 0.5 --schedule 6 --clip 5.0 --gamma 0.0 --epsilon 1e-6 \
64
+ --p_rnn 0.33 0.33 --p_in 0.33 --p_out 0.33 --arc_decode mst --pos_embedding random \
65
+ --punct_set '.' '``' ':' ',' --word_embedding fasttext --char_embedding random --word_path $word_path \
66
+ --gating --num_gates 4 \
67
+ --load_sequence_taggers_paths $saved_models/$model_path/Multitask_case_predict/domain_$domain.pt \
68
+ $saved_models/$model_path/Multitask_POS_predict/domain_$domain.pt \
69
+ $saved_models/$model_path/Multitask_label_predict/domain_$domain.pt \
70
+ --model_path $saved_models/$model_path/final_ensembled 2>&1 | tee $saved_models/$model_path/log.txt
71
+ mv $saved_models/$model_path/log.txt $saved_models/$model_path/final_ensembled/log.txt
72
+ python examples/write_1300_combined.py $model_path
73
+ python examples/macro_UAS_LAS.py $model_path
74
+ end_time=`date +%s`
75
+ echo execution time was `expr $end_time - $start_time` s.
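Once run_STBC.sh finishes, the two scorer calls near the end leave combined_1300_test.txt and Macro-UAS-LAS-score.txt under saved_models/STBC_<timestamp>/final_ensembled/. A minimal sketch for reading the summary back into a dict (the read_scores helper is hypothetical; the path layout and the "metric : value" line format are taken from examples/macro_UAS_LAS.py above):

```
# Sketch: parse the 'metric : value' lines written by examples/macro_UAS_LAS.py.
def read_scores(model_dir, root="./saved_models"):
    path = root + "/" + model_dir + "/final_ensembled/Macro-UAS-LAS-score.txt"
    scores = {}
    with open(path) as fp:
        for line in fp:
            if ' : ' in line:
                name, value = line.strip().split(' : ', 1)
                scores[name] = float(value)
    return scores

# e.g. read_scores("STBC_2022.01.01-00.00.00")["Word level UAS"]
```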