52Hz commited on
Commit
6185242
1 Parent(s): 03673d4

Update main_test_CMFNet.py

Browse files
Files changed (1) hide show
  1. main_test_CMFNet.py +13 -2
main_test_CMFNet.py CHANGED
@@ -13,6 +13,7 @@ import torch.nn.functional as F
13
  from natsort import natsorted
14
  from model.CMFNet import CMFNet
15
 
 
16
  def main():
17
  parser = argparse.ArgumentParser(description='Demo Image Deraindrop')
18
  parser.add_argument('--input_dir', default='test/', type=str, help='Input images')
@@ -64,7 +65,7 @@ def main():
64
 
65
  f = os.path.splitext(os.path.split(file_)[-1])[0]
66
  save_img((os.path.join(out_dir, f + '.png')), restored)
67
-
68
 
69
 
70
  def save_img(filepath, img):
@@ -82,7 +83,17 @@ def load_checkpoint(model, weights):
82
  name = k[7:] # remove `module.`
83
  new_state_dict[name] = v
84
  model.load_state_dict(new_state_dict)
85
-
 
 
 
 
 
 
 
 
 
 
86
 
87
  if __name__ == '__main__':
88
  main()
 
13
  from natsort import natsorted
14
  from model.CMFNet import CMFNet
15
 
16
+
17
  def main():
18
  parser = argparse.ArgumentParser(description='Demo Image Deraindrop')
19
  parser.add_argument('--input_dir', default='test/', type=str, help='Input images')
 
65
 
66
  f = os.path.splitext(os.path.split(file_)[-1])[0]
67
  save_img((os.path.join(out_dir, f + '.png')), restored)
68
+ clean_folder(inp_dir)
69
 
70
 
71
  def save_img(filepath, img):
 
83
  name = k[7:] # remove `module.`
84
  new_state_dict[name] = v
85
  model.load_state_dict(new_state_dict)
86
+
87
+ def clean_folder(folder):
88
+ for filename in os.listdir(folder):
89
+ file_path = os.path.join(folder, filename)
90
+ try:
91
+ if os.path.isfile(file_path) or os.path.islink(file_path):
92
+ os.unlink(file_path)
93
+ elif os.path.isdir(file_path):
94
+ shutil.rmtree(file_path)
95
+ except Exception as e:
96
+ print('Failed to delete %s. Reason: %s' % (file_path, e))
97
 
98
  if __name__ == '__main__':
99
  main()