|
|
|
import torch,math |
|
import numpy as np |
|
from torch import nn, Tensor |
|
from pretrain.layers import CNN |
|
from pretrain.transformer import Transformer |
|
|
|
class GroupWiseLinear(nn.Module): |
|
def __init__(self, num_class, hidden_dim, bias=True): |
|
super().__init__() |
|
self.num_class = num_class |
|
self.hidden_dim = hidden_dim |
|
self.bias = bias |
|
self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim)) |
|
if bias: |
|
self.b = nn.Parameter(torch.Tensor(1, num_class)) |
|
self.reset_parameters() |
|
def reset_parameters(self): |
|
stdv = 1. / math.sqrt(self.W.size(2)) |
|
for i in range(self.num_class): |
|
self.W[0][i].data.uniform_(-stdv, stdv) |
|
if self.bias: |
|
for i in range(self.num_class): |
|
self.b[0][i].data.uniform_(-stdv, stdv) |
|
def forward(self, x): |
|
x = (self.W * x).sum(-1) |
|
if self.bias: |
|
x = x + self.b |
|
return x |
|
|
|
class Tranmodel(nn.Module):
    """Backbone + transformer head producing per-class logits.

    Pipeline: backbone features -> 1x1 conv projection to the transformer
    width -> transformer conditioned on learned per-class label queries ->
    group-wise linear head.
    """

    def __init__(self, backbone, transfomer, num_class):
        super().__init__()
        self.backbone = backbone
        self.transformer = transfomer
        self.num_class = num_class
        hidden_dim = transfomer.d_model
        # Fixed index sequence [0 .. num_class-1]; embedded in forward() to
        # form one query per class. Moved to the input's device per call.
        self.label_input = torch.Tensor(np.arange(num_class)).view(1, -1).long()
        # 1x1 conv adapts backbone channel count to the transformer width.
        self.input_proj = nn.Conv1d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.query_embed = nn.Embedding(num_class, hidden_dim)
        self.fc = GroupWiseLinear(num_class, hidden_dim, bias=True)

    def forward(self, input):
        """Run the full pipeline; returns per-class logits."""
        features = self.backbone(input)
        batch = features.size(0)
        # Tile the class-index row per sample and embed into label queries.
        label_ids = self.label_input.repeat(batch, 1).to(input.device)
        label_queries = self.query_embed(label_ids)
        projected = self.input_proj(features)
        decoded = self.transformer(projected, label_queries)
        return self.fc(decoded)
|
|
|
def build_backbone():
    """Construct the CNN feature extractor used as the model backbone."""
    return CNN()
|
def build_transformer(args):
    """Build a Transformer from parsed command-line arguments.

    Reads ``hidden_dim``, ``dropout``, ``nheads``, ``dim_feedforward``,
    ``enc_layers`` and ``dec_layers`` from ``args``.
    """
    config = dict(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
    )
    return Transformer(**config)
|
def build_epd_model(args):
    """Assemble the full model: CNN backbone + transformer + class head."""
    return Tranmodel(
        backbone=build_backbone(),
        transfomer=build_transformer(args),
        num_class=args.num_class,
    )
|
|