bluestarburst commited on
Commit
6b1429b
1 Parent(s): fc56d40

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +92 -0
handler.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import base64
3
+ from PIL import Image
4
+ from io import BytesIO
5
+ from diffusers import StableDiffusionImg2ImgPipeline
6
+ import torch
7
+
8
+
9
+ import numpy as np
10
+ import cv2
11
+ import controlnet_hinter
12
+
13
+ # set device
14
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+ if device.type != 'cuda':
16
+ raise ValueError("need to run on GPU")
17
+ # set mixed precision dtype
18
+ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[
19
+ 0] == 8 else torch.float16
20
+
21
+ model_id = "nitrosocke/Ghibli-Diffusion"
22
+
23
+
24
+ class EndpointHandler():
25
+ def __init__(self, path=""):
26
+ # define default controlnet id and load controlnet
27
+ # Load StableDiffusionControlNetPipeline
28
+ self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained("nitrosocke/Ghibli-Diffusion", torch_dtype=torch.float16).to(
29
+ device
30
+ )
31
+ # Define Generator with seed
32
+ # self.generator = torch.Generator(device="cpu").manual_seed(3)
33
+ self.generator = torch.Generator(device=device).manual_seed(1024)
34
+
35
+ def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
36
+ """
37
+ :param data: A dictionary contains `inputs` and optional `image` field.
38
+ :return: A dictionary with `image` field contains image in base64.
39
+ """
40
+ prompt = data.pop("inputs", None)
41
+ image = data.pop("image", None)
42
+ strength = data.pop("strength", None)
43
+ steps = data.pop("steps", None)
44
+
45
+ # Check if neither prompt nor image is provided
46
+ if prompt is None and image is None:
47
+ return {"error": "Please provide a prompt and base64 encoded image."}
48
+
49
+ # hyperparamters
50
+ num_inference_steps = data.pop("num_inference_steps", 30)
51
+ guidance_scale = data.pop("guidance_scale", 7.5)
52
+ negative_prompt = data.pop("negative_prompt", None)
53
+ height = data.pop("height", None)
54
+ width = data.pop("width", None)
55
+ controlnet_conditioning_scale = data.pop(
56
+ "controlnet_conditioning_scale", 1.0)
57
+
58
+ # process image
59
+ image = self.decode_base64_image(image)
60
+ # control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
61
+
62
+ # run inference pipeline
63
+ # out = self.pipe(
64
+ # prompt=prompt,
65
+ # negative_prompt=negative_prompt,
66
+ # image=control_image,
67
+ # num_inference_steps=num_inference_steps,
68
+ # guidance_scale=strength,
69
+ # num_images_per_prompt=1,
70
+ # height=height,
71
+ # width=width,
72
+ # controlnet_conditioning_scale=controlnet_conditioning_scale,
73
+ # generator=self.generator
74
+ # )
75
+
76
+ out = pipe(
77
+ prompt=prompt,
78
+ image=image,
79
+ strength=0.75,
80
+ guidance_scale=7.5,
81
+ generator=self.generator
82
+ )
83
+
84
+ # return first generate PIL image
85
+ return out.images[0]
86
+
87
+ # helper to decode input image
88
+ def decode_base64_image(self, image_string):
89
+ base64_image = base64.b64decode(image_string)
90
+ buffer = BytesIO(base64_image)
91
+ image = Image.open(buffer).convert("RGB").thumbnail((768, 768))
92
+ return image