# NOTE: removed non-Python extraction artifacts (file-size line, commit hash,
# and a line-number dump) that would break this module at import time.
#most of the codes below are copied from Query2label
import torch,math
import numpy as np
from torch import nn, Tensor
from pretrain.layers import CNN
from pretrain.transformer import Transformer
class GroupWiseLinear(nn.Module):
    """Per-class linear head (from Query2Label).

    Unlike a shared ``nn.Linear``, every class (label query) owns its own
    weight vector: for each class ``c`` the output is
    ``dot(W[0, c], x[:, c]) + b[0, c]``, mapping features of shape
    (batch, num_class, hidden_dim) to logits of shape (batch, num_class).
    """

    def __init__(self, num_class, hidden_dim, bias=True):
        super().__init__()
        self.num_class = num_class
        self.hidden_dim = hidden_dim
        self.bias = bias
        # (1, num_class, hidden_dim): one weight vector per class,
        # broadcast over the batch dimension in forward().
        self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim))
        if bias:
            self.b = nn.Parameter(torch.Tensor(1, num_class))
        self.reset_parameters()

    def reset_parameters(self):
        # Fan-in based bound, same convention as nn.Linear's default init.
        stdv = 1. / math.sqrt(self.W.size(2))
        # Vectorized in-place init under no_grad instead of the original
        # Python-level per-class loop over `.data` — identical distribution,
        # idiomatic, and avoids deprecated-style `.data` mutation.
        with torch.no_grad():
            self.W.uniform_(-stdv, stdv)
            if self.bias:
                self.b.uniform_(-stdv, stdv)

    def forward(self, x):
        # x: (batch, num_class, hidden_dim) -> logits (batch, num_class)
        x = (self.W * x).sum(-1)
        if self.bias:
            x = x + self.b
        return x
class Tranmodel(nn.Module):
    """Query2Label-style classifier: backbone features, transformer decoding
    of per-label queries, and a group-wise linear head producing one logit
    per class.

    NOTE: the constructor keyword is spelled ``transfomer`` (sic) — callers
    such as build_epd_model rely on that spelling, so it is kept as-is.
    """

    def __init__(self, backbone, transfomer, num_class):
        super().__init__()
        self.backbone = backbone
        self.transformer = transfomer
        self.num_class = num_class
        hidden_dim = transfomer.d_model
        # Fixed label indices 0..num_class-1 of shape (1, num_class). Kept as
        # a plain tensor (not a registered buffer); forward() moves it to the
        # input's device explicitly.
        self.label_input = torch.Tensor(np.arange(num_class)).view(1, -1).long()
        # 1x1 conv projects backbone channels down/up to the transformer width.
        self.input_proj = nn.Conv1d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.query_embed = nn.Embedding(num_class, hidden_dim)
        self.fc = GroupWiseLinear(num_class, hidden_dim, bias=True)

    def forward(self, input):
        features = self.backbone(input)
        # One copy of the label indices per batch element, on the right device.
        queries = self.label_input.repeat(features.size(0), 1).to(input.device)
        query_embeddings = self.query_embed(queries)
        projected = self.input_proj(features)
        decoded = self.transformer(projected, query_embeddings)
        return self.fc(decoded)
def build_backbone():
    """Construct the CNN feature extractor used as the model backbone."""
    return CNN()
def build_transformer(args):
    """Build the Transformer from parsed config/CLI `args`.

    Expects: hidden_dim, nheads, enc_layers, dec_layers, dim_feedforward,
    dropout.
    """
    return Transformer(
        d_model=args.hidden_dim,
        nhead=args.nheads,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        dim_feedforward=args.dim_feedforward,
        dropout=args.dropout,
    )
def build_epd_model(args):
    """Assemble the full encoder-decoder model: CNN backbone + transformer
    + group-wise classification head, sized by `args.num_class`."""
    # Note: `transfomer` (sic) matches Tranmodel's constructor keyword.
    return Tranmodel(
        backbone=build_backbone(),
        transfomer=build_transformer(args),
        num_class=args.num_class,
    )
# (removed stray "|" extraction artifact)