'''
LinCIR
Copyright (c) 2023-present NAVER Corp.
CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
'''
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPTextModelWithProjection, CLIPVisionModelWithProjection, CLIPImageProcessor, CLIPTokenizer


def build_text_encoder(args):
    clip_model_dict = {'base32': 'openai/clip-vit-base-patch32',
                       'base': 'openai/clip-vit-base-patch16',
                       'large': 'openai/clip-vit-large-patch14',
                       'huge': 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
                       'giga': 'Geonmo/CLIP-Giga-config-fixed',
                       'meta-large': 'facebook/metaclip-l14-fullcc2.5b',
                       'meta-huge': 'facebook/metaclip-h14-fullcc2.5b',
                       }

    clip_preprocess = CLIPImageProcessor(crop_size={'height': 224, 'width': 224},
                                         do_center_crop=True,
                                         do_convert_rgb=True,
                                         do_normalize=True,
                                         do_rescale=True,
                                         do_resize=True,
                                         image_mean=[0.48145466, 0.4578275, 0.40821073],
                                         image_std=[0.26862954, 0.26130258, 0.27577711],
                                         resample=3,
                                         size={'shortest_edge': 224},
                                         )

    clip_vision_model = CLIPVisionModelWithProjection.from_pretrained(clip_model_dict[args.clip_model_name], torch_dtype=torch.float16 if args.mixed_precision == 'fp16' else torch.float32, cache_dir=args.cache_dir)
    clip_text_model = CLIPTextModelWithProjection.from_pretrained(clip_model_dict[args.clip_model_name], torch_dtype=torch.float16 if args.mixed_precision == 'fp16' else torch.float32, cache_dir=args.cache_dir)

    tokenizer = CLIPTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='tokenizer_2', cache_dir=args.cache_dir)
    tokenizer.add_special_tokens({'additional_special_tokens': ["[$]"]})  # NOTE: the added [$] placeholder token gets id 49408

    return clip_vision_model, clip_preprocess, clip_text_model, tokenizer
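
# Usage sketch (hypothetical `args`; the attribute names simply mirror how they are read above):
#
#   from argparse import Namespace
#   args = Namespace(clip_model_name='large', mixed_precision='fp16', cache_dir=None)
#   vision_model, preprocess, text_model, tokenizer = build_text_encoder(args)
#   # `preprocess` turns PIL images into pixel_values for `vision_model`;
#   # `tokenizer` now contains the extra "[$]" placeholder token.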


class Phi(nn.Module):
    """
    Textual Inversion Phi network.
    Takes as input the visual features of an image and outputs the pseudo-word embedding.
    Copy-paste from https://github.com/miccunifi/SEARLE/blob/main/src/phi.py
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        # x = F.normalize(x, dim=-1)
        return self.layers(x)
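
# Instantiation sketch (dimensions are assumptions, not fixed by this file): for CLIP
# ViT-L/14 the projected image/text embedding is 768-dimensional, so the pseudo-word
# embedding must also be 768-dimensional, e.g.
#
#   phi = Phi(input_dim=768, hidden_dim=768 * 4, output_dim=768, dropout=0.5)
#   pseudo_tokens = phi(image_features)  # (batch, 768)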


class EMAModel:
    """
    Exponential Moving Average of model weights.
    """

    def __init__(self, parameters, decay=0.9999):
        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        self.collected_params = None

        self.decay = decay
        self.optimization_step = 0

    def step(self, parameters):
        parameters = list(parameters)

        self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        # The effective decay min(self.decay, value) warms up from a small value
        # towards `self.decay`, so early shadow weights are not dominated by the
        # randomly initialized parameters.
        value = (1 + self.optimization_step) / (10 + self.optimization_step)
        one_minus_decay = 1 - min(self.decay, value)

        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                s_param.sub_(one_minus_decay * (s_param - param))
            else:
                s_param.copy_(param)

        torch.cuda.empty_cache()

    def copy_to(self, parameters) -> None:
        """
        Copy the current averaged parameters into the given collection of parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device` and/or cast them to `dtype`.
        Args:
            device: like the `device` argument to `torch.Tensor.to`
            dtype: like the `dtype` argument to `torch.Tensor.to`; only applied to floating-point buffers
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""
        Returns the state of the ExponentialMovingAverage as a dict.
        This method is used by accelerate during checkpointing to save the ema state dict.
        """
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "optimization_step": self.optimization_step,
            "shadow_params": self.shadow_params,
            "collected_params": self.collected_params,
        }

    def load_state_dict(self, state_dict: dict) -> None:
        r"""
        Loads the ExponentialMovingAverage state.
        This method is used by accelerate during checkpointing to load the ema state dict.
        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        self.decay = state_dict["decay"]
        if self.decay < 0.0 or self.decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.optimization_step = state_dict["optimization_step"]
        if not isinstance(self.optimization_step, int):
            raise ValueError("Invalid optimization_step")

        self.shadow_params = state_dict["shadow_params"]
        if not isinstance(self.shadow_params, list):
            raise ValueError("shadow_params must be a list")
        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
            raise ValueError("shadow_params must all be Tensors")

        self.collected_params = state_dict["collected_params"]
        if self.collected_params is not None:
            if not isinstance(self.collected_params, list):
                raise ValueError("collected_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
                raise ValueError("collected_params must all be Tensors")
            if len(self.collected_params) != len(self.shadow_params):
                raise ValueError("collected_params and shadow_params must have the same length")


class PIC2WORD(nn.Module):
    """
    Pic2Word-style mapping network: projects an image embedding into a
    pseudo-word token embedding through a small MLP.
    """

    def __init__(self, embed_dim=512, middle_dim=512, output_dim=512, n_layer=2, dropout=0.1):
        super().__init__()
        self.fc_out = nn.Linear(middle_dim, output_dim)
        layers = []
        dim = embed_dim
        for _ in range(n_layer):
            block = []
            block.append(nn.Linear(dim, middle_dim))
            block.append(nn.Dropout(dropout))
            block.append(nn.ReLU())
            dim = middle_dim
            layers.append(nn.Sequential(*block))
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        for layer in self.layers:
            x = layer(x)
        return self.fc_out(x)
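

# Usage sketch (dimensions are assumptions; 512 matches the projected embedding
# size of the CLIP ViT-B models):
#
#   pic2word = PIC2WORD(embed_dim=512, middle_dim=512, output_dim=512)
#   token_embedding = pic2word(image_features)  # (batch, 512) pseudo-word embedding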