jiangjiechen commited on
Commit
7f7285f
1 Parent(s): 081073f

init loren for spaces

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +225 -0
  2. app.py +101 -3
  3. cjjpy.py +249 -0
  4. docs/front.png +0 -0
  5. requirements.txt +28 -0
  6. src/available_models/aaai22_roberta.json +0 -0
  7. src/check_client/cjjpy.py +249 -0
  8. src/check_client/fact_checker.py +209 -0
  9. src/check_client/modules/cjjpy.py +249 -0
  10. src/check_client/modules/data_processor.py +354 -0
  11. src/check_client/modules/test_data_processor.py +26 -0
  12. src/check_client/plm_checkers/__init__.py +12 -0
  13. src/check_client/plm_checkers/bert_checker.py +203 -0
  14. src/check_client/plm_checkers/checker_utils.py +223 -0
  15. src/check_client/plm_checkers/roberta_checker.py +203 -0
  16. src/check_client/scripts/train_bert-large.sh +51 -0
  17. src/check_client/scripts/train_roberta.sh +51 -0
  18. src/check_client/train.py +647 -0
  19. src/check_client/utils.py +131 -0
  20. src/cjjpy.py +249 -0
  21. src/dataloaders.py +134 -0
  22. src/er_client/__init__.py +63 -0
  23. src/er_client/cjjpy.py +249 -0
  24. src/er_client/doc_retrieval_by_api.py +44 -0
  25. src/er_client/document_retrieval.py +225 -0
  26. src/er_client/entitylinker.py +84 -0
  27. src/er_client/retrieval_model/bert_model.py +775 -0
  28. src/er_client/retrieval_model/data_loader.py +276 -0
  29. src/er_client/retrieval_model/file_utils.py +249 -0
  30. src/er_client/retrieval_model/models.py +66 -0
  31. src/er_client/retrieval_model/process_data.py +41 -0
  32. src/er_client/retrieval_model/test.py +81 -0
  33. src/er_client/retrieval_model/test.sh +7 -0
  34. src/er_client/sentence_selection.py +54 -0
  35. src/eval_client/cjjpy.py +249 -0
  36. src/eval_client/culpa.py +61 -0
  37. src/eval_client/culprit/eval.human.ref.json +100 -0
  38. src/eval_client/fever_scorer.py +84 -0
  39. src/eval_client/scorer.py +153 -0
  40. src/loren.py +167 -0
  41. src/mrc_client/answer_generator.py +144 -0
  42. src/mrc_client/cjjpy.py +249 -0
  43. src/mrc_client/seq2seq/README.md +590 -0
  44. src/mrc_client/seq2seq/__init__.py +5 -0
  45. src/mrc_client/seq2seq/callbacks.py +115 -0
  46. src/mrc_client/seq2seq/cjjpy.py +249 -0
  47. src/mrc_client/seq2seq/convert_pl_checkpoint_to_hf.py +74 -0
  48. src/mrc_client/seq2seq/finetune.py +465 -0
  49. src/mrc_client/seq2seq/finetune_t5.sh +14 -0
  50. src/mrc_client/seq2seq/finetune_trainer.py +303 -0
.gitignore ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Created by .ignore support plugin (hsz.mobi)
2
+ ### macOS template
3
+ # General
4
+ .DS_Store
5
+ .AppleDouble
6
+ .LSOverride
7
+
8
+ # Icon must end with two \r
9
+ Icon
10
+
11
+ # Thumbnails
12
+ ._*
13
+
14
+ # Files that might appear in the root of a volume
15
+ .DocumentRevisions-V100
16
+ .fseventsd
17
+ .Spotlight-V100
18
+ .TemporaryItems
19
+ .Trashes
20
+ .VolumeIcon.icns
21
+ .com.apple.timemachine.donotpresent
22
+
23
+ # Directories potentially created on remote AFP share
24
+ .AppleDB
25
+ .AppleDesktop
26
+ Network Trash Folder
27
+ Temporary Items
28
+ .apdisk
29
+ ### Python template
30
+ # Byte-compiled / optimized / DLL files
31
+ __pycache__/
32
+ *.py[cod]
33
+ *$py.class
34
+
35
+ # C extensions
36
+ *.so
37
+
38
+ # Distribution / packaging
39
+ .Python
40
+ build/
41
+ develop-eggs/
42
+ dist/
43
+ downloads/
44
+ eggs/
45
+ .eggs/
46
+ lib/
47
+ lib64/
48
+ parts/
49
+ sdist/
50
+ var/
51
+ wheels/
52
+ *.egg-info/
53
+ .installed.cfg
54
+ *.egg
55
+ MANIFEST
56
+
57
+ # PyInstaller
58
+ # Usually these files are written by a python script from a template
59
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
60
+ *.manifest
61
+ *.spec
62
+
63
+ # Installer logs
64
+ pip-log.txt
65
+ pip-delete-this-directory.txt
66
+
67
+ # Unit test / coverage reports
68
+ htmlcov/
69
+ .tox/
70
+ .coverage
71
+ .coverage.*
72
+ .cache
73
+ nosetests.xml
74
+ coverage.xml
75
+ *.cover
76
+ .hypothesis/
77
+ .pytest_cache/
78
+
79
+ # Translations
80
+ *.mo
81
+ *.pot
82
+
83
+ # Django stuff:
84
+ *.log
85
+ local_settings.py
86
+ db.sqlite3
87
+
88
+ # Flask stuff:
89
+ instance/
90
+ .webassets-cache
91
+
92
+ # Scrapy stuff:
93
+ .scrapy
94
+
95
+ # Sphinx documentation
96
+ docs/_build/
97
+
98
+ # PyBuilder
99
+ target/
100
+
101
+ # Jupyter Notebook
102
+ .ipynb_checkpoints
103
+
104
+ # pyenv
105
+ .python-version
106
+
107
+ # celery beat schedule file
108
+ celerybeat-schedule
109
+
110
+ # SageMath parsed files
111
+ *.sage.py
112
+
113
+ # Environments
114
+ .env
115
+ .venv
116
+ env/
117
+ venv/
118
+ ENV/
119
+ env.bak/
120
+ venv.bak/
121
+
122
+ # Spyder project settings
123
+ .spyderproject
124
+ .spyproject
125
+
126
+ # Rope project settings
127
+ .ropeproject
128
+
129
+ # mkdocs documentation
130
+ /site
131
+
132
+ # mypy
133
+ .mypy_cache/
134
+ ### JetBrains template
135
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
136
+ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
137
+
138
+ # User-specific stuff
139
+ .idea/**/workspace.xml
140
+ .idea/**/tasks.xml
141
+ .idea/**/usage.statistics.xml
142
+ .idea/**/dictionaries
143
+ .idea/**/shelf
144
+
145
+ # Sensitive or high-churn files
146
+ .idea/**/dataSources/
147
+ .idea/**/dataSources.ids
148
+ .idea/**/dataSources.local.xml
149
+ .idea/**/sqlDataSources.xml
150
+ .idea/**/dynamic.xml
151
+ .idea/**/uiDesigner.xml
152
+ .idea/**/dbnavigator.xml
153
+
154
+ # Gradle
155
+ .idea/**/gradle.xml
156
+ .idea/**/libraries
157
+
158
+ # Gradle and Maven with auto-import
159
+ # When using Gradle or Maven with auto-import, you should exclude module files,
160
+ # since they will be recreated, and may cause churn. Uncomment if using
161
+ # auto-import.
162
+ # .idea/modules.xml
163
+ # .idea/*.iml
164
+ # .idea/modules
165
+
166
+ # CMake
167
+ cmake-build-*/
168
+
169
+ # Mongo Explorer plugin
170
+ .idea/**/mongoSettings.xml
171
+
172
+ # File-based project format
173
+ *.iws
174
+
175
+ # IntelliJ
176
+ out/
177
+
178
+ # mpeltonen/sbt-idea plugin
179
+ .idea_modules/
180
+
181
+ # JIRA plugin
182
+ atlassian-ide-plugin.xml
183
+
184
+ # Cursive Clojure plugin
185
+ .idea/replstate.xml
186
+
187
+ # Crashlytics plugin (for Android Studio and IntelliJ)
188
+ com_crashlytics_export_strings.xml
189
+ crashlytics.properties
190
+ crashlytics-build.properties
191
+ fabric.properties
192
+
193
+ # Editor-based Rest Client
194
+ .idea/httpRequests
195
+ ### VirtualEnv template
196
+ # Virtualenv
197
+ # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
198
+ .Python
199
+ [Bb]in
200
+ [Ii]nclude
201
+ [Ll]ib
202
+ [Ll]ib64
203
+ [Ll]ocal
204
+ pyvenv.cfg
205
+ .venv
206
+ pip-selfcheck.json
207
+
208
+ .idea/
209
+ eden.py
210
+ /_tmp/
211
+ runs
212
+ *nohup*
213
+ *.pt
214
+ *.out
215
+ *.pkl
216
+ *.db
217
+ /cache/
218
+ output/
219
+ *.csv
220
+ *_resources/
221
+ *_proc
222
+ lightning_logs/
223
+ wandb/
224
+ .lock
225
+ *gradio*
app.py CHANGED
@@ -1,7 +1,105 @@
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  iface.launch()
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ @Author : Jiangjie Chen
5
+ @Time : 2021/12/13 17:17
6
+ @Contact : jjchen19@fudan.edu.cn
7
+ @Description:
8
+ """
9
+
10
+ import os
11
  import gradio as gr
12
+ from src.loren import Loren
13
+ from huggingface_hub import snapshot_download
14
+ from prettytable import PrettyTable
15
+ import pandas as pd
16
+
17
+ config = {
18
+ "input": "demo",
19
+ "model_type": "roberta",
20
+ "model_name_or_path": "roberta-large",
21
+ "logic_lambda": 0.5,
22
+ "prior": "random",
23
+ "mask_rate": 0.0,
24
+ "cand_k": 3,
25
+ "max_seq2_length": 256,
26
+ "max_seq1_length": 128,
27
+ "max_num_questions": 8
28
+ }
29
+
30
+ model_dir = snapshot_download('Jiangjie/loren')
31
+
32
+ config['fc_dir'] = os.path.join(model_dir, 'fact_checking/roberta-large/')
33
+ config['mrc_dir'] = os.path.join(model_dir, 'mrc_seq2seq/bart-base/')
34
+ config['er_dir'] = os.path.join(model_dir, 'evidence_retrieval/')
35
+
36
+ loren = Loren(config)
37
+ try:
38
+ # js = {
39
+ # 'id': 0,
40
+ # 'evidence': ['EVIDENCE1', 'EVIDENCE2'],
41
+ # 'question': ['QUESTION1', 'QUESTION2'],
42
+ # 'claim_phrase': ['CLAIMPHRASE1', 'CLAIMPHRASE2'],
43
+ # 'local_premise': [['E1 ' * 100, 'E1' * 100, 'E1' * 10], ['E2', 'E2', 'E2']],
44
+ # 'phrase_veracity': [[0.1, 0.5, 0.4], [0.1, 0.7, 0.2]],
45
+ # 'claim_veracity': 'SUPPORT'
46
+ # }
47
+ js = loren.check('Donald Trump won the 2020 U.S. presidential election.')
48
+ except Exception as e:
49
+ raise ValueError(e)
50
+
51
+
52
+ def gradio_formatter(js, output_type):
53
+ if output_type == 'e':
54
+ data = {'Evidence': js['evidence']}
55
+ elif output_type == 'z':
56
+ data = {
57
+ 'Claim Phrase': js['claim_phrase'],
58
+ 'Local Premise': [x[0] for x in js['local_premise']],
59
+ 'p_SUP': [round(x[2], 4) for x in js['phrase_veracity']],
60
+ 'p_REF': [round(x[0], 4) for x in js['phrase_veracity']],
61
+ 'p_NEI': [round(x[1], 4) for x in js['phrase_veracity']],
62
+ }
63
+ else:
64
+ raise NotImplementedError
65
+ data = pd.DataFrame(data)
66
+ pt = PrettyTable(field_names=list(data.columns))
67
+ for v in data.values:
68
+ pt.add_row(v)
69
+
70
+ html = pt.get_html_string(attributes={
71
+ 'style': 'border-width: 1px; border-collapse: collapse',
72
+ }, format=True)
73
+ return html
74
+
75
+
76
+ def run(claim):
77
+ js = loren.check(claim)
78
+ ev_html = gradio_formatter(js, 'e')
79
+ z_html = gradio_formatter(js, 'z')
80
+ return ev_html, z_html, js['claim_veracity'], js
81
 
 
 
82
 
83
+ iface = gr.Interface(
84
+ fn=run,
85
+ inputs="text",
86
+ outputs=[
87
+ 'html',
88
+ 'html',
89
+ 'label',
90
+ 'json'
91
+ ],
92
+ examples=['Donald Trump won the U.S. 2020 presidential election.',
93
+ 'The first inauguration of Bill Clinton was in the United States.'],
94
+ title="LOREN",
95
+ layout='vertical',
96
+ description="LOREN is an interpretable Fact Verification model against Wikipedia. "
97
+ "This is a demo system for \"LOREN: Logic-Regularized Reasoning for Interpretable Fact Verification\". "
98
+ "See the paper for technical details. You can add FLAG on the bottom to record interesting or bad cases!",
99
+ flagging_dir='results/flagged/',
100
+ allow_flagging=True,
101
+ flagging_options=['Good Case!', 'Error: MRC', 'Error: Parsing',
102
+ 'Error: Commonsense', 'Error: Evidence', 'Error: Other'],
103
+ enable_queue=True
104
+ )
105
  iface.launch()
cjjpy.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ @Author : Jiangjie Chen
5
+ @Time : 2018/11/15 17:08
6
+ @Contact: jjchen19@fudan.edu.cn
7
+ '''
8
+
9
+ import re
10
+ import datetime
11
+ import os
12
+ import argparse
13
+ import logging
14
+ import traceback
15
+
16
+ try:
17
+ import ujson as json
18
+ except:
19
+ import json
20
+
21
+ HADOOP_BIN = 'PATH=/usr/bin/:$PATH hdfs'
22
+ FOR_PUBLIC = True
23
+
24
+
25
+ def LengthStats(filename):
26
+ len_list = []
27
+ thresholds = [0.8, 0.9, 0.95, 0.99, 0.999]
28
+ with open(filename) as f:
29
+ for line in f:
30
+ len_list.append(len(line.strip().split()))
31
+ stats = {
32
+ 'Max': max(len_list),
33
+ 'Min': min(len_list),
34
+ 'Avg': round(sum(len_list) / len(len_list), 4),
35
+ }
36
+ len_list.sort()
37
+ for t in thresholds:
38
+ stats[f"Top-{t}"] = len_list[int(len(len_list) * t)]
39
+
40
+ for k in stats:
41
+ print(f"- {k}: {stats[k]}")
42
+ return stats
43
+
44
+
45
+ class AttrDict(dict):
46
+ def __init__(self, *args, **kwargs):
47
+ super(AttrDict, self).__init__(*args, **kwargs)
48
+ self.__dict__ = self
49
+
50
+
51
+ def TraceBack(error_msg):
52
+ exc = traceback.format_exc()
53
+ msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}'
54
+ return msg
55
+
56
+
57
+ def Now():
58
+ return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
59
+
60
+
61
+ def AbsParentDir(file, parent='..', postfix=None):
62
+ ppath = os.path.abspath(file)
63
+ parent_level = parent.count('.')
64
+ while parent_level > 0:
65
+ ppath = os.path.dirname(ppath)
66
+ parent_level -= 1
67
+ if postfix is not None:
68
+ return os.path.join(ppath, postfix)
69
+ else:
70
+ return ppath
71
+
72
+
73
+ def init_logger(log_file=None, log_file_level=logging.NOTSET, from_scratch=False):
74
+ from coloredlogs import ColoredFormatter
75
+ import tensorflow as tf
76
+
77
+ fmt = "[%(asctime)s %(levelname)s] %(message)s"
78
+ log_format = ColoredFormatter(fmt=fmt)
79
+ # log_format = logging.Formatter()
80
+ logger = logging.getLogger()
81
+ logger.setLevel(log_file_level)
82
+
83
+ console_handler = logging.StreamHandler()
84
+ console_handler.setFormatter(log_format)
85
+ logger.handlers = [console_handler]
86
+
87
+ if log_file and log_file != '':
88
+ if from_scratch and tf.io.gfile.exists(log_file):
89
+ logger.warning('Removing previous log file: %s' % log_file)
90
+ tf.io.gfile.remove(log_file)
91
+ path = os.path.dirname(log_file)
92
+ os.makedirs(path, exist_ok=True)
93
+ file_handler = logging.FileHandler(log_file)
94
+ file_handler.setLevel(log_file_level)
95
+ file_handler.setFormatter(log_format)
96
+ logger.addHandler(file_handler)
97
+
98
+ return logger
99
+
100
+
101
+ def OverWriteCjjPy(root='.'):
102
+ # import difflib
103
+ # diff = difflib.HtmlDiff()
104
+ cnt = 0
105
+ golden_cjjpy = os.path.join(root, 'cjjpy.py')
106
+ # golden_content = open(golden_cjjpy).readlines()
107
+ for dir, folder, file in os.walk(root):
108
+ for f in file:
109
+ if f == 'cjjpy.py':
110
+ cjjpy = '%s/%s' % (dir, f)
111
+ # content = open(cjjpy).readlines()
112
+ # d = diff.make_file(golden_content, content)
113
+ cnt += 1
114
+ print('[%d]: %s' % (cnt, cjjpy))
115
+ os.system('cp %s %s' % (golden_cjjpy, cjjpy))
116
+
117
+
118
+ def ChangeFileFormat(filename, new_fmt):
119
+ assert type(filename) is str and type(new_fmt) is str
120
+ spt = filename.split('.')
121
+ if len(spt) == 0:
122
+ return filename
123
+ else:
124
+ return filename.replace('.' + spt[-1], new_fmt)
125
+
126
+
127
+ def CountLines(fname):
128
+ with open(fname, 'rb') as f:
129
+ count = 0
130
+ last_data = '\n'
131
+ while True:
132
+ data = f.read(0x400000)
133
+ if not data:
134
+ break
135
+ count += data.count(b'\n')
136
+ last_data = data
137
+ if last_data[-1:] != b'\n':
138
+ count += 1 # Remove this if a wc-like count is needed
139
+ return count
140
+
141
+
142
+ def GetDate():
143
+ return str(datetime.datetime.now())[5:10].replace('-', '')
144
+
145
+
146
+ def TimeClock(seconds):
147
+ sec = int(seconds)
148
+ hour = int(sec / 3600)
149
+ min = int((sec - hour * 3600) / 60)
150
+ ssec = float(seconds) - hour * 3600 - min * 60
151
+ # return '%dh %dm %.2fs' % (hour, min, ssec)
152
+ return '{0:>2d}h{1:>3d}m{2:>6.2f}s'.format(hour, min, ssec)
153
+
154
+
155
+ def StripAll(text):
156
+ return text.strip().replace('\t', '').replace('\n', '').replace(' ', '')
157
+
158
+
159
+ def GetBracket(text, bracket, en_br=False):
160
+ # input should be aa(bb)cc, True for bracket, False for text
161
+ if bracket:
162
+ try:
163
+ return re.findall('\((.*?)\)', text.strip())[-1]
164
+ except:
165
+ return ''
166
+ else:
167
+ if en_br:
168
+ text = re.sub('\(.*?\)', '', text.strip())
169
+ return re.sub('(.*?)', '', text.strip())
170
+
171
+
172
+ def CharLang(uchar, lang):
173
+ assert lang.lower() in ['en', 'cn', 'zh']
174
+ if lang.lower() in ['cn', 'zh']:
175
+ if uchar >= '\u4e00' and uchar <= '\u9fa5':
176
+ return True
177
+ else:
178
+ return False
179
+ elif lang.lower() == 'en':
180
+ if (uchar <= 'Z' and uchar >= 'A') or (uchar <= 'z' and uchar >= 'a'):
181
+ return True
182
+ else:
183
+ return False
184
+ else:
185
+ raise NotImplementedError
186
+
187
+
188
+ def WordLang(word, lang):
189
+ for i in word.strip():
190
+ if i.isspace(): continue
191
+ if not CharLang(i, lang):
192
+ return False
193
+ return True
194
+
195
+
196
+ def SortDict(_dict, reverse=True):
197
+ assert type(_dict) is dict
198
+ return sorted(_dict.items(), key=lambda d: d[1], reverse=reverse)
199
+
200
+
201
+ def lark(content='test'):
202
+ print(content)
203
+
204
+
205
+ if __name__ == '__main__':
206
+ parser = argparse.ArgumentParser()
207
+
208
+ parser.add_argument('--diff', nargs=2,
209
+ help='show difference between two files, shown in downloads/diff.html')
210
+ parser.add_argument('--de_unicode', action='store_true', default=False,
211
+ help='remove unicode characters')
212
+ parser.add_argument('--link_entity', action='store_true', default=False,
213
+ help='')
214
+ parser.add_argument('--max_comm_len', action='store_true', default=False,
215
+ help='')
216
+ parser.add_argument('--search', nargs=2,
217
+ help='search key from file, 2 args: file name & key')
218
+ parser.add_argument('--email', nargs=2,
219
+ help='sending emails, 2 args: subject & content')
220
+ parser.add_argument('--overwrite', action='store_true', default=None,
221
+ help='overwrite all cjjpy under given *dir* based on *dir*/cjjpy.py')
222
+ parser.add_argument('--replace', nargs=3,
223
+ help='replace char, 3 args: file name & replaced char & replacer char')
224
+ parser.add_argument('--lark', nargs=1)
225
+ parser.add_argument('--get_hdfs', nargs=2,
226
+ help='easy copy from hdfs to local fs, 2 args: remote_file/dir & local_dir')
227
+ parser.add_argument('--put_hdfs', nargs=2,
228
+ help='easy put from local fs to hdfs, 2 args: local_file/dir & remote_dir')
229
+ parser.add_argument('--length_stats', nargs=1,
230
+ help='simple token lengths distribution of a line-by-line file')
231
+
232
+ args = parser.parse_args()
233
+
234
+ if args.overwrite:
235
+ print('* Overwriting cjjpy...')
236
+ OverWriteCjjPy()
237
+
238
+ if args.lark:
239
+ try:
240
+ content = args.lark[0]
241
+ except:
242
+ content = 'running complete'
243
+ print(f'* Larking "{content}"...')
244
+ lark(content)
245
+
246
+ if args.length_stats:
247
+ file = args.length_stats[0]
248
+ print(f'* Working on {file} lengths statistics...')
249
+ LengthStats(file)
docs/front.png ADDED
requirements.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ nltk
2
+ tqdm
3
+ six
4
+ scikit-learn
5
+ pathlib
6
+ configargparse
7
+ bottle
8
+ ujson
9
+ GPUtil
10
+ coloredlogs
11
+ inflect
12
+ unidecode
13
+ psutil
14
+ wandb
15
+ rouge_score
16
+ sacrebleu
17
+ tagme
18
+ wikipedia-api
19
+ gradio
20
+ tensorflow
21
+ pytorch-lightning==1.0.4
22
+ allennlp==1.2.2
23
+ allennlp-models==1.2.2
24
+ transformers==3.5.1
25
+ torch==1.7.1
26
+ datasets
27
+ pandas
28
+ prettytable
src/available_models/aaai22_roberta.json ADDED
File without changes
src/check_client/cjjpy.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ @Author : Jiangjie Chen
5
+ @Time : 2018/11/15 17:08
6
+ @Contact: jjchen19@fudan.edu.cn
7
+ '''
8
+
9
+ import re
10
+ import datetime
11
+ import os
12
+ import argparse
13
+ import logging
14
+ import traceback
15
+
16
+ try:
17
+ import ujson as json
18
+ except:
19
+ import json
20
+
21
+ HADOOP_BIN = 'PATH=/usr/bin/:$PATH hdfs'
22
+ FOR_PUBLIC = True
23
+
24
+
25
+ def LengthStats(filename):
26
+ len_list = []
27
+ thresholds = [0.8, 0.9, 0.95, 0.99, 0.999]
28
+ with open(filename) as f:
29
+ for line in f:
30
+ len_list.append(len(line.strip().split()))
31
+ stats = {
32
+ 'Max': max(len_list),
33
+ 'Min': min(len_list),
34
+ 'Avg': round(sum(len_list) / len(len_list), 4),
35
+ }
36
+ len_list.sort()
37
+ for t in thresholds:
38
+ stats[f"Top-{t}"] = len_list[int(len(len_list) * t)]
39
+
40
+ for k in stats:
41
+ print(f"- {k}: {stats[k]}")
42
+ return stats
43
+
44
+
45
+ class AttrDict(dict):
46
+ def __init__(self, *args, **kwargs):
47
+ super(AttrDict, self).__init__(*args, **kwargs)
48
+ self.__dict__ = self
49
+
50
+
51
+ def TraceBack(error_msg):
52
+ exc = traceback.format_exc()
53
+ msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}'
54
+ return msg
55
+
56
+
57
+ def Now():
58
+ return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
59
+
60
+
61
+ def AbsParentDir(file, parent='..', postfix=None):
62
+ ppath = os.path.abspath(file)
63
+ parent_level = parent.count('.')
64
+ while parent_level > 0:
65
+ ppath = os.path.dirname(ppath)
66
+ parent_level -= 1
67
+ if postfix is not None:
68
+ return os.path.join(ppath, postfix)
69
+ else:
70
+ return ppath
71
+
72
+
73
+ def init_logger(log_file=None, log_file_level=logging.NOTSET, from_scratch=False):
74
+ from coloredlogs import ColoredFormatter
75
+ import tensorflow as tf
76
+
77
+ fmt = "[%(asctime)s %(levelname)s] %(message)s"
78
+ log_format = ColoredFormatter(fmt=fmt)
79
+ # log_format = logging.Formatter()
80
+ logger = logging.getLogger()
81
+ logger.setLevel(log_file_level)
82
+
83
+ console_handler = logging.StreamHandler()
84
+ console_handler.setFormatter(log_format)
85
+ logger.handlers = [console_handler]
86
+
87
+ if log_file and log_file != '':
88
+ if from_scratch and tf.io.gfile.exists(log_file):
89
+ logger.warning('Removing previous log file: %s' % log_file)
90
+ tf.io.gfile.remove(log_file)
91
+ path = os.path.dirname(log_file)
92
+ os.makedirs(path, exist_ok=True)
93
+ file_handler = logging.FileHandler(log_file)
94
+ file_handler.setLevel(log_file_level)
95
+ file_handler.setFormatter(log_format)
96
+ logger.addHandler(file_handler)
97
+
98
+ return logger
99
+
100
+
101
+ def OverWriteCjjPy(root='.'):
102
+ # import difflib
103
+ # diff = difflib.HtmlDiff()
104
+ cnt = 0
105
+ golden_cjjpy = os.path.join(root, 'cjjpy.py')
106
+ # golden_content = open(golden_cjjpy).readlines()
107
+ for dir, folder, file in os.walk(root):
108
+ for f in file:
109
+ if f == 'cjjpy.py':
110
+ cjjpy = '%s/%s' % (dir, f)
111
+ # content = open(cjjpy).readlines()
112
+ # d = diff.make_file(golden_content, content)
113
+ cnt += 1
114
+ print('[%d]: %s' % (cnt, cjjpy))
115
+ os.system('cp %s %s' % (golden_cjjpy, cjjpy))
116
+
117
+
118
+ def ChangeFileFormat(filename, new_fmt):
119
+ assert type(filename) is str and type(new_fmt) is str
120
+ spt = filename.split('.')
121
+ if len(spt) == 0:
122
+ return filename
123
+ else:
124
+ return filename.replace('.' + spt[-1], new_fmt)
125
+
126
+
127
+ def CountLines(fname):
128
+ with open(fname, 'rb') as f:
129
+ count = 0
130
+ last_data = '\n'
131
+ while True:
132
+ data = f.read(0x400000)
133
+ if not data:
134
+ break
135
+ count += data.count(b'\n')
136
+ last_data = data
137
+ if last_data[-1:] != b'\n':
138
+ count += 1 # Remove this if a wc-like count is needed
139
+ return count
140
+
141
+
142
+ def GetDate():
143
+ return str(datetime.datetime.now())[5:10].replace('-', '')
144
+
145
+
146
+ def TimeClock(seconds):
147
+ sec = int(seconds)
148
+ hour = int(sec / 3600)
149
+ min = int((sec - hour * 3600) / 60)
150
+ ssec = float(seconds) - hour * 3600 - min * 60
151
+ # return '%dh %dm %.2fs' % (hour, min, ssec)
152
+ return '{0:>2d}h{1:>3d}m{2:>6.2f}s'.format(hour, min, ssec)
153
+
154
+
155
+ def StripAll(text):
156
+ return text.strip().replace('\t', '').replace('\n', '').replace(' ', '')
157
+
158
+
159
+ def GetBracket(text, bracket, en_br=False):
160
+ # input should be aa(bb)cc, True for bracket, False for text
161
+ if bracket:
162
+ try:
163
+ return re.findall('\((.*?)\)', text.strip())[-1]
164
+ except:
165
+ return ''
166
+ else:
167
+ if en_br:
168
+ text = re.sub('\(.*?\)', '', text.strip())
169
+ return re.sub('(.*?)', '', text.strip())
170
+
171
+
172
+ def CharLang(uchar, lang):
173
+ assert lang.lower() in ['en', 'cn', 'zh']
174
+ if lang.lower() in ['cn', 'zh']:
175
+ if uchar >= '\u4e00' and uchar <= '\u9fa5':
176
+ return True
177
+ else:
178
+ return False
179
+ elif lang.lower() == 'en':
180
+ if (uchar <= 'Z' and uchar >= 'A') or (uchar <= 'z' and uchar >= 'a'):
181
+ return True
182
+ else:
183
+ return False
184
+ else:
185
+ raise NotImplementedError
186
+
187
+
188
+ def WordLang(word, lang):
189
+ for i in word.strip():
190
+ if i.isspace(): continue
191
+ if not CharLang(i, lang):
192
+ return False
193
+ return True
194
+
195
+
196
+ def SortDict(_dict, reverse=True):
197
+ assert type(_dict) is dict
198
+ return sorted(_dict.items(), key=lambda d: d[1], reverse=reverse)
199
+
200
+
201
+ def lark(content='test'):
202
+ print(content)
203
+
204
+
205
+ if __name__ == '__main__':
206
+ parser = argparse.ArgumentParser()
207
+
208
+ parser.add_argument('--diff', nargs=2,
209
+ help='show difference between two files, shown in downloads/diff.html')
210
+ parser.add_argument('--de_unicode', action='store_true', default=False,
211
+ help='remove unicode characters')
212
+ parser.add_argument('--link_entity', action='store_true', default=False,
213
+ help='')
214
+ parser.add_argument('--max_comm_len', action='store_true', default=False,
215
+ help='')
216
+ parser.add_argument('--search', nargs=2,
217
+ help='search key from file, 2 args: file name & key')
218
+ parser.add_argument('--email', nargs=2,
219
+ help='sending emails, 2 args: subject & content')
220
+ parser.add_argument('--overwrite', action='store_true', default=None,
221
+ help='overwrite all cjjpy under given *dir* based on *dir*/cjjpy.py')
222
+ parser.add_argument('--replace', nargs=3,
223
+ help='replace char, 3 args: file name & replaced char & replacer char')
224
+ parser.add_argument('--lark', nargs=1)
225
+ parser.add_argument('--get_hdfs', nargs=2,
226
+ help='easy copy from hdfs to local fs, 2 args: remote_file/dir & local_dir')
227
+ parser.add_argument('--put_hdfs', nargs=2,
228
+ help='easy put from local fs to hdfs, 2 args: local_file/dir & remote_dir')
229
+ parser.add_argument('--length_stats', nargs=1,
230
+ help='simple token lengths distribution of a line-by-line file')
231
+
232
+ args = parser.parse_args()
233
+
234
+ if args.overwrite:
235
+ print('* Overwriting cjjpy...')
236
+ OverWriteCjjPy()
237
+
238
+ if args.lark:
239
+ try:
240
+ content = args.lark[0]
241
+ except:
242
+ content = 'running complete'
243
+ print(f'* Larking "{content}"...')
244
+ lark(content)
245
+
246
+ if args.length_stats:
247
+ file = args.length_stats[0]
248
+ print(f'* Working on {file} lengths statistics...')
249
+ LengthStats(file)
src/check_client/fact_checker.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ @Author : Bao
5
+ @Date : 2020/8/12
6
+ @Desc :
7
+ @Last modified by : Bao
8
+ @Last modified date : 2020/8/20
9
+ """
10
+
11
+ import os
12
+ import sys
13
+ import logging
14
+ import torch
15
+ import numpy as np
16
+ from tqdm import tqdm
17
+ import tensorflow as tf
18
+ import ujson as json
19
+ import argparse
20
+ import cjjpy as cjj
21
+ from itertools import repeat
22
+ from torch.utils.data import DataLoader, SequentialSampler
23
+ from transformers import (
24
+ BertConfig, BertTokenizer, AutoTokenizer,
25
+ RobertaConfig, RobertaTokenizer,
26
+ )
27
+
28
+ try:
29
+ from .modules.data_processor import DataProcessor
30
+ from .plm_checkers import BertChecker, RobertaChecker
31
+ from .utils import read_json_lines, compute_metrics
32
+ from .train import do_evaluate, set_seed
33
+ from ..eval_client.fever_scorer import FeverScorer
34
+ except:
35
+ sys.path.append(cjj.AbsParentDir(__file__, '.'))
36
+ sys.path.append(cjj.AbsParentDir(__file__, '..'))
37
+ from eval_client.fever_scorer import FeverScorer
38
+ from modules.data_processor import DataProcessor
39
+ from plm_checkers import BertChecker, RobertaChecker
40
+ from utils import read_json_lines, compute_metrics
41
+ from train import do_evaluate, set_seed
42
+
43
+ MODEL_MAPPING = {
44
+ 'bert': (BertConfig, BertTokenizer, BertChecker),
45
+ 'roberta': (RobertaConfig, RobertaTokenizer, RobertaChecker),
46
+ }
47
+
48
+ logger = logging.getLogger(__name__)
49
+ label2id = {"SUPPORTS": 2, "REFUTES": 0, 'NOT ENOUGH INFO': 1}
50
+ id2label = {v: k for k, v in label2id.items()}
51
+
52
+
53
+ class FactChecker:
54
+ def __init__(self, args, fc_ckpt_dir=None, mask_rate=0.):
55
+ self.data_processor = None
56
+ self.tokenizer = None
57
+ self.model = None
58
+ self.args = args
59
+ self.ckpt = args.fc_dir if fc_ckpt_dir is None else fc_ckpt_dir
60
+ self.mask_rate = mask_rate
61
+
62
+ logger.info('Initializing fact checker.')
63
+ self._prepare_ckpt(self.args.model_name_or_path, self.ckpt)
64
+ self.load_model()
65
+
66
+ def _prepare_ckpt(self, model_name_or_path, ckpt_dir):
67
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
68
+ tokenizer.save_pretrained(ckpt_dir)
69
+
70
+ def load_model(self):
71
+ if self.model is None:
72
+ self.data_processor = DataProcessor(
73
+ self.args.model_name_or_path,
74
+ self.args.max_seq1_length,
75
+ self.args.max_seq2_length,
76
+ self.args.max_num_questions,
77
+ self.args.cand_k,
78
+ mask_rate=self.mask_rate
79
+ )
80
+
81
+ _, tokenizer_class, model_class = MODEL_MAPPING[self.args.model_type]
82
+ self.tokenizer = tokenizer_class.from_pretrained(
83
+ self.ckpt,
84
+ do_lower_case=self.args.do_lower_case
85
+ )
86
+ self.model = model_class.from_pretrained(
87
+ self.ckpt,
88
+ from_tf=bool(".ckpt" in self.ckpt),
89
+ logic_lambda=self.args.logic_lambda,
90
+ prior=self.args.prior,
91
+ )
92
+ self.model = torch.nn.DataParallel(self.model)
93
+
94
+ def _check(self, inputs: list, batch_size=32, verbose=True):
95
+ dataset = self.data_processor.convert_inputs_to_dataset(inputs, self.tokenizer, verbose=verbose)
96
+ sampler = SequentialSampler(dataset)
97
+ dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
98
+
99
+ with torch.no_grad():
100
+ self.model.to(self.args.device)
101
+ self.model.eval()
102
+ iter = tqdm(dataloader, desc="Fact Checking") if verbose else dataloader
103
+ _, y_predicted, z_predicted, m_attn, mask = \
104
+ do_evaluate(iter, self.model, self.args, during_training=False, with_label=False)
105
+
106
+ return y_predicted, z_predicted, m_attn, mask
107
+
108
+ def check_from_file(self, in_filename, out_filename, batch_size, verbose=False):
109
+ if 'test' in in_filename:
110
+ raw_inp = f'{os.environ["PJ_HOME"]}/data/fever/shared_task_test.jsonl'
111
+ else:
112
+ raw_inp = None
113
+ tf.io.gfile.makedirs(os.path.dirname(out_filename))
114
+ inputs = list(read_json_lines(in_filename))
115
+ y_predicted, z_predicted, m_attn, mask = self._check(inputs, batch_size)
116
+
117
+ z_predicted = repeat(None) if z_predicted is None else z_predicted
118
+ m_attn = repeat(None) if m_attn is None else m_attn
119
+ ordered_results = {}
120
+ with_label = inputs[0].get('label') is not None
121
+
122
+ if with_label:
123
+ label_truth = [label2id[x['label']] for x in inputs]
124
+ _, acc_results = compute_metrics(label_truth, y_predicted, z_predicted, mask)
125
+ else:
126
+ acc_results = {}
127
+
128
+ for i, (inp, y, z, attn, _mask) in \
129
+ enumerate(zip(inputs, y_predicted, z_predicted, m_attn, mask)):
130
+ result = {'id': inp['id'],
131
+ 'predicted_label': id2label[y],
132
+ 'predicted_evidence': inp.get('predicted_evidence', [])}
133
+ if verbose:
134
+ if i < 5:
135
+ print("{}\t{}\t{}".format(inp.get("id", i), inp["claim"], y))
136
+ if z is not None and attn is not None:
137
+ result.update({
138
+ 'z_prob': z[:torch.tensor(_mask).sum()],
139
+ 'm_attn': attn[:torch.tensor(_mask).sum()],
140
+ })
141
+ ordered_results[inp['id']] = result
142
+
143
+ with tf.io.gfile.GFile(out_filename, 'w') as fout:
144
+ if raw_inp:
145
+ with tf.io.gfile.GFile(raw_inp) as f:
146
+ for line in f:
147
+ raw_js = json.loads(line)
148
+ fout.write(json.dumps(ordered_results[raw_js['id']]) + '\n')
149
+ else:
150
+ for k in ordered_results:
151
+ fout.write(json.dumps(ordered_results[k]) + '\n')
152
+
153
+ if ('dev' in in_filename or 'val' in in_filename) and with_label:
154
+ scorer = FeverScorer()
155
+ fever_results = scorer.get_scores(out_filename)
156
+ fever_results.update(acc_results)
157
+
158
+ print(fever_results)
159
+ return fever_results
160
+
161
+ def check_from_batch(self, inputs: list, verbose=False):
162
+ y_predicted, z_predicted, m_attn, mask = self._check(inputs, len(inputs), verbose)
163
+ return y_predicted, z_predicted, m_attn
164
+
165
+
166
+ if __name__ == '__main__':
167
+ parser = argparse.ArgumentParser()
168
+ parser.add_argument('--input', '-i', required=True, type=str,
169
+ choices=['val', 'eval', 'test', 'demo'])
170
+ parser.add_argument('--output', '-o', default='none', type=str)
171
+ parser.add_argument('--ckpt', '-c', required=True, type=str)
172
+ parser.add_argument('--model_type', default='roberta', type=str,
173
+ choices=['roberta', 'bert'])
174
+ parser.add_argument('--model_name_or_path', default='roberta-large', type=str)
175
+ parser.add_argument('--verbose', '-v', action='store_true', default=False,
176
+ help='whether output phrasal veracity or not')
177
+ parser.add_argument('--logic_lambda', '-l', required=True, type=float)
178
+ parser.add_argument('--prior', default='random', type=str, choices=['nli', 'uniform', 'logic', 'random'],
179
+ help='type of prior distribution')
180
+ parser.add_argument('--mask_rate', '-m', default=0., type=float)
181
+
182
+ parser.add_argument('--cand_k', '-k', default=3, type=int)
183
+ parser.add_argument('--max_seq1_length', default=256, type=int)
184
+ parser.add_argument('--max_seq2_length', default=128, type=int)
185
+ parser.add_argument('--max_num_questions', default=8, type=int)
186
+ parser.add_argument('--do_lower_case', action='store_true', default=False)
187
+ parser.add_argument('--batch_size', '-b', default=64, type=int)
188
+ parser.add_argument('--seed', default=42)
189
+ parser.add_argument('--n_gpu', default=4)
190
+
191
+ args = parser.parse_args()
192
+
193
+ set_seed(args)
194
+
195
+ args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
196
+
197
+ if args.output == 'none':
198
+ args.ckpt = args.ckpt[:-1] if args.ckpt.endswith('/') else args.ckpt
199
+ base_name = os.path.basename(args.ckpt)
200
+ args.output = f'{os.environ["PJ_HOME"]}/results/fact_checking/AAAI22/{args.input}.{args.model_name_or_path}_m{args.mask_rate}_l{args.logic_lambda}_{base_name}_{args.prior}.predictions.jsonl'
201
+
202
+ assert args.output.endswith('predictions.jsonl'), \
203
+ f"{args.output} must end with predictions.jsonl"
204
+
205
+ args.input = f'{os.environ["PJ_HOME"]}/data/fact_checking/v5/{args.input}.json'
206
+
207
+ checker = FactChecker(args, args.ckpt, args.mask_rate)
208
+ fever_results = checker.check_from_file(args.input, args.output, args.batch_size, args.verbose)
209
+ cjj.lark(f"{args.output}: {fever_results}")
src/check_client/modules/cjjpy.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ @Author : Jiangjie Chen
5
+ @Time : 2018/11/15 17:08
6
+ @Contact: jjchen19@fudan.edu.cn
7
+ '''
8
+
9
+ import re
10
+ import datetime
11
+ import os
12
+ import argparse
13
+ import logging
14
+ import traceback
15
+
16
+ try:
17
+ import ujson as json
18
+ except:
19
+ import json
20
+
21
+ HADOOP_BIN = 'PATH=/usr/bin/:$PATH hdfs'
22
+ FOR_PUBLIC = True
23
+
24
+
25
+ def LengthStats(filename):
26
+ len_list = []
27
+ thresholds = [0.8, 0.9, 0.95, 0.99, 0.999]
28
+ with open(filename) as f:
29
+ for line in f:
30
+ len_list.append(len(line.strip().split()))
31
+ stats = {
32
+ 'Max': max(len_list),
33
+ 'Min': min(len_list),
34
+ 'Avg': round(sum(len_list) / len(len_list), 4),
35
+ }
36
+ len_list.sort()
37
+ for t in thresholds:
38
+ stats[f"Top-{t}"] = len_list[int(len(len_list) * t)]
39
+
40
+ for k in stats:
41
+ print(f"- {k}: {stats[k]}")
42
+ return stats
43
+
44
+
45
+ class AttrDict(dict):
46
+ def __init__(self, *args, **kwargs):
47
+ super(AttrDict, self).__init__(*args, **kwargs)
48
+ self.__dict__ = self
49
+
50
+
51
+ def TraceBack(error_msg):
52
+ exc = traceback.format_exc()
53
+ msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}'
54
+ return msg
55
+
56
+
57
+ def Now():
58
+ return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
59
+
60
+
61
+ def AbsParentDir(file, parent='..', postfix=None):
62
+ ppath = os.path.abspath(file)
63
+ parent_level = parent.count('.')
64
+ while parent_level > 0:
65
+ ppath = os.path.dirname(ppath)
66
+ parent_level -= 1
67
+ if postfix is not None:
68
+ return os.path.join(ppath, postfix)
69
+ else:
70
+ return ppath
71
+
72
+
73
+ def init_logger(log_file=None, log_file_level=logging.NOTSET, from_scratch=False):
74
+ from coloredlogs import ColoredFormatter
75
+ import tensorflow as tf
76
+
77
+ fmt = "[%(asctime)s %(levelname)s] %(message)s"
78
+ log_format = ColoredFormatter(fmt=fmt)
79
+ # log_format = logging.Formatter()
80
+ logger = logging.getLogger()
81
+ logger.setLevel(log_file_level)
82
+
83
+ console_handler = logging.StreamHandler()
84
+ console_handler.setFormatter(log_format)
85
+ logger.handlers = [console_handler]
86
+
87
+ if log_file and log_file != '':
88
+ if from_scratch and tf.io.gfile.exists(log_file):
89
+ logger.warning('Removing previous log file: %s' % log_file)
90
+ tf.io.gfile.remove(log_file)
91
+ path = os.path.dirname(log_file)
92
+ os.makedirs(path, exist_ok=True)
93
+ file_handler = logging.FileHandler(log_file)
94
+ file_handler.setLevel(log_file_level)
95
+ file_handler.setFormatter(log_format)
96
+ logger.addHandler(file_handler)
97
+
98
+ return logger
99
+
100
+
101
+ def OverWriteCjjPy(root='.'):
102
+ # import difflib
103
+ # diff = difflib.HtmlDiff()
104
+ cnt = 0
105
+ golden_cjjpy = os.path.join(root, 'cjjpy.py')
106
+ # golden_content = open(golden_cjjpy).readlines()
107
+ for dir, folder, file in os.walk(root):
108
+ for f in file:
109
+ if f == 'cjjpy.py':
110
+ cjjpy = '%s/%s' % (dir, f)
111
+ # content = open(cjjpy).readlines()
112
+ # d = diff.make_file(golden_content, content)
113
+ cnt += 1
114
+ print('[%d]: %s' % (cnt, cjjpy))
115
+ os.system('cp %s %s' % (golden_cjjpy, cjjpy))
116
+
117
+
118
+ def ChangeFileFormat(filename, new_fmt):
119
+ assert type(filename) is str and type(new_fmt) is str
120
+ spt = filename.split('.')
121
+ if len(spt) == 0:
122
+ return filename
123
+ else:
124
+ return filename.replace('.' + spt[-1], new_fmt)
125
+
126
+
127
+ def CountLines(fname):
128
+ with open(fname, 'rb') as f:
129
+ count = 0
130
+ last_data = '\n'
131
+ while True:
132
+ data = f.read(0x400000)
133
+ if not data:
134
+ break
135
+ count += data.count(b'\n')
136
+ last_data = data
137
+ if last_data[-1:] != b'\n':
138
+ count += 1 # Remove this if a wc-like count is needed
139
+ return count
140
+
141
+
142
+ def GetDate():
143
+ return str(datetime.datetime.now())[5:10].replace('-', '')
144
+
145
+
146
+ def TimeClock(seconds):
147
+ sec = int(seconds)
148
+ hour = int(sec / 3600)
149
+ min = int((sec - hour * 3600) / 60)
150
+ ssec = float(seconds) - hour * 3600 - min * 60
151
+ # return '%dh %dm %.2fs' % (hour, min, ssec)
152
+ return '{0:>2d}h{1:>3d}m{2:>6.2f}s'.format(hour, min, ssec)
153
+
154
+
155
+ def StripAll(text):
156
+ return text.strip().replace('\t', '').replace('\n', '').replace(' ', '')
157
+
158
+
159
+ def GetBracket(text, bracket, en_br=False):
160
+ # input should be aa(bb)cc, True for bracket, False for text
161
+ if bracket:
162
+ try:
163
+ return re.findall('\((.*?)\)', text.strip())[-1]
164
+ except:
165
+ return ''
166
+ else:
167
+ if en_br:
168
+ text = re.sub('\(.*?\)', '', text.strip())
169
+ return re.sub('(.*?)', '', text.strip())
170
+
171
+
172
+ def CharLang(uchar, lang):
173
+ assert lang.lower() in ['en', 'cn', 'zh']
174
+ if lang.lower() in ['cn', 'zh']:
175
+ if uchar >= '\u4e00' and uchar <= '\u9fa5':
176
+ return True
177
+ else:
178
+ return False
179
+ elif lang.lower() == 'en':
180
+ if (uchar <= 'Z' and uchar >= 'A') or (uchar <= 'z' and uchar >= 'a'):
181
+ return True
182
+ else:
183
+ return False
184
+ else:
185
+ raise NotImplementedError
186
+
187
+
188
+ def WordLang(word, lang):
189
+ for i in word.strip():
190
+ if i.isspace(): continue
191
+ if not CharLang(i, lang):
192
+ return False
193
+ return True
194
+
195
+
196
+ def SortDict(_dict, reverse=True):
197
+ assert type(_dict) is dict
198
+ return sorted(_dict.items(), key=lambda d: d[1], reverse=reverse)
199
+
200
+
201
+ def lark(content='test'):
202
+ print(content)
203
+
204
+
205
+ if __name__ == '__main__':
206
+ parser = argparse.ArgumentParser()
207
+
208
+ parser.add_argument('--diff', nargs=2,
209
+ help='show difference between two files, shown in downloads/diff.html')
210
+ parser.add_argument('--de_unicode', action='store_true', default=False,
211
+ help='remove unicode characters')
212
+ parser.add_argument('--link_entity', action='store_true', default=False,
213
+ help='')
214
+ parser.add_argument('--max_comm_len', action='store_true', default=False,
215
+ help='')
216
+ parser.add_argument('--search', nargs=2,
217
+ help='search key from file, 2 args: file name & key')
218
+ parser.add_argument('--email', nargs=2,
219
+ help='sending emails, 2 args: subject & content')
220
+ parser.add_argument('--overwrite', action='store_true', default=None,
221
+ help='overwrite all cjjpy under given *dir* based on *dir*/cjjpy.py')
222
+ parser.add_argument('--replace', nargs=3,
223
+ help='replace char, 3 args: file name & replaced char & replacer char')
224
+ parser.add_argument('--lark', nargs=1)
225
+ parser.add_argument('--get_hdfs', nargs=2,
226
+ help='easy copy from hdfs to local fs, 2 args: remote_file/dir & local_dir')
227
+ parser.add_argument('--put_hdfs', nargs=2,
228
+ help='easy put from local fs to hdfs, 2 args: local_file/dir & remote_dir')
229
+ parser.add_argument('--length_stats', nargs=1,
230
+ help='simple token lengths distribution of a line-by-line file')
231
+
232
+ args = parser.parse_args()
233
+
234
+ if args.overwrite:
235
+ print('* Overwriting cjjpy...')
236
+ OverWriteCjjPy()
237
+
238
+ if args.lark:
239
+ try:
240
+ content = args.lark[0]
241
+ except:
242
+ content = 'running complete'
243
+ print(f'* Larking "{content}"...')
244
+ lark(content)
245
+
246
+ if args.length_stats:
247
+ file = args.length_stats[0]
248
+ print(f'* Working on {file} lengths statistics...')
249
+ LengthStats(file)
src/check_client/modules/data_processor.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ @Author : Bao
5
+ @Date : 2020/4/14
6
+ @Desc :
7
+ @Last modified by : Bao
8
+ @Last modified date : 2020/8/12
9
+ """
10
+
11
+ import os
12
+ import copy
13
+ import logging
14
+ import ujson as json
15
+ import torch
16
+ from tqdm import tqdm
17
+ from torch.utils.data import TensorDataset
18
+ import tensorflow as tf
19
+ import cjjpy as cjj
20
+ import sys
21
+
22
+ try:
23
+ from ...mrc_client.answer_generator import assemble_answers_to_one
24
+ except:
25
+ sys.path.append(cjj.AbsParentDir(__file__, '...'))
26
+ from mrc_client.answer_generator import assemble_answers_to_one
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class InputExample(object):
32
+ def __init__(self, guid, claim, evidences, questions, answers,
33
+ evidential, label=None, nli_labels=None):
34
+ self.guid = guid
35
+ self.claim = claim
36
+ self.evidences = evidences
37
+ self.questions = questions
38
+ self.answers = answers
39
+ self.evidential = evidential
40
+ self.label = label
41
+ self.nli_labels = nli_labels
42
+
43
+ def __repr__(self):
44
+ return str(self.to_json_string())
45
+
46
+ def to_dict(self):
47
+ """Serializes this instance to a Python dictionary."""
48
+ output = copy.deepcopy(self.__dict__)
49
+ return output
50
+
51
+ def to_json_string(self):
52
+ """Serializes this instance to a JSON string."""
53
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
54
+
55
+
56
+ class InputFeatures(object):
57
+ def __init__(
58
+ self,
59
+ guid,
60
+ c_input_ids,
61
+ c_attention_mask,
62
+ c_token_type_ids,
63
+ q_input_ids_list,
64
+ q_attention_mask_list,
65
+ q_token_type_ids_list,
66
+ nli_labels=None,
67
+ label=None,
68
+ ):
69
+ self.guid = guid
70
+ self.c_input_ids = c_input_ids
71
+ self.c_attention_mask = c_attention_mask
72
+ self.c_token_type_ids = c_token_type_ids
73
+ self.q_input_ids_list = q_input_ids_list
74
+ self.q_attention_mask_list = q_attention_mask_list
75
+ self.q_token_type_ids_list = q_token_type_ids_list
76
+ self.nli_labels = nli_labels
77
+ self.label = label
78
+
79
+ def __repr__(self):
80
+ return str(self.to_json_string())
81
+
82
+ def to_dict(self):
83
+ """Serializes this instance to a Python dictionary."""
84
+ output = copy.deepcopy(self.__dict__)
85
+ return output
86
+
87
+ def to_json_string(self):
88
+ """Serializes this instance to a JSON string."""
89
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
90
+
91
+
92
+ def _create_input_ids_from_token_ids(token_ids_a, token_ids_b, tokenizer, max_seq_length):
93
+ pair = len(token_ids_b) != 0
94
+
95
+ # Truncate sequences.
96
+ num_special_tokens_to_add = tokenizer.num_special_tokens_to_add(pair=pair)
97
+ while len(token_ids_a) + len(token_ids_b) > max_seq_length - num_special_tokens_to_add:
98
+ if len(token_ids_b) > 0:
99
+ token_ids_b = token_ids_b[:-1]
100
+ else:
101
+ token_ids_a = token_ids_a[:-1]
102
+
103
+ # Add special tokens to input_ids.
104
+ input_ids = tokenizer.build_inputs_with_special_tokens(token_ids_a, token_ids_b if pair else None)
105
+
106
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
107
+ attention_mask = [1] * len(input_ids)
108
+
109
+ # Create token_type_ids.
110
+ token_type_ids = tokenizer.create_token_type_ids_from_sequences(token_ids_a, token_ids_b if pair else None)
111
+
112
+ # Pad up to the sequence length.
113
+ padding_length = max_seq_length - len(input_ids)
114
+ if tokenizer.padding_side == "right":
115
+ input_ids = input_ids + ([tokenizer.pad_token_id] * padding_length)
116
+ attention_mask = attention_mask + ([0] * padding_length)
117
+ token_type_ids = token_type_ids + ([tokenizer.pad_token_type_id] * padding_length)
118
+ else:
119
+ input_ids = ([tokenizer.pad_token_id] * padding_length) + input_ids
120
+ attention_mask = ([0] * padding_length) + attention_mask
121
+ token_type_ids = ([tokenizer.pad_token_type_id] * padding_length) + token_type_ids
122
+
123
+ assert len(input_ids) == max_seq_length
124
+ assert len(attention_mask) == max_seq_length
125
+ assert len(token_type_ids) == max_seq_length
126
+
127
+ return input_ids, attention_mask, token_type_ids
128
+
129
+
130
+ def convert_examples_to_features(
131
+ examples,
132
+ tokenizer,
133
+ max_seq1_length=256,
134
+ max_seq2_length=128,
135
+ verbose=True
136
+ ):
137
+ features = []
138
+ iter = tqdm(examples, desc="Converting Examples") if verbose else examples
139
+ for (ex_index, example) in enumerate(iter):
140
+ encoded_outputs = {"guid": example.guid, 'label': example.label,
141
+ 'nli_labels': example.nli_labels}
142
+
143
+ # ****** for sequence 1 ******* #
144
+ token_ids_a, token_ids_b = [], []
145
+
146
+ # text a in sequence 1
147
+ token_ids = tokenizer.encode(example.claim, add_special_tokens=False) # encode claim
148
+ token_ids_a.extend(token_ids)
149
+
150
+ # text b in sequence 1
151
+ for i, evidence in enumerate(example.evidences):
152
+ token_ids = tokenizer.encode(evidence, add_special_tokens=False) # encode evidence
153
+ token_ids_b.extend(token_ids + [tokenizer.sep_token_id])
154
+ # Remove last sep token in token_ids_b.
155
+ token_ids_b = token_ids_b[:-1]
156
+ token_ids_b = token_ids_b[:max_seq1_length - len(token_ids_a) - 4] # magic number for special tokens
157
+
158
+ # premise </s> </s> hypothesis
159
+ input_ids, attention_mask, token_type_ids = _create_input_ids_from_token_ids(
160
+ token_ids_b,
161
+ token_ids_a,
162
+ tokenizer,
163
+ max_seq1_length,
164
+ )
165
+
166
+ encoded_outputs["c_input_ids"] = input_ids
167
+ encoded_outputs["c_attention_mask"] = attention_mask
168
+ encoded_outputs["c_token_type_ids"] = token_type_ids
169
+
170
+ # ****** for sequence 2 ******* #
171
+ encoded_outputs["q_input_ids_list"] = [] # m x L
172
+ encoded_outputs["q_attention_mask_list"] = []
173
+ encoded_outputs["q_token_type_ids_list"] = []
174
+
175
+ for candidate in example.evidential:
176
+ # text a in sequence 2
177
+ token_ids_a = tokenizer.encode(example.claim, add_special_tokens=False) # encode claim
178
+ # text b in sequence 2
179
+ token_ids_b = tokenizer.encode(candidate, add_special_tokens=False) # encode candidate answer
180
+ # premise </s> </s> hypothesis
181
+ input_ids, attention_mask, token_type_ids = _create_input_ids_from_token_ids(
182
+ token_ids_b,
183
+ token_ids_a,
184
+ tokenizer,
185
+ max_seq2_length,
186
+ )
187
+
188
+ encoded_outputs["q_input_ids_list"].append(input_ids)
189
+ encoded_outputs["q_attention_mask_list"].append(attention_mask)
190
+ encoded_outputs["q_token_type_ids_list"].append(token_type_ids)
191
+
192
+ features.append(InputFeatures(**encoded_outputs))
193
+
194
+ if ex_index < 5 and verbose:
195
+ logger.info("*** Example ***")
196
+ logger.info("guid: {}".format(example.guid))
197
+ logger.info("c_input_ids: {}".format(encoded_outputs["c_input_ids"]))
198
+ for input_ids in encoded_outputs['q_input_ids_list']:
199
+ logger.info('q_input_ids: {}'.format(input_ids))
200
+ logger.info("label: {}".format(example.label))
201
+ logger.info("nli_labels: {}".format(example.nli_labels))
202
+
203
+ return features
204
+
205
+
206
+ class DataProcessor:
207
+ def __init__(
208
+ self,
209
+ model_name_or_path,
210
+ max_seq1_length,
211
+ max_seq2_length,
212
+ max_num_questions,
213
+ cand_k,
214
+ data_dir='',
215
+ cache_dir_name='cache_check',
216
+ overwrite_cache=False,
217
+ mask_rate=0.
218
+ ):
219
+ self.model_name_or_path = model_name_or_path
220
+ self.max_seq1_length = max_seq1_length
221
+ self.max_seq2_length = max_seq2_length
222
+ self.max_num_questions = max_num_questions
223
+ self.k = cand_k
224
+ self.mask_rate = mask_rate
225
+
226
+ self.data_dir = data_dir
227
+ self.cached_data_dir = os.path.join(data_dir, cache_dir_name)
228
+ self.overwrite_cache = overwrite_cache
229
+
230
+ self.label2id = {"SUPPORTS": 2, "REFUTES": 0, 'NOT ENOUGH INFO': 1}
231
+
232
+ def _format_file(self, role):
233
+ return os.path.join(self.data_dir, "{}.json".format(role))
234
+
235
+ def load_and_cache_data(self, role, tokenizer, data_tag):
236
+ tf.io.gfile.makedirs(self.cached_data_dir)
237
+ cached_file = os.path.join(
238
+ self.cached_data_dir,
239
+ "cached_features_{}_{}_{}_{}_{}_{}".format(
240
+ role,
241
+ list(filter(None, self.model_name_or_path.split("/"))).pop(),
242
+ str(self.max_seq1_length),
243
+ str(self.max_seq2_length),
244
+ str(self.k),
245
+ data_tag
246
+ ),
247
+ )
248
+ if os.path.exists(cached_file) and not self.overwrite_cache:
249
+ logger.info("Loading features from cached file {}".format(cached_file))
250
+ features = torch.load(cached_file)
251
+ else:
252
+ examples = []
253
+ with tf.io.gfile.GFile(self._format_file(role)) as f:
254
+ data = f.readlines()
255
+ for line in tqdm(data):
256
+ sample = self._load_line(line)
257
+ examples.append(InputExample(**sample))
258
+ features = convert_examples_to_features(examples, tokenizer,
259
+ self.max_seq1_length, self.max_seq2_length)
260
+ if 'train' in role or 'eval' in role:
261
+ logger.info("Saving features into cached file {}".format(cached_file))
262
+ torch.save(features, cached_file)
263
+
264
+ return self._create_tensor_dataset(features, tokenizer)
265
+
266
+ def convert_inputs_to_dataset(self, inputs, tokenizer, verbose=True):
267
+ examples = []
268
+ for line in inputs:
269
+ sample = self._load_line(line)
270
+ examples.append(InputExample(**sample))
271
+ features = convert_examples_to_features(examples, tokenizer,
272
+ self.max_seq1_length, self.max_seq2_length, verbose)
273
+
274
+ return self._create_tensor_dataset(features, tokenizer, do_predict=True)
275
+
276
+ def _create_tensor_dataset(self, features, tokenizer, do_predict=False):
277
+ all_c_input_ids = torch.tensor([f.c_input_ids for f in features], dtype=torch.long)
278
+ all_c_attention_mask = torch.tensor([f.c_attention_mask for f in features], dtype=torch.long)
279
+ all_c_token_type_ids = torch.tensor([f.c_token_type_ids for f in features], dtype=torch.long)
280
+
281
+ all_q_input_ids_list = []
282
+ all_q_attention_mask_list = []
283
+ all_q_token_type_ids_list = []
284
+
285
+ def _trunc_agg(self, feature, pad_token):
286
+ # feature: m x L
287
+ _input_list = [v for v in feature[:self.max_num_questions]]
288
+ while len(_input_list) < self.max_num_questions:
289
+ _input_list.append([pad_token] * self.max_seq2_length)
290
+ return _input_list
291
+
292
+ for f in features: # N x m x L
293
+ all_q_input_ids_list.append(_trunc_agg(self, f.q_input_ids_list, tokenizer.pad_token_id))
294
+ all_q_attention_mask_list.append(_trunc_agg(self, f.q_attention_mask_list, 0))
295
+ all_q_token_type_ids_list.append(_trunc_agg(self, f.q_token_type_ids_list, tokenizer.pad_token_type_id))
296
+
297
+ all_q_input_ids_list = torch.tensor(all_q_input_ids_list, dtype=torch.long)
298
+ all_q_attention_mask_list = torch.tensor(all_q_attention_mask_list, dtype=torch.long)
299
+ all_q_token_type_ids_list = torch.tensor(all_q_token_type_ids_list, dtype=torch.long)
300
+
301
+ all_nli_labels_list = []
302
+ for f in features:
303
+ all_nli_labels_list.append(f.nli_labels[:self.max_num_questions]
304
+ + max(0, (self.max_num_questions - len(f.nli_labels))) * [[0., 0., 0.]])
305
+ all_nli_labels = torch.tensor(all_nli_labels_list, dtype=torch.float)
306
+
307
+ if not do_predict:
308
+ all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
309
+ dataset = TensorDataset(
310
+ all_c_input_ids, all_c_attention_mask, all_c_token_type_ids,
311
+ all_q_input_ids_list, all_q_attention_mask_list, all_q_token_type_ids_list,
312
+ all_nli_labels, all_labels,
313
+ )
314
+ else:
315
+ dataset = TensorDataset(
316
+ all_c_input_ids, all_c_attention_mask, all_c_token_type_ids,
317
+ all_q_input_ids_list, all_q_attention_mask_list, all_q_token_type_ids_list,
318
+ all_nli_labels,
319
+ )
320
+
321
+ return dataset
322
+
323
+ def _load_line(self, line):
324
+ if isinstance(line, str):
325
+ line = json.loads(line)
326
+ guid = line["id"]
327
+ claim = line["claim"]
328
+
329
+ # TODO: hack no evidence situation
330
+ evidences = line["evidence"] if len(line['evidence']) > 0 else ['no idea'] * 5
331
+ questions = line["questions"]
332
+ answers = line["answers"]
333
+ evidential = assemble_answers_to_one(line, self.k, mask_rate=self.mask_rate)['evidential_assembled']
334
+ label = line.get("label", None)
335
+ nli_labels = line.get('nli_labels', [[0., 0., 0.]] * len(questions))
336
+
337
+ for i, e in enumerate(evidential):
338
+ if '<mask>' in e:
339
+ nli_labels[i] = [0., 0., 0.]
340
+
341
+ answers = [v[0] for v in answers] # k = 1
342
+ label = self.label2id.get(label)
343
+
344
+ sample = {
345
+ "guid": guid,
346
+ "claim": claim,
347
+ "evidences": evidences,
348
+ "questions": questions,
349
+ "answers": answers,
350
+ "evidential": evidential, # already assembled.
351
+ "label": label,
352
+ 'nli_labels': nli_labels
353
+ }
354
+ return sample
src/check_client/modules/test_data_processor.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ @Author : Jiangjie Chen
5
+ @Time : 2020/12/20 18:05
6
+ @Contact : jjchen19@fudan.edu.cn
7
+ @Description:
8
+ """
9
+
10
+ import os
11
+ from data_processor import DataProcessor
12
+ from transformers import RobertaTokenizer
13
+
14
+
15
+ root = os.environ['PJ_HOME']
16
+
17
+ tokenizer = RobertaTokenizer.from_pretrained('roberta-large')
18
+ dp = DataProcessor('roberta-large', 256, 128, 8, cand_k=3, data_dir=f'{root}/data/fact_checking/v5', overwrite_cache=True)
19
+
20
+ # dp.load_and_cache_data('val', tokenizer)
21
+
22
+
23
+ data = {"id":91198,"claim":"Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the National Football League .","evidence":["Things about Colin Kaepernick: He remained the team 's starting quarterback for the rest of the season and went on to lead the 49ers to their first Super Bowl appearance since 1994 , losing to the Baltimore Ravens .","Things about Colin Kaepernick: In the following seasons , Kaepernick lost and won back his starting job , with the 49ers missing the playoffs for three years consecutively .","Things about Colin Kaepernick: During the 2013 season , his first full season as a starter , Kaepernick helped the 49ers reach the NFC Championship , losing to the Seattle Seahawks .","Things about Colin Kaepernick: Kaepernick began his professional career as a backup to Alex Smith , but became the 49ers ' starter in the middle of the 2012 season after Smith suffered a concussion .","Things about Colin Kaepernick: Colin Rand Kaepernick ( ; born November 3 , 1987 ) is an American football quarterback who is currently a free agent ."],"answers":[["Colin Kaepernick",0,16],["a starting quarterback",24,46],["49ers",58,63],["63rd season",64,75],["National Football League",83,107]],"questions":["noun","noun","noun","noun","noun"],"label":"NOT ENOUGH INFO","evidential_assembled":["Who was the starting quarterback for the 49ers in the 63rd season? or <mask> became a starting quarterback during the 49ers 63rd season in the National Football League .","What was Colin Kaepernick's first job title? or Colin Kaepernick became <mask> during the 49ers 63rd season in the National Football League .","What team was Colin Kaepernick a quarterback for? or Colin Kaepernick became a starting quarterback during the <mask> 63rd season in the National Football League .","In what season did Colin Kaepernick become a starting quarterback for the 49ers? or Colin Kaepernick became a starting quarterback during the 49ers <mask> in the National Football League .","What league was Colin Kaepernick a quarterback in? or Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the <mask> ."],"evidential":[["Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the National Football League .","Colin Kapit became a starting quarterback during the 49ers 63rd season in the National Football League .","Colin Kapra became a starting quarterback during the 49ers 63rd season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the National Football League .","Colin Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the National Football League ."],["Colin Kaepernick became a quarterback during the 49ers 63rd season in the National Football League .","Colin Kaepernick became a starter during the 49ers 63rd season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the National Football League .","Colin Kaepernick became a backup quarterback during the 49ers 63rd season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the National Football League ,"],["Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers ' 63rd season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers' 63rd season in the National Football League .","Colin Kaepernick became a starting quarterback during the Niners 63rd season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the National Football League ."],["Colin Kaepernick became a starting quarterback during the 49ers season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers ' season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers first season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers second season in the National Football League ."],["Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the Super Bowl .","Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the NFC .","Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the professional sports .","Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the NFL .","Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the league ."]]}
24
+
25
+ s = dp.convert_inputs_to_dataset([data], tokenizer, True)
26
+ print(s)
src/check_client/plm_checkers/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ @Author : Jiangjie Chen
5
+ @Time : 2020/12/27 15:41
6
+ @Contact : jjchen19@fudan.edu.cn
7
+ @Description:
8
+ """
9
+
10
+
11
+ from .bert_checker import BertChecker
12
+ from .roberta_checker import RobertaChecker
src/check_client/plm_checkers/bert_checker.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ @Author : Jiangjie Chen
5
+ @Time : 2020/8/18 14:40
6
+ @Contact : jjchen19@fudan.edu.cn
7
+ @Description:
8
+ '''
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from transformers import BertModel, BertPreTrainedModel
15
+ from .checker_utils import attention_mask_to_mask, ClassificationHead, soft_logic, build_pseudo_labels, \
16
+ get_label_embeddings, temperature_annealing
17
+
18
+
19
+ class BertChecker(BertPreTrainedModel):
20
+ def __init__(self, config, logic_lambda=0.0, prior='nli', m=8, temperature=1):
21
+ super().__init__(config)
22
+ self.num_labels = config.num_labels
23
+ self.hidden_size = config.hidden_size
24
+ self.bert = BertModel(config)
25
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
26
+ self._lambda = logic_lambda
27
+ self.prior = prior
28
+ self.temperature = temperature
29
+ self._step = 0
30
+
31
+ # general attention
32
+ self.linear_self_attn = nn.Linear(self.hidden_size, 1, bias=False)
33
+ self.linear_m_attn = nn.Linear(self.hidden_size * 2, 1, bias=False)
34
+
35
+ self.var_hidden_size = self.hidden_size // 4
36
+
37
+ z_hid_size = self.num_labels * m
38
+ self.linear_P_theta = nn.Linear(self.hidden_size * 2 + z_hid_size, self.var_hidden_size)
39
+ y_hid_size = self.var_hidden_size
40
+ self.linear_Q_phi = nn.Linear(self.hidden_size * 2 + y_hid_size, self.var_hidden_size)
41
+
42
+ self.classifier = ClassificationHead(self.var_hidden_size, self.num_labels, config.hidden_dropout_prob) # label embedding for y
43
+ self.z_clf = self.classifier
44
+ self.init_weights()
45
+
46
+ def forward(self, claim_input_ids, claim_attention_mask, claim_token_type_ids,
47
+ qa_input_ids_list, qa_attention_mask_list, qa_token_type_ids_list,
48
+ nli_labels=None, labels=None):
49
+ '''
50
+ m: num of questions; n: num of evidence; k: num of candidate answers
51
+ :param claim_input_ids: b x L1
52
+ :param claim_attention_mask: b x L1
53
+ :param claim_token_type_ids: b x L1
54
+ :param qa_input_ids_list: b x m x L2
55
+ :param qa_attention_mask_list: b x m x L2
56
+ :param qa_token_type_ids_list: b x m x L2
57
+ :param labels: (b,)
58
+ :return:
59
+ '''
60
+ self._step += 1
61
+ _zero = torch.tensor(0.).to(claim_input_ids.device)
62
+
63
+ global_output = self.bert(
64
+ claim_input_ids,
65
+ attention_mask=claim_attention_mask,
66
+ token_type_ids=claim_token_type_ids
67
+ )[0] # b x L1 x h
68
+
69
+ global_output = self.self_select(global_output) # b x h
70
+
71
+ _qa_input_ids_list = qa_input_ids_list.transpose(1, 0) # m x b x L2
72
+ _qa_attention_mask_list = qa_attention_mask_list.transpose(1, 0)
73
+ _qa_token_type_ids_list = qa_token_type_ids_list.transpose(1, 0)
74
+
75
+ local_output_list = []
76
+ for _inp, _attn, _token_ids in zip(_qa_input_ids_list, _qa_attention_mask_list, _qa_token_type_ids_list):
77
+ _local_output = self.bert(_inp, attention_mask=_attn,
78
+ token_type_ids=_token_ids)[0]
79
+ _local_output = self.self_select(_local_output)
80
+ local_output_list.append(_local_output)
81
+
82
+ local_outputs = torch.stack(local_output_list, 0) # m x b x h
83
+ local_outputs = local_outputs.transpose(1, 0).contiguous() # b x m x h
84
+
85
+ neg_elbo, loss, logic_loss = _zero, _zero, _zero
86
+ mask = attention_mask_to_mask(qa_attention_mask_list)
87
+ # b x h, b x m x h -> b x h
88
+ local_outputs_w, m_attn = self.local_attn(global_output, local_outputs, mask)
89
+ local_outputs = torch.cat([local_outputs, global_output.unsqueeze(1).repeat(1, local_outputs.size(1), 1)], -1)
90
+
91
+ if labels is not None:
92
+ # Training
93
+ # ======================== Q_phi ================================
94
+
95
+ labels_onehot = F.one_hot(labels, num_classes=self.num_labels).to(torch.float)
96
+ y_star_emb = get_label_embeddings(labels_onehot, self.classifier.out_proj.weight) # b x h
97
+ z = self.Q_phi(local_outputs, y_star_emb)
98
+ z_softmax = z.softmax(-1)
99
+
100
+ # ======================== P_theta ==============================
101
+
102
+ z_gumbel = F.gumbel_softmax(z, tau=temperature_annealing(self.temperature, self._step),
103
+ dim=-1, hard=True) # b x m x 3
104
+ y = self.P_theta(global_output, local_outputs_w, z_gumbel)
105
+
106
+ # ======================== soft logic ===========================
107
+ mask = mask.to(torch.int)
108
+ y_z = soft_logic(z_softmax, mask) # b x 3
109
+ logic_loss = F.kl_div(y.log_softmax(-1), y_z)
110
+
111
+ # ======================== ELBO =================================
112
+ elbo_neg_p_log = F.cross_entropy(y.view(-1, self.num_labels), labels.view(-1))
113
+ if self.prior == 'nli':
114
+ prior = nli_labels.softmax(dim=-1)
115
+ elif self.prior == 'uniform':
116
+ prior = torch.tensor([1 / self.num_labels] * self.num_labels).to(y)
117
+ prior = prior.unsqueeze(0).unsqueeze(0).repeat(mask.size(0), mask.size(1), 1)
118
+ elif self.prior == 'logic':
119
+ prior = build_pseudo_labels(labels, m_attn)
120
+ else:
121
+ raise NotImplementedError(self.prior)
122
+
123
+ elbo_kl = F.kl_div(z_softmax.log(), prior)
124
+ neg_elbo = elbo_kl + elbo_neg_p_log
125
+
126
+ loss = (1 - abs(self._lambda)) * neg_elbo + abs(self._lambda) * logic_loss
127
+ else:
128
+ # Inference
129
+ if self.prior == 'nli':
130
+ z = nli_labels
131
+ elif self.prior == 'uniform':
132
+ prior = torch.tensor([1 / self.num_labels] * self.num_labels).to(y)
133
+ z = prior.unsqueeze(0).unsqueeze(0).repeat(mask.size(0), mask.size(1), 1)
134
+ else:
135
+ z = torch.rand([local_outputs.size(0), local_outputs.size(1), self.num_labels]).to(local_outputs)
136
+ z_softmax = z.softmax(-1)
137
+
138
+ for i in range(3): # N = 3
139
+ z = z_softmax.argmax(-1)
140
+ z = F.one_hot(z, num_classes=3).to(torch.float)
141
+ y = self.P_theta(global_output, local_outputs_w, z)
142
+ y = y.softmax(-1)
143
+ y_emb = get_label_embeddings(y, self.classifier.out_proj.weight)
144
+ z = self.Q_phi(local_outputs, y_emb)
145
+ z_softmax = z.softmax(-1)
146
+
147
+ return (loss, (neg_elbo, logic_loss), y, m_attn, (z_softmax, mask)) # batch first
148
+
149
+ def Q_phi(self, X, y):
150
+ '''
151
+ X, y => z
152
+ :param X: b x m x h
153
+ :param y_emb: b x 3 / b x h'
154
+ :return: b x m x 3 (ref, nei, sup)
155
+ '''
156
+ y_expand = y.unsqueeze(1).repeat(1, X.size(1), 1) # b x m x 3/h'
157
+ z_hidden = self.linear_Q_phi(torch.cat([y_expand, X], dim=-1)) # b x m x h'
158
+ z_hidden = F.tanh(z_hidden)
159
+ z = self.z_clf(z_hidden)
160
+ return z
161
+
162
+ def P_theta(self, X_global, X_local, z):
163
+ '''
164
+ X, z => y*
165
+ :param X_global: b x h
166
+ :param X_local: b x m x h
167
+ :param z: b x m x 3
168
+ :param mask: b x m
169
+ :return: b x 3, b x m
170
+ '''
171
+ b = z.size(0)
172
+ # global classification
173
+ _logits = torch.cat([X_local, X_global, z.reshape(b, -1)], dim=-1)
174
+ _logits = self.dropout(_logits)
175
+ _logits = self.linear_P_theta(_logits)
176
+ _logits = torch.tanh(_logits)
177
+
178
+ y = self.classifier(_logits)
179
+ return y
180
+
181
+ def self_select(self, h_x):
182
+ '''
183
+ self attention on a vector
184
+ :param h_x: b x L x h
185
+ :return: b x h
186
+ '''
187
+ w = self.dropout(self.linear_self_attn(h_x).squeeze(-1)).softmax(-1)
188
+ return torch.einsum('blh,bl->bh', h_x, w)
189
+
190
+ def local_attn(self, global_output, local_outputs, mask):
191
+ '''
192
+ :param global_output: b x h
193
+ :param qa_outputs: b x m x h
194
+ :param mask: b x m
195
+ :return: b x h, b x m
196
+ '''
197
+ m = local_outputs.size(1)
198
+ scores = self.linear_m_attn(torch.cat([global_output.unsqueeze(1).repeat(1, m, 1),
199
+ local_outputs], dim=-1)).squeeze(-1) # b x m
200
+ mask = 1 - mask
201
+ scores = scores.masked_fill(mask.to(torch.bool), -1e16)
202
+ attn = F.softmax(scores, -1)
203
+ return torch.einsum('bm,bmh->bh', attn, local_outputs), attn
src/check_client/plm_checkers/checker_utils.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ @Author : Jiangjie Chen
5
+ @Time : 2020/10/15 16:10
6
+ @Contact : jjchen19@fudan.edu.cn
7
+ @Description:
8
+ '''
9
+
10
+ import torch
11
+ import random
12
+ import torch.nn.functional as F
13
+ import torch.nn as nn
14
+
15
+
16
+ class ClassificationHead(nn.Module):
17
+ """Head for sentence-level classification tasks."""
18
+
19
+ def __init__(self, hidden_size, num_labels, hidden_dropout_prob=0.2):
20
+ super().__init__()
21
+ self.dropout = nn.Dropout(hidden_dropout_prob)
22
+ self.out_proj = nn.Linear(hidden_size, num_labels, bias=False)
23
+
24
+ def forward(self, features, **kwargs):
25
+ x = features
26
+ x = self.dropout(x)
27
+ x = self.out_proj(x)
28
+ return x
29
+
30
+
31
+ def temperature_annealing(tau, step):
32
+ if tau == 0.:
33
+ tau = 10. if step % 5 == 0 else 1.
34
+ return tau
35
+
36
+
37
+ def get_label_embeddings(labels, label_embedding):
38
+ '''
39
+ :param labels: b x 3
40
+ :param label_embedding: 3 x h'
41
+ :return: b x h'
42
+ '''
43
+ emb = torch.einsum('oi,bo->bi', label_embedding, labels)
44
+ return emb
45
+
46
+
47
+ def soft_logic(y_i, mask, tnorm='product'):
48
+ '''
49
+ a^b = ab
50
+ avb = 1 - ((1-a)(1-b))
51
+ :param y_i: b x m x 3
52
+ :param mask: b x m
53
+ :param tnorm: product or godel or lukasiewicz
54
+ :return: [b x 3]
55
+ '''
56
+ _sup = y_i[:, :, 2] # b x m
57
+ _ref = y_i[:, :, 0] # b x m
58
+ _sup = _sup * mask + (1 - mask) # pppp1111
59
+ _ref = _ref * mask # pppp0000
60
+
61
+ if tnorm == 'product':
62
+ p_sup = torch.exp(torch.log(_sup).sum(1))
63
+ p_ref = 1 - torch.exp(torch.log(1 - _ref).sum(1))
64
+ elif tnorm == 'godel':
65
+ p_sup = _sup.min(-1).values
66
+ p_ref = _ref.max(-1).values
67
+ elif tnorm == 'lukas':
68
+ raise NotImplementedError(tnorm)
69
+ else:
70
+ raise NotImplementedError(tnorm)
71
+
72
+ p_nei = 1 - p_sup - p_ref
73
+ p_sup = torch.max(p_sup, torch.zeros_like(p_sup))
74
+ p_ref = torch.max(p_ref, torch.zeros_like(p_ref))
75
+ p_nei = torch.max(p_nei, torch.zeros_like(p_nei))
76
+ logical_prob = torch.stack([p_ref, p_nei, p_sup], dim=-1)
77
+ assert torch.lt(logical_prob, 0).to(torch.int).sum().tolist() == 0, \
78
+ (logical_prob, _sup, _ref)
79
+ return logical_prob # b x 3
80
+
81
+
82
+ def build_pseudo_labels(labels, m_attn):
83
+ '''
84
+ :param labels: (b,)
85
+ :param m_attn: b x m
86
+ :return: b x m x 3
87
+ '''
88
+ mask = torch.gt(m_attn, 1e-16).to(torch.int)
89
+ sup_label = torch.tensor(2).to(labels)
90
+ nei_label = torch.tensor(1).to(labels)
91
+ ref_label = torch.tensor(0).to(labels)
92
+ pseudo_labels = []
93
+ for idx, label in enumerate(labels):
94
+ mm = mask[idx].sum(0)
95
+ if label == 2: # SUPPORTS
96
+ pseudo_label = F.one_hot(sup_label.repeat(mask.size(1)), num_classes=3).to(torch.float) # TODO: hyperparam
97
+
98
+ elif label == 0: # REFUTES
99
+ num_samples = magic_proportion(mm)
100
+ ids = torch.topk(m_attn[idx], k=num_samples).indices
101
+ pseudo_label = []
102
+ for i in range(mask.size(1)):
103
+ if i >= mm:
104
+ _label = torch.tensor([1/3, 1/3, 1/3]).to(labels)
105
+ elif i in ids:
106
+ _label = F.one_hot(ref_label, num_classes=3).to(torch.float)
107
+ else:
108
+ if random.random() > 0.5:
109
+ _label = torch.tensor([0., 0., 1.]).to(labels)
110
+ else:
111
+ _label = torch.tensor([0., 1., 0.]).to(labels)
112
+ pseudo_label.append(_label)
113
+ pseudo_label = torch.stack(pseudo_label)
114
+
115
+ else: # NEI
116
+ num_samples = magic_proportion(mm)
117
+ ids = torch.topk(m_attn[idx], k=num_samples).indices
118
+ pseudo_label = sup_label.repeat(mask.size(1))
119
+ pseudo_label[ids] = nei_label
120
+ pseudo_label = F.one_hot(pseudo_label, num_classes=3).to(torch.float) # TODO: hyperparam
121
+
122
+ pseudo_labels.append(pseudo_label)
123
+ return torch.stack(pseudo_labels)
124
+
125
+
126
+ def magic_proportion(m, magic_n=5):
127
+ # 1~4: 1, 5~m: 2
128
+ return m // magic_n + 1
129
+
130
+
131
+ def sequence_mask(lengths, max_len=None):
132
+ """
133
+ Creates a boolean mask from sequence lengths.
134
+ """
135
+ batch_size = lengths.numel()
136
+ max_len = max_len or lengths.max()
137
+ return (torch.arange(0, max_len, device=lengths.device)
138
+ .type_as(lengths)
139
+ .repeat(batch_size, 1)
140
+ .lt(lengths.unsqueeze(1)))
141
+
142
+
143
+ def collapse_w_mask(inputs, mask):
144
+ '''
145
+ :param inputs: b x L x h
146
+ :param mask: b x L
147
+ :return: b x h
148
+ '''
149
+ hidden = inputs.size(-1)
150
+ output = inputs * mask.unsqueeze(-1).repeat((1, 1, hidden)) # b x L x h
151
+ output = output.sum(-2)
152
+ output /= (mask.sum(-1) + 1e-6).unsqueeze(-1).repeat((1, hidden)) # b x h
153
+ return output
154
+
155
+
156
+ def parse_ce_outputs(ce_seq_output, ce_lengths):
157
+ '''
158
+ :param qa_seq_output: b x L1 x h
159
+ :param qa_lengths: e.g. [0,1,1,0,2,2,0,0] (b x L2)
160
+ :return:
161
+ c_output: b x h
162
+ e_output: b x h
163
+ '''
164
+ if ce_lengths.max() == 0:
165
+ b, L1, h = ce_seq_output.size()
166
+ return torch.zeros([b, h]).cuda(), torch.zeros([b, h]).cuda()
167
+ masks = []
168
+ for mask_id in range(1, ce_lengths.max() + 1):
169
+ _m = torch.ones_like(ce_lengths) * mask_id
170
+ mask = _m.eq(ce_lengths).to(torch.int)
171
+ masks.append(mask)
172
+ c_output = collapse_w_mask(ce_seq_output, masks[0])
173
+ e_output = torch.stack([collapse_w_mask(ce_seq_output, m)
174
+ for m in masks[1:]]).mean(0)
175
+ return c_output, e_output
176
+
177
+
178
+ def parse_qa_outputs(qa_seq_output, qa_lengths, k):
179
+ '''
180
+ :param qa_seq_output: b x L2 x h
181
+ :param qa_lengths: e.g. [0,1,1,0,2,2,0,3,0,4,0,5,0,0,0,0] (b x L2)
182
+ :return:
183
+ q_output: b x h
184
+ a_output: b x h
185
+ k_cand_output: k x b x h
186
+ '''
187
+ b, L2, h = qa_seq_output.size()
188
+ if qa_lengths.max() == 0:
189
+ return torch.zeros([b, h]).cuda(), torch.zeros([b, h]).cuda(), \
190
+ torch.zeros([k, b, h]).cuda()
191
+
192
+ masks = []
193
+ for mask_id in range(1, qa_lengths.max() + 1):
194
+ _m = torch.ones_like(qa_lengths) * mask_id
195
+ mask = _m.eq(qa_lengths).to(torch.int)
196
+ masks.append(mask)
197
+
198
+ q_output = collapse_w_mask(qa_seq_output, masks[0])
199
+ a_output = collapse_w_mask(qa_seq_output, masks[1])
200
+ k_cand_output = [collapse_w_mask(qa_seq_output, m)
201
+ for m in masks[2:2 + k]]
202
+ for i in range(k - len(k_cand_output)):
203
+ k_cand_output.append(torch.zeros([b, h]).cuda())
204
+ k_cand_output = torch.stack(k_cand_output, dim=0)
205
+
206
+ return q_output, a_output, k_cand_output
207
+
208
+
209
+ def attention_mask_to_mask(attention_mask):
210
+ '''
211
+ :param attention_mask: b x m x L
212
+ :return: b x m
213
+ '''
214
+ mask = torch.gt(attention_mask.sum(-1), 0).to(torch.int).sum(-1) # (b,)
215
+ mask = sequence_mask(mask, max_len=attention_mask.size(1)).to(torch.int) # (b, m)
216
+ return mask
217
+
218
+
219
+ if __name__ == "__main__":
220
+ y = torch.tensor([[[0.3,0.5,0.2],[0.1,0.4,0.5]]])
221
+ mask = torch.tensor([1,1])
222
+ s = soft_logic(y, mask)
223
+ print(s)
src/check_client/plm_checkers/roberta_checker.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ @Author : Jiangjie Chen
5
+ @Time : 2020/8/18 14:40
6
+ @Contact : jjchen19@fudan.edu.cn
7
+ @Description:
8
+ '''
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from transformers import RobertaModel, BertPreTrainedModel, RobertaConfig
15
+ from .checker_utils import attention_mask_to_mask, ClassificationHead, soft_logic, build_pseudo_labels, \
16
+ get_label_embeddings, temperature_annealing
17
+
18
+
19
+ class RobertaChecker(BertPreTrainedModel):
20
+ config_class = RobertaConfig
21
+ base_model_prefix = "roberta"
22
+
23
+ def __init__(self, config, logic_lambda=0.0, prior='nli', m=8, temperature=1):
24
+ super().__init__(config)
25
+ self.num_labels = config.num_labels
26
+ self.hidden_size = config.hidden_size
27
+ self.roberta = RobertaModel(config)
28
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
29
+ self._lambda = logic_lambda
30
+ self.prior = prior
31
+ self.temperature = temperature
32
+ self._step = 0
33
+
34
+ # general attention
35
+ self.linear_self_attn = nn.Linear(self.hidden_size, 1, bias=False)
36
+ self.linear_m_attn = nn.Linear(self.hidden_size * 2, 1, bias=False)
37
+
38
+ self.var_hidden_size = self.hidden_size // 4
39
+
40
+ z_hid_size = self.num_labels * m
41
+ self.linear_P_theta = nn.Linear(self.hidden_size * 2 + z_hid_size, self.var_hidden_size)
42
+ y_hid_size = self.var_hidden_size
43
+ self.linear_Q_phi = nn.Linear(self.hidden_size * 2 + y_hid_size, self.var_hidden_size)
44
+
45
+ # TODO: y_clf => classifier. compromise for mnli
46
+ self.classifier = ClassificationHead(self.var_hidden_size, self.num_labels,
47
+ config.hidden_dropout_prob) # label embedding for y
48
+ self.z_clf = self.classifier
49
+ self.init_weights()
50
+
51
+ def forward(self, claim_input_ids, claim_attention_mask, claim_token_type_ids,
52
+ qa_input_ids_list, qa_attention_mask_list, qa_token_type_ids_list,
53
+ nli_labels=None, labels=None):
54
+ '''
55
+ m: num of questions; n: num of evidence; k: num of candidate answers
56
+ :param claim_input_ids: b x L1
57
+ :param claim_attention_mask: b x L1
58
+ :param claim_token_type_ids: b x L1
59
+ :param qa_input_ids_list: b x m x L2
60
+ :param qa_attention_mask_list: b x m x L2
61
+ :param qa_token_type_ids_list: b x m x L2
62
+ :param nli_labels: b x m x 3
63
+ :param labels: (b,)
64
+ :return: (loss, (neg_elbo, logic_loss), y, m_attn, (z_softmax, mask))
65
+ '''
66
+ self._step += 1
67
+ _zero = torch.tensor(0.).to(claim_input_ids.device)
68
+
69
+ # ====================== Representation learning =======================
70
+ global_output = self.roberta(claim_input_ids, attention_mask=claim_attention_mask)[0] # b x L1 x h
71
+ global_output = self.self_select(global_output) # b x h
72
+
73
+ _qa_input_ids_list = qa_input_ids_list.transpose(1, 0) # m x b x L2
74
+ _qa_attention_mask_list = qa_attention_mask_list.transpose(1, 0)
75
+
76
+ local_output_list = []
77
+ for _inp, _attn in zip(_qa_input_ids_list, _qa_attention_mask_list):
78
+ _local_output = self.roberta(_inp, attention_mask=_attn)[0]
79
+ _local_output = self.self_select(_local_output)
80
+ local_output_list.append(_local_output)
81
+
82
+ _local_outputs = torch.stack(local_output_list, 0) # m x b x h
83
+ local_outputs = _local_outputs.transpose(1, 0).contiguous() # b x m x h
84
+
85
+ neg_elbo, loss, logic_loss = _zero, _zero, _zero
86
+ mask = attention_mask_to_mask(qa_attention_mask_list)
87
+ # b x h, b x m x h -> b x h
88
+ local_outputs_w, m_attn = self.local_attn(global_output, local_outputs, mask)
89
+ local_outputs = torch.cat([local_outputs, global_output.unsqueeze(1).repeat(1, local_outputs.size(1), 1)], -1)
90
+
91
+ if labels is not None:
92
+ # Training
93
+ # ======================== Q_phi ================================
94
+
95
+ labels_onehot = F.one_hot(labels, num_classes=self.num_labels).to(torch.float)
96
+ y_star_emb = get_label_embeddings(labels_onehot, self.classifier.out_proj.weight) # b x h
97
+ z = self.Q_phi(local_outputs, y_star_emb)
98
+ z_softmax = z.softmax(-1)
99
+
100
+ # ======================== P_theta ==============================
101
+
102
+ z_gumbel = F.gumbel_softmax(z, tau=temperature_annealing(self.temperature, self._step),
103
+ dim=-1, hard=True) # b x m x 3
104
+ y = self.P_theta(global_output, local_outputs_w, z_gumbel)
105
+
106
+ # ======================== soft logic ===========================
107
+ mask = mask.to(torch.int)
108
+ y_z = soft_logic(z_softmax, mask) # b x 3
109
+ logic_loss = F.kl_div(y.log_softmax(-1), y_z)
110
+
111
+ # ======================== ELBO =================================
112
+ elbo_neg_p_log = F.cross_entropy(y.view(-1, self.num_labels), labels.view(-1))
113
+ if self.prior == 'nli':
114
+ prior = nli_labels.softmax(dim=-1)
115
+ elif self.prior == 'uniform':
116
+ prior = torch.tensor([1 / self.num_labels] * self.num_labels).to(mask.device)
117
+ prior = prior.unsqueeze(0).unsqueeze(0).repeat(mask.size(0), mask.size(1), 1)
118
+ elif self.prior == 'logic':
119
+ prior = build_pseudo_labels(labels, m_attn)
120
+ else:
121
+ raise NotImplementedError(self.prior)
122
+
123
+ elbo_kl = F.kl_div(z_softmax.log(), prior)
124
+ neg_elbo = elbo_kl + elbo_neg_p_log
125
+
126
+ loss = (1 - abs(self._lambda)) * neg_elbo + abs(self._lambda) * logic_loss
127
+ else:
128
+ # Inference
129
+ if self.prior == 'nli':
130
+ z = nli_labels
131
+ elif self.prior == 'uniform':
132
+ prior = torch.tensor([1 / self.num_labels] * self.num_labels).to(mask.device)
133
+ z = prior.unsqueeze(0).unsqueeze(0).repeat(mask.size(0), mask.size(1), 1)
134
+ else:
135
+ z = torch.rand([local_outputs.size(0), local_outputs.size(1), self.num_labels]).to(local_outputs)
136
+ z_softmax = z.softmax(-1)
137
+
138
+ for i in range(3): # N = 3
139
+ z = z_softmax.argmax(-1)
140
+ z = F.one_hot(z, num_classes=3).to(torch.float)
141
+ y = self.P_theta(global_output, local_outputs_w, z)
142
+ y = y.softmax(-1)
143
+ y_emb = get_label_embeddings(y, self.classifier.out_proj.weight)
144
+ z = self.Q_phi(local_outputs, y_emb)
145
+ z_softmax = z.softmax(-1)
146
+
147
+ return (loss, (neg_elbo, logic_loss), y, m_attn, (z_softmax, mask)) # batch first
148
+
149
+ def Q_phi(self, X, y):
150
+ '''
151
+ X, y => z
152
+ :param X: b x m x h
153
+ :param y_emb: b x 3 / b x h'
154
+ :return: b x m x 3 (ref, nei, sup)
155
+ '''
156
+ y_expand = y.unsqueeze(1).repeat(1, X.size(1), 1) # b x m x 3/h'
157
+ z_hidden = self.linear_Q_phi(torch.cat([y_expand, X], dim=-1)) # b x m x h'
158
+ z_hidden = F.tanh(z_hidden)
159
+ z = self.z_clf(z_hidden)
160
+ return z
161
+
162
+ def P_theta(self, X_global, X_local, z):
163
+ '''
164
+ X, z => y*
165
+ :param X_global: b x h
166
+ :param X_local: b x m x h
167
+ :param z: b x m x 3
168
+ :param mask: b x m
169
+ :return: b x 3, b x m
170
+ '''
171
+ b = z.size(0)
172
+ # global classification
173
+ _logits = torch.cat([X_local, X_global, z.reshape(b, -1)], dim=-1)
174
+ _logits = self.dropout(_logits)
175
+ _logits = self.linear_P_theta(_logits)
176
+ _logits = torch.tanh(_logits)
177
+
178
+ y = self.classifier(_logits)
179
+ return y
180
+
181
+ def self_select(self, h_x):
182
+ '''
183
+ self attention on a vector
184
+ :param h_x: b x L x h
185
+ :return: b x h
186
+ '''
187
+ w = self.dropout(self.linear_self_attn(h_x).squeeze(-1)).softmax(-1)
188
+ return torch.einsum('blh,bl->bh', h_x, w)
189
+
190
+ def local_attn(self, global_output, local_outputs, mask):
191
+ '''
192
+ :param global_output: b x h
193
+ :param qa_outputs: b x m x h
194
+ :param mask: b x m
195
+ :return: b x h, b x m
196
+ '''
197
+ m = local_outputs.size(1)
198
+ scores = self.linear_m_attn(torch.cat([global_output.unsqueeze(1).repeat(1, m, 1),
199
+ local_outputs], dim=-1)).squeeze(-1) # b x m
200
+ mask = 1 - mask
201
+ scores = scores.masked_fill(mask.to(torch.bool), -1e16)
202
+ attn = F.softmax(scores, -1)
203
+ return torch.einsum('bm,bmh->bh', attn, local_outputs), attn
src/check_client/scripts/train_bert-large.sh ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ MODEL_TYPE=bert
4
+ MODEL_NAME_OR_PATH=bert-large-cased
5
+ VERSION=v5
6
+ MAX_NUM_QUESTIONS=8
7
+
8
+ MAX_SEQ1_LENGTH=256
9
+ MAX_SEQ2_LENGTH=128
10
+ CAND_K=3
11
+ LAMBDA=${1:-0.5}
12
+ PRIOR=${2:-nli}
13
+ MASK=${3:-0.0}
14
+ echo "lambda = $LAMBDA, prior = $PRIOR, mask = $MASK"
15
+
16
+ DATA_DIR=$PJ_HOME/data/fact_checking/${VERSION}
17
+ OUTPUT_DIR=$PJ_HOME/models/fact_checking/${VERSION}_${MODEL_NAME_OR_PATH}/${VERSION}_${MODEL_NAME_OR_PATH}_AAAI_K${CAND_K}_${PRIOR}_m${MASK}_l${LAMBDA}
18
+ NUM_TRAIN_EPOCH=7
19
+ GRADIENT_ACCUMULATION_STEPS=2
20
+ PER_GPU_TRAIN_BATCH_SIZE=8 # 4546
21
+ PER_GPU_EVAL_BATCH_SIZE=16
22
+ LOGGING_STEPS=200
23
+ SAVE_STEPS=200
24
+
25
+
26
+ python3 train.py \
27
+ --data_dir ${DATA_DIR} \
28
+ --output_dir ${OUTPUT_DIR} \
29
+ --model_type ${MODEL_TYPE} \
30
+ --model_name_or_path ${MODEL_NAME_OR_PATH} \
31
+ --max_seq1_length ${MAX_SEQ1_LENGTH} \
32
+ --max_seq2_length ${MAX_SEQ2_LENGTH} \
33
+ --max_num_questions ${MAX_NUM_QUESTIONS} \
34
+ --do_train \
35
+ --do_eval \
36
+ --evaluate_during_training \
37
+ --learning_rate 1e-5 \
38
+ --num_train_epochs ${NUM_TRAIN_EPOCH} \
39
+ --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEPS} \
40
+ --per_gpu_train_batch_size ${PER_GPU_TRAIN_BATCH_SIZE} \
41
+ --per_gpu_eval_batch_size ${PER_GPU_EVAL_BATCH_SIZE} \
42
+ --logging_steps ${LOGGING_STEPS} \
43
+ --save_steps ${SAVE_STEPS} \
44
+ --cand_k ${CAND_K} \
45
+ --logic_lambda ${LAMBDA} \
46
+ --prior ${PRIOR} \
47
+ --overwrite_output_dir \
48
+ --temperature 1.0 \
49
+ --mask_rate ${MASK}
50
+
51
+ python3 cjjpy.py --lark "$OUTPUT_DIR fact checking training completed"
src/check_client/scripts/train_roberta.sh ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ MODEL_TYPE=roberta
4
+ MODEL_NAME_OR_PATH=roberta-large
5
+ VERSION=v5
6
+ MAX_NUM_QUESTIONS=8
7
+
8
+ MAX_SEQ1_LENGTH=256
9
+ MAX_SEQ2_LENGTH=128
10
+ CAND_K=3
11
+ LAMBDA=${1:-0.5}
12
+ PRIOR=${2:-nli}
13
+ MASK=${3:-0.0}
14
+ echo "lambda = $LAMBDA, prior = $PRIOR, mask = $MASK"
15
+
16
+ DATA_DIR=$PJ_HOME/data/fact_checking/${VERSION}
17
+ OUTPUT_DIR=$PJ_HOME/models/fact_checking/${VERSION}_${MODEL_NAME_OR_PATH}/${VERSION}_${MODEL_NAME_OR_PATH}_AAAI_K${CAND_K}_${PRIOR}_m${MASK}_l${LAMBDA}
18
+ NUM_TRAIN_EPOCH=7
19
+ GRADIENT_ACCUMULATION_STEPS=2
20
+ PER_GPU_TRAIN_BATCH_SIZE=8 # 4546
21
+ PER_GPU_EVAL_BATCH_SIZE=16
22
+ LOGGING_STEPS=200
23
+ SAVE_STEPS=200
24
+
25
+
26
+ python3 train.py \
27
+ --data_dir ${DATA_DIR} \
28
+ --output_dir ${OUTPUT_DIR} \
29
+ --model_type ${MODEL_TYPE} \
30
+ --model_name_or_path ${MODEL_NAME_OR_PATH} \
31
+ --max_seq1_length ${MAX_SEQ1_LENGTH} \
32
+ --max_seq2_length ${MAX_SEQ2_LENGTH} \
33
+ --max_num_questions ${MAX_NUM_QUESTIONS} \
34
+ --do_train \
35
+ --do_eval \
36
+ --evaluate_during_training \
37
+ --learning_rate 1e-5 \
38
+ --num_train_epochs ${NUM_TRAIN_EPOCH} \
39
+ --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEPS} \
40
+ --per_gpu_train_batch_size ${PER_GPU_TRAIN_BATCH_SIZE} \
41
+ --per_gpu_eval_batch_size ${PER_GPU_EVAL_BATCH_SIZE} \
42
+ --logging_steps ${LOGGING_STEPS} \
43
+ --save_steps ${SAVE_STEPS} \
44
+ --cand_k ${CAND_K} \
45
+ --logic_lambda ${LAMBDA} \
46
+ --prior ${PRIOR} \
47
+ --overwrite_output_dir \
48
+ --temperature 1.0 \
49
+ --mask_rate ${MASK}
50
+
51
+ python3 cjjpy.py --lark "$OUTPUT_DIR fact checking training completed"
src/check_client/train.py ADDED
@@ -0,0 +1,647 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ import glob
19
+ import argparse
20
+ import logging
21
+ import random
22
+ import torch
23
+ import numpy as np
24
+ from tqdm import tqdm
25
+ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
26
+ from torch.utils.data.distributed import DistributedSampler
27
+ from transformers import (
28
+ AutoConfig,
29
+ AutoTokenizer
30
+ )
31
+ from transformers import WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup
32
+ import tensorflow as tf
33
+ from pytorch_lightning.loggers import WandbLogger
34
+
35
+ try:
36
+ from .modules.data_processor import DataProcessor
37
+ from .plm_checkers import BertChecker, RobertaChecker, XLNetChecker, DebertaChecker
38
+ from .utils import init_logger, compute_metrics
39
+ except:
40
+ from modules.data_processor import DataProcessor
41
+ from plm_checkers import BertChecker, RobertaChecker, XLNetChecker, DebertaChecker
42
+ from utils import init_logger, compute_metrics
43
+
44
+ try:
45
+ from torch.utils.tensorboard import SummaryWriter
46
+ except ImportError:
47
+ from tensorboardX import SummaryWriter
48
+
49
+ mAutoModel = {
50
+ 'bert': BertChecker,
51
+ 'roberta': RobertaChecker,
52
+ 'xlnet': XLNetChecker,
53
+ 'deberta': DebertaChecker,
54
+ }
55
+
56
+ logger = logging.getLogger(__name__)
57
+
58
+
59
+ def set_seed(args):
60
+ random.seed(args.seed)
61
+ np.random.seed(args.seed)
62
+ torch.manual_seed(args.seed)
63
+ if args.n_gpu > 0:
64
+ torch.cuda.manual_seed_all(args.seed)
65
+
66
+
67
+ def train(args, data_processor, model, tokenizer):
68
+ """ Train the model """
69
+ global wdblogger
70
+ if args.local_rank in [-1, 0]:
71
+ tb_writer = SummaryWriter()
72
+
73
+ tf.io.gfile.makedirs(os.path.dirname(args.output_dir))
74
+ args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
75
+ train_dataset = data_processor.load_and_cache_data("train", tokenizer, args.data_tag)
76
+ train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
77
+ train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
78
+ drop_last=True,
79
+ batch_size=args.train_batch_size)
80
+
81
+ if args.max_steps > 0:
82
+ t_total = args.max_steps
83
+ args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
84
+ else:
85
+ t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
86
+
87
+ # Prepare optimizer and schedule (linear warmup and decay)
88
+ no_decay = ["bias", "LayerNorm.weight"]
89
+ optimizer_grouped_parameters = [
90
+ {
91
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
92
+ "weight_decay": args.weight_decay,
93
+ },
94
+ {
95
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
96
+ "weight_decay": 0.0
97
+ },
98
+ ]
99
+ optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
100
+ scheduler = get_linear_schedule_with_warmup(
101
+ optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
102
+ )
103
+ if args.fp16:
104
+ try:
105
+ from apex import amp
106
+ except ImportError:
107
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
108
+ model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
109
+
110
+ # multi-gpu training (should be after apex fp16 initialization)
111
+ if args.n_gpu > 1:
112
+ model = torch.nn.DataParallel(model)
113
+
114
+ # Distributed training (should be after apex fp16 initialization)
115
+ if args.local_rank != -1:
116
+ model = torch.nn.parallel.DistributedDataParallel(
117
+ model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
118
+ )
119
+
120
+ # Train!
121
+ logger.info("***** Running training *****")
122
+ logger.info("Num examples = %d", len(train_dataset))
123
+ logger.info("Num Epochs = %d", args.num_train_epochs)
124
+ logger.info("Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
125
+ logger.info(
126
+ "Total train batch size (w. parallel, distributed & accumulation) = %d",
127
+ args.train_batch_size
128
+ * args.gradient_accumulation_steps
129
+ * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
130
+ )
131
+ logger.info("Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
132
+ logger.info("Total optimization steps = %d", t_total)
133
+
134
+ global_step = 0
135
+ tr_loss, logging_loss = 0.0, 0.0
136
+ tr_loss2, logging_loss2 = 0.0, 0.0
137
+ tr_loss3, logging_loss3 = 0.0, 0.0
138
+ set_seed(args) # Added here for reproductibility
139
+ model.zero_grad()
140
+ for _ in range(int(args.num_train_epochs)):
141
+ all_loss = 0.0
142
+ all_accuracy = 0.0
143
+ epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
144
+ for step, batch in enumerate(epoch_iterator):
145
+ model.train()
146
+ batch = tuple(t.to(args.device) for t in batch)
147
+ inputs = {
148
+ "claim_input_ids": batch[0],
149
+ "claim_attention_mask": batch[1],
150
+ "qa_input_ids_list": batch[3],
151
+ "qa_attention_mask_list": batch[4],
152
+ "nli_labels": batch[-2],
153
+ "labels": batch[-1],
154
+ }
155
+ if args.model_type != "distilbert":
156
+ # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
157
+ inputs["claim_token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
158
+ inputs["qa_token_type_ids_list"] = batch[5] if args.model_type in ["bert", "xlnet", "albert"] else None
159
+
160
+ outputs = model(**inputs)
161
+ loss, _loss2, logits = outputs[0], outputs[1], outputs[2]
162
+ loss2, loss3 = _loss2
163
+
164
+ if args.n_gpu > 1:
165
+ loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
166
+ loss2 = loss2.mean()
167
+ loss3 = loss3.mean()
168
+ if args.gradient_accumulation_steps > 1:
169
+ loss = loss / args.gradient_accumulation_steps
170
+ loss2 = loss2 / args.gradient_accumulation_steps
171
+ loss3 = loss3 / args.gradient_accumulation_steps
172
+
173
+ if args.fp16:
174
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
175
+ scaled_loss.backward()
176
+ else:
177
+ loss.backward()
178
+
179
+ tr_loss += loss.item()
180
+ tr_loss2 += loss2.item()
181
+ tr_loss3 += loss3.item()
182
+
183
+ all_loss += loss.detach().cpu().numpy() * args.gradient_accumulation_steps
184
+ all_accuracy += np.mean(
185
+ inputs["labels"].detach().cpu().numpy() == logits.detach().cpu().numpy().argmax(axis=-1)
186
+ )
187
+ description = "Global step: {:>6}, Loss: {:>.6f}, Accuracy: {:>.6f}".format(
188
+ global_step,
189
+ all_loss / (step + 1),
190
+ all_accuracy / (step + 1),
191
+ )
192
+ epoch_iterator.set_description(description)
193
+ if (step + 1) % args.gradient_accumulation_steps == 0:
194
+ if args.fp16:
195
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
196
+ else:
197
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
198
+
199
+ optimizer.step()
200
+ scheduler.step() # Update learning rate schedule
201
+ model.zero_grad()
202
+ global_step += 1
203
+
204
+ # Log metrics
205
+ if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
206
+ # Only evaluate when single GPU otherwise metrics may not average well
207
+ if args.local_rank == -1 and args.evaluate_during_training:
208
+ results = evaluate(args, data_processor, model, tokenizer)
209
+ for key, value in results.items():
210
+ logger.warning(f"Step: {global_step}, eval_{key}: {value}")
211
+ wdblogger.log_metrics({"eval_{}".format(key): value}, global_step)
212
+ tb_writer.add_scalar("eval_{}".format(key), value, global_step)
213
+ tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
214
+ wdblogger.log_metrics({"lr": scheduler.get_lr()[0]}, global_step)
215
+ tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
216
+ wdblogger.log_metrics({"loss": (tr_loss - logging_loss) / args.logging_steps}, global_step)
217
+ wdblogger.log_metrics({"loss2": (tr_loss2 - logging_loss2) / args.logging_steps}, global_step)
218
+ wdblogger.log_metrics({"loss3": (tr_loss3 - logging_loss3) / args.logging_steps}, global_step)
219
+
220
+ logging_loss = tr_loss
221
+ logging_loss2 = tr_loss2
222
+ logging_loss3 = tr_loss3
223
+ wdblogger.save()
224
+
225
+ # Save model checkpoint
226
+ if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
227
+ output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
228
+ if not os.path.exists(output_dir):
229
+ os.makedirs(output_dir)
230
+ # Take care of distributed/parallel training
231
+ model_to_save = model.module if hasattr(model, "module") else model
232
+ model_to_save.save_pretrained(output_dir)
233
+ torch.save(args, os.path.join(output_dir, "training_args.bin"))
234
+ logger.info("Saving model checkpoint to %s", output_dir)
235
+
236
+ if 0 < args.max_steps < global_step:
237
+ epoch_iterator.close()
238
+ break
239
+ if 0 < args.max_steps < global_step:
240
+ break
241
+
242
+ if args.local_rank in [-1, 0]:
243
+ tb_writer.close()
244
+
245
+ return global_step, tr_loss / global_step
246
+
247
+
248
+ def evaluate(args, data_processor, model, tokenizer, prefix=""):
249
+ if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
250
+ os.makedirs(args.output_dir)
251
+
252
+ args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
253
+ dataset = data_processor.load_and_cache_data("eval", tokenizer, args.data_tag)
254
+ eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
255
+ eval_dataloader = DataLoader(dataset, sampler=eval_sampler,
256
+ drop_last=True,
257
+ batch_size=args.eval_batch_size)
258
+
259
+ # Eval!
260
+ logger.info("***** Running evaluation {} *****".format(prefix))
261
+ logger.info("Num examples = %d", len(dataset))
262
+ logger.info("Batch size = %d", args.eval_batch_size)
263
+
264
+ label_truth, y_predicted, z_predicted, m_attn, mask = \
265
+ do_evaluate(tqdm(eval_dataloader, desc="Evaluating"), model, args, during_training=True, with_label=True)
266
+
267
+ outputs, results = compute_metrics(label_truth, y_predicted, z_predicted, mask)
268
+
269
+ return results
270
+
271
+
272
+ def do_evaluate(dataloader, model, args, during_training=False, with_label=True):
273
+ label_truth = []
274
+ y_predicted = []
275
+ z_predicted = []
276
+ m_attn = []
277
+ mask = []
278
+ for i, batch in enumerate(dataloader):
279
+ model.eval()
280
+ batch = tuple(t.to(args.device) for t in batch)
281
+ with torch.no_grad():
282
+ inputs = {
283
+ "claim_input_ids": batch[0],
284
+ "claim_attention_mask": batch[1],
285
+ "qa_input_ids_list": batch[3],
286
+ "qa_attention_mask_list": batch[4],
287
+ "nli_labels": batch[6],
288
+ }
289
+
290
+ if args.model_type != "distilbert":
291
+ # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
292
+ inputs["claim_token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
293
+ inputs["qa_token_type_ids_list"] = batch[5] if args.model_type in ["bert", "xlnet", "albert"] else None
294
+
295
+ outputs = model(**inputs)
296
+
297
+ if during_training and (i < 3 and (args.logic_lambda != 0)):
298
+ logger.warning(f'* m_attn:\n {outputs[-2][:5]}\n')
299
+ logger.warning(f'* Logic outputs:\n {outputs[-1][0][:5]}.\n Labels: {batch[-1][:5]}\n')
300
+
301
+ if with_label:
302
+ label_truth += batch[-1].tolist()
303
+ y_predicted += outputs[2].tolist()
304
+ mask += outputs[-1][1].tolist()
305
+ z_predicted += outputs[-1][0].tolist()
306
+ m_attn += outputs[-2].tolist()
307
+
308
+ y_predicted = np.argmax(y_predicted, axis=-1).tolist()
309
+
310
+ return label_truth, y_predicted, z_predicted, m_attn, mask
311
+
312
+
313
+ def main():
314
+ parser = argparse.ArgumentParser()
315
+
316
+ # Required parameters
317
+ parser.add_argument(
318
+ "--data_dir",
319
+ default=None,
320
+ type=str,
321
+ required=True,
322
+ help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
323
+ )
324
+ parser.add_argument(
325
+ "--model_type",
326
+ default=None,
327
+ type=str,
328
+ required=True,
329
+ help="Model type selected in the list: " + ", ".join(mAutoModel.keys()),
330
+ )
331
+ parser.add_argument(
332
+ "--model_name_or_path",
333
+ default=None,
334
+ type=str,
335
+ required=True,
336
+ help="Path to pre-trained model or shortcut name",
337
+ )
338
+ parser.add_argument(
339
+ "--data_tag",
340
+ default='default',
341
+ type=str,
342
+ help='Tag to cached data'
343
+ )
344
+ parser.add_argument(
345
+ "--max_seq1_length",
346
+ default=None,
347
+ type=int,
348
+ required=True,
349
+ help="The maximum total input claim sequence length after tokenization. "
350
+ "Sequences longer than this will be truncated, sequences shorter will be padded.",
351
+ )
352
+ parser.add_argument(
353
+ "--max_seq2_length",
354
+ default=None,
355
+ type=int,
356
+ required=True,
357
+ help="The maximum total input claim sequence length after tokenization. "
358
+ "Sequences longer than this will be truncated, sequences shorter will be padded.",
359
+ )
360
+ parser.add_argument(
361
+ "--max_num_questions",
362
+ default=None,
363
+ type=int,
364
+ required=True,
365
+ help='The maximum number of evidences.',
366
+ )
367
+ parser.add_argument(
368
+ "--cand_k",
369
+ default=1,
370
+ type=int,
371
+ help='The number of evidential answers out of beam size'
372
+ )
373
+ parser.add_argument(
374
+ '--mask_rate',
375
+ default=0.,
376
+ type=float,
377
+ help="Mask rate of QA"
378
+ )
379
+ parser.add_argument(
380
+ "--output_dir",
381
+ default=None,
382
+ type=str,
383
+ required=True,
384
+ help="The output directory where the model predictions and checkpoints will be written.",
385
+ )
386
+
387
+ # Other parameters
388
+ parser.add_argument(
389
+ "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name",
390
+ )
391
+ parser.add_argument(
392
+ "--tokenizer_name",
393
+ default="",
394
+ type=str,
395
+ help="Pretrained tokenizer name or path if not the same as model_name",
396
+ )
397
+ parser.add_argument(
398
+ "--cache_dir",
399
+ default="",
400
+ type=str,
401
+ help="Where do you want to store the pre-trained models downloaded from s3",
402
+ )
403
+ parser.add_argument(
404
+ "--max_seq_length",
405
+ default=128,
406
+ type=int,
407
+ help="The maximum total input sequence length after tokenization. Sequences longer "
408
+ "than this will be truncated, sequences shorter will be padded.",
409
+ )
410
+ parser.add_argument('--logic_lambda', required=True, type=float,
411
+ help='Regularization term for logic loss, also an indicator for using only logic.')
412
+ parser.add_argument('--prior', default='nli', type=str, choices=['nli', 'uniform', 'logic', 'random'],
413
+ help='type of prior distribution')
414
+ parser.add_argument('--temperature', required=True, type=float, help='Temperature for gumbel softmax.')
415
+
416
+ parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
417
+ parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
418
+ parser.add_argument(
419
+ "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step.",
420
+ )
421
+ parser.add_argument(
422
+ "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.",
423
+ )
424
+ parser.add_argument(
425
+ "--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",
426
+ )
427
+ parser.add_argument(
428
+ "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.",
429
+ )
430
+ parser.add_argument(
431
+ "--gradient_accumulation_steps",
432
+ type=int,
433
+ default=1,
434
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
435
+ )
436
+ parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
437
+ parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
438
+ parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
439
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
440
+ parser.add_argument(
441
+ "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.",
442
+ )
443
+ parser.add_argument(
444
+ "--max_steps",
445
+ default=-1,
446
+ type=int,
447
+ help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
448
+ )
449
+ parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
450
+ parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
451
+ parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
452
+ parser.add_argument(
453
+ "--eval_all_checkpoints",
454
+ action="store_true",
455
+ help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
456
+ )
457
+ parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
458
+ parser.add_argument(
459
+ "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory",
460
+ )
461
+ parser.add_argument(
462
+ "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets",
463
+ )
464
+ parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
465
+ parser.add_argument(
466
+ "--fp16",
467
+ action="store_true",
468
+ help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
469
+ )
470
+ parser.add_argument(
471
+ "--fp16_opt_level",
472
+ type=str,
473
+ default="O1",
474
+ help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
475
+ "See details at https://nvidia.github.io/apex/amp.html",
476
+ )
477
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
478
+ parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
479
+ parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
480
+ args = parser.parse_args()
481
+
482
+ if (
483
+ os.path.exists(args.output_dir)
484
+ and os.listdir(args.output_dir)
485
+ and args.do_train
486
+ and not args.overwrite_output_dir
487
+ ):
488
+ raise ValueError(
489
+ "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
490
+ args.output_dir
491
+ )
492
+ )
493
+
494
+ # Setup distant debugging if needed
495
+ if args.server_ip and args.server_port:
496
+ # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
497
+ import ptvsd
498
+
499
+ print("Waiting for debugger attach")
500
+ ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
501
+ ptvsd.wait_for_attach()
502
+
503
+ # Setup CUDA, GPU & distributed training
504
+ if args.local_rank == -1 or args.no_cuda:
505
+ device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
506
+ args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
507
+ else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
508
+ torch.cuda.set_device(args.local_rank)
509
+ device = torch.device("cuda", args.local_rank)
510
+ torch.distributed.init_process_group(backend="nccl")
511
+ args.n_gpu = 1
512
+ args.device = device
513
+
514
+ # Setup logging
515
+ if args.do_train:
516
+ global wdblogger
517
+ tf.io.gfile.makedirs(args.output_dir)
518
+ wdblogger = WandbLogger(name=os.path.basename(args.output_dir))
519
+ wdblogger.log_hyperparams(args)
520
+ wdblogger.save()
521
+ log_file = os.path.join(args.output_dir, 'train.log')
522
+ init_logger(logging.INFO if args.local_rank in [-1, 0] else logging.WARN, log_file)
523
+
524
+ logger.warning(
525
+ "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
526
+ args.local_rank,
527
+ device,
528
+ args.n_gpu,
529
+ bool(args.local_rank != -1),
530
+ args.fp16,
531
+ )
532
+
533
+ # Set seed
534
+ set_seed(args)
535
+
536
+ # Prepare task
537
+ data_processor = DataProcessor(
538
+ args.model_name_or_path,
539
+ args.max_seq1_length,
540
+ args.max_seq2_length,
541
+ args.max_num_questions,
542
+ args.cand_k,
543
+ data_dir=args.data_dir,
544
+ cache_dir_name=os.path.basename(args.output_dir),
545
+ overwrite_cache=args.overwrite_cache,
546
+ mask_rate=args.mask_rate
547
+ )
548
+
549
+ # Make sure only the first process in distributed training will download model & vocab
550
+ if args.local_rank not in [-1, 0]:
551
+ torch.distributed.barrier()
552
+
553
+ # Load pretrained model and tokenizer
554
+ args.model_type = args.model_type.lower()
555
+
556
+ config = AutoConfig.from_pretrained(
557
+ args.config_name if args.config_name else args.model_name_or_path,
558
+ num_labels=3,
559
+ cache_dir=args.cache_dir if args.cache_dir else None,
560
+ )
561
+ tokenizer = AutoTokenizer.from_pretrained(
562
+ args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
563
+ do_lower_case=args.do_lower_case,
564
+ cache_dir=args.cache_dir if args.cache_dir else None,
565
+ )
566
+ model = mAutoModel[args.model_type].from_pretrained(
567
+ args.model_name_or_path,
568
+ from_tf=bool(".ckpt" in args.model_name_or_path),
569
+ config=config,
570
+ cache_dir=args.cache_dir if args.cache_dir else None,
571
+ logic_lambda=args.logic_lambda,
572
+ m=args.max_num_questions,
573
+ prior=args.prior,
574
+ temperature=args.temperature
575
+ )
576
+
577
+ # Make sure only the first process in distributed training will download model & vocab
578
+ if args.local_rank == 0:
579
+ torch.distributed.barrier()
580
+
581
+ if args.do_train:
582
+ model.to(args.device)
583
+ wdblogger.watch(model)
584
+
585
+ logger.info("Training/evaluation parameters %s", args)
586
+
587
+ # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum
588
+ # if args.fp16 is set. Otherwise it'll default to "promote" mode, and we'll get fp32 operations.
589
+ # Note that running `--fp16_opt_level="O2"` will remove the need for this code, but it is still valid.
590
+ if args.fp16:
591
+ try:
592
+ import apex
593
+ apex.amp.register_half_function(torch, "einsum")
594
+ except ImportError:
595
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
596
+
597
+ # Training
598
+ if args.do_train:
599
+ global_step, tr_loss = train(args, data_processor, model, tokenizer)
600
+ logger.info("global_step = %s, average loss = %s", global_step, tr_loss)
601
+
602
+ # Save the trained model and the tokenizer
603
+ if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
604
+ logger.info("Saving model checkpoint to %s", args.output_dir)
605
+ # Save a trained model, configuration and tokenizer using `save_pretrained()`.
606
+ # They can then be reloaded using `from_pretrained()`
607
+ # Take care of distributed/parallel training
608
+ model_to_save = model.module if hasattr(model, "module") else model
609
+ model_to_save.save_pretrained(args.output_dir)
610
+ tokenizer.save_pretrained(args.output_dir)
611
+
612
+ # Good practice: save your training arguments together with the trained model
613
+ torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
614
+
615
+ # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
616
+ results = {}
617
+ if args.do_eval and args.local_rank in [-1, 0]:
618
+ checkpoints = [args.output_dir]
619
+ if args.eval_all_checkpoints:
620
+ checkpoints = list(
621
+ os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
622
+ )
623
+ logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
624
+
625
+ logger.info("Evaluate the following checkpoints: %s", checkpoints)
626
+ for checkpoint in checkpoints:
627
+ global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
628
+ model = mAutoModel[args.model_type].from_pretrained(
629
+ checkpoint,
630
+ logic_lambda=args.logic_lambda,
631
+ m=args.max_num_questions,
632
+ prior=args.prior,
633
+ temperature=args.temperature
634
+ )
635
+ model.to(args.device)
636
+
637
+ # Evaluate
638
+ result = evaluate(args, data_processor, model, tokenizer, prefix=global_step)
639
+ result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
640
+ results.update(result)
641
+
642
+ print(results)
643
+ return results
644
+
645
+
646
+ if __name__ == "__main__":
647
+ main()
src/check_client/utils.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ @Author : Bao
5
+ @Date : 2020/8/12
6
+ @Desc :
7
+ @Last modified by : Bao
8
+ @Last modified date : 2020/8/12
9
+ """
10
+
11
+ import logging
12
+ from numpy.core.fromnumeric import argmax
13
+ import ujson as json
14
+ import torch
15
+ from plm_checkers.checker_utils import soft_logic
16
+
17
+
18
+ def init_logger(level, filename=None, mode='a', encoding='utf-8'):
19
+ logging_config = {
20
+ 'format': '%(asctime)s - %(levelname)s - %(name)s:\t%(message)s',
21
+ 'datefmt': '%Y-%m-%d %H:%M:%S',
22
+ 'level': level,
23
+ 'handlers': [logging.StreamHandler()]
24
+ }
25
+ if filename:
26
+ logging_config['handlers'].append(logging.FileHandler(filename, mode, encoding))
27
+ logging.basicConfig(**logging_config)
28
+
29
+
30
+ def read_json(filename, mode='r', encoding='utf-8'):
31
+ with open(filename, mode, encoding=encoding) as fin:
32
+ return json.load(fin)
33
+
34
+
35
+ def save_json(data, filename, mode='w', encoding='utf-8'):
36
+ with open(filename, mode, encoding=encoding) as fout:
37
+ json.dump(data, fout, ensure_ascii=False, indent=4)
38
+
39
+
40
+ def read_json_lines(filename, mode='r', encoding='utf-8', skip=0):
41
+ with open(filename, mode, encoding=encoding) as fin:
42
+ for line in fin:
43
+ if skip > 0:
44
+ skip -= 1
45
+ continue
46
+ yield json.loads(line)
47
+
48
+
49
+ def save_json_lines(data, filename, mode='w', encoding='utf-8', skip=0):
50
+ with open(filename, mode, encoding=encoding) as fout:
51
+ for line in data:
52
+ if skip > 0:
53
+ skip -= 1
54
+ continue
55
+ print(json.dumps(line, ensure_ascii=False), file=fout)
56
+
57
+
58
+ def read_json_dict(filename, mode='r', encoding='utf-8'):
59
+ with open(filename, mode, encoding=encoding) as fin:
60
+ key_2_id = json.load(fin)
61
+ id_2_key = dict(zip(key_2_id.values(), key_2_id.keys()))
62
+
63
+ return key_2_id, id_2_key
64
+
65
+
66
+ def save_json_dict(data, filename, mode='w', encoding='utf-8'):
67
+ with open(filename, mode, encoding=encoding) as fout:
68
+ json.dump(data, fout, ensure_ascii=False, indent=4)
69
+
70
+
71
+ # Calculate precision, recall and f1 value
72
+ # According to https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure
73
+ def get_prf(res):
74
+ if res['TP'] == 0:
75
+ if res['FP'] == 0 and res['FN'] == 0:
76
+ precision = 1.0
77
+ recall = 1.0
78
+ f1 = 1.0
79
+ else:
80
+ precision = 0.0
81
+ recall = 0.0
82
+ f1 = 0.0
83
+ else:
84
+ precision = 1.0 * res['TP'] / (res['TP'] + res['FP'])
85
+ recall = 1.0 * res['TP'] / (res['TP'] + res['FN'])
86
+ f1 = 2 * precision * recall / (precision + recall)
87
+
88
+ return precision, recall, f1
89
+
90
+
91
+ def compute_metrics(truth, predicted, z_predicted, mask):
92
+ assert len(truth) == len(predicted)
93
+
94
+ outputs = []
95
+ results = {}
96
+ cnt = 0
97
+ z_cnt_h, z_cnt_s = 0, 0
98
+ agree_h, agree_s = 0, 0
99
+ for x, y, z, m in zip(truth, predicted, z_predicted, mask):
100
+ res = {'label': x, 'prediction': y}
101
+ if x == y:
102
+ cnt += 1
103
+
104
+ res['pred_z'] = z
105
+
106
+ y_ = soft_logic(torch.tensor([z]), torch.tensor([m]))[0]
107
+ if y_.argmax(-1).item() == x:
108
+ z_cnt_s += 1
109
+ if y_.argmax(-1).item() == y:
110
+ agree_s += 1
111
+
112
+ z_h = torch.tensor(z[:torch.tensor(m).sum()]).argmax(-1).tolist() # m' x 3
113
+ if 0 in z_h: # REFUTES
114
+ y__ = 0
115
+ elif 1 in z_h: # NEI
116
+ y__ = 1
117
+ else: # SUPPPORTS
118
+ y__ = 2
119
+ if y__ == x:
120
+ z_cnt_h += 1
121
+ if y__ == y:
122
+ agree_h += 1
123
+
124
+ outputs.append(res)
125
+
126
+ results['Accuracy'] = cnt / len(truth)
127
+ results['z_Acc_hard'] = z_cnt_h / len(truth)
128
+ results['z_Acc_soft'] = z_cnt_s / len(truth)
129
+ results['Agreement_hard'] = agree_h / len(truth)
130
+ results['Agreement_soft'] = agree_s / len(truth)
131
+ return outputs, results
src/cjjpy.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ @Author : Jiangjie Chen
5
+ @Time : 2018/11/15 17:08
6
+ @Contact: jjchen19@fudan.edu.cn
7
+ '''
8
+
9
+ import re
10
+ import datetime
11
+ import os
12
+ import argparse
13
+ import logging
14
+ import traceback
15
+
16
+ try:
17
+ import ujson as json
18
+ except:
19
+ import json
20
+
21
+ HADOOP_BIN = 'PATH=/usr/bin/:$PATH hdfs'
22
+ FOR_PUBLIC = True
23
+
24
+
25
+ def LengthStats(filename):
26
+ len_list = []
27
+ thresholds = [0.8, 0.9, 0.95, 0.99, 0.999]
28
+ with open(filename) as f:
29
+ for line in f:
30
+ len_list.append(len(line.strip().split()))
31
+ stats = {
32
+ 'Max': max(len_list),
33
+ 'Min': min(len_list),
34
+ 'Avg': round(sum(len_list) / len(len_list), 4),
35
+ }
36
+ len_list.sort()
37
+ for t in thresholds:
38
+ stats[f"Top-{t}"] = len_list[int(len(len_list) * t)]
39
+
40
+ for k in stats:
41
+ print(f"- {k}: {stats[k]}")
42
+ return stats
43
+
44
+
45
+ class AttrDict(dict):
46
+ def __init__(self, *args, **kwargs):
47
+ super(AttrDict, self).__init__(*args, **kwargs)
48
+ self.__dict__ = self
49
+
50
+
51
+ def TraceBack(error_msg):
52
+ exc = traceback.format_exc()
53
+ msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}'
54
+ return msg
55
+
56
+
57
+ def Now():
58
+ return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
59
+
60
+
61
+ def AbsParentDir(file, parent='..', postfix=None):
62
+ ppath = os.path.abspath(file)
63
+ parent_level = parent.count('.')
64
+ while parent_level > 0:
65
+ ppath = os.path.dirname(ppath)
66
+ parent_level -= 1
67
+ if postfix is not None:
68
+ return os.path.join(ppath, postfix)
69
+ else:
70
+ return ppath
71
+
72
+
73
+ def init_logger(log_file=None, log_file_level=logging.NOTSET, from_scratch=False):
74
+ from coloredlogs import ColoredFormatter
75
+ import tensorflow as tf
76
+
77
+ fmt = "[%(asctime)s %(levelname)s] %(message)s"
78
+ log_format = ColoredFormatter(fmt=fmt)
79
+ # log_format = logging.Formatter()
80
+ logger = logging.getLogger()
81
+ logger.setLevel(log_file_level)
82
+
83
+ console_handler = logging.StreamHandler()
84
+ console_handler.setFormatter(log_format)
85
+ logger.handlers = [console_handler]
86
+
87
+ if log_file and log_file != '':
88
+ if from_scratch and tf.io.gfile.exists(log_file):
89
+ logger.warning('Removing previous log file: %s' % log_file)
90
+ tf.io.gfile.remove(log_file)
91
+ path = os.path.dirname(log_file)
92
+ os.makedirs(path, exist_ok=True)
93
+ file_handler = logging.FileHandler(log_file)
94
+ file_handler.setLevel(log_file_level)
95
+ file_handler.setFormatter(log_format)
96
+ logger.addHandler(file_handler)
97
+
98
+ return logger
99
+
100
+
101
+ def OverWriteCjjPy(root='.'):
102
+ # import difflib
103
+ # diff = difflib.HtmlDiff()
104
+ cnt = 0
105
+ golden_cjjpy = os.path.join(root, 'cjjpy.py')
106
+ # golden_content = open(golden_cjjpy).readlines()
107
+ for dir, folder, file in os.walk(root):
108
+ for f in file:
109
+ if f == 'cjjpy.py':
110
+ cjjpy = '%s/%s' % (dir, f)
111
+ # content = open(cjjpy).readlines()
112
+ # d = diff.make_file(golden_content, content)
113
+ cnt += 1
114
+ print('[%d]: %s' % (cnt, cjjpy))
115
+ os.system('cp %s %s' % (golden_cjjpy, cjjpy))
116
+
117
+
118
+ def ChangeFileFormat(filename, new_fmt):
119
+ assert type(filename) is str and type(new_fmt) is str
120
+ spt = filename.split('.')
121
+ if len(spt) == 0:
122
+ return filename
123
+ else:
124
+ return filename.replace('.' + spt[-1], new_fmt)
125
+
126
+
127
+ def CountLines(fname):
128
+ with open(fname, 'rb') as f:
129
+ count = 0
130
+ last_data = '\n'
131
+ while True:
132
+ data = f.read(0x400000)
133
+ if not data:
134
+ break
135
+ count += data.count(b'\n')
136
+ last_data = data
137
+ if last_data[-1:] != b'\n':
138
+ count += 1 # Remove this if a wc-like count is needed
139
+ return count
140
+
141
+
142
+ def GetDate():
143
+ return str(datetime.datetime.now())[5:10].replace('-', '')
144
+
145
+
146
+ def TimeClock(seconds):
147
+ sec = int(seconds)
148
+ hour = int(sec / 3600)
149
+ min = int((sec - hour * 3600) / 60)
150
+ ssec = float(seconds) - hour * 3600 - min * 60
151
+ # return '%dh %dm %.2fs' % (hour, min, ssec)
152
+ return '{0:>2d}h{1:>3d}m{2:>6.2f}s'.format(hour, min, ssec)
153
+
154
+
155
+ def StripAll(text):
156
+ return text.strip().replace('\t', '').replace('\n', '').replace(' ', '')
157
+
158
+
159
+ def GetBracket(text, bracket, en_br=False):
160
+ # input should be aa(bb)cc, True for bracket, False for text
161
+ if bracket:
162
+ try:
163
+ return re.findall('\((.*?)\)', text.strip())[-1]
164
+ except:
165
+ return ''
166
+ else:
167
+ if en_br:
168
+ text = re.sub('\(.*?\)', '', text.strip())
169
+ return re.sub('(.*?)', '', text.strip())
170
+
171
+
172
+ def CharLang(uchar, lang):
173
+ assert lang.lower() in ['en', 'cn', 'zh']
174
+ if lang.lower() in ['cn', 'zh']:
175
+ if uchar >= '\u4e00' and uchar <= '\u9fa5':
176
+ return True
177
+ else:
178
+ return False
179
+ elif lang.lower() == 'en':
180
+ if (uchar <= 'Z' and uchar >= 'A') or (uchar <= 'z' and uchar >= 'a'):
181
+ return True
182
+ else:
183
+ return False
184
+ else:
185
+ raise NotImplementedError
186
+
187
+
188
+ def WordLang(word, lang):
189
+ for i in word.strip():
190
+ if i.isspace(): continue
191
+ if not CharLang(i, lang):
192
+ return False
193
+ return True
194
+
195
+
196
+ def SortDict(_dict, reverse=True):
197
+ assert type(_dict) is dict
198
+ return sorted(_dict.items(), key=lambda d: d[1], reverse=reverse)
199
+
200
+
201
+ def lark(content='test'):
202
+ print(content)
203
+
204
+
205
+ if __name__ == '__main__':
206
+ parser = argparse.ArgumentParser()
207
+
208
+ parser.add_argument('--diff', nargs=2,
209
+ help='show difference between two files, shown in downloads/diff.html')
210
+ parser.add_argument('--de_unicode', action='store_true', default=False,
211
+ help='remove unicode characters')
212
+ parser.add_argument('--link_entity', action='store_true', default=False,
213
+ help='')
214
+ parser.add_argument('--max_comm_len', action='store_true', default=False,
215
+ help='')
216
+ parser.add_argument('--search', nargs=2,
217
+ help='search key from file, 2 args: file name & key')
218
+ parser.add_argument('--email', nargs=2,
219
+ help='sending emails, 2 args: subject & content')
220
+ parser.add_argument('--overwrite', action='store_true', default=None,
221
+ help='overwrite all cjjpy under given *dir* based on *dir*/cjjpy.py')
222
+ parser.add_argument('--replace', nargs=3,
223
+ help='replace char, 3 args: file name & replaced char & replacer char')
224
+ parser.add_argument('--lark', nargs=1)
225
+ parser.add_argument('--get_hdfs', nargs=2,
226
+ help='easy copy from hdfs to local fs, 2 args: remote_file/dir & local_dir')
227
+ parser.add_argument('--put_hdfs', nargs=2,
228
+ help='easy put from local fs to hdfs, 2 args: local_file/dir & remote_dir')
229
+ parser.add_argument('--length_stats', nargs=1,
230
+ help='simple token lengths distribution of a line-by-line file')
231
+
232
+ args = parser.parse_args()
233
+
234
+ if args.overwrite:
235
+ print('* Overwriting cjjpy...')
236
+ OverWriteCjjPy()
237
+
238
+ if args.lark:
239
+ try:
240
+ content = args.lark[0]
241
+ except:
242
+ content = 'running complete'
243
+ print(f'* Larking "{content}"...')
244
+ lark(content)
245
+
246
+ if args.length_stats:
247
+ file = args.length_stats[0]
248
+ print(f'* Working on {file} lengths statistics...')
249
+ LengthStats(file)
src/dataloaders.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ @Author : Jiangjie Chen
5
+ @Time : 2020/7/20 17:34
6
+ @Contact : jjchen19@fudan.edu.cn
7
+ @Description:
8
+ '''
9
+
10
+ import tensorflow as tf
11
+ import cjjpy as cjj
12
+ import os
13
+ import re
14
+ import ujson as json
15
+ from collections import defaultdict
16
+
17
+ pj_prefix = cjj.AbsParentDir(__file__, '..')
18
+
19
+
20
+ class FEVERLoader:
21
+ def __init__(self, role):
22
+ role = 'dev' if role == 'val' else role
23
+ assert role in ['train', 'dev', 'test', 'eval']
24
+ self.role = role
25
+ self.fever_data = defaultdict(dict)
26
+ self.SUPPORTS = 'SUPPORTS'
27
+ self.REFUTES = 'REFUTES'
28
+ self.NEI = 'NOT ENOUGH INFO'
29
+
30
+ def __iter__(self):
31
+ for k in self.fever_data:
32
+ yield k
33
+
34
+ def __len__(self):
35
+ return len(self.fever_data)
36
+
37
+ def __getitem__(self, item):
38
+ return self.fever_data[item]
39
+
40
+ def load_fever(self, retrieve_type='bert', clean_load=True):
41
+ self._load_fever_golden()
42
+ self._load_fever_all()
43
+ self._load_fever_retrieved(retrieve_type, clean_load)
44
+
45
+ def _load_json(self, fname):
46
+ with tf.io.gfile.GFile(fname) as f:
47
+ return [json.loads(x) for x in f.readlines()]
48
+
49
+ def _new_role(self):
50
+ role = self.role if self.role != 'eval' else 'dev'
51
+ return role
52
+
53
+ def _load_fever_golden(self):
54
+ if self.role == 'test':
55
+ postfix = f'data/fever/shared_task_test.jsonl'
56
+ for js in self._load_json(f'{pj_prefix}/{postfix}'):
57
+ self.fever_data[js['id']].update({
58
+ 'id': js['id'],
59
+ 'claim': js['claim']
60
+ })
61
+ else:
62
+ role = self._new_role()
63
+ postfix = f'data/fever/baked_data/golden_{role}.json'
64
+ for js in self._load_json(f'{pj_prefix}/{postfix}'):
65
+ self.fever_data[js['id']].update({
66
+ 'id': js['id'],
67
+ 'claim': js['claim'],
68
+ 'label': js['label'],
69
+ 'golden_evidence': self._clean_evidence(js['evidence'])
70
+ })
71
+ print('* FEVER golden loaded.')
72
+
73
+ def _load_fever_all(self):
74
+ role = self._new_role()
75
+ postfix = f'data/fever/baked_data/all_{role}.json'
76
+ for js in self._load_json(f'{pj_prefix}/{postfix}'):
77
+ self.fever_data[js['id']].update({
78
+ 'all_evidence': self._clean_evidence(js['evidence'])
79
+ })
80
+ print('* FEVER all loaded.')
81
+
82
+ def _load_fever_retrieved(self, retrieve_type, clean_load):
83
+ assert retrieve_type in ['bert']
84
+ postfix = f'data/fever/baked_data/{retrieve_type}_{self.role}.json'
85
+ for js in self._load_json(f'{pj_prefix}/{postfix}'):
86
+ self.fever_data[js['id']].update({
87
+ f'{retrieve_type}_evidence': self._clean_evidence(js['evidence']) if clean_load else js['evidence']
88
+ })
89
+ print(f'* FEVER {retrieve_type} loaded.')
90
+
91
+ def clean_text(self, sentence):
92
+ sentence = re.sub(" \-LSB\-.*?\-RSB\-", "", sentence)
93
+ sentence = re.sub("\-LRB\- \-RRB\- ", "", sentence)
94
+ sentence = re.sub(" -LRB-", " ( ", sentence)
95
+ sentence = re.sub("-RRB-", " )", sentence)
96
+
97
+ sentence = re.sub(" LSB.*?RSB", "", sentence)
98
+ sentence = re.sub("LRB RRB ", "", sentence)
99
+ sentence = re.sub("LRB", " ( ", sentence)
100
+ sentence = re.sub("RRB", " )", sentence)
101
+ sentence = re.sub("--", "-", sentence)
102
+ sentence = re.sub("``", '"', sentence)
103
+ sentence = re.sub("''", '"', sentence)
104
+ sentence = re.sub(' ', ' ', sentence)
105
+ return sentence
106
+
107
+ def clean_title(self, title):
108
+ title = re.sub("_", " ", title)
109
+ title = re.sub(" -LRB-", " ( ", title)
110
+ title = re.sub("-RRB-", " )", title)
111
+ title = re.sub("-COLON-", ":", title)
112
+ title = re.sub(' ', ' ', title)
113
+ return title
114
+
115
+ def _clean_evidence(self, evidence):
116
+ cev = []
117
+ for ev in evidence:
118
+ if len(ev) == 4:
119
+ cev.append([self.clean_title(ev[0]), ev[1], self.clean_text(ev[2]), ev[3]])
120
+ elif len(ev) == 3:
121
+ cev.append([self.clean_title(ev[0]), ev[1], self.clean_text(ev[2])])
122
+ elif len(ev) == 0:
123
+ cev.append(ev)
124
+ else:
125
+ raise ValueError(ev)
126
+ return cev
127
+
128
+
129
+ if __name__ == '__main__':
130
+ floader = FEVERLoader('test')
131
+ floader.load_fever('bert', clean_load=False)
132
+ for k in floader:
133
+ print(floader[k])
134
+ input()
src/er_client/__init__.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ @Author : Jiangjie Chen
5
+ @Time : 2020/9/21 16:13
6
+ @Contact : jjchen19@fudan.edu.cn
7
+ @Description:
8
+ '''
9
+
10
+ import cjjpy as cjj
11
+ import os
12
+ # from .document_retrieval import DocRetrieval
13
+ from .doc_retrieval_by_api import DocRetrieval
14
+ from .sentence_selection import SentSelector
15
+
16
+
17
+ arg_values = {
18
+ 'batch_size': 32,
19
+ 'dropout': 0.6,
20
+ 'use_cuda': True,
21
+ 'bert_hidden_dim': 768,
22
+ 'layer': 1,
23
+ 'num_labels': 3,
24
+ 'evi_num': 5,
25
+ 'threshold': 0.0,
26
+ 'max_len': 120,
27
+ }
28
+
29
+ args = cjj.AttrDict(arg_values)
30
+
31
+ class EvidenceRetrieval:
32
+ def __init__(self, er_model_dir=cjj.AbsParentDir(__file__, '...', 'models/evidence_retrieval/')):
33
+ # self.doc_retriever = DocRetrieval(cjj.AbsParentDir(__file__, '...', 'data/fever.db'),
34
+ # add_claim=True, k_wiki_results=7)
35
+ self.doc_retrieval = DocRetrieval(link_type='tagme')
36
+ self.sent_selector = SentSelector(os.path.join(er_model_dir, 'bert_base/'),
37
+ os.path.join(er_model_dir, 'retrieval_model/model.best.pt'),
38
+ args)
39
+
40
+ def retrieve(self, claim):
41
+ # noun_phrases, wiki_results, predicted_pages = self.doc_retriever.exact_match(claim)
42
+ # evidence = []
43
+ # for page in predicted_pages:
44
+ # evidence.extend(self.doc_retriever.db.get_doc_lines(page))
45
+ evidence = self.doc_retrieval.retrieve_docs(claim)
46
+ evidence = self.rank_sentences(claim, evidence)
47
+ return evidence
48
+
49
+ def rank_sentences(self, claim, sentences, id=None):
50
+ '''
51
+ :param claim: str
52
+ :param sentences: [(ent, num, sent) * N]
53
+ :param id:
54
+ :return: [(ent, num, sent) * k]
55
+ '''
56
+ if id is None:
57
+ id = len(claim)
58
+
59
+ result = self.sent_selector.rank_sentences([{'claim': claim,
60
+ 'evidence': sentences,
61
+ 'id': id}])
62
+ evidence = result.get(id, [])
63
+ return evidence
src/er_client/cjjpy.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ @Author : Jiangjie Chen
5
+ @Time : 2018/11/15 17:08
6
+ @Contact: jjchen19@fudan.edu.cn
7
+ '''
8
+
9
+ import re
10
+ import datetime
11
+ import os
12
+ import argparse
13
+ import logging
14
+ import traceback
15
+
16
+ try:
17
+ import ujson as json
18
+ except:
19
+ import json
20
+
21
+ HADOOP_BIN = 'PATH=/usr/bin/:$PATH hdfs'
22
+ FOR_PUBLIC = True
23
+
24
+
25
+ def LengthStats(filename):
26
+ len_list = []
27
+ thresholds = [0.8, 0.9, 0.95, 0.99, 0.999]
28
+ with open(filename) as f:
29
+ for line in f:
30
+ len_list.append(len(line.strip().split()))
31
+ stats = {
32
+ 'Max': max(len_list),
33
+ 'Min': min(len_list),
34
+ 'Avg': round(sum(len_list) / len(len_list), 4),
35
+ }
36
+ len_list.sort()
37
+ for t in thresholds:
38
+ stats[f"Top-{t}"] = len_list[int(len(len_list) * t)]
39
+
40
+ for k in stats:
41
+ print(f"- {k}: {stats[k]}")
42
+ return stats
43
+
44
+
45
+ class AttrDict(dict):
46
+ def __init__(self, *args, **kwargs):
47
+ super(AttrDict, self).__init__(*args, **kwargs)
48
+ self.__dict__ = self
49
+
50
+
51
+ def TraceBack(error_msg):
52
+ exc = traceback.format_exc()
53
+ msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}'
54
+ return msg
55
+
56
+
57
+ def Now():
58
+ return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
59
+
60
+
61
+ def AbsParentDir(file, parent='..', postfix=None):
62
+ ppath = os.path.abspath(file)
63
+ parent_level = parent.count('.')
64
+ while parent_level > 0:
65
+ ppath = os.path.dirname(ppath)
66
+ parent_level -= 1
67
+ if postfix is not None:
68
+ return os.path.join(ppath, postfix)
69
+ else:
70
+ return ppath
71
+
72
+
73
+ def init_logger(log_file=None, log_file_level=logging.NOTSET, from_scratch=False):
74
+ from coloredlogs import ColoredFormatter
75
+ import tensorflow as tf
76
+
77
+ fmt = "[%(asctime)s %(levelname)s] %(message)s"
78
+ log_format = ColoredFormatter(fmt=fmt)
79
+ # log_format = logging.Formatter()
80
+ logger = logging.getLogger()
81
+ logger.setLevel(log_file_level)
82
+
83
+ console_handler = logging.StreamHandler()
84
+ console_handler.setFormatter(log_format)
85
+ logger.handlers = [console_handler]
86
+
87
+ if log_file and log_file != '':
88
+ if from_scratch and tf.io.gfile.exists(log_file):
89
+ logger.warning('Removing previous log file: %s' % log_file)
90
+ tf.io.gfile.remove(log_file)
91
+ path = os.path.dirname(log_file)
92
+ os.makedirs(path, exist_ok=True)
93
+ file_handler = logging.FileHandler(log_file)
94
+ file_handler.setLevel(log_file_level)
95
+ file_handler.setFormatter(log_format)
96
+ logger.addHandler(file_handler)
97
+
98
+ return logger
99
+
100
+
101
+ def OverWriteCjjPy(root='.'):
102
+ # import difflib
103
+ # diff = difflib.HtmlDiff()
104
+ cnt = 0
105
+ golden_cjjpy = os.path.join(root, 'cjjpy.py')
106
+ # golden_content = open(golden_cjjpy).readlines()
107
+ for dir, folder, file in os.walk(root):
108
+ for f in file:
109
+ if f == 'cjjpy.py':
110
+ cjjpy = '%s/%s' % (dir, f)
111
+ # content = open(cjjpy).readlines()
112
+ # d = diff.make_file(golden_content, content)
113
+ cnt += 1
114
+ print('[%d]: %s' % (cnt, cjjpy))
115
+ os.system('cp %s %s' % (golden_cjjpy, cjjpy))
116
+
117
+
118
+ def ChangeFileFormat(filename, new_fmt):
119
+ assert type(filename) is str and type(new_fmt) is str
120
+ spt = filename.split('.')
121
+ if len(spt) == 0:
122
+ return filename
123
+ else:
124
+ return filename.replace('.' + spt[-1], new_fmt)
125
+
126
+
127
+ def CountLines(fname):
128
+ with open(fname, 'rb') as f:
129
+ count = 0
130
+ last_data = '\n'
131
+ while True:
132
+ data = f.read(0x400000)
133
+ if not data:
134
+ break
135
+ count += data.count(b'\n')
136
+ last_data = data
137
+ if last_data[-1:] != b'\n':
138
+ count += 1 # Remove this if a wc-like count is needed
139
+ return count
140
+
141
+
142
+ def GetDate():
143
+ return str(datetime.datetime.now())[5:10].replace('-', '')
144
+
145
+
146
+ def TimeClock(seconds):
147
+ sec = int(seconds)
148
+ hour = int(sec / 3600)
149
+ min = int((sec - hour * 3600) / 60)
150
+ ssec = float(seconds) - hour * 3600 - min * 60
151
+ # return '%dh %dm %.2fs' % (hour, min, ssec)
152
+ return '{0:>2d}h{1:>3d}m{2:>6.2f}s'.format(hour, min, ssec)
153
+
154
+
155
+ def StripAll(text):
156
+ return text.strip().replace('\t', '').replace('\n', '').replace(' ', '')
157
+
158
+
159
+ def GetBracket(text, bracket, en_br=False):
160
+ # input should be aa(bb)cc, True for bracket, False for text
161
+ if bracket:
162
+ try:
163
+ return re.findall('\((.*?)\)', text.strip())[-1]
164
+ except:
165
+ return ''
166
+ else:
167
+ if en_br:
168
+ text = re.sub('\(.*?\)', '', text.strip())
169
+ return re.sub('(.*?)', '', text.strip())
170
+
171
+
172
+ def CharLang(uchar, lang):
173
+ assert lang.lower() in ['en', 'cn', 'zh']
174
+ if lang.lower() in ['cn', 'zh']:
175
+ if uchar >= '\u4e00' and uchar <= '\u9fa5':
176
+ return True
177
+ else:
178
+ return False
179
+ elif lang.lower() == 'en':
180
+ if (uchar <= 'Z' and uchar >= 'A') or (uchar <= 'z' and uchar >= 'a'):
181
+ return True
182
+ else:
183
+ return False
184
+ else:
185
+ raise NotImplementedError
186
+
187
+
188
+ def WordLang(word, lang):
189
+ for i in word.strip():
190
+ if i.isspace(): continue
191
+ if not CharLang(i, lang):
192
+ return False
193
+ return True
194
+
195
+
196
+ def SortDict(_dict, reverse=True):
197
+ assert type(_dict) is dict
198
+ return sorted(_dict.items(), key=lambda d: d[1], reverse=reverse)
199
+
200
+
201
+ def lark(content='test'):
202
+ print(content)
203
+
204
+
205
+ if __name__ == '__main__':
206
+ parser = argparse.ArgumentParser()
207
+
208
+ parser.add_argument('--diff', nargs=2,
209
+ help='show difference between two files, shown in downloads/diff.html')
210
+ parser.add_argument('--de_unicode', action='store_true', default=False,
211
+ help='remove unicode characters')
212
+ parser.add_argument('--link_entity', action='store_true', default=False,
213
+ help='')
214
+ parser.add_argument('--max_comm_len', action='store_true', default=False,
215
+ help='')
216
+ parser.add_argument('--search', nargs=2,
217
+ help='search key from file, 2 args: file name & key')
218
+ parser.add_argument('--email', nargs=2,
219
+ help='sending emails, 2 args: subject & content')
220
+ parser.add_argument('--overwrite', action='store_true', default=None,
221
+ help='overwrite all cjjpy under given *dir* based on *dir*/cjjpy.py')
222
+ parser.add_argument('--replace', nargs=3,
223
+ help='replace char, 3 args: file name & replaced char & replacer char')
224
+ parser.add_argument('--lark', nargs=1)
225
+ parser.add_argument('--get_hdfs', nargs=2,
226
+ help='easy copy from hdfs to local fs, 2 args: remote_file/dir & local_dir')
227
+ parser.add_argument('--put_hdfs', nargs=2,
228
+ help='easy put from local fs to hdfs, 2 args: local_file/dir & remote_dir')
229
+ parser.add_argument('--length_stats', nargs=1,
230
+ help='simple token lengths distribution of a line-by-line file')
231
+
232
+ args = parser.parse_args()
233
+
234
+ if args.overwrite:
235
+ print('* Overwriting cjjpy...')
236
+ OverWriteCjjPy()
237
+
238
+ if args.lark:
239
+ try:
240
+ content = args.lark[0]
241
+ except:
242
+ content = 'running complete'
243
+ print(f'* Larking "{content}"...')
244
+ lark(content)
245
+
246
+ if args.length_stats:
247
+ file = args.length_stats[0]
248
+ print(f'* Working on {file} lengths statistics...')
249
+ LengthStats(file)
src/er_client/doc_retrieval_by_api.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ @Author : Jiangjie Chen
5
+ @Time : 2020/11/12 21:19
6
+ @Contact : jjchen19@fudan.edu.cn
7
+ @Description:
8
+ '''
9
+
10
+ import wikipediaapi
11
+ import nltk
12
+ from nltk.tokenize import sent_tokenize
13
+ nltk.download('punkt')
14
+ try:
15
+ from entitylinker import ELClient
16
+ except:
17
+ from .entitylinker import ELClient
18
+
19
+
20
+ class DocRetrieval:
21
+ def __init__(self, link_type):
22
+ self.wiki = wikipediaapi.Wikipedia('en')
23
+ self.er_client = ELClient(link_type, verbose=True)
24
+
25
+ def _get_page(self, title):
26
+ summary = self.wiki.page(title).summary
27
+ sents = []
28
+ for i, sent in enumerate(sent_tokenize(summary)):
29
+ sents.append((title, i, sent, 0))
30
+ return sents
31
+
32
+ def retrieve_docs(self, claim):
33
+ el_results = self.er_client.link(claim)
34
+ sents = []
35
+ for text, label, kb_id, title in el_results:
36
+ if title == '': continue
37
+ sents += self._get_page(title)
38
+ return sents
39
+
40
+
41
+ if __name__ == '__main__':
42
+ doc = DocRetrieval('tagme')
43
+ print(doc.retrieve_docs('joe biden won the U.S. president.'))
44
+ print(doc.retrieve_docs('Joe Biden won the U.S. president.'))
src/er_client/document_retrieval.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding:utf-8 -*-
2
+
3
+ """
4
+ @Author : Bao
5
+ @Date : 2020/9/17
6
+ @Desc : Document selection and sentence ranking code from KGAT. Not used in LOREN.
7
+ @Last modified by : Bao
8
+ @Last modified date : 2020/9/17
9
+ """
10
+
11
+ import re
12
+ import time
13
+ import json
14
+ import nltk
15
+ from tqdm import tqdm
16
+ from allennlp.predictors import Predictor
17
+ from drqa.retriever import DocDB, utils
18
+ from drqa.retriever.utils import normalize
19
+ import wikipedia
20
+
21
+
22
+ class FeverDocDB(DocDB):
23
+ def __init__(self, path=None):
24
+ super().__init__(path)
25
+
26
+ def get_doc_lines(self, doc_id):
27
+ """Fetch the raw text of the doc for 'doc_id'."""
28
+ cursor = self.connection.cursor()
29
+ cursor.execute(
30
+ "SELECT lines FROM documents WHERE id = ?",
31
+ (utils.normalize(doc_id),)
32
+ )
33
+ result = cursor.fetchone()
34
+ cursor.close()
35
+
36
+ result = result[0] if result is not None else ''
37
+ doc_lines = []
38
+ for line in result.split('\n'):
39
+ if len(line) == 0: continue
40
+ line = line.split('\t')[1]
41
+ if len(line) == 0: continue
42
+ doc_lines.append((doc_id, len(doc_lines), line, 0))
43
+
44
+ return doc_lines
45
+
46
+ def get_non_empty_doc_ids(self):
47
+ """Fetch all ids of docs stored in the db."""
48
+ cursor = self.connection.cursor()
49
+ cursor.execute("SELECT id FROM documents WHERE length(trim(text)) > 0")
50
+ results = [r[0] for r in cursor.fetchall()]
51
+ cursor.close()
52
+ return results
53
+
54
+
55
+ class DocRetrieval:
56
+ def __init__(self, database_path, add_claim=False, k_wiki_results=None):
57
+ self.db = FeverDocDB(database_path)
58
+ self.add_claim = add_claim
59
+ self.k_wiki_results = k_wiki_results
60
+ self.porter_stemmer = nltk.PorterStemmer()
61
+ self.tokenizer = nltk.word_tokenize
62
+ self.predictor = Predictor.from_path(
63
+ "https://storage.googleapis.com/allennlp-public-models/elmo-constituency-parser-2020.02.10.tar.gz"
64
+ )
65
+
66
+ def get_NP(self, tree, nps):
67
+ if isinstance(tree, dict):
68
+ if "children" not in tree:
69
+ if tree['nodeType'] == "NP":
70
+ # print(tree['word'])
71
+ # print(tree)
72
+ nps.append(tree['word'])
73
+ elif "children" in tree:
74
+ if tree['nodeType'] == "NP":
75
+ # print(tree['word'])
76
+ nps.append(tree['word'])
77
+ self.get_NP(tree['children'], nps)
78
+ else:
79
+ self.get_NP(tree['children'], nps)
80
+ elif isinstance(tree, list):
81
+ for sub_tree in tree:
82
+ self.get_NP(sub_tree, nps)
83
+
84
+ return nps
85
+
86
+ def get_subjects(self, tree):
87
+ subject_words = []
88
+ subjects = []
89
+ for subtree in tree['children']:
90
+ if subtree['nodeType'] == "VP" or subtree['nodeType'] == 'S' or subtree['nodeType'] == 'VBZ':
91
+ subjects.append(' '.join(subject_words))
92
+ subject_words.append(subtree['word'])
93
+ else:
94
+ subject_words.append(subtree['word'])
95
+ return subjects
96
+
97
+ def get_noun_phrases(self, claim):
98
+ tokens = self.predictor.predict(claim)
99
+ nps = []
100
+ tree = tokens['hierplane_tree']['root']
101
+ noun_phrases = self.get_NP(tree, nps)
102
+ subjects = self.get_subjects(tree)
103
+ for subject in subjects:
104
+ if len(subject) > 0:
105
+ noun_phrases.append(subject)
106
+ if self.add_claim:
107
+ noun_phrases.append(claim)
108
+ return list(set(noun_phrases))
109
+
110
+ def get_doc_for_claim(self, noun_phrases):
111
+ predicted_pages = []
112
+ for np in noun_phrases:
113
+ if len(np) > 300:
114
+ continue
115
+ i = 1
116
+ while i < 12:
117
+ try:
118
+ # print(np)
119
+ # res = server.lookup(np, keep_all=True)
120
+ # docs = [y for _, y in res] if res is not None else []
121
+ docs = wikipedia.search(np)
122
+ if self.k_wiki_results is not None:
123
+ predicted_pages.extend(docs[:self.k_wiki_results])
124
+ else:
125
+ predicted_pages.extend(docs)
126
+ except (ConnectionResetError, ConnectionError, ConnectionAbortedError, ConnectionRefusedError):
127
+ print("Connection reset error received! Trial #" + str(i))
128
+ time.sleep(600 * i)
129
+ i += 1
130
+ else:
131
+ break
132
+
133
+ # sleep_num = random.uniform(0.1,0.7)
134
+ # time.sleep(sleep_num)
135
+ predicted_pages = set(predicted_pages)
136
+ processed_pages = []
137
+ for page in predicted_pages:
138
+ page = page.replace(" ", "_")
139
+ page = page.replace("(", "-LRB-")
140
+ page = page.replace(")", "-RRB-")
141
+ page = page.replace(":", "-COLON-")
142
+ processed_pages.append(page)
143
+
144
+ return processed_pages
145
+
146
+ def np_conc(self, noun_phrases):
147
+ noun_phrases = set(noun_phrases)
148
+ predicted_pages = []
149
+ for np in noun_phrases:
150
+ page = np.replace('( ', '-LRB-')
151
+ page = page.replace(' )', '-RRB-')
152
+ page = page.replace(' - ', '-')
153
+ page = page.replace(' :', '-COLON-')
154
+ page = page.replace(' ,', ',')
155
+ page = page.replace(" 's", "'s")
156
+ page = page.replace(' ', '_')
157
+
158
+ if len(page) < 1:
159
+ continue
160
+ doc_lines = self.db.get_doc_lines(page)
161
+ if len(doc_lines) > 0:
162
+ predicted_pages.append(page)
163
+ return predicted_pages
164
+
165
+ def exact_match(self, claim):
166
+ noun_phrases = self.get_noun_phrases(claim)
167
+ wiki_results = self.get_doc_for_claim(noun_phrases)
168
+ wiki_results = list(set(wiki_results))
169
+
170
+ claim = claim.replace(".", "")
171
+ claim = claim.replace("-", " ")
172
+ words = [self.porter_stemmer.stem(word.lower()) for word in self.tokenizer(claim)]
173
+ words = set(words)
174
+ predicted_pages = self.np_conc(noun_phrases)
175
+
176
+ for page in wiki_results:
177
+ page = normalize(page)
178
+ processed_page = re.sub("-LRB-.*?-RRB-", "", page)
179
+ processed_page = re.sub("_", " ", processed_page)
180
+ processed_page = re.sub("-COLON-", ":", processed_page)
181
+ processed_page = processed_page.replace("-", " ")
182
+ processed_page = processed_page.replace("–", " ")
183
+ processed_page = processed_page.replace(".", "")
184
+ page_words = [self.porter_stemmer.stem(word.lower()) for word in self.tokenizer(processed_page) if
185
+ len(word) > 0]
186
+
187
+ if all([item in words for item in page_words]):
188
+ if ':' in page:
189
+ page = page.replace(":", "-COLON-")
190
+ predicted_pages.append(page)
191
+ predicted_pages = list(set(predicted_pages))
192
+
193
+ return noun_phrases, wiki_results, predicted_pages
194
+
195
+
196
+ def save_to_file(results, client, filename):
197
+ with open(filename, 'w', encoding='utf-8') as fout:
198
+ for _id, line in enumerate(results):
199
+ claim = line['claim']
200
+ evidence = []
201
+ for page in line['predicted_pages']:
202
+ evidence.extend(client.db.get_doc_lines(page))
203
+ print(json.dumps({'claim': claim, 'evidence': evidence}, ensure_ascii=False), file=fout)
204
+
205
+
206
+ if __name__ == '__main__':
207
+ database_path = 'data/fever.db'
208
+ add_claim = True
209
+ k_wiki_results = 7
210
+ client = DocRetrieval(database_path, add_claim, k_wiki_results)
211
+
212
+ results = []
213
+ with open('data/claims.json', 'r', encoding='utf-8') as fin:
214
+ for line in tqdm(fin):
215
+ line = json.loads(line)
216
+ _, _, predicted_pages = client.exact_match(line['claim'])
217
+ evidence = []
218
+ for page in predicted_pages:
219
+ evidence.extend(client.db.get_doc_lines(page))
220
+ line['evidence'] = evidence
221
+ results.append(line)
222
+
223
+ with open('data/pages.json', 'w', encoding='utf-8') as fout:
224
+ for line in results:
225
+ print(json.dumps(line, ensure_ascii=False), file=fout)
src/er_client/entitylinker.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ @Author : Jiangjie Chen
5
+ @Time : 2020/5/11 19:08
6
+ @Contact : jjchen19@fudan.edu.cn
7
+ @Description:
8
+ '''
9
+
10
+ import os
11
+ import tagme
12
+
13
+
14
+ def read_title_id(entity_def_path):
15
+ id_to_title = {}
16
+ with open(entity_def_path, 'r', encoding='UTF-8') as f:
17
+ lines = f.readlines()
18
+ for i, line in enumerate(lines):
19
+ if i > 0:
20
+ entity, id = line.strip().split('|')
21
+ id_to_title[id] = entity
22
+
23
+ return id_to_title
24
+
25
+
26
+ class ELClient:
27
+ def __init__(self, link_type, min_rho=0.1, prefix=None, verbose=False):
28
+ self.verbose = verbose
29
+ self.link_type = link_type
30
+ if link_type == 'tagme':
31
+ self.min_rho = min_rho
32
+ tagme.GCUBE_TOKEN = os.environ['TAGME_APIKEY']
33
+ elif link_type == 'spacy':
34
+ assert prefix is not None
35
+ self.init_spacy_linker(prefix)
36
+ else:
37
+ raise NotImplementedError(link_type)
38
+
39
+ def init_spacy_linker(self, prefix):
40
+ entity_def_path = f"{prefix}/entity_defs.csv"
41
+ self._print('* Loading entity linker...')
42
+ self.nlp = spacy.load(prefix)
43
+ self.id2title = read_title_id(entity_def_path)
44
+ self._print('* Entity linker loaded.')
45
+
46
+ def _tagme_link(self, text):
47
+ result = []
48
+ for ann in tagme.annotate(text, long_text=1).get_annotations(min_rho=self.min_rho):
49
+ result.append((text[ann.begin:ann.end], ann.score, ann.entity_id, ann.entity_title))
50
+ # result.append({'begin': ann.begin,
51
+ # 'end': ann.end,
52
+ # 'id': ann.entity_id,
53
+ # 'title': ann.entity_title,
54
+ # 'score': ann.score})
55
+ result.sort(key=lambda x: x[1], reverse=True)
56
+ return result
57
+
58
+ def link(self, text):
59
+ if self.link_type == 'tagme':
60
+ return self._tagme_link(text)
61
+ else:
62
+ return self._spacy_link(text)
63
+
64
+ def _spacy_link(self, text):
65
+ text = self._preprocess_text(text)
66
+ doc = self.nlp(text)
67
+ ents = [(e.text, e.label_, e.kb_id_, self.id2title.get(e.kb_id_, ''))
68
+ for e in doc.ents if e.kb_id_ != 'NIL']
69
+ return ents
70
+
71
+ def _preprocess_text(self, text):
72
+ if isinstance(text, list):
73
+ text = ' '.join(text)
74
+ text = text.strip().replace('-lrb-', '(').replace('-rrb-', ')')
75
+ return text
76
+
77
+ def _print(self, x):
78
+ if self.verbose: print(x)
79
+
80
+
81
+ if __name__ == '__main__':
82
+ elcl = ELClient(link_type='tagme', verbose=True)
83
+ res = elcl.link('Jeff Dean wants to meet Yoshua Bengio.')
84
+ print(res)
src/er_client/retrieval_model/bert_model.py ADDED
@@ -0,0 +1,775 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model."""
17
+
18
+ from __future__ import absolute_import, division, print_function, unicode_literals
19
+
20
+ import copy
21
+ import json
22
+ import logging
23
+ import math
24
+ import os
25
+ import shutil
26
+ import tarfile
27
+ import tempfile
28
+ import sys
29
+ from io import open
30
+
31
+ import torch
32
+ from torch import nn
33
+ from torch.nn import CrossEntropyLoss
34
+
35
+ from .file_utils import cached_path
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+ PRETRAINED_MODEL_ARCHIVE_MAP = {
40
+ 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
41
+ 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
42
+ 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
43
+ 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
44
+ 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
45
+ 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
46
+ 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
47
+ }
48
+ CONFIG_NAME = 'bert_config.json'
49
+ WEIGHTS_NAME = 'pytorch_model.bin'
50
+ TF_WEIGHTS_NAME = 'model.ckpt'
51
+
52
+ def load_tf_weights_in_bert(model, tf_checkpoint_path):
53
+ """ Load tf checkpoints in a pytorch model
54
+ """
55
+ try:
56
+ import re
57
+ import numpy as np
58
+ import tensorflow as tf
59
+ except ImportError:
60
+ print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
61
+ "https://www.tensorflow.org/install/ for installation instructions.")
62
+ raise
63
+ tf_path = os.path.abspath(tf_checkpoint_path)
64
+ print("Converting TensorFlow checkpoint from {}".format(tf_path))
65
+ # Load weights from TF model
66
+ init_vars = tf.train.list_variables(tf_path)
67
+ names = []
68
+ arrays = []
69
+ for name, shape in init_vars:
70
+ print("Loading TF weight {} with shape {}".format(name, shape))
71
+ array = tf.train.load_variable(tf_path, name)
72
+ names.append(name)
73
+ arrays.append(array)
74
+
75
+ for name, array in zip(names, arrays):
76
+ name = name.split('/')
77
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
78
+ # which are not required for using pretrained model
79
+ if any(n in ["adam_v", "adam_m"] for n in name):
80
+ print("Skipping {}".format("/".join(name)))
81
+ continue
82
+ pointer = model
83
+ for m_name in name:
84
+ if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
85
+ l = re.split(r'_(\d+)', m_name)
86
+ else:
87
+ l = [m_name]
88
+ if l[0] == 'kernel' or l[0] == 'gamma':
89
+ pointer = getattr(pointer, 'weight')
90
+ elif l[0] == 'output_bias' or l[0] == 'beta':
91
+ pointer = getattr(pointer, 'bias')
92
+ elif l[0] == 'output_weights':
93
+ pointer = getattr(pointer, 'weight')
94
+ else:
95
+ pointer = getattr(pointer, l[0])
96
+ if len(l) >= 2:
97
+ num = int(l[1])
98
+ pointer = pointer[num]
99
+ if m_name[-11:] == '_embeddings':
100
+ pointer = getattr(pointer, 'weight')
101
+ elif m_name == 'kernel':
102
+ array = np.transpose(array)
103
+ try:
104
+ assert pointer.shape == array.shape
105
+ except AssertionError as e:
106
+ e.args += (pointer.shape, array.shape)
107
+ raise
108
+ print("Initialize PyTorch weight {}".format(name))
109
+ pointer.data = torch.from_numpy(array)
110
+ return model
111
+
112
+
113
+ def gelu(x):
114
+ """Implementation of the gelu activation function.
115
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
116
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
117
+ Also see https://arxiv.org/abs/1606.08415
118
+ """
119
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
120
+
121
+
122
+ def swish(x):
123
+ return x * torch.sigmoid(x)
124
+
125
+
126
+ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
127
+
128
+
129
+ class BertConfig(object):
130
+ """Configuration class to store the configuration of a `BertModel`.
131
+ """
132
+ def __init__(self,
133
+ vocab_size_or_config_json_file,
134
+ hidden_size=768,
135
+ num_hidden_layers=12,
136
+ num_attention_heads=12,
137
+ intermediate_size=3072,
138
+ hidden_act="gelu",
139
+ hidden_dropout_prob=0.1,
140
+ attention_probs_dropout_prob=0.1,
141
+ max_position_embeddings=512,
142
+ type_vocab_size=2,
143
+ initializer_range=0.02):
144
+ """Constructs BertConfig.
145
+
146
+ Args:
147
+ vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
148
+ hidden_size: Size of the encoder layers and the pooler layer.
149
+ num_hidden_layers: Number of hidden layers in the Transformer encoder.
150
+ num_attention_heads: Number of attention heads for each attention layer in
151
+ the Transformer encoder.
152
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
153
+ layer in the Transformer encoder.
154
+ hidden_act: The non-linear activation function (function or string) in the
155
+ encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
156
+ hidden_dropout_prob: The dropout probabilitiy for all fully connected
157
+ layers in the embeddings, encoder, and pooler.
158
+ attention_probs_dropout_prob: The dropout ratio for the attention
159
+ probabilities.
160
+ max_position_embeddings: The maximum sequence length that this model might
161
+ ever be used with. Typically set this to something large just in case
162
+ (e.g., 512 or 1024 or 2048).
163
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
164
+ `BertModel`.
165
+ initializer_range: The sttdev of the truncated_normal_initializer for
166
+ initializing all weight matrices.
167
+ """
168
+ if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
169
+ and isinstance(vocab_size_or_config_json_file, unicode)):
170
+ with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
171
+ json_config = json.loads(reader.read())
172
+ for key, value in json_config.items():
173
+ self.__dict__[key] = value
174
+ elif isinstance(vocab_size_or_config_json_file, int):
175
+ self.vocab_size = vocab_size_or_config_json_file
176
+ self.hidden_size = hidden_size
177
+ self.num_hidden_layers = num_hidden_layers
178
+ self.num_attention_heads = num_attention_heads
179
+ self.hidden_act = hidden_act
180
+ self.intermediate_size = intermediate_size
181
+ self.hidden_dropout_prob = hidden_dropout_prob
182
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
183
+ self.max_position_embeddings = max_position_embeddings
184
+ self.type_vocab_size = type_vocab_size
185
+ self.initializer_range = initializer_range
186
+ else:
187
+ raise ValueError("First argument must be either a vocabulary size (int)"
188
+ "or the path to a pretrained model config file (str)")
189
+
190
+ @classmethod
191
+ def from_dict(cls, json_object):
192
+ """Constructs a `BertConfig` from a Python dictionary of parameters."""
193
+ config = BertConfig(vocab_size_or_config_json_file=-1)
194
+ for key, value in json_object.items():
195
+ config.__dict__[key] = value
196
+ return config
197
+
198
+ @classmethod
199
+ def from_json_file(cls, json_file):
200
+ """Constructs a `BertConfig` from a json file of parameters."""
201
+ with open(json_file, "r", encoding='utf-8') as reader:
202
+ text = reader.read()
203
+ return cls.from_dict(json.loads(text))
204
+
205
+ def __repr__(self):
206
+ return str(self.to_json_string())
207
+
208
+ def to_dict(self):
209
+ """Serializes this instance to a Python dictionary."""
210
+ output = copy.deepcopy(self.__dict__)
211
+ return output
212
+
213
+ def to_json_string(self):
214
+ """Serializes this instance to a JSON string."""
215
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
216
+
217
+ try:
218
+ from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
219
+ except ImportError:
220
+ print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
221
+ class BertLayerNorm(nn.Module):
222
+ def __init__(self, hidden_size, eps=1e-12):
223
+ """Construct a layernorm module in the TF style (epsilon inside the square root).
224
+ """
225
+ super(BertLayerNorm, self).__init__()
226
+ self.weight = nn.Parameter(torch.ones(hidden_size))
227
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
228
+ self.variance_epsilon = eps
229
+
230
+ def forward(self, x):
231
+ u = x.mean(-1, keepdim=True)
232
+ s = (x - u).pow(2).mean(-1, keepdim=True)
233
+ x = (x - u) / torch.sqrt(s + self.variance_epsilon)
234
+ return self.weight * x + self.bias
235
+
236
+ class BertEmbeddings(nn.Module):
237
+ """Construct the embeddings from word, position and token_type embeddings.
238
+ """
239
+ def __init__(self, config):
240
+ super(BertEmbeddings, self).__init__()
241
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
242
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
243
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
244
+
245
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
246
+ # any TensorFlow checkpoint file
247
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
248
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
249
+
250
+ def forward(self, input_ids, token_type_ids=None):
251
+ seq_length = input_ids.size(1)
252
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
253
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
254
+ if token_type_ids is None:
255
+ token_type_ids = torch.zeros_like(input_ids)
256
+
257
+ words_embeddings = self.word_embeddings(input_ids)
258
+ position_embeddings = self.position_embeddings(position_ids)
259
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
260
+
261
+ embeddings = words_embeddings + position_embeddings + token_type_embeddings
262
+ embeddings = self.LayerNorm(embeddings)
263
+ embeddings = self.dropout(embeddings)
264
+ return embeddings
265
+
266
+
267
+ class BertSelfAttention(nn.Module):
268
+ def __init__(self, config):
269
+ super(BertSelfAttention, self).__init__()
270
+ if config.hidden_size % config.num_attention_heads != 0:
271
+ raise ValueError(
272
+ "The hidden size (%d) is not a multiple of the number of attention "
273
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
274
+ self.num_attention_heads = config.num_attention_heads
275
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
276
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
277
+
278
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
279
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
280
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
281
+
282
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
283
+
284
+ def transpose_for_scores(self, x):
285
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
286
+ x = x.view(*new_x_shape)
287
+ return x.permute(0, 2, 1, 3)
288
+
289
+ def forward(self, hidden_states, attention_mask):
290
+ mixed_query_layer = self.query(hidden_states)
291
+ mixed_key_layer = self.key(hidden_states)
292
+ mixed_value_layer = self.value(hidden_states)
293
+
294
+ query_layer = self.transpose_for_scores(mixed_query_layer)
295
+ key_layer = self.transpose_for_scores(mixed_key_layer)
296
+ value_layer = self.transpose_for_scores(mixed_value_layer)
297
+
298
+ # Take the dot product between "query" and "key" to get the raw attention scores.
299
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
300
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
301
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
302
+ attention_scores = attention_scores + attention_mask
303
+
304
+ # Normalize the attention scores to probabilities.
305
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
306
+
307
+ # This is actually dropping out entire tokens to attend to, which might
308
+ # seem a bit unusual, but is taken from the original Transformer paper.
309
+ attention_probs = self.dropout(attention_probs)
310
+
311
+ context_layer = torch.matmul(attention_probs, value_layer)
312
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
313
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
314
+ context_layer = context_layer.view(*new_context_layer_shape)
315
+ return context_layer
316
+
317
+
318
+ class BertSelfOutput(nn.Module):
319
+ def __init__(self, config):
320
+ super(BertSelfOutput, self).__init__()
321
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
322
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
323
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
324
+
325
+ def forward(self, hidden_states, input_tensor):
326
+ hidden_states = self.dense(hidden_states)
327
+ hidden_states = self.dropout(hidden_states)
328
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
329
+ return hidden_states
330
+
331
+
332
+ class BertAttention(nn.Module):
333
+ def __init__(self, config):
334
+ super(BertAttention, self).__init__()
335
+ self.self = BertSelfAttention(config)
336
+ self.output = BertSelfOutput(config)
337
+
338
+ def forward(self, input_tensor, attention_mask):
339
+ self_output = self.self(input_tensor, attention_mask)
340
+ attention_output = self.output(self_output, input_tensor)
341
+ return attention_output
342
+
343
+
344
+ class BertIntermediate(nn.Module):
345
+ def __init__(self, config):
346
+ super(BertIntermediate, self).__init__()
347
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
348
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
349
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
350
+ else:
351
+ self.intermediate_act_fn = config.hidden_act
352
+
353
+ def forward(self, hidden_states):
354
+ hidden_states = self.dense(hidden_states)
355
+ hidden_states = self.intermediate_act_fn(hidden_states)
356
+ return hidden_states
357
+
358
+
359
+ class BertOutput(nn.Module):
360
+ def __init__(self, config):
361
+ super(BertOutput, self).__init__()
362
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
363
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
364
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
365
+
366
+ def forward(self, hidden_states, input_tensor):
367
+ hidden_states = self.dense(hidden_states)
368
+ hidden_states = self.dropout(hidden_states)
369
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
370
+ return hidden_states
371
+
372
+
373
+ class BertLayer(nn.Module):
374
+ def __init__(self, config):
375
+ super(BertLayer, self).__init__()
376
+ self.attention = BertAttention(config)
377
+ self.intermediate = BertIntermediate(config)
378
+ self.output = BertOutput(config)
379
+
380
+ def forward(self, hidden_states, attention_mask):
381
+ attention_output = self.attention(hidden_states, attention_mask)
382
+ intermediate_output = self.intermediate(attention_output)
383
+ layer_output = self.output(intermediate_output, attention_output)
384
+ return layer_output
385
+
386
+
387
+ class BertEncoder(nn.Module):
388
+ def __init__(self, config):
389
+ super(BertEncoder, self).__init__()
390
+ layer = BertLayer(config)
391
+ self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
392
+
393
+ def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
394
+ all_encoder_layers = []
395
+ for layer_module in self.layer:
396
+ hidden_states = layer_module(hidden_states, attention_mask)
397
+ if output_all_encoded_layers:
398
+ all_encoder_layers.append(hidden_states)
399
+ if not output_all_encoded_layers:
400
+ all_encoder_layers.append(hidden_states)
401
+ return all_encoder_layers
402
+
403
+
404
+ class BertPooler(nn.Module):
405
+ def __init__(self, config):
406
+ super(BertPooler, self).__init__()
407
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
408
+ self.activation = nn.Tanh()
409
+
410
+ def forward(self, hidden_states):
411
+ # We "pool" the model by simply taking the hidden state corresponding
412
+ # to the first token.
413
+ first_token_tensor = hidden_states[:, 0]
414
+ pooled_output = self.dense(first_token_tensor)
415
+ pooled_output = self.activation(pooled_output)
416
+ return pooled_output
417
+
418
+
419
+ class BertPredictionHeadTransform(nn.Module):
420
+ def __init__(self, config):
421
+ super(BertPredictionHeadTransform, self).__init__()
422
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
423
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
424
+ self.transform_act_fn = ACT2FN[config.hidden_act]
425
+ else:
426
+ self.transform_act_fn = config.hidden_act
427
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
428
+
429
+ def forward(self, hidden_states):
430
+ hidden_states = self.dense(hidden_states)
431
+ hidden_states = self.transform_act_fn(hidden_states)
432
+ hidden_states = self.LayerNorm(hidden_states)
433
+ return hidden_states
434
+
435
+
436
+ class BertLMPredictionHead(nn.Module):
437
+ def __init__(self, config, bert_model_embedding_weights):
438
+ super(BertLMPredictionHead, self).__init__()
439
+ self.transform = BertPredictionHeadTransform(config)
440
+
441
+ # The output weights are the same as the input embeddings, but there is
442
+ # an output-only bias for each token.
443
+ self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
444
+ bert_model_embedding_weights.size(0),
445
+ bias=False)
446
+ self.decoder.weight = bert_model_embedding_weights
447
+ self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
448
+
449
+ def forward(self, hidden_states):
450
+ hidden_states = self.transform(hidden_states)
451
+ hidden_states = self.decoder(hidden_states) + self.bias
452
+ return hidden_states
453
+
454
+
455
+ class BertOnlyMLMHead(nn.Module):
456
+ def __init__(self, config, bert_model_embedding_weights):
457
+ super(BertOnlyMLMHead, self).__init__()
458
+ self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
459
+
460
+ def forward(self, sequence_output):
461
+ prediction_scores = self.predictions(sequence_output)
462
+ return prediction_scores
463
+
464
+
465
+ class BertOnlyNSPHead(nn.Module):
466
+ def __init__(self, config):
467
+ super(BertOnlyNSPHead, self).__init__()
468
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
469
+
470
+ def forward(self, pooled_output):
471
+ seq_relationship_score = self.seq_relationship(pooled_output)
472
+ return seq_relationship_score
473
+
474
+
475
+ class BertPreTrainingHeads(nn.Module):
476
+ def __init__(self, config, bert_model_embedding_weights):
477
+ super(BertPreTrainingHeads, self).__init__()
478
+ self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
479
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
480
+
481
+ def forward(self, sequence_output, pooled_output):
482
+ prediction_scores = self.predictions(sequence_output)
483
+ seq_relationship_score = self.seq_relationship(pooled_output)
484
+ return prediction_scores, seq_relationship_score
485
+
486
+
487
+ class BertPreTrainedModel(nn.Module):
488
+ """ An abstract class to handle weights initialization and
489
+ a simple interface for dowloading and loading pretrained models.
490
+ """
491
+ def __init__(self, config, *inputs, **kwargs):
492
+ super(BertPreTrainedModel, self).__init__()
493
+ if not isinstance(config, BertConfig):
494
+ raise ValueError(
495
+ "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
496
+ "To create a model from a Google pretrained model use "
497
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
498
+ self.__class__.__name__, self.__class__.__name__
499
+ ))
500
+ self.config = config
501
+
502
+ def init_bert_weights(self, module):
503
+ """ Initialize the weights.
504
+ """
505
+ if isinstance(module, (nn.Linear, nn.Embedding)):
506
+ # Slightly different from the TF version which uses truncated_normal for initialization
507
+ # cf https://github.com/pytorch/pytorch/pull/5617
508
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
509
+ elif isinstance(module, BertLayerNorm):
510
+ module.bias.data.zero_()
511
+ module.weight.data.fill_(1.0)
512
+ if isinstance(module, nn.Linear) and module.bias is not None:
513
+ module.bias.data.zero_()
514
+
515
+ @classmethod
516
+ def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
517
+ from_tf=False, *inputs, **kwargs):
518
+ """
519
+ Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
520
+ Download and cache the pre-trained model file if needed.
521
+
522
+ Params:
523
+ pretrained_model_name_or_path: either:
524
+ - a str with the name of a pre-trained model to load selected in the list of:
525
+ . `bert-base-uncased`
526
+ . `bert-large-uncased`
527
+ . `bert-base-cased`
528
+ . `bert-large-cased`
529
+ . `bert-base-multilingual-uncased`
530
+ . `bert-base-multilingual-cased`
531
+ . `bert-base-chinese`
532
+ - a path or url to a pretrained model archive containing:
533
+ . `bert_config.json` a configuration file for the model
534
+ . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
535
+ - a path or url to a pretrained model archive containing:
536
+ . `bert_config.json` a configuration file for the model
537
+ . `model.chkpt` a TensorFlow checkpoint
538
+ from_tf: should we load the weights from a locally saved TensorFlow checkpoint
539
+ cache_dir: an optional path to a folder in which the pre-trained models will be cached.
540
+ state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
541
+ *inputs, **kwargs: additional input for the specific Bert class
542
+ (ex: num_labels for BertForSequenceClassification)
543
+ """
544
+ if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
545
+ archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
546
+ else:
547
+ archive_file = pretrained_model_name_or_path
548
+ # redirect to the cache, if necessary
549
+ try:
550
+ resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
551
+ except EnvironmentError:
552
+ logger.error(
553
+ "Model name '{}' was not found in model name list ({}). "
554
+ "We assumed '{}' was a path or url but couldn't find any file "
555
+ "associated to this path or url.".format(
556
+ pretrained_model_name_or_path,
557
+ ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
558
+ archive_file))
559
+ return None
560
+ if resolved_archive_file == archive_file:
561
+ logger.info("loading archive file {}".format(archive_file))
562
+ else:
563
+ logger.info("loading archive file {} from cache at {}".format(
564
+ archive_file, resolved_archive_file))
565
+ tempdir = None
566
+ if os.path.isdir(resolved_archive_file) or from_tf:
567
+ serialization_dir = resolved_archive_file
568
+ else:
569
+ # Extract archive to temp dir
570
+ tempdir = tempfile.mkdtemp()
571
+ logger.info("extracting archive file {} to temp dir {}".format(
572
+ resolved_archive_file, tempdir))
573
+ with tarfile.open(resolved_archive_file, 'r:gz') as archive:
574
+ archive.extractall(tempdir)
575
+ serialization_dir = tempdir
576
+ # Load config
577
+ config_file = os.path.join(serialization_dir, CONFIG_NAME)
578
+ config = BertConfig.from_json_file(config_file)
579
+ logger.info("Model config {}".format(config))
580
+ # Instantiate model.
581
+ model = cls(config, *inputs, **kwargs)
582
+ if state_dict is None and not from_tf:
583
+ weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
584
+ state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None)
585
+ if tempdir:
586
+ # Clean up temp dir
587
+ shutil.rmtree(tempdir)
588
+ if from_tf:
589
+ # Directly load from a TensorFlow checkpoint
590
+ weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
591
+ return load_tf_weights_in_bert(model, weights_path)
592
+ # Load from a PyTorch state_dict
593
+ old_keys = []
594
+ new_keys = []
595
+ for key in state_dict.keys():
596
+ new_key = None
597
+ if 'gamma' in key:
598
+ new_key = key.replace('gamma', 'weight')
599
+ if 'beta' in key:
600
+ new_key = key.replace('beta', 'bias')
601
+ if new_key:
602
+ old_keys.append(key)
603
+ new_keys.append(new_key)
604
+ for old_key, new_key in zip(old_keys, new_keys):
605
+ state_dict[new_key] = state_dict.pop(old_key)
606
+
607
+ missing_keys = []
608
+ unexpected_keys = []
609
+ error_msgs = []
610
+ # copy state_dict so _load_from_state_dict can modify it
611
+ metadata = getattr(state_dict, '_metadata', None)
612
+ state_dict = state_dict.copy()
613
+ if metadata is not None:
614
+ state_dict._metadata = metadata
615
+
616
+ def load(module, prefix=''):
617
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
618
+ module._load_from_state_dict(
619
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
620
+ for name, child in module._modules.items():
621
+ if child is not None:
622
+ load(child, prefix + name + '.')
623
+ start_prefix = ''
624
+ if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
625
+ start_prefix = 'bert.'
626
+ load(model, prefix=start_prefix)
627
+ if len(missing_keys) > 0:
628
+ logger.info("Weights of {} not initialized from pretrained model: {}".format(
629
+ model.__class__.__name__, missing_keys))
630
+ if len(unexpected_keys) > 0:
631
+ logger.info("Weights from pretrained model not used in {}: {}".format(
632
+ model.__class__.__name__, unexpected_keys))
633
+ if len(error_msgs) > 0:
634
+ raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
635
+ model.__class__.__name__, "\n\t".join(error_msgs)))
636
+ return model
637
+
638
+
639
+ class BertModel(BertPreTrainedModel):
640
+ """BERT model ("Bidirectional Embedding Representations from a Transformer").
641
+
642
+ Params:
643
+ config: a BertConfig class instance with the configuration to build a new model
644
+
645
+ Inputs:
646
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
647
+ with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
648
+ `extract_features.py`, `run_classifier.py` and `run_squad.py`)
649
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
650
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
651
+ a `sentence B` token (see BERT paper for more details).
652
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
653
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
654
+ input sequence length in the current batch. It's the mask that we typically use for attention when
655
+ a batch has varying length sentences.
656
+ `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
657
+
658
+ Outputs: Tuple of (encoded_layers, pooled_output)
659
+ `encoded_layers`: controled by `output_all_encoded_layers` argument:
660
+ - `output_all_encoded_layers=True`: output a list of the full sequences of encoded-hidden-states at the end
661
+ of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
662
+ encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
663
+ - `output_all_encoded_layers=False`: output only the full sequence of hidden-states corresponding
664
+ to the last attention block of shape [batch_size, sequence_length, hidden_size],
665
+ `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
666
+ classifier pretrained on top of the hidden state associated to the first character of the
667
+ input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
668
+
669
+ Example usage:
670
+ ```python
671
+ # Already been converted into WordPiece token ids
672
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
673
+ input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
674
+ token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
675
+
676
+ config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
677
+ num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
678
+
679
+ model = modeling.BertModel(config=config)
680
+ all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
681
+ ```
682
+ """
683
+ def __init__(self, config):
684
+ super(BertModel, self).__init__(config)
685
+ self.embeddings = BertEmbeddings(config)
686
+ self.encoder = BertEncoder(config)
687
+ self.pooler = BertPooler(config)
688
+ self.apply(self.init_bert_weights)
689
+
690
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
691
+ if attention_mask is None:
692
+ attention_mask = torch.ones_like(input_ids)
693
+ if token_type_ids is None:
694
+ token_type_ids = torch.zeros_like(input_ids)
695
+
696
+ # We create a 3D attention mask from a 2D tensor mask.
697
+ # Sizes are [batch_size, 1, 1, to_seq_length]
698
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
699
+ # this attention mask is more simple than the triangular masking of causal attention
700
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
701
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
702
+
703
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
704
+ # masked positions, this operation will create a tensor which is 0.0 for
705
+ # positions we want to attend and -10000.0 for masked positions.
706
+ # Since we are adding it to the raw scores before the softmax, this is
707
+ # effectively the same as removing these entirely.
708
+ extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
709
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
710
+
711
+ embedding_output = self.embeddings(input_ids, token_type_ids)
712
+ encoded_layers = self.encoder(embedding_output,
713
+ extended_attention_mask,
714
+ output_all_encoded_layers=output_all_encoded_layers)
715
+ sequence_output = encoded_layers[-1]
716
+ pooled_output = self.pooler(sequence_output)
717
+ if not output_all_encoded_layers:
718
+ encoded_layers = encoded_layers[-1]
719
+ return encoded_layers, pooled_output
720
+
721
+
722
+
723
+
724
+
725
+ class BertForSequenceEncoder(BertPreTrainedModel):
726
+ """BERT model for classification.
727
+ This module is composed of the BERT model with a linear layer on top of
728
+ the pooled output.
729
+ Params:
730
+ `config`: a BertConfig class instance with the configuration to build a new model.
731
+ `num_labels`: the number of classes for the classifier. Default = 2.
732
+ Inputs:
733
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
734
+ with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
735
+ `extract_features.py`, `run_classifier.py` and `run_squad.py`)
736
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
737
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
738
+ a `sentence B` token (see BERT paper for more details).
739
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
740
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
741
+ input sequence length in the current batch. It's the mask that we typically use for attention when
742
+ a batch has varying length sentences.
743
+ `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
744
+ with indices selected in [0, ..., num_labels].
745
+ Outputs:
746
+ if `labels` is not `None`:
747
+ Outputs the CrossEntropy classification loss of the output with the labels.
748
+ if `labels` is `None`:
749
+ Outputs the classification logits of shape [batch_size, num_labels].
750
+ Example usage:
751
+ ```python
752
+ # Already been converted into WordPiece token ids
753
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
754
+ input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
755
+ token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
756
+ config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
757
+ num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
758
+ num_labels = 2
759
+ model = BertForSequenceClassification(config, num_labels)
760
+ logits = model(input_ids, token_type_ids, input_mask)
761
+ ```
762
+ """
763
+ def __init__(self, config):
764
+ super(BertForSequenceEncoder, self).__init__(config)
765
+ self.bert = BertModel(config)
766
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
767
+ self.apply(self.init_bert_weights)
768
+
769
+ def forward(self, input_ids, attention_mask, token_type_ids):
770
+ output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
771
+ output = self.dropout(output)
772
+ pooled_output = self.dropout(pooled_output)
773
+ return output, pooled_output
774
+
775
+
src/er_client/retrieval_model/data_loader.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import json
5
+ import re
6
+ from torch.autograd import Variable
7
+
8
+
9
+ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
10
+ """Truncates a sequence pair in place to the maximum length."""
11
+
12
+ # This is a simple heuristic which will always truncate the longer sequence
13
+ # one token at a time. This makes more sense than truncating an equal percent
14
+ # of tokens from each, since if one sequence is very short then each token
15
+ # that's truncated likely contains more information than a longer sequence.
16
+ while True:
17
+ total_length = len(tokens_a) + len(tokens_b)
18
+ if total_length <= max_length:
19
+ break
20
+ if len(tokens_a) > len(tokens_b):
21
+ tokens_a.pop()
22
+ else:
23
+ tokens_b.pop()
24
+
25
+
26
+ def tok2int_sent(sentence, tokenizer, max_seq_length):
27
+ """Loads a data file into a list of `InputBatch`s."""
28
+ sent_a, sent_b = sentence
29
+ tokens_a = tokenizer.tokenize(sent_a)
30
+
31
+ tokens_b = None
32
+ if sent_b:
33
+ tokens_b = tokenizer.tokenize(sent_b)
34
+ _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
35
+ else:
36
+ # Account for [CLS] and [SEP] with "- 2"
37
+ if len(tokens_a) > max_seq_length - 2:
38
+ tokens_a = tokens_a[:(max_seq_length - 2)]
39
+
40
+ tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
41
+ segment_ids = [0] * len(tokens)
42
+ if tokens_b:
43
+ tokens = tokens + tokens_b + ["[SEP]"]
44
+ segment_ids += [1] * (len(tokens_b) + 1)
45
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
46
+ input_mask = [1] * len(input_ids)
47
+ padding = [0] * (max_seq_length - len(input_ids))
48
+
49
+ input_ids += padding
50
+ input_mask += padding
51
+ segment_ids += padding
52
+
53
+ assert len(input_ids) == max_seq_length
54
+ assert len(input_mask) == max_seq_length
55
+ assert len(segment_ids) == max_seq_length
56
+
57
+ return input_ids, input_mask, segment_ids
58
+
59
+
60
+ def tok2int_list(src_list, tokenizer, max_seq_length, max_seq_size=-1):
61
+ inp_padding = list()
62
+ msk_padding = list()
63
+ seg_padding = list()
64
+ for step, sent in enumerate(src_list):
65
+ input_ids, input_mask, input_seg = tok2int_sent(sent, tokenizer, max_seq_length)
66
+ inp_padding.append(input_ids)
67
+ msk_padding.append(input_mask)
68
+ seg_padding.append(input_seg)
69
+ # if max_seq_size != -1:
70
+ # inp_padding = inp_padding[:max_seq_size]
71
+ # msk_padding = msk_padding[:max_seq_size]
72
+ # seg_padding = seg_padding[:max_seq_size]
73
+ # inp_padding += ([[0] * max_seq_length] * (max_seq_size - len(inp_padding)))
74
+ # msk_padding += ([[0] * max_seq_length] * (max_seq_size - len(msk_padding)))
75
+ # seg_padding += ([[0] * max_seq_length] * (max_seq_size - len(seg_padding)))
76
+ return inp_padding, msk_padding, seg_padding
77
+
78
+
79
+ class DataLoader(object):
80
+ ''' For data iteration '''
81
+
82
+ def __init__(self, data_path, tokenizer, args, test=False, cuda=True, batch_size=64):
83
+ self.cuda = cuda
84
+
85
+ self.batch_size = batch_size
86
+ self.tokenizer = tokenizer
87
+ self.max_len = args.max_len
88
+ self.evi_num = args.evi_num
89
+ self.threshold = args.threshold
90
+ self.data_path = data_path
91
+ self.test = test
92
+ examples = self.read_file(data_path)
93
+ self.examples = examples
94
+ self.total_num = len(examples)
95
+ if self.test:
96
+ self.total_num = 100000
97
+ self.total_step = np.ceil(self.total_num * 1.0 / batch_size)
98
+ self.shuffle()
99
+ else:
100
+ self.total_step = self.total_num / batch_size
101
+ self.shuffle()
102
+ self.step = 0
103
+
104
+ def process_sent(self, sentence):
105
+ sentence = re.sub(" \-LSB\-.*?\-RSB\-", "", sentence)
106
+ sentence = re.sub("\-LRB\- \-RRB\- ", "", sentence)
107
+ sentence = re.sub(" -LRB-", " ( ", sentence)
108
+ sentence = re.sub("-RRB-", " )", sentence)
109
+ sentence = re.sub("--", "-", sentence)
110
+ sentence = re.sub("``", '"', sentence)
111
+ sentence = re.sub("''", '"', sentence)
112
+
113
+ return sentence
114
+
115
+ def process_wiki_title(self, title):
116
+ title = re.sub("_", " ", title)
117
+ title = re.sub(" -LRB-", " ( ", title)
118
+ title = re.sub("-RRB-", " )", title)
119
+ title = re.sub("-COLON-", ":", title)
120
+ return title
121
+
122
+ def read_file(self, data_path):
123
+ examples = list()
124
+ with open(data_path) as fin:
125
+ for step, line in enumerate(fin):
126
+ sublines = line.strip().split("\t")
127
+ examples.append(
128
+ [self.process_sent(sublines[0]), self.process_sent(sublines[2]), self.process_sent(sublines[4])])
129
+ return examples
130
+
131
+ def shuffle(self):
132
+ np.random.shuffle(self.examples)
133
+
134
+ def __iter__(self):
135
+ return self
136
+
137
+ def __next__(self):
138
+ return self.next()
139
+
140
+ def __len__(self):
141
+ return self._n_batch
142
+
143
+ def next(self):
144
+ ''' Get the next batch '''
145
+ if self.step < self.total_step:
146
+ examples = self.examples[self.step * self.batch_size: (self.step + 1) * self.batch_size]
147
+ pos_inputs = list()
148
+ neg_inputs = list()
149
+ for example in examples:
150
+ pos_inputs.append([example[0], example[1]])
151
+ neg_inputs.append([example[0], example[2]])
152
+ inp_pos, msk_pos, seg_pos = tok2int_list(pos_inputs, self.tokenizer, self.max_len)
153
+ inp_neg, msk_neg, seg_neg = tok2int_list(neg_inputs, self.tokenizer, self.max_len)
154
+
155
+ inp_tensor_pos = Variable(
156
+ torch.LongTensor(inp_pos))
157
+ msk_tensor_pos = Variable(
158
+ torch.LongTensor(msk_pos))
159
+ seg_tensor_pos = Variable(
160
+ torch.LongTensor(seg_pos))
161
+ inp_tensor_neg = Variable(
162
+ torch.LongTensor(inp_neg))
163
+ msk_tensor_neg = Variable(
164
+ torch.LongTensor(msk_neg))
165
+ seg_tensor_neg = Variable(
166
+ torch.LongTensor(seg_neg))
167
+
168
+ if self.cuda:
169
+ inp_tensor_pos = inp_tensor_pos.cuda()
170
+ msk_tensor_pos = msk_tensor_pos.cuda()
171
+ seg_tensor_pos = seg_tensor_pos.cuda()
172
+ inp_tensor_neg = inp_tensor_neg.cuda()
173
+ msk_tensor_neg = msk_tensor_neg.cuda()
174
+ seg_tensor_neg = seg_tensor_neg.cuda()
175
+ self.step += 1
176
+ return inp_tensor_pos, msk_tensor_pos, seg_tensor_pos, inp_tensor_neg, msk_tensor_neg, seg_tensor_neg
177
+ else:
178
+ self.step = 0
179
+ if not self.test:
180
+ # examples = self.read_file(self.data_path)
181
+ # self.examples = examples
182
+ self.shuffle()
183
+ raise StopIteration()
184
+
185
+
186
+ class DataLoaderTest(object):
187
+ ''' For data iteration '''
188
+
189
+ def __init__(self, data_path, tokenizer, args, cuda=True, batch_size=64):
190
+ self.cuda = cuda
191
+
192
+ self.batch_size = batch_size
193
+ self.tokenizer = tokenizer
194
+ self.max_len = args.max_len
195
+ self.evi_num = args.evi_num
196
+ self.threshold = args.threshold
197
+ self.data_path = data_path
198
+ inputs, ids, evi_list = self.read_all(data_path)
199
+ self.inputs = inputs
200
+ self.ids = ids
201
+ self.evi_list = evi_list
202
+
203
+ self.total_num = len(inputs)
204
+ self.total_step = np.ceil(self.total_num * 1.0 / batch_size)
205
+ self.step = 0
206
+
207
+ def process_sent(self, sentence):
208
+ sentence = re.sub(" \-LSB\-.*?\-RSB\-", "", sentence)
209
+ sentence = re.sub("\-LRB\- \-RRB\- ", "", sentence)
210
+ sentence = re.sub(" -LRB-", " ( ", sentence)
211
+ sentence = re.sub("-RRB-", " )", sentence)
212
+ sentence = re.sub("--", "-", sentence)
213
+ sentence = re.sub("``", '"', sentence)
214
+ sentence = re.sub("''", '"', sentence)
215
+
216
+ return sentence
217
+
218
+ def process_wiki_title(self, title):
219
+ title = re.sub("_", " ", title)
220
+ title = re.sub(" -LRB-", " ( ", title)
221
+ title = re.sub("-RRB-", " )", title)
222
+ title = re.sub("-COLON-", ":", title)
223
+ return title
224
+
225
+ def read_all(self, data):
226
+ if not isinstance(data, list):
227
+ with open(data) as f:
228
+ data_ = [json.loads(line) for line in f]
229
+ else:
230
+ data_ = data
231
+ inputs = list()
232
+ ids = list()
233
+ evi_list = list()
234
+ for instance in data_:
235
+ claim = instance['claim']
236
+ id = instance['id']
237
+ for evidence in instance['evidence']:
238
+ ids.append(id)
239
+ inputs.append([self.process_sent(claim), self.process_sent(evidence[2])])
240
+ evi_list.append(evidence)
241
+ return inputs, ids, evi_list
242
+
243
+ def shuffle(self):
244
+ np.random.shuffle(self.examples)
245
+
246
+ def __iter__(self):
247
+ return self
248
+
249
+ def __next__(self):
250
+ return self.next()
251
+
252
+ def __len__(self):
253
+ return self._n_batch
254
+
255
+ def next(self):
256
+ ''' Get the next batch '''
257
+ if self.step < self.total_step:
258
+ inputs = self.inputs[self.step * self.batch_size: (self.step + 1) * self.batch_size]
259
+ ids = self.ids[self.step * self.batch_size: (self.step + 1) * self.batch_size]
260
+ evi_list = self.evi_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]
261
+ inp, msk, seg = tok2int_list(inputs, self.tokenizer, self.max_len, -1)
262
+ inp_tensor_input = Variable(
263
+ torch.LongTensor(inp))
264
+ msk_tensor_input = Variable(
265
+ torch.LongTensor(msk))
266
+ seg_tensor_input = Variable(
267
+ torch.LongTensor(seg))
268
+ if self.cuda:
269
+ inp_tensor_input = inp_tensor_input.cuda()
270
+ msk_tensor_input = msk_tensor_input.cuda()
271
+ seg_tensor_input = seg_tensor_input.cuda()
272
+ self.step += 1
273
+ return inp_tensor_input, msk_tensor_input, seg_tensor_input, ids, evi_list
274
+ else:
275
+ self.step = 0
276
+ raise StopIteration()
src/er_client/retrieval_model/file_utils.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for working with the local dataset cache.
3
+ This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4
+ Copyright by the AllenNLP authors.
5
+ """
6
+ from __future__ import (absolute_import, division, print_function, unicode_literals)
7
+
8
+ import json
9
+ import logging
10
+ import os
11
+ import shutil
12
+ import tempfile
13
+ from functools import wraps
14
+ from hashlib import sha256
15
+ import sys
16
+ from io import open
17
+
18
+ import boto3
19
+ import requests
20
+ from botocore.exceptions import ClientError
21
+ from tqdm import tqdm
22
+
23
+ try:
24
+ from urllib.parse import urlparse
25
+ except ImportError:
26
+ from urlparse import urlparse
27
+
28
+ try:
29
+ from pathlib import Path
30
+ PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
31
+ Path.home() / '.pytorch_pretrained_bert'))
32
+ except AttributeError:
33
+ PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
34
+ os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
35
+
36
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
37
+
38
+
39
+ def url_to_filename(url, etag=None):
40
+ """
41
+ Convert `url` into a hashed filename in a repeatable way.
42
+ If `etag` is specified, append its hash to the url's, delimited
43
+ by a period.
44
+ """
45
+ url_bytes = url.encode('utf-8')
46
+ url_hash = sha256(url_bytes)
47
+ filename = url_hash.hexdigest()
48
+
49
+ if etag:
50
+ etag_bytes = etag.encode('utf-8')
51
+ etag_hash = sha256(etag_bytes)
52
+ filename += '.' + etag_hash.hexdigest()
53
+
54
+ return filename
55
+
56
+
57
+ def filename_to_url(filename, cache_dir=None):
58
+ """
59
+ Return the url and etag (which may be ``None``) stored for `filename`.
60
+ Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
61
+ """
62
+ if cache_dir is None:
63
+ cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
64
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
65
+ cache_dir = str(cache_dir)
66
+
67
+ cache_path = os.path.join(cache_dir, filename)
68
+ if not os.path.exists(cache_path):
69
+ raise EnvironmentError("file {} not found".format(cache_path))
70
+
71
+ meta_path = cache_path + '.json'
72
+ if not os.path.exists(meta_path):
73
+ raise EnvironmentError("file {} not found".format(meta_path))
74
+
75
+ with open(meta_path, encoding="utf-8") as meta_file:
76
+ metadata = json.load(meta_file)
77
+ url = metadata['url']
78
+ etag = metadata['etag']
79
+
80
+ return url, etag
81
+
82
+
83
+ def cached_path(url_or_filename, cache_dir=None):
84
+ """
85
+ Given something that might be a URL (or might be a local path),
86
+ determine which. If it's a URL, download the file and cache it, and
87
+ return the path to the cached file. If it's already a local path,
88
+ make sure the file exists and then return the path.
89
+ """
90
+ if cache_dir is None:
91
+ cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
92
+ if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
93
+ url_or_filename = str(url_or_filename)
94
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
95
+ cache_dir = str(cache_dir)
96
+
97
+ parsed = urlparse(url_or_filename)
98
+
99
+ if parsed.scheme in ('http', 'https', 's3'):
100
+ # URL, so get it from the cache (downloading if necessary)
101
+ return get_from_cache(url_or_filename, cache_dir)
102
+ elif os.path.exists(url_or_filename):
103
+ # File, and it exists.
104
+ return url_or_filename
105
+ elif parsed.scheme == '':
106
+ # File, but it doesn't exist.
107
+ raise EnvironmentError("file {} not found".format(url_or_filename))
108
+ else:
109
+ # Something unknown
110
+ raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
111
+
112
+
113
+ def split_s3_path(url):
114
+ """Split a full s3 path into the bucket name and path."""
115
+ parsed = urlparse(url)
116
+ if not parsed.netloc or not parsed.path:
117
+ raise ValueError("bad s3 path {}".format(url))
118
+ bucket_name = parsed.netloc
119
+ s3_path = parsed.path
120
+ # Remove '/' at beginning of path.
121
+ if s3_path.startswith("/"):
122
+ s3_path = s3_path[1:]
123
+ return bucket_name, s3_path
124
+
125
+
126
+ def s3_request(func):
127
+ """
128
+ Wrapper function for s3 requests in order to create more helpful error
129
+ messages.
130
+ """
131
+
132
+ @wraps(func)
133
+ def wrapper(url, *args, **kwargs):
134
+ try:
135
+ return func(url, *args, **kwargs)
136
+ except ClientError as exc:
137
+ if int(exc.response["Error"]["Code"]) == 404:
138
+ raise EnvironmentError("file {} not found".format(url))
139
+ else:
140
+ raise
141
+
142
+ return wrapper
143
+
144
+
145
+ @s3_request
146
+ def s3_etag(url):
147
+ """Check ETag on S3 object."""
148
+ s3_resource = boto3.resource("s3")
149
+ bucket_name, s3_path = split_s3_path(url)
150
+ s3_object = s3_resource.Object(bucket_name, s3_path)
151
+ return s3_object.e_tag
152
+
153
+
154
+ @s3_request
155
+ def s3_get(url, temp_file):
156
+ """Pull a file directly from S3."""
157
+ s3_resource = boto3.resource("s3")
158
+ bucket_name, s3_path = split_s3_path(url)
159
+ s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
160
+
161
+
162
+ def http_get(url, temp_file):
163
+ req = requests.get(url, stream=True)
164
+ content_length = req.headers.get('Content-Length')
165
+ total = int(content_length) if content_length is not None else None
166
+ progress = tqdm(unit="B", total=total)
167
+ for chunk in req.iter_content(chunk_size=1024):
168
+ if chunk: # filter out keep-alive new chunks
169
+ progress.update(len(chunk))
170
+ temp_file.write(chunk)
171
+ progress.close()
172
+
173
+
174
+ def get_from_cache(url, cache_dir=None):
175
+ """
176
+ Given a URL, look for the corresponding dataset in the local cache.
177
+ If it's not there, download it. Then return the path to the cached file.
178
+ """
179
+ if cache_dir is None:
180
+ cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
181
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
182
+ cache_dir = str(cache_dir)
183
+
184
+ if not os.path.exists(cache_dir):
185
+ os.makedirs(cache_dir)
186
+
187
+ # Get eTag to add to filename, if it exists.
188
+ if url.startswith("s3://"):
189
+ etag = s3_etag(url)
190
+ else:
191
+ response = requests.head(url, allow_redirects=True)
192
+ if response.status_code != 200:
193
+ raise IOError("HEAD request failed for url {} with status code {}"
194
+ .format(url, response.status_code))
195
+ etag = response.headers.get("ETag")
196
+
197
+ filename = url_to_filename(url, etag)
198
+
199
+ # get cache path to put the file
200
+ cache_path = os.path.join(cache_dir, filename)
201
+
202
+ if not os.path.exists(cache_path):
203
+ # Download to temporary file, then copy to cache dir once finished.
204
+ # Otherwise you get corrupt cache entries if the download gets interrupted.
205
+ with tempfile.NamedTemporaryFile() as temp_file:
206
+ logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
207
+
208
+ # GET file object
209
+ if url.startswith("s3://"):
210
+ s3_get(url, temp_file)
211
+ else:
212
+ http_get(url, temp_file)
213
+
214
+ # we are copying the file before closing it, so flush to avoid truncation
215
+ temp_file.flush()
216
+ # shutil.copyfileobj() starts at the current position, so go to the start
217
+ temp_file.seek(0)
218
+
219
+ logger.info("copying %s to cache at %s", temp_file.name, cache_path)
220
+ with open(cache_path, 'wb') as cache_file:
221
+ shutil.copyfileobj(temp_file, cache_file)
222
+
223
+ logger.info("creating metadata file for %s", cache_path)
224
+ meta = {'url': url, 'etag': etag}
225
+ meta_path = cache_path + '.json'
226
+ with open(meta_path, 'w', encoding="utf-8") as meta_file:
227
+ json.dump(meta, meta_file)
228
+
229
+ logger.info("removing temp file %s", temp_file.name)
230
+
231
+ return cache_path
232
+
233
+
234
+ def read_set_from_file(filename):
235
+ '''
236
+ Extract a de-duped collection (set) of text from a file.
237
+ Expected file format is one item per line.
238
+ '''
239
+ collection = set()
240
+ with open(filename, 'r', encoding='utf-8') as file_:
241
+ for line in file_:
242
+ collection.add(line.rstrip())
243
+ return collection
244
+
245
+
246
+ def get_file_extension(path, dot=True, lower=True):
247
+ ext = os.path.splitext(path)[1]
248
+ ext = ext if dot else ext[1:]
249
+ return ext.lower() if lower else ext
src/er_client/retrieval_model/models.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.nn import BatchNorm1d, Linear, ReLU
6
+ from .bert_model import BertForSequenceEncoder
7
+
8
+ from torch.nn import BatchNorm1d, Linear, ReLU
9
+ from .bert_model import BertForSequenceEncoder
10
+ from torch.autograd import Variable
11
+ import numpy as np
12
+
13
+
14
+
15
+
16
+ def kernal_mus(n_kernels):
17
+ """
18
+ get the mu for each guassian kernel. Mu is the middle of each bin
19
+ :param n_kernels: number of kernels (including exact match). first one is exact match
20
+ :return: l_mu, a list of mu.
21
+ """
22
+ l_mu = [1]
23
+ if n_kernels == 1:
24
+ return l_mu
25
+
26
+ bin_size = 2.0 / (n_kernels - 1) # score range from [-1, 1]
27
+ l_mu.append(1 - bin_size / 2) # mu: middle of the bin
28
+ for i in range(1, n_kernels - 1):
29
+ l_mu.append(l_mu[i] - bin_size)
30
+ return l_mu
31
+
32
+
33
+ def kernel_sigmas(n_kernels):
34
+ """
35
+ get sigmas for each guassian kernel.
36
+ :param n_kernels: number of kernels (including exactmath.)
37
+ :param lamb:
38
+ :param use_exact:
39
+ :return: l_sigma, a list of simga
40
+ """
41
+ bin_size = 2.0 / (n_kernels - 1)
42
+ l_sigma = [0.001] # for exact match. small variance -> exact match
43
+ if n_kernels == 1:
44
+ return l_sigma
45
+
46
+ l_sigma += [0.1] * (n_kernels - 1)
47
+ return l_sigma
48
+
49
+ class inference_model(nn.Module):
50
+ def __init__(self, bert_model, args):
51
+ super(inference_model, self).__init__()
52
+ self.bert_hidden_dim = args.bert_hidden_dim
53
+ self.dropout = nn.Dropout(args.dropout)
54
+ self.max_len = args.max_len
55
+ self.num_labels = args.num_labels
56
+ self.pred_model = bert_model
57
+ #self.proj_hidden = nn.Linear(self.bert_hidden_dim, 128)
58
+ self.proj_match = nn.Linear(self.bert_hidden_dim, 1)
59
+
60
+
61
+ def forward(self, inp_tensor, msk_tensor, seg_tensor):
62
+ _, inputs = self.pred_model(inp_tensor, msk_tensor, seg_tensor)
63
+ inputs = self.dropout(inputs)
64
+ score = self.proj_match(inputs).squeeze(-1)
65
+ score = torch.tanh(score)
66
+ return score
src/er_client/retrieval_model/process_data.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import argparse
4
+
5
+ if __name__ == "__main__":
6
+ parser = argparse.ArgumentParser()
7
+ parser.add_argument('--gold_file')
8
+ parser.add_argument('--retrieval_file')
9
+ parser.add_argument('--output')
10
+ parser.add_argument('--test', action='store_true', default=False)
11
+ args = parser.parse_args()
12
+ filter_dict = dict()
13
+ data_dict = dict()
14
+ golden_dict = dict()
15
+ with open(args.gold_file) as f:
16
+ for line in f:
17
+ data = json.loads(line)
18
+ data_dict[data["id"]] = {"id": data["id"], "evidence":[], "claim": data["claim"]}
19
+ if "label" in data:
20
+ data_dict[data["id"]]["label"] = data["label"]
21
+ if not args.test:
22
+ for evidence in data["evidence"]:
23
+ data_dict[data["id"]]["evidence"].append([evidence[0], evidence[1], evidence[2], 1.0])
24
+ string = str(data["id"]) + "_" + evidence[0] + "_" + str(evidence[1])
25
+ golden_dict[string] = 1
26
+ with open(args.retrieval_file) as f:
27
+ for line in f:
28
+ data = json.loads(line)
29
+ for step, evidence in enumerate(data["evidence"]):
30
+ string = str(data["id"]) + "_" + str(evidence[0]) + "_" + str(evidence[1])
31
+ if string not in golden_dict and string not in filter_dict:
32
+ data_dict[data["id"]]["evidence"].append([evidence[0], evidence[1], evidence[2], evidence[4]])
33
+ filter_dict[string] = 1
34
+ with open(args.output, "w") as out:
35
+ for data in data_dict.values():
36
+ evidence_tmp = data["evidence"]
37
+ evidence_tmp = sorted(evidence_tmp, key=lambda x:x[3], reverse=True)
38
+ data["evidence"] = evidence_tmp[:5]
39
+ out.write(json.dumps(data) + "\n")
40
+
41
+
src/er_client/retrieval_model/test.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import argparse
3
+ import os
4
+ import json
5
+ import torch
6
+ from tqdm import tqdm
7
+ from transformers import BertTokenizer
8
+
9
+ from .models import inference_model
10
+ from .data_loader import DataLoaderTest
11
+ from .bert_model import BertForSequenceEncoder
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ def save_to_file(all_predict, outpath, evi_num):
17
+ with open(outpath, "w") as out:
18
+ for key, values in all_predict.items():
19
+ sorted_values = sorted(values, key=lambda x:x[-1], reverse=True)
20
+ data = json.dumps({"id": key, "evidence": sorted_values[:evi_num]})
21
+ out.write(data + "\n")
22
+
23
+
24
+ def eval_model(model, validset_reader):
25
+ model.eval()
26
+ all_predict = dict()
27
+ for inp_tensor, msk_tensor, seg_tensor, ids, evi_list in tqdm(validset_reader):
28
+ probs = model(inp_tensor, msk_tensor, seg_tensor)
29
+ probs = probs.tolist()
30
+ assert len(probs) == len(evi_list)
31
+ for i in range(len(probs)):
32
+ if ids[i] not in all_predict:
33
+ all_predict[ids[i]] = []
34
+ #if probs[i][1] >= probs[i][0]:
35
+ all_predict[ids[i]].append(evi_list[i] + [probs[i]])
36
+ return all_predict
37
+
38
+
39
+ if __name__ == "__main__":
40
+ parser = argparse.ArgumentParser()
41
+ parser.add_argument('--test_path', help='train path')
42
+ parser.add_argument('--name', help='train path')
43
+ parser.add_argument("--batch_size", default=32, type=int, help="Total batch size for training.")
44
+ parser.add_argument('--outdir', required=True, help='path to output directory')
45
+ parser.add_argument('--bert_pretrain', required=True)
46
+ parser.add_argument('--checkpoint', required=True)
47
+ parser.add_argument('--dropout', type=float, default=0.6, help='Dropout.')
48
+ parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.')
49
+ parser.add_argument("--bert_hidden_dim", default=768, type=int, help="Total batch size for training.")
50
+ parser.add_argument("--layer", type=int, default=1, help='Graph Layer.')
51
+ parser.add_argument("--num_labels", type=int, default=3)
52
+ parser.add_argument("--evi_num", type=int, default=5, help='Evidence num.')
53
+ parser.add_argument("--threshold", type=float, default=0.0, help='Evidence num.')
54
+ parser.add_argument("--max_len", default=120, type=int,
55
+ help="The maximum total input sequence length after WordPiece tokenization. Sequences "
56
+ "longer than this will be truncated, and sequences shorter than this will be padded.")
57
+ args = parser.parse_args()
58
+
59
+ if not os.path.exists(args.outdir):
60
+ os.mkdir(args.outdir)
61
+ args.cuda = not args.no_cuda and torch.cuda.is_available()
62
+ handlers = [logging.FileHandler(os.path.abspath(args.outdir) + '/train_log.txt'), logging.StreamHandler()]
63
+ logging.basicConfig(format='[%(asctime)s] %(levelname)s: %(message)s', level=logging.DEBUG,
64
+ datefmt='%d-%m-%Y %H:%M:%S', handlers=handlers)
65
+ logger.info(args)
66
+ logger.info('Start training!')
67
+
68
+ tokenizer = BertTokenizer.from_pretrained(args.bert_pretrain, do_lower_case=False)
69
+ logger.info("loading training set")
70
+ validset_reader = DataLoaderTest(args.test_path, tokenizer, args, batch_size=args.batch_size)
71
+
72
+ logger.info('initializing estimator model')
73
+ bert_model = BertForSequenceEncoder.from_pretrained(args.bert_pretrain)
74
+ bert_model = bert_model.cuda()
75
+ model = inference_model(bert_model, args)
76
+ model.load_state_dict(torch.load(args.checkpoint)['model'])
77
+ model = model.cuda()
78
+ logger.info('Start eval!')
79
+ save_path = args.outdir + "/" + args.name
80
+ predict_dict = eval_model(model, validset_reader)
81
+ save_to_file(predict_dict, save_path, args.evi_num)
src/er_client/retrieval_model/test.sh ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ python test.py \
2
+ --test_path ../data/pages.json \
3
+ --bert_pretrain ../evidence_retrieval/bert_base \
4
+ --checkpoint ../evidence_retrieval/retrieval_model/model.best.pt \
5
+ --evi_num 5 \
6
+ --outdir ../data \
7
+ --name evidence.json
src/er_client/sentence_selection.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ @Author : Jiangjie Chen
5
+ @Time : 2020/9/20 11:42
6
+ @Contact : jjchen19@fudan.edu.cn
7
+ @Description:
8
+ '''
9
+
10
+ import torch
11
+ from transformers import BertTokenizer
12
+ from .retrieval_model.bert_model import BertForSequenceEncoder
13
+ from .retrieval_model.models import inference_model
14
+ from .retrieval_model.data_loader import DataLoaderTest
15
+
16
+
17
+ class SentSelector:
18
+ def __init__(self, pretrained_bert_path, select_model_path, args):
19
+ self.args = args
20
+ self.use_cuda = self.args.use_cuda and torch.cuda.is_available()
21
+
22
+ self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
23
+ self.bert_model = BertForSequenceEncoder.from_pretrained(pretrained_bert_path)
24
+
25
+ self.rank_model = inference_model(self.bert_model, self.args)
26
+ self.rank_model.load_state_dict(torch.load(select_model_path)['model'])
27
+
28
+ if self.use_cuda:
29
+ self.bert_model = self.bert_model.cuda()
30
+ self.rank_model.cuda()
31
+
32
+ def rank_sentences(self, js: list):
33
+ '''
34
+ :param js: [{'claim': xxx, 'id': xx, 'evidence': xxx}]
35
+ :return: [(ent, num, sent, prob), (ent, num, sent, prob)]
36
+ '''
37
+ data_reader = DataLoaderTest(js, self.tokenizer, self.args, self.use_cuda)
38
+ self.rank_model.eval()
39
+ all_predict = dict()
40
+ for inp_tensor, msk_tensor, seg_tensor, ids, evi_list in data_reader:
41
+ probs = self.rank_model(inp_tensor, msk_tensor, seg_tensor)
42
+ probs = probs.tolist()
43
+ assert len(probs) == len(evi_list)
44
+ for i in range(len(probs)):
45
+ if ids[i] not in all_predict:
46
+ all_predict[ids[i]] = []
47
+ # if probs[i][1] >= probs[i][0]:
48
+ all_predict[ids[i]].append(tuple(evi_list[i]) + (probs[i],))
49
+
50
+ results = {}
51
+ for k, v in all_predict.items():
52
+ sorted_v = sorted(v, key=lambda x: x[-1], reverse=True)
53
+ results[k] = sorted_v[:self.args.evi_num]
54
+ return results
src/eval_client/cjjpy.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ @Author : Jiangjie Chen
5
+ @Time : 2018/11/15 17:08
6
+ @Contact: jjchen19@fudan.edu.cn
7
+ '''
8
+
9
+ import re
10
+ import datetime
11
+ import os
12
+ import argparse
13
+ import logging
14
+ import traceback
15
+
16
+ try:
17
+ import ujson as json
18
+ except:
19
+ import json
20
+
21
+ HADOOP_BIN = 'PATH=/usr/bin/:$PATH hdfs'
22
+ FOR_PUBLIC = True
23
+
24
+
25
+ def LengthStats(filename):
26
+ len_list = []
27
+ thresholds = [0.8, 0.9, 0.95, 0.99, 0.999]
28
+ with open(filename) as f:
29
+ for line in f:
30
+ len_list.append(len(line.strip().split()))
31
+ stats = {
32
+ 'Max': max(len_list),
33
+ 'Min': min(len_list),
34
+ 'Avg': round(sum(len_list) / len(len_list), 4),
35
+ }
36
+ len_list.sort()
37
+ for t in thresholds:
38
+ stats[f"Top-{t}"] = len_list[int(len(len_list) * t)]
39
+
40
+ for k in stats:
41
+ print(f"- {k}: {stats[k]}")
42
+ return stats
43
+
44
+
45
+ class AttrDict(dict):
46
+ def __init__(self, *args, **kwargs):
47
+ super(AttrDict, self).__init__(*args, **kwargs)
48
+ self.__dict__ = self
49
+
50
+
51
+ def TraceBack(error_msg):
52
+ exc = traceback.format_exc()
53
+ msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}'
54
+ return msg
55
+
56
+
57
+ def Now():
58
+ return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
59
+
60
+
61
+ def AbsParentDir(file, parent='..', postfix=None):
62
+ ppath = os.path.abspath(file)
63
+ parent_level = parent.count('.')
64
+ while parent_level > 0:
65
+ ppath = os.path.dirname(ppath)
66
+ parent_level -= 1
67
+ if postfix is not None:
68
+ return os.path.join(ppath, postfix)
69
+ else:
70
+ return ppath
71
+
72
+
73
+ def init_logger(log_file=None, log_file_level=logging.NOTSET, from_scratch=False):
74
+ from coloredlogs import ColoredFormatter
75
+ import tensorflow as tf
76
+
77
+ fmt = "[%(asctime)s %(levelname)s] %(message)s"
78
+ log_format = ColoredFormatter(fmt=fmt)
79
+ # log_format = logging.Formatter()
80
+ logger = logging.getLogger()
81
+ logger.setLevel(log_file_level)
82
+
83
+ console_handler = logging.StreamHandler()
84
+ console_handler.setFormatter(log_format)
85
+ logger.handlers = [console_handler]
86
+
87
+ if log_file and log_file != '':
88
+ if from_scratch and tf.io.gfile.exists(log_file):
89
+ logger.warning('Removing previous log file: %s' % log_file)
90
+ tf.io.gfile.remove(log_file)
91
+ path = os.path.dirname(log_file)
92
+ os.makedirs(path, exist_ok=True)
93
+ file_handler = logging.FileHandler(log_file)
94
+ file_handler.setLevel(log_file_level)
95
+ file_handler.setFormatter(log_format)
96
+ logger.addHandler(file_handler)
97
+
98
+ return logger
99
+
100
+
101
+ def OverWriteCjjPy(root='.'):
102
+ # import difflib
103
+ # diff = difflib.HtmlDiff()
104
+ cnt = 0
105
+ golden_cjjpy = os.path.join(root, 'cjjpy.py')
106
+ # golden_content = open(golden_cjjpy).readlines()
107
+ for dir, folder, file in os.walk(root):
108
+ for f in file:
109
+ if f == 'cjjpy.py':
110
+ cjjpy = '%s/%s' % (dir, f)
111
+ # content = open(cjjpy).readlines()
112
+ # d = diff.make_file(golden_content, content)
113
+ cnt += 1
114
+ print('[%d]: %s' % (cnt, cjjpy))
115
+ os.system('cp %s %s' % (golden_cjjpy, cjjpy))
116
+
117
+
118
+ def ChangeFileFormat(filename, new_fmt):
119
+ assert type(filename) is str and type(new_fmt) is str
120
+ spt = filename.split('.')
121
+ if len(spt) == 0:
122
+ return filename
123
+ else:
124
+ return filename.replace('.' + spt[-1], new_fmt)
125
+
126
+
127
+ def CountLines(fname):
128
+ with open(fname, 'rb') as f:
129
+ count = 0
130
+ last_data = '\n'
131
+ while True:
132
+ data = f.read(0x400000)
133
+ if not data:
134
+ break
135
+ count += data.count(b'\n')
136
+ last_data = data
137
+ if last_data[-1:] != b'\n':
138
+ count += 1 # Remove this if a wc-like count is needed
139
+ return count
140
+
141
+
142
+ def GetDate():
143
+ return str(datetime.datetime.now())[5:10].replace('-', '')
144
+
145
+
146
+ def TimeClock(seconds):
147
+ sec = int(seconds)
148
+ hour = int(sec / 3600)
149
+ min = int((sec - hour * 3600) / 60)
150
+ ssec = float(seconds) - hour * 3600 - min * 60
151
+ # return '%dh %dm %.2fs' % (hour, min, ssec)
152
+ return '{0:>2d}h{1:>3d}m{2:>6.2f}s'.format(hour, min, ssec)
153
+
154
+
155
+ def StripAll(text):
156
+ return text.strip().replace('\t', '').replace('\n', '').replace(' ', '')
157
+
158
+
159
+ def GetBracket(text, bracket, en_br=False):
160
+ # input should be aa(bb)cc, True for bracket, False for text
161
+ if bracket:
162
+ try:
163
+ return re.findall('\((.*?)\)', text.strip())[-1]
164
+ except:
165
+ return ''
166
+ else:
167
+ if en_br:
168
+ text = re.sub('\(.*?\)', '', text.strip())
169
+ return re.sub('(.*?)', '', text.strip())
170
+
171
+
172
+ def CharLang(uchar, lang):
173
+ assert lang.lower() in ['en', 'cn', 'zh']
174
+ if lang.lower() in ['cn', 'zh']:
175
+ if uchar >= '\u4e00' and uchar <= '\u9fa5':
176
+ return True
177
+ else:
178
+ return False
179
+ elif lang.lower() == 'en':
180
+ if (uchar <= 'Z' and uchar >= 'A') or (uchar <= 'z' and uchar >= 'a'):
181
+ return True
182
+ else:
183
+ return False
184
+ else:
185
+ raise NotImplementedError
186
+
187
+
188
+ def WordLang(word, lang):
189
+ for i in word.strip():
190
+ if i.isspace(): continue
191
+ if not CharLang(i, lang):
192
+ return False
193
+ return True
194
+
195
+
196
+ def SortDict(_dict, reverse=True):
197
+ assert type(_dict) is dict
198
+ return sorted(_dict.items(), key=lambda d: d[1], reverse=reverse)
199
+
200
+
201
+ def lark(content='test'):
202
+ print(content)
203
+
204
+
205
+ if __name__ == '__main__':
206
+ parser = argparse.ArgumentParser()
207
+
208
+ parser.add_argument('--diff', nargs=2,
209
+ help='show difference between two files, shown in downloads/diff.html')
210
+ parser.add_argument('--de_unicode', action='store_true', default=False,
211
+ help='remove unicode characters')
212
+ parser.add_argument('--link_entity', action='store_true', default=False,
213
+ help='')
214
+ parser.add_argument('--max_comm_len', action='store_true', default=False,
215
+ help='')
216
+ parser.add_argument('--search', nargs=2,
217
+ help='search key from file, 2 args: file name & key')
218
+ parser.add_argument('--email', nargs=2,
219
+ help='sending emails, 2 args: subject & content')
220
+ parser.add_argument('--overwrite', action='store_true', default=None,
221
+ help='overwrite all cjjpy under given *dir* based on *dir*/cjjpy.py')
222
+ parser.add_argument('--replace', nargs=3,
223
+ help='replace char, 3 args: file name & replaced char & replacer char')
224
+ parser.add_argument('--lark', nargs=1)
225
+ parser.add_argument('--get_hdfs', nargs=2,
226
+ help='easy copy from hdfs to local fs, 2 args: remote_file/dir & local_dir')
227
+ parser.add_argument('--put_hdfs', nargs=2,
228
+ help='easy put from local fs to hdfs, 2 args: local_file/dir & remote_dir')
229
+ parser.add_argument('--length_stats', nargs=1,
230
+ help='simple token lengths distribution of a line-by-line file')
231
+
232
+ args = parser.parse_args()
233
+
234
+ if args.overwrite:
235
+ print('* Overwriting cjjpy...')
236
+ OverWriteCjjPy()
237
+
238
+ if args.lark:
239
+ try:
240
+ content = args.lark[0]
241
+ except:
242
+ content = 'running complete'
243
+ print(f'* Larking "{content}"...')
244
+ lark(content)
245
+
246
+ if args.length_stats:
247
+ file = args.length_stats[0]
248
+ print(f'* Working on {file} lengths statistics...')
249
+ LengthStats(file)
src/eval_client/culpa.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding:utf-8 -*-
2
+
3
+ """
4
+ @Author : Bao
5
+ @Date : 2021/9/7
6
+ @Desc :
7
+ @Last modified by : Bao
8
+ @Last modified date : 2021/9/7
9
+ """
10
+
11
+ import json
12
+ import numpy as np
13
+ import argparse
14
+ from collections import defaultdict
15
+ from sklearn.metrics import precision_recall_fscore_support
16
+
17
+ # ref --> label 1, nei & sup --> label 0
18
+ idx2label = {0: 1, 1: 0, 2: 0}
19
+
20
+
21
+ def read_json_lines(filename, mode='r', encoding='utf-8', skip=0):
22
+ with open(filename, mode, encoding=encoding) as fin:
23
+ for line in fin:
24
+ if skip > 0:
25
+ skip -= 1
26
+ continue
27
+ yield json.loads(line)
28
+
29
+
30
+ def process(filein):
31
+ id2info = defaultdict(dict)
32
+ for line in read_json_lines('eval.human.ref.merged.json'):
33
+ labels = [0] * len(line['questions'])
34
+ for cul in line['culprit']:
35
+ labels[cul] = 1
36
+ id2info[line['id']].update({'id': line['id'], 'labels': labels})
37
+
38
+ for line in read_json_lines(filein):
39
+ if line['id'] not in id2info: continue
40
+ predicted = [idx2label[_] for _ in np.argmax(line['z_prob'], axis=-1)]
41
+ id2info[line['id']]['predicted'] = predicted
42
+
43
+ ps, rs, fs = [], [], []
44
+ for info in id2info.values():
45
+ p, r, f, _ = precision_recall_fscore_support(info['labels'], info['predicted'], average='binary')
46
+ ps.append(p)
47
+ rs.append(r)
48
+ fs.append(f)
49
+ print(filein)
50
+ print('Precision: {}'.format(sum(ps) / len(ps)))
51
+ print('Recall: {}'.format(sum(rs) / len(rs)))
52
+ print('F1: {}'.format(sum(fs) / len(fs)))
53
+
54
+ return sum(ps) / len(ps), sum(rs) / len(rs), sum(fs) / len(fs)
55
+
56
+
57
+ if __name__ == '__main__':
58
+ parser = argparse.ArgumentParser()
59
+ parser.add_argument('-i', type=str, help='predicted jsonl file with phrasal veracity predictions.')
60
+ args = parser.parse_args()
61
+ process(args.i)
src/eval_client/culprit/eval.human.ref.json ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"id": 102600, "claim": "Sausage Party was released in May of 2016 .", "questions": ["What was the name of the new album released in May 2016? or <mask> was released in May of 2016 .", "When was Sausage Party released? or Sausage Party was released in <mask> of 2016 .", "When was Sausage Party released? or Sausage Party was released in May of <mask> .", "What was Sausage Party's release date? or Sausage Party was <mask> in May of 2016 ."], "answers": [["Sausage Party", 0, 13], ["May", 30, 33], ["2016", 37, 41], ["released", 18, 26]], "evidential": [["Sausage Party", "The Sausage Party", "A Sausage Party", "Sausage party"], ["August", "the summer", "March", "the fall"], ["2016", "the year 2016", "March 2016", "2015"], ["released", "announced", "premiered", "released domestically"]], "culprit": [1]}
2
+ {"id": 92833, "claim": "Anne Boleyn did not live in England in 1522 .", "questions": ["Who did not live in England in 1522? or <mask> did not live in England in 1522 .", "Where did Anne Boleyn live in 1522? or Anne Boleyn did not live in <mask> in 1522 .", "When did Anne Boleyn not live in England? or Anne Boleyn did not live in England in <mask> .", "What did Anne Boleyn not do in England? or Anne Boleyn did not <mask> in England in 1522 ."], "answers": [["Anne Boleyn", 0, 11], ["England", 28, 35], ["1522", 39, 43], ["live", 20, 24]], "evidential": [["Anne Boleyn", "Ann Boleyn", "Anne Bolyn", "A woman"], ["England", "Europe", "the world", "the UK"], ["1532", "1536", "1533", "1534"], ["live", "stay", "marry", "reside"]], "culprit": [1, 2]}
3
+ {"id": 159707, "claim": "Edgar Wright is only a producer .", "questions": ["Who is the only producer? or <mask> is only a producer .", "What is Edgar Wright's job title? or Edgar Wright is <mask> ."], "answers": [["Edgar Wright", 0, 12], ["only a producer", 16, 31]], "evidential": [["Edgar Wright", "Edgar Wright Jr.", "Edgar W. Wright", "Edgar Wayne Wright"], ["a producer", "a director", "a screenwriter", "a film producer"]], "culprit": [1]}
4
+ {"id": 146055, "claim": "The Giver is a bill .", "questions": ["What is a bill called? or <mask> is a bill .", "What is the Giver? or The Giver is <mask> ."], "answers": [["The Giver", 0, 9], ["a bill", 13, 19]], "evidential": [["The Giver", "A The Giver", "The giver", "The Giver Act"], ["a film", "a work", "a motion picture", "a movie"]], "culprit": [1]}
5
+ {"id": 8443, "claim": "A Milli is by Justin Bieber .", "questions": ["What is the name of Justin Bieber's song? or <mask> is by Justin Bieber .", "Who is A Milli by? or A Milli is by <mask> ."], "answers": [["A Milli", 0, 7], ["Justin Bieber", 14, 27]], "evidential": [["A Milli", "A Milli song", "A Milli Song", "A Milli."], ["Justin Bieber", "a Justin Bieber", "an artist", "a musician"]], "culprit": [1]}
6
+ {"id": 67833, "claim": "Shane McMahon did not win the Hardcore Championship once .", "questions": ["Who won the Hardcore Championship once? or <mask> did not win the Hardcore Championship once .", "What did Shane McMahon not win once? or Shane McMahon did not win <mask> once .", "What did Shane McMahon not do once? or Shane McMahon did not <mask> the Hardcore Championship once ."], "answers": [["Shane McMahon", 0, 13], ["the Hardcore Championship", 26, 51], ["win", 22, 25]], "evidential": [["Shane McMahon", "Shane McMahon", "Shane McMah", "Shane McMahon ("], ["the European Championship", "the Hardcore Championship", "a wrestling championship", "a championship"], ["win", "won", "achieve", "earn"]], "culprit": [1]}
7
+ {"id": 116789, "claim": "Minor League Baseball is a hierarchy of only amateur baseball leagues .", "questions": ["What is the name of the only amateur baseball league? or <mask> is a hierarchy of only amateur baseball leagues .", "What is Minor League Baseball? or Minor League Baseball is <mask> of only amateur baseball leagues .", "What is Minor League Baseball a hierarchy of? or Minor League Baseball is a hierarchy of <mask> ."], "answers": [["Minor League Baseball", 0, 21], ["a hierarchy", 25, 36], ["only amateur baseball leagues", 40, 69]], "evidential": [["Minor League Baseball", "The Minor League Baseball", "Minor league Baseball", "Major League Baseball"], ["a hierarchy", "an organization", "a system", "a structure"], ["professional baseball leagues", "minor league baseball", "professional baseball teams", "baseball leagues"]], "culprit": [2]}
8
+ {"id": 12454, "claim": "Tangled is a silent film .", "questions": ["What is the name of the film that is a silent film? or <mask> is a silent film .", "What type of film is Tangled? or Tangled is <mask> ."], "answers": [["Tangled", 0, 7], ["a silent film", 11, 24]], "evidential": [["Tangled", "Tangles", "Tangled (", "Tangling"], ["an animated film", "a musical fantasy film", "a fantasy film", "a film"]], "culprit": [1]}
9
+ {"id": 149501, "claim": "Kung Fu Panda was number three at the box office .", "questions": ["What movie was number three at the box office? or <mask> was number three at the box office .", "What was Kung Fu Panda's box office number? or Kung Fu Panda was <mask> at the box office .", "Where was Kung Fu Panda number three? or Kung Fu Panda was number three at <mask> ."], "answers": [["Kung Fu Panda", 0, 13], ["number three", 18, 30], ["the box office", 34, 48]], "evidential": [["Kung Fu Panda", "Kung fu Panda", "Kung F Panda", "Kungfu Panda"], ["the number one", "number one", "the number one movie", "the number one film"], ["the box office", "the movie box office", "the US box office", "a box office"]], "culprit": [1]}
10
+ {"id": 51962, "claim": "Mandy Moore is a Canadian film actress .", "questions": ["Who is the name of the Canadian film actress? or <mask> is a Canadian film actress .", "What nationality is Mandy Moore? or Mandy Moore is <mask> film actress .", "What is Mandy Moore's career? or Mandy Moore is a Canadian <mask> ."], "answers": [["Mandy Moore", 0, 11], ["a Canadian", 15, 25], ["film actress", 26, 38]], "evidential": [["Mandy Moore", "Mandy Moore ( choreographer )", "Mandy Moore ( dancer )", "Mandy Moore( choreographer )"], ["an American", "an american", "a North American", "an North American"], ["actress", "film actress", "actor", "singer"]], "culprit": [1]}
11
+ {"id": 217102, "claim": "Innovation is viewed as the application of better solutions that negate market needs .", "questions": ["What is viewed as the application of better solutions that negate market needs? or <mask> is viewed as the application of better solutions that negate market needs .", "Innovation is viewed as what? or Innovation is viewed as <mask> of better solutions that negate market needs .", "Innovation is viewed as the application of what? or Innovation is viewed as the application of <mask> that negate market needs .", "Innovation is viewed as the application of better solutions that negate what? or Innovation is viewed as the application of better solutions that negate <mask> .", "What is innovation <mask> as? or Innovation is <mask> as the application of better solutions that negate market needs .", "Innovation is viewed as the application of better solutions that do what to market needs? or Innovation is viewed as the application of better solutions that <mask> market needs ."], "answers": [["Innovation", 0, 10], ["the application", 24, 39], ["better solutions", 43, 59], ["market needs", 72, 84], ["viewed", 14, 20], ["negate", 65, 71]], "evidential": [["Innovation", "Technology innovation", "Insulin", "In innovation"], ["the application", "an application", "a application", "the applications"], ["solutions", "new solutions", "better solutions", "products"], ["new requirements", "existing market needs", "existing market requirements", "existing requirements"], ["viewed", "perceived", "characterized", "described"], ["meet", "meet existing", "meet current", "met"]], "culprit": [5]}
12
+ {"id": 202314, "claim": "The New Jersey Turnpike has zero shoulders .", "questions": ["What has zero shoulders? or <mask> has zero shoulders .", "What is the total length of the New Jersey Turnpike? or The New Jersey Turnpike has <mask> ."], "answers": [["The New Jersey Turnpike", 0, 23], ["zero shoulders", 28, 42]], "evidential": [["The New Jersey Turnpike", "New Jersey Turnpike", "A New Jersey Turnpike", "the New Jersey Turnpike"], ["12 ft lanes", "a total length", "12 feet long", "12 feet"]], "culprit": [1]}
13
+ {"id": 226106, "claim": "Bongwater is set outside of Oregon .", "questions": ["What is the name of the town outside of Oregon? or <mask> is set outside of Oregon .", "What state is Bongwater located outside of? or Bongwater is set outside of <mask> .", "Where is Bongwater located outside of Oregon? or Bongwater is <mask> outside of Oregon .", "Where is Bongwater located? or Bongwater is set <mask> of Oregon ."], "answers": [["Bongwater", 0, 9], ["Oregon", 28, 34], ["set", 13, 16], ["outside", 17, 24]], "evidential": [["Bongwater", "The film Bongwater", "Bongwwater", "Bongswater"], ["Oregon", "a state", "Washington State", "the Oregon"], ["set", "located", "filmed", "based"], ["the state", "outside", "the city", "the coast"]], "culprit": [3]}
14
+ {"id": 182051, "claim": "The Fly was first released in 1999 .", "questions": ["What was the name of the first film released in 1999? or <mask> was first released in 1999 .", "When was The Fly first released? or The Fly was first released in <mask> .", "When was The Fly first <mask>? or The Fly was first <mask> in 1999 .", "When was The Fly released? or The Fly was <mask> released in 1999 ."], "answers": [["The Fly", 0, 7], ["1999", 30, 34], ["released", 18, 26], ["first", 12, 17]], "evidential": [["The Fly", "The Fly 's", "A film The Fly", "The fly"], ["August 1986", "1986", "the 1980s", "the eighties"], ["released", "published", "distributed", "release"], ["first", "originally", "last", "only"]], "culprit": [1]}
15
+ {"id": 65598, "claim": "Uganda was not ruled by the British .", "questions": ["What country was not ruled by the British? or <mask> was not ruled by the British .", "Who ruled Uganda? or Uganda was not ruled by <mask> .", "What was Uganda not <mask> by the British? or Uganda was not <mask> by the British ."], "answers": [["Uganda", 0, 6], ["the British", 24, 35], ["ruled", 15, 20]], "evidential": [["Uganda", "Uganda", "Ugandan", "Uganda"], ["the British", "Britain", "a colony", "British"], ["ruled", "controlled", "governed", "owned"]], "culprit": [1]}
16
+ {"id": 117126, "claim": "Pocahontas was not the daughter of Powhatan .", "questions": ["Who was not the daughter of Powhatan? or <mask> was not the daughter of Powhatan .", "What was Pocahontas' mother's name? or Pocahontas was not <mask> of Powhatan .", "Who was Pocahontas' father? or Pocahontas was not the daughter of <mask> ."], "answers": [["Pocahontas", 0, 10], ["the daughter", 19, 31], ["Powhatan", 35, 43]], "evidential": [["Pocahontas", "Pocahonta", "Pocahontas n't", "Pocahontas Jr."], ["the daughter", "a daughter", "the granddaughter", "the child"], ["Powhatan", "a Native American", "a chief", "a person"]], "culprit": [1, 2]}
17
+ {"id": 164506, "claim": "The Nobel Prize in Chemistry was awarded to a person from anywhere except the Netherlands .", "questions": ["What award was given to a person from anywhere except the Netherlands? or <mask> was awarded to a person from anywhere except the Netherlands .", "Who was the Nobel Prize in Chemistry awarded to? or The Nobel Prize in Chemistry was awarded to <mask> from anywhere except the Netherlands .", "Where was the Nobel Prize in Chemistry awarded to? or The Nobel Prize in Chemistry was awarded to a person from <mask> except the Netherlands .", "Where was the Nobel Prize in Chemistry awarded to? or The Nobel Prize in Chemistry was awarded to a person from anywhere except <mask> .", "How is the Nobel Prize in Chemistry <mask>? or The Nobel Prize in Chemistry was <mask> to a person from anywhere except the Netherlands ."], "answers": [["The Nobel Prize in Chemistry", 0, 28], ["a person", 44, 52], ["anywhere", 58, 66], ["the Netherlands", 74, 89], ["awarded", 33, 40]], "evidential": [["The Nobel Prize in Chemistry", "A Nobel Prize in Chemistry", "Nobel Prize in Chemistry", "The Nobel prize in Chemistry"], ["scientists", "a scientist", "people", "anyone"], ["every country", "every state", "every place", "all"], ["the Netherlands", "Sweden", "Europe", "Norway"], ["awarded", "given", "presented", "distributed"]], "culprit": [2, 3]}
18
+ {"id": 113010, "claim": "Duane Chapman is not a former bail bondsman .", "questions": ["Who is not a former bail bondsman? or <mask> is not a former bail bondsman .", "What is Duane Chapman's profession? or Duane Chapman is not <mask> ."], "answers": [["Duane Chapman", 0, 13], ["a former bail bondsman", 21, 43]], "evidential": [["Duane Chapman", "Duane ChapmanI.", "Duane Chapman I.", "Duane Chapman II."], ["a bail bondsman", "a bounty hunter", "a former bail bondsman", "a bail bondsman"]], "culprit": [1]}
19
+ {"id": 109582, "claim": "US Airways Flight 1549 did not have any people on board .", "questions": ["What flight did not have any people on board? or <mask> did not have any people on board .", "What was not on board the US Airways Flight 1549? or US Airways Flight 1549 did not have any <mask> on board .", "What was the name of the aircraft that did not have any people on? or US Airways Flight 1549 did not have any people on <mask> ."], "answers": [["US Airways Flight 1549", 0, 22], ["people", 40, 46], ["board", 50, 55]], "evidential": [["US Airways Flight 1549", "The US Airways Flight 1549", "American Airways Flight 1549", "United Airways Flight 1549"], ["people", "passengers", "humans", "birds"], ["an aircraft", "an Airbus A320", "the Airbus A320", "an airliner"]], "culprit": [2]}
20
+ {"id": 23766, "claim": "Charles de Gaulle was a Polish Resistance leader .", "questions": ["Who was the leader of the Polish Resistance? or <mask> was a Polish Resistance leader .", "What nationality was Charles de Gaulle? or Charles de Gaulle was <mask> Resistance leader .", "What political party was Charles de Gaulle a leader of? or Charles de Gaulle was a Polish <mask> leader .", "What was Charles de Gaulle's role in the Polish Resistance? or Charles de Gaulle was a Polish Resistance <mask> ."], "answers": [["Charles de Gaulle", 0, 17], ["a Polish", 22, 30], ["Resistance", 31, 41], ["leader", 42, 48]], "evidential": [["Charles de Gaulle", "Charles De Gaulle", "Charles de Gaulle", "Louis de Gaulle"], ["a French", "an American", "the French", "an English"], ["French", "Nationalist", "Communist", "National Socialist"], ["leader", "chief leader", "person", "chief strategist"]], "culprit": [1]}
21
+ {"id": 94556, "claim": "Pirates of the Caribbean has yet to be opened in Disneyland Paris .", "questions": ["What is the name of the movie that has yet to be opened at Disneyland Paris? or <mask> has yet to be opened in Disneyland Paris .", "Where is Pirates of the Caribbean currently located? or Pirates of the Caribbean has yet to be opened in <mask> .", "What is the name of the first attraction to open at Disneyland Paris? or Pirates of the Caribbean has yet to be <mask> in Disneyland Paris .", "How long has it been since the Pirates of the Caribbean opened? or Pirates of the Caribbean has <mask> to be opened in Disneyland Paris ."], "answers": [["Pirates of the Caribbean", 0, 24], ["Disneyland Paris", 49, 65], ["opened", 39, 45], ["yet", 29, 32]], "evidential": [["Pirates of the Caribbean", "The Pirates of the Caribbean", "Pirates of The Caribbean", "Pirates of Caribbean"], ["Disneyland Paris", "Disney Disneyland Paris", "Disney Paris", "Disney Park"], ["an attraction", "the first attraction", "a ride", "the first ride"], ["yet", "a decade", "a year", "the time"]], "culprit": [2, 3]}
22
+ {"id": 225871, "claim": "Revolver has only ever topped a single chart .", "questions": ["What has only ever topped a single chart? or <mask> has only ever topped a single chart .", "How many charts has Revolver ever topped? or Revolver has only ever topped <mask> .", "How many times has Revolver ever <mask> a single chart? or Revolver has only ever <mask> a single chart .", "How many times has Revolver topped a single chart? or Revolver has <mask> ever topped a single chart ."], "answers": [["Revolver", 0, 8], ["a single chart", 30, 44], ["topped", 23, 29], ["only", 13, 17]], "evidential": [["Revolver", "Revololver", "Revolver Record", "The Revolver"], ["four charts", "two charts", "three charts", "zero charts"], ["topped", "charted", "reached", "appeared"], ["never", "n't", "only", "rarely"]], "culprit": [1, 3]}
23
+ {"id": 164417, "claim": "Carey Hayes was born in 1897 .", "questions": ["Who was born in 1897? or <mask> was born in 1897 .", "When was Carey Hayes born? or Carey Hayes was born in <mask> .", "What was Carey Hayes' birth year? or Carey Hayes was <mask> in 1897 ."], "answers": [["Carey Hayes", 0, 11], ["1897", 24, 28], ["born", 16, 20]], "evidential": [["Carey Hayes", "Carey Hayes (", "Carey Hayes", "Carey Hayden"], ["1961", "the 1960s", "the 1960 's", "the 20th century"], ["born", "conceived", "created", "born in"]], "culprit": [1]}
24
+ {"id": 70311, "claim": "IMDb is a professional Dota 2 player .", "questions": ["What is the name of the professional Dota 2 player? or <mask> is a professional Dota 2 player .", "What is IMDb's professional name? or IMDb is a professional <mask> 2 player .", "How many players does IMDb have? or IMDb is a professional Dota <mask> .", "IMDb is what type of player? or IMDb is <mask> Dota 2 player ."], "answers": [["IMDb", 0, 4], ["Dota", 23, 27], ["2 player", 28, 36], ["a professional", 8, 22]], "evidential": [["IMDb", "The Internet Movie Database", "The internet movie database", "The internet Movie Database"], ["Game", "Web", "video game", "Webmaster"], ["users", "one player", "one user", "user"], ["an online database", "a fictional", "a popular", "a professional"]], "culprit": [1, 2, 3]}
25
+ {"id": 123479, "claim": "The Hundred Years ' War does not include the Lancastrian War .", "questions": ["What does not include the Lancastrian War? or <mask> does not include the Lancastrian War .", "What is not included in the Hundred Years' War? or The Hundred Years ' War does not include <mask> .", "What does the Hundred Years' War not <mask>? or The Hundred Years ' War does not <mask> the Lancastrian War ."], "answers": [["The Hundred Years ' War", 0, 23], ["the Lancastrian War", 41, 60], ["include", 33, 40]], "evidential": [["The Hundred Years ' War", "The Hundred Years' War", "Hundred Years ' War", "A Hundred Years ' War"], ["a conflict", "local conflicts", "a war", "several conflicts"], ["mention", "name", "see", "include"]], "culprit": [1]}
26
+ {"id": 16811, "claim": "Efraim Diveroli is a Spaniard .", "questions": ["Who is a Spaniard? or <mask> is a Spaniard .", "What is Efraim Diveroli's nationality? or Efraim Diveroli is <mask> ."], "answers": [["Efraim Diveroli", 0, 15], ["a Spaniard", 19, 29]], "evidential": [["Efraim Diveroli", "Efranim Diveroli", "Efriim Diveroli", "Efrafri Diveroli"], ["an American", "American", "North American", "a North American"]], "culprit": [1]}
27
+ {"id": 183618, "claim": "Finding Dory was written by anyone except Andrew Stanton .", "questions": ["What was the name of the book that was written by anyone other than Andrew Stanton? or <mask> was written by anyone except Andrew Stanton .", "Who wrote Finding Dory? or Finding Dory was written by <mask> except Andrew Stanton .", "Who wrote Finding Dory? or Finding Dory was written by anyone except <mask> .", "Who else wrote Finding Dory? or Finding Dory was <mask> by anyone except Andrew Stanton ."], "answers": [["Finding Dory", 0, 12], ["anyone", 28, 34], ["Andrew Stanton", 42, 56], ["written", 17, 24]], "evidential": [["Finding Dory", "The Finding Dory", "Finding dory", "Finding Dory 2"], ["anyone", "every person", "almost anyone", "almost all"], ["Andrew Stanton", "Andrew Strouse", "Andrew Stanton", "Andy Stanton"], ["written", "penned", "directed", "authored"]], "culprit": [2]}
28
+ {"id": 125315, "claim": "Phoenix , Arizona is the most populous city in Massachusetts .", "questions": ["What is the most populous city in Massachusetts? or <mask> , Arizona is the most populous city in Massachusetts .", "What state is Phoenix located in? or Phoenix , <mask> is the most populous city in Massachusetts .", "What state is Phoenix located in? or Phoenix , Arizona is the most populous city in <mask> .", "What is the population of Phoenix in Massachusetts? or Phoenix , Arizona is <mask> populous city in Massachusetts .", "What is the population of Phoenix? or Phoenix , Arizona is the most <mask> city in Massachusetts ."], "answers": [["Phoenix", 0, 7], ["Arizona", 10, 17], ["Massachusetts", 47, 60], ["the most", 21, 29], ["populous", 30, 38]], "evidential": [["Phoenix", "The Phoenix", "Arizona Phoenix", "Tempe"], ["Arizona", "Arizona Republic", "Arizona State", "United States"], ["the United States", "the US", "Arizona", "a state"], ["the most", "the fifth most", "the 5th most", "the fourth most"], ["populous", "populous city", "populous US", "large"]], "culprit": [2]}
29
+ {"id": 216367, "claim": "All speakers of the Chagatai language lived in France .", "questions": ["Who lived in France? or All <mask> of the Chagatai language lived in France .", "What language did all French speakers speak? or All speakers of <mask> lived in France .", "Where did the Chagatai language live? or All speakers of the Chagatai language lived in <mask> .", "Where did all speakers of the Chagatai language live? or All speakers of the Chagatai language <mask> in France ."], "answers": [["speakers", 4, 12], ["the Chagatai language", 16, 37], ["France", 47, 53], ["lived", 38, 43]], "evidential": [["The authors", "An author", "People", "A person"], ["the Chagatai language", "Chagatai language", "The Chagatai language", "a Chagatai language"], ["Europe", "a place", "France", "Asia"], ["lived", "existed", "resided", "originated"]], "culprit": [2]}
30
+ {"id": 23428, "claim": "The Cincinnati Kid was directed by Norman Jewison in 1960 .", "questions": ["What movie was directed by Norman Jewison? or <mask> was directed by Norman Jewison in 1960 .", "Who directed The Cincinnati Kid? or The Cincinnati Kid was directed by <mask> in 1960 .", "When was The Cincinnati Kid directed? or The Cincinnati Kid was directed by Norman Jewison in <mask> .", "What was the name of the film that was produced by Norman Jewison? or The Cincinnati Kid was <mask> by Norman Jewison in 1960 ."], "answers": [["The Cincinnati Kid", 0, 18], ["Norman Jewison", 35, 49], ["1960", 53, 57], ["directed", 23, 31]], "evidential": [["The Cincinnati Kid", "the Cincinnati Kid", "The CincinnatiKid", "Cincinnati Kid"], ["Norman Jewison", "a man", "Norman JewISON", "Norman Jewisons"], ["1965", "the 1960s", "the 1960 's", "the late 1960s"], ["directed", "produced", "written", "filmed"]], "culprit": [2]}
31
+ {"id": 67903, "claim": "Murda Beatz 's real name is Donald Trump .", "questions": ["Who is Donald Trump's real name? or <mask> 's real name is Donald Trump .", "What is Beatz' real name? or Murda Beatz 's real name is <mask> .", "What is the <mask> name of Murda Beatz? or Murda Beatz 's <mask> name is Donald Trump ."], "answers": [["Murda Beatz", 0, 11], ["Donald Trump", 28, 40], ["real", 15, 19]], "evidential": [["Murda Beatz", "Murdas Beatz", "Murda beatz", "Murdac Beatz"], ["Donald Trump", "a Donald Trump", "Donald Donald Trump", "Donald John Trump"], ["middle", "real", "full", "legal"]], "culprit": [1]}
32
+ {"id": 45585, "claim": "Harris Jayaraj is from Idaho .", "questions": ["Who is from Idaho? or <mask> is from Idaho .", "Where is Harris Jayaraj from? or Harris Jayaraj is from <mask> ."], "answers": [["Harris Jayaraj", 0, 14], ["Idaho", 23, 28]], "evidential": [["Harris Jayaraj", "Harris Jayaram", "Harris Jayarbaj", "Harris Jayaraja"], ["a state", "Idaho", "a place", "America"]], "culprit": [1]}
33
+ {"id": 95601, "claim": "Ian Gillan is only a singer .", "questions": ["Who is the only singer? or <mask> is only a singer .", "What is Ian Gillan's job? or Ian Gillan is <mask> ."], "answers": [["Ian Gillan", 0, 10], ["only a singer", 14, 27]], "evidential": [["Ian Gillan", "Ian Gillan", "Ian Gillan", "Ian Gillans"], ["a singer", "a vocalist", "a singer and songwriter", "a performer"]], "culprit": [1]}
34
+ {"id": 122348, "claim": "Wolfgang Amadeus Mozart never married .", "questions": ["Who never married? or <mask> never married .", "What did Wolfgang Amadeus Mozart never do? or Wolfgang Amadeus Mozart never <mask> .", "How did Wolfgang Amadeus Mozart get married? or Wolfgang Amadeus Mozart <mask> married ."], "answers": [["Wolfgang Amadeus Mozart", 0, 23], ["married", 30, 37], ["never", 24, 29]], "evidential": [["Wolfgang Amadeus Mozart", "Amadeus Mozart", "Johannes Amadeus Mozart", "Wolfgang Amadeu Mozart"], ["married", "marry", "died", "live"], ["got", "eventually", "later", "was"]], "culprit": [2]}
35
+ {"id": 146157, "claim": "The New England Patriots lost five Super Bowls .", "questions": ["Who lost five Super Bowls? or <mask> lost five Super Bowls .", "What type of game did the New England Patriots lose? or The New England Patriots lost five <mask> .", "How many Super Bowls did the New England Patriots win? or The New England Patriots <mask> five Super Bowls .", "How many Super Bowls did the New England Patriots lose? or The New England Patriots lost <mask> Super Bowls ."], "answers": [["The New England Patriots", 0, 24], ["Super Bowls", 35, 46], ["lost", 25, 29], ["five", 30, 34]], "evidential": [["New England Patriots", "The Patriots", "The New Patriots", "Patriots"], ["Super Bowls", "a Super Bowl", "the Super Bowl", "a football game"], ["won", "played", "reached", "achieved"], ["five", "5", "least five", "seven"]], "culprit": [2]}
36
+ {"id": 107699, "claim": "Floyd Mayweather Jr. is incapable of boxing .", "questions": ["Who is incapable of boxing? or <mask> is incapable of boxing .", "Floyd Mayweather Jr. is incapable of what sport? or Floyd Mayweather Jr. is incapable of <mask> .", "Is Floyd Mayweather Jr. capable or <mask> of boxing? or Floyd Mayweather Jr. is <mask> of boxing ."], "answers": [["Floyd Mayweather Jr.", 0, 20], ["boxing", 37, 43], ["incapable", 24, 33]], "evidential": [["Floyd Mayweather Jr.", "Floyd Mayweather Jr .", "Floyd Mayweather Jr.?", "Floyd Mayweather Jr.:"], ["boxing", "professional boxing", "boxed", "a sport"], ["incapable", "capable", "a capable", "an athlete"]], "culprit": [2]}
37
+ {"id": 216594, "claim": "Calcaneal spurs are only detected by a dancing technique .", "questions": ["What is only detected by a dancing technique? or <mask> are only detected by a dancing technique .", "What is the only way to detect Calcaneal spurs? or Calcaneal spurs are only detected by <mask> .", "How are Calcaneal spurs <mask>? or Calcaneal spurs are only <mask> by a dancing technique .", "How are Calcaneal spurs detected by a dancing technique? or Calcaneal spurs are <mask> detected by a dancing technique ."], "answers": [["Calcaneal spurs", 0, 15], ["a dancing technique", 37, 56], ["detected", 25, 33], ["only", 20, 24]], "evidential": [["Calcaneal spurs", "Calcaneal spur", "Calcaneals spurs", "Calcane al spurs"], ["a radiographic examination", "an x ray", "radiographic examination", "a radiographic exam"], ["detected", "observed", "seen", "indicated"], ["typically", "usually", "often", "frequently"]], "culprit": [1, 3]}
38
+ {"id": 118068, "claim": "Liverpool is unrelated to The Beatles .", "questions": ["What city is not related to The Beatles? or <mask> is unrelated to The Beatles .", "Liverpool is not related to what band? or Liverpool is unrelated to <mask> .", "Is Liverpool related to The Beatles? or Liverpool is <mask> to The Beatles ."], "answers": [["Liverpool", 0, 9], ["The Beatles", 26, 37], ["unrelated", 13, 22]], "evidential": [["Liverpool", "The Liverpool", "Liverpool City", "Liverpool"], ["The Beatles", "the Beatles", "a rock band", "a band"], ["related", "connected", "a home", "home"]], "culprit": [2]}
39
+ {"id": 110504, "claim": "The Mighty Ducks was only distributed by a subsidiary of 20th Century Fox .", "questions": ["What was the name of the show that was distributed by 20th Century Fox? or <mask> was only distributed by a subsidiary of 20th Century Fox .", "Who distributed the Mighty Ducks? or The Mighty Ducks was only distributed by <mask> of 20th Century Fox .", "Who distributed the Mighty Ducks? or The Mighty Ducks was only distributed by a subsidiary of <mask> .", "How was the Mighty Ducks <mask>? or The Mighty Ducks was only <mask> by a subsidiary of 20th Century Fox .", "How many times was The Mighty Ducks distributed by 20th Century Fox? or The Mighty Ducks was <mask> distributed by a subsidiary of 20th Century Fox ."], "answers": [["The Mighty Ducks", 0, 16], ["a subsidiary", 41, 53], ["20th Century Fox", 57, 73], ["distributed", 26, 37], ["only", 21, 25]], "evidential": [["The Mighty Ducks", "The Mighty Ducks of Anaheim", "The Mighty Duck", "Mighty Ducks"], ["the parent company", "a division", "a subsidiary", "the company"], ["Walt Disney Pictures", "Disney Pictures", "a company", "Walt Disney Productions"], ["distributed", "produced", "released", "created"], ["only", "never", "twice", "previously"]], "culprit": [1, 2, 4]}
40
+ {"id": 161151, "claim": "No Strings Attached was released on May 21 .", "questions": ["What was released on May 21? or No <mask> was released on May 21 .", "When was No Strings Attached released? or No Strings Attached was released on <mask> .", "When was No Strings Attached <mask>? or No Strings Attached was <mask> on May 21 ."], "answers": [["Strings Attached", 3, 19], ["May 21", 36, 42], ["released", 24, 32]], "evidential": [["Strings Attached", "strings Attached", "Strings Attached album", "Strings Attached film"], ["January 21 , 2011", "January 21st", "January 21st 2011", "January 21"], ["released", "published", "issued", "distributed"]], "culprit": [1]}
41
+ {"id": 150099, "claim": "Sherilyn Fenn is Japanese .", "questions": ["Who is the name of the Japanese woman who is a native of Japan? or <mask> is Japanese .", "What language is Sherilyn Fenn? or Sherilyn Fenn is <mask> ."], "answers": [["Sherilyn Fenn", 0, 13], ["Japanese", 17, 25]], "evidential": [["Sherilyn Fenn", "The Sherilyn Fenn", "Sherilyn Fenna", "Cherilyn Fenn"], ["American", "English", "North American", "French"]], "culprit": [1]}
42
+ {"id": 157652, "claim": "Touchscreens are only used in gaming computers .", "questions": ["What type of screen is used in gaming computers? or <mask> are only used in gaming computers .", "What type of computers are touch screens used for? or Touchscreens are only used in <mask> .", "What is the only way a touch screen can be <mask> in gaming computers? or Touchscreens are only <mask> in gaming computers .", "How are touchscreens used in gaming computers? or Touchscreens are <mask> used in gaming computers ."], "answers": [["Touchscreens", 0, 12], ["gaming computers", 30, 46], ["used", 22, 26], ["only", 17, 21]], "evidential": [["Touchscreens", "Touchscreen", "Touchscreen devices", "Touch screens"], ["personal computers", "electronic voting machines", "computer systems", "mobile computers"], ["common", "used", "found", "prevalent"], ["commonly", "frequently", "increasingly", "widely"]], "culprit": [3]}
43
+ {"id": 209863, "claim": "In a Lonely Place had nothing to do with any novel by Dorthy B. Hughes .", "questions": ["What was the name of the book that had nothing to do with any novel by Dorthy or <mask> had nothing to do with any novel by Dorthy B. Hughes .", "What did In a Lonely Place have to do with Dorthy B. Hughes or In a Lonely Place had <mask> to do with any novel by Dorthy B. Hughes .", "What type of work did In a Lonely Place have nothing to do with? or In a Lonely Place had nothing to do with any <mask> by Dorthy B. Hughes .", "Who wrote In a Lonely Place? or In a Lonely Place had nothing to do with any novel by <mask> ."], "answers": [["In a Lonely Place", 0, 17], ["nothing", 22, 29], ["novel", 45, 50], ["Dorthy B. Hughes", 54, 70]], "evidential": [["In a Lonely Place", "in a Lonely Place", "In a Lonely place", "In a Lonely Place ."], ["a lot", "a thing", "nothing", "a script"], ["novels", "mystery work", "written work", "written works"], ["Dorothy B. Hughes", "a mystery writer", "the mystery writer", "the author"]], "culprit": [1, 2, 3]}
44
+ {"id": 3305, "claim": "Julianne Moore was not in the television series As the World Turns .", "questions": ["Who was not in the television series As The World Turns? or <mask> was not in the television series As the World Turns .", "What was Julianne Moore not in? or Julianne Moore was not in <mask> As the World Turns .", "What television series did Julianne Moore not appear in? or Julianne Moore was not in the television series As <mask> ."], "answers": [["Julianne Moore", 0, 14], ["the television series", 26, 47], ["the World Turns", 51, 66]], "evidential": [["Julianne Moore", "Juliene Moore", "Juliann Moore", "Julianna Moore"], ["the soap opera", "the television show", "the television series", "the show"], ["the World Turns", "The World Turns", "the World turns", "a World Turns"]], "culprit": [1, 2]}
45
+ {"id": 83351, "claim": "In 2015 , among Mexicans , 70 % of adults had consumed alcoholic drink in the last year .", "questions": ["In what year did 70 % of Mexican adults drink alcohol? or In <mask> , among Mexicans , 70 % of adults had consumed alcoholic drink in the last year .", "What ethnicity had the highest percentage of alcoholic beverages in 2015? or In 2015 , among <mask> , 70 % of adults had consumed alcoholic drink in the last year .", "What percentage of Mexican adults had consumed alcohol in 2015? or In 2015 , among Mexicans , <mask> of adults had consumed alcoholic drink in the last year .", "What group of Mexicans consumed alcohol in 2015? or In 2015 , among Mexicans , 70 % of <mask> had consumed alcoholic drink in the last year .", "What type of drink did 70 % of Mexican adults consume in 2015? or In 2015 , among Mexicans , 70 % of adults had consumed <mask> in the last year .", "In what year did 70 % of Mexican adults drink alcohol? or In 2015 , among Mexicans , 70 % of adults had consumed alcoholic drink in <mask> .", "What did 70 % of adults in Mexico do with alcoholic beverages? or In 2015 , among Mexicans , 70 % of adults had <mask> alcoholic drink in the last year ."], "answers": [["2015", 3, 7], ["Mexicans", 16, 24], ["70 %", 27, 31], ["adults", 35, 41], ["alcoholic drink", 55, 70], ["the last year", 74, 87], ["consumed", 46, 54]], "evidential": [["2015", "2015 's", "the 2015 year", "the last year"], ["Americans", "Mexican", "the Mexican", "Mexicans"], ["89 %", "90 %", "70 %", "87 %"], ["adults", "people", "adult", "Americans"], ["alcohol", "alcoholic drink", "alcoholic drinks", "alcoholic beverages"], ["the last year", "the past year", "the year", "2015"], ["drank", "drunk", "consumed", "drinking"]], "culprit": [1]}
46
+ {"id": 97937, "claim": "Watchmen is a film set in the future .", "questions": ["What is the name of the film set in the future? or <mask> is a film set in the future .", "What type of film is Watchmen? or Watchmen is <mask> set in the future .", "What is the setting of Watchmen? or Watchmen is a film set in <mask> .", "Where is Watchmen <mask>? or Watchmen is a film <mask> in the future ."], "answers": [["Watchmen", 0, 8], ["a film", 12, 18], ["the future", 26, 36], ["set", 19, 22]], "evidential": [["Watchmen", "Watchmen ( film )", "Watchmen( film )", "Watchmen(film )"], ["a superhero film", "a film", "a dystopian film", "a cinematic film"], ["an alternate history", "a dystopian history", "a dystopian future", "a past"], ["set", "located", "filmed", "based"]], "culprit": [2]}
47
+ {"id": 8298, "claim": "Simon Pegg is only a banker .", "questions": ["Who is a banker? or <mask> is only a banker .", "What is Simon Pegg's job title? or Simon Pegg is <mask> ."], "answers": [["Simon Pegg", 0, 10], ["only a banker", 14, 27]], "evidential": [["Simon Pegg", "Simon Pgg", "Simon pegg", "Simon Pegg"], ["a producer", "a screenwriter", "an entertainer", "an executive producer"]], "culprit": [1]}
48
+ {"id": 193862, "claim": "Barry Van Dyke is the first son of Dick Van Dyke .", "questions": ["Who is the first son of Dick Van Dyke? or <mask> is the first son of Dick Van Dyke .", "What is Barry Van Dyke's first name? or Barry Van Dyke is <mask> of Dick Van Dyke .", "Who is Barry Van Dyke's father? or Barry Van Dyke is the first son of <mask> ."], "answers": [["Barry Van Dyke", 0, 14], ["the first son", 18, 31], ["Dick Van Dyke", 35, 48]], "evidential": [["Barry Van Dyke", "Barry van Dyke", "Dick Van Dyke", "A man"], ["the second son", "the first son", "the second child", "the son"], ["Dick Van Dyke", "an entertainer", "an actor", "a comedian"]], "culprit": [1]}
49
+ {"id": 55279, "claim": "Helmand Province contains a city .", "questions": ["What province contains a city? or <mask> contains a city .", "What does Helmand Province contain? or Helmand Province contains <mask> .", "What is the name of the city in Helmand Province? or Helmand Province <mask> a city ."], "answers": [["Helmand Province", 0, 16], ["a city", 26, 32], ["contains", 17, 25]], "evidential": [["Helmand Province", "Helmand province", "Helmand Provincial", "Helmand District"], ["people", "a city", "a town", "a population"], ["contains", "includes", "possesses", "features"]], "culprit": [1]}
50
+ {"id": 69871, "claim": "Robert Zemeckis has rarely directed movies .", "questions": ["Who has rarely directed a movie? or <mask> has rarely directed movies .", "What type of film has Zemeckis rarely directed? or Robert Zemeckis has rarely directed <mask> .", "What type of movies has Zemeckis rarely made? or Robert Zemeckis has rarely <mask> movies .", "How often has Zemeckis directed movies? or Robert Zemeckis has <mask> directed movies ."], "answers": [["Robert Zemeckis", 0, 15], ["movies", 36, 42], ["directed", 27, 35], ["rarely", 20, 26]], "evidential": [["Robert Zemeckis", "Robert Zemeckis", "Robert Zemckis", "Robert Memeckis"], ["a film", "a drama film", "a comedy", "a comedy film"], ["directed", "direct", "produced", "directing"], ["never", "rarely", "always", "only"]], "culprit": [3]}
51
+ {"id": 48276, "claim": "Raees ( film ) stars an Indian film actor born in April 1965 .", "questions": ["What film stars an Indian actor? or <mask> stars an Indian film actor born in April 1965 .", "What nationality is Raees? or Raees ( film ) stars <mask> film actor born in April 1965 .", "What is Raees' career? or Raees ( film ) stars an Indian <mask> born in April 1965 .", "When was Raees born? or Raees ( film ) stars an Indian film actor born in <mask> .", "What is Raees' career? or Raees ( film ) <mask> an Indian film actor born in April 1965 .", "What is the birth year of Raees? or Raees ( film ) stars an Indian film actor <mask> in April 1965 ."], "answers": [["Raees ( film )", 0, 14], ["an Indian", 21, 30], ["film actor", 31, 41], ["April 1965", 50, 60], ["stars", 15, 20], ["born", 42, 46]], "evidential": [["Raees ( film )", "Raees", "Raees( film )", "Raes ( film )"], ["an Indian", "a Indian", "An Indian", "an India"], ["film actor", "film actor and television personality", "actor", "television personality"], ["1965", "the sixties", "the 1960s", "the year 1965"], ["stars", "features", "starred", "includes"], ["born", "birth year", "birth date", "founded"]], "culprit": [3]}
52
+ {"id": 101845, "claim": "Richard Kuklinski is a innocent man .", "questions": ["Who is an innocent man? or <mask> is a innocent man .", "What is Richard Kuklinski? or Richard Kuklinski is <mask> ."], "answers": [["Richard Kuklinski", 0, 17], ["a innocent man", 21, 35]], "evidential": [["Richard Kuklinski", "Richard Kuklinski", "Richard Kuklinsky", "Richard Kuplinski"], ["a person", "a killer", "a serial killer", "a criminal"]], "culprit": [1]}
53
+ {"id": 44240, "claim": "Amancio Ortega refuses to be a businessman .", "questions": ["Who refuses to be a businessman? or <mask> refuses to be a businessman .", "What does Amancio Ortega refuse to be? or Amancio Ortega refuses to be <mask> .", "What does Amancio Ortega do to be a businessman? or Amancio Ortega <mask> to be a businessman ."], "answers": [["Amancio Ortega", 0, 14], ["a businessman", 29, 42], ["refuses", 15, 22]], "evidential": [["Amancio Ortega", "Amancio Ortega Gaona", "Amancio Ortega Jr.", "Amancio Orlando Ortega"], ["a businessman", "a tycoon", "a person", "a businessperson"], ["wants", "used", "works", "acts"]], "culprit": [2]}
54
+ {"id": 142735, "claim": "Elizabeth I was the daughter of a salesman .", "questions": ["What was my mother's name? or <mask> I was the daughter of a salesman .", "What was Elizabeth I's mother's name? or Elizabeth I was <mask> of a salesman .", "What was Elizabeth I's father's occupation? or Elizabeth I was the daughter of <mask> ."], "answers": [["Elizabeth", 0, 9], ["the daughter", 16, 28], ["a salesman", 32, 42]], "evidential": [["Elizabeth", "Elizabeth I", "ElizabethI", "Elizabeth II"], ["the daughter", "the second daughter", "the first daughter", "the second wife"], ["a man", "a second wife", "Henry VIII", "a person"]], "culprit": [2]}
55
+ {"id": 167977, "claim": "Don Bradman was called the \" greatest living Australian \" by a President .", "questions": ["Who was called the \"greatest living Australian\" by a President? or <mask> was called the \" greatest living Australian \" by a President .", "What nationality was Don Bradman? or Don Bradman was called the \" greatest living <mask> \" by a President .", "What was Bradman called by a President? or Don Bradman was called the \" greatest living Australian <mask> by a President .", "Who called Don Bradman the \"greatest living Australian\"? or Don Bradman was called the \" greatest living Australian \" by <mask> .", "What was Don Bradman called by a President? or Don Bradman was called <mask> Australian \" by a President ."], "answers": [["Don Bradman", 0, 11], ["Australian", 45, 55], ["\"", 56, 57], ["a President", 61, 72], ["the \" greatest living", 23, 44]], "evidential": [["Don Bradman", "Donald Bradman", "Don Bradm", "An Australian"], ["Australian", "American", "an Australian", "Australia"], ["person", "\"", "honored", "icon"], ["Prime Minister John Howard", "John Howard", "a Prime Minister", "the Prime Minister"], ["the \" greatest living", "the \" great living", "the \" best living", "the \" Greatest living"]], "culprit": [3]}
56
+ {"id": 227084, "claim": "Roar ( song ) is a Katy Perry song from her fifth album .", "questions": ["What is the name of Katy Perry's fifth album? or <mask> is a Katy Perry song from her fifth album .", "What is the name of the song Roar? or Roar ( song ) is <mask> song from her fifth album .", "What is Roar? or Roar ( song ) is a Katy Perry <mask> from her fifth album .", "What album is Roar from? or Roar ( song ) is a Katy Perry song from her <mask> ."], "answers": [["Roar ( song )", 0, 13], ["a Katy Perry", 17, 29], ["song", 30, 34], ["fifth album", 44, 55]], "evidential": [["Roar", "Roars", "Roar .", "Rar"], ["a Katy Perry", "an Katy Perry", "an American", "an artist 's"], ["song", "title", "single", "track"], ["fourth studio album", "fourth album", "fourth record", "fourth studio record"]], "culprit": [3]}
57
+ {"id": 205646, "claim": "St. Anger is the second studio album by Metallica .", "questions": ["What is the name of Metallica's second album? or <mask> is the second studio album by Metallica .", "What is the name of the second album by Metallica? or St. Anger is <mask> by Metallica .", "What band released St. Anger? or St. Anger is the second studio album by <mask> ."], "answers": [["St. Anger", 0, 9], ["the second studio album", 13, 36], ["Metallica", 40, 49]], "evidential": [["St. Anger", "The St. Anger", "St . Anger", "St. Anger ."], ["the eighth studio album", "an album", "an eighth studio album", "the eighth album"], ["Metallica", "a heavy metal band", "the Metallica", "a heavy metal group"]], "culprit": [1]}
58
+ {"id": 209095, "claim": "Stadium Arcadium was released after 2009 .", "questions": ["What stadium was released after 2009? or <mask> was released after 2009 .", "In what year was Stadium Arcadium released? or Stadium Arcadium was released after <mask> .", "What happened to Stadium Arcadium after 2009? or Stadium Arcadium was <mask> after 2009 .", "When was Stadium Arcadium released? or Stadium Arcadium was released <mask> 2009 ."], "answers": [["Stadium Arcadium", 0, 16], ["2009", 36, 40], ["released", 21, 29], ["after", 30, 35]], "evidential": [["Stadium Arcadium", "Stadium Arcadia", "Stadium Arcadadium", "Stadium Arcadion"], ["2006", "the 2000s", "a different year", "a 2006 album"], ["released", "disbanded", "dropped", "cancelled"], ["before", "after", "around", "back"]], "culprit": [1, 3]}
59
+ {"id": 155657, "claim": "The Prowler was created by Stan Lee , John Buscema , and dust .", "questions": ["What was the name of the film created by Stan Lee, John Buscema and Dust or <mask> was created by Stan Lee , John Buscema , and dust .", "Who created The Prowler? or The Prowler was created by <mask> , John Buscema , and dust .", "Who created The Prowler? or The Prowler was created by Stan Lee , <mask> , and dust .", "What was the Prowler made of? or The Prowler was created by Stan Lee , John Buscema , and <mask> .", "How was The Prowler <mask>? or The Prowler was <mask> by Stan Lee , John Buscema , and dust ."], "answers": [["The Prowler", 0, 11], ["Stan Lee", 27, 35], ["John Buscema", 38, 50], ["dust", 57, 61], ["created", 16, 23]], "evidential": [["The Prowler", "The Prowler ( 1981 film )", "The Prowler( 1981 film )", "Prowler ( 1981 film )"], ["Stan Lee", "Jim Mooney", "writer editor", "writer and editor"], ["comics editor", "comics writers", "people", "comics editors"], ["a writer", "a person", "characters", "comics"], ["created", "produced", "designed", "invented"]], "culprit": [3]}
60
+ {"id": 172095, "claim": "Selena Gomez & the Scene 's debut album was released in any month except September .", "questions": ["What group's debut album was released in any month except September? or <mask> 's debut album was released in any month except September .", "Selena Gomez & the Scene's debut album was released in what <mask> or Selena Gomez & the Scene 's debut album was released in any <mask> except September .", "Selena Gomez & the Scene's debut album was released in what month or Selena Gomez & the Scene 's debut album was released in any month except <mask> .", "When was Selena Gomez's debut album <mask>? or Selena Gomez & the Scene 's debut album was <mask> in any month except September ."], "answers": [["Selena Gomez & the Scene", 0, 24], ["month", 60, 65], ["September", 73, 82], ["released", 44, 52]], "evidential": [["Selena Gomez & the Scene", "The Selena Gomez & the Scene", "Selena Gomez & The Scene", "Selena Gomez & the Scene"], ["September", "the month", "the US", "the summer"], ["September", "October", "July", "August"], ["released", "published", "issued", "launched"]], "culprit": [2]}
61
+ {"id": 191441, "claim": "Keith Urban was released by Sony Music Entertainment .", "questions": ["What artist was released by Sony Music Entertainment? or <mask> was released by Sony Music Entertainment .", "What company released Keith Urban? or Keith Urban was released by <mask> .", "When was Keith Urban <mask>? or Keith Urban was <mask> by Sony Music Entertainment ."], "answers": [["Keith Urban", 0, 11], ["Sony Music Entertainment", 28, 52], ["released", 16, 24]], "evidential": [["Keith Urban", "Keith Urban II", "Keith U.", "The Keith Urban"], ["Capitol Nashville", "Capitol Records", "a company", "Capitol"], ["released", "created", "signed", "founded"]], "culprit": [1]}
62
+ {"id": 188640, "claim": "Foot Locker operates in only 11 countries .", "questions": ["What company operates in only 11 countries? or <mask> operates in only 11 countries .", "How many countries does Foot Locker operate in? or Foot Locker operates in <mask> countries .", "How does Foot Locker operate in 11 countries? or Foot Locker <mask> in only 11 countries ."], "answers": [["Foot Locker", 0, 11], ["only 11", 24, 31], ["operates", 12, 20]], "evidential": [["Foot Locker", "Foot Locker , Inc.", "Foot Locker ( Inc.", "Foot Locker Inc."], ["28", "least 28", "27", "29"], ["operates", "operate", "exists", "runs"]], "culprit": [1]}
63
+ {"id": 164407, "claim": "Carey Hayes is only a German lawyer .", "questions": ["Who is a German lawyer? or <mask> is only a German lawyer .", "What nationality is Hayes? or Carey Hayes is only <mask> lawyer .", "What is Hayes' profession? or Carey Hayes is only a German <mask> .", "How old is Carey Hayes? or Carey Hayes is <mask> German lawyer ."], "answers": [["Carey Hayes", 0, 11], ["a German", 20, 28], ["lawyer", 29, 35], ["only a", 15, 21]], "evidential": [["Carey Hayes", "Carey Hayes Jr.", "Carey Hayes", "Carey Hayden"], ["an American", "an american", "a North American", "an Oregon"], ["writer", "screenwriter", "a writer", "author"], ["a 21st century", "a 21 year old", "a 21-year old", "a young"]], "culprit": [1, 2, 3]}
64
+ {"id": 83545, "claim": "Volkswagen Group declines financing , leasing , and fleet management .", "questions": ["Which company declines financing, leasing and fleet management? or <mask> declines financing , leasing , and fleet management .", "What does Volkswagen Group decline? or Volkswagen Group declines financing , leasing , and <mask> .", "What does Volkswagen Group do with financing, leasing and fleet management? or Volkswagen Group <mask> financing , leasing , and fleet management ."], "answers": [["Volkswagen Group", 0, 16], ["fleet management", 52, 68], ["declines", 17, 25]], "evidential": [["Volkswagen Group", "The Volkswagen Group", "VW Group", "Volkswagen group"], ["fleet management", "fleet management services", "fleets management", "vehicles fleet management"], ["offers", "provides", "performs", "facilitates"]], "culprit": [2]}
65
+ {"id": 97837, "claim": "Caroline Kennedy is against diplomacy .", "questions": ["Who is against diplomacy? or <mask> is against diplomacy .", "Caroline Kennedy is against what? or Caroline Kennedy is against <mask> ."], "answers": [["Caroline Kennedy", 0, 16], ["diplomacy", 28, 37]], "evidential": [["Caroline Kennedy", "Caroline Flemming", "Caroline Klemming", "Caroline Kennedy"], ["politics", "the Democratic Party", "a presidential election", "a presidential campaign"]], "culprit": [1]}
66
+ {"id": 229309, "claim": "A working animal is released by humans .", "questions": ["What is released by humans? or <mask> is released by humans .", "Who releases a working animal? or A working animal is released by <mask> .", "What happens to a working animal when it is <mask>? or A working animal is <mask> by humans ."], "answers": [["A working animal", 0, 16], ["humans", 32, 38], ["released", 20, 28]], "evidential": [["A working animal", "A Working animal", "Working animal", "An animal"], ["humans", "a human", "human beings", "people"], ["kept", "domesticated", "raised", "captured"]], "culprit": [2]}
67
+ {"id": 98672, "claim": "Balibo ( film ) starts in the year 1995 .", "questions": ["What film was released in 1995? or <mask> starts in the year 1995 .", "When does Balibo begin? or Balibo ( film ) starts in <mask> .", "When does Balibo begin? or Balibo ( film ) <mask> in the year 1995 ."], "answers": [["Balibo ( film )", 0, 15], ["the year 1995", 26, 39], ["starts", 16, 22]], "evidential": [["Balibo", "Balibo ( film )", "Balibo( film )", "Balibo ( films )"], ["1975", "the 1970s", "the 1980s", "the year 1975"], ["begins", "starts", "began", "begin"]], "culprit": [1]}
68
+ {"id": 55239, "claim": "Victor Frankenstein is only a romance film .", "questions": ["What is the name of the film that is a romance? or <mask> is only a romance film .", "What is the purpose of Victor Frankenstein? or Victor Frankenstein is <mask> ."], "answers": [["Victor Frankenstein", 0, 19], ["only a romance film", 23, 42]], "evidential": [["Victor Frankenstein ( film )", "Victor Frankenstein", "Victor Frankenstein( film )", "Victor Frankenstein ( films )"], ["a film", "a motion picture", "a recorded work", "directed"]], "culprit": [1]}
69
+ {"id": 7728, "claim": "Hinduism has zero textual resources .", "questions": ["What religion has zero textual resources? or <mask> has zero textual resources .", "How many textual resources does Hinduism have? or Hinduism has <mask> ."], "answers": [["Hinduism", 0, 8], ["zero textual resources", 13, 35]], "evidential": [["Hinduism", "Hindu religion", "Indianism", "Buddhism"], ["multiple textual resources", "many shared textual resources", "shared textual resources", "many textual resources"]], "culprit": [1]}
70
+ {"id": 202475, "claim": "Tinker Tailor Soldier Spy only stars Gary Oldman .", "questions": ["What movie stars Gary Oldman? or <mask> only stars Gary Oldman .", "Who stars in Tinker Tailor Soldier Spy? or Tinker Tailor Soldier Spy only stars <mask> .", "What is Gary Oldman's first name? or Tinker Tailor Soldier Spy only <mask> Gary Oldman .", "How many episodes does Tinker Tailor Soldier Spy have? or Tinker Tailor Soldier Spy <mask> stars Gary Oldman ."], "answers": [["Tinker Tailor Soldier Spy", 0, 25], ["Gary Oldman", 37, 48], ["stars", 31, 36], ["only", 26, 30]], "evidential": [["Tinker Tailor Soldier Spy", "The Tinker Tailor Soldier Spy", "Tinker Tailor Soldier Spy", "Tinker Tailor Soldier Spy movie"], ["Gary Oldman", "an actor", "George Smiley", "a man"], ["stars", "features", "includes", "contains"], ["only", "one episode", "2 episodes", "one series"]], "culprit": [2, 3]}
71
+ {"id": 159091, "claim": "Guatemala has lived without war for its entire existence .", "questions": ["What country has lived without war for its entire existence? or <mask> has lived without war for its entire existence .", "What has Guatemala lived without? or Guatemala has lived without <mask> for its entire existence .", "How long has Guatemala lived without war? or Guatemala has lived without war for its <mask> .", "How long has Guatemala been without war? or Guatemala has <mask> without war for its entire existence ."], "answers": [["Guatemala", 0, 9], ["war", 28, 31], ["entire existence", 40, 56], ["lived", 14, 19]], "evidential": [["Guatemala", "Central America Guatemala", "Guatemalan", "Central America"], ["a military", "a government", "war", "a war"], ["time", "existence", "decade", "decades"], ["existed", "gone", "lived", "been"]], "culprit": [1, 2]}
72
+ {"id": 24481, "claim": "David Spade was fired from being in Joe Dirt 2 : Beautiful Loser .", "questions": ["Who was fired from being in Joe Dirt 2? or <mask> was fired from being in Joe Dirt 2 : Beautiful Loser .", "What was David Spade's first role in? or David Spade was fired from being in <mask> 2 : Beautiful Loser .", "How many episodes of Joe Dirt did Spade have? or David Spade was fired from being in Joe Dirt <mask> : Beautiful Loser .", "What was the title of Joe Dirt 2? or David Spade was fired from being in Joe Dirt 2 : <mask> .", "How did David Spade react to being in Joe Dirt 2? or David Spade was <mask> from being in Joe Dirt 2 : Beautiful Loser ."], "answers": [["David Spade", 0, 11], ["Joe Dirt", 36, 44], ["2", 45, 46], ["Beautiful Loser", 49, 64], ["fired", 16, 21]], "evidential": [["David Spade", "David Spades", "David Spade", "David Spader"], ["Joe Dirt", "the comedy Joe Dirt", "the film Joe Dirt", "the movie Joe Dirt"], ["2", "two episodes", "two", "2 :"], ["Beautiful Loser", "Beautiful Ler", "BeautifulLoser", "Beautiful Losers"], ["distracted", "banned", "traumatized", "disheartened"]], "culprit": [4]}
73
+ {"id": 67876, "claim": "Britt Robertson was not in the television series Girlboss .", "questions": ["Who was not in the television series Girlboss? or <mask> was not in the television series Girlboss .", "What television series did Britt Robertson not appear in? or Britt Robertson was not in the television series <mask> .", "What was Britt Robertson not in? or Britt Robertson was not in <mask> Girlboss ."], "answers": [["Britt Robertson", 0, 15], ["Girlboss", 49, 57], ["the television series", 27, 48]], "evidential": [["Britt Robertson", "Brittany Robertson", "Britt Roberts", "Brit Robertson"], ["Girlboss", "The Secret Circle", "Girlsboss", "a Netflix comedy"], ["the comedy television series", "the show", "the comedy TV series", "the TV series"]], "culprit": [1, 2]}
74
+ {"id": 76324, "claim": "Richard Dawson is still alive .", "questions": ["Who is still alive? or <mask> is still alive .", "How old is Richard Dawson? or Richard Dawson is <mask> alive .", "What is Richard Dawson's age? or Richard Dawson is still <mask> ."], "answers": [["Richard Dawson", 0, 14], ["still", 18, 23], ["alive", 24, 29]], "evidential": [["Richard Dawson", "Richard Dwayne Dawson", "Richard Dawsons", "Richard D Dawson"], ["still", "alive", "barely", "currently"], ["dead", "deceased", "alive", "63"]], "culprit": [1, 2]}
75
+ {"id": 104710, "claim": "Miranda Otto is the son of Barry Otto .", "questions": ["Who is the son of Barry Otto? or <mask> is the son of Barry Otto .", "What is Miranda Otto's biological name? or Miranda Otto is <mask> of Barry Otto .", "Who is Miranda Otto's father? or Miranda Otto is the son of <mask> ."], "answers": [["Miranda Otto", 0, 12], ["the son", 16, 23], ["Barry Otto", 27, 37]], "evidential": [["Miranda Otto", "Miriam Otto", "Miranda Oster", "Miranda Oste"], ["the daughter", "the sister", "the biological daughter", "the granddaughter"], ["an actor", "Barry Otto", "an actress", "a man"]], "culprit": [1]}
76
+ {"id": 92988, "claim": "See You on the Other Side is a boat .", "questions": ["What side of the boat is See You on? or See You on the Other <mask> is a boat .", "What is See You on the Other Side? or See You on the Other Side is <mask> .", "What is the name of the boat that is on the other side? or <mask> You on the Other Side is a boat ."], "answers": [["Side", 21, 25], ["a boat", 29, 35], ["See", 0, 3]], "evidential": [["Side", "side", "side 2", "Side 2"], ["an album", "a recorded work", "a record", "a work"], ["See", "The album", "See '", "see"]], "culprit": [1]}
77
+ {"id": 150834, "claim": "Tool has not produced albums .", "questions": ["Which tool has not produced an album? or <mask> has not produced albums .", "Tool has not produced what? or Tool has not produced <mask> .", "Tool has not what type of albums? or Tool has not <mask> albums ."], "answers": [["Tool", 0, 4], ["albums", 22, 28], ["produced", 13, 21]], "evidential": [["Tool", "Tool ( band )", "Tool( band )", "Tool(band )"], ["albums", "an album", "music", "records"], ["produced", "released", "created", "published"]], "culprit": [1]}
78
+ {"id": 135684, "claim": "Elizabeth I was the son of Anne Boleyn .", "questions": ["Who was the son of Anne Boleyn? or <mask> I was the son of Anne Boleyn .", "What was Elizabeth I's father's name? or Elizabeth I was <mask> of Anne Boleyn .", "Who was Elizabeth I's mother? or Elizabeth I was the son of <mask> ."], "answers": [["Elizabeth", 0, 9], ["the son", 16, 23], ["Anne Boleyn", 27, 38]], "evidential": [["Elizabeth", "Elizabeth I", "Queen Elizabeth", "ElizabethI"], ["the daughter", "the child", "the son", "a daughter"], ["Anne Boleyn", "Ann Boleyn", "Anne Bolyn", "a woman"]], "culprit": [1]}
79
+ {"id": 124045, "claim": "Ron Weasley was denied membership to Gryffindor house .", "questions": ["Who was denied membership to Gryffindor house? or <mask> was denied membership to Gryffindor house .", "What was Ron Weasley denied? or Ron Weasley was denied <mask> to Gryffindor house .", "What house was Ron Weasley denied membership to? or Ron Weasley was denied membership to <mask> house .", "What was Ron Weasley denied membership to? or Ron Weasley was denied membership to Gryffindor <mask> .", "What was Ron Weasley's status as a member of Gryffindor or Ron Weasley was <mask> membership to Gryffindor house ."], "answers": [["Ron Weasley", 0, 11], ["membership", 23, 33], ["Gryffindor", 37, 47], ["house", 48, 53], ["denied", 16, 22]], "evidential": [["Ron Weasley", "The Ron Weasley", "A Ron Weasley", "Ronald Weasley"], ["access", "a visit", "membership", "a membership"], ["the Gryffindor", "a Gryffindor", "The Gryffindor", "the Gryfindor"], ["house", "family", "houses", "home"], ["given", "granted", "denied", "required"]], "culprit": [4]}
80
+ {"id": 56381, "claim": "Lorelai Gilmore 's uncle was played by Edward Herrmann .", "questions": ["Who was the uncle of Edward Herrmann? or <mask> 's uncle was played by Edward Herrmann .", "Who played Lorelai Gilmore's uncle? or Lorelai Gilmore 's uncle was played by <mask> .", "What role did Edward Herrmann play in Lorelai Gilmore's uncle? or Lorelai Gilmore 's uncle was <mask> by Edward Herrmann ."], "answers": [["Lorelai Gilmore", 0, 15], ["Edward Herrmann", 39, 54], ["played", 29, 35]], "evidential": [["Lorelai Gilmore", "Lorelai Gilmore", "Lorelai Gilpin", "Lorelai Glyn"], ["Edward Herrmann", "an actor", "Edward Herrman", "a man"], ["played", "portrayed", "performed", "voiced"]], "culprit": [1]}
81
+ {"id": 78742, "claim": "Tim Roth is not an English actor .", "questions": ["Who is an English actor? or <mask> is not an English actor .", "What is Tim Roth's nationality? or Tim Roth is not <mask> actor .", "What is Tim Roth's profession? or Tim Roth is not an English <mask> ."], "answers": [["Tim Roth", 0, 8], ["an English", 16, 26], ["actor", 27, 32]], "evidential": [["Tim Roth", "Timothy Roth", "Tim Roth", "Tim R Roth"], ["an English", "a European", "a British", "an European"], ["actor", "director", "film actor", "film director"]], "culprit": [1, 2]}
82
+ {"id": 180717, "claim": "Victoria ( Dance Exponents song ) was released in New Zealand in 1980 .", "questions": ["What song was released in New Zealand in 1980? or <mask> was released in New Zealand in 1980 .", "Where was Victoria released? or Victoria ( Dance Exponents song ) was released in <mask> in 1980 .", "When was Victoria released in New Zealand? or Victoria ( Dance Exponents song ) was released in New Zealand in <mask> .", "What was the name of Victoria's song? or Victoria ( Dance Exponents song ) was <mask> in New Zealand in 1980 ."], "answers": [["Victoria ( Dance Exponents song )", 0, 33], ["New Zealand", 50, 61], ["1980", 65, 69], ["released", 38, 46]], "evidential": [["Victoria / Dance Exponents song", "Victoria", "Victoria Song", "Victoria song"], ["New Zealand", "China", "Australia", "the world"], ["1982", "the 1980s", "the eighties", "the nineties"], ["released", "performed", "played", "recorded"]], "culprit": [2]}
83
+ {"id": 125491, "claim": "Hot Right Now is from the album Escape from Planet Monday .", "questions": ["What is the name of the song from the album Escape from Planet Monday? or <mask> is from the album Escape from Planet Monday .", "What is the name of the album that Hot Right Now is from? or Hot Right Now is from <mask> from Planet Monday .", "What album is Hot Right Now from? or Hot Right Now is from the album <mask> ."], "answers": [["Hot Right Now", 0, 13], ["the album Escape", 22, 38], ["Escape from Planet Monday", 32, 57]], "evidential": [["Hot Right Now", "Hot right Now", "Hit Right Now", "Hot Right now"], ["Escape", "the album Escape", "an album", "the single Escape"], ["Escape from Planet Monday", "Nextlevelism", "Escape From Planet Monday", "Next Levelism"]], "culprit": [1, 2]}
84
+ {"id": 100204, "claim": "Shadowhunters did not premiere in 2016 .", "questions": ["What movie did not premiere in 2016? or <mask> did not premiere in 2016 .", "When did Shadowhunters not premiere? or Shadowhunters did not premiere in <mask> .", "What did Shadowhunters not do in 2016? or Shadowhunters did not <mask> in 2016 ."], "answers": [["Shadowhunters", 0, 13], ["2016", 34, 38], ["premiere", 22, 30]], "evidential": [["Shadowhunters", "The Shadowhunters", "Shadowshunters", "Shadowhunterters"], ["2016", "2015", "January 2016", "the 2010s"], ["premiere", "air", "start", "launch"]], "culprit": [1]}
85
+ {"id": 73208, "claim": "Reign Over Me was written and directed by Spike Lee .", "questions": ["What movie was directed by Spike Lee? or <mask> was written and directed by Spike Lee .", "Who directed Reign Over Me? or Reign Over Me was written and directed by <mask> .", "What was the name of the film that directed it? or Reign Over Me was <mask> and directed by Spike Lee .", "What was the film <mask> by Spike Lee? or Reign Over Me was written and <mask> by Spike Lee ."], "answers": [["Reign Over Me", 0, 13], ["Spike Lee", 42, 51], ["written", 18, 25], ["directed", 30, 38]], "evidential": [["Reign Over Me", "Reign over Me", "Reign of Me", "Reign Over me"], ["a man", "an American", "a person", "an actor"], ["written", "penned", "authored", "wrote"], ["directed", "produced", "written", "created"]], "culprit": [1]}
86
+ {"id": 225871, "claim": "Revolver has only ever topped a single chart .", "questions": ["What has only ever topped a single chart? or <mask> has only ever topped a single chart .", "How many charts has Revolver ever topped? or Revolver has only ever topped <mask> .", "How many times has Revolver ever <mask> a single chart? or Revolver has only ever <mask> a single chart .", "How many times has Revolver topped a single chart? or Revolver has <mask> ever topped a single chart ."], "answers": [["Revolver", 0, 8], ["a single chart", 30, 44], ["topped", 23, 29], ["only", 13, 17]], "evidential": [["Revolver", "Revololver", "Revolver Record", "The Revolver"], ["four charts", "two charts", "three charts", "zero charts"], ["topped", "charted", "reached", "appeared"], ["never", "n't", "only", "rarely"]], "culprit": [1, 3]}
87
+ {"id": 125225, "claim": "Omar Khadr has always been free .", "questions": ["Who has always been free? or <mask> has always been free .", "How long has Omar Khadr been free? or Omar Khadr has <mask> been free .", "Omar Khadr has always been what? or Omar Khadr has always been <mask> ."], "answers": [["Omar Khadr", 0, 10], ["always", 15, 21], ["free", 27, 31]], "evidential": [["Omar Khadr", "Omar Khadri", "Omar Khadr", "Om Khadr"], ["yet", "never", "always", "since"], ["a prisoner", "a person", "a human", "a detainee"]], "culprit": [1, 2]}
88
+ {"id": 174514, "claim": "Red Bull Racing was eradicated in the United Kingdom .", "questions": ["What was the name of the race that was eradicated in the UK? or <mask> was eradicated in the United Kingdom .", "Where was Red Bull Racing eradicated? or Red Bull Racing was eradicated in <mask> .", "What happened to Red Bull Racing in the UK? or Red Bull Racing was <mask> in the United Kingdom ."], "answers": [["Red Bull Racing", 0, 15], ["the United Kingdom", 34, 52], ["eradicated", 20, 30]], "evidential": [["Red Bull Racing", "Red Bull R&B Racing", "Red Bull Racing", "Red Bull racing"], ["Austria", "Europe", "the UK", "England"], ["acquired", "founded", "established", "created"]], "culprit": [2]}
89
+ {"id": 67464, "claim": "Louie ( season 1 ) was created by David Benioff .", "questions": ["What was the name of the show created by David Benioff? or <mask> was created by David Benioff .", "Who created Louie? or Louie ( season 1 ) was created by <mask> .", "What was the name of Louie? or Louie ( season 1 ) was <mask> by David Benioff ."], "answers": [["Louie ( season 1 )", 0, 18], ["David Benioff", 34, 47], ["created", 23, 30]], "evidential": [["Louie", "Louie ( season 1 )", "Louis C.K.", "The show Louie"], ["Louis C.K", "a person", "a series creator", "a man"], ["created", "written", "penned", "produced"]], "culprit": [1, 2]}
90
+ {"id": 84710, "claim": "Buffy the Vampire Slayer is created by Joss Whedon in 1990 .", "questions": ["What movie was created by Joss Whedon? or <mask> is created by Joss Whedon in 1990 .", "Who created Buffy the Vampire Slayer? or Buffy the Vampire Slayer is created by <mask> in 1990 .", "When was Buffy the Vampire Slayer created? or Buffy the Vampire Slayer is created by Joss Whedon in <mask> .", "What was the name of the film that made Buffy the Vampire Slayer? or Buffy the Vampire Slayer is <mask> by Joss Whedon in 1990 ."], "answers": [["Buffy the Vampire Slayer", 0, 24], ["Joss Whedon", 39, 50], ["1990", 54, 58], ["created", 28, 35]], "evidential": [["Buffy the Vampire Slayer", "The Buffy the Vampire Slayer", "Buffy The Vampire Slayer", "Buffy of the Vampire Slayer"], ["Joss Whedon", "a person", "a man", "an American"], ["the 1990s", "the 2000s", "the nineties", "1992"], ["created", "produced", "directed", "a film"]], "culprit": [2]}
91
+ {"id": 198041, "claim": "The New York City Landmarks Preservation Commission includes zero architects .", "questions": ["What organization has zero architects? or <mask> includes zero architects .", "How many architects does the New York City Landmarks Preservation Commission have? or The New York City Landmarks Preservation Commission includes <mask> .", "How many architects does the New York City Landmarks Preservation Commission have? or The New York City Landmarks Preservation Commission <mask> zero architects ."], "answers": [["The New York City Landmarks Preservation Commission", 0, 51], ["zero architects", 61, 76], ["includes", 52, 60]], "evidential": [["The New York City Landmarks Preservation Commission", "New York City Landmarks Preservation Commission", "The New York City Landmarks Preservation commission", "A New York City Landmarks Preservation Commission"], ["11 architects", "three architects", "11 commissioners", "ten architects"], ["includes", "contains", "consists", "involves"]], "culprit": [1]}
92
+ {"id": 42390, "claim": "Jack Falahee is Mongolian .", "questions": ["Who is the Mongolian whose name is? or <mask> is Mongolian .", "What nationality is Jack Falahee? or Jack Falahee is <mask> ."], "answers": [["Jack Falahee", 0, 12], ["Mongolian", 16, 25]], "evidential": [["Jack Falahee", "Jack Falahe", "John Falahee", "Jack Falaefhee"], ["American", "an American", "North American", "European"]], "culprit": [1]}
93
+ {"id": 175736, "claim": "The Cry of the Owl is based on Patricia Highsmith 's eighth movie .", "questions": ["What is the name of the movie based on Patricia Highsmith's eighth film? or <mask> is based on Patricia Highsmith 's eighth movie .", "Who wrote the movie The Cry of the Owl? or The Cry of the Owl is based on <mask> 's eighth movie .", "What is the story of The Cry Of The Owl? or The Cry of the Owl is <mask> on Patricia Highsmith 's eighth movie .", "What was the first movie based on? or The Cry of the Owl is based on Patricia Highsmith 's <mask> movie ."], "answers": [["The Cry of the Owl", 0, 18], ["Patricia Highsmith", 31, 49], ["based", 22, 27], ["eighth", 53, 59]], "evidential": [["The Cry of the Owl", "The Cry of the Owl ( 2009 film )", "The Cry of the Owl( 2009 film )", "The Cry of the Owl(2009 film )"], ["Patricia Highsmith", "an author", "a writer", "a novelist"], ["based", "a story", "a novel", "loosely"], ["first", "book", "novel", "a novel"]], "culprit": [3]}
94
+ {"id": 152929, "claim": "Firefox is an operating system shell .", "questions": ["What is the name of the operating system shell? or <mask> is an operating system shell .", "What is Firefox? or Firefox is <mask> ."], "answers": [["Firefox", 0, 7], ["an operating system shell", 11, 36]], "evidential": [["Firefox", "Mozilla", "Mozilla Firefox", "The Firefox"], ["a web browser", "a free web browser", "open source", "a free software application"]], "culprit": [1]}
95
+ {"id": 183589, "claim": "Finding Dory was directed by Ingmar Bergman .", "questions": ["What movie was directed by Ingmar Bergman? or <mask> was directed by Ingmar Bergman .", "Who directed Finding Dory? or Finding Dory was directed by <mask> .", "What was the name of the film that starred in Finding Dory? or Finding Dory was <mask> by Ingmar Bergman ."], "answers": [["Finding Dory", 0, 12], ["Ingmar Bergman", 29, 43], ["directed", 17, 25]], "evidential": [["Finding Dory", "The Finding Dory", "Finding dory", "Finding Dory movie"], ["Andrew Stanton", "Angus MacLane", "a person", "Angus Maclane"], ["directed", "written", "produced", "penned"]], "culprit": [1]}
96
+ {"id": 108957, "claim": "Agent Raghav \u2013 Crime Branch is a phone .", "questions": ["What is the name of the agent that is on the phone? or <mask> is a phone .", "What is the name of the agent in the Crime Branch? or Agent Raghav \u2013 Crime Branch is <mask> ."], "answers": [["Agent Raghav \u2013 Crime Branch", 0, 27], ["a phone", 31, 38]], "evidential": [["Agent Raghav \u2013 Crime Branch", "Agent Raghav - Crime Branch", "Agent Raghav", "Agent Raghav \u2014 Crime Branch"], ["an anthology television series", "a serial", "a television serial", "a television series"]], "culprit": [1]}
97
+ {"id": 3160, "claim": "University of Chicago Law School is ranked first in the 2016 QS World University Rankings .", "questions": ["What is the name of the law school that is ranked first in the 2016 QS World or <mask> is ranked first in the 2016 QS World University Rankings .", "What is the name of the organization that ranks law schools in the world? or University of Chicago Law School is ranked first in the 2016 <mask> .", "What is the ranking of University of Chicago Law School in the 2016 QS World University Rankings or University of Chicago Law School is <mask> first in the 2016 QS World University Rankings .", "What is the ranking of University of Chicago Law School in the 2016 QS World University Rankings or University of Chicago Law School is ranked <mask> in the 2016 QS World University Rankings .", "In what year did the University of Chicago Law School rank first in the QS World University Ranking or University of Chicago Law School is ranked first in <mask> QS World University Rankings ."], "answers": [["University of Chicago Law School", 0, 32], ["QS World University Rankings", 61, 89], ["ranked", 36, 42], ["first", 43, 48], ["the 2016", 52, 60]], "evidential": [["University of Chicago Law School", "The University of Chicago Law School", "the University of Chicago Law School", "University of Chicago law School"], ["QS World University Rankings", "the QS World University Rankings", "S&S World University Rankings", "QS World University Rankings."], ["ranked", "listed", "placed", "Ranked"], ["12th", "11th", "twelveth", "ninth"], ["the 2016", "the 2015", "2016", "The 2016"]], "culprit": [3]}
98
+ {"id": 148309, "claim": "The Adventures of Pluto Nash failed to be a released film .", "questions": ["What failed to be a released film? or <mask> failed to be a released film .", "What did The Adventures of Pluto Nash fail to be? or The Adventures of Pluto Nash failed to be <mask> .", "What was the result of The Adventures of Pluto Nash? or The Adventures of Pluto Nash <mask> to be a released film ."], "answers": [["The Adventures of Pluto Nash", 0, 28], ["a released film", 42, 57], ["failed", 29, 35]], "evidential": [["The Adventures of Pluto Nash", "The adventures of Pluto Nash", "The Adventures of Pluto N", "An Adventures of Pluto Nash"], ["a release", "released", "an release", "release"], ["happened", "ceased", "turned", "failed"]], "culprit": [2]}
99
+ {"id": 227135, "claim": "The New Orleans Pelicans only play in the NHL .", "questions": ["Who plays in the NHL? or <mask> only play in the NHL .", "What league do the New Orleans Pelicans play in? or The New Orleans Pelicans only play in <mask> .", "What do the New Orleans Pelicans only do in the NHL? or The New Orleans Pelicans only <mask> in the NHL .", "How many of the New Orleans Pelicans play in the NHL? or The New Orleans Pelicans <mask> play in the NHL ."], "answers": [["The New Orleans Pelicans", 0, 24], ["the NHL", 38, 45], ["play", 30, 34], ["only", 25, 29]], "evidential": [["The New Orleans Pelicans", "New Orleans Pelicans", "the New Orleans Pelicans", "The New Orleans Saints"], ["the National Basketball Association", "the NBA", "a league", "the Western Conference"], ["play", "compete", "participate", "plays"], ["only", "two", "one", "currently"]], "culprit": [1, 3]}
100
+ {"id": 126678, "claim": "The Colosseum is a wrestler from Italy .", "questions": ["What is the name of the wrestler from Italy? or <mask> is a wrestler from Italy .", "Who is the Colosseum? or The Colosseum is <mask> from Italy .", "Where is The Colosseum? or The Colosseum is a wrestler from <mask> ."], "answers": [["The Colosseum", 0, 13], ["a wrestler", 17, 27], ["Italy", 33, 38]], "evidential": [["The Colosseum", "Colosseum", "The colosseum", "A Colosseum"], ["a tourist attraction", "an attraction", "an amphitheater", "a popular tourist attraction"], ["Rome", "Italy", "a city", "the city"]], "culprit": [1]}
src/eval_client/fever_scorer.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding:utf-8 -*-
2
+
3
+ """
4
+ @Author : Bao
5
+ @Date : 2020/8/24
6
+ @Desc :
7
+ @Last modified by : Bao
8
+ @Last modified date : 2020/9/1
9
+ """
10
+
11
+ import os
12
+ import json
13
+ import numpy as np
14
+ from collections import defaultdict
15
+ import tensorflow as tf
16
+ from sklearn.metrics import precision_recall_fscore_support
17
+ try:
18
+ from .scorer import fever_score
19
+ except:
20
+ from scorer import fever_score
21
+
22
+
23
+ prefix = os.environ['PJ_HOME']
24
+
25
+
26
+ class FeverScorer:
27
+ def __init__(self):
28
+ self.id2label = {2: 'SUPPORTS', 0: 'REFUTES', 1: 'NOT ENOUGH INFO'}
29
+ self.label2id = {value: key for key, value in self.id2label.items()}
30
+
31
+ def get_scores(self, predicted_file, actual_file=f'{prefix}/data/fever/shared_task_dev.jsonl'):
32
+ id2results = defaultdict(dict)
33
+
34
+ with tf.io.gfile.GFile(predicted_file) as f:
35
+ for line in f:
36
+ js = json.loads(line)
37
+ guid = js['id']
38
+ id2results[guid] = js
39
+
40
+ with tf.io.gfile.GFile(actual_file) as fin:
41
+ for line in fin:
42
+ line = json.loads(line)
43
+ guid = line['id']
44
+ evidence = line['evidence']
45
+ label = line['label']
46
+ id2results[guid]['evidence'] = evidence
47
+ id2results[guid]['label'] = label
48
+
49
+ results = self.label_score(list(id2results.values()))
50
+ score, accuracy, precision, recall, f1 = fever_score(list(id2results.values()))
51
+ results.update({
52
+ 'Evidence Precision': precision,
53
+ 'Evidence Recall': recall,
54
+ 'Evidence F1': f1,
55
+ 'FEVER Score': score,
56
+ 'Label Accuracy': accuracy
57
+ })
58
+
59
+ return results
60
+
61
+ def label_score(self, results):
62
+ truth = np.array([v['label'] for v in results])
63
+ prediction = np.array([v['predicted_label'] for v in results])
64
+ labels = list(self.label2id.keys())
65
+ results = {}
66
+ p, r, f, _ = precision_recall_fscore_support(truth, prediction, labels=labels)
67
+ for i, label in enumerate(self.label2id.keys()):
68
+ results['{} Precision'.format(label)] = p[i]
69
+ results['{} Recall'.format(label)] = r[i]
70
+ results['{} F1'.format(label)] = f[i]
71
+
72
+ return results
73
+
74
+
75
+ if __name__ == '__main__':
76
+ import argparse
77
+
78
+ parser = argparse.ArgumentParser()
79
+ parser.add_argument("--predicted_file", '-i', type=str)
80
+ args = parser.parse_args()
81
+
82
+ scorer = FeverScorer()
83
+ results = scorer.get_scores(args.predicted_file)
84
+ print(json.dumps(results, ensure_ascii=False, indent=4))
src/eval_client/scorer.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import six
2
+
3
+ def check_predicted_evidence_format(instance):
4
+ if 'predicted_evidence' in instance.keys() and len(instance['predicted_evidence']):
5
+ assert all(isinstance(prediction, list)
6
+ for prediction in instance["predicted_evidence"]), \
7
+ "Predicted evidence must be a list of (page,line) lists"
8
+
9
+ assert all(len(prediction) == 2
10
+ for prediction in instance["predicted_evidence"]), \
11
+ "Predicted evidence must be a list of (page,line) lists"
12
+
13
+ assert all(isinstance(prediction[0], six.string_types)
14
+ for prediction in instance["predicted_evidence"]), \
15
+ "Predicted evidence must be a list of (page<string>,line<int>) lists"
16
+
17
+ assert all(isinstance(prediction[1], int)
18
+ for prediction in instance["predicted_evidence"]), \
19
+ "Predicted evidence must be a list of (page<string>,line<int>) lists"
20
+
21
+
22
+ def is_correct_label(instance):
23
+ return instance["label"].upper() == instance["predicted_label"].upper()
24
+
25
+
26
+ def is_strictly_correct(instance, max_evidence=None):
27
+ #Strict evidence matching is only for NEI class
28
+ check_predicted_evidence_format(instance)
29
+
30
+ if instance["label"].upper() != "NOT ENOUGH INFO" and is_correct_label(instance):
31
+ assert 'predicted_evidence' in instance, "Predicted evidence must be provided for strict scoring"
32
+
33
+ if max_evidence is None:
34
+ max_evidence = len(instance["predicted_evidence"])
35
+
36
+
37
+ for evience_group in instance["evidence"]:
38
+ #Filter out the annotation ids. We just want the evidence page and line number
39
+ actual_sentences = [[e[2], e[3]] for e in evience_group]
40
+ #Only return true if an entire group of actual sentences is in the predicted sentences
41
+ if all([actual_sent in instance["predicted_evidence"][:max_evidence] for actual_sent in actual_sentences]):
42
+ return True
43
+
44
+ #If the class is NEI, we don't score the evidence retrieval component
45
+ elif instance["label"].upper() == "NOT ENOUGH INFO" and is_correct_label(instance):
46
+ return True
47
+
48
+ return False
49
+
50
+
51
+ def evidence_macro_precision(instance, max_evidence=None):
52
+ this_precision = 0.0
53
+ this_precision_hits = 0.0
54
+
55
+ if instance["label"].upper() != "NOT ENOUGH INFO":
56
+ all_evi = [[e[2], e[3]] for eg in instance["evidence"] for e in eg if e[3] is not None]
57
+
58
+ predicted_evidence = instance["predicted_evidence"] if max_evidence is None else \
59
+ instance["predicted_evidence"][:max_evidence]
60
+
61
+ for prediction in predicted_evidence:
62
+ if prediction in all_evi:
63
+ this_precision += 1.0
64
+ this_precision_hits += 1.0
65
+
66
+ return (this_precision / this_precision_hits) if this_precision_hits > 0 else 1.0, 1.0
67
+
68
+ return 0.0, 0.0
69
+
70
+ def evidence_macro_recall(instance, max_evidence=None):
71
+ # We only want to score F1/Precision/Recall of recalled evidence for NEI claims
72
+ if instance["label"].upper() != "NOT ENOUGH INFO":
73
+ # If there's no evidence to predict, return 1
74
+ if len(instance["evidence"]) == 0 or all([len(eg) == 0 for eg in instance]):
75
+ return 1.0, 1.0
76
+
77
+ predicted_evidence = instance["predicted_evidence"] if max_evidence is None else \
78
+ instance["predicted_evidence"][:max_evidence]
79
+
80
+ for evidence_group in instance["evidence"]:
81
+ evidence = [[e[2], e[3]] for e in evidence_group]
82
+ if all([item in predicted_evidence for item in evidence]):
83
+ # We only want to score complete groups of evidence. Incomplete groups are worthless.
84
+ return 1.0, 1.0
85
+ return 0.0, 1.0
86
+ return 0.0, 0.0
87
+
88
+
89
+ # Micro is not used. This code is just included to demostrate our model of macro/micro
90
+ def evidence_micro_precision(instance):
91
+ this_precision = 0
92
+ this_precision_hits = 0
93
+
94
+ # We only want to score Macro F1/Precision/Recall of recalled evidence for NEI claims
95
+ if instance["label"].upper() != "NOT ENOUGH INFO":
96
+ all_evi = [[e[2], e[3]] for eg in instance["evidence"] for e in eg if e[3] is not None]
97
+
98
+ for prediction in instance["predicted_evidence"]:
99
+ if prediction in all_evi:
100
+ this_precision += 1.0
101
+ this_precision_hits += 1.0
102
+
103
+ return this_precision, this_precision_hits
104
+
105
+
106
+ def fever_score(predictions,actual=None, max_evidence=5):
107
+ correct = 0
108
+ strict = 0
109
+
110
+ macro_precision = 0
111
+ macro_precision_hits = 0
112
+
113
+ macro_recall = 0
114
+ macro_recall_hits = 0
115
+
116
+ for idx,instance in enumerate(predictions):
117
+ assert 'predicted_evidence' in instance.keys(), 'evidence must be provided for the prediction'
118
+
119
+ #If it's a blind test set, we need to copy in the values from the actual data
120
+ if 'evidence' not in instance or 'label' not in instance:
121
+ assert actual is not None, 'in blind evaluation mode, actual data must be provided'
122
+ assert len(actual) == len(predictions), 'actual data and predicted data length must match'
123
+ assert 'evidence' in actual[idx].keys(), 'evidence must be provided for the actual evidence'
124
+ instance['evidence'] = actual[idx]['evidence']
125
+ instance['label'] = actual[idx]['label']
126
+
127
+ assert 'evidence' in instance.keys(), 'gold evidence must be provided'
128
+
129
+ if is_correct_label(instance):
130
+ correct += 1.0
131
+
132
+ if is_strictly_correct(instance, max_evidence):
133
+ strict+=1.0
134
+
135
+ macro_prec = evidence_macro_precision(instance, max_evidence)
136
+ macro_precision += macro_prec[0]
137
+ macro_precision_hits += macro_prec[1]
138
+
139
+ macro_rec = evidence_macro_recall(instance, max_evidence)
140
+ macro_recall += macro_rec[0]
141
+ macro_recall_hits += macro_rec[1]
142
+
143
+ total = len(predictions)
144
+
145
+ strict_score = strict / total
146
+ acc_score = correct / total
147
+
148
+ pr = (macro_precision / macro_precision_hits) if macro_precision_hits > 0 else 1.0
149
+ rec = (macro_recall / macro_recall_hits) if macro_recall_hits > 0 else 0.0
150
+
151
+ f1 = 2.0 * pr * rec / (pr + rec)
152
+
153
+ return strict_score, acc_score, pr, rec, f1
src/loren.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ @Author : Jiangjie Chen
5
+ @Time : 2020/9/17 15:55
6
+ @Contact : jjchen19@fudan.edu.cn
7
+ @Description:
8
+ '''
9
+
10
+ import os
11
+ import sys
12
+ import json
13
+ import logging
14
+ import cjjpy as cjj
15
+
16
+ try:
17
+ from .qg_client.question_generator import QuestionGenerator
18
+ from .mrc_client.answer_generator import AnswerGenerator, chunks, assemble_answers_to_one
19
+ from .parsing_client.sentence_parser import SentenceParser, deal_bracket
20
+ from .check_client.fact_checker import FactChecker, id2label
21
+ from .er_client import EvidenceRetrieval
22
+ except:
23
+ sys.path.append(cjj.AbsParentDir(__file__, '.'))
24
+ from qg_client.question_generator import QuestionGenerator
25
+ from mrc_client.answer_generator import AnswerGenerator, chunks, assemble_answers_to_one
26
+ from parsing_client.sentence_parser import SentenceParser, deal_bracket
27
+ from check_client.fact_checker import FactChecker, id2label
28
+ from er_client import EvidenceRetrieval
29
+
30
+
31
+ def load_config(config):
32
+ if isinstance(config, str):
33
+ with open(config) as f:
34
+ config = json.load(f)
35
+ cfg = cjj.AttrDict(config)
36
+ return cfg
37
+
38
+
39
+ class Loren:
40
+ def __init__(self, config_file, verbose=True):
41
+ self.verbose = verbose
42
+ self.args = load_config(config_file)
43
+ self.sent_client = SentenceParser()
44
+ self.qg_client = QuestionGenerator('t5', verbose=False)
45
+ self.ag_client = AnswerGenerator(self.args.mrc_dir)
46
+ self.fc_client = FactChecker(self.args, self.args.fc_dir)
47
+ self.er_client = EvidenceRetrieval(self.args.er_dir)
48
+ self.logger = cjj.init_logger(f'{os.environ["PJ_HOME"]}/results/loren_dev.log',
49
+ log_file_level=logging.INFO if self.verbose else logging.WARNING)
50
+ self.logger.info('*** Loren initialized. ***')
51
+
52
+ def check(self, claim, evidence=None):
53
+ self.logger.info('*** Verifying "%s"... ***' % claim)
54
+ js = self.prep(claim, evidence)
55
+ js['id'] = 0
56
+ y_predicted, z_predicted, m_attn = self.fc_client.check_from_batch([js], verbose=self.verbose)
57
+ label = id2label[y_predicted[0]]
58
+
59
+ # Update js
60
+ js['local_premises'] = assemble_answers_to_one(js, k=3)
61
+ js['evidence'] = [self.fc_client.tokenizer.clean_up_tokenization(e[2]) for e in js['evidence']]
62
+ js['questions'] = [self.fc_client.tokenizer.clean_up_tokenization(q) for q in js['questions']]
63
+ js['claim_phrases'] = [self.fc_client.tokenizer.clean_up_tokenization(a[0]) for a in js['answers']]
64
+ js['local_premises'] = [[self.fc_client.tokenizer.clean_up_tokenization(a) for a in aa]
65
+ for aa in js['local_premises']]
66
+ # js['m_attn'] = m_attn[0][:len(js['claim_phrases'])]
67
+ js['phrase_veracity'] = z_predicted[0][:len(js['claim_phrases'])]
68
+ js['claim_veracity'] = label
69
+
70
+ self.logger.info(" * Intermediary: %s *" % str(js))
71
+ self.logger.info('*** Verification completed: "%s" ***' % label)
72
+ return js
73
+
74
+ def prep(self, claim, evidence=None):
75
+ '''
76
+ :param evidence: 'aaa||bbb||ccc' / [entity, num, evidence, (prob)] if not None
77
+ '''
78
+ evidence = self._prep_evidence(claim, evidence)
79
+ self.logger.info(' * Evidence prepared. *')
80
+ assert isinstance(evidence, list)
81
+
82
+ js = {'claim': claim, 'evidence': evidence}
83
+ js = self._prep_claim_phrases(js)
84
+ self.logger.info(' * Claim phrases prepared. *')
85
+ js = self._prep_questions(js)
86
+ self.logger.info(' * Probing questions prepared. *')
87
+ js = self._prep_evidential_phrases(js)
88
+ self.logger.info(' * Evidential phrases prepared. *')
89
+ return js
90
+
91
+ def _prep_claim_phrases(self, js):
92
+ results = self.sent_client.identify_NPs(deal_bracket(js['claim'], True),
93
+ candidate_NPs=[x[0] for x in js['evidence']])
94
+ NPs = results['NPs']
95
+ claim = results['text']
96
+ verbs = results['verbs']
97
+ adjs = results['adjs']
98
+ _cache = {'claim': claim,
99
+ 'evidence': js['evidence'],
100
+ 'answers': NPs + verbs + adjs,
101
+ 'answer_roles': ['noun'] * len(NPs) + ['verb'] * len(verbs) + ['adj'] * len(adjs)}
102
+ if len(_cache['answers']) == 0:
103
+ _cache['answers'] = js['claim'].split()[0]
104
+ _cache['answer_roles'] = ['noun']
105
+ return _cache
106
+
107
+ def _prep_questions(self, js):
108
+ _cache = []
109
+ for answer in js['answers']:
110
+ _cache.append((js['claim'], [answer]))
111
+ qa_pairs = self.qg_client.generate([(x, y) for x, y in _cache])
112
+ for q, clz_q, a in qa_pairs:
113
+ if 'questions' in js:
114
+ js['regular_qs'].append(q)
115
+ js['cloze_qs'].append(clz_q)
116
+ js['questions'].append(self.qg_client.assemble_question(q, clz_q))
117
+ else:
118
+ js['regular_qs'] = [q]
119
+ js['cloze_qs'] = [clz_q]
120
+ js['questions'] = [self.qg_client.assemble_question(q, clz_q)]
121
+ return js
122
+
123
+ def _prep_evidential_phrases(self, js):
124
+ examples = []
125
+ for q in js['questions']:
126
+ ex = self.ag_client.assemble(q, " ".join([x[2] for x in js['evidence']]))
127
+ examples.append(ex)
128
+ predicted = self.ag_client.generate(examples, num_beams=self.args['cand_k'],
129
+ num_return_sequences=self.args['cand_k'],
130
+ batch_size=2, verbose=False)
131
+ for answers in predicted:
132
+ if 'evidential' in js:
133
+ js['evidential'].append(answers)
134
+ else:
135
+ js['evidential'] = [answers]
136
+ return js
137
+
138
+ def _prep_evidence(self, claim, evidence=None):
139
+ '''
140
+ :param evidence: 'aaa||bbb||ccc' / [entity, num, evidence, (prob)]
141
+ :return: [entity, num, evidence, (prob)]
142
+ '''
143
+ if evidence in [None, '', 'null', 'NULL', 'Null']:
144
+ evidence = self.er_client.retrieve(claim)
145
+ evidence = [(ev[0], ev[1], deal_bracket(ev[2], True, ev[0])) for ev in evidence]
146
+ else:
147
+ if isinstance(evidence, str):
148
+ # TODO: magic sentence number
149
+ evidence = [("None", i, ev.strip()) for i, ev in enumerate(evidence.split('||')[:5])]
150
+ return evidence
151
+
152
+
153
+ if __name__ == '__main__':
154
+ import argparse
155
+
156
+ parser = argparse.ArgumentParser()
157
+ parser.add_argument('--config', '-c', type=str, required=True,
158
+ default='available_models/aaai22_roberta.json',
159
+ help='Config json file with hyper-parameters')
160
+ args = parser.parse_args()
161
+
162
+ loren = Loren(args.config)
163
+ while True:
164
+ claim = input('> ')
165
+ label, js = loren.check(claim)
166
+ print(label)
167
+ print(js)
src/mrc_client/answer_generator.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ @Author : Jiangjie Chen
5
+ @Time : 2020/8/12 14:44
6
+ @Contact : jjchen19@fudan.edu.cn
7
+ @Description:
8
+ '''
9
+
10
+ import re
11
+ import time
12
+ from pathlib import Path
13
+ from typing import Dict, List
14
+ import torch
15
+ from logging import getLogger
16
+ from tqdm import tqdm
17
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
18
+ import ujson as json
19
+ import random
20
+
21
+ try:
22
+ from .seq2seq.seq2seq_utils import (
23
+ use_task_specific_params,
24
+ calculate_rouge,
25
+ chunks,
26
+ Seq2SeqDataset,
27
+ lmap,
28
+ load_json,
29
+ save_json,
30
+ )
31
+ except ImportError:
32
+ import cjjpy as cjj
33
+ import sys
34
+ sys.path.append(cjj.AbsParentDir(__file__, '.'))
35
+ from seq2seq.seq2seq_utils import (
36
+ use_task_specific_params,
37
+ calculate_rouge,
38
+ chunks,
39
+ Seq2SeqDataset,
40
+ lmap,
41
+ load_json,
42
+ save_json,
43
+ )
44
+
45
+ logger = getLogger(__name__)
46
+ DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
47
+ random.seed(1111)
48
+
49
+
50
+ def assemble_answers_to_one(js, k=5, mask_token='<mask>', mask_rate=0.):
51
+ if isinstance(js, str):
52
+ js = json.loads(js)
53
+
54
+ should_keep = random.random() > mask_rate
55
+ js.pop('evidential_assembled')
56
+ for q, answers in zip(js['cloze_qs'], js['evidential']):
57
+ if mask_token in q:
58
+ s = q.find(mask_token)
59
+ e = s + len(mask_token)
60
+ nq_list = []
61
+ if should_keep:
62
+ for i in range(k):
63
+ answer_span = answers[i]
64
+ nq = q[:s] + answer_span + q[e:]
65
+ nq_list.append(nq)
66
+ else:
67
+ for i in range(k):
68
+ answer_span = mask_token
69
+ nq = q[:s] + answer_span + q[e:]
70
+ nq_list.append(nq)
71
+ ev_nqs = ' '.join(nq_list)
72
+ if js.get('evidential_assembled') is None:
73
+ js['evidential_assembled'] = [ev_nqs]
74
+ else:
75
+ js['evidential_assembled'].append(ev_nqs)
76
+ assert len(js['evidential_assembled']) == len(js['answers'])
77
+ return js
78
+
79
+
80
+ class AnswerGenerator():
81
+ def __init__(self, model_name, device=DEFAULT_DEVICE):
82
+ self.model_name = str(model_name)
83
+ self.device = device
84
+ self.model = None
85
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
86
+
87
+ def init_model(self):
88
+ if self.model is None:
89
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name).to(self.device)
90
+
91
+ def assemble(self, question, context):
92
+ sep = '\n' if 'unifiedqa' in self.tokenizer.name_or_path else self.tokenizer.sep_token
93
+ return f'{question} {sep} {context}'
94
+
95
+ def generate(self, examples, out_file=None, batch_size=16, verbose=True,
96
+ max_length=20, min_length=1, num_beams=4, num_return_sequences=4,
97
+ prefix=None, fp16=False, task='summarization', **generate_kwargs):
98
+ '''
99
+ :param examples: [N]
100
+ :return: [N x num_return_seq]
101
+ '''
102
+ self.init_model()
103
+ if fp16:
104
+ self.model = self.model.half()
105
+ # update config with summarization specific params
106
+ use_task_specific_params(self.model, task)
107
+
108
+ fout = None if out_file is None else Path(out_file).open("w", encoding="utf-8")
109
+ generated = []
110
+ if verbose:
111
+ iter = tqdm(list(chunks(examples, batch_size)), desc="MRC")
112
+ else:
113
+ iter = list(chunks(examples, batch_size))
114
+ if prefix is None:
115
+ prefix = prefix or getattr(self.model.config, "prefix", "") or ""
116
+ for examples_chunk in iter:
117
+ examples_chunk = [prefix + text for text in examples_chunk]
118
+ batch = self.tokenizer(examples_chunk, return_tensors="pt", truncation=True,
119
+ padding="longest").to(self.device)
120
+ summaries = self.model.generate(
121
+ input_ids=batch.input_ids,
122
+ attention_mask=batch.attention_mask,
123
+ max_length=max_length,
124
+ min_length=min_length,
125
+ num_beams=num_beams,
126
+ num_return_sequences=num_return_sequences,
127
+ length_penalty=1.2,
128
+ repetition_penalty=1.2,
129
+ **generate_kwargs,
130
+ )
131
+ dec = self.tokenizer.batch_decode(summaries, skip_special_tokens=True,
132
+ clean_up_tokenization_spaces=False)
133
+ if fout is not None:
134
+ for hypothesis in dec:
135
+ fout.write(hypothesis.strip() + "\n")
136
+ fout.flush()
137
+ else:
138
+ generated += dec
139
+ if fout is not None:
140
+ fout.close()
141
+ generated = list(map(lambda x: x.strip(), generated))
142
+ generated = list(chunks(generated, num_return_sequences))
143
+ return generated
144
+
src/mrc_client/cjjpy.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ @Author : Jiangjie Chen
5
+ @Time : 2018/11/15 17:08
6
+ @Contact: jjchen19@fudan.edu.cn
7
+ '''
8
+
9
+ import re
10
+ import datetime
11
+ import os
12
+ import argparse
13
+ import logging
14
+ import traceback
15
+
16
+ try:
17
+ import ujson as json
18
+ except:
19
+ import json
20
+
21
+ HADOOP_BIN = 'PATH=/usr/bin/:$PATH hdfs'
22
+ FOR_PUBLIC = True
23
+
24
+
25
+ def LengthStats(filename):
26
+ len_list = []
27
+ thresholds = [0.8, 0.9, 0.95, 0.99, 0.999]
28
+ with open(filename) as f:
29
+ for line in f:
30
+ len_list.append(len(line.strip().split()))
31
+ stats = {
32
+ 'Max': max(len_list),
33
+ 'Min': min(len_list),
34
+ 'Avg': round(sum(len_list) / len(len_list), 4),
35
+ }
36
+ len_list.sort()
37
+ for t in thresholds:
38
+ stats[f"Top-{t}"] = len_list[int(len(len_list) * t)]
39
+
40
+ for k in stats:
41
+ print(f"- {k}: {stats[k]}")
42
+ return stats
43
+
44
+
45
+ class AttrDict(dict):
46
+ def __init__(self, *args, **kwargs):
47
+ super(AttrDict, self).__init__(*args, **kwargs)
48
+ self.__dict__ = self
49
+
50
+
51
+ def TraceBack(error_msg):
52
+ exc = traceback.format_exc()
53
+ msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}'
54
+ return msg
55
+
56
+
57
+ def Now():
58
+ return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
59
+
60
+
61
+ def AbsParentDir(file, parent='..', postfix=None):
62
+ ppath = os.path.abspath(file)
63
+ parent_level = parent.count('.')
64
+ while parent_level > 0:
65
+ ppath = os.path.dirname(ppath)
66
+ parent_level -= 1
67
+ if postfix is not None:
68
+ return os.path.join(ppath, postfix)
69
+ else:
70
+ return ppath
71
+
72
+
73
+ def init_logger(log_file=None, log_file_level=logging.NOTSET, from_scratch=False):
74
+ from coloredlogs import ColoredFormatter
75
+ import tensorflow as tf
76
+
77
+ fmt = "[%(asctime)s %(levelname)s] %(message)s"
78
+ log_format = ColoredFormatter(fmt=fmt)
79
+ # log_format = logging.Formatter()
80
+ logger = logging.getLogger()
81
+ logger.setLevel(log_file_level)
82
+
83
+ console_handler = logging.StreamHandler()
84
+ console_handler.setFormatter(log_format)
85
+ logger.handlers = [console_handler]
86
+
87
+ if log_file and log_file != '':
88
+ if from_scratch and tf.io.gfile.exists(log_file):
89
+ logger.warning('Removing previous log file: %s' % log_file)
90
+ tf.io.gfile.remove(log_file)
91
+ path = os.path.dirname(log_file)
92
+ os.makedirs(path, exist_ok=True)
93
+ file_handler = logging.FileHandler(log_file)
94
+ file_handler.setLevel(log_file_level)
95
+ file_handler.setFormatter(log_format)
96
+ logger.addHandler(file_handler)
97
+
98
+ return logger
99
+
100
+
101
+ def OverWriteCjjPy(root='.'):
102
+ # import difflib
103
+ # diff = difflib.HtmlDiff()
104
+ cnt = 0
105
+ golden_cjjpy = os.path.join(root, 'cjjpy.py')
106
+ # golden_content = open(golden_cjjpy).readlines()
107
+ for dir, folder, file in os.walk(root):
108
+ for f in file:
109
+ if f == 'cjjpy.py':
110
+ cjjpy = '%s/%s' % (dir, f)
111
+ # content = open(cjjpy).readlines()
112
+ # d = diff.make_file(golden_content, content)
113
+ cnt += 1
114
+ print('[%d]: %s' % (cnt, cjjpy))
115
+ os.system('cp %s %s' % (golden_cjjpy, cjjpy))
116
+
117
+
118
+ def ChangeFileFormat(filename, new_fmt):
119
+ assert type(filename) is str and type(new_fmt) is str
120
+ spt = filename.split('.')
121
+ if len(spt) == 0:
122
+ return filename
123
+ else:
124
+ return filename.replace('.' + spt[-1], new_fmt)
125
+
126
+
127
+ def CountLines(fname):
128
+ with open(fname, 'rb') as f:
129
+ count = 0
130
+ last_data = '\n'
131
+ while True:
132
+ data = f.read(0x400000)
133
+ if not data:
134
+ break
135
+ count += data.count(b'\n')
136
+ last_data = data
137
+ if last_data[-1:] != b'\n':
138
+ count += 1 # Remove this if a wc-like count is needed
139
+ return count
140
+
141
+
142
+ def GetDate():
143
+ return str(datetime.datetime.now())[5:10].replace('-', '')
144
+
145
+
146
+ def TimeClock(seconds):
147
+ sec = int(seconds)
148
+ hour = int(sec / 3600)
149
+ min = int((sec - hour * 3600) / 60)
150
+ ssec = float(seconds) - hour * 3600 - min * 60
151
+ # return '%dh %dm %.2fs' % (hour, min, ssec)
152
+ return '{0:>2d}h{1:>3d}m{2:>6.2f}s'.format(hour, min, ssec)
153
+
154
+
155
+ def StripAll(text):
156
+ return text.strip().replace('\t', '').replace('\n', '').replace(' ', '')
157
+
158
+
159
+ def GetBracket(text, bracket, en_br=False):
160
+ # input should be aa(bb)cc, True for bracket, False for text
161
+ if bracket:
162
+ try:
163
+ return re.findall('\((.*?)\)', text.strip())[-1]
164
+ except:
165
+ return ''
166
+ else:
167
+ if en_br:
168
+ text = re.sub('\(.*?\)', '', text.strip())
169
+ return re.sub('(.*?)', '', text.strip())
170
+
171
+
172
+ def CharLang(uchar, lang):
173
+ assert lang.lower() in ['en', 'cn', 'zh']
174
+ if lang.lower() in ['cn', 'zh']:
175
+ if uchar >= '\u4e00' and uchar <= '\u9fa5':
176
+ return True
177
+ else:
178
+ return False
179
+ elif lang.lower() == 'en':
180
+ if (uchar <= 'Z' and uchar >= 'A') or (uchar <= 'z' and uchar >= 'a'):
181
+ return True
182
+ else:
183
+ return False
184
+ else:
185
+ raise NotImplementedError
186
+
187
+
188
+ def WordLang(word, lang):
189
+ for i in word.strip():
190
+ if i.isspace(): continue
191
+ if not CharLang(i, lang):
192
+ return False
193
+ return True
194
+
195
+
196
+ def SortDict(_dict, reverse=True):
197
+ assert type(_dict) is dict
198
+ return sorted(_dict.items(), key=lambda d: d[1], reverse=reverse)
199
+
200
+
201
+ def lark(content='test'):
202
+ print(content)
203
+
204
+
205
+ if __name__ == '__main__':
206
+ parser = argparse.ArgumentParser()
207
+
208
+ parser.add_argument('--diff', nargs=2,
209
+ help='show difference between two files, shown in downloads/diff.html')
210
+ parser.add_argument('--de_unicode', action='store_true', default=False,
211
+ help='remove unicode characters')
212
+ parser.add_argument('--link_entity', action='store_true', default=False,
213
+ help='')
214
+ parser.add_argument('--max_comm_len', action='store_true', default=False,
215
+ help='')
216
+ parser.add_argument('--search', nargs=2,
217
+ help='search key from file, 2 args: file name & key')
218
+ parser.add_argument('--email', nargs=2,
219
+ help='sending emails, 2 args: subject & content')
220
+ parser.add_argument('--overwrite', action='store_true', default=None,
221
+ help='overwrite all cjjpy under given *dir* based on *dir*/cjjpy.py')
222
+ parser.add_argument('--replace', nargs=3,
223
+ help='replace char, 3 args: file name & replaced char & replacer char')
224
+ parser.add_argument('--lark', nargs=1)
225
+ parser.add_argument('--get_hdfs', nargs=2,
226
+ help='easy copy from hdfs to local fs, 2 args: remote_file/dir & local_dir')
227
+ parser.add_argument('--put_hdfs', nargs=2,
228
+ help='easy put from local fs to hdfs, 2 args: local_file/dir & remote_dir')
229
+ parser.add_argument('--length_stats', nargs=1,
230
+ help='simple token lengths distribution of a line-by-line file')
231
+
232
+ args = parser.parse_args()
233
+
234
+ if args.overwrite:
235
+ print('* Overwriting cjjpy...')
236
+ OverWriteCjjPy()
237
+
238
+ if args.lark:
239
+ try:
240
+ content = args.lark[0]
241
+ except:
242
+ content = 'running complete'
243
+ print(f'* Larking "{content}"...')
244
+ lark(content)
245
+
246
+ if args.length_stats:
247
+ file = args.length_stats[0]
248
+ print(f'* Working on {file} lengths statistics...')
249
+ LengthStats(file)
src/mrc_client/seq2seq/README.md ADDED
@@ -0,0 +1,590 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Sequence to Sequence Training and Evaluation
2
+
3
+ This directory contains examples for finetuning and evaluating transformers on summarization and translation tasks.
4
+ Please tag @patil-suraj with any issues/unexpected behaviors, or send a PR!
5
+ For deprecated `bertabs` instructions, see [`bertabs/README.md`](bertabs/README.md).
6
+
7
+ ### Supported Architectures
8
+
9
+ - `BartForConditionalGeneration` (and anything that inherits from it)
10
+ - `MarianMTModel`
11
+ - `PegasusForConditionalGeneration`
12
+ - `MBartForConditionalGeneration`
13
+ - `FSMTForConditionalGeneration`
14
+ - `T5ForConditionalGeneration`
15
+
16
+ ## Datasets
17
+
18
+ #### XSUM
19
+
20
+ ```bash
21
+ cd examples/seq2seq
22
+ wget https://cdn-datasets.huggingface.co/summarization/xsum.tar.gz
23
+ tar -xzvf xsum.tar.gz
24
+ export XSUM_DIR=${PWD}/xsum
25
+ ```
26
+ this should make a directory called `xsum/` with files like `test.source`.
27
+ To use your own data, copy that files format. Each article to be summarized is on its own line.
28
+
29
+ #### CNN/DailyMail
30
+
31
+ ```bash
32
+ cd examples/seq2seq
33
+ wget https://cdn-datasets.huggingface.co/summarization/cnn_dm_v2.tgz
34
+ tar -xzvf cnn_dm_v2.tgz # empty lines removed
35
+ mv cnn_cln cnn_dm
36
+ export CNN_DIR=${PWD}/cnn_dm
37
+ ```
38
+ this should make a directory called `cnn_dm/` with 6 files.
39
+
40
+ #### WMT16 English-Romanian Translation Data
41
+
42
+ download with this command:
43
+ ```bash
44
+ wget https://cdn-datasets.huggingface.co/translation/wmt_en_ro.tar.gz
45
+ tar -xzvf wmt_en_ro.tar.gz
46
+ export ENRO_DIR=${PWD}/wmt_en_ro
47
+ ```
48
+ this should make a directory called `wmt_en_ro/` with 6 files.
49
+
50
+ #### WMT English-German
51
+
52
+ ```bash
53
+ wget https://cdn-datasets.huggingface.co/translation/wmt_en_de.tgz
54
+ tar -xzvf wmt_en_de.tgz
55
+ export DATA_DIR=${PWD}/wmt_en_de
56
+ ```
57
+
58
+ #### FSMT datasets (wmt)
59
+
60
+ Refer to the scripts starting with `eval_` under:
61
+ https://github.com/huggingface/transformers/tree/master/scripts/fsmt
62
+
63
+ #### Pegasus (multiple datasets)
64
+
65
+ Multiple eval datasets are available for download from:
66
+ https://github.com/stas00/porting/tree/master/datasets/pegasus
67
+
68
+
69
+ #### Your Data
70
+
71
+ If you are using your own data, it must be formatted as one directory with 6 files:
72
+ ```
73
+ train.source
74
+ train.target
75
+ val.source
76
+ val.target
77
+ test.source
78
+ test.target
79
+ ```
80
+ The `.source` files are the input, the `.target` files are the desired output.
81
+
82
+ ### Tips and Tricks
83
+
84
+ General Tips:
85
+ - since you need to run from `examples/seq2seq`, and likely need to modify code, the easiest workflow is fork transformers, clone your fork, and run `pip install -e .` before you get started.
86
+ - try `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr per epoch with bs=8, see the "xsum_shared_task" command below)
87
+ - `fp16_opt_level=O1` (the default works best).
88
+ - In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
89
+ Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr)`.
90
+ - At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
91
+ - This warning can be safely ignored:
92
+ > "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
93
+ - Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
94
+ - Read scripts before you run them!
95
+
96
+ Summarization Tips:
97
+ - (summ) 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
98
+ - If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
99
+ - For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`
100
+ - `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM.
101
+ - `wandb` can be used by specifying `--logger_name wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
102
+ - If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
103
+ (It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
104
+
105
+ **Update 2018-07-18**
106
+ Datasets: `LegacySeq2SeqDataset` will be used for all tokenizers without a `prepare_seq2seq_batch` method. Otherwise, `Seq2SeqDataset` will be used.
107
+ Future work/help wanted: A new dataset to support multilingual tasks.
108
+
109
+
110
+ ### Finetuning Scripts
111
+ All finetuning bash scripts call finetune.py (or distillation.py) with reasonable command line arguments. They usually require extra command line arguments to work.
112
+
113
+ To see all the possible command line options, run:
114
+
115
+ ```bash
116
+ ./finetune.py --help
117
+ ```
118
+
119
+ ### Finetuning Training Params
120
+
121
+ To override the pretrained model's training params, you can pass them to `./finetune.sh`:
122
+
123
+ ```bash
124
+ ./finetune.sh \
125
+ [...]
126
+ --encoder_layerdrop 0.1 \
127
+ --decoder_layerdrop 0.1 \
128
+ --dropout 0.1 \
129
+ --attention_dropout 0.1 \
130
+ ```
131
+
132
+ ### Summarization Finetuning
133
+ Run/modify `finetune.sh`
134
+
135
+ The following command should work on a 16GB GPU:
136
+ ```bash
137
+ ./finetune.sh \
138
+ --data_dir $XSUM_DIR \
139
+ --train_batch_size=1 \
140
+ --eval_batch_size=1 \
141
+ --output_dir=xsum_results \
142
+ --num_train_epochs 6 \
143
+ --model_name_or_path facebook/bart-large
144
+ ```
145
+
146
+ There is a starter finetuning script for pegasus at `finetune_pegasus_xsum.sh`.
147
+
148
+ ### Translation Finetuning
149
+
150
+ First, follow the wmt_en_ro download instructions.
151
+ Then you can finetune mbart_cc25 on english-romanian with the following command.
152
+ **Recommendation:** Read and potentially modify the fairly opinionated defaults in `train_mbart_cc25_enro.sh` script before running it.
153
+
154
+ Best performing command:
155
+ ```bash
156
+ # optionally
157
+ export ENRO_DIR='wmt_en_ro' # Download instructions above
158
+ # export WANDB_PROJECT="MT" # optional
159
+ export MAX_LEN=128
160
+ export BS=4
161
+ ./train_mbart_cc25_enro.sh --output_dir enro_finetune_baseline --label_smoothing 0.1 --fp16_opt_level=O1 --logger_name wandb --sortish_sampler
162
+ ```
163
+ This should take < 6h/epoch on a 16GB v100 and achieve test BLEU above 26
164
+ To get results in line with fairseq, you need to do some postprocessing. (see `romanian_postprocessing.md`)
165
+
166
+ MultiGPU command
167
+ (using 8 GPUS as an example)
168
+ ```bash
169
+ export ENRO_DIR='wmt_en_ro' # Download instructions above
170
+ # export WANDB_PROJECT="MT" # optional
171
+ export MAX_LEN=128
172
+ export BS=4
173
+ ./train_mbart_cc25_enro.sh --output_dir enro_finetune_baseline --gpus 8 --logger_name wandb
174
+ ```
175
+ ### Finetuning Outputs
176
+ As you train, `output_dir` will be filled with files, that look kind of like this (comments are mine).
177
+ Some of them are metrics, some of them are checkpoints, some of them are metadata. Here is a quick tour:
178
+
179
+ ```bash
180
+ output_dir
181
+ ├── best_tfmr # this is a huggingface checkpoint generated by save_pretrained. It is the same model as the PL .ckpt file below
182
+ │   ├── config.json
183
+ │   ├── merges.txt
184
+ │   ├── pytorch_model.bin
185
+ │   ├── special_tokens_map.json
186
+ │   ├── tokenizer_config.json
187
+ │   └── vocab.json
188
+ ├── git_log.json # repo, branch, and commit hash
189
+ ├── val_avg_rouge2=0.1984-step_count=11.ckpt # this is a pytorch lightning checkpoint associated with the best val score. (it will be called BLEU for MT)
190
+ ├── metrics.json # new validation metrics will continually be appended to this
191
+ ├── student # this is a huggingface checkpoint generated by SummarizationDistiller. It is the student before it gets finetuned.
192
+ │   ├── config.json
193
+ │   └── pytorch_model.bin
194
+ ├── test_generations.txt
195
+ # ^^ are the summaries or translations produced by your best checkpoint on the test data. Populated when training is done
196
+ ├── test_results.txt # a convenience file with the test set metrics. This data is also in metrics.json['test']
197
+ ├── hparams.pkl # the command line args passed after some light preprocessing. Should be saved fairly quickly.
198
+ ```
199
+ After training, you can recover the best checkpoint by running
200
+ ```python
201
+ from transformers import AutoModelForSeq2SeqLM
202
+ model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
203
+ ```
204
+
205
+ ### Fine-tuning using Seq2SeqTrainer
206
+ To use `Seq2SeqTrainer` for fine-tuning you should use the `finetune_trainer.py` script. It subclasses `Trainer` to extend it for seq2seq training. Except the `Trainer` releated `TrainingArguments`, it shares the same argument names as that of `finetune.py` file. One notable difference is that, calculating generative metrics (BLEU, ROUGE) is optional and is controlled using the `--predict_with_generate` argument, set this argument to calculate BLEU and ROUGE metrics.
207
+
208
+ With PyTorch 1.6+ it'll automatically use `native AMP` when `--fp16` is set.
209
+
210
+ To see all the possible command line options, run:
211
+
212
+ ```bash
213
+ ./builtin_trainer/finetune.sh --help # This calls python finetune_trainer.py --help
214
+ ```
215
+
216
+ **At the moment, `Seq2SeqTrainer` does not support *with teacher* distillation.**
217
+
218
+ All `Seq2SeqTrainer` based fine-tuning scripts are included in the `builtin_trainer` directory.
219
+
220
+ #### TPU Training
221
+ `Seq2SeqTrainer` supports TPU training with few caveats
222
+ 1. As `generate` method does not work on TPU at the moment, `predict_with_generate` can not be used. You should use `--prediction_loss_only` to only calculate loss, and do not set `--do_predict` and `--predict_with_generate`.
223
+ 2. All sequences should be padded to be of equal length otherwise it leads to extremely slow training. (`finetune_trainer.py` does this automatically when running on TPU.)
224
+
225
+ We provide a very simple launcher script named `xla_spawn.py` that lets you run our example scripts on multiple TPU cores without any boilerplate. Just pass a --num_cores flag to this script, then your regular training script with its arguments (this is similar to the torch.distributed.launch helper for torch.distributed).
226
+
227
+ `builtin_trainer/finetune_tpu.sh` script provides minimal arguments needed for TPU training.
228
+
229
+ Following command fine-tunes `sshleifer/student_marian_en_ro_6_3` on TPU V3-8 and should complete one epoch in ~5-6 mins.
230
+
231
+ ```bash
232
+ ./builtin_trainer/train_distil_marian_enro_tpu.sh
233
+ ```
234
+
235
+ # DistilBART
236
+ <!---It should be called distilling bart and pegasus, but I don't want to break the link in the paper.-->
237
+ This section describes all code and artifacts from our [Paper](http://arxiv.org/abs/2010.13002)
238
+
239
+ ![DBART](https://huggingface.co/front/thumbnails/distilbart_large.png)
240
+
241
+ + For the CNN/DailyMail dataset, (relatively longer, more extractive summaries), we found a simple technique that works, which we call "Shrink and Fine-tune", or SFT.
242
+ you just copy alternating layers from `facebook/bart-large-cnn` and fine-tune more on the cnn/dm data. `sshleifer/distill-pegasus-cnn-16-4`, `sshleifer/distilbart-cnn-12-6` and all other checkpoints under `sshleifer` that start with `distilbart-cnn` were trained this way.
243
+ + For the XSUM dataset, training on pseudo-labels worked best for Pegasus (`sshleifer/distill-pegasus-16-4`), while training with KD worked best for `distilbart-xsum-12-6`
244
+ + For `sshleifer/dbart-xsum-12-3`
245
+ + We ran 100s experiments, and didn't want to document 100s of commands. If you want a command to replicate a figure from the paper that is not documented below, feel free to ask on the [forums](https://discuss.huggingface.co/t/seq2seq-distillation-methodology-questions/1270) and tag `@sshleifer`.
246
+ + You can see the performance tradeoffs of model sizes [here](https://docs.google.com/spreadsheets/d/1EkhDMwVO02m8jCD1cG3RoFPLicpcL1GQHTQjfvDYgIM/edit#gid=0).
247
+ and more granular timing results [here](https://docs.google.com/spreadsheets/d/1EkhDMwVO02m8jCD1cG3RoFPLicpcL1GQHTQjfvDYgIM/edit#gid=1753259047&range=B2:I23).
248
+
249
+ ### Evaluation
250
+
251
+ use [run_distributed_eval](./run_distributed_eval.py), with the following convenient alias
252
+ ```bash
253
+ deval () {
254
+ proc=$1
255
+ m=$2
256
+ dd=$3
257
+ sd=$4
258
+ shift
259
+ shift
260
+ shift
261
+ shift
262
+ python -m torch.distributed.launch --nproc_per_node=$proc run_distributed_eval.py \
263
+ --model_name $m --save_dir $sd --data_dir $dd $@
264
+ }
265
+ ```
266
+ On a 1 GPU system, here are four commands (that assume `xsum`, `cnn_dm` are downloaded, cmd-F for those links in this file).
267
+
268
+ `distilBART`:
269
+ ```bash
270
+ deval 1 sshleifer/distilbart-xsum-12-3 xsum dbart_12_3_xsum_eval --fp16 # --help for more choices.
271
+ deval 1 sshleifer/distilbart-cnn_dm-12-6 cnn_dm dbart_12_6_cnn_eval --fp16
272
+ ```
273
+
274
+ `distill-pegasus`:
275
+ ```bash
276
+ deval 1 sshleifer/distill-pegasus-cnn-16-4 cnn_dm dpx_cnn_eval
277
+ deval 1 sshleifer/distill-pegasus-xsum-16-4 xsum dpx_xsum_eval
278
+ ```
279
+
280
+ ### Distillation
281
+ + For all of the following commands, you can get roughly equivalent result and faster run times by passing `--num_beams=4`. That's not what we did for the paper.
282
+ + Besides the KD section, you can also run commands with the built-in transformers trainer. See, for example, [builtin_trainer/train_distilbart_cnn.sh](./builtin_trainer/train_distilbart_cnn.sh).
283
+ + Large performance deviations (> 5X slower or more than 0.5 Rouge-2 worse), should be reported.
284
+ + Multi-gpu (controlled with `--gpus` should work, but might require more epochs).
285
+
286
+ #### Recommended Workflow
287
+ + Get your dataset in the right format. (see 6 files above).
288
+ + Find a teacher model [Pegasus](https://huggingface.co/models?search=pegasus) (slower, better ROUGE) or `facebook/bart-large-xsum`/`facebook/bart-large-cnn` (faster, slightly lower.).
289
+ Choose the checkpoint where the corresponding dataset is most similar (or identical to) your dataset.
290
+ + Follow the sections in order below. You can stop after SFT if you are satisfied, or move on to pseudo-labeling if you want more performance.
291
+ + student size: If you want a close to free 50% speedup, cut the decoder in half. If you want a larger speedup, cut it in 4.
292
+ + If your SFT run starts at a validation ROUGE-2 that is more than 10 pts below the teacher's validation ROUGE-2, you have a bug. Switching to a more expensive technique will not help. Try setting a breakpoint and looking at generation and truncation defaults/hyper-parameters, and share your experience on the forums!
293
+
294
+
295
+ #### Initialization
296
+ We use [make_student.py](./make_student.py) to copy alternating layers from the teacher, and save the resulting model to disk
297
+ ```bash
298
+ python make_student.py facebook/bart-large-xsum --save_path dbart_xsum_12_3 -e 12 -d 3
299
+ ```
300
+ or for `pegasus-xsum`
301
+ ```bash
302
+ python make_student.py google/pegasus-xsum --save_path dpx_xsum_16_4 --e 16 --d 4
303
+ ```
304
+ we now have an initialized student saved to `dbart_xsum_12_3`, which we will use for the following commands.
305
+ + Extension: To replicate more complicated initialize experiments in section 6.1, or try your own. Use the `create_student_by_copying_alternating_layers` function.
306
+
307
+ #### Pegasus
308
+ + The following commands are written for BART and will require, at minimum, the following modifications
309
+ + reduce batch size, and increase gradient accumulation steps so that the product `gpus * batch size * gradient_accumulation_steps = 256`. We used `--learning-rate` = 1e-4 * gradient accumulation steps.
310
+ + don't use fp16
311
+ + `--tokenizer_name google/pegasus-large`
312
+
313
+ ### SFT (No Teacher Distillation)
314
+ You don't need `distillation.py`, you can just run:
315
+
316
+ ```bash
317
+ python finetune.py \
318
+ --data_dir xsum \
319
+ --freeze_encoder --freeze_embeds \
320
+ --learning_rate=3e-4 \
321
+ --do_train \
322
+ --do_predict \
323
+ --fp16 --fp16_opt_level=O1 \
324
+ --val_check_interval 0.1 --n_val 1000 --eval_beams 2 --length_penalty=0.5 \
325
+ --max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
326
+ --model_name_or_path dbart_xsum_12_3 \
327
+ --train_batch_size=64 --eval_batch_size=64 \
328
+ --sortish_sampler \
329
+ --num_train_epochs=6 \
330
+ --warmup_steps 500 \
331
+ --output_dir distilbart_xsum_sft_12_3 --gpus 1
332
+ ```
333
+
334
+ + Note: The command that produced `sshleifer/distilbart-cnn-12-6` is at [train_distilbart_cnn.sh](./[train_distilbart_cnn.sh)
335
+
336
+ ```bash
337
+ ./train_distilbart_cnn.sh
338
+ ```
339
+ <!--- runtime: 6H on NVIDIA RTX 24GB GPU -->
340
+ + Tip: You can get the same simple distillation logic by using `distillation.py --no_teacher ` followed by identical arguments as the ones in `train_distilbart_cnn.sh`.
341
+ If you are using `wandb` and comparing the two distillation methods, using this entry point will make your logs consistent,
342
+ because you will have the same hyper-parameters logged in every run.
343
+
344
+ ### Pseudo-Labeling
345
+ + You don't need `distillation.py`.
346
+ + Instructions to generate pseudo-labels and use pre-computed pseudo-labels can be found [here](./precomputed_pseudo_labels.md).
347
+ Simply run `finetune.py` with one of those pseudo-label datasets as `--data_dir` (`DATA`, below).
348
+
349
+ ```bash
350
+ python finetune.py \
351
+ --teacher facebook/bart-large-xsum --data_dir DATA \
352
+ --freeze_encoder --freeze_embeds \
353
+ --learning_rate=3e-4 \
354
+ --do_train \
355
+ --do_predict \
356
+ --fp16 --fp16_opt_level=O1 \
357
+ --val_check_interval 0.1 --n_val 1000 --eval_beams 2 --length_penalty=0.5 \
358
+ --max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
359
+ --model_name_or_path dbart_xsum_12_3 \
360
+ --train_batch_size=32 --eval_batch_size=32 \
361
+ --sortish_sampler \
362
+ --num_train_epochs=5 \
363
+ --warmup_steps 500 \
364
+ --output_dir dbart_xsum_12_3_PL --gpus 1 --logger_name wandb
365
+ ```
366
+
367
+
368
+
369
+ To combine datasets, as in Section 6.2, try something like:
370
+ ```bash
371
+ curl -S https://cdn-datasets.huggingface.co/pseudo/xsum/bart_xsum_pl.tgz | tar -xvz -C .
372
+ curl -S https://cdn-datasets.huggingface.co/pseudo/xsum/pegasus_xsum.tgz | tar -xvz -C .
373
+ curl -S https://cdn-datasets.huggingface.co/summarization/xsum.tar.gz | tar -xvz -C .
374
+ mkdir all_pl
375
+ cat bart_xsum_pl/train.source pegasus_xsum/train.source xsum/train.source > all_pl/train.source
376
+ cat bart_xsum_pl/train.target pegasus_xsum/train.target xsum/train.target > all_pl/train.target
377
+ cp xsum/val* all_pl
378
+ cp xsum/test* all_pl
379
+ ```
380
+ then use `all_pl` as DATA in the command above.
381
+
382
+ #### Direct Knowledge Distillation (KD)
383
+ + In this method, we use try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `BartSummarizationDistiller`.
384
+ + This method was used for `sshleifer/distilbart-xsum-12-6`, `6-6`, and `9-6` checkpoints were produced.
385
+ + You must use [`distillation.py`](./distillation.py). Note that this command initializes the student for you.
386
+
387
+ The command that produced `sshleifer/distilbart-xsum-12-6` is at [./train_distilbart_xsum.sh](train_distilbart_xsum.sh)
388
+ ```bash
389
+ ./train_distilbart_xsum.sh --logger_name wandb --gpus 1
390
+ ```
391
+
392
+ + Expected ROUGE-2 between 21.3 and 21.6, run time ~13H.
393
+ + direct KD + Pegasus is VERY slow and works best with `--supervise_forward --normalize_hidden`.
394
+
395
+ <!--- runtime: 13H on V-100 16GB GPU. -->
396
+
397
+ ### Citation
398
+
399
+ ```bibtex
400
+ @misc{shleifer2020pretrained,
401
+ title={Pre-trained Summarization Distillation},
402
+ author={Sam Shleifer and Alexander M. Rush},
403
+ year={2020},
404
+ eprint={2010.13002},
405
+ archivePrefix={arXiv},
406
+ primaryClass={cs.CL}
407
+ }
408
+ @article{Wolf2019HuggingFacesTS,
409
+ title={HuggingFace's Transformers: State-of-the-art Natural Language Processing},
410
+ author={Thomas Wolf and Lysandre Debut and Victor Sanh and Julien Chaumond and Clement Delangue and Anthony Moi and Pierric Cistac and Tim Rault and Rémi Louf and Morgan Funtowicz and Joe Davison and Sam Shleifer and Patrick von Platen and Clara Ma and Yacine Jernite and Julien Plu and Canwen Xu and Teven Le Scao and Sylvain Gugger and Mariama Drame and Quentin Lhoest and Alexander M. Rush},
411
+ journal={ArXiv},
412
+ year={2019},
413
+ volume={abs/1910.03771}
414
+ }
415
+ ```
416
+
417
+ This is the end of the distillation section, the rest of this doc pertains to general seq2seq commands.
418
+
419
+ ## Evaluation Commands
420
+
421
+ To create summaries for each article in dataset, we use `run_eval.py`, here are a few commands that run eval for different tasks and models.
422
+ If 'translation' is in your task name, the computed metric will be BLEU. Otherwise, ROUGE will be used.
423
+
424
+ For t5, you need to specify --task translation_{src}_to_{tgt} as follows:
425
+ ```bash
426
+ export DATA_DIR=wmt_en_ro
427
+ ./run_eval.py t5-base \
428
+ $DATA_DIR/val.source t5_val_generations.txt \
429
+ --reference_path $DATA_DIR/val.target \
430
+ --score_path enro_bleu.json \
431
+ --task translation_en_to_ro \
432
+ --n_obs 100 \
433
+ --device cuda \
434
+ --fp16 \
435
+ --bs 32
436
+ ```
437
+
438
+ This command works for MBART, although the BLEU score is suspiciously low.
439
+ ```bash
440
+ export DATA_DIR=wmt_en_ro
441
+ ./run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_generations.txt \
442
+ --reference_path $DATA_DIR/val.target \
443
+ --score_path enro_bleu.json \
444
+ --task translation \
445
+ --n_obs 100 \
446
+ --device cuda \
447
+ --fp16 \
448
+ --bs 32
449
+ ```
450
+
451
+ Summarization (xsum will be very similar):
452
+ ```bash
453
+ export DATA_DIR=cnn_dm
454
+ ./run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_generations.txt \
455
+ --reference_path $DATA_DIR/val.target \
456
+ --score_path cnn_rouge.json \
457
+ --task summarization \
458
+ --n_obs 100 \
459
+
460
+ th 56 \
461
+ --fp16 \
462
+ --bs 32
463
+ ```
464
+
465
+ ### Multi-GPU Evaluation
466
+ here is a command to run xsum evaluation on 8 GPUS. It is more than linearly faster than run_eval.py in some cases
467
+ because it uses SortishSampler to minimize padding. You can also use it on 1 GPU. `data_dir` must have
468
+ `{type_path}.source` and `{type_path}.target`. Run `./run_distributed_eval.py --help` for all clargs.
469
+
470
+ ```bash
471
+ python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \
472
+ --model_name sshleifer/distilbart-large-xsum-12-3 \
473
+ --save_dir xsum_generations \
474
+ --data_dir xsum \
475
+ --fp16 # you can pass generate kwargs like num_beams here, just like run_eval.py
476
+ ```
477
+
478
+ Contributions that implement this command for other distributed hardware setups are welcome!
479
+
480
+ #### Single-GPU Eval: Tips and Tricks
481
+
482
+ When using `run_eval.py`, the following features can be useful:
483
+
484
+ * if you running the script multiple times and want to make it easier to track what arguments produced that output, use `--dump-args`. Along with the results it will also dump any custom params that were passed to the script. For example if you used: `--num_beams 8 --early_stopping true`, the output will be:
485
+ ```
486
+ {'bleu': 26.887, 'n_obs': 10, 'runtime': 1, 'seconds_per_sample': 0.1, 'num_beams': 8, 'early_stopping': True}
487
+ ```
488
+
489
+ `--info` is an additional argument available for the same purpose of tracking the conditions of the experiment. It's useful to pass things that weren't in the argument list, e.g. a language pair `--info "lang:en-ru"`. But also if you pass `--info` without a value it will fallback to the current date/time string, e.g. `2020-09-13 18:44:43`.
490
+
491
+ If using `--dump-args --info`, the output will be:
492
+
493
+ ```
494
+ {'bleu': 26.887, 'n_obs': 10, 'runtime': 1, 'seconds_per_sample': 0.1, 'num_beams': 8, 'early_stopping': True, 'info': '2020-09-13 18:44:43'}
495
+ ```
496
+
497
+ If using `--dump-args --info "pair:en-ru chkpt=best`, the output will be:
498
+
499
+ ```
500
+ {'bleu': 26.887, 'n_obs': 10, 'runtime': 1, 'seconds_per_sample': 0.1, 'num_beams': 8, 'early_stopping': True, 'info': 'pair=en-ru chkpt=best'}
501
+ ```
502
+
503
+
504
+ * if you need to perform a parametric search in order to find the best ones that lead to the highest BLEU score, let `run_eval_search.py` to do the searching for you.
505
+
506
+ The script accepts the exact same arguments as `run_eval.py`, plus an additional argument `--search`. The value of `--search` is parsed, reformatted and fed to ``run_eval.py`` as additional args.
507
+
508
+ The format for the `--search` value is a simple string with hparams and colon separated values to try, e.g.:
509
+ ```
510
+ --search "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false"
511
+ ```
512
+ which will generate `12` `(2*3*2)` searches for a product of each hparam. For example the example that was just used will invoke `run_eval.py` repeatedly with:
513
+
514
+ ```
515
+ --num_beams 5 --length_penalty 0.8 --early_stopping true
516
+ --num_beams 5 --length_penalty 0.8 --early_stopping false
517
+ [...]
518
+ --num_beams 10 --length_penalty 1.2 --early_stopping false
519
+ ```
520
+
521
+ On completion, this function prints a markdown table of the results sorted by the best BLEU score and the winning arguments.
522
+
523
+ ```
524
+ bleu | num_beams | length_penalty | early_stopping
525
+ ----- | --------- | -------------- | --------------
526
+ 26.71 | 5 | 1.1 | 1
527
+ 26.66 | 5 | 0.9 | 1
528
+ 26.66 | 5 | 0.9 | 0
529
+ 26.41 | 5 | 1.1 | 0
530
+ 21.94 | 1 | 0.9 | 1
531
+ 21.94 | 1 | 0.9 | 0
532
+ 21.94 | 1 | 1.1 | 1
533
+ 21.94 | 1 | 1.1 | 0
534
+
535
+ Best score args:
536
+ stas/wmt19-en-ru data/en-ru/val.source data/en-ru/test_translations.txt --reference_path data/en-ru/val.target --score_path data/en-ru/test_bleu.json --bs 8 --task translation --num_beams 5 --length_penalty 1.1 --early_stopping True
537
+ ```
538
+
539
+ If you pass `--info "some experiment-specific info"` it will get printed before the results table - this is useful for scripting and multiple runs, so one can tell the different sets of results from each other.
540
+
541
+
542
+ ### Contributing
543
+ - follow the standard contributing guidelines and code of conduct.
544
+ - add tests to `test_seq2seq_examples.py`
545
+ - To run only the seq2seq tests, you must be in the root of the repository and run:
546
+ ```bash
547
+ pytest examples/seq2seq/
548
+ ```
549
+
550
+ ### Converting pytorch-lightning checkpoints
551
+ pytorch lightning ``-do_predict`` often fails, after you are done training, the best way to evaluate your model is to convert it.
552
+
553
+ This should be done for you, with a file called `{save_dir}/best_tfmr`.
554
+
555
+ If that file doesn't exist but you have a lightning `.ckpt` file, you can run
556
+ ```bash
557
+ python convert_pl_checkpoint_to_hf.py PATH_TO_CKPT randomly_initialized_hf_model_path save_dir/best_tfmr
558
+ ```
559
+ Then either `run_eval` or `run_distributed_eval` with `save_dir/best_tfmr` (see previous sections)
560
+
561
+
562
+ # Experimental Features
563
+ These features are harder to use and not always useful.
564
+
565
+ ### Dynamic Batch Size for MT
566
+ `finetune.py` has a command line arg `--max_tokens_per_batch` that allows batches to be dynamically sized.
567
+ This feature can only be used:
568
+ - with fairseq installed
569
+ - on 1 GPU
570
+ - without sortish sampler
571
+ - after calling `./save_len_file.py $tok $data_dir`
572
+
573
+ For example,
574
+ ```bash
575
+ ./save_len_file.py Helsinki-NLP/opus-mt-en-ro wmt_en_ro
576
+ ./dynamic_bs_example.sh --max_tokens_per_batch=2000 --output_dir benchmark_dynamic_bs
577
+ ```
578
+ splits `wmt_en_ro/train` into 11,197 uneven lengthed batches and can finish 1 epoch in 8 minutes on a v100.
579
+
580
+ For comparison,
581
+ ```bash
582
+ ./dynamic_bs_example.sh --sortish_sampler --train_batch_size 48
583
+ ```
584
+ uses 12,723 batches of length 48 and takes slightly more time 9.5 minutes.
585
+
586
+ The feature is still experimental, because:
587
+ + we can make it much more robust if we have memory mapped/preprocessed datasets.
588
+ + The speedup over sortish sampler is not that large at the moment.
589
+
590
+
src/mrc_client/seq2seq/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+
5
+ sys.path.insert(1, os.path.dirname(os.path.realpath(__file__)))
src/mrc_client/seq2seq/callbacks.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
9
+ from pytorch_lightning.utilities import rank_zero_only
10
+
11
+ from seq2seq_utils import save_json
12
+
13
+
14
+ def count_trainable_parameters(model):
15
+ model_parameters = filter(lambda p: p.requires_grad, model.parameters())
16
+ params = sum([np.prod(p.size()) for p in model_parameters])
17
+ return params
18
+
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class Seq2SeqLoggingCallback(pl.Callback):
24
+ def on_batch_end(self, trainer, pl_module):
25
+ lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
26
+ pl_module.logger.log_metrics(lrs)
27
+
28
+ @rank_zero_only
29
+ def _write_logs(
30
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
31
+ ) -> None:
32
+ logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
33
+ metrics = trainer.callback_metrics
34
+ trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
35
+ # Log results
36
+ od = Path(pl_module.hparams.output_dir)
37
+ if type_path == "test":
38
+ results_file = od / "test_results.txt"
39
+ generations_file = od / "test_generations.txt"
40
+ else:
41
+ # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
42
+ # If people want this it will be easy enough to add back.
43
+ results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
44
+ generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
45
+ results_file.parent.mkdir(exist_ok=True)
46
+ generations_file.parent.mkdir(exist_ok=True)
47
+ with open(results_file, "a+") as writer:
48
+ for key in sorted(metrics):
49
+ if key in ["log", "progress_bar", "preds"]:
50
+ continue
51
+ val = metrics[key]
52
+ if isinstance(val, torch.Tensor):
53
+ val = val.item()
54
+ msg = f"{key}: {val:.6f}\n"
55
+ writer.write(msg)
56
+
57
+ if not save_generations:
58
+ return
59
+
60
+ if "preds" in metrics:
61
+ content = "\n".join(metrics["preds"])
62
+ generations_file.open("w+").write(content)
63
+
64
+ @rank_zero_only
65
+ def on_train_start(self, trainer, pl_module):
66
+ try:
67
+ npars = pl_module.model.model.num_parameters()
68
+ except AttributeError:
69
+ npars = pl_module.model.num_parameters()
70
+
71
+ n_trainable_pars = count_trainable_parameters(pl_module)
72
+ # mp stands for million parameters
73
+ trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
74
+
75
+ @rank_zero_only
76
+ def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
77
+ save_json(pl_module.metrics, pl_module.metrics_save_path)
78
+ return self._write_logs(trainer, pl_module, "test")
79
+
80
+ @rank_zero_only
81
+ def on_validation_end(self, trainer: pl.Trainer, pl_module):
82
+ save_json(pl_module.metrics, pl_module.metrics_save_path)
83
+ # Uncommenting this will save val generations
84
+ # return self._write_logs(trainer, pl_module, "valid")
85
+
86
+
87
+ def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
88
+ """Saves the best model by validation ROUGE2 score."""
89
+ if metric == "rouge2":
90
+ exp = "{val_avg_rouge2:.4f}-{step_count}"
91
+ elif metric == "bleu":
92
+ exp = "{val_avg_bleu:.4f}-{step_count}"
93
+ elif metric == "loss":
94
+ exp = "{val_avg_loss:.4f}-{step_count}"
95
+ else:
96
+ raise NotImplementedError(
97
+ f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to this function."
98
+ )
99
+
100
+ checkpoint_callback = ModelCheckpoint(
101
+ filepath=os.path.join(output_dir, exp),
102
+ monitor=f"val_{metric}",
103
+ mode="min" if "loss" in metric else "max",
104
+ save_top_k=save_top_k,
105
+ )
106
+ return checkpoint_callback
107
+
108
+
109
+ def get_early_stopping_callback(metric, patience):
110
+ return EarlyStopping(
111
+ monitor=f"val_{metric}", # does this need avg?
112
+ mode="min" if "loss" in metric else "max",
113
+ patience=patience,
114
+ verbose=True,
115
+ )
src/mrc_client/seq2seq/cjjpy.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ @Author : Jiangjie Chen
5
+ @Time : 2018/11/15 17:08
6
+ @Contact: jjchen19@fudan.edu.cn
7
+ '''
8
+
9
+ import re
10
+ import datetime
11
+ import os
12
+ import argparse
13
+ import logging
14
+ import traceback
15
+
16
+ try:
17
+ import ujson as json
18
+ except:
19
+ import json
20
+
21
+ HADOOP_BIN = 'PATH=/usr/bin/:$PATH hdfs'
22
+ FOR_PUBLIC = True
23
+
24
+
25
+ def LengthStats(filename):
26
+ len_list = []
27
+ thresholds = [0.8, 0.9, 0.95, 0.99, 0.999]
28
+ with open(filename) as f:
29
+ for line in f:
30
+ len_list.append(len(line.strip().split()))
31
+ stats = {
32
+ 'Max': max(len_list),
33
+ 'Min': min(len_list),
34
+ 'Avg': round(sum(len_list) / len(len_list), 4),
35
+ }
36
+ len_list.sort()
37
+ for t in thresholds:
38
+ stats[f"Top-{t}"] = len_list[int(len(len_list) * t)]
39
+
40
+ for k in stats:
41
+ print(f"- {k}: {stats[k]}")
42
+ return stats
43
+
44
+
45
+ class AttrDict(dict):
46
+ def __init__(self, *args, **kwargs):
47
+ super(AttrDict, self).__init__(*args, **kwargs)
48
+ self.__dict__ = self
49
+
50
+
51
+ def TraceBack(error_msg):
52
+ exc = traceback.format_exc()
53
+ msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}'
54
+ return msg
55
+
56
+
57
+ def Now():
58
+ return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
59
+
60
+
61
+ def AbsParentDir(file, parent='..', postfix=None):
62
+ ppath = os.path.abspath(file)
63
+ parent_level = parent.count('.')
64
+ while parent_level > 0:
65
+ ppath = os.path.dirname(ppath)
66
+ parent_level -= 1
67
+ if postfix is not None:
68
+ return os.path.join(ppath, postfix)
69
+ else:
70
+ return ppath
71
+
72
+
73
+ def init_logger(log_file=None, log_file_level=logging.NOTSET, from_scratch=False):
74
+ from coloredlogs import ColoredFormatter
75
+ import tensorflow as tf
76
+
77
+ fmt = "[%(asctime)s %(levelname)s] %(message)s"
78
+ log_format = ColoredFormatter(fmt=fmt)
79
+ # log_format = logging.Formatter()
80
+ logger = logging.getLogger()
81
+ logger.setLevel(log_file_level)
82
+
83
+ console_handler = logging.StreamHandler()
84
+ console_handler.setFormatter(log_format)
85
+ logger.handlers = [console_handler]
86
+
87
+ if log_file and log_file != '':
88
+ if from_scratch and tf.io.gfile.exists(log_file):
89
+ logger.warning('Removing previous log file: %s' % log_file)
90
+ tf.io.gfile.remove(log_file)
91
+ path = os.path.dirname(log_file)
92
+ os.makedirs(path, exist_ok=True)
93
+ file_handler = logging.FileHandler(log_file)
94
+ file_handler.setLevel(log_file_level)
95
+ file_handler.setFormatter(log_format)
96
+ logger.addHandler(file_handler)
97
+
98
+ return logger
99
+
100
+
101
+ def OverWriteCjjPy(root='.'):
102
+ # import difflib
103
+ # diff = difflib.HtmlDiff()
104
+ cnt = 0
105
+ golden_cjjpy = os.path.join(root, 'cjjpy.py')
106
+ # golden_content = open(golden_cjjpy).readlines()
107
+ for dir, folder, file in os.walk(root):
108
+ for f in file:
109
+ if f == 'cjjpy.py':
110
+ cjjpy = '%s/%s' % (dir, f)
111
+ # content = open(cjjpy).readlines()
112
+ # d = diff.make_file(golden_content, content)
113
+ cnt += 1
114
+ print('[%d]: %s' % (cnt, cjjpy))
115
+ os.system('cp %s %s' % (golden_cjjpy, cjjpy))
116
+
117
+
118
+ def ChangeFileFormat(filename, new_fmt):
119
+ assert type(filename) is str and type(new_fmt) is str
120
+ spt = filename.split('.')
121
+ if len(spt) == 0:
122
+ return filename
123
+ else:
124
+ return filename.replace('.' + spt[-1], new_fmt)
125
+
126
+
127
+ def CountLines(fname):
128
+ with open(fname, 'rb') as f:
129
+ count = 0
130
+ last_data = '\n'
131
+ while True:
132
+ data = f.read(0x400000)
133
+ if not data:
134
+ break
135
+ count += data.count(b'\n')
136
+ last_data = data
137
+ if last_data[-1:] != b'\n':
138
+ count += 1 # Remove this if a wc-like count is needed
139
+ return count
140
+
141
+
142
+ def GetDate():
143
+ return str(datetime.datetime.now())[5:10].replace('-', '')
144
+
145
+
146
+ def TimeClock(seconds):
147
+ sec = int(seconds)
148
+ hour = int(sec / 3600)
149
+ min = int((sec - hour * 3600) / 60)
150
+ ssec = float(seconds) - hour * 3600 - min * 60
151
+ # return '%dh %dm %.2fs' % (hour, min, ssec)
152
+ return '{0:>2d}h{1:>3d}m{2:>6.2f}s'.format(hour, min, ssec)
153
+
154
+
155
+ def StripAll(text):
156
+ return text.strip().replace('\t', '').replace('\n', '').replace(' ', '')
157
+
158
+
159
+ def GetBracket(text, bracket, en_br=False):
160
+ # input should be aa(bb)cc, True for bracket, False for text
161
+ if bracket:
162
+ try:
163
+ return re.findall('\((.*?)\)', text.strip())[-1]
164
+ except:
165
+ return ''
166
+ else:
167
+ if en_br:
168
+ text = re.sub('\(.*?\)', '', text.strip())
169
+ return re.sub('(.*?)', '', text.strip())
170
+
171
+
172
+ def CharLang(uchar, lang):
173
+ assert lang.lower() in ['en', 'cn', 'zh']
174
+ if lang.lower() in ['cn', 'zh']:
175
+ if uchar >= '\u4e00' and uchar <= '\u9fa5':
176
+ return True
177
+ else:
178
+ return False
179
+ elif lang.lower() == 'en':
180
+ if (uchar <= 'Z' and uchar >= 'A') or (uchar <= 'z' and uchar >= 'a'):
181
+ return True
182
+ else:
183
+ return False
184
+ else:
185
+ raise NotImplementedError
186
+
187
+
188
+ def WordLang(word, lang):
189
+ for i in word.strip():
190
+ if i.isspace(): continue
191
+ if not CharLang(i, lang):
192
+ return False
193
+ return True
194
+
195
+
196
+ def SortDict(_dict, reverse=True):
197
+ assert type(_dict) is dict
198
+ return sorted(_dict.items(), key=lambda d: d[1], reverse=reverse)
199
+
200
+
201
+ def lark(content='test'):
202
+ print(content)
203
+
204
+
205
+ if __name__ == '__main__':
206
+ parser = argparse.ArgumentParser()
207
+
208
+ parser.add_argument('--diff', nargs=2,
209
+ help='show difference between two files, shown in downloads/diff.html')
210
+ parser.add_argument('--de_unicode', action='store_true', default=False,
211
+ help='remove unicode characters')
212
+ parser.add_argument('--link_entity', action='store_true', default=False,
213
+ help='')
214
+ parser.add_argument('--max_comm_len', action='store_true', default=False,
215
+ help='')
216
+ parser.add_argument('--search', nargs=2,
217
+ help='search key from file, 2 args: file name & key')
218
+ parser.add_argument('--email', nargs=2,
219
+ help='sending emails, 2 args: subject & content')
220
+ parser.add_argument('--overwrite', action='store_true', default=None,
221
+ help='overwrite all cjjpy under given *dir* based on *dir*/cjjpy.py')
222
+ parser.add_argument('--replace', nargs=3,
223
+ help='replace char, 3 args: file name & replaced char & replacer char')
224
+ parser.add_argument('--lark', nargs=1)
225
+ parser.add_argument('--get_hdfs', nargs=2,
226
+ help='easy copy from hdfs to local fs, 2 args: remote_file/dir & local_dir')
227
+ parser.add_argument('--put_hdfs', nargs=2,
228
+ help='easy put from local fs to hdfs, 2 args: local_file/dir & remote_dir')
229
+ parser.add_argument('--length_stats', nargs=1,
230
+ help='simple token lengths distribution of a line-by-line file')
231
+
232
+ args = parser.parse_args()
233
+
234
+ if args.overwrite:
235
+ print('* Overwriting cjjpy...')
236
+ OverWriteCjjPy()
237
+
238
+ if args.lark:
239
+ try:
240
+ content = args.lark[0]
241
+ except:
242
+ content = 'running complete'
243
+ print(f'* Larking "{content}"...')
244
+ lark(content)
245
+
246
+ if args.length_stats:
247
+ file = args.length_stats[0]
248
+ print(f'* Working on {file} lengths statistics...')
249
+ LengthStats(file)
src/mrc_client/seq2seq/convert_pl_checkpoint_to_hf.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Dict, List
6
+
7
+ import fire
8
+ import torch
9
+
10
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
11
+ from transformers.utils.logging import get_logger
12
+
13
+
14
+ logger = get_logger(__name__)
15
+
16
+
17
+ def remove_prefix(text: str, prefix: str):
18
+ if text.startswith(prefix):
19
+ return text[len(prefix) :]
20
+ return text # or whatever
21
+
22
+
23
+ def sanitize(sd):
24
+ return {remove_prefix(k, "model."): v for k, v in sd.items()}
25
+
26
+
27
+ def average_state_dicts(state_dicts: List[Dict[str, torch.Tensor]]):
28
+ new_sd = {}
29
+ for k in state_dicts[0].keys():
30
+ tensors = [sd[k] for sd in state_dicts]
31
+ new_t = sum(tensors) / len(tensors)
32
+ assert isinstance(new_t, torch.Tensor)
33
+ new_sd[k] = new_t
34
+ return new_sd
35
+
36
+
37
+ def convert_pl_to_hf(pl_ckpt_path: str, hf_src_model_dir: str, save_path: str) -> None:
38
+ """Cleanup a pytorch-lightning .ckpt file or experiment dir and save a huggingface model with that state dict.
39
+ Silently allows extra pl keys (like teacher.) Puts all ckpt models into CPU RAM at once!
40
+
41
+ Args:
42
+ pl_ckpt_path (:obj:`str`): Path to a .ckpt file saved by pytorch_lightning or dir containing ckpt files.
43
+ If a directory is passed, all .ckpt files inside it will be averaged!
44
+ hf_src_model_dir (:obj:`str`): Path to a directory containing a correctly shaped checkpoint
45
+ save_path (:obj:`str`): Directory to save the new model
46
+
47
+ """
48
+ hf_model = AutoModelForSeq2SeqLM.from_pretrained(hf_src_model_dir)
49
+ if os.path.isfile(pl_ckpt_path):
50
+ ckpt_files = [pl_ckpt_path]
51
+ else:
52
+ assert os.path.isdir(pl_ckpt_path)
53
+ ckpt_files = list(Path(pl_ckpt_path).glob("*.ckpt"))
54
+ assert ckpt_files, f"could not find any ckpt files inside the {pl_ckpt_path} directory"
55
+
56
+ if len(ckpt_files) > 1:
57
+ logger.info(f"averaging the weights of {ckpt_files}")
58
+
59
+ state_dicts = [sanitize(torch.load(x, map_location="cpu")["state_dict"]) for x in ckpt_files]
60
+ state_dict = average_state_dicts(state_dicts)
61
+
62
+ missing, unexpected = hf_model.load_state_dict(state_dict, strict=False)
63
+ assert not missing, f"missing keys: {missing}"
64
+ hf_model.save_pretrained(save_path)
65
+ try:
66
+ tok = AutoTokenizer.from_pretrained(hf_src_model_dir)
67
+ tok.save_pretrained(save_path)
68
+ except Exception:
69
+ pass
70
+ # dont copy tokenizer if cant
71
+
72
+
73
+ if __name__ == "__main__":
74
+ fire.Fire(convert_pl_to_hf)
src/mrc_client/seq2seq/finetune.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import argparse
4
+ import glob
5
+ import logging
6
+ import os
7
+ import sys
8
+ import time
9
+ from collections import defaultdict
10
+ from pathlib import Path
11
+ from typing import Dict, List, Tuple
12
+
13
+ import numpy as np
14
+ import pytorch_lightning as pl
15
+ import torch
16
+ from torch.utils.data import DataLoader
17
+
18
+ from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
19
+ from transformers import MBartTokenizer, T5ForConditionalGeneration
20
+ try:
21
+ from transformers.modeling_bart import shift_tokens_right
22
+ except:
23
+ from transformers.models.bart.modeling_bart import shift_tokens_right
24
+ from seq2seq_utils import (
25
+ ROUGE_KEYS,
26
+ LegacySeq2SeqDataset,
27
+ Seq2SeqDataset,
28
+ UniQASeq2SeqDataset,
29
+ assert_all_frozen,
30
+ calculate_bleu,
31
+ calculate_rouge,
32
+ check_output_dir,
33
+ flatten_list,
34
+ freeze_embeds,
35
+ freeze_params,
36
+ get_git_info,
37
+ label_smoothed_nll_loss,
38
+ lmap,
39
+ pickle_save,
40
+ save_git_info,
41
+ save_json,
42
+ use_task_specific_params,
43
+ )
44
+
45
+
46
+ # need the parent dir module
47
+ sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
48
+ from lightning_base import BaseTransformer, add_generic_args, generic_train # noqa
49
+
50
+
51
+ logger = logging.getLogger(__name__)
52
+
53
+
54
+ class SummarizationModule(BaseTransformer):
55
+ mode = "summarization"
56
+ loss_names = ["loss"]
57
+ metric_names = ROUGE_KEYS
58
+ default_val_metric = "rouge2"
59
+
60
+ def __init__(self, hparams, **kwargs):
61
+ if hparams.sortish_sampler and hparams.gpus > 1:
62
+ hparams.replace_sampler_ddp = False
63
+ elif hparams.max_tokens_per_batch is not None:
64
+ if hparams.gpus > 1:
65
+ raise NotImplementedError("Dynamic Batch size does not work for multi-gpu training")
66
+ if hparams.sortish_sampler:
67
+ raise ValueError("--sortish_sampler and --max_tokens_per_batch may not be used simultaneously")
68
+
69
+ super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
70
+ use_task_specific_params(self.model, "summarization")
71
+ # TODO: hard-encoded length constraint
72
+ self.model.config.min_length = hparams.min_target_length
73
+ self.model.config.max_length = hparams.max_target_length
74
+ save_git_info(self.hparams.output_dir)
75
+ self.metrics_save_path = Path(self.output_dir) / "metrics.json"
76
+ self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
77
+ pickle_save(self.hparams, self.hparams_save_path)
78
+ self.step_count = 0
79
+ self.metrics = defaultdict(list)
80
+ self.model_type = self.config.model_type
81
+ self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size
82
+
83
+ self.dataset_kwargs: dict = dict(
84
+ data_dir=self.hparams.data_dir,
85
+ max_source_length=self.hparams.max_source_length,
86
+ prefix=self.model.config.prefix or "",
87
+ )
88
+ n_observations_per_split = {
89
+ "train": self.hparams.n_train,
90
+ "val": self.hparams.n_val,
91
+ "test": self.hparams.n_test,
92
+ }
93
+ self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
94
+
95
+ self.target_lens = {
96
+ "train": self.hparams.max_target_length,
97
+ "val": self.hparams.val_max_target_length,
98
+ "test": self.hparams.test_max_target_length,
99
+ }
100
+ assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
101
+ assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
102
+ if self.hparams.freeze_embeds:
103
+ freeze_embeds(self.model)
104
+ if self.hparams.freeze_encoder:
105
+ freeze_params(self.model.get_encoder())
106
+ assert_all_frozen(self.model.get_encoder())
107
+
108
+ self.hparams.git_sha = get_git_info()["repo_sha"]
109
+ self.num_workers = hparams.num_workers
110
+ self.decoder_start_token_id = None # default to config
111
+ if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
112
+ self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
113
+ self.model.config.decoder_start_token_id = self.decoder_start_token_id
114
+
115
+ if 'unifiedqa' in self.hparams.model_name_or_path:
116
+ self.dataset_class = (UniQASeq2SeqDataset)
117
+ else:
118
+ self.dataset_class = (
119
+ Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
120
+ )
121
+ self.already_saved_batch = False
122
+ self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
123
+ if self.hparams.eval_max_gen_length is not None:
124
+ self.eval_max_length = self.hparams.eval_max_gen_length
125
+ else:
126
+ self.eval_max_length = self.model.config.max_length
127
+ if self.hparams.min_target_length is not None:
128
+ self.min_length = self.hparams.min_target_length
129
+ else:
130
+ self.min_length = self.model.config.min_length
131
+
132
+ self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
133
+
134
+ def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
135
+ """A debugging utility"""
136
+ readable_batch = {
137
+ k: self.tokenizer.batch_decode(v.tolist()) if "mask" not in k else v.shape for k, v in batch.items()
138
+ }
139
+ save_json(readable_batch, Path(self.output_dir) / "text_batch.json")
140
+ save_json({k: v.tolist() for k, v in batch.items()}, Path(self.output_dir) / "tok_batch.json")
141
+
142
+ self.already_saved_batch = True
143
+ return readable_batch
144
+
145
+ def forward(self, input_ids, **kwargs):
146
+ return self.model(input_ids, **kwargs)
147
+
148
+ def ids_to_clean_text(self, generated_ids: List[int]):
149
+ gen_text = self.tokenizer.batch_decode(
150
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
151
+ )
152
+ return lmap(str.strip, gen_text)
153
+
154
+ def _step(self, batch: dict) -> Tuple:
155
+ pad_token_id = self.tokenizer.pad_token_id
156
+ src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
157
+ tgt_ids = batch["labels"]
158
+ if isinstance(self.model, T5ForConditionalGeneration):
159
+ decoder_input_ids = self.model._shift_right(tgt_ids)
160
+ else:
161
+ decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
162
+ if not self.already_saved_batch: # This would be slightly better if it only happened on rank zero
163
+ batch["decoder_input_ids"] = decoder_input_ids
164
+ self.save_readable_batch(batch)
165
+
166
+ outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
167
+ lm_logits = outputs[0]
168
+ if self.hparams.label_smoothing == 0:
169
+ # Same behavior as modeling_bart.py, besides ignoring pad_token_id
170
+ ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
171
+
172
+ assert lm_logits.shape[-1] == self.vocab_size
173
+ loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
174
+ else:
175
+ lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
176
+ loss, nll_loss = label_smoothed_nll_loss(
177
+ lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
178
+ )
179
+ return (loss,)
180
+
181
+ @property
182
+ def pad(self) -> int:
183
+ return self.tokenizer.pad_token_id
184
+
185
+ def training_step(self, batch, batch_idx) -> Dict:
186
+ loss_tensors = self._step(batch)
187
+
188
+ logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
189
+ # tokens per batch
190
+ logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
191
+ logs["bs"] = batch["input_ids"].shape[0]
192
+ logs["src_pad_tok"] = batch["input_ids"].eq(self.pad).sum()
193
+ logs["src_pad_frac"] = batch["input_ids"].eq(self.pad).float().mean()
194
+ # TODO(SS): make a wandb summary metric for this
195
+ return {"loss": loss_tensors[0], "log": logs}
196
+
197
+ def validation_step(self, batch, batch_idx) -> Dict:
198
+ return self._generative_step(batch)
199
+
200
+ def validation_epoch_end(self, outputs, prefix="val") -> Dict:
201
+ self.step_count += 1
202
+ losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
203
+ loss = losses["loss"]
204
+ generative_metrics = {
205
+ k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
206
+ }
207
+ metric_val = (
208
+ generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
209
+ )
210
+ metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
211
+ generative_metrics.update({k: v.item() for k, v in losses.items()})
212
+ losses.update(generative_metrics)
213
+ all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
214
+ all_metrics["step_count"] = self.step_count
215
+ self.metrics[prefix].append(all_metrics) # callback writes this to self.metrics_save_path
216
+ preds = flatten_list([x["preds"] for x in outputs])
217
+ return {
218
+ "log": all_metrics,
219
+ "preds": preds,
220
+ f"{prefix}_loss": loss,
221
+ f"{prefix}_{self.val_metric}": metric_tensor,
222
+ }
223
+
224
+ def calc_generative_metrics(self, preds, target) -> Dict:
225
+ return calculate_rouge(preds, target)
226
+
227
+ def _generative_step(self, batch: dict) -> dict:
228
+ t0 = time.time()
229
+
230
+ # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
231
+ generated_ids = self.model.generate(
232
+ batch["input_ids"],
233
+ attention_mask=batch["attention_mask"],
234
+ use_cache=True,
235
+ decoder_start_token_id=self.decoder_start_token_id,
236
+ num_beams=self.eval_beams,
237
+ max_length=self.eval_max_length,
238
+ min_length=self.min_length
239
+ )
240
+ gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
241
+ preds: List[str] = self.ids_to_clean_text(generated_ids)
242
+ target: List[str] = self.ids_to_clean_text(batch["labels"])
243
+ loss_tensors = self._step(batch)
244
+ base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
245
+ rouge: Dict = self.calc_generative_metrics(preds, target)
246
+ summ_len = np.mean(lmap(len, generated_ids))
247
+ base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
248
+ return base_metrics
249
+
250
+ def test_step(self, batch, batch_idx):
251
+ return self._generative_step(batch)
252
+
253
+ def test_epoch_end(self, outputs):
254
+ return self.validation_epoch_end(outputs, prefix="test")
255
+
256
+ def get_dataset(self, type_path) -> Seq2SeqDataset:
257
+ n_obs = self.n_obs[type_path]
258
+ max_target_length = self.target_lens[type_path]
259
+ dataset = self.dataset_class(
260
+ self.tokenizer,
261
+ type_path=type_path,
262
+ n_obs=n_obs,
263
+ max_target_length=max_target_length,
264
+ **self.dataset_kwargs,
265
+ )
266
+ return dataset
267
+
268
+ def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
269
+ dataset = self.get_dataset(type_path)
270
+
271
+ if self.hparams.sortish_sampler and type_path != "test" and type_path != "val":
272
+ sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
273
+ return DataLoader(
274
+ dataset,
275
+ batch_size=batch_size,
276
+ collate_fn=dataset.collate_fn,
277
+ shuffle=False,
278
+ num_workers=self.num_workers,
279
+ sampler=sampler,
280
+ )
281
+
282
+ elif self.hparams.max_tokens_per_batch is not None and type_path != "test" and type_path != "val":
283
+ batch_sampler = dataset.make_dynamic_sampler(
284
+ self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
285
+ )
286
+ return DataLoader(
287
+ dataset,
288
+ batch_sampler=batch_sampler,
289
+ collate_fn=dataset.collate_fn,
290
+ # shuffle=False,
291
+ num_workers=self.num_workers,
292
+ # batch_size=None,
293
+ )
294
+ else:
295
+ return DataLoader(
296
+ dataset,
297
+ batch_size=batch_size,
298
+ collate_fn=dataset.collate_fn,
299
+ shuffle=shuffle,
300
+ num_workers=self.num_workers,
301
+ sampler=None,
302
+ )
303
+
304
+ def train_dataloader(self) -> DataLoader:
305
+ dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
306
+ return dataloader
307
+
308
+ def val_dataloader(self) -> DataLoader:
309
+ return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
310
+
311
+ def test_dataloader(self) -> DataLoader:
312
+ return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
313
+
314
+ @staticmethod
315
+ def add_model_specific_args(parser, root_dir):
316
+ BaseTransformer.add_model_specific_args(parser, root_dir)
317
+ add_generic_args(parser, root_dir)
318
+ parser.add_argument(
319
+ "--min_target_length",
320
+ default=1,
321
+ type=int,
322
+ help="The minimum total target sequence length after tokenization.",
323
+ )
324
+ parser.add_argument(
325
+ "--max_source_length",
326
+ default=1024,
327
+ type=int,
328
+ help="The maximum total input sequence length after tokenization. Sequences longer "
329
+ "than this will be truncated, sequences shorter will be padded.",
330
+ )
331
+ parser.add_argument(
332
+ "--max_target_length",
333
+ default=56,
334
+ type=int,
335
+ help="The maximum total input sequence length after tokenization. Sequences longer "
336
+ "than this will be truncated, sequences shorter will be padded.",
337
+ )
338
+ parser.add_argument(
339
+ "--val_max_target_length",
340
+ default=142, # these defaults are optimized for CNNDM. For xsum, see README.md.
341
+ type=int,
342
+ help="The maximum total input sequence length after tokenization. Sequences longer "
343
+ "than this will be truncated, sequences shorter will be padded.",
344
+ )
345
+ parser.add_argument(
346
+ "--test_max_target_length",
347
+ default=142,
348
+ type=int,
349
+ help="The maximum total input sequence length after tokenization. Sequences longer "
350
+ "than this will be truncated, sequences shorter will be padded.",
351
+ )
352
+ parser.add_argument("--freeze_encoder", action="store_true")
353
+ parser.add_argument("--freeze_embeds", action="store_true")
354
+ parser.add_argument("--sortish_sampler", action="store_true", default=False)
355
+ parser.add_argument("--overwrite_output_dir", action="store_true", default=False)
356
+ parser.add_argument("--max_tokens_per_batch", type=int, default=None)
357
+ parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
358
+ parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
359
+ parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
360
+ parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
361
+ parser.add_argument(
362
+ "--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
363
+ )
364
+ parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
365
+ parser.add_argument("--src_lang", type=str, default="", required=False)
366
+ parser.add_argument("--tgt_lang", type=str, default="", required=False)
367
+ parser.add_argument("--eval_beams", type=int, default=None, required=False)
368
+ parser.add_argument(
369
+ "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
370
+ )
371
+ parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
372
+ parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
373
+ parser.add_argument(
374
+ "--early_stopping_patience",
375
+ type=int,
376
+ default=-1,
377
+ required=False,
378
+ help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
379
+ )
380
+ return parser
381
+
382
+
383
+ class TranslationModule(SummarizationModule):
384
+ mode = "translation"
385
+ loss_names = ["loss"]
386
+ metric_names = ["bleu"]
387
+ default_val_metric = "bleu"
388
+
389
+ def __init__(self, hparams, **kwargs):
390
+ super().__init__(hparams, **kwargs)
391
+ self.dataset_kwargs["src_lang"] = hparams.src_lang
392
+ self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
393
+
394
+ def calc_generative_metrics(self, preds, target) -> dict:
395
+ return calculate_bleu(preds, target)
396
+
397
+
398
+ def main(args, model=None) -> SummarizationModule:
399
+ Path(args.output_dir).mkdir(exist_ok=True)
400
+ check_output_dir(args, expected_items=3)
401
+
402
+ if model is None:
403
+ if "summarization" in args.task:
404
+ model: SummarizationModule = SummarizationModule(args)
405
+ else:
406
+ model: SummarizationModule = TranslationModule(args)
407
+ dataset = Path(args.data_dir).name
408
+ if (
409
+ args.logger_name == "default"
410
+ or args.fast_dev_run
411
+ or str(args.output_dir).startswith("/tmp")
412
+ or str(args.output_dir).startswith("/var")
413
+ ):
414
+ logger = True # don't pollute wandb logs unnecessarily
415
+ elif args.logger_name == "wandb":
416
+ from pytorch_lightning.loggers import WandbLogger
417
+
418
+ project = os.environ.get("WANDB_PROJECT", dataset)
419
+ logger = WandbLogger(name=model.output_dir.name, project=project)
420
+
421
+ elif args.logger_name == "wandb_shared":
422
+ from pytorch_lightning.loggers import WandbLogger
423
+
424
+ logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
425
+
426
+ if args.early_stopping_patience >= 0:
427
+ es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
428
+ else:
429
+ es_callback = False
430
+
431
+ lower_is_better = args.val_metric == "loss"
432
+ trainer: pl.Trainer = generic_train(
433
+ model,
434
+ args,
435
+ logging_callback=Seq2SeqLoggingCallback(),
436
+ checkpoint_callback=get_checkpoint_callback(
437
+ args.output_dir, model.val_metric, args.save_top_k, lower_is_better
438
+ ),
439
+ early_stopping_callback=es_callback,
440
+ logger=logger,
441
+ )
442
+ pickle_save(model.hparams, model.output_dir / "hparams.pkl")
443
+ if not args.do_predict:
444
+ return model
445
+
446
+ model.hparams.test_checkpoint = ""
447
+ checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
448
+ if checkpoints:
449
+ model.hparams.test_checkpoint = checkpoints[-1]
450
+ trainer.resume_from_checkpoint = checkpoints[-1]
451
+ trainer.logger.log_hyperparams(model.hparams)
452
+
453
+ # test() without a model tests using the best checkpoint automatically
454
+ trainer.test()
455
+ return model
456
+
457
+
458
+ if __name__ == "__main__":
459
+ parser = argparse.ArgumentParser()
460
+ parser = pl.Trainer.add_argparse_args(parser)
461
+ parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
462
+
463
+ args = parser.parse_args()
464
+
465
+ main(args)
src/mrc_client/seq2seq/finetune_t5.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Add parent directory to python path to access lightning_base.py
2
+ export PYTHONPATH="../":"${PYTHONPATH}"
3
+
4
+ python finetune.py \
5
+ --data_dir=$CNN_DIR \
6
+ --learning_rate=3e-5 \
7
+ --train_batch_size=$BS \
8
+ --eval_batch_size=$BS \
9
+ --output_dir=$OUTPUT_DIR \
10
+ --max_source_length=512 \
11
+ --max_target_length=56 \
12
+ --val_check_interval=0.1 --n_val=200 \
13
+ --do_train --do_predict \
14
+ "$@"
src/mrc_client/seq2seq/finetune_trainer.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import sys
4
+ from dataclasses import dataclass, field
5
+ from typing import Optional
6
+
7
+ from seq2seq_trainer import Seq2SeqTrainer
8
+ from seq2seq_training_args import Seq2SeqTrainingArguments
9
+ from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser, MBartTokenizer, set_seed
10
+ from transformers.trainer_utils import EvaluationStrategy
11
+ from seq2seq_utils import (
12
+ Seq2SeqDataCollator,
13
+ Seq2SeqDataset,
14
+ assert_all_frozen,
15
+ build_compute_metrics_fn,
16
+ check_output_dir,
17
+ freeze_embeds,
18
+ freeze_params,
19
+ lmap,
20
+ save_json,
21
+ use_task_specific_params,
22
+ write_txt_file,
23
+ )
24
+
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ @dataclass
30
+ class ModelArguments:
31
+ """
32
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
33
+ """
34
+
35
+ model_name_or_path: str = field(
36
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
37
+ )
38
+ config_name: Optional[str] = field(
39
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
40
+ )
41
+ tokenizer_name: Optional[str] = field(
42
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
43
+ )
44
+ cache_dir: Optional[str] = field(
45
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
46
+ )
47
+ freeze_encoder: bool = field(default=False, metadata={"help": "Whether tp freeze the encoder."})
48
+ freeze_embeds: bool = field(default=False, metadata={"help": "Whether to freeze the embeddings."})
49
+
50
+
51
+ @dataclass
52
+ class DataTrainingArguments:
53
+ """
54
+ Arguments pertaining to what data we are going to input our model for training and eval.
55
+ """
56
+
57
+ data_dir: str = field(
58
+ metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
59
+ )
60
+ task: Optional[str] = field(
61
+ default="summarization",
62
+ metadata={"help": "Task name, summarization (or summarization_{dataset} for pegasus) or translation"},
63
+ )
64
+ max_source_length: Optional[int] = field(
65
+ default=1024,
66
+ metadata={
67
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
68
+ "than this will be truncated, sequences shorter will be padded."
69
+ },
70
+ )
71
+ max_target_length: Optional[int] = field(
72
+ default=128,
73
+ metadata={
74
+ "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
75
+ "than this will be truncated, sequences shorter will be padded."
76
+ },
77
+ )
78
+ val_max_target_length: Optional[int] = field(
79
+ default=142,
80
+ metadata={
81
+ "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
82
+ "than this will be truncated, sequences shorter will be padded."
83
+ },
84
+ )
85
+ test_max_target_length: Optional[int] = field(
86
+ default=142,
87
+ metadata={
88
+ "help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
89
+ "than this will be truncated, sequences shorter will be padded."
90
+ },
91
+ )
92
+ n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."})
93
+ n_val: Optional[int] = field(default=-1, metadata={"help": "# validation examples. -1 means use all."})
94
+ n_test: Optional[int] = field(default=-1, metadata={"help": "# test examples. -1 means use all."})
95
+ src_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
96
+ tgt_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
97
+ eval_beams: Optional[int] = field(default=None, metadata={"help": "# num_beams to use for evaluation."})
98
+ ignore_pad_token_for_loss: bool = field(
99
+ default=True,
100
+ metadata={"help": "If only pad tokens should be ignored. This assumes that `config.pad_token_id` is defined."},
101
+ )
102
+
103
+
104
+ def main():
105
+ # See all possible arguments in src/transformers/training_args.py
106
+ # or by passing the --help flag to this script.
107
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
108
+
109
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
110
+
111
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
112
+ # If we pass only one argument to the script and it's the path to a json file,
113
+ # let's parse it to get our arguments.
114
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
115
+ else:
116
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
117
+
118
+ check_output_dir(training_args)
119
+
120
+ # Setup logging
121
+ logging.basicConfig(
122
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
123
+ datefmt="%m/%d/%Y %H:%M:%S",
124
+ level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
125
+ )
126
+ logger.warning(
127
+ "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
128
+ training_args.local_rank,
129
+ training_args.device,
130
+ training_args.n_gpu,
131
+ bool(training_args.local_rank != -1),
132
+ training_args.fp16,
133
+ )
134
+ logger.info("Training/evaluation parameters %s", training_args)
135
+
136
+ # Set seed
137
+ set_seed(training_args.seed)
138
+
139
+ # Load pretrained model and tokenizer
140
+ #
141
+ # Distributed training:
142
+ # The .from_pretrained methods guarantee that only one local process can concurrently
143
+ # download model & vocab.
144
+
145
+ config = AutoConfig.from_pretrained(
146
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
147
+ cache_dir=model_args.cache_dir,
148
+ )
149
+
150
+ extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
151
+ for p in extra_model_params:
152
+ if getattr(training_args, p, None):
153
+ assert hasattr(config, p), f"({config.__class__.__name__}) doesn't have a `{p}` attribute"
154
+ setattr(config, p, getattr(training_args, p))
155
+
156
+ tokenizer = AutoTokenizer.from_pretrained(
157
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
158
+ cache_dir=model_args.cache_dir,
159
+ )
160
+ model = AutoModelForSeq2SeqLM.from_pretrained(
161
+ model_args.model_name_or_path,
162
+ from_tf=".ckpt" in model_args.model_name_or_path,
163
+ config=config,
164
+ cache_dir=model_args.cache_dir,
165
+ )
166
+
167
+ # use task specific params
168
+ use_task_specific_params(model, data_args.task)
169
+
170
+ # set num_beams for evaluation
171
+ if data_args.eval_beams is None:
172
+ data_args.eval_beams = model.config.num_beams
173
+
174
+ # set decoder_start_token_id for MBart
175
+ if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
176
+ assert (
177
+ data_args.tgt_lang is not None and data_args.src_lang is not None
178
+ ), "mBart requires --tgt_lang and --src_lang"
179
+ model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
180
+
181
+ if model_args.freeze_embeds:
182
+ freeze_embeds(model)
183
+ if model_args.freeze_encoder:
184
+ freeze_params(model.get_encoder())
185
+ assert_all_frozen(model.get_encoder())
186
+
187
+ dataset_class = Seq2SeqDataset
188
+
189
+ # Get datasets
190
+ train_dataset = (
191
+ dataset_class(
192
+ tokenizer,
193
+ type_path="train",
194
+ data_dir=data_args.data_dir,
195
+ n_obs=data_args.n_train,
196
+ max_target_length=data_args.max_target_length,
197
+ max_source_length=data_args.max_source_length,
198
+ prefix=model.config.prefix or "",
199
+ )
200
+ if training_args.do_train
201
+ else None
202
+ )
203
+ eval_dataset = (
204
+ dataset_class(
205
+ tokenizer,
206
+ type_path="val",
207
+ data_dir=data_args.data_dir,
208
+ n_obs=data_args.n_val,
209
+ max_target_length=data_args.val_max_target_length,
210
+ max_source_length=data_args.max_source_length,
211
+ prefix=model.config.prefix or "",
212
+ )
213
+ if training_args.do_eval or training_args.evaluation_strategy != EvaluationStrategy.NO
214
+ else None
215
+ )
216
+ test_dataset = (
217
+ dataset_class(
218
+ tokenizer,
219
+ type_path="test",
220
+ data_dir=data_args.data_dir,
221
+ n_obs=data_args.n_test,
222
+ max_target_length=data_args.test_max_target_length,
223
+ max_source_length=data_args.max_source_length,
224
+ prefix=model.config.prefix or "",
225
+ )
226
+ if training_args.do_predict
227
+ else None
228
+ )
229
+
230
+ # Initialize our Trainer
231
+ compute_metrics_fn = (
232
+ build_compute_metrics_fn(data_args.task, tokenizer) if training_args.predict_with_generate else None
233
+ )
234
+ trainer = Seq2SeqTrainer(
235
+ model=model,
236
+ config=config,
237
+ args=training_args,
238
+ train_dataset=train_dataset,
239
+ eval_dataset=eval_dataset,
240
+ data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
241
+ compute_metrics=compute_metrics_fn,
242
+ data_args=data_args,
243
+ )
244
+
245
+ # Training
246
+ if training_args.do_train:
247
+ trainer.train(
248
+ model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
249
+ )
250
+ trainer.save_model()
251
+ # For convenience, we also re-save the tokenizer to the same directory,
252
+ # so that you can share your model easily on huggingface.co/models =)
253
+ if trainer.is_world_process_zero():
254
+ trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
255
+ tokenizer.save_pretrained(training_args.output_dir)
256
+
257
+ # Evaluation
258
+ eval_results = {}
259
+ if training_args.do_eval:
260
+ logger.info("*** Evaluate ***")
261
+
262
+ result = trainer.evaluate()
263
+
264
+ if trainer.is_world_process_zero():
265
+ logger.info("***** Eval results *****")
266
+ for key, value in result.items():
267
+ logger.info(" %s = %s", key, value)
268
+ save_json(result, os.path.join(training_args.output_dir, "eval_results.json"))
269
+ eval_results.update(result)
270
+
271
+ if training_args.do_predict:
272
+ logging.info("*** Test ***")
273
+
274
+ test_output = trainer.predict(test_dataset=test_dataset)
275
+ test_metrics = {k.replace("eval", "test"): v for k, v in test_output.metrics.items()}
276
+
277
+ if trainer.is_world_process_zero():
278
+ logger.info("***** Test results *****")
279
+ for key, value in test_metrics.items():
280
+ logger.info(" %s = %s", key, value)
281
+
282
+ save_json(test_metrics, os.path.join(training_args.output_dir, "test_results.json"))
283
+ eval_results.update(test_metrics)
284
+
285
+ if training_args.predict_with_generate:
286
+ test_preds = tokenizer.batch_decode(
287
+ test_output.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
288
+ )
289
+ test_preds = lmap(str.strip, test_preds)
290
+ write_txt_file(test_preds, os.path.join(training_args.output_dir, "test_generations.txt"))
291
+
292
+ if trainer.is_world_process_zero():
293
+ save_json(eval_results, "all_results.json")
294
+ return eval_results
295
+
296
+
297
+ def _mp_fn(index):
298
+ # For xla_spawn (TPUs)
299
+ main()
300
+
301
+
302
+ if __name__ == "__main__":
303
+ main()