File size: 553 Bytes
82b70d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
from .deepunfolding import *
def get_model(name, net=None):
    """Construct the network registered under ``name``.

    Supported names and the constructor call each one maps to:

    * ``'rpcanet'``: ``RPCANet9(stage_num=6, slayers=6, llayers=3,
      mlayers=3, channel=32)``
    * ``'rpcanet_pp'``: ``RPCANet_LSTM(stage_num=6, slayers=6, mlayers=3,
      channel=32)``
    * ``'rpcanet_pp_s3'``: same as ``'rpcanet_pp'`` but ``stage_num=3``
    * ``'rpcanet_pp_s9'``: same as ``'rpcanet_pp'`` but ``stage_num=9``

    Args:
        name: Identifier of the model configuration to build.
        net: Ignored. Retained only so callers that already pass it keep
            working; the returned network is always newly constructed.

    Returns:
        The object produced by the selected constructor call.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported names.
    """
    # Lambdas defer both the class lookup and the constructor call, so only
    # the requested network is ever built.
    # NOTE(review): RPCANet9 / RPCANet_LSTM are presumably supplied by the
    # star import from .deepunfolding -- confirm.
    builders = {
        'rpcanet': lambda: RPCANet9(
            stage_num=6, slayers=6, llayers=3, mlayers=3, channel=32),
        'rpcanet_pp': lambda: RPCANet_LSTM(
            stage_num=6, slayers=6, mlayers=3, channel=32),
        'rpcanet_pp_s3': lambda: RPCANet_LSTM(
            stage_num=3, slayers=6, mlayers=3, channel=32),
        'rpcanet_pp_s9': lambda: RPCANet_LSTM(
            stage_num=9, slayers=6, mlayers=3, channel=32),
    }

    # Skip the dict lookup for non-strings so that unhashable values (e.g. a
    # list) still end in NotImplementedError rather than a TypeError.
    builder = builders.get(name) if isinstance(name, str) else None
    if builder is None:
        raise NotImplementedError(
            f"Unknown model name {name!r}; expected one of: "
            f"{', '.join(sorted(builders))}"
        )
    return builder()
|