# MotionCLR / models/gaussian_diffusion.py
# EvanTHU's picture
# Update models/gaussian_diffusion.py
# 54ac52b verified
from diffusers import (
DPMSolverMultistepScheduler,
DDPMScheduler,
DDIMScheduler,
PNDMScheduler,
DEISMultistepScheduler,
)
import torch
import yaml
import math
import tqdm
import time
class DiffusePipeline(object):
    """Sampling pipeline that pairs a denoising model with a scheduler.

    The scheduler class name and its keyword arguments are read from a YAML
    file, keyed by ``diffuser_name``. Three scheduler arguments are always
    overridden with the values stored on ``opt`` so that sampling matches
    training: ``num_train_timesteps``, ``beta_schedule`` and
    ``prediction_type``.
    """

    def __init__(
        self,
        opt,
        model,
        diffuser_name,
        num_inference_steps,
        device,
        torch_dtype=torch.float16,
        config_path="config/diffuser_params.yaml",
    ):
        """Move the model to ``device`` and build the scheduler.

        Args:
            opt: Options object; ``diffusion_steps``, ``beta_schedule`` and
                ``prediction_type`` are read from it.
            model: Denoising model. It is converted with ``half()`` when
                ``torch_dtype`` is ``torch.float16`` and moved with ``to()``.
            diffuser_name: Key of the scheduler entry in the YAML file.
            num_inference_steps: Number of sampling steps per batch.
            device: Device used for the model and for sampling.
            torch_dtype: dtype of the sampled noise (and of the model when
                it is ``torch.float16``).
            config_path: Path of the YAML file describing the schedulers.
                Defaults to the location used originally, which is relative
                to the current working directory.

        Raises:
            ValueError: If ``diffuser_name`` is not present in the YAML file
                or the configured scheduler class is not available in this
                module.
        """
        self.device = device
        self.torch_dtype = torch_dtype
        self.diffuser_name = diffuser_name
        self.num_inference_steps = num_inference_steps
        self.opt = opt

        if self.torch_dtype == torch.float16:
            model = model.half()
        self.model = model.to(device)

        # Load scheduler parameters from the YAML file. An empty file parses
        # to None, which is treated like an empty mapping.
        with open(config_path, "r", encoding="utf-8") as yaml_file:
            diffuser_params = yaml.safe_load(yaml_file) or {}

        # Select the diffusion parameters based on diffuser_name.
        if diffuser_name not in diffuser_params:
            raise ValueError(f"Unsupported diffuser_name: {diffuser_name}")
        params = diffuser_params[diffuser_name]
        scheduler_class_name = params["scheduler_class"]
        # Copy so the parsed YAML structure is not mutated; a missing or
        # empty entry means "no extra arguments".
        additional_params = dict(params.get("additional_params") or {})

        # Align the scheduler with the training parameters.
        additional_params["num_train_timesteps"] = opt.diffusion_steps
        additional_params["beta_schedule"] = opt.beta_schedule
        additional_params["prediction_type"] = opt.prediction_type

        # The class is resolved by name among this module's globals, i.e. the
        # scheduler classes imported at the top of the file.
        try:
            scheduler_class = globals()[scheduler_class_name]
        except KeyError as err:
            raise ValueError(
                f"Class '{scheduler_class_name}' not found."
            ) from err
        self.scheduler = scheduler_class(**additional_params)

    def _cuda_synchronize(self):
        """Wait for pending CUDA work on ``self.device``; no-op otherwise."""
        if not torch.cuda.is_available():
            return
        device = torch.device(self.device)
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    def generate_batch(self, caption, m_lens):
        """Run the full reverse diffusion loop for one batch.

        Args:
            caption: Sequence of text prompts; its length is the batch size.
            m_lens: Per-sample lengths; only ``m_lens.max()`` is used here,
                to size the time axis of the sampled noise.

        Returns:
            The final sample, created with shape
            ``(len(caption), max(m_lens), self.model.input_feats)``.
        """
        B = len(caption)
        max_len = int(m_lens.max())
        shape = (B, max_len, self.model.input_feats)

        # Random sampling noise x_T.
        sample = torch.randn(shape, device=self.device, dtype=self.torch_dtype)

        # Set timesteps.
        self.scheduler.set_timesteps(self.num_inference_steps, self.device)

        # The text embedding is computed once and reused at every step.
        enc_text = self.model.encode_text(caption, self.device)

        # Loop-invariant: models exposing a positive cond_mask_prob are
        # sampled through forward_with_cfg.
        use_cfg = getattr(self.model, "cond_mask_prob", 0) > 0

        for step_t in self.scheduler.timesteps:
            # One copy of the current timestep per batch element.
            batch_t = torch.full(
                (B,), int(step_t), dtype=torch.long, device=self.device
            )

            # 1. Model prediction.
            with torch.no_grad():
                if use_cfg:
                    predict = self.model.forward_with_cfg(
                        sample, batch_t, enc_text=enc_text
                    )
                else:
                    predict = self.model(sample, batch_t, enc_text=enc_text)

            # 2. Compute the less noisy motion: x_t -> x_t-1.
            sample = self.scheduler.step(predict, batch_t[0], sample).prev_sample

        return sample

    def generate(self, caption, m_lens, batch_size=32):
        """Generate one output per caption, processing ``batch_size`` at a time.

        Args:
            caption: Sequence of text prompts.
            m_lens: Per-sample lengths, indexable and sliceable in step with
                ``caption``; each output is cropped to its length.
            batch_size: Number of captions sampled together.

        Returns:
            A tuple ``(all_output, t_eval)``. ``all_output`` is a list with
            one cropped sample per caption, in input order. ``t_eval`` is the
            mean wall-clock time in seconds of the batches that were timed,
            or ``0.0`` when no batch was timed (no input, or the run ended
            before the warm-up threshold was reached).
        """
        N = len(caption)
        infer_mode = ""
        if getattr(self.model, "cond_mask_prob", 0) > 0:
            infer_mode = "classifier-free-guidance"
        print(
            f"\nUsing {self.diffuser_name} diffusion scheduler to {infer_mode} generate {N} motions, sampling {self.num_inference_steps} steps."
        )
        self.model.eval()

        all_output = []
        t_sum = 0.0
        timed_batches = 0
        num_batches = math.ceil(N / batch_size)

        for batch_idx in tqdm.tqdm(range(num_batches)):
            # Slicing clamps at the end, so the last (short) batch needs no
            # special case.
            start = batch_idx * batch_size
            batch_caption = caption[start : start + batch_size]
            batch_m_lens = m_lens[start : start + batch_size]

            # Synchronize around the call so the timer covers the GPU work.
            self._cuda_synchronize()
            start_time = time.time()
            output = self.generate_batch(batch_caption, batch_m_lens)
            self._cuda_synchronize()
            elapsed = time.time() - start_time

            # The average inference time only counts batches completed after
            # the first 50 sampling steps, which serve as GPU warm-up.
            if (batch_idx + 1) * self.num_inference_steps >= 50:
                t_sum += elapsed
                timed_batches += 1

            # Crop each motion to its gt/predicted length.
            for i in range(output.shape[0]):
                all_output.append(output[i, : int(batch_m_lens[i])])

        # Average over the batches that were actually timed.
        t_eval = t_sum / timed_batches if timed_batches else 0.0
        print(
            "The average generation time of a batch motion (bs=%d) is %f seconds"
            % (batch_size, t_eval)
        )
        return all_output, t_eval