import torch
from torch import nn
import math


from pytorch_transformers.modeling_bert import (
	BertEncoder,
	BertPreTrainedModel,
	BertConfig
)

class GeLU(nn.Module):
	"""Implementation of the gelu activation function.
		For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
		0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
		Also see https://arxiv.org/abs/1606.08415
	"""

	def __init__(self):
		super().__init__()

	def forward(self, x):
		return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

class BertLayerNorm(nn.Module):
	def __init__(self, hidden_size, eps=1e-12):
		"""Construct a layernorm module in the TF style (epsilon inside the square root).
		"""
		super(BertLayerNorm, self).__init__()
		self.weight = nn.Parameter(torch.ones(hidden_size))
		self.bias = nn.Parameter(torch.zeros(hidden_size))
		self.variance_epsilon = eps

	def forward(self, x):
		# Normalize over the last dimension, then apply the learned affine transform.
		u = x.mean(-1, keepdim=True)
		s = (x - u).pow(2).mean(-1, keepdim=True)
		x = (x - u) / torch.sqrt(s + self.variance_epsilon)
		return self.weight * x + self.bias

class mlp_meta(nn.Module):
	"""A single Linear -> GELU -> LayerNorm -> Dropout block, stackable via nn.ModuleList."""

	def __init__(self, config):
		super().__init__()
		self.mlp = nn.Sequential(
			nn.Linear(config.hid_dim, config.hid_dim),
			GeLU(),
			BertLayerNorm(config.hid_dim, eps=1e-12),
			nn.Dropout(config.dropout),
		)

	def forward(self, x):
		return self.mlp(x)
	
class Bert_Transformer_Layer(BertPreTrainedModel):
	def __init__(self, fusion_config):
		bertconfig_fusion = BertConfig(**fusion_config)
		super().__init__(bertconfig_fusion)
		self.encoder = BertEncoder(bertconfig_fusion)
		self.init_weights()
	
	def forward(self, input, mask=None):
		"""
		input: (batch, feats, dim); mask (optional): (batch, feats - 1), 1 for real features, 0 for padding
		"""
		batch, feats, dim = input.size()
		if mask is not None:
			# Always attend to the first feature; the given mask covers the remaining ones.
			mask_ = torch.ones(size=(batch, feats), device=mask.device)
			mask_[:, 1:] = mask
			# Outer product builds a (batch, feats, feats) pairwise attention mask.
			mask_ = torch.bmm(mask_.view(batch, 1, -1).transpose(1, 2), mask_.view(batch, 1, -1))
			mask_ = mask_.unsqueeze(1)
		else:
			# No mask given: attend everywhere.
			mask = torch.Tensor([1.0]).to(input.device)
			mask_ = mask.repeat(batch, 1, feats, feats)

		# Additive attention mask: 0 where attention is allowed, -10000 where it is blocked.
		extend_mask = (1.0 - mask_) * -10000.0
		assert not extend_mask.requires_grad
		head_mask = [None] * self.config.num_hidden_layers

		enc_output = self.encoder(
			input, extend_mask, head_mask=head_mask
		)
		output = enc_output[0]
		all_attention = enc_output[1]

		return output, all_attention
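
	# Hypothetical usage note (assumption; not from the original file): `input` has shape
	# (batch, feats, hidden_size), hidden_size must be divisible by num_attention_heads,
	# and the optional `mask` has shape (batch, feats - 1) with 1 for real features and
	# 0 for padding, e.g.:
	#
	#   fusion_config = {'hidden_size': 768, 'num_hidden_layers': 1,
	#                    'num_attention_heads': 4, 'output_attentions': True}
	#   layer = Bert_Transformer_Layer(fusion_config)
	#   out, attn = layer(torch.randn(2, 4, 768), mask=torch.ones(2, 3))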
	
class mmdPreModel(nn.Module):
	def __init__(self, config, num_mlp=0, transformer_flag=False, num_hidden_layers=1, mlp_flag=True):
		super(mmdPreModel, self).__init__()
		self.num_mlp = num_mlp
		self.transformer_flag = transformer_flag
		self.mlp_flag = mlp_flag
		token_num = config.token_num
		# Projects per-token features from in_dim to hid_dim.
		self.mlp = nn.Sequential(
			nn.Linear(config.in_dim, config.hid_dim),
			GeLU(),
			BertLayerNorm(config.hid_dim, eps=1e-12),
			nn.Dropout(config.dropout),
			# nn.Linear(config.hid_dim, config.out_dim),
		)
		self.fusion_config = {
			'hidden_size': config.in_dim,
			'num_hidden_layers': num_hidden_layers,
			'num_attention_heads': 4,  # in_dim must be divisible by this when transformer_flag is set
			'output_attentions': True,
		}
		if self.num_mlp > 0:
			self.mlp2 = nn.ModuleList([mlp_meta(config) for _ in range(self.num_mlp)])
		if self.transformer_flag:
			self.transformer = Bert_Transformer_Layer(self.fusion_config)
		# Flattens all token features and maps them to the final output dimension.
		self.feature = nn.Linear(config.hid_dim * token_num, config.out_dim)

	def forward(self, features):
		"""
		input: [batch, token_num, in_dim], output: [batch, out_dim]
		"""
		if self.transformer_flag:
			# Fuse token features with the small BERT encoder before the MLP projection.
			features, _ = self.transformer(features)
		if self.mlp_flag:
			features = self.mlp(features)

		if self.num_mlp > 0:
			for mlp in self.mlp2:
				features = mlp(features)

		# Flatten token features and project to the output dimension.
		features = self.feature(features.view(features.shape[0], -1))
		return features
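

# --- Hypothetical smoke test (assumption; not part of the original module) ---
# A minimal sketch, assuming a config object that exposes token_num, in_dim,
# hid_dim, out_dim, and dropout. SimpleNamespace and the dimensions below are
# placeholders; the real project presumably supplies its own config.
if __name__ == "__main__":
	from types import SimpleNamespace

	config = SimpleNamespace(
		token_num=4,   # number of token-level feature vectors per example
		in_dim=768,    # input feature dimension
		hid_dim=512,   # hidden dimension of the projection MLP
		out_dim=300,   # final output dimension
		dropout=0.1,
	)
	model = mmdPreModel(config, num_mlp=1, transformer_flag=False)

	features = torch.randn(8, config.token_num, config.in_dim)
	out = model(features)
	print(out.shape)  # expected: torch.Size([8, 300])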