import torch
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self,
                 in_features,
                 ffd_hidden_size,
                 num_classes,
                 attn_layer_num,
                 ):
        super().__init__()
        # Stack of self-attention layers applied over the time dimension.
        self.attn = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    embed_dim=in_features,
                    num_heads=8,
                    dropout=0.2,
                    batch_first=True,
                )
                for _ in range(attn_layer_num)
            ]
        )
        # Position-wise feed-forward projection applied before attention.
        self.ffd = nn.Sequential(
            nn.Linear(in_features, ffd_hidden_size),
            nn.ReLU(),
            nn.Linear(ffd_hidden_size, in_features),
        )
        self.dropout = nn.Dropout(0.2)
        # Classifier over the concatenated mean- and max-pooled features.
        self.fc = nn.Linear(in_features * 2, num_classes)
        self.proj = nn.Tanh()

    def forward(self, ssl_feature, judge_id=None):
        '''
        ssl_feature: [B, T, D]
        output: [B, num_classes]
        '''
        # judge_id is accepted for interface compatibility but unused here.
        B, T, D = ssl_feature.shape
        ssl_feature = self.ffd(ssl_feature)
        # Refine the features with the self-attention stack.
        tmp_ssl_feature = ssl_feature
        for attn in self.attn:
            tmp_ssl_feature, _ = attn(tmp_ssl_feature, tmp_ssl_feature, tmp_ssl_feature)
        # Pool over time: mean of the attended features concatenated with the
        # max of the feed-forward features, giving [B, 2D].
        ssl_feature = self.dropout(
            torch.cat(
                [torch.mean(tmp_ssl_feature, dim=1), torch.max(ssl_feature, dim=1)[0]],
                dim=1,
            )
        )
        x = self.fc(ssl_feature)  # [B, num_classes]
        # Tanh maps to [-1, 1]; scale and shift to the [1, 5] score range.
        x = self.proj(x) * 2.0 + 3.0
        return x
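

# Minimal usage sketch (not part of the original file). The hyperparameters
# below are illustrative assumptions: in_features=768 matches a typical SSL
# feature dimension (e.g. a wav2vec 2.0 / HuBERT base model), and
# num_classes=1 assumes a single scalar score output.
if __name__ == "__main__":
    model = Generator(
        in_features=768,
        ffd_hidden_size=256,
        num_classes=1,
        attn_layer_num=2,
    )
    model.eval()  # disable dropout for deterministic inference
    dummy = torch.randn(4, 100, 768)  # [B=4, T=100, D=768] dummy SSL features
    score = model(dummy)
    print(score.shape)  # torch.Size([4, 1]); values lie in [1, 5]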