diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..374a18cb2ba91abffab7d5a69a78c2b104f57129 --- /dev/null +++ b/.gitignore @@ -0,0 +1,225 @@ +# Created by .ignore support plugin (hsz.mobi) +### macOS template +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +### JetBrains template +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. +# .idea/modules.xml +# .idea/*.iml +# .idea/modules + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests +### VirtualEnv template +# Virtualenv +# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ +.Python +[Bb]in +[Ii]nclude +[Ll]ib +[Ll]ib64 +[Ll]ocal +pyvenv.cfg +.venv +pip-selfcheck.json + +.idea/ +eden.py +/_tmp/ +runs +*nohup* +*.pt +*.out +*.pkl +*.db +/cache/ +output/ +*.csv +*_resources/ +*_proc +lightning_logs/ +wandb/ +.lock +*gradio* \ No newline at end of file diff --git a/app.py b/app.py index 3703e2db0009fea1686d779101b431c47248e5e9..6ca155096d0ec67c6532f1755cac90bb6f513972 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,105 @@ +# -*- coding: utf-8 -*- + +""" +@Author : Jiangjie Chen +@Time : 2021/12/13 17:17 +@Contact : jjchen19@fudan.edu.cn +@Description: +""" + +import os import gradio as gr +from src.loren import Loren +from huggingface_hub import snapshot_download +from prettytable import PrettyTable +import pandas as pd + +config = { + "input": "demo", + "model_type": "roberta", + "model_name_or_path": "roberta-large", + "logic_lambda": 0.5, + "prior": "random", + "mask_rate": 0.0, + "cand_k": 3, + "max_seq2_length": 256, + "max_seq1_length": 128, + "max_num_questions": 8 +} + +model_dir = snapshot_download('Jiangjie/loren') + +config['fc_dir'] = os.path.join(model_dir, 'fact_checking/roberta-large/') +config['mrc_dir'] = os.path.join(model_dir, 'mrc_seq2seq/bart-base/') +config['er_dir'] = os.path.join(model_dir, 'evidence_retrieval/') + +loren = Loren(config) +try: + # js = { + # 'id': 0, + # 'evidence': ['EVIDENCE1', 'EVIDENCE2'], + # 'question': ['QUESTION1', 'QUESTION2'], + # 'claim_phrase': ['CLAIMPHRASE1', 'CLAIMPHRASE2'], + # 'local_premise': [['E1 ' * 100, 'E1' * 100, 'E1' * 10], ['E2', 'E2', 'E2']], + # 'phrase_veracity': [[0.1, 0.5, 0.4], [0.1, 0.7, 0.2]], + # 'claim_veracity': 'SUPPORT' + # } + js = loren.check('Donald Trump won the 2020 U.S. presidential election.') +except Exception as e: + raise ValueError(e) + + +def gradio_formatter(js, output_type): + if output_type == 'e': + data = {'Evidence': js['evidence']} + elif output_type == 'z': + data = { + 'Claim Phrase': js['claim_phrase'], + 'Local Premise': [x[0] for x in js['local_premise']], + 'p_SUP': [round(x[2], 4) for x in js['phrase_veracity']], + 'p_REF': [round(x[0], 4) for x in js['phrase_veracity']], + 'p_NEI': [round(x[1], 4) for x in js['phrase_veracity']], + } + else: + raise NotImplementedError + data = pd.DataFrame(data) + pt = PrettyTable(field_names=list(data.columns)) + for v in data.values: + pt.add_row(v) + + html = pt.get_html_string(attributes={ + 'style': 'border-width: 1px; border-collapse: collapse', + }, format=True) + return html + + +def run(claim): + js = loren.check(claim) + ev_html = gradio_formatter(js, 'e') + z_html = gradio_formatter(js, 'z') + return ev_html, z_html, js['claim_veracity'], js -def greet(name): - return "Hello " + name + "!!" -iface = gr.Interface(fn=greet, inputs="text", outputs="text") +iface = gr.Interface( + fn=run, + inputs="text", + outputs=[ + 'html', + 'html', + 'label', + 'json' + ], + examples=['Donald Trump won the U.S. 2020 presidential election.', + 'The first inauguration of Bill Clinton was in the United States.'], + title="LOREN", + layout='vertical', + description="LOREN is an interpretable Fact Verification model against Wikipedia. " + "This is a demo system for \"LOREN: Logic-Regularized Reasoning for Interpretable Fact Verification\". " + "See the paper for technical details. You can add FLAG on the bottom to record interesting or bad cases!", + flagging_dir='results/flagged/', + allow_flagging=True, + flagging_options=['Good Case!', 'Error: MRC', 'Error: Parsing', + 'Error: Commonsense', 'Error: Evidence', 'Error: Other'], + enable_queue=True +) iface.launch() diff --git a/cjjpy.py b/cjjpy.py new file mode 100755 index 0000000000000000000000000000000000000000..2cc70b5e553924123810ab198c143bf7ee28e5d6 --- /dev/null +++ b/cjjpy.py @@ -0,0 +1,249 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2018/11/15 17:08 +@Contact: jjchen19@fudan.edu.cn +''' + +import re +import datetime +import os +import argparse +import logging +import traceback + +try: + import ujson as json +except: + import json + +HADOOP_BIN = 'PATH=/usr/bin/:$PATH hdfs' +FOR_PUBLIC = True + + +def LengthStats(filename): + len_list = [] + thresholds = [0.8, 0.9, 0.95, 0.99, 0.999] + with open(filename) as f: + for line in f: + len_list.append(len(line.strip().split())) + stats = { + 'Max': max(len_list), + 'Min': min(len_list), + 'Avg': round(sum(len_list) / len(len_list), 4), + } + len_list.sort() + for t in thresholds: + stats[f"Top-{t}"] = len_list[int(len(len_list) * t)] + + for k in stats: + print(f"- {k}: {stats[k]}") + return stats + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def TraceBack(error_msg): + exc = traceback.format_exc() + msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}' + return msg + + +def Now(): + return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def AbsParentDir(file, parent='..', postfix=None): + ppath = os.path.abspath(file) + parent_level = parent.count('.') + while parent_level > 0: + ppath = os.path.dirname(ppath) + parent_level -= 1 + if postfix is not None: + return os.path.join(ppath, postfix) + else: + return ppath + + +def init_logger(log_file=None, log_file_level=logging.NOTSET, from_scratch=False): + from coloredlogs import ColoredFormatter + import tensorflow as tf + + fmt = "[%(asctime)s %(levelname)s] %(message)s" + log_format = ColoredFormatter(fmt=fmt) + # log_format = logging.Formatter() + logger = logging.getLogger() + logger.setLevel(log_file_level) + + console_handler = logging.StreamHandler() + console_handler.setFormatter(log_format) + logger.handlers = [console_handler] + + if log_file and log_file != '': + if from_scratch and tf.io.gfile.exists(log_file): + logger.warning('Removing previous log file: %s' % log_file) + tf.io.gfile.remove(log_file) + path = os.path.dirname(log_file) + os.makedirs(path, exist_ok=True) + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(log_file_level) + file_handler.setFormatter(log_format) + logger.addHandler(file_handler) + + return logger + + +def OverWriteCjjPy(root='.'): + # import difflib + # diff = difflib.HtmlDiff() + cnt = 0 + golden_cjjpy = os.path.join(root, 'cjjpy.py') + # golden_content = open(golden_cjjpy).readlines() + for dir, folder, file in os.walk(root): + for f in file: + if f == 'cjjpy.py': + cjjpy = '%s/%s' % (dir, f) + # content = open(cjjpy).readlines() + # d = diff.make_file(golden_content, content) + cnt += 1 + print('[%d]: %s' % (cnt, cjjpy)) + os.system('cp %s %s' % (golden_cjjpy, cjjpy)) + + +def ChangeFileFormat(filename, new_fmt): + assert type(filename) is str and type(new_fmt) is str + spt = filename.split('.') + if len(spt) == 0: + return filename + else: + return filename.replace('.' + spt[-1], new_fmt) + + +def CountLines(fname): + with open(fname, 'rb') as f: + count = 0 + last_data = '\n' + while True: + data = f.read(0x400000) + if not data: + break + count += data.count(b'\n') + last_data = data + if last_data[-1:] != b'\n': + count += 1 # Remove this if a wc-like count is needed + return count + + +def GetDate(): + return str(datetime.datetime.now())[5:10].replace('-', '') + + +def TimeClock(seconds): + sec = int(seconds) + hour = int(sec / 3600) + min = int((sec - hour * 3600) / 60) + ssec = float(seconds) - hour * 3600 - min * 60 + # return '%dh %dm %.2fs' % (hour, min, ssec) + return '{0:>2d}h{1:>3d}m{2:>6.2f}s'.format(hour, min, ssec) + + +def StripAll(text): + return text.strip().replace('\t', '').replace('\n', '').replace(' ', '') + + +def GetBracket(text, bracket, en_br=False): + # input should be aa(bb)cc, True for bracket, False for text + if bracket: + try: + return re.findall('\((.*?)\)', text.strip())[-1] + except: + return '' + else: + if en_br: + text = re.sub('\(.*?\)', '', text.strip()) + return re.sub('(.*?)', '', text.strip()) + + +def CharLang(uchar, lang): + assert lang.lower() in ['en', 'cn', 'zh'] + if lang.lower() in ['cn', 'zh']: + if uchar >= '\u4e00' and uchar <= '\u9fa5': + return True + else: + return False + elif lang.lower() == 'en': + if (uchar <= 'Z' and uchar >= 'A') or (uchar <= 'z' and uchar >= 'a'): + return True + else: + return False + else: + raise NotImplementedError + + +def WordLang(word, lang): + for i in word.strip(): + if i.isspace(): continue + if not CharLang(i, lang): + return False + return True + + +def SortDict(_dict, reverse=True): + assert type(_dict) is dict + return sorted(_dict.items(), key=lambda d: d[1], reverse=reverse) + + +def lark(content='test'): + print(content) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--diff', nargs=2, + help='show difference between two files, shown in downloads/diff.html') + parser.add_argument('--de_unicode', action='store_true', default=False, + help='remove unicode characters') + parser.add_argument('--link_entity', action='store_true', default=False, + help='') + parser.add_argument('--max_comm_len', action='store_true', default=False, + help='') + parser.add_argument('--search', nargs=2, + help='search key from file, 2 args: file name & key') + parser.add_argument('--email', nargs=2, + help='sending emails, 2 args: subject & content') + parser.add_argument('--overwrite', action='store_true', default=None, + help='overwrite all cjjpy under given *dir* based on *dir*/cjjpy.py') + parser.add_argument('--replace', nargs=3, + help='replace char, 3 args: file name & replaced char & replacer char') + parser.add_argument('--lark', nargs=1) + parser.add_argument('--get_hdfs', nargs=2, + help='easy copy from hdfs to local fs, 2 args: remote_file/dir & local_dir') + parser.add_argument('--put_hdfs', nargs=2, + help='easy put from local fs to hdfs, 2 args: local_file/dir & remote_dir') + parser.add_argument('--length_stats', nargs=1, + help='simple token lengths distribution of a line-by-line file') + + args = parser.parse_args() + + if args.overwrite: + print('* Overwriting cjjpy...') + OverWriteCjjPy() + + if args.lark: + try: + content = args.lark[0] + except: + content = 'running complete' + print(f'* Larking "{content}"...') + lark(content) + + if args.length_stats: + file = args.length_stats[0] + print(f'* Working on {file} lengths statistics...') + LengthStats(file) diff --git a/docs/front.png b/docs/front.png new file mode 100644 index 0000000000000000000000000000000000000000..ca42d1a4f6c1019d00babceee62d5aae44b55047 Binary files /dev/null and b/docs/front.png differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..65437a5f67f213099f867da72ea1ad4194e0782d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,28 @@ +nltk +tqdm +six +scikit-learn +pathlib +configargparse +bottle +ujson +GPUtil +coloredlogs +inflect +unidecode +psutil +wandb +rouge_score +sacrebleu +tagme +wikipedia-api +gradio +tensorflow +pytorch-lightning==1.0.4 +allennlp==1.2.2 +allennlp-models==1.2.2 +transformers==3.5.1 +torch==1.7.1 +datasets +pandas +prettytable \ No newline at end of file diff --git a/src/available_models/aaai22_roberta.json b/src/available_models/aaai22_roberta.json new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/check_client/cjjpy.py b/src/check_client/cjjpy.py new file mode 100755 index 0000000000000000000000000000000000000000..2cc70b5e553924123810ab198c143bf7ee28e5d6 --- /dev/null +++ b/src/check_client/cjjpy.py @@ -0,0 +1,249 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2018/11/15 17:08 +@Contact: jjchen19@fudan.edu.cn +''' + +import re +import datetime +import os +import argparse +import logging +import traceback + +try: + import ujson as json +except: + import json + +HADOOP_BIN = 'PATH=/usr/bin/:$PATH hdfs' +FOR_PUBLIC = True + + +def LengthStats(filename): + len_list = [] + thresholds = [0.8, 0.9, 0.95, 0.99, 0.999] + with open(filename) as f: + for line in f: + len_list.append(len(line.strip().split())) + stats = { + 'Max': max(len_list), + 'Min': min(len_list), + 'Avg': round(sum(len_list) / len(len_list), 4), + } + len_list.sort() + for t in thresholds: + stats[f"Top-{t}"] = len_list[int(len(len_list) * t)] + + for k in stats: + print(f"- {k}: {stats[k]}") + return stats + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def TraceBack(error_msg): + exc = traceback.format_exc() + msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}' + return msg + + +def Now(): + return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def AbsParentDir(file, parent='..', postfix=None): + ppath = os.path.abspath(file) + parent_level = parent.count('.') + while parent_level > 0: + ppath = os.path.dirname(ppath) + parent_level -= 1 + if postfix is not None: + return os.path.join(ppath, postfix) + else: + return ppath + + +def init_logger(log_file=None, log_file_level=logging.NOTSET, from_scratch=False): + from coloredlogs import ColoredFormatter + import tensorflow as tf + + fmt = "[%(asctime)s %(levelname)s] %(message)s" + log_format = ColoredFormatter(fmt=fmt) + # log_format = logging.Formatter() + logger = logging.getLogger() + logger.setLevel(log_file_level) + + console_handler = logging.StreamHandler() + console_handler.setFormatter(log_format) + logger.handlers = [console_handler] + + if log_file and log_file != '': + if from_scratch and tf.io.gfile.exists(log_file): + logger.warning('Removing previous log file: %s' % log_file) + tf.io.gfile.remove(log_file) + path = os.path.dirname(log_file) + os.makedirs(path, exist_ok=True) + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(log_file_level) + file_handler.setFormatter(log_format) + logger.addHandler(file_handler) + + return logger + + +def OverWriteCjjPy(root='.'): + # import difflib + # diff = difflib.HtmlDiff() + cnt = 0 + golden_cjjpy = os.path.join(root, 'cjjpy.py') + # golden_content = open(golden_cjjpy).readlines() + for dir, folder, file in os.walk(root): + for f in file: + if f == 'cjjpy.py': + cjjpy = '%s/%s' % (dir, f) + # content = open(cjjpy).readlines() + # d = diff.make_file(golden_content, content) + cnt += 1 + print('[%d]: %s' % (cnt, cjjpy)) + os.system('cp %s %s' % (golden_cjjpy, cjjpy)) + + +def ChangeFileFormat(filename, new_fmt): + assert type(filename) is str and type(new_fmt) is str + spt = filename.split('.') + if len(spt) == 0: + return filename + else: + return filename.replace('.' + spt[-1], new_fmt) + + +def CountLines(fname): + with open(fname, 'rb') as f: + count = 0 + last_data = '\n' + while True: + data = f.read(0x400000) + if not data: + break + count += data.count(b'\n') + last_data = data + if last_data[-1:] != b'\n': + count += 1 # Remove this if a wc-like count is needed + return count + + +def GetDate(): + return str(datetime.datetime.now())[5:10].replace('-', '') + + +def TimeClock(seconds): + sec = int(seconds) + hour = int(sec / 3600) + min = int((sec - hour * 3600) / 60) + ssec = float(seconds) - hour * 3600 - min * 60 + # return '%dh %dm %.2fs' % (hour, min, ssec) + return '{0:>2d}h{1:>3d}m{2:>6.2f}s'.format(hour, min, ssec) + + +def StripAll(text): + return text.strip().replace('\t', '').replace('\n', '').replace(' ', '') + + +def GetBracket(text, bracket, en_br=False): + # input should be aa(bb)cc, True for bracket, False for text + if bracket: + try: + return re.findall('\((.*?)\)', text.strip())[-1] + except: + return '' + else: + if en_br: + text = re.sub('\(.*?\)', '', text.strip()) + return re.sub('(.*?)', '', text.strip()) + + +def CharLang(uchar, lang): + assert lang.lower() in ['en', 'cn', 'zh'] + if lang.lower() in ['cn', 'zh']: + if uchar >= '\u4e00' and uchar <= '\u9fa5': + return True + else: + return False + elif lang.lower() == 'en': + if (uchar <= 'Z' and uchar >= 'A') or (uchar <= 'z' and uchar >= 'a'): + return True + else: + return False + else: + raise NotImplementedError + + +def WordLang(word, lang): + for i in word.strip(): + if i.isspace(): continue + if not CharLang(i, lang): + return False + return True + + +def SortDict(_dict, reverse=True): + assert type(_dict) is dict + return sorted(_dict.items(), key=lambda d: d[1], reverse=reverse) + + +def lark(content='test'): + print(content) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--diff', nargs=2, + help='show difference between two files, shown in downloads/diff.html') + parser.add_argument('--de_unicode', action='store_true', default=False, + help='remove unicode characters') + parser.add_argument('--link_entity', action='store_true', default=False, + help='') + parser.add_argument('--max_comm_len', action='store_true', default=False, + help='') + parser.add_argument('--search', nargs=2, + help='search key from file, 2 args: file name & key') + parser.add_argument('--email', nargs=2, + help='sending emails, 2 args: subject & content') + parser.add_argument('--overwrite', action='store_true', default=None, + help='overwrite all cjjpy under given *dir* based on *dir*/cjjpy.py') + parser.add_argument('--replace', nargs=3, + help='replace char, 3 args: file name & replaced char & replacer char') + parser.add_argument('--lark', nargs=1) + parser.add_argument('--get_hdfs', nargs=2, + help='easy copy from hdfs to local fs, 2 args: remote_file/dir & local_dir') + parser.add_argument('--put_hdfs', nargs=2, + help='easy put from local fs to hdfs, 2 args: local_file/dir & remote_dir') + parser.add_argument('--length_stats', nargs=1, + help='simple token lengths distribution of a line-by-line file') + + args = parser.parse_args() + + if args.overwrite: + print('* Overwriting cjjpy...') + OverWriteCjjPy() + + if args.lark: + try: + content = args.lark[0] + except: + content = 'running complete' + print(f'* Larking "{content}"...') + lark(content) + + if args.length_stats: + file = args.length_stats[0] + print(f'* Working on {file} lengths statistics...') + LengthStats(file) diff --git a/src/check_client/fact_checker.py b/src/check_client/fact_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..b45f56c5fa9f46b96d711bad47910f2c222b2139 --- /dev/null +++ b/src/check_client/fact_checker.py @@ -0,0 +1,209 @@ +# -*- coding: utf-8 -*- + +""" +@Author : Bao +@Date : 2020/8/12 +@Desc : +@Last modified by : Bao +@Last modified date : 2020/8/20 +""" + +import os +import sys +import logging +import torch +import numpy as np +from tqdm import tqdm +import tensorflow as tf +import ujson as json +import argparse +import cjjpy as cjj +from itertools import repeat +from torch.utils.data import DataLoader, SequentialSampler +from transformers import ( + BertConfig, BertTokenizer, AutoTokenizer, + RobertaConfig, RobertaTokenizer, +) + +try: + from .modules.data_processor import DataProcessor + from .plm_checkers import BertChecker, RobertaChecker + from .utils import read_json_lines, compute_metrics + from .train import do_evaluate, set_seed + from ..eval_client.fever_scorer import FeverScorer +except: + sys.path.append(cjj.AbsParentDir(__file__, '.')) + sys.path.append(cjj.AbsParentDir(__file__, '..')) + from eval_client.fever_scorer import FeverScorer + from modules.data_processor import DataProcessor + from plm_checkers import BertChecker, RobertaChecker + from utils import read_json_lines, compute_metrics + from train import do_evaluate, set_seed + +MODEL_MAPPING = { + 'bert': (BertConfig, BertTokenizer, BertChecker), + 'roberta': (RobertaConfig, RobertaTokenizer, RobertaChecker), +} + +logger = logging.getLogger(__name__) +label2id = {"SUPPORTS": 2, "REFUTES": 0, 'NOT ENOUGH INFO': 1} +id2label = {v: k for k, v in label2id.items()} + + +class FactChecker: + def __init__(self, args, fc_ckpt_dir=None, mask_rate=0.): + self.data_processor = None + self.tokenizer = None + self.model = None + self.args = args + self.ckpt = args.fc_dir if fc_ckpt_dir is None else fc_ckpt_dir + self.mask_rate = mask_rate + + logger.info('Initializing fact checker.') + self._prepare_ckpt(self.args.model_name_or_path, self.ckpt) + self.load_model() + + def _prepare_ckpt(self, model_name_or_path, ckpt_dir): + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + tokenizer.save_pretrained(ckpt_dir) + + def load_model(self): + if self.model is None: + self.data_processor = DataProcessor( + self.args.model_name_or_path, + self.args.max_seq1_length, + self.args.max_seq2_length, + self.args.max_num_questions, + self.args.cand_k, + mask_rate=self.mask_rate + ) + + _, tokenizer_class, model_class = MODEL_MAPPING[self.args.model_type] + self.tokenizer = tokenizer_class.from_pretrained( + self.ckpt, + do_lower_case=self.args.do_lower_case + ) + self.model = model_class.from_pretrained( + self.ckpt, + from_tf=bool(".ckpt" in self.ckpt), + logic_lambda=self.args.logic_lambda, + prior=self.args.prior, + ) + self.model = torch.nn.DataParallel(self.model) + + def _check(self, inputs: list, batch_size=32, verbose=True): + dataset = self.data_processor.convert_inputs_to_dataset(inputs, self.tokenizer, verbose=verbose) + sampler = SequentialSampler(dataset) + dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size) + + with torch.no_grad(): + self.model.to(self.args.device) + self.model.eval() + iter = tqdm(dataloader, desc="Fact Checking") if verbose else dataloader + _, y_predicted, z_predicted, m_attn, mask = \ + do_evaluate(iter, self.model, self.args, during_training=False, with_label=False) + + return y_predicted, z_predicted, m_attn, mask + + def check_from_file(self, in_filename, out_filename, batch_size, verbose=False): + if 'test' in in_filename: + raw_inp = f'{os.environ["PJ_HOME"]}/data/fever/shared_task_test.jsonl' + else: + raw_inp = None + tf.io.gfile.makedirs(os.path.dirname(out_filename)) + inputs = list(read_json_lines(in_filename)) + y_predicted, z_predicted, m_attn, mask = self._check(inputs, batch_size) + + z_predicted = repeat(None) if z_predicted is None else z_predicted + m_attn = repeat(None) if m_attn is None else m_attn + ordered_results = {} + with_label = inputs[0].get('label') is not None + + if with_label: + label_truth = [label2id[x['label']] for x in inputs] + _, acc_results = compute_metrics(label_truth, y_predicted, z_predicted, mask) + else: + acc_results = {} + + for i, (inp, y, z, attn, _mask) in \ + enumerate(zip(inputs, y_predicted, z_predicted, m_attn, mask)): + result = {'id': inp['id'], + 'predicted_label': id2label[y], + 'predicted_evidence': inp.get('predicted_evidence', [])} + if verbose: + if i < 5: + print("{}\t{}\t{}".format(inp.get("id", i), inp["claim"], y)) + if z is not None and attn is not None: + result.update({ + 'z_prob': z[:torch.tensor(_mask).sum()], + 'm_attn': attn[:torch.tensor(_mask).sum()], + }) + ordered_results[inp['id']] = result + + with tf.io.gfile.GFile(out_filename, 'w') as fout: + if raw_inp: + with tf.io.gfile.GFile(raw_inp) as f: + for line in f: + raw_js = json.loads(line) + fout.write(json.dumps(ordered_results[raw_js['id']]) + '\n') + else: + for k in ordered_results: + fout.write(json.dumps(ordered_results[k]) + '\n') + + if ('dev' in in_filename or 'val' in in_filename) and with_label: + scorer = FeverScorer() + fever_results = scorer.get_scores(out_filename) + fever_results.update(acc_results) + + print(fever_results) + return fever_results + + def check_from_batch(self, inputs: list, verbose=False): + y_predicted, z_predicted, m_attn, mask = self._check(inputs, len(inputs), verbose) + return y_predicted, z_predicted, m_attn + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--input', '-i', required=True, type=str, + choices=['val', 'eval', 'test', 'demo']) + parser.add_argument('--output', '-o', default='none', type=str) + parser.add_argument('--ckpt', '-c', required=True, type=str) + parser.add_argument('--model_type', default='roberta', type=str, + choices=['roberta', 'bert']) + parser.add_argument('--model_name_or_path', default='roberta-large', type=str) + parser.add_argument('--verbose', '-v', action='store_true', default=False, + help='whether output phrasal veracity or not') + parser.add_argument('--logic_lambda', '-l', required=True, type=float) + parser.add_argument('--prior', default='random', type=str, choices=['nli', 'uniform', 'logic', 'random'], + help='type of prior distribution') + parser.add_argument('--mask_rate', '-m', default=0., type=float) + + parser.add_argument('--cand_k', '-k', default=3, type=int) + parser.add_argument('--max_seq1_length', default=256, type=int) + parser.add_argument('--max_seq2_length', default=128, type=int) + parser.add_argument('--max_num_questions', default=8, type=int) + parser.add_argument('--do_lower_case', action='store_true', default=False) + parser.add_argument('--batch_size', '-b', default=64, type=int) + parser.add_argument('--seed', default=42) + parser.add_argument('--n_gpu', default=4) + + args = parser.parse_args() + + set_seed(args) + + args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if args.output == 'none': + args.ckpt = args.ckpt[:-1] if args.ckpt.endswith('/') else args.ckpt + base_name = os.path.basename(args.ckpt) + args.output = f'{os.environ["PJ_HOME"]}/results/fact_checking/AAAI22/{args.input}.{args.model_name_or_path}_m{args.mask_rate}_l{args.logic_lambda}_{base_name}_{args.prior}.predictions.jsonl' + + assert args.output.endswith('predictions.jsonl'), \ + f"{args.output} must end with predictions.jsonl" + + args.input = f'{os.environ["PJ_HOME"]}/data/fact_checking/v5/{args.input}.json' + + checker = FactChecker(args, args.ckpt, args.mask_rate) + fever_results = checker.check_from_file(args.input, args.output, args.batch_size, args.verbose) + cjj.lark(f"{args.output}: {fever_results}") diff --git a/src/check_client/modules/cjjpy.py b/src/check_client/modules/cjjpy.py new file mode 100755 index 0000000000000000000000000000000000000000..2cc70b5e553924123810ab198c143bf7ee28e5d6 --- /dev/null +++ b/src/check_client/modules/cjjpy.py @@ -0,0 +1,249 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2018/11/15 17:08 +@Contact: jjchen19@fudan.edu.cn +''' + +import re +import datetime +import os +import argparse +import logging +import traceback + +try: + import ujson as json +except: + import json + +HADOOP_BIN = 'PATH=/usr/bin/:$PATH hdfs' +FOR_PUBLIC = True + + +def LengthStats(filename): + len_list = [] + thresholds = [0.8, 0.9, 0.95, 0.99, 0.999] + with open(filename) as f: + for line in f: + len_list.append(len(line.strip().split())) + stats = { + 'Max': max(len_list), + 'Min': min(len_list), + 'Avg': round(sum(len_list) / len(len_list), 4), + } + len_list.sort() + for t in thresholds: + stats[f"Top-{t}"] = len_list[int(len(len_list) * t)] + + for k in stats: + print(f"- {k}: {stats[k]}") + return stats + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def TraceBack(error_msg): + exc = traceback.format_exc() + msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}' + return msg + + +def Now(): + return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def AbsParentDir(file, parent='..', postfix=None): + ppath = os.path.abspath(file) + parent_level = parent.count('.') + while parent_level > 0: + ppath = os.path.dirname(ppath) + parent_level -= 1 + if postfix is not None: + return os.path.join(ppath, postfix) + else: + return ppath + + +def init_logger(log_file=None, log_file_level=logging.NOTSET, from_scratch=False): + from coloredlogs import ColoredFormatter + import tensorflow as tf + + fmt = "[%(asctime)s %(levelname)s] %(message)s" + log_format = ColoredFormatter(fmt=fmt) + # log_format = logging.Formatter() + logger = logging.getLogger() + logger.setLevel(log_file_level) + + console_handler = logging.StreamHandler() + console_handler.setFormatter(log_format) + logger.handlers = [console_handler] + + if log_file and log_file != '': + if from_scratch and tf.io.gfile.exists(log_file): + logger.warning('Removing previous log file: %s' % log_file) + tf.io.gfile.remove(log_file) + path = os.path.dirname(log_file) + os.makedirs(path, exist_ok=True) + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(log_file_level) + file_handler.setFormatter(log_format) + logger.addHandler(file_handler) + + return logger + + +def OverWriteCjjPy(root='.'): + # import difflib + # diff = difflib.HtmlDiff() + cnt = 0 + golden_cjjpy = os.path.join(root, 'cjjpy.py') + # golden_content = open(golden_cjjpy).readlines() + for dir, folder, file in os.walk(root): + for f in file: + if f == 'cjjpy.py': + cjjpy = '%s/%s' % (dir, f) + # content = open(cjjpy).readlines() + # d = diff.make_file(golden_content, content) + cnt += 1 + print('[%d]: %s' % (cnt, cjjpy)) + os.system('cp %s %s' % (golden_cjjpy, cjjpy)) + + +def ChangeFileFormat(filename, new_fmt): + assert type(filename) is str and type(new_fmt) is str + spt = filename.split('.') + if len(spt) == 0: + return filename + else: + return filename.replace('.' + spt[-1], new_fmt) + + +def CountLines(fname): + with open(fname, 'rb') as f: + count = 0 + last_data = '\n' + while True: + data = f.read(0x400000) + if not data: + break + count += data.count(b'\n') + last_data = data + if last_data[-1:] != b'\n': + count += 1 # Remove this if a wc-like count is needed + return count + + +def GetDate(): + return str(datetime.datetime.now())[5:10].replace('-', '') + + +def TimeClock(seconds): + sec = int(seconds) + hour = int(sec / 3600) + min = int((sec - hour * 3600) / 60) + ssec = float(seconds) - hour * 3600 - min * 60 + # return '%dh %dm %.2fs' % (hour, min, ssec) + return '{0:>2d}h{1:>3d}m{2:>6.2f}s'.format(hour, min, ssec) + + +def StripAll(text): + return text.strip().replace('\t', '').replace('\n', '').replace(' ', '') + + +def GetBracket(text, bracket, en_br=False): + # input should be aa(bb)cc, True for bracket, False for text + if bracket: + try: + return re.findall('\((.*?)\)', text.strip())[-1] + except: + return '' + else: + if en_br: + text = re.sub('\(.*?\)', '', text.strip()) + return re.sub('(.*?)', '', text.strip()) + + +def CharLang(uchar, lang): + assert lang.lower() in ['en', 'cn', 'zh'] + if lang.lower() in ['cn', 'zh']: + if uchar >= '\u4e00' and uchar <= '\u9fa5': + return True + else: + return False + elif lang.lower() == 'en': + if (uchar <= 'Z' and uchar >= 'A') or (uchar <= 'z' and uchar >= 'a'): + return True + else: + return False + else: + raise NotImplementedError + + +def WordLang(word, lang): + for i in word.strip(): + if i.isspace(): continue + if not CharLang(i, lang): + return False + return True + + +def SortDict(_dict, reverse=True): + assert type(_dict) is dict + return sorted(_dict.items(), key=lambda d: d[1], reverse=reverse) + + +def lark(content='test'): + print(content) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--diff', nargs=2, + help='show difference between two files, shown in downloads/diff.html') + parser.add_argument('--de_unicode', action='store_true', default=False, + help='remove unicode characters') + parser.add_argument('--link_entity', action='store_true', default=False, + help='') + parser.add_argument('--max_comm_len', action='store_true', default=False, + help='') + parser.add_argument('--search', nargs=2, + help='search key from file, 2 args: file name & key') + parser.add_argument('--email', nargs=2, + help='sending emails, 2 args: subject & content') + parser.add_argument('--overwrite', action='store_true', default=None, + help='overwrite all cjjpy under given *dir* based on *dir*/cjjpy.py') + parser.add_argument('--replace', nargs=3, + help='replace char, 3 args: file name & replaced char & replacer char') + parser.add_argument('--lark', nargs=1) + parser.add_argument('--get_hdfs', nargs=2, + help='easy copy from hdfs to local fs, 2 args: remote_file/dir & local_dir') + parser.add_argument('--put_hdfs', nargs=2, + help='easy put from local fs to hdfs, 2 args: local_file/dir & remote_dir') + parser.add_argument('--length_stats', nargs=1, + help='simple token lengths distribution of a line-by-line file') + + args = parser.parse_args() + + if args.overwrite: + print('* Overwriting cjjpy...') + OverWriteCjjPy() + + if args.lark: + try: + content = args.lark[0] + except: + content = 'running complete' + print(f'* Larking "{content}"...') + lark(content) + + if args.length_stats: + file = args.length_stats[0] + print(f'* Working on {file} lengths statistics...') + LengthStats(file) diff --git a/src/check_client/modules/data_processor.py b/src/check_client/modules/data_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..5958fa4032b508f4bde7b83e56a2ea78cecb2991 --- /dev/null +++ b/src/check_client/modules/data_processor.py @@ -0,0 +1,354 @@ +# -*- coding: utf-8 -*- + +""" +@Author : Bao +@Date : 2020/4/14 +@Desc : +@Last modified by : Bao +@Last modified date : 2020/8/12 +""" + +import os +import copy +import logging +import ujson as json +import torch +from tqdm import tqdm +from torch.utils.data import TensorDataset +import tensorflow as tf +import cjjpy as cjj +import sys + +try: + from ...mrc_client.answer_generator import assemble_answers_to_one +except: + sys.path.append(cjj.AbsParentDir(__file__, '...')) + from mrc_client.answer_generator import assemble_answers_to_one + +logger = logging.getLogger(__name__) + + +class InputExample(object): + def __init__(self, guid, claim, evidences, questions, answers, + evidential, label=None, nli_labels=None): + self.guid = guid + self.claim = claim + self.evidences = evidences + self.questions = questions + self.answers = answers + self.evidential = evidential + self.label = label + self.nli_labels = nli_labels + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + + +class InputFeatures(object): + def __init__( + self, + guid, + c_input_ids, + c_attention_mask, + c_token_type_ids, + q_input_ids_list, + q_attention_mask_list, + q_token_type_ids_list, + nli_labels=None, + label=None, + ): + self.guid = guid + self.c_input_ids = c_input_ids + self.c_attention_mask = c_attention_mask + self.c_token_type_ids = c_token_type_ids + self.q_input_ids_list = q_input_ids_list + self.q_attention_mask_list = q_attention_mask_list + self.q_token_type_ids_list = q_token_type_ids_list + self.nli_labels = nli_labels + self.label = label + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + + +def _create_input_ids_from_token_ids(token_ids_a, token_ids_b, tokenizer, max_seq_length): + pair = len(token_ids_b) != 0 + + # Truncate sequences. + num_special_tokens_to_add = tokenizer.num_special_tokens_to_add(pair=pair) + while len(token_ids_a) + len(token_ids_b) > max_seq_length - num_special_tokens_to_add: + if len(token_ids_b) > 0: + token_ids_b = token_ids_b[:-1] + else: + token_ids_a = token_ids_a[:-1] + + # Add special tokens to input_ids. + input_ids = tokenizer.build_inputs_with_special_tokens(token_ids_a, token_ids_b if pair else None) + + # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to. + attention_mask = [1] * len(input_ids) + + # Create token_type_ids. + token_type_ids = tokenizer.create_token_type_ids_from_sequences(token_ids_a, token_ids_b if pair else None) + + # Pad up to the sequence length. + padding_length = max_seq_length - len(input_ids) + if tokenizer.padding_side == "right": + input_ids = input_ids + ([tokenizer.pad_token_id] * padding_length) + attention_mask = attention_mask + ([0] * padding_length) + token_type_ids = token_type_ids + ([tokenizer.pad_token_type_id] * padding_length) + else: + input_ids = ([tokenizer.pad_token_id] * padding_length) + input_ids + attention_mask = ([0] * padding_length) + attention_mask + token_type_ids = ([tokenizer.pad_token_type_id] * padding_length) + token_type_ids + + assert len(input_ids) == max_seq_length + assert len(attention_mask) == max_seq_length + assert len(token_type_ids) == max_seq_length + + return input_ids, attention_mask, token_type_ids + + +def convert_examples_to_features( + examples, + tokenizer, + max_seq1_length=256, + max_seq2_length=128, + verbose=True +): + features = [] + iter = tqdm(examples, desc="Converting Examples") if verbose else examples + for (ex_index, example) in enumerate(iter): + encoded_outputs = {"guid": example.guid, 'label': example.label, + 'nli_labels': example.nli_labels} + + # ****** for sequence 1 ******* # + token_ids_a, token_ids_b = [], [] + + # text a in sequence 1 + token_ids = tokenizer.encode(example.claim, add_special_tokens=False) # encode claim + token_ids_a.extend(token_ids) + + # text b in sequence 1 + for i, evidence in enumerate(example.evidences): + token_ids = tokenizer.encode(evidence, add_special_tokens=False) # encode evidence + token_ids_b.extend(token_ids + [tokenizer.sep_token_id]) + # Remove last sep token in token_ids_b. + token_ids_b = token_ids_b[:-1] + token_ids_b = token_ids_b[:max_seq1_length - len(token_ids_a) - 4] # magic number for special tokens + + # premise hypothesis + input_ids, attention_mask, token_type_ids = _create_input_ids_from_token_ids( + token_ids_b, + token_ids_a, + tokenizer, + max_seq1_length, + ) + + encoded_outputs["c_input_ids"] = input_ids + encoded_outputs["c_attention_mask"] = attention_mask + encoded_outputs["c_token_type_ids"] = token_type_ids + + # ****** for sequence 2 ******* # + encoded_outputs["q_input_ids_list"] = [] # m x L + encoded_outputs["q_attention_mask_list"] = [] + encoded_outputs["q_token_type_ids_list"] = [] + + for candidate in example.evidential: + # text a in sequence 2 + token_ids_a = tokenizer.encode(example.claim, add_special_tokens=False) # encode claim + # text b in sequence 2 + token_ids_b = tokenizer.encode(candidate, add_special_tokens=False) # encode candidate answer + # premise hypothesis + input_ids, attention_mask, token_type_ids = _create_input_ids_from_token_ids( + token_ids_b, + token_ids_a, + tokenizer, + max_seq2_length, + ) + + encoded_outputs["q_input_ids_list"].append(input_ids) + encoded_outputs["q_attention_mask_list"].append(attention_mask) + encoded_outputs["q_token_type_ids_list"].append(token_type_ids) + + features.append(InputFeatures(**encoded_outputs)) + + if ex_index < 5 and verbose: + logger.info("*** Example ***") + logger.info("guid: {}".format(example.guid)) + logger.info("c_input_ids: {}".format(encoded_outputs["c_input_ids"])) + for input_ids in encoded_outputs['q_input_ids_list']: + logger.info('q_input_ids: {}'.format(input_ids)) + logger.info("label: {}".format(example.label)) + logger.info("nli_labels: {}".format(example.nli_labels)) + + return features + + +class DataProcessor: + def __init__( + self, + model_name_or_path, + max_seq1_length, + max_seq2_length, + max_num_questions, + cand_k, + data_dir='', + cache_dir_name='cache_check', + overwrite_cache=False, + mask_rate=0. + ): + self.model_name_or_path = model_name_or_path + self.max_seq1_length = max_seq1_length + self.max_seq2_length = max_seq2_length + self.max_num_questions = max_num_questions + self.k = cand_k + self.mask_rate = mask_rate + + self.data_dir = data_dir + self.cached_data_dir = os.path.join(data_dir, cache_dir_name) + self.overwrite_cache = overwrite_cache + + self.label2id = {"SUPPORTS": 2, "REFUTES": 0, 'NOT ENOUGH INFO': 1} + + def _format_file(self, role): + return os.path.join(self.data_dir, "{}.json".format(role)) + + def load_and_cache_data(self, role, tokenizer, data_tag): + tf.io.gfile.makedirs(self.cached_data_dir) + cached_file = os.path.join( + self.cached_data_dir, + "cached_features_{}_{}_{}_{}_{}_{}".format( + role, + list(filter(None, self.model_name_or_path.split("/"))).pop(), + str(self.max_seq1_length), + str(self.max_seq2_length), + str(self.k), + data_tag + ), + ) + if os.path.exists(cached_file) and not self.overwrite_cache: + logger.info("Loading features from cached file {}".format(cached_file)) + features = torch.load(cached_file) + else: + examples = [] + with tf.io.gfile.GFile(self._format_file(role)) as f: + data = f.readlines() + for line in tqdm(data): + sample = self._load_line(line) + examples.append(InputExample(**sample)) + features = convert_examples_to_features(examples, tokenizer, + self.max_seq1_length, self.max_seq2_length) + if 'train' in role or 'eval' in role: + logger.info("Saving features into cached file {}".format(cached_file)) + torch.save(features, cached_file) + + return self._create_tensor_dataset(features, tokenizer) + + def convert_inputs_to_dataset(self, inputs, tokenizer, verbose=True): + examples = [] + for line in inputs: + sample = self._load_line(line) + examples.append(InputExample(**sample)) + features = convert_examples_to_features(examples, tokenizer, + self.max_seq1_length, self.max_seq2_length, verbose) + + return self._create_tensor_dataset(features, tokenizer, do_predict=True) + + def _create_tensor_dataset(self, features, tokenizer, do_predict=False): + all_c_input_ids = torch.tensor([f.c_input_ids for f in features], dtype=torch.long) + all_c_attention_mask = torch.tensor([f.c_attention_mask for f in features], dtype=torch.long) + all_c_token_type_ids = torch.tensor([f.c_token_type_ids for f in features], dtype=torch.long) + + all_q_input_ids_list = [] + all_q_attention_mask_list = [] + all_q_token_type_ids_list = [] + + def _trunc_agg(self, feature, pad_token): + # feature: m x L + _input_list = [v for v in feature[:self.max_num_questions]] + while len(_input_list) < self.max_num_questions: + _input_list.append([pad_token] * self.max_seq2_length) + return _input_list + + for f in features: # N x m x L + all_q_input_ids_list.append(_trunc_agg(self, f.q_input_ids_list, tokenizer.pad_token_id)) + all_q_attention_mask_list.append(_trunc_agg(self, f.q_attention_mask_list, 0)) + all_q_token_type_ids_list.append(_trunc_agg(self, f.q_token_type_ids_list, tokenizer.pad_token_type_id)) + + all_q_input_ids_list = torch.tensor(all_q_input_ids_list, dtype=torch.long) + all_q_attention_mask_list = torch.tensor(all_q_attention_mask_list, dtype=torch.long) + all_q_token_type_ids_list = torch.tensor(all_q_token_type_ids_list, dtype=torch.long) + + all_nli_labels_list = [] + for f in features: + all_nli_labels_list.append(f.nli_labels[:self.max_num_questions] + + max(0, (self.max_num_questions - len(f.nli_labels))) * [[0., 0., 0.]]) + all_nli_labels = torch.tensor(all_nli_labels_list, dtype=torch.float) + + if not do_predict: + all_labels = torch.tensor([f.label for f in features], dtype=torch.long) + dataset = TensorDataset( + all_c_input_ids, all_c_attention_mask, all_c_token_type_ids, + all_q_input_ids_list, all_q_attention_mask_list, all_q_token_type_ids_list, + all_nli_labels, all_labels, + ) + else: + dataset = TensorDataset( + all_c_input_ids, all_c_attention_mask, all_c_token_type_ids, + all_q_input_ids_list, all_q_attention_mask_list, all_q_token_type_ids_list, + all_nli_labels, + ) + + return dataset + + def _load_line(self, line): + if isinstance(line, str): + line = json.loads(line) + guid = line["id"] + claim = line["claim"] + + # TODO: hack no evidence situation + evidences = line["evidence"] if len(line['evidence']) > 0 else ['no idea'] * 5 + questions = line["questions"] + answers = line["answers"] + evidential = assemble_answers_to_one(line, self.k, mask_rate=self.mask_rate)['evidential_assembled'] + label = line.get("label", None) + nli_labels = line.get('nli_labels', [[0., 0., 0.]] * len(questions)) + + for i, e in enumerate(evidential): + if '' in e: + nli_labels[i] = [0., 0., 0.] + + answers = [v[0] for v in answers] # k = 1 + label = self.label2id.get(label) + + sample = { + "guid": guid, + "claim": claim, + "evidences": evidences, + "questions": questions, + "answers": answers, + "evidential": evidential, # already assembled. + "label": label, + 'nli_labels': nli_labels + } + return sample diff --git a/src/check_client/modules/test_data_processor.py b/src/check_client/modules/test_data_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..414f99edaec1158745f90537c4072f10793b3978 --- /dev/null +++ b/src/check_client/modules/test_data_processor.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- + +""" +@Author : Jiangjie Chen +@Time : 2020/12/20 18:05 +@Contact : jjchen19@fudan.edu.cn +@Description: +""" + +import os +from data_processor import DataProcessor +from transformers import RobertaTokenizer + + +root = os.environ['PJ_HOME'] + +tokenizer = RobertaTokenizer.from_pretrained('roberta-large') +dp = DataProcessor('roberta-large', 256, 128, 8, cand_k=3, data_dir=f'{root}/data/fact_checking/v5', overwrite_cache=True) + +# dp.load_and_cache_data('val', tokenizer) + + +data = {"id":91198,"claim":"Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the National Football League .","evidence":["Things about Colin Kaepernick: He remained the team 's starting quarterback for the rest of the season and went on to lead the 49ers to their first Super Bowl appearance since 1994 , losing to the Baltimore Ravens .","Things about Colin Kaepernick: In the following seasons , Kaepernick lost and won back his starting job , with the 49ers missing the playoffs for three years consecutively .","Things about Colin Kaepernick: During the 2013 season , his first full season as a starter , Kaepernick helped the 49ers reach the NFC Championship , losing to the Seattle Seahawks .","Things about Colin Kaepernick: Kaepernick began his professional career as a backup to Alex Smith , but became the 49ers ' starter in the middle of the 2012 season after Smith suffered a concussion .","Things about Colin Kaepernick: Colin Rand Kaepernick ( ; born November 3 , 1987 ) is an American football quarterback who is currently a free agent ."],"answers":[["Colin Kaepernick",0,16],["a starting quarterback",24,46],["49ers",58,63],["63rd season",64,75],["National Football League",83,107]],"questions":["noun","noun","noun","noun","noun"],"label":"NOT ENOUGH INFO","evidential_assembled":["Who was the starting quarterback for the 49ers in the 63rd season? or became a starting quarterback during the 49ers 63rd season in the National Football League .","What was Colin Kaepernick's first job title? or Colin Kaepernick became during the 49ers 63rd season in the National Football League .","What team was Colin Kaepernick a quarterback for? or Colin Kaepernick became a starting quarterback during the 63rd season in the National Football League .","In what season did Colin Kaepernick become a starting quarterback for the 49ers? or Colin Kaepernick became a starting quarterback during the 49ers in the National Football League .","What league was Colin Kaepernick a quarterback in? or Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the ."],"evidential":[["Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the National Football League .","Colin Kapit became a starting quarterback during the 49ers 63rd season in the National Football League .","Colin Kapra became a starting quarterback during the 49ers 63rd season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the National Football League .","Colin Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the National Football League ."],["Colin Kaepernick became a quarterback during the 49ers 63rd season in the National Football League .","Colin Kaepernick became a starter during the 49ers 63rd season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the National Football League .","Colin Kaepernick became a backup quarterback during the 49ers 63rd season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the National Football League ,"],["Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers ' 63rd season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers' 63rd season in the National Football League .","Colin Kaepernick became a starting quarterback during the Niners 63rd season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the National Football League ."],["Colin Kaepernick became a starting quarterback during the 49ers season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers ' season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers first season in the National Football League .","Colin Kaepernick became a starting quarterback during the 49ers second season in the National Football League ."],["Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the Super Bowl .","Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the NFC .","Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the professional sports .","Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the NFL .","Colin Kaepernick became a starting quarterback during the 49ers 63rd season in the league ."]]} + +s = dp.convert_inputs_to_dataset([data], tokenizer, True) +print(s) \ No newline at end of file diff --git a/src/check_client/plm_checkers/__init__.py b/src/check_client/plm_checkers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43596de2c2c4de791f0beccd188fae797ec8796d --- /dev/null +++ b/src/check_client/plm_checkers/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- + +""" +@Author : Jiangjie Chen +@Time : 2020/12/27 15:41 +@Contact : jjchen19@fudan.edu.cn +@Description: +""" + + +from .bert_checker import BertChecker +from .roberta_checker import RobertaChecker diff --git a/src/check_client/plm_checkers/bert_checker.py b/src/check_client/plm_checkers/bert_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..0063f7c627031de493c52b1f6e6c5ee13fac32f4 --- /dev/null +++ b/src/check_client/plm_checkers/bert_checker.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2020/8/18 14:40 +@Contact : jjchen19@fudan.edu.cn +@Description: +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers import BertModel, BertPreTrainedModel +from .checker_utils import attention_mask_to_mask, ClassificationHead, soft_logic, build_pseudo_labels, \ + get_label_embeddings, temperature_annealing + + +class BertChecker(BertPreTrainedModel): + def __init__(self, config, logic_lambda=0.0, prior='nli', m=8, temperature=1): + super().__init__(config) + self.num_labels = config.num_labels + self.hidden_size = config.hidden_size + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self._lambda = logic_lambda + self.prior = prior + self.temperature = temperature + self._step = 0 + + # general attention + self.linear_self_attn = nn.Linear(self.hidden_size, 1, bias=False) + self.linear_m_attn = nn.Linear(self.hidden_size * 2, 1, bias=False) + + self.var_hidden_size = self.hidden_size // 4 + + z_hid_size = self.num_labels * m + self.linear_P_theta = nn.Linear(self.hidden_size * 2 + z_hid_size, self.var_hidden_size) + y_hid_size = self.var_hidden_size + self.linear_Q_phi = nn.Linear(self.hidden_size * 2 + y_hid_size, self.var_hidden_size) + + self.classifier = ClassificationHead(self.var_hidden_size, self.num_labels, config.hidden_dropout_prob) # label embedding for y + self.z_clf = self.classifier + self.init_weights() + + def forward(self, claim_input_ids, claim_attention_mask, claim_token_type_ids, + qa_input_ids_list, qa_attention_mask_list, qa_token_type_ids_list, + nli_labels=None, labels=None): + ''' + m: num of questions; n: num of evidence; k: num of candidate answers + :param claim_input_ids: b x L1 + :param claim_attention_mask: b x L1 + :param claim_token_type_ids: b x L1 + :param qa_input_ids_list: b x m x L2 + :param qa_attention_mask_list: b x m x L2 + :param qa_token_type_ids_list: b x m x L2 + :param labels: (b,) + :return: + ''' + self._step += 1 + _zero = torch.tensor(0.).to(claim_input_ids.device) + + global_output = self.bert( + claim_input_ids, + attention_mask=claim_attention_mask, + token_type_ids=claim_token_type_ids + )[0] # b x L1 x h + + global_output = self.self_select(global_output) # b x h + + _qa_input_ids_list = qa_input_ids_list.transpose(1, 0) # m x b x L2 + _qa_attention_mask_list = qa_attention_mask_list.transpose(1, 0) + _qa_token_type_ids_list = qa_token_type_ids_list.transpose(1, 0) + + local_output_list = [] + for _inp, _attn, _token_ids in zip(_qa_input_ids_list, _qa_attention_mask_list, _qa_token_type_ids_list): + _local_output = self.bert(_inp, attention_mask=_attn, + token_type_ids=_token_ids)[0] + _local_output = self.self_select(_local_output) + local_output_list.append(_local_output) + + local_outputs = torch.stack(local_output_list, 0) # m x b x h + local_outputs = local_outputs.transpose(1, 0).contiguous() # b x m x h + + neg_elbo, loss, logic_loss = _zero, _zero, _zero + mask = attention_mask_to_mask(qa_attention_mask_list) + # b x h, b x m x h -> b x h + local_outputs_w, m_attn = self.local_attn(global_output, local_outputs, mask) + local_outputs = torch.cat([local_outputs, global_output.unsqueeze(1).repeat(1, local_outputs.size(1), 1)], -1) + + if labels is not None: + # Training + # ======================== Q_phi ================================ + + labels_onehot = F.one_hot(labels, num_classes=self.num_labels).to(torch.float) + y_star_emb = get_label_embeddings(labels_onehot, self.classifier.out_proj.weight) # b x h + z = self.Q_phi(local_outputs, y_star_emb) + z_softmax = z.softmax(-1) + + # ======================== P_theta ============================== + + z_gumbel = F.gumbel_softmax(z, tau=temperature_annealing(self.temperature, self._step), + dim=-1, hard=True) # b x m x 3 + y = self.P_theta(global_output, local_outputs_w, z_gumbel) + + # ======================== soft logic =========================== + mask = mask.to(torch.int) + y_z = soft_logic(z_softmax, mask) # b x 3 + logic_loss = F.kl_div(y.log_softmax(-1), y_z) + + # ======================== ELBO ================================= + elbo_neg_p_log = F.cross_entropy(y.view(-1, self.num_labels), labels.view(-1)) + if self.prior == 'nli': + prior = nli_labels.softmax(dim=-1) + elif self.prior == 'uniform': + prior = torch.tensor([1 / self.num_labels] * self.num_labels).to(y) + prior = prior.unsqueeze(0).unsqueeze(0).repeat(mask.size(0), mask.size(1), 1) + elif self.prior == 'logic': + prior = build_pseudo_labels(labels, m_attn) + else: + raise NotImplementedError(self.prior) + + elbo_kl = F.kl_div(z_softmax.log(), prior) + neg_elbo = elbo_kl + elbo_neg_p_log + + loss = (1 - abs(self._lambda)) * neg_elbo + abs(self._lambda) * logic_loss + else: + # Inference + if self.prior == 'nli': + z = nli_labels + elif self.prior == 'uniform': + prior = torch.tensor([1 / self.num_labels] * self.num_labels).to(y) + z = prior.unsqueeze(0).unsqueeze(0).repeat(mask.size(0), mask.size(1), 1) + else: + z = torch.rand([local_outputs.size(0), local_outputs.size(1), self.num_labels]).to(local_outputs) + z_softmax = z.softmax(-1) + + for i in range(3): # N = 3 + z = z_softmax.argmax(-1) + z = F.one_hot(z, num_classes=3).to(torch.float) + y = self.P_theta(global_output, local_outputs_w, z) + y = y.softmax(-1) + y_emb = get_label_embeddings(y, self.classifier.out_proj.weight) + z = self.Q_phi(local_outputs, y_emb) + z_softmax = z.softmax(-1) + + return (loss, (neg_elbo, logic_loss), y, m_attn, (z_softmax, mask)) # batch first + + def Q_phi(self, X, y): + ''' + X, y => z + :param X: b x m x h + :param y_emb: b x 3 / b x h' + :return: b x m x 3 (ref, nei, sup) + ''' + y_expand = y.unsqueeze(1).repeat(1, X.size(1), 1) # b x m x 3/h' + z_hidden = self.linear_Q_phi(torch.cat([y_expand, X], dim=-1)) # b x m x h' + z_hidden = F.tanh(z_hidden) + z = self.z_clf(z_hidden) + return z + + def P_theta(self, X_global, X_local, z): + ''' + X, z => y* + :param X_global: b x h + :param X_local: b x m x h + :param z: b x m x 3 + :param mask: b x m + :return: b x 3, b x m + ''' + b = z.size(0) + # global classification + _logits = torch.cat([X_local, X_global, z.reshape(b, -1)], dim=-1) + _logits = self.dropout(_logits) + _logits = self.linear_P_theta(_logits) + _logits = torch.tanh(_logits) + + y = self.classifier(_logits) + return y + + def self_select(self, h_x): + ''' + self attention on a vector + :param h_x: b x L x h + :return: b x h + ''' + w = self.dropout(self.linear_self_attn(h_x).squeeze(-1)).softmax(-1) + return torch.einsum('blh,bl->bh', h_x, w) + + def local_attn(self, global_output, local_outputs, mask): + ''' + :param global_output: b x h + :param qa_outputs: b x m x h + :param mask: b x m + :return: b x h, b x m + ''' + m = local_outputs.size(1) + scores = self.linear_m_attn(torch.cat([global_output.unsqueeze(1).repeat(1, m, 1), + local_outputs], dim=-1)).squeeze(-1) # b x m + mask = 1 - mask + scores = scores.masked_fill(mask.to(torch.bool), -1e16) + attn = F.softmax(scores, -1) + return torch.einsum('bm,bmh->bh', attn, local_outputs), attn diff --git a/src/check_client/plm_checkers/checker_utils.py b/src/check_client/plm_checkers/checker_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..77812da1d39bde00ad34494b73514a9d94a81221 --- /dev/null +++ b/src/check_client/plm_checkers/checker_utils.py @@ -0,0 +1,223 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2020/10/15 16:10 +@Contact : jjchen19@fudan.edu.cn +@Description: +''' + +import torch +import random +import torch.nn.functional as F +import torch.nn as nn + + +class ClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, hidden_size, num_labels, hidden_dropout_prob=0.2): + super().__init__() + self.dropout = nn.Dropout(hidden_dropout_prob) + self.out_proj = nn.Linear(hidden_size, num_labels, bias=False) + + def forward(self, features, **kwargs): + x = features + x = self.dropout(x) + x = self.out_proj(x) + return x + + +def temperature_annealing(tau, step): + if tau == 0.: + tau = 10. if step % 5 == 0 else 1. + return tau + + +def get_label_embeddings(labels, label_embedding): + ''' + :param labels: b x 3 + :param label_embedding: 3 x h' + :return: b x h' + ''' + emb = torch.einsum('oi,bo->bi', label_embedding, labels) + return emb + + +def soft_logic(y_i, mask, tnorm='product'): + ''' + a^b = ab + avb = 1 - ((1-a)(1-b)) + :param y_i: b x m x 3 + :param mask: b x m + :param tnorm: product or godel or lukasiewicz + :return: [b x 3] + ''' + _sup = y_i[:, :, 2] # b x m + _ref = y_i[:, :, 0] # b x m + _sup = _sup * mask + (1 - mask) # pppp1111 + _ref = _ref * mask # pppp0000 + + if tnorm == 'product': + p_sup = torch.exp(torch.log(_sup).sum(1)) + p_ref = 1 - torch.exp(torch.log(1 - _ref).sum(1)) + elif tnorm == 'godel': + p_sup = _sup.min(-1).values + p_ref = _ref.max(-1).values + elif tnorm == 'lukas': + raise NotImplementedError(tnorm) + else: + raise NotImplementedError(tnorm) + + p_nei = 1 - p_sup - p_ref + p_sup = torch.max(p_sup, torch.zeros_like(p_sup)) + p_ref = torch.max(p_ref, torch.zeros_like(p_ref)) + p_nei = torch.max(p_nei, torch.zeros_like(p_nei)) + logical_prob = torch.stack([p_ref, p_nei, p_sup], dim=-1) + assert torch.lt(logical_prob, 0).to(torch.int).sum().tolist() == 0, \ + (logical_prob, _sup, _ref) + return logical_prob # b x 3 + + +def build_pseudo_labels(labels, m_attn): + ''' + :param labels: (b,) + :param m_attn: b x m + :return: b x m x 3 + ''' + mask = torch.gt(m_attn, 1e-16).to(torch.int) + sup_label = torch.tensor(2).to(labels) + nei_label = torch.tensor(1).to(labels) + ref_label = torch.tensor(0).to(labels) + pseudo_labels = [] + for idx, label in enumerate(labels): + mm = mask[idx].sum(0) + if label == 2: # SUPPORTS + pseudo_label = F.one_hot(sup_label.repeat(mask.size(1)), num_classes=3).to(torch.float) # TODO: hyperparam + + elif label == 0: # REFUTES + num_samples = magic_proportion(mm) + ids = torch.topk(m_attn[idx], k=num_samples).indices + pseudo_label = [] + for i in range(mask.size(1)): + if i >= mm: + _label = torch.tensor([1/3, 1/3, 1/3]).to(labels) + elif i in ids: + _label = F.one_hot(ref_label, num_classes=3).to(torch.float) + else: + if random.random() > 0.5: + _label = torch.tensor([0., 0., 1.]).to(labels) + else: + _label = torch.tensor([0., 1., 0.]).to(labels) + pseudo_label.append(_label) + pseudo_label = torch.stack(pseudo_label) + + else: # NEI + num_samples = magic_proportion(mm) + ids = torch.topk(m_attn[idx], k=num_samples).indices + pseudo_label = sup_label.repeat(mask.size(1)) + pseudo_label[ids] = nei_label + pseudo_label = F.one_hot(pseudo_label, num_classes=3).to(torch.float) # TODO: hyperparam + + pseudo_labels.append(pseudo_label) + return torch.stack(pseudo_labels) + + +def magic_proportion(m, magic_n=5): + # 1~4: 1, 5~m: 2 + return m // magic_n + 1 + + +def sequence_mask(lengths, max_len=None): + """ + Creates a boolean mask from sequence lengths. + """ + batch_size = lengths.numel() + max_len = max_len or lengths.max() + return (torch.arange(0, max_len, device=lengths.device) + .type_as(lengths) + .repeat(batch_size, 1) + .lt(lengths.unsqueeze(1))) + + +def collapse_w_mask(inputs, mask): + ''' + :param inputs: b x L x h + :param mask: b x L + :return: b x h + ''' + hidden = inputs.size(-1) + output = inputs * mask.unsqueeze(-1).repeat((1, 1, hidden)) # b x L x h + output = output.sum(-2) + output /= (mask.sum(-1) + 1e-6).unsqueeze(-1).repeat((1, hidden)) # b x h + return output + + +def parse_ce_outputs(ce_seq_output, ce_lengths): + ''' + :param qa_seq_output: b x L1 x h + :param qa_lengths: e.g. [0,1,1,0,2,2,0,0] (b x L2) + :return: + c_output: b x h + e_output: b x h + ''' + if ce_lengths.max() == 0: + b, L1, h = ce_seq_output.size() + return torch.zeros([b, h]).cuda(), torch.zeros([b, h]).cuda() + masks = [] + for mask_id in range(1, ce_lengths.max() + 1): + _m = torch.ones_like(ce_lengths) * mask_id + mask = _m.eq(ce_lengths).to(torch.int) + masks.append(mask) + c_output = collapse_w_mask(ce_seq_output, masks[0]) + e_output = torch.stack([collapse_w_mask(ce_seq_output, m) + for m in masks[1:]]).mean(0) + return c_output, e_output + + +def parse_qa_outputs(qa_seq_output, qa_lengths, k): + ''' + :param qa_seq_output: b x L2 x h + :param qa_lengths: e.g. [0,1,1,0,2,2,0,3,0,4,0,5,0,0,0,0] (b x L2) + :return: + q_output: b x h + a_output: b x h + k_cand_output: k x b x h + ''' + b, L2, h = qa_seq_output.size() + if qa_lengths.max() == 0: + return torch.zeros([b, h]).cuda(), torch.zeros([b, h]).cuda(), \ + torch.zeros([k, b, h]).cuda() + + masks = [] + for mask_id in range(1, qa_lengths.max() + 1): + _m = torch.ones_like(qa_lengths) * mask_id + mask = _m.eq(qa_lengths).to(torch.int) + masks.append(mask) + + q_output = collapse_w_mask(qa_seq_output, masks[0]) + a_output = collapse_w_mask(qa_seq_output, masks[1]) + k_cand_output = [collapse_w_mask(qa_seq_output, m) + for m in masks[2:2 + k]] + for i in range(k - len(k_cand_output)): + k_cand_output.append(torch.zeros([b, h]).cuda()) + k_cand_output = torch.stack(k_cand_output, dim=0) + + return q_output, a_output, k_cand_output + + +def attention_mask_to_mask(attention_mask): + ''' + :param attention_mask: b x m x L + :return: b x m + ''' + mask = torch.gt(attention_mask.sum(-1), 0).to(torch.int).sum(-1) # (b,) + mask = sequence_mask(mask, max_len=attention_mask.size(1)).to(torch.int) # (b, m) + return mask + + +if __name__ == "__main__": + y = torch.tensor([[[0.3,0.5,0.2],[0.1,0.4,0.5]]]) + mask = torch.tensor([1,1]) + s = soft_logic(y, mask) + print(s) \ No newline at end of file diff --git a/src/check_client/plm_checkers/roberta_checker.py b/src/check_client/plm_checkers/roberta_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..51d0ab6f0219961072ef5f730872dde433614e4a --- /dev/null +++ b/src/check_client/plm_checkers/roberta_checker.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2020/8/18 14:40 +@Contact : jjchen19@fudan.edu.cn +@Description: +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers import RobertaModel, BertPreTrainedModel, RobertaConfig +from .checker_utils import attention_mask_to_mask, ClassificationHead, soft_logic, build_pseudo_labels, \ + get_label_embeddings, temperature_annealing + + +class RobertaChecker(BertPreTrainedModel): + config_class = RobertaConfig + base_model_prefix = "roberta" + + def __init__(self, config, logic_lambda=0.0, prior='nli', m=8, temperature=1): + super().__init__(config) + self.num_labels = config.num_labels + self.hidden_size = config.hidden_size + self.roberta = RobertaModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self._lambda = logic_lambda + self.prior = prior + self.temperature = temperature + self._step = 0 + + # general attention + self.linear_self_attn = nn.Linear(self.hidden_size, 1, bias=False) + self.linear_m_attn = nn.Linear(self.hidden_size * 2, 1, bias=False) + + self.var_hidden_size = self.hidden_size // 4 + + z_hid_size = self.num_labels * m + self.linear_P_theta = nn.Linear(self.hidden_size * 2 + z_hid_size, self.var_hidden_size) + y_hid_size = self.var_hidden_size + self.linear_Q_phi = nn.Linear(self.hidden_size * 2 + y_hid_size, self.var_hidden_size) + + # TODO: y_clf => classifier. compromise for mnli + self.classifier = ClassificationHead(self.var_hidden_size, self.num_labels, + config.hidden_dropout_prob) # label embedding for y + self.z_clf = self.classifier + self.init_weights() + + def forward(self, claim_input_ids, claim_attention_mask, claim_token_type_ids, + qa_input_ids_list, qa_attention_mask_list, qa_token_type_ids_list, + nli_labels=None, labels=None): + ''' + m: num of questions; n: num of evidence; k: num of candidate answers + :param claim_input_ids: b x L1 + :param claim_attention_mask: b x L1 + :param claim_token_type_ids: b x L1 + :param qa_input_ids_list: b x m x L2 + :param qa_attention_mask_list: b x m x L2 + :param qa_token_type_ids_list: b x m x L2 + :param nli_labels: b x m x 3 + :param labels: (b,) + :return: (loss, (neg_elbo, logic_loss), y, m_attn, (z_softmax, mask)) + ''' + self._step += 1 + _zero = torch.tensor(0.).to(claim_input_ids.device) + + # ====================== Representation learning ======================= + global_output = self.roberta(claim_input_ids, attention_mask=claim_attention_mask)[0] # b x L1 x h + global_output = self.self_select(global_output) # b x h + + _qa_input_ids_list = qa_input_ids_list.transpose(1, 0) # m x b x L2 + _qa_attention_mask_list = qa_attention_mask_list.transpose(1, 0) + + local_output_list = [] + for _inp, _attn in zip(_qa_input_ids_list, _qa_attention_mask_list): + _local_output = self.roberta(_inp, attention_mask=_attn)[0] + _local_output = self.self_select(_local_output) + local_output_list.append(_local_output) + + _local_outputs = torch.stack(local_output_list, 0) # m x b x h + local_outputs = _local_outputs.transpose(1, 0).contiguous() # b x m x h + + neg_elbo, loss, logic_loss = _zero, _zero, _zero + mask = attention_mask_to_mask(qa_attention_mask_list) + # b x h, b x m x h -> b x h + local_outputs_w, m_attn = self.local_attn(global_output, local_outputs, mask) + local_outputs = torch.cat([local_outputs, global_output.unsqueeze(1).repeat(1, local_outputs.size(1), 1)], -1) + + if labels is not None: + # Training + # ======================== Q_phi ================================ + + labels_onehot = F.one_hot(labels, num_classes=self.num_labels).to(torch.float) + y_star_emb = get_label_embeddings(labels_onehot, self.classifier.out_proj.weight) # b x h + z = self.Q_phi(local_outputs, y_star_emb) + z_softmax = z.softmax(-1) + + # ======================== P_theta ============================== + + z_gumbel = F.gumbel_softmax(z, tau=temperature_annealing(self.temperature, self._step), + dim=-1, hard=True) # b x m x 3 + y = self.P_theta(global_output, local_outputs_w, z_gumbel) + + # ======================== soft logic =========================== + mask = mask.to(torch.int) + y_z = soft_logic(z_softmax, mask) # b x 3 + logic_loss = F.kl_div(y.log_softmax(-1), y_z) + + # ======================== ELBO ================================= + elbo_neg_p_log = F.cross_entropy(y.view(-1, self.num_labels), labels.view(-1)) + if self.prior == 'nli': + prior = nli_labels.softmax(dim=-1) + elif self.prior == 'uniform': + prior = torch.tensor([1 / self.num_labels] * self.num_labels).to(mask.device) + prior = prior.unsqueeze(0).unsqueeze(0).repeat(mask.size(0), mask.size(1), 1) + elif self.prior == 'logic': + prior = build_pseudo_labels(labels, m_attn) + else: + raise NotImplementedError(self.prior) + + elbo_kl = F.kl_div(z_softmax.log(), prior) + neg_elbo = elbo_kl + elbo_neg_p_log + + loss = (1 - abs(self._lambda)) * neg_elbo + abs(self._lambda) * logic_loss + else: + # Inference + if self.prior == 'nli': + z = nli_labels + elif self.prior == 'uniform': + prior = torch.tensor([1 / self.num_labels] * self.num_labels).to(mask.device) + z = prior.unsqueeze(0).unsqueeze(0).repeat(mask.size(0), mask.size(1), 1) + else: + z = torch.rand([local_outputs.size(0), local_outputs.size(1), self.num_labels]).to(local_outputs) + z_softmax = z.softmax(-1) + + for i in range(3): # N = 3 + z = z_softmax.argmax(-1) + z = F.one_hot(z, num_classes=3).to(torch.float) + y = self.P_theta(global_output, local_outputs_w, z) + y = y.softmax(-1) + y_emb = get_label_embeddings(y, self.classifier.out_proj.weight) + z = self.Q_phi(local_outputs, y_emb) + z_softmax = z.softmax(-1) + + return (loss, (neg_elbo, logic_loss), y, m_attn, (z_softmax, mask)) # batch first + + def Q_phi(self, X, y): + ''' + X, y => z + :param X: b x m x h + :param y_emb: b x 3 / b x h' + :return: b x m x 3 (ref, nei, sup) + ''' + y_expand = y.unsqueeze(1).repeat(1, X.size(1), 1) # b x m x 3/h' + z_hidden = self.linear_Q_phi(torch.cat([y_expand, X], dim=-1)) # b x m x h' + z_hidden = F.tanh(z_hidden) + z = self.z_clf(z_hidden) + return z + + def P_theta(self, X_global, X_local, z): + ''' + X, z => y* + :param X_global: b x h + :param X_local: b x m x h + :param z: b x m x 3 + :param mask: b x m + :return: b x 3, b x m + ''' + b = z.size(0) + # global classification + _logits = torch.cat([X_local, X_global, z.reshape(b, -1)], dim=-1) + _logits = self.dropout(_logits) + _logits = self.linear_P_theta(_logits) + _logits = torch.tanh(_logits) + + y = self.classifier(_logits) + return y + + def self_select(self, h_x): + ''' + self attention on a vector + :param h_x: b x L x h + :return: b x h + ''' + w = self.dropout(self.linear_self_attn(h_x).squeeze(-1)).softmax(-1) + return torch.einsum('blh,bl->bh', h_x, w) + + def local_attn(self, global_output, local_outputs, mask): + ''' + :param global_output: b x h + :param qa_outputs: b x m x h + :param mask: b x m + :return: b x h, b x m + ''' + m = local_outputs.size(1) + scores = self.linear_m_attn(torch.cat([global_output.unsqueeze(1).repeat(1, m, 1), + local_outputs], dim=-1)).squeeze(-1) # b x m + mask = 1 - mask + scores = scores.masked_fill(mask.to(torch.bool), -1e16) + attn = F.softmax(scores, -1) + return torch.einsum('bm,bmh->bh', attn, local_outputs), attn diff --git a/src/check_client/scripts/train_bert-large.sh b/src/check_client/scripts/train_bert-large.sh new file mode 100644 index 0000000000000000000000000000000000000000..4b2b7cabf902ea8bdbd753d612f12248882c9e61 --- /dev/null +++ b/src/check_client/scripts/train_bert-large.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash + +MODEL_TYPE=bert +MODEL_NAME_OR_PATH=bert-large-cased +VERSION=v5 +MAX_NUM_QUESTIONS=8 + +MAX_SEQ1_LENGTH=256 +MAX_SEQ2_LENGTH=128 +CAND_K=3 +LAMBDA=${1:-0.5} +PRIOR=${2:-nli} +MASK=${3:-0.0} +echo "lambda = $LAMBDA, prior = $PRIOR, mask = $MASK" + +DATA_DIR=$PJ_HOME/data/fact_checking/${VERSION} +OUTPUT_DIR=$PJ_HOME/models/fact_checking/${VERSION}_${MODEL_NAME_OR_PATH}/${VERSION}_${MODEL_NAME_OR_PATH}_AAAI_K${CAND_K}_${PRIOR}_m${MASK}_l${LAMBDA} +NUM_TRAIN_EPOCH=7 +GRADIENT_ACCUMULATION_STEPS=2 +PER_GPU_TRAIN_BATCH_SIZE=8 # 4546 +PER_GPU_EVAL_BATCH_SIZE=16 +LOGGING_STEPS=200 +SAVE_STEPS=200 + + +python3 train.py \ + --data_dir ${DATA_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --model_type ${MODEL_TYPE} \ + --model_name_or_path ${MODEL_NAME_OR_PATH} \ + --max_seq1_length ${MAX_SEQ1_LENGTH} \ + --max_seq2_length ${MAX_SEQ2_LENGTH} \ + --max_num_questions ${MAX_NUM_QUESTIONS} \ + --do_train \ + --do_eval \ + --evaluate_during_training \ + --learning_rate 1e-5 \ + --num_train_epochs ${NUM_TRAIN_EPOCH} \ + --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEPS} \ + --per_gpu_train_batch_size ${PER_GPU_TRAIN_BATCH_SIZE} \ + --per_gpu_eval_batch_size ${PER_GPU_EVAL_BATCH_SIZE} \ + --logging_steps ${LOGGING_STEPS} \ + --save_steps ${SAVE_STEPS} \ + --cand_k ${CAND_K} \ + --logic_lambda ${LAMBDA} \ + --prior ${PRIOR} \ + --overwrite_output_dir \ + --temperature 1.0 \ + --mask_rate ${MASK} + +python3 cjjpy.py --lark "$OUTPUT_DIR fact checking training completed" diff --git a/src/check_client/scripts/train_roberta.sh b/src/check_client/scripts/train_roberta.sh new file mode 100644 index 0000000000000000000000000000000000000000..5b00d8ee1967eadf5406db00f44a79a18bf2de59 --- /dev/null +++ b/src/check_client/scripts/train_roberta.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash + +MODEL_TYPE=roberta +MODEL_NAME_OR_PATH=roberta-large +VERSION=v5 +MAX_NUM_QUESTIONS=8 + +MAX_SEQ1_LENGTH=256 +MAX_SEQ2_LENGTH=128 +CAND_K=3 +LAMBDA=${1:-0.5} +PRIOR=${2:-nli} +MASK=${3:-0.0} +echo "lambda = $LAMBDA, prior = $PRIOR, mask = $MASK" + +DATA_DIR=$PJ_HOME/data/fact_checking/${VERSION} +OUTPUT_DIR=$PJ_HOME/models/fact_checking/${VERSION}_${MODEL_NAME_OR_PATH}/${VERSION}_${MODEL_NAME_OR_PATH}_AAAI_K${CAND_K}_${PRIOR}_m${MASK}_l${LAMBDA} +NUM_TRAIN_EPOCH=7 +GRADIENT_ACCUMULATION_STEPS=2 +PER_GPU_TRAIN_BATCH_SIZE=8 # 4546 +PER_GPU_EVAL_BATCH_SIZE=16 +LOGGING_STEPS=200 +SAVE_STEPS=200 + + +python3 train.py \ + --data_dir ${DATA_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --model_type ${MODEL_TYPE} \ + --model_name_or_path ${MODEL_NAME_OR_PATH} \ + --max_seq1_length ${MAX_SEQ1_LENGTH} \ + --max_seq2_length ${MAX_SEQ2_LENGTH} \ + --max_num_questions ${MAX_NUM_QUESTIONS} \ + --do_train \ + --do_eval \ + --evaluate_during_training \ + --learning_rate 1e-5 \ + --num_train_epochs ${NUM_TRAIN_EPOCH} \ + --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEPS} \ + --per_gpu_train_batch_size ${PER_GPU_TRAIN_BATCH_SIZE} \ + --per_gpu_eval_batch_size ${PER_GPU_EVAL_BATCH_SIZE} \ + --logging_steps ${LOGGING_STEPS} \ + --save_steps ${SAVE_STEPS} \ + --cand_k ${CAND_K} \ + --logic_lambda ${LAMBDA} \ + --prior ${PRIOR} \ + --overwrite_output_dir \ + --temperature 1.0 \ + --mask_rate ${MASK} + +python3 cjjpy.py --lark "$OUTPUT_DIR fact checking training completed" diff --git a/src/check_client/train.py b/src/check_client/train.py new file mode 100644 index 0000000000000000000000000000000000000000..e0877fc794122789558eb2f66ddc172937aa42ef --- /dev/null +++ b/src/check_client/train.py @@ -0,0 +1,647 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import glob +import argparse +import logging +import random +import torch +import numpy as np +from tqdm import tqdm +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler +from transformers import ( + AutoConfig, + AutoTokenizer +) +from transformers import WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup +import tensorflow as tf +from pytorch_lightning.loggers import WandbLogger + +try: + from .modules.data_processor import DataProcessor + from .plm_checkers import BertChecker, RobertaChecker, XLNetChecker, DebertaChecker + from .utils import init_logger, compute_metrics +except: + from modules.data_processor import DataProcessor + from plm_checkers import BertChecker, RobertaChecker, XLNetChecker, DebertaChecker + from utils import init_logger, compute_metrics + +try: + from torch.utils.tensorboard import SummaryWriter +except ImportError: + from tensorboardX import SummaryWriter + +mAutoModel = { + 'bert': BertChecker, + 'roberta': RobertaChecker, + 'xlnet': XLNetChecker, + 'deberta': DebertaChecker, +} + +logger = logging.getLogger(__name__) + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def train(args, data_processor, model, tokenizer): + """ Train the model """ + global wdblogger + if args.local_rank in [-1, 0]: + tb_writer = SummaryWriter() + + tf.io.gfile.makedirs(os.path.dirname(args.output_dir)) + args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) + train_dataset = data_processor.load_and_cache_data("train", tokenizer, args.data_tag) + train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) + train_dataloader = DataLoader(train_dataset, sampler=train_sampler, + drop_last=True, + batch_size=args.train_batch_size) + + if args.max_steps > 0: + t_total = args.max_steps + args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 + else: + t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs + + # Prepare optimizer and schedule (linear warmup and decay) + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0 + }, + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) + scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total + ) + if args.fp16: + try: + from apex import amp + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) + + # multi-gpu training (should be after apex fp16 initialization) + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + + # Distributed training (should be after apex fp16 initialization) + if args.local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True + ) + + # Train! + logger.info("***** Running training *****") + logger.info("Num examples = %d", len(train_dataset)) + logger.info("Num Epochs = %d", args.num_train_epochs) + logger.info("Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) + logger.info( + "Total train batch size (w. parallel, distributed & accumulation) = %d", + args.train_batch_size + * args.gradient_accumulation_steps + * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), + ) + logger.info("Gradient Accumulation steps = %d", args.gradient_accumulation_steps) + logger.info("Total optimization steps = %d", t_total) + + global_step = 0 + tr_loss, logging_loss = 0.0, 0.0 + tr_loss2, logging_loss2 = 0.0, 0.0 + tr_loss3, logging_loss3 = 0.0, 0.0 + set_seed(args) # Added here for reproductibility + model.zero_grad() + for _ in range(int(args.num_train_epochs)): + all_loss = 0.0 + all_accuracy = 0.0 + epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) + for step, batch in enumerate(epoch_iterator): + model.train() + batch = tuple(t.to(args.device) for t in batch) + inputs = { + "claim_input_ids": batch[0], + "claim_attention_mask": batch[1], + "qa_input_ids_list": batch[3], + "qa_attention_mask_list": batch[4], + "nli_labels": batch[-2], + "labels": batch[-1], + } + if args.model_type != "distilbert": + # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids + inputs["claim_token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None + inputs["qa_token_type_ids_list"] = batch[5] if args.model_type in ["bert", "xlnet", "albert"] else None + + outputs = model(**inputs) + loss, _loss2, logits = outputs[0], outputs[1], outputs[2] + loss2, loss3 = _loss2 + + if args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training + loss2 = loss2.mean() + loss3 = loss3.mean() + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + loss2 = loss2 / args.gradient_accumulation_steps + loss3 = loss3 / args.gradient_accumulation_steps + + if args.fp16: + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + tr_loss += loss.item() + tr_loss2 += loss2.item() + tr_loss3 += loss3.item() + + all_loss += loss.detach().cpu().numpy() * args.gradient_accumulation_steps + all_accuracy += np.mean( + inputs["labels"].detach().cpu().numpy() == logits.detach().cpu().numpy().argmax(axis=-1) + ) + description = "Global step: {:>6}, Loss: {:>.6f}, Accuracy: {:>.6f}".format( + global_step, + all_loss / (step + 1), + all_accuracy / (step + 1), + ) + epoch_iterator.set_description(description) + if (step + 1) % args.gradient_accumulation_steps == 0: + if args.fp16: + torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) + else: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + + optimizer.step() + scheduler.step() # Update learning rate schedule + model.zero_grad() + global_step += 1 + + # Log metrics + if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: + # Only evaluate when single GPU otherwise metrics may not average well + if args.local_rank == -1 and args.evaluate_during_training: + results = evaluate(args, data_processor, model, tokenizer) + for key, value in results.items(): + logger.warning(f"Step: {global_step}, eval_{key}: {value}") + wdblogger.log_metrics({"eval_{}".format(key): value}, global_step) + tb_writer.add_scalar("eval_{}".format(key), value, global_step) + tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step) + wdblogger.log_metrics({"lr": scheduler.get_lr()[0]}, global_step) + tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step) + wdblogger.log_metrics({"loss": (tr_loss - logging_loss) / args.logging_steps}, global_step) + wdblogger.log_metrics({"loss2": (tr_loss2 - logging_loss2) / args.logging_steps}, global_step) + wdblogger.log_metrics({"loss3": (tr_loss3 - logging_loss3) / args.logging_steps}, global_step) + + logging_loss = tr_loss + logging_loss2 = tr_loss2 + logging_loss3 = tr_loss3 + wdblogger.save() + + # Save model checkpoint + if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: + output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # Take care of distributed/parallel training + model_to_save = model.module if hasattr(model, "module") else model + model_to_save.save_pretrained(output_dir) + torch.save(args, os.path.join(output_dir, "training_args.bin")) + logger.info("Saving model checkpoint to %s", output_dir) + + if 0 < args.max_steps < global_step: + epoch_iterator.close() + break + if 0 < args.max_steps < global_step: + break + + if args.local_rank in [-1, 0]: + tb_writer.close() + + return global_step, tr_loss / global_step + + +def evaluate(args, data_processor, model, tokenizer, prefix=""): + if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: + os.makedirs(args.output_dir) + + args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) + dataset = data_processor.load_and_cache_data("eval", tokenizer, args.data_tag) + eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset) + eval_dataloader = DataLoader(dataset, sampler=eval_sampler, + drop_last=True, + batch_size=args.eval_batch_size) + + # Eval! + logger.info("***** Running evaluation {} *****".format(prefix)) + logger.info("Num examples = %d", len(dataset)) + logger.info("Batch size = %d", args.eval_batch_size) + + label_truth, y_predicted, z_predicted, m_attn, mask = \ + do_evaluate(tqdm(eval_dataloader, desc="Evaluating"), model, args, during_training=True, with_label=True) + + outputs, results = compute_metrics(label_truth, y_predicted, z_predicted, mask) + + return results + + +def do_evaluate(dataloader, model, args, during_training=False, with_label=True): + label_truth = [] + y_predicted = [] + z_predicted = [] + m_attn = [] + mask = [] + for i, batch in enumerate(dataloader): + model.eval() + batch = tuple(t.to(args.device) for t in batch) + with torch.no_grad(): + inputs = { + "claim_input_ids": batch[0], + "claim_attention_mask": batch[1], + "qa_input_ids_list": batch[3], + "qa_attention_mask_list": batch[4], + "nli_labels": batch[6], + } + + if args.model_type != "distilbert": + # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids + inputs["claim_token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None + inputs["qa_token_type_ids_list"] = batch[5] if args.model_type in ["bert", "xlnet", "albert"] else None + + outputs = model(**inputs) + + if during_training and (i < 3 and (args.logic_lambda != 0)): + logger.warning(f'* m_attn:\n {outputs[-2][:5]}\n') + logger.warning(f'* Logic outputs:\n {outputs[-1][0][:5]}.\n Labels: {batch[-1][:5]}\n') + + if with_label: + label_truth += batch[-1].tolist() + y_predicted += outputs[2].tolist() + mask += outputs[-1][1].tolist() + z_predicted += outputs[-1][0].tolist() + m_attn += outputs[-2].tolist() + + y_predicted = np.argmax(y_predicted, axis=-1).tolist() + + return label_truth, y_predicted, z_predicted, m_attn, mask + + +def main(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--data_dir", + default=None, + type=str, + required=True, + help="The input data dir. Should contain the .tsv files (or other data files) for the task.", + ) + parser.add_argument( + "--model_type", + default=None, + type=str, + required=True, + help="Model type selected in the list: " + ", ".join(mAutoModel.keys()), + ) + parser.add_argument( + "--model_name_or_path", + default=None, + type=str, + required=True, + help="Path to pre-trained model or shortcut name", + ) + parser.add_argument( + "--data_tag", + default='default', + type=str, + help='Tag to cached data' + ) + parser.add_argument( + "--max_seq1_length", + default=None, + type=int, + required=True, + help="The maximum total input claim sequence length after tokenization. " + "Sequences longer than this will be truncated, sequences shorter will be padded.", + ) + parser.add_argument( + "--max_seq2_length", + default=None, + type=int, + required=True, + help="The maximum total input claim sequence length after tokenization. " + "Sequences longer than this will be truncated, sequences shorter will be padded.", + ) + parser.add_argument( + "--max_num_questions", + default=None, + type=int, + required=True, + help='The maximum number of evidences.', + ) + parser.add_argument( + "--cand_k", + default=1, + type=int, + help='The number of evidential answers out of beam size' + ) + parser.add_argument( + '--mask_rate', + default=0., + type=float, + help="Mask rate of QA" + ) + parser.add_argument( + "--output_dir", + default=None, + type=str, + required=True, + help="The output directory where the model predictions and checkpoints will be written.", + ) + + # Other parameters + parser.add_argument( + "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--tokenizer_name", + default="", + type=str, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--cache_dir", + default="", + type=str, + help="Where do you want to store the pre-trained models downloaded from s3", + ) + parser.add_argument( + "--max_seq_length", + default=128, + type=int, + help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.", + ) + parser.add_argument('--logic_lambda', required=True, type=float, + help='Regularization term for logic loss, also an indicator for using only logic.') + parser.add_argument('--prior', default='nli', type=str, choices=['nli', 'uniform', 'logic', 'random'], + help='type of prior distribution') + parser.add_argument('--temperature', required=True, type=float, help='Temperature for gumbel softmax.') + + parser.add_argument("--do_train", action="store_true", help="Whether to run training.") + parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.") + parser.add_argument( + "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step.", + ) + parser.add_argument( + "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.", + ) + parser.add_argument( + "--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.", + ) + parser.add_argument( + "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") + parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") + parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.", + ) + parser.add_argument( + "--max_steps", + default=-1, + type=int, + help="If > 0: set total number of training steps to perform. Override num_train_epochs.", + ) + parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") + parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.") + parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.") + parser.add_argument( + "--eval_all_checkpoints", + action="store_true", + help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number", + ) + parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") + parser.add_argument( + "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory", + ) + parser.add_argument( + "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets", + ) + parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") + parser.add_argument( + "--fp16", + action="store_true", + help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", + ) + parser.add_argument( + "--fp16_opt_level", + type=str, + default="O1", + help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." + "See details at https://nvidia.github.io/apex/amp.html", + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.") + parser.add_argument("--server_port", type=str, default="", help="For distant debugging.") + args = parser.parse_args() + + if ( + os.path.exists(args.output_dir) + and os.listdir(args.output_dir) + and args.do_train + and not args.overwrite_output_dir + ): + raise ValueError( + "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format( + args.output_dir + ) + ) + + # Setup distant debugging if needed + if args.server_ip and args.server_port: + # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script + import ptvsd + + print("Waiting for debugger attach") + ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) + ptvsd.wait_for_attach() + + # Setup CUDA, GPU & distributed training + if args.local_rank == -1 or args.no_cuda: + device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count() + else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + torch.distributed.init_process_group(backend="nccl") + args.n_gpu = 1 + args.device = device + + # Setup logging + if args.do_train: + global wdblogger + tf.io.gfile.makedirs(args.output_dir) + wdblogger = WandbLogger(name=os.path.basename(args.output_dir)) + wdblogger.log_hyperparams(args) + wdblogger.save() + log_file = os.path.join(args.output_dir, 'train.log') + init_logger(logging.INFO if args.local_rank in [-1, 0] else logging.WARN, log_file) + + logger.warning( + "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", + args.local_rank, + device, + args.n_gpu, + bool(args.local_rank != -1), + args.fp16, + ) + + # Set seed + set_seed(args) + + # Prepare task + data_processor = DataProcessor( + args.model_name_or_path, + args.max_seq1_length, + args.max_seq2_length, + args.max_num_questions, + args.cand_k, + data_dir=args.data_dir, + cache_dir_name=os.path.basename(args.output_dir), + overwrite_cache=args.overwrite_cache, + mask_rate=args.mask_rate + ) + + # Make sure only the first process in distributed training will download model & vocab + if args.local_rank not in [-1, 0]: + torch.distributed.barrier() + + # Load pretrained model and tokenizer + args.model_type = args.model_type.lower() + + config = AutoConfig.from_pretrained( + args.config_name if args.config_name else args.model_name_or_path, + num_labels=3, + cache_dir=args.cache_dir if args.cache_dir else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, + do_lower_case=args.do_lower_case, + cache_dir=args.cache_dir if args.cache_dir else None, + ) + model = mAutoModel[args.model_type].from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + cache_dir=args.cache_dir if args.cache_dir else None, + logic_lambda=args.logic_lambda, + m=args.max_num_questions, + prior=args.prior, + temperature=args.temperature + ) + + # Make sure only the first process in distributed training will download model & vocab + if args.local_rank == 0: + torch.distributed.barrier() + + if args.do_train: + model.to(args.device) + wdblogger.watch(model) + + logger.info("Training/evaluation parameters %s", args) + + # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum + # if args.fp16 is set. Otherwise it'll default to "promote" mode, and we'll get fp32 operations. + # Note that running `--fp16_opt_level="O2"` will remove the need for this code, but it is still valid. + if args.fp16: + try: + import apex + apex.amp.register_half_function(torch, "einsum") + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + + # Training + if args.do_train: + global_step, tr_loss = train(args, data_processor, model, tokenizer) + logger.info("global_step = %s, average loss = %s", global_step, tr_loss) + + # Save the trained model and the tokenizer + if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): + logger.info("Saving model checkpoint to %s", args.output_dir) + # Save a trained model, configuration and tokenizer using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + # Take care of distributed/parallel training + model_to_save = model.module if hasattr(model, "module") else model + model_to_save.save_pretrained(args.output_dir) + tokenizer.save_pretrained(args.output_dir) + + # Good practice: save your training arguments together with the trained model + torch.save(args, os.path.join(args.output_dir, "training_args.bin")) + + # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory + results = {} + if args.do_eval and args.local_rank in [-1, 0]: + checkpoints = [args.output_dir] + if args.eval_all_checkpoints: + checkpoints = list( + os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True)) + ) + logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs + + logger.info("Evaluate the following checkpoints: %s", checkpoints) + for checkpoint in checkpoints: + global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else "" + model = mAutoModel[args.model_type].from_pretrained( + checkpoint, + logic_lambda=args.logic_lambda, + m=args.max_num_questions, + prior=args.prior, + temperature=args.temperature + ) + model.to(args.device) + + # Evaluate + result = evaluate(args, data_processor, model, tokenizer, prefix=global_step) + result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items()) + results.update(result) + + print(results) + return results + + +if __name__ == "__main__": + main() diff --git a/src/check_client/utils.py b/src/check_client/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3340a99a76f51f8ca5ff893e861d1f9e45d11d --- /dev/null +++ b/src/check_client/utils.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- + +""" +@Author : Bao +@Date : 2020/8/12 +@Desc : +@Last modified by : Bao +@Last modified date : 2020/8/12 +""" + +import logging +from numpy.core.fromnumeric import argmax +import ujson as json +import torch +from plm_checkers.checker_utils import soft_logic + + +def init_logger(level, filename=None, mode='a', encoding='utf-8'): + logging_config = { + 'format': '%(asctime)s - %(levelname)s - %(name)s:\t%(message)s', + 'datefmt': '%Y-%m-%d %H:%M:%S', + 'level': level, + 'handlers': [logging.StreamHandler()] + } + if filename: + logging_config['handlers'].append(logging.FileHandler(filename, mode, encoding)) + logging.basicConfig(**logging_config) + + +def read_json(filename, mode='r', encoding='utf-8'): + with open(filename, mode, encoding=encoding) as fin: + return json.load(fin) + + +def save_json(data, filename, mode='w', encoding='utf-8'): + with open(filename, mode, encoding=encoding) as fout: + json.dump(data, fout, ensure_ascii=False, indent=4) + + +def read_json_lines(filename, mode='r', encoding='utf-8', skip=0): + with open(filename, mode, encoding=encoding) as fin: + for line in fin: + if skip > 0: + skip -= 1 + continue + yield json.loads(line) + + +def save_json_lines(data, filename, mode='w', encoding='utf-8', skip=0): + with open(filename, mode, encoding=encoding) as fout: + for line in data: + if skip > 0: + skip -= 1 + continue + print(json.dumps(line, ensure_ascii=False), file=fout) + + +def read_json_dict(filename, mode='r', encoding='utf-8'): + with open(filename, mode, encoding=encoding) as fin: + key_2_id = json.load(fin) + id_2_key = dict(zip(key_2_id.values(), key_2_id.keys())) + + return key_2_id, id_2_key + + +def save_json_dict(data, filename, mode='w', encoding='utf-8'): + with open(filename, mode, encoding=encoding) as fout: + json.dump(data, fout, ensure_ascii=False, indent=4) + + +# Calculate precision, recall and f1 value +# According to https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure +def get_prf(res): + if res['TP'] == 0: + if res['FP'] == 0 and res['FN'] == 0: + precision = 1.0 + recall = 1.0 + f1 = 1.0 + else: + precision = 0.0 + recall = 0.0 + f1 = 0.0 + else: + precision = 1.0 * res['TP'] / (res['TP'] + res['FP']) + recall = 1.0 * res['TP'] / (res['TP'] + res['FN']) + f1 = 2 * precision * recall / (precision + recall) + + return precision, recall, f1 + + +def compute_metrics(truth, predicted, z_predicted, mask): + assert len(truth) == len(predicted) + + outputs = [] + results = {} + cnt = 0 + z_cnt_h, z_cnt_s = 0, 0 + agree_h, agree_s = 0, 0 + for x, y, z, m in zip(truth, predicted, z_predicted, mask): + res = {'label': x, 'prediction': y} + if x == y: + cnt += 1 + + res['pred_z'] = z + + y_ = soft_logic(torch.tensor([z]), torch.tensor([m]))[0] + if y_.argmax(-1).item() == x: + z_cnt_s += 1 + if y_.argmax(-1).item() == y: + agree_s += 1 + + z_h = torch.tensor(z[:torch.tensor(m).sum()]).argmax(-1).tolist() # m' x 3 + if 0 in z_h: # REFUTES + y__ = 0 + elif 1 in z_h: # NEI + y__ = 1 + else: # SUPPPORTS + y__ = 2 + if y__ == x: + z_cnt_h += 1 + if y__ == y: + agree_h += 1 + + outputs.append(res) + + results['Accuracy'] = cnt / len(truth) + results['z_Acc_hard'] = z_cnt_h / len(truth) + results['z_Acc_soft'] = z_cnt_s / len(truth) + results['Agreement_hard'] = agree_h / len(truth) + results['Agreement_soft'] = agree_s / len(truth) + return outputs, results diff --git a/src/cjjpy.py b/src/cjjpy.py new file mode 100755 index 0000000000000000000000000000000000000000..2cc70b5e553924123810ab198c143bf7ee28e5d6 --- /dev/null +++ b/src/cjjpy.py @@ -0,0 +1,249 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2018/11/15 17:08 +@Contact: jjchen19@fudan.edu.cn +''' + +import re +import datetime +import os +import argparse +import logging +import traceback + +try: + import ujson as json +except: + import json + +HADOOP_BIN = 'PATH=/usr/bin/:$PATH hdfs' +FOR_PUBLIC = True + + +def LengthStats(filename): + len_list = [] + thresholds = [0.8, 0.9, 0.95, 0.99, 0.999] + with open(filename) as f: + for line in f: + len_list.append(len(line.strip().split())) + stats = { + 'Max': max(len_list), + 'Min': min(len_list), + 'Avg': round(sum(len_list) / len(len_list), 4), + } + len_list.sort() + for t in thresholds: + stats[f"Top-{t}"] = len_list[int(len(len_list) * t)] + + for k in stats: + print(f"- {k}: {stats[k]}") + return stats + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def TraceBack(error_msg): + exc = traceback.format_exc() + msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}' + return msg + + +def Now(): + return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def AbsParentDir(file, parent='..', postfix=None): + ppath = os.path.abspath(file) + parent_level = parent.count('.') + while parent_level > 0: + ppath = os.path.dirname(ppath) + parent_level -= 1 + if postfix is not None: + return os.path.join(ppath, postfix) + else: + return ppath + + +def init_logger(log_file=None, log_file_level=logging.NOTSET, from_scratch=False): + from coloredlogs import ColoredFormatter + import tensorflow as tf + + fmt = "[%(asctime)s %(levelname)s] %(message)s" + log_format = ColoredFormatter(fmt=fmt) + # log_format = logging.Formatter() + logger = logging.getLogger() + logger.setLevel(log_file_level) + + console_handler = logging.StreamHandler() + console_handler.setFormatter(log_format) + logger.handlers = [console_handler] + + if log_file and log_file != '': + if from_scratch and tf.io.gfile.exists(log_file): + logger.warning('Removing previous log file: %s' % log_file) + tf.io.gfile.remove(log_file) + path = os.path.dirname(log_file) + os.makedirs(path, exist_ok=True) + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(log_file_level) + file_handler.setFormatter(log_format) + logger.addHandler(file_handler) + + return logger + + +def OverWriteCjjPy(root='.'): + # import difflib + # diff = difflib.HtmlDiff() + cnt = 0 + golden_cjjpy = os.path.join(root, 'cjjpy.py') + # golden_content = open(golden_cjjpy).readlines() + for dir, folder, file in os.walk(root): + for f in file: + if f == 'cjjpy.py': + cjjpy = '%s/%s' % (dir, f) + # content = open(cjjpy).readlines() + # d = diff.make_file(golden_content, content) + cnt += 1 + print('[%d]: %s' % (cnt, cjjpy)) + os.system('cp %s %s' % (golden_cjjpy, cjjpy)) + + +def ChangeFileFormat(filename, new_fmt): + assert type(filename) is str and type(new_fmt) is str + spt = filename.split('.') + if len(spt) == 0: + return filename + else: + return filename.replace('.' + spt[-1], new_fmt) + + +def CountLines(fname): + with open(fname, 'rb') as f: + count = 0 + last_data = '\n' + while True: + data = f.read(0x400000) + if not data: + break + count += data.count(b'\n') + last_data = data + if last_data[-1:] != b'\n': + count += 1 # Remove this if a wc-like count is needed + return count + + +def GetDate(): + return str(datetime.datetime.now())[5:10].replace('-', '') + + +def TimeClock(seconds): + sec = int(seconds) + hour = int(sec / 3600) + min = int((sec - hour * 3600) / 60) + ssec = float(seconds) - hour * 3600 - min * 60 + # return '%dh %dm %.2fs' % (hour, min, ssec) + return '{0:>2d}h{1:>3d}m{2:>6.2f}s'.format(hour, min, ssec) + + +def StripAll(text): + return text.strip().replace('\t', '').replace('\n', '').replace(' ', '') + + +def GetBracket(text, bracket, en_br=False): + # input should be aa(bb)cc, True for bracket, False for text + if bracket: + try: + return re.findall('\((.*?)\)', text.strip())[-1] + except: + return '' + else: + if en_br: + text = re.sub('\(.*?\)', '', text.strip()) + return re.sub('(.*?)', '', text.strip()) + + +def CharLang(uchar, lang): + assert lang.lower() in ['en', 'cn', 'zh'] + if lang.lower() in ['cn', 'zh']: + if uchar >= '\u4e00' and uchar <= '\u9fa5': + return True + else: + return False + elif lang.lower() == 'en': + if (uchar <= 'Z' and uchar >= 'A') or (uchar <= 'z' and uchar >= 'a'): + return True + else: + return False + else: + raise NotImplementedError + + +def WordLang(word, lang): + for i in word.strip(): + if i.isspace(): continue + if not CharLang(i, lang): + return False + return True + + +def SortDict(_dict, reverse=True): + assert type(_dict) is dict + return sorted(_dict.items(), key=lambda d: d[1], reverse=reverse) + + +def lark(content='test'): + print(content) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--diff', nargs=2, + help='show difference between two files, shown in downloads/diff.html') + parser.add_argument('--de_unicode', action='store_true', default=False, + help='remove unicode characters') + parser.add_argument('--link_entity', action='store_true', default=False, + help='') + parser.add_argument('--max_comm_len', action='store_true', default=False, + help='') + parser.add_argument('--search', nargs=2, + help='search key from file, 2 args: file name & key') + parser.add_argument('--email', nargs=2, + help='sending emails, 2 args: subject & content') + parser.add_argument('--overwrite', action='store_true', default=None, + help='overwrite all cjjpy under given *dir* based on *dir*/cjjpy.py') + parser.add_argument('--replace', nargs=3, + help='replace char, 3 args: file name & replaced char & replacer char') + parser.add_argument('--lark', nargs=1) + parser.add_argument('--get_hdfs', nargs=2, + help='easy copy from hdfs to local fs, 2 args: remote_file/dir & local_dir') + parser.add_argument('--put_hdfs', nargs=2, + help='easy put from local fs to hdfs, 2 args: local_file/dir & remote_dir') + parser.add_argument('--length_stats', nargs=1, + help='simple token lengths distribution of a line-by-line file') + + args = parser.parse_args() + + if args.overwrite: + print('* Overwriting cjjpy...') + OverWriteCjjPy() + + if args.lark: + try: + content = args.lark[0] + except: + content = 'running complete' + print(f'* Larking "{content}"...') + lark(content) + + if args.length_stats: + file = args.length_stats[0] + print(f'* Working on {file} lengths statistics...') + LengthStats(file) diff --git a/src/dataloaders.py b/src/dataloaders.py new file mode 100644 index 0000000000000000000000000000000000000000..fc80da1c26a0a17ec881f408da8efe1c92aa3dbb --- /dev/null +++ b/src/dataloaders.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2020/7/20 17:34 +@Contact : jjchen19@fudan.edu.cn +@Description: +''' + +import tensorflow as tf +import cjjpy as cjj +import os +import re +import ujson as json +from collections import defaultdict + +pj_prefix = cjj.AbsParentDir(__file__, '..') + + +class FEVERLoader: + def __init__(self, role): + role = 'dev' if role == 'val' else role + assert role in ['train', 'dev', 'test', 'eval'] + self.role = role + self.fever_data = defaultdict(dict) + self.SUPPORTS = 'SUPPORTS' + self.REFUTES = 'REFUTES' + self.NEI = 'NOT ENOUGH INFO' + + def __iter__(self): + for k in self.fever_data: + yield k + + def __len__(self): + return len(self.fever_data) + + def __getitem__(self, item): + return self.fever_data[item] + + def load_fever(self, retrieve_type='bert', clean_load=True): + self._load_fever_golden() + self._load_fever_all() + self._load_fever_retrieved(retrieve_type, clean_load) + + def _load_json(self, fname): + with tf.io.gfile.GFile(fname) as f: + return [json.loads(x) for x in f.readlines()] + + def _new_role(self): + role = self.role if self.role != 'eval' else 'dev' + return role + + def _load_fever_golden(self): + if self.role == 'test': + postfix = f'data/fever/shared_task_test.jsonl' + for js in self._load_json(f'{pj_prefix}/{postfix}'): + self.fever_data[js['id']].update({ + 'id': js['id'], + 'claim': js['claim'] + }) + else: + role = self._new_role() + postfix = f'data/fever/baked_data/golden_{role}.json' + for js in self._load_json(f'{pj_prefix}/{postfix}'): + self.fever_data[js['id']].update({ + 'id': js['id'], + 'claim': js['claim'], + 'label': js['label'], + 'golden_evidence': self._clean_evidence(js['evidence']) + }) + print('* FEVER golden loaded.') + + def _load_fever_all(self): + role = self._new_role() + postfix = f'data/fever/baked_data/all_{role}.json' + for js in self._load_json(f'{pj_prefix}/{postfix}'): + self.fever_data[js['id']].update({ + 'all_evidence': self._clean_evidence(js['evidence']) + }) + print('* FEVER all loaded.') + + def _load_fever_retrieved(self, retrieve_type, clean_load): + assert retrieve_type in ['bert'] + postfix = f'data/fever/baked_data/{retrieve_type}_{self.role}.json' + for js in self._load_json(f'{pj_prefix}/{postfix}'): + self.fever_data[js['id']].update({ + f'{retrieve_type}_evidence': self._clean_evidence(js['evidence']) if clean_load else js['evidence'] + }) + print(f'* FEVER {retrieve_type} loaded.') + + def clean_text(self, sentence): + sentence = re.sub(" \-LSB\-.*?\-RSB\-", "", sentence) + sentence = re.sub("\-LRB\- \-RRB\- ", "", sentence) + sentence = re.sub(" -LRB-", " ( ", sentence) + sentence = re.sub("-RRB-", " )", sentence) + + sentence = re.sub(" LSB.*?RSB", "", sentence) + sentence = re.sub("LRB RRB ", "", sentence) + sentence = re.sub("LRB", " ( ", sentence) + sentence = re.sub("RRB", " )", sentence) + sentence = re.sub("--", "-", sentence) + sentence = re.sub("``", '"', sentence) + sentence = re.sub("''", '"', sentence) + sentence = re.sub(' ', ' ', sentence) + return sentence + + def clean_title(self, title): + title = re.sub("_", " ", title) + title = re.sub(" -LRB-", " ( ", title) + title = re.sub("-RRB-", " )", title) + title = re.sub("-COLON-", ":", title) + title = re.sub(' ', ' ', title) + return title + + def _clean_evidence(self, evidence): + cev = [] + for ev in evidence: + if len(ev) == 4: + cev.append([self.clean_title(ev[0]), ev[1], self.clean_text(ev[2]), ev[3]]) + elif len(ev) == 3: + cev.append([self.clean_title(ev[0]), ev[1], self.clean_text(ev[2])]) + elif len(ev) == 0: + cev.append(ev) + else: + raise ValueError(ev) + return cev + + +if __name__ == '__main__': + floader = FEVERLoader('test') + floader.load_fever('bert', clean_load=False) + for k in floader: + print(floader[k]) + input() diff --git a/src/er_client/__init__.py b/src/er_client/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1456c18a25f206b09431019c04c3871260f047fa --- /dev/null +++ b/src/er_client/__init__.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2020/9/21 16:13 +@Contact : jjchen19@fudan.edu.cn +@Description: +''' + +import cjjpy as cjj +import os +# from .document_retrieval import DocRetrieval +from .doc_retrieval_by_api import DocRetrieval +from .sentence_selection import SentSelector + + +arg_values = { + 'batch_size': 32, + 'dropout': 0.6, + 'use_cuda': True, + 'bert_hidden_dim': 768, + 'layer': 1, + 'num_labels': 3, + 'evi_num': 5, + 'threshold': 0.0, + 'max_len': 120, +} + +args = cjj.AttrDict(arg_values) + +class EvidenceRetrieval: + def __init__(self, er_model_dir=cjj.AbsParentDir(__file__, '...', 'models/evidence_retrieval/')): + # self.doc_retriever = DocRetrieval(cjj.AbsParentDir(__file__, '...', 'data/fever.db'), + # add_claim=True, k_wiki_results=7) + self.doc_retrieval = DocRetrieval(link_type='tagme') + self.sent_selector = SentSelector(os.path.join(er_model_dir, 'bert_base/'), + os.path.join(er_model_dir, 'retrieval_model/model.best.pt'), + args) + + def retrieve(self, claim): + # noun_phrases, wiki_results, predicted_pages = self.doc_retriever.exact_match(claim) + # evidence = [] + # for page in predicted_pages: + # evidence.extend(self.doc_retriever.db.get_doc_lines(page)) + evidence = self.doc_retrieval.retrieve_docs(claim) + evidence = self.rank_sentences(claim, evidence) + return evidence + + def rank_sentences(self, claim, sentences, id=None): + ''' + :param claim: str + :param sentences: [(ent, num, sent) * N] + :param id: + :return: [(ent, num, sent) * k] + ''' + if id is None: + id = len(claim) + + result = self.sent_selector.rank_sentences([{'claim': claim, + 'evidence': sentences, + 'id': id}]) + evidence = result.get(id, []) + return evidence \ No newline at end of file diff --git a/src/er_client/cjjpy.py b/src/er_client/cjjpy.py new file mode 100755 index 0000000000000000000000000000000000000000..2cc70b5e553924123810ab198c143bf7ee28e5d6 --- /dev/null +++ b/src/er_client/cjjpy.py @@ -0,0 +1,249 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2018/11/15 17:08 +@Contact: jjchen19@fudan.edu.cn +''' + +import re +import datetime +import os +import argparse +import logging +import traceback + +try: + import ujson as json +except: + import json + +HADOOP_BIN = 'PATH=/usr/bin/:$PATH hdfs' +FOR_PUBLIC = True + + +def LengthStats(filename): + len_list = [] + thresholds = [0.8, 0.9, 0.95, 0.99, 0.999] + with open(filename) as f: + for line in f: + len_list.append(len(line.strip().split())) + stats = { + 'Max': max(len_list), + 'Min': min(len_list), + 'Avg': round(sum(len_list) / len(len_list), 4), + } + len_list.sort() + for t in thresholds: + stats[f"Top-{t}"] = len_list[int(len(len_list) * t)] + + for k in stats: + print(f"- {k}: {stats[k]}") + return stats + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def TraceBack(error_msg): + exc = traceback.format_exc() + msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}' + return msg + + +def Now(): + return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def AbsParentDir(file, parent='..', postfix=None): + ppath = os.path.abspath(file) + parent_level = parent.count('.') + while parent_level > 0: + ppath = os.path.dirname(ppath) + parent_level -= 1 + if postfix is not None: + return os.path.join(ppath, postfix) + else: + return ppath + + +def init_logger(log_file=None, log_file_level=logging.NOTSET, from_scratch=False): + from coloredlogs import ColoredFormatter + import tensorflow as tf + + fmt = "[%(asctime)s %(levelname)s] %(message)s" + log_format = ColoredFormatter(fmt=fmt) + # log_format = logging.Formatter() + logger = logging.getLogger() + logger.setLevel(log_file_level) + + console_handler = logging.StreamHandler() + console_handler.setFormatter(log_format) + logger.handlers = [console_handler] + + if log_file and log_file != '': + if from_scratch and tf.io.gfile.exists(log_file): + logger.warning('Removing previous log file: %s' % log_file) + tf.io.gfile.remove(log_file) + path = os.path.dirname(log_file) + os.makedirs(path, exist_ok=True) + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(log_file_level) + file_handler.setFormatter(log_format) + logger.addHandler(file_handler) + + return logger + + +def OverWriteCjjPy(root='.'): + # import difflib + # diff = difflib.HtmlDiff() + cnt = 0 + golden_cjjpy = os.path.join(root, 'cjjpy.py') + # golden_content = open(golden_cjjpy).readlines() + for dir, folder, file in os.walk(root): + for f in file: + if f == 'cjjpy.py': + cjjpy = '%s/%s' % (dir, f) + # content = open(cjjpy).readlines() + # d = diff.make_file(golden_content, content) + cnt += 1 + print('[%d]: %s' % (cnt, cjjpy)) + os.system('cp %s %s' % (golden_cjjpy, cjjpy)) + + +def ChangeFileFormat(filename, new_fmt): + assert type(filename) is str and type(new_fmt) is str + spt = filename.split('.') + if len(spt) == 0: + return filename + else: + return filename.replace('.' + spt[-1], new_fmt) + + +def CountLines(fname): + with open(fname, 'rb') as f: + count = 0 + last_data = '\n' + while True: + data = f.read(0x400000) + if not data: + break + count += data.count(b'\n') + last_data = data + if last_data[-1:] != b'\n': + count += 1 # Remove this if a wc-like count is needed + return count + + +def GetDate(): + return str(datetime.datetime.now())[5:10].replace('-', '') + + +def TimeClock(seconds): + sec = int(seconds) + hour = int(sec / 3600) + min = int((sec - hour * 3600) / 60) + ssec = float(seconds) - hour * 3600 - min * 60 + # return '%dh %dm %.2fs' % (hour, min, ssec) + return '{0:>2d}h{1:>3d}m{2:>6.2f}s'.format(hour, min, ssec) + + +def StripAll(text): + return text.strip().replace('\t', '').replace('\n', '').replace(' ', '') + + +def GetBracket(text, bracket, en_br=False): + # input should be aa(bb)cc, True for bracket, False for text + if bracket: + try: + return re.findall('\((.*?)\)', text.strip())[-1] + except: + return '' + else: + if en_br: + text = re.sub('\(.*?\)', '', text.strip()) + return re.sub('(.*?)', '', text.strip()) + + +def CharLang(uchar, lang): + assert lang.lower() in ['en', 'cn', 'zh'] + if lang.lower() in ['cn', 'zh']: + if uchar >= '\u4e00' and uchar <= '\u9fa5': + return True + else: + return False + elif lang.lower() == 'en': + if (uchar <= 'Z' and uchar >= 'A') or (uchar <= 'z' and uchar >= 'a'): + return True + else: + return False + else: + raise NotImplementedError + + +def WordLang(word, lang): + for i in word.strip(): + if i.isspace(): continue + if not CharLang(i, lang): + return False + return True + + +def SortDict(_dict, reverse=True): + assert type(_dict) is dict + return sorted(_dict.items(), key=lambda d: d[1], reverse=reverse) + + +def lark(content='test'): + print(content) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--diff', nargs=2, + help='show difference between two files, shown in downloads/diff.html') + parser.add_argument('--de_unicode', action='store_true', default=False, + help='remove unicode characters') + parser.add_argument('--link_entity', action='store_true', default=False, + help='') + parser.add_argument('--max_comm_len', action='store_true', default=False, + help='') + parser.add_argument('--search', nargs=2, + help='search key from file, 2 args: file name & key') + parser.add_argument('--email', nargs=2, + help='sending emails, 2 args: subject & content') + parser.add_argument('--overwrite', action='store_true', default=None, + help='overwrite all cjjpy under given *dir* based on *dir*/cjjpy.py') + parser.add_argument('--replace', nargs=3, + help='replace char, 3 args: file name & replaced char & replacer char') + parser.add_argument('--lark', nargs=1) + parser.add_argument('--get_hdfs', nargs=2, + help='easy copy from hdfs to local fs, 2 args: remote_file/dir & local_dir') + parser.add_argument('--put_hdfs', nargs=2, + help='easy put from local fs to hdfs, 2 args: local_file/dir & remote_dir') + parser.add_argument('--length_stats', nargs=1, + help='simple token lengths distribution of a line-by-line file') + + args = parser.parse_args() + + if args.overwrite: + print('* Overwriting cjjpy...') + OverWriteCjjPy() + + if args.lark: + try: + content = args.lark[0] + except: + content = 'running complete' + print(f'* Larking "{content}"...') + lark(content) + + if args.length_stats: + file = args.length_stats[0] + print(f'* Working on {file} lengths statistics...') + LengthStats(file) diff --git a/src/er_client/doc_retrieval_by_api.py b/src/er_client/doc_retrieval_by_api.py new file mode 100644 index 0000000000000000000000000000000000000000..fa4055ef0f3b8273d9b4eae0e57364961658756c --- /dev/null +++ b/src/er_client/doc_retrieval_by_api.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2020/11/12 21:19 +@Contact : jjchen19@fudan.edu.cn +@Description: +''' + +import wikipediaapi +import nltk +from nltk.tokenize import sent_tokenize +nltk.download('punkt') +try: + from entitylinker import ELClient +except: + from .entitylinker import ELClient + + +class DocRetrieval: + def __init__(self, link_type): + self.wiki = wikipediaapi.Wikipedia('en') + self.er_client = ELClient(link_type, verbose=True) + + def _get_page(self, title): + summary = self.wiki.page(title).summary + sents = [] + for i, sent in enumerate(sent_tokenize(summary)): + sents.append((title, i, sent, 0)) + return sents + + def retrieve_docs(self, claim): + el_results = self.er_client.link(claim) + sents = [] + for text, label, kb_id, title in el_results: + if title == '': continue + sents += self._get_page(title) + return sents + + +if __name__ == '__main__': + doc = DocRetrieval('tagme') + print(doc.retrieve_docs('joe biden won the U.S. president.')) + print(doc.retrieve_docs('Joe Biden won the U.S. president.')) \ No newline at end of file diff --git a/src/er_client/document_retrieval.py b/src/er_client/document_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..3b83fb17cb4f72a72c13151d3320eb5a5ff1da4f --- /dev/null +++ b/src/er_client/document_retrieval.py @@ -0,0 +1,225 @@ +# -*- coding:utf-8 -*- + +""" +@Author : Bao +@Date : 2020/9/17 +@Desc : Document selection and sentence ranking code from KGAT. Not used in LOREN. +@Last modified by : Bao +@Last modified date : 2020/9/17 +""" + +import re +import time +import json +import nltk +from tqdm import tqdm +from allennlp.predictors import Predictor +from drqa.retriever import DocDB, utils +from drqa.retriever.utils import normalize +import wikipedia + + +class FeverDocDB(DocDB): + def __init__(self, path=None): + super().__init__(path) + + def get_doc_lines(self, doc_id): + """Fetch the raw text of the doc for 'doc_id'.""" + cursor = self.connection.cursor() + cursor.execute( + "SELECT lines FROM documents WHERE id = ?", + (utils.normalize(doc_id),) + ) + result = cursor.fetchone() + cursor.close() + + result = result[0] if result is not None else '' + doc_lines = [] + for line in result.split('\n'): + if len(line) == 0: continue + line = line.split('\t')[1] + if len(line) == 0: continue + doc_lines.append((doc_id, len(doc_lines), line, 0)) + + return doc_lines + + def get_non_empty_doc_ids(self): + """Fetch all ids of docs stored in the db.""" + cursor = self.connection.cursor() + cursor.execute("SELECT id FROM documents WHERE length(trim(text)) > 0") + results = [r[0] for r in cursor.fetchall()] + cursor.close() + return results + + +class DocRetrieval: + def __init__(self, database_path, add_claim=False, k_wiki_results=None): + self.db = FeverDocDB(database_path) + self.add_claim = add_claim + self.k_wiki_results = k_wiki_results + self.porter_stemmer = nltk.PorterStemmer() + self.tokenizer = nltk.word_tokenize + self.predictor = Predictor.from_path( + "https://storage.googleapis.com/allennlp-public-models/elmo-constituency-parser-2020.02.10.tar.gz" + ) + + def get_NP(self, tree, nps): + if isinstance(tree, dict): + if "children" not in tree: + if tree['nodeType'] == "NP": + # print(tree['word']) + # print(tree) + nps.append(tree['word']) + elif "children" in tree: + if tree['nodeType'] == "NP": + # print(tree['word']) + nps.append(tree['word']) + self.get_NP(tree['children'], nps) + else: + self.get_NP(tree['children'], nps) + elif isinstance(tree, list): + for sub_tree in tree: + self.get_NP(sub_tree, nps) + + return nps + + def get_subjects(self, tree): + subject_words = [] + subjects = [] + for subtree in tree['children']: + if subtree['nodeType'] == "VP" or subtree['nodeType'] == 'S' or subtree['nodeType'] == 'VBZ': + subjects.append(' '.join(subject_words)) + subject_words.append(subtree['word']) + else: + subject_words.append(subtree['word']) + return subjects + + def get_noun_phrases(self, claim): + tokens = self.predictor.predict(claim) + nps = [] + tree = tokens['hierplane_tree']['root'] + noun_phrases = self.get_NP(tree, nps) + subjects = self.get_subjects(tree) + for subject in subjects: + if len(subject) > 0: + noun_phrases.append(subject) + if self.add_claim: + noun_phrases.append(claim) + return list(set(noun_phrases)) + + def get_doc_for_claim(self, noun_phrases): + predicted_pages = [] + for np in noun_phrases: + if len(np) > 300: + continue + i = 1 + while i < 12: + try: + # print(np) + # res = server.lookup(np, keep_all=True) + # docs = [y for _, y in res] if res is not None else [] + docs = wikipedia.search(np) + if self.k_wiki_results is not None: + predicted_pages.extend(docs[:self.k_wiki_results]) + else: + predicted_pages.extend(docs) + except (ConnectionResetError, ConnectionError, ConnectionAbortedError, ConnectionRefusedError): + print("Connection reset error received! Trial #" + str(i)) + time.sleep(600 * i) + i += 1 + else: + break + + # sleep_num = random.uniform(0.1,0.7) + # time.sleep(sleep_num) + predicted_pages = set(predicted_pages) + processed_pages = [] + for page in predicted_pages: + page = page.replace(" ", "_") + page = page.replace("(", "-LRB-") + page = page.replace(")", "-RRB-") + page = page.replace(":", "-COLON-") + processed_pages.append(page) + + return processed_pages + + def np_conc(self, noun_phrases): + noun_phrases = set(noun_phrases) + predicted_pages = [] + for np in noun_phrases: + page = np.replace('( ', '-LRB-') + page = page.replace(' )', '-RRB-') + page = page.replace(' - ', '-') + page = page.replace(' :', '-COLON-') + page = page.replace(' ,', ',') + page = page.replace(" 's", "'s") + page = page.replace(' ', '_') + + if len(page) < 1: + continue + doc_lines = self.db.get_doc_lines(page) + if len(doc_lines) > 0: + predicted_pages.append(page) + return predicted_pages + + def exact_match(self, claim): + noun_phrases = self.get_noun_phrases(claim) + wiki_results = self.get_doc_for_claim(noun_phrases) + wiki_results = list(set(wiki_results)) + + claim = claim.replace(".", "") + claim = claim.replace("-", " ") + words = [self.porter_stemmer.stem(word.lower()) for word in self.tokenizer(claim)] + words = set(words) + predicted_pages = self.np_conc(noun_phrases) + + for page in wiki_results: + page = normalize(page) + processed_page = re.sub("-LRB-.*?-RRB-", "", page) + processed_page = re.sub("_", " ", processed_page) + processed_page = re.sub("-COLON-", ":", processed_page) + processed_page = processed_page.replace("-", " ") + processed_page = processed_page.replace("–", " ") + processed_page = processed_page.replace(".", "") + page_words = [self.porter_stemmer.stem(word.lower()) for word in self.tokenizer(processed_page) if + len(word) > 0] + + if all([item in words for item in page_words]): + if ':' in page: + page = page.replace(":", "-COLON-") + predicted_pages.append(page) + predicted_pages = list(set(predicted_pages)) + + return noun_phrases, wiki_results, predicted_pages + + +def save_to_file(results, client, filename): + with open(filename, 'w', encoding='utf-8') as fout: + for _id, line in enumerate(results): + claim = line['claim'] + evidence = [] + for page in line['predicted_pages']: + evidence.extend(client.db.get_doc_lines(page)) + print(json.dumps({'claim': claim, 'evidence': evidence}, ensure_ascii=False), file=fout) + + +if __name__ == '__main__': + database_path = 'data/fever.db' + add_claim = True + k_wiki_results = 7 + client = DocRetrieval(database_path, add_claim, k_wiki_results) + + results = [] + with open('data/claims.json', 'r', encoding='utf-8') as fin: + for line in tqdm(fin): + line = json.loads(line) + _, _, predicted_pages = client.exact_match(line['claim']) + evidence = [] + for page in predicted_pages: + evidence.extend(client.db.get_doc_lines(page)) + line['evidence'] = evidence + results.append(line) + + with open('data/pages.json', 'w', encoding='utf-8') as fout: + for line in results: + print(json.dumps(line, ensure_ascii=False), file=fout) diff --git a/src/er_client/entitylinker.py b/src/er_client/entitylinker.py new file mode 100644 index 0000000000000000000000000000000000000000..4ce2ca9289471fb7e62277046c42efe6eb1f2a03 --- /dev/null +++ b/src/er_client/entitylinker.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2020/5/11 19:08 +@Contact : jjchen19@fudan.edu.cn +@Description: +''' + +import os +import tagme + + +def read_title_id(entity_def_path): + id_to_title = {} + with open(entity_def_path, 'r', encoding='UTF-8') as f: + lines = f.readlines() + for i, line in enumerate(lines): + if i > 0: + entity, id = line.strip().split('|') + id_to_title[id] = entity + + return id_to_title + + +class ELClient: + def __init__(self, link_type, min_rho=0.1, prefix=None, verbose=False): + self.verbose = verbose + self.link_type = link_type + if link_type == 'tagme': + self.min_rho = min_rho + tagme.GCUBE_TOKEN = os.environ['TAGME_APIKEY'] + elif link_type == 'spacy': + assert prefix is not None + self.init_spacy_linker(prefix) + else: + raise NotImplementedError(link_type) + + def init_spacy_linker(self, prefix): + entity_def_path = f"{prefix}/entity_defs.csv" + self._print('* Loading entity linker...') + self.nlp = spacy.load(prefix) + self.id2title = read_title_id(entity_def_path) + self._print('* Entity linker loaded.') + + def _tagme_link(self, text): + result = [] + for ann in tagme.annotate(text, long_text=1).get_annotations(min_rho=self.min_rho): + result.append((text[ann.begin:ann.end], ann.score, ann.entity_id, ann.entity_title)) + # result.append({'begin': ann.begin, + # 'end': ann.end, + # 'id': ann.entity_id, + # 'title': ann.entity_title, + # 'score': ann.score}) + result.sort(key=lambda x: x[1], reverse=True) + return result + + def link(self, text): + if self.link_type == 'tagme': + return self._tagme_link(text) + else: + return self._spacy_link(text) + + def _spacy_link(self, text): + text = self._preprocess_text(text) + doc = self.nlp(text) + ents = [(e.text, e.label_, e.kb_id_, self.id2title.get(e.kb_id_, '')) + for e in doc.ents if e.kb_id_ != 'NIL'] + return ents + + def _preprocess_text(self, text): + if isinstance(text, list): + text = ' '.join(text) + text = text.strip().replace('-lrb-', '(').replace('-rrb-', ')') + return text + + def _print(self, x): + if self.verbose: print(x) + + +if __name__ == '__main__': + elcl = ELClient(link_type='tagme', verbose=True) + res = elcl.link('Jeff Dean wants to meet Yoshua Bengio.') + print(res) diff --git a/src/er_client/retrieval_model/bert_model.py b/src/er_client/retrieval_model/bert_model.py new file mode 100755 index 0000000000000000000000000000000000000000..006173823070eee5fb9bc41814d6b274d1c13a50 --- /dev/null +++ b/src/er_client/retrieval_model/bert_model.py @@ -0,0 +1,775 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import copy +import json +import logging +import math +import os +import shutil +import tarfile +import tempfile +import sys +from io import open + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from .file_utils import cached_path + +logger = logging.getLogger(__name__) + +PRETRAINED_MODEL_ARCHIVE_MAP = { + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", + 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", +} +CONFIG_NAME = 'bert_config.json' +WEIGHTS_NAME = 'pytorch_model.bin' +TF_WEIGHTS_NAME = 'model.ckpt' + +def load_tf_weights_in_bert(model, tf_checkpoint_path): + """ Load tf checkpoints in a pytorch model + """ + try: + import re + import numpy as np + import tensorflow as tf + except ImportError: + print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions.") + raise + tf_path = os.path.abspath(tf_checkpoint_path) + print("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + print("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split('/') + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any(n in ["adam_v", "adam_m"] for n in name): + print("Skipping {}".format("/".join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r'[A-Za-z]+_\d+', m_name): + l = re.split(r'_(\d+)', m_name) + else: + l = [m_name] + if l[0] == 'kernel' or l[0] == 'gamma': + pointer = getattr(pointer, 'weight') + elif l[0] == 'output_bias' or l[0] == 'beta': + pointer = getattr(pointer, 'bias') + elif l[0] == 'output_weights': + pointer = getattr(pointer, 'weight') + else: + pointer = getattr(pointer, l[0]) + if len(l) >= 2: + num = int(l[1]) + pointer = pointer[num] + if m_name[-11:] == '_embeddings': + pointer = getattr(pointer, 'weight') + elif m_name == 'kernel': + array = np.transpose(array) + try: + assert pointer.shape == array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + print("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + return model + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class BertConfig(object): + """Configuration class to store the configuration of a `BertModel`. + """ + def __init__(self, + vocab_size_or_config_json_file, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02): + """Constructs BertConfig. + + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. + hidden_size: Size of the encoder layers and the pooler layer. + num_hidden_layers: Number of hidden layers in the Transformer encoder. + num_attention_heads: Number of attention heads for each attention layer in + the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported. + hidden_dropout_prob: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed into + `BertModel`. + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. + """ + if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 + and isinstance(vocab_size_or_config_json_file, unicode)): + with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + else: + raise ValueError("First argument must be either a vocabulary size (int)" + "or the path to a pretrained model config file (str)") + + @classmethod + def from_dict(cls, json_object): + """Constructs a `BertConfig` from a Python dictionary of parameters.""" + config = BertConfig(vocab_size_or_config_json_file=-1) + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BertConfig` from a json file of parameters.""" + with open(json_file, "r", encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + +try: + from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm +except ImportError: + print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.") + class BertLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). + """ + super(BertLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + def __init__(self, config): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None): + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + + def forward(self, input_tensor, attention_mask): + self_output = self.self(input_tensor, attention_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config): + super(BertLayer, self).__init__() + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states, attention_mask): + attention_output = self.attention(hidden_states, attention_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super(BertEncoder, self).__init__() + layer = BertLayer(config) + self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): + all_encoder_layers = [] + for layer_module in self.layer: + hidden_states = layer_module(hidden_states, attention_mask) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + if not output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BertPooler(nn.Module): + def __init__(self, config): + super(BertPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super(BertPredictionHeadTransform, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertLMPredictionHead, self).__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(bert_model_embedding_weights.size(1), + bert_model_embedding_weights.size(0), + bias=False) + self.decoder.weight = bert_model_embedding_weights + self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + self.bias + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertOnlyMLMHead, self).__init__() + self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + def __init__(self, config): + super(BertOnlyNSPHead, self).__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertPreTrainingHeads, self).__init__() + self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(nn.Module): + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + def __init__(self, config, *inputs, **kwargs): + super(BertPreTrainedModel, self).__init__() + if not isinstance(config, BertConfig): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " + "To create a model from a Google pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + )) + self.config = config + + def init_bert_weights(self, module): + """ Initialize the weights. + """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None, + from_tf=False, *inputs, **kwargs): + """ + Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. + Download and cache the pre-trained model file if needed. + + Params: + pretrained_model_name_or_path: either: + - a str with the name of a pre-trained model to load selected in the list of: + . `bert-base-uncased` + . `bert-large-uncased` + . `bert-base-cased` + . `bert-large-cased` + . `bert-base-multilingual-uncased` + . `bert-base-multilingual-cased` + . `bert-base-chinese` + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `model.chkpt` a TensorFlow checkpoint + from_tf: should we load the weights from a locally saved TensorFlow checkpoint + cache_dir: an optional path to a folder in which the pre-trained models will be cached. + state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models + *inputs, **kwargs: additional input for the specific Bert class + (ex: num_labels for BertForSequenceClassification) + """ + if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: + archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] + else: + archive_file = pretrained_model_name_or_path + # redirect to the cache, if necessary + try: + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + except EnvironmentError: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file " + "associated to this path or url.".format( + pretrained_model_name_or_path, + ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), + archive_file)) + return None + if resolved_archive_file == archive_file: + logger.info("loading archive file {}".format(archive_file)) + else: + logger.info("loading archive file {} from cache at {}".format( + archive_file, resolved_archive_file)) + tempdir = None + if os.path.isdir(resolved_archive_file) or from_tf: + serialization_dir = resolved_archive_file + else: + # Extract archive to temp dir + tempdir = tempfile.mkdtemp() + logger.info("extracting archive file {} to temp dir {}".format( + resolved_archive_file, tempdir)) + with tarfile.open(resolved_archive_file, 'r:gz') as archive: + archive.extractall(tempdir) + serialization_dir = tempdir + # Load config + config_file = os.path.join(serialization_dir, CONFIG_NAME) + config = BertConfig.from_json_file(config_file) + logger.info("Model config {}".format(config)) + # Instantiate model. + model = cls(config, *inputs, **kwargs) + if state_dict is None and not from_tf: + weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) + state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None) + if tempdir: + # Clean up temp dir + shutil.rmtree(tempdir) + if from_tf: + # Directly load from a TensorFlow checkpoint + weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) + return load_tf_weights_in_bert(model, weights_path) + # Load from a PyTorch state_dict + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + start_prefix = '' + if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()): + start_prefix = 'bert.' + load(model, prefix=start_prefix) + if len(missing_keys) > 0: + logger.info("Weights of {} not initialized from pretrained model: {}".format( + model.__class__.__name__, missing_keys)) + if len(unexpected_keys) > 0: + logger.info("Weights from pretrained model not used in {}: {}".format( + model.__class__.__name__, unexpected_keys)) + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + model.__class__.__name__, "\n\t".join(error_msgs))) + return model + + +class BertModel(BertPreTrainedModel): + """BERT model ("Bidirectional Embedding Representations from a Transformer"). + + Params: + config: a BertConfig class instance with the configuration to build a new model + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. + + Outputs: Tuple of (encoded_layers, pooled_output) + `encoded_layers`: controled by `output_all_encoded_layers` argument: + - `output_all_encoded_layers=True`: output a list of the full sequences of encoded-hidden-states at the end + of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each + encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], + - `output_all_encoded_layers=False`: output only the full sequence of hidden-states corresponding + to the last attention block of shape [batch_size, sequence_length, hidden_size], + `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a + classifier pretrained on top of the hidden state associated to the first character of the + input (`CLS`) to train on the Next-Sentence task (see BERT's paper). + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = modeling.BertModel(config=config) + all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertModel, self).__init__(config) + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings(input_ids, token_type_ids) + encoded_layers = self.encoder(embedding_output, + extended_attention_mask, + output_all_encoded_layers=output_all_encoded_layers) + sequence_output = encoded_layers[-1] + pooled_output = self.pooler(sequence_output) + if not output_all_encoded_layers: + encoded_layers = encoded_layers[-1] + return encoded_layers, pooled_output + + + + + +class BertForSequenceEncoder(BertPreTrainedModel): + """BERT model for classification. + This module is composed of the BERT model with a linear layer on top of + the pooled output. + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `num_labels`: the number of classes for the classifier. Default = 2. + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] + with indices selected in [0, ..., num_labels]. + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. + if `labels` is `None`: + Outputs the classification logits of shape [batch_size, num_labels]. + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + num_labels = 2 + model = BertForSequenceClassification(config, num_labels) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertForSequenceEncoder, self).__init__(config) + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, attention_mask, token_type_ids): + output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + output = self.dropout(output) + pooled_output = self.dropout(pooled_output) + return output, pooled_output + + diff --git a/src/er_client/retrieval_model/data_loader.py b/src/er_client/retrieval_model/data_loader.py new file mode 100755 index 0000000000000000000000000000000000000000..84189fb07fc70f8eba7c64c6fbb7e4d1fe91d0d8 --- /dev/null +++ b/src/er_client/retrieval_model/data_loader.py @@ -0,0 +1,276 @@ +import os +import torch +import numpy as np +import json +import re +from torch.autograd import Variable + + +def _truncate_seq_pair(tokens_a, tokens_b, max_length): + """Truncates a sequence pair in place to the maximum length.""" + + # This is a simple heuristic which will always truncate the longer sequence + # one token at a time. This makes more sense than truncating an equal percent + # of tokens from each, since if one sequence is very short then each token + # that's truncated likely contains more information than a longer sequence. + while True: + total_length = len(tokens_a) + len(tokens_b) + if total_length <= max_length: + break + if len(tokens_a) > len(tokens_b): + tokens_a.pop() + else: + tokens_b.pop() + + +def tok2int_sent(sentence, tokenizer, max_seq_length): + """Loads a data file into a list of `InputBatch`s.""" + sent_a, sent_b = sentence + tokens_a = tokenizer.tokenize(sent_a) + + tokens_b = None + if sent_b: + tokens_b = tokenizer.tokenize(sent_b) + _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) + else: + # Account for [CLS] and [SEP] with "- 2" + if len(tokens_a) > max_seq_length - 2: + tokens_a = tokens_a[:(max_seq_length - 2)] + + tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + segment_ids = [0] * len(tokens) + if tokens_b: + tokens = tokens + tokens_b + ["[SEP]"] + segment_ids += [1] * (len(tokens_b) + 1) + input_ids = tokenizer.convert_tokens_to_ids(tokens) + input_mask = [1] * len(input_ids) + padding = [0] * (max_seq_length - len(input_ids)) + + input_ids += padding + input_mask += padding + segment_ids += padding + + assert len(input_ids) == max_seq_length + assert len(input_mask) == max_seq_length + assert len(segment_ids) == max_seq_length + + return input_ids, input_mask, segment_ids + + +def tok2int_list(src_list, tokenizer, max_seq_length, max_seq_size=-1): + inp_padding = list() + msk_padding = list() + seg_padding = list() + for step, sent in enumerate(src_list): + input_ids, input_mask, input_seg = tok2int_sent(sent, tokenizer, max_seq_length) + inp_padding.append(input_ids) + msk_padding.append(input_mask) + seg_padding.append(input_seg) + # if max_seq_size != -1: + # inp_padding = inp_padding[:max_seq_size] + # msk_padding = msk_padding[:max_seq_size] + # seg_padding = seg_padding[:max_seq_size] + # inp_padding += ([[0] * max_seq_length] * (max_seq_size - len(inp_padding))) + # msk_padding += ([[0] * max_seq_length] * (max_seq_size - len(msk_padding))) + # seg_padding += ([[0] * max_seq_length] * (max_seq_size - len(seg_padding))) + return inp_padding, msk_padding, seg_padding + + +class DataLoader(object): + ''' For data iteration ''' + + def __init__(self, data_path, tokenizer, args, test=False, cuda=True, batch_size=64): + self.cuda = cuda + + self.batch_size = batch_size + self.tokenizer = tokenizer + self.max_len = args.max_len + self.evi_num = args.evi_num + self.threshold = args.threshold + self.data_path = data_path + self.test = test + examples = self.read_file(data_path) + self.examples = examples + self.total_num = len(examples) + if self.test: + self.total_num = 100000 + self.total_step = np.ceil(self.total_num * 1.0 / batch_size) + self.shuffle() + else: + self.total_step = self.total_num / batch_size + self.shuffle() + self.step = 0 + + def process_sent(self, sentence): + sentence = re.sub(" \-LSB\-.*?\-RSB\-", "", sentence) + sentence = re.sub("\-LRB\- \-RRB\- ", "", sentence) + sentence = re.sub(" -LRB-", " ( ", sentence) + sentence = re.sub("-RRB-", " )", sentence) + sentence = re.sub("--", "-", sentence) + sentence = re.sub("``", '"', sentence) + sentence = re.sub("''", '"', sentence) + + return sentence + + def process_wiki_title(self, title): + title = re.sub("_", " ", title) + title = re.sub(" -LRB-", " ( ", title) + title = re.sub("-RRB-", " )", title) + title = re.sub("-COLON-", ":", title) + return title + + def read_file(self, data_path): + examples = list() + with open(data_path) as fin: + for step, line in enumerate(fin): + sublines = line.strip().split("\t") + examples.append( + [self.process_sent(sublines[0]), self.process_sent(sublines[2]), self.process_sent(sublines[4])]) + return examples + + def shuffle(self): + np.random.shuffle(self.examples) + + def __iter__(self): + return self + + def __next__(self): + return self.next() + + def __len__(self): + return self._n_batch + + def next(self): + ''' Get the next batch ''' + if self.step < self.total_step: + examples = self.examples[self.step * self.batch_size: (self.step + 1) * self.batch_size] + pos_inputs = list() + neg_inputs = list() + for example in examples: + pos_inputs.append([example[0], example[1]]) + neg_inputs.append([example[0], example[2]]) + inp_pos, msk_pos, seg_pos = tok2int_list(pos_inputs, self.tokenizer, self.max_len) + inp_neg, msk_neg, seg_neg = tok2int_list(neg_inputs, self.tokenizer, self.max_len) + + inp_tensor_pos = Variable( + torch.LongTensor(inp_pos)) + msk_tensor_pos = Variable( + torch.LongTensor(msk_pos)) + seg_tensor_pos = Variable( + torch.LongTensor(seg_pos)) + inp_tensor_neg = Variable( + torch.LongTensor(inp_neg)) + msk_tensor_neg = Variable( + torch.LongTensor(msk_neg)) + seg_tensor_neg = Variable( + torch.LongTensor(seg_neg)) + + if self.cuda: + inp_tensor_pos = inp_tensor_pos.cuda() + msk_tensor_pos = msk_tensor_pos.cuda() + seg_tensor_pos = seg_tensor_pos.cuda() + inp_tensor_neg = inp_tensor_neg.cuda() + msk_tensor_neg = msk_tensor_neg.cuda() + seg_tensor_neg = seg_tensor_neg.cuda() + self.step += 1 + return inp_tensor_pos, msk_tensor_pos, seg_tensor_pos, inp_tensor_neg, msk_tensor_neg, seg_tensor_neg + else: + self.step = 0 + if not self.test: + # examples = self.read_file(self.data_path) + # self.examples = examples + self.shuffle() + raise StopIteration() + + +class DataLoaderTest(object): + ''' For data iteration ''' + + def __init__(self, data_path, tokenizer, args, cuda=True, batch_size=64): + self.cuda = cuda + + self.batch_size = batch_size + self.tokenizer = tokenizer + self.max_len = args.max_len + self.evi_num = args.evi_num + self.threshold = args.threshold + self.data_path = data_path + inputs, ids, evi_list = self.read_all(data_path) + self.inputs = inputs + self.ids = ids + self.evi_list = evi_list + + self.total_num = len(inputs) + self.total_step = np.ceil(self.total_num * 1.0 / batch_size) + self.step = 0 + + def process_sent(self, sentence): + sentence = re.sub(" \-LSB\-.*?\-RSB\-", "", sentence) + sentence = re.sub("\-LRB\- \-RRB\- ", "", sentence) + sentence = re.sub(" -LRB-", " ( ", sentence) + sentence = re.sub("-RRB-", " )", sentence) + sentence = re.sub("--", "-", sentence) + sentence = re.sub("``", '"', sentence) + sentence = re.sub("''", '"', sentence) + + return sentence + + def process_wiki_title(self, title): + title = re.sub("_", " ", title) + title = re.sub(" -LRB-", " ( ", title) + title = re.sub("-RRB-", " )", title) + title = re.sub("-COLON-", ":", title) + return title + + def read_all(self, data): + if not isinstance(data, list): + with open(data) as f: + data_ = [json.loads(line) for line in f] + else: + data_ = data + inputs = list() + ids = list() + evi_list = list() + for instance in data_: + claim = instance['claim'] + id = instance['id'] + for evidence in instance['evidence']: + ids.append(id) + inputs.append([self.process_sent(claim), self.process_sent(evidence[2])]) + evi_list.append(evidence) + return inputs, ids, evi_list + + def shuffle(self): + np.random.shuffle(self.examples) + + def __iter__(self): + return self + + def __next__(self): + return self.next() + + def __len__(self): + return self._n_batch + + def next(self): + ''' Get the next batch ''' + if self.step < self.total_step: + inputs = self.inputs[self.step * self.batch_size: (self.step + 1) * self.batch_size] + ids = self.ids[self.step * self.batch_size: (self.step + 1) * self.batch_size] + evi_list = self.evi_list[self.step * self.batch_size: (self.step + 1) * self.batch_size] + inp, msk, seg = tok2int_list(inputs, self.tokenizer, self.max_len, -1) + inp_tensor_input = Variable( + torch.LongTensor(inp)) + msk_tensor_input = Variable( + torch.LongTensor(msk)) + seg_tensor_input = Variable( + torch.LongTensor(seg)) + if self.cuda: + inp_tensor_input = inp_tensor_input.cuda() + msk_tensor_input = msk_tensor_input.cuda() + seg_tensor_input = seg_tensor_input.cuda() + self.step += 1 + return inp_tensor_input, msk_tensor_input, seg_tensor_input, ids, evi_list + else: + self.step = 0 + raise StopIteration() diff --git a/src/er_client/retrieval_model/file_utils.py b/src/er_client/retrieval_model/file_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..c323146a11b6c50e916b4221f0b973ae81c69d56 --- /dev/null +++ b/src/er_client/retrieval_model/file_utils.py @@ -0,0 +1,249 @@ +""" +Utilities for working with the local dataset cache. +This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp +Copyright by the AllenNLP authors. +""" +from __future__ import (absolute_import, division, print_function, unicode_literals) + +import json +import logging +import os +import shutil +import tempfile +from functools import wraps +from hashlib import sha256 +import sys +from io import open + +import boto3 +import requests +from botocore.exceptions import ClientError +from tqdm import tqdm + +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse + +try: + from pathlib import Path + PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', + Path.home() / '.pytorch_pretrained_bert')) +except AttributeError: + PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', + os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def url_to_filename(url, etag=None): + """ + Convert `url` into a hashed filename in a repeatable way. + If `etag` is specified, append its hash to the url's, delimited + by a period. + """ + url_bytes = url.encode('utf-8') + url_hash = sha256(url_bytes) + filename = url_hash.hexdigest() + + if etag: + etag_bytes = etag.encode('utf-8') + etag_hash = sha256(etag_bytes) + filename += '.' + etag_hash.hexdigest() + + return filename + + +def filename_to_url(filename, cache_dir=None): + """ + Return the url and etag (which may be ``None``) stored for `filename`. + Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + cache_path = os.path.join(cache_dir, filename) + if not os.path.exists(cache_path): + raise EnvironmentError("file {} not found".format(cache_path)) + + meta_path = cache_path + '.json' + if not os.path.exists(meta_path): + raise EnvironmentError("file {} not found".format(meta_path)) + + with open(meta_path, encoding="utf-8") as meta_file: + metadata = json.load(meta_file) + url = metadata['url'] + etag = metadata['etag'] + + return url, etag + + +def cached_path(url_or_filename, cache_dir=None): + """ + Given something that might be a URL (or might be a local path), + determine which. If it's a URL, download the file and cache it, and + return the path to the cached file. If it's already a local path, + make sure the file exists and then return the path. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + parsed = urlparse(url_or_filename) + + if parsed.scheme in ('http', 'https', 's3'): + # URL, so get it from the cache (downloading if necessary) + return get_from_cache(url_or_filename, cache_dir) + elif os.path.exists(url_or_filename): + # File, and it exists. + return url_or_filename + elif parsed.scheme == '': + # File, but it doesn't exist. + raise EnvironmentError("file {} not found".format(url_or_filename)) + else: + # Something unknown + raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) + + +def split_s3_path(url): + """Split a full s3 path into the bucket name and path.""" + parsed = urlparse(url) + if not parsed.netloc or not parsed.path: + raise ValueError("bad s3 path {}".format(url)) + bucket_name = parsed.netloc + s3_path = parsed.path + # Remove '/' at beginning of path. + if s3_path.startswith("/"): + s3_path = s3_path[1:] + return bucket_name, s3_path + + +def s3_request(func): + """ + Wrapper function for s3 requests in order to create more helpful error + messages. + """ + + @wraps(func) + def wrapper(url, *args, **kwargs): + try: + return func(url, *args, **kwargs) + except ClientError as exc: + if int(exc.response["Error"]["Code"]) == 404: + raise EnvironmentError("file {} not found".format(url)) + else: + raise + + return wrapper + + +@s3_request +def s3_etag(url): + """Check ETag on S3 object.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_object = s3_resource.Object(bucket_name, s3_path) + return s3_object.e_tag + + +@s3_request +def s3_get(url, temp_file): + """Pull a file directly from S3.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) + + +def http_get(url, temp_file): + req = requests.get(url, stream=True) + content_length = req.headers.get('Content-Length') + total = int(content_length) if content_length is not None else None + progress = tqdm(unit="B", total=total) + for chunk in req.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def get_from_cache(url, cache_dir=None): + """ + Given a URL, look for the corresponding dataset in the local cache. + If it's not there, download it. Then return the path to the cached file. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + # Get eTag to add to filename, if it exists. + if url.startswith("s3://"): + etag = s3_etag(url) + else: + response = requests.head(url, allow_redirects=True) + if response.status_code != 200: + raise IOError("HEAD request failed for url {} with status code {}" + .format(url, response.status_code)) + etag = response.headers.get("ETag") + + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + if not os.path.exists(cache_path): + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. + with tempfile.NamedTemporaryFile() as temp_file: + logger.info("%s not found in cache, downloading to %s", url, temp_file.name) + + # GET file object + if url.startswith("s3://"): + s3_get(url, temp_file) + else: + http_get(url, temp_file) + + # we are copying the file before closing it, so flush to avoid truncation + temp_file.flush() + # shutil.copyfileobj() starts at the current position, so go to the start + temp_file.seek(0) + + logger.info("copying %s to cache at %s", temp_file.name, cache_path) + with open(cache_path, 'wb') as cache_file: + shutil.copyfileobj(temp_file, cache_file) + + logger.info("creating metadata file for %s", cache_path) + meta = {'url': url, 'etag': etag} + meta_path = cache_path + '.json' + with open(meta_path, 'w', encoding="utf-8") as meta_file: + json.dump(meta, meta_file) + + logger.info("removing temp file %s", temp_file.name) + + return cache_path + + +def read_set_from_file(filename): + ''' + Extract a de-duped collection (set) of text from a file. + Expected file format is one item per line. + ''' + collection = set() + with open(filename, 'r', encoding='utf-8') as file_: + for line in file_: + collection.add(line.rstrip()) + return collection + + +def get_file_extension(path, dot=True, lower=True): + ext = os.path.splitext(path)[1] + ext = ext if dot else ext[1:] + return ext.lower() if lower else ext \ No newline at end of file diff --git a/src/er_client/retrieval_model/models.py b/src/er_client/retrieval_model/models.py new file mode 100755 index 0000000000000000000000000000000000000000..ad8658a37011aa113a75967d38b2a4c6f9a295ca --- /dev/null +++ b/src/er_client/retrieval_model/models.py @@ -0,0 +1,66 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import BatchNorm1d, Linear, ReLU +from .bert_model import BertForSequenceEncoder + +from torch.nn import BatchNorm1d, Linear, ReLU +from .bert_model import BertForSequenceEncoder +from torch.autograd import Variable +import numpy as np + + + + +def kernal_mus(n_kernels): + """ + get the mu for each guassian kernel. Mu is the middle of each bin + :param n_kernels: number of kernels (including exact match). first one is exact match + :return: l_mu, a list of mu. + """ + l_mu = [1] + if n_kernels == 1: + return l_mu + + bin_size = 2.0 / (n_kernels - 1) # score range from [-1, 1] + l_mu.append(1 - bin_size / 2) # mu: middle of the bin + for i in range(1, n_kernels - 1): + l_mu.append(l_mu[i] - bin_size) + return l_mu + + +def kernel_sigmas(n_kernels): + """ + get sigmas for each guassian kernel. + :param n_kernels: number of kernels (including exactmath.) + :param lamb: + :param use_exact: + :return: l_sigma, a list of simga + """ + bin_size = 2.0 / (n_kernels - 1) + l_sigma = [0.001] # for exact match. small variance -> exact match + if n_kernels == 1: + return l_sigma + + l_sigma += [0.1] * (n_kernels - 1) + return l_sigma + +class inference_model(nn.Module): + def __init__(self, bert_model, args): + super(inference_model, self).__init__() + self.bert_hidden_dim = args.bert_hidden_dim + self.dropout = nn.Dropout(args.dropout) + self.max_len = args.max_len + self.num_labels = args.num_labels + self.pred_model = bert_model + #self.proj_hidden = nn.Linear(self.bert_hidden_dim, 128) + self.proj_match = nn.Linear(self.bert_hidden_dim, 1) + + + def forward(self, inp_tensor, msk_tensor, seg_tensor): + _, inputs = self.pred_model(inp_tensor, msk_tensor, seg_tensor) + inputs = self.dropout(inputs) + score = self.proj_match(inputs).squeeze(-1) + score = torch.tanh(score) + return score \ No newline at end of file diff --git a/src/er_client/retrieval_model/process_data.py b/src/er_client/retrieval_model/process_data.py new file mode 100755 index 0000000000000000000000000000000000000000..7115d7800645cc1f2832a455f8e97a77010af4eb --- /dev/null +++ b/src/er_client/retrieval_model/process_data.py @@ -0,0 +1,41 @@ +import json +import os +import argparse + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--gold_file') + parser.add_argument('--retrieval_file') + parser.add_argument('--output') + parser.add_argument('--test', action='store_true', default=False) + args = parser.parse_args() + filter_dict = dict() + data_dict = dict() + golden_dict = dict() + with open(args.gold_file) as f: + for line in f: + data = json.loads(line) + data_dict[data["id"]] = {"id": data["id"], "evidence":[], "claim": data["claim"]} + if "label" in data: + data_dict[data["id"]]["label"] = data["label"] + if not args.test: + for evidence in data["evidence"]: + data_dict[data["id"]]["evidence"].append([evidence[0], evidence[1], evidence[2], 1.0]) + string = str(data["id"]) + "_" + evidence[0] + "_" + str(evidence[1]) + golden_dict[string] = 1 + with open(args.retrieval_file) as f: + for line in f: + data = json.loads(line) + for step, evidence in enumerate(data["evidence"]): + string = str(data["id"]) + "_" + str(evidence[0]) + "_" + str(evidence[1]) + if string not in golden_dict and string not in filter_dict: + data_dict[data["id"]]["evidence"].append([evidence[0], evidence[1], evidence[2], evidence[4]]) + filter_dict[string] = 1 + with open(args.output, "w") as out: + for data in data_dict.values(): + evidence_tmp = data["evidence"] + evidence_tmp = sorted(evidence_tmp, key=lambda x:x[3], reverse=True) + data["evidence"] = evidence_tmp[:5] + out.write(json.dumps(data) + "\n") + + diff --git a/src/er_client/retrieval_model/test.py b/src/er_client/retrieval_model/test.py new file mode 100755 index 0000000000000000000000000000000000000000..b7c109457a7f1b4b4704bd2334d3e35aab5cacc7 --- /dev/null +++ b/src/er_client/retrieval_model/test.py @@ -0,0 +1,81 @@ +import logging +import argparse +import os +import json +import torch +from tqdm import tqdm +from transformers import BertTokenizer + +from .models import inference_model +from .data_loader import DataLoaderTest +from .bert_model import BertForSequenceEncoder + +logger = logging.getLogger(__name__) + + +def save_to_file(all_predict, outpath, evi_num): + with open(outpath, "w") as out: + for key, values in all_predict.items(): + sorted_values = sorted(values, key=lambda x:x[-1], reverse=True) + data = json.dumps({"id": key, "evidence": sorted_values[:evi_num]}) + out.write(data + "\n") + + +def eval_model(model, validset_reader): + model.eval() + all_predict = dict() + for inp_tensor, msk_tensor, seg_tensor, ids, evi_list in tqdm(validset_reader): + probs = model(inp_tensor, msk_tensor, seg_tensor) + probs = probs.tolist() + assert len(probs) == len(evi_list) + for i in range(len(probs)): + if ids[i] not in all_predict: + all_predict[ids[i]] = [] + #if probs[i][1] >= probs[i][0]: + all_predict[ids[i]].append(evi_list[i] + [probs[i]]) + return all_predict + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--test_path', help='train path') + parser.add_argument('--name', help='train path') + parser.add_argument("--batch_size", default=32, type=int, help="Total batch size for training.") + parser.add_argument('--outdir', required=True, help='path to output directory') + parser.add_argument('--bert_pretrain', required=True) + parser.add_argument('--checkpoint', required=True) + parser.add_argument('--dropout', type=float, default=0.6, help='Dropout.') + parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.') + parser.add_argument("--bert_hidden_dim", default=768, type=int, help="Total batch size for training.") + parser.add_argument("--layer", type=int, default=1, help='Graph Layer.') + parser.add_argument("--num_labels", type=int, default=3) + parser.add_argument("--evi_num", type=int, default=5, help='Evidence num.') + parser.add_argument("--threshold", type=float, default=0.0, help='Evidence num.') + parser.add_argument("--max_len", default=120, type=int, + help="The maximum total input sequence length after WordPiece tokenization. Sequences " + "longer than this will be truncated, and sequences shorter than this will be padded.") + args = parser.parse_args() + + if not os.path.exists(args.outdir): + os.mkdir(args.outdir) + args.cuda = not args.no_cuda and torch.cuda.is_available() + handlers = [logging.FileHandler(os.path.abspath(args.outdir) + '/train_log.txt'), logging.StreamHandler()] + logging.basicConfig(format='[%(asctime)s] %(levelname)s: %(message)s', level=logging.DEBUG, + datefmt='%d-%m-%Y %H:%M:%S', handlers=handlers) + logger.info(args) + logger.info('Start training!') + + tokenizer = BertTokenizer.from_pretrained(args.bert_pretrain, do_lower_case=False) + logger.info("loading training set") + validset_reader = DataLoaderTest(args.test_path, tokenizer, args, batch_size=args.batch_size) + + logger.info('initializing estimator model') + bert_model = BertForSequenceEncoder.from_pretrained(args.bert_pretrain) + bert_model = bert_model.cuda() + model = inference_model(bert_model, args) + model.load_state_dict(torch.load(args.checkpoint)['model']) + model = model.cuda() + logger.info('Start eval!') + save_path = args.outdir + "/" + args.name + predict_dict = eval_model(model, validset_reader) + save_to_file(predict_dict, save_path, args.evi_num) \ No newline at end of file diff --git a/src/er_client/retrieval_model/test.sh b/src/er_client/retrieval_model/test.sh new file mode 100755 index 0000000000000000000000000000000000000000..b0129f74c72df97d14970f96e22a9152bd802aca --- /dev/null +++ b/src/er_client/retrieval_model/test.sh @@ -0,0 +1,7 @@ +python test.py \ +--test_path ../data/pages.json \ +--bert_pretrain ../evidence_retrieval/bert_base \ +--checkpoint ../evidence_retrieval/retrieval_model/model.best.pt \ +--evi_num 5 \ +--outdir ../data \ +--name evidence.json diff --git a/src/er_client/sentence_selection.py b/src/er_client/sentence_selection.py new file mode 100644 index 0000000000000000000000000000000000000000..9d4a71a69fee416fccff99735213770b2d697d92 --- /dev/null +++ b/src/er_client/sentence_selection.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2020/9/20 11:42 +@Contact : jjchen19@fudan.edu.cn +@Description: +''' + +import torch +from transformers import BertTokenizer +from .retrieval_model.bert_model import BertForSequenceEncoder +from .retrieval_model.models import inference_model +from .retrieval_model.data_loader import DataLoaderTest + + +class SentSelector: + def __init__(self, pretrained_bert_path, select_model_path, args): + self.args = args + self.use_cuda = self.args.use_cuda and torch.cuda.is_available() + + self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + self.bert_model = BertForSequenceEncoder.from_pretrained(pretrained_bert_path) + + self.rank_model = inference_model(self.bert_model, self.args) + self.rank_model.load_state_dict(torch.load(select_model_path)['model']) + + if self.use_cuda: + self.bert_model = self.bert_model.cuda() + self.rank_model.cuda() + + def rank_sentences(self, js: list): + ''' + :param js: [{'claim': xxx, 'id': xx, 'evidence': xxx}] + :return: [(ent, num, sent, prob), (ent, num, sent, prob)] + ''' + data_reader = DataLoaderTest(js, self.tokenizer, self.args, self.use_cuda) + self.rank_model.eval() + all_predict = dict() + for inp_tensor, msk_tensor, seg_tensor, ids, evi_list in data_reader: + probs = self.rank_model(inp_tensor, msk_tensor, seg_tensor) + probs = probs.tolist() + assert len(probs) == len(evi_list) + for i in range(len(probs)): + if ids[i] not in all_predict: + all_predict[ids[i]] = [] + # if probs[i][1] >= probs[i][0]: + all_predict[ids[i]].append(tuple(evi_list[i]) + (probs[i],)) + + results = {} + for k, v in all_predict.items(): + sorted_v = sorted(v, key=lambda x: x[-1], reverse=True) + results[k] = sorted_v[:self.args.evi_num] + return results diff --git a/src/eval_client/cjjpy.py b/src/eval_client/cjjpy.py new file mode 100755 index 0000000000000000000000000000000000000000..2cc70b5e553924123810ab198c143bf7ee28e5d6 --- /dev/null +++ b/src/eval_client/cjjpy.py @@ -0,0 +1,249 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2018/11/15 17:08 +@Contact: jjchen19@fudan.edu.cn +''' + +import re +import datetime +import os +import argparse +import logging +import traceback + +try: + import ujson as json +except: + import json + +HADOOP_BIN = 'PATH=/usr/bin/:$PATH hdfs' +FOR_PUBLIC = True + + +def LengthStats(filename): + len_list = [] + thresholds = [0.8, 0.9, 0.95, 0.99, 0.999] + with open(filename) as f: + for line in f: + len_list.append(len(line.strip().split())) + stats = { + 'Max': max(len_list), + 'Min': min(len_list), + 'Avg': round(sum(len_list) / len(len_list), 4), + } + len_list.sort() + for t in thresholds: + stats[f"Top-{t}"] = len_list[int(len(len_list) * t)] + + for k in stats: + print(f"- {k}: {stats[k]}") + return stats + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def TraceBack(error_msg): + exc = traceback.format_exc() + msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}' + return msg + + +def Now(): + return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def AbsParentDir(file, parent='..', postfix=None): + ppath = os.path.abspath(file) + parent_level = parent.count('.') + while parent_level > 0: + ppath = os.path.dirname(ppath) + parent_level -= 1 + if postfix is not None: + return os.path.join(ppath, postfix) + else: + return ppath + + +def init_logger(log_file=None, log_file_level=logging.NOTSET, from_scratch=False): + from coloredlogs import ColoredFormatter + import tensorflow as tf + + fmt = "[%(asctime)s %(levelname)s] %(message)s" + log_format = ColoredFormatter(fmt=fmt) + # log_format = logging.Formatter() + logger = logging.getLogger() + logger.setLevel(log_file_level) + + console_handler = logging.StreamHandler() + console_handler.setFormatter(log_format) + logger.handlers = [console_handler] + + if log_file and log_file != '': + if from_scratch and tf.io.gfile.exists(log_file): + logger.warning('Removing previous log file: %s' % log_file) + tf.io.gfile.remove(log_file) + path = os.path.dirname(log_file) + os.makedirs(path, exist_ok=True) + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(log_file_level) + file_handler.setFormatter(log_format) + logger.addHandler(file_handler) + + return logger + + +def OverWriteCjjPy(root='.'): + # import difflib + # diff = difflib.HtmlDiff() + cnt = 0 + golden_cjjpy = os.path.join(root, 'cjjpy.py') + # golden_content = open(golden_cjjpy).readlines() + for dir, folder, file in os.walk(root): + for f in file: + if f == 'cjjpy.py': + cjjpy = '%s/%s' % (dir, f) + # content = open(cjjpy).readlines() + # d = diff.make_file(golden_content, content) + cnt += 1 + print('[%d]: %s' % (cnt, cjjpy)) + os.system('cp %s %s' % (golden_cjjpy, cjjpy)) + + +def ChangeFileFormat(filename, new_fmt): + assert type(filename) is str and type(new_fmt) is str + spt = filename.split('.') + if len(spt) == 0: + return filename + else: + return filename.replace('.' + spt[-1], new_fmt) + + +def CountLines(fname): + with open(fname, 'rb') as f: + count = 0 + last_data = '\n' + while True: + data = f.read(0x400000) + if not data: + break + count += data.count(b'\n') + last_data = data + if last_data[-1:] != b'\n': + count += 1 # Remove this if a wc-like count is needed + return count + + +def GetDate(): + return str(datetime.datetime.now())[5:10].replace('-', '') + + +def TimeClock(seconds): + sec = int(seconds) + hour = int(sec / 3600) + min = int((sec - hour * 3600) / 60) + ssec = float(seconds) - hour * 3600 - min * 60 + # return '%dh %dm %.2fs' % (hour, min, ssec) + return '{0:>2d}h{1:>3d}m{2:>6.2f}s'.format(hour, min, ssec) + + +def StripAll(text): + return text.strip().replace('\t', '').replace('\n', '').replace(' ', '') + + +def GetBracket(text, bracket, en_br=False): + # input should be aa(bb)cc, True for bracket, False for text + if bracket: + try: + return re.findall('\((.*?)\)', text.strip())[-1] + except: + return '' + else: + if en_br: + text = re.sub('\(.*?\)', '', text.strip()) + return re.sub('(.*?)', '', text.strip()) + + +def CharLang(uchar, lang): + assert lang.lower() in ['en', 'cn', 'zh'] + if lang.lower() in ['cn', 'zh']: + if uchar >= '\u4e00' and uchar <= '\u9fa5': + return True + else: + return False + elif lang.lower() == 'en': + if (uchar <= 'Z' and uchar >= 'A') or (uchar <= 'z' and uchar >= 'a'): + return True + else: + return False + else: + raise NotImplementedError + + +def WordLang(word, lang): + for i in word.strip(): + if i.isspace(): continue + if not CharLang(i, lang): + return False + return True + + +def SortDict(_dict, reverse=True): + assert type(_dict) is dict + return sorted(_dict.items(), key=lambda d: d[1], reverse=reverse) + + +def lark(content='test'): + print(content) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--diff', nargs=2, + help='show difference between two files, shown in downloads/diff.html') + parser.add_argument('--de_unicode', action='store_true', default=False, + help='remove unicode characters') + parser.add_argument('--link_entity', action='store_true', default=False, + help='') + parser.add_argument('--max_comm_len', action='store_true', default=False, + help='') + parser.add_argument('--search', nargs=2, + help='search key from file, 2 args: file name & key') + parser.add_argument('--email', nargs=2, + help='sending emails, 2 args: subject & content') + parser.add_argument('--overwrite', action='store_true', default=None, + help='overwrite all cjjpy under given *dir* based on *dir*/cjjpy.py') + parser.add_argument('--replace', nargs=3, + help='replace char, 3 args: file name & replaced char & replacer char') + parser.add_argument('--lark', nargs=1) + parser.add_argument('--get_hdfs', nargs=2, + help='easy copy from hdfs to local fs, 2 args: remote_file/dir & local_dir') + parser.add_argument('--put_hdfs', nargs=2, + help='easy put from local fs to hdfs, 2 args: local_file/dir & remote_dir') + parser.add_argument('--length_stats', nargs=1, + help='simple token lengths distribution of a line-by-line file') + + args = parser.parse_args() + + if args.overwrite: + print('* Overwriting cjjpy...') + OverWriteCjjPy() + + if args.lark: + try: + content = args.lark[0] + except: + content = 'running complete' + print(f'* Larking "{content}"...') + lark(content) + + if args.length_stats: + file = args.length_stats[0] + print(f'* Working on {file} lengths statistics...') + LengthStats(file) diff --git a/src/eval_client/culpa.py b/src/eval_client/culpa.py new file mode 100644 index 0000000000000000000000000000000000000000..41fa4d9b876a05061cd79d1ebe7a2aec194fac3f --- /dev/null +++ b/src/eval_client/culpa.py @@ -0,0 +1,61 @@ +# -*- coding:utf-8 -*- + +""" +@Author : Bao +@Date : 2021/9/7 +@Desc : +@Last modified by : Bao +@Last modified date : 2021/9/7 +""" + +import json +import numpy as np +import argparse +from collections import defaultdict +from sklearn.metrics import precision_recall_fscore_support + +# ref --> label 1, nei & sup --> label 0 +idx2label = {0: 1, 1: 0, 2: 0} + + +def read_json_lines(filename, mode='r', encoding='utf-8', skip=0): + with open(filename, mode, encoding=encoding) as fin: + for line in fin: + if skip > 0: + skip -= 1 + continue + yield json.loads(line) + + +def process(filein): + id2info = defaultdict(dict) + for line in read_json_lines('eval.human.ref.merged.json'): + labels = [0] * len(line['questions']) + for cul in line['culprit']: + labels[cul] = 1 + id2info[line['id']].update({'id': line['id'], 'labels': labels}) + + for line in read_json_lines(filein): + if line['id'] not in id2info: continue + predicted = [idx2label[_] for _ in np.argmax(line['z_prob'], axis=-1)] + id2info[line['id']]['predicted'] = predicted + + ps, rs, fs = [], [], [] + for info in id2info.values(): + p, r, f, _ = precision_recall_fscore_support(info['labels'], info['predicted'], average='binary') + ps.append(p) + rs.append(r) + fs.append(f) + print(filein) + print('Precision: {}'.format(sum(ps) / len(ps))) + print('Recall: {}'.format(sum(rs) / len(rs))) + print('F1: {}'.format(sum(fs) / len(fs))) + + return sum(ps) / len(ps), sum(rs) / len(rs), sum(fs) / len(fs) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-i', type=str, help='predicted jsonl file with phrasal veracity predictions.') + args = parser.parse_args() + process(args.i) diff --git a/src/eval_client/culprit/eval.human.ref.json b/src/eval_client/culprit/eval.human.ref.json new file mode 100644 index 0000000000000000000000000000000000000000..a51c29d792668abacb74532e58a2ef3497206f1d --- /dev/null +++ b/src/eval_client/culprit/eval.human.ref.json @@ -0,0 +1,100 @@ +{"id": 102600, "claim": "Sausage Party was released in May of 2016 .", "questions": ["What was the name of the new album released in May 2016? or was released in May of 2016 .", "When was Sausage Party released? or Sausage Party was released in of 2016 .", "When was Sausage Party released? or Sausage Party was released in May of .", "What was Sausage Party's release date? or Sausage Party was in May of 2016 ."], "answers": [["Sausage Party", 0, 13], ["May", 30, 33], ["2016", 37, 41], ["released", 18, 26]], "evidential": [["Sausage Party", "The Sausage Party", "A Sausage Party", "Sausage party"], ["August", "the summer", "March", "the fall"], ["2016", "the year 2016", "March 2016", "2015"], ["released", "announced", "premiered", "released domestically"]], "culprit": [1]} +{"id": 92833, "claim": "Anne Boleyn did not live in England in 1522 .", "questions": ["Who did not live in England in 1522? or did not live in England in 1522 .", "Where did Anne Boleyn live in 1522? or Anne Boleyn did not live in in 1522 .", "When did Anne Boleyn not live in England? or Anne Boleyn did not live in England in .", "What did Anne Boleyn not do in England? or Anne Boleyn did not in England in 1522 ."], "answers": [["Anne Boleyn", 0, 11], ["England", 28, 35], ["1522", 39, 43], ["live", 20, 24]], "evidential": [["Anne Boleyn", "Ann Boleyn", "Anne Bolyn", "A woman"], ["England", "Europe", "the world", "the UK"], ["1532", "1536", "1533", "1534"], ["live", "stay", "marry", "reside"]], "culprit": [1, 2]} +{"id": 159707, "claim": "Edgar Wright is only a producer .", "questions": ["Who is the only producer? or is only a producer .", "What is Edgar Wright's job title? or Edgar Wright is ."], "answers": [["Edgar Wright", 0, 12], ["only a producer", 16, 31]], "evidential": [["Edgar Wright", "Edgar Wright Jr.", "Edgar W. Wright", "Edgar Wayne Wright"], ["a producer", "a director", "a screenwriter", "a film producer"]], "culprit": [1]} +{"id": 146055, "claim": "The Giver is a bill .", "questions": ["What is a bill called? or is a bill .", "What is the Giver? or The Giver is ."], "answers": [["The Giver", 0, 9], ["a bill", 13, 19]], "evidential": [["The Giver", "A The Giver", "The giver", "The Giver Act"], ["a film", "a work", "a motion picture", "a movie"]], "culprit": [1]} +{"id": 8443, "claim": "A Milli is by Justin Bieber .", "questions": ["What is the name of Justin Bieber's song? or is by Justin Bieber .", "Who is A Milli by? or A Milli is by ."], "answers": [["A Milli", 0, 7], ["Justin Bieber", 14, 27]], "evidential": [["A Milli", "A Milli song", "A Milli Song", "A Milli."], ["Justin Bieber", "a Justin Bieber", "an artist", "a musician"]], "culprit": [1]} +{"id": 67833, "claim": "Shane McMahon did not win the Hardcore Championship once .", "questions": ["Who won the Hardcore Championship once? or did not win the Hardcore Championship once .", "What did Shane McMahon not win once? or Shane McMahon did not win once .", "What did Shane McMahon not do once? or Shane McMahon did not the Hardcore Championship once ."], "answers": [["Shane McMahon", 0, 13], ["the Hardcore Championship", 26, 51], ["win", 22, 25]], "evidential": [["Shane McMahon", "Shane McMahon", "Shane McMah", "Shane McMahon ("], ["the European Championship", "the Hardcore Championship", "a wrestling championship", "a championship"], ["win", "won", "achieve", "earn"]], "culprit": [1]} +{"id": 116789, "claim": "Minor League Baseball is a hierarchy of only amateur baseball leagues .", "questions": ["What is the name of the only amateur baseball league? or is a hierarchy of only amateur baseball leagues .", "What is Minor League Baseball? or Minor League Baseball is of only amateur baseball leagues .", "What is Minor League Baseball a hierarchy of? or Minor League Baseball is a hierarchy of ."], "answers": [["Minor League Baseball", 0, 21], ["a hierarchy", 25, 36], ["only amateur baseball leagues", 40, 69]], "evidential": [["Minor League Baseball", "The Minor League Baseball", "Minor league Baseball", "Major League Baseball"], ["a hierarchy", "an organization", "a system", "a structure"], ["professional baseball leagues", "minor league baseball", "professional baseball teams", "baseball leagues"]], "culprit": [2]} +{"id": 12454, "claim": "Tangled is a silent film .", "questions": ["What is the name of the film that is a silent film? or is a silent film .", "What type of film is Tangled? or Tangled is ."], "answers": [["Tangled", 0, 7], ["a silent film", 11, 24]], "evidential": [["Tangled", "Tangles", "Tangled (", "Tangling"], ["an animated film", "a musical fantasy film", "a fantasy film", "a film"]], "culprit": [1]} +{"id": 149501, "claim": "Kung Fu Panda was number three at the box office .", "questions": ["What movie was number three at the box office? or was number three at the box office .", "What was Kung Fu Panda's box office number? or Kung Fu Panda was at the box office .", "Where was Kung Fu Panda number three? or Kung Fu Panda was number three at ."], "answers": [["Kung Fu Panda", 0, 13], ["number three", 18, 30], ["the box office", 34, 48]], "evidential": [["Kung Fu Panda", "Kung fu Panda", "Kung F Panda", "Kungfu Panda"], ["the number one", "number one", "the number one movie", "the number one film"], ["the box office", "the movie box office", "the US box office", "a box office"]], "culprit": [1]} +{"id": 51962, "claim": "Mandy Moore is a Canadian film actress .", "questions": ["Who is the name of the Canadian film actress? or is a Canadian film actress .", "What nationality is Mandy Moore? or Mandy Moore is film actress .", "What is Mandy Moore's career? or Mandy Moore is a Canadian ."], "answers": [["Mandy Moore", 0, 11], ["a Canadian", 15, 25], ["film actress", 26, 38]], "evidential": [["Mandy Moore", "Mandy Moore ( choreographer )", "Mandy Moore ( dancer )", "Mandy Moore( choreographer )"], ["an American", "an american", "a North American", "an North American"], ["actress", "film actress", "actor", "singer"]], "culprit": [1]} +{"id": 217102, "claim": "Innovation is viewed as the application of better solutions that negate market needs .", "questions": ["What is viewed as the application of better solutions that negate market needs? or is viewed as the application of better solutions that negate market needs .", "Innovation is viewed as what? or Innovation is viewed as of better solutions that negate market needs .", "Innovation is viewed as the application of what? or Innovation is viewed as the application of that negate market needs .", "Innovation is viewed as the application of better solutions that negate what? or Innovation is viewed as the application of better solutions that negate .", "What is innovation as? or Innovation is as the application of better solutions that negate market needs .", "Innovation is viewed as the application of better solutions that do what to market needs? or Innovation is viewed as the application of better solutions that market needs ."], "answers": [["Innovation", 0, 10], ["the application", 24, 39], ["better solutions", 43, 59], ["market needs", 72, 84], ["viewed", 14, 20], ["negate", 65, 71]], "evidential": [["Innovation", "Technology innovation", "Insulin", "In innovation"], ["the application", "an application", "a application", "the applications"], ["solutions", "new solutions", "better solutions", "products"], ["new requirements", "existing market needs", "existing market requirements", "existing requirements"], ["viewed", "perceived", "characterized", "described"], ["meet", "meet existing", "meet current", "met"]], "culprit": [5]} +{"id": 202314, "claim": "The New Jersey Turnpike has zero shoulders .", "questions": ["What has zero shoulders? or has zero shoulders .", "What is the total length of the New Jersey Turnpike? or The New Jersey Turnpike has ."], "answers": [["The New Jersey Turnpike", 0, 23], ["zero shoulders", 28, 42]], "evidential": [["The New Jersey Turnpike", "New Jersey Turnpike", "A New Jersey Turnpike", "the New Jersey Turnpike"], ["12 ft lanes", "a total length", "12 feet long", "12 feet"]], "culprit": [1]} +{"id": 226106, "claim": "Bongwater is set outside of Oregon .", "questions": ["What is the name of the town outside of Oregon? or is set outside of Oregon .", "What state is Bongwater located outside of? or Bongwater is set outside of .", "Where is Bongwater located outside of Oregon? or Bongwater is outside of Oregon .", "Where is Bongwater located? or Bongwater is set of Oregon ."], "answers": [["Bongwater", 0, 9], ["Oregon", 28, 34], ["set", 13, 16], ["outside", 17, 24]], "evidential": [["Bongwater", "The film Bongwater", "Bongwwater", "Bongswater"], ["Oregon", "a state", "Washington State", "the Oregon"], ["set", "located", "filmed", "based"], ["the state", "outside", "the city", "the coast"]], "culprit": [3]} +{"id": 182051, "claim": "The Fly was first released in 1999 .", "questions": ["What was the name of the first film released in 1999? or was first released in 1999 .", "When was The Fly first released? or The Fly was first released in .", "When was The Fly first ? or The Fly was first in 1999 .", "When was The Fly released? or The Fly was released in 1999 ."], "answers": [["The Fly", 0, 7], ["1999", 30, 34], ["released", 18, 26], ["first", 12, 17]], "evidential": [["The Fly", "The Fly 's", "A film The Fly", "The fly"], ["August 1986", "1986", "the 1980s", "the eighties"], ["released", "published", "distributed", "release"], ["first", "originally", "last", "only"]], "culprit": [1]} +{"id": 65598, "claim": "Uganda was not ruled by the British .", "questions": ["What country was not ruled by the British? or was not ruled by the British .", "Who ruled Uganda? or Uganda was not ruled by .", "What was Uganda not by the British? or Uganda was not by the British ."], "answers": [["Uganda", 0, 6], ["the British", 24, 35], ["ruled", 15, 20]], "evidential": [["Uganda", "Uganda", "Ugandan", "Uganda"], ["the British", "Britain", "a colony", "British"], ["ruled", "controlled", "governed", "owned"]], "culprit": [1]} +{"id": 117126, "claim": "Pocahontas was not the daughter of Powhatan .", "questions": ["Who was not the daughter of Powhatan? or was not the daughter of Powhatan .", "What was Pocahontas' mother's name? or Pocahontas was not of Powhatan .", "Who was Pocahontas' father? or Pocahontas was not the daughter of ."], "answers": [["Pocahontas", 0, 10], ["the daughter", 19, 31], ["Powhatan", 35, 43]], "evidential": [["Pocahontas", "Pocahonta", "Pocahontas n't", "Pocahontas Jr."], ["the daughter", "a daughter", "the granddaughter", "the child"], ["Powhatan", "a Native American", "a chief", "a person"]], "culprit": [1, 2]} +{"id": 164506, "claim": "The Nobel Prize in Chemistry was awarded to a person from anywhere except the Netherlands .", "questions": ["What award was given to a person from anywhere except the Netherlands? or was awarded to a person from anywhere except the Netherlands .", "Who was the Nobel Prize in Chemistry awarded to? or The Nobel Prize in Chemistry was awarded to from anywhere except the Netherlands .", "Where was the Nobel Prize in Chemistry awarded to? or The Nobel Prize in Chemistry was awarded to a person from except the Netherlands .", "Where was the Nobel Prize in Chemistry awarded to? or The Nobel Prize in Chemistry was awarded to a person from anywhere except .", "How is the Nobel Prize in Chemistry ? or The Nobel Prize in Chemistry was to a person from anywhere except the Netherlands ."], "answers": [["The Nobel Prize in Chemistry", 0, 28], ["a person", 44, 52], ["anywhere", 58, 66], ["the Netherlands", 74, 89], ["awarded", 33, 40]], "evidential": [["The Nobel Prize in Chemistry", "A Nobel Prize in Chemistry", "Nobel Prize in Chemistry", "The Nobel prize in Chemistry"], ["scientists", "a scientist", "people", "anyone"], ["every country", "every state", "every place", "all"], ["the Netherlands", "Sweden", "Europe", "Norway"], ["awarded", "given", "presented", "distributed"]], "culprit": [2, 3]} +{"id": 113010, "claim": "Duane Chapman is not a former bail bondsman .", "questions": ["Who is not a former bail bondsman? or is not a former bail bondsman .", "What is Duane Chapman's profession? or Duane Chapman is not ."], "answers": [["Duane Chapman", 0, 13], ["a former bail bondsman", 21, 43]], "evidential": [["Duane Chapman", "Duane ChapmanI.", "Duane Chapman I.", "Duane Chapman II."], ["a bail bondsman", "a bounty hunter", "a former bail bondsman", "a bail bondsman"]], "culprit": [1]} +{"id": 109582, "claim": "US Airways Flight 1549 did not have any people on board .", "questions": ["What flight did not have any people on board? or did not have any people on board .", "What was not on board the US Airways Flight 1549? or US Airways Flight 1549 did not have any on board .", "What was the name of the aircraft that did not have any people on? or US Airways Flight 1549 did not have any people on ."], "answers": [["US Airways Flight 1549", 0, 22], ["people", 40, 46], ["board", 50, 55]], "evidential": [["US Airways Flight 1549", "The US Airways Flight 1549", "American Airways Flight 1549", "United Airways Flight 1549"], ["people", "passengers", "humans", "birds"], ["an aircraft", "an Airbus A320", "the Airbus A320", "an airliner"]], "culprit": [2]} +{"id": 23766, "claim": "Charles de Gaulle was a Polish Resistance leader .", "questions": ["Who was the leader of the Polish Resistance? or was a Polish Resistance leader .", "What nationality was Charles de Gaulle? or Charles de Gaulle was Resistance leader .", "What political party was Charles de Gaulle a leader of? or Charles de Gaulle was a Polish leader .", "What was Charles de Gaulle's role in the Polish Resistance? or Charles de Gaulle was a Polish Resistance ."], "answers": [["Charles de Gaulle", 0, 17], ["a Polish", 22, 30], ["Resistance", 31, 41], ["leader", 42, 48]], "evidential": [["Charles de Gaulle", "Charles De Gaulle", "Charles de Gaulle", "Louis de Gaulle"], ["a French", "an American", "the French", "an English"], ["French", "Nationalist", "Communist", "National Socialist"], ["leader", "chief leader", "person", "chief strategist"]], "culprit": [1]} +{"id": 94556, "claim": "Pirates of the Caribbean has yet to be opened in Disneyland Paris .", "questions": ["What is the name of the movie that has yet to be opened at Disneyland Paris? or has yet to be opened in Disneyland Paris .", "Where is Pirates of the Caribbean currently located? or Pirates of the Caribbean has yet to be opened in .", "What is the name of the first attraction to open at Disneyland Paris? or Pirates of the Caribbean has yet to be in Disneyland Paris .", "How long has it been since the Pirates of the Caribbean opened? or Pirates of the Caribbean has to be opened in Disneyland Paris ."], "answers": [["Pirates of the Caribbean", 0, 24], ["Disneyland Paris", 49, 65], ["opened", 39, 45], ["yet", 29, 32]], "evidential": [["Pirates of the Caribbean", "The Pirates of the Caribbean", "Pirates of The Caribbean", "Pirates of Caribbean"], ["Disneyland Paris", "Disney Disneyland Paris", "Disney Paris", "Disney Park"], ["an attraction", "the first attraction", "a ride", "the first ride"], ["yet", "a decade", "a year", "the time"]], "culprit": [2, 3]} +{"id": 225871, "claim": "Revolver has only ever topped a single chart .", "questions": ["What has only ever topped a single chart? or has only ever topped a single chart .", "How many charts has Revolver ever topped? or Revolver has only ever topped .", "How many times has Revolver ever a single chart? or Revolver has only ever a single chart .", "How many times has Revolver topped a single chart? or Revolver has ever topped a single chart ."], "answers": [["Revolver", 0, 8], ["a single chart", 30, 44], ["topped", 23, 29], ["only", 13, 17]], "evidential": [["Revolver", "Revololver", "Revolver Record", "The Revolver"], ["four charts", "two charts", "three charts", "zero charts"], ["topped", "charted", "reached", "appeared"], ["never", "n't", "only", "rarely"]], "culprit": [1, 3]} +{"id": 164417, "claim": "Carey Hayes was born in 1897 .", "questions": ["Who was born in 1897? or was born in 1897 .", "When was Carey Hayes born? or Carey Hayes was born in .", "What was Carey Hayes' birth year? or Carey Hayes was in 1897 ."], "answers": [["Carey Hayes", 0, 11], ["1897", 24, 28], ["born", 16, 20]], "evidential": [["Carey Hayes", "Carey Hayes (", "Carey Hayes", "Carey Hayden"], ["1961", "the 1960s", "the 1960 's", "the 20th century"], ["born", "conceived", "created", "born in"]], "culprit": [1]} +{"id": 70311, "claim": "IMDb is a professional Dota 2 player .", "questions": ["What is the name of the professional Dota 2 player? or is a professional Dota 2 player .", "What is IMDb's professional name? or IMDb is a professional 2 player .", "How many players does IMDb have? or IMDb is a professional Dota .", "IMDb is what type of player? or IMDb is Dota 2 player ."], "answers": [["IMDb", 0, 4], ["Dota", 23, 27], ["2 player", 28, 36], ["a professional", 8, 22]], "evidential": [["IMDb", "The Internet Movie Database", "The internet movie database", "The internet Movie Database"], ["Game", "Web", "video game", "Webmaster"], ["users", "one player", "one user", "user"], ["an online database", "a fictional", "a popular", "a professional"]], "culprit": [1, 2, 3]} +{"id": 123479, "claim": "The Hundred Years ' War does not include the Lancastrian War .", "questions": ["What does not include the Lancastrian War? or does not include the Lancastrian War .", "What is not included in the Hundred Years' War? or The Hundred Years ' War does not include .", "What does the Hundred Years' War not ? or The Hundred Years ' War does not the Lancastrian War ."], "answers": [["The Hundred Years ' War", 0, 23], ["the Lancastrian War", 41, 60], ["include", 33, 40]], "evidential": [["The Hundred Years ' War", "The Hundred Years' War", "Hundred Years ' War", "A Hundred Years ' War"], ["a conflict", "local conflicts", "a war", "several conflicts"], ["mention", "name", "see", "include"]], "culprit": [1]} +{"id": 16811, "claim": "Efraim Diveroli is a Spaniard .", "questions": ["Who is a Spaniard? or is a Spaniard .", "What is Efraim Diveroli's nationality? or Efraim Diveroli is ."], "answers": [["Efraim Diveroli", 0, 15], ["a Spaniard", 19, 29]], "evidential": [["Efraim Diveroli", "Efranim Diveroli", "Efriim Diveroli", "Efrafri Diveroli"], ["an American", "American", "North American", "a North American"]], "culprit": [1]} +{"id": 183618, "claim": "Finding Dory was written by anyone except Andrew Stanton .", "questions": ["What was the name of the book that was written by anyone other than Andrew Stanton? or was written by anyone except Andrew Stanton .", "Who wrote Finding Dory? or Finding Dory was written by except Andrew Stanton .", "Who wrote Finding Dory? or Finding Dory was written by anyone except .", "Who else wrote Finding Dory? or Finding Dory was by anyone except Andrew Stanton ."], "answers": [["Finding Dory", 0, 12], ["anyone", 28, 34], ["Andrew Stanton", 42, 56], ["written", 17, 24]], "evidential": [["Finding Dory", "The Finding Dory", "Finding dory", "Finding Dory 2"], ["anyone", "every person", "almost anyone", "almost all"], ["Andrew Stanton", "Andrew Strouse", "Andrew Stanton", "Andy Stanton"], ["written", "penned", "directed", "authored"]], "culprit": [2]} +{"id": 125315, "claim": "Phoenix , Arizona is the most populous city in Massachusetts .", "questions": ["What is the most populous city in Massachusetts? or , Arizona is the most populous city in Massachusetts .", "What state is Phoenix located in? or Phoenix , is the most populous city in Massachusetts .", "What state is Phoenix located in? or Phoenix , Arizona is the most populous city in .", "What is the population of Phoenix in Massachusetts? or Phoenix , Arizona is populous city in Massachusetts .", "What is the population of Phoenix? or Phoenix , Arizona is the most city in Massachusetts ."], "answers": [["Phoenix", 0, 7], ["Arizona", 10, 17], ["Massachusetts", 47, 60], ["the most", 21, 29], ["populous", 30, 38]], "evidential": [["Phoenix", "The Phoenix", "Arizona Phoenix", "Tempe"], ["Arizona", "Arizona Republic", "Arizona State", "United States"], ["the United States", "the US", "Arizona", "a state"], ["the most", "the fifth most", "the 5th most", "the fourth most"], ["populous", "populous city", "populous US", "large"]], "culprit": [2]} +{"id": 216367, "claim": "All speakers of the Chagatai language lived in France .", "questions": ["Who lived in France? or All of the Chagatai language lived in France .", "What language did all French speakers speak? or All speakers of lived in France .", "Where did the Chagatai language live? or All speakers of the Chagatai language lived in .", "Where did all speakers of the Chagatai language live? or All speakers of the Chagatai language in France ."], "answers": [["speakers", 4, 12], ["the Chagatai language", 16, 37], ["France", 47, 53], ["lived", 38, 43]], "evidential": [["The authors", "An author", "People", "A person"], ["the Chagatai language", "Chagatai language", "The Chagatai language", "a Chagatai language"], ["Europe", "a place", "France", "Asia"], ["lived", "existed", "resided", "originated"]], "culprit": [2]} +{"id": 23428, "claim": "The Cincinnati Kid was directed by Norman Jewison in 1960 .", "questions": ["What movie was directed by Norman Jewison? or was directed by Norman Jewison in 1960 .", "Who directed The Cincinnati Kid? or The Cincinnati Kid was directed by in 1960 .", "When was The Cincinnati Kid directed? or The Cincinnati Kid was directed by Norman Jewison in .", "What was the name of the film that was produced by Norman Jewison? or The Cincinnati Kid was by Norman Jewison in 1960 ."], "answers": [["The Cincinnati Kid", 0, 18], ["Norman Jewison", 35, 49], ["1960", 53, 57], ["directed", 23, 31]], "evidential": [["The Cincinnati Kid", "the Cincinnati Kid", "The CincinnatiKid", "Cincinnati Kid"], ["Norman Jewison", "a man", "Norman JewISON", "Norman Jewisons"], ["1965", "the 1960s", "the 1960 's", "the late 1960s"], ["directed", "produced", "written", "filmed"]], "culprit": [2]} +{"id": 67903, "claim": "Murda Beatz 's real name is Donald Trump .", "questions": ["Who is Donald Trump's real name? or 's real name is Donald Trump .", "What is Beatz' real name? or Murda Beatz 's real name is .", "What is the name of Murda Beatz? or Murda Beatz 's name is Donald Trump ."], "answers": [["Murda Beatz", 0, 11], ["Donald Trump", 28, 40], ["real", 15, 19]], "evidential": [["Murda Beatz", "Murdas Beatz", "Murda beatz", "Murdac Beatz"], ["Donald Trump", "a Donald Trump", "Donald Donald Trump", "Donald John Trump"], ["middle", "real", "full", "legal"]], "culprit": [1]} +{"id": 45585, "claim": "Harris Jayaraj is from Idaho .", "questions": ["Who is from Idaho? or is from Idaho .", "Where is Harris Jayaraj from? or Harris Jayaraj is from ."], "answers": [["Harris Jayaraj", 0, 14], ["Idaho", 23, 28]], "evidential": [["Harris Jayaraj", "Harris Jayaram", "Harris Jayarbaj", "Harris Jayaraja"], ["a state", "Idaho", "a place", "America"]], "culprit": [1]} +{"id": 95601, "claim": "Ian Gillan is only a singer .", "questions": ["Who is the only singer? or is only a singer .", "What is Ian Gillan's job? or Ian Gillan is ."], "answers": [["Ian Gillan", 0, 10], ["only a singer", 14, 27]], "evidential": [["Ian Gillan", "Ian Gillan", "Ian Gillan", "Ian Gillans"], ["a singer", "a vocalist", "a singer and songwriter", "a performer"]], "culprit": [1]} +{"id": 122348, "claim": "Wolfgang Amadeus Mozart never married .", "questions": ["Who never married? or never married .", "What did Wolfgang Amadeus Mozart never do? or Wolfgang Amadeus Mozart never .", "How did Wolfgang Amadeus Mozart get married? or Wolfgang Amadeus Mozart married ."], "answers": [["Wolfgang Amadeus Mozart", 0, 23], ["married", 30, 37], ["never", 24, 29]], "evidential": [["Wolfgang Amadeus Mozart", "Amadeus Mozart", "Johannes Amadeus Mozart", "Wolfgang Amadeu Mozart"], ["married", "marry", "died", "live"], ["got", "eventually", "later", "was"]], "culprit": [2]} +{"id": 146157, "claim": "The New England Patriots lost five Super Bowls .", "questions": ["Who lost five Super Bowls? or lost five Super Bowls .", "What type of game did the New England Patriots lose? or The New England Patriots lost five .", "How many Super Bowls did the New England Patriots win? or The New England Patriots five Super Bowls .", "How many Super Bowls did the New England Patriots lose? or The New England Patriots lost Super Bowls ."], "answers": [["The New England Patriots", 0, 24], ["Super Bowls", 35, 46], ["lost", 25, 29], ["five", 30, 34]], "evidential": [["New England Patriots", "The Patriots", "The New Patriots", "Patriots"], ["Super Bowls", "a Super Bowl", "the Super Bowl", "a football game"], ["won", "played", "reached", "achieved"], ["five", "5", "least five", "seven"]], "culprit": [2]} +{"id": 107699, "claim": "Floyd Mayweather Jr. is incapable of boxing .", "questions": ["Who is incapable of boxing? or is incapable of boxing .", "Floyd Mayweather Jr. is incapable of what sport? or Floyd Mayweather Jr. is incapable of .", "Is Floyd Mayweather Jr. capable or of boxing? or Floyd Mayweather Jr. is of boxing ."], "answers": [["Floyd Mayweather Jr.", 0, 20], ["boxing", 37, 43], ["incapable", 24, 33]], "evidential": [["Floyd Mayweather Jr.", "Floyd Mayweather Jr .", "Floyd Mayweather Jr.?", "Floyd Mayweather Jr.:"], ["boxing", "professional boxing", "boxed", "a sport"], ["incapable", "capable", "a capable", "an athlete"]], "culprit": [2]} +{"id": 216594, "claim": "Calcaneal spurs are only detected by a dancing technique .", "questions": ["What is only detected by a dancing technique? or are only detected by a dancing technique .", "What is the only way to detect Calcaneal spurs? or Calcaneal spurs are only detected by .", "How are Calcaneal spurs ? or Calcaneal spurs are only by a dancing technique .", "How are Calcaneal spurs detected by a dancing technique? or Calcaneal spurs are detected by a dancing technique ."], "answers": [["Calcaneal spurs", 0, 15], ["a dancing technique", 37, 56], ["detected", 25, 33], ["only", 20, 24]], "evidential": [["Calcaneal spurs", "Calcaneal spur", "Calcaneals spurs", "Calcane al spurs"], ["a radiographic examination", "an x ray", "radiographic examination", "a radiographic exam"], ["detected", "observed", "seen", "indicated"], ["typically", "usually", "often", "frequently"]], "culprit": [1, 3]} +{"id": 118068, "claim": "Liverpool is unrelated to The Beatles .", "questions": ["What city is not related to The Beatles? or is unrelated to The Beatles .", "Liverpool is not related to what band? or Liverpool is unrelated to .", "Is Liverpool related to The Beatles? or Liverpool is to The Beatles ."], "answers": [["Liverpool", 0, 9], ["The Beatles", 26, 37], ["unrelated", 13, 22]], "evidential": [["Liverpool", "The Liverpool", "Liverpool City", "Liverpool"], ["The Beatles", "the Beatles", "a rock band", "a band"], ["related", "connected", "a home", "home"]], "culprit": [2]} +{"id": 110504, "claim": "The Mighty Ducks was only distributed by a subsidiary of 20th Century Fox .", "questions": ["What was the name of the show that was distributed by 20th Century Fox? or was only distributed by a subsidiary of 20th Century Fox .", "Who distributed the Mighty Ducks? or The Mighty Ducks was only distributed by of 20th Century Fox .", "Who distributed the Mighty Ducks? or The Mighty Ducks was only distributed by a subsidiary of .", "How was the Mighty Ducks ? or The Mighty Ducks was only by a subsidiary of 20th Century Fox .", "How many times was The Mighty Ducks distributed by 20th Century Fox? or The Mighty Ducks was distributed by a subsidiary of 20th Century Fox ."], "answers": [["The Mighty Ducks", 0, 16], ["a subsidiary", 41, 53], ["20th Century Fox", 57, 73], ["distributed", 26, 37], ["only", 21, 25]], "evidential": [["The Mighty Ducks", "The Mighty Ducks of Anaheim", "The Mighty Duck", "Mighty Ducks"], ["the parent company", "a division", "a subsidiary", "the company"], ["Walt Disney Pictures", "Disney Pictures", "a company", "Walt Disney Productions"], ["distributed", "produced", "released", "created"], ["only", "never", "twice", "previously"]], "culprit": [1, 2, 4]} +{"id": 161151, "claim": "No Strings Attached was released on May 21 .", "questions": ["What was released on May 21? or No was released on May 21 .", "When was No Strings Attached released? or No Strings Attached was released on .", "When was No Strings Attached ? or No Strings Attached was on May 21 ."], "answers": [["Strings Attached", 3, 19], ["May 21", 36, 42], ["released", 24, 32]], "evidential": [["Strings Attached", "strings Attached", "Strings Attached album", "Strings Attached film"], ["January 21 , 2011", "January 21st", "January 21st 2011", "January 21"], ["released", "published", "issued", "distributed"]], "culprit": [1]} +{"id": 150099, "claim": "Sherilyn Fenn is Japanese .", "questions": ["Who is the name of the Japanese woman who is a native of Japan? or is Japanese .", "What language is Sherilyn Fenn? or Sherilyn Fenn is ."], "answers": [["Sherilyn Fenn", 0, 13], ["Japanese", 17, 25]], "evidential": [["Sherilyn Fenn", "The Sherilyn Fenn", "Sherilyn Fenna", "Cherilyn Fenn"], ["American", "English", "North American", "French"]], "culprit": [1]} +{"id": 157652, "claim": "Touchscreens are only used in gaming computers .", "questions": ["What type of screen is used in gaming computers? or are only used in gaming computers .", "What type of computers are touch screens used for? or Touchscreens are only used in .", "What is the only way a touch screen can be in gaming computers? or Touchscreens are only in gaming computers .", "How are touchscreens used in gaming computers? or Touchscreens are used in gaming computers ."], "answers": [["Touchscreens", 0, 12], ["gaming computers", 30, 46], ["used", 22, 26], ["only", 17, 21]], "evidential": [["Touchscreens", "Touchscreen", "Touchscreen devices", "Touch screens"], ["personal computers", "electronic voting machines", "computer systems", "mobile computers"], ["common", "used", "found", "prevalent"], ["commonly", "frequently", "increasingly", "widely"]], "culprit": [3]} +{"id": 209863, "claim": "In a Lonely Place had nothing to do with any novel by Dorthy B. Hughes .", "questions": ["What was the name of the book that had nothing to do with any novel by Dorthy or had nothing to do with any novel by Dorthy B. Hughes .", "What did In a Lonely Place have to do with Dorthy B. Hughes or In a Lonely Place had to do with any novel by Dorthy B. Hughes .", "What type of work did In a Lonely Place have nothing to do with? or In a Lonely Place had nothing to do with any by Dorthy B. Hughes .", "Who wrote In a Lonely Place? or In a Lonely Place had nothing to do with any novel by ."], "answers": [["In a Lonely Place", 0, 17], ["nothing", 22, 29], ["novel", 45, 50], ["Dorthy B. Hughes", 54, 70]], "evidential": [["In a Lonely Place", "in a Lonely Place", "In a Lonely place", "In a Lonely Place ."], ["a lot", "a thing", "nothing", "a script"], ["novels", "mystery work", "written work", "written works"], ["Dorothy B. Hughes", "a mystery writer", "the mystery writer", "the author"]], "culprit": [1, 2, 3]} +{"id": 3305, "claim": "Julianne Moore was not in the television series As the World Turns .", "questions": ["Who was not in the television series As The World Turns? or was not in the television series As the World Turns .", "What was Julianne Moore not in? or Julianne Moore was not in As the World Turns .", "What television series did Julianne Moore not appear in? or Julianne Moore was not in the television series As ."], "answers": [["Julianne Moore", 0, 14], ["the television series", 26, 47], ["the World Turns", 51, 66]], "evidential": [["Julianne Moore", "Juliene Moore", "Juliann Moore", "Julianna Moore"], ["the soap opera", "the television show", "the television series", "the show"], ["the World Turns", "The World Turns", "the World turns", "a World Turns"]], "culprit": [1, 2]} +{"id": 83351, "claim": "In 2015 , among Mexicans , 70 % of adults had consumed alcoholic drink in the last year .", "questions": ["In what year did 70 % of Mexican adults drink alcohol? or In , among Mexicans , 70 % of adults had consumed alcoholic drink in the last year .", "What ethnicity had the highest percentage of alcoholic beverages in 2015? or In 2015 , among , 70 % of adults had consumed alcoholic drink in the last year .", "What percentage of Mexican adults had consumed alcohol in 2015? or In 2015 , among Mexicans , of adults had consumed alcoholic drink in the last year .", "What group of Mexicans consumed alcohol in 2015? or In 2015 , among Mexicans , 70 % of had consumed alcoholic drink in the last year .", "What type of drink did 70 % of Mexican adults consume in 2015? or In 2015 , among Mexicans , 70 % of adults had consumed in the last year .", "In what year did 70 % of Mexican adults drink alcohol? or In 2015 , among Mexicans , 70 % of adults had consumed alcoholic drink in .", "What did 70 % of adults in Mexico do with alcoholic beverages? or In 2015 , among Mexicans , 70 % of adults had alcoholic drink in the last year ."], "answers": [["2015", 3, 7], ["Mexicans", 16, 24], ["70 %", 27, 31], ["adults", 35, 41], ["alcoholic drink", 55, 70], ["the last year", 74, 87], ["consumed", 46, 54]], "evidential": [["2015", "2015 's", "the 2015 year", "the last year"], ["Americans", "Mexican", "the Mexican", "Mexicans"], ["89 %", "90 %", "70 %", "87 %"], ["adults", "people", "adult", "Americans"], ["alcohol", "alcoholic drink", "alcoholic drinks", "alcoholic beverages"], ["the last year", "the past year", "the year", "2015"], ["drank", "drunk", "consumed", "drinking"]], "culprit": [1]} +{"id": 97937, "claim": "Watchmen is a film set in the future .", "questions": ["What is the name of the film set in the future? or is a film set in the future .", "What type of film is Watchmen? or Watchmen is set in the future .", "What is the setting of Watchmen? or Watchmen is a film set in .", "Where is Watchmen ? or Watchmen is a film in the future ."], "answers": [["Watchmen", 0, 8], ["a film", 12, 18], ["the future", 26, 36], ["set", 19, 22]], "evidential": [["Watchmen", "Watchmen ( film )", "Watchmen( film )", "Watchmen(film )"], ["a superhero film", "a film", "a dystopian film", "a cinematic film"], ["an alternate history", "a dystopian history", "a dystopian future", "a past"], ["set", "located", "filmed", "based"]], "culprit": [2]} +{"id": 8298, "claim": "Simon Pegg is only a banker .", "questions": ["Who is a banker? or is only a banker .", "What is Simon Pegg's job title? or Simon Pegg is ."], "answers": [["Simon Pegg", 0, 10], ["only a banker", 14, 27]], "evidential": [["Simon Pegg", "Simon Pgg", "Simon pegg", "Simon Pegg"], ["a producer", "a screenwriter", "an entertainer", "an executive producer"]], "culprit": [1]} +{"id": 193862, "claim": "Barry Van Dyke is the first son of Dick Van Dyke .", "questions": ["Who is the first son of Dick Van Dyke? or is the first son of Dick Van Dyke .", "What is Barry Van Dyke's first name? or Barry Van Dyke is of Dick Van Dyke .", "Who is Barry Van Dyke's father? or Barry Van Dyke is the first son of ."], "answers": [["Barry Van Dyke", 0, 14], ["the first son", 18, 31], ["Dick Van Dyke", 35, 48]], "evidential": [["Barry Van Dyke", "Barry van Dyke", "Dick Van Dyke", "A man"], ["the second son", "the first son", "the second child", "the son"], ["Dick Van Dyke", "an entertainer", "an actor", "a comedian"]], "culprit": [1]} +{"id": 55279, "claim": "Helmand Province contains a city .", "questions": ["What province contains a city? or contains a city .", "What does Helmand Province contain? or Helmand Province contains .", "What is the name of the city in Helmand Province? or Helmand Province a city ."], "answers": [["Helmand Province", 0, 16], ["a city", 26, 32], ["contains", 17, 25]], "evidential": [["Helmand Province", "Helmand province", "Helmand Provincial", "Helmand District"], ["people", "a city", "a town", "a population"], ["contains", "includes", "possesses", "features"]], "culprit": [1]} +{"id": 69871, "claim": "Robert Zemeckis has rarely directed movies .", "questions": ["Who has rarely directed a movie? or has rarely directed movies .", "What type of film has Zemeckis rarely directed? or Robert Zemeckis has rarely directed .", "What type of movies has Zemeckis rarely made? or Robert Zemeckis has rarely movies .", "How often has Zemeckis directed movies? or Robert Zemeckis has directed movies ."], "answers": [["Robert Zemeckis", 0, 15], ["movies", 36, 42], ["directed", 27, 35], ["rarely", 20, 26]], "evidential": [["Robert Zemeckis", "Robert Zemeckis", "Robert Zemckis", "Robert Memeckis"], ["a film", "a drama film", "a comedy", "a comedy film"], ["directed", "direct", "produced", "directing"], ["never", "rarely", "always", "only"]], "culprit": [3]} +{"id": 48276, "claim": "Raees ( film ) stars an Indian film actor born in April 1965 .", "questions": ["What film stars an Indian actor? or stars an Indian film actor born in April 1965 .", "What nationality is Raees? or Raees ( film ) stars film actor born in April 1965 .", "What is Raees' career? or Raees ( film ) stars an Indian born in April 1965 .", "When was Raees born? or Raees ( film ) stars an Indian film actor born in .", "What is Raees' career? or Raees ( film ) an Indian film actor born in April 1965 .", "What is the birth year of Raees? or Raees ( film ) stars an Indian film actor in April 1965 ."], "answers": [["Raees ( film )", 0, 14], ["an Indian", 21, 30], ["film actor", 31, 41], ["April 1965", 50, 60], ["stars", 15, 20], ["born", 42, 46]], "evidential": [["Raees ( film )", "Raees", "Raees( film )", "Raes ( film )"], ["an Indian", "a Indian", "An Indian", "an India"], ["film actor", "film actor and television personality", "actor", "television personality"], ["1965", "the sixties", "the 1960s", "the year 1965"], ["stars", "features", "starred", "includes"], ["born", "birth year", "birth date", "founded"]], "culprit": [3]} +{"id": 101845, "claim": "Richard Kuklinski is a innocent man .", "questions": ["Who is an innocent man? or is a innocent man .", "What is Richard Kuklinski? or Richard Kuklinski is ."], "answers": [["Richard Kuklinski", 0, 17], ["a innocent man", 21, 35]], "evidential": [["Richard Kuklinski", "Richard Kuklinski", "Richard Kuklinsky", "Richard Kuplinski"], ["a person", "a killer", "a serial killer", "a criminal"]], "culprit": [1]} +{"id": 44240, "claim": "Amancio Ortega refuses to be a businessman .", "questions": ["Who refuses to be a businessman? or refuses to be a businessman .", "What does Amancio Ortega refuse to be? or Amancio Ortega refuses to be .", "What does Amancio Ortega do to be a businessman? or Amancio Ortega to be a businessman ."], "answers": [["Amancio Ortega", 0, 14], ["a businessman", 29, 42], ["refuses", 15, 22]], "evidential": [["Amancio Ortega", "Amancio Ortega Gaona", "Amancio Ortega Jr.", "Amancio Orlando Ortega"], ["a businessman", "a tycoon", "a person", "a businessperson"], ["wants", "used", "works", "acts"]], "culprit": [2]} +{"id": 142735, "claim": "Elizabeth I was the daughter of a salesman .", "questions": ["What was my mother's name? or I was the daughter of a salesman .", "What was Elizabeth I's mother's name? or Elizabeth I was of a salesman .", "What was Elizabeth I's father's occupation? or Elizabeth I was the daughter of ."], "answers": [["Elizabeth", 0, 9], ["the daughter", 16, 28], ["a salesman", 32, 42]], "evidential": [["Elizabeth", "Elizabeth I", "ElizabethI", "Elizabeth II"], ["the daughter", "the second daughter", "the first daughter", "the second wife"], ["a man", "a second wife", "Henry VIII", "a person"]], "culprit": [2]} +{"id": 167977, "claim": "Don Bradman was called the \" greatest living Australian \" by a President .", "questions": ["Who was called the \"greatest living Australian\" by a President? or was called the \" greatest living Australian \" by a President .", "What nationality was Don Bradman? or Don Bradman was called the \" greatest living \" by a President .", "What was Bradman called by a President? or Don Bradman was called the \" greatest living Australian by a President .", "Who called Don Bradman the \"greatest living Australian\"? or Don Bradman was called the \" greatest living Australian \" by .", "What was Don Bradman called by a President? or Don Bradman was called Australian \" by a President ."], "answers": [["Don Bradman", 0, 11], ["Australian", 45, 55], ["\"", 56, 57], ["a President", 61, 72], ["the \" greatest living", 23, 44]], "evidential": [["Don Bradman", "Donald Bradman", "Don Bradm", "An Australian"], ["Australian", "American", "an Australian", "Australia"], ["person", "\"", "honored", "icon"], ["Prime Minister John Howard", "John Howard", "a Prime Minister", "the Prime Minister"], ["the \" greatest living", "the \" great living", "the \" best living", "the \" Greatest living"]], "culprit": [3]} +{"id": 227084, "claim": "Roar ( song ) is a Katy Perry song from her fifth album .", "questions": ["What is the name of Katy Perry's fifth album? or is a Katy Perry song from her fifth album .", "What is the name of the song Roar? or Roar ( song ) is song from her fifth album .", "What is Roar? or Roar ( song ) is a Katy Perry from her fifth album .", "What album is Roar from? or Roar ( song ) is a Katy Perry song from her ."], "answers": [["Roar ( song )", 0, 13], ["a Katy Perry", 17, 29], ["song", 30, 34], ["fifth album", 44, 55]], "evidential": [["Roar", "Roars", "Roar .", "Rar"], ["a Katy Perry", "an Katy Perry", "an American", "an artist 's"], ["song", "title", "single", "track"], ["fourth studio album", "fourth album", "fourth record", "fourth studio record"]], "culprit": [3]} +{"id": 205646, "claim": "St. Anger is the second studio album by Metallica .", "questions": ["What is the name of Metallica's second album? or is the second studio album by Metallica .", "What is the name of the second album by Metallica? or St. Anger is by Metallica .", "What band released St. Anger? or St. Anger is the second studio album by ."], "answers": [["St. Anger", 0, 9], ["the second studio album", 13, 36], ["Metallica", 40, 49]], "evidential": [["St. Anger", "The St. Anger", "St . Anger", "St. Anger ."], ["the eighth studio album", "an album", "an eighth studio album", "the eighth album"], ["Metallica", "a heavy metal band", "the Metallica", "a heavy metal group"]], "culprit": [1]} +{"id": 209095, "claim": "Stadium Arcadium was released after 2009 .", "questions": ["What stadium was released after 2009? or was released after 2009 .", "In what year was Stadium Arcadium released? or Stadium Arcadium was released after .", "What happened to Stadium Arcadium after 2009? or Stadium Arcadium was after 2009 .", "When was Stadium Arcadium released? or Stadium Arcadium was released 2009 ."], "answers": [["Stadium Arcadium", 0, 16], ["2009", 36, 40], ["released", 21, 29], ["after", 30, 35]], "evidential": [["Stadium Arcadium", "Stadium Arcadia", "Stadium Arcadadium", "Stadium Arcadion"], ["2006", "the 2000s", "a different year", "a 2006 album"], ["released", "disbanded", "dropped", "cancelled"], ["before", "after", "around", "back"]], "culprit": [1, 3]} +{"id": 155657, "claim": "The Prowler was created by Stan Lee , John Buscema , and dust .", "questions": ["What was the name of the film created by Stan Lee, John Buscema and Dust or was created by Stan Lee , John Buscema , and dust .", "Who created The Prowler? or The Prowler was created by , John Buscema , and dust .", "Who created The Prowler? or The Prowler was created by Stan Lee , , and dust .", "What was the Prowler made of? or The Prowler was created by Stan Lee , John Buscema , and .", "How was The Prowler ? or The Prowler was by Stan Lee , John Buscema , and dust ."], "answers": [["The Prowler", 0, 11], ["Stan Lee", 27, 35], ["John Buscema", 38, 50], ["dust", 57, 61], ["created", 16, 23]], "evidential": [["The Prowler", "The Prowler ( 1981 film )", "The Prowler( 1981 film )", "Prowler ( 1981 film )"], ["Stan Lee", "Jim Mooney", "writer editor", "writer and editor"], ["comics editor", "comics writers", "people", "comics editors"], ["a writer", "a person", "characters", "comics"], ["created", "produced", "designed", "invented"]], "culprit": [3]} +{"id": 172095, "claim": "Selena Gomez & the Scene 's debut album was released in any month except September .", "questions": ["What group's debut album was released in any month except September? or 's debut album was released in any month except September .", "Selena Gomez & the Scene's debut album was released in what or Selena Gomez & the Scene 's debut album was released in any except September .", "Selena Gomez & the Scene's debut album was released in what month or Selena Gomez & the Scene 's debut album was released in any month except .", "When was Selena Gomez's debut album ? or Selena Gomez & the Scene 's debut album was in any month except September ."], "answers": [["Selena Gomez & the Scene", 0, 24], ["month", 60, 65], ["September", 73, 82], ["released", 44, 52]], "evidential": [["Selena Gomez & the Scene", "The Selena Gomez & the Scene", "Selena Gomez & The Scene", "Selena Gomez & the Scene"], ["September", "the month", "the US", "the summer"], ["September", "October", "July", "August"], ["released", "published", "issued", "launched"]], "culprit": [2]} +{"id": 191441, "claim": "Keith Urban was released by Sony Music Entertainment .", "questions": ["What artist was released by Sony Music Entertainment? or was released by Sony Music Entertainment .", "What company released Keith Urban? or Keith Urban was released by .", "When was Keith Urban ? or Keith Urban was by Sony Music Entertainment ."], "answers": [["Keith Urban", 0, 11], ["Sony Music Entertainment", 28, 52], ["released", 16, 24]], "evidential": [["Keith Urban", "Keith Urban II", "Keith U.", "The Keith Urban"], ["Capitol Nashville", "Capitol Records", "a company", "Capitol"], ["released", "created", "signed", "founded"]], "culprit": [1]} +{"id": 188640, "claim": "Foot Locker operates in only 11 countries .", "questions": ["What company operates in only 11 countries? or operates in only 11 countries .", "How many countries does Foot Locker operate in? or Foot Locker operates in countries .", "How does Foot Locker operate in 11 countries? or Foot Locker in only 11 countries ."], "answers": [["Foot Locker", 0, 11], ["only 11", 24, 31], ["operates", 12, 20]], "evidential": [["Foot Locker", "Foot Locker , Inc.", "Foot Locker ( Inc.", "Foot Locker Inc."], ["28", "least 28", "27", "29"], ["operates", "operate", "exists", "runs"]], "culprit": [1]} +{"id": 164407, "claim": "Carey Hayes is only a German lawyer .", "questions": ["Who is a German lawyer? or is only a German lawyer .", "What nationality is Hayes? or Carey Hayes is only lawyer .", "What is Hayes' profession? or Carey Hayes is only a German .", "How old is Carey Hayes? or Carey Hayes is German lawyer ."], "answers": [["Carey Hayes", 0, 11], ["a German", 20, 28], ["lawyer", 29, 35], ["only a", 15, 21]], "evidential": [["Carey Hayes", "Carey Hayes Jr.", "Carey Hayes", "Carey Hayden"], ["an American", "an american", "a North American", "an Oregon"], ["writer", "screenwriter", "a writer", "author"], ["a 21st century", "a 21 year old", "a 21-year old", "a young"]], "culprit": [1, 2, 3]} +{"id": 83545, "claim": "Volkswagen Group declines financing , leasing , and fleet management .", "questions": ["Which company declines financing, leasing and fleet management? or declines financing , leasing , and fleet management .", "What does Volkswagen Group decline? or Volkswagen Group declines financing , leasing , and .", "What does Volkswagen Group do with financing, leasing and fleet management? or Volkswagen Group financing , leasing , and fleet management ."], "answers": [["Volkswagen Group", 0, 16], ["fleet management", 52, 68], ["declines", 17, 25]], "evidential": [["Volkswagen Group", "The Volkswagen Group", "VW Group", "Volkswagen group"], ["fleet management", "fleet management services", "fleets management", "vehicles fleet management"], ["offers", "provides", "performs", "facilitates"]], "culprit": [2]} +{"id": 97837, "claim": "Caroline Kennedy is against diplomacy .", "questions": ["Who is against diplomacy? or is against diplomacy .", "Caroline Kennedy is against what? or Caroline Kennedy is against ."], "answers": [["Caroline Kennedy", 0, 16], ["diplomacy", 28, 37]], "evidential": [["Caroline Kennedy", "Caroline Flemming", "Caroline Klemming", "Caroline Kennedy"], ["politics", "the Democratic Party", "a presidential election", "a presidential campaign"]], "culprit": [1]} +{"id": 229309, "claim": "A working animal is released by humans .", "questions": ["What is released by humans? or is released by humans .", "Who releases a working animal? or A working animal is released by .", "What happens to a working animal when it is ? or A working animal is by humans ."], "answers": [["A working animal", 0, 16], ["humans", 32, 38], ["released", 20, 28]], "evidential": [["A working animal", "A Working animal", "Working animal", "An animal"], ["humans", "a human", "human beings", "people"], ["kept", "domesticated", "raised", "captured"]], "culprit": [2]} +{"id": 98672, "claim": "Balibo ( film ) starts in the year 1995 .", "questions": ["What film was released in 1995? or starts in the year 1995 .", "When does Balibo begin? or Balibo ( film ) starts in .", "When does Balibo begin? or Balibo ( film ) in the year 1995 ."], "answers": [["Balibo ( film )", 0, 15], ["the year 1995", 26, 39], ["starts", 16, 22]], "evidential": [["Balibo", "Balibo ( film )", "Balibo( film )", "Balibo ( films )"], ["1975", "the 1970s", "the 1980s", "the year 1975"], ["begins", "starts", "began", "begin"]], "culprit": [1]} +{"id": 55239, "claim": "Victor Frankenstein is only a romance film .", "questions": ["What is the name of the film that is a romance? or is only a romance film .", "What is the purpose of Victor Frankenstein? or Victor Frankenstein is ."], "answers": [["Victor Frankenstein", 0, 19], ["only a romance film", 23, 42]], "evidential": [["Victor Frankenstein ( film )", "Victor Frankenstein", "Victor Frankenstein( film )", "Victor Frankenstein ( films )"], ["a film", "a motion picture", "a recorded work", "directed"]], "culprit": [1]} +{"id": 7728, "claim": "Hinduism has zero textual resources .", "questions": ["What religion has zero textual resources? or has zero textual resources .", "How many textual resources does Hinduism have? or Hinduism has ."], "answers": [["Hinduism", 0, 8], ["zero textual resources", 13, 35]], "evidential": [["Hinduism", "Hindu religion", "Indianism", "Buddhism"], ["multiple textual resources", "many shared textual resources", "shared textual resources", "many textual resources"]], "culprit": [1]} +{"id": 202475, "claim": "Tinker Tailor Soldier Spy only stars Gary Oldman .", "questions": ["What movie stars Gary Oldman? or only stars Gary Oldman .", "Who stars in Tinker Tailor Soldier Spy? or Tinker Tailor Soldier Spy only stars .", "What is Gary Oldman's first name? or Tinker Tailor Soldier Spy only Gary Oldman .", "How many episodes does Tinker Tailor Soldier Spy have? or Tinker Tailor Soldier Spy stars Gary Oldman ."], "answers": [["Tinker Tailor Soldier Spy", 0, 25], ["Gary Oldman", 37, 48], ["stars", 31, 36], ["only", 26, 30]], "evidential": [["Tinker Tailor Soldier Spy", "The Tinker Tailor Soldier Spy", "Tinker Tailor Soldier Spy", "Tinker Tailor Soldier Spy movie"], ["Gary Oldman", "an actor", "George Smiley", "a man"], ["stars", "features", "includes", "contains"], ["only", "one episode", "2 episodes", "one series"]], "culprit": [2, 3]} +{"id": 159091, "claim": "Guatemala has lived without war for its entire existence .", "questions": ["What country has lived without war for its entire existence? or has lived without war for its entire existence .", "What has Guatemala lived without? or Guatemala has lived without for its entire existence .", "How long has Guatemala lived without war? or Guatemala has lived without war for its .", "How long has Guatemala been without war? or Guatemala has without war for its entire existence ."], "answers": [["Guatemala", 0, 9], ["war", 28, 31], ["entire existence", 40, 56], ["lived", 14, 19]], "evidential": [["Guatemala", "Central America Guatemala", "Guatemalan", "Central America"], ["a military", "a government", "war", "a war"], ["time", "existence", "decade", "decades"], ["existed", "gone", "lived", "been"]], "culprit": [1, 2]} +{"id": 24481, "claim": "David Spade was fired from being in Joe Dirt 2 : Beautiful Loser .", "questions": ["Who was fired from being in Joe Dirt 2? or was fired from being in Joe Dirt 2 : Beautiful Loser .", "What was David Spade's first role in? or David Spade was fired from being in 2 : Beautiful Loser .", "How many episodes of Joe Dirt did Spade have? or David Spade was fired from being in Joe Dirt : Beautiful Loser .", "What was the title of Joe Dirt 2? or David Spade was fired from being in Joe Dirt 2 : .", "How did David Spade react to being in Joe Dirt 2? or David Spade was from being in Joe Dirt 2 : Beautiful Loser ."], "answers": [["David Spade", 0, 11], ["Joe Dirt", 36, 44], ["2", 45, 46], ["Beautiful Loser", 49, 64], ["fired", 16, 21]], "evidential": [["David Spade", "David Spades", "David Spade", "David Spader"], ["Joe Dirt", "the comedy Joe Dirt", "the film Joe Dirt", "the movie Joe Dirt"], ["2", "two episodes", "two", "2 :"], ["Beautiful Loser", "Beautiful Ler", "BeautifulLoser", "Beautiful Losers"], ["distracted", "banned", "traumatized", "disheartened"]], "culprit": [4]} +{"id": 67876, "claim": "Britt Robertson was not in the television series Girlboss .", "questions": ["Who was not in the television series Girlboss? or was not in the television series Girlboss .", "What television series did Britt Robertson not appear in? or Britt Robertson was not in the television series .", "What was Britt Robertson not in? or Britt Robertson was not in Girlboss ."], "answers": [["Britt Robertson", 0, 15], ["Girlboss", 49, 57], ["the television series", 27, 48]], "evidential": [["Britt Robertson", "Brittany Robertson", "Britt Roberts", "Brit Robertson"], ["Girlboss", "The Secret Circle", "Girlsboss", "a Netflix comedy"], ["the comedy television series", "the show", "the comedy TV series", "the TV series"]], "culprit": [1, 2]} +{"id": 76324, "claim": "Richard Dawson is still alive .", "questions": ["Who is still alive? or is still alive .", "How old is Richard Dawson? or Richard Dawson is alive .", "What is Richard Dawson's age? or Richard Dawson is still ."], "answers": [["Richard Dawson", 0, 14], ["still", 18, 23], ["alive", 24, 29]], "evidential": [["Richard Dawson", "Richard Dwayne Dawson", "Richard Dawsons", "Richard D Dawson"], ["still", "alive", "barely", "currently"], ["dead", "deceased", "alive", "63"]], "culprit": [1, 2]} +{"id": 104710, "claim": "Miranda Otto is the son of Barry Otto .", "questions": ["Who is the son of Barry Otto? or is the son of Barry Otto .", "What is Miranda Otto's biological name? or Miranda Otto is of Barry Otto .", "Who is Miranda Otto's father? or Miranda Otto is the son of ."], "answers": [["Miranda Otto", 0, 12], ["the son", 16, 23], ["Barry Otto", 27, 37]], "evidential": [["Miranda Otto", "Miriam Otto", "Miranda Oster", "Miranda Oste"], ["the daughter", "the sister", "the biological daughter", "the granddaughter"], ["an actor", "Barry Otto", "an actress", "a man"]], "culprit": [1]} +{"id": 92988, "claim": "See You on the Other Side is a boat .", "questions": ["What side of the boat is See You on? or See You on the Other is a boat .", "What is See You on the Other Side? or See You on the Other Side is .", "What is the name of the boat that is on the other side? or You on the Other Side is a boat ."], "answers": [["Side", 21, 25], ["a boat", 29, 35], ["See", 0, 3]], "evidential": [["Side", "side", "side 2", "Side 2"], ["an album", "a recorded work", "a record", "a work"], ["See", "The album", "See '", "see"]], "culprit": [1]} +{"id": 150834, "claim": "Tool has not produced albums .", "questions": ["Which tool has not produced an album? or has not produced albums .", "Tool has not produced what? or Tool has not produced .", "Tool has not what type of albums? or Tool has not albums ."], "answers": [["Tool", 0, 4], ["albums", 22, 28], ["produced", 13, 21]], "evidential": [["Tool", "Tool ( band )", "Tool( band )", "Tool(band )"], ["albums", "an album", "music", "records"], ["produced", "released", "created", "published"]], "culprit": [1]} +{"id": 135684, "claim": "Elizabeth I was the son of Anne Boleyn .", "questions": ["Who was the son of Anne Boleyn? or I was the son of Anne Boleyn .", "What was Elizabeth I's father's name? or Elizabeth I was of Anne Boleyn .", "Who was Elizabeth I's mother? or Elizabeth I was the son of ."], "answers": [["Elizabeth", 0, 9], ["the son", 16, 23], ["Anne Boleyn", 27, 38]], "evidential": [["Elizabeth", "Elizabeth I", "Queen Elizabeth", "ElizabethI"], ["the daughter", "the child", "the son", "a daughter"], ["Anne Boleyn", "Ann Boleyn", "Anne Bolyn", "a woman"]], "culprit": [1]} +{"id": 124045, "claim": "Ron Weasley was denied membership to Gryffindor house .", "questions": ["Who was denied membership to Gryffindor house? or was denied membership to Gryffindor house .", "What was Ron Weasley denied? or Ron Weasley was denied to Gryffindor house .", "What house was Ron Weasley denied membership to? or Ron Weasley was denied membership to house .", "What was Ron Weasley denied membership to? or Ron Weasley was denied membership to Gryffindor .", "What was Ron Weasley's status as a member of Gryffindor or Ron Weasley was membership to Gryffindor house ."], "answers": [["Ron Weasley", 0, 11], ["membership", 23, 33], ["Gryffindor", 37, 47], ["house", 48, 53], ["denied", 16, 22]], "evidential": [["Ron Weasley", "The Ron Weasley", "A Ron Weasley", "Ronald Weasley"], ["access", "a visit", "membership", "a membership"], ["the Gryffindor", "a Gryffindor", "The Gryffindor", "the Gryfindor"], ["house", "family", "houses", "home"], ["given", "granted", "denied", "required"]], "culprit": [4]} +{"id": 56381, "claim": "Lorelai Gilmore 's uncle was played by Edward Herrmann .", "questions": ["Who was the uncle of Edward Herrmann? or 's uncle was played by Edward Herrmann .", "Who played Lorelai Gilmore's uncle? or Lorelai Gilmore 's uncle was played by .", "What role did Edward Herrmann play in Lorelai Gilmore's uncle? or Lorelai Gilmore 's uncle was by Edward Herrmann ."], "answers": [["Lorelai Gilmore", 0, 15], ["Edward Herrmann", 39, 54], ["played", 29, 35]], "evidential": [["Lorelai Gilmore", "Lorelai Gilmore", "Lorelai Gilpin", "Lorelai Glyn"], ["Edward Herrmann", "an actor", "Edward Herrman", "a man"], ["played", "portrayed", "performed", "voiced"]], "culprit": [1]} +{"id": 78742, "claim": "Tim Roth is not an English actor .", "questions": ["Who is an English actor? or is not an English actor .", "What is Tim Roth's nationality? or Tim Roth is not actor .", "What is Tim Roth's profession? or Tim Roth is not an English ."], "answers": [["Tim Roth", 0, 8], ["an English", 16, 26], ["actor", 27, 32]], "evidential": [["Tim Roth", "Timothy Roth", "Tim Roth", "Tim R Roth"], ["an English", "a European", "a British", "an European"], ["actor", "director", "film actor", "film director"]], "culprit": [1, 2]} +{"id": 180717, "claim": "Victoria ( Dance Exponents song ) was released in New Zealand in 1980 .", "questions": ["What song was released in New Zealand in 1980? or was released in New Zealand in 1980 .", "Where was Victoria released? or Victoria ( Dance Exponents song ) was released in in 1980 .", "When was Victoria released in New Zealand? or Victoria ( Dance Exponents song ) was released in New Zealand in .", "What was the name of Victoria's song? or Victoria ( Dance Exponents song ) was in New Zealand in 1980 ."], "answers": [["Victoria ( Dance Exponents song )", 0, 33], ["New Zealand", 50, 61], ["1980", 65, 69], ["released", 38, 46]], "evidential": [["Victoria / Dance Exponents song", "Victoria", "Victoria Song", "Victoria song"], ["New Zealand", "China", "Australia", "the world"], ["1982", "the 1980s", "the eighties", "the nineties"], ["released", "performed", "played", "recorded"]], "culprit": [2]} +{"id": 125491, "claim": "Hot Right Now is from the album Escape from Planet Monday .", "questions": ["What is the name of the song from the album Escape from Planet Monday? or is from the album Escape from Planet Monday .", "What is the name of the album that Hot Right Now is from? or Hot Right Now is from from Planet Monday .", "What album is Hot Right Now from? or Hot Right Now is from the album ."], "answers": [["Hot Right Now", 0, 13], ["the album Escape", 22, 38], ["Escape from Planet Monday", 32, 57]], "evidential": [["Hot Right Now", "Hot right Now", "Hit Right Now", "Hot Right now"], ["Escape", "the album Escape", "an album", "the single Escape"], ["Escape from Planet Monday", "Nextlevelism", "Escape From Planet Monday", "Next Levelism"]], "culprit": [1, 2]} +{"id": 100204, "claim": "Shadowhunters did not premiere in 2016 .", "questions": ["What movie did not premiere in 2016? or did not premiere in 2016 .", "When did Shadowhunters not premiere? or Shadowhunters did not premiere in .", "What did Shadowhunters not do in 2016? or Shadowhunters did not in 2016 ."], "answers": [["Shadowhunters", 0, 13], ["2016", 34, 38], ["premiere", 22, 30]], "evidential": [["Shadowhunters", "The Shadowhunters", "Shadowshunters", "Shadowhunterters"], ["2016", "2015", "January 2016", "the 2010s"], ["premiere", "air", "start", "launch"]], "culprit": [1]} +{"id": 73208, "claim": "Reign Over Me was written and directed by Spike Lee .", "questions": ["What movie was directed by Spike Lee? or was written and directed by Spike Lee .", "Who directed Reign Over Me? or Reign Over Me was written and directed by .", "What was the name of the film that directed it? or Reign Over Me was and directed by Spike Lee .", "What was the film by Spike Lee? or Reign Over Me was written and by Spike Lee ."], "answers": [["Reign Over Me", 0, 13], ["Spike Lee", 42, 51], ["written", 18, 25], ["directed", 30, 38]], "evidential": [["Reign Over Me", "Reign over Me", "Reign of Me", "Reign Over me"], ["a man", "an American", "a person", "an actor"], ["written", "penned", "authored", "wrote"], ["directed", "produced", "written", "created"]], "culprit": [1]} +{"id": 225871, "claim": "Revolver has only ever topped a single chart .", "questions": ["What has only ever topped a single chart? or has only ever topped a single chart .", "How many charts has Revolver ever topped? or Revolver has only ever topped .", "How many times has Revolver ever a single chart? or Revolver has only ever a single chart .", "How many times has Revolver topped a single chart? or Revolver has ever topped a single chart ."], "answers": [["Revolver", 0, 8], ["a single chart", 30, 44], ["topped", 23, 29], ["only", 13, 17]], "evidential": [["Revolver", "Revololver", "Revolver Record", "The Revolver"], ["four charts", "two charts", "three charts", "zero charts"], ["topped", "charted", "reached", "appeared"], ["never", "n't", "only", "rarely"]], "culprit": [1, 3]} +{"id": 125225, "claim": "Omar Khadr has always been free .", "questions": ["Who has always been free? or has always been free .", "How long has Omar Khadr been free? or Omar Khadr has been free .", "Omar Khadr has always been what? or Omar Khadr has always been ."], "answers": [["Omar Khadr", 0, 10], ["always", 15, 21], ["free", 27, 31]], "evidential": [["Omar Khadr", "Omar Khadri", "Omar Khadr", "Om Khadr"], ["yet", "never", "always", "since"], ["a prisoner", "a person", "a human", "a detainee"]], "culprit": [1, 2]} +{"id": 174514, "claim": "Red Bull Racing was eradicated in the United Kingdom .", "questions": ["What was the name of the race that was eradicated in the UK? or was eradicated in the United Kingdom .", "Where was Red Bull Racing eradicated? or Red Bull Racing was eradicated in .", "What happened to Red Bull Racing in the UK? or Red Bull Racing was in the United Kingdom ."], "answers": [["Red Bull Racing", 0, 15], ["the United Kingdom", 34, 52], ["eradicated", 20, 30]], "evidential": [["Red Bull Racing", "Red Bull R&B Racing", "Red Bull Racing", "Red Bull racing"], ["Austria", "Europe", "the UK", "England"], ["acquired", "founded", "established", "created"]], "culprit": [2]} +{"id": 67464, "claim": "Louie ( season 1 ) was created by David Benioff .", "questions": ["What was the name of the show created by David Benioff? or was created by David Benioff .", "Who created Louie? or Louie ( season 1 ) was created by .", "What was the name of Louie? or Louie ( season 1 ) was by David Benioff ."], "answers": [["Louie ( season 1 )", 0, 18], ["David Benioff", 34, 47], ["created", 23, 30]], "evidential": [["Louie", "Louie ( season 1 )", "Louis C.K.", "The show Louie"], ["Louis C.K", "a person", "a series creator", "a man"], ["created", "written", "penned", "produced"]], "culprit": [1, 2]} +{"id": 84710, "claim": "Buffy the Vampire Slayer is created by Joss Whedon in 1990 .", "questions": ["What movie was created by Joss Whedon? or is created by Joss Whedon in 1990 .", "Who created Buffy the Vampire Slayer? or Buffy the Vampire Slayer is created by in 1990 .", "When was Buffy the Vampire Slayer created? or Buffy the Vampire Slayer is created by Joss Whedon in .", "What was the name of the film that made Buffy the Vampire Slayer? or Buffy the Vampire Slayer is by Joss Whedon in 1990 ."], "answers": [["Buffy the Vampire Slayer", 0, 24], ["Joss Whedon", 39, 50], ["1990", 54, 58], ["created", 28, 35]], "evidential": [["Buffy the Vampire Slayer", "The Buffy the Vampire Slayer", "Buffy The Vampire Slayer", "Buffy of the Vampire Slayer"], ["Joss Whedon", "a person", "a man", "an American"], ["the 1990s", "the 2000s", "the nineties", "1992"], ["created", "produced", "directed", "a film"]], "culprit": [2]} +{"id": 198041, "claim": "The New York City Landmarks Preservation Commission includes zero architects .", "questions": ["What organization has zero architects? or includes zero architects .", "How many architects does the New York City Landmarks Preservation Commission have? or The New York City Landmarks Preservation Commission includes .", "How many architects does the New York City Landmarks Preservation Commission have? or The New York City Landmarks Preservation Commission zero architects ."], "answers": [["The New York City Landmarks Preservation Commission", 0, 51], ["zero architects", 61, 76], ["includes", 52, 60]], "evidential": [["The New York City Landmarks Preservation Commission", "New York City Landmarks Preservation Commission", "The New York City Landmarks Preservation commission", "A New York City Landmarks Preservation Commission"], ["11 architects", "three architects", "11 commissioners", "ten architects"], ["includes", "contains", "consists", "involves"]], "culprit": [1]} +{"id": 42390, "claim": "Jack Falahee is Mongolian .", "questions": ["Who is the Mongolian whose name is? or is Mongolian .", "What nationality is Jack Falahee? or Jack Falahee is ."], "answers": [["Jack Falahee", 0, 12], ["Mongolian", 16, 25]], "evidential": [["Jack Falahee", "Jack Falahe", "John Falahee", "Jack Falaefhee"], ["American", "an American", "North American", "European"]], "culprit": [1]} +{"id": 175736, "claim": "The Cry of the Owl is based on Patricia Highsmith 's eighth movie .", "questions": ["What is the name of the movie based on Patricia Highsmith's eighth film? or is based on Patricia Highsmith 's eighth movie .", "Who wrote the movie The Cry of the Owl? or The Cry of the Owl is based on 's eighth movie .", "What is the story of The Cry Of The Owl? or The Cry of the Owl is on Patricia Highsmith 's eighth movie .", "What was the first movie based on? or The Cry of the Owl is based on Patricia Highsmith 's movie ."], "answers": [["The Cry of the Owl", 0, 18], ["Patricia Highsmith", 31, 49], ["based", 22, 27], ["eighth", 53, 59]], "evidential": [["The Cry of the Owl", "The Cry of the Owl ( 2009 film )", "The Cry of the Owl( 2009 film )", "The Cry of the Owl(2009 film )"], ["Patricia Highsmith", "an author", "a writer", "a novelist"], ["based", "a story", "a novel", "loosely"], ["first", "book", "novel", "a novel"]], "culprit": [3]} +{"id": 152929, "claim": "Firefox is an operating system shell .", "questions": ["What is the name of the operating system shell? or is an operating system shell .", "What is Firefox? or Firefox is ."], "answers": [["Firefox", 0, 7], ["an operating system shell", 11, 36]], "evidential": [["Firefox", "Mozilla", "Mozilla Firefox", "The Firefox"], ["a web browser", "a free web browser", "open source", "a free software application"]], "culprit": [1]} +{"id": 183589, "claim": "Finding Dory was directed by Ingmar Bergman .", "questions": ["What movie was directed by Ingmar Bergman? or was directed by Ingmar Bergman .", "Who directed Finding Dory? or Finding Dory was directed by .", "What was the name of the film that starred in Finding Dory? or Finding Dory was by Ingmar Bergman ."], "answers": [["Finding Dory", 0, 12], ["Ingmar Bergman", 29, 43], ["directed", 17, 25]], "evidential": [["Finding Dory", "The Finding Dory", "Finding dory", "Finding Dory movie"], ["Andrew Stanton", "Angus MacLane", "a person", "Angus Maclane"], ["directed", "written", "produced", "penned"]], "culprit": [1]} +{"id": 108957, "claim": "Agent Raghav \u2013 Crime Branch is a phone .", "questions": ["What is the name of the agent that is on the phone? or is a phone .", "What is the name of the agent in the Crime Branch? or Agent Raghav \u2013 Crime Branch is ."], "answers": [["Agent Raghav \u2013 Crime Branch", 0, 27], ["a phone", 31, 38]], "evidential": [["Agent Raghav \u2013 Crime Branch", "Agent Raghav - Crime Branch", "Agent Raghav", "Agent Raghav \u2014 Crime Branch"], ["an anthology television series", "a serial", "a television serial", "a television series"]], "culprit": [1]} +{"id": 3160, "claim": "University of Chicago Law School is ranked first in the 2016 QS World University Rankings .", "questions": ["What is the name of the law school that is ranked first in the 2016 QS World or is ranked first in the 2016 QS World University Rankings .", "What is the name of the organization that ranks law schools in the world? or University of Chicago Law School is ranked first in the 2016 .", "What is the ranking of University of Chicago Law School in the 2016 QS World University Rankings or University of Chicago Law School is first in the 2016 QS World University Rankings .", "What is the ranking of University of Chicago Law School in the 2016 QS World University Rankings or University of Chicago Law School is ranked in the 2016 QS World University Rankings .", "In what year did the University of Chicago Law School rank first in the QS World University Ranking or University of Chicago Law School is ranked first in QS World University Rankings ."], "answers": [["University of Chicago Law School", 0, 32], ["QS World University Rankings", 61, 89], ["ranked", 36, 42], ["first", 43, 48], ["the 2016", 52, 60]], "evidential": [["University of Chicago Law School", "The University of Chicago Law School", "the University of Chicago Law School", "University of Chicago law School"], ["QS World University Rankings", "the QS World University Rankings", "S&S World University Rankings", "QS World University Rankings."], ["ranked", "listed", "placed", "Ranked"], ["12th", "11th", "twelveth", "ninth"], ["the 2016", "the 2015", "2016", "The 2016"]], "culprit": [3]} +{"id": 148309, "claim": "The Adventures of Pluto Nash failed to be a released film .", "questions": ["What failed to be a released film? or failed to be a released film .", "What did The Adventures of Pluto Nash fail to be? or The Adventures of Pluto Nash failed to be .", "What was the result of The Adventures of Pluto Nash? or The Adventures of Pluto Nash to be a released film ."], "answers": [["The Adventures of Pluto Nash", 0, 28], ["a released film", 42, 57], ["failed", 29, 35]], "evidential": [["The Adventures of Pluto Nash", "The adventures of Pluto Nash", "The Adventures of Pluto N", "An Adventures of Pluto Nash"], ["a release", "released", "an release", "release"], ["happened", "ceased", "turned", "failed"]], "culprit": [2]} +{"id": 227135, "claim": "The New Orleans Pelicans only play in the NHL .", "questions": ["Who plays in the NHL? or only play in the NHL .", "What league do the New Orleans Pelicans play in? or The New Orleans Pelicans only play in .", "What do the New Orleans Pelicans only do in the NHL? or The New Orleans Pelicans only in the NHL .", "How many of the New Orleans Pelicans play in the NHL? or The New Orleans Pelicans play in the NHL ."], "answers": [["The New Orleans Pelicans", 0, 24], ["the NHL", 38, 45], ["play", 30, 34], ["only", 25, 29]], "evidential": [["The New Orleans Pelicans", "New Orleans Pelicans", "the New Orleans Pelicans", "The New Orleans Saints"], ["the National Basketball Association", "the NBA", "a league", "the Western Conference"], ["play", "compete", "participate", "plays"], ["only", "two", "one", "currently"]], "culprit": [1, 3]} +{"id": 126678, "claim": "The Colosseum is a wrestler from Italy .", "questions": ["What is the name of the wrestler from Italy? or is a wrestler from Italy .", "Who is the Colosseum? or The Colosseum is from Italy .", "Where is The Colosseum? or The Colosseum is a wrestler from ."], "answers": [["The Colosseum", 0, 13], ["a wrestler", 17, 27], ["Italy", 33, 38]], "evidential": [["The Colosseum", "Colosseum", "The colosseum", "A Colosseum"], ["a tourist attraction", "an attraction", "an amphitheater", "a popular tourist attraction"], ["Rome", "Italy", "a city", "the city"]], "culprit": [1]} diff --git a/src/eval_client/fever_scorer.py b/src/eval_client/fever_scorer.py new file mode 100644 index 0000000000000000000000000000000000000000..9c4b498e7c6b6698a0a6a4dfd4fcab735fe7811d --- /dev/null +++ b/src/eval_client/fever_scorer.py @@ -0,0 +1,84 @@ +# -*- coding:utf-8 -*- + +""" +@Author : Bao +@Date : 2020/8/24 +@Desc : +@Last modified by : Bao +@Last modified date : 2020/9/1 +""" + +import os +import json +import numpy as np +from collections import defaultdict +import tensorflow as tf +from sklearn.metrics import precision_recall_fscore_support +try: + from .scorer import fever_score +except: + from scorer import fever_score + + +prefix = os.environ['PJ_HOME'] + + +class FeverScorer: + def __init__(self): + self.id2label = {2: 'SUPPORTS', 0: 'REFUTES', 1: 'NOT ENOUGH INFO'} + self.label2id = {value: key for key, value in self.id2label.items()} + + def get_scores(self, predicted_file, actual_file=f'{prefix}/data/fever/shared_task_dev.jsonl'): + id2results = defaultdict(dict) + + with tf.io.gfile.GFile(predicted_file) as f: + for line in f: + js = json.loads(line) + guid = js['id'] + id2results[guid] = js + + with tf.io.gfile.GFile(actual_file) as fin: + for line in fin: + line = json.loads(line) + guid = line['id'] + evidence = line['evidence'] + label = line['label'] + id2results[guid]['evidence'] = evidence + id2results[guid]['label'] = label + + results = self.label_score(list(id2results.values())) + score, accuracy, precision, recall, f1 = fever_score(list(id2results.values())) + results.update({ + 'Evidence Precision': precision, + 'Evidence Recall': recall, + 'Evidence F1': f1, + 'FEVER Score': score, + 'Label Accuracy': accuracy + }) + + return results + + def label_score(self, results): + truth = np.array([v['label'] for v in results]) + prediction = np.array([v['predicted_label'] for v in results]) + labels = list(self.label2id.keys()) + results = {} + p, r, f, _ = precision_recall_fscore_support(truth, prediction, labels=labels) + for i, label in enumerate(self.label2id.keys()): + results['{} Precision'.format(label)] = p[i] + results['{} Recall'.format(label)] = r[i] + results['{} F1'.format(label)] = f[i] + + return results + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--predicted_file", '-i', type=str) + args = parser.parse_args() + + scorer = FeverScorer() + results = scorer.get_scores(args.predicted_file) + print(json.dumps(results, ensure_ascii=False, indent=4)) diff --git a/src/eval_client/scorer.py b/src/eval_client/scorer.py new file mode 100644 index 0000000000000000000000000000000000000000..ecae63dcd900f84c498ad0f2e5d7dd06e305e329 --- /dev/null +++ b/src/eval_client/scorer.py @@ -0,0 +1,153 @@ +import six + +def check_predicted_evidence_format(instance): + if 'predicted_evidence' in instance.keys() and len(instance['predicted_evidence']): + assert all(isinstance(prediction, list) + for prediction in instance["predicted_evidence"]), \ + "Predicted evidence must be a list of (page,line) lists" + + assert all(len(prediction) == 2 + for prediction in instance["predicted_evidence"]), \ + "Predicted evidence must be a list of (page,line) lists" + + assert all(isinstance(prediction[0], six.string_types) + for prediction in instance["predicted_evidence"]), \ + "Predicted evidence must be a list of (page,line) lists" + + assert all(isinstance(prediction[1], int) + for prediction in instance["predicted_evidence"]), \ + "Predicted evidence must be a list of (page,line) lists" + + +def is_correct_label(instance): + return instance["label"].upper() == instance["predicted_label"].upper() + + +def is_strictly_correct(instance, max_evidence=None): + #Strict evidence matching is only for NEI class + check_predicted_evidence_format(instance) + + if instance["label"].upper() != "NOT ENOUGH INFO" and is_correct_label(instance): + assert 'predicted_evidence' in instance, "Predicted evidence must be provided for strict scoring" + + if max_evidence is None: + max_evidence = len(instance["predicted_evidence"]) + + + for evience_group in instance["evidence"]: + #Filter out the annotation ids. We just want the evidence page and line number + actual_sentences = [[e[2], e[3]] for e in evience_group] + #Only return true if an entire group of actual sentences is in the predicted sentences + if all([actual_sent in instance["predicted_evidence"][:max_evidence] for actual_sent in actual_sentences]): + return True + + #If the class is NEI, we don't score the evidence retrieval component + elif instance["label"].upper() == "NOT ENOUGH INFO" and is_correct_label(instance): + return True + + return False + + +def evidence_macro_precision(instance, max_evidence=None): + this_precision = 0.0 + this_precision_hits = 0.0 + + if instance["label"].upper() != "NOT ENOUGH INFO": + all_evi = [[e[2], e[3]] for eg in instance["evidence"] for e in eg if e[3] is not None] + + predicted_evidence = instance["predicted_evidence"] if max_evidence is None else \ + instance["predicted_evidence"][:max_evidence] + + for prediction in predicted_evidence: + if prediction in all_evi: + this_precision += 1.0 + this_precision_hits += 1.0 + + return (this_precision / this_precision_hits) if this_precision_hits > 0 else 1.0, 1.0 + + return 0.0, 0.0 + +def evidence_macro_recall(instance, max_evidence=None): + # We only want to score F1/Precision/Recall of recalled evidence for NEI claims + if instance["label"].upper() != "NOT ENOUGH INFO": + # If there's no evidence to predict, return 1 + if len(instance["evidence"]) == 0 or all([len(eg) == 0 for eg in instance]): + return 1.0, 1.0 + + predicted_evidence = instance["predicted_evidence"] if max_evidence is None else \ + instance["predicted_evidence"][:max_evidence] + + for evidence_group in instance["evidence"]: + evidence = [[e[2], e[3]] for e in evidence_group] + if all([item in predicted_evidence for item in evidence]): + # We only want to score complete groups of evidence. Incomplete groups are worthless. + return 1.0, 1.0 + return 0.0, 1.0 + return 0.0, 0.0 + + +# Micro is not used. This code is just included to demostrate our model of macro/micro +def evidence_micro_precision(instance): + this_precision = 0 + this_precision_hits = 0 + + # We only want to score Macro F1/Precision/Recall of recalled evidence for NEI claims + if instance["label"].upper() != "NOT ENOUGH INFO": + all_evi = [[e[2], e[3]] for eg in instance["evidence"] for e in eg if e[3] is not None] + + for prediction in instance["predicted_evidence"]: + if prediction in all_evi: + this_precision += 1.0 + this_precision_hits += 1.0 + + return this_precision, this_precision_hits + + +def fever_score(predictions,actual=None, max_evidence=5): + correct = 0 + strict = 0 + + macro_precision = 0 + macro_precision_hits = 0 + + macro_recall = 0 + macro_recall_hits = 0 + + for idx,instance in enumerate(predictions): + assert 'predicted_evidence' in instance.keys(), 'evidence must be provided for the prediction' + + #If it's a blind test set, we need to copy in the values from the actual data + if 'evidence' not in instance or 'label' not in instance: + assert actual is not None, 'in blind evaluation mode, actual data must be provided' + assert len(actual) == len(predictions), 'actual data and predicted data length must match' + assert 'evidence' in actual[idx].keys(), 'evidence must be provided for the actual evidence' + instance['evidence'] = actual[idx]['evidence'] + instance['label'] = actual[idx]['label'] + + assert 'evidence' in instance.keys(), 'gold evidence must be provided' + + if is_correct_label(instance): + correct += 1.0 + + if is_strictly_correct(instance, max_evidence): + strict+=1.0 + + macro_prec = evidence_macro_precision(instance, max_evidence) + macro_precision += macro_prec[0] + macro_precision_hits += macro_prec[1] + + macro_rec = evidence_macro_recall(instance, max_evidence) + macro_recall += macro_rec[0] + macro_recall_hits += macro_rec[1] + + total = len(predictions) + + strict_score = strict / total + acc_score = correct / total + + pr = (macro_precision / macro_precision_hits) if macro_precision_hits > 0 else 1.0 + rec = (macro_recall / macro_recall_hits) if macro_recall_hits > 0 else 0.0 + + f1 = 2.0 * pr * rec / (pr + rec) + + return strict_score, acc_score, pr, rec, f1 \ No newline at end of file diff --git a/src/loren.py b/src/loren.py new file mode 100644 index 0000000000000000000000000000000000000000..163efff1046329b8461a2169aa41fa2e56ef151b --- /dev/null +++ b/src/loren.py @@ -0,0 +1,167 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2020/9/17 15:55 +@Contact : jjchen19@fudan.edu.cn +@Description: +''' + +import os +import sys +import json +import logging +import cjjpy as cjj + +try: + from .qg_client.question_generator import QuestionGenerator + from .mrc_client.answer_generator import AnswerGenerator, chunks, assemble_answers_to_one + from .parsing_client.sentence_parser import SentenceParser, deal_bracket + from .check_client.fact_checker import FactChecker, id2label + from .er_client import EvidenceRetrieval +except: + sys.path.append(cjj.AbsParentDir(__file__, '.')) + from qg_client.question_generator import QuestionGenerator + from mrc_client.answer_generator import AnswerGenerator, chunks, assemble_answers_to_one + from parsing_client.sentence_parser import SentenceParser, deal_bracket + from check_client.fact_checker import FactChecker, id2label + from er_client import EvidenceRetrieval + + +def load_config(config): + if isinstance(config, str): + with open(config) as f: + config = json.load(f) + cfg = cjj.AttrDict(config) + return cfg + + +class Loren: + def __init__(self, config_file, verbose=True): + self.verbose = verbose + self.args = load_config(config_file) + self.sent_client = SentenceParser() + self.qg_client = QuestionGenerator('t5', verbose=False) + self.ag_client = AnswerGenerator(self.args.mrc_dir) + self.fc_client = FactChecker(self.args, self.args.fc_dir) + self.er_client = EvidenceRetrieval(self.args.er_dir) + self.logger = cjj.init_logger(f'{os.environ["PJ_HOME"]}/results/loren_dev.log', + log_file_level=logging.INFO if self.verbose else logging.WARNING) + self.logger.info('*** Loren initialized. ***') + + def check(self, claim, evidence=None): + self.logger.info('*** Verifying "%s"... ***' % claim) + js = self.prep(claim, evidence) + js['id'] = 0 + y_predicted, z_predicted, m_attn = self.fc_client.check_from_batch([js], verbose=self.verbose) + label = id2label[y_predicted[0]] + + # Update js + js['local_premises'] = assemble_answers_to_one(js, k=3) + js['evidence'] = [self.fc_client.tokenizer.clean_up_tokenization(e[2]) for e in js['evidence']] + js['questions'] = [self.fc_client.tokenizer.clean_up_tokenization(q) for q in js['questions']] + js['claim_phrases'] = [self.fc_client.tokenizer.clean_up_tokenization(a[0]) for a in js['answers']] + js['local_premises'] = [[self.fc_client.tokenizer.clean_up_tokenization(a) for a in aa] + for aa in js['local_premises']] + # js['m_attn'] = m_attn[0][:len(js['claim_phrases'])] + js['phrase_veracity'] = z_predicted[0][:len(js['claim_phrases'])] + js['claim_veracity'] = label + + self.logger.info(" * Intermediary: %s *" % str(js)) + self.logger.info('*** Verification completed: "%s" ***' % label) + return js + + def prep(self, claim, evidence=None): + ''' + :param evidence: 'aaa||bbb||ccc' / [entity, num, evidence, (prob)] if not None + ''' + evidence = self._prep_evidence(claim, evidence) + self.logger.info(' * Evidence prepared. *') + assert isinstance(evidence, list) + + js = {'claim': claim, 'evidence': evidence} + js = self._prep_claim_phrases(js) + self.logger.info(' * Claim phrases prepared. *') + js = self._prep_questions(js) + self.logger.info(' * Probing questions prepared. *') + js = self._prep_evidential_phrases(js) + self.logger.info(' * Evidential phrases prepared. *') + return js + + def _prep_claim_phrases(self, js): + results = self.sent_client.identify_NPs(deal_bracket(js['claim'], True), + candidate_NPs=[x[0] for x in js['evidence']]) + NPs = results['NPs'] + claim = results['text'] + verbs = results['verbs'] + adjs = results['adjs'] + _cache = {'claim': claim, + 'evidence': js['evidence'], + 'answers': NPs + verbs + adjs, + 'answer_roles': ['noun'] * len(NPs) + ['verb'] * len(verbs) + ['adj'] * len(adjs)} + if len(_cache['answers']) == 0: + _cache['answers'] = js['claim'].split()[0] + _cache['answer_roles'] = ['noun'] + return _cache + + def _prep_questions(self, js): + _cache = [] + for answer in js['answers']: + _cache.append((js['claim'], [answer])) + qa_pairs = self.qg_client.generate([(x, y) for x, y in _cache]) + for q, clz_q, a in qa_pairs: + if 'questions' in js: + js['regular_qs'].append(q) + js['cloze_qs'].append(clz_q) + js['questions'].append(self.qg_client.assemble_question(q, clz_q)) + else: + js['regular_qs'] = [q] + js['cloze_qs'] = [clz_q] + js['questions'] = [self.qg_client.assemble_question(q, clz_q)] + return js + + def _prep_evidential_phrases(self, js): + examples = [] + for q in js['questions']: + ex = self.ag_client.assemble(q, " ".join([x[2] for x in js['evidence']])) + examples.append(ex) + predicted = self.ag_client.generate(examples, num_beams=self.args['cand_k'], + num_return_sequences=self.args['cand_k'], + batch_size=2, verbose=False) + for answers in predicted: + if 'evidential' in js: + js['evidential'].append(answers) + else: + js['evidential'] = [answers] + return js + + def _prep_evidence(self, claim, evidence=None): + ''' + :param evidence: 'aaa||bbb||ccc' / [entity, num, evidence, (prob)] + :return: [entity, num, evidence, (prob)] + ''' + if evidence in [None, '', 'null', 'NULL', 'Null']: + evidence = self.er_client.retrieve(claim) + evidence = [(ev[0], ev[1], deal_bracket(ev[2], True, ev[0])) for ev in evidence] + else: + if isinstance(evidence, str): + # TODO: magic sentence number + evidence = [("None", i, ev.strip()) for i, ev in enumerate(evidence.split('||')[:5])] + return evidence + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--config', '-c', type=str, required=True, + default='available_models/aaai22_roberta.json', + help='Config json file with hyper-parameters') + args = parser.parse_args() + + loren = Loren(args.config) + while True: + claim = input('> ') + label, js = loren.check(claim) + print(label) + print(js) diff --git a/src/mrc_client/answer_generator.py b/src/mrc_client/answer_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..5df0be03d1298fd5ec430fef656f361d2c0fd31f --- /dev/null +++ b/src/mrc_client/answer_generator.py @@ -0,0 +1,144 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2020/8/12 14:44 +@Contact : jjchen19@fudan.edu.cn +@Description: +''' + +import re +import time +from pathlib import Path +from typing import Dict, List +import torch +from logging import getLogger +from tqdm import tqdm +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +import ujson as json +import random + +try: + from .seq2seq.seq2seq_utils import ( + use_task_specific_params, + calculate_rouge, + chunks, + Seq2SeqDataset, + lmap, + load_json, + save_json, + ) +except ImportError: + import cjjpy as cjj + import sys + sys.path.append(cjj.AbsParentDir(__file__, '.')) + from seq2seq.seq2seq_utils import ( + use_task_specific_params, + calculate_rouge, + chunks, + Seq2SeqDataset, + lmap, + load_json, + save_json, + ) + +logger = getLogger(__name__) +DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +random.seed(1111) + + +def assemble_answers_to_one(js, k=5, mask_token='', mask_rate=0.): + if isinstance(js, str): + js = json.loads(js) + + should_keep = random.random() > mask_rate + js.pop('evidential_assembled') + for q, answers in zip(js['cloze_qs'], js['evidential']): + if mask_token in q: + s = q.find(mask_token) + e = s + len(mask_token) + nq_list = [] + if should_keep: + for i in range(k): + answer_span = answers[i] + nq = q[:s] + answer_span + q[e:] + nq_list.append(nq) + else: + for i in range(k): + answer_span = mask_token + nq = q[:s] + answer_span + q[e:] + nq_list.append(nq) + ev_nqs = ' '.join(nq_list) + if js.get('evidential_assembled') is None: + js['evidential_assembled'] = [ev_nqs] + else: + js['evidential_assembled'].append(ev_nqs) + assert len(js['evidential_assembled']) == len(js['answers']) + return js + + +class AnswerGenerator(): + def __init__(self, model_name, device=DEFAULT_DEVICE): + self.model_name = str(model_name) + self.device = device + self.model = None + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + + def init_model(self): + if self.model is None: + self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name).to(self.device) + + def assemble(self, question, context): + sep = '\n' if 'unifiedqa' in self.tokenizer.name_or_path else self.tokenizer.sep_token + return f'{question} {sep} {context}' + + def generate(self, examples, out_file=None, batch_size=16, verbose=True, + max_length=20, min_length=1, num_beams=4, num_return_sequences=4, + prefix=None, fp16=False, task='summarization', **generate_kwargs): + ''' + :param examples: [N] + :return: [N x num_return_seq] + ''' + self.init_model() + if fp16: + self.model = self.model.half() + # update config with summarization specific params + use_task_specific_params(self.model, task) + + fout = None if out_file is None else Path(out_file).open("w", encoding="utf-8") + generated = [] + if verbose: + iter = tqdm(list(chunks(examples, batch_size)), desc="MRC") + else: + iter = list(chunks(examples, batch_size)) + if prefix is None: + prefix = prefix or getattr(self.model.config, "prefix", "") or "" + for examples_chunk in iter: + examples_chunk = [prefix + text for text in examples_chunk] + batch = self.tokenizer(examples_chunk, return_tensors="pt", truncation=True, + padding="longest").to(self.device) + summaries = self.model.generate( + input_ids=batch.input_ids, + attention_mask=batch.attention_mask, + max_length=max_length, + min_length=min_length, + num_beams=num_beams, + num_return_sequences=num_return_sequences, + length_penalty=1.2, + repetition_penalty=1.2, + **generate_kwargs, + ) + dec = self.tokenizer.batch_decode(summaries, skip_special_tokens=True, + clean_up_tokenization_spaces=False) + if fout is not None: + for hypothesis in dec: + fout.write(hypothesis.strip() + "\n") + fout.flush() + else: + generated += dec + if fout is not None: + fout.close() + generated = list(map(lambda x: x.strip(), generated)) + generated = list(chunks(generated, num_return_sequences)) + return generated + diff --git a/src/mrc_client/cjjpy.py b/src/mrc_client/cjjpy.py new file mode 100755 index 0000000000000000000000000000000000000000..2cc70b5e553924123810ab198c143bf7ee28e5d6 --- /dev/null +++ b/src/mrc_client/cjjpy.py @@ -0,0 +1,249 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2018/11/15 17:08 +@Contact: jjchen19@fudan.edu.cn +''' + +import re +import datetime +import os +import argparse +import logging +import traceback + +try: + import ujson as json +except: + import json + +HADOOP_BIN = 'PATH=/usr/bin/:$PATH hdfs' +FOR_PUBLIC = True + + +def LengthStats(filename): + len_list = [] + thresholds = [0.8, 0.9, 0.95, 0.99, 0.999] + with open(filename) as f: + for line in f: + len_list.append(len(line.strip().split())) + stats = { + 'Max': max(len_list), + 'Min': min(len_list), + 'Avg': round(sum(len_list) / len(len_list), 4), + } + len_list.sort() + for t in thresholds: + stats[f"Top-{t}"] = len_list[int(len(len_list) * t)] + + for k in stats: + print(f"- {k}: {stats[k]}") + return stats + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def TraceBack(error_msg): + exc = traceback.format_exc() + msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}' + return msg + + +def Now(): + return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def AbsParentDir(file, parent='..', postfix=None): + ppath = os.path.abspath(file) + parent_level = parent.count('.') + while parent_level > 0: + ppath = os.path.dirname(ppath) + parent_level -= 1 + if postfix is not None: + return os.path.join(ppath, postfix) + else: + return ppath + + +def init_logger(log_file=None, log_file_level=logging.NOTSET, from_scratch=False): + from coloredlogs import ColoredFormatter + import tensorflow as tf + + fmt = "[%(asctime)s %(levelname)s] %(message)s" + log_format = ColoredFormatter(fmt=fmt) + # log_format = logging.Formatter() + logger = logging.getLogger() + logger.setLevel(log_file_level) + + console_handler = logging.StreamHandler() + console_handler.setFormatter(log_format) + logger.handlers = [console_handler] + + if log_file and log_file != '': + if from_scratch and tf.io.gfile.exists(log_file): + logger.warning('Removing previous log file: %s' % log_file) + tf.io.gfile.remove(log_file) + path = os.path.dirname(log_file) + os.makedirs(path, exist_ok=True) + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(log_file_level) + file_handler.setFormatter(log_format) + logger.addHandler(file_handler) + + return logger + + +def OverWriteCjjPy(root='.'): + # import difflib + # diff = difflib.HtmlDiff() + cnt = 0 + golden_cjjpy = os.path.join(root, 'cjjpy.py') + # golden_content = open(golden_cjjpy).readlines() + for dir, folder, file in os.walk(root): + for f in file: + if f == 'cjjpy.py': + cjjpy = '%s/%s' % (dir, f) + # content = open(cjjpy).readlines() + # d = diff.make_file(golden_content, content) + cnt += 1 + print('[%d]: %s' % (cnt, cjjpy)) + os.system('cp %s %s' % (golden_cjjpy, cjjpy)) + + +def ChangeFileFormat(filename, new_fmt): + assert type(filename) is str and type(new_fmt) is str + spt = filename.split('.') + if len(spt) == 0: + return filename + else: + return filename.replace('.' + spt[-1], new_fmt) + + +def CountLines(fname): + with open(fname, 'rb') as f: + count = 0 + last_data = '\n' + while True: + data = f.read(0x400000) + if not data: + break + count += data.count(b'\n') + last_data = data + if last_data[-1:] != b'\n': + count += 1 # Remove this if a wc-like count is needed + return count + + +def GetDate(): + return str(datetime.datetime.now())[5:10].replace('-', '') + + +def TimeClock(seconds): + sec = int(seconds) + hour = int(sec / 3600) + min = int((sec - hour * 3600) / 60) + ssec = float(seconds) - hour * 3600 - min * 60 + # return '%dh %dm %.2fs' % (hour, min, ssec) + return '{0:>2d}h{1:>3d}m{2:>6.2f}s'.format(hour, min, ssec) + + +def StripAll(text): + return text.strip().replace('\t', '').replace('\n', '').replace(' ', '') + + +def GetBracket(text, bracket, en_br=False): + # input should be aa(bb)cc, True for bracket, False for text + if bracket: + try: + return re.findall('\((.*?)\)', text.strip())[-1] + except: + return '' + else: + if en_br: + text = re.sub('\(.*?\)', '', text.strip()) + return re.sub('(.*?)', '', text.strip()) + + +def CharLang(uchar, lang): + assert lang.lower() in ['en', 'cn', 'zh'] + if lang.lower() in ['cn', 'zh']: + if uchar >= '\u4e00' and uchar <= '\u9fa5': + return True + else: + return False + elif lang.lower() == 'en': + if (uchar <= 'Z' and uchar >= 'A') or (uchar <= 'z' and uchar >= 'a'): + return True + else: + return False + else: + raise NotImplementedError + + +def WordLang(word, lang): + for i in word.strip(): + if i.isspace(): continue + if not CharLang(i, lang): + return False + return True + + +def SortDict(_dict, reverse=True): + assert type(_dict) is dict + return sorted(_dict.items(), key=lambda d: d[1], reverse=reverse) + + +def lark(content='test'): + print(content) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--diff', nargs=2, + help='show difference between two files, shown in downloads/diff.html') + parser.add_argument('--de_unicode', action='store_true', default=False, + help='remove unicode characters') + parser.add_argument('--link_entity', action='store_true', default=False, + help='') + parser.add_argument('--max_comm_len', action='store_true', default=False, + help='') + parser.add_argument('--search', nargs=2, + help='search key from file, 2 args: file name & key') + parser.add_argument('--email', nargs=2, + help='sending emails, 2 args: subject & content') + parser.add_argument('--overwrite', action='store_true', default=None, + help='overwrite all cjjpy under given *dir* based on *dir*/cjjpy.py') + parser.add_argument('--replace', nargs=3, + help='replace char, 3 args: file name & replaced char & replacer char') + parser.add_argument('--lark', nargs=1) + parser.add_argument('--get_hdfs', nargs=2, + help='easy copy from hdfs to local fs, 2 args: remote_file/dir & local_dir') + parser.add_argument('--put_hdfs', nargs=2, + help='easy put from local fs to hdfs, 2 args: local_file/dir & remote_dir') + parser.add_argument('--length_stats', nargs=1, + help='simple token lengths distribution of a line-by-line file') + + args = parser.parse_args() + + if args.overwrite: + print('* Overwriting cjjpy...') + OverWriteCjjPy() + + if args.lark: + try: + content = args.lark[0] + except: + content = 'running complete' + print(f'* Larking "{content}"...') + lark(content) + + if args.length_stats: + file = args.length_stats[0] + print(f'* Working on {file} lengths statistics...') + LengthStats(file) diff --git a/src/mrc_client/seq2seq/README.md b/src/mrc_client/seq2seq/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2802347e66d60169fbc62be8ca597db6dbc17ffe --- /dev/null +++ b/src/mrc_client/seq2seq/README.md @@ -0,0 +1,590 @@ +## Sequence to Sequence Training and Evaluation + +This directory contains examples for finetuning and evaluating transformers on summarization and translation tasks. +Please tag @patil-suraj with any issues/unexpected behaviors, or send a PR! +For deprecated `bertabs` instructions, see [`bertabs/README.md`](bertabs/README.md). + +### Supported Architectures + +- `BartForConditionalGeneration` (and anything that inherits from it) +- `MarianMTModel` +- `PegasusForConditionalGeneration` +- `MBartForConditionalGeneration` +- `FSMTForConditionalGeneration` +- `T5ForConditionalGeneration` + +## Datasets + +#### XSUM + +```bash +cd examples/seq2seq +wget https://cdn-datasets.huggingface.co/summarization/xsum.tar.gz +tar -xzvf xsum.tar.gz +export XSUM_DIR=${PWD}/xsum +``` +this should make a directory called `xsum/` with files like `test.source`. +To use your own data, copy that files format. Each article to be summarized is on its own line. + +#### CNN/DailyMail + +```bash +cd examples/seq2seq +wget https://cdn-datasets.huggingface.co/summarization/cnn_dm_v2.tgz +tar -xzvf cnn_dm_v2.tgz # empty lines removed +mv cnn_cln cnn_dm +export CNN_DIR=${PWD}/cnn_dm +``` +this should make a directory called `cnn_dm/` with 6 files. + +#### WMT16 English-Romanian Translation Data + +download with this command: +```bash +wget https://cdn-datasets.huggingface.co/translation/wmt_en_ro.tar.gz +tar -xzvf wmt_en_ro.tar.gz +export ENRO_DIR=${PWD}/wmt_en_ro +``` +this should make a directory called `wmt_en_ro/` with 6 files. + +#### WMT English-German + +```bash +wget https://cdn-datasets.huggingface.co/translation/wmt_en_de.tgz +tar -xzvf wmt_en_de.tgz +export DATA_DIR=${PWD}/wmt_en_de +``` + +#### FSMT datasets (wmt) + +Refer to the scripts starting with `eval_` under: +https://github.com/huggingface/transformers/tree/master/scripts/fsmt + +#### Pegasus (multiple datasets) + +Multiple eval datasets are available for download from: +https://github.com/stas00/porting/tree/master/datasets/pegasus + + +#### Your Data + +If you are using your own data, it must be formatted as one directory with 6 files: +``` +train.source +train.target +val.source +val.target +test.source +test.target +``` +The `.source` files are the input, the `.target` files are the desired output. + +### Tips and Tricks + +General Tips: +- since you need to run from `examples/seq2seq`, and likely need to modify code, the easiest workflow is fork transformers, clone your fork, and run `pip install -e .` before you get started. +- try `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr per epoch with bs=8, see the "xsum_shared_task" command below) +- `fp16_opt_level=O1` (the default works best). +- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved. +Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr)`. +- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code. +- This warning can be safely ignored: + > "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']" +- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start). +- Read scripts before you run them! + +Summarization Tips: +- (summ) 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100. +- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter. +- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()` +- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM. +- `wandb` can be used by specifying `--logger_name wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task. +- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries. +(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods). + +**Update 2018-07-18** +Datasets: `LegacySeq2SeqDataset` will be used for all tokenizers without a `prepare_seq2seq_batch` method. Otherwise, `Seq2SeqDataset` will be used. +Future work/help wanted: A new dataset to support multilingual tasks. + + +### Finetuning Scripts +All finetuning bash scripts call finetune.py (or distillation.py) with reasonable command line arguments. They usually require extra command line arguments to work. + +To see all the possible command line options, run: + +```bash +./finetune.py --help +``` + +### Finetuning Training Params + +To override the pretrained model's training params, you can pass them to `./finetune.sh`: + +```bash +./finetune.sh \ + [...] + --encoder_layerdrop 0.1 \ + --decoder_layerdrop 0.1 \ + --dropout 0.1 \ + --attention_dropout 0.1 \ +``` + +### Summarization Finetuning +Run/modify `finetune.sh` + +The following command should work on a 16GB GPU: +```bash +./finetune.sh \ + --data_dir $XSUM_DIR \ + --train_batch_size=1 \ + --eval_batch_size=1 \ + --output_dir=xsum_results \ + --num_train_epochs 6 \ + --model_name_or_path facebook/bart-large +``` + +There is a starter finetuning script for pegasus at `finetune_pegasus_xsum.sh`. + +### Translation Finetuning + +First, follow the wmt_en_ro download instructions. +Then you can finetune mbart_cc25 on english-romanian with the following command. +**Recommendation:** Read and potentially modify the fairly opinionated defaults in `train_mbart_cc25_enro.sh` script before running it. + +Best performing command: +```bash +# optionally +export ENRO_DIR='wmt_en_ro' # Download instructions above +# export WANDB_PROJECT="MT" # optional +export MAX_LEN=128 +export BS=4 +./train_mbart_cc25_enro.sh --output_dir enro_finetune_baseline --label_smoothing 0.1 --fp16_opt_level=O1 --logger_name wandb --sortish_sampler +``` +This should take < 6h/epoch on a 16GB v100 and achieve test BLEU above 26 +To get results in line with fairseq, you need to do some postprocessing. (see `romanian_postprocessing.md`) + +MultiGPU command +(using 8 GPUS as an example) +```bash +export ENRO_DIR='wmt_en_ro' # Download instructions above + # export WANDB_PROJECT="MT" # optional +export MAX_LEN=128 +export BS=4 +./train_mbart_cc25_enro.sh --output_dir enro_finetune_baseline --gpus 8 --logger_name wandb +``` +### Finetuning Outputs +As you train, `output_dir` will be filled with files, that look kind of like this (comments are mine). +Some of them are metrics, some of them are checkpoints, some of them are metadata. Here is a quick tour: + +```bash +output_dir +├── best_tfmr # this is a huggingface checkpoint generated by save_pretrained. It is the same model as the PL .ckpt file below +│   ├── config.json +│   ├── merges.txt +│   ├── pytorch_model.bin +│   ├── special_tokens_map.json +│   ├── tokenizer_config.json +│   └── vocab.json +├── git_log.json # repo, branch, and commit hash +├── val_avg_rouge2=0.1984-step_count=11.ckpt # this is a pytorch lightning checkpoint associated with the best val score. (it will be called BLEU for MT) +├── metrics.json # new validation metrics will continually be appended to this +├── student # this is a huggingface checkpoint generated by SummarizationDistiller. It is the student before it gets finetuned. +│   ├── config.json +│   └── pytorch_model.bin +├── test_generations.txt +# ^^ are the summaries or translations produced by your best checkpoint on the test data. Populated when training is done +├── test_results.txt # a convenience file with the test set metrics. This data is also in metrics.json['test'] +├── hparams.pkl # the command line args passed after some light preprocessing. Should be saved fairly quickly. +``` +After training, you can recover the best checkpoint by running +```python +from transformers import AutoModelForSeq2SeqLM +model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr') +``` + +### Fine-tuning using Seq2SeqTrainer +To use `Seq2SeqTrainer` for fine-tuning you should use the `finetune_trainer.py` script. It subclasses `Trainer` to extend it for seq2seq training. Except the `Trainer` releated `TrainingArguments`, it shares the same argument names as that of `finetune.py` file. One notable difference is that, calculating generative metrics (BLEU, ROUGE) is optional and is controlled using the `--predict_with_generate` argument, set this argument to calculate BLEU and ROUGE metrics. + +With PyTorch 1.6+ it'll automatically use `native AMP` when `--fp16` is set. + +To see all the possible command line options, run: + +```bash +./builtin_trainer/finetune.sh --help # This calls python finetune_trainer.py --help +``` + +**At the moment, `Seq2SeqTrainer` does not support *with teacher* distillation.** + +All `Seq2SeqTrainer` based fine-tuning scripts are included in the `builtin_trainer` directory. + +#### TPU Training +`Seq2SeqTrainer` supports TPU training with few caveats +1. As `generate` method does not work on TPU at the moment, `predict_with_generate` can not be used. You should use `--prediction_loss_only` to only calculate loss, and do not set `--do_predict` and `--predict_with_generate`. +2. All sequences should be padded to be of equal length otherwise it leads to extremely slow training. (`finetune_trainer.py` does this automatically when running on TPU.) + +We provide a very simple launcher script named `xla_spawn.py` that lets you run our example scripts on multiple TPU cores without any boilerplate. Just pass a --num_cores flag to this script, then your regular training script with its arguments (this is similar to the torch.distributed.launch helper for torch.distributed). + +`builtin_trainer/finetune_tpu.sh` script provides minimal arguments needed for TPU training. + +Following command fine-tunes `sshleifer/student_marian_en_ro_6_3` on TPU V3-8 and should complete one epoch in ~5-6 mins. + +```bash +./builtin_trainer/train_distil_marian_enro_tpu.sh +``` + +# DistilBART + +This section describes all code and artifacts from our [Paper](http://arxiv.org/abs/2010.13002) + +![DBART](https://huggingface.co/front/thumbnails/distilbart_large.png) + ++ For the CNN/DailyMail dataset, (relatively longer, more extractive summaries), we found a simple technique that works, which we call "Shrink and Fine-tune", or SFT. +you just copy alternating layers from `facebook/bart-large-cnn` and fine-tune more on the cnn/dm data. `sshleifer/distill-pegasus-cnn-16-4`, `sshleifer/distilbart-cnn-12-6` and all other checkpoints under `sshleifer` that start with `distilbart-cnn` were trained this way. ++ For the XSUM dataset, training on pseudo-labels worked best for Pegasus (`sshleifer/distill-pegasus-16-4`), while training with KD worked best for `distilbart-xsum-12-6` ++ For `sshleifer/dbart-xsum-12-3` ++ We ran 100s experiments, and didn't want to document 100s of commands. If you want a command to replicate a figure from the paper that is not documented below, feel free to ask on the [forums](https://discuss.huggingface.co/t/seq2seq-distillation-methodology-questions/1270) and tag `@sshleifer`. ++ You can see the performance tradeoffs of model sizes [here](https://docs.google.com/spreadsheets/d/1EkhDMwVO02m8jCD1cG3RoFPLicpcL1GQHTQjfvDYgIM/edit#gid=0). +and more granular timing results [here](https://docs.google.com/spreadsheets/d/1EkhDMwVO02m8jCD1cG3RoFPLicpcL1GQHTQjfvDYgIM/edit#gid=1753259047&range=B2:I23). + +### Evaluation + +use [run_distributed_eval](./run_distributed_eval.py), with the following convenient alias +```bash +deval () { + proc=$1 + m=$2 + dd=$3 + sd=$4 + shift + shift + shift + shift + python -m torch.distributed.launch --nproc_per_node=$proc run_distributed_eval.py \ + --model_name $m --save_dir $sd --data_dir $dd $@ +} +``` +On a 1 GPU system, here are four commands (that assume `xsum`, `cnn_dm` are downloaded, cmd-F for those links in this file). + +`distilBART`: +```bash +deval 1 sshleifer/distilbart-xsum-12-3 xsum dbart_12_3_xsum_eval --fp16 # --help for more choices. +deval 1 sshleifer/distilbart-cnn_dm-12-6 cnn_dm dbart_12_6_cnn_eval --fp16 +``` + +`distill-pegasus`: +```bash +deval 1 sshleifer/distill-pegasus-cnn-16-4 cnn_dm dpx_cnn_eval +deval 1 sshleifer/distill-pegasus-xsum-16-4 xsum dpx_xsum_eval +``` + +### Distillation ++ For all of the following commands, you can get roughly equivalent result and faster run times by passing `--num_beams=4`. That's not what we did for the paper. ++ Besides the KD section, you can also run commands with the built-in transformers trainer. See, for example, [builtin_trainer/train_distilbart_cnn.sh](./builtin_trainer/train_distilbart_cnn.sh). ++ Large performance deviations (> 5X slower or more than 0.5 Rouge-2 worse), should be reported. ++ Multi-gpu (controlled with `--gpus` should work, but might require more epochs). + +#### Recommended Workflow ++ Get your dataset in the right format. (see 6 files above). ++ Find a teacher model [Pegasus](https://huggingface.co/models?search=pegasus) (slower, better ROUGE) or `facebook/bart-large-xsum`/`facebook/bart-large-cnn` (faster, slightly lower.). +Choose the checkpoint where the corresponding dataset is most similar (or identical to) your dataset. ++ Follow the sections in order below. You can stop after SFT if you are satisfied, or move on to pseudo-labeling if you want more performance. ++ student size: If you want a close to free 50% speedup, cut the decoder in half. If you want a larger speedup, cut it in 4. ++ If your SFT run starts at a validation ROUGE-2 that is more than 10 pts below the teacher's validation ROUGE-2, you have a bug. Switching to a more expensive technique will not help. Try setting a breakpoint and looking at generation and truncation defaults/hyper-parameters, and share your experience on the forums! + + +#### Initialization +We use [make_student.py](./make_student.py) to copy alternating layers from the teacher, and save the resulting model to disk +```bash +python make_student.py facebook/bart-large-xsum --save_path dbart_xsum_12_3 -e 12 -d 3 +``` +or for `pegasus-xsum` +```bash +python make_student.py google/pegasus-xsum --save_path dpx_xsum_16_4 --e 16 --d 4 +``` +we now have an initialized student saved to `dbart_xsum_12_3`, which we will use for the following commands. ++ Extension: To replicate more complicated initialize experiments in section 6.1, or try your own. Use the `create_student_by_copying_alternating_layers` function. + +#### Pegasus ++ The following commands are written for BART and will require, at minimum, the following modifications ++ reduce batch size, and increase gradient accumulation steps so that the product `gpus * batch size * gradient_accumulation_steps = 256`. We used `--learning-rate` = 1e-4 * gradient accumulation steps. ++ don't use fp16 ++ `--tokenizer_name google/pegasus-large` + +### SFT (No Teacher Distillation) +You don't need `distillation.py`, you can just run: + +```bash +python finetune.py \ + --data_dir xsum \ + --freeze_encoder --freeze_embeds \ + --learning_rate=3e-4 \ + --do_train \ + --do_predict \ + --fp16 --fp16_opt_level=O1 \ + --val_check_interval 0.1 --n_val 1000 --eval_beams 2 --length_penalty=0.5 \ + --max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \ + --model_name_or_path dbart_xsum_12_3 \ + --train_batch_size=64 --eval_batch_size=64 \ + --sortish_sampler \ + --num_train_epochs=6 \ + --warmup_steps 500 \ + --output_dir distilbart_xsum_sft_12_3 --gpus 1 +``` + ++ Note: The command that produced `sshleifer/distilbart-cnn-12-6` is at [train_distilbart_cnn.sh](./[train_distilbart_cnn.sh) + +```bash +./train_distilbart_cnn.sh +``` + ++ Tip: You can get the same simple distillation logic by using `distillation.py --no_teacher ` followed by identical arguments as the ones in `train_distilbart_cnn.sh`. +If you are using `wandb` and comparing the two distillation methods, using this entry point will make your logs consistent, +because you will have the same hyper-parameters logged in every run. + +### Pseudo-Labeling ++ You don't need `distillation.py`. ++ Instructions to generate pseudo-labels and use pre-computed pseudo-labels can be found [here](./precomputed_pseudo_labels.md). +Simply run `finetune.py` with one of those pseudo-label datasets as `--data_dir` (`DATA`, below). + +```bash +python finetune.py \ + --teacher facebook/bart-large-xsum --data_dir DATA \ + --freeze_encoder --freeze_embeds \ + --learning_rate=3e-4 \ + --do_train \ + --do_predict \ + --fp16 --fp16_opt_level=O1 \ + --val_check_interval 0.1 --n_val 1000 --eval_beams 2 --length_penalty=0.5 \ + --max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \ + --model_name_or_path dbart_xsum_12_3 \ + --train_batch_size=32 --eval_batch_size=32 \ + --sortish_sampler \ + --num_train_epochs=5 \ + --warmup_steps 500 \ + --output_dir dbart_xsum_12_3_PL --gpus 1 --logger_name wandb +``` + + + +To combine datasets, as in Section 6.2, try something like: +```bash +curl -S https://cdn-datasets.huggingface.co/pseudo/xsum/bart_xsum_pl.tgz | tar -xvz -C . +curl -S https://cdn-datasets.huggingface.co/pseudo/xsum/pegasus_xsum.tgz | tar -xvz -C . +curl -S https://cdn-datasets.huggingface.co/summarization/xsum.tar.gz | tar -xvz -C . +mkdir all_pl +cat bart_xsum_pl/train.source pegasus_xsum/train.source xsum/train.source > all_pl/train.source +cat bart_xsum_pl/train.target pegasus_xsum/train.target xsum/train.target > all_pl/train.target +cp xsum/val* all_pl +cp xsum/test* all_pl +``` +then use `all_pl` as DATA in the command above. + +#### Direct Knowledge Distillation (KD) ++ In this method, we use try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `BartSummarizationDistiller`. ++ This method was used for `sshleifer/distilbart-xsum-12-6`, `6-6`, and `9-6` checkpoints were produced. ++ You must use [`distillation.py`](./distillation.py). Note that this command initializes the student for you. + +The command that produced `sshleifer/distilbart-xsum-12-6` is at [./train_distilbart_xsum.sh](train_distilbart_xsum.sh) +```bash +./train_distilbart_xsum.sh --logger_name wandb --gpus 1 +``` + ++ Expected ROUGE-2 between 21.3 and 21.6, run time ~13H. ++ direct KD + Pegasus is VERY slow and works best with `--supervise_forward --normalize_hidden`. + + + +### Citation + +```bibtex +@misc{shleifer2020pretrained, + title={Pre-trained Summarization Distillation}, + author={Sam Shleifer and Alexander M. Rush}, + year={2020}, + eprint={2010.13002}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +@article{Wolf2019HuggingFacesTS, + title={HuggingFace's Transformers: State-of-the-art Natural Language Processing}, + author={Thomas Wolf and Lysandre Debut and Victor Sanh and Julien Chaumond and Clement Delangue and Anthony Moi and Pierric Cistac and Tim Rault and Rémi Louf and Morgan Funtowicz and Joe Davison and Sam Shleifer and Patrick von Platen and Clara Ma and Yacine Jernite and Julien Plu and Canwen Xu and Teven Le Scao and Sylvain Gugger and Mariama Drame and Quentin Lhoest and Alexander M. Rush}, + journal={ArXiv}, + year={2019}, + volume={abs/1910.03771} +} +``` + +This is the end of the distillation section, the rest of this doc pertains to general seq2seq commands. + +## Evaluation Commands + +To create summaries for each article in dataset, we use `run_eval.py`, here are a few commands that run eval for different tasks and models. +If 'translation' is in your task name, the computed metric will be BLEU. Otherwise, ROUGE will be used. + +For t5, you need to specify --task translation_{src}_to_{tgt} as follows: +```bash +export DATA_DIR=wmt_en_ro +./run_eval.py t5-base \ + $DATA_DIR/val.source t5_val_generations.txt \ + --reference_path $DATA_DIR/val.target \ + --score_path enro_bleu.json \ + --task translation_en_to_ro \ + --n_obs 100 \ + --device cuda \ + --fp16 \ + --bs 32 +``` + +This command works for MBART, although the BLEU score is suspiciously low. +```bash +export DATA_DIR=wmt_en_ro +./run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_generations.txt \ + --reference_path $DATA_DIR/val.target \ + --score_path enro_bleu.json \ + --task translation \ + --n_obs 100 \ + --device cuda \ + --fp16 \ + --bs 32 +``` + +Summarization (xsum will be very similar): +```bash +export DATA_DIR=cnn_dm +./run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_generations.txt \ + --reference_path $DATA_DIR/val.target \ + --score_path cnn_rouge.json \ + --task summarization \ + --n_obs 100 \ + +th 56 \ + --fp16 \ + --bs 32 +``` + +### Multi-GPU Evaluation +here is a command to run xsum evaluation on 8 GPUS. It is more than linearly faster than run_eval.py in some cases +because it uses SortishSampler to minimize padding. You can also use it on 1 GPU. `data_dir` must have +`{type_path}.source` and `{type_path}.target`. Run `./run_distributed_eval.py --help` for all clargs. + +```bash +python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \ + --model_name sshleifer/distilbart-large-xsum-12-3 \ + --save_dir xsum_generations \ + --data_dir xsum \ + --fp16 # you can pass generate kwargs like num_beams here, just like run_eval.py +``` + +Contributions that implement this command for other distributed hardware setups are welcome! + +#### Single-GPU Eval: Tips and Tricks + +When using `run_eval.py`, the following features can be useful: + +* if you running the script multiple times and want to make it easier to track what arguments produced that output, use `--dump-args`. Along with the results it will also dump any custom params that were passed to the script. For example if you used: `--num_beams 8 --early_stopping true`, the output will be: + ``` + {'bleu': 26.887, 'n_obs': 10, 'runtime': 1, 'seconds_per_sample': 0.1, 'num_beams': 8, 'early_stopping': True} + ``` + + `--info` is an additional argument available for the same purpose of tracking the conditions of the experiment. It's useful to pass things that weren't in the argument list, e.g. a language pair `--info "lang:en-ru"`. But also if you pass `--info` without a value it will fallback to the current date/time string, e.g. `2020-09-13 18:44:43`. + + If using `--dump-args --info`, the output will be: + + ``` + {'bleu': 26.887, 'n_obs': 10, 'runtime': 1, 'seconds_per_sample': 0.1, 'num_beams': 8, 'early_stopping': True, 'info': '2020-09-13 18:44:43'} + ``` + + If using `--dump-args --info "pair:en-ru chkpt=best`, the output will be: + + ``` + {'bleu': 26.887, 'n_obs': 10, 'runtime': 1, 'seconds_per_sample': 0.1, 'num_beams': 8, 'early_stopping': True, 'info': 'pair=en-ru chkpt=best'} + ``` + + +* if you need to perform a parametric search in order to find the best ones that lead to the highest BLEU score, let `run_eval_search.py` to do the searching for you. + + The script accepts the exact same arguments as `run_eval.py`, plus an additional argument `--search`. The value of `--search` is parsed, reformatted and fed to ``run_eval.py`` as additional args. + + The format for the `--search` value is a simple string with hparams and colon separated values to try, e.g.: + ``` + --search "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false" + ``` + which will generate `12` `(2*3*2)` searches for a product of each hparam. For example the example that was just used will invoke `run_eval.py` repeatedly with: + + ``` + --num_beams 5 --length_penalty 0.8 --early_stopping true + --num_beams 5 --length_penalty 0.8 --early_stopping false + [...] + --num_beams 10 --length_penalty 1.2 --early_stopping false + ``` + + On completion, this function prints a markdown table of the results sorted by the best BLEU score and the winning arguments. + +``` +bleu | num_beams | length_penalty | early_stopping +----- | --------- | -------------- | -------------- +26.71 | 5 | 1.1 | 1 +26.66 | 5 | 0.9 | 1 +26.66 | 5 | 0.9 | 0 +26.41 | 5 | 1.1 | 0 +21.94 | 1 | 0.9 | 1 +21.94 | 1 | 0.9 | 0 +21.94 | 1 | 1.1 | 1 +21.94 | 1 | 1.1 | 0 + +Best score args: +stas/wmt19-en-ru data/en-ru/val.source data/en-ru/test_translations.txt --reference_path data/en-ru/val.target --score_path data/en-ru/test_bleu.json --bs 8 --task translation --num_beams 5 --length_penalty 1.1 --early_stopping True +``` + +If you pass `--info "some experiment-specific info"` it will get printed before the results table - this is useful for scripting and multiple runs, so one can tell the different sets of results from each other. + + +### Contributing +- follow the standard contributing guidelines and code of conduct. +- add tests to `test_seq2seq_examples.py` +- To run only the seq2seq tests, you must be in the root of the repository and run: +```bash +pytest examples/seq2seq/ +``` + +### Converting pytorch-lightning checkpoints +pytorch lightning ``-do_predict`` often fails, after you are done training, the best way to evaluate your model is to convert it. + +This should be done for you, with a file called `{save_dir}/best_tfmr`. + +If that file doesn't exist but you have a lightning `.ckpt` file, you can run +```bash +python convert_pl_checkpoint_to_hf.py PATH_TO_CKPT randomly_initialized_hf_model_path save_dir/best_tfmr +``` +Then either `run_eval` or `run_distributed_eval` with `save_dir/best_tfmr` (see previous sections) + + +# Experimental Features +These features are harder to use and not always useful. + +### Dynamic Batch Size for MT +`finetune.py` has a command line arg `--max_tokens_per_batch` that allows batches to be dynamically sized. +This feature can only be used: +- with fairseq installed +- on 1 GPU +- without sortish sampler +- after calling `./save_len_file.py $tok $data_dir` + +For example, +```bash +./save_len_file.py Helsinki-NLP/opus-mt-en-ro wmt_en_ro +./dynamic_bs_example.sh --max_tokens_per_batch=2000 --output_dir benchmark_dynamic_bs +``` +splits `wmt_en_ro/train` into 11,197 uneven lengthed batches and can finish 1 epoch in 8 minutes on a v100. + +For comparison, +```bash +./dynamic_bs_example.sh --sortish_sampler --train_batch_size 48 +``` +uses 12,723 batches of length 48 and takes slightly more time 9.5 minutes. + +The feature is still experimental, because: ++ we can make it much more robust if we have memory mapped/preprocessed datasets. ++ The speedup over sortish sampler is not that large at the moment. + + diff --git a/src/mrc_client/seq2seq/__init__.py b/src/mrc_client/seq2seq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3cee09bb7f51087e92d778c4c9e27d76085d1b30 --- /dev/null +++ b/src/mrc_client/seq2seq/__init__.py @@ -0,0 +1,5 @@ +import os +import sys + + +sys.path.insert(1, os.path.dirname(os.path.realpath(__file__))) diff --git a/src/mrc_client/seq2seq/callbacks.py b/src/mrc_client/seq2seq/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..f4a5732d06085a501b37e2c40e47ba4ae84bdd09 --- /dev/null +++ b/src/mrc_client/seq2seq/callbacks.py @@ -0,0 +1,115 @@ +import logging +import os +from pathlib import Path + +import numpy as np +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.utilities import rank_zero_only + +from seq2seq_utils import save_json + + +def count_trainable_parameters(model): + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + params = sum([np.prod(p.size()) for p in model_parameters]) + return params + + +logger = logging.getLogger(__name__) + + +class Seq2SeqLoggingCallback(pl.Callback): + def on_batch_end(self, trainer, pl_module): + lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)} + pl_module.logger.log_metrics(lrs) + + @rank_zero_only + def _write_logs( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True + ) -> None: + logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****") + metrics = trainer.callback_metrics + trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]}) + # Log results + od = Path(pl_module.hparams.output_dir) + if type_path == "test": + results_file = od / "test_results.txt" + generations_file = od / "test_generations.txt" + else: + # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json + # If people want this it will be easy enough to add back. + results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt" + generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt" + results_file.parent.mkdir(exist_ok=True) + generations_file.parent.mkdir(exist_ok=True) + with open(results_file, "a+") as writer: + for key in sorted(metrics): + if key in ["log", "progress_bar", "preds"]: + continue + val = metrics[key] + if isinstance(val, torch.Tensor): + val = val.item() + msg = f"{key}: {val:.6f}\n" + writer.write(msg) + + if not save_generations: + return + + if "preds" in metrics: + content = "\n".join(metrics["preds"]) + generations_file.open("w+").write(content) + + @rank_zero_only + def on_train_start(self, trainer, pl_module): + try: + npars = pl_module.model.model.num_parameters() + except AttributeError: + npars = pl_module.model.num_parameters() + + n_trainable_pars = count_trainable_parameters(pl_module) + # mp stands for million parameters + trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6}) + + @rank_zero_only + def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + save_json(pl_module.metrics, pl_module.metrics_save_path) + return self._write_logs(trainer, pl_module, "test") + + @rank_zero_only + def on_validation_end(self, trainer: pl.Trainer, pl_module): + save_json(pl_module.metrics, pl_module.metrics_save_path) + # Uncommenting this will save val generations + # return self._write_logs(trainer, pl_module, "valid") + + +def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False): + """Saves the best model by validation ROUGE2 score.""" + if metric == "rouge2": + exp = "{val_avg_rouge2:.4f}-{step_count}" + elif metric == "bleu": + exp = "{val_avg_bleu:.4f}-{step_count}" + elif metric == "loss": + exp = "{val_avg_loss:.4f}-{step_count}" + else: + raise NotImplementedError( + f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to this function." + ) + + checkpoint_callback = ModelCheckpoint( + filepath=os.path.join(output_dir, exp), + monitor=f"val_{metric}", + mode="min" if "loss" in metric else "max", + save_top_k=save_top_k, + ) + return checkpoint_callback + + +def get_early_stopping_callback(metric, patience): + return EarlyStopping( + monitor=f"val_{metric}", # does this need avg? + mode="min" if "loss" in metric else "max", + patience=patience, + verbose=True, + ) diff --git a/src/mrc_client/seq2seq/cjjpy.py b/src/mrc_client/seq2seq/cjjpy.py new file mode 100755 index 0000000000000000000000000000000000000000..2cc70b5e553924123810ab198c143bf7ee28e5d6 --- /dev/null +++ b/src/mrc_client/seq2seq/cjjpy.py @@ -0,0 +1,249 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2018/11/15 17:08 +@Contact: jjchen19@fudan.edu.cn +''' + +import re +import datetime +import os +import argparse +import logging +import traceback + +try: + import ujson as json +except: + import json + +HADOOP_BIN = 'PATH=/usr/bin/:$PATH hdfs' +FOR_PUBLIC = True + + +def LengthStats(filename): + len_list = [] + thresholds = [0.8, 0.9, 0.95, 0.99, 0.999] + with open(filename) as f: + for line in f: + len_list.append(len(line.strip().split())) + stats = { + 'Max': max(len_list), + 'Min': min(len_list), + 'Avg': round(sum(len_list) / len(len_list), 4), + } + len_list.sort() + for t in thresholds: + stats[f"Top-{t}"] = len_list[int(len(len_list) * t)] + + for k in stats: + print(f"- {k}: {stats[k]}") + return stats + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def TraceBack(error_msg): + exc = traceback.format_exc() + msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}' + return msg + + +def Now(): + return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def AbsParentDir(file, parent='..', postfix=None): + ppath = os.path.abspath(file) + parent_level = parent.count('.') + while parent_level > 0: + ppath = os.path.dirname(ppath) + parent_level -= 1 + if postfix is not None: + return os.path.join(ppath, postfix) + else: + return ppath + + +def init_logger(log_file=None, log_file_level=logging.NOTSET, from_scratch=False): + from coloredlogs import ColoredFormatter + import tensorflow as tf + + fmt = "[%(asctime)s %(levelname)s] %(message)s" + log_format = ColoredFormatter(fmt=fmt) + # log_format = logging.Formatter() + logger = logging.getLogger() + logger.setLevel(log_file_level) + + console_handler = logging.StreamHandler() + console_handler.setFormatter(log_format) + logger.handlers = [console_handler] + + if log_file and log_file != '': + if from_scratch and tf.io.gfile.exists(log_file): + logger.warning('Removing previous log file: %s' % log_file) + tf.io.gfile.remove(log_file) + path = os.path.dirname(log_file) + os.makedirs(path, exist_ok=True) + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(log_file_level) + file_handler.setFormatter(log_format) + logger.addHandler(file_handler) + + return logger + + +def OverWriteCjjPy(root='.'): + # import difflib + # diff = difflib.HtmlDiff() + cnt = 0 + golden_cjjpy = os.path.join(root, 'cjjpy.py') + # golden_content = open(golden_cjjpy).readlines() + for dir, folder, file in os.walk(root): + for f in file: + if f == 'cjjpy.py': + cjjpy = '%s/%s' % (dir, f) + # content = open(cjjpy).readlines() + # d = diff.make_file(golden_content, content) + cnt += 1 + print('[%d]: %s' % (cnt, cjjpy)) + os.system('cp %s %s' % (golden_cjjpy, cjjpy)) + + +def ChangeFileFormat(filename, new_fmt): + assert type(filename) is str and type(new_fmt) is str + spt = filename.split('.') + if len(spt) == 0: + return filename + else: + return filename.replace('.' + spt[-1], new_fmt) + + +def CountLines(fname): + with open(fname, 'rb') as f: + count = 0 + last_data = '\n' + while True: + data = f.read(0x400000) + if not data: + break + count += data.count(b'\n') + last_data = data + if last_data[-1:] != b'\n': + count += 1 # Remove this if a wc-like count is needed + return count + + +def GetDate(): + return str(datetime.datetime.now())[5:10].replace('-', '') + + +def TimeClock(seconds): + sec = int(seconds) + hour = int(sec / 3600) + min = int((sec - hour * 3600) / 60) + ssec = float(seconds) - hour * 3600 - min * 60 + # return '%dh %dm %.2fs' % (hour, min, ssec) + return '{0:>2d}h{1:>3d}m{2:>6.2f}s'.format(hour, min, ssec) + + +def StripAll(text): + return text.strip().replace('\t', '').replace('\n', '').replace(' ', '') + + +def GetBracket(text, bracket, en_br=False): + # input should be aa(bb)cc, True for bracket, False for text + if bracket: + try: + return re.findall('\((.*?)\)', text.strip())[-1] + except: + return '' + else: + if en_br: + text = re.sub('\(.*?\)', '', text.strip()) + return re.sub('(.*?)', '', text.strip()) + + +def CharLang(uchar, lang): + assert lang.lower() in ['en', 'cn', 'zh'] + if lang.lower() in ['cn', 'zh']: + if uchar >= '\u4e00' and uchar <= '\u9fa5': + return True + else: + return False + elif lang.lower() == 'en': + if (uchar <= 'Z' and uchar >= 'A') or (uchar <= 'z' and uchar >= 'a'): + return True + else: + return False + else: + raise NotImplementedError + + +def WordLang(word, lang): + for i in word.strip(): + if i.isspace(): continue + if not CharLang(i, lang): + return False + return True + + +def SortDict(_dict, reverse=True): + assert type(_dict) is dict + return sorted(_dict.items(), key=lambda d: d[1], reverse=reverse) + + +def lark(content='test'): + print(content) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--diff', nargs=2, + help='show difference between two files, shown in downloads/diff.html') + parser.add_argument('--de_unicode', action='store_true', default=False, + help='remove unicode characters') + parser.add_argument('--link_entity', action='store_true', default=False, + help='') + parser.add_argument('--max_comm_len', action='store_true', default=False, + help='') + parser.add_argument('--search', nargs=2, + help='search key from file, 2 args: file name & key') + parser.add_argument('--email', nargs=2, + help='sending emails, 2 args: subject & content') + parser.add_argument('--overwrite', action='store_true', default=None, + help='overwrite all cjjpy under given *dir* based on *dir*/cjjpy.py') + parser.add_argument('--replace', nargs=3, + help='replace char, 3 args: file name & replaced char & replacer char') + parser.add_argument('--lark', nargs=1) + parser.add_argument('--get_hdfs', nargs=2, + help='easy copy from hdfs to local fs, 2 args: remote_file/dir & local_dir') + parser.add_argument('--put_hdfs', nargs=2, + help='easy put from local fs to hdfs, 2 args: local_file/dir & remote_dir') + parser.add_argument('--length_stats', nargs=1, + help='simple token lengths distribution of a line-by-line file') + + args = parser.parse_args() + + if args.overwrite: + print('* Overwriting cjjpy...') + OverWriteCjjPy() + + if args.lark: + try: + content = args.lark[0] + except: + content = 'running complete' + print(f'* Larking "{content}"...') + lark(content) + + if args.length_stats: + file = args.length_stats[0] + print(f'* Working on {file} lengths statistics...') + LengthStats(file) diff --git a/src/mrc_client/seq2seq/convert_pl_checkpoint_to_hf.py b/src/mrc_client/seq2seq/convert_pl_checkpoint_to_hf.py new file mode 100755 index 0000000000000000000000000000000000000000..5f3c984f3724c1cb46ffcdc9e57b20a391a423cf --- /dev/null +++ b/src/mrc_client/seq2seq/convert_pl_checkpoint_to_hf.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python + +import os +from pathlib import Path +from typing import Dict, List + +import fire +import torch + +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +from transformers.utils.logging import get_logger + + +logger = get_logger(__name__) + + +def remove_prefix(text: str, prefix: str): + if text.startswith(prefix): + return text[len(prefix) :] + return text # or whatever + + +def sanitize(sd): + return {remove_prefix(k, "model."): v for k, v in sd.items()} + + +def average_state_dicts(state_dicts: List[Dict[str, torch.Tensor]]): + new_sd = {} + for k in state_dicts[0].keys(): + tensors = [sd[k] for sd in state_dicts] + new_t = sum(tensors) / len(tensors) + assert isinstance(new_t, torch.Tensor) + new_sd[k] = new_t + return new_sd + + +def convert_pl_to_hf(pl_ckpt_path: str, hf_src_model_dir: str, save_path: str) -> None: + """Cleanup a pytorch-lightning .ckpt file or experiment dir and save a huggingface model with that state dict. + Silently allows extra pl keys (like teacher.) Puts all ckpt models into CPU RAM at once! + + Args: + pl_ckpt_path (:obj:`str`): Path to a .ckpt file saved by pytorch_lightning or dir containing ckpt files. + If a directory is passed, all .ckpt files inside it will be averaged! + hf_src_model_dir (:obj:`str`): Path to a directory containing a correctly shaped checkpoint + save_path (:obj:`str`): Directory to save the new model + + """ + hf_model = AutoModelForSeq2SeqLM.from_pretrained(hf_src_model_dir) + if os.path.isfile(pl_ckpt_path): + ckpt_files = [pl_ckpt_path] + else: + assert os.path.isdir(pl_ckpt_path) + ckpt_files = list(Path(pl_ckpt_path).glob("*.ckpt")) + assert ckpt_files, f"could not find any ckpt files inside the {pl_ckpt_path} directory" + + if len(ckpt_files) > 1: + logger.info(f"averaging the weights of {ckpt_files}") + + state_dicts = [sanitize(torch.load(x, map_location="cpu")["state_dict"]) for x in ckpt_files] + state_dict = average_state_dicts(state_dicts) + + missing, unexpected = hf_model.load_state_dict(state_dict, strict=False) + assert not missing, f"missing keys: {missing}" + hf_model.save_pretrained(save_path) + try: + tok = AutoTokenizer.from_pretrained(hf_src_model_dir) + tok.save_pretrained(save_path) + except Exception: + pass + # dont copy tokenizer if cant + + +if __name__ == "__main__": + fire.Fire(convert_pl_to_hf) diff --git a/src/mrc_client/seq2seq/finetune.py b/src/mrc_client/seq2seq/finetune.py new file mode 100755 index 0000000000000000000000000000000000000000..10ff08a4817d857b27fe77119e82d52b5f220616 --- /dev/null +++ b/src/mrc_client/seq2seq/finetune.py @@ -0,0 +1,465 @@ +#!/usr/bin/env python + +import argparse +import glob +import logging +import os +import sys +import time +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader + +from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback +from transformers import MBartTokenizer, T5ForConditionalGeneration +try: + from transformers.modeling_bart import shift_tokens_right +except: + from transformers.models.bart.modeling_bart import shift_tokens_right +from seq2seq_utils import ( + ROUGE_KEYS, + LegacySeq2SeqDataset, + Seq2SeqDataset, + UniQASeq2SeqDataset, + assert_all_frozen, + calculate_bleu, + calculate_rouge, + check_output_dir, + flatten_list, + freeze_embeds, + freeze_params, + get_git_info, + label_smoothed_nll_loss, + lmap, + pickle_save, + save_git_info, + save_json, + use_task_specific_params, +) + + +# need the parent dir module +sys.path.insert(2, str(Path(__file__).resolve().parents[1])) +from lightning_base import BaseTransformer, add_generic_args, generic_train # noqa + + +logger = logging.getLogger(__name__) + + +class SummarizationModule(BaseTransformer): + mode = "summarization" + loss_names = ["loss"] + metric_names = ROUGE_KEYS + default_val_metric = "rouge2" + + def __init__(self, hparams, **kwargs): + if hparams.sortish_sampler and hparams.gpus > 1: + hparams.replace_sampler_ddp = False + elif hparams.max_tokens_per_batch is not None: + if hparams.gpus > 1: + raise NotImplementedError("Dynamic Batch size does not work for multi-gpu training") + if hparams.sortish_sampler: + raise ValueError("--sortish_sampler and --max_tokens_per_batch may not be used simultaneously") + + super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs) + use_task_specific_params(self.model, "summarization") + # TODO: hard-encoded length constraint + self.model.config.min_length = hparams.min_target_length + self.model.config.max_length = hparams.max_target_length + save_git_info(self.hparams.output_dir) + self.metrics_save_path = Path(self.output_dir) / "metrics.json" + self.hparams_save_path = Path(self.output_dir) / "hparams.pkl" + pickle_save(self.hparams, self.hparams_save_path) + self.step_count = 0 + self.metrics = defaultdict(list) + self.model_type = self.config.model_type + self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size + + self.dataset_kwargs: dict = dict( + data_dir=self.hparams.data_dir, + max_source_length=self.hparams.max_source_length, + prefix=self.model.config.prefix or "", + ) + n_observations_per_split = { + "train": self.hparams.n_train, + "val": self.hparams.n_val, + "test": self.hparams.n_test, + } + self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()} + + self.target_lens = { + "train": self.hparams.max_target_length, + "val": self.hparams.val_max_target_length, + "test": self.hparams.test_max_target_length, + } + assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}" + assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}" + if self.hparams.freeze_embeds: + freeze_embeds(self.model) + if self.hparams.freeze_encoder: + freeze_params(self.model.get_encoder()) + assert_all_frozen(self.model.get_encoder()) + + self.hparams.git_sha = get_git_info()["repo_sha"] + self.num_workers = hparams.num_workers + self.decoder_start_token_id = None # default to config + if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer): + self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang] + self.model.config.decoder_start_token_id = self.decoder_start_token_id + + if 'unifiedqa' in self.hparams.model_name_or_path: + self.dataset_class = (UniQASeq2SeqDataset) + else: + self.dataset_class = ( + Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset + ) + self.already_saved_batch = False + self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams + if self.hparams.eval_max_gen_length is not None: + self.eval_max_length = self.hparams.eval_max_gen_length + else: + self.eval_max_length = self.model.config.max_length + if self.hparams.min_target_length is not None: + self.min_length = self.hparams.min_target_length + else: + self.min_length = self.model.config.min_length + + self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric + + def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]: + """A debugging utility""" + readable_batch = { + k: self.tokenizer.batch_decode(v.tolist()) if "mask" not in k else v.shape for k, v in batch.items() + } + save_json(readable_batch, Path(self.output_dir) / "text_batch.json") + save_json({k: v.tolist() for k, v in batch.items()}, Path(self.output_dir) / "tok_batch.json") + + self.already_saved_batch = True + return readable_batch + + def forward(self, input_ids, **kwargs): + return self.model(input_ids, **kwargs) + + def ids_to_clean_text(self, generated_ids: List[int]): + gen_text = self.tokenizer.batch_decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + return lmap(str.strip, gen_text) + + def _step(self, batch: dict) -> Tuple: + pad_token_id = self.tokenizer.pad_token_id + src_ids, src_mask = batch["input_ids"], batch["attention_mask"] + tgt_ids = batch["labels"] + if isinstance(self.model, T5ForConditionalGeneration): + decoder_input_ids = self.model._shift_right(tgt_ids) + else: + decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id) + if not self.already_saved_batch: # This would be slightly better if it only happened on rank zero + batch["decoder_input_ids"] = decoder_input_ids + self.save_readable_batch(batch) + + outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False) + lm_logits = outputs[0] + if self.hparams.label_smoothing == 0: + # Same behavior as modeling_bart.py, besides ignoring pad_token_id + ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id) + + assert lm_logits.shape[-1] == self.vocab_size + loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1)) + else: + lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1) + loss, nll_loss = label_smoothed_nll_loss( + lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id + ) + return (loss,) + + @property + def pad(self) -> int: + return self.tokenizer.pad_token_id + + def training_step(self, batch, batch_idx) -> Dict: + loss_tensors = self._step(batch) + + logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} + # tokens per batch + logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum() + logs["bs"] = batch["input_ids"].shape[0] + logs["src_pad_tok"] = batch["input_ids"].eq(self.pad).sum() + logs["src_pad_frac"] = batch["input_ids"].eq(self.pad).float().mean() + # TODO(SS): make a wandb summary metric for this + return {"loss": loss_tensors[0], "log": logs} + + def validation_step(self, batch, batch_idx) -> Dict: + return self._generative_step(batch) + + def validation_epoch_end(self, outputs, prefix="val") -> Dict: + self.step_count += 1 + losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names} + loss = losses["loss"] + generative_metrics = { + k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"] + } + metric_val = ( + generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric] + ) + metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss) + generative_metrics.update({k: v.item() for k, v in losses.items()}) + losses.update(generative_metrics) + all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()} + all_metrics["step_count"] = self.step_count + self.metrics[prefix].append(all_metrics) # callback writes this to self.metrics_save_path + preds = flatten_list([x["preds"] for x in outputs]) + return { + "log": all_metrics, + "preds": preds, + f"{prefix}_loss": loss, + f"{prefix}_{self.val_metric}": metric_tensor, + } + + def calc_generative_metrics(self, preds, target) -> Dict: + return calculate_rouge(preds, target) + + def _generative_step(self, batch: dict) -> dict: + t0 = time.time() + + # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens') + generated_ids = self.model.generate( + batch["input_ids"], + attention_mask=batch["attention_mask"], + use_cache=True, + decoder_start_token_id=self.decoder_start_token_id, + num_beams=self.eval_beams, + max_length=self.eval_max_length, + min_length=self.min_length + ) + gen_time = (time.time() - t0) / batch["input_ids"].shape[0] + preds: List[str] = self.ids_to_clean_text(generated_ids) + target: List[str] = self.ids_to_clean_text(batch["labels"]) + loss_tensors = self._step(batch) + base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} + rouge: Dict = self.calc_generative_metrics(preds, target) + summ_len = np.mean(lmap(len, generated_ids)) + base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge) + return base_metrics + + def test_step(self, batch, batch_idx): + return self._generative_step(batch) + + def test_epoch_end(self, outputs): + return self.validation_epoch_end(outputs, prefix="test") + + def get_dataset(self, type_path) -> Seq2SeqDataset: + n_obs = self.n_obs[type_path] + max_target_length = self.target_lens[type_path] + dataset = self.dataset_class( + self.tokenizer, + type_path=type_path, + n_obs=n_obs, + max_target_length=max_target_length, + **self.dataset_kwargs, + ) + return dataset + + def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader: + dataset = self.get_dataset(type_path) + + if self.hparams.sortish_sampler and type_path != "test" and type_path != "val": + sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1) + return DataLoader( + dataset, + batch_size=batch_size, + collate_fn=dataset.collate_fn, + shuffle=False, + num_workers=self.num_workers, + sampler=sampler, + ) + + elif self.hparams.max_tokens_per_batch is not None and type_path != "test" and type_path != "val": + batch_sampler = dataset.make_dynamic_sampler( + self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1 + ) + return DataLoader( + dataset, + batch_sampler=batch_sampler, + collate_fn=dataset.collate_fn, + # shuffle=False, + num_workers=self.num_workers, + # batch_size=None, + ) + else: + return DataLoader( + dataset, + batch_size=batch_size, + collate_fn=dataset.collate_fn, + shuffle=shuffle, + num_workers=self.num_workers, + sampler=None, + ) + + def train_dataloader(self) -> DataLoader: + dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True) + return dataloader + + def val_dataloader(self) -> DataLoader: + return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size) + + def test_dataloader(self) -> DataLoader: + return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size) + + @staticmethod + def add_model_specific_args(parser, root_dir): + BaseTransformer.add_model_specific_args(parser, root_dir) + add_generic_args(parser, root_dir) + parser.add_argument( + "--min_target_length", + default=1, + type=int, + help="The minimum total target sequence length after tokenization.", + ) + parser.add_argument( + "--max_source_length", + default=1024, + type=int, + help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.", + ) + parser.add_argument( + "--max_target_length", + default=56, + type=int, + help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.", + ) + parser.add_argument( + "--val_max_target_length", + default=142, # these defaults are optimized for CNNDM. For xsum, see README.md. + type=int, + help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.", + ) + parser.add_argument( + "--test_max_target_length", + default=142, + type=int, + help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.", + ) + parser.add_argument("--freeze_encoder", action="store_true") + parser.add_argument("--freeze_embeds", action="store_true") + parser.add_argument("--sortish_sampler", action="store_true", default=False) + parser.add_argument("--overwrite_output_dir", action="store_true", default=False) + parser.add_argument("--max_tokens_per_batch", type=int, default=None) + parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default") + parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.") + parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.") + parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.") + parser.add_argument( + "--task", type=str, default="summarization", required=False, help="# examples. -1 means use all." + ) + parser.add_argument("--label_smoothing", type=float, default=0.0, required=False) + parser.add_argument("--src_lang", type=str, default="", required=False) + parser.add_argument("--tgt_lang", type=str, default="", required=False) + parser.add_argument("--eval_beams", type=int, default=None, required=False) + parser.add_argument( + "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None] + ) + parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens") + parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save") + parser.add_argument( + "--early_stopping_patience", + type=int, + default=-1, + required=False, + help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.", + ) + return parser + + +class TranslationModule(SummarizationModule): + mode = "translation" + loss_names = ["loss"] + metric_names = ["bleu"] + default_val_metric = "bleu" + + def __init__(self, hparams, **kwargs): + super().__init__(hparams, **kwargs) + self.dataset_kwargs["src_lang"] = hparams.src_lang + self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang + + def calc_generative_metrics(self, preds, target) -> dict: + return calculate_bleu(preds, target) + + +def main(args, model=None) -> SummarizationModule: + Path(args.output_dir).mkdir(exist_ok=True) + check_output_dir(args, expected_items=3) + + if model is None: + if "summarization" in args.task: + model: SummarizationModule = SummarizationModule(args) + else: + model: SummarizationModule = TranslationModule(args) + dataset = Path(args.data_dir).name + if ( + args.logger_name == "default" + or args.fast_dev_run + or str(args.output_dir).startswith("/tmp") + or str(args.output_dir).startswith("/var") + ): + logger = True # don't pollute wandb logs unnecessarily + elif args.logger_name == "wandb": + from pytorch_lightning.loggers import WandbLogger + + project = os.environ.get("WANDB_PROJECT", dataset) + logger = WandbLogger(name=model.output_dir.name, project=project) + + elif args.logger_name == "wandb_shared": + from pytorch_lightning.loggers import WandbLogger + + logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}") + + if args.early_stopping_patience >= 0: + es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience) + else: + es_callback = False + + lower_is_better = args.val_metric == "loss" + trainer: pl.Trainer = generic_train( + model, + args, + logging_callback=Seq2SeqLoggingCallback(), + checkpoint_callback=get_checkpoint_callback( + args.output_dir, model.val_metric, args.save_top_k, lower_is_better + ), + early_stopping_callback=es_callback, + logger=logger, + ) + pickle_save(model.hparams, model.output_dir / "hparams.pkl") + if not args.do_predict: + return model + + model.hparams.test_checkpoint = "" + checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))) + if checkpoints: + model.hparams.test_checkpoint = checkpoints[-1] + trainer.resume_from_checkpoint = checkpoints[-1] + trainer.logger.log_hyperparams(model.hparams) + + # test() without a model tests using the best checkpoint automatically + trainer.test() + return model + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + parser = SummarizationModule.add_model_specific_args(parser, os.getcwd()) + + args = parser.parse_args() + + main(args) diff --git a/src/mrc_client/seq2seq/finetune_t5.sh b/src/mrc_client/seq2seq/finetune_t5.sh new file mode 100755 index 0000000000000000000000000000000000000000..504e9eb71e3596360bfb575ded4136689854e250 --- /dev/null +++ b/src/mrc_client/seq2seq/finetune_t5.sh @@ -0,0 +1,14 @@ +# Add parent directory to python path to access lightning_base.py +export PYTHONPATH="../":"${PYTHONPATH}" + +python finetune.py \ +--data_dir=$CNN_DIR \ +--learning_rate=3e-5 \ +--train_batch_size=$BS \ +--eval_batch_size=$BS \ +--output_dir=$OUTPUT_DIR \ +--max_source_length=512 \ +--max_target_length=56 \ +--val_check_interval=0.1 --n_val=200 \ +--do_train --do_predict \ + "$@" diff --git a/src/mrc_client/seq2seq/finetune_trainer.py b/src/mrc_client/seq2seq/finetune_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..ccd0b15f4b293d3d384249ea4d7ae7c14cd0acb7 --- /dev/null +++ b/src/mrc_client/seq2seq/finetune_trainer.py @@ -0,0 +1,303 @@ +import logging +import os +import sys +from dataclasses import dataclass, field +from typing import Optional + +from seq2seq_trainer import Seq2SeqTrainer +from seq2seq_training_args import Seq2SeqTrainingArguments +from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser, MBartTokenizer, set_seed +from transformers.trainer_utils import EvaluationStrategy +from seq2seq_utils import ( + Seq2SeqDataCollator, + Seq2SeqDataset, + assert_all_frozen, + build_compute_metrics_fn, + check_output_dir, + freeze_embeds, + freeze_params, + lmap, + save_json, + use_task_specific_params, + write_txt_file, +) + + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) + freeze_encoder: bool = field(default=False, metadata={"help": "Whether tp freeze the encoder."}) + freeze_embeds: bool = field(default=False, metadata={"help": "Whether to freeze the embeddings."}) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + data_dir: str = field( + metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."} + ) + task: Optional[str] = field( + default="summarization", + metadata={"help": "Task name, summarization (or summarization_{dataset} for pegasus) or translation"}, + ) + max_source_length: Optional[int] = field( + default=1024, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + max_target_length: Optional[int] = field( + default=128, + metadata={ + "help": "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + val_max_target_length: Optional[int] = field( + default=142, + metadata={ + "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + test_max_target_length: Optional[int] = field( + default=142, + metadata={ + "help": "The maximum total sequence length for test target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."}) + n_val: Optional[int] = field(default=-1, metadata={"help": "# validation examples. -1 means use all."}) + n_test: Optional[int] = field(default=-1, metadata={"help": "# test examples. -1 means use all."}) + src_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."}) + tgt_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."}) + eval_beams: Optional[int] = field(default=None, metadata={"help": "# num_beams to use for evaluation."}) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={"help": "If only pad tokens should be ignored. This assumes that `config.pad_token_id` is defined."}, + ) + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + check_output_dir(training_args) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, + ) + logger.warning( + "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", + training_args.local_rank, + training_args.device, + training_args.n_gpu, + bool(training_args.local_rank != -1), + training_args.fp16, + ) + logger.info("Training/evaluation parameters %s", training_args) + + # Set seed + set_seed(training_args.seed) + + # Load pretrained model and tokenizer + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + ) + + extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout") + for p in extra_model_params: + if getattr(training_args, p, None): + assert hasattr(config, p), f"({config.__class__.__name__}) doesn't have a `{p}` attribute" + setattr(config, p, getattr(training_args, p)) + + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + ) + model = AutoModelForSeq2SeqLM.from_pretrained( + model_args.model_name_or_path, + from_tf=".ckpt" in model_args.model_name_or_path, + config=config, + cache_dir=model_args.cache_dir, + ) + + # use task specific params + use_task_specific_params(model, data_args.task) + + # set num_beams for evaluation + if data_args.eval_beams is None: + data_args.eval_beams = model.config.num_beams + + # set decoder_start_token_id for MBart + if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer): + assert ( + data_args.tgt_lang is not None and data_args.src_lang is not None + ), "mBart requires --tgt_lang and --src_lang" + model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang] + + if model_args.freeze_embeds: + freeze_embeds(model) + if model_args.freeze_encoder: + freeze_params(model.get_encoder()) + assert_all_frozen(model.get_encoder()) + + dataset_class = Seq2SeqDataset + + # Get datasets + train_dataset = ( + dataset_class( + tokenizer, + type_path="train", + data_dir=data_args.data_dir, + n_obs=data_args.n_train, + max_target_length=data_args.max_target_length, + max_source_length=data_args.max_source_length, + prefix=model.config.prefix or "", + ) + if training_args.do_train + else None + ) + eval_dataset = ( + dataset_class( + tokenizer, + type_path="val", + data_dir=data_args.data_dir, + n_obs=data_args.n_val, + max_target_length=data_args.val_max_target_length, + max_source_length=data_args.max_source_length, + prefix=model.config.prefix or "", + ) + if training_args.do_eval or training_args.evaluation_strategy != EvaluationStrategy.NO + else None + ) + test_dataset = ( + dataset_class( + tokenizer, + type_path="test", + data_dir=data_args.data_dir, + n_obs=data_args.n_test, + max_target_length=data_args.test_max_target_length, + max_source_length=data_args.max_source_length, + prefix=model.config.prefix or "", + ) + if training_args.do_predict + else None + ) + + # Initialize our Trainer + compute_metrics_fn = ( + build_compute_metrics_fn(data_args.task, tokenizer) if training_args.predict_with_generate else None + ) + trainer = Seq2SeqTrainer( + model=model, + config=config, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores), + compute_metrics=compute_metrics_fn, + data_args=data_args, + ) + + # Training + if training_args.do_train: + trainer.train( + model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None + ) + trainer.save_model() + # For convenience, we also re-save the tokenizer to the same directory, + # so that you can share your model easily on huggingface.co/models =) + if trainer.is_world_process_zero(): + trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json")) + tokenizer.save_pretrained(training_args.output_dir) + + # Evaluation + eval_results = {} + if training_args.do_eval: + logger.info("*** Evaluate ***") + + result = trainer.evaluate() + + if trainer.is_world_process_zero(): + logger.info("***** Eval results *****") + for key, value in result.items(): + logger.info(" %s = %s", key, value) + save_json(result, os.path.join(training_args.output_dir, "eval_results.json")) + eval_results.update(result) + + if training_args.do_predict: + logging.info("*** Test ***") + + test_output = trainer.predict(test_dataset=test_dataset) + test_metrics = {k.replace("eval", "test"): v for k, v in test_output.metrics.items()} + + if trainer.is_world_process_zero(): + logger.info("***** Test results *****") + for key, value in test_metrics.items(): + logger.info(" %s = %s", key, value) + + save_json(test_metrics, os.path.join(training_args.output_dir, "test_results.json")) + eval_results.update(test_metrics) + + if training_args.predict_with_generate: + test_preds = tokenizer.batch_decode( + test_output.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + test_preds = lmap(str.strip, test_preds) + write_txt_file(test_preds, os.path.join(training_args.output_dir, "test_generations.txt")) + + if trainer.is_world_process_zero(): + save_json(eval_results, "all_results.json") + return eval_results + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/src/mrc_client/seq2seq/lightning_base.py b/src/mrc_client/seq2seq/lightning_base.py new file mode 100644 index 0000000000000000000000000000000000000000..3f35ffe0f09ab1087dd1f014e2da04dc6e8a765c --- /dev/null +++ b/src/mrc_client/seq2seq/lightning_base.py @@ -0,0 +1,397 @@ +import argparse +import logging +import os +from pathlib import Path +from typing import Any, Dict + +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_info + +import pkg_resources +from transformers import ( + AdamW, + AutoConfig, + AutoModel, + AutoModelForPreTraining, + AutoModelForQuestionAnswering, + AutoModelForSeq2SeqLM, + AutoModelForSequenceClassification, + AutoModelForTokenClassification, + AutoModelWithLMHead, + AutoTokenizer, + PretrainedConfig, + PreTrainedTokenizer, +) +from transformers.optimization import ( + Adafactor, + get_cosine_schedule_with_warmup, + get_cosine_with_hard_restarts_schedule_with_warmup, + get_linear_schedule_with_warmup, + get_polynomial_decay_schedule_with_warmup, +) + + +logger = logging.getLogger(__name__) + +try: + pkg = "pytorch_lightning" + min_ver = "1.0.4" + pkg_resources.require(f"{pkg}>={min_ver}") +except pkg_resources.VersionConflict: + logger.warning( + f"{pkg}>={min_ver} is required for a normal functioning of this module, but found {pkg}=={pkg_resources.get_distribution(pkg).version}. Try pip install -r examples/requirements.txt" + ) + + +MODEL_MODES = { + "base": AutoModel, + "sequence-classification": AutoModelForSequenceClassification, + "question-answering": AutoModelForQuestionAnswering, + "pretraining": AutoModelForPreTraining, + "token-classification": AutoModelForTokenClassification, + "language-modeling": AutoModelWithLMHead, + "summarization": AutoModelForSeq2SeqLM, + "translation": AutoModelForSeq2SeqLM, +} + + +# update this and the import above to support new schedulers from transformers.optimization +arg_to_scheduler = { + "linear": get_linear_schedule_with_warmup, + "cosine": get_cosine_schedule_with_warmup, + "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup, + "polynomial": get_polynomial_decay_schedule_with_warmup, + # '': get_constant_schedule, # not supported for now + # '': get_constant_schedule_with_warmup, # not supported for now +} +arg_to_scheduler_choices = sorted(arg_to_scheduler.keys()) +arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}" + + +class BaseTransformer(pl.LightningModule): + def __init__( + self, + hparams: argparse.Namespace, + num_labels=None, + mode="base", + config=None, + tokenizer=None, + model=None, + **config_kwargs + ): + """Initialize a model, tokenizer and config.""" + super().__init__() + # TODO: move to self.save_hyperparameters() + # self.save_hyperparameters() + # can also expand arguments into trainer signature for easier reading + + self.save_hyperparameters(hparams) + self.step_count = 0 + self.output_dir = Path(self.hparams.output_dir) + cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None + if config is None: + self.config = AutoConfig.from_pretrained( + self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, + **({"num_labels": num_labels} if num_labels is not None else {}), + cache_dir=cache_dir, + **config_kwargs, + ) + else: + self.config: PretrainedConfig = config + + extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout") + for p in extra_model_params: + if getattr(self.hparams, p, None): + assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute" + setattr(self.config, p, getattr(self.hparams, p)) + + if tokenizer is None: + self.tokenizer = AutoTokenizer.from_pretrained( + self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path, + cache_dir=cache_dir, + ) + else: + self.tokenizer: PreTrainedTokenizer = tokenizer + self.model_type = MODEL_MODES[mode] + if model is None: + self.model = self.model_type.from_pretrained( + self.hparams.model_name_or_path, + from_tf=bool(".ckpt" in self.hparams.model_name_or_path), + config=self.config, + cache_dir=cache_dir, + ) + else: + self.model = model + + def load_hf_checkpoint(self, *args, **kwargs): + self.model = self.model_type.from_pretrained(*args, **kwargs) + + def get_lr_scheduler(self): + get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler] + scheduler = get_schedule_func( + self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps() + ) + scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} + return scheduler + + def configure_optimizers(self): + """Prepare optimizer and schedule (linear warmup and decay)""" + model = self.model + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": self.hparams.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + if self.hparams.adafactor: + optimizer = Adafactor( + optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False + ) + + else: + optimizer = AdamW( + optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon + ) + self.opt = optimizer + + scheduler = self.get_lr_scheduler() + + return [optimizer], [scheduler] + + def test_step(self, batch, batch_nb): + return self.validation_step(batch, batch_nb) + + def test_epoch_end(self, outputs): + return self.validation_end(outputs) + + def total_steps(self) -> int: + """The number of total training steps that will be run. Used for lr scheduler purposes.""" + num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores + effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices + return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs + + def setup(self, mode): + if mode == "test": + self.dataset_size = len(self.test_dataloader().dataset) + else: + self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True) + self.dataset_size = len(self.train_dataloader().dataset) + + def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False): + raise NotImplementedError("You must implement this for your task") + + def train_dataloader(self): + return self.train_loader + + def val_dataloader(self): + return self.get_dataloader("dev", self.hparams.eval_batch_size, shuffle=False) + + def test_dataloader(self): + return self.get_dataloader("test", self.hparams.eval_batch_size, shuffle=False) + + def _feature_file(self, mode): + return os.path.join( + self.hparams.data_dir, + "cached_{}_{}_{}".format( + mode, + list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(), + str(self.hparams.max_seq_length), + ), + ) + + @pl.utilities.rank_zero_only + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + save_path = self.output_dir.joinpath("best_tfmr") + self.model.config.save_step = self.step_count + self.model.save_pretrained(save_path) + self.tokenizer.save_pretrained(save_path) + + @staticmethod + def add_model_specific_args(parser, root_dir): + parser.add_argument( + "--model_name_or_path", + default=None, + type=str, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models", + ) + parser.add_argument( + "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name" + ) + parser.add_argument( + "--tokenizer_name", + default=None, + type=str, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--cache_dir", + default="", + type=str, + help="Where do you want to store the pre-trained models downloaded from s3", + ) + parser.add_argument( + "--encoder_layerdrop", + type=float, + help="Encoder layer dropout probability (Optional). Goes into model.config", + ) + parser.add_argument( + "--decoder_layerdrop", + type=float, + help="Decoder layer dropout probability (Optional). Goes into model.config", + ) + parser.add_argument( + "--dropout", + type=float, + help="Dropout probability (Optional). Goes into model.config", + ) + parser.add_argument( + "--attention_dropout", + type=float, + help="Attention dropout probability (Optional). Goes into model.config", + ) + parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") + parser.add_argument( + "--lr_scheduler", + default="linear", + choices=arg_to_scheduler_choices, + metavar=arg_to_scheduler_metavar, + type=str, + help="Learning rate scheduler", + ) + parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") + parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") + parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") + parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader") + parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int) + parser.add_argument("--train_batch_size", default=32, type=int) + parser.add_argument("--eval_batch_size", default=32, type=int) + parser.add_argument("--adafactor", action="store_true") + + +class LoggingCallback(pl.Callback): + def on_batch_end(self, trainer, pl_module): + lr_scheduler = trainer.lr_schedulers[0]["scheduler"] + lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())} + pl_module.logger.log_metrics(lrs) + + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + rank_zero_info("***** Validation results *****") + metrics = trainer.callback_metrics + # Log results + for key in sorted(metrics): + if key not in ["log", "progress_bar"]: + rank_zero_info("{} = {}\n".format(key, str(metrics[key]))) + + def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + rank_zero_info("***** Test results *****") + metrics = trainer.callback_metrics + # Log and save results to file + output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt") + with open(output_test_results_file, "w") as writer: + for key in sorted(metrics): + if key not in ["log", "progress_bar"]: + rank_zero_info("{} = {}\n".format(key, str(metrics[key]))) + writer.write("{} = {}\n".format(key, str(metrics[key]))) + + +def add_generic_args(parser, root_dir) -> None: + # To allow all pl args uncomment the following line + # parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument( + "--output_dir", + default=None, + type=str, + required=True, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--fp16", + action="store_true", + help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", + ) + + parser.add_argument( + "--fp16_opt_level", + type=str, + default="O2", + help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." + "See details at https://nvidia.github.io/apex/amp.html", + ) + parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int) + parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm") + parser.add_argument("--do_train", action="store_true", help="Whether to run training.") + parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.") + parser.add_argument( + "--gradient_accumulation_steps", + dest="accumulate_grad_batches", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") + parser.add_argument( + "--data_dir", + default=None, + type=str, + required=True, + help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.", + ) + + +def generic_train( + model: BaseTransformer, + args: argparse.Namespace, + early_stopping_callback=None, + logger=True, # can pass WandbLogger() here + extra_callbacks=[], + checkpoint_callback=None, + logging_callback=None, + **extra_train_kwargs +): + pl.seed_everything(args.seed) + + # init model + odir = Path(model.hparams.output_dir) + odir.mkdir(exist_ok=True) + + # add custom checkpoints + if checkpoint_callback is None: + checkpoint_callback = pl.callbacks.ModelCheckpoint( + filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1 + ) + if early_stopping_callback: + extra_callbacks.append(early_stopping_callback) + if logging_callback is None: + logging_callback = LoggingCallback() + + train_params = {} + + # TODO: remove with PyTorch 1.6 since pl uses native amp + if args.fp16: + train_params["precision"] = 16 + train_params["amp_level"] = args.fp16_opt_level + + if args.gpus > 1: + train_params["distributed_backend"] = "ddp" + + train_params["accumulate_grad_batches"] = args.accumulate_grad_batches + + trainer = pl.Trainer.from_argparse_args( + args, + weights_summary=None, + callbacks=[logging_callback] + extra_callbacks, + logger=logger, + checkpoint_callback=checkpoint_callback, + **train_params, + ) + + if args.do_train: + trainer.fit(model) + + return trainer diff --git a/src/mrc_client/seq2seq/rouge_cli.py b/src/mrc_client/seq2seq/rouge_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..62236d06a0ff57dccdc677e7380b1530d6248c1f --- /dev/null +++ b/src/mrc_client/seq2seq/rouge_cli.py @@ -0,0 +1,17 @@ +import fire + +from seq2seq_utils import calculate_rouge, save_json + + +def calculate_rouge_path(pred_path, tgt_path, save_path=None, **kwargs): + """Kwargs will be passed to calculate_rouge""" + pred_lns = [x.strip() for x in open(pred_path).readlines()] + tgt_lns = [x.strip() for x in open(tgt_path).readlines()][: len(pred_lns)] + metrics = calculate_rouge(pred_lns, tgt_lns, **kwargs) + if save_path is not None: + save_json(metrics, save_path, indent=None) + return metrics # these print nicely + + +if __name__ == "__main__": + fire.Fire(calculate_rouge_path) diff --git a/src/mrc_client/seq2seq/run_distributed_eval.py b/src/mrc_client/seq2seq/run_distributed_eval.py new file mode 100755 index 0000000000000000000000000000000000000000..783af80ff81376ee9784af10d3d8b99d79f7921e --- /dev/null +++ b/src/mrc_client/seq2seq/run_distributed_eval.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python + +import argparse +import shutil +import time +from json import JSONDecodeError +from logging import getLogger +from pathlib import Path +from typing import Dict, List + +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +from seq2seq_utils import ( + Seq2SeqDataset, + calculate_bleu, + calculate_rouge, + chunks, + lmap, + load_json, + parse_numeric_n_bool_cl_kwargs, + save_json, + use_task_specific_params, + write_txt_file, +) + + +logger = getLogger(__name__) + + +def eval_data_dir( + data_dir, + save_dir: str, + model_name: str, + bs: int = 8, + max_source_length: int = 1024, + type_path="val", + n_obs=None, + fp16=False, + task="summarization", + local_rank=None, + num_return_sequences=1, + dataset_kwargs: Dict = None, + prefix="", + **generate_kwargs, +) -> Dict: + """Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json""" + model_name = str(model_name) + assert local_rank is not None + torch.distributed.init_process_group(backend="nccl", rank=local_rank) + + save_dir = Path(save_dir) + save_path = save_dir.joinpath(f"rank_{local_rank}_output.json") + torch.cuda.set_device(local_rank) + model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda() + if fp16: + model = model.half() + # determine if we need to increase num_beams + use_task_specific_params(model, task) # update config with task specific params + num_beams = generate_kwargs.pop("num_beams", model.config.num_beams) # AttributeError risk? + if num_return_sequences > num_beams: + num_beams = num_return_sequences + + tokenizer = AutoTokenizer.from_pretrained(model_name) + logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") # if this is wrong, check config.model_type. + + if max_source_length is None: + max_source_length = tokenizer.model_max_length + if prefix is None: + prefix = prefix or getattr(model.config, "prefix", "") or "" + ds = Seq2SeqDataset( + tokenizer, + data_dir, + max_source_length, + max_target_length=1024, + type_path=type_path, + n_obs=n_obs, + prefix=prefix, + **dataset_kwargs, + ) + # I set shuffle=True for a more accurate progress bar. + # If all the longest samples are first, the prog bar estimate is too high at the beginning. + sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False, shuffle=True) + data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn) + results = [] + for batch in tqdm(data_loader): + summaries = model.generate( + input_ids=batch["input_ids"].to(model.device), + attention_mask=batch["attention_mask"].to(model.device), + num_return_sequences=num_return_sequences, + num_beams=num_beams, + **generate_kwargs, + ) + preds = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False) + ids = batch["ids"] + if num_return_sequences > 1: + preds = chunks(preds, num_return_sequences) # batch size chunks, each of size num_return_seq + for i, pred in enumerate(preds): + results.append(dict(pred=pred, id=ids[i].item())) + save_json(results, save_path) + return results, sampler.num_replicas + + +def run_generate(): + parser = argparse.ArgumentParser( + epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate" + ) + parser.add_argument("--data_dir", type=str, help="like cnn_dm/test.source") + parser.add_argument( + "--model_name", + type=str, + help="like facebook/bart-large-cnn,t5-base, etc.", + default="sshleifer/distilbart-xsum-12-3", + ) + parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen") + parser.add_argument("--max_source_length", type=int, default=None) + parser.add_argument( + "--type_path", type=str, default="test", help="which subset to evaluate typically train/val/test" + ) + parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics") + parser.add_argument("--bs", type=int, default=8, required=False, help="batch size") + parser.add_argument( + "--local_rank", type=int, default=-1, required=False, help="should be passed by distributed.launch" + ) + + parser.add_argument( + "--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all." + ) + parser.add_argument( + "--num_return_sequences", type=int, default=1, required=False, help="How many sequences to return" + ) + parser.add_argument( + "--sync_timeout", + type=int, + default=600, + required=False, + help="How long should master process wait for other processes to finish.", + ) + parser.add_argument("--src_lang", type=str, default=None, required=False) + parser.add_argument("--tgt_lang", type=str, default=None, required=False) + parser.add_argument( + "--prefix", type=str, required=False, default=None, help="will be added to the begininng of src examples" + ) + parser.add_argument("--fp16", action="store_true") + parser.add_argument("--debug", action="store_true") + start_time = time.time() + args, rest = parser.parse_known_args() + generate_kwargs = parse_numeric_n_bool_cl_kwargs(rest) + if generate_kwargs and args.local_rank <= 0: + print(f"parsed the following generate kwargs: {generate_kwargs}") + json_save_dir = Path(args.save_dir + "_tmp") + Path(json_save_dir).mkdir(exist_ok=True) # this handles locking. + intermediate_files = list(json_save_dir.glob("rank_*.json")) + if intermediate_files: + raise ValueError(f"Found files at {json_save_dir} please move or remove them.") + # In theory, a node could finish and save before another node hits this. If this happens, we can address later. + dataset_kwargs = {} + if args.src_lang is not None: + dataset_kwargs["src_lang"] = args.src_lang + if args.tgt_lang is not None: + dataset_kwargs["tgt_lang"] = args.tgt_lang + + Path(args.save_dir).mkdir(exist_ok=True) + results, num_replicas = eval_data_dir( + args.data_dir, + json_save_dir, + args.model_name, + type_path=args.type_path, + bs=args.bs, + fp16=args.fp16, + task=args.task, + local_rank=args.local_rank, + n_obs=args.n_obs, + max_source_length=args.max_source_length, + num_return_sequences=args.num_return_sequences, + prefix=args.prefix, + dataset_kwargs=dataset_kwargs, + **generate_kwargs, + ) + + if args.local_rank <= 0: + save_dir = Path(args.save_dir) + save_dir.mkdir(exist_ok=True) + partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout) + preds = combine_partial_results(partial_results) + if args.num_return_sequences > 1: + save_path = save_dir.joinpath("pseudolabel_results.json") + print(f"Saving aggregated results at {save_path}, intermediate in {json_save_dir}/") + save_json(preds, save_path) + return + tgt_file = Path(args.data_dir).joinpath(args.type_path + ".target") + labels = [x.rstrip() for x in open(tgt_file).readlines()][: len(preds)] + + # Calculate metrics, save metrics, and save _generations.txt + calc_bleu = "translation" in args.task + score_fn = calculate_bleu if calc_bleu else calculate_rouge + metric_name = "bleu" if calc_bleu else "rouge" + metrics: Dict = score_fn(preds, labels) + metrics["n_obs"] = len(preds) + runtime = time.time() - start_time + metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 4) + metrics["n_gpus"] = num_replicas + # TODO(@stas00): add whatever metadata to metrics + metrics_save_path = save_dir.joinpath(f"{args.type_path}_{metric_name}.json") + save_json(metrics, metrics_save_path, indent=None) + print(metrics) + write_txt_file(preds, save_dir.joinpath(f"{args.type_path}_generations.txt")) + if args.debug: + write_txt_file(labels, save_dir.joinpath(f"{args.type_path}.target")) + else: + shutil.rmtree(json_save_dir) + + +def combine_partial_results(partial_results) -> List: + """Concatenate partial results into one file, then sort it by id.""" + records = [] + for partial_result in partial_results: + records.extend(partial_result) + records = list(sorted(records, key=lambda x: x["id"])) + preds = [x["pred"] for x in records] + return preds + + +def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]: + # WAIT FOR lots of .json files + start_wait = time.time() + logger.info("waiting for all nodes to finish") + json_data = None + while (time.time() - start_wait) < timeout: + json_files = list(save_dir.glob("rank_*.json")) + if len(json_files) < num_replicas: + continue + try: + # make sure all json files are fully saved + json_data = lmap(load_json, json_files) + return json_data + except JSONDecodeError: + continue + else: + raise TimeoutError("Rank 0 gave up on waiting for other processes") + # Unreachable + + +if __name__ == "__main__": + # Usage for MT: + run_generate() diff --git a/src/mrc_client/seq2seq/run_eval.py b/src/mrc_client/seq2seq/run_eval.py new file mode 100755 index 0000000000000000000000000000000000000000..b10223194499ee9b45892db4c4cb1e3fa6f89691 --- /dev/null +++ b/src/mrc_client/seq2seq/run_eval.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python + +import argparse +import datetime +import json +import time +import warnings +from logging import getLogger +from pathlib import Path +from typing import Dict, List + +import torch +from tqdm import tqdm + +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +from seq2seq_utils import calculate_bleu, calculate_rouge, chunks, parse_numeric_n_bool_cl_kwargs, use_task_specific_params + + +logger = getLogger(__name__) + + +DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + +def generate_summaries_or_translations( + examples: List[str], + out_file: str, + model_name: str, + batch_size: int = 8, + device: str = DEFAULT_DEVICE, + fp16=False, + task="summarization", + prefix=None, + **generate_kwargs, +) -> Dict: + """Save model.generate results to , and return how long it took.""" + fout = Path(out_file).open("w", encoding="utf-8") + model_name = str(model_name) + model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device) + if fp16: + model = model.half() + + tokenizer = AutoTokenizer.from_pretrained(model_name) + logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") # if this is wrong, check config.model_type. + + start_time = time.time() + # update config with task specific params + use_task_specific_params(model, task) + if prefix is None: + prefix = prefix or getattr(model.config, "prefix", "") or "" + for examples_chunk in tqdm(list(chunks(examples, batch_size))): + examples_chunk = [prefix + text for text in examples_chunk] + batch = tokenizer(examples_chunk, return_tensors="pt", truncation=True, padding="longest").to(device) + summaries = model.generate( + input_ids=batch.input_ids, + attention_mask=batch.attention_mask, + **generate_kwargs, + ) + dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False) + for hypothesis in dec: + fout.write(hypothesis + "\n") + fout.flush() + fout.close() + runtime = int(time.time() - start_time) # seconds + n_obs = len(examples) + return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4)) + + +def datetime_now(): + return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def run_generate(verbose=True): + """ + + Takes input text, generates output, and then using reference calculates the BLEU scores. + + The results are saved to a file and returned to the caller, and printed out unless ``verbose=False`` is passed. + + Args: + verbose (:obj:`bool`, `optional`, defaults to :obj:`True`): print results to stdout + + Returns: + a tuple: ``(scores, params}`` + - ``scores``: a dict of scores data ``{'bleu': 39.6501, 'n_obs': 2000, 'runtime': 186, 'seconds_per_sample': 0.093}`` + - ``params``: a dict of custom params, e.g. ``{'num_beams': 5, 'length_penalty': 0.8}`` + """ + + parser = argparse.ArgumentParser() + parser.add_argument("--model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.") + parser.add_argument("--input_path", type=str, help="like cnn_dm/test.source") + parser.add_argument("--save_path", type=str, help="where to save summaries") + parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target") + parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics") + parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.") + parser.add_argument( + "--prefix", type=str, required=False, default=None, help="will be added to the begininng of src examples" + ) + parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics") + parser.add_argument("--bs", type=int, default=8, required=False, help="batch size") + parser.add_argument( + "--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all." + ) + parser.add_argument("--fp16", action="store_true") + parser.add_argument("--dump-args", action="store_true", help="print the custom hparams with the results") + parser.add_argument( + "--info", + nargs="?", + type=str, + const=datetime_now(), + help="use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g. lang=en-ru. If no value is passed, the current datetime string will be used.", + ) + # Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate + args, rest = parser.parse_known_args() + parsed_args = parse_numeric_n_bool_cl_kwargs(rest) + if parsed_args and verbose: + print(f"parsed the following generate kwargs: {parsed_args}") + examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()] + if args.n_obs > 0: + examples = examples[: args.n_obs] + Path(args.save_path).parent.mkdir(exist_ok=True) + if args.reference_path is None and Path(args.score_path).exists(): + warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.") + runtime_metrics = generate_summaries_or_translations( + examples, + args.save_path, + args.model_name, + batch_size=args.bs, + device=args.device, + fp16=args.fp16, + task=args.task, + prefix=args.prefix, + **parsed_args, + ) + + if args.reference_path is None: + return {} + + # Compute scores + score_fn = calculate_bleu if "translation" in args.task else calculate_rouge + output_lns = [x.rstrip() for x in open(args.save_path).readlines()] + reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)] + scores: dict = score_fn(output_lns, reference_lns) + scores.update(runtime_metrics) + + if args.dump_args: + scores.update(parsed_args) + if args.info: + scores["info"] = args.info + + if verbose: + print(scores) + + if args.score_path is not None: + json.dump(scores, open(args.score_path, "w")) + + return scores + + +if __name__ == "__main__": + # Usage for MT: + # python run_eval.py MODEL_NAME $DATA_DIR/test.source $save_dir/test_translations.txt --reference_path $DATA_DIR/test.target --score_path $save_dir/test_bleu.json --task translation $@ + run_generate(verbose=True) diff --git a/src/mrc_client/seq2seq/run_eval_search.py b/src/mrc_client/seq2seq/run_eval_search.py new file mode 100755 index 0000000000000000000000000000000000000000..71599a2b664a7f36e1391499de72d1bf88633cfd --- /dev/null +++ b/src/mrc_client/seq2seq/run_eval_search.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python + +import argparse +import itertools +import operator +import sys +from collections import OrderedDict + +from run_eval import datetime_now, run_generate +from seq2seq_utils import ROUGE_KEYS + + +# A table of supported tasks and the list of scores in the order of importance to be sorted by. +# To add a new task, simply list the score names that `run_eval.run_generate()` returns +task_score_names = { + "translation": ["bleu"], + "summarization": ROUGE_KEYS, +} + + +def parse_search_arg(search): + groups = search.split() + entries = {k: vs for k, vs in (g.split("=") for g in groups)} + entry_names = list(entries.keys()) + sets = [list((f"--{k} {v}") for v in vs.split(":")) for k, vs in entries.items()] + matrix = [list(x) for x in itertools.product(*sets)] + return matrix, entry_names + + +def run_search(): + """ + Run parametric search over the desired hparam space with help of ``run_eval.py``. + + All the arguments except ``--search`` are passed to ``run_eval.py`` as is. The values inside of "--search" are parsed, reformatted and fed to ``run_eval.py`` as additional args. + + The format for the ``--search`` value is a simple string with hparams and colon separated values to try, e.g.: + ``` + --search "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false" + ``` + which will generate ``12`` ``(2*3*2)`` searches for a product of each hparam. For example the example that was just used will invoke ``run_eval.py`` repeatedly with: + + ``` + --num_beams 5 --length_penalty 0.8 --early_stopping true + --num_beams 5 --length_penalty 0.8 --early_stopping false + [...] + --num_beams 10 --length_penalty 1.2 --early_stopping false + ``` + + On completion, this function prints a markdown table of the results sorted by the best BLEU score and the winning arguments. + + + """ + prog = sys.argv[0] + + parser = argparse.ArgumentParser( + usage="\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore refer to `run_eval.py -h` for the complete list." + ) + parser.add_argument( + "--search", + type=str, + required=False, + help='param space to search, e.g. "num_beams=5:10 length_penalty=0.8:1.0:1.2"', + ) + parser.add_argument( + "--bs", type=int, default=8, required=False, help="initial batch size (may get reduced if it's too big)" + ) + parser.add_argument("--task", type=str, help="used for task_specific_params + metrics") + parser.add_argument( + "--info", + nargs="?", + type=str, + const=datetime_now(), + help="add custom notes to be printed before the results table. If no value is passed, the current datetime string will be used.", + ) + args, args_main = parser.parse_known_args() + # we share some of the args + args_main.extend(["--task", args.task]) + args_normal = [prog] + args_main + + # to support variations like translation_en_to_de" + task = "translation" if "translation" in args.task else "summarization" + + matrix, col_names = parse_search_arg(args.search) + col_names[0:0] = task_score_names[task] # score cols first + col_widths = {col: len(str(col)) for col in col_names} + results = [] + for r in matrix: + hparams = {k: v for k, v in (x.replace("--", "").split() for x in r)} + args_exp = " ".join(r).split() + args_exp.extend(["--bs", str(args.bs)]) # in case we need to reduce its size due to CUDA OOM + sys.argv = args_normal + args_exp + + # XXX: need to trap CUDA OOM and lower args.bs if that happens and retry + + scores = run_generate(verbose=False) + # make sure scores are first in the table + result = OrderedDict() + for score in task_score_names[task]: + result[score] = scores[score] + result.update(hparams) + results.append(result) + + # find widest entries + for k, v in result.items(): + l = len(str(v)) + if l > col_widths[k]: + col_widths[k] = l + + results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[task]), reverse=True) + print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names])) + print(" | ".join([f"{'-'*col_widths[col]}" for col in col_names])) + for row in results_sorted: + print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names])) + + best = results_sorted[0] + for score in task_score_names[task]: + del best[score] + best_args = [f"--{k} {v}" for k, v in best.items()] + dyn_args = ["--bs", str(args.bs)] + if args.info: + print(f"\nInfo: {args.info}") + print("\nBest score args:") + print(" ".join(args_main + best_args + dyn_args)) + + return results_sorted + + +if __name__ == "__main__": + # Usage: + # [normal-run_eval_search.py cmd plus] \ + # --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false" + # + # Example: + # PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval_search.py $MODEL_NAME \ + # $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target \ + # --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation \ + # --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false" + run_search() diff --git a/src/mrc_client/seq2seq/scripts/test.sh b/src/mrc_client/seq2seq/scripts/test.sh new file mode 100644 index 0000000000000000000000000000000000000000..068117074b901a9b31d0f111f784b2f85f47936c --- /dev/null +++ b/src/mrc_client/seq2seq/scripts/test.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +echo "--input_path --model_name --save_path --reference_path" + +python3 run_eval.py \ + --max_length 15 \ + --min_length 1 \ + "$@" + diff --git a/src/mrc_client/seq2seq/scripts/train.sh b/src/mrc_client/seq2seq/scripts/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..eda1cb85f521c111c484cda17aa760d3756ac5f7 --- /dev/null +++ b/src/mrc_client/seq2seq/scripts/train.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +echo "--data_dir --model_name_or_path=facebook/bart-base --output_dir" + +python3 finetune.py \ + --gpus 8 \ + --do_train \ + --train_batch_size 16 \ + --eval_batch_size 32 \ + --gradient_accumulation_steps 2 \ + --num_train_epochs 20 \ + --max_source_length 400 \ + --max_target_length 15 \ + --val_max_target_length 15 \ + --test_max_target_length 15 \ + --min_target_length 1 \ + --val_check_interval 0.5 \ + --n_val -1 \ + --save_top_k 5 \ + --logger_name wandb \ + --overwrite_output_dir \ + "$@" + +python3 cjjpy.py --lark "training mrc-seq2seq completed" diff --git a/src/mrc_client/seq2seq/sentence_splitter.py b/src/mrc_client/seq2seq/sentence_splitter.py new file mode 100644 index 0000000000000000000000000000000000000000..c5acec73928ccd00dcf049601ebdf37bcdf4cfea --- /dev/null +++ b/src/mrc_client/seq2seq/sentence_splitter.py @@ -0,0 +1,22 @@ +import re + +from filelock import FileLock + + +try: + import nltk + + NLTK_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + NLTK_AVAILABLE = False + +if NLTK_AVAILABLE: + with FileLock(".lock") as lock: + nltk.download("punkt", quiet=True) + + +def add_newline_to_end_of_each_sentence(x: str) -> str: + """This was added to get rougeLsum scores matching published rougeL scores for BART and PEGASUS.""" + re.sub("", "", x) # remove pegasus newline char + assert NLTK_AVAILABLE, "nltk must be installed to separate newlines between sentences. (pip install nltk)" + return "\n".join(nltk.sent_tokenize(x)) diff --git a/src/mrc_client/seq2seq/seq2seq_trainer.py b/src/mrc_client/seq2seq/seq2seq_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..f643c279b9ec8dfc4b91c693b47439ec1e2e08bd --- /dev/null +++ b/src/mrc_client/seq2seq/seq2seq_trainer.py @@ -0,0 +1,226 @@ +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from torch import nn +from torch.utils.data import DistributedSampler, RandomSampler + +from transformers import PreTrainedModel, Trainer, logging +from transformers.configuration_fsmt import FSMTConfig +from transformers.file_utils import is_torch_tpu_available +from transformers.optimization import ( + Adafactor, + AdamW, + get_constant_schedule, + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + get_cosine_with_hard_restarts_schedule_with_warmup, + get_linear_schedule_with_warmup, + get_polynomial_decay_schedule_with_warmup, +) +from transformers.trainer_pt_utils import get_tpu_sampler + + +logger = logging.get_logger(__name__) + +arg_to_scheduler = { + "linear": get_linear_schedule_with_warmup, + "cosine": get_cosine_schedule_with_warmup, + "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup, + "polynomial": get_polynomial_decay_schedule_with_warmup, + "constant": get_constant_schedule, + "constant_w_warmup": get_constant_schedule_with_warmup, +} + + +class Seq2SeqTrainer(Trainer): + def __init__(self, config=None, data_args=None, *args, **kwargs): + super().__init__(*args, **kwargs) + + if config is None: + assert isinstance( + self.model, PreTrainedModel + ), f"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is {self.model.__class__}" + self.config = self._actual_model(self.model).config + else: + self.config = config + + self.data_args = data_args + self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size + + if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss): + assert ( + self.config.pad_token_id is not None + ), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing." + + if self.config.pad_token_id is None and self.config.eos_token_id is not None: + logger.warn( + f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.." + ) + + if self.args.label_smoothing == 0: + self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id) + else: + # dynamically import label_smoothed_nll_loss + from seq2seq_utils import label_smoothed_nll_loss + + self.loss_fn = label_smoothed_nll_loss + + def create_optimizer_and_scheduler(self, num_training_steps: int): + """ + Setup the optimizer and the learning rate scheduler. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass. + """ + if self.optimizer is None: + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": self.args.weight_decay, + }, + { + "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + if self.args.adafactor: + self.optimizer = Adafactor( + optimizer_grouped_parameters, + lr=self.args.learning_rate, + scale_parameter=False, + relative_step=False, + ) + + else: + self.optimizer = AdamW( + optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon + ) + + if self.lr_scheduler is None: + self.lr_scheduler = self._get_lr_scheduler(num_training_steps) + else: # ignoring --lr_scheduler + logger.warn("scheduler is passed to `Seq2SeqTrainer`, `--lr_scheduler` arg is ignored.") + + def _get_lr_scheduler(self, num_training_steps): + schedule_func = arg_to_scheduler[self.args.lr_scheduler] + if self.args.lr_scheduler == "constant": + scheduler = schedule_func(self.optimizer) + elif self.args.lr_scheduler == "constant_w_warmup": + scheduler = schedule_func(self.optimizer, num_warmup_steps=self.args.warmup_steps) + else: + scheduler = schedule_func( + self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps + ) + return scheduler + + def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: + if isinstance(self.train_dataset, torch.utils.data.IterableDataset): + return None + elif is_torch_tpu_available(): + return get_tpu_sampler(self.train_dataset) + else: + if self.args.sortish_sampler: + self.train_dataset.make_sortish_sampler( + self.args.per_device_train_batch_size, distributed=self.args.n_gpu > 1 + ) + + return ( + RandomSampler(self.train_dataset) + if self.args.local_rank == -1 + else DistributedSampler(self.train_dataset) + ) + + def _compute_loss(self, model, inputs, labels): + if self.args.label_smoothing == 0: + if self.data_args is not None and self.data_args.ignore_pad_token_for_loss: + # force training to ignore pad token + logits = model(**inputs, use_cache=False)[0] + loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1)) + else: + # compute usual loss via models + loss, logits = model(**inputs, labels=labels, use_cache=False)[:2] + else: + # compute label smoothed loss + logits = model(**inputs, use_cache=False)[0] + lprobs = torch.nn.functional.log_softmax(logits, dim=-1) + loss, _ = self.loss_fn(lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id) + return loss, logits + + def compute_loss(self, model, inputs): + labels = inputs.pop("labels") + loss, _ = self._compute_loss(model, inputs, labels) + return loss + + def prediction_step( + self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool + ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform an evaluation step on :obj:`model` using obj:`inputs`. + + Subclass and override to inject custom behavior. + + Args: + model (:obj:`nn.Module`): + The model to evaluate. + inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument :obj:`labels`. Check your model's documentation for all accepted arguments. + prediction_loss_only (:obj:`bool`): + Whether or not to return the loss only. + + Return: + Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + A tuple with the loss, logits and labels (each being optional). + """ + inputs = self._prepare_inputs(inputs) + + gen_kwargs = { + "max_length": self.data_args.val_max_target_length + if self.data_args is not None + else self.config.max_length, + "num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams, + } + + if self.args.predict_with_generate and not self.args.prediction_loss_only: + generated_tokens = model.generate( + inputs["input_ids"], + attention_mask=inputs["attention_mask"], + **gen_kwargs, + ) + # in case the batch is shorter than max length, the output should be padded + if generated_tokens.shape[-1] < gen_kwargs["max_length"]: + generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) + + labels = inputs.pop("labels") + with torch.no_grad(): + # compute loss on predict data + loss, logits = self._compute_loss(model, inputs, labels) + + loss = loss.mean().detach() + if self.args.prediction_loss_only: + return (loss, None, None) + + logits = generated_tokens if self.args.predict_with_generate else logits + + if labels.shape[-1] < gen_kwargs["max_length"]: + labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) + + return (loss, logits, labels) + + def _pad_tensors_to_max_len(self, tensor, max_length): + # If PAD token is not defined at least EOS token has to be defined + pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else self.config.eos_token_id + + if pad_token_id is None: + raise ValueError( + f"Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be padded to `max_length`={max_length}" + ) + + padded_tensor = pad_token_id * torch.ones( + (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device + ) + padded_tensor[:, : tensor.shape[-1]] = tensor + return padded_tensor diff --git a/src/mrc_client/seq2seq/seq2seq_training_args.py b/src/mrc_client/seq2seq/seq2seq_training_args.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd486026a2b4550ef2f5bd1086b0fe847400db3 --- /dev/null +++ b/src/mrc_client/seq2seq/seq2seq_training_args.py @@ -0,0 +1,45 @@ +import logging +from dataclasses import dataclass, field +from typing import Optional + +from seq2seq_trainer import arg_to_scheduler +from transformers import TrainingArguments + + +logger = logging.getLogger(__name__) + + +@dataclass +class Seq2SeqTrainingArguments(TrainingArguments): + """ + Parameters: + label_smoothing (:obj:`float`, `optional`, defaults to 0): + The label smoothing epsilon to apply (if not zero). + sortish_sampler (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to SortishSamler or not. It sorts the inputs according to lenghts in-order to minimizing the padding size. + predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to use generate to calculate generative metrics (ROUGE, BLEU). + """ + + label_smoothing: Optional[float] = field( + default=0.0, metadata={"help": "The label smoothing epsilon to apply (if not zero)."} + ) + sortish_sampler: bool = field(default=False, metadata={"help": "Whether to SortishSamler or not."}) + predict_with_generate: bool = field( + default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."} + ) + adafactor: bool = field(default=False, metadata={"help": "whether to use adafactor"}) + encoder_layerdrop: Optional[float] = field( + default=None, metadata={"help": "Encoder layer dropout probability. Goes into model.config."} + ) + decoder_layerdrop: Optional[float] = field( + default=None, metadata={"help": "Decoder layer dropout probability. Goes into model.config."} + ) + dropout: Optional[float] = field(default=None, metadata={"help": "Dropout probability. Goes into model.config."}) + attention_dropout: Optional[float] = field( + default=None, metadata={"help": "Attention dropout probability. Goes into model.config."} + ) + lr_scheduler: Optional[str] = field( + default="linear", + metadata={"help": f"Which lr scheduler to use. Selected in {sorted(arg_to_scheduler.keys())}"}, + ) diff --git a/src/mrc_client/seq2seq/seq2seq_utils.py b/src/mrc_client/seq2seq/seq2seq_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5fb875871d84ebd92adf70a5acab74c852cb0e4c --- /dev/null +++ b/src/mrc_client/seq2seq/seq2seq_utils.py @@ -0,0 +1,672 @@ +import itertools +import json +import linecache +import math +import os +import pickle +import socket +from logging import getLogger +from pathlib import Path +from typing import Callable, Dict, Iterable, List, Tuple, Union + +import git +import numpy as np +import torch +import torch.distributed as dist +from rouge_score import rouge_scorer, scoring +from sacrebleu import corpus_bleu +from torch import nn +from torch.utils.data import Dataset, Sampler + +from sentence_splitter import add_newline_to_end_of_each_sentence +from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer +from transformers.file_utils import cached_property +try: + from transformers.modeling_bart import shift_tokens_right +except: + from transformers.models.bart.modeling_bart import shift_tokens_right + +try: + from fairseq.data.data_utils import batch_by_size + + FAIRSEQ_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + FAIRSEQ_AVAILABLE = False + + +def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): + """From fairseq""" + if target.dim() == lprobs.dim() - 1: + target = target.unsqueeze(-1) + nll_loss = -lprobs.gather(dim=-1, index=target) + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + if ignore_index is not None: + pad_mask = target.eq(ignore_index) + nll_loss.masked_fill_(pad_mask, 0.0) + smooth_loss.masked_fill_(pad_mask, 0.0) + else: + nll_loss = nll_loss.squeeze(-1) + smooth_loss = smooth_loss.squeeze(-1) + + nll_loss = nll_loss.sum() # mean()? Scared to break other math. + smooth_loss = smooth_loss.sum() + eps_i = epsilon / lprobs.size(-1) + loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss + return loss, nll_loss + + +def lmap(f: Callable, x: Iterable) -> List: + """list(map(f, x))""" + return list(map(f, x)) + + +def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict: + """Uses sacrebleu's corpus_bleu implementation.""" + return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)} + + +def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]: + def non_pad_len(tokens: np.ndarray) -> int: + return np.count_nonzero(tokens != tokenizer.pad_token_id) + + def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]: + pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True) + label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True) + pred_str = lmap(str.strip, pred_str) + label_str = lmap(str.strip, label_str) + return pred_str, label_str + + def summarization_metrics(pred: EvalPrediction) -> Dict: + pred_str, label_str = decode_pred(pred) + rouge: Dict = calculate_rouge(pred_str, label_str) + summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1) + rouge.update({"gen_len": summ_len}) + return rouge + + def translation_metrics(pred: EvalPrediction) -> Dict: + pred_str, label_str = decode_pred(pred) + bleu: Dict = calculate_bleu(pred_str, label_str) + gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1) + bleu.update({"gen_len": gen_len}) + return bleu + + compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics + return compute_metrics_fn + + +def trim_batch( + input_ids, + pad_token_id, + attention_mask=None, +): + """Remove columns that are populated exclusively by pad_token_id""" + keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) + if attention_mask is None: + return input_ids[:, keep_column_mask] + else: + return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) + + +class AbstractSeq2SeqDataset(Dataset): + def __init__( + self, + tokenizer, + data_dir, + max_source_length, + max_target_length, + type_path="train", + n_obs=None, + prefix="", + **dataset_kwargs + ): + super().__init__() + self.src_file = Path(data_dir).joinpath(type_path + ".source") + self.tgt_file = Path(data_dir).joinpath(type_path + ".target") + self.len_file = Path(data_dir).joinpath(type_path + ".len") + if os.path.exists(self.len_file): + self.src_lens = pickle_load(self.len_file) + self.used_char_len = False + else: + self.src_lens = self.get_char_lens(self.src_file) + self.used_char_len = True + self.max_source_length = max_source_length + self.max_target_length = max_target_length + assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" + self.tokenizer = tokenizer + self.prefix = prefix if prefix is not None else "" + + if n_obs is not None: + self.src_lens = self.src_lens[:n_obs] + self.pad_token_id = self.tokenizer.pad_token_id + self.dataset_kwargs = dataset_kwargs + dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {}) + + def __len__(self): + return len(self.src_lens) + + @staticmethod + def get_char_lens(data_file): + return [len(x) for x in Path(data_file).open().readlines()] + + @cached_property + def tgt_lens(self): + """Length in characters of target documents""" + return self.get_char_lens(self.tgt_file) + + def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs): + if distributed: + return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs) + else: + return SortishSampler(self.src_lens, batch_size, shuffle=shuffle) + + def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs): + assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`" + assert not self.used_char_len, "You must call python make_len_file.py before calling make_dynamic_sampler" + sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False)) + + def num_tokens_in_example(i): + return min(self.src_lens[i], self.max_target_length) + + # call fairseq cython function + batch_sampler: List[List[int]] = batch_by_size( + sorted_indices, + num_tokens_fn=num_tokens_in_example, + max_tokens=max_tokens_per_batch, + required_batch_size_multiple=64, + ) + shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))] + # move the largest batch to the front to OOM quickly (uses an approximation for padding) + approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches] + largest_batch_idx = np.argmax(approximate_toks_per_batch) + shuffled_batches[0], shuffled_batches[largest_batch_idx] = ( + shuffled_batches[largest_batch_idx], + shuffled_batches[0], + ) + return shuffled_batches + + def __getitem__(self, item): + raise NotImplementedError("You must implement this") + + def collate_fn(self, batch): + raise NotImplementedError("You must implement this") + + +class LegacySeq2SeqDataset(AbstractSeq2SeqDataset): + def __getitem__(self, index) -> Dict[str, torch.Tensor]: + """Call tokenizer on src and tgt_lines""" + index = index + 1 # linecache starts at 1 + source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") + tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") + assert source_line, f"empty source line for index {index}" + assert tgt_line, f"empty tgt line for index {index}" + source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length) + target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length) + + source_ids = source_inputs["input_ids"].squeeze() + target_ids = target_inputs["input_ids"].squeeze() + src_mask = source_inputs["attention_mask"].squeeze() + return { + "input_ids": source_ids, + "attention_mask": src_mask, + "labels": target_ids, + } + + def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"): + """Only used by LegacyDataset""" + return tokenizer( + [line], + max_length=max_length, + padding="max_length" if pad_to_max_length else None, + truncation=True, + return_tensors=return_tensors, + **self.dataset_kwargs, + ) + + def collate_fn(self, batch) -> Dict[str, torch.Tensor]: + input_ids = torch.stack([x["input_ids"] for x in batch]) + masks = torch.stack([x["attention_mask"] for x in batch]) + target_ids = torch.stack([x["labels"] for x in batch]) + pad_token_id = self.pad_token_id + y = trim_batch(target_ids, pad_token_id) + source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) + batch = { + "input_ids": source_ids, + "attention_mask": source_mask, + "labels": y, + } + return batch + + +class Seq2SeqDataset(AbstractSeq2SeqDataset): + """A dataset that calls prepare_seq2seq_batch.""" + + def __getitem__(self, index) -> Dict[str, str]: + index = index + 1 # linecache starts at 1 + source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") + tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") + assert source_line, f"empty source line for index {index}" + assert tgt_line, f"empty tgt line for index {index}" + return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1} + + def collate_fn(self, batch) -> Dict[str, torch.Tensor]: + """Call prepare_seq2seq_batch.""" + batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch( + [x["src_texts"] for x in batch], + tgt_texts=[x["tgt_texts"] for x in batch], + max_length=self.max_source_length, + max_target_length=self.max_target_length, + return_tensors="pt", + **self.dataset_kwargs, + ).data + batch_encoding["ids"] = torch.tensor([x["id"] for x in batch]) + return batch_encoding + + +class UniQASeq2SeqDataset(AbstractSeq2SeqDataset): + """A dataset that calls prepare_seq2seq_batch.""" + + def __getitem__(self, index) -> Dict[str, str]: + index = index + 1 # linecache starts at 1 + source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n").replace('', '\n') + tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") + assert source_line, f"empty source line for index {index}" + assert tgt_line, f"empty tgt line for index {index}" + return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1} + + def collate_fn(self, batch) -> Dict[str, torch.Tensor]: + """Call prepare_seq2seq_batch.""" + batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch( + [x["src_texts"] for x in batch], + tgt_texts=[x["tgt_texts"] for x in batch], + max_length=self.max_source_length, + max_target_length=self.max_target_length, + return_tensors="pt", + **self.dataset_kwargs, + ).data + batch_encoding["ids"] = torch.tensor([x["id"] for x in batch]) + return batch_encoding + + +class Seq2SeqDataCollator: + def __init__(self, tokenizer, data_args, tpu_num_cores=None): + self.tokenizer = tokenizer + self.pad_token_id = tokenizer.pad_token_id + assert ( + self.pad_token_id is not None + ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined." + self.data_args = data_args + self.tpu_num_cores = tpu_num_cores + self.dataset_kwargs = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} + if data_args.src_lang is not None: + self.dataset_kwargs["src_lang"] = data_args.src_lang + if data_args.tgt_lang is not None: + self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang + + def __call__(self, batch) -> Dict[str, torch.Tensor]: + if hasattr(self.tokenizer, "prepare_seq2seq_batch"): + batch = self._encode(batch) + input_ids, attention_mask, labels = ( + batch["input_ids"], + batch["attention_mask"], + batch["labels"], + ) + else: + input_ids = torch.stack([x["input_ids"] for x in batch]) + attention_mask = torch.stack([x["attention_mask"] for x in batch]) + labels = torch.stack([x["labels"] for x in batch]) + + labels = trim_batch(labels, self.pad_token_id) + input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask) + + if isinstance(self.tokenizer, T5Tokenizer): + decoder_input_ids = self._shift_right_t5(labels) + else: + decoder_input_ids = shift_tokens_right(labels, self.pad_token_id) + + batch = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "labels": labels, + } + return batch + + def _shift_right_t5(self, input_ids): + # shift inputs to the right + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = self.pad_token_id + return shifted_input_ids + + def _encode(self, batch) -> Dict[str, torch.Tensor]: + batch_encoding = self.tokenizer.prepare_seq2seq_batch( + [x["src_texts"] for x in batch], + tgt_texts=[x["tgt_texts"] for x in batch], + max_length=self.data_args.max_source_length, + max_target_length=self.data_args.max_target_length, + padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack + return_tensors="pt", + **self.dataset_kwargs, + ) + return batch_encoding.data + + +class SortishSampler(Sampler): + "Go through the text data by order of src length with a bit of randomness. From fastai repo." + + def __init__(self, data, batch_size, shuffle=True): + self.data, self.bs, self.shuffle = data, batch_size, shuffle + + def __len__(self) -> int: + return len(self.data) + + def __iter__(self): + return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle)) + + +def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array: + "Go through the text data by order of src length with a bit of randomness. From fastai repo." + if not shuffle: + return np.argsort(np.array(data) * -1) + + def key_fn(i): + return data[i] + + idxs = np.random.permutation(len(data)) + sz = bs * 50 + ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)] + sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx]) + sz = bs + ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)] + max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx]) # find the chunk with the largest key, + ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first. + sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int) + sort_idx = np.concatenate((ck_idx[0], sort_idx)) + return sort_idx + + +class DistributedSortishSampler(Sampler): + """Copied from torch DistributedSampler""" + + def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + if add_extra_examples: + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + else: + self.total_size = len(dataset) + self.num_samples = len(self.available_indices) + self.batch_size = batch_size + self.add_extra_examples = add_extra_examples + self.shuffle = shuffle + + def __iter__(self) -> Iterable: + g = torch.Generator() + g.manual_seed(self.epoch) + + sortish_data = [self.dataset.src_lens[i] for i in self.available_indices] + sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle) + indices = [self.available_indices[i] for i in sortish_indices] + assert len(indices) == self.num_samples + return iter(indices) + + @cached_property + def available_indices(self) -> np.array: + indices = list(range(len(self.dataset))) + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + assert len(indices) == self.total_size + # subsample + available_indices = indices[self.rank : self.total_size : self.num_replicas] + return available_indices + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch + + +logger = getLogger(__name__) + + +def use_task_specific_params(model, task): + """Update config with summarization specific params.""" + task_specific_params = model.config.task_specific_params + + if task_specific_params is not None: + pars = task_specific_params.get(task, {}) + logger.info(f"using task specific params for {task}: {pars}") + model.config.update(pars) + + +def pickle_load(path): + """pickle.load(path)""" + with open(path, "rb") as f: + return pickle.load(f) + + +def pickle_save(obj, path): + """pickle.dump(obj, path)""" + with open(path, "wb") as f: + return pickle.dump(obj, f) + + +def flatten_list(summary_ids: List[List]): + return [x for x in itertools.chain.from_iterable(summary_ids)] + + +def save_git_info(folder_path: str) -> None: + """Save git information to output_dir/git_log.json""" + repo_infos = get_git_info() + save_json(repo_infos, os.path.join(folder_path, "git_log.json")) + + +def save_json(content, path, indent=4, **json_dump_kwargs): + with open(path, "w") as f: + json.dump(content, f, indent=indent, **json_dump_kwargs) + + +def load_json(path): + with open(path) as f: + return json.load(f) + + +def get_git_info(): + try: + repo = git.Repo(search_parent_directories=True) + repo_infos = { + "repo_id": str(repo), + "repo_sha": str(repo.head.object.hexsha), + "repo_branch": str(repo.active_branch), + "hostname": str(socket.gethostname()), + } + return repo_infos + except TypeError: + return { + "repo_id": None, + "repo_sha": None, + "repo_branch": None, + "hostname": None, + } + + +ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"] + + +def extract_rouge_mid_statistics(dct): + new_dict = {} + for k1, v1 in dct.items(): + mid = v1.mid + new_dict[k1] = {stat: round(getattr(mid, stat), 4) for stat in ["precision", "recall", "fmeasure"]} + return new_dict + + +def calculate_rouge( + pred_lns: List[str], + tgt_lns: List[str], + use_stemmer=True, + rouge_keys=ROUGE_KEYS, + return_precision_and_recall=False, + bootstrap_aggregation=True, + newline_sep=True, +) -> Dict: + """Calculate rouge using rouge_scorer package. + + Args: + pred_lns: list of summaries generated by model + tgt_lns: list of groundtruth summaries (e.g. contents of val.target) + use_stemmer: Bool indicating whether Porter stemmer should be used to + strip word suffixes to improve matching. + rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum + return_precision_and_recall: (False) whether to also return precision and recall. + bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True, if False + this function returns a collections.defaultdict[metric: list of values for each observation for each subscore]`` + newline_sep:(default=True) whether to add newline between sentences. This is essential for calculation rougeL + on multi sentence summaries (CNN/DM dataset). + + Returns: + Dict[score: value] if aggregate else defaultdict(list) keyed by rouge_keys + + """ + scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer) + aggregator = scoring.BootstrapAggregator() + for pred, tgt in zip(tgt_lns, pred_lns): + # rougeLsum expects "\n" separated sentences within a summary + if newline_sep: + pred = add_newline_to_end_of_each_sentence(pred) + tgt = add_newline_to_end_of_each_sentence(tgt) + scores = scorer.score(pred, tgt) + aggregator.add_scores(scores) + + if bootstrap_aggregation: + result = aggregator.aggregate() + if return_precision_and_recall: + return extract_rouge_mid_statistics(result) # here we return dict + else: + return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()} + + else: + return aggregator._scores # here we return defaultdict(list) + + +# Utilities for freezing parameters and checking whether they are frozen + + +def freeze_params(model: nn.Module): + """Set requires_grad=False for each of model.parameters()""" + for par in model.parameters(): + par.requires_grad = False + + +def freeze_embeds(model): + """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" + model_type = model.config.model_type + + if model_type == "t5": + freeze_params(model.shared) + for d in [model.encoder, model.decoder]: + freeze_params(d.embed_tokens) + elif model_type == "fsmt": + for d in [model.model.encoder, model.model.decoder]: + freeze_params(d.embed_positions) + freeze_params(d.embed_tokens) + else: + freeze_params(model.model.shared) + for d in [model.model.encoder, model.model.decoder]: + freeze_params(d.embed_positions) + freeze_params(d.embed_tokens) + + +def grad_status(model: nn.Module) -> Iterable: + return (par.requires_grad for par in model.parameters()) + + +def any_requires_grad(model: nn.Module) -> bool: + return any(grad_status(model)) + + +def assert_all_frozen(model): + model_grads: List[bool] = list(grad_status(model)) + n_require_grad = sum(lmap(int, model_grads)) + npars = len(model_grads) + assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad" + + +def assert_not_all_frozen(model): + model_grads: List[bool] = list(grad_status(model)) + npars = len(model_grads) + assert any(model_grads), f"none of {npars} weights require grad" + + +def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]: + """ + Parse an argv list of unspecified command line args to a dict. + Assumes all values are either numeric or boolean in the form of true/false. + """ + result = {} + assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}" + num_pairs = len(unparsed_args) // 2 + for pair_num in range(num_pairs): + i = 2 * pair_num + assert unparsed_args[i].startswith("--") + if unparsed_args[i + 1].lower() == "true": + value = True + elif unparsed_args[i + 1].lower() == "false": + value = False + else: + try: + value = int(unparsed_args[i + 1]) + except ValueError: + value = float(unparsed_args[i + 1]) # this can raise another informative ValueError + + result[unparsed_args[i][2:]] = value + return result + + +def write_txt_file(ordered_tgt, path): + f = Path(path).open("w") + for ln in ordered_tgt: + f.write(ln + "\n") + f.flush() + + +def chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + +def check_output_dir(args, expected_items=0): + """ + Checks whether to bail out if output_dir already exists and has more than expected_items in it + + `args`: needs to have the following attributes of `args`: + - output_dir + - do_train + - overwrite_output_dir + + `expected_items`: normally 0 (default) - i.e. empty dir, but in some cases a few files are expected (e.g. recovery from OOM) + """ + if ( + os.path.exists(args.output_dir) + and len(os.listdir(args.output_dir)) > expected_items + and args.do_train + and not args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({args.output_dir}) already exists and " + f"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). " + "Use --overwrite_output_dir to overcome." + ) diff --git a/src/parsing_client/sentence_parser.py b/src/parsing_client/sentence_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..0e56fdadc7233f08a275aae888d2a07687b8e729 --- /dev/null +++ b/src/parsing_client/sentence_parser.py @@ -0,0 +1,266 @@ +# -*- coding:utf-8 -*- + +""" +@Last modified date : 2020/12/23 +""" + +import re +import nltk +from nltk.stem import WordNetLemmatizer +from allennlp.predictors.predictor import Predictor + +nltk.download('wordnet') +nltk.download('stopwords') + + +def deal_bracket(text, restore, leading_ent=None): + if leading_ent: + leading_ent = ' '.join(leading_ent.split('_')) + text = f'Things about {leading_ent}: ' + text + if restore: + text = text.replace('-LRB-', '(').replace('-RRB-', ')') + text = text.replace('LRB', '(').replace('RRB', ')') + return text + + +def refine_entity(entity): + entity = re.sub(r'-LRB- .+ -RRB-$', '', entity) + entity = re.sub(r'LRB .+ RRB$', '', entity) + entity = re.sub(r'_', ' ', entity) + entity = re.sub(r'\s+', ' ', entity) + return entity.strip() + + +def find_sub_seq(seq_a, seq_b, shift=0, uncased=False, lemmatizer=None): + if uncased: + seq_a = [token.lower() for token in seq_a] + seq_b = [token.lower() for token in seq_b] + if lemmatizer is not None: + seq_a = [lemmatizer.lemmatize(token) for token in seq_a] + seq_b = [lemmatizer.lemmatize(token) for token in seq_b] + for i in range(shift, len(seq_a)): + if seq_a[i:i+len(seq_b)] == seq_b: + return i, i + len(seq_b) + return -1, -1 + + +def is_sub_seq(seq_start, seq_end, all_seqs): + for start, end, is_candidate in all_seqs: + if start <= seq_start < seq_end <= end: + return start, end, is_candidate + return None + + +# extract named entity with B-I-L-U-O schema +def extract_named_entity(tags): + all_NEs = [] + ne_type, ne_start = '', -1 + for i, t in enumerate(tags): + if t == 'O': + ne_type, ne_start = '', -1 + continue + t1, t2 = t.split('-') + if t1 == 'B': + ne_type, ne_start = t2, i + elif t1 == 'I' and t2 != ne_type: + ne_type, ne_start = '', -1 + elif t1 == 'L' and t2 != ne_type: + ne_type, ne_start = '', -1 + elif t1 == 'L' and t2 == ne_type: + all_NEs.append((ne_start, i + 1, False)) + ne_type, ne_start = '', -1 + elif t1 == 'U': + all_NEs.append((i, i + 1, False)) + ne_type, ne_start = '', -1 + + return all_NEs + + +def refine_results(tokens, spans, stopwords): + all_spans = [] + for span_start, span_end, is_candidate in spans: + # remove stopwords + if not is_candidate: + while span_start < span_end and tokens[span_start].lower() in stopwords: + span_start += 1 + if span_start >= span_end: + continue + + # add prefix + if span_start > 0 and tokens[span_start - 1] in ['a', 'an', 'A', 'An', 'the', 'The']: + span_start -= 1 + + # convert token-level index into char-level index + span = ' '.join(tokens[span_start:span_end]) + span_start = len(' '.join(tokens[:span_start])) + 1 * min(1, span_start) # 1 for blank + span_end = span_start + len(span) + + all_spans.append((span, span_start, span_end)) + all_spans = sorted(all_spans, key=lambda x: (x[1], x[1] - x[2])) + + # remove overlap + refined_spans = [] + for span, span_start, span_end in all_spans: + flag = True + for _, start, end in refined_spans: + if start <= span_start < span_end <= end: + flag = False + break + if flag: + refined_spans.append((span, span_start, span_end)) + + return refined_spans + + +class SentenceParser: + def __init__(self, device='cuda:0', + ner_path="https://storage.googleapis.com/allennlp-public-models/ner-model-2020.02.10.tar.gz", + cp_path="https://storage.googleapis.com/allennlp-public-models/elmo-constituency-parser-2020.02.10.tar.gz"): + self.device = self.parse_device(device) + self.ner = Predictor.from_path(ner_path, cuda_device=self.device) + print('* ner loaded') + self.cp = Predictor.from_path(cp_path, cuda_device=self.device) + print('* constituency parser loaded') + self.lemmatizer = WordNetLemmatizer() + + # some heuristic rules can be added here + self.stopwords = set(nltk.corpus.stopwords.words('english')) + self.stopwords.update({'-', '\'s', 'try', 'tries', 'tried', 'trying', + 'become', 'becomes', 'became', 'becoming', + 'make', 'makes', 'made', 'making', 'call', 'called', 'calling', + 'put', 'ever', 'something', 'someone', 'sometime'}) + self.special_tokens = ['only', 'most', 'before', 'after', 'behind'] + for token in self.special_tokens: + if token in self.stopwords: self.stopwords.remove(token) + if 'won' in self.stopwords: self.stopwords.remove('won') + if 'own' in self.stopwords: self.stopwords.remove('own') + + def parse_device(self, device): + if 'cpu' in device: + return -1 + else: + dev = re.findall('\d+', device) + return 0 if len(dev) == 0 else int(dev[0]) + + def identify_NPs(self, text, candidate_NPs=None): + text = re.sub(r'\s+', ' ', text).strip() + if len(text) == 0: return {'text': '', 'NPs': [], 'verbs': [], 'adjs': []} + + cp_outputs = self.cp.predict(text) + ner_outputs = self.ner.predict(text) + tokens = cp_outputs['tokens'] + pos_tags = cp_outputs['pos_tags'] + ner_tags = ner_outputs['tags'] + tree = cp_outputs['hierplane_tree']['root'] + + # extract candidate noun phrases passed by user with token index + all_NPs = [] + candidate_NPs = [refine_entity(np).split() for np in candidate_NPs] if candidate_NPs else [] + for np in sorted(candidate_NPs, key=len, reverse=True): + np_start, np_end = find_sub_seq(tokens, np, 0, uncased=True, lemmatizer=self.lemmatizer) + while np_start != -1 and np_end != -1: + if not is_sub_seq(np_start, np_end, all_NPs): + all_NPs.append((np_start, np_end, True)) + np_start, np_end = find_sub_seq(tokens, np, np_end, uncased=True, lemmatizer=self.lemmatizer) + + # extract noun phrases from tree + def _get_bottom_NPs(children): + if 'children' not in children: + return None + if {'NP', 'OP', 'XP', 'QP'} & set(children['attributes']): + is_bottom = True + for child in children['children']: + if 'children' in child: + is_bottom = False + if is_bottom: + bottom_NPs.append(children['word'].split()) + else: + for child in children['children']: + _get_bottom_NPs(child) + else: + for child in children['children']: + _get_bottom_NPs(child) + bottom_NPs = [] + _get_bottom_NPs(tree) + + # find token indices of noun phrases + np_index = -1 + for np in bottom_NPs: + np_start, np_end = find_sub_seq(tokens, np, np_index + 1) + if not is_sub_seq(np_start, np_end, all_NPs): + all_NPs.append((np_start, np_end, False)) + np_index = np_end + + # extract named entities with token index + all_NEs = extract_named_entity(ner_tags) + + # extract verbs with token index + all_verbs = [] + for i, pos in enumerate(pos_tags): + if pos[0] == 'V': + if not is_sub_seq(i, i + 1, all_NPs) and not is_sub_seq(i, i + 1, all_NEs): + all_verbs.append((i, i + 1, False)) + + # extract modifiers with token index + all_modifiers = [] + for i, (token, pos) in enumerate(zip(tokens, pos_tags)): + if pos in ['JJ', 'RB']: # adj. and adv. + if not is_sub_seq(i, i + 1, all_NPs) and not is_sub_seq(i, i + 1, all_NEs): + all_modifiers.append((i, i + 1, False)) + elif token in self.special_tokens: + if not is_sub_seq(i, i + 1, all_NPs) and not is_sub_seq(i, i + 1, all_NEs): + all_modifiers.append((i, i + 1, False)) + + # split noun phrases with named entities + all_spans = [] + for np_start, np_end, np_is_candidate in all_NPs: + if np_is_candidate: # candidate noun phrases will be preserved + all_spans.append((np_start, np_end, np_is_candidate)) + else: + match = is_sub_seq(np_start, np_end, all_NEs) + if match: # if a noun phrase is a sub span of a named entity, the named entity will be preserved + all_spans.append(match) + else: # else if a named entity is a sub span of a noun phrase, the noun phrase will be split + index = np_start + for ne_start, ne_end, ne_is_candidate in all_NEs: + if np_start <= ne_start < ne_end <= np_end: + all_modifiers.append((index, ne_start, False)) + all_spans.append((ne_start, ne_end, ne_is_candidate)) + index = ne_end + all_spans.append((index, np_end, False)) + + # named entities without overlapping + for ne_start, ne_end, is_candidate in all_NEs: + if not is_sub_seq(ne_start, ne_end, all_spans): + all_spans.append((ne_start, ne_end, is_candidate)) + + all_spans = refine_results(tokens, all_spans, self.stopwords) + all_verbs = refine_results(tokens, all_verbs, self.stopwords) + all_modifiers = refine_results(tokens, all_modifiers, self.stopwords) + + return {'text': tree['word'], 'NPs': all_spans, 'verbs': all_verbs, 'adjs': all_modifiers} + + +if __name__ == '__main__': + import json + + print('Initializing sentence parser.') + client = SentenceParser(device='cpu') + + print('Parsing sentence.') + sentence = "The Africa Cup of Nations is held in odd - numbered years due to conflict with the World Cup . " + entities = ['Africa Cup of Nations', 'Africa_Cup_of_Nations', 'Africa Cup', 'Africa_Cup'] + results = client.identify_NPs(sentence, entities) + print(json.dumps(results, ensure_ascii=False, indent=4)) + + # import random + # from tqdm import tqdm + # from utils import read_json_lines, save_json + # + # print('Parsing file.') + # results = [] + # data = list(read_json_lines('data/train.jsonl')) + # random.shuffle(data) + # for entry in tqdm(data[:100]): + # results.append(client.identify_NPs(entry['claim'])) + # save_json(results, 'data/results.json') diff --git a/src/pproc_client/cjjpy.py b/src/pproc_client/cjjpy.py new file mode 100755 index 0000000000000000000000000000000000000000..2cc70b5e553924123810ab198c143bf7ee28e5d6 --- /dev/null +++ b/src/pproc_client/cjjpy.py @@ -0,0 +1,249 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2018/11/15 17:08 +@Contact: jjchen19@fudan.edu.cn +''' + +import re +import datetime +import os +import argparse +import logging +import traceback + +try: + import ujson as json +except: + import json + +HADOOP_BIN = 'PATH=/usr/bin/:$PATH hdfs' +FOR_PUBLIC = True + + +def LengthStats(filename): + len_list = [] + thresholds = [0.8, 0.9, 0.95, 0.99, 0.999] + with open(filename) as f: + for line in f: + len_list.append(len(line.strip().split())) + stats = { + 'Max': max(len_list), + 'Min': min(len_list), + 'Avg': round(sum(len_list) / len(len_list), 4), + } + len_list.sort() + for t in thresholds: + stats[f"Top-{t}"] = len_list[int(len(len_list) * t)] + + for k in stats: + print(f"- {k}: {stats[k]}") + return stats + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def TraceBack(error_msg): + exc = traceback.format_exc() + msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}' + return msg + + +def Now(): + return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def AbsParentDir(file, parent='..', postfix=None): + ppath = os.path.abspath(file) + parent_level = parent.count('.') + while parent_level > 0: + ppath = os.path.dirname(ppath) + parent_level -= 1 + if postfix is not None: + return os.path.join(ppath, postfix) + else: + return ppath + + +def init_logger(log_file=None, log_file_level=logging.NOTSET, from_scratch=False): + from coloredlogs import ColoredFormatter + import tensorflow as tf + + fmt = "[%(asctime)s %(levelname)s] %(message)s" + log_format = ColoredFormatter(fmt=fmt) + # log_format = logging.Formatter() + logger = logging.getLogger() + logger.setLevel(log_file_level) + + console_handler = logging.StreamHandler() + console_handler.setFormatter(log_format) + logger.handlers = [console_handler] + + if log_file and log_file != '': + if from_scratch and tf.io.gfile.exists(log_file): + logger.warning('Removing previous log file: %s' % log_file) + tf.io.gfile.remove(log_file) + path = os.path.dirname(log_file) + os.makedirs(path, exist_ok=True) + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(log_file_level) + file_handler.setFormatter(log_format) + logger.addHandler(file_handler) + + return logger + + +def OverWriteCjjPy(root='.'): + # import difflib + # diff = difflib.HtmlDiff() + cnt = 0 + golden_cjjpy = os.path.join(root, 'cjjpy.py') + # golden_content = open(golden_cjjpy).readlines() + for dir, folder, file in os.walk(root): + for f in file: + if f == 'cjjpy.py': + cjjpy = '%s/%s' % (dir, f) + # content = open(cjjpy).readlines() + # d = diff.make_file(golden_content, content) + cnt += 1 + print('[%d]: %s' % (cnt, cjjpy)) + os.system('cp %s %s' % (golden_cjjpy, cjjpy)) + + +def ChangeFileFormat(filename, new_fmt): + assert type(filename) is str and type(new_fmt) is str + spt = filename.split('.') + if len(spt) == 0: + return filename + else: + return filename.replace('.' + spt[-1], new_fmt) + + +def CountLines(fname): + with open(fname, 'rb') as f: + count = 0 + last_data = '\n' + while True: + data = f.read(0x400000) + if not data: + break + count += data.count(b'\n') + last_data = data + if last_data[-1:] != b'\n': + count += 1 # Remove this if a wc-like count is needed + return count + + +def GetDate(): + return str(datetime.datetime.now())[5:10].replace('-', '') + + +def TimeClock(seconds): + sec = int(seconds) + hour = int(sec / 3600) + min = int((sec - hour * 3600) / 60) + ssec = float(seconds) - hour * 3600 - min * 60 + # return '%dh %dm %.2fs' % (hour, min, ssec) + return '{0:>2d}h{1:>3d}m{2:>6.2f}s'.format(hour, min, ssec) + + +def StripAll(text): + return text.strip().replace('\t', '').replace('\n', '').replace(' ', '') + + +def GetBracket(text, bracket, en_br=False): + # input should be aa(bb)cc, True for bracket, False for text + if bracket: + try: + return re.findall('\((.*?)\)', text.strip())[-1] + except: + return '' + else: + if en_br: + text = re.sub('\(.*?\)', '', text.strip()) + return re.sub('(.*?)', '', text.strip()) + + +def CharLang(uchar, lang): + assert lang.lower() in ['en', 'cn', 'zh'] + if lang.lower() in ['cn', 'zh']: + if uchar >= '\u4e00' and uchar <= '\u9fa5': + return True + else: + return False + elif lang.lower() == 'en': + if (uchar <= 'Z' and uchar >= 'A') or (uchar <= 'z' and uchar >= 'a'): + return True + else: + return False + else: + raise NotImplementedError + + +def WordLang(word, lang): + for i in word.strip(): + if i.isspace(): continue + if not CharLang(i, lang): + return False + return True + + +def SortDict(_dict, reverse=True): + assert type(_dict) is dict + return sorted(_dict.items(), key=lambda d: d[1], reverse=reverse) + + +def lark(content='test'): + print(content) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--diff', nargs=2, + help='show difference between two files, shown in downloads/diff.html') + parser.add_argument('--de_unicode', action='store_true', default=False, + help='remove unicode characters') + parser.add_argument('--link_entity', action='store_true', default=False, + help='') + parser.add_argument('--max_comm_len', action='store_true', default=False, + help='') + parser.add_argument('--search', nargs=2, + help='search key from file, 2 args: file name & key') + parser.add_argument('--email', nargs=2, + help='sending emails, 2 args: subject & content') + parser.add_argument('--overwrite', action='store_true', default=None, + help='overwrite all cjjpy under given *dir* based on *dir*/cjjpy.py') + parser.add_argument('--replace', nargs=3, + help='replace char, 3 args: file name & replaced char & replacer char') + parser.add_argument('--lark', nargs=1) + parser.add_argument('--get_hdfs', nargs=2, + help='easy copy from hdfs to local fs, 2 args: remote_file/dir & local_dir') + parser.add_argument('--put_hdfs', nargs=2, + help='easy put from local fs to hdfs, 2 args: local_file/dir & remote_dir') + parser.add_argument('--length_stats', nargs=1, + help='simple token lengths distribution of a line-by-line file') + + args = parser.parse_args() + + if args.overwrite: + print('* Overwriting cjjpy...') + OverWriteCjjPy() + + if args.lark: + try: + content = args.lark[0] + except: + content = 'running complete' + print(f'* Larking "{content}"...') + lark(content) + + if args.length_stats: + file = args.length_stats[0] + print(f'* Working on {file} lengths statistics...') + LengthStats(file) diff --git a/src/pproc_client/fix_predicted_evidence.py b/src/pproc_client/fix_predicted_evidence.py new file mode 100644 index 0000000000000000000000000000000000000000..dfa880837882b87da27ec6934bb2811d4b67d96d --- /dev/null +++ b/src/pproc_client/fix_predicted_evidence.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- + +""" +@Author : Jiangjie Chen +@Time : 2020/12/31 18:38 +@Contact : jjchen19@fudan.edu.cn +@Description: +""" + +import tensorflow as tf +import cjjpy as cjj +import ujson as json +from hparams import * +import sys, os +sys.path.append('..') +from dataloaders import FEVERLoader + + +def rewrite_file(filename, loader): + with tf.io.gfile.GFile(filename) as f: + data = f.readlines() + + with tf.io.gfile.GFile(filename, 'w') as fo: + for line in data: + js = json.loads(line) + if js.get('predicted_evidence') is None: + js['predicted_evidence'] = [[ev[0], ev[1]] for ev in loader[js['id']]['bert_evidence']] + fo.write(json.dumps(js) + '\n') + print(f'* {filename} rewritten') + + +for role in ['eval', 'test']: + floader = FEVERLoader(role) + floader.load_fever('bert', clean_load=False) + filename = os.path.join(AG_PREFIX.format(version='v5'), CACHED_EVIDENTIAL_FILE.format(role=role, k_cand=4)) + rewrite_file(filename, floader) + final_output = os.path.join(cjj.AbsParentDir(AG_PREFIX.format(version='v5'), '.'), + FINAL_FILE.format(role=role)) + rewrite_file(final_output, floader) \ No newline at end of file diff --git a/src/pproc_client/hparams.py b/src/pproc_client/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..3fc07c835da5dc13b46515e2e4f2b21502b48f26 --- /dev/null +++ b/src/pproc_client/hparams.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- + +""" +@Author : Jiangjie Chen +@Time : 2020/12/10 14:07 +@Contact : jjchen19@fudan.edu.cn +@Description: +""" + +import os + + +QG_PREFIX = '%s/data/fact_checking/{version}/cache_qg/' % os.environ['PJ_HOME'] +AG_PREFIX = '%s/data/fact_checking/{version}/cache_ag/' % os.environ['PJ_HOME'] + +CACHED_QUESTION_FILE = 'question.{role}.cache' +CACHED_ANSEWR_FILE = 'answer.{role}.cache' +CACHED_EVIDENTIAL_FILE = 'evidential.k_{k_cand}.{role}.cache' +FINAL_FILE = '{role}.json' \ No newline at end of file diff --git a/src/pproc_client/pproc_evidential.py b/src/pproc_client/pproc_evidential.py new file mode 100644 index 0000000000000000000000000000000000000000..5d8c34cdd6aa54b72dca346bebc78877b11aaf09 --- /dev/null +++ b/src/pproc_client/pproc_evidential.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2020/8/16 16:51 +@Contact : jjchen19@fudan.edu.cn +@Description: +''' + +import cjjpy as cjj +import sys +import tensorflow as tf +import ujson as json +import argparse + +try: + from .hparams import * + from ..dataloaders import FEVERLoader + from ..mrc_client.answer_generator import AnswerGenerator, assemble_answers_to_one +except: + sys.path.append(cjj.AbsParentDir(__file__, '..')) + from hparams import * + from dataloaders import FEVERLoader + from mrc_client.answer_generator import AnswerGenerator, assemble_answers_to_one + + +def prepare_evidential(version, role, mrc_model_path, evi_key, + k_cand=4, batch_size=64): + ''' + After pproc_questions (prepare_answers, prepare_questions) + :return: + { + 'id': id, + 'label': x, + 'claim': c, + 'evidence': [e1, e2, ...], + 'answers': [a1, a2, ...], + 'questions': [q1, q2, ...], + 'cloze_qs': [q1, q2, ...], #m + 'regular_qs': [q1, q2, ...], #m + 'answer_roles': [noun, noun, adj, verb, ...], # m + 'evidential': [[b1, b2, ... bk]_1, [...]_2, ...], + 'evidential_assembled': [], # m + } + ''' + tf.io.gfile.makedirs(AG_PREFIX.format(version=version)) + cached_evidential = AG_PREFIX.format(version=version) \ + + CACHED_EVIDENTIAL_FILE.format(k_cand=k_cand, role=role) + cached_question = QG_PREFIX.format(version=version) + CACHED_QUESTION_FILE.format(role=role) + + ag = AnswerGenerator(mrc_model_path) + ag.init_model() + with tf.io.gfile.GFile(cached_question) as f, \ + tf.io.gfile.GFile(cached_evidential, 'w') as fo: + data = f.read().splitlines() + examples, ids = [], [] + data_dict = {} + for line in data: + js = json.loads(line) + data_dict[js['id']] = js + for q in js['questions']: + ids.append(js['id']) + ex = ag.assemble(q, " ".join(js["evidence"])) + examples.append(ex) + + predicted = ag.generate(examples, num_beams=k_cand, num_return_sequences=k_cand, + batch_size=batch_size, verbose=True) + assert len(predicted) == len(examples) + + # follow by strict order + for answers, id in zip(predicted, ids): + if 'evidential' in data_dict[id]: + # [b1, b2, ..., bk] + data_dict[id]['evidential'].append(answers) + else: + data_dict[id]['evidential'] = [answers] + + _ = [_sanity_check(data_dict[k]) for k in data_dict] + + if role in ['eval', 'test']: + floader = FEVERLoader(role) + print('Loading FEVER...') + floader.load_fever(evi_key.split('_')[0], clean_load=False) + + for k in data_dict: + js = data_dict[k] + if role in ['eval', 'test']: + if js.get('predicted_evidence') is None: + js['predicted_evidence'] = [[ev[0], ev[1]] for ev in floader[js['id']][evi_key]] + fo.write(json.dumps(js) + '\n') + + final_output = os.path.join(cjj.AbsParentDir(AG_PREFIX.format(version=version), '.'), + FINAL_FILE.format(role=role)) + + tf.io.gfile.copy(cached_evidential, final_output) + + cjj.lark(f'Final baked in {final_output}') + return final_output + + +def _sanity_check(js): + assert len(js['evidential']) == len(js['questions']) == len(js['answers']), js + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--k_cand', '-k', type=int, default=4, + help='number of candidate answer') + parser.add_argument('--batch_size', '-b', type=int, default=64) + parser.add_argument('--version', '-v', type=str, help='v1, v2, ...', default='v5') + parser.add_argument('--roles', nargs='+', required=True, + help='train val test eval') + parser.add_argument('--evi_key', '-e', type=str, choices=['bert_evidence'], default='bert_evidence') + parser.add_argument('--mrc_model_name', '-m', type=str, required=True, + help='Absolute path of the mrc model') + args = parser.parse_args() + + server = None + + for role in args.roles: + evidential_output = prepare_evidential(args.version, role, args.mrc_model_name, args.evi_key, + args.k_cand, args.batch_size) diff --git a/src/pproc_client/pproc_mrc.py b/src/pproc_client/pproc_mrc.py new file mode 100644 index 0000000000000000000000000000000000000000..a72d2f4134fb0177e6ea40228fd328f6fd2728f4 --- /dev/null +++ b/src/pproc_client/pproc_mrc.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2020/7/20 17:54 +@Contact : jjchen19@fudan.edu.cn +@Description: +''' + +import os +import ujson as json +import tensorflow as tf +import argparse +import random +from transformers import BartTokenizer + +try: + from .hparams import CACHED_QUESTION_FILE, QG_PREFIX +except: + from hparams import CACHED_QUESTION_FILE, QG_PREFIX + + +random.seed(1111) + + +def pproc_seq2seq(input_file, output_dir, role): + ''' + :param input_file: + { + 'id': id, + 'claim': c, + 'label': x, + 'evidence': [e1, e2, ...], # n + 'answers': [a1, a2, ...], # m + 'questions': [q1, q2, ...], # m + 'cloze_qs': [q1, q2, ...], #m + 'regular_qs': [q1, q2, ...], #m + 'answer_roles': [noun, noun, adj, verb, ...] # m + } + ''' + assert role in ['val', 'test', 'train'], role + + use_rag = 'v6' in input_file + if not use_rag: + tokenizer = BartTokenizer.from_pretrained('facebook/bart-base') + + tf.io.gfile.makedirs(output_dir) + src_fname = os.path.join(output_dir, f'{role}.source') + tgt_fname = os.path.join(output_dir, f'{role}.target') + + with tf.io.gfile.GFile(input_file) as fin, \ + tf.io.gfile.GFile(src_fname, 'w') as srcf, \ + tf.io.gfile.GFile(tgt_fname, 'w') as tgtf: + data = fin.readlines() + for line in data: + js = json.loads(line) + if js['label'] == 'SUPPORTS': + evidence = ' '.join(js['evidence']) + questions = js['questions'] + i = random.randint(0, len(questions) - 1) + if use_rag: + srcf.write(f'{questions[i]}\n') + else: + srcf.write(f'{questions[i]} {tokenizer.sep_token} {evidence}\n') + tgtf.write(js['answers'][i][0] + '\n') + + return src_fname, tgt_fname + + +def pproc_for_mrc(output_dir, version): + assert version in ['v5'] + for role in ['val', 'train', 'test']: + _role = 'val' if role == 'test' else role + input_file = os.path.join(QG_PREFIX.format(version=version), + CACHED_QUESTION_FILE.format(role=_role)) + pproc_seq2seq(input_file, output_dir, role) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', '-o', required=True, default='data/mrc_seq2seq_v5', + help='data path, e.g. data/mrc_seq2seq_v5') + parser.add_argument('--version', '-v', type=str, default='v5') + args = parser.parse_args() + pproc_for_mrc(args.output_dir, args.version) diff --git a/src/pproc_client/pproc_nli_labels.py b/src/pproc_client/pproc_nli_labels.py new file mode 100644 index 0000000000000000000000000000000000000000..3fde9803da2e0fe28f35297f8a56eb70a8b528ab --- /dev/null +++ b/src/pproc_client/pproc_nli_labels.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- + +""" +@Author : Jiangjie Chen +@Time : 2021/5/7 19:39 +@Contact : jjchen19@fudan.edu.cn +@Description: +""" + +import sys +import os +import cjjpy as cjj +from tqdm import tqdm +import ujson as json +import argparse +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +try: + from ..mrc_client.answer_generator import assemble_answers_to_one, chunks +except: + sys.path.append('..') + from mrc_client.answer_generator import assemble_answers_to_one, chunks + + +def load_model(model_name_or_path, device='cuda'): + model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path).to(device) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + return model, tokenizer + + +def run_nli_line(line, model, tokenizer): + js = json.loads(line) if isinstance(line, str) else line + js = assemble_answers_to_one(js, 1) + premises, hypotheses = [], [] + for ev in js['evidential_assembled']: + premises.append(ev) + hypotheses.append(js['claim']) + nli_labels = [] + for p_chunk, h_chunk in zip(chunks(premises, 8), chunks(hypotheses, 8)): + inputs = tokenizer(p_chunk, h_chunk, return_tensors='pt', padding=True, truncation=True).to(model.device) + s = model(**inputs).logits.tolist() + nli_labels += s + assert len(nli_labels) == len(js['answers']) + js['nli_labels'] = nli_labels + return js + + +def run(filename, model, tokenizer): + with open(filename) as f: + data = f.readlines() + with open(filename, 'w') as fo: + for line in tqdm(data, desc=os.path.basename(filename)): + js = run_nli_line(line, model, tokenizer) + fo.write(json.dumps(js) + '\n') + cjj.lark(f'{filename} done.') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--model_name_or_path', '-m', type=str, required=True) + parser.add_argument('--input', '-i', type=str, required=True) + args = parser.parse_args() + + model, tokenizer = load_model(args.model_name_or_path) + run(args.input, model, tokenizer) \ No newline at end of file diff --git a/src/pproc_client/pproc_questions.py b/src/pproc_client/pproc_questions.py new file mode 100644 index 0000000000000000000000000000000000000000..1a6d3c5fd1927e86c608b22f1d73891432b25a83 --- /dev/null +++ b/src/pproc_client/pproc_questions.py @@ -0,0 +1,170 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2020/7/25 18:23 +@Contact : jjchen19@fudan.edu.cn +@Description: +''' + +import os +import cjjpy as cjj +import sys +import tensorflow as tf +import ujson as json +from tqdm import tqdm +import argparse + +try: + sys.path.append(cjj.AbsParentDir(__file__, '..')) + from hparams import * + from pseudo_multiproc_toolkit import * + from dataloaders import FEVERLoader + from parsing_client.sentence_parser import SentenceParser, deal_bracket + from qg_client.question_generator import QuestionGenerator +except: + from .hparams import * + from .pseudo_multiproc_toolkit import * + from ..dataloaders import FEVERLoader + from ..parsing_client.sentence_parser import SentenceParser, deal_bracket + from ..qg_client.question_generator import QuestionGenerator + + +def prepare_answers(version, role, evi_key='bert_evidence', overwrite=False): + ''' + :return + { + 'id': id, + 'claim': c, + 'label': x, + 'evidence': [e1, e2, ...], # n + 'answers': [a1, a2, ...], # m + 'answer_roles': [noun, noun, adj, verb, ...] # m + } + ''' + assert role in ['val', 'test', 'train', 'eval'], role + + def _proc_one(js): + js.pop('all_evidence') + evidence = [deal_bracket(ev[2], True, ev[0]) for ev in js[evi_key]] + results = sent_client.identify_NPs(deal_bracket(js['claim'], True), + candidate_NPs=[x[0] for x in js[evi_key]]) + NPs = results['NPs'] + claim = results['text'] + verbs = results['verbs'] + adjs = results['adjs'] + _cache = {'id': js['id'], + 'claim': claim, + 'evidence': evidence, + 'answers': NPs + verbs + adjs, + 'answer_roles': ['noun'] * len(NPs) + ['verb'] * len(verbs) + ['adj'] * len(adjs)} + if js.get('label'): + _cache.update({'label': js['label']}) + return _cache + + cached_ = QG_PREFIX.format(version=version) + CACHED_ANSEWR_FILE.format(role=role) + tf.io.gfile.makedirs(QG_PREFIX.format(version=version)) + if tf.io.gfile.exists(cached_) and not overwrite: + print(f'* Skipped, exising {cached_}') + return cached_ + + sent_client = SentenceParser(device='cuda:0') + floader = FEVERLoader(role) + floader.load_fever(evi_key.split('_')[0]) + + with tf.io.gfile.GFile(cached_, 'w') as f: + for id in tqdm(floader, desc=f'{role} answer'): + res = _proc_one(floader[id]) + f.write(json.dumps(res) + '\n') + + cjj.lark(f'* NPs baked in {cached_}') + return cached_ + + +def prepare_questions(version, role, qg_model='t5', batch_size=64, overwrite=False): + ''' + After prepare_nps + :return + { + 'id': id, + 'claim': c, + 'label': x, + 'evidence': [e1, e2, ...], # n + 'answers': [a1, a2, ...], # m + 'questions': [q1, q2, ...], # m + 'cloze_qs': [q1, q2, ...], #m + 'regular_qs': [q1, q2, ...], #m + 'answer_roles': [noun, noun, adj, verb, ...] # m + } + ''' + cached_answer = QG_PREFIX.format(version=version) + CACHED_ANSEWR_FILE.format(role=role) + cached_question = QG_PREFIX.format(version=version) + CACHED_QUESTION_FILE.format(role=role) + + if tf.io.gfile.exists(cached_question) and not overwrite: + print(f'* Skipped, existing {cached_question}') + return cached_question + + qg_client = QuestionGenerator(qg_model) + with tf.io.gfile.GFile(cached_answer, 'r') as f, \ + tf.io.gfile.GFile(cached_question, 'w') as fo: + data = f.read().splitlines() + data_dict = {} + _cache = [] + for line in data: + js = json.loads(line) + data_dict[js['id']] = js + if len(js['answers']) == 0: + # TODO: hack empty answer + print('Empty answer:', js) + pseudo_answer = js['claim'].split()[0] + js['answers'] = [(pseudo_answer, 0, len(pseudo_answer))] + js['answer_roles'] = ['noun'] + for answer in js['answers']: + _cache.append((js['claim'], [answer], js['id'])) + print(_cache[:5]) + + qa_pairs = qg_client.generate([(x, y) for x, y, z in _cache], batch_size=batch_size) + print(qa_pairs[:5]) + + for (q, clz_q, a), (_, _, id) in zip(qa_pairs, _cache): + if 'questions' in data_dict[id]: + data_dict[id]['cloze_qs'].append(clz_q) + data_dict[id]['regular_qs'].append(q) + data_dict[id]['questions'].append(qg_client.assemble_question(q, clz_q)) + else: + data_dict[id]['cloze_qs'] = [clz_q] + data_dict[id]['regular_qs'] = [q] + data_dict[id]['questions'] = [qg_client.assemble_question(q, clz_q)] + + _ = [_sanity_check(data_dict[k]) for k in data_dict] + + for k in data_dict: + fo.write(json.dumps(data_dict[k]) + '\n') + + cjj.lark(f'* Questions baked in {cached_question}') + return cached_question + + +def _sanity_check(js): + try: + assert len(js['questions']) == len(js['answers']) + assert len(js['answers']) == len(js['answer_roles']) + except: + print(js) + raise Exception + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--overwrite', action='store_true') + parser.add_argument('--batch_size', '-b', type=int, default=64) + parser.add_argument('--evi_key', '-e', type=str, default='bert_evidence') + parser.add_argument('--version', '-v', type=str, help='v1, v2, ...', default='v5') + parser.add_argument('--roles', nargs='+', required=True, + help='train val test eval') + parser.add_argument('--qg_model', '-m', type=str, default='t5') + args = parser.parse_args() + + for role in args.roles: + prepare_answers(args.version, role, args.evi_key, args.overwrite) + prepare_questions(args.version, role, args.qg_model, args.batch_size, args.overwrite) diff --git a/src/pproc_client/pseudo_multiproc_toolkit.py b/src/pproc_client/pseudo_multiproc_toolkit.py new file mode 100644 index 0000000000000000000000000000000000000000..864ff4daf4456d239a9f7d4312939c38966752c3 --- /dev/null +++ b/src/pproc_client/pseudo_multiproc_toolkit.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- + +""" +@Author : Jiangjie Chen +@Time : 2020/6/8 22:17 +@Contact : jjchen19@fudan.edu.cn +@Description: +""" + +import re +import os +import ujson as json +import tensorflow as tf + + +def args_to_shell(args): + args_dict = vars(args) + shell_args = '' + for k, v in args_dict.items(): + if isinstance(v, bool): + if v: shell_args += f'--{k} ' + else: + if isinstance(v, list): + v = ' '.join([str(x) for x in v]) + shell_args += f'--{k} {v} ' + return shell_args + + +def _is_proc_file(fname): + return re.search('._\d+_proc$', fname) is not None + + +def _restore_fname_from_proc(fname): + if _is_proc_file(fname): + return '.'.join(fname.split('.')[:-1]) + else: + return fname + + +def rename_fname_by_proc(fname: str, proc_num: int): + if not _is_proc_file(fname): + fname = fname + f'._{proc_num}_proc' + return fname + + +def chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + +# def slice_dataset_json(in_fname, slice_num): +# with tf.io.gfile.GFile(in_fname) as fin: +# data = json.load(fin) +# sliced_data = chunks(data, slice_num) +# datasets = [] +# for i in range(len(list(sliced_data))): +# proc_fname = rename_fname_by_proc(in_fname, i) +# with tf.io.gfile.GFile(proc_fname, 'w') as f: +# js = [] +# for line in sliced_data[i]: +# js.append(line) +# json.dump(js, f) +# datasets.append(proc_fname) +# return datasets + + +def slice_filenames(in_fname, slice_num): + sliced_f = [] + for i in range(slice_num): + sliced_f.append(rename_fname_by_proc(in_fname, i)) + return sliced_f + + +def slice_dataset(in_fname, slice_num): + ''' + :param in_fname: + :param slice_num: + :return: sliced dataset filenames + ''' + with tf.io.gfile.GFile(in_fname) as fin: + data = fin.readlines() + _sliced_data = list(chunks(data, len(data) // slice_num)) + if len(_sliced_data) == slice_num + 1: # loose ends + sliced_data = _sliced_data[:slice_num] + sliced_data[-1] += _sliced_data[-1] + else: + sliced_data = _sliced_data + datasets = [] + for i in range(len(list(sliced_data))): + proc_fname = rename_fname_by_proc(in_fname, i) + with tf.io.gfile.GFile(proc_fname, 'w') as f: + for line in sliced_data[i]: + f.write(line) + datasets.append(proc_fname) + return datasets + + +def union_multiproc_files(files, overwrite=False): + real_out_fname = None + for i, file in enumerate(files): + if not _is_proc_file(file): + raise FileNotFoundError(file) + else: + _out_fname = _restore_fname_from_proc(file) + if i > 0 and _out_fname != real_out_fname: + raise ValueError(file, real_out_fname) + real_out_fname = _out_fname + + if real_out_fname is None: + raise FileNotFoundError(real_out_fname) + + if tf.io.gfile.exists(real_out_fname) and not overwrite: + print(f'Skip {real_out_fname}, as it already exists.') + else: + with tf.io.gfile.GFile(real_out_fname, 'w') as fo: + for file in files: + if _is_proc_file(file): + with tf.io.gfile.GFile(file) as f: + data = f.readlines() + for line in data: + fo.write(line) + print(f'{files} united into {real_out_fname}.') + return real_out_fname + + +def clean_multiproc_files(files): + for file in files: + if _is_proc_file(file): + if tf.io.gfile.exists(file): + print(f'Removing {file}...') + tf.io.gfile.remove(file) + else: + print(f'Removing {file}, but does not exists.') + + +if __name__ == '__main__': + test_file = 'cjjpy.py' + sliced_files = slice_dataset(test_file, 2) + file = union_multiproc_files(sliced_files) + clean_multiproc_files(sliced_files) \ No newline at end of file diff --git a/src/qg_client/cjjpy.py b/src/qg_client/cjjpy.py new file mode 100755 index 0000000000000000000000000000000000000000..2cc70b5e553924123810ab198c143bf7ee28e5d6 --- /dev/null +++ b/src/qg_client/cjjpy.py @@ -0,0 +1,249 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2018/11/15 17:08 +@Contact: jjchen19@fudan.edu.cn +''' + +import re +import datetime +import os +import argparse +import logging +import traceback + +try: + import ujson as json +except: + import json + +HADOOP_BIN = 'PATH=/usr/bin/:$PATH hdfs' +FOR_PUBLIC = True + + +def LengthStats(filename): + len_list = [] + thresholds = [0.8, 0.9, 0.95, 0.99, 0.999] + with open(filename) as f: + for line in f: + len_list.append(len(line.strip().split())) + stats = { + 'Max': max(len_list), + 'Min': min(len_list), + 'Avg': round(sum(len_list) / len(len_list), 4), + } + len_list.sort() + for t in thresholds: + stats[f"Top-{t}"] = len_list[int(len(len_list) * t)] + + for k in stats: + print(f"- {k}: {stats[k]}") + return stats + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def TraceBack(error_msg): + exc = traceback.format_exc() + msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}' + return msg + + +def Now(): + return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def AbsParentDir(file, parent='..', postfix=None): + ppath = os.path.abspath(file) + parent_level = parent.count('.') + while parent_level > 0: + ppath = os.path.dirname(ppath) + parent_level -= 1 + if postfix is not None: + return os.path.join(ppath, postfix) + else: + return ppath + + +def init_logger(log_file=None, log_file_level=logging.NOTSET, from_scratch=False): + from coloredlogs import ColoredFormatter + import tensorflow as tf + + fmt = "[%(asctime)s %(levelname)s] %(message)s" + log_format = ColoredFormatter(fmt=fmt) + # log_format = logging.Formatter() + logger = logging.getLogger() + logger.setLevel(log_file_level) + + console_handler = logging.StreamHandler() + console_handler.setFormatter(log_format) + logger.handlers = [console_handler] + + if log_file and log_file != '': + if from_scratch and tf.io.gfile.exists(log_file): + logger.warning('Removing previous log file: %s' % log_file) + tf.io.gfile.remove(log_file) + path = os.path.dirname(log_file) + os.makedirs(path, exist_ok=True) + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(log_file_level) + file_handler.setFormatter(log_format) + logger.addHandler(file_handler) + + return logger + + +def OverWriteCjjPy(root='.'): + # import difflib + # diff = difflib.HtmlDiff() + cnt = 0 + golden_cjjpy = os.path.join(root, 'cjjpy.py') + # golden_content = open(golden_cjjpy).readlines() + for dir, folder, file in os.walk(root): + for f in file: + if f == 'cjjpy.py': + cjjpy = '%s/%s' % (dir, f) + # content = open(cjjpy).readlines() + # d = diff.make_file(golden_content, content) + cnt += 1 + print('[%d]: %s' % (cnt, cjjpy)) + os.system('cp %s %s' % (golden_cjjpy, cjjpy)) + + +def ChangeFileFormat(filename, new_fmt): + assert type(filename) is str and type(new_fmt) is str + spt = filename.split('.') + if len(spt) == 0: + return filename + else: + return filename.replace('.' + spt[-1], new_fmt) + + +def CountLines(fname): + with open(fname, 'rb') as f: + count = 0 + last_data = '\n' + while True: + data = f.read(0x400000) + if not data: + break + count += data.count(b'\n') + last_data = data + if last_data[-1:] != b'\n': + count += 1 # Remove this if a wc-like count is needed + return count + + +def GetDate(): + return str(datetime.datetime.now())[5:10].replace('-', '') + + +def TimeClock(seconds): + sec = int(seconds) + hour = int(sec / 3600) + min = int((sec - hour * 3600) / 60) + ssec = float(seconds) - hour * 3600 - min * 60 + # return '%dh %dm %.2fs' % (hour, min, ssec) + return '{0:>2d}h{1:>3d}m{2:>6.2f}s'.format(hour, min, ssec) + + +def StripAll(text): + return text.strip().replace('\t', '').replace('\n', '').replace(' ', '') + + +def GetBracket(text, bracket, en_br=False): + # input should be aa(bb)cc, True for bracket, False for text + if bracket: + try: + return re.findall('\((.*?)\)', text.strip())[-1] + except: + return '' + else: + if en_br: + text = re.sub('\(.*?\)', '', text.strip()) + return re.sub('(.*?)', '', text.strip()) + + +def CharLang(uchar, lang): + assert lang.lower() in ['en', 'cn', 'zh'] + if lang.lower() in ['cn', 'zh']: + if uchar >= '\u4e00' and uchar <= '\u9fa5': + return True + else: + return False + elif lang.lower() == 'en': + if (uchar <= 'Z' and uchar >= 'A') or (uchar <= 'z' and uchar >= 'a'): + return True + else: + return False + else: + raise NotImplementedError + + +def WordLang(word, lang): + for i in word.strip(): + if i.isspace(): continue + if not CharLang(i, lang): + return False + return True + + +def SortDict(_dict, reverse=True): + assert type(_dict) is dict + return sorted(_dict.items(), key=lambda d: d[1], reverse=reverse) + + +def lark(content='test'): + print(content) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--diff', nargs=2, + help='show difference between two files, shown in downloads/diff.html') + parser.add_argument('--de_unicode', action='store_true', default=False, + help='remove unicode characters') + parser.add_argument('--link_entity', action='store_true', default=False, + help='') + parser.add_argument('--max_comm_len', action='store_true', default=False, + help='') + parser.add_argument('--search', nargs=2, + help='search key from file, 2 args: file name & key') + parser.add_argument('--email', nargs=2, + help='sending emails, 2 args: subject & content') + parser.add_argument('--overwrite', action='store_true', default=None, + help='overwrite all cjjpy under given *dir* based on *dir*/cjjpy.py') + parser.add_argument('--replace', nargs=3, + help='replace char, 3 args: file name & replaced char & replacer char') + parser.add_argument('--lark', nargs=1) + parser.add_argument('--get_hdfs', nargs=2, + help='easy copy from hdfs to local fs, 2 args: remote_file/dir & local_dir') + parser.add_argument('--put_hdfs', nargs=2, + help='easy put from local fs to hdfs, 2 args: local_file/dir & remote_dir') + parser.add_argument('--length_stats', nargs=1, + help='simple token lengths distribution of a line-by-line file') + + args = parser.parse_args() + + if args.overwrite: + print('* Overwriting cjjpy...') + OverWriteCjjPy() + + if args.lark: + try: + content = args.lark[0] + except: + content = 'running complete' + print(f'* Larking "{content}"...') + lark(content) + + if args.length_stats: + file = args.length_stats[0] + print(f'* Working on {file} lengths statistics...') + LengthStats(file) diff --git a/src/qg_client/question_generator.py b/src/qg_client/question_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..c133f38889335aa9fc63c45ed3498ac28060aeb8 --- /dev/null +++ b/src/qg_client/question_generator.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- + +''' +@Author : Jiangjie Chen +@Time : 2020/7/29 21:50 +@Contact : jjchen19@fudan.edu.cn +@Description: +''' + +import random +import cjjpy as cjj +import sys, os +import torch +from tqdm import tqdm + +try: + from .t5_qg.generator import Generator +except: + sys.path.append(cjj.AbsParentDir(__file__, '.')) + from t5_qg.generator import Generator + + +def chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i: i + n] + + +class QuestionGenerator: + def __init__(self, model, prefix=None, verbose=True): + assert model in ['t5'] + self.verbose = verbose + prefix = f'{prefix}/models/question_generation/t5-base-qg-hl/' if prefix else None + self.qg = Generator('valhalla/t5-base-qg-hl', prefix, + device='cuda' if torch.cuda.is_available() else 'cpu', + verbose=self.verbose) + + def _clean_input_lines(self, input_lines): + # Only use the first option + if isinstance(input_lines[0][1], tuple) and len(input_lines[0][1]) == 3: + input_lines = list(map(lambda x: (x[0], [x[1]]), input_lines)) + return input_lines + + def generate(self, input_lines: list, sample_num=1, batch_size=128, mask_token=''): + ''' + :param input_lines: List([text, options=[('answer', 0, 1), (x, y, z), ...]]) + :param sample_num: default as 1, as usually only provide one option. + :return: List((regular_q, cloze_q, a)) + ''' + qa_pairs = [] + if len(input_lines) == 0: + return qa_pairs + input_lines = self._clean_input_lines(input_lines) + ques_chunk = [] + + for text, options in input_lines: + masked_qa = self.mask_text(text, options, sample_num=sample_num, mask_token=mask_token) + for q, a in masked_qa: + ques_chunk.append({'context': text, 'answer': a, 'cloze_q': q}) + + ques_pairs = self.qg(ques_chunk, batch_size=batch_size) + iter = tqdm(zip(ques_pairs, ques_chunk), desc='Replacing') \ + if self.verbose else zip(ques_pairs, ques_chunk) + for qa, mq in iter: + q = qa['questions'][0] + a = qa['answer'] + q = q.replace(a[0], mask_token) + qa_pairs.append((q, mq['cloze_q'], a)) + + return qa_pairs + + def _sample(self, options, sample_num=1): + if len(options) <= sample_num: + return options + else: + return random.sample(options, sample_num) + + def mask_text(self, text: str, options, sample_num=1, mask_token=''): + ''' + :param text: text + :param options: [('xx', 1, 2), (), ()] + :return: [text, ('xx', 1, 2)] * sample_num + ''' + masked_span = self._sample(options, sample_num) + masked = [] + for span in masked_span: + if isinstance(span, str): + ntext = text.replace(span, mask_token) + elif len(span) == 3: + assert text[span[1]:span[2]] == span[0], (text[span[1]:span[2]], span[0]) + ntext = text[:span[1]] + mask_token + text[span[2]:] + else: + raise ValueError(span) + masked.append((ntext, span)) + return masked + + def assemble_question(self, regular_q, cloze_q): + return f'{regular_q} or {cloze_q}' + + +if __name__ == '__main__': + qg = QuestionGenerator('t5') + qa_pairs = qg.generate([['I was born yesterday.', [('born', 6, 10), ('yesterday', 11, 20)]]], sample_num=1) + print(qa_pairs) diff --git a/src/qg_client/t5_qg/LICENSE b/src/qg_client/t5_qg/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f0a13039bbf2faa76888eac5ef6ec8d3dca5e0de --- /dev/null +++ b/src/qg_client/t5_qg/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Suraj Patil + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/src/qg_client/t5_qg/README.md b/src/qg_client/t5_qg/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9b90eb75538e340a6a30b50409e43593fbc5682e --- /dev/null +++ b/src/qg_client/t5_qg/README.md @@ -0,0 +1,351 @@ +# Question Generation using 🤗transformers + +- [Question Generation using 🤗transformers](#question-generation-using-transformers) + - [Project Details](#project-details) + - [Initial experiments](#initial-experiments) + - [answer aware question generation](#answer-aware-question-generation) + - [answer extraction models](#answer-extraction-models) + - [Multitask QA-QG](#multitask-qa-qg) + - [End-to-End question generation (answer agnostic)](#end-to-end-question-generation-answer-agnostic) + - [Results](#results) + - [Requirements](#requirements) + - [Usage](#usage) + - [Question Generation](#question-generation) + - [Multitask QA-QG](#multitask-qa-qg-1) + - [End-to-end question generation (without answer supervision)](#end-to-end-question-generation-without-answer-supervision) + - [Fine-tuning](#fine-tuning) + - [Data processing](#data-processing) + - [training](#training) + - [Evaluation](#evaluation) + - [Applications 🚀](#applications-) + - [Relevant papers](#relevant-papers) + + +## Project Details +Question generation is the task of automatically generating questions from a text paragraph. The most straight-forward way for this is answer aware question generation. In answer aware question generation the model is presented with the answer and the passage and asked to generate a question for that answer by considering the passage context. While there are many papers available for QG task, it's still not as mainstream as QA. One of the reasons is most of the earlier papers use complicated models/processing pipelines and have no pre-trained models available. Few recent papers, specifically UniLM and ProphetNet have SOTA pre-trained weights availble for QG but the usage seems quite complicated. + +This project is aimed as an open source study on question generation with pre-trained transformers (specifically seq-2-seq models) using straight-forward end-to-end methods without much complicated pipelines. The goal is to provide simplified data processing and training scripts and easy to use pipelines for inference. + + +## Initial experiments +Initial experiments are conducted using the SQuADv1 dataset and T5 model with different input processing formats as described below. + +### answer aware question generation + +For answer aware models the input text can be processed in two ways. + +**1. prepend format:** + + Here the answer is simply added before the context and seperated by sep token. For example + + `42 [SEP] 42 is the answer to life, the universe and everything.` + + for T5 model the input is processed like this + + `answer: 42 context: 42 is the answer to life, the universe and everything.` + +**2. highlight format** + +Here the answer span is highlighted within the text with special highlight tokens. + +` 42 is the answer to life, the universe and everything.` + +This idea is proposed in the "A Recurrent BERT-based Model for Question Generation" [paper](https://www.aclweb.org/anthology/D19-5821.pdf). See section 4.3 + +### answer extraction models + +As the answer aware models need answers for generating question, we need something which can extract answer like spans from the text. This can be done using various methods like NER, noun-phrase extarction etc. But here a model is trained to extract answer like spans, to see how it'll work. With T5, answer extarction is done using the text-to-format. + +As the highlight format will need to know the position of extracted answer spans the input for answer extraction is processed as follows + + 1. split the text into senteces. + 2. for each sentence that has answers, highlight the sentence with `` tokens. + 3. for the target text join the answers in that sentence with `` tokens. + +For example for this text + +`Python is a programming language. Created by Guido van Rossum and first released in 1991.` + +following examples will be created + +Input text: +` Python is a programming language. Created by Guido van Rossum and first released in 1991.` + +target text: +`Python ` + +and + +Input text: +`Python is a programming language. Created by Guido van Rossum and first released in 1991 .` + +target text: +`Guido van Rossum 1991 ` + +At inference time the text is split into sentences and each sentence is highlighted. + +### Multitask QA-QG + +For answer aware question generation we usually need 3 models, first which will extract answer like spans, second model will generate question on that answer and third will be a QA model which will take the question and produce an answer, +then we can compare the two answers to see if the generated question is correct or not. + +Having 3 models for single task is lot of complexity, so goal is to create a multi-task model which can do all of these 3 tasks + +1. extract answer like spans +2. generate question based on the answer +3. QA + +T5 model is fine-tuned in multi-task way using task prefixes as described in the paper. + +

+ +

+ +### End-to-End question generation (answer agnostic) + +In end-to-end question generation the model is aksed to generate questions without providing the answers. [This](https://arxiv.org/pdf/2005.01107v1.pdf) paper discusses these ideas in more detail. Here the T5 model is trained to generate multiple questions simultaneously by just providing the context. The questions are seperated by the `` token. Here's how the examples are processed + +input text: `Python is a programming language. Created by Guido van Rossum and first released in 1991.` + +target text: `Who created Python ? When was python released ? ` + +**All the training details can be found in [this](https://app.wandb.ai/psuraj/question-generation) wandb project** + +## Results + +Results on the SQuAD1.0 dev set using above approaches. For decoding, beam search with num_beams 4 is used with max decoding length set to 32. + +For multitask qa-qg models the EM and F1 scores are privded as QA-EM and QA-F1. + +The [nlg-eval](https://github.com/Maluuba/nlg-eval) package is used for calculating the metrics. + + +| Name | BLEU-4 | METEOR | ROUGE-L | QA-EM | QA-F1 | QG-FORMAT | +|----------------------------------------------------------------------------|---------|---------|---------|--------|--------|-----------| +| [t5-base-qg-hl](https://huggingface.co/valhalla/t5-base-qg-hl) | 21.3226 | 27.0854 | 43.5962 | - | - | highlight | +| [t5-base-qa-qg-hl](https://huggingface.co/valhalla/t5-base-qa-qg-hl) | 21.0141 | 26.9113 | 43.2484 | 82.46 | 90.272 | highlight | +| [t5-small-qa-qg-hl](https://huggingface.co/valhalla/t5-small-qa-qg-hl) | 18.9872 | 25.2217 | 40.7893 | 76.121 | 84.904 | highlight | +| [t5-small-qg-hl](https://huggingface.co/valhalla/t5-small-qg-hl) | 18.5921 | 24.9915 | 40.1886 | - | - | highlight | +| [t5-small-qg-prepend](https://huggingface.co/valhalla/t5-small-qg-prepend) | 18.2791 | 24.6722 | 39.958 | - | - | prepend | + + +## Requirements +``` +transformers==3.0.0 +nltk +nlp==0.2.0 # only if you want to fine-tune. +``` + +after installing `nltk` do +```bash +python -m nltk.downloader punkt +``` + +## Usage +Use the pipeline whch mimics 🤗transformers pipeline for easy inference. + +The pipeline is divided into 3 tasks +1. `question-generation`: for single task question generation models. +2. `multitask-qa-qg`: for multi-task qa,qg models. +3. `e2e-qg`: for end-to-end question generation. + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/question_generation/blob/master/question_generation.ipynb) + +#### Question Generation + +```python3 +from pipelines import pipeline + +nlp = pipeline("question-generation") +nlp("42 is the answer to life, the universe and everything.") +=> [{'answer': '42', 'question': 'What is the answer to life, the universe and everything?'}] +``` + +**prepend format** +```python3 +nlp = pipeline("question-generation", model="valhalla/t5-small-qg-prepend", qg_format="prepend") +nlp("42 is the answer to life, the universe and everything.") +=> [{'answer': '42 ', 'question': 'What is the answer to life, the universe, and everything?'}] +``` + +#### Multitask QA-QG +```python3 +nlp = pipeline("multitask-qa-qg") + +# to generate questions simply pass the text +nlp("42 is the answer to life, the universe and everything.") +=> [{'answer': '42', 'question': 'What is the answer to life, the universe and everything?'}] + +# for qa pass a dict with "question" and "context" +nlp({ + "question": "What is 42 ?", + "context": "42 is the answer to life, the universe and everything." +}) +=> 'the answer to life, the universe and everything' +``` + +#### End-to-end question generation (without answer supervision) +```python3 +nlp = pipeline("e2e-qg") +nlp("Python is a programming language. Created by Guido van Rossum and first released in 1991.") +=> [ + 'What is a programming language?', + 'Who created Python?', + 'When was Python first released?' +] +``` + +By default both pipelines will use the t5-small* models, to use the other models pass the path through `model` paramter. + +By default the `question-generation` pipeline will download the [valhalla/t5-small-qg-hl](https://huggingface.co/valhalla/t5-small-qg-hl) model with `highlight` qg format. If you want to use prepend format then provide the path to the prepend model and set `qg_format` to `"prepend"`. For extracting answer like spans it uses [valhalla/t5-small-qa-qg-hl](https://huggingface.co/valhalla/t5-small-qa-qg-hl) model, you can provide a different model through `ans_model` parameter. + +The `multitask-qa-qg` model is for multitask models which can extract answer like spans, do qg and qa, so it won't need seperate `ans_model`. By default [valhalla/t5-small-qa-qg-hl](https://huggingface.co/valhalla/t5-small-qa-qg-hl) model is used with `highlight` format. If you want to use prepend format then provide the path to the prepend model and set `qg_format` to `"prepend"` + +The `e2e-qg` pipeline is for end-to-end question generation. These models can generate multiple questions simultaneously without answer supervision. By default it uses [valhalla/t5-small-e2e-qg](https://huggingface.co/valhalla/t5-small-e2e-qg) + +## Fine-tuning + +### Data processing + +To support different data formats the trainer expects pre-processed cached dataset, so you can process the data the way you want. +The cached dataset should be saved using `torch.save` and it should return a `dict` with `source_ids`, `target_ids`, `attention_mask` keys from `__getitem__`. + +- `source_ids`: encoded source text +- `target_ids`: encoded target text +- `attention_mask`: attention mask for the `source_ids` + +The `T2TDataCollator` takes care of preparing right `input_ids` and `labels`. It also trims the batches dynamically to remove excessive padding tokens, to speed up the training. + +The `data/squad_multitask` containes the modifed SQuAD dataset for answer aware question generation (using both prepend and highlight formats), question answering (text-to-text), answer extraction and end-to-end question generation. This dataset can be loaded using the awesome 🤗`nlp` library, this makes processing very easy. + +To process and cache the dataset use `prepare_data.py` script. It will load the correct tokenizer depending on the `model_type` argument. It adds two new tokens `` and `` to the tokenizer and saves it at `{model_type}_qg_tokenizer` path. You should pass this tokenizer to the fine-tuning script. + +The datasets will be saved in `data/` directory. You should provide filenames using `train_file_name` and `valid_file_name` arguments. + +**process data for single task question generation with highlight_qg_format** +```bash +python prepare_data.py \ + --task qg \ + --model_type t5 \ + --dataset_path data/squad_multitask/ \ + --qg_format highlight_qg_format \ + --max_source_length 512 \ + --max_target_length 32 \ + --train_file_name train_data_qg_hl_t5.pt \ + --valid_file_name valid_data_qg_hl_t5.pt \ +``` + +**process data for multi-task qa-qg with highlight_qg_format** + +`valid_for_qg_only` argument is used to decide if the validation set should only contain data for qg task. For my multi-task experiments I used validation data with only qg task so that the eval loss curve can be easly compared with other single task models + +```bash +python prepare_data.py \ + --task multi \ + --valid_for_qg_only \ + --model_type t5 \ + --dataset_path data/squad_multitask/ \ + --qg_format highlight_qg_format \ + --max_source_length 512 \ + --max_target_length 32 \ + --train_file_name train_data_qa_qg_hl_t5.pt \ + --valid_file_name valid_data_qg_hl_t5.pt \ +``` + +**process dataset for end-to-end question generation** +```bash +python prepare_data.py \ + --task e2e_qg \ + --valid_for_qg_only \ + --model_type t5 \ + --dataset_path data/squad_multitask/ \ + --qg_format highlight_qg_format \ + --max_source_length 512 \ + --max_target_length 32 \ + --train_file_name train_data_e2e_qg_t5.pt \ + --valid_file_name valid_data_e2e_qg_t5.pt \ +``` + +### training +Use the `run_qg.py` script to start training. It uses transformers `Trainer` class for training the models. + + +```bash +python run_qg.py \ + --model_name_or_path t5-small \ + --model_type t5 \ + --tokenizer_name_or_path t5_qg_tokenizer \ + --output_dir t5-small-qg-hl \ + --train_file_path data/train_data_qg_hl_t5.pt \ + --valid_file_path data/valid_data_qg_hl_t5.pt \ + --per_device_train_batch_size 32 \ + --per_device_eval_batch_size 32 \ + --gradient_accumulation_steps 8 \ + --learning_rate 1e-4 \ + --num_train_epochs 10 \ + --seed 42 \ + --do_train \ + --do_eval \ + --evaluate_during_training \ + --logging_steps 100 +``` + +or if you want to train it from script or notebook then + +```python3 +from run_qg import run_qg + +args_dict = { + "model_name_or_path": "t5-small", + "model_type": "t5", + "tokenizer_name_or_path": "t5_qg_tokenizer", + "output_dir": "t5-small-qg-hl", + "train_file_path": "data/train_data_qg_hl_t5.pt", + "valid_file_path": "data/valid_data_qg_hl_t5.pt", + "per_device_train_batch_size": 32, + "per_device_eval_batch_size": 32, + "gradient_accumulation_steps": 8, + "learning_rate": 1e-4, + "num_train_epochs": 10, + "seed": 42, + "do_train": True, + "do_eval": True, + "evaluate_during_training": True, + "logging_steps": 100 +} + +# start training +run_qg(args_dict) +``` + +### Evaluation + +Use the `eval.py` script for evaluting the model. + +```bash +python eval.py \ + --model_name_or_path t5-base-qg-hl \ + --valid_file_path valid_data_qg_hl_t5.pt \ + --model_type t5 \ + --num_beams 4 \ + --max_decoding_length 32 \ + --output_path hypothesis_t5-base-qg-hl.txt +``` + +This will save the output at {output_path} file. + +To calculate the metrics install the [nlg-eval](https://github.com/Maluuba/nlg-eval) package and run + +```bash +nlg-eval --hypothesis=hypothesis_t5-base-qg-hl.txt --references=data/references.txt --no-skipthoughts --no-glove +``` + +## Applications 🚀 + +1. A simple Trivia Quiz on topics of your choice -
+ [Medium article](https://medium.com/@nvarshney97/using-the-latest-nlp-techniques-for-fun-98f31ce7b556) and its [Colab Notebook](https://colab.research.google.com/gist/nrjvarshney/39ed6c80e2fe293b9e7eca5bc3a45b7d/quiz.ipynb) + +## Relevant papers +- https://arxiv.org/abs/1906.05416 +- https://www.aclweb.org/anthology/D19-5821/ +- https://arxiv.org/abs/2005.01107v1 diff --git a/src/qg_client/t5_qg/generator.py b/src/qg_client/t5_qg/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..c234600d521c51845a996aeee128b0bfcbf9fa54 --- /dev/null +++ b/src/qg_client/t5_qg/generator.py @@ -0,0 +1,123 @@ +# -*- coding:utf-8 -*- + +""" +@Author : Bao +@Date : 2020/10/13 +@Desc : +@Last modified by : Bao +@Last modified date : 2020/11/11 +""" + +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM + + +DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' + + +class Generator: + """ + Examples: + import json + import torch + + input_data = [ + {'context': 'My name is Sarah.', 'answer': ('Sarah', 11, 16)}, + {'context': 'My name is Sarah and I live in London.', 'answer': ('London', 31, 37)}, + {'context': 'Sarah lived in London. Jone lived in Canada.', 'answer': ('Canada', 37, 43)}, + {'context': 'Sarah lived in London. Jone lived in Canada.', 'answer': ('lived', 28, 33)}, + ] + generator = Generator( + 'valhalla/t5-base-qg-hl', + 'your_cache_dir', + 'cuda' if torch.cuda.is_available() else 'cpu', + ) + + results = generator(input_data, beam_size=5) + print(json.dumps(results, ensure_ascii=False, indent=4)) + """ + + def __init__(self, model_name_or_path, cache_dir=None, device=DEFAULT_DEVICE, verbose=True): + self.seed = 1111 + self.device = device + self.verbose = verbose + self.tokenizer = AutoTokenizer.from_pretrained( + model_name_or_path, + cache_dir=cache_dir if cache_dir else None, + ) + self.model = AutoModelForSeq2SeqLM.from_pretrained( + model_name_or_path, + cache_dir=cache_dir if cache_dir else None, + ) + n_gpu = torch.cuda.device_count() + if n_gpu > 0: + torch.cuda.manual_seed_all(self.seed) + self.model.to(device) + # if n_gpu > 1: + # self.model = torch.nn.DataParallel(self.model) + self.model.eval() + + def __call__(self, input_data, beam_size=1, max_length=100, batch_size=8): + all_ids_with_beam = [] + num_batches = (len(input_data) + batch_size - 1) // batch_size + iter = tqdm(range(num_batches), desc='Generate questions') if self.verbose else range(num_batches) + for step in iter: + batch_start = step * batch_size + batch_end = min((step + 1) * batch_size, len(input_data)) + + batch_text = [] + for entry in input_data[batch_start:batch_end]: + context = entry['context'] + answer, answer_start, answer_end = entry['answer'] + context = 'generate question: ' + context[:answer_start] + \ + ' ' + answer + ' ' + context[answer_end:] + ' ' + batch_text.append(context) + + inputs = self.tokenizer.batch_encode_plus( + batch_text, + padding='max_length', + truncation='longest_first', + max_length=max_length, + return_tensors='pt', + ) + + for key, value in inputs.items(): + inputs[key] = value.to(self.device) + + ids_with_beam = self.model.generate(num_beams=beam_size, + num_return_sequences=beam_size, + no_repeat_ngram_size=3, + early_stopping=True, + length_penalty=1.5, + repetition_penalty=1.5, + min_length=3, + **inputs) + ids_with_beam = ids_with_beam.reshape([len(batch_text), beam_size, -1]) + all_ids_with_beam.extend(ids_with_beam.detach().cpu().tolist()) + + for i, ids_with_beam in enumerate(all_ids_with_beam): + input_data[i]['questions'] = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in ids_with_beam] + + return input_data + + +if __name__ == '__main__': + import json + + input_data = [ + {'context': 'My name is Sarah.', 'answer': ('Sarah', 11, 16)}, + {'context': 'My name is Sarah and I live in London.', 'answer': ('London', 31, 37)}, + {'context': 'Sarah lived in London. Jone lived in Canada.', 'answer': ('Canada', 37, 43)}, + {'context': 'Sarah lived in London. Jone lived in Canada.', 'answer': ('lived', 28, 33)}, + ] + + generator = Generator( + 'valhalla/t5-base-qg-hl', + 'cache/', + 'cuda' if torch.cuda.is_available() else 'cpu', + ) + + results = generator(input_data, beam_size=5) + + print(json.dumps(results, ensure_ascii=False, indent=4)) \ No newline at end of file