Koi953215 commited on
Commit
79910d2
1 Parent(s): 32f7278
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -15,6 +15,7 @@ from NaRCan_model import Homography, Siren
15
  from util import get_mgrid, apply_homography, jacobian, VideoFitting, TestVideoFitting
16
 
17
 
 
18
 
19
  def get_example():
20
  case = [
@@ -116,9 +117,9 @@ def NaRCan_make_video(edit_canonical, pth_path, frames_path):
116
  # load NaRCan model
117
  checkpoint_g_old = torch.load(os.path.join(pth_path, "homography_g.pth"))
118
  checkpoint_g = torch.load(os.path.join(pth_path, "mlp_g.pth"))
119
- g_old = Homography(hidden_features=256, hidden_layers=2).cuda()
120
  g = Siren(in_features=3, out_features=2, hidden_features=256,
121
- hidden_layers=5, outermost_linear=True).cuda()
122
 
123
  g_old.load_state_dict(checkpoint_g_old)
124
  g.load_state_dict(checkpoint_g)
@@ -135,7 +136,7 @@ def NaRCan_make_video(edit_canonical, pth_path, frames_path):
135
  videoloader = DataLoader(v, batch_size=1, pin_memory=True, num_workers=0)
136
 
137
  model_input, ground_truth = next(iter(videoloader))
138
- model_input, ground_truth = model_input[0].cuda(), ground_truth[0].cuda()
139
 
140
  myoutput = None
141
  data_len = len(os.listdir(frames_path))
@@ -156,7 +157,7 @@ def NaRCan_make_video(edit_canonical, pth_path, frames_path):
156
  # use canonical to reconstruct
157
  w, h = v.W, v.H
158
  canonical_img = np.array(edit_canonical.convert('RGB'))
159
- canonical_img = torch.from_numpy(canonical_img).float().cuda()
160
  h_c, w_c = canonical_img.shape[:2]
161
  grid_new = xy_.clone()
162
  grid_new[..., 1] = xy_[..., 0] / 1.5
@@ -204,7 +205,7 @@ def edit_with_pnp(input_video, prompt, num_steps, guidance_scale, seed, n_prompt
204
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
205
  "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
206
  )
207
- pipe.to("cuda")
208
  # lineart
209
  processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
210
  processor_partial = partial(processor, coarse=False)
@@ -231,7 +232,7 @@ def edit_with_pnp(input_video, prompt, num_steps, guidance_scale, seed, n_prompt
231
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
232
  "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
233
  )
234
- pipe.to("cuda")
235
  # canny
236
  canonical_image = cv2.imread(image_path)
237
  canonical_image = cv2.cvtColor(canonical_image, cv2.COLOR_BGR2RGB)
 
15
  from util import get_mgrid, apply_homography, jacobian, VideoFitting, TestVideoFitting
16
 
17
 
18
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
 
20
  def get_example():
21
  case = [
 
117
  # load NaRCan model
118
  checkpoint_g_old = torch.load(os.path.join(pth_path, "homography_g.pth"))
119
  checkpoint_g = torch.load(os.path.join(pth_path, "mlp_g.pth"))
120
+ g_old = Homography(hidden_features=256, hidden_layers=2).to(device)
121
  g = Siren(in_features=3, out_features=2, hidden_features=256,
122
+ hidden_layers=5, outermost_linear=True).to(device)
123
 
124
  g_old.load_state_dict(checkpoint_g_old)
125
  g.load_state_dict(checkpoint_g)
 
136
  videoloader = DataLoader(v, batch_size=1, pin_memory=True, num_workers=0)
137
 
138
  model_input, ground_truth = next(iter(videoloader))
139
+ model_input, ground_truth = model_input[0].to(device), ground_truth[0].to(device)
140
 
141
  myoutput = None
142
  data_len = len(os.listdir(frames_path))
 
157
  # use canonical to reconstruct
158
  w, h = v.W, v.H
159
  canonical_img = np.array(edit_canonical.convert('RGB'))
160
+ canonical_img = torch.from_numpy(canonical_img).float().to(device)
161
  h_c, w_c = canonical_img.shape[:2]
162
  grid_new = xy_.clone()
163
  grid_new[..., 1] = xy_[..., 0] / 1.5
 
205
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
206
  "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
207
  )
208
+ pipe.to(device)
209
  # lineart
210
  processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
211
  processor_partial = partial(processor, coarse=False)
 
232
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
233
  "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
234
  )
235
+ pipe.to(device)
236
  # canny
237
  canonical_image = cv2.imread(image_path)
238
  canonical_image = cv2.cvtColor(canonical_image, cv2.COLOR_BGR2RGB)