import sys
import os

THIS_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.abspath(os.path.join(THIS_DIR, os.pardir))
sys.path.append(ROOT_DIR)
sys.path.append(THIS_DIR)

import torch
import torch.nn as nn
import torch.nn.functional as F
import uuid
import numpy as np
from functools import partial
from torch.nn import TransformerEncoder, TransformerEncoderLayer, Transformer
#from models.x_transformers import ContinuousTransformerWrapper, Decoder, Encoder, AutoregressiveWrapper
from x_transformers import ContinuousTransformerWrapper, Decoder, Encoder, AutoregressiveWrapper

# Keyword-argument helpers for splitting kwargs by prefix.
def pick_and_pop(keys, d):
    # Remove `keys` from dict `d` (in place) and return them as a new dict.
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))


def groupby_prefix_and_trim(prefix, d):
    # Split `d` into (kwargs that started with `prefix`, with the prefix stripped;
    # all remaining kwargs).
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs


def group_dict_by_key(cond, d):
    # Partition `d` into (entries whose key satisfies `cond`, the rest).
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)


def string_begins_with(prefix, s):
    return s.startswith(prefix)
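
# Illustrative example (added for clarity, not part of the original code):
# pick_and_pop removes the listed keys from a dict in place and returns them as a
# new dict; groupby_prefix_and_trim splits kwargs by prefix and strips it, e.g.
#   groupby_prefix_and_trim('enc_', {'enc_dim': 512, 'dec_dim': 256})
#   -> ({'dim': 512}, {'dec_dim': 256})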

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000, device=None):
        super(PositionalEncoding, self).__init__()
        self.device = device
        self.dropout = nn.Dropout(p=dropout)
        self.lpe = nn.Embedding(max_len + 1, d_model)
        # self.weight = None
        self.indices = torch.arange(max_len).unsqueeze(1) + 1
        if device is not None:
            self.indices = self.indices.to(self.device)

    def init_weights(self):
        initrange = 0.1
        self.lpe.weight.data.uniform_(-initrange, initrange)

    def forward(self, x, indices=None):
        # np.save(str(uuid.uuid4()) + ".np", self.lpe.weight.data.cpu().numpy())  # debug dump of the positional table, disabled
        if indices is None:
            indices = self.indices[:x.size(0), :]
        x = x + self.lpe(indices)
        return self.dropout(x)

class LearnedPositionalEncoding(nn.Module):
    # NOTE: the table is registered as a buffer, so despite the name it is a fixed
    # (all-zero) additive encoding rather than a learned parameter.
    def __init__(self, d_model, input_length, dropout=0.1, device=None):
        super(LearnedPositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(input_length, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # print(x.shape)
        x = x + self.pe
        return self.dropout(x)

class BasicTransformerModel(nn.Module):
    def __init__(self, dout, dinp, nhead, dhid, nlayers, dropout=0.5, ntokens=0, device=None,
                 use_pos_emb=False, use_rel_pos_emb=False, input_length=0,
                 use_x_transformers=False, opt=None, discrete_inputs=False):
        super(BasicTransformerModel, self).__init__()
        self.device = device
        self.model_type = 'Transformer'
        self.use_x_transformers = use_x_transformers
        self.discrete_inputs = discrete_inputs
        if not use_x_transformers:
            if discrete_inputs:
                assert ntokens != 0  # ntokens must be set if we use an embedding layer (discrete inputs)
                self.encoder = nn.Embedding(ntokens, dinp)
            self.encoder1 = nn.Linear(dinp, dhid)
            if use_pos_emb:
                self.pos_encoder = LearnedPositionalEncoding(dhid, input_length, dropout)
            encoder_layers = TransformerEncoderLayer(dhid, nhead, dhid, dropout)
            self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
            self.dinp = dinp
            self.dhid = dhid
            self.decoder = nn.Linear(dhid, dout)
            self.use_pos_emb = use_pos_emb
            self.use_rel_pos_emb = use_rel_pos_emb
            self.input_length = input_length
            if use_rel_pos_emb and input_length > 1:
                # assert input_length > 0
                self.pos_emb = nn.Parameter(torch.zeros(input_length, input_length))
                # self.pos_emb = nn.Parameter(torch.eye(input_length, input_length))
                # self.pos_emb = nn.Parameter(torch.randn(input_length, input_length) / np.sqrt(dinp))
            self.init_weights()
            # self.pos_encoder.init_weights()
        else:
            if discrete_inputs:
                assert ntokens != 0  # ntokens must be set if we use an embedding layer (discrete inputs)
                self.encoder = nn.Embedding(ntokens, dinp)
            self.model = ContinuousTransformerWrapper(
                dim_in=dinp,
                dim_out=dout,
                max_seq_len=1024,
                use_pos_emb=use_pos_emb,
                attn_layers=Encoder(
                    dim=dhid,
                    depth=nlayers,
                    heads=nhead,
                    rotary_pos_emb=opt.use_rotary_pos_emb,
                    # rel_pos_bias = True
                )
            )

    def generate_square_subsequent_mask(self, sz, prefix_length=1):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask[:, :prefix_length] = 1
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
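
    # Illustrative example (added for clarity): for sz=3 and prefix_length=1 the
    # mask above is
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # i.e. an additive causal mask in which every position may also attend to the
    # prefix positions.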

    def init_weights(self):
        initrange = 0.1
        # self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask=None):
        # print(src.shape)
        if self.discrete_inputs:
            src = self.encoder(src.squeeze(0))
        if not self.use_x_transformers:
            src = self.encoder1(src)
            # import pdb; pdb.set_trace()
            # src *= math.sqrt(self.dhid)
            if self.use_pos_emb:
                src = self.pos_encoder(src)
            # src /= math.sqrt(self.dhid)
            # print(src)
            # print(torch.mm(src[:, 0, :], src[:, 0, :].T))
            if self.use_rel_pos_emb and self.input_length > 1:
                # print(self.pos_emb)
                # src_mask += self.pos_emb
                if src_mask is not None:
                    output = self.transformer_encoder(src, src_mask + self.pos_emb)
                else:
                    output = self.transformer_encoder(src, self.pos_emb)
            else:
                if src_mask is not None:
                    output = self.transformer_encoder(src, src_mask)
                else:
                    output = self.transformer_encoder(src)
            output = self.decoder(output)
            # print(output.shape)
            if self.discrete_inputs:
                # return output.unsqueeze(0)
                return output.permute(1, 0, 2)
            else:
                return output
        else:
            assert src_mask is None
            src = src.permute(1, 0, 2)
            mask = torch.ones(src.shape[0], src.shape[1], device=src.device).bool()
            output = self.model(src, mask=mask)
            # output = self.model(src.permute(1, 0, 2))
            return output.permute(1, 0, 2)
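
# Illustrative usage sketch (added; the values below are arbitrary and not part of
# the original code). With the default vanilla-PyTorch path the model consumes
# continuous features shaped (seq_len, batch, dinp) and returns (seq_len, batch, dout);
# a causal mask can optionally be supplied via src_mask:
#   model = BasicTransformerModel(dout=4, dinp=8, nhead=2, dhid=32, nlayers=2,
#                                 dropout=0.1, use_pos_emb=True, input_length=16)
#   feats = torch.randn(16, 2, 8)                       # (seq_len, batch, dinp)
#   causal = model.generate_square_subsequent_mask(16)
#   out = model(feats, src_mask=causal)                 # -> (16, 2, 4)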

class BasicTransformerModelCausal(nn.Module):
    def __init__(self, dout, dinp, nhead, dhid, nlayers, dropout=0.5, ntokens=0, device=None,
                 use_pos_emb=False, use_rel_pos_emb=False, input_length=0,
                 use_x_transformers=False, opt=None, discrete_inputs=False):
        super(BasicTransformerModelCausal, self).__init__()
        self.model = BasicTransformerModel(dout, dinp, nhead, dhid, nlayers, dropout=dropout,
                                           ntokens=ntokens, device=device, use_pos_emb=use_pos_emb,
                                           use_rel_pos_emb=use_rel_pos_emb, input_length=input_length,
                                           use_x_transformers=use_x_transformers, opt=opt,
                                           discrete_inputs=discrete_inputs)
        self.register_buffer('mask', self.model.generate_square_subsequent_mask(input_length))

    def init_weights(self):
        self.model.init_weights()

    def generate_square_subsequent_mask(self, sz, prefix_length=1):
        return self.model.generate_square_subsequent_mask(sz, prefix_length)

    def forward(self, src, src_mask=None):
        # the fixed causal mask built in __init__ always takes precedence
        return self.model(src, src_mask=self.mask)

class EncDecTransformerModel(nn.Module):
    def __init__(self, dout, src_d, tgt_d, nhead, dhid, nlayers, dropout=0.5, device=None,
                 use_pos_emb=False, src_length=0, tgt_length=0, use_x_transformers=False, opt=None):
        super(EncDecTransformerModel, self).__init__()
        self.device = device
        self.model_type = 'Transformer'
        self.use_x_transformers = use_x_transformers
        self.encoder1 = nn.Linear(src_d, dhid)
        self.encoder2 = nn.Linear(tgt_d, dhid)
        if not use_x_transformers:
            self.transformer = Transformer(d_model=dhid, nhead=nhead, num_encoder_layers=nlayers,
                                           num_decoder_layers=nlayers, dropout=0, activation="relu")
            # enc_layer = TransformerEncoderLayer(d_model=dhid, nhead=nhead, dropout=0, activation="relu")
            # self.transformerEnc = TransformerEncoder(enc_layer, nlayers)
        else:
            self.transformer = EncDecXTransformer(enc_dim_in=src_d, enc_dim_out=tgt_d, dec_dim_in=tgt_d,
                                                  dec_dim_out=dout, enc_dim=dhid, dec_dim=dhid,
                                                  enc_heads=nhead, dec_heads=nhead, enc_depth=nlayers,
                                                  dec_depth=nlayers, enc_dropout=dropout, dec_dropout=dropout,
                                                  enc_max_seq_len=1024, dec_max_seq_len=1024)
            # xdecoder = Decoder(dim=dhid, depth=nlayers, heads=nhead, cross_attend=True)
            # self.transformer = Transformer(d_model=dhid, nhead=nhead, num_encoder_layers=nlayers, num_decoder_layers=nlayers, dropout=dropout, activation="gelu", custom_decoder=xdecoder)
            # self.transformer = Transformer(d_model=dhid, nhead=nhead, num_encoder_layers=nlayers, num_decoder_layers=nlayers)
        # self.encoder = nn.Embedding(ntoken, dinp)
        self.src_d = src_d
        self.tgt_d = tgt_d
        self.dhid = dhid
        self.decoder = nn.Linear(dhid, dout)
        self.use_pos_emb = use_pos_emb
        if use_pos_emb:
            assert src_length > 0
            assert tgt_length > 0
            self.src_pos_emb = nn.Parameter(torch.zeros(src_length, src_length))
            self.tgt_pos_emb = nn.Parameter(torch.zeros(tgt_length, tgt_length))
        # forward() always needs the causal target mask on the vanilla-PyTorch path,
        # so register it regardless of use_pos_emb
        if not use_x_transformers:
            assert tgt_length > 0
            tgt_mask = self.generate_square_subsequent_mask(tgt_length)
        else:
            tgt_mask = self.generate_square_subsequent_mask_bool(tgt_length)
        self.register_buffer("tgt_mask", tgt_mask)
        # a = torch.randn(32, 3, 512)
        # b = torch.randn(32, 3, 512)
        # self.register_buffer('a', a)
        # self.register_buffer('b', b)
        self.init_weights()

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def generate_square_subsequent_mask_bool(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1).bool()
        return mask

    def init_weights(self):
        initrange = 0.1
        # self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, tgt):
        if not self.use_x_transformers:
            # import pdb; pdb.set_trace()
            src = self.encoder1(src)
            tgt = self.encoder2(tgt)
            tgt_mask = self.tgt_mask[:tgt.shape[0], :tgt.shape[0]]
            if self.use_pos_emb:
                tgt_pos_emb = self.tgt_pos_emb[:tgt.shape[0], :tgt.shape[0]]
                output = self.transformer(src=src, tgt=tgt, src_mask=self.src_pos_emb,
                                          tgt_mask=tgt_pos_emb + tgt_mask)
                # output = self.transformer(src=src, tgt=tgt)
            else:
                output = self.transformer(src=src, tgt=tgt, tgt_mask=tgt_mask)
            # output = self.transformer(src=self.a, tgt=self.b)
            # output = self.transformerEnc(src=self.a)
            output = self.decoder(output)
            return output
        else:
            # NOTE: this experimental x_transformers path projects src/tgt to dhid
            # before handing them to EncDecXTransformer, whose own input projections
            # expect src_d/tgt_d features, so the dimensions may not line up.
            src = self.encoder1(src)
            tgt = self.encoder2(tgt)
            output = self.transformer(src=src, tgt=tgt)
            # output = self.transformer(src=self.a, tgt=self.b)
            # output = self.transformer(src=tgt, tgt=tgt)  # an interesting way of doing residual attention
            output = self.decoder(output)
            return output

class EncDecXTransformer(nn.Module):
    def __init__(
        self,
        *,
        # dim,
        tie_token_emb=False,
        **kwargs
    ):
        super().__init__()
        enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
        dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
        # import pdb; pdb.set_trace()
        # assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'
        # enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
        # dim_in / dim_out / max_seq_len are ContinuousTransformerWrapper arguments,
        # not attention-layer arguments, so route them to the wrapper here
        enc_transformer_kwargs = pick_and_pop(['dim_in', 'dim_out', 'max_seq_len'], enc_kwargs)
        # enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
        # dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
        dec_transformer_kwargs = pick_and_pop(['dim_in', 'dim_out', 'max_seq_len'], dec_kwargs)
        self.encoder = ContinuousTransformerWrapper(
            **enc_transformer_kwargs,
            attn_layers=Encoder(**enc_kwargs)
        )
        self.decoder = ContinuousTransformerWrapper(
            **dec_transformer_kwargs,
            attn_layers=Decoder(cross_attend=True, **dec_kwargs)
        )
        if tie_token_emb:
            self.decoder.token_emb = self.encoder.token_emb
        # self.decoder = AutoregressiveWrapper(self.decoder)

    def generate(self, seq_in, seq_out_start, seq_len, src_mask=None):
        # NOTE: this relies on the decoder exposing .generate(), which it only does
        # when wrapped in the AutoregressiveWrapper commented out above.
        encodings = self.encoder(seq_in, return_embeddings=True, mask=src_mask)
        return self.decoder.generate(seq_out_start, seq_len, context=encodings, context_mask=src_mask)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        enc = self.encoder(src, mask=src_mask, return_embeddings=True)
        # out = self.decoder(tgt, context=enc, mask=tgt_mask, context_mask=src_mask)
        out = self.decoder(tgt, context=enc)
        return out
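

if __name__ == "__main__":
    # Minimal smoke test (added for illustration; not part of the original module).
    # Hyperparameter values are arbitrary; tensors use the (seq_len, batch, features)
    # layout expected by the vanilla PyTorch transformer path.
    causal_model = BasicTransformerModelCausal(dout=4, dinp=8, nhead=2, dhid=32, nlayers=2,
                                               dropout=0.1, use_pos_emb=False, input_length=16)
    x = torch.randn(16, 2, 8)
    print(causal_model(x).shape)  # expected: torch.Size([16, 2, 4])

    enc_dec = EncDecTransformerModel(dout=4, src_d=8, tgt_d=4, nhead=2, dhid=32, nlayers=2,
                                     dropout=0.1, use_pos_emb=True, src_length=16, tgt_length=12)
    src = torch.randn(16, 2, 8)
    tgt = torch.randn(12, 2, 4)
    print(enc_dec(src, tgt).shape)  # expected: torch.Size([12, 2, 4])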