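"""Pre-compute and cache FramePack text encoder outputs for a dataset.

For each caption in the configured datasets, the prompt is encoded with
Text Encoder 1 (LLaMA) and Text Encoder 2 (CLIP) via hunyuan.encode_prompt_conds,
the LLaMA hidden states are padded or cropped to a fixed length of 512, and the
resulting llama_vec, attention mask, and clip_l_pooler are written to per-item
cache files. Shared CLI arguments come from cache_text_encoder_outputs.setup_parser_common;
FramePack-specific arguments are added by framepack_setup_parser below.
"""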
import argparse
import logging
import os
from typing import Optional, Union

import numpy as np
import torch
from tqdm import tqdm
from transformers import LlamaTokenizerFast, LlamaModel, CLIPTokenizer, CLIPTextModel

import cache_text_encoder_outputs
from dataset import config_utils
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
from dataset.image_video_dataset import ARCHITECTURE_FRAMEPACK, ItemInfo, save_text_encoder_output_cache_framepack
from frame_pack import hunyuan
from frame_pack.framepack_utils import load_text_encoder1, load_text_encoder2
from frame_pack.utils import crop_or_pad_yield_mask

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def encode_and_save_batch(
    tokenizer1: LlamaTokenizerFast,
    text_encoder1: LlamaModel,
    tokenizer2: CLIPTokenizer,
    text_encoder2: CLIPTextModel,
    batch: list[ItemInfo],
    device: torch.device,
):
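    """Encode a batch of captions with both text encoders and save the outputs to each item's cache."""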
    prompts = [item.caption for item in batch]

    # encode prompt
    # FramePack's encode_prompt_conds only supports a single prompt, so we encode each prompt separately
    list_of_llama_vec = []
    list_of_llama_attention_mask = []
    list_of_clip_l_pooler = []
    for prompt in prompts:
        with torch.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
            # llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(prompts, text_encoder1, text_encoder2, tokenizer1, tokenizer2)
            llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2)
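            # pad or crop the LLaMA hidden states to a fixed length of 512 and get the matching attention mask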
            llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)

        # drop the batch dimension so each cached tensor corresponds to a single item
        list_of_llama_vec.append(llama_vec.squeeze(0))
        list_of_llama_attention_mask.append(llama_attention_mask.squeeze(0))
        list_of_clip_l_pooler.append(clip_l_pooler.squeeze(0))

    # save prompt cache
    for item, llama_vec, llama_attention_mask, clip_l_pooler in zip(
        batch, list_of_llama_vec, list_of_llama_attention_mask, list_of_clip_l_pooler
    ):
        # save llama_vec and clip_l_pooler to cache
        save_text_encoder_output_cache_framepack(item, llama_vec, llama_attention_mask, clip_l_pooler)


def main(args):
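    """Cache text encoder outputs for every dataset described by the dataset config."""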
    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Load dataset config
    blueprint_generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    user_config = config_utils.load_user_config(args.dataset_config)
    blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_FRAMEPACK)
    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    datasets = train_dataset_group.datasets

    # prepare cache files and paths: all_cache_files_for_dataset = existing cache files, all_cache_paths_for_dataset = all cache paths in the dataset
    all_cache_files_for_dataset, all_cache_paths_for_dataset = cache_text_encoder_outputs.prepare_cache_files_and_paths(datasets)

    # load text encoders
    tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, device)
    tokenizer2, text_encoder2 = load_text_encoder2(args)
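    # load_text_encoder2 does not receive the device, so move Text Encoder 2 to it explicitly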
    text_encoder2.to(device)

    # Encode with Text Encoders
    logger.info("Encoding with Text Encoders")

    def encode_for_text_encoder(batch: list[ItemInfo]):
        encode_and_save_batch(tokenizer1, text_encoder1, tokenizer2, text_encoder2, batch, device)

    cache_text_encoder_outputs.process_text_encoder_batches(
        args.num_workers,
        args.skip_existing,
        args.batch_size,
        datasets,
        all_cache_files_for_dataset,
        all_cache_paths_for_dataset,
        encode_for_text_encoder,
    )

    # remove cache files not in dataset
    cache_text_encoder_outputs.post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset, args.keep_cache)


def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
    parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
    parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
    return parser


if __name__ == "__main__":
    parser = cache_text_encoder_outputs.setup_parser_common()
    parser = framepack_setup_parser(parser)

    args = parser.parse_args()
    main(args)
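
# Example invocation (sketch; the script/file name and paths are illustrative, and
# --dataset_config / --batch_size / --skip_existing are assumed to be defined by
# cache_text_encoder_outputs.setup_parser_common):
#   python cache_text_encoder_outputs_framepack.py \
#       --dataset_config path/to/dataset.toml \
#       --text_encoder1 path/to/text_encoder1 \
#       --text_encoder2 path/to/text_encoder2 \
#       --batch_size 16 --skip_existing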