jbetker committed
Commit f01c9a2 · Parent: 24a5b84

AND OTHER DEPS

Files changed (2):
  1. models/clvp.py (+1, -1)
  2. models/xtransformers.py (+0, -47)
models/clvp.py CHANGED
@@ -2,10 +2,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import einsum
-from x_transformers import Encoder
 
 from models.arch_util import CheckpointedXTransformerEncoder
 from models.transformer import Transformer
+from models.xtransformers import Encoder
 
 
 def exists(val):
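
Net effect of the clvp.py hunk: Encoder is no longer pulled from the external x_transformers package but from the repo's vendored models/xtransformers.py. A minimal sketch of the swapped-in import, assuming the vendored copy keeps the upstream x_transformers keyword API; the dim/depth/heads values and the dummy tensor shape below are hypothetical and not taken from the commit:

import torch
from models.xtransformers import Encoder

# Hypothetical hyperparameters; the vendored Encoder is assumed to accept the
# same dim/depth/heads keywords as the upstream x_transformers Encoder.
encoder = Encoder(dim=512, depth=6, heads=8)

x = torch.randn(1, 100, 512)  # (batch, sequence, dim) dummy features
y = encoder(x)                # same-shape output from the attention stack
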
models/xtransformers.py CHANGED
@@ -1253,50 +1253,3 @@ class ContinuousTransformerWrapper(nn.Module):
             return tuple(res)
         return res[0]
 
-
-class XTransformer(nn.Module):
-    def __init__(
-            self,
-            *,
-            dim,
-            tie_token_emb=False,
-            **kwargs
-    ):
-        super().__init__()
-        enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
-        dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
-
-        assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'
-        enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
-        enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0)
-        enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
-        enc_transformer_kwargs['use_pos_emb'] = enc_kwargs.pop('use_pos_emb', True)
-
-        dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
-        dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)
-        dec_transformer_kwargs['use_pos_emb'] = dec_kwargs.pop('use_pos_emb', True)
-
-        self.encoder = TransformerWrapper(
-            **enc_transformer_kwargs,
-            attn_layers=Encoder(dim=dim, **enc_kwargs)
-        )
-
-        self.decoder = TransformerWrapper(
-            **dec_transformer_kwargs,
-            attn_layers=Decoder(dim=dim, cross_attend=True, **dec_kwargs)
-        )
-
-        if tie_token_emb:
-            self.decoder.token_emb = self.encoder.token_emb
-
-        self.decoder = AutoregressiveWrapper(self.decoder)
-
-    @torch.no_grad()
-    def generate(self, seq_in, seq_out_start, seq_len, src_mask=None, src_attn_mask=None, **kwargs):
-        encodings = self.encoder(seq_in, mask=src_mask, attn_mask=src_attn_mask, return_embeddings=True)
-        return self.decoder.generate(seq_out_start, seq_len, context=encodings, context_mask=src_mask, **kwargs)
-
-    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_attn_mask=None):
-        enc = self.encoder(src, mask=src_mask, attn_mask=src_attn_mask, return_embeddings=True)
-        out = self.decoder(tgt, context=enc, mask=tgt_mask, context_mask=src_mask)
-        return out