Nicolò commited on
Commit
5105c39
1 Parent(s): 7d6d113

add more friendly printing

Browse files
Files changed (1) hide show
  1. gan_vs_real_detector.py +12 -3
gan_vs_real_detector.py CHANGED
@@ -17,6 +17,8 @@ from utils import architectures
17
  from utils.python_patch_extractor.PatchExtractor import PatchExtractor
18
  from PIL import Image
19
 
 
 
20
 
21
  class Detector:
22
  def __init__(self):
@@ -29,11 +31,11 @@ class Detector:
29
  self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
30
 
31
  self.nets = []
32
- for i in range(5):
33
  # Instantiate and load network
34
  network_class = getattr(architectures, 'EfficientNetB4')
35
  net = network_class(n_classes=2, pretrained=False).eval().to(self.device)
36
- print('Loading model...')
37
  state_tmp = torch.load(self.weights_path_list[i], map_location='cpu')
38
 
39
  if 'net' not in state_tmp.keys():
@@ -75,6 +77,7 @@ class Detector:
75
  print('Omitting alpha channel')
76
  img = img[:, :, :3]
77
 
 
78
  img_net_scores = []
79
  for net_idx, net in enumerate(self.nets):
80
 
@@ -113,8 +116,14 @@ class Detector:
113
 
114
 
115
  def main():
 
 
 
 
 
 
116
  # debug img_path on fermi:
117
- img_path = '/home/nbonettini/nvidia_temp/nvidia-alias-free-gan/faces/alias-free-r-afhqv2-512x512/seed40000.png'
118
 
119
  detector = Detector()
120
  score = detector.synth_real_detector(img_path)
 
17
  from utils.python_patch_extractor.PatchExtractor import PatchExtractor
18
  from PIL import Image
19
 
20
+ import argparse
21
+
22
 
23
  class Detector:
24
  def __init__(self):
 
31
  self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
32
 
33
  self.nets = []
34
+ for i, l in enumerate('ABCDE'):
35
  # Instantiate and load network
36
  network_class = getattr(architectures, 'EfficientNetB4')
37
  net = network_class(n_classes=2, pretrained=False).eval().to(self.device)
38
+ print(f'Loading model {l}...')
39
  state_tmp = torch.load(self.weights_path_list[i], map_location='cpu')
40
 
41
  if 'net' not in state_tmp.keys():
 
77
  print('Omitting alpha channel')
78
  img = img[:, :, :3]
79
 
80
+ print('Computing scores...')
81
  img_net_scores = []
82
  for net_idx, net in enumerate(self.nets):
83
 
 
116
 
117
 
118
  def main():
119
+
120
+ parser = argparse.ArgumentParser()
121
+ parser.add_argument('--img_path', help='Pat to the test image', required=True)
122
+ args = parser.parse_args()
123
+
124
+ img_path = args.img_path
125
  # debug img_path on fermi:
126
+ # img_path = '/home/nbonettini/nvidia_temp/nvidia-alias-free-gan/faces/alias-free-r-afhqv2-512x512/seed40000.png'
127
 
128
  detector = Detector()
129
  score = detector.synth_real_detector(img_path)