С Чичерин commited on
Commit
8e0f2c0
1 Parent(s): 6498dfa

added config

Browse files
Files changed (2) hide show
  1. base.yaml +61 -0
  2. test.py +2 -4
base.yaml ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ world_size: 1
2
+ experiment_name: "test"
3
+ datasets:
4
+ synthetic_fg: "/home/jovyan/datasets/synthetic_psi/"
5
+ synthetic_animals: "/home/jovyan/datasets/synthetic_psiny/"
6
+ bg: "/home/jovyan/datasets/matting/background/testval/"
7
+ ppm100: "/home/jovyan/kvanchiani/stylegan3/PPM-100/image"
8
+ aim500: "/home/jovyan/datasets/AIM-500"
9
+ am2k:
10
+ train_original: "/home/jovyan/datasets/matting/am-2k/train/original"
11
+ train_mask: "/home/jovyan/datasets/matting/am-2k/train/mask"
12
+ background: "/home/jovyan/datasets/matting/am-2k/background/train"
13
+ validation_original: "/home/jovyan/datasets/matting/am-2k/validation/original/"
14
+ validation_mask: "/home/jovyan/datasets/matting/am-2k/validation/mask/"
15
+ validation_trimap: "/home/jovyan/datasets/matting/am-2k/validation/trimap/"
16
+ tiktok: "/home/jovyan/datasets/tiktokdataset/dataset"
17
+ p3m10k: "/home/jovyan/datasets/matting/P3M-10k"
18
+ p3m10k_test:
19
+ VAL500P:
20
+ ROOT_PATH: "P3M-500-P/"
21
+ ORIGINAL_PATH: "P3M-500-P/blurred_image/"
22
+ MASK_PATH: "P3M-500-P/mask/"
23
+ TRIMAP_PATH: "P3M-500-P/trimap/"
24
+ SAMPLE_NUMBER: 500
25
+ VAL500NP:
26
+ ROOT_PATH: "P3M-500-NP/"
27
+ ORIGINAL_PATH: "P3M-500-NP/original_image/"
28
+ MASK_PATH: "P3M-500-NP/mask/"
29
+ TRIMAP_PATH: "P3M-500-NP/trimap/"
30
+ SAMPLE_NUMBER: 500
31
+ MAX_SIZE_H: 1600
32
+ MAX_SIZE_W: 1600
33
+ image_crop: 800
34
+ max_image_count: 10000
35
+ dataset_to_use: MixedDataset
36
+ pretrained_model: "microsoft/swinv2-tiny-patch4-window8-256" #"nielsr/mask2former-swin-base-youtubevis-2021" #"nvidia/mit-b2"
37
+ batch_size: 4
38
+ num_workers: 4
39
+ log_dir: "log"
40
+ checkpoint_dir: "checkpoints"
41
+ checkpoint: "best-89.pth"
42
+ distributed_addr: "localhost"
43
+ distributed_port: "12357"
44
+ image_size: 800
45
+ lr: 1e-7
46
+ epochs: 200
47
+ disable_validation: False
48
+ warmup_steps: 2
49
+ validate_each_epoch: 5
50
+ max_images_for_validation: 500
51
+ disable_mixed_precision: True
52
+ log_image_interval: 500
53
+ log_image_number: 8
54
+ save_model_interval: 10000
55
+ switch: 3
56
+ lambda_losses:
57
+ default: 1.
58
+ Laplassian: 1.
59
+ Grad: 3.
60
+ L1: 1.
61
+ switch: 1e-6
test.py CHANGED
@@ -25,8 +25,7 @@ from typing import Type, Any, Callable, Union, List, Optional
25
  import logging
26
  import time
27
  from omegaconf import OmegaConf
28
- config = OmegaConf.load(os.path.join(os.path.dirname(
29
- os.path.abspath(__file__)), "config/base.yaml"))
30
  device = "cuda"
31
 
32
  def conv3x3(in_planes, out_planes, stride=1):
@@ -981,8 +980,7 @@ def log(str):
981
 
982
  if __name__ == '__main__':
983
  print('*********************************')
984
- config = OmegaConf.load(os.path.join(os.path.dirname(
985
- os.path.abspath(__file__)), "config/base.yaml"))
986
  config=OmegaConf.merge(config,OmegaConf.from_cli())
987
  print(config)
988
  model = MaskForm()
 
25
  import logging
26
  import time
27
  from omegaconf import OmegaConf
28
+ config = OmegaConf.load("base.yaml")
 
29
  device = "cuda"
30
 
31
  def conv3x3(in_planes, out_planes, stride=1):
 
980
 
981
  if __name__ == '__main__':
982
  print('*********************************')
983
+ config = OmegaConf.load("base.yaml"))
 
984
  config=OmegaConf.merge(config,OmegaConf.from_cli())
985
  print(config)
986
  model = MaskForm()