webshop committed on
Commit 1a0f94f
Parent(s): 6fdba25
Files changed (4)
  1. bert.py +118 -0
  2. config.json +14 -0
  3. modules.py +156 -0
  4. pytorch_model.bin +3 -0
bert.py ADDED
@@ -0,0 +1,118 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import (
+     BertModel,
+     BertConfig,
+     PretrainedConfig,
+     PreTrainedModel,
+ )
+ from transformers.modeling_outputs import SequenceClassifierOutput
+ from .modules import EncoderRNN, BiAttention, get_aggregated
+
+
+ class BertConfigForWebshop(PretrainedConfig):
+     model_type = "bert"
+
+     def __init__(
+         self,
+         pretrained_bert=True,
+         image=False,
+         **kwargs
+     ):
+         self.pretrained_bert = pretrained_bert
+         self.image = image
+         super().__init__(**kwargs)
+
+
+ class BertModelForWebshop(PreTrainedModel):
+
+     config_class = BertConfigForWebshop
+
+     def __init__(self, config):
+         super().__init__(config)
+         bert_config = BertConfig.from_pretrained('bert-base-uncased')
+         if config.pretrained_bert:
+             self.bert = BertModel.from_pretrained('bert-base-uncased')
+         else:
+             self.bert = BertModel(bert_config)  # untrained BERT built from the base config
+         self.bert.resize_token_embeddings(30526)
+         self.attn = BiAttention(768, 0.0)
+         self.linear_1 = nn.Linear(768 * 4, 768)
+         self.relu = nn.ReLU()
+         self.linear_2 = nn.Linear(768, 1)
+         if config.image:
+             self.image_linear = nn.Linear(512, 768)
+         else:
+             self.image_linear = None
+
+         # for state value prediction, used in RL
+         self.linear_3 = nn.Sequential(
+             nn.Linear(768, 128),
+             nn.LeakyReLU(),
+             nn.Linear(128, 1),
+         )
+
+     def forward(self, state_input_ids, state_attention_mask, action_input_ids, action_attention_mask, sizes, images=None, labels=None):
+         sizes = sizes.tolist()
+         # print(state_input_ids.shape, action_input_ids.shape)
+         state_rep = self.bert(state_input_ids, attention_mask=state_attention_mask)[0]
+         if images is not None and self.image_linear is not None:
+             images = self.image_linear(images)
+             state_rep = torch.cat([images.unsqueeze(1), state_rep], dim=1)
+             state_attention_mask = torch.cat([state_attention_mask[:, :1], state_attention_mask], dim=1)
+         action_rep = self.bert(action_input_ids, attention_mask=action_attention_mask)[0]
+         state_rep = torch.cat([state_rep[i:i+1].repeat(j, 1, 1) for i, j in enumerate(sizes)], dim=0)
+         state_attention_mask = torch.cat([state_attention_mask[i:i+1].repeat(j, 1) for i, j in enumerate(sizes)], dim=0)
+         act_lens = action_attention_mask.sum(1).tolist()
+         state_action_rep = self.attn(action_rep, state_rep, state_attention_mask)
+         state_action_rep = self.relu(self.linear_1(state_action_rep))
+         act_values = get_aggregated(state_action_rep, act_lens, 'mean')
+         act_values = self.linear_2(act_values).squeeze(1)
+
+         logits = [F.log_softmax(_, dim=0) for _ in act_values.split(sizes)]
+
+         loss = None
+         if labels is not None:
+             loss = -sum([logit[label] for logit, label in zip(logits, labels)]) / len(logits)
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+         )
+
+     def rl_forward(self, state_batch, act_batch, value=False, q=False, act=False):
+         act_values = []
+         act_sizes = []
+         values = []
+         for state, valid_acts in zip(state_batch, act_batch):
+             with torch.set_grad_enabled(not act):
+                 state_ids = torch.tensor([state.obs]).cuda()
+                 state_mask = (state_ids > 0).int()
+                 act_lens = [len(_) for _ in valid_acts]
+                 act_ids = [torch.tensor(_) for _ in valid_acts]
+                 act_ids = nn.utils.rnn.pad_sequence(act_ids, batch_first=True).cuda()
+                 act_mask = (act_ids > 0).int()
+                 act_size = torch.tensor([len(valid_acts)]).cuda()
+                 if self.image_linear is not None:
+                     images = [state.image_feat]
+                     images = [torch.zeros(512) if _ is None else _ for _ in images]
+                     images = torch.stack(images).cuda()  # BS x 512
+                 else:
+                     images = None
+                 logits = self.forward(state_ids, state_mask, act_ids, act_mask, act_size, images=images).logits[0]
+                 act_values.append(logits)
+                 act_sizes.append(len(valid_acts))
+             if value:
+                 v = self.bert(state_ids, state_mask)[0]
+                 values.append(self.linear_3(v[0][0]))
+         act_values = torch.cat(act_values, dim=0)
+         act_values = torch.cat([F.log_softmax(_, dim=0) for _ in act_values.split(act_sizes)], dim=0)
+         # Optionally, output state value prediction
+         if value:
+             values = torch.cat(values, dim=0)
+             return act_values, act_sizes, values
+         else:
+             return act_values, act_sizes
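The model added above scores a variable number of candidate actions against one state encoding: the BERT state sequence is repeated once per action, fused with each encoded action through `BiAttention`, mean-pooled over the action tokens, and reduced to a single logit per action, with a log-softmax taken within each state's group of actions. A minimal smoke-test sketch follows; it assumes `bert.py` and `modules.py` are placed in a hypothetical local package `webshop_model/` (the relative import `from .modules import ...` needs a package, or loading through `AutoModel`), and the query/action strings are made up for illustration.

```python
# Hypothetical smoke test for BertModelForWebshop.forward (not part of this commit).
import torch
from transformers import BertTokenizer

# Assumption: bert.py and modules.py live in a local package named webshop_model/.
from webshop_model.bert import BertConfigForWebshop, BertModelForWebshop

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModelForWebshop(BertConfigForWebshop(pretrained_bert=True, image=False))
model.eval()

# One state (instruction + observation) and its three candidate actions (made-up text).
state = tokenizer(['i need red running shoes, size 9 [SEP] search results page ...'],
                  return_tensors='pt', padding=True)
actions = tokenizer(['click[buy now]', 'click[back to search]', 'click[next >]'],
                    return_tensors='pt', padding=True)

out = model(
    state_input_ids=state.input_ids,
    state_attention_mask=state.attention_mask,
    action_input_ids=actions.input_ids,
    action_attention_mask=actions.attention_mask,
    sizes=torch.tensor([3]),    # number of candidate actions per state
    labels=torch.tensor([0]),   # optional: index of the gold action for each state
)
print(out.logits[0])   # log-probabilities over the 3 candidate actions
print(out.loss)        # mean negative log-likelihood of the gold actions
```

With `image=True` (as in config.json below), `forward` additionally accepts an `images` tensor of 512-dimensional visual features of shape (batch, 512), which is projected to 768 and prepended to the state sequence.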
config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "architectures": [
+     "BertModelForWebshop"
+   ],
+   "auto_map": {
+     "AutoConfig": "bert.BertConfigForWebshop",
+     "AutoModel": "bert.BertModelForWebshop"
+   },
+   "image": true,
+   "model_type": "bert",
+   "pretrained_bert": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.17.0"
+ }
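`config.json` registers the custom classes through `auto_map`, so the checkpoint can be instantiated with the Auto classes as long as remote code execution is allowed; `"image": true` means this checkpoint includes the 512-to-768 image projection. A loading sketch, with a placeholder repo id standing in for wherever these files are hosted:

```python
# Sketch of loading this checkpoint via the auto_map entries in config.json.
# "ORG/REPO" is a placeholder; substitute the actual Hugging Face repo id.
# trust_remote_code=True is required so transformers may run bert.py from the repo.
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("ORG/REPO", trust_remote_code=True)
print(config.model_type, config.image, config.pretrained_bert)  # bert True True

model = AutoModel.from_pretrained("ORG/REPO", trust_remote_code=True)
model.eval()
```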
modules.py ADDED
@@ -0,0 +1,156 @@
+ import itertools
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.utils import rnn
+
+
+ def duplicate(output, mask, lens, act_sizes):
+     """
+     Duplicate the output based on the action sizes.
+     """
+     output = torch.cat([output[i:i+1].repeat(j, 1, 1) for i, j in enumerate(act_sizes)], dim=0)
+     mask = torch.cat([mask[i:i+1].repeat(j, 1) for i, j in enumerate(act_sizes)], dim=0)
+     lens = list(itertools.chain.from_iterable([lens[i:i+1] * j for i, j in enumerate(act_sizes)]))
+     return output, mask, lens
+
+
+ def get_aggregated(output, lens, method):
+     """
+     Get the aggregated hidden state of the encoder.
+     B x D
+     """
+     if method == 'mean':
+         return torch.stack([output[i, :j, :].mean(0) for i, j in enumerate(lens)], dim=0)
+     elif method == 'last':
+         return torch.stack([output[i, j - 1, :] for i, j in enumerate(lens)], dim=0)
+     elif method == 'first':
+         return output[:, 0, :]
+
+
+ class EncoderRNN(nn.Module):
+     def __init__(self, input_size, num_units, nlayers, concat,
+                  bidir, layernorm, return_last):
+         super().__init__()
+         self.layernorm = (layernorm == 'layer')
+         if layernorm:
+             self.norm = nn.LayerNorm(input_size)
+
+         self.rnns = []
+         for i in range(nlayers):
+             if i == 0:
+                 input_size_ = input_size
+                 output_size_ = num_units
+             else:
+                 input_size_ = num_units if not bidir else num_units * 2
+                 output_size_ = num_units
+             self.rnns.append(
+                 nn.GRU(input_size_, output_size_, 1,
+                        bidirectional=bidir, batch_first=True))
+
+         self.rnns = nn.ModuleList(self.rnns)
+         self.init_hidden = nn.ParameterList(
+             [nn.Parameter(
+                 torch.zeros(size=(2 if bidir else 1, 1, num_units)),
+                 requires_grad=True) for _ in range(nlayers)])
+         self.concat = concat
+         self.nlayers = nlayers
+         self.return_last = return_last
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         with torch.no_grad():
+             for rnn_layer in self.rnns:
+                 for name, p in rnn_layer.named_parameters():
+                     if 'weight_ih' in name:
+                         torch.nn.init.xavier_uniform_(p.data)
+                     elif 'weight_hh' in name:
+                         torch.nn.init.orthogonal_(p.data)
+                     elif 'bias' in name:
+                         p.data.fill_(0.0)
+                     else:
+                         p.data.normal_(std=0.1)
+
+     def get_init(self, bsz, i):
+         return self.init_hidden[i].expand(-1, bsz, -1).contiguous()
+
+     def forward(self, inputs, input_lengths=None):
+         bsz, slen = inputs.size(0), inputs.size(1)
+         if self.layernorm:
+             inputs = self.norm(inputs)
+         output = inputs
+         outputs = []
+         lens = 0
+         if input_lengths is not None:
+             lens = input_lengths  # .data.cpu().numpy()
+         for i in range(self.nlayers):
+             hidden = self.get_init(bsz, i)
+             # output = self.dropout(output)
+             if input_lengths is not None:
+                 output = rnn.pack_padded_sequence(output, lens,
+                                                   batch_first=True,
+                                                   enforce_sorted=False)
+             output, hidden = self.rnns[i](output, hidden)
+             if input_lengths is not None:
+                 output, _ = rnn.pad_packed_sequence(output, batch_first=True)
+                 if output.size(1) < slen:
+                     # used for parallel
+                     # padding = Variable(output.data.new(1, 1, 1).zero_())
+                     padding = torch.zeros(
+                         size=(1, 1, 1), dtype=output.dtype,
+                         device=output.device)
+                     output = torch.cat(
+                         [output,
+                          padding.expand(
+                              output.size(0),
+                              slen - output.size(1),
+                              output.size(2))
+                          ], dim=1)
+             if self.return_last:
+                 outputs.append(
+                     hidden.permute(1, 0, 2).contiguous().view(bsz, -1))
+             else:
+                 outputs.append(output)
+         if self.concat:
+             return torch.cat(outputs, dim=2)
+         return outputs[-1]
+
+
+ class BiAttention(nn.Module):
+     def __init__(self, input_size, dropout):
+         super().__init__()
+         self.dropout = nn.Dropout(dropout)
+         self.input_linear = nn.Linear(input_size, 1, bias=False)
+         self.memory_linear = nn.Linear(input_size, 1, bias=False)
+         self.dot_scale = nn.Parameter(
+             torch.zeros(size=(input_size,)).uniform_(1. / (input_size ** 0.5)),
+             requires_grad=True)
+         self.init_parameters()
+
+     def init_parameters(self):
+         return
+
+     def forward(self, context, memory, mask):
+         bsz, input_len = context.size(0), context.size(1)
+         memory_len = memory.size(1)
+         context = self.dropout(context)
+         memory = self.dropout(memory)
+
+         input_dot = self.input_linear(context)
+         memory_dot = self.memory_linear(memory).view(bsz, 1, memory_len)
+         cross_dot = torch.bmm(
+             context * self.dot_scale,
+             memory.permute(0, 2, 1).contiguous())
+         att = input_dot + memory_dot + cross_dot
+         att = att - 1e30 * (1 - mask[:, None])
+
+         weight_one = F.softmax(att, dim=-1)
+         output_one = torch.bmm(weight_one, memory)
+         weight_two = (F.softmax(att.max(dim=-1)[0], dim=-1)
+                       .view(bsz, 1, input_len))
+         output_two = torch.bmm(weight_two, context)
+         return torch.cat(
+             [context, output_one, context * output_one,
+              output_two * output_one],
+             dim=-1)
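`modules.py` supplies the fusion pieces used by `bert.py`: `BiAttention` is a BiDAF-style bi-directional attention layer that returns a fused representation four times the input width, and `get_aggregated` pools a padded sequence by mean, last, or first position (`EncoderRNN` and `duplicate` are defined here but not used by the model in `bert.py`). A shape-check sketch with arbitrary sizes (768 matches the BERT hidden size used above):

```python
# Shape check for BiAttention and get_aggregated from modules.py.
# Dimensions are arbitrary; assumes modules.py is importable from the working directory.
import torch
from modules import BiAttention, get_aggregated

bsz, act_len, state_len, dim = 4, 12, 50, 768
action_rep = torch.randn(bsz, act_len, dim)    # "context" argument
state_rep = torch.randn(bsz, state_len, dim)   # "memory" argument
state_mask = torch.ones(bsz, state_len)        # 1 = real token, 0 = padding

attn = BiAttention(input_size=dim, dropout=0.0)
fused = attn(action_rep, state_rep, state_mask)
print(fused.shape)                             # torch.Size([4, 12, 3072]) == 4 * dim

# Pool each fused action sequence over its unpadded length.
act_lens = [12, 10, 7, 12]
pooled = get_aggregated(fused, act_lens, 'mean')
print(pooled.shape)                            # torch.Size([4, 3072])
```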
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:32c00421372d607ffbd15f317f0040569dbe3cc7843f8885d2b54ffd2db9d0a8
+ size 449449751