add model
- config.json +14 -0
- pytorch_model.bin +3 -0
- webshop_bert.py +126 -0
config.json
ADDED
@@ -0,0 +1,14 @@
+{
+  "architectures": [
+    "BertModelForWebshop"
+  ],
+  "auto_map": {
+    "AutoConfig": "webshop_bert.BertConfigForWebshop",
+    "AutoModel": "webshop_bert.BertModelForWebshop"
+  },
+  "image": true,
+  "model_type": "bert",
+  "pretrained_bert": true,
+  "torch_dtype": "float32",
+  "transformers_version": "4.17.0"
+}
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d507b59c465e300dd8e7bcc2ae541131c272f10af996267015e66291f397a397
+size 449054239
webshop_bert.py
ADDED
@@ -0,0 +1,126 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import (
+    BertModel,
+    BertConfig,
+    PretrainedConfig,
+    PreTrainedModel,
+)
+from transformers.modeling_outputs import SequenceClassifierOutput
+
+
+class BertConfigForWebshop(PretrainedConfig):
+    model_type = "bert"
+
+    def __init__(
+        self,
+        pretrained_bert=True,
+        image=False,
+        **kwargs
+    ):
+        self.pretrained_bert = pretrained_bert
+        self.image = image
+        super().__init__(**kwargs)
+
+
+class BiAttention(nn.Module):
+    def __init__(self, input_size, dropout):
+        super().__init__()
+        self.dropout = nn.Dropout(dropout)
+        self.input_linear = nn.Linear(input_size, 1, bias=False)
+        self.memory_linear = nn.Linear(input_size, 1, bias=False)
+        self.dot_scale = nn.Parameter(
+            torch.zeros(size=(input_size,)).uniform_(1. / (input_size ** 0.5)),
+            requires_grad=True)
+        self.init_parameters()
+
+    def init_parameters(self):
+        return
+
+    def forward(self, context, memory, mask):
+        bsz, input_len = context.size(0), context.size(1)
+        memory_len = memory.size(1)
+        context = self.dropout(context)
+        memory = self.dropout(memory)
+
+        input_dot = self.input_linear(context)
+        memory_dot = self.memory_linear(memory).view(bsz, 1, memory_len)
+        cross_dot = torch.bmm(
+            context * self.dot_scale,
+            memory.permute(0, 2, 1).contiguous())
+        att = input_dot + memory_dot + cross_dot
+        att = att - 1e30 * (1 - mask[:, None])
+
+        weight_one = F.softmax(att, dim=-1)
+        output_one = torch.bmm(weight_one, memory)
+        weight_two = (F.softmax(att.max(dim=-1)[0], dim=-1)
+                      .view(bsz, 1, input_len))
+        output_two = torch.bmm(weight_two, context)
+        return torch.cat(
+            [context, output_one, context * output_one,
+             output_two * output_one],
+            dim=-1)
+
+
+class BertModelForWebshop(PreTrainedModel):
+
+    config_class = BertConfigForWebshop
+
+    def __init__(self, config):
+        super().__init__(config)
+        bert_config = BertConfig.from_pretrained('bert-base-uncased')
+        if config.pretrained_bert:
+            self.bert = BertModel.from_pretrained('bert-base-uncased')
+        else:
+            # Build the untrained BERT from the base BERT config; the custom
+            # wrapper config does not carry the BERT architecture fields.
+            self.bert = BertModel(bert_config)
+        self.bert.resize_token_embeddings(30526)
+        self.attn = BiAttention(768, 0.0)
+        self.linear_1 = nn.Linear(768 * 4, 768)
+        self.relu = nn.ReLU()
+        self.linear_2 = nn.Linear(768, 1)
+        if config.image:
+            self.image_linear = nn.Linear(512, 768)
+        else:
+            self.image_linear = None
+
+    @staticmethod
+    def get_aggregated(output, lens, method):
+        """
+        Get the aggregated hidden state of the encoder.
+        B x D
+        """
+        if method == 'mean':
+            return torch.stack([output[i, :j, :].mean(0) for i, j in enumerate(lens)], dim=0)
+        elif method == 'last':
+            return torch.stack([output[i, j-1, :] for i, j in enumerate(lens)], dim=0)
+        elif method == 'first':
+            return output[:, 0, :]
+
+    def forward(self, state_input_ids, state_attention_mask, action_input_ids, action_attention_mask, sizes, images=None, labels=None):
+        sizes = sizes.tolist()
+        # print(state_input_ids.shape, action_input_ids.shape)
+        state_rep = self.bert(state_input_ids, attention_mask=state_attention_mask)[0]
+        if images is not None and self.image_linear is not None:
+            images = self.image_linear(images)
+            state_rep = torch.cat([images.unsqueeze(1), state_rep], dim=1)
+            state_attention_mask = torch.cat([state_attention_mask[:, :1], state_attention_mask], dim=1)
+        action_rep = self.bert(action_input_ids, attention_mask=action_attention_mask)[0]
+        state_rep = torch.cat([state_rep[i:i+1].repeat(j, 1, 1) for i, j in enumerate(sizes)], dim=0)
+        state_attention_mask = torch.cat([state_attention_mask[i:i+1].repeat(j, 1) for i, j in enumerate(sizes)], dim=0)
+        act_lens = action_attention_mask.sum(1).tolist()
+        state_action_rep = self.attn(action_rep, state_rep, state_attention_mask)
+        state_action_rep = self.relu(self.linear_1(state_action_rep))
+        act_values = self.get_aggregated(state_action_rep, act_lens, 'mean')
+        act_values = self.linear_2(act_values).squeeze(1)
+
+        logits = [F.log_softmax(_, dim=0) for _ in act_values.split(sizes)]
+
+        loss = None
+        if labels is not None:
+            loss = - sum([logit[label] for logit, label in zip(logits, labels)]) / len(logits)
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+        )
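
Usage note (not part of the commit): a minimal loading sketch based on the auto_map entries in config.json above. The repo id "your-username/webshop-bert" is a placeholder, and the tokenizer choice plus the example state/action strings are assumptions about how this WebShop choice model is driven; trust_remote_code=True is needed because the model class lives in webshop_bert.py inside the repo.

    import torch
    from transformers import AutoModel, BertTokenizer

    # Placeholder repo id; replace with the actual Hub repo hosting this commit.
    model = AutoModel.from_pretrained("your-username/webshop-bert", trust_remote_code=True)
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # One observation (state) scored against two candidate actions;
    # sizes[i] is the number of candidate actions for state i.
    state = tokenizer(["[button] Buy Now [button_]"], return_tensors="pt", padding=True)
    actions = tokenizer(["click[buy now]", "click[back to search]"],
                        return_tensors="pt", padding=True)
    sizes = torch.tensor([2])
    images = torch.zeros(1, 512)  # 512-dim image features expected by image_linear

    out = model(state.input_ids, state.attention_mask,
                actions.input_ids, actions.attention_mask,
                sizes, images=images)
    print(out.logits)  # list with one tensor of log-probabilities over the 2 actions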