import torch.nn as nn

from model.SUNet_detail import SUNet


class SUNet_model(nn.Module):
    """Thin wrapper that builds a SUNet (Swin-Transformer U-Net) from a config dict."""

    def __init__(self, config):
        super(SUNet_model, self).__init__()
        self.config = config
        # All architecture hyper-parameters come from the 'SWINUNET' section of the config.
        self.swin_unet = SUNet(img_size=config['SWINUNET']['IMG_SIZE'],
                               patch_size=config['SWINUNET']['PATCH_SIZE'],
                               in_chans=3,
                               out_chans=3,
                               embed_dim=config['SWINUNET']['EMB_DIM'],
                               depths=config['SWINUNET']['DEPTH_EN'],
                               num_heads=config['SWINUNET']['HEAD_NUM'],
                               window_size=config['SWINUNET']['WIN_SIZE'],
                               mlp_ratio=config['SWINUNET']['MLP_RATIO'],
                               qkv_bias=config['SWINUNET']['QKV_BIAS'],
                               qk_scale=config['SWINUNET']['QK_SCALE'],
                               drop_rate=config['SWINUNET']['DROP_RATE'],
                               drop_path_rate=config['SWINUNET']['DROP_PATH_RATE'],
                               ape=config['SWINUNET']['APE'],
                               patch_norm=config['SWINUNET']['PATCH_NORM'],
                               use_checkpoint=config['SWINUNET']['USE_CHECKPOINTS'])

    def forward(self, x):
        # Grayscale inputs are replicated to 3 channels to match in_chans=3.
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        logits = self.swin_unet(x)
        return logits
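

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The SWINUNET values
# below are illustrative assumptions, not the repository's shipped defaults;
# adjust them to match your own YAML config before running.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    config = {'SWINUNET': {'IMG_SIZE': 256, 'PATCH_SIZE': 4, 'EMB_DIM': 96,
                           'DEPTH_EN': [8, 8, 8, 8], 'HEAD_NUM': [8, 8, 8, 8],
                           'WIN_SIZE': 8, 'MLP_RATIO': 4.0, 'QKV_BIAS': True,
                           'QK_SCALE': 8, 'DROP_RATE': 0.0, 'DROP_PATH_RATE': 0.1,
                           'APE': False, 'PATCH_NORM': True,
                           'USE_CHECKPOINTS': False}}

    model = SUNet_model(config)
    x = torch.randn(1, 1, 256, 256)   # single-channel input
    y = model(x)                      # replicated to 3 channels inside forward()
    print(y.shape)                    # expected: torch.Size([1, 3, 256, 256])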