# oguzakif's picture
# init repo
# d4b77ac
import argparse
def args_parser(argv=None) -> argparse.Namespace:
    """Build the trainer's command-line interface and parse it.

    Options are registered in groups: general experiment / path settings,
    data settings, model settings, loss weights, and spatial / temporal
    transformer settings.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse
            reads ``sys.argv[1:]``, which is the behaviour existing callers
            of ``args_parser()`` rely on.

    Returns:
        The ``argparse.Namespace`` holding one attribute per option.

    Raises:
        SystemExit: Raised by argparse on unknown options, values that fail
            type conversion or a ``choices`` check, or ``-h`` / ``--help``.
    """
    parser = argparse.ArgumentParser(description="General top layer trainer")

    # general experiment, path and checkpoint settings
    parser.add_argument("--opt", type=str, default="config/train.yaml", help="Path to optional configuration file")
    parser.add_argument('--model', type=str, default='model',
                        help='Model block name, in the `model` directory')
    parser.add_argument('--name', type=str, default='FGT_train', help='Experiment name')
    parser.add_argument('--outputdir', type=str, default='/myData/ret/experiments', help='Output dir to save results')
    parser.add_argument('--datadir', type=str, default='/myData/', metavar='PATH')
    parser.add_argument('--datasetName_train', type=str, default='train_dataset_frames_diffusedFlows',
                        help='The file name of the train dataset, in `data` directory')
    parser.add_argument('--network', type=str, default='network',
                        help='The network file which defines the training process, in the `network` directory')
    parser.add_argument('--finetune', type=int, default=0, help='Whether to fine tune trained models')
    # parser.add_argument('--checkPoint', type=str, default='', help='checkpoint path for continue training')
    parser.add_argument('--gen_state', type=str, default='', help='Checkpoint of the generator')
    parser.add_argument('--dis_state', type=str, default='', help='Checkpoint of the discriminator')
    parser.add_argument('--opt_state', type=str, default='', help='Checkpoint of the options')
    parser.add_argument('--record_iter', type=int, default=16, help='How many iters to print an item of log')
    parser.add_argument('--flow_checkPoint', type=str, default='flowCheckPoint/',
                        help='The path for flow model filling')
    parser.add_argument('--dataMode', type=str, default='resize', choices=['resize', 'crop'])

    # data related parameters
    parser.add_argument('--flow2rgb', type=int, default=1, help='Whether to transform flows from raw data to rgb')
    parser.add_argument('--flow_direction', type=str, default='for', choices=['for', 'back', 'bi'],
                        help='Which GT flow should be chosen for guidance')
    parser.add_argument('--num_frames', type=int, default=5, help='How many frames are chosen for frame completion')
    parser.add_argument('--sample', type=str, default='random', choices=['random', 'seq'],
                        help='Choose the sample method for training in each iterations')
    parser.add_argument('--max_val', type=float, default=0.01, help='The maximal value to quantize the optical flows')

    # model related parameters
    parser.add_argument('--res_h', type=int, default=240, help='The height of the frame resolution')
    parser.add_argument('--res_w', type=int, default=432, help='The width of the frame resolution')
    parser.add_argument('--in_channel', type=int, default=4, help='The input channel of the frame branch')
    parser.add_argument('--cnum', type=int, default=64, help='The initial channel number of the frame branch')
    parser.add_argument('--flow_inChannel', type=int, default=2, help='The input channel of the flow branch')
    parser.add_argument('--flow_cnum', type=int, default=64, help='The initial channel dimension of the flow branch')
    parser.add_argument('--dist_cnum', type=int, default=32, help='The initial channel num in the discriminator')
    parser.add_argument('--frame_hidden', type=int, default=512,
                        help='The channel / patch dimension in the frame branch')
    parser.add_argument('--flow_hidden', type=int, default=256, help='The channel / patch dimension in the flow branch')
    parser.add_argument('--PASSMASK', type=int, default=1,
                        help='1 -> concat the mask with the corrupted optical flows to fill the flow')
    parser.add_argument('--numBlocks', type=int, default=8, help='How many transformer blocks do we need to stack')
    parser.add_argument('--kernel_size_w', type=int, default=7, help='The width of the kernel for extracting patches')
    parser.add_argument('--kernel_size_h', type=int, default=7, help='The height of the kernel for extracting patches')
    parser.add_argument('--stride_h', type=int, default=3, help='The height of the stride')
    parser.add_argument('--stride_w', type=int, default=3, help='The width of the stride')
    parser.add_argument('--pad_h', type=int, default=3, help='The height of the padding')
    parser.add_argument('--pad_w', type=int, default=3, help='The width of the padding')
    parser.add_argument('--num_head', type=int, default=4, help='The head number for the multihead attention')
    parser.add_argument('--conv_type', type=str, choices=['vanilla', 'gated', 'partial'], default='vanilla',
                        help='Which kind of conv to use')
    parser.add_argument('--norm', type=str, default='None', choices=['None', 'BN', 'SN', 'IN'],
                        help='The normalization method for the conv blocks')
    parser.add_argument('--use_bias', type=int, default=1, help='If 1, use bias in the convolution blocks')
    parser.add_argument('--ape', type=int, default=1, help='If ape = 1, use absolute positional embedding')
    parser.add_argument('--pos_mode', type=str, default='single', choices=['single', 'dual'],
                        help='If pos_mode = dual, add positional embedding to flow patches')
    parser.add_argument('--mlp_ratio', type=int, default=40, help='The mlp dilation rate for the feed forward layers')
    # A rate is fractional: parse as float so values such as 0.1 are accepted
    # (with type=int argparse rejected them). The default is left at 0 so runs
    # that do not pass the flag see the same value as before.
    parser.add_argument('--drop', type=float, default=0, help='The dropout rate, 0 by default')
    parser.add_argument('--init_weights', type=int, default=1, help='If 1, initialize the network, 1 by default')

    # loss related parameters
    parser.add_argument('--L1M', type=float, default=1, help='The weight of L1 loss in the masked area')
    parser.add_argument('--L1V', type=float, default=1, help='The weight of L1 loss in the valid area')
    parser.add_argument('--adv', type=float, default=0.01, help='The weight of adversarial loss')

    # spatial and temporal related parameters
    parser.add_argument('--tw', type=int, default=2, help='The number of temporal group in the temporal transformer')
    parser.add_argument('--sw', type=int, default=8,
                        help='The number of spatial window size in the spatial transformer')
    parser.add_argument('--gd', type=int, default=4, help='Global downsample rate for spatial transformer')
    parser.add_argument('--ref_length', type=int, default=10, help='The sample interval during inference')
    parser.add_argument('--use_valid', action='store_true')

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)
    return args