Spaces:
Runtime error
Runtime error
birdortyedi
commited on
Commit
•
2ee4de9
1
Parent(s):
74de975
hf hub added
Browse files
app.py
CHANGED
@@ -18,14 +18,15 @@ cfg.MODEL.CKPT = model_path
|
|
18 |
net, _ = build_model(cfg)
|
19 |
net = net.eval()
|
20 |
vgg16 = models.vgg16(pretrained=True).features.eval()
|
|
|
21 |
|
22 |
|
23 |
-
def load_checkpoints_from_ckpt(ckpt_path):
|
24 |
-
checkpoints = torch.load(ckpt_path, map_location=
|
25 |
net.load_state_dict(checkpoints["ifr"])
|
26 |
|
27 |
|
28 |
-
load_checkpoints_from_ckpt(cfg.MODEL.CKPT)
|
29 |
|
30 |
|
31 |
def filter_removal(img):
|
|
|
18 |
net, _ = build_model(cfg)
|
19 |
net = net.eval()
|
20 |
vgg16 = models.vgg16(pretrained=True).features.eval()
|
21 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
22 |
|
23 |
|
24 |
+
def load_checkpoints_from_ckpt(ckpt_path, device):
|
25 |
+
checkpoints = torch.load(ckpt_path, map_location=device)
|
26 |
net.load_state_dict(checkpoints["ifr"])
|
27 |
|
28 |
|
29 |
+
load_checkpoints_from_ckpt(cfg.MODEL.CKPT, device)
|
30 |
|
31 |
|
32 |
def filter_removal(img):
|