Text-to-Image
Diffusers
Safetensors
tolgacangoz committed on
Commit
c747cba
·
verified ·
1 Parent(s): 04fc5b3

Upload anytext.py

Browse files
Files changed (1) hide show
  1. text_embedding_module/anytext.py +25 -15
text_embedding_module/anytext.py CHANGED
@@ -69,6 +69,7 @@ from diffusers.utils import (
69
  from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
70
  from diffusers.configuration_utils import register_to_config, ConfigMixin
71
  from diffusers.models.modeling_utils import ModelMixin
 
72
 
73
 
74
  checker = BasicTokenizer()
@@ -152,7 +153,7 @@ class EmbeddingManager(nn.Module):
152
  self.token_dim = token_dim
153
 
154
  self.proj = nn.Linear(40 * 64, token_dim)
155
- self.proj.load_state_dict(load_file("proj.safetensors", device=str(embedder.device)))
156
  if use_fp16:
157
  self.proj = self.proj.to(dtype=torch.float16)
158
 
@@ -269,9 +270,20 @@ def crop_image(src_img, mask):
269
 
270
 
271
  def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False):
272
- model_file_path = model_dir
273
- if model_file_path is not None and not os.path.exists(model_file_path):
274
- raise ValueError("not find model file path {}".format(model_file_path))
 
 
 
 
 
 
 
 
 
 
 
275
 
276
  if model_lang == "ch":
277
  n_class = 6625
@@ -287,8 +299,8 @@ def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=Fal
287
  )
288
 
289
  rec_model = RecModel(rec_config)
290
- if model_file_path is not None:
291
- rec_model.load_state_dict(torch.load(model_file_path, map_location=device))
292
  return rec_model
293
 
294
 
@@ -450,22 +462,20 @@ class TextRecognizer(object):
450
  return loss
451
 
452
 
453
- class TextEmbeddingModule(ModelMixin, ConfigMixin):
454
- @register_to_config
455
  def __init__(self, font_path, use_fp16=False, device="cpu"):
456
  super().__init__()
457
- self.use_fp16 = use_fp16
458
- self.device = device
459
  # TODO: Learn if the recommended font file is free to use
460
  self.font = ImageFont.truetype(font_path, 60)
461
- self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device, use_fp16=self.use_fp16)
462
- self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=self.use_fp16)
463
- rec_model_dir = "OCR/ppv3_rec.pth"
464
- self.text_predictor = create_predictor(rec_model_dir, device=self.device, use_fp16=self.use_fp16).eval()
465
  args = {}
466
  args["rec_image_shape"] = "3, 48, 320"
467
  args["rec_batch_num"] = 6
468
- args["rec_char_dict_path"] = "OCR/ppocr_keys_v1.txt"
469
  args["use_fp16"] = self.use_fp16
470
  self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)
471
 
 
69
  from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
70
  from diffusers.configuration_utils import register_to_config, ConfigMixin
71
  from diffusers.models.modeling_utils import ModelMixin
72
+ from huggingface_hub import hf_hub_download
73
 
74
 
75
  checker = BasicTokenizer()
 
153
  self.token_dim = token_dim
154
 
155
  self.proj = nn.Linear(40 * 64, token_dim)
156
+ # self.proj.load_state_dict(load_file("proj.safetensors", device=str(embedder.device)))
157
  if use_fp16:
158
  self.proj = self.proj.to(dtype=torch.float16)
159
 
 
270
 
271
 
272
  def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False):
273
+ if model_dir is None or not os.path.exists(model_dir):
274
+ try:
275
+ # Use the repo id from which the pipeline was loaded
276
+ model_dir = hf_hub_download(
277
+ repo_id="tolgacangoz/anytext",
278
+ filename="text_embedding_module/OCR/ppv3_rec.pth",
279
+ local_dir=".cache/diffusers",
280
+ local_dir_use_symlinks=True
281
+ )
282
+ except Exception as e:
283
+ raise ValueError(f"Could not download the model file: {e}")
284
+
285
+ if model_dir is not None and not os.path.exists(model_dir):
286
+ raise ValueError("not find model file path {}".format(model_dir))
287
 
288
  if model_lang == "ch":
289
  n_class = 6625
 
299
  )
300
 
301
  rec_model = RecModel(rec_config)
302
+ state_dict = torch.load(model_dir, map_location=device)
303
+ rec_model.load_state_dict(state_dict)
304
  return rec_model
305
 
306
 
 
462
  return loss
463
 
464
 
465
+ class TextEmbeddingModule(nn.Module):
466
+ # @register_to_config
467
  def __init__(self, font_path, use_fp16=False, device="cpu"):
468
  super().__init__()
 
 
469
  # TODO: Learn if the recommended font file is free to use
470
  self.font = ImageFont.truetype(font_path, 60)
471
+ self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16)
472
+ self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16)
473
+ rec_model_dir = "./text_embedding_module/OCR/ppv3_rec.pth"
474
+ self.text_predictor = create_predictor(rec_model_dir, device=device, use_fp16=use_fp16).eval()
475
  args = {}
476
  args["rec_image_shape"] = "3, 48, 320"
477
  args["rec_batch_num"] = 6
478
+ args["rec_char_dict_path"] = "./text_embedding_module/OCR/ppocr_keys_v1.txt"
479
  args["use_fp16"] = self.use_fp16
480
  self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)
481