nateraw commited on
Commit
674ae47
1 Parent(s): 7127bf8

Create new file

Browse files
Files changed (1) hide show
  1. handler.py +47 -0
handler.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from base64 import b64encode
2
+ from io import BytesIO
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ from basicsr.archs.rrdbnet_arch import RRDBNet
7
+ from PIL import Image
8
+ from realesrgan import RealESRGANer
9
+
10
+
11
+ class EndpointHandler:
12
+ def __init__(self, path=""):
13
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
14
+ self.upsampler = RealESRGANer(
15
+ scale=4,
16
+ model_path=str(Path(path) / "RealESRGAN_x4plus.pth"),
17
+ model=model,
18
+ tile=0,
19
+ tile_pad=10,
20
+ pre_pad=0,
21
+ half=True,
22
+ )
23
+
24
+ def __call__(self, data):
25
+ """
26
+ Args:
27
+ data (:obj:):
28
+ includes the input data and the parameters for the inference.
29
+ Return:
30
+ A :obj:`dict`:. base64 encoded image
31
+ """
32
+ image = data.pop("inputs", data)
33
+ image = Image.open(BytesIO(image)).convert("RGB")
34
+ image = np.array(image)
35
+ image = image[:, :, ::-1] # RGB -> BGR
36
+
37
+ image, _ = self.upsampler.enhance(image, outscale=4)
38
+ image = image[:, :, ::-1] # BGR -> RGB
39
+ image = Image.fromarray(image)
40
+
41
+ # encode image as base 64
42
+ buffered = BytesIO()
43
+ image.save(buffered, format="JPEG")
44
+ img_str = b64encode(buffered.getvalue())
45
+
46
+ # postprocess the prediction
47
+ return {"image": img_str.decode()}