HungNP
New single commit message
cb80c28
raw
history blame contribute delete
No virus
2.05 kB
# Maps each accepted (lower-cased) model name to the module that defines it
# and the name of the class to instantiate from that module.
_MODEL_REGISTRY = {
    "icarl": ("models.icarl", "iCaRL"),
    "bic": ("models.bic", "BiC"),
    "podnet": ("models.podnet", "PODNet"),
    "lwf": ("models.lwf", "LwF"),
    "ewc": ("models.ewc", "EWC"),
    "wa": ("models.wa", "WA"),
    "der": ("models.der", "DER"),
    "finetune": ("models.finetune", "Finetune"),
    "replay": ("models.replay", "Replay"),
    "gem": ("models.gem", "GEM"),
    "coil": ("models.coil", "COIL"),
    "foster": ("models.foster", "FOSTER"),
    "rmm-icarl": ("models.rmm", "RMM_iCaRL"),
    "rmm-foster": ("models.rmm", "RMM_FOSTER"),
    "fetril": ("models.fetril", "FeTrIL"),
    "pass": ("models.pa2s", "PASS"),
    "il2a": ("models.il2a", "IL2A"),
    "ssre": ("models.ssre", "SSRE"),
    "memo": ("models.memo", "MEMO"),
    "beefiso": ("models.beef_iso", "BEEFISO"),
    "simplecil": ("models.simplecil", "SimpleCIL"),
}


def get_model(model_name, args):
    """Build the model registered under ``model_name``.

    The lookup is case-insensitive: ``model_name`` is lower-cased before it
    is matched against the registry. Only the module belonging to the
    selected model is imported, and only at call time.

    Args:
        model_name: Name of the model to build (e.g. ``"icarl"``).
        args: Passed unchanged as the sole argument to the model class.

    Returns:
        A new instance of the selected model class.

    Raises:
        AssertionError: If ``model_name`` does not match any known model.
    """
    # Imported locally so this function does not rely on a module-level
    # import being present elsewhere in the file.
    import importlib

    name = model_name.lower()
    try:
        module_path, class_name = _MODEL_REGISTRY[name]
    except KeyError:
        # Raised explicitly rather than via ``assert 0``: assert statements
        # are removed under ``python -O``, which would make an unknown name
        # silently return None. AssertionError is kept as the exception type
        # so existing callers that catch it keep working.
        known = ", ".join(sorted(_MODEL_REGISTRY))
        raise AssertionError(
            f"Unknown model name {model_name!r}; expected one of: {known}"
        ) from None

    model_cls = getattr(importlib.import_module(module_path), class_name)
    return model_cls(args)