add model
- bert.py +118 -0
- config.json +14 -0
- modules.py +156 -0
- pytorch_model.bin +3 -0
bert.py
ADDED
@@ -0,0 +1,118 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    BertModel,
    BertConfig,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.modeling_outputs import SequenceClassifierOutput
from .modules import EncoderRNN, BiAttention, get_aggregated


class BertConfigForWebshop(PretrainedConfig):
    model_type = "bert"

    def __init__(
        self,
        pretrained_bert=True,
        image=False,
        **kwargs
    ):
        self.pretrained_bert = pretrained_bert
        self.image = image
        super().__init__(**kwargs)


class BertModelForWebshop(PreTrainedModel):

    config_class = BertConfigForWebshop

    def __init__(self, config):
        super().__init__(config)
        bert_config = BertConfig.from_pretrained('bert-base-uncased')
        if config.pretrained_bert:
            self.bert = BertModel.from_pretrained('bert-base-uncased')
        else:
            # randomly initialized BERT with the bert-base-uncased architecture
            self.bert = BertModel(bert_config)
        self.bert.resize_token_embeddings(30526)  # bert-base-uncased vocab (30522) plus extra special tokens
        self.attn = BiAttention(768, 0.0)
        self.linear_1 = nn.Linear(768 * 4, 768)
        self.relu = nn.ReLU()
        self.linear_2 = nn.Linear(768, 1)
        if config.image:
            self.image_linear = nn.Linear(512, 768)
        else:
            self.image_linear = None

        # for state value prediction, used in RL
        self.linear_3 = nn.Sequential(
            nn.Linear(768, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, state_input_ids, state_attention_mask, action_input_ids, action_attention_mask, sizes, images=None, labels=None):
        """Score each candidate action against its state; sizes[i] is the
        number of candidate actions belonging to state i."""
        sizes = sizes.tolist()
        state_rep = self.bert(state_input_ids, attention_mask=state_attention_mask)[0]
        if images is not None and self.image_linear is not None:
            # prepend the projected image feature as an extra "token"
            images = self.image_linear(images)
            state_rep = torch.cat([images.unsqueeze(1), state_rep], dim=1)
            state_attention_mask = torch.cat([state_attention_mask[:, :1], state_attention_mask], dim=1)
        action_rep = self.bert(action_input_ids, attention_mask=action_attention_mask)[0]
        # repeat each state once per candidate action so states and actions align
        state_rep = torch.cat([state_rep[i:i+1].repeat(j, 1, 1) for i, j in enumerate(sizes)], dim=0)
        state_attention_mask = torch.cat([state_attention_mask[i:i+1].repeat(j, 1) for i, j in enumerate(sizes)], dim=0)
        act_lens = action_attention_mask.sum(1).tolist()
        state_action_rep = self.attn(action_rep, state_rep, state_attention_mask)
        state_action_rep = self.relu(self.linear_1(state_action_rep))
        act_values = get_aggregated(state_action_rep, act_lens, 'mean')
        act_values = self.linear_2(act_values).squeeze(1)

        # per-state log-softmax over that state's candidate actions
        logits = [F.log_softmax(_, dim=0) for _ in act_values.split(sizes)]

        loss = None
        if labels is not None:
            loss = -sum(logit[label] for logit, label in zip(logits, labels)) / len(logits)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )

    def rl_forward(self, state_batch, act_batch, value=False, q=False, act=False):
        act_values = []
        act_sizes = []
        values = []
        for state, valid_acts in zip(state_batch, act_batch):
            with torch.set_grad_enabled(not act):
                state_ids = torch.tensor([state.obs]).cuda()
                state_mask = (state_ids > 0).int()
                act_lens = [len(_) for _ in valid_acts]
                act_ids = [torch.tensor(_) for _ in valid_acts]
                act_ids = nn.utils.rnn.pad_sequence(act_ids, batch_first=True).cuda()
                act_mask = (act_ids > 0).int()
                act_size = torch.tensor([len(valid_acts)]).cuda()
                if self.image_linear is not None:
                    images = [state.image_feat]
                    images = [torch.zeros(512) if _ is None else _ for _ in images]
                    images = torch.stack(images).cuda()  # BS x 512
                else:
                    images = None
                logits = self.forward(state_ids, state_mask, act_ids, act_mask, act_size, images=images).logits[0]
            act_values.append(logits)
            act_sizes.append(len(valid_acts))
            if value:
                v = self.bert(state_ids, state_mask)[0]
                values.append(self.linear_3(v[0][0]))
        act_values = torch.cat(act_values, dim=0)
        act_values = torch.cat([F.log_softmax(_, dim=0) for _ in act_values.split(act_sizes)], dim=0)
        # optionally output the state value prediction, used in RL
        if value:
            values = torch.cat(values, dim=0)
            return act_values, act_sizes, values
        else:
            return act_values, act_sizes
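For orientation, a minimal sketch of how `forward` expects its inputs: candidate actions for all states are flattened into one batch, and `sizes` records how many actions belong to each state. The tokenizer choice, example strings, and labels below are illustrative assumptions, not part of this commit; the plain bert-base-uncased tokenizer emits ids below 30522, so they fit within the resized 30526-token embedding.

import torch
from transformers import BertTokenizer

from bert import BertConfigForWebshop, BertModelForWebshop

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# pretrained_bert=True (the default) downloads bert-base-uncased weights
model = BertModelForWebshop(BertConfigForWebshop(image=False))

# Two states; the first has 2 candidate actions, the second has 3.
states = ["instruction: buy red shoes", "instruction: buy a desk lamp"]
actions = ["click[red shoes]", "search[red shoes]",
           "click[lamp a]", "click[lamp b]", "search[desk lamp]"]
sizes = torch.tensor([2, 3])  # actions per state, in order

state_enc = tokenizer(states, padding=True, return_tensors='pt')
action_enc = tokenizer(actions, padding=True, return_tensors='pt')

out = model(state_enc.input_ids, state_enc.attention_mask,
            action_enc.input_ids, action_enc.attention_mask, sizes,
            labels=torch.tensor([0, 2]))  # index of the correct action per state
print(out.loss)                        # scalar NLL averaged over the two states
print([l.shape for l in out.logits])   # [torch.Size([2]), torch.Size([3])]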
config.json
ADDED
@@ -0,0 +1,14 @@
{
  "architectures": [
    "BertModelForWebshop"
  ],
  "auto_map": {
    "AutoConfig": "bert.BertConfigForWebshop",
    "AutoModel": "bert.BertModelForWebshop"
  },
  "image": true,
  "model_type": "bert",
  "pretrained_bert": true,
  "torch_dtype": "float32",
  "transformers_version": "4.17.0"
}
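The auto_map entries let transformers resolve the custom classes in bert.py straight from the repository. A loading sketch, assuming a sufficiently recent transformers, with "<user>/<repo>" as a placeholder for wherever this model is hosted:

from transformers import AutoModel

# custom code from the Hub must be opted into explicitly
model = AutoModel.from_pretrained("<user>/<repo>", trust_remote_code=True)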
modules.py
ADDED
@@ -0,0 +1,156 @@
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import rnn


def duplicate(output, mask, lens, act_sizes):
    """
    Duplicate the output based on the action sizes.
    """
    output = torch.cat([output[i:i+1].repeat(j, 1, 1) for i, j in enumerate(act_sizes)], dim=0)
    mask = torch.cat([mask[i:i+1].repeat(j, 1) for i, j in enumerate(act_sizes)], dim=0)
    lens = list(itertools.chain.from_iterable([lens[i:i+1] * j for i, j in enumerate(act_sizes)]))
    return output, mask, lens


def get_aggregated(output, lens, method):
    """
    Aggregate a padded batch of hidden states (B x L x D) into B x D,
    using only the first lens[i] positions of row i.
    """
    if method == 'mean':
        return torch.stack([output[i, :j, :].mean(0) for i, j in enumerate(lens)], dim=0)
    elif method == 'last':
        return torch.stack([output[i, j-1, :] for i, j in enumerate(lens)], dim=0)
    elif method == 'first':
        return output[:, 0, :]


class EncoderRNN(nn.Module):
    def __init__(self, input_size, num_units, nlayers, concat,
                 bidir, layernorm, return_last):
        super().__init__()
        self.layernorm = (layernorm == 'layer')
        if layernorm:
            self.norm = nn.LayerNorm(input_size)

        self.rnns = []
        for i in range(nlayers):
            if i == 0:
                input_size_ = input_size
                output_size_ = num_units
            else:
                input_size_ = num_units if not bidir else num_units * 2
                output_size_ = num_units
            self.rnns.append(
                nn.GRU(input_size_, output_size_, 1,
                       bidirectional=bidir, batch_first=True))

        self.rnns = nn.ModuleList(self.rnns)
        self.init_hidden = nn.ParameterList(
            [nn.Parameter(
                torch.zeros(size=(2 if bidir else 1, 1, num_units)),
                requires_grad=True) for _ in range(nlayers)])
        self.concat = concat
        self.nlayers = nlayers
        self.return_last = return_last

        self.reset_parameters()

    def reset_parameters(self):
        with torch.no_grad():
            for rnn_layer in self.rnns:
                for name, p in rnn_layer.named_parameters():
                    if 'weight_ih' in name:
                        torch.nn.init.xavier_uniform_(p.data)
                    elif 'weight_hh' in name:
                        torch.nn.init.orthogonal_(p.data)
                    elif 'bias' in name:
                        p.data.fill_(0.0)
                    else:
                        p.data.normal_(std=0.1)

    def get_init(self, bsz, i):
        return self.init_hidden[i].expand(-1, bsz, -1).contiguous()

    def forward(self, inputs, input_lengths=None):
        bsz, slen = inputs.size(0), inputs.size(1)
        if self.layernorm:
            inputs = self.norm(inputs)
        output = inputs
        outputs = []
        lens = 0
        if input_lengths is not None:
            lens = input_lengths  # .data.cpu().numpy()
        for i in range(self.nlayers):
            hidden = self.get_init(bsz, i)
            # output = self.dropout(output)
            if input_lengths is not None:
                output = rnn.pack_padded_sequence(output, lens,
                                                  batch_first=True,
                                                  enforce_sorted=False)
            output, hidden = self.rnns[i](output, hidden)
            if input_lengths is not None:
                output, _ = rnn.pad_packed_sequence(output, batch_first=True)
                if output.size(1) < slen:
                    # pad back to the full sequence length (used for data parallel)
                    padding = torch.zeros(
                        size=(1, 1, 1), dtype=output.dtype,
                        device=output.device)
                    output = torch.cat(
                        [output,
                         padding.expand(
                             output.size(0),
                             slen - output.size(1),
                             output.size(2))
                         ], dim=1)
            if self.return_last:
                outputs.append(
                    hidden.permute(1, 0, 2).contiguous().view(bsz, -1))
            else:
                outputs.append(output)
        if self.concat:
            return torch.cat(outputs, dim=2)
        return outputs[-1]


class BiAttention(nn.Module):
    def __init__(self, input_size, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.input_linear = nn.Linear(input_size, 1, bias=False)
        self.memory_linear = nn.Linear(input_size, 1, bias=False)
        self.dot_scale = nn.Parameter(
            torch.zeros(size=(input_size,)).uniform_(1. / (input_size ** 0.5)),
            requires_grad=True)
        self.init_parameters()

    def init_parameters(self):
        return

    def forward(self, context, memory, mask):
        bsz, input_len = context.size(0), context.size(1)
        memory_len = memory.size(1)
        context = self.dropout(context)
        memory = self.dropout(memory)

        input_dot = self.input_linear(context)
        memory_dot = self.memory_linear(memory).view(bsz, 1, memory_len)
        cross_dot = torch.bmm(
            context * self.dot_scale,
            memory.permute(0, 2, 1).contiguous())
        att = input_dot + memory_dot + cross_dot
        # mask out padded memory positions before the softmax
        att = att - 1e30 * (1 - mask[:, None])

        weight_one = F.softmax(att, dim=-1)
        output_one = torch.bmm(weight_one, memory)
        weight_two = (F.softmax(att.max(dim=-1)[0], dim=-1)
                      .view(bsz, 1, input_len))
        output_two = torch.bmm(weight_two, context)
        return torch.cat(
            [context, output_one, context * output_one,
             output_two * output_one],
            dim=-1)
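A quick shape sanity check for the two pieces bert.py relies on: BiAttention returns a 4 * input_size-wide fusion of context and memory (matching linear_1's 768 * 4 input), and get_aggregated mean-pools each row over its true length. All tensors below are random stand-ins.

import torch
from modules import BiAttention, get_aggregated

attn = BiAttention(input_size=768, dropout=0.0)

context = torch.randn(4, 10, 768)  # e.g. action token features, B x Lc x D
memory = torch.randn(4, 25, 768)   # e.g. state token features,  B x Lm x D
mask = torch.ones(4, 25)           # 1 = real token, 0 = padding (over memory)

fused = attn(context, memory, mask)
print(fused.shape)  # torch.Size([4, 10, 3072]) -- 4 * input_size

pooled = get_aggregated(fused, lens=[10, 8, 10, 5], method='mean')
print(pooled.shape)  # torch.Size([4, 3072])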
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:32c00421372d607ffbd15f317f0040569dbe3cc7843f8885d2b54ffd2db9d0a8
size 449449751