birdortyedi commited on
Commit
2ee4de9
1 Parent(s): 74de975

hf hub added

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -18,14 +18,15 @@ cfg.MODEL.CKPT = model_path
18
  net, _ = build_model(cfg)
19
  net = net.eval()
20
  vgg16 = models.vgg16(pretrained=True).features.eval()
 
21
 
22
 
23
- def load_checkpoints_from_ckpt(ckpt_path):
24
- checkpoints = torch.load(ckpt_path, map_location=torch.device('cuda'))
25
  net.load_state_dict(checkpoints["ifr"])
26
 
27
 
28
- load_checkpoints_from_ckpt(cfg.MODEL.CKPT)
29
 
30
 
31
  def filter_removal(img):
 
18
  net, _ = build_model(cfg)
19
  net = net.eval()
20
  vgg16 = models.vgg16(pretrained=True).features.eval()
21
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
22
 
23
 
24
+ def load_checkpoints_from_ckpt(ckpt_path, device):
25
+ checkpoints = torch.load(ckpt_path, map_location=device)
26
  net.load_state_dict(checkpoints["ifr"])
27
 
28
 
29
+ load_checkpoints_from_ckpt(cfg.MODEL.CKPT, device)
30
 
31
 
32
  def filter_removal(img):