zhengrongzhang commited on
Commit
30c4d88
1 Parent(s): da9195c

change onnx to NHWC (#1)

Browse files

- change onnx to NHWC (9f86314792bc63c7d4327617e9dd76b69ed0c1b6)

RCAN_int8.onnx → RCAN_int8_NHWC.onnx RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1f80a5945e9d7bd9da2625aeec430dad3ba1123788edf36416f80ef59207c804
3
- size 445505
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a1cad5da6396a4c812bb5b0d60c1470a1e2de1b9b4e8fe58ce4132972b164f7
3
+ size 445692
data/data_tiling.py CHANGED
@@ -18,7 +18,7 @@ def tiling_inference(session, lr, overlapping, patch_size):
18
  w_idx = w_idx if w_idx + patch_size[1] <= w else w - patch_size[1]
19
 
20
  tilling_lr = lr[..., h_idx: h_idx+patch_size[0], w_idx: w_idx+patch_size[1]]
21
- sr_tiling = session.run(None, {session.get_inputs()[0].name: tilling_lr})[0]
22
 
23
  left, right, top, bottom = 0, patch_size[1], 0, patch_size[0]
24
  left += overlapping//2
 
18
  w_idx = w_idx if w_idx + patch_size[1] <= w else w - patch_size[1]
19
 
20
  tilling_lr = lr[..., h_idx: h_idx+patch_size[0], w_idx: w_idx+patch_size[1]]
21
+ sr_tiling = session.run(None, {session.get_inputs()[0].name: tilling_lr.transpose(0,2,3,1)})[0].transpose(0,3,1,2)
22
 
23
  left, right, top, bottom = 0, patch_size[1], 0, patch_size[0]
24
  left += overlapping//2
eval_onnx.py CHANGED
@@ -19,7 +19,7 @@ class Configs():
19
  # ipu test or cpu, you need to provide onnx path
20
  parser.add_argument('--ipu', action='store_true',
21
  help='use ipu')
22
- parser.add_argument('--onnx_path', type=str, default='RCAN_int8.onnx',
23
  help='onnx path')
24
  parser.add_argument('--provider_config', type=str, default=None,
25
  help='provider config path')
 
19
  # ipu test or cpu, you need to provide onnx path
20
  parser.add_argument('--ipu', action='store_true',
21
  help='use ipu')
22
+ parser.add_argument('--onnx_path', type=str, default='RCAN_int8_NHWC.onnx',
23
  help='onnx path')
24
  parser.add_argument('--provider_config', type=str, default=None,
25
  help='provider config path')
infer_onnx.py CHANGED
@@ -31,7 +31,7 @@ def main(args):
31
 
32
  if __name__ == '__main__':
33
  parser = argparse.ArgumentParser(description='RCAN SISR')
34
- parser.add_argument('--onnx_path', type=str, default='RCAN_int8.onnx',
35
  help='onnx path')
36
  parser.add_argument('--image_path', default='test_data/test.png',
37
  help='path of your image')
 
31
 
32
  if __name__ == '__main__':
33
  parser = argparse.ArgumentParser(description='RCAN SISR')
34
+ parser.add_argument('--onnx_path', type=str, default='RCAN_int8_NHWC.onnx',
35
  help='onnx path')
36
  parser.add_argument('--image_path', default='test_data/test.png',
37
  help='path of your image')