Tej3 commited on
Commit
d0f8eba
1 Parent(s): 71bd54f

Updating app file

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -61,13 +61,13 @@ def infer(model,data, notes):
61
  data= torch.tensor(data)
62
  if model == "CNN":
63
  model = MMCNN_CAT()
64
- checkpoint = torch.load(MMCNN_CAT_ckpt_path)
65
  model.load_state_dict(checkpoint['model_state_dict'])
66
  data = data.transpose(1,2).float()
67
 
68
  elif model == "RNN":
69
  model = MMRNN(device='cpu')
70
- model.load_state_dict(torch.load(MMRNN_ckpt_path)['model_state_dict'])
71
  data = data.float()
72
  model.eval()
73
  outputs, predicted = predict(model, data, embed_notes, device='cpu')
 
61
  data= torch.tensor(data)
62
  if model == "CNN":
63
  model = MMCNN_CAT()
64
+ checkpoint = torch.load(MMCNN_CAT_ckpt_path, map_location="cpu")
65
  model.load_state_dict(checkpoint['model_state_dict'])
66
  data = data.transpose(1,2).float()
67
 
68
  elif model == "RNN":
69
  model = MMRNN(device='cpu')
70
+ model.load_state_dict(torch.load(MMRNN_ckpt_path, map_location="cpu")['model_state_dict'])
71
  data = data.float()
72
  model.eval()
73
  outputs, predicted = predict(model, data, embed_notes, device='cpu')