CharacterFactory / models /embedding_manager.py
wangqinghehe's picture
0515_first_upload
3ab16a9
raw
history blame
No virus
7.93 kB
import torch
from torch import nn
from einops import rearrange
import numpy as np
from typing import List
from models.id_embedding.helpers import get_rep_pos, shift_tensor_dim0
from models.id_embedding.meta_net import StyleVectorizer
from models.celeb_embeddings import _get_celeb_embeddings_basis
from functools import partial
import torch.nn.functional as F
import torch.nn as nn
import torch.nn.init as init
# Default placeholder token string(s). Not referenced in the visible part of this
# file: the embedding manager below resolves its placeholder id from a literal "*".
DEFAULT_PLACEHOLDER_TOKEN = ["*"]
# NOTE(review): not referenced anywhere in the visible code -- presumably a
# leftover step-count constant from an earlier training schedule; TODO confirm
# before removing.
PROGRESSIVE_SCALE = 2000
def get_clip_token_for_string(tokenizer, string):
    """Tokenize ``string`` and return the token ids produced by ``tokenizer``.

    Args:
        tokenizer: Callable invoked as ``tokenizer(string, **options)`` whose
            result supports ``["input_ids"]`` lookup (presumably a Hugging Face
            CLIP tokenizer -- confirm against the caller).
        string: A single string or a batch of strings to tokenize.

    Returns:
        Whatever the tokenizer stores under ``"input_ids"``; PyTorch tensors are
        requested via ``return_tensors="pt"``.
    """
    tokenizer_options = {
        "return_length": True,
        "padding": True,
        "truncation": True,
        "return_overflowing_tokens": False,
        "return_tensors": "pt",
    }
    encoded = tokenizer(string, **tokenizer_options)
    return encoded["input_ids"]
def get_embedding_for_clip_token(embedder, token):
    """Run ``embedder`` on ``token`` after adding a leading dimension.

    Args:
        embedder: Callable applied to the unsqueezed token tensor.
        token: Object exposing ``unsqueeze`` (a tensor of token ids).

    Returns:
        The embedder's output for ``token.unsqueeze(0)``.
    """
    token_with_leading_dim = token.unsqueeze(0)
    return embedder(token_with_leading_dim)
class EmbeddingManagerId_adain(nn.Module):
    """Map random vectors to pseudo-identity word embeddings and splice them into prompts.

    A ``StyleVectorizer`` MLP projects a random vector to ``num_es`` token
    embeddings of width ``token_dim``. With ``adain_mode == 0`` the result is
    rescaled with the celebrity-name embedding statistics
    (``total * std + mean``) before being written over the placeholder-token
    positions of the already-embedded prompt.
    """

    # Celebrity-name list used to compute the embedding statistics, per experiment.
    _CELEB_NAME_FILES = {
        "normal_GAN": "datasets_face/good_names.txt",
        "man_GAN": "datasets_face/good_names_man.txt",
        "woman_GAN": "datasets_face/good_names_woman.txt",
    }

    def __init__(
        self,
        tokenizer,
        text_encoder,
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
        experiment_name = "normal_GAN",
        num_embeds_per_token: int = 2,
        loss_type: str = None,
        mlp_depth: int = 2,
        token_dim: int = 1024,
        input_dim: int = 1024,
        **kwargs
    ):
        """Build the projection MLP, the discriminator and the celeb statistics.

        Args:
            tokenizer: Tokenizer passed to ``get_clip_token_for_string``.
            text_encoder: Text encoder exposing ``text_model.embeddings``.
            device: Device the celeb statistics are moved to; also used by
                ``load`` and as a fallback device in ``forward``.
            experiment_name: One of ``"normal_GAN"``, ``"man_GAN"``,
                ``"woman_GAN"``; selects the celebrity-name list.
            num_embeds_per_token: Number of token embeddings generated per
                identity (stored as ``self.num_es``).
            loss_type: Accepted for interface compatibility; unused here.
            mlp_depth: Depth forwarded to ``StyleVectorizer``.
            token_dim: Width of a single token embedding.
            input_dim: Width of the random input vector.
            **kwargs: Ignored.

        Raises:
            ValueError: If ``experiment_name`` is not a known experiment
                (previously a bare ``assert 0``, which is skipped under ``-O``).
        """
        super().__init__()
        self.device = device
        self.num_es = num_embeds_per_token
        self.get_token_for_string = partial(get_clip_token_for_string, tokenizer)
        self.get_embedding_for_tkn = partial(get_embedding_for_clip_token, text_encoder.text_model.embeddings)
        self.token_dim = token_dim

        # 1. Placeholder mapping: id of the token at index 1 of the encoding of
        # "*" (index 0 is presumably the begin-of-text token -- confirm).
        self.placeholder_token = self.get_token_for_string("*")[0][1]

        names_path = self._CELEB_NAME_FILES.get(experiment_name)
        if names_path is None:
            raise ValueError(
                f"Unknown experiment_name {experiment_name!r}; "
                f"expected one of {sorted(self._CELEB_NAME_FILES)}"
            )
        celeb_mean, celeb_std = _get_celeb_embeddings_basis(tokenizer, text_encoder, names_path)
        print("now experiment_name:", experiment_name)
        self.celeb_embeddings_mean = celeb_mean.to(device)
        self.celeb_embeddings_std = celeb_std.to(device)

        self.name_projection_layer = StyleVectorizer(input_dim, self.token_dim * self.num_es, depth=mlp_depth, lr_mul=0.1)
        self.embedding_discriminator = Embedding_discriminator(self.token_dim * self.num_es, dropout_rate = 0.2)
        # 0 -> rescale generated embeddings with the celeb mean/std; otherwise use them raw.
        self.adain_mode = 0

    def forward(
        self,
        tokenized_text,
        embedded_text,
        name_batch,
        random_embeddings = None,
        timesteps = None,
    ):
        """Generate identity embeddings and write them over placeholder positions.

        Args:
            tokenized_text: Token ids of the prompts, or ``None`` to skip the
                placeholder replacement.
            embedded_text: Embedded prompts; modified in place at the
                placeholder positions and returned.
            name_batch: ``None`` or a list of name strings whose embeddings
                are returned under ``"name_embeddings"``.
            random_embeddings: Input to the projection MLP, or ``None``.
            timesteps: Unused.

        Returns:
            ``(embedded_text, other_return_dict)``. The dict may contain
            ``"total_embedding"``, ``"adained_total_embedding"`` and
            ``"name_embeddings"`` depending on which inputs were supplied.

        Raises:
            TypeError: If ``name_batch`` is neither ``None`` nor a list.
            ValueError: If placeholders are found in ``tokenized_text`` but no
                ``random_embeddings`` were given to generate replacements.
        """
        other_return_dict = {}
        adained_total_embedding = None

        if random_embeddings is not None:
            mlp_output_embedding = self.name_projection_layer(random_embeddings)
            # Use the configured sizes rather than hard-coded (2, 1024) so that
            # non-default num_embeds_per_token / token_dim work.
            total_embedding = mlp_output_embedding.view(
                mlp_output_embedding.shape[0], self.num_es, self.token_dim
            )
            if self.adain_mode == 0:
                adained_total_embedding = total_embedding * self.celeb_embeddings_std + self.celeb_embeddings_mean
            else:
                adained_total_embedding = total_embedding
            other_return_dict["total_embedding"] = total_embedding
            other_return_dict["adained_total_embedding"] = adained_total_embedding

        if name_batch is not None:
            if not isinstance(name_batch, list):
                raise TypeError(
                    f"name_batch must be a list of strings or None, got {type(name_batch).__name__}"
                )
            # Columns 1 and 2 of each encoding: the two token ids after the first one.
            name_tokens = self.get_token_for_string(name_batch)[:, 1:3]
            # Fall back to the manager's device when no random embeddings were
            # passed (the original dereferenced None here).
            target_device = random_embeddings.device if random_embeddings is not None else self.device
            name_embeddings = self.get_embedding_for_tkn(name_tokens.to(target_device))[0]
            other_return_dict["name_embeddings"] = name_embeddings

        if tokenized_text is not None:
            # presumably an iterable of (row, column) index pairs -- verify against get_rep_pos
            placeholder_pos = np.array(get_rep_pos(tokenized_text, [self.placeholder_token]))
            if len(placeholder_pos) != 0:
                if adained_total_embedding is None:
                    raise ValueError(
                        "tokenized_text contains placeholder tokens but random_embeddings "
                        "is None, so there is nothing to insert"
                    )
                end_index = min(adained_total_embedding.shape[0], placeholder_pos.shape[0])
                rows, cols = placeholder_pos[:, 0], placeholder_pos[:, 1]
                # Each identity occupies num_es consecutive token slots starting
                # at the placeholder position.
                for slot in range(self.num_es):
                    embedded_text[rows, cols + slot] = adained_total_embedding[:end_index, slot, :]

        return embedded_text, other_return_dict

    def load(self, ckpt_path):
        """Restore ``name_projection_layer`` from a checkpoint written by ``save``.

        The checkpoint is mapped onto ``self.device`` instead of a hard-coded
        ``'cuda'`` so loading also works on CPU-only hosts.
        """
        # NOTE: the checkpoint stores a pickled nn.Module, so torch.load runs the
        # unpickler -- only load checkpoints from trusted sources. Recent PyTorch
        # releases may additionally require ``weights_only=False`` for this format.
        ckpt = torch.load(ckpt_path, map_location=self.device)
        projection = ckpt.get("name_projection_layer")
        if projection is not None:
            self.name_projection_layer = projection.float()
            print('[Embedding Manager] weights loaded.')

    def save(self, ckpt_path):
        """Write ``name_projection_layer`` (the whole module object) to ``ckpt_path``."""
        torch.save({"name_projection_layer": self.name_projection_layer}, ckpt_path)

    def trainable_projection_parameters(self):
        """Return the parameters of ``name_projection_layer`` as a list."""
        return list(self.name_projection_layer.parameters())
class Embedding_discriminator(nn.Module):
    """Three-layer MLP that maps a flattened embedding to a single score.

    Layout: ``input_size -> 512 -> 256 -> 1`` with LeakyReLU(0.2) after the
    first two linear layers and optional dropout applied before each
    activation.
    """

    # Sub-modules persisted by save() and restored by load().
    _CKPT_KEYS = (
        "fc1", "fc2", "fc3",
        "LayerNorm1", "LayerNorm2",
        "leaky_relu",
        "dropout1", "dropout2",
    )

    def __init__(self, input_size, dropout_rate):
        """Create the layers.

        Args:
            input_size: Number of features per sample after flattening.
            dropout_rate: Dropout probability; dropout layers are only created
                and applied when this is greater than 0.
        """
        super(Embedding_discriminator, self).__init__()
        self.input_size = input_size

        self.fc1 = nn.Linear(input_size, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)

        # NOTE(review): the LayerNorm modules are constructed and checkpointed
        # but never applied in forward(); kept as-is so numerical behaviour and
        # existing checkpoints stay unchanged.
        self.LayerNorm1 = nn.LayerNorm(512)
        self.LayerNorm2 = nn.LayerNorm(256)

        self.leaky_relu = nn.LeakyReLU(0.2)

        self.dropout_rate = dropout_rate
        if self.dropout_rate > 0:
            self.dropout1 = nn.Dropout(dropout_rate)
            self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, input):
        """Score ``input`` after reshaping it to ``(-1, input_size)``.

        Returns:
            Tensor with one score per flattened row (last dimension of size 1).
        """
        use_dropout = self.dropout_rate > 0

        x = self.fc1(input.view(-1, self.input_size))
        if use_dropout:
            x = self.dropout1(x)
        x = self.leaky_relu(x)

        x = self.fc2(x)
        if use_dropout:
            x = self.dropout2(x)
        x = self.leaky_relu(x)

        return self.fc3(x)

    def save(self, ckpt_path):
        """Write every sub-module (as whole module objects) to ``ckpt_path``.

        The dropout layers do not exist when ``dropout_rate <= 0``; they are
        stored as ``None`` in that case instead of raising ``AttributeError``.
        """
        save_dict = {key: getattr(self, key, None) for key in self._CKPT_KEYS}
        torch.save(save_dict, ckpt_path)

    def load(self, ckpt_path):
        """Restore the sub-modules from a checkpoint written by ``save``.

        The presence check uses ``"fc1"``, a key ``save`` actually writes; the
        previous check looked for ``"first_name_proj_layer"``, which is never
        saved, so nothing was ever restored. Entries stored as ``None`` (e.g.
        absent dropout layers) are skipped.
        """
        # Fall back to CPU when CUDA is unavailable rather than failing in torch.load.
        map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
        # NOTE: the checkpoint holds pickled nn.Module objects, so torch.load
        # runs the unpickler -- only load checkpoints from trusted sources.
        # Recent PyTorch releases may additionally require ``weights_only=False``.
        ckpt = torch.load(ckpt_path, map_location=map_location)
        if ckpt.get("fc1") is not None:
            for key in self._CKPT_KEYS:
                module = ckpt.get(key)
                if module is not None:
                    setattr(self, key, module.float())
            print('[Embedding D] weights loaded.')