'''
LinCIR
Copyright (c) 2023-present NAVER Corp.
CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
'''
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPTextModelWithProjection, CLIPVisionModelWithProjection, CLIPImageProcessor, CLIPTokenizer


def build_text_encoder(args):
    """Build the CLIP vision/text encoders, the image preprocessor, and a tokenizer
    extended with the '[$]' placeholder token used for textual inversion."""
    clip_model_dict = {'base32': 'openai/clip-vit-base-patch32',
                       'base': 'openai/clip-vit-base-patch16',
                       'large': 'openai/clip-vit-large-patch14',
                       'huge': 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
                       'giga': 'Geonmo/CLIP-Giga-config-fixed',
                       'meta-large': 'facebook/metaclip-l14-fullcc2.5b',
                       'meta-huge': 'facebook/metaclip-h14-fullcc2.5b',
                       }

    clip_preprocess = CLIPImageProcessor(crop_size={'height': 224, 'width': 224},
                                         do_center_crop=True,
                                         do_convert_rgb=True,
                                         do_normalize=True,
                                         do_rescale=True,
                                         do_resize=True,
                                         image_mean=[0.48145466, 0.4578275, 0.40821073],
                                         image_std=[0.26862954, 0.26130258, 0.27577711],
                                         resample=3,
                                         size={'shortest_edge': 224},
                                         )

    clip_vision_model = CLIPVisionModelWithProjection.from_pretrained(clip_model_dict[args.clip_model_name],
                                                                      torch_dtype=torch.float16 if args.mixed_precision == 'fp16' else torch.float32,
                                                                      cache_dir=args.cache_dir)
    clip_text_model = CLIPTextModelWithProjection.from_pretrained(clip_model_dict[args.clip_model_name],
                                                                  torch_dtype=torch.float16 if args.mixed_precision == 'fp16' else torch.float32,
                                                                  cache_dir=args.cache_dir)

    tokenizer = CLIPTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
                                              subfolder='tokenizer_2',
                                              cache_dir=args.cache_dir)
    tokenizer.add_special_tokens({'additional_special_tokens': ["[$]"]})  # NOTE: the added token id is 49408

    return clip_vision_model, clip_preprocess, clip_text_model, tokenizer


class Phi(nn.Module):
    """
    Textual Inversion Phi network.
    Takes as input the visual features of an image and outputs the pseudo-word embedding.
    Copy-paste from https://github.com/miccunifi/SEARLE/blob/main/src/phi.py
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        # x = F.normalize(x, dim=-1)
        return self.layers(x)
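# Illustrative usage sketch (not from the original repository; the `args` fields and the
# dimensions below are assumptions): Phi maps a CLIP image embedding into the token
# embedding space, where it can stand in for the "[$]" placeholder token.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(clip_model_name='large', mixed_precision='no', cache_dir=None)
#   vision_model, preprocess, text_model, tokenizer = build_text_encoder(args)
#   phi = Phi(input_dim=vision_model.config.projection_dim,
#             hidden_dim=4 * vision_model.config.projection_dim,
#             output_dim=text_model.config.hidden_size,
#             dropout=0.5)
#   image_embeds = torch.randn(2, vision_model.config.projection_dim)
#   pseudo_tokens = phi(image_embeds)  # shape: (2, text_model.config.hidden_size)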
""" parameters = list(parameters) for s_param, param in zip(self.shadow_params, parameters): param.data.copy_(s_param.data) def to(self, device=None, dtype=None) -> None: r"""Move internal buffers of the ExponentialMovingAverage to `device`. Args: device: like `device` argument to `torch.Tensor.to` """ # .to() on the tensors handles None correctly self.shadow_params = [ p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) for p in self.shadow_params ] def state_dict(self) -> dict: r""" Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during checkpointing to save the ema state dict. """ # Following PyTorch conventions, references to tensors are returned: # "returns a reference to the state and not its copy!" - # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict return { "decay": self.decay, "optimization_step": self.optimization_step, "shadow_params": self.shadow_params, "collected_params": self.collected_params, } def load_state_dict(self, state_dict: dict) -> None: r""" Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the ema state dict. Args: state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`. """ # deepcopy, to be consistent with module API state_dict = copy.deepcopy(state_dict) self.decay = state_dict["decay"] if self.decay < 0.0 or self.decay > 1.0: raise ValueError("Decay must be between 0 and 1") self.optimization_step = state_dict["optimization_step"] if not isinstance(self.optimization_step, int): raise ValueError("Invalid optimization_step") self.shadow_params = state_dict["shadow_params"] if not isinstance(self.shadow_params, list): raise ValueError("shadow_params must be a list") if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): raise ValueError("shadow_params must all be Tensors") self.collected_params = state_dict["collected_params"] if self.collected_params is not None: if not isinstance(self.collected_params, list): raise ValueError("collected_params must be a list") if not all(isinstance(p, torch.Tensor) for p in self.collected_params): raise ValueError("collected_params must all be Tensors") if len(self.collected_params) != len(self.shadow_params): raise ValueError("collected_params and shadow_params must have the same length") class PIC2WORD(nn.Module): def __init__(self, embed_dim=512, middle_dim=512, output_dim=512, n_layer=2, dropout=0.1): super().__init__() self.fc_out = nn.Linear(middle_dim, output_dim) layers = [] dim = embed_dim for _ in range(n_layer): block = [] block.append(nn.Linear(dim, middle_dim)) block.append(nn.Dropout(dropout)) block.append(nn.ReLU()) dim = middle_dim layers.append(nn.Sequential(*block)) self.layers = nn.Sequential(*layers) def forward(self, x: torch.Tensor): for layer in self.layers: x = layer(x) return self.fc_out(x)