import re

import numpy as np
import torch
from torch.utils.data import Dataset


class SummarizationDataset(Dataset):
    """Title-as-summary dataset. With domain_adaptation=True, __getitem__
    instead yields a masked-token objective for domain-adaptive pretraining."""

    def __init__(self, raw_data, tokenizer, domain_adaptation=False, wwm_prob=0.1):
        data = self.select_data(raw_data)
        self.data = self.data_clean(data)
        self.len = len(self.data)
        self.tokenizer = tokenizer
        self.padding_max_len = "max_length"
        self.domain_adaptation = domain_adaptation
        self.wwm_prob = wwm_prob

    def select_data(self, raw_data):
        # Drop fields the model never sees. Building new dicts avoids
        # mutating the caller's raw_data in place and avoids a KeyError
        # when a field is already missing.
        data = []
        for item in raw_data:
            data.append({k: v for k, v in item.items() if k not in ("pub_time", "labels")})
        return data

    def data_clean(self, data):
        # Lowercase, then strip URLs, /slash-delimited spans/, and stray
        # punctuation. URLs are matched first so they are removed whole
        # rather than merely losing their punctuation.
        pattern = re.compile(r"http\S+|/.*?/|[\\?!+:()|\-]")
        for item in data:
            item["text"] = pattern.sub("", item["text"].lower())
            item["title"] = pattern.sub("", item["title"].lower())
        return data

    def __getitem__(self, index):
        if self.domain_adaptation:
            # Masked-token objective for domain-adaptive pretraining.
            tokenized_text = self.tokenizer(
                self.data[index]["text"],
                add_special_tokens=True,
                padding=self.padding_max_len,
                return_token_type_ids=True,
                truncation=True,
            )
            text_mask = tokenized_text["attention_mask"]

            input_ids, labels = self._word_masking(
                self.tokenizer, tokenized_text, self.wwm_prob
            )
            return {
                "input_ids": torch.tensor(input_ids),
                "attention_mask": torch.tensor(text_mask),
                "labels": torch.tensor(labels),
            }
        else:
            # Supervised summarization: T5-style task prefix on the input,
            # the title as the target sequence. Note the space in
            # "summarize: ", matching the T5 prefix convention.
            tokenized_text = self.tokenizer(
                "summarize: " + self.data[index]["text"],
                add_special_tokens=True,
                padding=self.padding_max_len,
                return_token_type_ids=True,
                truncation=True,
            )
            text_ids = tokenized_text["input_ids"]
            text_mask = tokenized_text["attention_mask"]

            tokenized_title = self.tokenizer(
                self.data[index]["title"],
                padding=self.padding_max_len,
                return_token_type_ids=True,
                truncation=True,
            )
            # Pad positions in the labels are set to -100 so the loss
            # function ignores them (standard for seq2seq fine-tuning).
            title_ids = [
                tid if tid != self.tokenizer.pad_token_id else -100
                for tid in tokenized_title["input_ids"]
            ]
            return {
                "input_ids": torch.tensor(text_ids),
                "attention_mask": torch.tensor(text_mask),
                "labels": torch.tensor(title_ids),
            }

    def _word_masking(self, tokenizer, tokenized_inputs, wwm_prob):
        # Randomly replace tokens with T5 sentinel tokens; the target is
        # the flat sequence of (sentinel_id, original_id) pairs.
        input_ids = tokenized_inputs["input_ids"]
        attention_mask = tokenized_inputs["attention_mask"]

        # Sample mask positions, but only among real tokens: padding
        # positions (attention_mask == 0) must stay untouched.
        mask = np.random.binomial(1, wwm_prob, (len(input_ids),))
        mask = mask * np.array(attention_mask)

        labels = []
        for idx in np.where(mask == 1)[0]:
            # Pick one of T5's 100 sentinel tokens; indexing by the token
            # id modulo 100 makes the choice deterministic per token.
            sentinel_token = tokenizer.additional_special_tokens[input_ids[idx] % 100]
            sentinel_id = tokenizer.convert_tokens_to_ids(sentinel_token)

            labels.append(sentinel_id)
            labels.append(input_ids[idx])
            input_ids[idx] = sentinel_id

        # labels varies in length with the number of masked tokens, so
        # batching requires a collate_fn that pads it.
        return input_ids, labels

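    # ------------------------------------------------------------------
    # Hypothetical helper, not in the original code: a minimal collate_fn
    # sketch that pads the variable-length labels produced in
    # domain-adaptation mode so a DataLoader can batch them. -100 is used
    # as the pad value so those positions are ignored by the loss.
    # ------------------------------------------------------------------
    @staticmethod
    def pad_labels_collate(batch):
        input_ids = torch.stack([ex["input_ids"] for ex in batch])
        attention_mask = torch.stack([ex["attention_mask"] for ex in batch])
        max_len = max(len(ex["labels"]) for ex in batch)
        labels = torch.full((len(batch), max_len), -100, dtype=torch.long)
        for i, ex in enumerate(batch):
            labels[i, : len(ex["labels"])] = ex["labels"]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
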
    def __len__(self):
        return self.len
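
# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original class. It assumes a Hugging
# Face T5 tokenizer and raw_data items shaped like the fields accessed above
# ("text", "title", "pub_time", "labels"); the model name and sample record
# below are illustrative placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    raw_data = [
        {
            "text": "Visit http://example.com for details! (Breaking) story text.",
            "title": "A headline for the story",
            "pub_time": "2023-01-01",
            "labels": [],
        }
    ]

    # Summarization mode: every field is padded to a fixed length, so the
    # default collate function can batch examples directly.
    dataset = SummarizationDataset(raw_data, tokenizer)
    batch = next(iter(DataLoader(dataset, batch_size=1)))
    print(batch["input_ids"].shape, batch["labels"].shape)

    # Domain-adaptation mode: labels length depends on how many tokens were
    # masked, so batching goes through the padding collate_fn sketched above.
    da_dataset = SummarizationDataset(raw_data, tokenizer, domain_adaptation=True)
    da_loader = DataLoader(
        da_dataset, batch_size=1, collate_fn=SummarizationDataset.pad_labels_collate
    )
    da_batch = next(iter(da_loader))
    print(da_batch["input_ids"].shape, da_batch["labels"].shape)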