Yuanhao Zhai
commited on
Commit
•
2f7bb6a
1
Parent(s):
83e38fa
append probablities to the output
Browse files- demo.py +7 -1
- opt.py +8 -5
- requirements.txt +1 -0
demo.py
CHANGED
@@ -31,6 +31,11 @@ def demo(folder_path, output_path=Path("tmp")):
|
|
31 |
image = image.to(opt.device).unsqueeze(0)
|
32 |
outputs = model(image, seg_size=image_size)
|
33 |
out_map = outputs["ensemble"]["out_map"][0, ...].detach().cpu()
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
overlay = draw_segmentation_masks(
|
36 |
dsm_image, masks=out_map[0, ...] > opt.mask_threshold
|
@@ -43,7 +48,8 @@ def demo(folder_path, output_path=Path("tmp")):
|
|
43 |
],
|
44 |
padding=5,
|
45 |
)
|
46 |
-
|
|
|
47 |
|
48 |
|
49 |
if __name__ == "__main__":
|
|
|
31 |
image = image.to(opt.device).unsqueeze(0)
|
32 |
outputs = model(image, seg_size=image_size)
|
33 |
out_map = outputs["ensemble"]["out_map"][0, ...].detach().cpu()
|
34 |
+
pred = outputs["ensemble"]["out_map"].max().item()
|
35 |
+
if pred > opt.mask_threshold:
|
36 |
+
print(f"Found manipulation in {image_path.name}")
|
37 |
+
else:
|
38 |
+
print(f"No manipulation found in {image_path.name}")
|
39 |
|
40 |
overlay = draw_segmentation_masks(
|
41 |
dsm_image, masks=out_map[0, ...] > opt.mask_threshold
|
|
|
48 |
],
|
49 |
padding=5,
|
50 |
)
|
51 |
+
image_name = image_path.stem + f"-{pred:.2f}" + image_path.suffix
|
52 |
+
save_image(grid_image, (output_path / image_name).as_posix())
|
53 |
|
54 |
|
55 |
if __name__ == "__main__":
|
opt.py
CHANGED
@@ -10,8 +10,8 @@ import yaml
|
|
10 |
from termcolor import cprint
|
11 |
|
12 |
|
13 |
-
def load_dataset_arguments(opt):
|
14 |
-
if opt.load is None:
|
15 |
return
|
16 |
|
17 |
# exclude parameters assigned in the command
|
@@ -24,7 +24,10 @@ def load_dataset_arguments(opt):
|
|
24 |
arguments = []
|
25 |
|
26 |
# load parameters in the yaml file
|
27 |
-
|
|
|
|
|
|
|
28 |
with open(opt.load, "r") as f:
|
29 |
yaml_arguments = yaml.safe_load(f)
|
30 |
# TODO this should be verified
|
@@ -33,7 +36,7 @@ def load_dataset_arguments(opt):
|
|
33 |
setattr(opt, k, v)
|
34 |
|
35 |
|
36 |
-
def get_opt(additional_parsers: Optional[List] = None):
|
37 |
parents = [get_arguments_parser()]
|
38 |
if additional_parsers:
|
39 |
parents.extend(additional_parsers)
|
@@ -43,7 +46,7 @@ def get_opt(additional_parsers: Optional[List] = None):
|
|
43 |
opt = parser.parse_known_args()[0]
|
44 |
|
45 |
# load dataset argument file
|
46 |
-
load_dataset_arguments(opt)
|
47 |
|
48 |
# user-defined warnings and assertions
|
49 |
if opt.decoder.lower() not in ["c1"]:
|
|
|
10 |
from termcolor import cprint
|
11 |
|
12 |
|
13 |
+
def load_dataset_arguments(cfg_path, opt):
|
14 |
+
if opt.load is None and cfg_path is None:
|
15 |
return
|
16 |
|
17 |
# exclude parameters assigned in the command
|
|
|
24 |
arguments = []
|
25 |
|
26 |
# load parameters in the yaml file
|
27 |
+
if cfg_path is not None:
|
28 |
+
opt.load = cfg_path
|
29 |
+
else:
|
30 |
+
assert os.path.exists(opt.load)
|
31 |
with open(opt.load, "r") as f:
|
32 |
yaml_arguments = yaml.safe_load(f)
|
33 |
# TODO this should be verified
|
|
|
36 |
setattr(opt, k, v)
|
37 |
|
38 |
|
39 |
+
def get_opt(cfg_path: Optional[str] = None, additional_parsers: Optional[List] = None):
|
40 |
parents = [get_arguments_parser()]
|
41 |
if additional_parsers:
|
42 |
parents.extend(additional_parsers)
|
|
|
46 |
opt = parser.parse_known_args()[0]
|
47 |
|
48 |
# load dataset argument file
|
49 |
+
load_dataset_arguments(cfg_path, opt)
|
50 |
|
51 |
# user-defined warnings and assertions
|
52 |
if opt.decoder.lower() not in ["c1"]:
|
requirements.txt
CHANGED
@@ -26,3 +26,4 @@ timm==0.9.12
|
|
26 |
torch==1.12.1+cu116
|
27 |
torchvision==0.13.1+cu116
|
28 |
tqdm==4.64.1
|
|
|
|
26 |
torch==1.12.1+cu116
|
27 |
torchvision==0.13.1+cu116
|
28 |
tqdm==4.64.1
|
29 |
+
markupsafe==2.0.1
|