RPCANet / models /__init__.py
fengyiwu's picture
Upload 93 files
82b70d0 verified
from .deepunfolding import *
def get_model(name, net=None):
    """Construct the network registered under ``name``.

    Recognised names and the constructor call each one maps to:

    * ``'rpcanet'``       -> ``RPCANet9(stage_num=6, slayers=6, llayers=3, mlayers=3, channel=32)``
    * ``'rpcanet_pp'``    -> ``RPCANet_LSTM(stage_num=6, slayers=6, mlayers=3, channel=32)``
    * ``'rpcanet_pp_s3'`` -> ``RPCANet_LSTM(stage_num=3, slayers=6, mlayers=3, channel=32)``
    * ``'rpcanet_pp_s9'`` -> ``RPCANet_LSTM(stage_num=9, slayers=6, mlayers=3, channel=32)``

    ``RPCANet9`` and ``RPCANet_LSTM`` are not defined in this module; they are
    presumably provided by the ``.deepunfolding`` star import (verify there).

    Args:
        name: Identifier of the model variant to build.
        net: Ignored. Kept only so existing callers that pass it keep working;
            a freshly constructed network is always returned.

    Returns:
        A newly constructed network instance for ``name``.

    Raises:
        NotImplementedError: If ``name`` is not one of the recognised names.
    """
    if name == 'rpcanet':
        return RPCANet9(stage_num=6, slayers=6, llayers=3, mlayers=3, channel=32)

    # The three "pp" variants differ only in the number of unfolding stages.
    if name == 'rpcanet_pp':
        return RPCANet_LSTM(stage_num=6, slayers=6, mlayers=3, channel=32)
    if name == 'rpcanet_pp_s3':
        return RPCANet_LSTM(stage_num=3, slayers=6, mlayers=3, channel=32)
    if name == 'rpcanet_pp_s9':
        return RPCANet_LSTM(stage_num=9, slayers=6, mlayers=3, channel=32)

    # Same exception type as before, but now it says what was asked for and
    # what would have been accepted, so a typo is diagnosable from the traceback.
    raise NotImplementedError(
        "Unknown model name {!r}; expected one of: "
        "'rpcanet', 'rpcanet_pp', 'rpcanet_pp_s3', 'rpcanet_pp_s9'".format(name)
    )