# World_Model / URSA / inference.py
# Uploaded by BryanW using the upload-large-folder tool (commit d2253eb, verified).
"""URSA inference demo: text-to-image, image-to-video, text-to-video, and
iterative video-to-video extension, with wall-clock timing for each stage."""
import os
import time

import numpy
import torch

from diffnext.pipelines import URSAPipeline
from diffnext.utils import export_to_video

# Reduce CUDA memory fragmentation across the four multi-stage generations.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

model_id, height, width = "BAAI/URSA-1.7B-FSQ320", 320, 512
model_args = {"torch_dtype": torch.bfloat16, "trust_remote_code": True}
pipe = URSAPipeline.from_pretrained(model_id, **model_args)
pipe = pipe.to(torch.device("cuda"))

text_prompt = "tom and jerry"
negative_prompt = "worst quality, low quality, inconsistent motion, static, still, blurry, jittery, distorted, ugly"

# Ensure the output directory exists; saving to "tom/..." otherwise raises
# FileNotFoundError on a fresh checkout.
os.makedirs("tom", exist_ok=True)

# Keyword arguments shared by every pipeline call. Explicit kwargs replace the
# original `pipe(**locals())` idiom, which silently forwarded every local
# variable (including `pipe` itself) as a keyword argument.
base_args = {"height": height, "width": width, "negative_prompt": negative_prompt}

t1 = time.time()

# Text-to-Image: a single frame.
image = pipe(
    prompt=text_prompt,
    num_frames=1,
    num_inference_steps=25,
    **base_args,
).frames[0]
image.save("tom/ursa.jpg")
t2 = time.time()

# Image-to-Video: condition on the image generated above.
video = pipe(
    prompt=f"motion=9.0, {text_prompt}",
    image=image,
    num_frames=49,
    num_inference_steps=50,
    **base_args,
).frames[0]
export_to_video(video, "tom/ursa_1+48f.mp4", fps=12)
t3 = time.time()

# Text-to-Video: no conditioning image or video.
video = pipe(
    prompt=f"motion=9.0, {text_prompt}",
    num_frames=49,
    num_inference_steps=50,
    **base_args,
).frames[0]
export_to_video(video, "tom/ursa_49f.mp4", fps=12)
t4 = time.time()

# Video-to-Video: iteratively extend the clip, conditioning each new segment
# on the last `num_cond_frames` frames of the video so far.
num_cond_frames, cond_noise_scale = 13, 0.1
for _ in range(12):
    start_video = video
    segment = pipe(
        prompt=f"motion=5.0, {text_prompt}",
        video=video[-num_cond_frames:],
        num_frames=49,
        num_inference_steps=50,
        num_cond_frames=num_cond_frames,
        cond_noise_scale=cond_noise_scale,
        **base_args,
    ).frames[0]
    # Drop the conditioning frames from the new segment before stitching, so
    # the overlap region is not duplicated.
    video = numpy.concatenate([start_video, segment[num_cond_frames:]])
export_to_video(video, "tom/ursa_{}f.mp4".format(video.shape[0]), fps=12)
t5 = time.time()

print(f"Text-to-Image time: {t2-t1:.2f} seconds")
print(f"Image-to-Video time: {t3-t2:.2f} seconds")
print(f"Text-to-Video time: {t4-t3:.2f} seconds")
print(f"Video-to-Video time: {t5-t4:.2f} seconds")
# Single H800 GPU, batch_size=1, the inference time is:
# Text-to-Image time: 5.05 seconds
# Image-to-Video time: 101.92 seconds
# Text-to-Video time: 101.52 seconds
# Video-to-Video time: 1226.25 seconds
# cd URSA/
# source .venv_ursa/bin/activate
# accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml --machine_rank 0 --num_machines 1 --num_processes 8 scripts/train_distill_dimo.py config="./configs/distill_dimo.yaml" experiment.output_dir="./experiments/distill_dimo_v3" distill.teacher_ckpt="/gfs/space/private/fengzl/World_Model/URSA-1.7B" distill.prompt_source="/gfs/space/private/fengzl/World_Model/Koala-36M-v1"