""" huggingface model adapter

Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models
for use as a text tower in CLIP model.
"""
import re
import torch
import torch.nn as nn
from torch import TensorType

# `transformers` is an optional dependency. If it cannot be imported, `transformers`
# is set to None and placeholder classes are defined for BaseModelOutput and
# PretrainedConfig only; the other imported names (AutoModel, AutoTokenizer,
# AutoConfig, BaseModelOutputWithPooling, ...) stay undefined in that case.
try:
    import transformers
    from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
    from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
        BaseModelOutputWithPoolingAndCrossAttentions
except ImportError as e:
    transformers = None

    class BaseModelOutput:
        pass

    class PretrainedConfig:
        pass

# arch_dict: mapping keyed by HF `model_type`; lock() below reads attribute names
# from its ["config_names"] entries. Its contents are not visible in this file.
from .hf_configs import arch_dict

# NOTE(review): this file appears to be corrupted starting inside `_camel2snake`.
# The regex literal `r'(?` is cut off and is immediately followed by `TensorType:`,
# which looks like the end of a method signature (`... -> TensorType:`). Everything
# in between is missing, so the module will not parse as-is. That presumably
# includes the rest of this helper, whatever defines `self.pooler` / `self.proj`,
# and the header and `__init__` of the class that owns the methods below.
# Restore the lost span from version control; the code below is left as found.


# utils
# Helper whose name suggests CamelCase -> snake_case conversion (presumably);
# its regex and the rest of its body are truncated on the next line.
def _camel2snake(s):
    return re.sub(r'(? TensorType:
        # --- Tail of a method whose `def` line was lost (presumably the text
        # tower's forward). `x` is passed as `input_ids`, so it appears to be a
        # tensor of token ids. ---
        # Attention mask: 1 where the token id differs from the configured pad id, else 0.
        attn_mask = (x != self.config.pad_token_id).long()
        # Run the wrapped HF transformer with that padding mask.
        out = self.transformer(input_ids=x, attention_mask=attn_mask)
        # Pool the transformer output using the mask (pooler defined outside this view;
        # presumably reduces to one feature vector per sequence -- confirm).
        pooled_out = self.pooler(out, attn_mask)
        # Final projection (self.proj is defined outside this view).
        return self.proj(pooled_out)

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Freeze parameters of the wrapped transformer, optionally keeping the tail trainable.

        Args:
            unlocked_layers: number of trailing modules of
                [token embeddings, *transformer layers] to leave untouched.
                0 (any falsy value) freezes every parameter of self.transformer.
            freeze_layer_norm: within the frozen region, parameters whose dotted
                name has a "LayerNorm" component get requires_grad=False when True,
                and requires_grad=True when False. All other parameters in the
                frozen region get requires_grad=False.

        Only parameters in the frozen region are modified. In the partial case, the
        unlocked tail modules, and anything in self.transformer that is neither the
        embeddings module nor in the layer list, keep their current requires_grad.
        """
        if not unlocked_layers:  # full freezing
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
            return

        # Some architectures keep their layer stack under `.encoder`; otherwise the
        # layer attribute is looked up on the transformer itself.
        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        # "+ 1" counts the token-embeddings module alongside the transformer layers.
        print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
        embeddings = getattr(
            self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
        # Everything except the last `unlocked_layers` modules gets frozen.
        # NOTE(review): a negative unlocked_layers slices as [:k] (freezing only the
        # first k modules), and a value >= len(layer_list) + 1 yields an empty list
        # (nothing frozen). Consider validating the argument.
        modules = [embeddings, *layer_list][:-unlocked_layers]
        # freeze layers
        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Turn on gradient checkpointing in the wrapped transformer.

        Decorated with torch.jit.ignore so TorchScript leaves it as a plain Python call.
        """
        # NOTE(review): `enable` is ignored -- checkpointing is enabled even when this
        # is called with enable=False. Consider calling
        # self.transformer.gradient_checkpointing_disable() in that case (confirm the
        # wrapped model provides it).
        self.transformer.gradient_checkpointing_enable()

    def init_parameters(self):
        """No-op: this method does not modify any parameters."""
        pass