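"""Cache FramePack text encoder outputs for a dataset.

Encodes every caption with Text Encoder 1 (Llama) and Text Encoder 2 (CLIP-L)
and writes the resulting embeddings to per-item cache files, so training does
not need to keep the text encoders in memory.
"""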
import argparse
import logging
import os
from typing import Optional, Union

import numpy as np
import torch
from tqdm import tqdm
from transformers import LlamaTokenizerFast, LlamaModel, CLIPTokenizer, CLIPTextModel

from dataset import config_utils
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
from dataset.image_video_dataset import ARCHITECTURE_FRAMEPACK, ItemInfo, save_text_encoder_output_cache_framepack
import cache_text_encoder_outputs
from frame_pack import hunyuan
from frame_pack.framepack_utils import load_text_encoder1, load_text_encoder2
from frame_pack.utils import crop_or_pad_yield_mask

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def encode_and_save_batch(
    tokenizer1: LlamaTokenizerFast,
    text_encoder1: LlamaModel,
    tokenizer2: CLIPTokenizer,
    text_encoder2: CLIPTextModel,
    batch: list[ItemInfo],
    device: torch.device,
):
    """Encode a batch of captions and write one cache file per item."""
    prompts = [item.caption for item in batch]

    # FramePack's encode_prompt_conds only supports a single prompt, so each
    # prompt is encoded separately.
    list_of_llama_vec = []
    list_of_llama_attention_mask = []
    list_of_clip_l_pooler = []
    for prompt in prompts:
        with torch.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
            llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2)

        # Crop or zero-pad the Llama hidden states to a fixed 512 tokens and
        # keep a mask marking the valid positions.
        llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)

        # Drop the batch dimension so each cached tensor belongs to one item.
        list_of_llama_vec.append(llama_vec.squeeze(0))
        list_of_llama_attention_mask.append(llama_attention_mask.squeeze(0))
        list_of_clip_l_pooler.append(clip_l_pooler.squeeze(0))

    # Save the Llama embeddings, their attention mask, and the CLIP-L pooled
    # output to the per-item cache files.
    for item, llama_vec, llama_attention_mask, clip_l_pooler in zip(
        batch, list_of_llama_vec, list_of_llama_attention_mask, list_of_clip_l_pooler
    ):
        save_text_encoder_output_cache_framepack(item, llama_vec, llama_attention_mask, clip_l_pooler)
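

# For reference, a minimal sketch of the crop-or-pad behavior relied on above.
# This is an illustration of the idea only, not the actual frame_pack.utils
# implementation: the sequence dimension is cropped or zero-padded to a fixed
# length, and a boolean mask marks which positions hold real tokens.
def _crop_or_pad_sketch(x: torch.Tensor, length: int) -> tuple[torch.Tensor, torch.Tensor]:
    b, t, c = x.shape
    if t >= length:
        # Long sequences are cropped; every kept position is valid.
        return x[:, :length, :], torch.ones(b, length, dtype=torch.bool, device=x.device)
    # Short sequences are zero-padded; only the original positions are valid.
    y = torch.zeros(b, length, c, dtype=x.dtype, device=x.device)
    y[:, :t, :] = x
    mask = torch.zeros(b, length, dtype=torch.bool, device=x.device)
    mask[:, :t] = True
    return y, mask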


def main(args):
    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Load dataset config
    blueprint_generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    user_config = config_utils.load_user_config(args.dataset_config)
    blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_FRAMEPACK)
    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    datasets = train_dataset_group.datasets

    # Prepare cache files and paths: all_cache_files_for_dataset holds the existing
    # cache files, all_cache_paths_for_dataset holds all cache paths in the dataset.
    all_cache_files_for_dataset, all_cache_paths_for_dataset = cache_text_encoder_outputs.prepare_cache_files_and_paths(datasets)

    # Load text encoders
    tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, device)
    tokenizer2, text_encoder2 = load_text_encoder2(args)
    text_encoder2.to(device)

    # Encode with the text encoders
    logger.info("Encoding with Text Encoders")

    def encode_for_text_encoder(batch: list[ItemInfo]):
        encode_and_save_batch(tokenizer1, text_encoder1, tokenizer2, text_encoder2, batch, device)

    cache_text_encoder_outputs.process_text_encoder_batches(
        args.num_workers,
        args.skip_existing,
        args.batch_size,
        datasets,
        all_cache_files_for_dataset,
        all_cache_paths_for_dataset,
        encode_for_text_encoder,
    )

    # Remove cache files that are no longer in the dataset
    cache_text_encoder_outputs.post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset, args.keep_cache)


def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
    parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
    parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
    return parser


if __name__ == "__main__":
    parser = cache_text_encoder_outputs.setup_parser_common()
    parser = framepack_setup_parser(parser)

    args = parser.parse_args()
    main(args)
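
# Example invocation (illustrative; the common flags --dataset_config,
# --device, --batch_size, --num_workers, --skip_existing, and --keep_cache
# are assumed to come from cache_text_encoder_outputs.setup_parser_common(),
# based on the args.* attributes read in main()):
#
#   python this_script.py \
#       --dataset_config path/to/dataset.toml \
#       --text_encoder1 path/to/text_encoder1 \
#       --text_encoder2 path/to/text_encoder2 \
#       --batch_size 16 --skip_existing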