AND OTHER DEPS
- models/clvp.py  +1 -1
- models/xtransformers.py  +0 -47
models/clvp.py
CHANGED
@@ -2,10 +2,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import einsum
-from x_transformers import Encoder
 
 from models.arch_util import CheckpointedXTransformerEncoder
 from models.transformer import Transformer
+from models.xtransformers import Encoder
 
 
 def exists(val):
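The edit above is a drop-in import swap: clvp.py now takes Encoder from the vendored copy in models/xtransformers.py instead of the external x_transformers package. A minimal sketch of what that import gives you, assuming the vendored Encoder keeps the upstream dim/depth/heads keyword interface (the values below are placeholders, not the actual CLVP configuration):

# Sketch only: the import path changes, the Encoder API does not.
# dim/depth/heads are illustrative placeholders, not the real CLVP settings.
from models.xtransformers import Encoder   # previously: from x_transformers import Encoder

encoder = Encoder(
    dim=512,    # model width (placeholder)
    depth=6,    # number of attention blocks (placeholder)
    heads=8,    # attention heads per block (placeholder)
)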
models/xtransformers.py
CHANGED
@@ -1253,50 +1253,3 @@ class ContinuousTransformerWrapper(nn.Module):
             return tuple(res)
         return res[0]
 
-
-class XTransformer(nn.Module):
-    def __init__(
-            self,
-            *,
-            dim,
-            tie_token_emb=False,
-            **kwargs
-    ):
-        super().__init__()
-        enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
-        dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
-
-        assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'
-        enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
-        enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0)
-        enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
-        enc_transformer_kwargs['use_pos_emb'] = enc_kwargs.pop('use_pos_emb', True)
-
-        dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
-        dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)
-        dec_transformer_kwargs['use_pos_emb'] = dec_kwargs.pop('use_pos_emb', True)
-
-        self.encoder = TransformerWrapper(
-            **enc_transformer_kwargs,
-            attn_layers=Encoder(dim=dim, **enc_kwargs)
-        )
-
-        self.decoder = TransformerWrapper(
-            **dec_transformer_kwargs,
-            attn_layers=Decoder(dim=dim, cross_attend=True, **dec_kwargs)
-        )
-
-        if tie_token_emb:
-            self.decoder.token_emb = self.encoder.token_emb
-
-        self.decoder = AutoregressiveWrapper(self.decoder)
-
-    @torch.no_grad()
-    def generate(self, seq_in, seq_out_start, seq_len, src_mask=None, src_attn_mask=None, **kwargs):
-        encodings = self.encoder(seq_in, mask=src_mask, attn_mask=src_attn_mask, return_embeddings=True)
-        return self.decoder.generate(seq_out_start, seq_len, context=encodings, context_mask=src_mask, **kwargs)
-
-    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_attn_mask=None):
-        enc = self.encoder(src, mask=src_mask, attn_mask=src_attn_mask, return_embeddings=True)
-        out = self.decoder(tgt, context=enc, mask=tgt_mask, context_mask=src_mask)
-        return out
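The only change to the vendored file is the deletion of the XTransformer encoder-decoder convenience wrapper; models/clvp.py imports Encoder, not XTransformer, so the import swap above is unaffected. A small post-change sanity check one could run (a sketch, not part of the commit; it assumes the repository root is on PYTHONPATH):

# Sketch of a post-change sanity check, not part of the commit.
import models.xtransformers as xt

# clvp.py depends only on Encoder from the vendored module...
assert hasattr(xt, 'Encoder')
# ...while the encoder-decoder wrapper is gone after this diff.
assert not hasattr(xt, 'XTransformer')
print('vendored x-transformers module matches the diff')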