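# Multi-GPU training launcher for Anymate: loads a YAML config, prepares log and
# checkpoint directories, optionally preloads the training set and shares it across
# worker processes via a multiprocessing.Manager, then spawns one train_model process per GPU.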
import os
import shutil
import argparse

import torch
import torch.multiprocessing as mp
import yaml

from Anymate.dataset import AnymateDataset
from Anymate.utils.train_utils import train_model


if __name__ == '__main__':
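    # Command-line interface: choose which YAML config to train with and whether to
    # use the split version of the dataset.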
    parser = argparse.ArgumentParser(description='Anymate training')
    parser.add_argument('--config', type=str, default='joints',
                        help='name of the config file under ./Anymate/configs (without the .yaml extension)')
    parser.add_argument('--split', action='store_true', help='use split dataset')
    args = parser.parse_args()

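    # One worker process will be spawned per visible GPU.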
    world_size = torch.cuda.device_count()
    print('world_size', world_size)

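    # Load the YAML config that defines the model (encoder/decoder) and training arguments.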
    config_folder = './Anymate/configs'
    config_path = os.path.join(config_folder, args.config + '.yaml')
    assert os.path.exists(config_path), f"Config file {config_path} not found"
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

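    # Merge the config's training arguments into args, then derive per-run log and
    # checkpoint directories from the mode and encoder/decoder names.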
    for key, value in config['args'].items():
        setattr(args, key, value)
    setattr(args, 'decoder', config['model']['decoder'])
    run_name = args.mode + '-' + config['model']['encoder'] + '-' + config['model']['decoder']
    args.logdir = os.path.join(args.logdir, run_name)
    args.checkpoint = os.path.join(args.checkpoint, run_name)
    print(args)

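    # Prepare output directories: a fresh run clears any existing log directory,
    # while a resumed run (args.resume comes from the config) keeps it.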
    if not os.path.isdir(args.checkpoint):
        print("Create new checkpoint folder " + args.checkpoint)
    os.makedirs(args.checkpoint, exist_ok=True)
    if not args.resume:
        if os.path.isdir(args.logdir):
            shutil.rmtree(args.logdir)
        os.makedirs(args.logdir, exist_ok=True)
    else:
        os.makedirs(args.logdir, exist_ok=True)

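    # Without --split, load the full training set once in this parent process and share
    # it with the spawned workers through a multiprocessing.Manager dict. With --split,
    # nothing is shared here and train_model receives shared_dict=None.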
    if not args.split:
        train_dataset = AnymateDataset(name=args.trainset, root=args.root)
        # Drop samples whose 'vox' entry is empty.
        train_dataset.data_list = [data for data in train_dataset.data_list if data['vox'].shape[0] != 0]
        print('train_dataset', len(train_dataset.data_list))

        import multiprocessing
        manager = multiprocessing.Manager()
        shared_dict = manager.dict()
        shared_dict['train_dataset'] = train_dataset
    else:
        shared_dict = None

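    # Spawn one train_model process per GPU. If the chosen rendezvous port is already
    # in use, retry with the next port number.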
    port = 12355
    while port < 65535:
        try:
            mp.spawn(train_model, args=(world_size, config, args, shared_dict, port), nprocs=world_size)
            break
        except Exception as e:
            if "address already in use" in str(e).lower():
                print(f"Port {port} is already in use, trying next port")
                port += 1
            else:
                print(f"Error starting training on port {port}: {e}")
                raise e
    print(f"Successfully started training on port {port}")