"""
Extract text-prompt features (token embeddings, token masks and tokens) and
save them to disk for visualization during training.
"""
| import os | |
| import sys | |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
| import torch | |
| import os | |
| import numpy as np | |
| from tqdm import tqdm | |
| import libs.autoencoder | |
| from libs.clip import FrozenCLIPEmbedder | |
| from libs.t5 import T5Embedder | |
def main(llm='clip', save_dir='run_vis', device='cuda'):
    """Encode a fixed list of prompts and save per-prompt features to disk.

    Every prompt is passed through the selected text encoder. For prompt
    number ``i`` a dict holding the prompt text, its token embedding, token
    mask and tokens is written to ``<save_dir>/<i>.npy``. Because the payload
    is a dict, NumPy stores it as a pickled object; read it back with
    ``np.load(path, allow_pickle=True).item()``.

    Args:
        llm: Name of the text encoder, either ``'clip'`` or ``'t5'``.
        save_dir: Output directory for the ``.npy`` files; created if missing.
        device: Torch device the encoder is placed on.

    Raises:
        NotImplementedError: If ``llm`` is not one of the supported encoders.
    """
    # Validate up front so an unsupported name fails before any model is
    # built or any directory is created.
    if llm not in ('clip', 't5'):
        raise NotImplementedError(f'unsupported text encoder: {llm!r}')

    prompts = [
        'A road with traffic lights, street lights and cars.',
        'A bus driving in a city area with traffic signs.',
        'A bus pulls over to the curb close to an intersection.',
        'A group of people are walking and one is holding an umbrella.',
        'A baseball player taking a swing at an incoming ball.',
        'A dog next to a white cat with black-tipped ears.',
        'A tiger standing on a rooftop while singing and jamming on an electric guitar under a spotlight. anime illustration.',
        'A bird wearing headphones and speaking into a high-end microphone in a recording studio.',
        'A bus made of cardboard.',
        'A tower in the mountains.',
        'Two cups of coffee, one with latte art of a cat. The other has latter art of a bird.',
        'Oil painting of a robot made of sushi, holding chopsticks.',
        'Portrait of a dog wearing a hat and holding a flag that has a yin-yang symbol on it.',
        'A teddy bear wearing a motorcycle helmet and cape is standing in front of Loch Awe with Kilchurn Castle behind him. dslr photo.',
        'A man standing on the moon',
    ]

    os.makedirs(save_dir, exist_ok=True)

    if llm == 'clip':
        encoder = FrozenCLIPEmbedder()
        encoder.eval()
        encoder.to(device)
        # Inference only: no autograd graph is needed for saved features.
        with torch.no_grad():
            _, extras = encoder.encode(prompts)
        token_embedding = extras['token_embedding']
    else:
        encoder = T5Embedder(device=device)
        with torch.no_grad():
            _, extras = encoder.get_text_embeddings(prompts)
        # NOTE(review): the 10.0 scale factor is carried over unchanged; its
        # rationale is not visible in this file -- confirm against training.
        token_embedding = extras['token_embedding'].to(torch.float32) * 10.0
    token_mask = extras['token_mask']
    token = extras['tokens']

    for i, prompt in enumerate(prompts):
        data = {
            # NOTE(review): key is spelled 'promt'; kept as-is because readers
            # of these files may look it up by that name -- confirm before
            # renaming.
            'promt': prompt,
            'token_embedding': token_embedding[i].detach().cpu().numpy(),
            'token_mask': token_mask[i].detach().cpu().numpy(),
            'token': token[i].detach().cpu().numpy(),
        }
        np.save(os.path.join(save_dir, f'{i}.npy'), data)
# Script entry point: run the extraction only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()