import difflib
import os
import json
from tqdm import tqdm
from glob import glob

# # if not os.path.exists('./evttgr2type.json'):
# for file_name in glob('data/RAMS_1.0/data/test.jsonlines'):
#     dic = {}
#     with open(file_name, 'r', encoding='utf-8') as f:
#         lines = f.readlines()
#     for line in tqdm(lines):
#         linej = json.loads(line.strip())
#         evt_triggers = linej['evt_triggers']
#         # print(evt_triggers)
#         sentences = linej['sentences']
#         # print(sentences)
#         sentences_uni = []
#         for s in sentences:
#             sentences_uni += s
#         print(' '.join(sentences_uni))
#         triggers = ' '.join(sentences_uni[evt_triggers[0][0]:evt_triggers[0][1] + 1])
#         evt_type = evt_triggers[0][2][0][0]
#         if triggers in dic:
#             if dic[triggers] != evt_type:
#                 print('One trigger has multiple event types: {} {} {}'.format(triggers, evt_type, dic[triggers]))
#         dic[triggers] = evt_type
#         print(evt_type, triggers)
# exit()

import argparse
import jsonlines
import torch
from src.genie.data import my_collate
from src.genie.data_module_w import RAMSDataModule
from src.genie.model import GenIEModel
import gradio as gr
import re
from transformers import BartTokenizer

MAX_LENGTH = 424
MAX_TGT_LENGTH = 72
DOC_STRIDE = 256


class DataModule4():
    def __init__(self, ontology_file):
        super().__init__()
        self.ontology_file = ontology_file
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        # NOTE: the special tokens were lost in this copy of the file; ' <arg>' and
        # ' <tgr>' are the markers used by the GenIE/gen-arg data pipeline.
        self.tokenizer.add_tokens([' <arg>', ' <tgr>'])
        self.ontology_dict = self.load_ontology()

    def create_gold_gen(self, context_words, evt_type, trigger):
        # Two top-level lists: one for the input templates, one for the contexts.
        INPUT = []
        CONTEXT = []
        input_template = self.ontology_dict[evt_type.replace('n/a', 'unspecified')]['template']
        i = len(input_template)
        input_list = []
        for x in range(i):
            # Collapse the numbered placeholders (<arg1>, <arg2>, ...) into the
            # generic <arg> token (patterns reconstructed; they were stripped in this copy).
            s = re.sub(r'<arg\d>', '<arg>', input_template[x])
            input_list.append(s)
        # input_list now holds the templates with every <argN> replaced by <arg>;
        # next, split them into words.
        temp = []
        for x in range(i):
            space_tokenized_template = input_list[x].split(' ')
            temp.append(space_tokenized_template)
        # temp holds the space-split templates; next, run the BART tokenizer over every word.
        tokenized_input_template = []
        for x in range(len(temp)):
            for w in temp[x]:
                tokenized_input_template.extend(self.tokenizer.tokenize(w, add_prefix_space=True))
            INPUT.append(tokenized_input_template)
            tokenized_input_template = []

        context_words = context_words.split(' ')
        trigger_words = trigger.split(' ')
        trigger_span_start = context_words.index(trigger_words[0])
        trigger_span_end = context_words.index(trigger_words[-1])
        # Words before the trigger.
        prefix = self.tokenizer.tokenize(' '.join(context_words[:trigger_span_start]), add_prefix_space=True)
        # The trigger phrase itself.
        tgt = self.tokenizer.tokenize(trigger, add_prefix_space=True)
        # Words after the trigger.
        suffix = self.tokenizer.tokenize(' '.join(context_words[trigger_span_end + 1:]), add_prefix_space=True)
        context = prefix + [' <tgr>', ] + tgt + [' <tgr>', ] + suffix
        # context = self.tokenizer.tokenize(' '.join(context_words), add_prefix_space=True)

        # One copy of the context per template.
        for w in range(i):
            CONTEXT.append(context)
        return INPUT, CONTEXT

    def load_ontology(self):
        ontology_dict = {}
        with open(self.ontology_file, 'r') as f:
            for lidx, line in enumerate(f):
                if lidx == 0:  # header
                    continue
                fields = line.strip().split(',')
                if len(fields) < 2:
                    break
                evt_type = fields[0]
                if evt_type in ontology_dict.keys():
                    args = fields[2:]
                    ontology_dict[evt_type]['template'].append(fields[1])
                    for i, arg in enumerate(args):
                        if arg != '':
                            ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
                            ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)
                else:
                    ontology_dict[evt_type] = {}
                    args = fields[2:]
                    ontology_dict[evt_type]['template'] = []
                    ontology_dict[evt_type]['template'].append(fields[1])
                    for i, arg in enumerate(args):
                        if arg != '':
                            ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
                            ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)
        return ontology_dict

    def prepare_data(self, sentences, evt_type, trigger):
        input_template, context = self.create_gold_gen(sentences, evt_type, trigger)
        length = len(input_template)
        # print(input_template)
        # print(context)
        results = []
        for i in range(length):
            # Encode each (template, context) pair; truncate only the context so the
            # template is never cut off.
            input_tokens = self.tokenizer.encode_plus(input_template[i], context[i],
                                                      add_special_tokens=True,
                                                      add_prefix_space=True,
                                                      max_length=MAX_LENGTH,
                                                      truncation='only_second',
                                                      padding='max_length')
            # input_ids: the vocabulary ids of the tokens.
            results.append(input_tokens['input_ids'])
        temp = self.ontology_dict[evt_type.replace('n/a', 'unspecified')]
        return results, temp
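# A rough sketch of the CSV layout that DataModule4.load_ontology() expects: a header
# row, then one row per template line (an event type may span several rows, one per
# sub-template). The concrete row below is illustrative, not copied from
# aida_ontology_cleaned.csv:
#
#   evt_type,template,arg1,arg2,arg3,arg4,arg5
#   contact.prevarication.broadcast,<arg1> communicated to <arg2> about <arg4> topic,communicator,recipient,,topic,
#
# which load_ontology() turns into roughly:
#
#   {'contact.prevarication.broadcast': {
#       'template': ['<arg1> communicated to <arg2> about <arg4> topic'],
#       'arg1': 'communicator', 'communicator': 'arg1',
#       'arg2': 'recipient', 'recipient': 'arg2',
#       'arg4': 'topic', 'topic': 'arg4'}}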
class DataModuleW():
    def __init__(self, ontology_file):
        super().__init__()
        self.ontology_file = ontology_file
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        # As above, the special tokens were lost in this copy; ' <arg>' and ' <tgr>'
        # are the markers used by this pipeline.
        self.tokenizer.add_tokens([' <arg>', ' <tgr>'])
        self.ontology_dict = self.load_ontology()

    def create_gold_gen(self, context_words, evt_type, trigger):
        # Two top-level lists: one for the input templates, one for the contexts.
        INPUT = []
        CONTEXT = []
        input_template = self.ontology_dict[evt_type.replace('n/a', 'unspecified')]['template']
        i = len(input_template)
        input_list = []
        for x in range(i):
            # The placeholder patterns here were lost in this copy of the file; the
            # reconstruction assumes the trigger placeholder <tgr> is filled with the
            # trigger words and numbered <argN> slots collapse to the generic <arg> token.
            s = re.sub('<tgr>', trigger, input_template[x])
            s = re.sub(r'<arg\d>', '<arg>', s)
            input_list.append(s)
        # input_list now holds the filled-in templates; next, split them into words.
        temp = []
        for x in range(i):
            space_tokenized_template = input_list[x].split(' ')
            temp.append(space_tokenized_template)
        # temp holds the space-split templates; next, run the BART tokenizer over every word.
        tokenized_input_template = []
        for x in range(len(temp)):
            for w in temp[x]:
                tokenized_input_template.extend(self.tokenizer.tokenize(w, add_prefix_space=True))
            INPUT.append(tokenized_input_template)
            tokenized_input_template = []

        # Deliberately rewrites the cached templates in place, so that later lookups of
        # ontology_dict[...]['template'] (e.g. in handle()) see the trigger substituted.
        template = self.ontology_dict[evt_type.replace('n/a', 'unspecified')]['template']
        for y in range(len(template)):
            template[y] = re.sub('<tgr>', trigger, template[y])

        context = self.tokenizer.tokenize(context_words, add_prefix_space=True)
        # One copy of the context per template.
        for w in range(i):
            CONTEXT.append(context)
        return INPUT, CONTEXT

    def load_ontology(self):
        ontology_dict = {}
        with open(self.ontology_file, 'r') as f:
            for lidx, line in tqdm(enumerate(f)):
                if lidx == 0:  # header
                    continue
                fields = line.strip().split(',')
                if len(fields) < 2:
                    break
                evt_type = fields[0]
                if evt_type in ontology_dict.keys():
                    args = fields[2:]
                    ontology_dict[evt_type]['template'].append(fields[1])
                    for i, arg in enumerate(args):
                        if arg != '':
                            ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
                            ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)
                else:
                    ontology_dict[evt_type] = {}
                    args = fields[2:]
                    ontology_dict[evt_type]['template'] = []
                    ontology_dict[evt_type]['template'].append(fields[1])
                    for i, arg in enumerate(args):
                        if arg != '':
                            ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
                            ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)
        return ontology_dict

    def prepare_data(self, sentences, evt_type, trigger):
        input_template, context = self.create_gold_gen(sentences, evt_type, trigger)
        length = len(input_template)
        # print(input_template)
        # print(output_template)
        # print(context)
        results = []
        for i in range(length):
            # Encode each (template, context) pair; truncate only the context so the
            # template is never cut off.
            input_tokens = self.tokenizer.encode_plus(input_template[i], context[i],
                                                      add_special_tokens=True,
                                                      add_prefix_space=True,
                                                      max_length=MAX_LENGTH,
                                                      truncation='only_second',
                                                      padding='max_length')
            # input_ids: the vocabulary ids of the tokens.
            results.append(input_tokens['input_ids'])
        temp = self.ontology_dict[evt_type.replace('n/a', 'unspecified')]
        return results, temp
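# How prepare_data() packs each pair: encode_plus(template_tokens, context_tokens) with
# truncation='only_second' builds "<s> template </s></s> context </s>", trims only the
# context when the pair exceeds MAX_LENGTH, and pads everything to MAX_LENGTH. A minimal
# standalone sketch (the token lists below are hypothetical):
#
#   tok = BartTokenizer.from_pretrained('facebook/bart-large')
#   enc = tok.encode_plus(['<arg>', 'communicated'], ['Mutko', 'stopped', 'short'],
#                         add_special_tokens=True, add_prefix_space=True,
#                         max_length=MAX_LENGTH, truncation='only_second',
#                         padding='max_length')
#   assert len(enc['input_ids']) == MAX_LENGTH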
class Runner():
    def __init__(self, load_ckpt='checkpoints/gen-RAMS-what-new-span/epoch=2-v0.ckpt'):
        model = 'gen'
        self.ckpt_name = 'gen-RAMS-pred'
        self.load_ckpt = load_ckpt
        self.dataset = 'RAMS'
        self.eval_only = True
        self.train_file = 'data/RAMS_1.0/data/train.jsonlines'
        self.val_file = 'data/RAMS_1.0/data/dev.jsonlines'
        self.test_file = 'data/RAMS_1.0/data/test.jsonlines'
        self.train_batch_size = 2
        self.eval_batch_size = 4
        self.learning_rate = 3e-5
        self.accumulate_grad_batches = 4
        self.num_train_epochs = 3

        parser = argparse.ArgumentParser()
        # Required parameters
        parser.add_argument("--model", type=str, default=model)
        parser.add_argument("--dataset", type=str, default=self.dataset)
        parser.add_argument('--tmp_dir', type=str)
        parser.add_argument(
            "--ckpt_name",
            default=self.ckpt_name,
            type=str,
            help="The output directory where the model checkpoints and predictions will be written.",
        )
        parser.add_argument("--load_ckpt", default=self.load_ckpt, type=str)
        parser.add_argument(
            "--train_file",
            default=self.train_file,
            type=str,
            help="The input training file. If a data dir is specified, will look for the file there. "
                 "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
        )
        parser.add_argument(
            "--val_file",
            default=self.val_file,
            type=str,
            help="The input evaluation file. If a data dir is specified, will look for the file there. "
                 "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
        )
        parser.add_argument('--test_file', type=str, default=self.test_file)
        parser.add_argument('--input_dir', type=str, default=None)
        parser.add_argument('--coref_dir', type=str, default='data/kairos/coref_outputs')
        parser.add_argument('--use_info', action='store_true', default=False,
                            help='use informative mentions instead of the nearest mention.')
        parser.add_argument('--mark_trigger', action='store_true')
        parser.add_argument('--sample-gen', action='store_true', help='Do sampling during generation.')
        parser.add_argument("--train_batch_size", default=self.train_batch_size, type=int,
                            help="Batch size per GPU/CPU for training.")
        parser.add_argument("--eval_batch_size", default=self.eval_batch_size, type=int,
                            help="Batch size per GPU/CPU for evaluation.")
        parser.add_argument("--learning_rate", default=self.learning_rate, type=float,
                            help="The initial learning rate for Adam.")
        parser.add_argument("--accumulate_grad_batches", type=int, default=self.accumulate_grad_batches,
                            help="Number of update steps to accumulate before performing a backward/update pass.")
        parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
        parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
        parser.add_argument("--gradient_clip_val", default=1.0, type=float, help="Max gradient norm.")
        parser.add_argument("--num_train_epochs", default=self.num_train_epochs, type=int,
                            help="Total number of training epochs to perform.")
        parser.add_argument("--max_steps", default=-1, type=int,
                            help="If > 0: set total number of training steps to perform. Overrides num_train_epochs.")
        parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
        parser.add_argument("--gpus", default=None, help='-1 means train on all gpus')
        parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
        parser.add_argument(
            "--fp16",
            action="store_true",
            help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
        )
        parser.add_argument("--threads", type=int, default=1,
                            help="multiple threads for converting examples to features")
        self.args = parser.parse_args()

        self.model = GenIEModel(self.args)
        self.model.load_state_dict(torch.load(self.args.load_ckpt, map_location=self.model.device)['state_dict'])

    def pred(self, input):
        # Stack the encoded (template, context) pairs into a LongTensor batch.
        x = torch.stack([torch.LongTensor(u) for u in input])
        return self.model.pred(x)
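# Hedged end-to-end sketch of the pieces defined above (the paths and return shapes are
# assumed from this file, not verified against src/genie):
#
#   dm = DataModule4('aida_ontology_cleaned.csv')
#   runner = Runner('checkpoints/gen-RAMS-1-span/epoch=2-v1.ckpt')
#   ids, argnames = dm.prepare_data(doc_text, 'contact.prevarication.broadcast', 'deceive')
#   outputs = runner.pred(ids)  # one generated string per template line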
print('Loading data...')
dm1 = DataModule4('aida_ontology_cleaned.csv')
dm2 = DataModuleW('aida_ontology_fj-w-2.csv')
dm3 = DataModuleW('aida_ontology_fj-w-3.csv')
dm4 = DataModule4('aida_ontology_fj-5.csv')
print('Loading Model 1...')
runner1 = Runner('checkpoints/gen-RAMS-1-span/epoch=2-v1.ckpt')
print('Loading Model 2...')
runner2 = Runner('checkpoints/gen-RAMS-2-span/epoch=2-v0.ckpt')
print('Loading Model 3...')
runner3 = Runner('checkpoints/gen-RAMS-3-span/epoch=2-v0.ckpt')
print('Loading Model 4...')
runner4 = Runner('checkpoints/gen-RAMS-4-span/epoch=2-v0.ckpt')

# Index the data modules and runners by template variant instead of building the
# variable names with eval(), as the original did.
DMS = [dm1, dm2, dm3, dm4]
RUNNERS = [runner1, runner2, runner3, runner4]


def handle(sentences, trigger, temp=3, evt_type='contact.prevarication.broadcast'):
    # 'temp' is the 0-based index of the template variant chosen in the UI.
    x, argnames = DMS[temp].prepare_data(sentences, evt_type, trigger)
    ys = RUNNERS[temp].pred(x)
    print(ys)
    results = []
    for y in ys:
        # Collapse runs of spaces, then split the generated string into words.
        while '  ' in y:
            y = y.replace('  ', ' ')
        result = y.strip(' ').split(' ')
        results.append(result)
    print(results)
    argss = []
    for n, template in enumerate(argnames['template']):
        template = template.split(' ')
        # print(template)
        args = []
        for i, w in enumerate(template):
            # NOTE: this loop was garbled in this copy of the file; the reconstruction
            # assumes the gen-arg convention that template slots look like <arg1>,
            # <arg2>, ... and that the model copies <arg> through unchanged when a
            # role is left unfilled.
            m = re.match(r'<(arg\d+)>', w)
            if m:
                label = m.group(1)
                if results[n][i] == '<arg>':
                    args.append(label + ': None')
                else:
                    args.append(label + ': ' + results[n][i])
        argss.append(', '.join(args))
    return '\n'.join(argss)
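# Example call of handle() (hypothetical input; temp=0 selects dm1/runner1, the base
# template, and the output format follows the loop above):
#
#   handle('Mutko stopped short of admitting the doping scandal was state sponsored .',
#          'admitting', temp=0, evt_type='contact.prevarication.broadcast')
#
# returns one line per template, e.g. 'arg1: Mutko, arg2: None, ...'.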
if __name__ == "__main__":
    # trigger = 'deceive'
    # sentences = """We are ashamed of them . " However , Mutko stopped short of admitting the doping scandal was state sponsored . " We are very sorry that athletes who tried to deceive us , and the world , were not caught sooner . We are very sorry because Russia is committed to upholding the highest standards in sport and is opposed to anything that threatens the Olympic values , " he said . English former heptathlete and Athens 2004 bronze medallist Kelly Sotherton was unhappy with Mutko 's plea for Russia 's ban to be lifted for Rio"""
    # print(handle(sentences, trigger))

    dm_key = list(dm1.ontology_dict.keys())
    print(len(dm_key))

    def get_tmp(index, evt_type):
        if index is None or evt_type is None:
            return ''
        input_template = DMS[index].ontology_dict[evt_type.replace('n/a', 'unspecified')]['template']
        return '\n'.join(input_template)

    with gr.Blocks() as demo:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant="panel"):
                stens = gr.Text(label='Document')
                evt_type = gr.Dropdown(choices=dm_key, label='Event type')
                trigger = gr.Text(label='Trigger')
                temp = gr.Dropdown(
                    choices=['Base template', 'Simple sub-templates',
                             'Sub-templates with semantic information',
                             'Sub-templates with argument information'],
                    type='index', value='Base template', label='Template')
                output_tmp = gr.Text(label='Template content')
                btn = gr.Button("Run")
            with gr.Column(variant="panel"):
                result = gr.Text(label='Output')
        evt_type.change(get_tmp, inputs=[temp, evt_type], outputs=[output_tmp])
        temp.change(get_tmp, inputs=[temp, evt_type], outputs=[output_tmp])
        btn.click(fn=handle, inputs=[stens, trigger, temp, evt_type], outputs=[result])

    demo.launch(server_name='0.0.0.0', server_port=6006, share=True)