DD0101 committed
Commit: ddb3043
Parent: 02d206c

Upload 4 files

model/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .modeling_jointphobert import JointPhoBERT
+ from .modeling_jointxlmr import JointXLMR
model/modeling_jointphobert.py ADDED
@@ -0,0 +1,89 @@
+ import torch
+ import torch.nn as nn
+ from torchcrf import CRF
+ from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaPreTrainedModel
+
+ from .module import IntentClassifier, SlotClassifier
+
+
+ class JointPhoBERT(RobertaPreTrainedModel):
+     def __init__(self, config, args, intent_label_lst, slot_label_lst):
+         super(JointPhoBERT, self).__init__(config)
+         self.args = args
+         self.num_intent_labels = len(intent_label_lst)
+         self.num_slot_labels = len(slot_label_lst)
+         self.roberta = RobertaModel(config)  # Load pretrained PhoBERT
+
+         self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)
+
+         self.slot_classifier = SlotClassifier(
+             config.hidden_size,
+             self.num_intent_labels,
+             self.num_slot_labels,
+             self.args.use_intent_context_concat,
+             self.args.use_intent_context_attention,
+             self.args.max_seq_len,
+             self.args.attention_embedding_size,
+             args.dropout_rate,
+         )
+
+         if args.use_crf:
+             self.crf = CRF(num_tags=self.num_slot_labels, batch_first=True)
+
+     def forward(self, input_ids, attention_mask, token_type_ids, intent_label_ids, slot_labels_ids):
+         outputs = self.roberta(
+             input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
+         )  # sequence_output, pooled_output, (hidden_states), (attentions)
+         sequence_output = outputs[0]
+         pooled_output = outputs[1]  # [CLS]
+
+         intent_logits = self.intent_classifier(pooled_output)
+         if not self.args.use_attention_mask:
+             tmp_attention_mask = None
+         else:
+             tmp_attention_mask = attention_mask
+
+         if self.args.embedding_type == "hard":
+             hard_intent_logits = torch.zeros(intent_logits.shape)
+             for i, sample in enumerate(intent_logits):
+                 max_idx = torch.argmax(sample)
+                 hard_intent_logits[i][max_idx] = 1
+             slot_logits = self.slot_classifier(sequence_output, hard_intent_logits, tmp_attention_mask)
+         else:
+             slot_logits = self.slot_classifier(sequence_output, intent_logits, tmp_attention_mask)
+
+         total_loss = 0
+         # 1. Intent Softmax
+         if intent_label_ids is not None:
+             if self.num_intent_labels == 1:
+                 intent_loss_fct = nn.MSELoss()
+                 intent_loss = intent_loss_fct(intent_logits.view(-1), intent_label_ids.view(-1))
+             else:
+                 intent_loss_fct = nn.CrossEntropyLoss()
+                 intent_loss = intent_loss_fct(
+                     intent_logits.view(-1, self.num_intent_labels), intent_label_ids.view(-1)
+                 )
+             total_loss += self.args.intent_loss_coef * intent_loss
+
+         # 2. Slot Softmax
+         if slot_labels_ids is not None:
+             if self.args.use_crf:
+                 slot_loss = self.crf(slot_logits, slot_labels_ids, mask=attention_mask.byte(), reduction="mean")
+                 slot_loss = -1 * slot_loss  # negative log-likelihood
+             else:
+                 slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
+                 # Only keep active parts of the loss
+                 if attention_mask is not None:
+                     active_loss = attention_mask.view(-1) == 1
+                     active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
+                     active_labels = slot_labels_ids.view(-1)[active_loss]
+                     slot_loss = slot_loss_fct(active_logits, active_labels)
+                 else:
+                     slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))
+             total_loss += (1 - self.args.intent_loss_coef) * slot_loss
+
+         outputs = ((intent_logits, slot_logits),) + outputs[2:]  # add hidden states and attention if they are here
+
+         outputs = (total_loss,) + outputs
+
+         return outputs  # (loss), logits, (hidden_states), (attentions) # Logits is a tuple of intent and slot logits
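
For reference, a minimal sketch of how a model like this is typically constructed and run. The checkpoint name, label lists, and hyperparameter values below are illustrative assumptions, not part of this commit; the args namespace simply mirrors the attributes the class reads, and it assumes the repository root is on PYTHONPATH.

import argparse

import torch
from transformers import AutoConfig

from model import JointPhoBERT

# Hypothetical hyperparameters mirroring the attributes JointPhoBERT reads from `args`.
args = argparse.Namespace(
    dropout_rate=0.1,
    use_intent_context_concat=False,
    use_intent_context_attention=True,
    max_seq_len=50,
    attention_embedding_size=200,
    use_crf=True,
    use_attention_mask=True,
    embedding_type="soft",
    intent_loss_coef=0.6,
    ignore_index=0,
)
intent_labels = ["UNK", "greeting", "request"]           # placeholder label lists
slot_labels = ["PAD", "UNK", "O", "B-date", "I-date"]

# Assumes a PhoBERT checkpoint such as vinai/phobert-base; the extra keyword
# arguments are forwarded by from_pretrained to JointPhoBERT.__init__.
config = AutoConfig.from_pretrained("vinai/phobert-base")
model = JointPhoBERT.from_pretrained(
    "vinai/phobert-base",
    config=config,
    args=args,
    intent_label_lst=intent_labels,
    slot_label_lst=slot_labels,
)

# Dummy batch: labels may be passed as None, in which case only the logits are meaningful.
input_ids = torch.randint(0, config.vocab_size, (2, args.max_seq_len))
attention_mask = torch.ones_like(input_ids)
outputs = model(input_ids, attention_mask, None, None, None)
intent_logits, slot_logits = outputs[1]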
model/modeling_jointxlmr.py ADDED
@@ -0,0 +1,88 @@
+ import torch
+ import torch.nn as nn
+ from torchcrf import CRF
+ from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaModel
+
+ from .module import IntentClassifier, SlotClassifier
+
+
+ class JointXLMR(RobertaPreTrainedModel):
+     def __init__(self, config, args, intent_label_lst, slot_label_lst):
+         super(JointXLMR, self).__init__(config)
+         self.args = args
+         self.num_intent_labels = len(intent_label_lst)
+         self.num_slot_labels = len(slot_label_lst)
+         self.roberta = XLMRobertaModel(config)  # Load pretrained XLM-R
+         self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)
+         self.slot_classifier = SlotClassifier(
+             config.hidden_size,
+             self.num_intent_labels,
+             self.num_slot_labels,
+             self.args.use_intent_context_concat,
+             self.args.use_intent_context_attention,
+             self.args.max_seq_len,
+             self.args.attention_embedding_size,
+             args.dropout_rate,
+         )
+
+         if args.use_crf:
+             self.crf = CRF(num_tags=self.num_slot_labels, batch_first=True)
+
+     def forward(self, input_ids, attention_mask, token_type_ids, intent_label_ids, slot_labels_ids):
+         outputs = self.roberta(
+             input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
+         )  # sequence_output, pooled_output, (hidden_states), (attentions)
+         sequence_output = outputs[0]
+         pooled_output = outputs[1]  # [CLS]
+
+         intent_logits = self.intent_classifier(pooled_output)
+         if not self.args.use_attention_mask:
+             tmp_attention_mask = None
+         else:
+             tmp_attention_mask = attention_mask
+
+         if self.args.embedding_type == "hard":
+             hard_intent_logits = torch.zeros(intent_logits.shape)
+             for i, sample in enumerate(intent_logits):
+                 max_idx = torch.argmax(sample)
+                 hard_intent_logits[i][max_idx] = 1
+             slot_logits = self.slot_classifier(sequence_output, hard_intent_logits, tmp_attention_mask)
+         else:
+             slot_logits = self.slot_classifier(sequence_output, intent_logits, tmp_attention_mask)
+
+         total_loss = 0
+         # 1. Intent Softmax
+         if intent_label_ids is not None:
+             if self.num_intent_labels == 1:
+                 intent_loss_fct = nn.MSELoss()
+                 intent_loss = intent_loss_fct(intent_logits.view(-1), intent_label_ids.view(-1))
+             else:
+                 intent_loss_fct = nn.CrossEntropyLoss()
+                 intent_loss = intent_loss_fct(
+                     intent_logits.view(-1, self.num_intent_labels), intent_label_ids.view(-1)
+                 )
+             total_loss += self.args.intent_loss_coef * intent_loss
+
+         # 2. Slot Softmax
+         if slot_labels_ids is not None:
+             if self.args.use_crf:
+                 slot_loss = self.crf(slot_logits, slot_labels_ids, mask=attention_mask.byte(), reduction="mean")
+                 slot_loss = -1 * slot_loss  # negative log-likelihood
+             else:
+                 slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
+                 # Only keep active parts of the loss
+                 if attention_mask is not None:
+                     active_loss = attention_mask.view(-1) == 1
+                     active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
+                     active_labels = slot_labels_ids.view(-1)[active_loss]
+                     slot_loss = slot_loss_fct(active_logits, active_labels)
+                 else:
+                     slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))
+             total_loss += (1 - self.args.intent_loss_coef) * slot_loss
+
+         outputs = ((intent_logits, slot_logits),) + outputs[2:]  # add hidden states and attention if they are here
+
+         outputs = (total_loss,) + outputs
+
+         return outputs  # (loss), logits, (hidden_states), (attentions) # Logits is a tuple of intent and slot logits
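
JointXLMR is identical to JointPhoBERT apart from the XLM-RoBERTa backbone. Note that forward returns raw slot emissions even when args.use_crf is set, so Viterbi decoding happens outside the model. A sketch of that inference step, reusing the model, args, input_ids, and attention_mask from the sketch above and pytorch-crf's CRF.decode:

import torch

model.eval()
with torch.no_grad():
    outputs = model(input_ids, attention_mask, None, None, None)
intent_logits, slot_logits = outputs[1]

intent_pred = intent_logits.argmax(dim=-1)  # (batch,)
if args.use_crf:
    # pytorch-crf: Viterbi decoding over the emission scores; returns a list of tag-id lists.
    slot_pred = model.crf.decode(slot_logits, mask=attention_mask.byte())
else:
    slot_pred = slot_logits.argmax(dim=-1).tolist()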
model/module.py ADDED
@@ -0,0 +1,157 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+
+ class Attention(nn.Module):
+     """Applies attention mechanism on the `context` using the `query`.
+     Args:
+         dimensions (int): Dimensionality of the query and context.
+         attention_type (str, optional): How to compute the attention score:
+
+             * dot: :math:`score(H_j,q) = H_j^T q`
+             * general: :math:`score(H_j, q) = H_j^T W_a q`
+
+     Example:
+
+         >>> attention = Attention(256)
+         >>> query = torch.randn(32, 50, 256)
+         >>> context = torch.randn(32, 1, 256)
+         >>> output, weights = attention(query, context, None)
+         >>> output.size()
+         torch.Size([32, 50, 256])
+         >>> weights.size()
+         torch.Size([32, 50, 1])
+     """
+
+     def __init__(self, dimensions):
+         super(Attention, self).__init__()
+
+         self.dimensions = dimensions
+         self.linear_out = nn.Linear(dimensions * 2, dimensions, bias=False)
+         self.softmax = nn.Softmax(dim=1)
+         self.tanh = nn.Tanh()
+
+     def forward(self, query, context, attention_mask):
+         """
+         Args:
+             query (:class:`torch.FloatTensor` [batch size, output length, dimensions]): Sequence of
+                 queries to query the context.
+             context (:class:`torch.FloatTensor` [batch size, query length, dimensions]): Data
+                 over which to apply the attention mechanism.
+             output length: length of utterance
+             query length: length of each token (1)
+         Returns:
+             :class:`tuple` with `output` and `weights`:
+             * **output** (:class:`torch.LongTensor` [batch size, output length, dimensions]):
+               Tensor containing the attended features.
+             * **weights** (:class:`torch.FloatTensor` [batch size, output length, query length]):
+               Tensor containing attention weights.
+         """
+         # query = self.linear_query(query)
+
+         batch_size, output_len, hidden_size = query.size()
+         # query_len = context.size(1)
+
+         # (batch_size, output_len, dimensions) * (batch_size, query_len, dimensions) ->
+         # (batch_size, output_len, query_len)
+         attention_scores = torch.bmm(query, context.transpose(1, 2).contiguous())
+         # Compute weights across every context sequence
+         # attention_scores = attention_scores.view(batch_size * output_len, query_len)
+         if attention_mask is not None:
+             # Create attention mask, apply attention mask before softmax
+             attention_mask = torch.unsqueeze(attention_mask, 2)
+             # attention_mask = attention_mask.view(batch_size * output_len, query_len)
+             attention_scores.masked_fill_(attention_mask == 0, -np.inf)
+         # attention_scores = torch.squeeze(attention_scores, 1)
+         attention_weights = self.softmax(attention_scores)
+         # attention_weights = attention_weights.view(batch_size, output_len, query_len)
+
+         # (batch_size, output_len, query_len) * (batch_size, query_len, dimensions) ->
+         # (batch_size, output_len, dimensions)
+         mix = torch.bmm(attention_weights, context)
+         # from IPython import embed; embed()
+         # concat -> (batch_size * output_len, 2*dimensions)
+         combined = torch.cat((mix, query), dim=2)
+         # combined = combined.view(batch_size * output_len, 2 * self.dimensions)
+
+         # Apply linear_out on every 2nd dimension of concat
+         # output -> (batch_size, output_len, dimensions)
+         # output = self.linear_out(combined).view(batch_size, output_len, self.dimensions)
+         output = self.linear_out(combined)
+
+         output = self.tanh(output)
+         # output = combined
+         return output, attention_weights
+
+
+ class IntentClassifier(nn.Module):
+     def __init__(self, input_dim, num_intent_labels, dropout_rate=0.0):
+         super(IntentClassifier, self).__init__()
+         self.dropout = nn.Dropout(dropout_rate)
+         self.linear = nn.Linear(input_dim, num_intent_labels)
+
+     def forward(self, x):
+         x = self.dropout(x)
+         return self.linear(x)
+
+
+ class SlotClassifier(nn.Module):
+     def __init__(
+         self,
+         input_dim,
+         num_intent_labels,
+         num_slot_labels,
+         use_intent_context_concat=False,
+         use_intent_context_attn=False,
+         max_seq_len=50,
+         attention_embedding_size=200,
+         dropout_rate=0.0,
+     ):
+         super(SlotClassifier, self).__init__()
+         self.use_intent_context_attn = use_intent_context_attn
+         self.use_intent_context_concat = use_intent_context_concat
+         self.max_seq_len = max_seq_len
+         self.num_intent_labels = num_intent_labels
+         self.num_slot_labels = num_slot_labels
+         self.attention_embedding_size = attention_embedding_size
+
+         output_dim = self.attention_embedding_size  # base model
+         if self.use_intent_context_concat:
+             output_dim = self.attention_embedding_size
+             self.linear_out = nn.Linear(2 * attention_embedding_size, attention_embedding_size)
+
+         elif self.use_intent_context_attn:
+             output_dim = self.attention_embedding_size
+             self.attention = Attention(attention_embedding_size)
+
+         self.linear_slot = nn.Linear(input_dim, self.attention_embedding_size, bias=False)
+
+         if self.use_intent_context_attn or self.use_intent_context_concat:
+             # project intent vector and slot vector to have the same dimensions
+             self.linear_intent_context = nn.Linear(self.num_intent_labels, self.attention_embedding_size, bias=False)
+             self.softmax = nn.Softmax(dim=-1)  # softmax layer for intent logits
+
+         # self.linear_out = nn.Linear(2 * intent_embedding_size, intent_embedding_size)
+         # output
+         self.dropout = nn.Dropout(dropout_rate)
+         self.linear = nn.Linear(output_dim, num_slot_labels)
+
+     def forward(self, x, intent_context, attention_mask):
+         x = self.linear_slot(x)
+         if self.use_intent_context_concat:
+             intent_context = self.softmax(intent_context)
+             intent_context = self.linear_intent_context(intent_context)
+             intent_context = torch.unsqueeze(intent_context, 1)
+             intent_context = intent_context.expand(-1, self.max_seq_len, -1)
+             x = torch.cat((x, intent_context), dim=2)
+             x = self.linear_out(x)
+
+         elif self.use_intent_context_attn:
+             intent_context = self.softmax(intent_context)
+             intent_context = self.linear_intent_context(intent_context)
+             intent_context = torch.unsqueeze(intent_context, 1)  # 1: query length (each token)
+             output, weights = self.attention(x, intent_context, attention_mask)
+             x = output
+         x = self.dropout(x)
+         return self.linear(x)
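
Finally, a standalone shape check for the modules above. The batch size, sequence length, and label counts are illustrative, and it assumes the repository root is on PYTHONPATH so that model.module is importable.

import torch

from model.module import IntentClassifier, SlotClassifier

batch_size, max_seq_len, hidden_size = 4, 50, 768
num_intents, num_slots = 10, 20

sequence_output = torch.randn(batch_size, max_seq_len, hidden_size)    # token-level encoder states
pooled_output = torch.randn(batch_size, hidden_size)                   # [CLS] representation
attention_mask = torch.ones(batch_size, max_seq_len, dtype=torch.long)

intent_clf = IntentClassifier(hidden_size, num_intents, dropout_rate=0.1)
intent_logits = intent_clf(pooled_output)                               # (4, 10)

slot_clf = SlotClassifier(
    hidden_size,
    num_intents,
    num_slots,
    use_intent_context_attn=True,        # fuse intent logits into slot features via Attention
    max_seq_len=max_seq_len,
    attention_embedding_size=200,
    dropout_rate=0.1,
)
slot_logits = slot_clf(sequence_output, intent_logits, attention_mask)
print(intent_logits.shape, slot_logits.shape)  # torch.Size([4, 10]) torch.Size([4, 50, 20])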