ayush2003 commited on
Commit
f6b1175
·
1 Parent(s): 95fd861

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -493,13 +493,15 @@ from torchvision import transforms, utils
493
  from matplotlib import pyplot as plt
494
  import numpy as np
495
 
 
 
 
496
 
497
  def predict_pose(test_image):
498
  img = cv2.resize(test_image, (32,32))
499
  convert_tensor = transforms.ToTensor()
500
  tensor_img = convert_tensor(img)
501
  tensor_img = tensor_img[None,:,:,:]
502
- model.eval()
503
 
504
  outputs = model(tensor_img)
505
 
 
493
  from matplotlib import pyplot as plt
494
  import numpy as np
495
 
496
+ model = SimpleCNN()
497
+ model.load_state_dict(torch.load("model.pth"))
498
+ model.eval()
499
 
500
  def predict_pose(test_image):
501
  img = cv2.resize(test_image, (32,32))
502
  convert_tensor = transforms.ToTensor()
503
  tensor_img = convert_tensor(img)
504
  tensor_img = tensor_img[None,:,:,:]
 
505
 
506
  outputs = model(tensor_img)
507