import torch
from torch.utils.data import Dataset
from args import args


class items_dataset(Dataset):
    def __init__(self, tokenizer, data_set, label_dict, stride=0, max_length=args.max_length):
        self.data_set = data_set
        self.tokenizer = tokenizer
        self.label_dict = label_dict
        self.max_length = max_length
        self.encode_max_length = max_length - 2  # reserve space for [CLS] and [SEP]
        self.batch_max_length = max_length
        self.stride = stride

    def __getitem__(self, index):
        return self.data_set[index]

    def __len__(self):
        return len(self.data_set)
    def create_label_list(self, span_label, max_len):
        # Build a character-level BIO tag table: 0 = "O", 1 = "B", 2 = "I".
        table = torch.zeros(max_len)
        for start, end in span_label:
            table[start:end] = 2  # "I" for every character inside the span
            table[start] = 1      # "B" for the first character of the span
        return table
    def encode_label(self, encoded, batch_table):
        # Project the character-level tag tables onto the tokenized (possibly overflowed) sequences.
        batch_encode_seq_lens = []
        sample_mapping = encoded["overflow_to_sample_mapping"]
        offset_mapping = encoded["offset_mapping"]
        encoded_label = torch.zeros(len(sample_mapping), self.encode_max_length, dtype=torch.long)
        for id_in_batch in range(len(sample_mapping)):
            encode_len = 0
            table = batch_table[sample_mapping[id_in_batch]]
            for i in range(self.max_length):
                char_start, char_end = offset_mapping[id_in_batch][i]
                # skip special and padding tokens ([CLS], [SEP], [PAD]), whose offsets are (0, 0)
                if char_start != 0 or char_end != 0:
                    encode_len += 1
                    encoded_label[id_in_batch][i - 1] = table[char_start].long()
            batch_encode_seq_lens.append(encode_len)
        return encoded_label, batch_encode_seq_lens
    def create_crf_mask(self, batch_encode_seq_lens):
        # Mark the first `batch_len` positions of every sequence as valid for the CRF.
        mask = torch.zeros(len(batch_encode_seq_lens), self.encode_max_length, dtype=torch.bool)
        for i, batch_len in enumerate(batch_encode_seq_lens):
            mask[i][:batch_len] = True
        return mask
    def boundary_encoded(self, encodings, batch_boundary):
        # Convert character-level segment lengths into token-level segment lengths.
        batch_boundary_encoded = []
        for batch_id, span_labels in enumerate(batch_boundary):
            boundary_encoded = []
            end = 0
            for boundary in span_labels:
                end += boundary
                encoded_end = encodings[batch_id].char_to_token(end - 1)
                # fall back to the nearest preceding character that maps to a token
                tmp_end = end
                while encoded_end is None and tmp_end > 0:
                    tmp_end -= 1
                    encoded_end = encodings[batch_id].char_to_token(tmp_end - 1)
                if encoded_end is not None:
                    encoded_end += 1
                if encoded_end > self.encode_max_length:
                    boundary_encoded.append(self.encode_max_length)
                    break
                else:
                    boundary_encoded.append(encoded_end)
            # turn cumulative token positions back into per-segment lengths
            for i in range(len(boundary_encoded) - 1, 0, -1):
                boundary_encoded[i] = boundary_encoded[i] - boundary_encoded[i - 1]
            batch_boundary_encoded.append(boundary_encoded)
        return batch_boundary_encoded
    def cal_agreement_span(self, agreement_table, min_agree=2, max_agree=3):
        """
        Find [start, end) spans over positions whose agreement count lies
        between min_agree and max_agree (inclusive), then merge adjacent spans.
        """
        ans_span = []
        start, end = (0, 0)
        pre_p = agreement_table[0]
        for i, word_agreement in enumerate(agreement_table):
            curr_p = word_agreement
            if curr_p != pre_p:
                if start != end:
                    ans_span.append([start, end])
                start = i
                end = i
                pre_p = curr_p
            if word_agreement < min_agree:
                start += 1
            if word_agreement <= max_agree:
                end += 1
            pre_p = curr_p
        if start != end:
            ans_span.append([start, end])
        if len(ans_span) <= 1 or min_agree == max_agree:
            return ans_span
        # merge consecutive spans that share a boundary
        span_concate = []
        start, end = ans_span[0][0], ans_span[0][1]
        for span_id in range(1, len(ans_span)):
            if ans_span[span_id - 1][1] == ans_span[span_id][0]:
                ans_span[span_id] = [ans_span[span_id - 1][0], ans_span[span_id][1]]
                if span_id == len(ans_span) - 1:
                    span_concate.append(ans_span[span_id])
            elif span_id == len(ans_span) - 1:
                span_concate.extend([ans_span[span_id - 1], ans_span[span_id]])
            else:
                span_concate.append(ans_span[span_id - 1])
        return span_concate
    def collate_fn(self, batch_sample):
        batch_text = []
        batch_table = []
        batch_span_label = []
        seq_lens = []
        for sample in batch_sample:
            batch_text.append(sample['original_text'])
            batch_table.append(self.create_label_list(sample["span_labels"], len(sample['original_text'])))
            #batch_boundary = [sample['data_len_c'] for sample in batch_sample]
            batch_span_label.append(sample["span_labels"])
            seq_lens.append(len(sample['original_text']))
        self.batch_max_length = min(max(seq_lens), self.encode_max_length)
        encoded = self.tokenizer(batch_text, truncation=True, max_length=512, padding='max_length', stride=self.stride,
                                 return_overflowing_tokens=True, return_tensors="pt", return_offsets_mapping=True)
        encoded['labels'], batch_encode_seq_lens = self.encode_label(encoded, batch_table)
        encoded["crf_mask"] = self.create_crf_mask(batch_encode_seq_lens)
        #encoded["boundary"] = batch_boundary
        #encoded["boundary_encode"] = self.boundary_encoded(encoded, batch_boundary)
        encoded["span_labels"] = batch_span_label
        encoded["batch_text"] = batch_text
        return encoded
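

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): shows how the
# dataset and its collate_fn are expected to plug into a DataLoader.
# The tokenizer checkpoint, the toy samples, and the label_dict below are
# assumptions for illustration only; running this also requires the local
# `args` module imported at the top of the file.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    # Offset mappings require a "fast" tokenizer; the checkpoint is assumed.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

    # Each sample provides the raw text and its character-level spans as
    # (start, end) pairs, matching the keys read in collate_fn.
    toy_data = [
        {"original_text": "散步三十分鐘有益健康", "span_labels": [(0, 2), (2, 6)]},
        {"original_text": "多喝水", "span_labels": [(0, 3)]},
    ]
    label_dict = {"O": 0, "B": 1, "I": 2}  # assumed BIO mapping

    dataset = items_dataset(tokenizer, toy_data, label_dict, stride=0, max_length=512)
    loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn)

    batch = next(iter(loader))
    print(batch["input_ids"].shape, batch["labels"].shape, batch["crf_mask"].shape)

    # cal_agreement_span groups positions whose agreement count falls between
    # min_agree and max_agree into [start, end) spans.
    spans = dataset.cal_agreement_span([0, 3, 3, 1, 2, 2, 0], min_agree=2, max_agree=3)
    print(spans)  # expected: [[1, 3], [4, 6]]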