lalala125 commited on
Commit
4101738
1 Parent(s): 0be014d

[Release] cpu support

Browse files
Files changed (1) hide show
  1. demo_img.py +1 -1
demo_img.py CHANGED
@@ -20,7 +20,7 @@ def img2vid(model_type, img0, img1, frame_ratio, iters):
20
  ckpt_path = hf_hub_download(repo_id='lalala125/AMT', filename=f'{model_type.lower()}.pth')
21
  print(model_type)
22
  ckpt = torch.load(ckpt_path)
23
- model.load_state_dict(ckpt['state_dict'])
24
  model.eval()
25
  img0_t = img2tensor(img0).to(device)
26
  img1_t = img2tensor(img1).to(device)
 
20
  ckpt_path = hf_hub_download(repo_id='lalala125/AMT', filename=f'{model_type.lower()}.pth')
21
  print(model_type)
22
  ckpt = torch.load(ckpt_path)
23
+ model.load_state_dict(ckpt['state_dict'], map_location=torch.device('cpu'))
24
  model.eval()
25
  img0_t = img2tensor(img0).to(device)
26
  img1_t = img2tensor(img1).to(device)