Spaces:
Running
Running
import torch.nn as nn | |
from model.SUNet_detail import SUNet | |
class SUNet_model(nn.Module):
    """``nn.Module`` wrapper that builds a ``SUNet`` from a config mapping.

    All SUNet hyperparameters except the channel counts are read from
    ``config['SWINUNET']``. Input and output channels are fixed at 3.

    Args:
        config: Mapping with a ``'SWINUNET'`` section providing the keys
            IMG_SIZE, PATCH_SIZE, EMB_DIM, DEPTH_EN, HEAD_NUM, WIN_SIZE,
            MLP_RATIO, QKV_BIAS, QK_SCALE, DROP_RATE, DROP_PATH_RATE, APE,
            PATCH_NORM and USE_CHECKPOINTS. The whole mapping is stored
            on the instance as ``self.config``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        net_cfg = config['SWINUNET']
        # Keyword order matches the key lookup order, so a missing key
        # raises KeyError for the same entry it always did.
        self.swin_unet = SUNet(
            img_size=net_cfg['IMG_SIZE'],
            patch_size=net_cfg['PATCH_SIZE'],
            in_chans=3,
            out_chans=3,
            embed_dim=net_cfg['EMB_DIM'],
            depths=net_cfg['DEPTH_EN'],
            num_heads=net_cfg['HEAD_NUM'],
            window_size=net_cfg['WIN_SIZE'],
            mlp_ratio=net_cfg['MLP_RATIO'],
            qkv_bias=net_cfg['QKV_BIAS'],
            qk_scale=net_cfg['QK_SCALE'],
            drop_rate=net_cfg['DROP_RATE'],
            drop_path_rate=net_cfg['DROP_PATH_RATE'],
            ape=net_cfg['APE'],
            patch_norm=net_cfg['PATCH_NORM'],
            use_checkpoint=net_cfg['USE_CHECKPOINTS'],
        )

    def forward(self, x):
        """Run the wrapped SUNet on ``x``.

        If dimension 1 of ``x`` has size 1, that slice is first tiled three
        times along dimension 1 (``repeat(1, 3, 1, 1)``) so the input matches
        the network's 3 input channels. NOTE(review): ``x`` is presumably an
        (N, C, H, W) image batch; confirm this against the callers.

        Returns:
            The output of ``self.swin_unet``.
        """
        if x.shape[1] != 1:
            return self.swin_unet(x)
        # Copy the single channel into three identical channels.
        return self.swin_unet(x.repeat(1, 3, 1, 1))