Spaces:
Running
on
Zero
Running
on
Zero
from einops import rearrange | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
class Generator(nn.Module):
    """Attention-pooled regression head over a sequence of SSL features.

    The input sequence goes through a position-wise feed-forward network,
    then through a stack of self-attention layers. The pooled vector is
    formed from two parts:

    - the attention output, mean-pooled over time;
    - the feed-forward output, max-pooled over time.

    The two pooled vectors are concatenated, passed through dropout, and
    linearly projected to ``num_classes`` values. Those values are squashed
    into the open interval (1, 5).

    Args:
        in_features: Feature size ``D`` of the input. It is also the
            attention embedding size, so it must be divisible by
            ``num_heads``.
        ffd_hidden_size: Hidden width of the feed-forward network.
        num_classes: Number of output values per batch item.
        attn_layer_num: Number of stacked ``nn.MultiheadAttention`` layers.
            With 0 layers, mean-pooling is applied to the FFD output.
        num_heads: Attention heads per layer. Defaults to 8, the value
            previously hard-coded.
        dropout: Dropout probability. It is used both inside the attention
            layers and on the pooled features. Defaults to 0.2, the value
            previously hard-coded.
    """

    def __init__(
        self,
        in_features,
        ffd_hidden_size,
        num_classes,
        attn_layer_num,
        num_heads=8,
        dropout=0.2,
    ):
        super().__init__()
        # Submodules are created in the same order as the original, so
        # seeded parameter initialisation and state_dict keys are unchanged.
        self.attn = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    embed_dim=in_features,
                    num_heads=num_heads,
                    dropout=dropout,
                    batch_first=True,
                )
                for _ in range(attn_layer_num)
            ]
        )
        self.ffd = nn.Sequential(
            nn.Linear(in_features, ffd_hidden_size),
            nn.ReLU(),
            nn.Linear(ffd_hidden_size, in_features),
        )
        self.dropout = nn.Dropout(dropout)
        # The input is two pooled D-dim vectors concatenated together.
        self.fc = nn.Linear(in_features * 2, num_classes)
        self.proj = nn.Tanh()

    def forward(self, ssl_feature, judge_id=None):
        """Score a batch of feature sequences.

        Args:
            ssl_feature: Tensor of shape [B, T, D] with ``D == in_features``.
            judge_id: Accepted for interface compatibility; currently ignored.

        Returns:
            Tensor of shape [B, num_classes]. Values lie in (1, 5).
            NOTE(review): this looks like a 1-5 rating scale (e.g. MOS);
            confirm with callers.

        Raises:
            ValueError: If ``ssl_feature`` is not 3-dimensional.
        """
        if ssl_feature.dim() != 3:
            raise ValueError(
                "ssl_feature must be 3-D [B, T, D], got shape "
                f"{tuple(ssl_feature.shape)}"
            )
        ssl_feature = self.ffd(ssl_feature)  # [B, T, D]

        # Stacked self-attention with no residual connection or normalisation:
        # each layer consumes the previous layer's output directly.
        attn_out = ssl_feature
        for attn in self.attn:
            attn_out, _ = attn(attn_out, attn_out, attn_out)

        # NOTE(review): mean-pooling uses the attention output, while
        # max-pooling uses the pre-attention FFD output. The original
        # behaviour is preserved; confirm this asymmetry is intended.
        pooled = torch.cat(
            [attn_out.mean(dim=1), torch.max(ssl_feature, dim=1).values],
            dim=1,
        )  # [B, 2D]
        pooled = self.dropout(pooled)
        x = self.fc(pooled)  # [B, num_classes]
        # tanh lies in (-1, 1); * 2 + 3 maps that to (1, 5).
        return self.proj(x) * 2.0 + 3