Upload summarization_dataset.py
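Adds `summarization_data`, a PyTorch-style dataset over EDT news records. In summarization mode it tokenizes "summarize: " plus the article text as the encoder input and the cleaned title as the labels; with `domain_adaption=True` it instead replaces random tokens with T5 sentinel tokens and returns the sentinel/original-token pairs as labels for denoising-style pretraining.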
summarization_dataset.py
ADDED
@@ -0,0 +1,131 @@
import json
import re

import numpy as np
import torch
from torch.utils.data import Dataset


class summarization_data(Dataset):
    """EDT news dataset: yields (text -> title) pairs for seq2seq summarization,
    or sentinel-masked inputs for T5-style domain adaption."""

    def __init__(self, raw_data, tokenizer, domain_adaption=False, wwm_prob=0.1):
        data = self.select_data(raw_data)
        self.data = self.data_clean(data)
        self.len = len(self.data)
        self.tokenizer = tokenizer
        self.padding_max_len = "max_length"
        self.domain_adaption = domain_adaption
        self.wwm_prob = wwm_prob
    def select_data(self, raw_data):
        # keep only the fields needed for summarization; pop() with a default
        # avoids a KeyError on records that lack these keys
        data = list()
        for item in raw_data:
            item.pop('pub_time', None)
            item.pop('labels', None)
            data.append(item)
        return data
    def data_clean(self, data):
        # lowercase, then strip punctuation, backslashes, hyphens, /.../ spans
        # and URLs, e.g. "breaking: stocks (aapl) up! see http://x.co"
        # becomes "breaking stocks aapl up see "
        pattern = r"[?!+:()|]|\\|-|/.*?/|http\S+"
        for item in data:
            item["text"] = re.sub(pattern, "", item["text"].lower())
            item["title"] = re.sub(pattern, "", item["title"].lower())
        return data

    def __getitem__(self, index):
        if self.domain_adaption:
            # domain adaption: mask random tokens and train to reconstruct them
            tokenized_text = self.tokenizer(
                self.data[index]["text"],
                add_special_tokens=True,
                padding="max_length",
                return_token_type_ids=True,
                truncation=True,
            )
            text_mask = tokenized_text['attention_mask']

            input_ids, labels = self._word_masking(self.tokenizer, tokenized_text, self.wwm_prob)
            return {
                'input_ids': torch.tensor(input_ids),
                'attention_mask': torch.tensor(text_mask),
                'labels': torch.tensor(labels)
            }
        else:
            # summarization: "summarize: " task prefix + article text -> title
            tokenized_text = self.tokenizer(
                "summarize: " + self.data[index]["text"],
                add_special_tokens=True,
                padding="max_length",
                return_token_type_ids=True,
                truncation=True,
            )
            text_ids = tokenized_text['input_ids']
            text_mask = tokenized_text['attention_mask']

            tokenized_title = self.tokenizer(
                self.data[index]["title"],
                padding="max_length",
                return_token_type_ids=True,
                truncation=True,
            )
            # note: pad ids are left in the labels, so loss is also computed on
            # padding; replacing them with -100 would mask them out of the loss
            title_ids = tokenized_title['input_ids']
            return {
                'input_ids': torch.tensor(text_ids),
                'attention_mask': torch.tensor(text_mask),
                'labels': torch.tensor(title_ids)
            }
    def _word_masking(self, tokenizer, tokenized_inputs, wwm_prob):
        # randomly replace tokens with T5 sentinel tokens (<extra_id_*>); the
        # labels interleave each sentinel with the token it replaced
        input_ids = tokenized_inputs["input_ids"]
        attention_mask = tokenized_inputs["attention_mask"]
        mask = np.random.binomial(1, wwm_prob, (len(input_ids),))

        labels = list()

        for idx in np.where(mask == 1)[0]:
            # skip padding positions so only real tokens are masked
            if attention_mask[idx] == 0:
                continue
            # pick a sentinel token for this position
            sentinel_token = tokenizer.additional_special_tokens[input_ids[idx] % 100]
            sentinel_id = tokenizer.convert_tokens_to_ids(sentinel_token)

            labels.append(sentinel_id)
            labels.append(input_ids[idx])
            input_ids[idx] = sentinel_id

        # T5 targets conventionally end with the EOS token
        labels.append(tokenizer.eos_token_id)
        return input_ids, labels

    def __len__(self):
        return self.len

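# A quick sanity check of the masking format (a sketch, not part of the module;
# it assumes the t5-small tokenizer is downloadable and uses a made-up record):
#
# from transformers import T5TokenizerFast
# tokenizer = T5TokenizerFast.from_pretrained("t5-small", model_max_length=32)
# ds = summarization_data(
#     [{"text": "stocks rallied after the earnings report", "title": "stocks rally"}],
#     tokenizer, domain_adaption=True, wwm_prob=0.3)
# item = ds[0]
# print(tokenizer.decode(item["input_ids"]))  # text with <extra_id_*> at masked spots
# print(tokenizer.decode(item["labels"]))     # sentinel/original pairs ending in </s>
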
# Example usage (requires the FTE_NLP project layout and the EDT benchmark data):
#
# from tqdm.auto import tqdm
# from transformers import T5TokenizerFast, T5ForConditionalGeneration
# from transformers import DataCollatorForSeq2Seq
# from torch.utils.data import DataLoader, random_split
# from FTE_NLP.model.summarization_dataset_v1 import *
# from torch.optim import AdamW
# from transformers import get_scheduler
# from FTE_NLP.utils.post_process import *
#
# json_filename = '../data/raw_EDT/Trading_benchmark/evaluate_news_test.json'
# with open(json_filename) as data_file:
#     test_data = json.loads(data_file.read())
#
# model_checkpoint = "t5-small"
# tokenizer = T5TokenizerFast.from_pretrained(model_checkpoint, model_max_length=512)
# model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)
#
# all_dataset = summarization_data(test_data, tokenizer, domain_adaption=True, wwm_prob=0.1)
# train_dataset, eval_dataset = random_split(all_dataset, [7, 3], generator=torch.Generator().manual_seed(42))
#
# # data_collator = DataCollatorForSeq2Seq(tokenizer, model, label_pad_token_id=0)
# data_collator = DataCollatorForSeq2Seq(tokenizer, model)
#
# # pass data to dataloader
# train_params = {'batch_size': 2, 'shuffle': False, 'num_workers': 0}
# train_loader = DataLoader(train_dataset, collate_fn=data_collator, **train_params)
#
# eval_params = {'batch_size': 2, 'shuffle': False, 'num_workers': 0}
# eval_loader = DataLoader(eval_dataset, collate_fn=data_collator, **eval_params)
#
# for item in train_loader:
#     print(item)
#     print("input id numbers:", item["input_ids"][0])
#     print("input ids decoded:", tokenizer.decode(item["input_ids"][0]))
#     print("labels id numbers:", item["labels"][0])
#     print("labels:", tokenizer.decode(item["labels"][0], skip_special_tokens=False))
#     break
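#
# The demo imports AdamW and get_scheduler but stops before training; a minimal
# continuation under the same setup (a sketch with illustrative hyperparameters,
# not the project's actual training code):
#
# optimizer = AdamW(model.parameters(), lr=5e-5)
# num_epochs = 3
# num_training_steps = num_epochs * len(train_loader)
# lr_scheduler = get_scheduler("linear", optimizer=optimizer,
#                              num_warmup_steps=0, num_training_steps=num_training_steps)
# model.train()
# for epoch in range(num_epochs):
#     for batch in train_loader:
#         outputs = model(**batch)  # T5 computes the LM loss from `labels`
#         outputs.loss.backward()
#         optimizer.step()
#         lr_scheduler.step()
#         optimizer.zero_grad()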