ohayonguy commited on
Commit
3cd5bc5
1 Parent(s): 7f45a1a
Files changed (2) hide show
  1. app.py +3 -2
  2. arch/hourglass/image_transformer_v2.py +1 -1
app.py CHANGED
@@ -76,8 +76,9 @@ def enhance_face(img, face_helper, has_aligned, only_center_face=False, paste_ba
76
  cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
77
 
78
  dummy_x = torch.zeros_like(cropped_face_t)
79
- output = generate_reconstructions(pmrf, dummy_x, cropped_face_t, None, 25, device)
80
- restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(0, 1))
 
81
  # restored_face = cropped_face
82
 
83
  restored_face = restored_face.astype('uint8')
 
76
  cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
77
 
78
  dummy_x = torch.zeros_like(cropped_face_t)
79
+ with torch.autocast("cuda", dtype=torch.bfloat16):
80
+ output = generate_reconstructions(pmrf, dummy_x, cropped_face_t, None, 25, device)
81
+ restored_face = tensor2img(output.to(torch.float32).squeeze(0), rgb2bgr=True, min_max=(0, 1))
82
  # restored_face = cropped_face
83
 
84
  restored_face = restored_face.astype('uint8')
arch/hourglass/image_transformer_v2.py CHANGED
@@ -92,7 +92,7 @@ def linear_geglu(x, weight, bias=None):
92
  if bias is not None:
93
  x = x + bias
94
  x, gate = x.chunk(2, dim=-1)
95
- return x.clone() * F.gelu(gate.clone())
96
 
97
 
98
  def rms_norm(x, scale, eps):
 
92
  if bias is not None:
93
  x = x + bias
94
  x, gate = x.chunk(2, dim=-1)
95
+ return x * F.gelu(gate)
96
 
97
 
98
  def rms_norm(x, scale, eps):