SHSH0819 committed on
Commit
0a0161b
1 Parent(s): 10058cd

Upload event_detection_dataset.py

Files changed (1)
  1. event_detection_dataset.py +77 -0
event_detection_dataset.py ADDED
@@ -0,0 +1,77 @@
import numpy as np
import torch
from torch.utils.data import Dataset
from collections import defaultdict


class event_detection_data(Dataset):
    def __init__(self, raw_data, tokenizer, max_len, domain_adaption=False, wwm_prob=0.1):
        self.len = len(raw_data)
        self.data = raw_data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.domain_adaption = domain_adaption
        self.wwm_prob = wwm_prob

    def __getitem__(self, index):
        tokenized_inputs = self.tokenizer(
            self.data[index]["text"],
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            is_split_into_words=True
        )

        ids = tokenized_inputs['input_ids']
        mask = tokenized_inputs['attention_mask']

        if self.domain_adaption:
            # domain adaptation: return masked-LM inputs built with whole word masking
            if self.tokenizer.is_fast:
                input_ids, labels = self._whole_word_masking(self.tokenizer, tokenized_inputs, self.wwm_prob)
                return {
                    'input_ids': torch.tensor(input_ids),
                    'attention_mask': torch.tensor(mask),
                    'labels': torch.tensor(labels, dtype=torch.long)
                }
            else:
                raise ValueError("whole word masking requires a fast tokenizer that exposes word_ids")
        else:
            # classification: return the tokenized text with its event tag as target
            return {
                'input_ids': torch.tensor(ids),
                'attention_mask': torch.tensor(mask),
                'targets': torch.tensor(self.data[index]["text_tag_id"][0], dtype=torch.long)
            }

    def __len__(self):
        return self.len

    def _whole_word_masking(self, tokenizer, tokenized_inputs, wwm_prob):
        word_ids = tokenized_inputs.word_ids(0)

        # map each word index to the positions of its sub-word tokens
        mapping = defaultdict(list)
        current_word_index = -1
        current_word = None

        for idx, word_id in enumerate(word_ids):
            if word_id is not None:
                if word_id != current_word:
                    current_word = word_id
                    current_word_index += 1
                mapping[current_word_index].append(idx)

        # randomly select whole words to mask with probability wwm_prob
        mask = np.random.binomial(1, wwm_prob, (len(mapping),))
        input_ids = tokenized_inputs["input_ids"]

        # labels hold targets only for masked tokens; -100 is ignored by the loss
        labels = [-100] * len(input_ids)

        for word_id in np.where(mask == 1)[0]:
            for idx in mapping[word_id]:
                labels[idx] = tokenized_inputs["input_ids"][idx]
                # replace every sub-word token of the chosen word with the mask token
                input_ids[idx] = tokenizer.mask_token_id
        return input_ids, labels
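
A minimal usage sketch (not part of the commit): it assumes a fast tokenizer such as bert-base-uncased, and that raw_data is a list of dicts with a "text" key holding a pre-split token list and a "text_tag_id" key, as the code above reads them. The checkpoint name and the sample record are illustrative assumptions only.

# hypothetical usage example, not from the uploaded file
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)  # fast tokenizer needed for word_ids
raw_data = [{"text": ["markets", "rallied", "after", "the", "announcement"], "text_tag_id": [1]}]  # assumed record shape

# domain_adaption=True yields masked-LM batches (input_ids, attention_mask, labels)
dataset = event_detection_data(raw_data, tokenizer, max_len=128, domain_adaption=True, wwm_prob=0.1)
loader = DataLoader(dataset, batch_size=8, shuffle=True)
batch = next(iter(loader))

With domain_adaption=False the same dataset instead returns "targets" taken from text_tag_id, suitable for the event classification objective.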