swzamir committed on
Commit 5670d74
1 Parent(s): c2af834

Update demo_gradio.py

Files changed (1)
  1. demo_gradio.py +21 -28
demo_gradio.py CHANGED
@@ -6,10 +6,8 @@
 import torch
 import torch.nn.functional as F
 import os
-from runpy import run_path
 from skimage import img_as_ubyte
 import cv2
-from tqdm import tqdm
 import argparse
 
 parser = argparse.ArgumentParser(description='Test Restormer on your own images')
@@ -25,45 +23,36 @@ parser.add_argument('--task', required=True, type=str, help='Task to run', choic
 args = parser.parse_args()
 
 
-def get_weights_and_parameters(task, parameters):
-    if task == 'Motion_Deblurring':
-        weights = os.path.join('Motion_Deblurring', 'pretrained_models', 'motion_deblurring.pth')
-    elif task == 'Single_Image_Defocus_Deblurring':
-        weights = os.path.join('Defocus_Deblurring', 'pretrained_models', 'single_image_defocus_deblurring.pth')
-    elif task == 'Deraining':
-        weights = os.path.join('Deraining', 'pretrained_models', 'deraining.pth')
-    elif task == 'Real_Denoising':
-        weights = os.path.join('Denoising', 'pretrained_models', 'real_denoising.pth')
-        parameters['LayerNorm_type'] = 'BiasFree'
-    return weights, parameters
-
 task = args.task
 out_dir = os.path.join(args.result_dir, task)
 
 os.makedirs(out_dir, exist_ok=True)
 
-# Get model weights and parameters
-parameters = {'inp_channels':3, 'out_channels':3, 'dim':48, 'num_blocks':[4,6,6,8], 'num_refinement_blocks':4, 'heads':[1,2,4,8], 'ffn_expansion_factor':2.66, 'bias':False, 'LayerNorm_type':'WithBias', 'dual_pixel_task':False}
-weights, parameters = get_weights_and_parameters(task, parameters)
-
-load_arch = run_path(os.path.join('basicsr', 'models', 'archs', 'restormer_arch.py'))
-model = load_arch['Restormer'](**parameters)
 
+if task == 'Motion_Deblurring':
+    model = torch.jit.load('motion_deblurring.pt')
+elif task == 'Single_Image_Defocus_Deblurring':
+    model = torch.jit.load('single_image_defocus_deblurring.pt')
+elif task == 'Deraining':
+    model = torch.jit.load('deraining.pt')
+elif task == 'Real_Denoising':
+    model = torch.jit.load('real_denoising.pt')
+
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 # device = torch.device('cpu')
+# stx()
 
 model = model.to(device)
-checkpoint = torch.load(weights)
-model.load_state_dict(checkpoint['params'])
-
-model.eval()
 
 
 img_multiple_of = 8
 
 
 with torch.inference_mode():
-
+    if torch.cuda.is_available():
+        torch.cuda.ipc_collect()
+        torch.cuda.empty_cache()
+
     img = cv2.cvtColor(cv2.imread(args.input_path), cv2.COLOR_BGR2RGB)
 
     input_ = torch.from_numpy(img).float().div(255.).permute(2,0,1).unsqueeze(0).to(device)
@@ -74,10 +63,14 @@ with torch.inference_mode():
     padh = H-h if h%img_multiple_of!=0 else 0
     padw = W-w if w%img_multiple_of!=0 else 0
    input_ = F.pad(input_, (0,padw,0,padh), 'reflect')
-
+
+    # print(h,w)
     restored = torch.clamp(model(input_),0,1)
-
+
     # Unpad the output
     restored = img_as_ubyte(restored[:,:,:h,:w].permute(0, 2, 3, 1).cpu().detach().numpy()[0])
 
-    cv2.imwrite(os.path.join(out_dir, os.path.split(args.input_path)[-1]),cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))
+    out_path = os.path.join(out_dir, os.path.split(args.input_path)[-1])
+    cv2.imwrite(out_path,cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))
+
+    # print(f"\nRestored images are saved at {out_dir}")
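Note: the updated loader expects pre-exported TorchScript files (motion_deblurring.pt, single_image_defocus_deblurring.pt, deraining.pt, real_denoising.pt) next to the script, while the commit removes the code that built the Restormer architecture and read the .pth checkpoints. The export step itself is not part of this commit; the following is only a minimal sketch of how such a file could be produced, assuming the pre-commit repo layout (basicsr/models/archs/restormer_arch.py and the pretrained .pth checkpoints) is available.

# Hypothetical export step (not in this commit): trace a Restormer checkpoint
# into the TorchScript file that demo_gradio.py now loads with torch.jit.load.
import os
import torch
from runpy import run_path

# Same architecture parameters the removed code used (Motion_Deblurring shown here).
parameters = {'inp_channels': 3, 'out_channels': 3, 'dim': 48, 'num_blocks': [4, 6, 6, 8],
              'num_refinement_blocks': 4, 'heads': [1, 2, 4, 8], 'ffn_expansion_factor': 2.66,
              'bias': False, 'LayerNorm_type': 'WithBias', 'dual_pixel_task': False}

load_arch = run_path(os.path.join('basicsr', 'models', 'archs', 'restormer_arch.py'))
model = load_arch['Restormer'](**parameters)

weights = os.path.join('Motion_Deblurring', 'pretrained_models', 'motion_deblurring.pth')
checkpoint = torch.load(weights, map_location='cpu')
model.load_state_dict(checkpoint['params'])
model.eval()

# Trace with a dummy input whose sides are a multiple of 8, matching the demo's padding rule.
example = torch.rand(1, 3, 256, 256)
traced = torch.jit.trace(model, example)
traced.save('motion_deblurring.pt')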
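The command-line interface is unchanged by this commit. A typical call, assuming the elided parser lines define --input_path and --result_dir to match the args.input_path and args.result_dir attributes used above (the image path here is just a placeholder):

python demo_gradio.py --task Deraining --input_path ./your_image.png --result_dir ./results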