Yuanhao Zhai commited on
Commit
2f7bb6a
1 Parent(s): 83e38fa

append probablities to the output

Browse files
Files changed (3) hide show
  1. demo.py +7 -1
  2. opt.py +8 -5
  3. requirements.txt +1 -0
demo.py CHANGED
@@ -31,6 +31,11 @@ def demo(folder_path, output_path=Path("tmp")):
31
  image = image.to(opt.device).unsqueeze(0)
32
  outputs = model(image, seg_size=image_size)
33
  out_map = outputs["ensemble"]["out_map"][0, ...].detach().cpu()
 
 
 
 
 
34
 
35
  overlay = draw_segmentation_masks(
36
  dsm_image, masks=out_map[0, ...] > opt.mask_threshold
@@ -43,7 +48,8 @@ def demo(folder_path, output_path=Path("tmp")):
43
  ],
44
  padding=5,
45
  )
46
- save_image(grid_image, (output_path / image_path.name).as_posix())
 
47
 
48
 
49
  if __name__ == "__main__":
 
31
  image = image.to(opt.device).unsqueeze(0)
32
  outputs = model(image, seg_size=image_size)
33
  out_map = outputs["ensemble"]["out_map"][0, ...].detach().cpu()
34
+ pred = outputs["ensemble"]["out_map"].max().item()
35
+ if pred > opt.mask_threshold:
36
+ print(f"Found manipulation in {image_path.name}")
37
+ else:
38
+ print(f"No manipulation found in {image_path.name}")
39
 
40
  overlay = draw_segmentation_masks(
41
  dsm_image, masks=out_map[0, ...] > opt.mask_threshold
 
48
  ],
49
  padding=5,
50
  )
51
+ image_name = image_path.stem + f"-{pred:.2f}" + image_path.suffix
52
+ save_image(grid_image, (output_path / image_name).as_posix())
53
 
54
 
55
  if __name__ == "__main__":
opt.py CHANGED
@@ -10,8 +10,8 @@ import yaml
10
  from termcolor import cprint
11
 
12
 
13
- def load_dataset_arguments(opt):
14
- if opt.load is None:
15
  return
16
 
17
  # exclude parameters assigned in the command
@@ -24,7 +24,10 @@ def load_dataset_arguments(opt):
24
  arguments = []
25
 
26
  # load parameters in the yaml file
27
- assert os.path.exists(opt.load)
 
 
 
28
  with open(opt.load, "r") as f:
29
  yaml_arguments = yaml.safe_load(f)
30
  # TODO this should be verified
@@ -33,7 +36,7 @@ def load_dataset_arguments(opt):
33
  setattr(opt, k, v)
34
 
35
 
36
- def get_opt(additional_parsers: Optional[List] = None):
37
  parents = [get_arguments_parser()]
38
  if additional_parsers:
39
  parents.extend(additional_parsers)
@@ -43,7 +46,7 @@ def get_opt(additional_parsers: Optional[List] = None):
43
  opt = parser.parse_known_args()[0]
44
 
45
  # load dataset argument file
46
- load_dataset_arguments(opt)
47
 
48
  # user-defined warnings and assertions
49
  if opt.decoder.lower() not in ["c1"]:
 
10
  from termcolor import cprint
11
 
12
 
13
+ def load_dataset_arguments(cfg_path, opt):
14
+ if opt.load is None and cfg_path is None:
15
  return
16
 
17
  # exclude parameters assigned in the command
 
24
  arguments = []
25
 
26
  # load parameters in the yaml file
27
+ if cfg_path is not None:
28
+ opt.load = cfg_path
29
+ else:
30
+ assert os.path.exists(opt.load)
31
  with open(opt.load, "r") as f:
32
  yaml_arguments = yaml.safe_load(f)
33
  # TODO this should be verified
 
36
  setattr(opt, k, v)
37
 
38
 
39
+ def get_opt(cfg_path: Optional[str] = None, additional_parsers: Optional[List] = None):
40
  parents = [get_arguments_parser()]
41
  if additional_parsers:
42
  parents.extend(additional_parsers)
 
46
  opt = parser.parse_known_args()[0]
47
 
48
  # load dataset argument file
49
+ load_dataset_arguments(cfg_path, opt)
50
 
51
  # user-defined warnings and assertions
52
  if opt.decoder.lower() not in ["c1"]:
requirements.txt CHANGED
@@ -26,3 +26,4 @@ timm==0.9.12
26
  torch==1.12.1+cu116
27
  torchvision==0.13.1+cu116
28
  tqdm==4.64.1
 
 
26
  torch==1.12.1+cu116
27
  torchvision==0.13.1+cu116
28
  tqdm==4.64.1
29
+ markupsafe==2.0.1