# EPCOT / cage / model.py
# Uploaded by drjieliu ("Upload 4 files", commit 890b6a3)
import os,sys
import math
from pretrain.track.model import build_track_model
import torch.nn as nn
class Downstream_cage_model(nn.Module):
    """Downstream head for CAGE signal prediction.

    Wraps a pretrained track model and maps its per-position embeddings to a
    single scalar per bin via a small MLP, cropping ``crop`` bins from each
    end of the sequence axis (commonly done to discard boundary artifacts).
    """

    def __init__(self, pretrain_model, embed_dim, crop):
        """
        Args:
            pretrain_model: module whose output is fed to the MLP; assumed to
                produce (batch, length, embed_dim) — TODO confirm against
                build_track_model's output.
            embed_dim: feature dimension of the pretrained model's output.
            crop: number of positions to trim from each end of the length axis.
        """
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )
        self.pretrain_model = pretrain_model
        self.crop = crop

    def forward(self, x):
        """Return per-bin predictions of shape (batch, length - 2*crop, 1)."""
        x = self.pretrain_model(x)
        # BUG FIX: the original slice `self.crop:-self.crop` yields an EMPTY
        # tensor when crop == 0 (because -0 == 0). Use None as the end index
        # so crop=0 keeps the full sequence length.
        end = -self.crop if self.crop else None
        out = self.mlp(x[:, self.crop:end, :])
        return out
def build_cage_model(args):
    """Build the CAGE downstream model around a pretrained track backbone.

    Expects ``args`` to carry ``embed_dim`` and ``crop`` (plus whatever
    ``build_track_model`` itself reads).
    """
    backbone = build_track_model(args)
    return Downstream_cage_model(
        pretrain_model=backbone,
        embed_dim=args.embed_dim,
        crop=args.crop,
    )
# import os,sys
# # import inspect
# # currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
# # parentdir = os.path.dirname(currentdir)
# # sys.path.insert(0, parentdir)
# from pretrain.track.layers import AttentionPool,Enformer,CNN
# from pretrain.track.transformers import Transformer
# from einops.layers.torch import Rearrange
# from einops import rearrange
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
#
#
# class Convblock(nn.Module):
# def __init__(self,in_channel,kernel_size,dilate_size,dropout=0.1):
# super().__init__()
# self.conv=nn.Sequential(
# nn.Conv2d(
# in_channel, in_channel,
# kernel_size, padding=self.pad(kernel_size, dilate_size),
# dilation=dilate_size),
# nn.GroupNorm(16, in_channel),
# nn.Dropout(dropout)
# )
# def pad(self,kernelsize, dialte_size):
# return (kernelsize - 1) * dialte_size // 2
# def symmetric(self,x):
# return (x + x.permute(0,1,3,2)) / 2
# def forward(self,x):
# identity=x
# out=self.conv(x)
# x=out+identity
# x=self.symmetric(x)
# return F.relu(x)
#
# class dilated_tower(nn.Module):
# def __init__(self,embed_dim,in_channel=48,kernel_size=9,dilate_rate=4):
# super().__init__()
# dilate_convs=[]
# for i in range(dilate_rate+1):
# dilate_convs.append(
# Convblock(in_channel,kernel_size=kernel_size,dilate_size=2**i))
#
# self.cnn=nn.Sequential(
# Rearrange('b l n d -> b d l n'),
# nn.Conv2d(embed_dim, in_channel, kernel_size=1),
# *dilate_convs,
# nn.Conv2d(in_channel, in_channel, kernel_size=1),
# Rearrange('b d l n -> b l n d'),
# )
# def forward(self,x,crop):
# x=self.cnn(x)
# x=x[:,crop:-crop,crop:-crop,:]
# return x
#
#
# class Tranmodel(nn.Module):
# def __init__(self, backbone, transfomer):
# super().__init__()
# self.backbone = backbone
# self.transformer = transfomer
# hidden_dim = transfomer.d_model
# self.input_proj = nn.Conv1d(backbone.num_channels, hidden_dim, kernel_size=1)
# def forward(self, input):
# input=rearrange(input,'b n c l -> (b n) c l')
# src = self.backbone(input)
# src=self.input_proj(src)
# src = self.transformer(src)
# return src
#
# class finetunemodel(nn.Module):
# def __init__(self, pretrain_model, hidden_dim, embed_dim, bins, crop=25):
# super().__init__()
# self.pretrain_model = pretrain_model
# self.bins = bins
# self.crop = crop
# self.attention_pool = AttentionPool(hidden_dim)
# self.project = nn.Sequential(
# Rearrange('(b n) c -> b c n', n=bins),
# nn.Conv1d(hidden_dim, hidden_dim, kernel_size=9, padding=4, groups=hidden_dim),
# nn.InstanceNorm1d(hidden_dim, affine=True),
# nn.Conv1d(hidden_dim, embed_dim, kernel_size=1),
# nn.ReLU(inplace=True),
# nn.Dropout(0.2)
# )
# self.transformer = Enformer(dim=embed_dim, depth=4, heads=6)
# self.prediction_head = nn.Sequential(
# nn.Linear(embed_dim, 1)
# )
#
#
# def forward(self, x):
# # x = rearrange(x, 'b n c l -> (b n) c l')
# x = self.pretrain_model(x)
# x = self.attention_pool(x)
# x = self.project(x)
# x = rearrange(x, 'b c n -> b n c')
# x = self.transformer(x)
# x = self.prediction_head(x[:, self.crop:-self.crop, :])
# return x
#
# def build_backbone():
# model = CNN()
# return model
# def build_transformer(args):
# return Transformer(
# d_model=args.hidden_dim,
# dropout=args.dropout,
# nhead=args.nheads,
# dim_feedforward=args.dim_feedforward,
# num_encoder_layers=args.enc_layers,
# num_decoder_layers=args.dec_layers
# )
# def build_cage_model(args):
# backbone = build_backbone()
# transformer = build_transformer(args)
# pretrain_model = Tranmodel(
# backbone=backbone,
# transfomer=transformer,
# )
#
# model=finetunemodel(pretrain_model,hidden_dim=args.hidden_dim,embed_dim=args.embed_dim,
# bins=args.bins,crop=args.crop)
# return model