Open-Sora / scripts /inference.py
kadirnar's picture
Upload 98 files
e7d5680 verified
raw
history blame
No virus
3.98 kB
import os
import torch
import colossalai
import torch.distributed as dist
from mmengine.runner import set_random_seed
from opensora.datasets import save_sample
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import to_torch_dtype
from opensora.acceleration.parallel_states import set_sequence_parallel_group
from colossalai.cluster import DistCoordinator
def load_prompts(prompt_path):
with open(prompt_path, "r") as f:
prompts = [line.strip() for line in f.readlines()]
return prompts
def main():
# ======================================================
# 1. cfg and init distributed env
# ======================================================
cfg = parse_configs(training=False)
print(cfg)
# init distributed
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
if coordinator.world_size > 1:
set_sequence_parallel_group(dist.group.WORLD)
enable_sequence_parallelism = True
else:
enable_sequence_parallelism = False
# ======================================================
# 2. runtime variables
# ======================================================
torch.set_grad_enabled(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = to_torch_dtype(cfg.dtype)
set_random_seed(seed=cfg.seed)
prompts = load_prompts(cfg.prompt_path)
# ======================================================
# 3. build model & load weights
# ======================================================
# 3.1. build model
input_size = (cfg.num_frames, *cfg.image_size)
vae = build_module(cfg.vae, MODELS)
latent_size = vae.get_latent_size(input_size)
text_encoder = build_module(cfg.text_encoder, MODELS, device=device) # T5 must be fp32
model = build_module(
cfg.model,
MODELS,
input_size=latent_size,
in_channels=vae.out_channels,
caption_channels=text_encoder.output_dim,
model_max_length=text_encoder.model_max_length,
dtype=dtype,
enable_sequence_parallelism=enable_sequence_parallelism,
)
text_encoder.y_embedder = model.y_embedder # hack for classifier-free guidance
# 3.2. move to device & eval
vae = vae.to(device, dtype).eval()
model = model.to(device, dtype).eval()
# 3.3. build scheduler
scheduler = build_module(cfg.scheduler, SCHEDULERS)
# 3.4. support for multi-resolution
model_args = dict()
if cfg.multi_resolution:
image_size = cfg.image_size
hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
model_args["data_info"] = dict(ar=ar, hw=hw)
# ======================================================
# 4. inference
# ======================================================
sample_idx = 0
save_dir = cfg.save_dir
os.makedirs(save_dir, exist_ok=True)
for i in range(0, len(prompts), cfg.batch_size):
batch_prompts = prompts[i : i + cfg.batch_size]
samples = scheduler.sample(
model,
text_encoder,
z_size=(vae.out_channels, *latent_size),
prompts=batch_prompts,
device=device,
additional_args=model_args,
)
samples = vae.decode(samples.to(dtype))
if coordinator.is_master():
for idx, sample in enumerate(samples):
print(f"Prompt: {batch_prompts[idx]}")
save_path = os.path.join(save_dir, f"sample_{sample_idx}")
save_sample(sample, fps=cfg.fps, save_path=save_path)
sample_idx += 1
if __name__ == "__main__":
main()