Nigus commited on
Commit
0d74fb6
1 Parent(s): 23b6452

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -0
app.py CHANGED
@@ -9,6 +9,15 @@ torch.hub.download_url_to_file("https://github.com/pytorch/hub/raw/master/images
9
  model = torch.hub.load('pytorch/vision:v0.10.0', 'shufflenet_v2_x1_0', pretrained=True)
10
  model.load_state_dict(torch.load('shufflenetv2_x1-5666bf0f80.pth'))
11
 
 
 
 
 
 
 
 
 
 
12
  os.system("wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")
13
 
14
  def inference(input_image):
 
9
  model = torch.hub.load('pytorch/vision:v0.10.0', 'shufflenet_v2_x1_0', pretrained=True)
10
  model.load_state_dict(torch.load('shufflenetv2_x1-5666bf0f80.pth'))
11
 
12
+ # To export
13
+ model = torch.hub.load('pytorch/vision:v0.10.0', 'shufflenet_v2_x1_0', pretrained=True).eval()
14
+ traced_graph = torch.jit.trace(model, torch.randn(1, 3, H, W))
15
+ traced_graph.save('shufflenet.pth')
16
+
17
+ # To load
18
+ model = torch.jit.load('shufflenet.pth').eval().to(device)
19
+
20
+
21
  os.system("wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")
22
 
23
  def inference(input_image):