camenduru committed on
Commit
44c1a06
β€’
1 Parent(s): e67e909

Create worker.py

Files changed (1)
  1. worker.py +391 -0
worker.py ADDED
@@ -0,0 +1,391 @@
+ import os
+ import math
+ import numpy as np
+ import torch
+ import safetensors.torch as sf
+
+ from PIL import Image
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
+ from diffusers.models.attention_processor import AttnProcessor2_0
+ from transformers import CLIPTextModel, CLIPTokenizer
+ from briarmbg import BriaRMBG
+ from enum import Enum
+ from torch.hub import download_url_to_file
+
+ import runpod
+
+ # 'stablediffusionapi/realistic-vision-v51'
+ # 'runwayml/stable-diffusion-v1-5'
+ sd15_name = 'stablediffusionapi/realistic-vision-v51'
+ tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
+ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
+ rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
+
+ # Change UNet
+
+ # Expand conv_in from 4 to 8 input channels so the UNet can take the
+ # concatenated foreground-condition latent; the new channels start at zero.
+ with torch.no_grad():
+     new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
+     new_conv_in.weight.zero_()
+     new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+     new_conv_in.bias = unet.conv_in.bias
+     unet.conv_in = new_conv_in
+
+ unet_original_forward = unet.forward
+
+
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
+     # Concatenate the condition latents to the noisy sample along the channel
+     # axis, then call the original UNet forward.
+     c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
+     c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
+     new_sample = torch.cat([sample, c_concat], dim=1)
+     kwargs['cross_attention_kwargs'] = {}
+     return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
+
+
+ unet.forward = hooked_unet_forward
+
+ # Load
+
+ model_path = 'iclight_sd15_fc.safetensors'
+
+ if not os.path.exists(model_path):
+     download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', dst=model_path)
+
+ # Merge the IC-Light offset weights into the base UNet state dict.
+ sd_offset = sf.load_file(model_path)
+ sd_origin = unet.state_dict()
+ keys = sd_origin.keys()
+ sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
+ unet.load_state_dict(sd_merged, strict=True)
+ del sd_offset, sd_origin, sd_merged, keys
+
+ # Device
+
+ device = torch.device('cuda:1')
+ text_encoder = text_encoder.to(device=device, dtype=torch.float16)
+ vae = vae.to(device=device, dtype=torch.bfloat16)
+ unet = unet.to(device=device, dtype=torch.float16)
+ rmbg = rmbg.to(device=device, dtype=torch.float32)
+
+ # SDP
+
+ unet.set_attn_processor(AttnProcessor2_0())
+ vae.set_attn_processor(AttnProcessor2_0())
+
+ # Samplers
+
+ ddim_scheduler = DDIMScheduler(
+     num_train_timesteps=1000,
+     beta_start=0.00085,
+     beta_end=0.012,
+     beta_schedule="scaled_linear",
+     clip_sample=False,
+     set_alpha_to_one=False,
+     steps_offset=1,
+ )
+
+ euler_a_scheduler = EulerAncestralDiscreteScheduler(
+     num_train_timesteps=1000,
+     beta_start=0.00085,
+     beta_end=0.012,
+     steps_offset=1
+ )
+
+ dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
+     num_train_timesteps=1000,
+     beta_start=0.00085,
+     beta_end=0.012,
+     algorithm_type="sde-dpmsolver++",
+     use_karras_sigmas=True,
+     steps_offset=1
+ )
+
+ # Pipelines
+
+ t2i_pipe = StableDiffusionPipeline(
+     vae=vae,
+     text_encoder=text_encoder,
+     tokenizer=tokenizer,
+     unet=unet,
+     scheduler=dpmpp_2m_sde_karras_scheduler,
+     safety_checker=None,
+     requires_safety_checker=False,
+     feature_extractor=None,
+     image_encoder=None
+ )
+
+ i2i_pipe = StableDiffusionImg2ImgPipeline(
+     vae=vae,
+     text_encoder=text_encoder,
+     tokenizer=tokenizer,
+     unet=unet,
+     scheduler=dpmpp_2m_sde_karras_scheduler,
+     safety_checker=None,
+     requires_safety_checker=False,
+     feature_extractor=None,
+     image_encoder=None
+ )
+
+
+ @torch.inference_mode()
+ def encode_prompt_inner(txt: str):
+     # Encode a prompt of arbitrary length by splitting it into 75-token chunks,
+     # wrapping each chunk with BOS/EOS and padding to the CLIP context length.
+     max_length = tokenizer.model_max_length
+     chunk_length = tokenizer.model_max_length - 2
+     id_start = tokenizer.bos_token_id
+     id_end = tokenizer.eos_token_id
+     id_pad = id_end
+
+     def pad(x, p, i):
+         return x[:i] if len(x) >= i else x + [p] * (i - len(x))
+
+     tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
+     chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
+     chunks = [pad(ck, id_pad, max_length) for ck in chunks]
+
+     token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
+     conds = text_encoder(token_ids).last_hidden_state
+
+     return conds
+
+
+ @torch.inference_mode()
+ def encode_prompt_pair(positive_prompt, negative_prompt):
+     # Repeat whichever embedding has fewer chunks so the positive and negative
+     # conditionings end up with matching shapes.
+     c = encode_prompt_inner(positive_prompt)
+     uc = encode_prompt_inner(negative_prompt)
+
+     c_len = float(len(c))
+     uc_len = float(len(uc))
+     max_count = max(c_len, uc_len)
+     c_repeat = int(math.ceil(max_count / c_len))
+     uc_repeat = int(math.ceil(max_count / uc_len))
+     max_chunk = max(len(c), len(uc))
+
+     c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
+     uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]
+
+     c = torch.cat([p[None, ...] for p in c], dim=1)
+     uc = torch.cat([p[None, ...] for p in uc], dim=1)
+
+     return c, uc
+
+
+ @torch.inference_mode()
+ def pytorch2numpy(imgs, quant=True):
+     results = []
+     for x in imgs:
+         y = x.movedim(0, -1)
+
+         if quant:
+             y = y * 127.5 + 127.5
+             y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
+         else:
+             y = y * 0.5 + 0.5
+             y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)
+
+         results.append(y)
+     return results
+
+
+ @torch.inference_mode()
+ def numpy2pytorch(imgs):
+     h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0  # so that 127 must be strictly 0.0
+     h = h.movedim(-1, 1)
+     return h
+
+
+ def resize_and_center_crop(image, target_width, target_height):
+     pil_image = Image.fromarray(image)
+     original_width, original_height = pil_image.size
+     scale_factor = max(target_width / original_width, target_height / original_height)
+     resized_width = int(round(original_width * scale_factor))
+     resized_height = int(round(original_height * scale_factor))
+     resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
+     left = (resized_width - target_width) / 2
+     top = (resized_height - target_height) / 2
+     right = (resized_width + target_width) / 2
+     bottom = (resized_height + target_height) / 2
+     cropped_image = resized_image.crop((left, top, right, bottom))
+     return np.array(cropped_image)
+
+
+ def resize_without_crop(image, target_width, target_height):
+     pil_image = Image.fromarray(image)
+     resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
+     return np.array(resized_image)
+
+
+ @torch.inference_mode()
+ def run_rmbg(img, sigma=0.0):
+     # Estimate a foreground alpha matte with RMBG and blend the input toward
+     # neutral gray (127) wherever the matte marks background.
+     H, W, C = img.shape
+     assert C == 3
+     k = (256.0 / float(H * W)) ** 0.5
+     feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
+     feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
+     alpha = rmbg(feed)[0][0]
+     alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
+     alpha = alpha.movedim(1, -1)[0]
+     alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
+     result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
+     return result.clip(0, 255).astype(np.uint8), alpha
+
+ @torch.inference_mode()
+ def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
+     # Optionally build a lighting-direction gradient as the initial background.
+     input_bg = None
+
+     if bg_source == 'NONE':
+         pass
+     elif bg_source == 'LEFT':
+         gradient = np.linspace(255, 0, image_width)
+         image = np.tile(gradient, (image_height, 1))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     elif bg_source == 'RIGHT':
+         gradient = np.linspace(0, 255, image_width)
+         image = np.tile(gradient, (image_height, 1))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     elif bg_source == 'TOP':
+         gradient = np.linspace(255, 0, image_height)[:, None]
+         image = np.tile(gradient, (1, image_width))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     elif bg_source == 'BOTTOM':
+         gradient = np.linspace(0, 255, image_height)[:, None]
+         image = np.tile(gradient, (1, image_width))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     else:
+         raise ValueError('Wrong initial latent!')
+
+     rng = torch.Generator(device=device).manual_seed(int(seed))
+
+     fg = resize_and_center_crop(input_fg, image_width, image_height)
+
+     concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
+     concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
+
+     conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
+
+     # First pass: generate latents at the requested size, either text-to-image
+     # or img2img starting from the background gradient.
+     if input_bg is None:
+         latents = t2i_pipe(
+             prompt_embeds=conds,
+             negative_prompt_embeds=unconds,
+             width=image_width,
+             height=image_height,
+             num_inference_steps=steps,
+             num_images_per_prompt=num_samples,
+             generator=rng,
+             output_type='latent',
+             guidance_scale=cfg,
+             cross_attention_kwargs={'concat_conds': concat_conds},
+         ).images.to(vae.dtype) / vae.config.scaling_factor
+     else:
+         bg = resize_and_center_crop(input_bg, image_width, image_height)
+         bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
+         bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
+         latents = i2i_pipe(
+             image=bg_latent,
+             strength=lowres_denoise,
+             prompt_embeds=conds,
+             negative_prompt_embeds=unconds,
+             width=image_width,
+             height=image_height,
+             num_inference_steps=int(round(steps / lowres_denoise)),
+             num_images_per_prompt=num_samples,
+             generator=rng,
+             output_type='latent',
+             guidance_scale=cfg,
+             cross_attention_kwargs={'concat_conds': concat_conds},
+         ).images.to(vae.dtype) / vae.config.scaling_factor
+
+     # Second pass: decode, upscale by highres_scale, re-encode and refine with img2img.
+     pixels = vae.decode(latents).sample
+     pixels = pytorch2numpy(pixels)
+     pixels = [resize_without_crop(
+         image=p,
+         target_width=int(round(image_width * highres_scale / 64.0) * 64),
+         target_height=int(round(image_height * highres_scale / 64.0) * 64))
+     for p in pixels]
+
+     pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
+     latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
+     latents = latents.to(device=unet.device, dtype=unet.dtype)
+
+     image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
+
+     fg = resize_and_center_crop(input_fg, image_width, image_height)
+     concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
+     concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
+
+     latents = i2i_pipe(
+         image=latents,
+         strength=highres_denoise,
+         prompt_embeds=conds,
+         negative_prompt_embeds=unconds,
+         width=image_width,
+         height=image_height,
+         num_inference_steps=int(round(steps / highres_denoise)),
+         num_images_per_prompt=num_samples,
+         generator=rng,
+         output_type='latent',
+         guidance_scale=cfg,
+         cross_attention_kwargs={'concat_conds': concat_conds},
+     ).images.to(vae.dtype) / vae.config.scaling_factor
+
+     pixels = vae.decode(latents).sample
+
+     return pytorch2numpy(pixels)
+
+ import json
+ from diffusers.utils import load_image
+
+ def closestNumber(n, m):
+     q = int(n / m)
+     n1 = m * q
+     if (n * m) > 0:
+         n2 = m * (q + 1)
+     else:
+         n2 = m * (q - 1)
+     if abs(n - n1) < abs(n - n2):
+         return n1
+     return n2
+
+ def is_parsable_json(command):
+     try:
+         json.loads(command)
+         return True
+     except json.JSONDecodeError:
+         return False
+
+ @torch.inference_mode()
+ def generate(command):
+     print(command)
+     if is_parsable_json(command):
+         values = json.loads(command)
+         input_fg = values['input_fg']
+         input_fg = load_image(input_fg)
+         input_fg = np.asarray(input_fg)
+         prompt = values['prompt']
+         width = closestNumber(values['width'], 8)
+         height = closestNumber(values['height'], 8)
+         seed = values['seed']
+         steps = values['steps']
+         a_prompt = values['a_prompt']
+         n_prompt = values['n_prompt']
+         cfg = values['cfg']
+         highres_scale = values['highres_scale']
+         highres_denoise = values['highres_denoise']
+         lowres_denoise = values['lowres_denoise']
+         bg_source = values['bg_source']
+         input_fg, matting = run_rmbg(input_fg)
+         images = process(input_fg, prompt, width, height, 1, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
+         image = Image.fromarray(images[0])
+         image.save('/content/image.jpg')
+         return image
+     else:
+         # Fallback: treat the raw command string as the prompt and use default settings.
+         input_fg = load_image("https://hips.hearstapps.com/hmg-prod/images/scarlett-johansson-attends-the-premiere-of-illuminations-news-photo-1639390369.jpg?crop=1.00xw:0.836xh;0,0&resize=640:*")
+         input_fg = np.asarray(input_fg)
+         width = closestNumber(512, 8)
+         height = closestNumber(512, 8)
+         images = process(input_fg, command, width, height, 1, 1, 25, 'best quality', 'lowres, bad anatomy, bad hands, cropped, worst quality', 2, 1.5, 0.5, 0.9, 'RIGHT')
+         image = Image.fromarray(images[0])
+         image.save('/content/image.jpg')
+         return image
+
+ runpod.serverless.start({"handler": generate})
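For reference, a minimal sketch of the kind of JSON command the generate handler above parses. The key names mirror the values[...] lookups in generate; the image URL and the concrete parameter values here are placeholders chosen for illustration, not part of this commit:

import json

# Hypothetical example payload; keys match what generate() reads.
command = json.dumps({
    "input_fg": "https://example.com/foreground.png",  # placeholder foreground image URL (assumption)
    "prompt": "portrait photo, cinematic lighting",
    "width": 512,
    "height": 640,
    "seed": 12345,
    "steps": 25,
    "a_prompt": "best quality",
    "n_prompt": "lowres, bad anatomy, bad hands, cropped, worst quality",
    "cfg": 2,
    "highres_scale": 1.5,
    "highres_denoise": 0.5,
    "lowres_denoise": 0.9,
    "bg_source": "RIGHT",
})
# generate(command) would then run background removal, the relighting pipeline,
# and save the result to /content/image.jpg.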