import io
import logging
import time

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)

class ImageHandler(BaseHandler):
    """TorchServe handler for an image-to-image (denoising) model with quality and latency logging."""

    def __init__(self):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.transform = transforms.Compose([transforms.ToTensor()])
        # Kept so postprocess() can compare the model output against the original input.
        self.input_tensor_for_metrics = None
        self.start_time = 0

    def preprocess(self, data):
        self.start_time = time.time()

        # Assumes batch size 1; TorchServe delivers the payload under "data" or "body"
        # depending on how the client sends the request.
        image_bytes = data[0].get("data") or data[0].get("body")
        image = Image.open(io.BytesIO(image_bytes))
        image_format = image.format  # capture before convert(), which drops the format
        image = image.convert("RGB")
        width, height = image.size
        logger.info(f"DATA_QUALITY: resolution={width}x{height}, format={image_format}")
        tensor = self.transform(image).unsqueeze(0).to(self.device)
        self.input_tensor_for_metrics = tensor.clone().detach()
        return tensor

    def inference(self, data, *args, **kwargs):
        # Forward pass only; no gradients are needed at serving time.
        with torch.no_grad():
            output = self.model(data)
        return output

    def postprocess(self, data):
        output_batched = data
        input_batched = self.input_tensor_for_metrics
        output_tensor = output_batched.squeeze(0).cpu().clamp(0, 1)
        input_tensor = input_batched.squeeze(0).cpu()

        # If the model changes the spatial size, resize the output back to the input
        # resolution so the pixel-difference metric compares like with like.
        output_tensor_resized = output_tensor
        if output_tensor.shape != input_tensor.shape:
            output_tensor_resized = F.interpolate(
                output_tensor.unsqueeze(0),
                size=input_tensor.shape[-2:],
                mode="bilinear",
                align_corners=False,
            ).squeeze(0)

        # Mean absolute pixel change is a rough proxy for how strongly the model
        # altered the image (its "denoising intensity").
        pixel_difference = torch.mean(torch.abs(input_tensor - output_tensor_resized)).item()
        logger.info(f"OUTPUT_QUALITY: denoising_intensity={pixel_difference:.4f}")

        latency_ms = (time.time() - self.start_time) * 1000
        logger.info(f"OPERATIONAL_HEALTH: total_latency={latency_ms:.2f}ms")

        # Return the PNG-encoded output at the model's native output resolution.
        output_image = transforms.ToPILImage()(output_tensor)
        buf = io.BytesIO()
        output_image.save(buf, format="PNG")
        return [buf.getvalue()]
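
# ---------------------------------------------------------------------------
# Hypothetical client-side smoke test, not part of the handler API. It assumes
# the model archive was registered with TorchServe under the name
# "image_denoiser", the default inference port 8080, and a local input file
# "noisy.png"; all of these names are placeholders, not taken from this repo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import requests  # third-party dependency, only needed for this manual check

    with open("noisy.png", "rb") as f:
        resp = requests.post(
            "http://localhost:8080/predictions/image_denoiser",
            data=f.read(),
            headers={"Content-Type": "application/octet-stream"},
        )
    resp.raise_for_status()
    with open("denoised.png", "wb") as out:
        out.write(resp.content)  # the handler returns PNG-encoded bytes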