# LVM-Med / train_classification.py
# (initial commit be2715b, duynhm)
from classification_R50.train_R50_classification import train_R50
from utils.func import (
parse_config,
load_config
)
def selected_backbone(cfg):
    """Return which backbone the config selects: ``"R50"`` or ``"SAMVIT"``.

    Reads the ``base.is_R50`` and ``base.is_SAMVIT`` flags from *cfg*. This
    assumes *cfg* is the object produced by ``utils.func.load_config``, which
    exposes nested attribute access; that is not verified here.

    Raises:
        ValueError: if not exactly one of the two flags is truthy. This uses an
            explicit exception rather than ``assert`` so the check still runs
            under ``python -O``.
    """
    use_r50 = bool(cfg.base.is_R50)
    use_samvit = bool(cfg.base.is_SAMVIT)
    if use_r50 == use_samvit:
        raise ValueError(
            "Exactly one of base.is_R50 and base.is_SAMVIT must be true in the "
            f"config (got is_R50={cfg.base.is_R50!r}, "
            f"is_SAMVIT={cfg.base.is_SAMVIT!r})"
        )
    return "R50" if use_r50 else "SAMVIT"


if __name__ == "__main__":
    # Parse CLI arguments (including the path to the YAML config) and load it.
    yml_args = parse_config()
    cfg = load_config(yml_args.config)

    if selected_backbone(cfg) == "R50":
        train_R50(yml_args, cfg)
    else:
        # This script only knows how to train the ResNet-50 classifier. Fail
        # loudly with a non-zero exit status instead of printing and exiting 0,
        # so callers don't mistake an unsupported config for a successful run.
        raise SystemExit(
            "Classification with the SAM ViT backbone (base.is_SAMVIT) is not "
            "supported by this script; set base.is_R50 to true instead."
        )