Punyajoy committed
Commit
37bf7e5
1 Parent(s): 88a58ac

commit the model file.

Files changed (1):
  model.py +90 -0
model.py ADDED
@@ -0,0 +1,90 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import BertPreTrainedModel, BertModel
+ from .utils import *  # expected to provide BertPooler used below
+
+
+ class Model_Rational_Label(BertPreTrainedModel):
+     """BERT encoder with three heads: token-level rationales, class label, targets."""
+
+     def __init__(self, config, params):
+         super().__init__(config)
+         self.num_labels = params['num_classes']
+         self.num_targets = params['targets_num']
+         self.impact_factor = params['rationale_impact']
+         self.target_factor = params['target_impact']
+         self.bert = BertModel(config, add_pooling_layer=False)
+         self.pooler = BertPooler(config)
+         # Token-level head: predicts whether each token is part of a rationale.
+         self.token_dropout = nn.Dropout(0.2)
+         self.token_classifier = nn.Linear(config.hidden_size, 2)
+         # Sequence-level head: predicts the class label.
+         self.dropout = nn.Dropout(0.2)
+         self.classifier = nn.Linear(config.hidden_size, self.num_labels)
+         # Multi-label head: predicts which groups are targeted.
+         self.target_dropout = nn.Dropout(0.2)
+         self.target_classifier = nn.Linear(config.hidden_size, self.num_targets)
+         self.init_weights()
+
+     def forward(self, input_ids=None, mask=None, attn=None, labels=None, targets=None):
+         outputs = self.bert(input_ids, attention_mask=mask)
+         out = outputs[0]  # last hidden state: (batch, seq_len, hidden_size)
+         logits = self.token_classifier(self.token_dropout(out))
+
+         embed = self.pooler(out)
+         y_pred = self.classifier(self.dropout(embed))
+         y_pred_target = torch.sigmoid(self.target_classifier(self.target_dropout(embed)))
+
+         loss_token = None
+         loss_targets = None
+         loss_label = None
+         loss_total = None
+
+         if attn is not None:
+             loss_fct = nn.CrossEntropyLoss()
+             if mask is not None:
+                 # Only keep the active (non-padding) positions in the token loss.
+                 class_weights = torch.tensor([1.0, 1.0], dtype=torch.float).to(input_ids.device)
+                 loss_funct = nn.CrossEntropyLoss(weight=class_weights)
+                 active_loss = mask.view(-1) == 1
+                 active_logits = logits.view(-1, 2)
+                 active_labels = torch.where(
+                     active_loss, attn.view(-1), torch.tensor(loss_fct.ignore_index).type_as(attn)
+                 )
+                 loss_token = loss_funct(active_logits, active_labels)
+             else:
+                 loss_token = loss_fct(logits.view(-1, 2), attn.view(-1))
+             loss_total = self.impact_factor * loss_token
+
+         if targets is not None:
+             loss_funct = nn.BCELoss()
+             loss_targets = loss_funct(
+                 y_pred_target.view(-1, self.num_targets), targets.view(-1, self.num_targets)
+             )
+             if loss_total is not None:
+                 loss_total += self.target_factor * loss_targets
+             else:
+                 loss_total = self.target_factor * loss_targets
+
+         if labels is not None:
+             loss_funct = nn.CrossEntropyLoss()
+             loss_label = loss_funct(y_pred.view(-1, self.num_labels), labels.view(-1))
+             if loss_total is not None:
+                 loss_total += loss_label
+             else:
+                 loss_total = loss_label
+
+         if loss_total is not None:
+             return y_pred, y_pred_target, logits, loss_total
+         return y_pred, y_pred_target, logits
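
For reference, a minimal usage sketch (not part of the commit): the checkpoint id, the `params` values, and the example sentence below are placeholders chosen for illustration, since the commit does not state the configuration used for the released model. It relies on `from_pretrained` forwarding unused keyword arguments (here `params`) to the model's `__init__`.

    from transformers import AutoTokenizer

    # All values below are assumptions for illustration, not the released configuration.
    params = {
        'num_classes': 2,        # label classes for self.classifier
        'targets_num': 22,       # target-group labels for self.target_classifier
        'rationale_impact': 10,  # weight on the token (rationale) loss
        'target_impact': 10,     # weight on the target loss
    }

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint
    model = Model_Rational_Label.from_pretrained("bert-base-uncased", params=params)

    inputs = tokenizer("This is an example sentence.", return_tensors="pt")
    # Without attn/labels/targets, forward returns predictions only.
    y_pred, y_pred_target, token_logits = model(
        input_ids=inputs['input_ids'], mask=inputs['attention_mask']
    )

    # During training, passing attn (token rationales), targets, and labels makes
    # forward also return the combined multi-task loss:
    #   loss_total = rationale_impact * loss_token + target_impact * loss_targets + loss_label

All three heads share the BERT encoder, so `rationale_impact` and `target_factor` control how strongly the auxiliary rationale and target losses pull on the shared weights relative to the main label loss.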