nateraw commited on
Commit
fa550ea
1 Parent(s): 6be24a2

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -8
handler.py CHANGED
@@ -10,13 +10,12 @@ from realesrgan import RealESRGANer
10
 
11
  class EndpointHandler:
12
  def __init__(self, path=""):
13
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
14
  self.upsampler = RealESRGANer(
15
  scale=4,
16
  model_path=str(Path(path) / "RealESRGAN_x4plus.pth"),
17
  model=model,
18
  tile=0,
19
- tile_pad=10,
20
  pre_pad=0,
21
  half=True,
22
  )
@@ -30,18 +29,21 @@ class EndpointHandler:
30
  A :obj:`dict`:. base64 encoded image
31
  """
32
  image = data.pop("inputs", data)
33
- # image = Image.open(BytesIO(image)).convert("RGB")
 
 
 
 
34
  image = np.array(image)
35
  image = image[:, :, ::-1] # RGB -> BGR
36
-
37
  image, _ = self.upsampler.enhance(image, outscale=4)
38
  image = image[:, :, ::-1] # BGR -> RGB
39
  image = Image.fromarray(image)
40
 
41
- # encode image as base 64
42
  buffered = BytesIO()
43
  image.save(buffered, format="PNG")
44
- img_str = b64encode(buffered.getvalue())
 
45
 
46
- # postprocess the prediction
47
- return {"image": img_str.decode()}
 
10
 
11
  class EndpointHandler:
12
  def __init__(self, path=""):
13
+ model = RRDBNet(num_in_ch=3, num_out_ch=3)
14
  self.upsampler = RealESRGANer(
15
  scale=4,
16
  model_path=str(Path(path) / "RealESRGAN_x4plus.pth"),
17
  model=model,
18
  tile=0,
 
19
  pre_pad=0,
20
  half=True,
21
  )
 
29
  A :obj:`dict`:. base64 encoded image
30
  """
31
  image = data.pop("inputs", data)
32
+
33
+ # This lets us pass local images as well while developing
34
+ if isinstance(image, bytes):
35
+ image = Image.open(BytesIO(image)) # .convert("RGB")
36
+
37
  image = np.array(image)
38
  image = image[:, :, ::-1] # RGB -> BGR
 
39
  image, _ = self.upsampler.enhance(image, outscale=4)
40
  image = image[:, :, ::-1] # BGR -> RGB
41
  image = Image.fromarray(image)
42
 
43
+ # Turn output image into bytestr
44
  buffered = BytesIO()
45
  image.save(buffered, format="PNG")
46
+ img_bytes = b64encode(buffered.getvalue())
47
+ img_str = img_bytes.decode()
48
 
49
+ return {"image": img_str}