Ahsen Khaliq commited on
Commit
cb897c1
1 Parent(s): f3fa32b

gpu updates

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -90,8 +90,8 @@ size = 256
90
  means = [0.485, 0.456, 0.406]
91
  stds = [0.229, 0.224, 0.225]
92
 
93
- t_stds = torch.tensor(stds).cpu()[:,None,None]
94
- t_means = torch.tensor(means).cpu()[:,None,None]
95
 
96
  def makeEven(_x):
97
  return int(_x) if (_x % 2 == 0) else int(_x+1)
@@ -104,7 +104,7 @@ def tensor2im(var):
104
  return var.mul(t_stds).add(t_means).mul(255.).clamp(0,255).permute(1,2,0)
105
 
106
  def proc_pil_img(input_image, model):
107
- transformed_image = img_transforms(input_image)[None,...].cpu()
108
 
109
  with torch.no_grad():
110
  result_image = model(transformed_image)[0]; print(result_image.shape)
@@ -112,6 +112,9 @@ def proc_pil_img(input_image, model):
112
  output_image = output_image.detach().cpu().numpy().astype('uint8')
113
  output_image = PIL.Image.fromarray(output_image)
114
  return output_image
 
 
 
115
 
116
  def fit(img,maxsize=512):
117
  maxdim = max(*img.size)
@@ -124,9 +127,9 @@ def fit(img,maxsize=512):
124
 
125
  def process(im, version):
126
  if version == 'version 0.3':
127
- model = torch.jit.load('./ArcaneGANv0.3.jit',map_location='cpu').to('cpu').float().eval().cpu()
128
  else:
129
- model = torch.jit.load('./ArcaneGANv0.2.jit',map_location='cpu').to('cpu').float().eval().cpu()
130
  im = scale_by_face_size(im, target_face=300, max_res=1_500_000, max_upscale=2)
131
  res = proc_pil_img(im, model)
132
  return res
90
  means = [0.485, 0.456, 0.406]
91
  stds = [0.229, 0.224, 0.225]
92
 
93
+ t_stds = torch.tensor(stds).cuda().half()[:,None,None]
94
+ t_means = torch.tensor(means).cuda().half()[:,None,None]
95
 
96
  def makeEven(_x):
97
  return int(_x) if (_x % 2 == 0) else int(_x+1)
104
  return var.mul(t_stds).add(t_means).mul(255.).clamp(0,255).permute(1,2,0)
105
 
106
  def proc_pil_img(input_image, model):
107
+ transformed_image = img_transforms(input_image)[None,...].cuda().half()
108
 
109
  with torch.no_grad():
110
  result_image = model(transformed_image)[0]; print(result_image.shape)
112
  output_image = output_image.detach().cpu().numpy().astype('uint8')
113
  output_image = PIL.Image.fromarray(output_image)
114
  return output_image
115
+
116
+
117
+
118
 
119
  def fit(img,maxsize=512):
120
  maxdim = max(*img.size)
127
 
128
  def process(im, version):
129
  if version == 'version 0.3':
130
+ model = torch.jit.load('./ArcaneGANv0.3.jit').eval().cuda().half()
131
  else:
132
+ model = torch.jit.load('./ArcaneGANv0.2.jit').eval().cuda().half()
133
  im = scale_by_face_size(im, target_face=300, max_res=1_500_000, max_upscale=2)
134
  res = proc_pil_img(im, model)
135
  return res