Spaces:
Runtime error
Runtime error
import os,sys | |
import importlib | |
import options | |
from utils import log | |
import warnings | |
# Suppress all Python warnings for the whole process (action "ignore",
# no message/category restriction).
warnings.filterwarnings("ignore")
# Weights & Biases configuration, applied via environment variables:
#   WANDB_IGNORE_GLOBS='*.pth' -> files matching *.pth are excluded by W&B
#   WANDB_MODE='dryrun'        -> presumably keeps runs local / unsynced
#                                 (NOTE(review): confirm against the wandb
#                                 version in use; newer releases call this
#                                 mode "offline").
# These values overwrite anything already set in the caller's environment.
os.environ.update({
    'WANDB_IGNORE_GLOBS': '*.pth',
    'WANDB_MODE': 'dryrun',
})
# Task names that train a model (both spellings are accepted).
TRAIN_TASKS = ("train", "Train")

# Evaluation task name -> name of the Model method that runs it.
# "test", "val" and "vis_test" all go through Model.test.
EVAL_TASKS = {
    "test": "test",
    "val": "test",
    "vis_test": "test",
    "test_speed": "test_speed",
    "test_vid": "test_vid",
    "test_sty": "test_sty",
    "test_interpolation": "test_interpolation",
}


def resolve_task(task, test_ckpt_path):
    """Return the task to run, validating it against the known task names.

    When ``test_ckpt_path`` is truthy, any task that is not an evaluation
    task (training included) is replaced by ``"test"``.

    Args:
        task: Task name requested on the command line.
        test_ckpt_path: Checkpoint path option; only its truthiness is used.

    Returns:
        The effective task name.

    Raises:
        ValueError: If the effective task is neither a training nor an
            evaluation task.
    """
    if test_ckpt_path and task not in EVAL_TASKS:
        task = "test"
    if task not in TRAIN_TASKS and task not in EVAL_TASKS:
        raise ValueError("Unknown task: {}".format(task))
    return task


def main():
    """Parse command-line options, build the configured model and run a task.

    The task comes from ``opt.task`` (see ``resolve_task``). It is validated
    *before* the model module is imported and the dataset/networks are built,
    so a typo fails immediately instead of after expensive setup.

    Raises:
        ValueError: If the requested task is unknown.
    """
    log.process(os.getpid())
    log.title("[{}] (PyTorch code for testing Sat2Density and debug)".format(sys.argv[0]))
    opt_cmd = options.parse_arguments(sys.argv[1:])
    opt = options.set(opt_cmd=opt_cmd)

    opt.task = resolve_task(opt.task, opt.test_ckpt_path)
    opt.isTrain = opt.task in TRAIN_TASKS
    # Default the run name to the YAML config name when none was given.
    opt.name = opt.yaml if opt.name is None else opt.name

    # The model implementation is selected dynamically: model/<opt.model>.py
    model_module = importlib.import_module("model.{}".format(opt.model))
    m = model_module.Model(opt)
    m.load_dataset(opt)
    m.build_networks(opt)

    if opt.task in TRAIN_TASKS:
        # Training needs an optimizer; evaluation tasks do not.
        m.setup_optimizer(opt)
        m.train(opt)
    elif opt.task in EVAL_TASKS:
        # test/val/vis_test, speed test, video, style and interpolation runs.
        getattr(m, EVAL_TASKS[opt.task])(opt)
    else:
        # Only reachable if model setup rewrote opt.task to an unknown value.
        raise ValueError("Unknown task: {}".format(opt.task))
# Entry point: run main() only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()