SHSH0819 committed
Commit 6ef99d1
1 Parent(s): 78e069f

Upload summarization_dataset.py

Files changed (1)
  1. summarization_dataset.py +131 -0
summarization_dataset.py ADDED
@@ -0,0 +1,131 @@
import json
import re
import torch
import numpy as np
from collections import defaultdict


class summarization_data:
    """Map-style dataset that yields either summarization pairs (text -> title)
    or sentinel-masked inputs for domain-adaptation pretraining."""

    def __init__(self, raw_data, tokenizer, domain_adaption=False, wwm_prob=0.1):
        data = self.select_data(raw_data)
        self.data = self.data_clean(data)
        self.len = len(self.data)
        self.tokenizer = tokenizer
        self.padding_max_len = "max_length"
        self.domain_adaption = domain_adaption
        self.wwm_prob = wwm_prob

    def select_data(self, raw_data):
        # drop fields that are not needed for summarization
        data = list()
        for item in raw_data:
            del item['pub_time']
            del item['labels']
            data.append(item)
        return data

    def data_clean(self, data):
        # lowercase and strip punctuation, backslashes, /.../ spans and URLs
        for item in data:
            item["text"] = re.sub(r"[?|!+?|:|(|)]|\\|-|/.*?/|http\S+", "", item["text"].lower())
            item["title"] = re.sub(r"[?|!+?|:|(|)]|\\|-|/.*?/|http\S+", "", item["title"].lower())
        return data

    def __getitem__(self, index):
        if self.domain_adaption:
            # domain adaptation: mask tokens in the text and predict them
            tokenized_text = self.tokenizer(
                self.data[index]["text"],
                add_special_tokens=True,
                padding="max_length",
                return_token_type_ids=True,
                truncation=True,
            )
            text_mask = tokenized_text['attention_mask']

            input_ids, labels = self._word_masking(self.tokenizer, tokenized_text, self.wwm_prob)
            return {
                'input_ids': torch.tensor(input_ids),
                'attention_mask': torch.tensor(text_mask),
                'labels': torch.tensor(labels)
            }
        else:
            # summarization: the prefixed text is the input, the title is the label
            tokenized_text = self.tokenizer(
                "summarize:" + self.data[index]["text"],
                add_special_tokens=True,
                padding="max_length",
                return_token_type_ids=True,
                truncation=True,
            )
            text_ids = tokenized_text['input_ids']
            text_mask = tokenized_text['attention_mask']

            tokenized_title = self.tokenizer(
                self.data[index]["title"],
                padding="max_length",
                return_token_type_ids=True,
                truncation=True,
            )
            title_ids = tokenized_title['input_ids']
            return {
                'input_ids': torch.tensor(text_ids),
                'attention_mask': torch.tensor(text_mask),
                'labels': torch.tensor(title_ids)
            }

    def _word_masking(self, tokenizer, tokenized_inputs, wwm_prob):
        # randomly mask tokens with probability wwm_prob
        input_ids = tokenized_inputs["input_ids"]
        mask = np.random.binomial(1, wwm_prob, (len(input_ids),))

        labels = list()

        for idx in np.where(mask == 1)[0]:
            # replace the masked token with one of the tokenizer's sentinel tokens
            # and record the (sentinel, original token) pair as labels
            sentinel_token = tokenizer.additional_special_tokens[input_ids[idx] % 100]

            labels.append(tokenizer(sentinel_token).input_ids[0])
            labels.append(input_ids[idx])
            input_ids[idx] = tokenizer(sentinel_token).input_ids[0]

        return input_ids, labels

    def __len__(self):
        return self.len


# Example usage:
# from tqdm.auto import tqdm
# from transformers import T5TokenizerFast, T5ForConditionalGeneration
# from transformers import DataCollatorForSeq2Seq
# from torch.utils.data import DataLoader, random_split
# from FTE_NLP.model.summarization_dataset_v1 import *
# from torch.optim import AdamW
# from transformers import get_scheduler
# from FTE_NLP.utils.post_process import *
#
# json_filename = '../data/raw_EDT/Trading_benchmark/evaluate_news_test.json'
# with open(json_filename) as data_file:
#     test_data = json.loads(data_file.read())
#
# model_checkpoint = "t5-small"
# tokenizer = T5TokenizerFast.from_pretrained(model_checkpoint, model_max_length=512)
# model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)
#
# all_dataset = summarization_data(test_data, tokenizer, domain_adaption=True, wwm_prob=0.1)
# train_dataset, eval_dataset = random_split(all_dataset, [7, 3], generator=torch.Generator().manual_seed(42))
#
# # data_collator = DataCollatorForSeq2Seq(tokenizer, model, label_pad_token_id=0)
# data_collator = DataCollatorForSeq2Seq(tokenizer, model)
# # pass data to dataloader
# train_params = {'batch_size': 2, 'shuffle': False, 'num_workers': 0}
# train_loader = DataLoader(train_dataset, collate_fn=data_collator, **train_params)
#
# eval_params = {'batch_size': 2, 'shuffle': False, 'num_workers': 0}
# eval_loader = DataLoader(eval_dataset, collate_fn=data_collator, **eval_params)
#
#
# for item in train_loader:
#     print(item)
#     print("input id numbers:", item["input_ids"][0])
#     print("input id: ", tokenizer.decode(item["input_ids"][0]))
#     print("labels id number:", item["labels"][0])
#     print("labels:", tokenizer.decode(item["labels"][0], skip_special_tokens=False))
#     break