rodrigomasini commited on
Commit
a670855
1 Parent(s): 255d880

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +423 -1
app.py CHANGED
@@ -94,6 +94,7 @@ from diffusers import StableDiffusionXLPipeline
94
  from utils import PhotoMakerStableDiffusionXLPipeline
95
  from diffusers import DDIMScheduler
96
  import torch.nn.functional as F
 
97
  def cal_attn_mask(total_length,id_length,sa16,sa32,sa64,device="cuda",dtype= torch.float16):
98
  bool_matrix256 = torch.rand((1, total_length * 256),device = device,dtype = dtype) < sa16
99
  bool_matrix1024 = torch.rand((1, total_length * 1024),device = device,dtype = dtype) < sa32
@@ -133,7 +134,428 @@ import copy
133
  import os
134
  from huggingface_hub import hf_hub_download
135
  from diffusers.utils import load_image
136
- from utils.utils import get_comic # must remove this one
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  style_list = [
139
  {
 
94
  from utils import PhotoMakerStableDiffusionXLPipeline
95
  from diffusers import DDIMScheduler
96
  import torch.nn.functional as F
97
+
98
  def cal_attn_mask(total_length,id_length,sa16,sa32,sa64,device="cuda",dtype= torch.float16):
99
  bool_matrix256 = torch.rand((1, total_length * 256),device = device,dtype = dtype) < sa16
100
  bool_matrix1024 = torch.rand((1, total_length * 1024),device = device,dtype = dtype) < sa32
 
134
  import os
135
  from huggingface_hub import hf_hub_download
136
  from diffusers.utils import load_image
137
+ 15.8 kB
138
+ from email.mime import image
139
+ import torch
140
+ import base64
141
+ import gradio as gr
142
+ import numpy as np
143
+ from PIL import Image,ImageOps,ImageDraw, ImageFont
144
+ from io import BytesIO
145
+ import random
146
+ MAX_COLORS = 12
147
+ def get_random_bool():
148
+ return random.choice([True, False])
149
+
150
+ def add_white_border(input_image, border_width=10):
151
+ """
152
+ 为PIL图像添加指定宽度的白色边框。
153
+
154
+ :param input_image: PIL图像对象
155
+ :param border_width: 边框宽度(单位:像素)
156
+ :return: 带有白色边框的PIL图像对象
157
+ """
158
+ border_color = 'white' # 白色边框
159
+ # 添加边框
160
+ img_with_border = ImageOps.expand(input_image, border=border_width, fill=border_color)
161
+ return img_with_border
162
+
163
+ def process_mulline_text(draw, text, font, max_width):
164
+ """
165
+ Draw the text on an image with word wrapping.
166
+ """
167
+ lines = [] # Store the lines of text here
168
+ words = text.split()
169
+
170
+ # Start building lines of text, and wrap when necessary
171
+ current_line = ""
172
+ for word in words:
173
+ test_line = f"{current_line} {word}".strip()
174
+ # Check the width of the line with this word added
175
+ width, _ = draw.textsize(test_line, font=font)
176
+ if width <= max_width:
177
+ # If it fits, add this word to the current line
178
+ current_line = test_line
179
+ else:
180
+ # If not, store the line and start a new one
181
+ lines.append(current_line)
182
+ current_line = word
183
+ # Add the last line
184
+ lines.append(current_line)
185
+ return lines
186
+
187
+
188
+
189
+ def add_caption(image, text, position = "bottom-mid", font = None, text_color= 'black', bg_color = (255, 255, 255) , bg_opacity = 200):
190
+ if text == "":
191
+ return image
192
+ image = image.convert("RGBA")
193
+ draw = ImageDraw.Draw(image)
194
+ width, height = image.size
195
+ lines = process_mulline_text(draw,text,font,width)
196
+ text_positions = []
197
+ maxwidth = 0
198
+ for ind, line in enumerate(lines[::-1]):
199
+ text_width, text_height = draw.textsize(line, font=font)
200
+ if position == 'bottom-right':
201
+ text_position = (width - text_width - 10, height - (text_height + 20))
202
+ elif position == 'bottom-left':
203
+ text_position = (10, height - (text_height + 20))
204
+ elif position == 'bottom-mid':
205
+ text_position = ((width - text_width) // 2, height - (text_height + 20) ) # 居中文本
206
+ height = text_position[1]
207
+ maxwidth = max(maxwidth,text_width)
208
+ text_positions.append(text_position)
209
+ rectpos = (width - maxwidth) // 2
210
+ rectangle_position = [rectpos - 5, text_positions[-1][1] - 5, rectpos + maxwidth + 5, text_positions[0][1] + text_height + 5]
211
+ image_with_transparency = Image.new('RGBA', image.size)
212
+ draw_with_transparency = ImageDraw.Draw(image_with_transparency)
213
+ draw_with_transparency.rectangle(rectangle_position, fill=bg_color + (bg_opacity,))
214
+
215
+ image.paste(Image.alpha_composite(image.convert('RGBA'), image_with_transparency))
216
+ print(ind,text_position)
217
+ draw = ImageDraw.Draw(image)
218
+ for ind, line in enumerate(lines[::-1]):
219
+ text_position = text_positions[ind]
220
+ draw.text(text_position, line, fill=text_color, font=font)
221
+
222
+ return image.convert('RGB')
223
+
224
+ def get_comic(images,types = "4panel",captions = [],font = None,pad_image = None):
225
+ if pad_image == None:
226
+ pad_image = Image.open("./images/pad_images.png")
227
+ if font == None:
228
+ font = ImageFont.truetype("./fonts/Inkfree.ttf", int(30 * images[0].size[1] / 1024))
229
+ if types == "No typesetting (default)":
230
+ return images
231
+ elif types == "Four Pannel":
232
+ return get_comic_4panel(images,captions,font,pad_image)
233
+ else: # "Classic Comic Style"
234
+ return get_comic_classical(images,captions,font,pad_image)
235
+
236
+ def get_caption_group(images_groups,captions = []):
237
+ caption_groups = []
238
+ for i in range(len(images_groups)):
239
+ length = len(images_groups[i])
240
+ caption_groups.append(captions[:length])
241
+ captions = captions[length:]
242
+ if len(caption_groups[-1]) < len(images_groups[-1]):
243
+ caption_groups[-1] = caption_groups[-1] + [""] * (len(images_groups[-1]) - len(caption_groups[-1]))
244
+ return caption_groups
245
+
246
+ def get_comic_classical(images,captions = None,font = None,pad_image = None):
247
+ if pad_image == None:
248
+ raise ValueError("pad_image is None")
249
+ images = [add_white_border(image) for image in images]
250
+ pad_image = pad_image.resize(images[0].size, Image.ANTIALIAS)
251
+ images_groups = distribute_images2(images,pad_image)
252
+ print(images_groups)
253
+ if captions != None:
254
+ captions_groups = get_caption_group(images_groups,captions)
255
+ # print(images_groups)
256
+ row_images = []
257
+ for ind, img_group in enumerate(images_groups):
258
+ row_images.append(get_row_image2(img_group ,captions= captions_groups[ind] if captions != None else None,font = font))
259
+
260
+ return [combine_images_vertically_with_resize(row_images)]
261
+
262
+
263
+
264
+ def get_comic_4panel(images,captions = [],font = None,pad_image = None):
265
+ if pad_image == None:
266
+ raise ValueError("pad_image is None")
267
+ pad_image = pad_image.resize(images[0].size, Image.ANTIALIAS)
268
+ images = [add_white_border(image) for image in images]
269
+ assert len(captions) == len(images)
270
+ for i,caption in enumerate(captions):
271
+ images[i] = add_caption(images[i],caption,font = font)
272
+ images_nums = len(images)
273
+ pad_nums = int((4 - images_nums % 4) % 4)
274
+ images = images + [pad_image for _ in range(pad_nums)]
275
+ comics = []
276
+ assert len(images)%4 == 0
277
+ for i in range(len(images)//4):
278
+ comics.append(combine_images_vertically_with_resize([combine_images_horizontally(images[i*4:i*4+2]), combine_images_horizontally(images[i*4+2:i*4+4])]))
279
+
280
+ return comics
281
+
282
+ def get_row_image(images):
283
+ row_image_arr = []
284
+ if len(images)>3:
285
+ stack_img_nums = (len(images) - 2)//2
286
+ else:
287
+ stack_img_nums = 0
288
+ while(len(images)>0):
289
+ if stack_img_nums <=0:
290
+ row_image_arr.append(images[0])
291
+ images = images[1:]
292
+ elif len(images)>stack_img_nums*2:
293
+ if get_random_bool():
294
+ row_image_arr.append(concat_images_vertically_and_scale(images[:2]))
295
+ images = images[2:]
296
+ stack_img_nums -=1
297
+ else:
298
+ row_image_arr.append(images[0])
299
+ images = images[1:]
300
+ else:
301
+ row_image_arr.append(concat_images_vertically_and_scale(images[:2]))
302
+ images = images[2:]
303
+ stack_img_nums-=1
304
+ return combine_images_horizontally(row_image_arr)
305
+
306
+ def get_row_image2(images,captions = None, font = None):
307
+ row_image_arr = []
308
+ if len(images)== 6:
309
+ sequence_list = [1,1,2,2]
310
+ elif len(images)== 4:
311
+ sequence_list = [1,1,2]
312
+ else:
313
+ raise ValueError("images nums is not 4 or 6 found",len(images))
314
+ random.shuffle(sequence_list)
315
+ index = 0
316
+ for length in sequence_list:
317
+ if length == 1:
318
+ if captions != None:
319
+ images_tmp = add_caption(images[0],text = captions[index],font= font)
320
+ else:
321
+ images_tmp = images[0]
322
+ row_image_arr.append( images_tmp)
323
+ images = images[1:]
324
+ index +=1
325
+ elif length == 2:
326
+ row_image_arr.append(concat_images_vertically_and_scale(images[:2]))
327
+ images = images[2:]
328
+ index +=2
329
+
330
+ return combine_images_horizontally(row_image_arr)
331
+
332
+
333
+
334
+ def concat_images_vertically_and_scale(images,scale_factor=2):
335
+ # 加载所有图像
336
+ # 确保所有图像的宽度一致
337
+ widths = [img.width for img in images]
338
+ if not all(width == widths[0] for width in widths):
339
+ raise ValueError('All images must have the same width.')
340
+
341
+ # 计算总高度
342
+ total_height = sum(img.height for img in images)
343
+
344
+ # 创建新的图像,宽度与原图相同,高度为所有图像高度之和
345
+ max_width = max(widths)
346
+ concatenated_image = Image.new('RGB', (max_width, total_height))
347
+
348
+ # 竖直拼接图像
349
+ current_height = 0
350
+ for img in images:
351
+ concatenated_image.paste(img, (0, current_height))
352
+ current_height += img.height
353
+
354
+ # 缩放图像为1/n高度
355
+ new_height = concatenated_image.height // scale_factor
356
+ new_width = concatenated_image.width // scale_factor
357
+ resized_image = concatenated_image.resize((new_width, new_height), Image.ANTIALIAS)
358
+
359
+ return resized_image
360
+
361
+
362
+ def combine_images_horizontally(images):
363
+ # 读取所有图片并存入列表
364
+
365
+ # 获取每幅图像的宽度和高度
366
+ widths, heights = zip(*(i.size for i in images))
367
+
368
+ # 计算总宽度和最大高度
369
+ total_width = sum(widths)
370
+ max_height = max(heights)
371
+
372
+ # 创建新的空白图片,用于拼接
373
+ new_im = Image.new('RGB', (total_width, max_height))
374
+
375
+ # 将图片横向拼接
376
+ x_offset = 0
377
+ for im in images:
378
+ new_im.paste(im, (x_offset, 0))
379
+ x_offset += im.width
380
+
381
+ return new_im
382
+
383
+ def combine_images_vertically_with_resize(images):
384
+
385
+ # 获取所有图片的宽度和高度
386
+ widths, heights = zip(*(i.size for i in images))
387
+
388
+ # 确定新图片的宽度,即所有图片中最小的宽度
389
+ min_width = min(widths)
390
+
391
+ # 调整图片尺寸以保持宽度一致,长宽比不变
392
+ resized_images = []
393
+ for img in images:
394
+ # 计算新高度保持图片长宽比
395
+ new_height = int(min_width * img.height / img.width)
396
+ # 调整图片大小
397
+ resized_img = img.resize((min_width, new_height), Image.ANTIALIAS)
398
+ resized_images.append(resized_img)
399
+
400
+ # 计算所有调整尺寸后图片的总高度
401
+ total_height = sum(img.height for img in resized_images)
402
+
403
+ # 创建一个足够宽和高的新图片对象
404
+ new_im = Image.new('RGB', (min_width, total_height))
405
+
406
+ # 竖直拼接图片
407
+ y_offset = 0
408
+ for im in resized_images:
409
+ new_im.paste(im, (0, y_offset))
410
+ y_offset += im.height
411
+
412
+ return new_im
413
+
414
+ def distribute_images2(images, pad_image):
415
+ groups = []
416
+ remaining = len(images)
417
+ if len(images) <= 8:
418
+ group_sizes = [4]
419
+ else:
420
+ group_sizes = [4, 6]
421
+
422
+ size_index = 0
423
+ while remaining > 0:
424
+ size = group_sizes[size_index%len(group_sizes)]
425
+ if remaining < size and remaining < min(group_sizes):
426
+ size = min(group_sizes)
427
+ if remaining > size:
428
+ new_group = images[-remaining: -remaining + size]
429
+ else:
430
+ new_group = images[-remaining:]
431
+ groups.append(new_group)
432
+ size_index += 1
433
+ remaining -= size
434
+ print(remaining,groups)
435
+ groups[-1] = groups[-1] + [pad_image for _ in range(-remaining)]
436
+
437
+ return groups
438
+
439
+
440
+ def distribute_images(images, group_sizes=(4, 3, 2)):
441
+ groups = []
442
+ remaining = len(images)
443
+
444
+ while remaining > 0:
445
+ # 优先分配最大组(4张图片),再考虑3张,最后处理2张
446
+ for size in sorted(group_sizes, reverse=True):
447
+ # 如果剩下的图片数量大于等于当前组大小,或者为图片总数时(也就是第一次迭代)
448
+ # 开始创建新组
449
+ if remaining >= size or remaining == len(images):
450
+ if remaining > size:
451
+ new_group = images[-remaining: -remaining + size]
452
+ else:
453
+ new_group = images[-remaining:]
454
+ groups.append(new_group)
455
+ remaining -= size
456
+ break
457
+ # 如果剩下的图片少于最小的组大小(2张)并且已经有组了,就把剩下的图片加到最后一个组
458
+ elif remaining < min(group_sizes) and groups:
459
+ groups[-1].extend(images[-remaining:])
460
+ remaining = 0
461
+
462
+ return groups
463
+
464
+ def create_binary_matrix(img_arr, target_color):
465
+ mask = np.all(img_arr == target_color, axis=-1)
466
+ binary_matrix = mask.astype(int)
467
+ return binary_matrix
468
+
469
+ def preprocess_mask(mask_, h, w, device):
470
+ mask = np.array(mask_)
471
+ mask = mask.astype(np.float32)
472
+ mask = mask[None, None]
473
+ mask[mask < 0.5] = 0
474
+ mask[mask >= 0.5] = 1
475
+ mask = torch.from_numpy(mask).to(device)
476
+ mask = torch.nn.functional.interpolate(mask, size=(h, w), mode='nearest')
477
+ return mask
478
+
479
+ def process_sketch(canvas_data):
480
+ binary_matrixes = []
481
+ base64_img = canvas_data['image']
482
+ image_data = base64.b64decode(base64_img.split(',')[1])
483
+ image = Image.open(BytesIO(image_data)).convert("RGB")
484
+ im2arr = np.array(image)
485
+ colors = [tuple(map(int, rgb[4:-1].split(','))) for rgb in canvas_data['colors']]
486
+ colors_fixed = []
487
+
488
+ r, g, b = 255, 255, 255
489
+ binary_matrix = create_binary_matrix(im2arr, (r,g,b))
490
+ binary_matrixes.append(binary_matrix)
491
+ binary_matrix_ = np.repeat(np.expand_dims(binary_matrix, axis=(-1)), 3, axis=(-1))
492
+ colored_map = binary_matrix_*(r,g,b) + (1-binary_matrix_)*(50,50,50)
493
+ colors_fixed.append(gr.update(value=colored_map.astype(np.uint8)))
494
+
495
+ for color in colors:
496
+ r, g, b = color
497
+ if any(c != 255 for c in (r, g, b)):
498
+ binary_matrix = create_binary_matrix(im2arr, (r,g,b))
499
+ binary_matrixes.append(binary_matrix)
500
+ binary_matrix_ = np.repeat(np.expand_dims(binary_matrix, axis=(-1)), 3, axis=(-1))
501
+ colored_map = binary_matrix_*(r,g,b) + (1-binary_matrix_)*(50,50,50)
502
+ colors_fixed.append(gr.update(value=colored_map.astype(np.uint8)))
503
+
504
+ visibilities = []
505
+ colors = []
506
+ for n in range(MAX_COLORS):
507
+ visibilities.append(gr.update(visible=False))
508
+ colors.append(gr.update())
509
+ for n in range(len(colors_fixed)):
510
+ visibilities[n] = gr.update(visible=True)
511
+ colors[n] = colors_fixed[n]
512
+
513
+ return [gr.update(visible=True), binary_matrixes, *visibilities, *colors]
514
+
515
+ def process_prompts(binary_matrixes, *seg_prompts):
516
+ return [gr.update(visible=True), gr.update(value=' , '.join(seg_prompts[:len(binary_matrixes)]))]
517
+
518
+ def process_example(layout_path, all_prompts, seed_):
519
+
520
+ all_prompts = all_prompts.split('***')
521
+
522
+ binary_matrixes = []
523
+ colors_fixed = []
524
+
525
+ im2arr = np.array(Image.open(layout_path))[:,:,:3]
526
+ unique, counts = np.unique(np.reshape(im2arr,(-1,3)), axis=0, return_counts=True)
527
+ sorted_idx = np.argsort(-counts)
528
+
529
+ binary_matrix = create_binary_matrix(im2arr, (0,0,0))
530
+ binary_matrixes.append(binary_matrix)
531
+ binary_matrix_ = np.repeat(np.expand_dims(binary_matrix, axis=(-1)), 3, axis=(-1))
532
+ colored_map = binary_matrix_*(255,255,255) + (1-binary_matrix_)*(50,50,50)
533
+ colors_fixed.append(gr.update(value=colored_map.astype(np.uint8)))
534
+
535
+ for i in range(len(all_prompts)-1):
536
+ r, g, b = unique[sorted_idx[i]]
537
+ if any(c != 255 for c in (r, g, b)) and any(c != 0 for c in (r, g, b)):
538
+ binary_matrix = create_binary_matrix(im2arr, (r,g,b))
539
+ binary_matrixes.append(binary_matrix)
540
+ binary_matrix_ = np.repeat(np.expand_dims(binary_matrix, axis=(-1)), 3, axis=(-1))
541
+ colored_map = binary_matrix_*(r,g,b) + (1-binary_matrix_)*(50,50,50)
542
+ colors_fixed.append(gr.update(value=colored_map.astype(np.uint8)))
543
+
544
+ visibilities = []
545
+ colors = []
546
+ prompts = []
547
+ for n in range(MAX_COLORS):
548
+ visibilities.append(gr.update(visible=False))
549
+ colors.append(gr.update())
550
+ prompts.append(gr.update())
551
+
552
+ for n in range(len(colors_fixed)):
553
+ visibilities[n] = gr.update(visible=True)
554
+ colors[n] = colors_fixed[n]
555
+ prompts[n] = all_prompts[n+1]
556
+
557
+ return [gr.update(visible=True), binary_matrixes, *visibilities, *colors, *prompts,
558
+ gr.update(visible=True), gr.update(value=all_prompts[0]), int(seed_)]
559
 
560
  style_list = [
561
  {