import torch
import torch.nn as nn
class Generator(nn.Module):
    def __init__(
        self,
        in_features,
        ffd_hidden_size,
        num_classes,
        attn_layer_num,
    ):
        super().__init__()
        # Stack of self-attention layers applied over the time axis.
        self.attn = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    embed_dim=in_features,
                    num_heads=8,
                    dropout=0.2,
                    batch_first=True,
                )
                for _ in range(attn_layer_num)
            ]
        )
        # Position-wise feed-forward block applied to the SSL features.
        self.ffd = nn.Sequential(
            nn.Linear(in_features, ffd_hidden_size),
            nn.ReLU(),
            nn.Linear(ffd_hidden_size, in_features),
        )
        self.dropout = nn.Dropout(0.2)
        # Classifier over the concatenated mean- and max-pooled features.
        self.fc = nn.Linear(in_features * 2, num_classes)
        self.proj = nn.Tanh()
    def forward(self, ssl_feature, judge_id=None):
        """
        ssl_feature: [B, T, D]
        output: [B, num_classes]
        """
        # judge_id is accepted for interface compatibility but unused here.
        B, T, D = ssl_feature.shape
        ssl_feature = self.ffd(ssl_feature)
        tmp_ssl_feature = ssl_feature
        for attn in self.attn:
            tmp_ssl_feature, _ = attn(tmp_ssl_feature, tmp_ssl_feature, tmp_ssl_feature)
        # Pool over time: mean of the attention output concatenated with the
        # max of the feed-forward output, giving a [B, 2D] vector.
        ssl_feature = self.dropout(
            torch.cat(
                [torch.mean(tmp_ssl_feature, dim=1), torch.max(ssl_feature, dim=1)[0]],
                dim=1,
            )
        )
        x = self.fc(ssl_feature)  # [B, num_classes]
        # Squash with tanh to (-1, 1), then rescale to the (1, 5) score range.
        x = self.proj(x) * 2.0 + 3
        return x
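

# A minimal usage sketch (not part of the original file). The dimensions below
# are hypothetical, chosen only to illustrate the expected input/output shapes.
if __name__ == "__main__":
    model = Generator(
        in_features=768,      # e.g. the hidden size of a typical SSL encoder
        ffd_hidden_size=256,  # hypothetical feed-forward width
        num_classes=1,        # a single scalar score per utterance
        attn_layer_num=2,     # hypothetical attention depth
    )
    dummy = torch.randn(4, 100, 768)  # [B, T, D] batch of SSL features
    score = model(dummy)              # -> [4, 1], values in (1, 5)
    print(score.shape, score.min().item(), score.max().item())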