|
|
import os |
|
|
import sys |
|
|
|
|
|
import torch |
|
|
from lightning import seed_everything |
|
|
from safetensors.torch import load_file as load_safetensors |
|
|
|
|
|
from ldf_utils.initialize import compare_statedict_and_parameters, instantiate, load_config |
|
|
|
|
|
|
|
|
# Disable HuggingFace tokenizers' fork-based parallelism; presumably set to
# silence the "The current process just got forked" warning when worker
# processes fork after tokenizer use — NOTE(review): confirm tokenizers are used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
|
|
|
|
def load_model_from_config():
    """Load the VAE and the main model described by the CLI config.

    Reads the config via ``load_config()`` (which consumes CLI arguments
    itself), seeds all RNGs, instantiates both networks, loads their
    safetensors checkpoints (relative checkpoint paths are resolved against
    the directory of the ``--config`` file, falling back to the CWD),
    sanity-checks each state dict against the module's parameters/buffers,
    and moves both modules to the available device in eval mode.

    Returns:
        tuple: ``(vae, model)`` ready for inference.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.set_float32_matmul_precision("high")

    cfg = load_config()
    seed_everything(cfg.seed)

    # Base directory for resolving relative checkpoint/tokenizer paths.
    config_dir = _config_dir_from_argv()

    # --- VAE ---
    vae = instantiate(
        target=cfg.test_vae.target,
        cfg=None,
        hfstyle=False,
        **cfg.test_vae.params,
    )
    _load_checkpoint(vae, _resolve_path(cfg.test_vae_ckpt, config_dir), "VAE model")
    vae.to(device)
    vae.eval()

    # --- main model ---
    model_params = dict(cfg.model.params)
    # These optional params may carry relative paths; resolve them too.
    for key in ("checkpoint_path", "tokenizer_path"):
        if model_params.get(key):
            model_params[key] = _resolve_path(model_params[key], config_dir)

    model = instantiate(
        target=cfg.model.target, cfg=None, hfstyle=False, **model_params
    )
    _load_checkpoint(model, _resolve_path(cfg.test_ckpt, config_dir), "model")
    model.to(device)
    model.eval()

    return vae, model


def _config_dir_from_argv():
    """Return the directory of the ``--config`` file from sys.argv, else the CWD."""
    if '--config' in sys.argv:
        config_idx = sys.argv.index('--config') + 1
        return os.path.dirname(os.path.abspath(sys.argv[config_idx]))
    return os.getcwd()


def _resolve_path(path, base_dir):
    """Return *path* unchanged if absolute, otherwise joined onto *base_dir*."""
    return path if os.path.isabs(path) else os.path.join(base_dir, path)


def _load_checkpoint(module, ckpt_path, label):
    """Strictly load a safetensors checkpoint into *module* and verify it.

    Args:
        module: torch.nn.Module to receive the weights.
        ckpt_path: Absolute path to the ``.safetensors`` file.
        label: Human-readable name used in the progress message.
    """
    state_dict = load_safetensors(ckpt_path)
    module.load_state_dict(state_dict, strict=True)
    print(f"Loaded {label} from {ckpt_path}")
    # Cross-check the loaded state dict against the module's registered
    # parameters and buffers (project-level sanity check).
    compare_statedict_and_parameters(
        state_dict=module.state_dict(),
        named_parameters=module.named_parameters(),
        named_buffers=module.named_buffers(),
    )
|
|
|
|
|
|
|
|
@torch.inference_mode()
def generate_feature_stream(
    model, feature_length, text, feature_text_end=None, num_denoise_steps=None
):
    """
    Streaming interface for feature generation

    Args:
        model: Loaded model
        feature_length: List[int], generation length for each sample
        text: List[str] or List[List[str]], text prompts
        feature_text_end: List[List[int]], time points where text ends (if text is list of list)
        num_denoise_steps: Number of denoising steps
    Yields:
        dict: Contains "generated" (current generated feature segment)
    """
    # Assemble the conditioning batch expected by the model.
    batch = {"feature_length": torch.tensor(feature_length), "text": text}
    if feature_text_end is not None:
        batch["feature_text_end"] = feature_text_end

    # Delegate directly to the model's streaming generator, forwarding
    # each step's output as it is produced.
    yield from model.stream_generate(batch, num_denoise_steps=num_denoise_steps)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # CLI entry point: collect arguments, then load the VAE + model pair.
    # NOTE(review): in this portion of the file only --config is consumed
    # (read from sys.argv inside load_model_from_config); the remaining
    # arguments are parsed but not yet used here.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to config")
    parser.add_argument("--text", type=str, default="a person walks forward", help="Text prompt")
    parser.add_argument("--length", type=int, default=120, help="Motion length")
    parser.add_argument("--output", type=str, default="output.mp4", help="Output video path")
    parser.add_argument("--num_denoise_steps", type=int, default=None, help="Number of denoising steps")
    args = parser.parse_args()

    print("Loading model...")
    vae, model = load_model_from_config()
|
|
|
|
|
|