tidus2102 commited on
Commit
c9a1cd9
·
verified ·
1 Parent(s): 702b70a

Create Real-ESRGAN_x2plus__convert_pth_to_onnx.py

Browse files
Real-ESRGAN_x2plus__convert_pth_to_onnx.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.onnx as onnx
3
+ from basicsr.archs.rrdbnet_arch import RRDBNet
4
+
5
+ # Load the PyTorch model
6
+ device = torch.device('cpu')
7
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
8
+
9
+ # Load the state dictionary
10
+ state_dict = torch.load('Real-ESRGAN_x2plus.pth', map_location=device)
11
+
12
+ # Load the state dictionary
13
+ model.load_state_dict(state_dict['params_ema'])
14
+ model.train(False)
15
+
16
+ # Set the model to evaluation mode
17
+ model.eval()
18
+
19
+ # Define the input shape
20
+ input_shape = (1, 3, 64, 64) # batch_size, channels, height, width
21
+
22
+ # Create a dummy input tensor
23
+ dummy_input = torch.randn(input_shape)
24
+
25
+ # Convert the model to ONNX
26
+ onnx.export(model,
27
+ dummy_input,
28
+ 'Real-ESRGAN_x2plus.onnx',
29
+ opset_version=11,
30
+ input_names=['input'],
31
+ output_names=['output'],
32
+ dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})