import torch
import torch.nn as nn

from .attentions import MultiHeadAttention


class VAEMemoryBank(nn.Module):
    """Fixed-size memory bank queried via multi-head attention.

    The latent sequence z attends over a learned bank of `bank_size`
    vectors; the attended result is projected to `output_channels`.
    """

    def __init__(
        self,
        bank_size=1000,
        n_hidden_dims=512,
        n_attn_heads=2,
        init_values=None,
        output_channels=192,
    ):
        super().__init__()

        self.bank_size = bank_size
        self.n_hidden_dims = n_hidden_dims
        self.n_attn_heads = n_attn_heads

        # Cross-attention: queries come from z, keys/values from the memory bank.
        self.encoder = MultiHeadAttention(
            channels=n_hidden_dims,
            out_channels=n_hidden_dims,
            n_heads=n_attn_heads,
        )

        # Learnable bank of shape (n_hidden_dims, bank_size).
        self.memory_bank = nn.Parameter(torch.randn(n_hidden_dims, bank_size))
        self.proj = nn.Conv1d(n_hidden_dims, output_channels, 1)

        # Optionally initialize the bank from provided values instead of random noise.
        if init_values is not None:
            with torch.no_grad():
                self.memory_bank.copy_(init_values)

    def forward(self, z: torch.Tensor):
        # z: (batch, n_hidden_dims, time)
        b, _, _ = z.shape
        # Broadcast the bank across the batch and attend to it from z.
        ret = self.encoder(
            z, self.memory_bank.unsqueeze(0).repeat(b, 1, 1), attn_mask=None
        )
        # Project the attended features to the output channel count.
        ret = self.proj(ret)
        return ret
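

# A minimal usage sketch, not part of the original module: the shapes below are
# assumptions inferred from the layer definitions above (z is expected as
# (batch, n_hidden_dims, time); the output has `output_channels` channels).
# Because of the relative import, run this as a module within its package
# (e.g. `python -m <package>.<module>`), not as a standalone script.
if __name__ == "__main__":
    bank = VAEMemoryBank(bank_size=1000, n_hidden_dims=512, n_attn_heads=2)
    z = torch.randn(2, 512, 100)   # (batch, hidden dims, frames)
    out = bank(z)
    print(out.shape)               # expected: torch.Size([2, 192, 100])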