|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from transformers import ( |
|
BertModel, |
|
BertConfig, |
|
PretrainedConfig, |
|
PreTrainedModel, |
|
) |
|
from transformers.modeling_outputs import SequenceClassifierOutput |
|
|
|
|
|
class BertConfigForWebshop(PretrainedConfig):
    """Configuration carrying the two WebShop-specific model switches.

    Args:
        pretrained_bert: Stored as ``self.pretrained_bert``. Defaults to
            ``True``.
        image: Stored as ``self.image``. Defaults to ``False``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "bert"

    def __init__(self, pretrained_bert=True, image=False, **kwargs):
        # The custom flags are assigned before the base-class constructor
        # runs; the remaining keyword arguments are handled by the base class.
        self.image = image
        self.pretrained_bert = pretrained_bert
        super().__init__(**kwargs)
|
|
|
|
|
class BiAttention(nn.Module):
    """Bi-directional attention between a context and a memory sequence.

    The score for a (context position, memory position) pair is the sum of
    a linear score of the context vector, a linear score of the memory
    vector, and a dot product between the element-wise scaled context vector
    and the memory vector.

    Args:
        input_size: Feature dimension shared by ``context`` and ``memory``.
        dropout: Dropout probability applied to both inputs in ``forward``.
    """

    def __init__(self, input_size, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Bias-free projections producing one additive score per position.
        self.input_linear = nn.Linear(input_size, 1, bias=False)
        self.memory_linear = nn.Linear(input_size, 1, bias=False)
        # ``uniform_`` receives only its lower bound here, so the upper bound
        # stays at its default of 1.
        lower_bound = 1. / (input_size ** 0.5)
        initial_scale = torch.zeros(size=(input_size,)).uniform_(lower_bound)
        self.dot_scale = nn.Parameter(initial_scale, requires_grad=True)
        self.init_parameters()

    def init_parameters(self):
        """Called at the end of ``__init__``; currently a no-op."""
        return None

    def forward(self, context, memory, mask):
        """Attend from ``context`` to ``memory`` and back.

        Args:
            context: Tensor of shape ``(batch, context_len, input_size)``.
            memory: Tensor of shape ``(batch, memory_len, input_size)``.
            mask: Tensor of shape ``(batch, memory_len)``; entries equal to 1
                keep a memory position, entries equal to 0 push its score
                down by ``1e30`` so it receives no attention weight.

        Returns:
            Tensor of shape ``(batch, context_len, 4 * input_size)``: the
            concatenation of the context, the memory attended per context
            position, their element-wise product, and the product of the
            attended memory with a single attended-context summary vector.
        """
        batch_size, context_len = context.shape[:2]
        memory_len = memory.shape[1]

        context = self.dropout(context)
        memory = self.dropout(memory)

        # Three additive score terms, broadcast to
        # (batch, context_len, memory_len).
        context_score = self.input_linear(context)
        memory_score = self.memory_linear(memory).view(
            batch_size, 1, memory_len)
        scaled_context = context * self.dot_scale
        memory_transposed = memory.permute(0, 2, 1).contiguous()
        pair_score = torch.bmm(scaled_context, memory_transposed)
        scores = context_score + memory_score + pair_score
        scores = scores - 1e30 * (1 - mask[:, None])

        # Context -> memory: each context position gets a weighted memory sum.
        context_to_memory = F.softmax(scores, dim=-1)
        attended_memory = torch.bmm(context_to_memory, memory)

        # Memory -> context: weight each context position by its best score
        # and pool the context into one vector per batch element.
        best_score = scores.max(dim=-1)[0]
        memory_to_context = F.softmax(best_score, dim=-1).view(
            batch_size, 1, context_len)
        attended_context = torch.bmm(memory_to_context, context)

        features = [
            context,
            attended_memory,
            context * attended_memory,
            attended_context * attended_memory,
        ]
        return torch.cat(features, dim=-1)
|
|
|
|
|
class BertModelForWebshop(PreTrainedModel):
    """Scores candidate actions against a state using a shared BERT encoder.

    The state and every candidate action are encoded with the same BERT
    model. Each action attends over its state through ``BiAttention``, the
    result is mean-pooled over the action tokens and projected to one scalar
    per action, and the scalars belonging to one state are normalised with
    ``log_softmax``.
    """

    config_class = BertConfigForWebshop

    def __init__(self, config):
        """Build the encoder and scoring head.

        Args:
            config: A ``BertConfigForWebshop``. ``config.pretrained_bert``
                selects pretrained versus randomly initialised BERT weights;
                ``config.image`` enables the image projection layer.
        """
        super().__init__(config)
        if config.pretrained_bert:
            self.bert = BertModel.from_pretrained('bert-base-uncased')
        else:
            # The encoder must be built from a BERT config: ``config`` is a
            # BertConfigForWebshop, which declares no BERT architecture
            # fields of its own. Only the architecture description is
            # loaded here, not the weights.
            bert_config = BertConfig.from_pretrained('bert-base-uncased')
            self.bert = BertModel(bert_config)
        self.bert.resize_token_embeddings(30526)
        self.attn = BiAttention(768, 0.0)
        # BiAttention concatenates four 768-wide feature groups.
        self.linear_1 = nn.Linear(768 * 4, 768)
        self.relu = nn.ReLU()
        self.linear_2 = nn.Linear(768, 1)
        if config.image:
            self.image_linear = nn.Linear(512, 768)
        else:
            self.image_linear = None

    @staticmethod
    def get_aggregated(output, lens, method):
        """Get the aggregated hidden state of the encoder.

        Args:
            output: Tensor of shape ``B x T x D``.
            lens: Iterable of ``B`` integer sequence lengths.
            method: ``'mean'`` averages the first ``lens[i]`` positions of
                row ``i``, ``'last'`` takes position ``lens[i] - 1``, and
                ``'first'`` takes position 0.

        Returns:
            Tensor of shape ``B x D``.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        if method == 'mean':
            return torch.stack(
                [output[i, :j, :].mean(0) for i, j in enumerate(lens)], dim=0)
        if method == 'last':
            return torch.stack(
                [output[i, j - 1, :] for i, j in enumerate(lens)], dim=0)
        if method == 'first':
            return output[:, 0, :]
        # Fail here rather than handing ``None`` to the next layer.
        raise ValueError(
            "method must be 'mean', 'last' or 'first', got %r" % (method,))

    def forward(self, state_input_ids, state_attention_mask, action_input_ids,
                action_attention_mask, sizes, images=None, labels=None):
        """Score every candidate action of every state.

        Args:
            state_input_ids: Token ids of the states, one row per state.
            state_attention_mask: Attention mask matching ``state_input_ids``.
            action_input_ids: Token ids of all candidate actions, with the
                actions of each state stored contiguously in state order.
            action_attention_mask: Attention mask matching
                ``action_input_ids``.
            sizes: 1-D tensor; ``sizes[i]`` is the number of candidate
                actions that belong to state ``i``.
            images: Optional image features, one 512-wide row per state.
                Used only when the model was built with ``config.image``.
            labels: Optional iterable with, per state, the index of the
                target action among that state's candidates.

        Returns:
            ``SequenceClassifierOutput`` whose ``logits`` is a list holding
            one 1-D tensor of log-probabilities per state, and whose ``loss``
            is the mean negative log-probability of the labelled actions
            (``None`` when ``labels`` is not given).
        """
        sizes = sizes.tolist()

        state_rep = self.bert(
            state_input_ids, attention_mask=state_attention_mask)[0]
        if images is not None and self.image_linear is not None:
            # Prepend the projected image as an extra state position. Its
            # mask entry is copied from the first state token's mask entry.
            images = self.image_linear(images)
            state_rep = torch.cat([images.unsqueeze(1), state_rep], dim=1)
            state_attention_mask = torch.cat(
                [state_attention_mask[:, :1], state_attention_mask], dim=1)
        action_rep = self.bert(
            action_input_ids, attention_mask=action_attention_mask)[0]

        # Repeat each state once per candidate action so that states and
        # actions line up row by row.
        state_rep = torch.cat(
            [state_rep[i:i + 1].repeat(j, 1, 1) for i, j in enumerate(sizes)],
            dim=0)
        state_attention_mask = torch.cat(
            [state_attention_mask[i:i + 1].repeat(j, 1)
             for i, j in enumerate(sizes)],
            dim=0)

        act_lens = action_attention_mask.sum(1).tolist()
        state_action_rep = self.attn(
            action_rep, state_rep, state_attention_mask)
        state_action_rep = self.relu(self.linear_1(state_action_rep))
        act_values = self.get_aggregated(state_action_rep, act_lens, 'mean')
        act_values = self.linear_2(act_values).squeeze(1)

        # Normalise the scores separately within each state's candidates.
        logits = [
            F.log_softmax(group, dim=0) for group in act_values.split(sizes)
        ]

        loss = None
        if labels is not None:
            picked = [logit[label] for logit, label in zip(logits, labels)]
            loss = - sum(picked) / len(logits)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )
|
|