import math

import numpy as np
import torch
import torch.nn as nn


class TransformerEncoder(nn.Module):
    def __init__(self,
                 hidden_size=256,
                 fc_size=1024,
                 num_heads=4,
                 layer_norm=True,
                 num_layers=8,
                 dropout_rate=0.2,
                 re_zero=True,
                 memory_efficient=False):
        super(TransformerEncoder, self).__init__()
        self.hidden_size = hidden_size
        self.fc_size = fc_size
        self.num_heads = num_heads
        self.layer_norm = layer_norm
        self.num_layers = num_layers
        self.dropout_rate = dropout_rate
        self.re_zero = re_zero
        self.memory_efficient = memory_efficient

        ### Self-attention sub-layers and related modules ###
        self.attention_layers = nn.ModuleList()
        if self.layer_norm:
            self.layer_norm_layers = nn.ModuleList()
        if self.re_zero:
            self.re_zero_vars = nn.ParameterList()
        if self.dropout_rate:
            self.dropout_layers = nn.ModuleList()

        for i in range(self.num_layers):
            cur_atten_layer = nn.MultiheadAttention(
                self.hidden_size, self.num_heads, dropout=0.0, bias=True,
                kdim=self.hidden_size, vdim=self.hidden_size, batch_first=True)
            self.attention_layers.append(cur_atten_layer)
            if self.layer_norm:
                self.layer_norm_layers.append(nn.LayerNorm(self.hidden_size))
            if self.re_zero:
                # re-zero scalar, one per layer, initialised to zero
                cur_re_zero_var = nn.Parameter(
                    torch.zeros(size=(1,), dtype=torch.float32), requires_grad=True)
                self.re_zero_vars.append(cur_re_zero_var)
            if self.dropout_rate:
                self.dropout_layers.append(nn.Dropout(p=self.dropout_rate))

        ### Feed-forward sub-layers and related modules ###
        self.fc_layers = nn.ModuleList()
        if self.layer_norm:
            self.fc_layer_norm_layers = nn.ModuleList()
        if self.re_zero:
            self.fc_re_zero_vars = nn.ParameterList()
        if self.dropout_rate:
            self.fc_dropout_layers = nn.ModuleList()

        for i in range(self.num_layers):
            # position-wise feed-forward block (no activation between the two
            # linear layers in this implementation)
            cur_fc_layer = nn.Linear(in_features=self.hidden_size, out_features=self.fc_size, bias=True)
            cur_fc_layer_2 = nn.Linear(in_features=self.fc_size, out_features=self.hidden_size, bias=True)
            self.fc_layers.append(nn.Sequential(cur_fc_layer, cur_fc_layer_2))
            if self.layer_norm:
                self.fc_layer_norm_layers.append(nn.LayerNorm(self.hidden_size))
            if self.re_zero:
                cur_re_zero_var = nn.Parameter(
                    torch.zeros(size=(1,), dtype=torch.float32), requires_grad=True)
                self.fc_re_zero_vars.append(cur_re_zero_var)
            if self.dropout_rate:
                self.fc_dropout_layers.append(nn.Dropout(p=self.dropout_rate))

        if self.layer_norm:
            self.out_layer_norm = nn.LayerNorm(self.hidden_size)

    def forward(self, inputs, set_attn_to_none=False):
        # inputs: bsz x seq_length x embedding_dim
        bsz, seq_length = inputs.size(0), inputs.size(1)
        if set_attn_to_none:
            atten_mask = None
        else:
            # entries strictly above the diagonal are set to 1.0, i.e. each
            # position may only attend to itself and earlier positions
            atten_mask = np.tri(seq_length, seq_length, -1.0, dtype=np.float32).T
            atten_mask = torch.from_numpy(atten_mask).float()
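            # NOTE: nn.MultiheadAttention treats a *boolean* attn_mask as
            # "True = this (query, key) pair is not allowed to attend", so the
            # upper-triangular (future) entries become True below, giving a
            # causal mask.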
            atten_mask = atten_mask.to(inputs.device)
            atten_mask = atten_mask > 0.5  # boolean mask: True = blocked (future) position

        # (a padding-based attention mask derived from the inputs could be used
        #  here instead of the fixed triangular mask)
        x = inputs  # bsz x seq_length x hidden_size

        for i in range(self.num_layers):
            # self-attention sub-layer
            res = x.clone()
            if self.layer_norm:
                res = self.layer_norm_layers[i](res)
            res, _ = self.attention_layers[i](res, res, res, attn_mask=atten_mask)
            if self.re_zero:
                res = res * self.re_zero_vars[i]
            if self.dropout_rate:
                res = self.dropout_layers[i](res)
            x = x + res

            # feed-forward sub-layer
            res = x.clone()
            if self.layer_norm:
                res = self.fc_layer_norm_layers[i](res)
            res = self.fc_layers[i](res)
            if self.re_zero:
                res = res * self.fc_re_zero_vars[i]
            if self.dropout_rate:
                res = self.fc_dropout_layers[i](res)
            x = x + res

        if self.layer_norm:
            x = self.out_layer_norm(x)
        return x


class TransformerDecoder(nn.Module):
    def __init__(self,
                 hidden_size=256,
                 fc_size=1024,
                 num_heads=4,
                 layer_norm=True,
                 num_layers=8,
                 dropout_rate=0.2,
                 re_zero=True,
                 with_seq_context=False):
        super(TransformerDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.fc_size = fc_size
        self.num_heads = num_heads
        self.layer_norm = layer_norm
        self.num_layers = num_layers
        self.dropout_rate = dropout_rate
        self.re_zero = re_zero
        self.with_seq_context = with_seq_context
        # cached attention masks (not read by the current forward pass)
        self.atten_mask = None
        self.context_atten_mask = None

        ### Self-attention sub-layers and related modules ###
        self.attention_layers = nn.ModuleList()
        if self.layer_norm:
            self.layer_norm_layers = nn.ModuleList()
        if self.re_zero:
            self.re_zero_vars = nn.ParameterList()
        if self.dropout_rate:
            self.dropout_layers = nn.ModuleList()

        for i in range(self.num_layers):
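            # Per-layer self-attention modules, mirroring TransformerEncoder:
            # multi-head attention plus optional LayerNorm, re-zero scalar, and dropout.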
            cur_atten_layer = nn.MultiheadAttention(
                self.hidden_size, self.num_heads, dropout=0.0, bias=True,
                kdim=self.hidden_size, vdim=self.hidden_size, batch_first=True)
            self.attention_layers.append(cur_atten_layer)
            if self.layer_norm:
                self.layer_norm_layers.append(nn.LayerNorm(self.hidden_size))
            if self.re_zero:
                cur_re_zero_var = nn.Parameter(
                    torch.zeros(size=(1,), dtype=torch.float32), requires_grad=True)
                self.re_zero_vars.append(cur_re_zero_var)
            if self.dropout_rate:
                self.dropout_layers.append(nn.Dropout(p=self.dropout_rate))

        if self.with_seq_context:
            ### Cross-attention over the sequential context: attention, layer-norm,
            ### re-zero, and dropout modules for each layer ###
            self.context_attention_layers = nn.ModuleList()
            if self.layer_norm:
                self.context_norm_layers = nn.ModuleList()
            if self.re_zero:
                self.context_re_zero_vars = nn.ParameterList()
            if self.dropout_rate:
                self.context_dropout_layers = nn.ModuleList()

            for i in range(self.num_layers):
                cur_atten_layer = nn.MultiheadAttention(
                    self.hidden_size, self.num_heads, dropout=0.0, bias=True,
                    kdim=self.hidden_size, vdim=self.hidden_size, batch_first=True)
                self.context_attention_layers.append(cur_atten_layer)
                if self.layer_norm:
                    self.context_norm_layers.append(nn.LayerNorm(self.hidden_size))
                if self.re_zero:
                    cur_re_zero_var = nn.Parameter(
                        torch.zeros(size=(1,), dtype=torch.float32), requires_grad=True)
                    self.context_re_zero_vars.append(cur_re_zero_var)
                if self.dropout_rate:
                    self.context_dropout_layers.append(nn.Dropout(p=self.dropout_rate))

        ### Feed-forward sub-layers and related modules ###
        self.fc_layers = nn.ModuleList()
        if self.layer_norm:
            self.fc_layer_norm_layers = nn.ModuleList()
        if self.re_zero:
            self.fc_re_zero_vars = nn.ParameterList()
        if self.dropout_rate:
            self.fc_dropout_layers = nn.ModuleList()

        for i in range(self.num_layers):
            cur_fc_layer = nn.Linear(in_features=self.hidden_size, out_features=self.fc_size, bias=True)
            cur_fc_layer_2 = nn.Linear(in_features=self.fc_size, out_features=self.hidden_size, bias=True)
            self.fc_layers.append(nn.Sequential(cur_fc_layer, cur_fc_layer_2))
            if self.layer_norm:
                self.fc_layer_norm_layers.append(nn.LayerNorm(self.hidden_size))
            if self.re_zero:
                cur_re_zero_var = nn.Parameter(
                    torch.zeros(size=(1,), dtype=torch.float32), requires_grad=True)
                self.fc_re_zero_vars.append(cur_re_zero_var)
            if self.dropout_rate:
                self.fc_dropout_layers.append(nn.Dropout(p=self.dropout_rate))

        if self.layer_norm:
            self.out_layer_norm = nn.LayerNorm(self.hidden_size)

    def forward(self, inputs, sequential_context_embeddings=None):
        # inputs: bsz x seq_length x hidden_size
        # sequential_context_embeddings (optional): bsz x context_length x hidden_size
        bsz, seq_length = inputs.size(0), inputs.size(1)

        # self-attention mask: None means every position attends to the full
        # sequence; a causal mask (built as in TransformerEncoder.forward), a
        # local context-window mask, or prefix-key masking could be used instead
        atten_mask = None

        if sequential_context_embeddings is not None:
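            # Cross-attention mask of shape (seq_length, context_length):
            # nn.MultiheadAttention expects attn_mask as (target_len, source_len),
            # with True marking a blocked (query, key) pair, so each decoder
            # position only sees context steps up to its own index.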
            sequential_context_mask = np.tri(
                sequential_context_embeddings.size(1), seq_length, -1.0, dtype=np.float32).T
            sequential_context_mask = torch.from_numpy(sequential_context_mask).float()
            sequential_context_mask = sequential_context_mask.to(inputs.device)
            sequential_context_mask = sequential_context_mask > 0.5
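
        # Re-zero residual scaling: each residual branch is multiplied by a
        # learned scalar initialised at zero, so every layer starts out as the
        # identity mapping and its contribution is learned during training.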
        x = inputs

        for i in range(self.num_layers):
            # self-attention sub-layer
            res = x.clone()
            if self.layer_norm:
                res = self.layer_norm_layers[i](res)
            res, _ = self.attention_layers[i](res, res, res, attn_mask=atten_mask)
            if self.re_zero:
                res = res * self.re_zero_vars[i].unsqueeze(0).unsqueeze(0)
            if self.dropout_rate:
                res = self.dropout_layers[i](res)
            x = x + res

            # cross-attention over the sequential context embeddings, if provided
            if sequential_context_embeddings is not None:
                res = x.clone()
                if self.layer_norm:
                    res = self.context_norm_layers[i](res)
                res, _ = self.context_attention_layers[i](
                    res, sequential_context_embeddings, sequential_context_embeddings,
                    attn_mask=sequential_context_mask)
                if self.re_zero:
                    res = res * self.context_re_zero_vars[i].unsqueeze(0).unsqueeze(0)
                if self.dropout_rate:
                    res = self.context_dropout_layers[i](res)
                x = x + res

            # feed-forward sub-layer
            res = x.clone()
            if self.layer_norm:
                res = self.fc_layer_norm_layers[i](res)
            res = self.fc_layers[i](res)
            if self.re_zero:
                res = res * self.fc_re_zero_vars[i]
            if self.dropout_rate:
                res = self.fc_dropout_layers[i](res)
            x = x + res

        if self.layer_norm:
            x = self.out_layer_norm(x)
        return x
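

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustrative assumption, not part of the original
# training pipeline): the batch size, sequence lengths, and layer counts below
# are arbitrary and only exercise the two modules end to end.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    bsz, ctx_len, tgt_len, hidden_size = 2, 16, 10, 256

    encoder = TransformerEncoder(hidden_size=hidden_size, num_layers=2)
    decoder = TransformerDecoder(hidden_size=hidden_size, num_layers=2, with_seq_context=True)

    context = torch.randn(bsz, ctx_len, hidden_size)   # already-embedded context tokens
    targets = torch.randn(bsz, tgt_len, hidden_size)   # already-embedded target tokens

    enc_out = encoder(context)                               # causal self-attention (default)
    enc_out_full = encoder(context, set_attn_to_none=True)   # full self-attention

    # decoder self-attends over `targets` and cross-attends to the encoder output
    dec_out = decoder(targets, sequential_context_embeddings=enc_out)

    print(enc_out.shape, enc_out_full.shape, dec_out.shape)
    # torch.Size([2, 16, 256]) torch.Size([2, 16, 256]) torch.Size([2, 10, 256])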