cocktailpeanut commited on
Commit
ce91763
1 Parent(s): 7639b35
src/utils/frame_interpolation.py CHANGED
@@ -5,6 +5,13 @@ import torch
5
  import bisect
6
  import shutil
7
 
 
 
 
 
 
 
 
8
  def init_frame_interpolation_model():
9
  print("Initializing frame interpolation model")
10
  checkpoint_name = os.path.join("./pretrained_model/film_net_fp16.pt")
@@ -12,7 +19,7 @@ def init_frame_interpolation_model():
12
  model = torch.load(checkpoint_name, map_location='cpu')
13
  model.eval()
14
  model = model.half()
15
- model = model.to(device="cuda")
16
  return model
17
 
18
 
@@ -54,8 +61,8 @@ def batch_images_interpolation_tool(input_file, model, fps, inter_frames=1):
54
 
55
  x0 = x0.half()
56
  x1 = x1.half()
57
- x0 = x0.cuda()
58
- x1 = x1.cuda()
59
 
60
  dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])
61
 
@@ -87,4 +94,4 @@ def batch_images_interpolation_tool(input_file, model, fps, inter_frames=1):
87
 
88
  shutil.rmtree(image_save_dir)
89
 
90
- return video_save_dir
 
5
  import bisect
6
  import shutil
7
 
8
+ if torch.backends.mps.is_available():
9
+ device = "mps"
10
+ elif torch.cuda.is_available():
11
+ device = "cuda"
12
+ else:
13
+ device = "cpu"
14
+
15
  def init_frame_interpolation_model():
16
  print("Initializing frame interpolation model")
17
  checkpoint_name = os.path.join("./pretrained_model/film_net_fp16.pt")
 
19
  model = torch.load(checkpoint_name, map_location='cpu')
20
  model.eval()
21
  model = model.half()
22
+ model = model.to(device=device)
23
  return model
24
 
25
 
 
61
 
62
  x0 = x0.half()
63
  x1 = x1.half()
64
+ x0 = x0.to(device)
65
+ x1 = x1.to(device)
66
 
67
  dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])
68
 
 
94
 
95
  shutil.rmtree(image_save_dir)
96
 
97
+ return video_save_dir
src/utils/util.py CHANGED
@@ -12,6 +12,13 @@ import torchvision
12
  from einops import rearrange
13
  from PIL import Image
14
 
 
 
 
 
 
 
 
15
 
16
  def seed_everything(seed):
17
  import random
@@ -19,7 +26,8 @@ def seed_everything(seed):
19
  import numpy as np
20
 
21
  torch.manual_seed(seed)
22
- torch.cuda.manual_seed_all(seed)
 
23
  np.random.seed(seed % (2**32))
24
  random.seed(seed)
25
 
 
12
  from einops import rearrange
13
  from PIL import Image
14
 
15
+ if torch.backends.mps.is_available():
16
+ device = "mps"
17
+ elif torch.cuda.is_available():
18
+ device = "cuda"
19
+ else:
20
+ device = "cpu"
21
+
22
 
23
  def seed_everything(seed):
24
  import random
 
26
  import numpy as np
27
 
28
  torch.manual_seed(seed)
29
+ if device == "cuda":
30
+ torch.cuda.manual_seed_all(seed)
31
  np.random.seed(seed % (2**32))
32
  random.seed(seed)
33