NikeZoldyck commited on
Commit
88bec8b
1 Parent(s): 7f88488

Update utils/shared_utils.py

Browse files
Files changed (1) hide show
  1. utils/shared_utils.py +4 -2
utils/shared_utils.py CHANGED
@@ -21,9 +21,9 @@ from utils.photo_smooth import Propagator
21
  root = Path.cwd()
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
  # Load model
24
- p_wct = PhotoWCT()
25
  p_wct.load_state_dict(torch.load(root/"models/components/photo_wct.pth"))
26
- p_pro = Propagator()
27
  stylization_module=p_wct
28
  smoothing_module=p_pro
29
 
@@ -115,6 +115,8 @@ def style_transfer(cont_img,styl_img):
115
  return stylized_img
116
 
117
  def smoother(stylized_img, over_img):
 
 
118
  final_img = smoothing_module.process(stylized_img, over_img)
119
  #final_img = smooth_filter(stylized_img, over_img, f_radius=15, f_edge=1e-1)
120
  return final_img
 
21
  root = Path.cwd()
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
  # Load model
24
+ p_wct = PhotoWCT().to(device)
25
  p_wct.load_state_dict(torch.load(root/"models/components/photo_wct.pth"))
26
+ p_pro = Propagator().to(device)
27
  stylization_module=p_wct
28
  smoothing_module=p_pro
29
 
 
115
  return stylized_img
116
 
117
  def smoother(stylized_img, over_img):
118
+ if device == 'cuda':
119
+ smoothing_module.to(device)
120
  final_img = smoothing_module.process(stylized_img, over_img)
121
  #final_img = smooth_filter(stylized_img, over_img, f_radius=15, f_edge=1e-1)
122
  return final_img