|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from transformers import ( |
|
BertModel, |
|
BertConfig, |
|
PretrainedConfig, |
|
PreTrainedModel, |
|
) |
|
from transformers.modeling_outputs import SequenceClassifierOutput |
|
from .modules import EncoderRNN, BiAttention, get_aggregated |
|
|
|
|
|
class BertConfigForWebshop(PretrainedConfig):
    """Configuration object consumed by ``BertModelForWebshop``.

    Args:
        pretrained_bert: Stored as ``self.pretrained_bert``. The model reads
            it to choose between loading ``bert-base-uncased`` weights and
            building a freshly initialised encoder.
        image: Stored as ``self.image``. The model reads it to decide whether
            to create the image projection layer.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "bert"

    def __init__(self, pretrained_bert=True, image=False, **kwargs):
        # The two webshop-specific flags are recorded first; the base-class
        # initialiser then handles every remaining keyword argument.
        self.pretrained_bert, self.image = pretrained_bert, image
        super().__init__(**kwargs)
|
|
|
|
|
|
|
class BertModelForWebshop(PreTrainedModel):
    """BERT-based model that scores candidate actions against a state.

    A single BERT encoder embeds both the state text and every candidate
    action. ``BiAttention`` fuses each action with its state, the fused
    sequence is projected and aggregated with ``get_aggregated(..., 'mean')``,
    and ``linear_2`` maps the result to one score per action. Scores are
    log-softmax normalised within each state's group of actions.

    ``linear_3`` is a separate head used only by ``rl_forward`` to produce a
    scalar value per state.
    """

    config_class = BertConfigForWebshop

    def __init__(self, config):
        """Build the encoder and the scoring / value heads.

        Args:
            config: A ``BertConfigForWebshop``. ``config.pretrained_bert``
                selects pretrained vs. randomly initialised BERT weights and
                ``config.image`` enables the image projection layer.
        """
        super().__init__(config)

        hidden_size = 768   # width used by every head below
        image_dim = 512     # input width of the image projection
        # NOTE(review): presumably the tokenizer vocabulary after extra
        # special tokens were added -- confirm against the tokenizer in use.
        vocab_size = 30526

        if config.pretrained_bert:
            self.bert = BertModel.from_pretrained('bert-base-uncased')
        else:
            # Randomly initialised encoder with the bert-base-uncased
            # architecture. The webshop config only declares
            # ``pretrained_bert`` and ``image``, so the BERT hyper-parameters
            # have to come from a real ``BertConfig``.
            bert_config = BertConfig.from_pretrained('bert-base-uncased')
            self.bert = BertModel(bert_config)
        self.bert.resize_token_embeddings(vocab_size)

        self.attn = BiAttention(hidden_size, 0.0)
        self.linear_1 = nn.Linear(hidden_size * 4, hidden_size)
        self.relu = nn.ReLU()
        self.linear_2 = nn.Linear(hidden_size, 1)

        if config.image:
            self.image_linear = nn.Linear(image_dim, hidden_size)
        else:
            self.image_linear = None

        # Value head applied to the first-token representation in rl_forward.
        self.linear_3 = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, state_input_ids, state_attention_mask, action_input_ids,
                action_attention_mask, sizes, images=None, labels=None):
        """Score every candidate action and optionally compute the loss.

        Args:
            state_input_ids: Token ids of the states, one row per state.
            state_attention_mask: Attention mask matching ``state_input_ids``.
            action_input_ids: Token ids of all candidate actions, one row per
                action, ordered so that the actions of state 0 come first,
                then those of state 1, and so on.
            action_attention_mask: Attention mask matching
                ``action_input_ids``.
            sizes: 1-D tensor; ``sizes[i]`` is the number of candidate actions
                belonging to state ``i``.
            images: Optional image features, one row per state, with last
                dimension 512. Ignored when the model was built without the
                image projection layer.
            labels: Optional iterable with one index per state selecting the
                target action inside that state's group.

        Returns:
            ``SequenceClassifierOutput`` whose ``logits`` is a list with one
            tensor of log-probabilities per state and whose ``loss`` is the
            mean negative log-probability of the labelled actions (``None``
            when ``labels`` is not given).
        """
        sizes = sizes.tolist()

        state_rep = self.bert(state_input_ids, attention_mask=state_attention_mask)[0]
        if images is not None and self.image_linear is not None:
            # Prepend the projected image as one extra position in front of
            # the state sequence.
            images = self.image_linear(images)
            state_rep = torch.cat([images.unsqueeze(1), state_rep], dim=1)
            # The extra position reuses the mask value of the first token.
            state_attention_mask = torch.cat(
                [state_attention_mask[:, :1], state_attention_mask], dim=1)

        action_rep = self.bert(action_input_ids, attention_mask=action_attention_mask)[0]

        # Repeat each state once per candidate action so that state rows line
        # up one-to-one with action rows.
        state_rep = torch.cat(
            [state_rep[i:i + 1].repeat(n, 1, 1) for i, n in enumerate(sizes)], dim=0)
        state_attention_mask = torch.cat(
            [state_attention_mask[i:i + 1].repeat(n, 1) for i, n in enumerate(sizes)], dim=0)

        act_lens = action_attention_mask.sum(1).tolist()
        state_action_rep = self.attn(action_rep, state_rep, state_attention_mask)
        state_action_rep = self.relu(self.linear_1(state_action_rep))
        act_values = get_aggregated(state_action_rep, act_lens, 'mean')
        act_values = self.linear_2(act_values).squeeze(1)

        # Normalise independently inside each state's group of actions.
        logits = [F.log_softmax(group, dim=0) for group in act_values.split(sizes)]

        loss = None
        if labels is not None:
            loss = - sum([logit[label] for logit, label in zip(logits, labels)]) / len(logits)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )

    def rl_forward(self, state_batch, act_batch, value=False, q=False, act=False):
        """Score the valid actions of each state, one state at a time.

        Args:
            state_batch: Iterable of state objects. Each must expose ``obs``
                (a sequence of token ids) and, when the image projection
                layer exists, ``image_feat`` (a tensor or ``None``; ``None``
                is replaced by a zero vector of size 512).
            act_batch: Iterable aligned with ``state_batch``; each element is
                a sequence of token-id sequences, one per valid action.
            value: When true, also return one value prediction per state
                computed by ``linear_3`` on the first-token BERT output.
            q: Unused; kept for interface compatibility.
            act: When true, gradient tracking is disabled while scoring.

        Returns:
            ``(act_values, act_sizes)`` or, if ``value`` is true,
            ``(act_values, act_sizes, values)``. ``act_values`` is the
            concatenation of the per-state log-probabilities and
            ``act_sizes`` lists how many entries belong to each state.
        """
        # Build inputs on the device the model lives on instead of assuming
        # CUDA, so the method also works on CPU or a non-default GPU.
        device = self.device

        act_values = []
        act_sizes = []
        values = []
        for state, valid_acts in zip(state_batch, act_batch):
            with torch.set_grad_enabled(not act):
                state_ids = torch.tensor([state.obs], device=device)
                # Token id 0 is treated as padding.
                state_mask = (state_ids > 0).int()

                act_ids = [torch.tensor(tokens) for tokens in valid_acts]
                act_ids = nn.utils.rnn.pad_sequence(act_ids, batch_first=True).to(device)
                act_mask = (act_ids > 0).int()
                act_size = torch.tensor([len(valid_acts)], device=device)

                if self.image_linear is not None:
                    images = [state.image_feat]
                    images = [torch.zeros(512) if feat is None else feat for feat in images]
                    images = torch.stack(images).to(device)
                else:
                    images = None

                logits = self.forward(
                    state_ids, state_mask, act_ids, act_mask, act_size, images=images,
                ).logits[0]
                act_values.append(logits)
                act_sizes.append(len(valid_acts))

                if value:
                    v = self.bert(state_ids, state_mask)[0]
                    # First token of the single state in this mini-batch.
                    values.append(self.linear_3(v[0][0]))

        act_values = torch.cat(act_values, dim=0)
        # forward() already log-softmaxes each group; applying it again is
        # mathematically a no-op and is kept to preserve the original output.
        act_values = torch.cat(
            [F.log_softmax(group, dim=0) for group in act_values.split(act_sizes)], dim=0)

        if value:
            values = torch.cat(values, dim=0)
            return act_values, act_sizes, values
        else:
            return act_values, act_sizes