Upload anytext.py

anytext.py  (+12 −62)
@@ -35,6 +35,7 @@ import PIL.Image
 import torch
 import torch.nn.functional as F
 from easydict import EasyDict as edict
+from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3
 from huggingface_hub import hf_hub_download
 from ocr_recog.RecModel import RecModel
 from PIL import Image, ImageDraw, ImageFont
@@ -206,13 +207,12 @@ def get_recog_emb(encoder, img_list):
 class EmbeddingManager(nn.Module):
     def __init__(
         self,
-        clip_tokenizer,
+        embedder,
         placeholder_string="*",
         use_fp16=False,
-        device="cpu",
     ):
         super().__init__()
-        get_token_for_string = partial(get_clip_token_for_string, clip_tokenizer)
+        get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
         token_dim = 768
         self.get_recog_emb = None
         self.token_dim = token_dim
@@ -223,7 +223,7 @@ class EmbeddingManager(nn.Module):
             filename="text_embedding_module/proj.safetensors",
             cache_dir=HF_MODULES_CACHE,
         )
-        self.proj.load_state_dict(load_file(proj_dir, device=str(device)))
+        self.proj.load_state_dict(load_file(proj_dir, device=str(embedder.device)))
         if use_fp16:
             self.proj = self.proj.to(dtype=torch.float16)

@@ -526,20 +526,14 @@ class TextEmbeddingModule(nn.Module):
         self.font = ImageFont.truetype(font_path, 60)
         self.use_fp16 = use_fp16
         self.device = device
-
-
-        version = "openai/clip-vit-large-patch14"
-        torch_dtype = torch.float16 if use_fp16 else torch.float32
-        self.clip_tokenizer = CLIPTokenizer.from_pretrained(version)
-        self.clip_text_model = CLIPTextModel.from_pretrained(version, torch_dtype=torch_dtype).to(device)
-        self.max_length = 77  # same as before
-
-        self.embedding_manager = EmbeddingManager(self.clip_tokenizer, use_fp16=use_fp16, device=device)
+        self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16)
+        self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16)
         rec_model_dir = "./text_embedding_module/OCR/ppv3_rec.pth"
         self.text_predictor = create_predictor(rec_model_dir, device=device, use_fp16=use_fp16).eval()
         args = {}
         args["rec_image_shape"] = "3, 48, 320"
         args["rec_batch_num"] = 6
+        args["rec_char_dict_path"] = "./text_embedding_module/OCR/ppocr_keys_v1.txt"
         args["rec_char_dict_path"] = hf_hub_download(
             repo_id="tolgacangoz/anytext",
             filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
@@ -548,50 +542,6 @@ class TextEmbeddingModule(nn.Module):
         args["use_fp16"] = use_fp16
         self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)

-    # New helper method to mimic old encode() functionality with chunk splitting
-    def _encode_text(self, texts, embedding_manager=None, **kwargs):
-        batch_encoding = self.clip_tokenizer(
-            texts,
-            truncation=False,
-            max_length=self.max_length,
-            padding="longest",
-            return_tensors="pt",
-        )
-        input_ids = batch_encoding["input_ids"]
-        tokens_list = self._split_chunks(input_ids)
-        embeds_list = []
-        for tokens in tokens_list:
-            tokens = tokens.to(self.device)
-            outputs = self.clip_text_model(input_ids=tokens, **kwargs)
-            # use last_hidden_state as in the old version
-            embeds_list.append(outputs.last_hidden_state)
-        return torch.cat(embeds_list, dim=1)
-
-    # New helper for splitting tokens (mimicking split_chunks behavior)
-    def _split_chunks(self, input_ids, chunk_size=75):
-        tokens_list = []
-        bs, n = input_ids.shape
-        id_start = input_ids[:, 0].unsqueeze(1)
-        id_end = input_ids[:, -1].unsqueeze(1)
-        if n == 2:  # empty caption
-            tokens_list.append(torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1))
-            return tokens_list
-
-        trimmed = input_ids[:, 1:-1]
-        num_full = (n - 2) // chunk_size
-        for i in range(num_full):
-            group = trimmed[:, i * chunk_size:(i + 1) * chunk_size]
-            group_pad = torch.cat((id_start, group, id_end), dim=1)
-            tokens_list.append(group_pad)
-        rem = (n - 2) % chunk_size
-        if rem > 0:
-            group = trimmed[:, -rem:]
-            pad_cols = chunk_size - group.shape[1]
-            padding = id_end.expand(bs, pad_cols)
-            group_pad = torch.cat((id_start, group, padding, id_end), dim=1)
-            tokens_list.append(group_pad)
-        return tokens_list
-
     @torch.no_grad()
     def forward(
         self,
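Note on the block removed above: the deleted _encode_text/_split_chunks helpers implemented CLIP long-prompt handling by splitting the token ids into 75-token chunks, re-wrapping each chunk with the BOS/EOS ids so it fits CLIP's 77-token window, encoding each chunk, and concatenating the hidden states. After this commit that behavior is expected to live inside FrozenCLIPEmbedderT3. Below is a standalone sketch of the chunking step, restated from the deleted code; the free-function form and the split_chunks name are ours, the logic is unchanged:

import torch

def split_chunks(input_ids: torch.Tensor, chunk_size: int = 75) -> list:
    # Split (batch, n) token ids into BOS/EOS-wrapped chunks of 77 tokens each.
    bs, n = input_ids.shape
    id_start = input_ids[:, 0].unsqueeze(1)  # BOS column
    id_end = input_ids[:, -1].unsqueeze(1)   # EOS column
    if n == 2:  # empty caption: a single chunk padded entirely with EOS
        return [torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1)]
    tokens_list = []
    trimmed = input_ids[:, 1:-1]  # strip BOS/EOS before chunking
    for i in range((n - 2) // chunk_size):  # full 75-token chunks
        group = trimmed[:, i * chunk_size:(i + 1) * chunk_size]
        tokens_list.append(torch.cat((id_start, group, id_end), dim=1))
    rem = (n - 2) % chunk_size
    if rem > 0:  # final partial chunk, right-padded with EOS
        group = trimmed[:, -rem:]
        padding = id_end.expand(bs, chunk_size - group.shape[1])
        tokens_list.append(torch.cat((id_start, group, padding, id_end), dim=1))
    return tokens_list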
@@ -704,9 +654,10 @@ class TextEmbeddingModule(nn.Module):
         # hint = self.arr2tensor(np_hint, len(prompt))

         self.embedding_manager.encode_text(text_info)
-        prompt_embeds = self._encode_text([prompt], embedding_manager=self.embedding_manager)
+        prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager)
+
         self.embedding_manager.encode_text(text_info)
-        negative_prompt_embeds = self._encode_text(
+        negative_prompt_embeds = self.frozen_CLIP_embedder_t3.encode(
             [negative_prompt or ""], embedding_manager=self.embedding_manager
         )

@@ -856,11 +807,10 @@ class TextEmbeddingModule(nn.Module):
         return new_string[:-nSpace]

     def to(self, *args, **kwargs):
-        self.clip_text_model = self.clip_text_model.to(*args, **kwargs)
-        self.device = self.clip_text_model.device
+        self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs)
         self.embedding_manager = self.embedding_manager.to(*args, **kwargs)
         self.text_predictor = self.text_predictor.to(*args, **kwargs)
-        self.device = self.clip_text_model.device
+        self.device = self.frozen_CLIP_embedder_t3.device
         return self
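Taken together, the edits route all prompt encoding through FrozenCLIPEmbedderT3, which now owns the tokenizer and device. A minimal usage sketch assembled only from the calls visible in this diff; the import path for EmbeddingManager, the example prompt, and the use_fp16=False setting are assumptions:

import torch
from anytext import EmbeddingManager  # assumed import path (defined in anytext.py)
from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3

device = "cuda" if torch.cuda.is_available() else "cpu"
embedder = FrozenCLIPEmbedderT3(device=device, use_fp16=False)
# The manager now takes the embedder itself; tokenizer and device come from it.
manager = EmbeddingManager(embedder, use_fp16=False)

# Long-prompt chunk splitting presumably happens inside encode() now:
prompt_embeds = embedder.encode(["a storefront sign"], embedding_manager=manager)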