# Most of the code below is adapted from Query2Label.
import math

import torch
from torch import nn

from pretrain.layers import CNN
from pretrain.transformer import Transformer


class GroupWiseLinear(nn.Module):
    """Per-class linear head: each class has its own weight vector, so the
    layer maps (batch, num_class, hidden_dim) -> (batch, num_class)."""

    def __init__(self, num_class, hidden_dim, bias=True):
        super().__init__()
        self.num_class = num_class
        self.hidden_dim = hidden_dim
        self.bias = bias
        # One hidden_dim-sized weight vector (and optional scalar bias) per class.
        self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim))
        if bias:
            self.b = nn.Parameter(torch.Tensor(1, num_class))
        self.reset_parameters()

    def reset_parameters(self):
        # Same uniform fan-in initialization as nn.Linear.
        stdv = 1. / math.sqrt(self.W.size(2))
        self.W.data.uniform_(-stdv, stdv)
        if self.bias:
            self.b.data.uniform_(-stdv, stdv)

    def forward(self, x):
        # x: (batch, num_class, hidden_dim); dot each class embedding with
        # its own weight vector to get one logit per class.
        x = (self.W * x).sum(-1)
        if self.bias:
            x = x + self.b
        return x
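
# Shape sketch for GroupWiseLinear (an illustrative helper, not part of the
# original file; the sizes below are arbitrary):
def _demo_group_wise_linear(batch=2, num_class=3, hidden_dim=8):
    layer = GroupWiseLinear(num_class, hidden_dim)
    x = torch.randn(batch, num_class, hidden_dim)  # one embedding per class
    out = layer(x)                                 # (batch, num_class) logits
    assert out.shape == (batch, num_class)
    return out
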
class Tranmodel(nn.Module):
    """Query2Label-style model: a CNN backbone extracts sequence features,
    a transformer decodes one learned query per class, and a group-wise
    linear head turns each decoded query into a class logit."""

    def __init__(self, backbone, transformer, num_class):
        super().__init__()
        self.backbone = backbone
        self.transformer = transformer
        self.num_class = num_class
        hidden_dim = transformer.d_model
        # Indices 0..num_class-1 select one query embedding per class.
        self.label_input = torch.arange(num_class).view(1, -1)
        # 1x1 convolution projecting backbone channels to the transformer width.
        self.input_proj = nn.Conv1d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.query_embed = nn.Embedding(num_class, hidden_dim)
        self.fc = GroupWiseLinear(num_class, hidden_dim, bias=True)

    def forward(self, input):
        src = self.backbone(input)
        # One copy of the class-query indices per batch element.
        label_inputs = self.label_input.repeat(src.size(0), 1).to(input.device)
        label_embed = self.query_embed(label_inputs)
        src = self.input_proj(src)
        hs = self.transformer(src, label_embed)
        out = self.fc(hs)
        return out

def build_backbone():
    model = CNN()
    return model

def build_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
    )

def build_epd_model(args):
    backbone = build_backbone()
    transformer = build_transformer(args)
    model = Tranmodel(
        backbone=backbone,
        transformer=transformer,
        num_class=args.num_class,
    )
    return model
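
if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the original file.
    # The hyper-parameter values, num_class, and the (batch, 5, 1600) input
    # shape are assumptions; adjust them to whatever pretrain.layers.CNN
    # actually expects.
    from argparse import Namespace

    args = Namespace(hidden_dim=512, dropout=0.1, nheads=4,
                     dim_feedforward=1024, enc_layers=1, dec_layers=2,
                     num_class=245)
    model = build_epd_model(args)
    x = torch.randn(2, 5, 1600)  # (batch, channels, sequence length) -- assumed
    print(model(x).shape)        # expected: torch.Size([2, args.num_class])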