gauneg committed
Commit c490f2d · 1 Parent(s): 96e5a22

commit files to HF hub

bert_gts_pretrained.py ADDED
@@ -0,0 +1,75 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModel, PreTrainedModel, PretrainedConfig
+ import torch.nn.functional as F
+
+
+ class GTSBertBaseABSATripleConfig(PretrainedConfig):
+     def __init__(self, feat_dim=768, max_len=64, class_num=6, **kwargs):
+         super().__init__(**kwargs)
+         self.feat_dim = feat_dim
+         self.max_len = max_len
+         self.class_num = class_num
+
+ class GTSBertBaseABSATriple(PreTrainedModel):
+     config_class = GTSBertBaseABSATripleConfig
+     def __init__(self, config):
+         model_id = 'google-bert/bert-base-uncased'
+         super().__init__(config)
+         self.model = AutoModel.from_pretrained(model_id)
+         self.max_seq_len = config.max_len
+         self.bert_feat_dim = config.feat_dim  # 768
+         self.class_num = config.class_num  # 6
+         self.cls_linear = torch.nn.Linear(self.bert_feat_dim*2, self.class_num)
+         self.feature_linear = torch.nn.Linear(self.bert_feat_dim*2 + self.class_num*3, self.bert_feat_dim*2)
+         self.dropout_output = torch.nn.Dropout(0.1)
+         self.post_init()
+
+
+     def multi_hops(self, features, mask, k):
+         max_length = features.shape[1]
+         mask = mask[:, :max_length]
+         mask_a = mask.unsqueeze(1).expand([-1, max_length, -1])
+         mask_b = mask.unsqueeze(2).expand([-1, -1, max_length])
+         mask = mask_a * mask_b
+         mask = torch.triu(mask).unsqueeze(3).expand([-1, -1, -1, self.class_num])
+
+         '''save all logits'''
+         logits_list = []
+         logits = self.cls_linear(features)
+         logits_list.append(logits)
+         for i in range(k):
+             # probs = torch.softmax(logits, dim=3)
+             probs = logits
+             logits = probs * mask
+             logits_a = torch.max(logits, dim=1)[0]
+             logits_b = torch.max(logits, dim=2)[0]
+             logits = torch.cat([logits_a.unsqueeze(3), logits_b.unsqueeze(3)], dim=3)
+             logits = torch.max(logits, dim=3)[0]
+
+             logits = logits.unsqueeze(2).expand([-1, -1, max_length, -1])
+             logits_T = logits.transpose(1, 2)
+             logits = torch.cat([logits, logits_T], dim=3)
+
+             new_features = torch.cat([features, logits, probs], dim=3)
+             features = self.feature_linear(new_features)
+             logits = self.cls_linear(features)
+             logits_list.append(logits)
+         return logits_list
+
+     def forward(self, input_ids, attention_masks, labels=None):  # rename if required
+         model_feature = self.model(input_ids, attention_masks)
+         model_feature = model_feature.last_hidden_state.detach()
+         bert_feature = self.dropout_output(model_feature)
+         bert_feature = bert_feature.unsqueeze(2).expand([-1, -1, self.max_seq_len, -1])
+         bert_feature_T = bert_feature.transpose(1, 2)
+         features = torch.cat([bert_feature, bert_feature_T], dim=3)
+         logits = self.multi_hops(features, attention_masks, 1)
+         fin_logits = logits[-1]
+         loss = None
+         if labels is not None:
+             ## performing the loss operation, cross-check with the previous impl
+             gold_floss = labels.reshape([-1])
+             pred_floss = fin_logits.reshape([-1, fin_logits.shape[3]])
+             loss = F.cross_entropy(pred_floss, gold_floss, ignore_index=-1)
+         return {'logits': fin_logits, 'loss': loss}
+
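For orientation, here is a minimal usage sketch; it is not part of the commit. It builds the config, instantiates the model, and runs one sentence padded to `max_len` through it. `"<this-repo-id>"` is a placeholder for this repository's id, and the example sentence is purely illustrative.

```python
# Usage sketch only -- not part of the committed files.
import torch
from transformers import AutoTokenizer

from bert_gts_pretrained import GTSBertBaseABSATriple, GTSBertBaseABSATripleConfig

# Either build a model with the default config (fresh BERT backbone, random heads)...
config = GTSBertBaseABSATripleConfig()          # feat_dim=768, max_len=64, class_num=6
model = GTSBertBaseABSATriple(config).eval()
# ...or load the committed weights ("<this-repo-id>" is a placeholder):
# model = GTSBertBaseABSATriple.from_pretrained("<this-repo-id>").eval()

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
# forward() pairs every token with every other token, so the input must be
# padded to exactly config.max_len for the concatenation in forward() to line up.
enc = tokenizer("the battery life is great",
                padding="max_length", max_length=config.max_len,
                truncation=True, return_tensors="pt")

with torch.no_grad():
    out = model(enc["input_ids"], enc["attention_mask"])

print(out["logits"].shape)   # torch.Size([1, 64, 64, 6]) -- a grid of tag logits
print(out["loss"])           # None, since no labels were passed
```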
config.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "architectures": [
+     "MultiInferBertUncased"
+   ],
+   "class_num": 6,
+   "feat_dim": 768,
+   "max_len": 64,
+   "model_type": "gts_opinion_triple",
+   "torch_dtype": "float32",
+   "transformers_version": "4.42.3"
+ }
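The `model_type` declared here, `gts_opinion_triple`, is not one that transformers knows natively, so the stock `Auto*` factories will not resolve this repo on their own. A hedged sketch of one way to wire it up locally follows; this registration step is an assumption about usage, not something the commit itself does.

```python
# Registration sketch only -- the commit does not include this wiring.
from transformers import AutoConfig, AutoModel

from bert_gts_pretrained import GTSBertBaseABSATripleConfig, GTSBertBaseABSATriple

# AutoConfig.register requires the config class to carry the same model_type
# string as config.json; in practice this would be a class attribute defined
# on GTSBertBaseABSATripleConfig itself.
GTSBertBaseABSATripleConfig.model_type = "gts_opinion_triple"

AutoConfig.register("gts_opinion_triple", GTSBertBaseABSATripleConfig)
AutoModel.register(GTSBertBaseABSATripleConfig, GTSBertBaseABSATriple)

# After this, AutoModel.from_pretrained("<this-repo-id>") resolves the
# "gts_opinion_triple" model_type to GTSBertBaseABSATriple.
```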
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cf4d4d02fb6b5d28e9c8008034866fad48cc684579d46c14b8521bdc0e98b736
+ size 447543680
post.py ADDED
@@ -0,0 +1,135 @@
+
+ import torch
+
+ class DecodeAndEvaluate:
+     def __init__(self, tokenizer):
+         self.tokenizer = tokenizer
+         self.sentiment2id = {'negative': 3, 'neutral': 4, 'positive': 5}
+         self.id2sentiment = {v: k for k, v in self.sentiment2id.items()}
+
+     def get_span_from_tags(self, tags, token_range, tok_type):  # tok_type: 1 for aspects, 2 for opinions
+         sel_spans = []
+         end_ind = -1
+         has_prev = False
+         start_ind = -1
+         for i in range(len(token_range)):
+             l, r = token_range[i]
+             if tags[l][l] != tok_type:
+                 if has_prev:
+                     sel_spans.append([start_ind, end_ind])
+                 start_ind = -1
+                 end_ind = -1
+                 has_prev = False
+             if tags[l][l] == tok_type and not has_prev:
+                 start_ind = l
+                 end_ind = r
+                 has_prev = True
+             if tags[l][l] == tok_type and has_prev:
+                 end_ind = r
+                 has_prev = True
+         if has_prev:
+             sel_spans.append([start_ind, end_ind])
+
+         return sel_spans
+
+     ## Corner cases where one sentiment span expresses over multiple sentiments
+     # and one aspect has multiple sentiments expressed on it
+     def find_triplet(self, tags, aspect_spans, opinion_spans):
+         triplets = []
+         for al, ar in aspect_spans:
+             for pl, pr in opinion_spans:
+                 ## get the overlapping indices
+                 # we select such that tag[aspect_l:aspect_r+1, opi_l:opi_r]
+                 # if opi > asp, a lower-triangular block that is never annotated would be selected
+                 # print(al, ar, pl, pr)
+                 if al <= pl:
+                     sent_tags = tags[al:ar+1, pl:pr+1]
+                     flat_tags = sent_tags.reshape([-1])
+                     flat_tags = torch.tensor([v.item() for v in flat_tags if v.item() >= 0])
+                     val = torch.mode(flat_tags).values.item()
+                     if val > 0:
+                         triplets.append([al, ar, pl, pr, val])
+                 else:  # In this case the aspect becomes the column and the sentiment becomes the row
+                     # print(al, pl)
+                     sent_tags = tags[pl:pr+1, al:ar+1]
+                     # print(sent_tags)
+                     flat_tags = sent_tags.reshape([-1])
+                     flat_tags = torch.tensor([v.item() for v in flat_tags if v.item() >= 0])
+                     val = torch.mode(flat_tags).values.item()
+                     if val > 0:
+                         triplets.append([al, ar, pl, pr, val])
+         return triplets
+
+     def decode_triplets(self, triplets, sent_tokens):
+         triplet_list = []
+         for alt, art, olt, ort, pol in triplets:
+             asp_toks = sent_tokens[alt:art+1]
+             op_toks = sent_tokens[olt:ort+1]
+             asp_string = self.tokenizer.decode(asp_toks)
+             op_string = self.tokenizer.decode(op_toks)
+             if pol in [3, 4, 5]:
+                 sentiment_pol = self.id2sentiment[pol]  # .get(pol, "inconsistent")
+                 triplet_list.append([asp_string, op_string, sentiment_pol])
+         return triplet_list
+
+     def decode_predict_one(self, tags, token_range, sent_tokens):
+         aspect_spans = self.get_span_from_tags(tags, token_range, 1)
+         opinion_spans = self.get_span_from_tags(tags, token_range, 2)
+         triplets = self.find_triplet(tags, aspect_spans, opinion_spans)
+         return self.decode_triplets(triplets, sent_tokens)
+
+
+     def decode_pred_batch(self, tags_batch, token_range_batch, sent_tokens):
+         decoded_batch_results = []
+         for i in range(tags_batch.shape[0]):
+             res = self.decode_predict_one(tags_batch[i], token_range_batch[i], sent_tokens[i])
+             decoded_batch_results.append(res)
+         return decoded_batch_results
+
+     def decode_predict_string_one(self, text_sent, model, max_len=64):
+         token_range = []
+         words = text_sent.strip().split()
+         bert_tokens_padding = torch.zeros(max_len).long()
+         bert_tokens = self.tokenizer.encode(text_sent)  # tokenization (in sub-words)
+
+         tok_length = len(bert_tokens)
+         if tok_length > max_len:
+             raise Exception(f'Sub-word length {tok_length} exceeds `max_len` ({max_len})')
+         # token_range maps each whitespace word to its (first, last) sub-word index,
+         # offset by 1 to skip the leading [CLS] token
+         token_start = 1
+         for i, w in enumerate(words):
+             token_end = token_start + len(self.tokenizer.encode(w, add_special_tokens=False))
+             token_range.append([token_start, token_end-1])
+             token_start = token_end
+
+         bert_tokens_padding[:tok_length] = torch.tensor(bert_tokens).long()
+         attention_mask = torch.zeros(max_len).long()
+         attention_mask[:tok_length] = 1
+
+         tags_pred = model(bert_tokens_padding.unsqueeze(0),
+                           attention_masks=attention_mask.unsqueeze(0))
+
+         tags = tags_pred['logits'][0].argmax(dim=-1)
+         return self.decode_predict_one(tags, token_range, bert_tokens)
+
+
+
+     def get_batch_tp_fp_tn(self, tags_batch, token_range_batch, sent_tokens, gold_labels):
+
+         batch_results = self.decode_pred_batch(tags_batch, token_range_batch, sent_tokens)
+         flat_gold, flat_pred = [], []
+
+         for preds, golds in list(zip(batch_results, gold_labels)):
+             for pred in preds:
+                 flat_pred.append("-".join(pred))
+             for gold in golds:
+                 flat_gold.append("-".join(gold))
+         gold_set = set(flat_gold)
+         pred_set = set(flat_pred)
+         tp = len(gold_set & pred_set)
+         fp = len(pred_set - gold_set)
+         fn = len(gold_set - pred_set)
+
+         return tp, fp, fn
+
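Putting the two modules together, a short inference sketch follows. It is an assumption about usage rather than part of the commit: the repo id is a placeholder and the sentence is illustrative.

```python
# Inference sketch only -- placeholder repo id, illustrative sentence.
from transformers import AutoTokenizer

from bert_gts_pretrained import GTSBertBaseABSATriple
from post import DecodeAndEvaluate

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
model = GTSBertBaseABSATriple.from_pretrained("<this-repo-id>").eval()

decoder = DecodeAndEvaluate(tokenizer)
triplets = decoder.decode_predict_string_one("the battery life is great", model)
print(triplets)   # list of [aspect_phrase, opinion_phrase, polarity] entries

# For evaluation, get_batch_tp_fp_tn returns raw counts; precision/recall/F1
# follow directly from them, e.g.:
# tp, fp, fn = decoder.get_batch_tp_fp_tn(tags_batch, token_ranges, tokens, gold)
# precision = tp / (tp + fp) if tp + fp else 0.0
# recall    = tp / (tp + fn) if tp + fn else 0.0
# f1        = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
```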
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,55 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "100": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "101": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "102": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "103": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_lower_case": true,
+   "mask_token": "[MASK]",
+   "model_max_length": 512,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "BertTokenizer",
+   "unk_token": "[UNK]"
+ }
vocab.txt ADDED
The diff for this file is too large to render. See raw diff
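The tokenizer files above (special_tokens_map.json, tokenizer.json, tokenizer_config.json, vocab.txt) appear to be the standard bert-base-uncased tokenizer artifacts (lowercasing, model_max_length 512, the usual special tokens), so the tokenizer can be loaded straight from the repo. A small sketch, again with a placeholder repo id:

```python
# Tokenizer sketch only -- "<this-repo-id>" is a placeholder for this repository.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("<this-repo-id>")
print(type(tok).__name__)   # a BERT (fast) tokenizer, per tokenizer_class in tokenizer_config.json
print(tok.cls_token, tok.sep_token, tok.pad_token, tok.unk_token, tok.mask_token)
# [CLS] [SEP] [PAD] [UNK] [MASK], matching special_tokens_map.json above
```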