import torch
import torch.nn as nn
from functools import partial
import clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia
import numpy as np
import os


def embedding_forward(
    self,
    input_ids=None,
    position_ids=None,
    name_batch=None,
    inputs_embeds=None,
    embedding_manager=None,
    only_embedding=True,
    random_embeddings=None,
    timesteps=None,
):
    """Replacement forward for the CLIP text embedding module.

    Returns the raw token embeddings when `only_embedding=True`; otherwise
    returns `(embeddings, other_return_dict)`, where the token embeddings may
    first be rewritten by `embedding_manager`.
    """
    seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

    if inputs_embeds is None:
        inputs_embeds = self.token_embedding(input_ids)

    # Early exit: token embeddings only, no position information added.
    if only_embedding:
        return inputs_embeds

    # Initialize so the return below is well-defined even without a manager.
    other_return_dict = None
    if embedding_manager is not None:
        inputs_embeds, other_return_dict = embedding_manager(
            input_ids, inputs_embeds, name_batch, random_embeddings, timesteps
        )

    if position_ids is None:
        position_ids = self.position_ids[:, :seq_length]

    position_embeddings = self.position_embedding(position_ids)
    embeddings = inputs_embeds + position_embeddings

    return embeddings, other_return_dict
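

# A minimal sketch of how `embedding_forward` is typically attached, assuming
# the usual monkey-patch pattern: bind it as the instance `forward` of the
# text model's CLIPTextEmbeddings module so that calls such as
# `text_encoder.text_model.embeddings(input_ids=..., only_embedding=True)`
# below accept the extra keyword arguments. The helper name
# `_patch_text_embeddings` is hypothetical, not part of the original file.
def _patch_text_embeddings(text_encoder):
    embeddings_module = text_encoder.text_model.embeddings
    # `__get__` binds the plain function to this module instance.
    embeddings_module.forward = embedding_forward.__get__(embeddings_module)
    return text_encoder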
def _get_celeb_embeddings_basis(tokenizer, text_encoder, good_names_txt):
    """Build a celeb-name embedding basis: the per-token mean and std over
    the first-/last-name token embeddings of a list of celebrity names."""
    device = text_encoder.device
    max_length = 77  # CLIP's maximum text sequence length

    with open(good_names_txt, "r") as f:
        celeb_names = f.read().splitlines()

    # Get tokens and embeddings for every name in the list.
    all_embeddings = []
    for name in celeb_names:
        batch_encoding = tokenizer(name, truncation=True, max_length=max_length,
                                   return_tensors="pt")
        # Keep only the two name tokens (positions 1 and 2, skipping BOS).
        tokens = batch_encoding["input_ids"].to(device)[:, 1:3]
        embeddings = text_encoder.text_model.embeddings(input_ids=tokens, only_embedding=True)
        all_embeddings.append(embeddings)

    all_embeddings: torch.Tensor = torch.cat(all_embeddings, dim=0)  # (num_names, 2, dim)
    print('[all_embeddings loaded] shape =', all_embeddings.shape,
          'max =', all_embeddings.max().item(),
          'min =', all_embeddings.min().item())

    name_emb_mean = all_embeddings.mean(0)  # (2, dim)
    name_emb_std = all_embeddings.std(0)    # (2, dim)
    print('[name_emb_mean loaded] shape =', name_emb_mean.shape,
          'max =', name_emb_mean.max().item(),
          'min =', name_emb_mean.min().item())

    return name_emb_mean, name_emb_std
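

# Hypothetical usage sketch: the checkpoint name and names file below are
# assumptions for illustration, not values from the original file. Load the
# CLIP tokenizer/text encoder, patch the embedding forward, then build the
# mean/std basis from a newline-separated list of "First Last" names.
if __name__ == "__main__":
    version = "openai/clip-vit-large-patch14"
    tokenizer = CLIPTokenizer.from_pretrained(version)
    text_encoder = CLIPTextModel.from_pretrained(version)
    _patch_text_embeddings(text_encoder)

    name_emb_mean, name_emb_std = _get_celeb_embeddings_basis(
        tokenizer, text_encoder, "celeb_names.txt")

    # One plausible way to draw `random_embeddings` for embedding_forward,
    # sampling per-token from the basis statistics (an assumption about how
    # the mean/std pair is consumed downstream):
    random_embeddings = name_emb_mean + name_emb_std * torch.randn_like(name_emb_std)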