import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection


class SDXLTextEncoder(torch.nn.Module):
    """Wrapper around HuggingFace text encoders for SDXL.

    Creates two text encoders (a CLIPTextModel and a CLIPTextModelWithProjection)
    that behave like one.

    Args:
        file_path_or_name (str): Name or path of the model's text encoders to load.
            Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
        encode_latents_in_fp16 (bool): Whether to run the text encoders in fp16.
            Defaults to True.
        torch_dtype (torch.dtype, optional): Explicit dtype for the encoders.
            Overrides encode_latents_in_fp16 when set. Defaults to None.
    """

    def __init__(self,
                 file_path_or_name='stabilityai/stable-diffusion-xl-base-1.0',
                 encode_latents_in_fp16=True,
                 torch_dtype=None):
        super().__init__()
        if torch_dtype is None:
            torch_dtype = torch.float16 if encode_latents_in_fp16 else None
        self.dtype = torch_dtype
        # First encoder: CLIP ViT-L (hidden size 768).
        self.text_encoder = CLIPTextModel.from_pretrained(
            file_path_or_name, subfolder='text_encoder', torch_dtype=torch_dtype)
        # Second encoder: OpenCLIP ViT-bigG (hidden size 1280), with a projection
        # head that produces the pooled embedding.
        self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
            file_path_or_name, subfolder='text_encoder_2', torch_dtype=torch_dtype)

    @classmethod
    def from_pretrained(cls,
                        file_path_or_name='stabilityai/stable-diffusion-xl-base-1.0',
                        encode_latents_in_fp16=True,
                        torch_dtype=None,
                        **kwargs):
        """Create a new SDXLTextEncoder from pretrained model parameters.

        Args:
            file_path_or_name (str): Name or path of the model's text encoders to load.
            encode_latents_in_fp16 (bool): Whether to run the text encoders in fp16.
            torch_dtype (torch.dtype): Data type for model parameters.
            **kwargs: Additional keyword arguments forwarded to the constructor.

        Returns:
            SDXLTextEncoder: A new instance of SDXLTextEncoder.
        """
        # Merge the explicit arguments with any additional kwargs.
        init_args = {
            'file_path_or_name': file_path_or_name,
            'encode_latents_in_fp16': encode_latents_in_fp16,
            'torch_dtype': torch_dtype,
        }
        init_args.update(kwargs)
        return cls(**init_args)

    @property
    def device(self):
        return self.text_encoder.device

    def forward(self, tokenized_text):
        # Penultimate hidden state of the first text encoder.
        conditioning = self.text_encoder(tokenized_text[0],
                                         output_hidden_states=True).hidden_states[-2]
        # Second text encoder: its first output is the pooled, projected embedding.
        text_encoder_2_out = self.text_encoder_2(tokenized_text[1],
                                                 output_hidden_states=True)
        pooled_conditioning = text_encoder_2_out[0]  # (batch_size, 1280)
        conditioning_2 = text_encoder_2_out.hidden_states[-2]  # (batch_size, 77, 1280)
        # Concatenate the two sequence embeddings along the feature dimension.
        conditioning = torch.concat([conditioning, conditioning_2], dim=-1)
        return conditioning, pooled_conditioning
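

# Minimal usage sketch (an illustration, not part of the original module). It
# assumes the standard SDXL repository layout, where the 'tokenizer' and
# 'tokenizer_2' subfolders hold the tokenizers matching text_encoder and
# text_encoder_2 respectively; the prompt text is arbitrary.
if __name__ == '__main__':
    from transformers import CLIPTokenizer

    repo = 'stabilityai/stable-diffusion-xl-base-1.0'
    tokenizer = CLIPTokenizer.from_pretrained(repo, subfolder='tokenizer')
    tokenizer_2 = CLIPTokenizer.from_pretrained(repo, subfolder='tokenizer_2')
    # fp32 here so the sketch also runs on CPU; pass encode_latents_in_fp16=True
    # (the default) when running on GPU.
    text_encoder = SDXLTextEncoder(encode_latents_in_fp16=False)

    prompt = ['a photo of an astronaut riding a horse']
    # Each tokenizer pads/truncates to CLIP's 77-token context length.
    tokens = [
        tok(prompt, padding='max_length', max_length=tok.model_max_length,
            truncation=True, return_tensors='pt').input_ids
        for tok in (tokenizer, tokenizer_2)
    ]
    with torch.no_grad():
        conditioning, pooled = text_encoder(tokens)
    print(conditioning.shape)  # torch.Size([1, 77, 2048]) -- 768 + 1280 features
    print(pooled.shape)        # torch.Size([1, 1280])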