File size: 2,927 Bytes
e4d55ce
 
 
 
 
 
c6d0824
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4d55ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cba83e
f37021f
ca2f7e7
c6d0824
 
f37021f
 
c6d0824
f37021f
c6d0824
 
 
 
 
 
 
 
e4d55ce
c6d0824
e4d55ce
c6d0824
 
 
 
 
 
 
 
 
64362a8
 
c6d0824
e4d55ce
c6d0824
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from typing import  Dict, List, Any
import torch
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, DDIMScheduler

import base64
import requests
from io import BytesIO
from PIL import Image

def load_image(image_url):
    """Load a PIL image from either a ``data:`` URI or an HTTP(S) URL.

    Args:
        image_url: A ``data:`` URI whose payload (after the first comma)
            is base64-encoded image bytes, or a plain URL fetched with
            ``requests``.

    Returns:
        PIL.Image.Image: the decoded image (mode unchanged; callers
        convert as needed).

    Raises:
        requests.HTTPError: if the remote server returns an error status.
        requests.Timeout: if the server does not respond in time.
    """
    if image_url.startswith('data:'):
        # Decode base64 data_uri: payload is everything after the first comma.
        image_data = base64.b64decode(image_url.split(',')[1])
        return Image.open(BytesIO(image_data))
    # Plain URL: use a timeout so a hung server cannot block the endpoint
    # forever, and surface HTTP errors instead of handing an error page
    # to PIL (which would fail with an unrelated decode error).
    response = requests.get(image_url, timeout=30)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))

# Device selection: the fp16 pipelines constructed below require CUDA,
# so refuse to start on a CPU-only host.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

if device.type != 'cuda':
    raise ValueError("need to run on GPU")

# Base checkpoint shared by both pipelines created in EndpointHandler.
model_id = "stabilityai/stable-diffusion-2-1-base"

class EndpointHandler():
    """Inference-endpoint handler wrapping Stable Diffusion pipelines.

    Loads a text2img pipeline (currently unused by ``__call__`` but kept
    for interface compatibility) and an img2img pipeline sharing the same
    base checkpoint, both in fp16 on the module-level ``device``.
    """

    def __init__(self, path=""):
        # load the optimized text-to-image model
        self.textPipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
        self.textPipe.scheduler = DDIMScheduler.from_config(self.textPipe.scheduler.config)
        self.textPipe = self.textPipe.to(device)

        # create an img2img model from the same checkpoint
        self.imgPipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
        self.imgPipe.scheduler = DDIMScheduler.from_config(self.imgPipe.scheduler.config)
        self.imgPipe = self.imgPipe.to(device)

    def __call__(self, data: Any) -> Image.Image:
        """Run img2img generation on the request payload.

        Args:
            data (dict): request payload with keys:
                - "inputs": default text prompt (overridden by
                  ``parameters["prompt"]`` when present).
                - "url": image URL or data URI used as the init image.
                - "parameters" (optional): hyperparameter dict; recognised
                  keys: num_inference_steps, guidance_scale,
                  negative_prompt, prompt, height, width, manual_seed.

        Returns:
            PIL.Image.Image: the first generated image (the docstring
            previously claimed a base64 dict, but the code has always
            returned a PIL image).

        Raises:
            ValueError: if no "url" field is supplied.
        """
        prompt = data.pop("inputs", None)
        url = data.pop("url", None)
        # Previously the fallback was the whole payload dict, which crashed
        # deep inside load_image with a confusing error; fail fast instead.
        if url is None:
            raise ValueError("request payload must include a 'url' field")

        init_image = load_image(url).convert("RGB")
        init_image.thumbnail((512, 512))

        # Fall back to the remaining top-level payload when no "parameters"
        # dict is supplied (original behaviour, kept for compatibility).
        params = data.pop("parameters", data)

        # hyperparameters
        num_inference_steps = params.pop("num_inference_steps", 25)
        guidance_scale = params.pop("guidance_scale", 7.5)
        negative_prompt = params.pop("negative_prompt", None)
        # Bug fix: this used to default to None, silently discarding the
        # prompt taken from "inputs" above; now "inputs" is the fallback.
        prompt = params.pop("prompt", prompt)
        # height/width are popped (so they don't linger in the payload) but
        # unused: img2img derives the output size from the init image.
        height = params.pop("height", None)
        width = params.pop("width", None)
        manual_seed = params.pop("manual_seed", -1)

        # Bug fix: the generator used to be created but never passed to the
        # pipeline, so manual_seed had no effect. Seed only when the caller
        # asked for determinism; -1 keeps the default random behaviour.
        generator = None
        if manual_seed != -1:
            generator = torch.Generator(device='cuda')
            generator.manual_seed(manual_seed)

        # run img2img pipeline
        out = self.imgPipe(prompt,
                    image=init_image,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    num_images_per_prompt=1,
                    negative_prompt=negative_prompt,
                    generator=generator,
        )

        # return first generated PIL image
        return out.images[0]