# webshop_bert.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
BertModel,
BertConfig,
PretrainedConfig,
PreTrainedModel,
)
from transformers.modeling_outputs import SequenceClassifierOutput
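
# Config for the WebShop choice model: `pretrained_bert` loads bert-base-uncased
# weights, `image` adds a linear projection for a 512-d image feature that gets
# prepended to the state tokens.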
class BertConfigForWebshop(PretrainedConfig):
model_type = "bert"
def __init__(
self,
pretrained_bert=True,
image=False,
**kwargs
):
self.pretrained_bert = pretrained_bert
self.image = image
super().__init__(**kwargs)
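
# Bi-directional attention (BiDAF-style): computes context-to-memory and
# memory-to-context attention and returns the concatenation
# [context, attended_memory, context * attended_memory, context_summary * attended_memory].
# In this model the action tokens are the context and the state tokens are the memory.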
class BiAttention(nn.Module):
def __init__(self, input_size, dropout):
super().__init__()
self.dropout = nn.Dropout(dropout)
self.input_linear = nn.Linear(input_size, 1, bias=False)
self.memory_linear = nn.Linear(input_size, 1, bias=False)
self.dot_scale = nn.Parameter(
torch.zeros(size=(input_size,)).uniform_(1. / (input_size ** 0.5)),
requires_grad=True)
self.init_parameters()
    def init_parameters(self):
        # no-op: dot_scale is already initialized in __init__
        return
def forward(self, context, memory, mask):
bsz, input_len = context.size(0), context.size(1)
memory_len = memory.size(1)
context = self.dropout(context)
memory = self.dropout(memory)
input_dot = self.input_linear(context)
memory_dot = self.memory_linear(memory).view(bsz, 1, memory_len)
cross_dot = torch.bmm(
context * self.dot_scale,
memory.permute(0, 2, 1).contiguous())
att = input_dot + memory_dot + cross_dot
att = att - 1e30 * (1 - mask[:, None])
weight_one = F.softmax(att, dim=-1)
output_one = torch.bmm(weight_one, memory)
weight_two = (F.softmax(att.max(dim=-1)[0], dim=-1)
.view(bsz, 1, input_len))
output_two = torch.bmm(weight_two, context)
return torch.cat(
[context, output_one, context * output_one,
output_two * output_one],
dim=-1)
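
# Imitation-learning choice model: scores each candidate action against the current
# state observation and normalizes the scores per state with a log-softmax.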
class BertModelForWebshop(PreTrainedModel):
config_class = BertConfigForWebshop
def __init__(self, config):
super().__init__(config)
        bert_config = BertConfig.from_pretrained('bert-base-uncased')
        if config.pretrained_bert:
            self.bert = BertModel.from_pretrained('bert-base-uncased')
        else:
            # the custom config carries no BERT hyperparameters, so build the
            # bert-base-uncased architecture from bert_config without pretrained weights
            self.bert = BertModel(bert_config)
        self.bert.resize_token_embeddings(30526)  # 30522 base vocab + 4 added tokens
self.attn = BiAttention(768, 0.0)
self.linear_1 = nn.Linear(768 * 4, 768)
self.relu = nn.ReLU()
self.linear_2 = nn.Linear(768, 1)
if config.image:
self.image_linear = nn.Linear(512, 768)
else:
self.image_linear = None
@staticmethod
def get_aggregated(output, lens, method):
"""
Get the aggregated hidden state of the encoder.
B x D
"""
if method == 'mean':
return torch.stack([output[i, :j, :].mean(0) for i, j in enumerate(lens)], dim=0)
elif method == 'last':
return torch.stack([output[i, j-1, :] for i, j in enumerate(lens)], dim=0)
elif method == 'first':
return output[:, 0, :]
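
    # state_input_ids has one row per state; action_input_ids has one row per candidate
    # action, flattened over all states; sizes[i] is the number of candidate actions of
    # state i, and labels[i] indexes the chosen action within that group.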
def forward(self, state_input_ids, state_attention_mask, action_input_ids, action_attention_mask, sizes, images=None, labels=None):
sizes = sizes.tolist()
        # encode the state observation with BERT; keep the token-level hidden states
state_rep = self.bert(state_input_ids, attention_mask=state_attention_mask)[0]
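        # if provided, project the 512-d image feature to BERT's hidden size and
        # prepend it to the state sequence as an extra "token" (mask extended to match)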
if images is not None and self.image_linear is not None:
images = self.image_linear(images)
state_rep = torch.cat([images.unsqueeze(1), state_rep], dim=1)
state_attention_mask = torch.cat([state_attention_mask[:, :1], state_attention_mask], dim=1)
action_rep = self.bert(action_input_ids, attention_mask=action_attention_mask)[0]
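        # repeat each state representation once per candidate action so that state i
        # is paired with each of its sizes[i] actions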
state_rep = torch.cat([state_rep[i:i+1].repeat(j, 1, 1) for i, j in enumerate(sizes)], dim=0)
state_attention_mask = torch.cat([state_attention_mask[i:i+1].repeat(j, 1) for i, j in enumerate(sizes)], dim=0)
act_lens = action_attention_mask.sum(1).tolist()
state_action_rep = self.attn(action_rep, state_rep, state_attention_mask)
state_action_rep = self.relu(self.linear_1(state_action_rep))
act_values = self.get_aggregated(state_action_rep, act_lens, 'mean')
act_values = self.linear_2(act_values).squeeze(1)
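        # split the flat action scores back into per-state groups and normalize each
        # group with a log-softmax over that state's candidate actions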
logits = [F.log_softmax(_, dim=0) for _ in act_values.split(sizes)]
loss = None
if labels is not None:
loss = - sum([logit[label] for logit, label in zip(logits, labels)]) / len(logits)
return SequenceClassifierOutput(
loss=loss,
logits=logits,
)
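
# Illustrative usage sketch (not part of the original file): it shows how one state and
# its candidate actions could be batched and scored. The tokenizer setup, the extra
# tokens (chosen only to match the 30526-entry embedding table above), and the example
# strings are assumptions; the real WebShop pipeline builds these batches from its data.
if __name__ == '__main__':
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # bert-base-uncased has 30522 tokens; four placeholder tokens bring it to 30526
    tokenizer.add_tokens(['[button]', '[button_]', '[clicked button]', '[clicked button_]'],
                         special_tokens=True)

    model = BertModelForWebshop(BertConfigForWebshop(pretrained_bert=True, image=False))
    model.eval()

    state = 'instruction: buy a red ceramic mug [SEP] search results ...'
    actions = ['click[red ceramic mug]', 'click[back to search]', 'search[red mug 12 oz]']

    s = tokenizer([state], padding=True, truncation=True, max_length=512, return_tensors='pt')
    a = tokenizer(actions, padding=True, return_tensors='pt')
    sizes = torch.tensor([len(actions)])  # candidate actions per state
    labels = torch.tensor([0])            # index of the chosen action within each group

    with torch.no_grad():
        out = model(s.input_ids, s.attention_mask, a.input_ids, a.attention_mask,
                    sizes, labels=labels)
    print(out.loss, [logit.tolist() for logit in out.logits])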