import os
import json
import jsonlines
import re
import random
from collections import defaultdict
import argparse

import transformers
from transformers import BartTokenizer
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl

# Relative import: this file must be run as a module inside its package
# (e.g. with `python -m ...`) for the import below to resolve.
from .data import IEDataset, my_collate

MAX_LENGTH = 424
MAX_TGT_LENGTH = 72
DOC_STRIDE = 256


class RAMSDataModule(pl.LightningDataModule):
    def __init__(self, args):
        super().__init__()
        self.hparams = args
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        # Special placeholder tokens: <arg> for unfilled argument slots,
        # <tgr> for marking the trigger in the context.
        self.tokenizer.add_tokens([' <arg>', ' <tgr>'])

    def get_event_type(self, ex):
        evt_type = []
        for evt in ex['evt_triggers']:
            for t in evt[2]:
                evt_type.append(t[0])
        return evt_type

    # Build the gold (label) sequence for one example.
    def create_gold_gen(self, ex, ontology_dict, mark_trigger=True):
        '''Assumes that each line only contains one event.
        Input: template with special placeholders + passage.
        Output: template with arguments filled in, and <arg> kept when no argument is found.
        '''
        # Current template style: "what is the <arg> in <trg>".
        # Get the event type of this example.
        evt_type = self.get_event_type(ex)[0]
        # Flatten the document into a single list of words.
        context_words = [w for sent in ex['sentences'] for w in sent]
        # Look up the template for this event type in the ontology
        # (special case: the 'n/a' sub-type is stored as 'unspecified').
        template = ontology_dict[evt_type.replace('n/a', 'unspecified')]['template']

        # Replace the <trg> placeholder with the trigger word
        # (the first word of the trigger span).
        trigger_index = ex['evt_triggers'][0][0]
        trg = context_words[trigger_index]
        template = re.sub(r'<trg>', trg, template)

        # Replace the numbered placeholders (<arg1>, <arg2>, ...) in the input
        # template with the generic <arg> token.
        input_template = re.sub(r'<arg\d>', '<arg>', template)
        # After this step: "what is the <arg> in trg what is the <arg> in trg ..."
        space_tokenized_input_template = input_template.split(' ')
        # Sub-word tokenize the input template.
        tokenized_input_template = []
        for w in space_tokenized_input_template:
            tokenized_input_template.extend(self.tokenizer.tokenize(w, add_prefix_space=True))

        # Use the gold (trigger, argument, role) triples to build the output
        # template, i.e. the label sequence.
        for triple in ex['gold_evt_links']:
            # Example: "gold_evt_links":
            # [[[40, 40], [33, 33], "evt089arg01victim"],
            #  [[40, 40], [28, 28], "evt089arg02place"]]
            trigger_span, argument_span, arg_name = triple
            # Which numbered argument slot this role maps to (e.g. 'arg1').
            arg_num = ontology_dict[evt_type.replace('n/a', 'unspecified')][arg_name]
            # The argument text (a phrase from the document).
            arg_text = ' '.join(context_words[argument_span[0]:argument_span[1] + 1])
            # Replace the corresponding <argN> placeholder with the argument text.
            template = re.sub('<{}>'.format(arg_num), arg_text, template)

        # Mark the trigger in the context with <tgr> tokens.
        trigger = ex['evt_triggers'][0]
        if mark_trigger:
            trigger_span_start = trigger[0]
            trigger_span_end = trigger[1] + 2  # one for inclusion, one for the extra start marker
            # Words before the trigger.
            prefix = self.tokenizer.tokenize(' '.join(context_words[:trigger[0]]), add_prefix_space=True)
            # The trigger phrase itself.
            tgt = self.tokenizer.tokenize(' '.join(context_words[trigger[0]: trigger[1] + 1]), add_prefix_space=True)
            # Words after the trigger.
            suffix = self.tokenizer.tokenize(' '.join(context_words[trigger[1] + 1:]), add_prefix_space=True)
            context = prefix + [' <tgr>', ] + tgt + [' <tgr>', ] + suffix
        else:
            context = self.tokenizer.tokenize(' '.join(context_words), add_prefix_space=True)

        # In the output template, replace any remaining <argN> placeholders
        # (roles with no gold argument) with the generic <arg> token.
        output_template = re.sub(r'<arg\d>', '<arg>', template)
        space_tokenized_template = output_template.split(' ')
        tokenized_template = []
        for w in space_tokenized_template:
            tokenized_template.extend(self.tokenizer.tokenize(w, add_prefix_space=True))

        return tokenized_input_template, tokenized_template, context

    def load_ontology(self):
        # Read the event ontology.
        ontology_dict = {}
        with open('aida_ontology_new.csv', 'r') as f:
            for lidx, line in enumerate(f):
                if lidx == 0:  # header
                    continue
                fields = line.strip().split(',')
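                # Assumed CSV row layout (inferred from how the fields are used
                # below, not from the file itself): column 0 = event type,
                # column 1 = the original template (unused here),
                # columns 2+ = argument role names for that event type.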
                if len(fields) < 2:
                    break
                # Event type.
                evt_type = fields[0]
                # All argument roles defined for this event type.
                args = fields[2:]
                # Number of argument roles to be extracted for this event type.
                args_len = 0
                for i, arg in enumerate(args):
                    if arg != '':
                        args_len += 1
                # Build the template for this event type by concatenating one
                # "what is the <argN> in <trg>" sub-template per argument slot.
                # The final template looks like:
                # "what is the <arg1> in <trg> what is the <arg2> in <trg> ..."
                temp_template = ""
                for i in range(args_len):
                    temp_template = temp_template + " what is the <arg{}> in <trg>".format(i + 1)
                # Store the template under the event type key.
                ontology_dict[evt_type] = {
                    'template': temp_template
                }
                # Record the mapping between numbered slots and role names
                # in both directions (arg1 -> role and role -> arg1).
                for i, arg in enumerate(args):
                    if arg != '':
                        ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
                        ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)

        return ontology_dict

    def prepare_data(self):
        if not os.path.exists('head_templates_preprocessed_data'):
            os.makedirs('head_templates_preprocessed_data')

        ontology_dict = self.load_ontology()

        for split, f in [('train', self.hparams.train_file),
                         ('val', self.hparams.val_file),
                         ('test', self.hparams.test_file)]:
            with open(f, 'r') as reader, \
                    open('head_templates_preprocessed_data/{}.jsonl'.format(split), 'w') as writer:
                for lidx, line in enumerate(reader):
                    # Each line of the jsonlines file is one example.
                    ex = json.loads(line.strip())
                    # Input template, gold output template, and document context.
                    input_template, output_template, context = self.create_gold_gen(
                        ex, ontology_dict, self.hparams.mark_trigger)
                    # Encode the (template, context) pair as the encoder input.
                    input_tokens = self.tokenizer.encode_plus(
                        input_template, context,
                        add_special_tokens=True,
                        add_prefix_space=True,
                        max_length=MAX_LENGTH,
                        truncation='only_second',
                        padding='max_length')
                    # Encode the gold output template as the decoder target.
                    tgt_tokens = self.tokenizer.encode_plus(
                        output_template,
                        add_special_tokens=True,
                        add_prefix_space=True,
                        max_length=MAX_TGT_LENGTH,
                        truncation=True,
                        padding='max_length')
                    # input_ids: token ids in the vocabulary.
                    # attention_mask: which positions attention should attend to.
                    processed_ex = {
                        # 'idx': lidx,
                        'doc_key': ex['doc_key'],
                        'input_token_ids': input_tokens['input_ids'],
                        'input_attn_mask': input_tokens['attention_mask'],
                        'tgt_token_ids': tgt_tokens['input_ids'],
                        'tgt_attn_mask': tgt_tokens['attention_mask'],
                    }
                    writer.write(json.dumps(processed_ex) + "\n")

    def train_dataloader(self):
        dataset = IEDataset('head_templates_preprocessed_data/train.jsonl')
        dataloader = DataLoader(dataset,
                                pin_memory=True, num_workers=2,
                                collate_fn=my_collate,
                                batch_size=self.hparams.train_batch_size,
                                shuffle=True)
        return dataloader

    def val_dataloader(self):
        dataset = IEDataset('head_templates_preprocessed_data/val.jsonl')
        dataloader = DataLoader(dataset,
                                pin_memory=True, num_workers=2,
                                collate_fn=my_collate,
                                batch_size=self.hparams.eval_batch_size,
                                shuffle=False)
        return dataloader

    def test_dataloader(self):
        dataset = IEDataset('head_templates_preprocessed_data/test.jsonl')
        dataloader = DataLoader(dataset,
                                pin_memory=True, num_workers=2,
                                collate_fn=my_collate,
                                batch_size=self.hparams.eval_batch_size,
                                shuffle=False)
        return dataloader


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-file', type=str, default='data/RAMS_1.0/data/train.jsonlines')
    parser.add_argument('--val-file', type=str, default='data/RAMS_1.0/data/dev.jsonlines')
    parser.add_argument('--test-file', type=str, default='data/RAMS_1.0/data/test.jsonlines')
    parser.add_argument('--train_batch_size', type=int, default=2)
    parser.add_argument('--eval_batch_size', type=int, default=4)
    parser.add_argument('--mark-trigger', action='store_true', default=True)
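    # argparse converts the dashes in --train-file / --val-file / --test-file /
    # --mark-trigger to underscores, so these are read back in the module above
    # as hparams.train_file, hparams.val_file, hparams.test_file and
    # hparams.mark_trigger.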
    args = parser.parse_args()

    dm = RAMSDataModule(args=args)
    dm.prepare_data()

    # training dataloader
    dataloader = dm.train_dataloader()
    for idx, batch in enumerate(dataloader):
        print(batch)
        break

    # val dataloader
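    # Quick, self-contained sanity check of the template-filling logic
    # (a minimal sketch: the event type and role names below are hypothetical,
    # not taken from aida_ontology_new.csv).
    demo_entry = {
        'template': ' what is the <arg1> in <trg> what is the <arg2> in <trg>',
        'arg1': 'attacker', 'attacker': 'arg1',
        'arg2': 'place', 'place': 'arg2',
    }
    demo = demo_entry['template']
    # The same substitutions create_gold_gen() performs: fill the trigger word,
    # fill the gold argument for its slot, and collapse any unfilled slots into
    # the generic <arg> token.
    demo = re.sub(r'<trg>', 'attacked', demo)
    demo = re.sub('<{}>'.format(demo_entry['attacker']), 'the militia', demo)
    demo = re.sub(r'<arg\d>', '<arg>', demo)
    print(demo)
    # -> " what is the the militia in attacked what is the <arg> in attacked"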