Upload 4 files
- model/__init__.py +2 -0
- model/modeling_jointphobert.py +89 -0
- model/modeling_jointxlmr.py +88 -0
- model/module.py +157 -0
model/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .modeling_jointphobert import JointPhoBERT
from .modeling_jointxlmr import JointXLMR

model/modeling_jointphobert.py
ADDED
@@ -0,0 +1,89 @@
import torch
import torch.nn as nn
from torchcrf import CRF
from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaPreTrainedModel

from .module import IntentClassifier, SlotClassifier


class JointPhoBERT(RobertaPreTrainedModel):
    def __init__(self, config, args, intent_label_lst, slot_label_lst):
        super(JointPhoBERT, self).__init__(config)
        self.args = args
        self.num_intent_labels = len(intent_label_lst)
        self.num_slot_labels = len(slot_label_lst)
        self.roberta = RobertaModel(config)  # Load pretrained PhoBERT

        self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)

        self.slot_classifier = SlotClassifier(
            config.hidden_size,
            self.num_intent_labels,
            self.num_slot_labels,
            self.args.use_intent_context_concat,
            self.args.use_intent_context_attention,
            self.args.max_seq_len,
            self.args.attention_embedding_size,
            args.dropout_rate,
        )

        if args.use_crf:
            self.crf = CRF(num_tags=self.num_slot_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, token_type_ids, intent_label_ids, slot_labels_ids):
        outputs = self.roberta(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )  # sequence_output, pooled_output, (hidden_states), (attentions)
        sequence_output = outputs[0]
        pooled_output = outputs[1]  # [CLS]

        intent_logits = self.intent_classifier(pooled_output)
        if not self.args.use_attention_mask:
            tmp_attention_mask = None
        else:
            tmp_attention_mask = attention_mask

        if self.args.embedding_type == "hard":
            hard_intent_logits = torch.zeros(intent_logits.shape)
            for i, sample in enumerate(intent_logits):
                max_idx = torch.argmax(sample)
                hard_intent_logits[i][max_idx] = 1
            slot_logits = self.slot_classifier(sequence_output, hard_intent_logits, tmp_attention_mask)
        else:
            slot_logits = self.slot_classifier(sequence_output, intent_logits, tmp_attention_mask)

        total_loss = 0
        # 1. Intent Softmax
        if intent_label_ids is not None:
            if self.num_intent_labels == 1:
                intent_loss_fct = nn.MSELoss()
                intent_loss = intent_loss_fct(intent_logits.view(-1), intent_label_ids.view(-1))
            else:
                intent_loss_fct = nn.CrossEntropyLoss()
                intent_loss = intent_loss_fct(
                    intent_logits.view(-1, self.num_intent_labels), intent_label_ids.view(-1)
                )
            total_loss += self.args.intent_loss_coef * intent_loss

        # 2. Slot Softmax
        if slot_labels_ids is not None:
            if self.args.use_crf:
                slot_loss = self.crf(slot_logits, slot_labels_ids, mask=attention_mask.byte(), reduction="mean")
                slot_loss = -1 * slot_loss  # negative log-likelihood
            else:
                slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
                # Only keep active parts of the loss
                if attention_mask is not None:
                    active_loss = attention_mask.view(-1) == 1
                    active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
                    active_labels = slot_labels_ids.view(-1)[active_loss]
                    slot_loss = slot_loss_fct(active_logits, active_labels)
                else:
                    slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))
            total_loss += (1 - self.args.intent_loss_coef) * slot_loss

        outputs = ((intent_logits, slot_logits),) + outputs[2:]  # add hidden states and attentions if they are here

        outputs = (total_loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)  # logits is a tuple of intent and slot logits

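For orientation, a minimal usage sketch for JointPhoBERT follows. The attribute names on `args` mirror exactly what the class reads above, but the concrete values, the `argparse.Namespace` container, the `vinai/phobert-base` checkpoint, the dummy batch, and the reliance on `from_pretrained` forwarding extra keyword arguments to `__init__` are assumptions for illustration; none of this is part of the commit.

    # Hypothetical usage sketch -- not part of this commit.
    import argparse

    import torch
    from transformers import RobertaConfig

    from model import JointPhoBERT

    args = argparse.Namespace(           # placeholder hyperparameters
        dropout_rate=0.1,
        use_intent_context_concat=False,
        use_intent_context_attention=True,
        max_seq_len=50,
        attention_embedding_size=200,
        use_crf=True,
        use_attention_mask=True,
        embedding_type="soft",
        intent_loss_coef=0.6,
        ignore_index=0,
    )
    intent_labels = ["UNK", "book_hotel"]           # placeholder label sets
    slot_labels = ["PAD", "O", "B-date", "I-date"]

    config = RobertaConfig.from_pretrained("vinai/phobert-base")  # assumed checkpoint
    model = JointPhoBERT.from_pretrained(
        "vinai/phobert-base",
        config=config,
        args=args,
        intent_label_lst=intent_labels,
        slot_label_lst=slot_labels,
    )

    # Dummy batch with the shapes the forward pass expects.
    batch_size = 2
    inputs = {
        "input_ids": torch.randint(0, config.vocab_size, (batch_size, args.max_seq_len)),
        "attention_mask": torch.ones(batch_size, args.max_seq_len, dtype=torch.long),
        "token_type_ids": torch.zeros(batch_size, args.max_seq_len, dtype=torch.long),
        "intent_label_ids": torch.zeros(batch_size, dtype=torch.long),
        "slot_labels_ids": torch.zeros(batch_size, args.max_seq_len, dtype=torch.long),
    }
    total_loss, (intent_logits, slot_logits) = model(**inputs)[:2]
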
model/modeling_jointxlmr.py
ADDED
@@ -0,0 +1,88 @@
import torch
import torch.nn as nn
from torchcrf import CRF
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaModel

from .module import IntentClassifier, SlotClassifier


class JointXLMR(RobertaPreTrainedModel):
    def __init__(self, config, args, intent_label_lst, slot_label_lst):
        super(JointXLMR, self).__init__(config)
        self.args = args
        self.num_intent_labels = len(intent_label_lst)
        self.num_slot_labels = len(slot_label_lst)
        self.roberta = XLMRobertaModel(config)  # Load pretrained XLM-RoBERTa
        self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)
        self.slot_classifier = SlotClassifier(
            config.hidden_size,
            self.num_intent_labels,
            self.num_slot_labels,
            self.args.use_intent_context_concat,
            self.args.use_intent_context_attention,
            self.args.max_seq_len,
            self.args.attention_embedding_size,
            args.dropout_rate,
        )

        if args.use_crf:
            self.crf = CRF(num_tags=self.num_slot_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, token_type_ids, intent_label_ids, slot_labels_ids):
        outputs = self.roberta(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )  # sequence_output, pooled_output, (hidden_states), (attentions)
        sequence_output = outputs[0]
        pooled_output = outputs[1]  # [CLS]

        intent_logits = self.intent_classifier(pooled_output)
        if not self.args.use_attention_mask:
            tmp_attention_mask = None
        else:
            tmp_attention_mask = attention_mask

        if self.args.embedding_type == "hard":
            hard_intent_logits = torch.zeros(intent_logits.shape)
            for i, sample in enumerate(intent_logits):
                max_idx = torch.argmax(sample)
                hard_intent_logits[i][max_idx] = 1
            slot_logits = self.slot_classifier(sequence_output, hard_intent_logits, tmp_attention_mask)
        else:
            slot_logits = self.slot_classifier(sequence_output, intent_logits, tmp_attention_mask)

        total_loss = 0
        # 1. Intent Softmax
        if intent_label_ids is not None:
            if self.num_intent_labels == 1:
                intent_loss_fct = nn.MSELoss()
                intent_loss = intent_loss_fct(intent_logits.view(-1), intent_label_ids.view(-1))
            else:
                intent_loss_fct = nn.CrossEntropyLoss()
                intent_loss = intent_loss_fct(
                    intent_logits.view(-1, self.num_intent_labels), intent_label_ids.view(-1)
                )
            total_loss += self.args.intent_loss_coef * intent_loss

        # 2. Slot Softmax
        if slot_labels_ids is not None:
            if self.args.use_crf:
                slot_loss = self.crf(slot_logits, slot_labels_ids, mask=attention_mask.byte(), reduction="mean")
                slot_loss = -1 * slot_loss  # negative log-likelihood
            else:
                slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
                # Only keep active parts of the loss
                if attention_mask is not None:
                    active_loss = attention_mask.view(-1) == 1
                    active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
                    active_labels = slot_labels_ids.view(-1)[active_loss]
                    slot_loss = slot_loss_fct(active_logits, active_labels)
                else:
                    slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))
            total_loss += (1 - self.args.intent_loss_coef) * slot_loss

        outputs = ((intent_logits, slot_logits),) + outputs[2:]  # add hidden states and attentions if they are here

        outputs = (total_loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)  # logits is a tuple of intent and slot logits

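JointXLMR mirrors JointPhoBERT except for the backbone it wraps, so a training script would typically choose between them through a small registry keyed by model type. The mapping below is a hypothetical illustration; the function name, the registry keys, and the checkpoint names in the comments are assumptions, not code from this commit.

    # Hypothetical model registry -- illustrative only, not part of this commit.
    from transformers import RobertaConfig, XLMRobertaConfig

    from model import JointPhoBERT, JointXLMR

    MODEL_CLASSES = {
        "phobert": (RobertaConfig, JointPhoBERT),   # e.g. vinai/phobert-base
        "xlmr": (XLMRobertaConfig, JointXLMR),      # e.g. xlm-roberta-base
    }


    def load_joint_model(model_type, checkpoint, args, intent_labels, slot_labels):
        """Instantiate the requested joint model from a pretrained checkpoint."""
        config_cls, model_cls = MODEL_CLASSES[model_type]
        config = config_cls.from_pretrained(checkpoint)
        return model_cls.from_pretrained(
            checkpoint,
            config=config,
            args=args,
            intent_label_lst=intent_labels,
            slot_label_lst=slot_labels,
        )
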
model/module.py
ADDED
@@ -0,0 +1,157 @@
import numpy as np
import torch
import torch.nn as nn


class Attention(nn.Module):
    """Applies attention mechanism on the `context` using the `query`.

    Args:
        dimensions (int): Dimensionality of the query and context.
        attention_type (str, optional): How to compute the attention score:

            * dot: :math:`score(H_j, q) = H_j^T q`
            * general: :math:`score(H_j, q) = H_j^T W_a q`

    Example:

        >>> attention = Attention(256)
        >>> query = torch.randn(32, 50, 256)
        >>> context = torch.randn(32, 1, 256)
        >>> output, weights = attention(query, context, attention_mask=None)
        >>> output.size()
        torch.Size([32, 50, 256])
        >>> weights.size()
        torch.Size([32, 50, 1])
    """

    def __init__(self, dimensions):
        super(Attention, self).__init__()

        self.dimensions = dimensions
        self.linear_out = nn.Linear(dimensions * 2, dimensions, bias=False)
        self.softmax = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()

    def forward(self, query, context, attention_mask):
        """
        Args:
            query (:class:`torch.FloatTensor` [batch size, output length, dimensions]): Sequence of
                queries to query the context.
            context (:class:`torch.FloatTensor` [batch size, query length, dimensions]): Data
                over which to apply the attention mechanism.
            output length: length of utterance
            query length: length of each token (1)
        Returns:
            :class:`tuple` with `output` and `weights`:
            * **output** (:class:`torch.FloatTensor` [batch size, output length, dimensions]):
              Tensor containing the attended features.
            * **weights** (:class:`torch.FloatTensor` [batch size, output length, query length]):
              Tensor containing attention weights.
        """
        # query = self.linear_query(query)

        batch_size, output_len, hidden_size = query.size()
        # query_len = context.size(1)

        # (batch_size, output_len, dimensions) * (batch_size, query_len, dimensions) ->
        # (batch_size, output_len, query_len)
        attention_scores = torch.bmm(query, context.transpose(1, 2).contiguous())
        # Compute weights across every context sequence
        # attention_scores = attention_scores.view(batch_size * output_len, query_len)
        if attention_mask is not None:
            # Create attention mask, apply attention mask before softmax
            attention_mask = torch.unsqueeze(attention_mask, 2)
            # attention_mask = attention_mask.view(batch_size * output_len, query_len)
            attention_scores.masked_fill_(attention_mask == 0, -np.inf)
        # attention_scores = torch.squeeze(attention_scores, 1)
        attention_weights = self.softmax(attention_scores)
        # attention_weights = attention_weights.view(batch_size, output_len, query_len)

        # (batch_size, output_len, query_len) * (batch_size, query_len, dimensions) ->
        # (batch_size, output_len, dimensions)
        mix = torch.bmm(attention_weights, context)
        # concat -> (batch_size * output_len, 2*dimensions)
        combined = torch.cat((mix, query), dim=2)
        # combined = combined.view(batch_size * output_len, 2 * self.dimensions)

        # Apply linear_out on every 2nd dimension of concat
        # output -> (batch_size, output_len, dimensions)
        # output = self.linear_out(combined).view(batch_size, output_len, self.dimensions)
        output = self.linear_out(combined)

        output = self.tanh(output)
        # output = combined
        return output, attention_weights


class IntentClassifier(nn.Module):
    def __init__(self, input_dim, num_intent_labels, dropout_rate=0.0):
        super(IntentClassifier, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_intent_labels)

    def forward(self, x):
        x = self.dropout(x)
        return self.linear(x)


class SlotClassifier(nn.Module):
    def __init__(
        self,
        input_dim,
        num_intent_labels,
        num_slot_labels,
        use_intent_context_concat=False,
        use_intent_context_attn=False,
        max_seq_len=50,
        attention_embedding_size=200,
        dropout_rate=0.0,
    ):
        super(SlotClassifier, self).__init__()
        self.use_intent_context_attn = use_intent_context_attn
        self.use_intent_context_concat = use_intent_context_concat
        self.max_seq_len = max_seq_len
        self.num_intent_labels = num_intent_labels
        self.num_slot_labels = num_slot_labels
        self.attention_embedding_size = attention_embedding_size

        output_dim = self.attention_embedding_size  # base model
        if self.use_intent_context_concat:
            output_dim = self.attention_embedding_size
            self.linear_out = nn.Linear(2 * attention_embedding_size, attention_embedding_size)

        elif self.use_intent_context_attn:
            output_dim = self.attention_embedding_size
            self.attention = Attention(attention_embedding_size)

        self.linear_slot = nn.Linear(input_dim, self.attention_embedding_size, bias=False)

        if self.use_intent_context_attn or self.use_intent_context_concat:
            # project intent vector and slot vector to have the same dimensions
            self.linear_intent_context = nn.Linear(self.num_intent_labels, self.attention_embedding_size, bias=False)
            self.softmax = nn.Softmax(dim=-1)  # softmax layer for intent logits

        # self.linear_out = nn.Linear(2 * intent_embedding_size, intent_embedding_size)
        # output
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(output_dim, num_slot_labels)

    def forward(self, x, intent_context, attention_mask):
        x = self.linear_slot(x)
        if self.use_intent_context_concat:
            intent_context = self.softmax(intent_context)
            intent_context = self.linear_intent_context(intent_context)
            intent_context = torch.unsqueeze(intent_context, 1)
            intent_context = intent_context.expand(-1, self.max_seq_len, -1)
            x = torch.cat((x, intent_context), dim=2)
            x = self.linear_out(x)

        elif self.use_intent_context_attn:
            intent_context = self.softmax(intent_context)
            intent_context = self.linear_intent_context(intent_context)
            intent_context = torch.unsqueeze(intent_context, 1)  # 1: query length (each token)
            output, weights = self.attention(x, intent_context, attention_mask)
            x = output
        x = self.dropout(x)
        return self.linear(x)

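To make the tensor shapes concrete, here is a small standalone shape check for SlotClassifier in its attention configuration; the batch size, sequence length, hidden size, and label counts are arbitrary placeholders chosen for illustration.

    # Standalone shape check for SlotClassifier (arbitrary sizes, illustrative only).
    import torch

    from model.module import SlotClassifier

    batch_size, seq_len, hidden_size = 2, 50, 768
    num_intents, num_slots = 5, 9

    clf = SlotClassifier(
        input_dim=hidden_size,
        num_intent_labels=num_intents,
        num_slot_labels=num_slots,
        use_intent_context_attn=True,      # exercise the Attention branch
        max_seq_len=seq_len,
        attention_embedding_size=200,
        dropout_rate=0.1,
    )

    sequence_output = torch.randn(batch_size, seq_len, hidden_size)  # encoder token states
    intent_logits = torch.randn(batch_size, num_intents)             # unnormalized intent scores
    attention_mask = torch.ones(batch_size, seq_len)                 # 1 = real token

    slot_logits = clf(sequence_output, intent_logits, attention_mask)
    print(slot_logits.shape)  # torch.Size([2, 50, 9]) -> one slot distribution per token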