import json
import re

import numpy as np
import torch
from torch.utils.data import Dataset
class summarization_data(Dataset):
    def __init__(self, raw_data, tokenizer, domain_adaption=False, wwm_prob=0.1):
        data = self.select_data(raw_data)
        self.data = self.data_clean(data)
        self.len = len(self.data)
        self.tokenizer = tokenizer
        self.padding_max_len = "max_length"
        self.domain_adaption = domain_adaption
        self.wwm_prob = wwm_prob
    def select_data(self, raw_data):
        # keep only the fields needed for summarization; drop unused metadata
        # (pop with a default avoids a KeyError when a field is absent)
        data = []
        for item in raw_data:
            item.pop('pub_time', None)
            item.pop('labels', None)
            data.append(item)
        return data
    def data_clean(self, data):
        # lowercase, then strip URLs, /.../ spans, and stray punctuation
        pattern = re.compile(r"http\S+|/.*?/|[?!|+:()\\-]")
        for item in data:
            item["text"] = pattern.sub("", item["text"].lower())
            item["title"] = pattern.sub("", item["title"].lower())
        return data
    def __getitem__(self, index):
        if self.domain_adaption:
            # domain adaption: mask random tokens and train to reconstruct them
            tokenized_text = self.tokenizer(
                self.data[index]["text"],
                add_special_tokens=True,
                padding=self.padding_max_len,
                return_token_type_ids=True,
                truncation=True,
            )
            text_mask = tokenized_text['attention_mask']
            input_ids, labels = self._word_masking(self.tokenizer, tokenized_text, self.wwm_prob)
            return {
                'input_ids': torch.tensor(input_ids),
                'attention_mask': torch.tensor(text_mask),
                'labels': torch.tensor(labels)
            }
        else:
            # supervised summarization: T5 expects a task prefix on the source text
            tokenized_text = self.tokenizer(
                "summarize: " + self.data[index]["text"],
                add_special_tokens=True,
                padding=self.padding_max_len,
                return_token_type_ids=True,
                truncation=True,
            )
            text_ids = tokenized_text['input_ids']
            text_mask = tokenized_text['attention_mask']
            # the article title serves as the reference summary
            tokenized_title = self.tokenizer(
                self.data[index]["title"],
                padding=self.padding_max_len,
                return_token_type_ids=True,
                truncation=True,
            )
            title_ids = tokenized_title['input_ids']
            return {
                'input_ids': torch.tensor(text_ids),
                'attention_mask': torch.tensor(text_mask),
                'labels': torch.tensor(title_ids)
            }
    def _word_masking(self, tokenizer, tokenized_inputs, wwm_prob):
        # randomly replace tokens with sentinel tokens; the target sequence is
        # each sentinel followed by the token it replaced
        input_ids = tokenized_inputs["input_ids"]
        attention_mask = tokenized_inputs["attention_mask"]
        mask = np.random.binomial(1, wwm_prob, (len(input_ids),))
        labels = []
        for idx in np.where(mask == 1)[0]:
            if attention_mask[idx] == 0:
                continue  # never mask padding positions
            # pick a sentinel (<extra_id_*>) keyed off the original token id
            sentinel_token = tokenizer.additional_special_tokens[input_ids[idx] % 100]
            sentinel_id = tokenizer.convert_tokens_to_ids(sentinel_token)
            labels.append(sentinel_id)
            labels.append(input_ids[idx])
            input_ids[idx] = sentinel_id
        return input_ids, labels
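    # A hypothetical walk-through of _word_masking (ids are made up, assuming a
    # T5-style tokenizer whose additional_special_tokens are <extra_id_0..99>):
    #   before: input_ids = [100, 200, 300], and position 1 is drawn for masking
    #   after:  input_ids = [100, sentinel_id, 300] where sentinel_id is the id
    #           of <extra_id_{200 % 100}>, and labels = [sentinel_id, 200]
    # Because labels grows by two per masked token, its length varies per
    # example; the DataCollatorForSeq2Seq in the usage example below pads
    # labels up to the batch maximum.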
    def __len__(self):
        return self.len
# from tqdm.auto import tqdm
# from transformers import T5TokenizerFast, T5ForConditionalGeneration
# from transformers import DataCollatorForSeq2Seq
# from torch.utils.data import DataLoader, random_split
# from FTE_NLP.model.summarization_dataset_v1 import *
# from torch.optim import AdamW
# from transformers import get_scheduler
# from FTE_NLP.utils.post_process import *
#
# json_filename = '../data/raw_EDT/Trading_benchmark/evaluate_news_test.json'
# with open(json_filename) as data_file:
#     test_data = json.loads(data_file.read())
#
# model_checkpoint = "t5-small"
# tokenizer = T5TokenizerFast.from_pretrained(model_checkpoint, model_max_length=512)
# model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)
#
# all_dataset = summarization_data(test_data, tokenizer, domain_adaption=True, wwm_prob=0.1)
# # the 7/3 split assumes the test file holds exactly ten examples
# train_dataset, eval_dataset = random_split(all_dataset, [7, 3], generator=torch.Generator().manual_seed(42))
#
# # data_collator = DataCollatorForSeq2Seq(tokenizer, model, label_pad_token_id=0)
# data_collator = DataCollatorForSeq2Seq(tokenizer, model)
# # pass data to dataloader
# train_params = {'batch_size': 2, 'shuffle': False, 'num_workers': 0}
# train_loader = DataLoader(train_dataset, collate_fn=data_collator, **train_params)
#
# eval_params = {'batch_size': 2, 'shuffle': False, 'num_workers': 0}
# eval_loader = DataLoader(eval_dataset, collate_fn=data_collator, **eval_params)
#
#
# for item in train_loader:
#     print(item)
#     print("input id numbers:", item["input_ids"][0])
#     print("input id: ", tokenizer.decode(item["input_ids"][0]))
#     print("labels id number:", item["labels"][0])
#     print("labels:", tokenizer.decode(item["labels"][0], skip_special_tokens=False))
#     break
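#
# # The unused AdamW / get_scheduler imports above suggest a fine-tuning step
# # like the following (a minimal sketch, not the project's actual training
# # loop; lr and num_epochs are illustrative values):
# optimizer = AdamW(model.parameters(), lr=2e-5)
# num_epochs = 3
# num_training_steps = num_epochs * len(train_loader)
# lr_scheduler = get_scheduler("linear", optimizer=optimizer,
#                              num_warmup_steps=0, num_training_steps=num_training_steps)
# model.train()
# for epoch in range(num_epochs):
#     for batch in tqdm(train_loader):
#         outputs = model(**batch)  # T5 computes the seq2seq loss from labels
#         outputs.loss.backward()
#         optimizer.step()
#         lr_scheduler.step()
#         optimizer.zero_grad()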