|
import torch |
|
import torch.nn as nn |
|
|
|
from .attentions import MultiHeadAttention |
|
|
|
|
|
class VAEMemoryBank(nn.Module):
    """Learnable memory bank queried via multi-head cross-attention.

    Holds a trainable bank of ``bank_size`` embedding vectors. An input
    ``z`` attends over the bank (queries come from ``z``, keys/values from
    the bank) and the attended result is projected to ``output_channels``
    with a pointwise ``Conv1d``.
    """

    def __init__(
        self,
        bank_size=1000,
        n_hidden_dims=512,
        n_attn_heads=2,
        init_values=None,
        output_channels=192,
    ):
        """Build the memory bank module.

        Args:
            bank_size: number of memory slots (columns of the bank).
            n_hidden_dims: channel dimension of the bank and the attention.
            n_attn_heads: number of attention heads.
            init_values: optional tensor of shape
                ``(n_hidden_dims, bank_size)`` used to overwrite the random
                bank initialization.
            output_channels: channels produced by the output projection.
        """
        super().__init__()

        self.bank_size = bank_size
        self.n_hidden_dims = n_hidden_dims
        self.n_attn_heads = n_attn_heads

        # Cross-attention: queries are the input z, keys/values are the
        # memory bank (see forward()).
        self.encoder = MultiHeadAttention(
            channels=n_hidden_dims,
            out_channels=n_hidden_dims,
            n_heads=n_attn_heads,
        )

        # Trainable memory of shape (n_hidden_dims, bank_size).
        self.memory_bank = nn.Parameter(torch.randn(n_hidden_dims, bank_size))
        self.proj = nn.Conv1d(n_hidden_dims, output_channels, 1)
        if init_values is not None:
            # Overwrite the random init in place without recording the copy
            # in autograd; the parameter itself remains trainable.
            with torch.no_grad():
                self.memory_bank.copy_(init_values)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Attend over the memory bank with ``z`` and project the result.

        Args:
            z: batched input; assumed shape (batch, n_hidden_dims, time)
               given the Conv1d channel layout — TODO confirm against
               MultiHeadAttention's expected layout.

        Returns:
            Projected attention output with ``output_channels`` channels.
        """
        b = z.size(0)
        # expand() broadcasts the bank across the batch as a zero-copy view;
        # the previous .repeat(b, 1, 1) materialized b full copies of the
        # bank on every forward pass.
        ret = self.encoder(
            z, self.memory_bank.unsqueeze(0).expand(b, -1, -1), attn_mask=None
        )
        return self.proj(ret)
|
|