jonathanpark committed
Commit
e4d55ce
1 Parent(s): 6c599c3

try adding custom handler

Files changed (3)
  1. .vscode/settings.json +1 -0
  2. handler.py +81 -0
  3. requirements.txt +1 -0
.vscode/settings.json ADDED
@@ -0,0 +1 @@
+ {}
handler.py ADDED
@@ -0,0 +1,81 @@
+ from typing import Any, Dict
+ from io import BytesIO
+
+ import requests
+ import torch
+ from PIL import Image
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, DDIMScheduler
+
+ # set device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ if device.type != "cuda":
+     raise ValueError("need to run on GPU")
+
+ model_id = "stabilityai/stable-diffusion-2-1-base"
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         # load the text2img pipeline in half precision
+         self.textPipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+         self.textPipe.scheduler = DDIMScheduler.from_config(self.textPipe.scheduler.config)
+         self.textPipe = self.textPipe.to(device)
+
+         # create an img2img pipeline from the same weights
+         self.imgPipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+         self.imgPipe.scheduler = DDIMScheduler.from_config(self.imgPipe.scheduler.config)
+         self.imgPipe = self.imgPipe.to(device)
+
+     def __call__(self, data: Dict[str, Any]) -> Image.Image:
+         """
+         Args:
+             data (:obj:`dict`):
+                 includes the input data and the parameters for the inference.
+         Return:
+             The first generated PIL image.
+         """
+         prompt = data.pop("inputs", data)
+         params = data.pop("parameters", {})
+
+         # hyperparameters
+         num_inference_steps = params.pop("num_inference_steps", 25)
+         guidance_scale = params.pop("guidance_scale", 7.5)
+         negative_prompt = params.pop("negative_prompt", None)
+         height = params.pop("height", None)
+         width = params.pop("width", None)
+         manual_seed = params.pop("manual_seed", -1)
+
+         # seed a generator only when the caller passes an explicit seed
+         generator = None
+         if manual_seed != -1:
+             generator = torch.Generator(device=device).manual_seed(manual_seed)
+
+         # an optional source image URL switches the request to img2img
+         url = data.pop("src", None)
+         if url:
+             response = requests.get(url)
+             init_image = Image.open(BytesIO(response.content)).convert("RGB")
+             init_image.thumbnail((512, 512))
+
+             # run img2img pipeline; the output size follows the init
+             # image, so height/width are not passed here
+             out = self.imgPipe(prompt,
+                                image=init_image,
+                                num_inference_steps=num_inference_steps,
+                                guidance_scale=guidance_scale,
+                                num_images_per_prompt=1,
+                                negative_prompt=negative_prompt,
+                                generator=generator)
+         else:
+             # run text2img pipeline
+             out = self.textPipe(prompt,
+                                 num_inference_steps=num_inference_steps,
+                                 guidance_scale=guidance_scale,
+                                 num_images_per_prompt=1,
+                                 negative_prompt=negative_prompt,
+                                 height=height,
+                                 width=width,
+                                 generator=generator)
+
+         # return first generated PIL image
+         return out.images[0]
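
For reference, a minimal local smoke test for this handler might look like the sketch below. It assumes the endpoint runtime delivers the payload as a plain dict with "inputs", optional "src", and optional "parameters" keys (the keys handler.py pops); the file name, prompts, and URL are placeholders, and a CUDA GPU is required.

# test_handler.py - hypothetical local smoke test
from handler import EndpointHandler

handler = EndpointHandler()

# text2img: no "src" key, so the text pipeline runs
image = handler({
    "inputs": "a photo of an astronaut riding a horse",
    "parameters": {"num_inference_steps": 25, "manual_seed": 42},
})
image.save("text2img.png")

# img2img: "src" points at an init image to transform
image = handler({
    "inputs": "the same scene as a watercolor painting",
    "src": "https://example.com/init.png",  # placeholder URL
    "parameters": {"guidance_scale": 7.5},
})
image.save("img2img.png")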
requirements.txt ADDED
@@ -0,0 +1 @@
+ diffusers==0.10.2
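
Note that handler.py also imports torch, requests, and Pillow, and the Stable Diffusion pipelines need transformers for the CLIP text encoder. Inference Endpoints base images typically ship these already, but if they had to be pinned explicitly, a fuller requirements.txt might look like this (unpinned entries are assumptions, not tested versions):

diffusers==0.10.2
transformers
torch
Pillow
requests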