import argparse


def args_parser():
    parser = argparse.ArgumentParser(description="General top layer trainer")
    parser.add_argument("--opt", type=str, default="config/train.yaml", help="Path to optional configuration file")
    parser.add_argument('--model', type=str, default='model',
                        help='Model block name, in the `model` directory')
    parser.add_argument('--name', type=str, default='FGT_train', help='Experiment name')
    parser.add_argument('--outputdir', type=str, default='/myData/ret/experiments', help='Output dir to save results')
    parser.add_argument('--datadir', type=str, default='/myData/', metavar='PATH',
                        help='Root directory of the data')
    parser.add_argument('--datasetName_train', type=str, default='train_dataset_frames_diffusedFlows',
                        help='The file name of the train dataset, in `data` directory')
    parser.add_argument('--network', type=str, default='network',
                        help='The network file which defines the training process, in the `network` directory')
    parser.add_argument('--finetune', type=int, default=0, help='If 1, fine-tune previously trained models')
    # parser.add_argument('--checkPoint', type=str, default='', help='checkpoint path for continue training')
    parser.add_argument('--gen_state', type=str, default='', help='Checkpoint of the generator')
    parser.add_argument('--dis_state', type=str, default='', help='Checkpoint of the discriminator')
    parser.add_argument('--opt_state', type=str, default='', help='Checkpoint of the optimizer state')
    parser.add_argument('--record_iter', type=int, default=16, help='Number of iterations between log entries')
    parser.add_argument('--flow_checkPoint', type=str, default='flowCheckPoint/',
                        help='Checkpoint path of the flow-filling model')
    parser.add_argument('--dataMode', type=str, default='resize', choices=['resize', 'crop'],
                        help='Whether to resize or crop the input frames')

    # data related parameters
    parser.add_argument('--flow2rgb', type=int, default=1, help='Whether to transform flows from raw data to rgb')
    parser.add_argument('--flow_direction', type=str, default='for', choices=['for', 'back', 'bi'],
                        help='Which GT flow should be chosen for guidance')
    parser.add_argument('--num_frames', type=int, default=5, help='How many frames are chosen for frame completion')
    parser.add_argument('--sample', type=str, default='random', choices=['random', 'seq'],
                        help='The frame sampling method used in each training iteration')
    parser.add_argument('--max_val', type=float, default=0.01, help='The maximal value to quantize the optical flows')

    # model related parameters
    parser.add_argument('--res_h', type=int, default=240, help='The height of the frame resolution')
    parser.add_argument('--res_w', type=int, default=432, help='The width of the frame resolution')
    parser.add_argument('--in_channel', type=int, default=4, help='The input channel of the frame branch')
    parser.add_argument('--cnum', type=int, default=64, help='The initial channel number of the frame branch')
    parser.add_argument('--flow_inChannel', type=int, default=2, help='The input channel of the flow branch')
    parser.add_argument('--flow_cnum', type=int, default=64, help='The initial channel dimension of the flow branch')
    parser.add_argument('--dist_cnum', type=int, default=32, help='The initial channel number of the discriminator')
    parser.add_argument('--frame_hidden', type=int, default=512,
                        help='The channel / patch dimension in the frame branch')
    parser.add_argument('--flow_hidden', type=int, default=256, help='The channel / patch dimension in the flow branch')
    parser.add_argument('--PASSMASK', type=int, default=1,
                        help='1 -> concat the mask with the corrupted optical flows to fill the flow')
    parser.add_argument('--numBlocks', type=int, default=8, help='The number of stacked transformer blocks')
    parser.add_argument('--kernel_size_w', type=int, default=7, help='The width of the kernel for extracting patches')
    parser.add_argument('--kernel_size_h', type=int, default=7, help='The height of the kernel for extracting patches')
    parser.add_argument('--stride_h', type=int, default=3, help='The vertical stride for extracting patches')
    parser.add_argument('--stride_w', type=int, default=3, help='The horizontal stride for extracting patches')
    parser.add_argument('--pad_h', type=int, default=3, help='The vertical padding for extracting patches')
    parser.add_argument('--pad_w', type=int, default=3, help='The horizontal padding for extracting patches')
    parser.add_argument('--num_head', type=int, default=4, help='The number of heads in the multi-head attention')
    parser.add_argument('--conv_type', type=str, choices=['vanilla', 'gated', 'partial'], default='vanilla',
                        help='Which kind of conv to use')
    parser.add_argument('--norm', type=str, default='None', choices=['None', 'BN', 'SN', 'IN'],
                        help='The normalization method for the conv blocks')
    parser.add_argument('--use_bias', type=int, default=1, help='If 1, use bias in the convolution blocks')
    parser.add_argument('--ape', type=int, default=1, help='If ape = 1, use absolute positional embedding')
    parser.add_argument('--pos_mode', type=str, default='single', choices=['single', 'dual'],
                        help='If pos_mode = dual, add positional embedding to flow patches')
    parser.add_argument('--mlp_ratio', type=int, default=40,
                        help='The hidden-dimension ratio of the feed-forward (MLP) layers')
    parser.add_argument('--drop', type=float, default=0., help='The dropout rate, 0 by default')
    parser.add_argument('--init_weights', type=int, default=1, help='If 1, initialize the network weights, 1 by default')

    # loss related parameters
    parser.add_argument('--L1M', type=float, default=1, help='The weight of L1 loss in the masked area')
    parser.add_argument('--L1V', type=float, default=1, help='The weight of L1 loss in the valid area')
    parser.add_argument('--adv', type=float, default=0.01, help='The weight of adversarial loss')

    # spatial and temporal related parameters
    parser.add_argument('--tw', type=int, default=2, help='The number of temporal groups in the temporal transformer')
    parser.add_argument('--sw', type=int, default=8,
                        help='The spatial window size in the spatial transformer')
    parser.add_argument('--gd', type=int, default=4, help='Global downsample rate for spatial transformer')

    parser.add_argument('--ref_length', type=int, default=10, help='The sample interval during inference')
    parser.add_argument('--use_valid', action='store_true')

    args = parser.parse_args()
    return args
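

# Hypothetical helper (not part of the original trainer): a sketch of how the
# --kernel_size_*, --stride_* and --pad_* options interact, assuming patches
# are extracted with a sliding window such as torch.nn.Unfold. Per spatial
# dimension, the number of patch positions follows the standard unfold/conv
# output-size formula:
#     n = floor((size + 2 * pad - kernel) / stride) + 1
# Which feature-map size the patches are taken from is also an assumption and
# must be supplied by the caller.
def num_patches(height, width, kernel_h=7, kernel_w=7, stride_h=3, stride_w=3, pad_h=3, pad_w=3):
    """Return (rows, cols, total) patch positions for a height x width feature map."""
    if height + 2 * pad_h < kernel_h or width + 2 * pad_w < kernel_w:
        raise ValueError('The padded feature map is smaller than the kernel')
    rows = (height + 2 * pad_h - kernel_h) // stride_h + 1
    cols = (width + 2 * pad_w - kernel_w) // stride_w + 1
    return rows, cols, rows * cols


# Example: if an encoder downsampled the default 240x432 frames by 4 (an
# assumption), num_patches(60, 108) would give 20 x 36 = 720 patches per frame.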
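

# Hypothetical sanity check (not part of the original trainer): multi-head
# attention splits each token's channels evenly across heads, so a hidden
# dimension must be divisible by the number of heads (PyTorch's
# nn.MultiheadAttention enforces the same constraint on embed_dim). This
# sketch assumes both the frame branch (--frame_hidden) and the flow branch
# (--flow_hidden) use --num_head heads, which the options do not state.
def check_head_dims(frame_hidden=512, flow_hidden=256, num_head=4):
    """Return the per-head dims (frame, flow), or raise ValueError if a split is uneven."""
    for name, dim in (('frame_hidden', frame_hidden), ('flow_hidden', flow_hidden)):
        if dim % num_head != 0:
            raise ValueError(f'{name}={dim} is not divisible by num_head={num_head}')
    return frame_hidden // num_head, flow_hidden // num_head


# With the defaults above, each head would see 128 frame channels and 64 flow channels.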
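

# Hypothetical helper (not part of the original trainer): one plausible
# reading of the --L1M, --L1V and --adv weights is a weighted sum of an L1 term
# over the masked (hole) region, an L1 term over the valid region and an
# adversarial term. The per-term normalization used here (mean over the pixels
# of each region) is an assumption; the real training process is defined in
# the module selected by --network and may differ.
def generator_loss(pred, target, mask, adv_loss, w_l1m=1.0, w_l1v=1.0, w_adv=0.01):
    """pred, target, mask: equal-length flat sequences; mask is 1 in the hole, 0 elsewhere."""
    if not (len(pred) == len(target) == len(mask)):
        raise ValueError('pred, target and mask must have the same length')
    hole = [abs(p - t) for p, t, m in zip(pred, target, mask) if m]
    valid = [abs(p - t) for p, t, m in zip(pred, target, mask) if not m]
    l1_masked = sum(hole) / len(hole) if hole else 0.0
    l1_valid = sum(valid) / len(valid) if valid else 0.0
    return w_l1m * l1_masked + w_l1v * l1_valid + w_adv * adv_loss


# Usage: generator_loss(pred, gt, mask, adv, args.L1M, args.L1V, args.adv)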