cycool29 commited on
Commit
bc9b6cb
·
1 Parent(s): 20b1517
Files changed (1) hide show
  1. predict.py +1 -2
predict.py CHANGED
@@ -25,8 +25,7 @@ from configs import *
25
  model = MODEL.to(DEVICE)
26
  # model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
27
  model.load_state_dict(
28
- torch.load("output/checkpoints/EfficientNetB3WithDropout.pth", map_location=DEVICE)
29
- )
30
  model.eval()
31
 
32
  torch.set_grad_enabled(False)
 
25
  model = MODEL.to(DEVICE)
26
  # model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
27
  model.load_state_dict(
28
+ torch.load(f"output/checkpoints/{MODEL.__class__.__name__}.pth", map_location=DEVICE))
 
29
  model.eval()
30
 
31
  torch.set_grad_enabled(False)