jjeamin commited on
Commit
ab42d96
1 Parent(s): c70c0c7

Fix device cuda -> cpu

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -8,13 +8,15 @@ from tqdm.notebook import tqdm
8
  import gradio as gr
9
  import torch
10
 
 
 
11
  image_size = 512
12
 
13
  means = [0.5, 0.5, 0.5]
14
  stds = [0.5, 0.5, 0.5]
15
 
16
  model_path = hf_hub_download(repo_id="jjeamin/ArcaneStyleTransfer", filename="pytorch_model.bin")
17
- style_transfer = torch.jit.load(model_path).eval().cuda().half()
18
 
19
  mtcnn = MTCNN(image_size=image_size, margin=80)
20
 
@@ -76,8 +78,8 @@ def scale_by_face_size(_img, max_res=1_500_000, target_face=256, fix_ratio=0, ma
76
  img_resized = scale(boxes, _img, max_res, target_face, fix_ratio, max_upscale, VERBOSE)
77
  return img_resized
78
 
79
- t_stds = torch.tensor(stds).cuda().half()[:,None,None]
80
- t_means = torch.tensor(means).cuda().half()[:,None,None]
81
 
82
  img_transforms = transforms.Compose([
83
  transforms.ToTensor(),
@@ -87,7 +89,7 @@ def tensor2im(var):
87
  return var.mul(t_stds).add(t_means).mul(255.).clamp(0,255).permute(1,2,0)
88
 
89
  def proc_pil_img(input_image):
90
- transformed_image = img_transforms(input_image)[None,...].cuda().half()
91
 
92
  with torch.no_grad():
93
  result_image = style_transfer(transformed_image)[0]
 
8
  import gradio as gr
9
  import torch
10
 
11
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
12
+
13
  image_size = 512
14
 
15
  means = [0.5, 0.5, 0.5]
16
  stds = [0.5, 0.5, 0.5]
17
 
18
  model_path = hf_hub_download(repo_id="jjeamin/ArcaneStyleTransfer", filename="pytorch_model.bin")
19
+ style_transfer = torch.jit.load(model_path).eval().to(device).half()
20
 
21
  mtcnn = MTCNN(image_size=image_size, margin=80)
22
 
 
78
  img_resized = scale(boxes, _img, max_res, target_face, fix_ratio, max_upscale, VERBOSE)
79
  return img_resized
80
 
81
+ t_stds = torch.tensor(stds).to(device).half()[:,None,None]
82
+ t_means = torch.tensor(means).to(device).half()[:,None,None]
83
 
84
  img_transforms = transforms.Compose([
85
  transforms.ToTensor(),
 
89
  return var.mul(t_stds).add(t_means).mul(255.).clamp(0,255).permute(1,2,0)
90
 
91
  def proc_pil_img(input_image):
92
+ transformed_image = img_transforms(input_image)[None,...].to(device).half()
93
 
94
  with torch.no_grad():
95
  result_image = style_transfer(transformed_image)[0]