'''
Copyright (c) Alibaba, Inc. and its affiliates.
'''
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from iopaint.model.anytext.ldm.modules.diffusionmodules.util import conv_nd, linear


def get_clip_token_for_string(tokenizer, string):
    batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
                               return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
    tokens = batch_encoding["input_ids"]
    # 49407 is CLIP's end-of-text id, which also fills the padding positions. A
    # string that maps to a single token therefore yields exactly two ids that
    # differ from it: the start-of-text marker and the token itself.
    assert torch.count_nonzero(tokens - 49407) == 2, \
        f"String '{string}' maps to more than a single token. Please use another string"
    return tokens[0, 1]


def get_bert_token_for_string(tokenizer, string):
    token = tokenizer(string)
    # A single-token string encodes to exactly three non-zero ids: the start
    # marker, the token itself, and the end marker.
    assert torch.count_nonzero(token) == 3, \
        f"String '{string}' maps to more than a single token. Please use another string"
    return token[0, 1]


def get_clip_vision_emb(encoder, processor, img):
    # Tile the single-channel glyph image to 3 channels and rescale to 0-255
    # before handing it to the CLIP image processor.
    _img = img.repeat(1, 3, 1, 1) * 255
    inputs = processor(images=_img, return_tensors="pt")
    inputs['pixel_values'] = inputs['pixel_values'].to(img.device)
    outputs = encoder(**inputs)
    return outputs.image_embeds


def get_recog_emb(encoder, img_list):
    # Convert each 1-channel glyph line to a 3-channel 0-255 image, run the OCR
    # recognizer, and keep its neck features as the glyph embedding.
    _img_list = [(img.repeat(1, 3, 1, 1) * 255)[0] for img in img_list]
    encoder.predictor.eval()
    _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
    return preds_neck


def pad_H(x):
    # Zero-pad the height of an NCHW tensor so it matches the width (assumes W >= H).
    _, _, H, W = x.shape
    p_top = (W - H) // 2
    p_bot = W - H - p_top
    return F.pad(x, (0, 0, p_top, p_bot))
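

# A quick worked example of pad_H (illustrative only, not part of the module):
# a 1x1x32x64 glyph line gets 16 rows of zeros above and 16 below, producing a
# square 1x1x64x64 tensor suitable for the ViT/conv glyph encoders.
#
#   x = torch.zeros(1, 1, 32, 64)
#   assert pad_H(x).shape == (1, 1, 64, 64)

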
class EncodeNet(nn.Module):
    """Small CNN that maps an image to a flat embedding of size out_channels."""

    def __init__(self, in_channels, out_channels):
        super(EncodeNet, self).__init__()
        chan = 16
        n_layer = 4  # each downsampling conv halves H and W and doubles the channels

        self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1)
        self.conv_list = nn.ModuleList([])
        _c = chan
        for _ in range(n_layer):
            self.conv_list.append(conv_nd(2, _c, _c * 2, 3, padding=1, stride=2))
            _c *= 2
        self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.act = nn.SiLU()

    def forward(self, x):
        x = self.act(self.conv1(x))
        for layer in self.conv_list:
            x = self.act(layer(x))
        x = self.act(self.conv2(x))
        # Global average pooling collapses the spatial dims to 1x1, then flatten.
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return x
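

# Shape sketch for EncodeNet (illustrative; the input size is an assumption):
# a (B, 1, 64, 64) position mask passes through conv1 (16 channels), four
# stride-2 convs (32 -> 64 -> 128 -> 256 channels, ending at 4x4 spatial),
# conv2 to out_channels, and global average pooling, giving (B, out_channels).
#
#   net = EncodeNet(in_channels=1, out_channels=768)
#   emb = net(torch.zeros(2, 1, 64, 64))  # emb.shape == (2, 768)

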
class EmbeddingManager(nn.Module):
    def __init__(
            self,
            embedder,
            valid=True,
            glyph_channels=20,
            position_channels=1,
            placeholder_string='*',
            add_pos=False,
            emb_type='ocr',
            **kwargs
    ):
        super().__init__()
        if hasattr(embedder, 'tokenizer'):  # using Stable Diffusion's CLIP encoder
            get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
            token_dim = 768
            if hasattr(embedder, 'vit'):
                assert emb_type == 'vit'
                self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor)
            # For the 'ocr' type this is bound lazily in encode_text, once the
            # recognizer has been attached to self.recog.
            self.get_recog_emb = None
        else:  # using LDM's BERT-style encoder
            get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
            token_dim = 1280
        self.token_dim = token_dim
        self.emb_type = emb_type

        self.add_pos = add_pos
        if add_pos:
            self.position_encoder = EncodeNet(position_channels, token_dim)
        if emb_type == 'ocr':
            # The recognizer's neck features flatten to 40*64 per line; project
            # them to the text encoder's token dimension.
            self.proj = linear(40 * 64, token_dim)
        if emb_type == 'conv':
            self.glyph_encoder = EncodeNet(glyph_channels, token_dim)

        self.placeholder_token = get_token_for_string(placeholder_string)

    def encode_text(self, text_info):
        if self.get_recog_emb is None and self.emb_type == 'ocr':
            # self.recog (the OCR recognition model) is expected to be attached
            # externally before the first call.
            self.get_recog_emb = partial(get_recog_emb, self.recog)

        # Gather every rendered glyph line (and, optionally, its position mask)
        # across the batch into flat lists.
        gline_list = []
        pos_list = []
        for i in range(len(text_info['n_lines'])):  # sample index within the batch
            n_lines = text_info['n_lines'][i]
            for j in range(n_lines):  # line index within the sample
                gline_list += [text_info['gly_line'][j][i:i + 1]]
                if self.add_pos:
                    pos_list += [text_info['positions'][j][i:i + 1]]

        if len(gline_list) > 0:
            if self.emb_type == 'ocr':
                recog_emb = self.get_recog_emb(gline_list)
                enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
            elif self.emb_type == 'vit':
                enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
            elif self.emb_type == 'conv':
                enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
            if self.add_pos:
                # Encode the collected position masks and add them onto the
                # glyph embeddings.
                enc_pos = self.position_encoder(torch.cat(pos_list, dim=0))
                enc_glyph = enc_glyph + enc_pos

        # Regroup the flat embeddings back into per-sample lists of lines.
        self.text_embs_all = []
        n_idx = 0
        for i in range(len(text_info['n_lines'])):
            n_lines = text_info['n_lines'][i]
            text_embs = []
            for j in range(n_lines):
                text_embs += [enc_glyph[n_idx:n_idx + 1]]
                n_idx += 1
            self.text_embs_all += [text_embs]
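
    # Expected layout of text_info (inferred from the accesses above; exact
    # shapes are assumptions): 'n_lines' is a per-sample count of text lines,
    # 'gly_line'[j][i] is the rendered glyph image of line j for sample i
    # (1-channel, NCHW), and 'positions'[j][i] is the matching 1-channel
    # position mask used when add_pos=True.
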
    def forward(
            self,
            tokenized_text,
            embedded_text,
    ):
        b, device = tokenized_text.shape[0], tokenized_text.device
        for i in range(b):
            # Mask of caption positions occupied by the placeholder token.
            idx = tokenized_text[i] == self.placeholder_token.to(device)
            if sum(idx) > 0:
                if i >= len(self.text_embs_all):
                    print('truncation for log images...')
                    break
                text_emb = torch.cat(self.text_embs_all[i], dim=0)
                if sum(idx) != len(text_emb):
                    print('truncation for long caption...')
                # Overwrite the placeholder embeddings with the glyph embeddings
                # cached by encode_text.
                embedded_text[i][idx] = text_emb[:sum(idx)]
        return embedded_text

    def embedding_parameters(self):
        return self.parameters()
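

# Hedged usage sketch (not part of this module): FrozenCLIPEmbedder and
# recog_model are assumptions standing in for however the surrounding AnyText
# pipeline builds its text encoder and OCR recognizer.
#
#   embedder = FrozenCLIPEmbedder()               # exposes .tokenizer, so token_dim=768
#   manager = EmbeddingManager(embedder, emb_type='ocr')
#   manager.recog = recog_model                   # OCR model attached before encoding
#   manager.encode_text(text_info)                # caches per-line glyph embeddings
#   out = manager(tokenized_text, embedded_text)  # splices them over '*' tokens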