| | import argparse |
| | import os |
| | from typing import Optional, Union |
| |
|
| | import numpy as np |
| | import torch |
| | from tqdm import tqdm |
| |
|
| | from dataset import config_utils |
| | from dataset.config_utils import BlueprintGenerator, ConfigSanitizer |
| | import accelerate |
| |
|
| | from dataset.image_video_dataset import ARCHITECTURE_WAN, ItemInfo, save_text_encoder_output_cache_wan |
| |
|
| | |
| | from wan.configs import wan_t2v_14B |
| |
|
| | import cache_text_encoder_outputs |
| | import logging |
| |
|
| | from utils.model_utils import str_to_dtype |
| | from wan.modules.t5 import T5EncoderModel |
| |
|
# Configure root logging at INFO, then create this module's named logger.
# The two statements are independent, so their order does not matter.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| |
|
| |
|
def encode_and_save_batch(
    text_encoder: T5EncoderModel, batch: list[ItemInfo], device: torch.device, accelerator: Optional[accelerate.Accelerator]
):
    """Encode the captions of ``batch`` with the text encoder and cache each result.

    Args:
        text_encoder: Callable encoder, invoked as ``text_encoder(prompts, device)``.
        batch: Items whose ``caption`` attribute supplies the prompts.
        device: Device passed through to the encoder call.
        accelerator: When not ``None``, the encoder call runs inside
            ``accelerator.autocast()``; otherwise it runs without autocast.
    """
    captions = [entry.caption for entry in batch]

    def _run_encoder():
        # Single place for the encoder call so both branches invoke it identically.
        return text_encoder(captions, device)

    # Gradients are never needed here; this is inference only.
    with torch.no_grad():
        if accelerator is None:
            encoded = _run_encoder()
        else:
            with accelerator.autocast():
                encoded = _run_encoder()

    # Pair every item with its encoder output (in batch order) and write the cache.
    for entry, encoded_item in zip(batch, encoded):
        save_text_encoder_output_cache_wan(entry, encoded_item)
| |
|
| |
|
def main(args):
    """Encode all dataset captions with T5 and write the text encoder output caches.

    Steps, in order: pick the device, build the dataset group from
    ``args.dataset_config``, optionally create an ``Accelerator`` (only when
    ``args.fp8_t5`` is set), prepare the cache file lists, load the T5 encoder
    from ``args.t5``, process all batches, and finally post-process the cache
    files.

    Args:
        args: Parsed command-line namespace. This function reads ``device``,
            ``dataset_config``, ``fp8_t5``, ``t5``, ``num_workers``,
            ``skip_existing`` and ``batch_size``; ``args`` is also forwarded
            to the blueprint generator.
    """
    # An explicit device wins; otherwise use CUDA when torch reports it available.
    device_name = args.device
    if device_name is None:
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    # Build the dataset group described by the user's dataset config.
    blueprint_generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    user_config = config_utils.load_user_config(args.dataset_config)
    blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_WAN)
    dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    datasets = dataset_group.datasets

    # T5 settings (text length, dtype) are read from the t2v_14B config object.
    wan_config = wan_t2v_14B.t2v_14B

    # An Accelerator is only created for fp8 T5; its mixed precision follows the T5 dtype.
    if args.fp8_t5:
        precision = "bf16" if wan_config.t5_dtype == torch.bfloat16 else "fp16"
        accelerator = accelerate.Accelerator(mixed_precision=precision)
    else:
        accelerator = None

    # Collect cache files and paths before the model is loaded.
    cache_files, cache_paths = cache_text_encoder_outputs.prepare_cache_files_and_paths(datasets)

    logger.info(f"Loading T5: {args.t5}")
    text_encoder = T5EncoderModel(
        text_len=wan_config.text_len,
        dtype=wan_config.t5_dtype,
        device=device,
        weight_path=args.t5,
        fp8=args.fp8_t5,
    )

    logger.info("Encoding with T5")

    def _encode_batch(batch: list[ItemInfo]):
        # Bind the loaded encoder, device and accelerator for the batch processor.
        encode_and_save_batch(text_encoder, batch, device, accelerator)

    cache_text_encoder_outputs.process_text_encoder_batches(
        args.num_workers,
        args.skip_existing,
        args.batch_size,
        datasets,
        cache_files,
        cache_paths,
        _encode_batch,
    )
    # Drop this function's reference to the encoder once all batches are processed.
    del text_encoder

    cache_text_encoder_outputs.post_process_cache_files(datasets, cache_files, cache_paths)
| |
|
| |
|
def wan_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Add the Wan-specific options ``--t5`` and ``--fp8_t5`` to ``parser``.

    Args:
        parser: Parser to extend; it is modified in place.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    # Required path to the T5 checkpoint.
    parser.add_argument(
        "--t5",
        type=str,
        default=None,
        required=True,
        help="text encoder (T5) checkpoint path",
    )
    # Boolean switch; False unless the flag is given.
    parser.add_argument(
        "--fp8_t5",
        action="store_true",
        help="use fp8 for Text Encoder model",
    )
    return parser
| |
|
| |
|
if __name__ == "__main__":
    # Start from the common parser of cache_text_encoder_outputs, add the
    # Wan-specific options, then run with the parsed command line.
    parser = wan_setup_parser(cache_text_encoder_outputs.setup_parser_common())
    main(parser.parse_args())
| |
|