File size: 8,853 Bytes
e8dca02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# ------------------------------------------
# TextDiffuser: Diffusion Models as Text Painters
# Paper Link: https://arxiv.org/abs/2305.10855
# Code Link: https://github.com/microsoft/unilm/tree/master/textdiffuser
# Copyright (c) Microsoft Corporation.
# This file aims to predict the layout of keywords in user prompts.
# ------------------------------------------

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import re
import numpy as np
import torch
import torch.nn as nn
from transformers import CLIPTokenizer
from PIL import Image, ImageDraw, ImageFont
from util import get_width, get_key_words, adjust_overlap_box, shrink_box, adjust_font_size, alphabet_dic
from model.layout_transformer import LayoutTransformer, TextConditioner
from termcolor import colored

# Layout transformer that predicts the keyword boxes: moved to the GPU, switched to
# eval mode, then loaded with the released checkpoint weights.
model = LayoutTransformer()
model = model.cuda()
model = model.eval()
model.load_state_dict(torch.load('textdiffuser-ckpt/layout_transformer.pth'))

# Prompt text encoder (GPU, eval mode) and the CLIP tokenizer used by process_caption.
text_encoder = TextConditioner()
text_encoder = text_encoder.cuda()
text_encoder = text_encoder.eval()
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')


def _keyword_word_indices(caption_split, keywords):
    """Return the caption-word index of every keyword found in ``caption_split``.

    ``caption_split`` must already be lower-cased; keywords are lower-cased before
    matching. The first time a keyword is seen it maps to its first occurrence in
    the caption; each repeat of the same keyword maps to the next occurrence after
    the one used previously. Keywords (or repeats) with no remaining occurrence are
    skipped, so the result may be shorter than ``keywords``.
    """
    indices = []
    last_used = {}  # lower-cased keyword -> caption index of the occurrence used last
    for keyword in keywords:
        keyword = keyword.lower()
        search_from = last_used.get(keyword, -1) + 1
        try:
            index = caption_split.index(keyword, search_from)
        except ValueError:
            continue  # not in the caption, or no further occurrence for this repeat
        last_used[keyword] = index
        indices.append(index)
    return indices


def process_caption(font_path, caption, keywords):
    """Tokenize ``caption`` and build the per-token inputs for the layout transformer.

    Args:
        font_path: font file forwarded to ``get_width`` to measure each caption word.
        caption: the user prompt. Every character outside ASCII letters and digits
            is replaced by a space before tokenization.
        keywords: words whose boxes are to be predicted. Matching is
            case-insensitive; keywords absent from the caption are skipped, and a
            repeated keyword is mapped to the next unused occurrence in the caption.

    Returns:
        ``(caption, length_list, width_list, info_array, words, state_list,
        word_match_list, boxes, boxes_length)``:
            caption: the cleaned caption.
            length_list: (77,) long tensor, character count of the caption word each
                token belongs to (0 for special/padding tokens).
            width_list: (77,) long tensor, ``get_width`` of that word (0 for
                special/padding tokens).
            info_array: (77, 5) tensor; column 0 is 1 at the first token of each
                matched keyword, columns 1-4 hold its (all-zero) box.
            words: token strings for all 77 positions, with ``</w>`` stripped.
            state_list: (77,) long tensor; 0 = first token of a word, 1 = continuation
                token, 2 = first/last special token, 127 = padding.
            word_match_list: (77,) long tensor; caption-word index of each token,
                127 for special/padding tokens.
            boxes: (8, 4) tensor of zero placeholder boxes.
            boxes_length: number of matched keywords before truncation to 8.
    """
    # remove punctuations. please remove this statement if you want to paint punctuations
    caption = re.sub(u"([^\u0041-\u005a\u0061-\u007a\u0030-\u0039])", " ", caption)

    # tokenize into ids (padded to 77) and get the token count
    caption_words = tokenizer([caption], truncation=True, max_length=77, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
    caption_words_ids = caption_words['input_ids']  # (1, 77)
    # number of tokens walked below; the first and the last of them are special tokens
    length = int(caption_words['length'])

    # convert ids back to token strings, dropping the end-of-word marker
    words = tokenizer.convert_ids_to_tokens(caption_words_ids.view(-1).tolist())
    words = [token.replace('</w>', '') for token in words]
    words_valid = words[:length]

    # per token: [is_keyword_start, xmin, ymin, xmax, ymax]
    info_array = np.zeros((77, 5))  # (77, 5)

    # caption words in their original case (for lengths/widths) and lower-cased (for matching)
    original_split = caption.split()
    caption_split = [word.lower() for word in original_split]

    start_dic = {}  # caption word index -> index of its first token
    state_list = []  # 0: start, 1: middle, 2: special token
    word_match_list = []  # the index of the word in the caption
    current_caption_index = 0
    current_match = ''
    for i in range(length):
        # the first and last token are special tokens
        if i == 0 or i == length - 1:
            state_list.append(2)
            word_match_list.append(127)
            continue

        if current_match == '':
            state_list.append(0)
            start_dic[current_caption_index] = i
        else:
            state_list.append(1)

        # accumulate sub-word tokens until they spell out the current caption word
        current_match += words_valid[i]
        word_match_list.append(current_caption_index)
        if current_match == caption_split[current_caption_index]:
            current_match = ''
            current_caption_index += 1

    # pad both lists to the fixed 77-token length with the 127 marker
    state_list += [127] * (77 - len(state_list))
    word_match_list += [127] * (77 - len(word_match_list))

    # word_match_list is 77 long here, so these lists come out 77 long as well
    length_list = []
    width_list = []
    for word_index in word_match_list:
        if word_index == 127:
            length_list.append(0)
            width_list.append(0)
        else:
            word = original_split[word_index]
            length_list.append(len(word))
            width_list.append(get_width(font_path, word))

    length_list = torch.Tensor(length_list).long()  # (77, )
    width_list = torch.Tensor(width_list).long()  # (77, )

    boxes = []
    for word_index in _keyword_word_indices(caption_split, keywords):
        if word_index not in start_dic:
            # the word has no aligned token (e.g. it fell outside the 77-token window)
            continue
        token_index = start_dic[word_index]
        info_array[token_index][0] = 1

        box = [0, 0, 0, 0]
        boxes.append(list(box))
        info_array[token_index][1:] = box

    # the layout transformer works on exactly 8 box slots: truncate, then zero-pad
    boxes_length = len(boxes)
    boxes = boxes[:8]
    boxes.extend([0, 0, 0, 0] for _ in range(8 - len(boxes)))

    return caption, length_list, width_list, torch.from_numpy(info_array), words, torch.Tensor(state_list).long(), torch.Tensor(word_match_list).long(), torch.Tensor(boxes), boxes_length


def _text_size(font, text):
    """Return ``(right, bottom)`` of ``text`` drawn with ``font``, as ``font.getsize`` did.

    ``FreeTypeFont.getsize`` was deprecated in Pillow 9.2 and removed in Pillow 10.
    When it is unavailable, the right/bottom edges of ``getbbox`` are used instead
    (the equivalent values for text drawn without a stroke).
    """
    if hasattr(font, 'getsize'):
        return font.getsize(text)
    _, _, right, bottom = font.getbbox(text)
    return right, bottom


def get_layout_from_prompt(args):
    """Predict a box for each keyword of ``args.prompt`` and render the layout.

    Args:
        args: options object; ``args.prompt`` and ``args.font_path`` are read here,
            and ``args`` is also forwarded to ``adjust_font_size``.

    Returns:
        ``(render_image, segmentation_mask)``: a 512x512 white RGB image with each
        keyword drawn in black at its predicted box, and a 512x512 ``"L"`` image in
        which each character's shrunk box is filled with its ``alphabet_dic`` value
        (0 elsewhere).
    """
    font_path = args.font_path
    keywords = get_key_words(args.prompt)

    print(f'{colored("[!]", "red")} Detected keywords: {keywords} from prompt {args.prompt}')

    # inference only: no autograd graph is needed for the text embedding
    with torch.no_grad():
        text_embedding, mask = text_encoder(args.prompt)  # (1, 77 768) / (1, 77)

    # process all relevant info
    caption, length_list, width_list, target, words, state_list, word_match_list, boxes, boxes_length = process_caption(font_path, args.prompt, keywords)
    target = target.cuda().unsqueeze(0)  # (1, 77, 5)
    width_list = width_list.cuda().unsqueeze(0)  # (1, 77)
    length_list = length_list.cuda().unsqueeze(0)  # (1, 77)
    state_list = state_list.cuda().unsqueeze(0)  # (1, 77)
    word_match_list = word_match_list.cuda().unsqueeze(0)  # (1, 77)

    padding = torch.zeros(1, 1, 4).cuda()
    boxes = boxes.unsqueeze(0).cuda()
    # decoder input: boxes shifted right by one slot, filled in as predictions arrive
    right_shifted_boxes = torch.cat([padding, boxes[:, 0:-1, :]], 1)  # (1, 8, 4)
    max_boxes = right_shifted_boxes.shape[1]

    # at most `max_boxes` slots exist; keywords beyond that get no box
    num_boxes = min(boxes_length, max_boxes)

    # inference: one box per step, each step conditioned on the boxes predicted so far
    return_boxes = []
    # None on the first call; afterwards the embedding returned by the model is passed back in
    encoder_embedding = None
    with torch.no_grad():
        for box_index in range(num_boxes):
            output, encoder_embedding = model(text_embedding, length_list, width_list, mask, state_list, word_match_list, target, right_shifted_boxes, train=False, encoder_embedding=encoder_embedding)
            output = torch.clamp(output, min=0, max=1)  # (1, 8, 4)

            # add overlap detection
            output = adjust_overlap_box(output, box_index)  # (1, 8, 4)

            # feed this box to the next step; the last slot has no next step to feed
            if box_index + 1 < max_boxes:
                right_shifted_boxes[:, box_index + 1, :] = output[:, box_index, :]
            xmin, ymin, xmax, ymax = output[0, box_index, :].tolist()
            return_boxes.append([xmin, ymin, xmax, ymax])

    if len(return_boxes) != len(keywords):
        # some keywords were not matched in the cleaned caption or exceed the slot limit
        print(f'{colored("[!]", "red")} Predicted {len(return_boxes)} boxes for {len(keywords)} keywords; only the first {len(return_boxes)} keywords are laid out')

    # print the location of keywords
    print(f'index\tkeyword\tx_min\ty_min\tx_max\ty_max')
    for index, (keyword, box) in enumerate(zip(keywords, return_boxes)):
        x_min, y_min, x_max, y_max = [int(v * 512) for v in box]
        print(f'{index}\t{keyword}\t{x_min}\t{y_min}\t{x_max}\t{y_max}')

    # paint the layout
    render_image = Image.new('RGB', (512, 512), (255, 255, 255))
    draw = ImageDraw.Draw(render_image)
    segmentation_mask = Image.new("L", (512, 512), 0)
    segmentation_mask_draw = ImageDraw.Draw(segmentation_mask)

    for text, box in zip(keywords, return_boxes):
        xmin, ymin, xmax, ymax = [int(v * 512) for v in box]
        box_width = xmax - xmin
        box_height = ymax - ymin

        font_size = adjust_font_size(args, box_width, box_height, draw, text)
        font = ImageFont.truetype(args.font_path, font_size)
        draw.text((xmin, ymin), text, font=font, fill=(0, 0, 0))

        for i, char in enumerate(text):
            # paint character-level segmentation masks: the right edge comes from the
            # extent of the prefix text[:i+1], the bottom from the smaller of the
            # character's and the prefix's extents
            # https://github.com/python-pillow/Pillow/issues/3921
            bottom_1 = _text_size(font, char)[1]
            right, bottom_2 = _text_size(font, text[:i + 1])
            bottom = min(bottom_1, bottom_2)
            char_width, char_height = font.getmask(char).size
            right += xmin
            bottom += ymin
            char_box = (right - char_width, bottom - char_height, right, bottom)

            segmentation_mask_draw.rectangle(shrink_box(char_box, scale_factor=0.9), fill=alphabet_dic[char])

    print(f'{colored("[√]", "green")} Layout is successfully generated')
    return render_image, segmentation_mask