Jacobmadwed commited on
Commit
f24cb6a
·
verified ·
1 Parent(s): f719992

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +21 -16
handler.py CHANGED
@@ -1,31 +1,36 @@
1
  import os
2
- from diffusers import DiffusionPipeline
 
3
  from PIL import Image
 
4
 
5
  class EndpointHandler:
6
- def __init__(self, model_dir="stabilityai/stable-diffusion-x4-upscaler"):
7
- # Load the diffusion pipeline
8
- self.pipeline = DiffusionPipeline.from_pretrained(model_dir)
9
-
 
 
 
10
  # Ensure the output directory exists
11
  os.makedirs('output', exist_ok=True)
12
 
13
- def upscale_image(self, image_path):
14
  # Load the image
15
- image = Image.open(image_path)
16
 
17
- # Upscale the image
18
- upscaled_image = self.pipeline(image)
19
 
20
- # Save the upscaled image
21
- save_path = "output/upscaled_image.png"
22
- upscaled_image.save(save_path)
23
-
24
- return upscaled_image, save_path
25
 
26
  # Example usage for testing
27
  if __name__ == "__main__":
28
  handler = EndpointHandler()
29
  test_image_path = 'path_to_test_image.jpg' # Replace with your test image path
30
- upscaled_image, save_path = handler.upscale_image(test_image_path)
31
- print(f"Upscaled image saved at {save_path}")
 
1
  import os
2
+ import torch
3
+ from gfpgan import GFPGANer
4
  from PIL import Image
5
+ import cv2
6
 
7
  class EndpointHandler:
8
+ def __init__(self, model_path='GFPGANv1.4.pth'):
9
+ # Load the GFPGAN model
10
+ self.model_path = model_path
11
+ self.bg_upsampler = None # You can set this to RealESRGANer if needed
12
+ self.face_enhancer = GFPGANer(
13
+ model_path=self.model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=self.bg_upsampler)
14
+
15
  # Ensure the output directory exists
16
  os.makedirs('output', exist_ok=True)
17
 
18
+ def enhance_image(self, image_path):
19
  # Load the image
20
+ img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
21
 
22
+ # Perform face enhancement
23
+ _, _, output = self.face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
24
 
25
+ # Save the enhanced image
26
+ save_path = "output/enhanced_image.png"
27
+ cv2.imwrite(save_path, output)
28
+
29
+ return output, save_path
30
 
31
  # Example usage for testing
32
  if __name__ == "__main__":
33
  handler = EndpointHandler()
34
  test_image_path = 'path_to_test_image.jpg' # Replace with your test image path
35
+ enhanced_image, save_path = handler.enhance_image(test_image_path)
36
+ print(f"Enhanced image saved at {save_path}")