# SUNet.py
# Repository page metadata captured with this file:
#   author: 52Hz · commit 6e75a05 ("Create SUNet.py") · size 1.47 kB
import torch.nn as nn
from model.SUNet_detail import SUNet
class SUNet_model(nn.Module):
    """Wrapper module that builds a ``SUNet`` from a nested config mapping.

    Every architecture hyper-parameter is read from ``config['SWINUNET']``;
    the wrapped network is always constructed with 3 input channels and
    3 output channels.

    Args:
        config: mapping whose ``'SWINUNET'`` entry supplies the keys
            IMG_SIZE, PATCH_SIZE, EMB_DIM, DEPTH_EN, HEAD_NUM, WIN_SIZE,
            MLP_RATIO, QKV_BIAS, QK_SCALE, DROP_RATE, DROP_PATH_RATE, APE,
            PATCH_NORM and USE_CHECKPOINTS. The full config is kept on
            ``self.config``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        swin_cfg = config['SWINUNET']
        self.swin_unet = SUNet(
            img_size=swin_cfg['IMG_SIZE'],
            patch_size=swin_cfg['PATCH_SIZE'],
            in_chans=3,
            out_chans=3,
            embed_dim=swin_cfg['EMB_DIM'],
            depths=swin_cfg['DEPTH_EN'],
            num_heads=swin_cfg['HEAD_NUM'],
            window_size=swin_cfg['WIN_SIZE'],
            mlp_ratio=swin_cfg['MLP_RATIO'],
            qkv_bias=swin_cfg['QKV_BIAS'],
            qk_scale=swin_cfg['QK_SCALE'],
            drop_rate=swin_cfg['DROP_RATE'],
            drop_path_rate=swin_cfg['DROP_PATH_RATE'],
            ape=swin_cfg['APE'],
            patch_norm=swin_cfg['PATCH_NORM'],
            use_checkpoint=swin_cfg['USE_CHECKPOINTS'],
        )

    def forward(self, x):
        """Run the wrapped SUNet on ``x`` and return its output.

        When ``x`` has size 1 along dim 1, it is tiled to size 3 with
        ``repeat(1, 3, 1, 1)`` so it matches the network's ``in_chans=3``.
        Assumes ``x`` is 4-D (batch, channel, height, width) — TODO confirm
        against callers.
        """
        is_single_channel = x.shape[1] == 1
        if is_single_channel:
            # repeat (not expand) yields a real contiguous copy per channel.
            x = x.repeat(1, 3, 1, 1)
        return self.swin_unet(x)