Hann99 committed on
Commit
cf9bbdf
1 Parent(s): 45ff3f2

Upload data_utils.py

Files changed (1)
  1. data_utils.py +319 -0
data_utils.py ADDED
@@ -0,0 +1,319 @@
# -*- coding: utf-8 -*-

import re
import six
import unicodedata
import torch
import rouge
import numpy as np
import random
# from fengshen.examples.pegasus.pegasus_utils import text_segmentate
import sys

sys.path.append('../../../')

rouge = rouge.Rouge()

is_py2 = six.PY2

if not is_py2:
    basestring = str

def _is_chinese_char(cp):
    """Checks whether CP is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    #
    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    # despite its name. The modern Korean Hangul alphabet is a different block,
    # as is Japanese Hiragana and Katakana. Those alphabets are used to write
    # space-separated words, so they are not treated specially and are handled
    # like all of the other languages.
    if ((cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF)
            or (cp >= 0x20000 and cp <= 0x2A6DF)
            or (cp >= 0x2A700 and cp <= 0x2B73F)
            or (cp >= 0x2B740 and cp <= 0x2B81F)
            or (cp >= 0x2B820 and cp <= 0x2CEAF)
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)):
        return True

    return False


def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (
            cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False

def is_string(s):
    """Check whether `s` is a string."""
    return isinstance(s, basestring)


def is_stopwords(word, stopwords):
    """Check whether `word` is in the stopword dict."""
    return word in stopwords

def text_segmentate(text):
    """Split `text` into sentence-like segments on English and Chinese
    sentence-ending punctuation."""
    en_seg_pattern = '((?:\\!|\\?|\\.|\\n)+(?:\\s)+)'
    ch_seg_pattern = '((?:？|！|。|\\n)+)'
    try:
        text = re.sub(en_seg_pattern, r'\1[SEP]', text)
        # print("sub text: ", text)
    except Exception as e:
        print("input: ", text)
        raise e
    text = re.sub(ch_seg_pattern, r'\1[SEP]', text)
    # print("sub ch text: ", text)
    text_list = text.split("[SEP]")
    text_list = list(filter(lambda x: len(x) != 0, text_list))
    return text_list

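# Illustrative usage of text_segmentate (a hedged sketch; exact pieces depend
# on the punctuation and whitespace in the input):
#
#     text_segmentate("今天天气很好。我们去公园吧！It was fun. Right?  Sure")
#     # -> ['今天天气很好。', '我们去公园吧！', 'It was fun. ', 'Right?  ', 'Sure']
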
def load_stopwords(stopwords_path):
    """Load one stopword per line into a dict for O(1) membership tests."""
    stopwords_dict = {}
    with open(stopwords_path, "r") as rf:
        for line in rf:
            line = line.strip()
            if line not in stopwords_dict:
                stopwords_dict[line] = 0
    return stopwords_dict

def text_process(text, max_length):
    """Split `text` into chunks of whole sentences of roughly `max_length`."""
    texts = text_segmentate(text)

    result, length = [], 0
    for text in texts:
        if length + len(text) > max_length * 1.3 and len(result) >= 3:
            yield result
            result, length = [], 0
        result.append(text)
        length += len(text)
    if result and len(result) >= 3:
        yield result

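# Illustrative sketch: text_process yields lists of consecutive sentences whose
# combined length stays near max_length; chunks with fewer than 3 sentences are
# not emitted. `long_document` below is a placeholder string.
#
#     for chunk in text_process(long_document, max_length=512):
#         sample_text = ''.join(chunk)
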
def text_process_split_long_content(text, max_length):
    """Split long text into chunks, skipping overly long single sentences."""
    texts = text_segmentate(text)

    result, sentence_num = "", 0
    for text in texts:
        if len(text) > 500:
            if len(result) > 300 and sentence_num >= 3:
                yield result
                result, sentence_num = "", 0
            else:
                result, sentence_num = "", 0
            continue
        else:
            if len(result) + len(text) > max_length * 1.1 and sentence_num >= 3:
                yield result
                result, sentence_num = "", 0
            result += text
            sentence_num += 1

    if result and sentence_num >= 3:
        yield result

def gather_join(texts, idxs):
    """Gather the texts at the given indices and join them into one string."""
    return ''.join([texts[i] for i in idxs])


def gather_join_f1(texts_token, idsx):
    """Gather the token lists at the given indices into one flat token list."""
    join_texts = []
    for idx in idsx:
        join_texts.extend(texts_token[idx])
    return join_texts

def compute_rouge(source, target):
    """Compute rouge-1, rouge-2 and rouge-l f-scores between token lists."""
    source, target = ' '.join(source), ' '.join(target)
    try:
        scores = rouge.get_scores(hyps=source, refs=target)
        return {
            'rouge-1': scores[0]['rouge-1']['f'],
            'rouge-2': scores[0]['rouge-2']['f'],
            'rouge-l': scores[0]['rouge-l']['f'],
        }
    except ValueError:
        return {
            'rouge-1': 0.0,
            'rouge-2': 0.0,
            'rouge-l': 0.0,
        }

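# Illustrative usage (a sketch; exact values come from the `rouge` package).
# Both arguments are token lists, joined with spaces before scoring; a
# ValueError (e.g. an empty hypothesis) falls back to all-zero scores.
#
#     compute_rouge(['我', '们', '出', '去'], ['我', '们', '出', '去'])
#     # -> all three f-scores are (essentially) 1.0
#     compute_rouge(['我', '们'], ['出', '去'])
#     # -> all three f-scores are 0.0
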
def remove_stopwords(texts, stopwords_dict):
    """Remove stopwords from each token list (in place) and return the lists."""
    for i, text in enumerate(texts):
        texts[i] = list(filter(lambda x: x not in stopwords_dict, text))
    return texts

def pseudo_summary_f1(texts,
                      stopwords,
                      tokenizer,
                      max_length,
                      rouge_strategy="rouge-l"):
    """Build pseudo-label summarization data via greedy gap-sentence selection."""
    summary_rate = 0.25
    max_length = max_length - 1
    texts_tokens = []
    sentece_idxs_vec = []
    for text in texts:
        # skip empty sentences
        if len(text) == 0:
            continue
        try:
            ids = tokenizer.encode(text.strip())[:-1]
        except ValueError:
            print("error, input : ", text)
            raise ValueError
        sentece_idxs_vec.append(ids)
        tokens = [tokenizer._convert_id_to_token(token) for token in ids]
        texts_tokens.append(tokens)

    texts_tokens_rm = remove_stopwords(texts_tokens, stopwords)
    source_idxs, target_idxs = list(range(len(texts))), []

    assert len(texts_tokens) == len(texts)
    # truncate_index = 0
    while True:
        sims = []
        for i in source_idxs:
            new_source_idxs = [j for j in source_idxs if j != i]
            new_target_idxs = sorted(target_idxs + [i])
            new_source = gather_join_f1(texts_tokens_rm, new_source_idxs)
            new_target = gather_join_f1(texts_tokens_rm, new_target_idxs)
            sim = compute_rouge(new_source, new_target)[rouge_strategy]
            sims.append(sim)
        new_idx = source_idxs[np.argmax(sims)]
        del sims
        source_idxs.remove(new_idx)
        target_idxs = sorted(target_idxs + [new_idx])
        source = gather_join(texts, source_idxs)
        target = gather_join(texts, target_idxs)
        try:
            if (len(source_idxs) == 1
                    or 1.0 * len(target) / len(source) > summary_rate):
                break
        except ZeroDivisionError as e:
            print(e)
            print(texts)
            print("source: ", source)
            print("target: ", target)

    if len(source) < len(target):
        source, target = target, source
        source_idxs, target_idxs = target_idxs, source_idxs

    return sentece_idxs_vec, source, target, source_idxs, target_idxs

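# Rough usage sketch (hypothetical names; `tokenizer` is assumed to expose
# `encode` and `_convert_id_to_token`, e.g. a PEGASUS-style tokenizer, and
# "stopwords.txt" is a placeholder path):
#
#     sentences = text_segmentate(document)
#     stopwords = load_stopwords("stopwords.txt")
#     sent_ids, source, target, src_idx, tgt_idx = pseudo_summary_f1(
#         sentences, stopwords, tokenizer, max_length=512,
#         rouge_strategy="rouge-l")
#
# Sentences are greedily moved from the source into the pseudo-summary, each
# step picking the sentence that maximizes the chosen ROUGE score between the
# remaining source and the growing target, until the target reaches roughly
# `summary_rate` (25%) of the source length.
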
def get_input_mask(sentence_id_vec, indexs):
    """Mask the selected sentences in the input and build the target ids."""
    target_idxs = []
    input_idxs = []
    kMaskSentenceTokenId = 2
    kEosTokenId = 1
    mask_sentence_options_cumulative_prob = [0.9, 0.9, 1, 1]
    for index in indexs:
        target_idxs.extend(sentence_id_vec[index])
        choice = random.uniform(0, 1)
        if choice < mask_sentence_options_cumulative_prob[0]:
            # print("mask index: ", index)
            sentence_id_vec[index] = [kMaskSentenceTokenId]
        elif choice < mask_sentence_options_cumulative_prob[1]:
            # print("replace index: ", index)
            # randint is inclusive on both ends, so cap at len - 1
            replace_id = random.randint(0, len(sentence_id_vec) - 1)
            sentence_id_vec[index] = sentence_id_vec[replace_id]
        elif choice < mask_sentence_options_cumulative_prob[2]:
            pass
        else:
            sentence_id_vec[index] = []

    target_idxs.append(kEosTokenId)
    # print(sentence_id_vec)
    for index, sentence_id in enumerate(sentence_id_vec):
        # print(index, sentence_id)
        if len(sentence_id) == 0:
            continue
        input_idxs.extend(sentence_id_vec[index])

    input_idxs.append(kEosTokenId)
    return input_idxs, target_idxs

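# Illustrative sketch: `sentence_id_vec` is the per-sentence token-id list from
# pseudo_summary_f1 and `indexs` are the indices of the selected pseudo-summary
# sentences. Each selected sentence is copied to the target and then, per
# mask_sentence_options_cumulative_prob (90% mask / 0% replace / 10% keep /
# 0% drop), usually replaced by the mask-sentence token in the input:
#
#     input_ids, labels = get_input_mask([[11, 12], [13, 14], [15]], [1])
#     # labels is always [13, 14, 1]; input_ids is usually [11, 12, 2, 15, 1]
#
# Note that the function mutates `sentence_id_vec` in place.
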
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int,
                       decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids

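# Illustrative usage: labels are shifted right to build decoder inputs, with
# the first position set to `decoder_start_token_id` and any -100 padding in
# the labels mapped back to `pad_token_id`:
#
#     labels = torch.tensor([[5, 6, 7, -100]])
#     shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=0)
#     # -> tensor([[0, 5, 6, 7]])
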
def padding_to_maxlength(ids, max_length, pad_id):
    """Pad `ids` to `max_length` and return (padded_ids, attention_mask)."""
    cur_len = len(ids)
    len_diff = max_length - cur_len
    return ids + [pad_id] * len_diff, [1] * cur_len + [0] * len_diff
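

# Illustrative usage: pad a token-id list to `max_length` and get the matching
# attention mask:
#
#     padding_to_maxlength([5, 6, 7], max_length=6, pad_id=0)
#     # -> ([5, 6, 7, 0, 0, 0], [1, 1, 1, 0, 0, 0])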