|
|
|
import numpy as np |
|
import os |
|
import sys |
|
from joblib import Parallel, delayed |
|
from tqdm import tqdm |
|
import torch |
|
import json |
|
import pandas as pd |
|
import argparse |
|
|
|
|
|
|
|
from model.custom_sd import StableDiffusionPipeline |
|
from model.custom_vd import TextToVideoSDPipeline |
|
from model.custom_ad import AudioLDMPipeline |
|
|
|
|
|
|
|
def save_to_path(emb, path):
    """Save one embedding array to ``path`` in NumPy ``.npy`` format.

    This is dispatched through joblib workers, one call per sample, so it is
    best-effort: a failure is reported and swallowed rather than raised, so
    that one bad file does not abort the rest of the batch.

    Args:
        emb: array-like passed straight to ``np.save``.
        path: destination file path.

    Returns:
        ``path``, whether or not the write succeeded.
    """
    try:
        with open(path, 'wb') as wf:
            np.save(wf, emb)
    except Exception as err:  # not bare ``except``: let KeyboardInterrupt/SystemExit through
        print("Error with", path, err)
        # Drop any partially written file. The extraction script skips every
        # name whose ``.npy`` already exists, so a truncated file would
        # otherwise be treated as finished on every rerun.
        try:
            os.remove(path)
        except OSError:
            pass  # nothing was created (e.g. open() itself failed)
    return path
|
|
|
|
|
# JSON field that holds each sample's output file stem, per modality.
NAME_KEYS = {
    'audio': 'audio_name',
    'image': 'image_name',
    'video': 'video_name',
}


def list_existing_names(output_dir):
    """Return the stems of the ``.npy`` files already present in ``output_dir``.

    ``os.path.splitext`` is used instead of ``str.strip('.npy')``. ``strip``
    removes any run of the characters ``.``, ``n``, ``p``, ``y`` from *both*
    ends, so ``'happy.npy'`` became ``'ha'``. That meant ``happy`` was
    recomputed on every run, and an unrelated sample ``ha`` was wrongly
    skipped. Files that are not ``.npy`` are ignored.
    """
    return {
        os.path.splitext(fname)[0]
        for fname in os.listdir(output_dir)
        if fname.endswith('.npy')
    }


def collect_pending(rows, name_key, existing):
    """Select the rows that still need an embedding.

    Args:
        rows: iterable of dicts, each with ``name_key`` and ``"caption"`` keys.
        name_key: key holding the output file stem (e.g. ``"audio_name"``).
        existing: stems that already have a saved ``.npy``.

    Returns:
        ``(captions, names)`` lists in input order, excluding rows whose name
        is in ``existing``.
    """
    captions, names = [], []
    for row in rows:
        name = row[name_key]
        if name not in existing:
            captions.append(row["caption"])
            names.append(name)
    return captions, names


def main(argv=None):
    """Extract caption embeddings for one modality and save one ``.npy`` per sample.

    Command line: ``<data_path> <audio|image|video> <clip_output_dir> <ckpt_path>``.

    ``data_path`` is a JSON list of objects with a ``"caption"`` key and the
    modality's name key (see ``NAME_KEYS``). Samples whose ``<name>.npy``
    already exists in ``clip_output_dir`` are skipped, so the script can be
    rerun to resume an interrupted extraction.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description='Extract caption embeddings with a pretrained diffusion pipeline.')
    parser.add_argument('data_path', help='JSON file with the captions')
    parser.add_argument('modality', choices=sorted(NAME_KEYS))
    parser.add_argument('clip_output_dir', help='directory for per-sample .npy embeddings')
    parser.add_argument('ckpt_path', help='checkpoint passed to from_pretrained')
    args = parser.parse_args(argv)

    pipeline_classes = {
        'audio': AudioLDMPipeline,
        'image': StableDiffusionPipeline,
        'video': TextToVideoSDPipeline,
    }

    batch_size = 128
    use_cuda = torch.cuda.is_available()
    dtype = torch.float16 if use_cuda else torch.float32

    os.makedirs(args.clip_output_dir, exist_ok=True)
    existing_files = list_existing_names(args.clip_output_dir)

    print(f'extract {args.modality} caption embedding')
    with open(args.data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    caption_list, name_list = collect_pending(
        tqdm(data, total=len(data)), NAME_KEYS[args.modality], existing_files)

    if not caption_list:
        # Everything is already extracted; avoid loading the model at all.
        print('All embeddings already exist; nothing to do.')
        return

    pipe = pipeline_classes[args.modality].from_pretrained(args.ckpt_path, torch_dtype=dtype)
    if use_cuda:
        pipe = pipe.to("cuda")
    else:
        print('WARNING: using CPU, this will be slow!')

    print('Extract embeddings in batches.')
    for start_idx in tqdm(range(0, len(caption_list), batch_size)):
        batch_captions = caption_list[start_idx:start_idx + batch_size]
        batch_ids = name_list[start_idx:start_idx + batch_size]
        # Inference only: no autograd graph is needed for the embeddings.
        with torch.no_grad():
            prompt_embeds = pipe(batch_captions, return_prompts_only=True).detach().cpu().numpy()

        # One file per sample; iterating the array yields prompt_embeds[j].
        Parallel(n_jobs=8)(
            delayed(save_to_path)(emb, os.path.join(args.clip_output_dir, f'{name}.npy'))
            for name, emb in zip(batch_ids, prompt_embeds)
        )


if __name__ == '__main__':
    main()
|
|