File size: 5,008 Bytes
e8dca02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# ------------------------------------------
# TextDiffuser: Diffusion Models as Text Painters
# Paper Link: https://arxiv.org/abs/2305.10855
# Code Link: https://github.com/microsoft/unilm/tree/master/textdiffuser
# Copyright (c) Microsoft Corporation.
# This file defines the Layout Transformer for predicting the layout of keywords.
# ------------------------------------------

import torch
import torch.nn as nn 
from transformers import CLIPTokenizer, CLIPTextModel

class TextConditioner(nn.Module):
    """Frozen CLIP text encoder that turns prompts into per-token embeddings.

    Both the tokenizer and the text model are loaded from
    ``openai/clip-vit-large-patch14``. The text model is switched to eval
    mode and ``requires_grad`` is disabled on all of its parameters.
    """

    def __init__(self):
        super().__init__()
        self.transformer = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')
        self.tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')

        # Freeze the text encoder so its weights are never updated.
        self.transformer.eval()
        for param in self.transformer.parameters():
            param.requires_grad = False

    def forward(self, prompt_list):
        """Tokenize and encode the given prompts.

        Args:
            prompt_list: prompt string(s) handed to the CLIP tokenizer. Each
                prompt is truncated / padded to exactly 77 tokens.

        Returns:
            A tuple ``(last_hidden_state, attention_mask)``, both placed on
            the device that holds the text encoder's weights:

            * ``last_hidden_state`` -- per-token embeddings, (batch, 77, 768)
              according to the original author's shape note.
            * ``attention_mask`` -- the tokenizer's padding mask,
              (batch, 77) given ``padding="max_length"`` and ``max_length=77``.
        """
        # Follow the encoder's own device rather than assuming CUDA, so the
        # module also works on CPU-only hosts or after `.to(other_device)`.
        device = next(self.transformer.parameters()).device

        batch_encoding = self.tokenizer(
            prompt_list,
            truncation=True,
            max_length=77,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        text_embedding = self.transformer(batch_encoding["input_ids"].to(device))
        hidden_states = text_embedding.last_hidden_state.to(device)
        attention_mask = batch_encoding["attention_mask"].to(device)
        return hidden_states, attention_mask


class LayoutTransformer(nn.Module):
    """Transformer encoder-decoder that predicts 4 box values per decoder slot.

    The encoder consumes per-token text features (last dim 768) together with
    width, position and target embeddings. The decoder consumes the
    right-shifted boxes under a causal mask and a small MLP maps each decoder
    output to 4 numbers.

    Several submodules (``encoder_layer``, ``decoder_layer``,
    ``mask_embedding``, ``length_embedding``, ``state_embedding``,
    ``match_embedding``) are not read by ``forward``. They are kept so that
    the parameter names in ``state_dict`` stay unchanged.
    """

    def __init__(self, layer_number=2):
        """Build the model.

        Args:
            layer_number: number of layers in both the encoder stack and the
                decoder stack.
        """
        super().__init__()

        self.encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        self.transformer = torch.nn.TransformerEncoder(
            self.encoder_layer, num_layers=layer_number
        )

        self.decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        self.decoder_transformer = torch.nn.TransformerDecoder(
            self.decoder_layer, num_layers=layer_number
        )

        self.mask_embedding = nn.Embedding(2, 512)
        self.length_embedding = nn.Embedding(256, 512)
        self.width_embedding = nn.Embedding(256, 512)
        self.position_embedding = nn.Embedding(256, 512)
        self.state_embedding = nn.Embedding(256, 512)
        self.match_embedding = nn.Embedding(256, 512)

        # One table per box component; indices are box values scaled to 0..511.
        self.x_embedding = nn.Embedding(512, 512)
        self.y_embedding = nn.Embedding(512, 512)
        self.w_embedding = nn.Embedding(512, 512)
        self.h_embedding = nn.Embedding(512, 512)

        self.encoder_target_embedding = nn.Embedding(256, 512)

        self.input_layer = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
        )

        self.output_layer = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 4),
        )

    def _encode(self, x, width, target, train, device):
        """Run the encoder stack and return its output as (seq, batch, 512)."""
        if train:
            # Augmentation: jitter the widths. torch.randint excludes the
            # upper bound, so offsets lie in [-3, 2]. Sampling stays on the
            # CPU generator (as before) so seeded runs remain reproducible.
            noise = torch.randint(-3, 3, (width.shape[0], width.shape[1]))
            width = width + noise.to(width.device)

        x = self.input_layer(x)  # (B, S, 512)
        width_embedding = self.width_embedding(torch.clamp(width, 0, 255).long())  # (B, S, 512)
        encoder_target_embedding = self.encoder_target_embedding(target[:, :, 0].long())  # (B, S, 512)
        positions = torch.arange(x.shape[1], device=device)
        pe_embedding = self.position_embedding(positions).unsqueeze(0)  # (1, S, 512)

        # Combine all embeddings, then switch to the sequence-first layout
        # expected by the (batch_first=False) transformer stacks.
        total_embedding = x + width_embedding + pe_embedding + encoder_target_embedding
        return self.transformer(total_embedding.permute(1, 0, 2))  # (S, B, 512)

    def forward(self, x, length, width, mask, state, match, target,
                right_shifted_boxes, train=False, encoder_embedding=None):
        """Predict boxes for every decoder slot.

        Args:
            x: text features with last dim 768, (batch, seq, 768); the
                original shape notes use seq = 77. Not read when
                ``encoder_embedding`` is given.
            length: accepted for interface compatibility; not read here.
            width: per-token width indices, (batch, seq); clamped to
                [0, 255]. Not read when ``encoder_embedding`` is given.
            mask: accepted for interface compatibility; not read here.
            state: accepted for interface compatibility; not read here.
            match: accepted for interface compatibility; not read here.
            target: rank-3 tensor whose ``[:, :, 0]`` slice indexes
                ``encoder_target_embedding``. Not read when
                ``encoder_embedding`` is given.
            right_shifted_boxes: (batch, slots, 4) decoder inputs; values are
                multiplied by 512 and clamped to [0, 511], so presumably
                normalized to [0, 1] -- confirm against the caller. The
                original shape notes use slots = 8.
            train: when true, applies random jitter to ``width``. This flag
                does not switch the module's train/eval mode.
            encoder_embedding: optional cached encoder output from a previous
                call; when given, the encoder pass is skipped.

        Returns:
            ``(box_prediction, encoder_embedding)`` where ``box_prediction``
            is (batch, slots, 4) and ``encoder_embedding`` is the encoder
            output (seq, batch, 512), suitable for passing back in as a cache.
        """
        # Tensors created inside this method follow the module's own device
        # instead of assuming CUDA.
        device = self.position_embedding.weight.device

        # Only run the encoder when no cached result was supplied.
        if encoder_embedding is None:
            encoder_embedding = self._encode(x, width, target, train, device)

        boxes = torch.clamp((right_shifted_boxes * 512).long(), 0, 511)  # (B, T, 4)
        num_slots = boxes.shape[1]

        # Decoder input: slot position + one embedding per box component.
        slot_positions = torch.arange(num_slots, device=device)  # (T,)
        pe_embedding_decoder = self.position_embedding(slot_positions).unsqueeze(0)  # (1, T, 512)
        decoder_input = (
            pe_embedding_decoder
            + self.x_embedding(boxes[:, :, 0])
            + self.y_embedding(boxes[:, :, 1])
            + self.w_embedding(boxes[:, :, 2])
            + self.h_embedding(boxes[:, :, 3])
        )  # (B, T, 512)
        decoder_input = decoder_input.permute(1, 0, 2)  # (T, B, 512)

        # Triangular mask so each slot only attends to itself and earlier
        # slots. Kept under its own name so the `mask` argument is not
        # clobbered.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(num_slots).to(device)  # (T, T)
        decoder_result = self.decoder_transformer(
            decoder_input, encoder_embedding, tgt_mask=causal_mask
        )  # (T, B, 512)
        decoder_result = decoder_result.permute(1, 0, 2)  # (B, T, 512)

        box_prediction = self.output_layer(decoder_result)  # (B, T, 4)
        return box_prediction, encoder_embedding