File size: 1,747 Bytes
2c83deb
d70b1c6
2c83deb
 
 
 
 
26fde58
 
 
 
 
 
2c83deb
 
 
 
 
2c745bf
 
2c83deb
 
 
2c745bf
2c83deb
22d7300
01e3be6
 
2c83deb
 
be8f6aa
2c83deb
 
4fa2b44
2c83deb
 
 
be8f6aa
 
 
2c83deb
01e3be6
 
 
 
 
 
 
 
22d7300
01e3be6
 
 
22d7300
2c83deb
01e3be6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from typing import  Dict, List, Any
from PIL import Image
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
import base64
from io import BytesIO
from transformers.utils import logging

# Raise the transformers library logger to INFO so that messages emitted via
# `logger.info` below (and in the handler) are not filtered out.
logging.set_verbosity_info()
logger = logging.get_logger("transformers")
# NOTE(review): these two lines look like a smoke test that log output is visible.
logger.info("INFO")
logger.warning("WARN")


# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class EndpointHandler():
    """Inference-endpoint handler for image upscaling.

    NOTE(review): currently a stub. No diffusion pipeline is loaded, and
    ``__call__`` only decodes the incoming image before returning a
    placeholder response.
    """

    # Checkpoint this handler is intended to serve once the pipeline is enabled.
    MODEL_ID = "stabilityai/stable-diffusion-x4-upscaler"

    def __init__(self, path=""):
        """Remember the model path handed in by the hosting runtime.

        Args:
            path: Location of the model repository (presumably a local
                checkout provided by the endpoint runtime — confirm).
        """
        self.path = path
        # TODO: load the upscaling pipeline here and move it to `device`.
        # NOTE(review): the x4-upscaler checkpoint is normally loaded with
        # diffusers' StableDiffusionUpscalePipeline, not StableDiffusionPipeline
        # as the earlier commented-out code attempted — confirm before enabling.

    def __call__(self, data) -> Dict[str, Any]:
        """Decode the base64-encoded image in the request payload.

        Args:
            data: Request payload of the form
                ``{"inputs": {"image": <base64-encoded image>}}``.

        Returns:
            A dict with a single ``"image"`` key. It is currently the
            placeholder string ``"test"``; the upscaled image is not produced
            yet.

        Raises:
            ValueError: If ``inputs`` or ``inputs["image"]`` is missing, or if
                the image string is not valid base64 (``binascii.Error`` is a
                ``ValueError`` subclass).
            PIL.UnidentifiedImageError: If the decoded bytes are not an image
                format Pillow recognizes.
        """
        inputs = data.get("inputs")
        try:
            encoded_image = inputs["image"]
        except (TypeError, KeyError) as err:
            # TypeError covers a missing/None or non-mapping `inputs`.
            raise ValueError(
                'expected a payload of the form {"inputs": {"image": <base64 string>}}'
            ) from err

        # Log only metadata. The raw payload carries the whole base64 image,
        # which would flood the logs and copy user data into them.
        logger.info('request received with input keys %s', list(inputs))

        image_bytes = base64.b64decode(encoded_image)
        # Image.open parses the header, which is enough to report the metadata below.
        image = Image.open(BytesIO(image_bytes))
        logger.info('decoded image: format=%s size=%s mode=%s',
                    image.format, image.size, image.mode)

        # TODO: run the upscaling pipeline on `image`, re-encode the result
        # (e.g. JPEG into a BytesIO) as base64 and return it under "image".
        return {"image": "test"}