DINO-HuVITS / src /vae_memory_bank.py
SazerLife's picture
feat: added model
36a67ca
raw
history blame contribute delete
No virus
1.09 kB
import torch
import torch.nn as nn
from .attentions import MultiHeadAttention
class VAEMemoryBank(nn.Module):
    """Learnable memory bank read out through multi-head attention.

    A trainable matrix ``memory_bank`` of shape ``(n_hidden_dims, bank_size)``
    holds ``bank_size`` slot vectors of size ``n_hidden_dims``. ``forward``
    feeds the input ``z`` as the first argument of a ``MultiHeadAttention``
    encoder and the bank (broadcast over the batch) as the second argument,
    then projects the encoder output to ``output_channels`` channels with a
    kernel-size-1 ``Conv1d``.

    Args:
        bank_size: Number of slots in the memory bank.
        n_hidden_dims: Channel size of each slot, and the ``channels`` /
            ``out_channels`` of the attention encoder.
        n_attn_heads: Number of attention heads in the encoder.
        init_values: Optional initial contents for the bank. Anything accepted
            by ``torch.as_tensor`` (tensor, NumPy array, nested lists). It is
            copied into ``memory_bank`` with ``Tensor.copy_``, so it should
            have shape ``(n_hidden_dims, bank_size)``. When ``None``, the bank
            keeps its standard-normal (``torch.randn``) initialisation.
        output_channels: Number of channels produced by the output projection.
    """

    def __init__(
        self,
        bank_size=1000,
        n_hidden_dims=512,
        n_attn_heads=2,
        init_values=None,
        output_channels=192,
    ):
        super().__init__()
        self.bank_size = bank_size
        self.n_hidden_dims = n_hidden_dims
        self.n_attn_heads = n_attn_heads
        # NOTE(review): assumes the project's MultiHeadAttention follows the
        # VITS convention forward(x, c, attn_mask) with channel-first tensors,
        # x providing queries and c providing keys/values -- confirm against
        # the .attentions module.
        self.encoder = MultiHeadAttention(
            channels=n_hidden_dims,
            out_channels=n_hidden_dims,
            n_heads=n_attn_heads,
        )
        self.memory_bank = nn.Parameter(torch.randn(n_hidden_dims, bank_size))
        self.proj = nn.Conv1d(n_hidden_dims, output_channels, 1)
        if init_values is not None:
            with torch.no_grad():
                # as_tensor also accepts NumPy arrays / lists. A tensor that
                # already matches the parameter's dtype and device is used
                # as-is, so existing callers see no change.
                self.memory_bank.copy_(
                    torch.as_tensor(
                        init_values,
                        dtype=self.memory_bank.dtype,
                        device=self.memory_bank.device,
                    )
                )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Attend from ``z`` over the memory bank and project the result.

        Args:
            z: 3-D input tensor whose first dimension is the batch size.
                Presumably ``(batch, n_hidden_dims, time)`` to match the
                encoder's channel count -- TODO confirm with callers.

        Returns:
            The encoder output passed through ``self.proj`` (a 1x1 Conv1d
            mapping ``n_hidden_dims`` to ``output_channels`` channels).

        Raises:
            ValueError: If ``z`` is not 3-D (raised by the shape unpacking).
        """
        batch_size, _, _ = z.shape
        # expand() broadcasts the bank across the batch as a view instead of
        # materialising ``batch_size`` copies the way repeat() does. Values are
        # identical, and autograd still sums the per-sample gradients back into
        # the single memory_bank parameter.
        bank = self.memory_bank.unsqueeze(0).expand(batch_size, -1, -1)
        ret = self.encoder(z, bank, attn_mask=None)
        return self.proj(ret)