from transformers import (
    AutoConfig,
    AutoModel,
    PretrainedConfig,
    CLIPTextConfig,
    CLIPVisionConfig,
    PreTrainedModel,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)
from transformers.utils import ModelOutput
import torch
import open_clip
from dataclasses import dataclass
import safetensors.torch
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
import os

HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"
HF_SAFE_WEIGHTS_NAME_PRIOR = "prior_model.safetensors"

# Module names that receive LoRA adapters in the CLIP / OpenCLIP based encoders.
_CLIP_LORA_TARGET_MODULES = (
    "k_proj",
    "v_proj",
    "q_proj",
    "out_proj",
    "fc1",
    "fc2",
    "visual_projection",
    "text_projection",
)


def _lora_value(lora, key):
    """
    Read ``key`` from LoRA settings given either as a dict or as an attribute-style object.

    The configs declare ``lora`` as a dict, and a config reloaded from ``config.json``
    holds it as a plain dict, so plain attribute access (``lora.lora_r``) is not enough.
    """
    if isinstance(lora, dict):
        return lora[key]
    return getattr(lora, key)


def _build_lora_config(lora, **extra):
    """
    Build a ``peft.LoraConfig`` from LoRA settings (``lora_r``, ``lora_alpha``, ``lora_dropout``).

    :param lora: dict or attribute-style object holding the LoRA hyper-parameters.
    :param extra: Additional ``LoraConfig`` keyword arguments (e.g. ``target_modules``, ``task_type``).
    :return: A ``LoraConfig`` with ``bias="lora_only"``.
    """
    return LoraConfig(
        r=_lora_value(lora, "lora_r"),
        lora_alpha=_lora_value(lora, "lora_alpha"),
        lora_dropout=_lora_value(lora, "lora_dropout"),
        bias="lora_only",
        **extra,
    )


def _freeze(module):
    """Disable gradient computation for every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad = False


@dataclass
class PriorTransformerOutput(ModelOutput):
    """
    The output of [`PriorTransformer`].

    Args:
        predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
            The predicted CLIP image embedding conditioned on the CLIP text embedding input.
    """

    predicted_image_embedding: torch.FloatTensor


@dataclass
class TextEncoderOutput(ModelOutput):
    """
    Output of the text encoders in this module, in Hugging Face ``ModelOutput`` style.

    Attributes:
        text_embeds (torch.FloatTensor): Sentence-level embeddings of the input prompts.
        last_hidden_state (torch.FloatTensor): Token-level hidden states of the final layer
            (projected by ``fc2`` in ``CustomTextEncoderOnly`` when ``last_hidden_state=True``).
    """

    text_embeds: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None


class CLIPTextEncoderOnlyConfig(CLIPTextConfig):
    """Configuration for ``CLIPTextEncoderOnly``; extra ``kwargs`` go to ``CLIPTextConfig``."""

    model_type = "clip_custom_text_model"

    def __init__(self, model_name: str = None, pretrained: bool = True, frozen: bool = False, lora: dict = None, **kwargs):
        self.model_name = model_name
        self.pretrained = pretrained
        self.frozen = frozen
        self.lora = lora
        super().__init__(**kwargs)


class CLIPTextEncoderOnly(PreTrainedModel):
    """CLIP text tower with projection (``CLIPTextModelWithProjection``) as a ``PreTrainedModel``."""

    config_class = CLIPTextEncoderOnlyConfig

    def __init__(self, config):
        """
        Build the CLIP text encoder described by ``config``.

        :param config: ``CLIPTextEncoderOnlyConfig``. ``model_name`` is the checkpoint name or path;
            ``pretrained`` loads its weights (otherwise only its architecture config is used);
            ``frozen`` disables gradients of the pretrained backbone; ``lora`` (dict or
            attribute-style object with ``lora_r``, ``lora_alpha``, ``lora_dropout``) wraps the
            backbone with LoRA adapters.
        """
        super().__init__(config)
        if config.pretrained:
            self.model = CLIPTextModelWithProjection.from_pretrained(config.model_name)
            if config.frozen:
                _freeze(self.model)
        else:
            base_cfg = CLIPTextConfig.from_pretrained(config.model_name)
            self.model = CLIPTextModelWithProjection(base_cfg)
        if config.lora:
            l_config = _build_lora_config(config.lora, target_modules=list(_CLIP_LORA_TARGET_MODULES))
            self.model = get_peft_model(self.model, l_config)

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        """
        Forward pass of the model.

        :param input_ids: Indices of input sequence tokens in the vocabulary.
        :param attention_mask: Mask to avoid performing attention on padding token indices.
        :param position_ids: Optional position indices forwarded to the CLIP text model.
        :return: ``TextEncoderOutput`` with the projected ``text_embeds`` and ``last_hidden_state``.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=True,
        )
        return TextEncoderOutput(text_embeds=outputs.text_embeds, last_hidden_state=outputs.last_hidden_state)


class CustomTextEncoderOnlyConfig(CLIPTextConfig):
    """Configuration for ``CustomTextEncoderOnly``; extra ``kwargs`` go to ``CLIPTextConfig``."""

    model_type = "whole_custom_text_model"

    def __init__(self, model_name: str = None, pretrained: bool = True, frozen: bool = False, output_hidden_size: int = 512, last_hidden_state: bool = False, lora: dict = None, **kwargs):
        self.model_name = model_name
        self.pretrained = pretrained
        self.frozen = frozen
        self.output_hidden_size = output_hidden_size
        self.last_hidden_state = last_hidden_state
        self.lora = lora
        super().__init__(**kwargs)


class CustomTextEncoderOnly(PreTrainedModel):
    """Any ``AutoModel`` text backbone followed by linear projection head(s)."""

    config_class = CustomTextEncoderOnlyConfig

    def __init__(self, config):
        """
        Build the backbone and projection heads described by ``config``.

        :param config: ``CustomTextEncoderOnlyConfig``. ``model_name`` is the backbone checkpoint name
            or path; ``pretrained`` loads its weights (otherwise a randomly initialised model with the
            checkpoint's architecture); ``frozen`` disables gradients of the pretrained backbone;
            ``output_hidden_size`` is the width of the projection heads; ``last_hidden_state`` adds
            ``fc2`` to project token states too; ``lora`` wraps the backbone with LoRA adapters.
        """
        super().__init__(config)
        self.last_hidden_state = config.last_hidden_state
        if config.pretrained:
            self.model = AutoModel.from_pretrained(config.model_name)
            if config.frozen:
                _freeze(self.model)
        else:
            # AutoModel cannot be instantiated directly: build the backbone architecture
            # (random weights) from the checkpoint's own config instead.
            self.model = AutoModel.from_config(AutoConfig.from_pretrained(config.model_name))
        self.fc1 = torch.nn.Linear(self.model.config.hidden_size, config.output_hidden_size)
        if config.last_hidden_state:
            self.fc2 = torch.nn.Linear(self.model.config.hidden_size, config.output_hidden_size)
        if config.lora:
            l_config = _build_lora_config(config.lora, task_type=TaskType.FEATURE_EXTRACTION)
            self.model = get_peft_model(self.model, l_config)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """
        Forward pass of the model.

        :param input_ids: Indices of input sequence tokens in the vocabulary.
        :param attention_mask: Mask to avoid performing attention on padding token indices.
        :param token_type_ids: Segment token indices to indicate first and second portions of the inputs.
        :return: ``TextEncoderOutput``; ``text_embeds`` is ``fc1`` applied to the backbone's second
            output, ``last_hidden_state`` is its first output (projected by ``fc2`` if enabled).
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
        )
        # NOTE(review): assumes the backbone returns its pooled output at index 1 (true for BERT-style models).
        text_embeds = self.fc1(outputs[1])
        if self.last_hidden_state:
            last_hidden_state = self.fc2(outputs[0])
        else:
            last_hidden_state = outputs[0]
        return TextEncoderOutput(text_embeds=text_embeds, last_hidden_state=last_hidden_state)


class CLIPVisionEncoderOnlyConfig(PretrainedConfig):
    """Configuration for ``CLIPVisionEncoderOnly``; extra ``kwargs`` go to ``PretrainedConfig``."""

    model_type = "clip_custom_vision_model"

    def __init__(self, model_name: str = None, pretrained: bool = True, frozen: bool = False, lora: dict = None, **kwargs):
        self.model_name = model_name
        self.pretrained = pretrained
        self.frozen = frozen
        self.lora = lora
        super().__init__(**kwargs)


class CLIPVisionEncoderOnly(PreTrainedModel):
    """CLIP vision tower with projection (``CLIPVisionModelWithProjection``) as a ``PreTrainedModel``."""

    config_class = CLIPVisionEncoderOnlyConfig

    def __init__(self, config):
        """
        Build the CLIP vision encoder described by ``config``.

        :param config: ``CLIPVisionEncoderOnlyConfig``. ``model_name`` is the checkpoint name or path;
            ``pretrained`` loads its weights (otherwise only its architecture config is used);
            ``frozen`` disables gradients of the pretrained backbone; ``lora`` wraps the backbone
            with LoRA adapters.
        """
        super().__init__(config)
        if config.pretrained:
            self.model = CLIPVisionModelWithProjection.from_pretrained(config.model_name)
            if config.frozen:
                _freeze(self.model)
        else:
            base_cfg = CLIPVisionConfig.from_pretrained(config.model_name)
            self.model = CLIPVisionModelWithProjection(base_cfg)
        if config.lora:
            l_config = _build_lora_config(config.lora, target_modules=list(_CLIP_LORA_TARGET_MODULES))
            self.model = get_peft_model(self.model, l_config)

    def forward(self, data):
        """
        Forward pass of the model.

        :param data: Mapping of keyword arguments for the wrapped model (e.g. ``pixel_values``).
        :return: The projected ``image_embeds``.
        """
        return self.model(**data).image_embeds

    def parameters(self, recurse: bool = True):
        """Return the parameters of the wrapped vision model; ``recurse`` is forwarded."""
        return self.model.parameters(recurse=recurse)


class OpenCLIPVisionEncoderOnly(torch.nn.Module):
    """Visual tower of an OpenCLIP model loaded from the Hugging Face hub."""

    def __init__(self, model_name: str, pretrained: bool = True, frozen: bool = False, lora: dict = None):
        """
        Load the OpenCLIP visual tower.

        :param model_name: Hub repository id, loaded as ``hf-hub:{model_name}``.
        :param pretrained: Must be True; building an untrained OpenCLIP model is not supported.
        :param frozen: Disable gradients of the pretrained visual tower.
        :param lora: Optional LoRA settings (``lora_r``, ``lora_alpha``, ``lora_dropout``).
        :raises NotImplementedError: If ``pretrained`` is False.
        """
        super().__init__()
        if not pretrained:
            raise NotImplementedError("OpenCLIPVisionEncoderOnly only supports pretrained=True")
        model, _ = open_clip.create_model_from_pretrained(f"hf-hub:{model_name}")
        self.model = model.visual
        if frozen:
            _freeze(self.model)
        if lora:
            l_config = _build_lora_config(lora, target_modules=list(_CLIP_LORA_TARGET_MODULES))
            self.model = get_peft_model(self.model, l_config)

    def forward(self, image):
        """
        Forward pass of the model.

        :param image: Input image batch passed directly to the visual tower.
        :return: The visual tower's output.
        """
        return self.model(image)

    def save_pretrained(self, save_dir):
        """
        Save the wrapped model's ``state_dict`` as safetensors at ``save_dir/HF_SAFE_WEIGHTS_NAME``.

        :param save_dir: Target directory (``str`` or ``os.PathLike``); created if missing.
        """
        os.makedirs(save_dir, exist_ok=True)
        tensors = self.model.state_dict()
        safetensors.torch.save_file(tensors, os.path.join(save_dir, HF_SAFE_WEIGHTS_NAME))


class CustomPriorModel(torch.nn.Module):
    """Two-layer MLP mapping concatenated features to a predicted image embedding."""

    def __init__(self, in_hidden_state, out_hidden_state):
        """
        Build the MLP.

        :param in_hidden_state: Width of each input feature; the input itself is ``2 * in_hidden_state`` wide.
        :param out_hidden_state: Width of the predicted embedding.
        """
        super().__init__()
        mid_hidden_state = max(in_hidden_state, out_hidden_state)
        self.fc1 = torch.nn.Linear(in_hidden_state * 2, mid_hidden_state)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(mid_hidden_state, out_hidden_state)

    def reinitialize_model(self):
        """
        Re-initialise trainable parameters in place.

        Matrices use Xavier-uniform; 1-D parameters named ``weight`` use a standard normal;
        other 1-D parameters (biases) are zeroed.
        """
        for name, param in self.named_parameters():
            if param.requires_grad:
                if len(param.shape) > 1:
                    torch.nn.init.xavier_uniform_(param)
                else:
                    if 'weight' in name:
                        torch.nn.init.normal_(param)
                    else:
                        torch.nn.init.zeros_(param)

    def forward(self, feats):
        """
        Forward pass of the model.

        :param feats: Tensor whose last dimension is ``2 * in_hidden_state``.
        :return: ``PriorTransformerOutput`` holding the predicted image embedding.
        """
        return PriorTransformerOutput(predicted_image_embedding=self.fc2(self.relu(self.fc1(feats))))

    def save_pretrained(self, save_dir):
        """Intentionally a no-op: prior weights are currently not saved."""
        # NOTE: saving was deliberately disabled; to re-enable:
        # tensors = self.state_dict()
        # safetensors.torch.save_file(tensors, os.path.join(save_dir, HF_SAFE_WEIGHTS_NAME_PRIOR))
        pass


def test_text_model(register=False, upload=False):
    """Optionally register ``CLIPTextEncoderOnly`` with the Auto classes and upload/reload it from the hub."""
    # register the classes
    if register:
        AutoConfig.register("clip_custom_text_model", CLIPTextEncoderOnlyConfig)
        AutoModel.register(CLIPTextEncoderOnlyConfig, CLIPTextEncoderOnly)
        CLIPTextEncoderOnlyConfig.register_for_auto_class()
        CLIPTextEncoderOnly.register_for_auto_class("AutoModel")
    if upload:
        # Initialize the model
        model_name = "openai/clip-vit-base-patch32"
        pretrained = True
        lora = None
        cfg = CLIPTextEncoderOnlyConfig(model_name=model_name, pretrained=pretrained, lora=lora)
        model = CLIPTextEncoderOnly(cfg)
        model.push_to_hub("test-text-hf-upload")
    model = CLIPTextEncoderOnly.from_pretrained("mpatel57/test-text-hf-upload", force_download=True)


def test_custom_text_model(register=False, upload=False):
    """Optionally register ``CustomTextEncoderOnly`` with the Auto classes and upload/reload it from the hub."""
    # register the classes
    if register:
        AutoConfig.register("whole_custom_text_model", CustomTextEncoderOnlyConfig)
        AutoModel.register(CustomTextEncoderOnlyConfig, CustomTextEncoderOnly)
        CustomTextEncoderOnlyConfig.register_for_auto_class()
        CustomTextEncoderOnly.register_for_auto_class("AutoModel")
    if upload:
        # Initialize the model
        model_name = "google-bert/bert-base-uncased"
        pretrained = True
        frozen = False
        output_hidden_size = 512
        last_hidden_state = False
        lora = None
        cfg = CustomTextEncoderOnlyConfig(
            model_name=model_name,
            pretrained=pretrained,
            frozen=frozen,
            output_hidden_size=output_hidden_size,
            last_hidden_state=last_hidden_state,
            lora=lora,
        )
        model = CustomTextEncoderOnly(cfg)
        model.push_to_hub("test-text-hf-upload")
    model = CustomTextEncoderOnly.from_pretrained("mpatel57/test-text-hf-upload", force_download=True)


def test_vision_model(register=False, upload=False):
    """Optionally register ``CLIPVisionEncoderOnly`` with the Auto classes and upload/reload it from the hub."""
    # register the classes
    if register:
        AutoConfig.register("clip_custom_vision_model", CLIPVisionEncoderOnlyConfig)
        AutoModel.register(CLIPVisionEncoderOnlyConfig, CLIPVisionEncoderOnly)
        CLIPVisionEncoderOnlyConfig.register_for_auto_class()
        CLIPVisionEncoderOnly.register_for_auto_class("AutoModel")
    if upload:
        # Initialize the model
        model_name = "openai/clip-vit-base-patch32"
        pretrained = True
        lora = None
        cfg = CLIPVisionEncoderOnlyConfig(model_name=model_name, pretrained=pretrained, lora=lora)
        model = CLIPVisionEncoderOnly(cfg)
        model.push_to_hub("test-vision-hf-upload")
    model = CLIPVisionEncoderOnly.from_pretrained("mpatel57/test-vision-hf-upload", force_download=True)


if __name__ == "__main__":
    test_custom_text_model(register=False, upload=True)
    # test_text_model(register=False, upload=True)
    # test_vision_model(register=False, upload=True)