Upload anytext.py
Browse files- text_embedding_module/anytext.py +25 -15
text_embedding_module/anytext.py
CHANGED
@@ -69,6 +69,7 @@ from diffusers.utils import (
|
|
69 |
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
70 |
from diffusers.configuration_utils import register_to_config, ConfigMixin
|
71 |
from diffusers.models.modeling_utils import ModelMixin
|
|
|
72 |
|
73 |
|
74 |
checker = BasicTokenizer()
|
@@ -152,7 +153,7 @@ class EmbeddingManager(nn.Module):
|
|
152 |
self.token_dim = token_dim
|
153 |
|
154 |
self.proj = nn.Linear(40 * 64, token_dim)
|
155 |
-
self.proj.load_state_dict(load_file("proj.safetensors", device=str(embedder.device)))
|
156 |
if use_fp16:
|
157 |
self.proj = self.proj.to(dtype=torch.float16)
|
158 |
|
@@ -269,9 +270,20 @@ def crop_image(src_img, mask):
|
|
269 |
|
270 |
|
271 |
def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False):
|
272 |
-
|
273 |
-
|
274 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
|
276 |
if model_lang == "ch":
|
277 |
n_class = 6625
|
@@ -287,8 +299,8 @@ def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=Fal
|
|
287 |
)
|
288 |
|
289 |
rec_model = RecModel(rec_config)
|
290 |
-
|
291 |
-
|
292 |
return rec_model
|
293 |
|
294 |
|
@@ -450,22 +462,20 @@ class TextRecognizer(object):
|
|
450 |
return loss
|
451 |
|
452 |
|
453 |
-
class TextEmbeddingModule(
|
454 |
-
@register_to_config
|
455 |
def __init__(self, font_path, use_fp16=False, device="cpu"):
|
456 |
super().__init__()
|
457 |
-
self.use_fp16 = use_fp16
|
458 |
-
self.device = device
|
459 |
# TODO: Learn if the recommended font file is free to use
|
460 |
self.font = ImageFont.truetype(font_path, 60)
|
461 |
-
self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=
|
462 |
-
self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=
|
463 |
-
rec_model_dir = "OCR/ppv3_rec.pth"
|
464 |
-
self.text_predictor = create_predictor(rec_model_dir, device=
|
465 |
args = {}
|
466 |
args["rec_image_shape"] = "3, 48, 320"
|
467 |
args["rec_batch_num"] = 6
|
468 |
-
args["rec_char_dict_path"] = "OCR/ppocr_keys_v1.txt"
|
469 |
args["use_fp16"] = self.use_fp16
|
470 |
self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)
|
471 |
|
|
|
69 |
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
70 |
from diffusers.configuration_utils import register_to_config, ConfigMixin
|
71 |
from diffusers.models.modeling_utils import ModelMixin
|
72 |
+
from huggingface_hub import hf_hub_download
|
73 |
|
74 |
|
75 |
checker = BasicTokenizer()
|
|
|
153 |
self.token_dim = token_dim
|
154 |
|
155 |
self.proj = nn.Linear(40 * 64, token_dim)
|
156 |
+
# self.proj.load_state_dict(load_file("proj.safetensors", device=str(embedder.device)))
|
157 |
if use_fp16:
|
158 |
self.proj = self.proj.to(dtype=torch.float16)
|
159 |
|
|
|
270 |
|
271 |
|
272 |
def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False):
|
273 |
+
if model_dir is None or not os.path.exists(model_dir):
|
274 |
+
try:
|
275 |
+
# Use the repo id from which the pipeline was loaded
|
276 |
+
model_dir = hf_hub_download(
|
277 |
+
repo_id="tolgacangoz/anytext",
|
278 |
+
filename="text_embedding_module/OCR/ppv3_rec.pth",
|
279 |
+
local_dir=".cache/diffusers",
|
280 |
+
local_dir_use_symlinks=True
|
281 |
+
)
|
282 |
+
except Exception as e:
|
283 |
+
raise ValueError(f"Could not download the model file: {e}")
|
284 |
+
|
285 |
+
if model_dir is not None and not os.path.exists(model_dir):
|
286 |
+
raise ValueError("not find model file path {}".format(model_dir))
|
287 |
|
288 |
if model_lang == "ch":
|
289 |
n_class = 6625
|
|
|
299 |
)
|
300 |
|
301 |
rec_model = RecModel(rec_config)
|
302 |
+
state_dict = torch.load(model_dir, map_location=device)
|
303 |
+
rec_model.load_state_dict(state_dict)
|
304 |
return rec_model
|
305 |
|
306 |
|
|
|
462 |
return loss
|
463 |
|
464 |
|
465 |
+
class TextEmbeddingModule(nn.Module):
|
466 |
+
# @register_to_config
|
467 |
def __init__(self, font_path, use_fp16=False, device="cpu"):
|
468 |
super().__init__()
|
|
|
|
|
469 |
# TODO: Learn if the recommended font file is free to use
|
470 |
self.font = ImageFont.truetype(font_path, 60)
|
471 |
+
self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16)
|
472 |
+
self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16)
|
473 |
+
rec_model_dir = "./text_embedding_module/OCR/ppv3_rec.pth"
|
474 |
+
self.text_predictor = create_predictor(rec_model_dir, device=device, use_fp16=use_fp16).eval()
|
475 |
args = {}
|
476 |
args["rec_image_shape"] = "3, 48, 320"
|
477 |
args["rec_batch_num"] = 6
|
478 |
+
args["rec_char_dict_path"] = "./text_embedding_module/OCR/ppocr_keys_v1.txt"
|
479 |
args["use_fp16"] = self.use_fp16
|
480 |
self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)
|
481 |
|