| |
|
|
| import argparse |
| import os |
|
|
| import torch |
| import torch.distributed as dist |
| from accelerate.logging import get_logger |
|
|
| from fastvideo.utils.load import load_text_encoder |
|
|
| logger = get_logger(__name__) |
|
|
|
|
def main(args):
    """Encode every validation prompt and save its embedding and mask.

    Reads one prompt per line from ``args.validation_prompt_txt``, encodes
    each non-blank prompt with the text encoder returned by
    ``load_text_encoder`` and writes two tensors per prompt below
    ``<args.output_dir>/validation``:

    * ``prompt_embed/<name>.pt`` -- first element of the returned embeddings
    * ``prompt_attention_mask/<name>.pt`` -- first element of the returned mask

    ``<name>`` is the prompt text up to (not including) its first ``"."``.

    Args:
        args: Parsed command-line namespace providing ``model_type``,
            ``model_path``, ``validation_prompt_txt`` and ``output_dir``.
    """
    # RANK is the global rank used by the process group; LOCAL_RANK (set by
    # torchrun) is the per-node index that must be used to pick a GPU.  Fall
    # back to RANK when LOCAL_RANK is absent so single-node launches that only
    # export RANK behave as before.
    rank = int(os.getenv("RANK", 0))
    local_rank = int(os.getenv("LOCAL_RANK", rank))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size, "local rank", local_rank)

    # Only touch CUDA-specific APIs when CUDA is actually present; the device
    # selection below already allows a CPU fallback.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    if use_cuda:
        torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(
            # NCCL only works with CUDA devices; gloo is the CPU backend.
            backend="nccl" if use_cuda else "gloo",
            init_method="env://",
            world_size=world_size,
            rank=rank,
        )

    text_encoder = load_text_encoder(args.model_type,
                                     args.model_path,
                                     device=device)
    # "hunyuan" runs under float16 autocast, every other model type bfloat16.
    autocast_type = (torch.float16
                     if args.model_type == "hunyuan" else torch.bfloat16)

    validation_dir = os.path.join(args.output_dir, "validation")
    embed_dir = os.path.join(validation_dir, "prompt_embed")
    mask_dir = os.path.join(validation_dir, "prompt_attention_mask")
    # makedirs also creates the shared "validation" parent directory.
    for directory in (embed_dir, mask_dir):
        os.makedirs(directory, exist_ok=True)

    with open(args.validation_prompt_txt, "r", encoding="utf-8") as file:
        stripped_lines = [line.strip() for line in file]
    # Blank lines would otherwise be encoded as an empty prompt and saved as
    # a hidden ".pt" file, so drop them.
    prompts = [prompt for prompt in stripped_lines if prompt]

    # NOTE(review): there is no per-rank sharding here, so every rank encodes
    # the full prompt list and writes the same output files.
    with torch.inference_mode(), torch.autocast(device.type,
                                                dtype=autocast_type):
        for prompt in prompts:
            prompt_embeds, prompt_attention_mask = text_encoder.encode_prompt(
                prompt)
            # NOTE(review): the file name is the text before the first "."
            # and is used verbatim, so a prompt containing a path separator
            # will not land directly inside the output directories.
            file_name = prompt.split(".")[0]
            prompt_embed_path = os.path.join(embed_dir, f"{file_name}.pt")
            prompt_attention_mask_path = os.path.join(mask_dir,
                                                      f"{file_name}.pt")
            torch.save(prompt_embeds[0], prompt_embed_path)
            torch.save(prompt_attention_mask[0], prompt_attention_mask_path)
            print(f"sample {file_name} saved")
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")
    # Both paths are mandatory: main() opens the prompt file and joins paths
    # onto the output directory, so a missing value (None) would only surface
    # later as an opaque TypeError.
    parser.add_argument(
        "--validation_prompt_txt",
        type=str,
        required=True,
        help="Text file containing one validation prompt per line.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help=
        "The output directory; prompt embeddings and attention masks are written under <output_dir>/validation.",
    )
    args = parser.parse_args()
    main(args)
|
|