File size: 3,589 Bytes
2e9cf56 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import torch,math
from torch import nn, Tensor
from pretrain.track.transformers import Transformer
from einops import rearrange,repeat
from pretrain.track.layers import CNN,Enformer,AttentionPool
from einops.layers.torch import Rearrange
import numpy as np
import torch.nn.functional as F
class Tranmodel(nn.Module):
    """Backbone -> 1x1 convolution -> transformer pipeline.

    Args:
        backbone: Feature extractor module. Its ``num_channels`` attribute
            is read to size the input side of the projection.
        transfomer: Sequence module (parameter name kept as spelled for
            keyword-argument compatibility). Its ``d_model`` attribute is
            read to size the output side of the projection.
    """

    def __init__(self, backbone, transfomer):
        super().__init__()
        self.backbone = backbone
        self.transformer = transfomer
        # Kernel-size-1 convolution: only remaps the channel count from the
        # backbone's width to the transformer's model width.
        self.input_proj = nn.Conv1d(
            in_channels=backbone.num_channels,
            out_channels=transfomer.d_model,
            kernel_size=1,
        )

    def forward(self, input):
        """Run the full pipeline on a 4-D input.

        Args:
            input: Tensor laid out as ``(b, n, c, l)``; the first two axes
                are merged so the backbone sees ``(b * n, c, l)``.

        Returns:
            Whatever ``self.transformer`` returns for the projected features.
        """
        merged = rearrange(input, 'b n c l -> (b n) c l')
        features = self.backbone(merged)
        projected = self.input_proj(features)
        return self.transformer(projected)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    The encoding table is precomputed once with shape ``(max_len, 1, d_model)``
    and stored as the buffer ``pe``: even feature indices hold sines and odd
    feature indices hold cosines of ``position * div_term``.

    Args:
        d_model: Size of the feature (last) dimension. Both even and odd
            values are supported.
        dropout: Dropout probability applied to the sum.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd-indexed column than
        # even-indexed ones, so trim the cosine term to fit. For even d_model
        # the slice is a no-op and the table is unchanged.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``dropout(x + pe[:x.size(0)])``.

        Args:
            x: Tensor whose first axis is the position axis and whose last
                axis has size ``d_model``; the table's size-1 middle axis
                broadcasts over the second axis of ``x``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
class finetunemodel(nn.Module):
    """Fine-tuning wrapper: pretrained model -> pooling -> conv projection -> transformer -> linear head.

    Args:
        pretrain_model: Module applied first to the raw input.
        hidden_dim: Channel count of the pooled features fed to the
            convolutional projection.
        embed_dim: Channel count produced by the projection and consumed by
            the transformer and the prediction head.
        bins: Number of bins per sample; used to unfold the merged
            ``(batch * bins)`` axis of the pooled features.
        crop: Number of positions dropped from each end of the sequence
            before the prediction head. ``0`` keeps every position.
        num_class: Output size of the linear prediction head.
        return_embed: If true, ``forward`` returns the (uncropped)
            transformer output instead of the predictions.
    """

    def __init__(self,pretrain_model,hidden_dim,embed_dim,bins,crop=50,num_class=245,return_embed=True):
        super().__init__()
        self.pretrain_model = pretrain_model
        self.bins=bins
        self.crop=crop
        self.return_embed = return_embed
        self.attention_pool = AttentionPool(hidden_dim)
        self.project=nn.Sequential(
            # Pooled features arrive as (batch * bins, channels); unfold them
            # into (batch, channels, bins) so Conv1d runs along the bin axis.
            Rearrange('(b n) c -> b c n', n=bins),
            # groups=hidden_dim: one independent filter per channel.
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=15, padding=7,groups=hidden_dim),
            nn.InstanceNorm1d(hidden_dim, affine=True),
            nn.Conv1d(hidden_dim, embed_dim, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv1d(embed_dim, embed_dim, kernel_size=9, padding=4),
            nn.InstanceNorm1d(embed_dim, affine=True),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        )
        self.transformer = Enformer(dim=embed_dim, depth=4, heads=8)
        self.prediction_head=nn.Linear(embed_dim,num_class)

    def forward(self, x):
        """Return embeddings or cropped predictions, depending on ``return_embed``.

        Args:
            x: Input forwarded unchanged to ``pretrain_model``.

        Returns:
            If ``return_embed`` is true, the transformer output laid out as
            ``(batch, bins, channels)``. Otherwise the prediction head's
            output on the sequence with ``crop`` positions removed from each
            end.
        """
        x = self.pretrain_model(x)
        x = self.attention_pool(x)
        x = self.project(x)
        # (batch, channels, bins) -> (batch, bins, channels)
        x = x.permute(0, 2, 1)
        x = self.transformer(x)
        if self.return_embed:
            # The head's output is not needed here, so it is not computed.
            return x
        # Explicit end index: a plain ``crop:-crop`` slice is empty when
        # crop == 0, which would silently yield zero-length predictions.
        end = x.shape[1] - self.crop
        return self.prediction_head(x[:, self.crop:end, :])
def build_backbone():
    """Create the backbone: a ``CNN`` constructed with no arguments."""
    return CNN()
def build_transformer(args):
    """Create a ``Transformer`` configured from ``args``.

    Args:
        args: Object providing ``hidden_dim``, ``dropout``, ``nheads``,
            ``dim_feedforward``, ``enc_layers`` and ``dec_layers``.

    Returns:
        The constructed ``Transformer``.
    """
    # Map the option names on ``args`` to the constructor's keyword names.
    settings = {
        'd_model': args.hidden_dim,
        'dropout': args.dropout,
        'nhead': args.nheads,
        'dim_feedforward': args.dim_feedforward,
        'num_encoder_layers': args.enc_layers,
        'num_decoder_layers': args.dec_layers,
    }
    return Transformer(**settings)
def build_track_model(args):
    """Assemble the full fine-tuning model from ``args``.

    Builds the backbone and transformer, wraps them in ``Tranmodel``, and
    places that inside ``finetunemodel``.

    Args:
        args: Object providing the transformer options read by
            ``build_transformer`` plus ``hidden_dim``, ``embed_dim``,
            ``bins``, ``crop`` and ``return_embed``.

    Returns:
        The constructed ``finetunemodel``.
    """
    pretrained = Tranmodel(
        backbone=build_backbone(),
        transfomer=build_transformer(args),
    )
    return finetunemodel(
        pretrained,
        hidden_dim=args.hidden_dim,
        embed_dim=args.embed_dim,
        bins=args.bins,
        crop=args.crop,
        return_embed=args.return_embed,
    )
|