import copy

import torch
from torch import nn
import torch.nn.functional as F

from motion.model.layer_norm_fp16 import LayerNorm, RMSNorm

class SwiGLU(nn.Module):
    """SwiGLU feed-forward block, following the structure used in LLaMA."""

    def __init__(self, dim, hidden_dim, multiple_of=256):
        super().__init__()
        # Scale the hidden size by 2/3 (to keep the parameter count comparable to
        # a standard MLP) and round it up to a multiple of `multiple_of`.
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        # SwiGLU(x) = W2(SiLU(W1 x) * W3 x)
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

def _get_activation_fn(activation: str):
    if activation.lower() == "relu":
        return F.relu
    elif activation.lower() == "gelu":
        return F.gelu
    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class RefinedLayer(nn.Module):
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation=F.relu, layer_norm_eps=1e-5, device=None, dtype=None,
                 max_seq_len=196, position_type="static", word_tokens=False,
                 norm_type="rmsnorm", attention_type="torch"):
        factory_kwargs = {'device': device, 'dtype': dtype, 'bias': False}
        super().__init__()
        if norm_type.lower() == "rmsnorm":
            Norm = RMSNorm
        elif norm_type.lower() == "layer":
            Norm = LayerNorm
        else:
            raise ValueError("norm_type should be rmsnorm/layer, not {}".format(norm_type))
        self.attention_type = attention_type
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout,
                                               batch_first=False, **factory_kwargs)
        # Optional cross-attention over word tokens (e.g. text conditioning).
        if word_tokens:
            self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout,
                                                    batch_first=False, **factory_kwargs)
            self.norm3 = Norm(d_model, layer_norm_eps)
            self.dropout3 = nn.Dropout(dropout)
        self.word_tokens = word_tokens

        self.norm1 = Norm(d_model, layer_norm_eps)
        self.norm2 = Norm(d_model, layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Implementation of the feed-forward model: SwiGLU when requested,
        # otherwise a two-layer MLP with the given activation (string or callable).
        if isinstance(activation, str) and activation.lower() == "swiglu":
            self.ffn = SwiGLU(d_model, dim_feedforward)
        else:
            if isinstance(activation, str):
                # Legacy string support for the activation function.
                activation = _get_activation_fn(activation)
            self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
            self.dropout = nn.Dropout(dropout)
            self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
            self.ffn = self._ff_block
        self.activation = activation
    def forward(self, src, word_tokens=None, src_mask=None, src_key_padding_mask=None):
        # Pre-norm residual blocks: self-attention, optional cross-attention, FFN.
        x = src
        x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
        if self.word_tokens:
            x = x + self._csa_block(self.norm3(x), word_tokens)
        x = x + self.dropout2(self.ffn(self.norm2(x)))
        return x

    # self-attention block
    def _sa_block(self, x, attn_mask, key_padding_mask):
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    # cross-attention block over word tokens
    def _csa_block(self, x, mem, attn_mask=None, key_padding_mask=None):
        x = self.cross_attn(x, mem, mem,
                            attn_mask=attn_mask,
                            key_padding_mask=key_padding_mask,
                            need_weights=False)[0]
        return self.dropout3(x)

    # feed-forward block
    def _ff_block(self, x):
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

class Refined_Transformer(nn.Module):
    def __init__(self, refined_layer, num_layers):
        super().__init__()
        # Stack of independent (deep-copied) RefinedLayer instances.
        self.layers = _get_clones(refined_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, src, word_tokens=None, src_mask=None, src_key_padding_mask=None):
        output = src
        for mod in self.layers:
            output = mod(output, word_tokens=word_tokens, src_mask=src_mask,
                         src_key_padding_mask=src_key_padding_mask)
        return output
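
# Minimal usage sketch (illustrative addition, not part of the original module):
# builds a small Refined_Transformer and runs one forward pass. Shapes follow the
# batch_first=False convention of nn.MultiheadAttention, i.e. (seq, batch, d_model);
# the hyperparameter values below are arbitrary examples.
if __name__ == "__main__":
    d_model, nhead, num_layers = 256, 4, 2
    layer = RefinedLayer(d_model, nhead, dim_feedforward=1024,
                         activation="swiglu", word_tokens=True)
    model = Refined_Transformer(layer, num_layers)

    seq_len, batch, num_words = 196, 8, 77
    motion = torch.randn(seq_len, batch, d_model)             # motion token sequence
    words = torch.randn(num_words, batch, d_model)            # word-token memory for cross-attention
    padding = torch.zeros(batch, seq_len, dtype=torch.bool)   # True marks padded positions

    out = model(motion, word_tokens=words, src_key_padding_mask=padding)
    print(out.shape)  # torch.Size([196, 8, 256])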