import argparse
from typing import Optional, Sequence


def args_parser(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Build the trainer's command-line parser and parse the arguments.

    Args:
        argv: Argument strings to parse. When ``None`` (the default),
            argparse falls back to ``sys.argv[1:]``, which matches the
            behaviour of calling this function with no arguments.

    Returns:
        The ``argparse.Namespace`` holding every option declared below,
        populated with its default when the option is not supplied.
    """
    parser = argparse.ArgumentParser(description="General top layer trainer")

    # general / experiment parameters
    parser.add_argument("--opt", type=str, default="config/train.yaml",
                        help="Path to optional configuration file")
    parser.add_argument('--model', type=str, default='model',
                        help='Model block name, in the `model` directory')
    parser.add_argument('--name', type=str, default='FGT_train',
                        help='Experiment name')
    parser.add_argument('--outputdir', type=str, default='/myData/ret/experiments',
                        help='Output dir to save results')
    parser.add_argument('--datadir', type=str, default='/myData/', metavar='PATH')
    parser.add_argument('--datasetName_train', type=str,
                        default='train_dataset_frames_diffusedFlows',
                        help='The file name of the train dataset, in `data` directory')
    parser.add_argument('--network', type=str, default='network',
                        help='The network file which defines the training process, in the `network` directory')
    parser.add_argument('--finetune', type=int, default=0,
                        help='Whether to fine tune trained models')
    parser.add_argument('--gen_state', type=str, default='',
                        help='Checkpoint of the generator')
    parser.add_argument('--dis_state', type=str, default='',
                        help='Checkpoint of the discriminator')
    parser.add_argument('--opt_state', type=str, default='',
                        help='Checkpoint of the options')
    parser.add_argument('--record_iter', type=int, default=16,
                        help='How many iters to print an item of log')
    parser.add_argument('--flow_checkPoint', type=str, default='flowCheckPoint/',
                        help='The path for flow model filling')
    parser.add_argument('--dataMode', type=str, default='resize',
                        choices=['resize', 'crop'])

    # data related parameters
    parser.add_argument('--flow2rgb', type=int, default=1,
                        help='Whether to transform flows from raw data to rgb')
    parser.add_argument('--flow_direction', type=str, default='for',
                        choices=['for', 'back', 'bi'],
                        help='Which GT flow should be chosen for guidance')
    parser.add_argument('--num_frames', type=int, default=5,
                        help='How many frames are chosen for frame completion')
    parser.add_argument('--sample', type=str, default='random',
                        choices=['random', 'seq'],
                        help='Choose the sample method for training in each iterations')
    parser.add_argument('--max_val', type=float, default=0.01,
                        help='The maximal value to quantize the optical flows')

    # model related parameters
    parser.add_argument('--res_h', type=int, default=240,
                        help='The height of the frame resolution')
    parser.add_argument('--res_w', type=int, default=432,
                        help='The width of the frame resolution')
    parser.add_argument('--in_channel', type=int, default=4,
                        help='The input channel of the frame branch')
    parser.add_argument('--cnum', type=int, default=64,
                        help='The initial channel number of the frame branch')
    parser.add_argument('--flow_inChannel', type=int, default=2,
                        help='The input channel of the flow branch')
    parser.add_argument('--flow_cnum', type=int, default=64,
                        help='The initial channel dimension of the flow branch')
    parser.add_argument('--dist_cnum', type=int, default=32,
                        help='The initial channel num in the discriminator')
    parser.add_argument('--frame_hidden', type=int, default=512,
                        help='The channel / patch dimension in the frame branch')
    parser.add_argument('--flow_hidden', type=int, default=256,
                        help='The channel / patch dimension in the flow branch')
    parser.add_argument('--PASSMASK', type=int, default=1,
                        help='1 -> concat the mask with the corrupted optical flows to fill the flow')
    parser.add_argument('--numBlocks', type=int, default=8,
                        help='How many transformer blocks do we need to stack')
    parser.add_argument('--kernel_size_w', type=int, default=7,
                        help='The width of the kernel for extracting patches')
    parser.add_argument('--kernel_size_h', type=int, default=7,
                        help='The height of the kernel for extracting patches')
    parser.add_argument('--stride_h', type=int, default=3,
                        help='The height of the stride')
    parser.add_argument('--stride_w', type=int, default=3,
                        help='The width of the stride')
    parser.add_argument('--pad_h', type=int, default=3,
                        help='The height of the padding')
    parser.add_argument('--pad_w', type=int, default=3,
                        help='The width of the padding')
    parser.add_argument('--num_head', type=int, default=4,
                        help='The head number for the multihead attention')
    parser.add_argument('--conv_type', type=str,
                        choices=['vanilla', 'gated', 'partial'], default='vanilla',
                        help='Which kind of conv to use')
    parser.add_argument('--norm', type=str, default='None',
                        choices=['None', 'BN', 'SN', 'IN'],
                        help='The normalization method for the conv blocks')
    parser.add_argument('--use_bias', type=int, default=1,
                        help='If 1, use bias in the convolution blocks')
    parser.add_argument('--ape', type=int, default=1,
                        help='If ape = 1, use absolute positional embedding')
    parser.add_argument('--pos_mode', type=str, default='single',
                        choices=['single', 'dual'],
                        help='If pos_mode = dual, add positional embedding to flow patches')
    parser.add_argument('--mlp_ratio', type=int, default=40,
                        help='The mlp dilation rate for the feed forward layers')
    # Parsed as float so fractional rates (e.g. 0.1) are accepted; with
    # type=int any non-integer rate was rejected. The default stays 0.
    parser.add_argument('--drop', type=float, default=0,
                        help='The dropout rate, 0 by default')
    parser.add_argument('--init_weights', type=int, default=1,
                        help='If 1, initialize the network, 1 by default')

    # loss related parameters
    parser.add_argument('--L1M', type=float, default=1,
                        help='The weight of L1 loss in the masked area')
    parser.add_argument('--L1V', type=float, default=1,
                        help='The weight of L1 loss in the valid area')
    parser.add_argument('--adv', type=float, default=0.01,
                        help='The weight of adversarial loss')

    # spatial and temporal related parameters
    parser.add_argument('--tw', type=int, default=2,
                        help='The number of temporal group in the temporal transformer')
    parser.add_argument('--sw', type=int, default=8,
                        help='The number of spatial window size in the spatial transformer')
    parser.add_argument('--gd', type=int, default=4,
                        help='Global downsample rate for spatial transformer')
    parser.add_argument('--ref_length', type=int, default=10,
                        help='The sample interval during inference')
    parser.add_argument('--use_valid', action='store_true')

    # argv=None makes argparse read sys.argv[1:], as the no-argument call did.
    args = parser.parse_args(argv)
    return args