webshop committed
Commit 7051b8b
1 Parent(s): 7eb4ec1
Files changed (3)
  1. config.json +14 -0
  2. pytorch_model.bin +3 -0
  3. webshop_bert.py +126 -0
config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "architectures": [
+     "BertModelForWebshop"
+   ],
+   "auto_map": {
+     "AutoConfig": "webshop_bert.BertConfigForWebshop",
+     "AutoModel": "webshop_bert.BertModelForWebshop"
+   },
+   "image": true,
+   "model_type": "bert",
+   "pretrained_bert": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.17.0"
+ }
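The `auto_map` entries wire `AutoConfig`/`AutoModel` to the classes defined in `webshop_bert.py`, so the checkpoint has to be loaded with `trust_remote_code=True`. A minimal sketch of that loading path; the repo id below is a placeholder, since the actual Hub path is not shown in this commit:

```python
from transformers import AutoConfig, AutoModel

# Hypothetical repo id -- replace with the actual Hub path of this checkpoint.
repo_id = "webshop/webshop-bert"

# trust_remote_code=True is required because auto_map points at classes shipped
# in webshop_bert.py inside the repo, not classes built into transformers.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
```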
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d507b59c465e300dd8e7bcc2ae541131c272f10af996267015e66291f397a397
+ size 449054239
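The file above is only a Git LFS pointer: it records the SHA-256 and byte size of the real weights, which live in LFS storage. A minimal sketch of fetching and verifying the binary via `huggingface_hub`; the repo id is again an assumption:

```python
import hashlib

from huggingface_hub import hf_hub_download

# Hypothetical repo id -- replace with the actual Hub path of this checkpoint.
weights_path = hf_hub_download(repo_id="webshop/webshop-bert", filename="pytorch_model.bin")

# Optionally check the download against the oid recorded in the pointer file.
digest = hashlib.sha256()
with open(weights_path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        digest.update(chunk)
print(digest.hexdigest())  # should match d507b59c...f397a397
```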
webshop_bert.py ADDED
@@ -0,0 +1,126 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import (
+     BertModel,
+     BertConfig,
+     PretrainedConfig,
+     PreTrainedModel,
+ )
+ from transformers.modeling_outputs import SequenceClassifierOutput
+
+
+ class BertConfigForWebshop(PretrainedConfig):
+     model_type = "bert"
+
+     def __init__(
+         self,
+         pretrained_bert=True,
+         image=False,
+         **kwargs
+     ):
+         self.pretrained_bert = pretrained_bert
+         self.image = image
+         super().__init__(**kwargs)
+
+
+ class BiAttention(nn.Module):
+     def __init__(self, input_size, dropout):
+         super().__init__()
+         self.dropout = nn.Dropout(dropout)
+         self.input_linear = nn.Linear(input_size, 1, bias=False)
+         self.memory_linear = nn.Linear(input_size, 1, bias=False)
+         self.dot_scale = nn.Parameter(
+             torch.zeros(size=(input_size,)).uniform_(1. / (input_size ** 0.5)),
+             requires_grad=True)
+         self.init_parameters()
+
+     def init_parameters(self):
+         return
+
+     def forward(self, context, memory, mask):
+         bsz, input_len = context.size(0), context.size(1)
+         memory_len = memory.size(1)
+         context = self.dropout(context)
+         memory = self.dropout(memory)
+
+         input_dot = self.input_linear(context)
+         memory_dot = self.memory_linear(memory).view(bsz, 1, memory_len)
+         cross_dot = torch.bmm(
+             context * self.dot_scale,
+             memory.permute(0, 2, 1).contiguous())
+         att = input_dot + memory_dot + cross_dot
+         att = att - 1e30 * (1 - mask[:, None])  # mask out padded memory positions
+
+         weight_one = F.softmax(att, dim=-1)
+         output_one = torch.bmm(weight_one, memory)
+         weight_two = (F.softmax(att.max(dim=-1)[0], dim=-1)
+                       .view(bsz, 1, input_len))
+         output_two = torch.bmm(weight_two, context)
+         return torch.cat(
+             [context, output_one, context * output_one,
+              output_two * output_one],
+             dim=-1)
+
+
+ class BertModelForWebshop(PreTrainedModel):
+
+     config_class = BertConfigForWebshop
+
+     def __init__(self, config):
+         super().__init__(config)
+         bert_config = BertConfig.from_pretrained('bert-base-uncased')
+         if config.pretrained_bert:
+             self.bert = BertModel.from_pretrained('bert-base-uncased')
+         else:
+             self.bert = BertModel(bert_config)  # train from scratch with the standard BERT config
+         self.bert.resize_token_embeddings(30526)  # make room for the extra WebShop special tokens
+         self.attn = BiAttention(768, 0.0)
+         self.linear_1 = nn.Linear(768 * 4, 768)
+         self.relu = nn.ReLU()
+         self.linear_2 = nn.Linear(768, 1)
+         if config.image:
+             self.image_linear = nn.Linear(512, 768)
+         else:
+             self.image_linear = None
+
+     @staticmethod
+     def get_aggregated(output, lens, method):
+         """
+         Aggregate the encoder hidden states over the sequence dimension.
+         Returns a B x D tensor.
+         """
+         if method == 'mean':
+             return torch.stack([output[i, :j, :].mean(0) for i, j in enumerate(lens)], dim=0)
+         elif method == 'last':
+             return torch.stack([output[i, j-1, :] for i, j in enumerate(lens)], dim=0)
+         elif method == 'first':
+             return output[:, 0, :]
+
+     def forward(self, state_input_ids, state_attention_mask, action_input_ids, action_attention_mask, sizes, images=None, labels=None):
+         sizes = sizes.tolist()  # number of candidate actions per state
+         # state_input_ids: one row per state; action_input_ids: one row per candidate action across all states
+         state_rep = self.bert(state_input_ids, attention_mask=state_attention_mask)[0]
+         if images is not None and self.image_linear is not None:
+             images = self.image_linear(images)
+             state_rep = torch.cat([images.unsqueeze(1), state_rep], dim=1)
+             state_attention_mask = torch.cat([state_attention_mask[:, :1], state_attention_mask], dim=1)
+         action_rep = self.bert(action_input_ids, attention_mask=action_attention_mask)[0]
+         state_rep = torch.cat([state_rep[i:i+1].repeat(j, 1, 1) for i, j in enumerate(sizes)], dim=0)
+         state_attention_mask = torch.cat([state_attention_mask[i:i+1].repeat(j, 1) for i, j in enumerate(sizes)], dim=0)
+         act_lens = action_attention_mask.sum(1).tolist()
+         state_action_rep = self.attn(action_rep, state_rep, state_attention_mask)
+         state_action_rep = self.relu(self.linear_1(state_action_rep))
+         act_values = self.get_aggregated(state_action_rep, act_lens, 'mean')
+         act_values = self.linear_2(act_values).squeeze(1)
+
+         logits = [F.log_softmax(v, dim=0) for v in act_values.split(sizes)]  # per-state log-probs over its candidate actions
+
+         loss = None
+         if labels is not None:
+             loss = - sum([logit[label] for logit, label in zip(logits, labels)]) / len(logits)
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+         )
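For reference, a minimal sketch of invoking the model's forward pass on dummy inputs. The placeholder special tokens, example strings, and local import of `webshop_bert.py` (the file added in this commit) are assumptions for illustration; the four extra tokens are there only so the tokenizer's vocabulary lines up with the 30526 embeddings the model resizes to.

```python
import torch
from transformers import BertTokenizer

from webshop_bert import BertConfigForWebshop, BertModelForWebshop

# Placeholder special tokens (assumption): 30522 + 4 = 30526 embeddings.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_tokens(["[extra_0]", "[extra_1]", "[extra_2]", "[extra_3]"])

config = BertConfigForWebshop(pretrained_bert=True, image=False)
model = BertModelForWebshop(config)

# One state (web page observation) with three candidate actions.
state = tokenizer(["results for red running shoes"], return_tensors="pt", padding=True)
actions = tokenizer(
    ["click[item a]", "click[item b]", "click[back to search]"],
    return_tensors="pt",
    padding=True,
)
sizes = torch.tensor([3])   # number of candidate actions per state
labels = torch.tensor([0])  # index of the correct action (optional)

out = model(
    state_input_ids=state.input_ids,
    state_attention_mask=state.attention_mask,
    action_input_ids=actions.input_ids,
    action_attention_mask=actions.attention_mask,
    sizes=sizes,
    labels=labels,
)
print(out.loss)    # mean negative log-likelihood of the labeled actions
print(out.logits)  # list with one tensor of log-probs per state
```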