Text-to-Image
Diffusers
Safetensors
tolgacangoz committed · verified
Commit 481bca0 · 1 Parent(s): 651cfb4

Upload anytext.py

Files changed (1)
  1. auxiliary_latent_module/anytext.py +233 -26
auxiliary_latent_module/anytext.py CHANGED
@@ -35,7 +35,6 @@ import PIL.Image
 import torch
 import torch.nn.functional as F
 from easydict import EasyDict as edict
-from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3
 from huggingface_hub import hf_hub_download
 from ocr_recog.RecModel import RecModel
 from PIL import Image, ImageDraw, ImageFont
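(The dropped import does not remove functionality: the FrozenCLIPEmbedderT3 class is inlined into this same file by the large hunk further down, -526,11 +519,225.)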
@@ -325,12 +324,6 @@ def adjust_image(box, img):
     return result


-"""
-mask: numpy.ndarray, mask of textual, HWC
-src_img: torch.Tensor, source image, CHW
-"""
-
-
 def crop_image(src_img, mask):
     box = min_bounding_rect(mask)
     result = adjust_image(box, src_img)
@@ -526,11 +519,225 @@ class TextRecognizer(object):
         return loss


+import torch
+from torch import nn
+from transformers import CLIPTextModel, CLIPTokenizer
+from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
+
+
+class AbstractEncoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def encode(self, *args, **kwargs):
+        raise NotImplementedError
+
+
+class FrozenCLIPEmbedderT3(AbstractEncoder):
+    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+
+    def __init__(
+        self,
+        version="openai/clip-vit-large-patch14",
+        device="cpu",
+        max_length=77,
+        freeze=True,
+        use_fp16=False,
+    ):
+        super().__init__()
+        self.tokenizer = CLIPTokenizer.from_pretrained(version)
+        self.transformer = CLIPTextModel.from_pretrained(
+            version, use_safetensors=True, torch_dtype=torch.float16 if use_fp16 else torch.float32
+        ).to(device)
+        self.device = device
+        self.max_length = max_length
+        if freeze:
+            self.freeze()
+
+        def embedding_forward(
+            self,
+            input_ids=None,
+            position_ids=None,
+            inputs_embeds=None,
+            embedding_manager=None,
+        ):
+            seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+            if position_ids is None:
+                position_ids = self.position_ids[:, :seq_length]
+            if inputs_embeds is None:
+                inputs_embeds = self.token_embedding(input_ids)
+            if embedding_manager is not None:
+                inputs_embeds = embedding_manager(input_ids, inputs_embeds)
+            position_embeddings = self.position_embedding(position_ids)
+            embeddings = inputs_embeds + position_embeddings
+            return embeddings
+
+        self.transformer.text_model.embeddings.forward = embedding_forward.__get__(
+            self.transformer.text_model.embeddings
+        )
+
+        def encoder_forward(
+            self,
+            inputs_embeds,
+            attention_mask=None,
+            causal_attention_mask=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+        ):
+            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+            output_hidden_states = (
+                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            )
+            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+            encoder_states = () if output_hidden_states else None
+            all_attentions = () if output_attentions else None
+            hidden_states = inputs_embeds
+            for idx, encoder_layer in enumerate(self.layers):
+                if output_hidden_states:
+                    encoder_states = encoder_states + (hidden_states,)
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    causal_attention_mask,
+                    output_attentions=output_attentions,
+                )
+                hidden_states = layer_outputs[0]
+                if output_attentions:
+                    all_attentions = all_attentions + (layer_outputs[1],)
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            return hidden_states
+
+        self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)
+
+        def text_encoder_forward(
+            self,
+            input_ids=None,
+            attention_mask=None,
+            position_ids=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+            embedding_manager=None,
+        ):
+            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+            output_hidden_states = (
+                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            )
+            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+            if input_ids is None:
+                raise ValueError("You have to specify either input_ids")
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+            hidden_states = self.embeddings(
+                input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager
+            )
+            # CLIP's text model uses causal mask, prepare it here.
+            # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+            causal_attention_mask = _create_4d_causal_attention_mask(
+                input_shape, hidden_states.dtype, device=hidden_states.device
+            )
+            # expand attention_mask
+            if attention_mask is not None:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+            last_hidden_state = self.encoder(
+                inputs_embeds=hidden_states,
+                attention_mask=attention_mask,
+                causal_attention_mask=causal_attention_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            last_hidden_state = self.final_layer_norm(last_hidden_state)
+            return last_hidden_state
+
+        self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)
+
+        def transformer_forward(
+            self,
+            input_ids=None,
+            attention_mask=None,
+            position_ids=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+            embedding_manager=None,
+        ):
+            return self.text_model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                embedding_manager=embedding_manager,
+            )
+
+        self.transformer.forward = transformer_forward.__get__(self.transformer)
+
+    def freeze(self):
+        self.transformer = self.transformer.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text, **kwargs):
+        batch_encoding = self.tokenizer(
+            text,
+            truncation=False,
+            max_length=self.max_length,
+            return_length=True,
+            return_overflowing_tokens=False,
+            padding="longest",
+            return_tensors="pt",
+        )
+        input_ids = batch_encoding["input_ids"]
+        tokens_list = self.split_chunks(input_ids)
+        z_list = []
+        for tokens in tokens_list:
+            tokens = tokens.to(self.device)
+            _z = self.transformer(input_ids=tokens, **kwargs)
+            z_list += [_z]
+        return torch.cat(z_list, dim=1)
+
+    def encode(self, text, **kwargs):
+        return self(text, **kwargs)
+
+    def split_chunks(self, input_ids, chunk_size=75):
+        tokens_list = []
+        bs, n = input_ids.shape
+        id_start = input_ids[:, 0].unsqueeze(1)  # dim --> [bs, 1]
+        id_end = input_ids[:, -1].unsqueeze(1)
+        if n == 2:  # empty caption
+            tokens_list.append(torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1))
+
+        trimmed_encoding = input_ids[:, 1:-1]
+        num_full_groups = (n - 2) // chunk_size
+
+        for i in range(num_full_groups):
+            group = trimmed_encoding[:, i * chunk_size : (i + 1) * chunk_size]
+            group_pad = torch.cat((id_start, group, id_end), dim=1)
+            tokens_list.append(group_pad)
+
+        remaining_columns = (n - 2) % chunk_size
+        if remaining_columns > 0:
+            remaining_group = trimmed_encoding[:, -remaining_columns:]
+            padding_columns = chunk_size - remaining_group.shape[1]
+            padding = id_end.expand(bs, padding_columns)
+            remaining_group_pad = torch.cat((id_start, remaining_group, padding, id_end), dim=1)
+            tokens_list.append(remaining_group_pad)
+        return tokens_list
+
+    def to(self, *args, **kwargs):
+        self.transformer = self.transformer.to(*args, **kwargs)
+        self.device = self.transformer.device
+        return self
+
+
 class TextEmbeddingModule(nn.Module):
-    # @register_to_config
     def __init__(self, font_path, use_fp16=False, device="cpu"):
         super().__init__()
-        # TODO: Learn if the recommended font file is free to use
         self.font = ImageFont.truetype(font_path, 60)
         self.use_fp16 = use_fp16
         self.device = device
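Taken together, the inlined FrozenCLIPEmbedderT3 replaces the old external import: prompts are tokenized without truncation, split into 75-token chunks that each get the BOS/EOS pair back, encoded separately through the monkey-patched CLIP text model, and concatenated along the sequence dimension. A minimal usage sketch, not part of the commit; the import path and prompt are illustrative, and it assumes the openai/clip-vit-large-patch14 checkpoint can be downloaded:

# Illustrative only: exercises the chunked encoder defined in the hunk above.
import torch
from anytext import FrozenCLIPEmbedderT3  # assumes auxiliary_latent_module/anytext.py is importable

encoder = FrozenCLIPEmbedderT3(version="openai/clip-vit-large-patch14", device="cpu")
prompt = "a neon sign that reads " + "sparkling " * 120 + "coffee"  # well past 77 tokens
with torch.no_grad():
    z = encoder.encode(prompt)
# One [1, 77, 768] tensor per chunk, concatenated on dim=1 -> [1, 77 * num_chunks, 768]
print(z.shape)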
@@ -724,10 +931,11 @@ class TextEmbeddingModule(nn.Module):
         ratio = min(W * 0.9 / text_width, H * 0.9 / text_height)
         new_font = font.font_variant(size=int(g_size * ratio))

-        text_width, text_height = new_font.getsize(text)
-        offset_x, offset_y = new_font.getoffset(text)
+        left, top, right, bottom = new_font.getbbox(text)
+        text_width = right - left
+        text_height = bottom - top
         x = (img.width - text_width) // 2
-        y = (img.height - text_height) // 2 - offset_y // 2
+        y = (img.height - text_height) // 2 - top // 2
         draw.text((x, y), text, font=new_font, fill="white")
         img = np.expand_dims(np.array(img), axis=2).astype(np.float64)
         return img
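The deleted getsize()/getoffset() calls were removed in Pillow 10, and getbbox() covers both: it returns the glyph bounding box relative to the drawing origin, so the old offset_y is recovered as the bbox top. A standalone sketch of the same migration; the font file name is a placeholder:

from PIL import ImageFont

font = ImageFont.truetype("SomeFont.ttf", 60)  # placeholder font file
left, top, right, bottom = font.getbbox("AnyText")
text_width, text_height = right - left, bottom - top
# `top` is the gap between the drawing origin and the first inked row,
# i.e. the value the removed getoffset() used to report for y.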
@@ -1019,7 +1227,7 @@ class AnyTextPipeline(
    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
-        text_encoder ([`~transformers.CLIPTextModel`]):
+        text_encoder ([`~anytext.TextEmbeddingModule`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
@@ -1049,26 +1257,25 @@ class AnyTextPipeline(
         self,
         font_path: str,
         vae: AutoencoderKL,
-        text_encoder: CLIPTextModel,
+        text_encoder: TextEmbeddingModule,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
         controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        auxiliary_latent_module: AuxiliaryLatentModule,
         trust_remote_code: bool = False,
-        text_embedding_module: TextEmbeddingModule = None,
-        auxiliary_latent_module: AuxiliaryLatentModule = None,
         image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
-        self.text_embedding_module = TextEmbeddingModule(
-            use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path
-        )
-        self.auxiliary_latent_module = AuxiliaryLatentModule(
-            vae=vae, use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path
-        )
+        # self.text_embedding_module = TextEmbeddingModule(
+        #     use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path
+        # )
+        # self.auxiliary_latent_module = AuxiliaryLatentModule(
+        #     vae=vae, use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path
+        # )

         if safety_checker is None and requires_safety_checker:
             logger.warning(
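With this hunk, text_encoder is typed as the TextEmbeddingModule defined in this file and auxiliary_latent_module becomes an explicit, required component instead of being built inside __init__ (the old construction is kept only as a comment). A hypothetical loading sketch; the repository id, font file, and the exact remote-code setup are assumptions, not taken from this commit:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "some-user/anytext",                  # placeholder repo id hosting this custom pipeline
    font_path="fonts/Arial_Unicode.ttf",  # placeholder font file forwarded to __init__
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
# The pre-built TextEmbeddingModule and AuxiliaryLatentModule now arrive as the
# `text_encoder` and `auxiliary_latent_module` components rather than being created here.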
@@ -1099,8 +1306,8 @@ class AnyTextPipeline(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
-            text_embedding_module=self.text_embedding_module,
-            auxiliary_latent_module=self.auxiliary_latent_module,
+            # text_embedding_module=self.text_embedding_module,
+            auxiliary_latent_module=auxiliary_latent_module,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
@@ -1968,7 +2175,7 @@ class AnyTextPipeline(
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
        draw_pos = draw_pos.to(device=device) if isinstance(draw_pos, torch.Tensor) else draw_pos
-        prompt_embeds, negative_prompt_embeds, text_info, np_hint = self.text_embedding_module(
+        prompt_embeds, negative_prompt_embeds, text_info, np_hint = self.text_encoder(
            prompt,
            texts,
            negative_prompt,
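The call site follows the rename: prompt and glyph conditioning now come from `self.text_encoder` (the registered TextEmbeddingModule) rather than a separate `self.text_embedding_module`, while the returned tuple `(prompt_embeds, negative_prompt_embeds, text_info, np_hint)` is unchanged.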
@@ -2210,6 +2417,6 @@ class AnyTextPipeline(

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
-        self.text_embedding_module.to(*args, **kwargs)
+        # self.text_embedding_module.to(*args, **kwargs)
        self.auxiliary_latent_module.to(*args, **kwargs)
        return self
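Only the auxiliary latent module still needs the manual move; the text encoder is presumably expected to follow the pipeline's normal component handling now that it is registered as `text_encoder`. A sketch of the resulting call, assuming a CUDA device is available:

pipe = pipe.to("cuda")
# auxiliary_latent_module is moved explicitly in the override above;
# the commented-out text_embedding_module line is no longer needed.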
 