import math

import torch
import torch.nn.functional as F
import numpy as np
from torch import nn, Tensor
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from pretrain.track.transformers import Transformer
from pretrain.track.layers import CNN, Enformer, AttentionPool


class Tranmodel(nn.Module):
    """Pretraining model: CNN backbone + 1x1 channel projection + transformer."""

    def __init__(self, backbone, transformer):
        super().__init__()
        self.backbone = backbone
        self.transformer = transformer
        hidden_dim = transformer.d_model
        # Project the backbone's output channels to the transformer width.
        self.input_proj = nn.Conv1d(backbone.num_channels, hidden_dim, kernel_size=1)

    def forward(self, x):
        # Fold the region dimension into the batch so each region is encoded independently.
        x = rearrange(x, 'b n c l -> (b n) c l')
        src = self.backbone(x)
        src = self.input_proj(src)
        src = self.transformer(src)
        return src


class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding (sequence-first layout)."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


class finetunemodel(nn.Module):
    """Fine-tuning head on top of the pretrained encoder.

    Pools each region's token embeddings, projects them with a small
    convolutional stack, runs an Enformer transformer over the resulting
    bin sequence, and predicts `num_class` tracks per (cropped) bin.
    """

    def __init__(self, pretrain_model, hidden_dim, embed_dim, bins, crop=50,
                 num_class=245, return_embed=True):
        super().__init__()
        self.pretrain_model = pretrain_model
        self.bins = bins
        self.crop = crop
        self.return_embed = return_embed
        self.attention_pool = AttentionPool(hidden_dim)
        self.project = nn.Sequential(
            Rearrange('(b n) c -> b c n', n=bins),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=15, padding=7, groups=hidden_dim),
            nn.InstanceNorm1d(hidden_dim, affine=True),
            nn.Conv1d(hidden_dim, embed_dim, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv1d(embed_dim, embed_dim, kernel_size=9, padding=4),
            nn.InstanceNorm1d(embed_dim, affine=True),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        )
        self.transformer = Enformer(dim=embed_dim, depth=4, heads=8)
        self.prediction_head = nn.Linear(embed_dim, num_class)

    def forward(self, x):
        x = self.pretrain_model(x)
        x = self.attention_pool(x)          # -> (b*n, hidden_dim)
        x = self.project(x)                 # -> (b, embed_dim, n)
        x = rearrange(x, 'b c n -> b n c')
        x = self.transformer(x)
        # Drop `crop` bins on each side before prediction to avoid edge effects.
        out = self.prediction_head(x[:, self.crop:-self.crop, :])
        if self.return_embed:
            return x
        return out


def build_backbone():
    return CNN()


def build_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
    )


def build_track_model(args):
    backbone = build_backbone()
    transformer = build_transformer(args)
    pretrain_model = Tranmodel(
        backbone=backbone,
        transformer=transformer,
    )
    model = finetunemodel(
        pretrain_model,
        hidden_dim=args.hidden_dim,
        embed_dim=args.embed_dim,
        bins=args.bins,
        crop=args.crop,
        return_embed=args.return_embed,
    )
    return model
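

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): shows how build_track_model wires the
# backbone, transformer, and fine-tuning head together. The hyper-parameter
# values and the dummy input shape below are assumptions for demonstration,
# not values from the original configuration; channels/length per bin must
# match whatever the CNN backbone actually expects.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    # Assumed hyper-parameters (placeholders, not the project's defaults).
    args = Namespace(
        hidden_dim=256,
        dropout=0.1,
        nheads=8,
        dim_feedforward=1024,
        enc_layers=6,
        dec_layers=6,
        embed_dim=128,
        bins=200,
        crop=50,
        return_embed=True,
    )
    model = build_track_model(args)

    # Dummy input shaped (batch, bins, channels, length); 4 channels and a
    # length of 1600 per bin are placeholders.
    dummy = torch.randn(1, args.bins, 4, 1600)
    with torch.no_grad():
        out = model(dummy)
    # With return_embed=True this is the per-bin embedding, (1, bins, embed_dim).
    print(out.shape)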