import math
import time

import torch
import tqdm
import yaml
from diffusers import (
    DDIMScheduler,
    DDPMScheduler,
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    PNDMScheduler,
)
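

# DiffusePipeline wraps a motion-generation model with a configurable diffusers
# noise scheduler: it loads scheduler settings from config/diffuser_params.yaml,
# runs the reverse-diffusion sampling loop in generate_batch, and handles
# batching and timing of full-set generation in generate.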
class DiffusePipeline(object):

    def __init__(
        self,
        opt,
        model,
        diffuser_name,
        num_inference_steps,
        device,
        torch_dtype=torch.float16,
    ):
        self.device = device
        self.torch_dtype = torch_dtype
        self.diffuser_name = diffuser_name
        self.num_inference_steps = num_inference_steps
        if self.torch_dtype == torch.float16:
            model = model.half()
        print(f"Moving model to device: {device}")
        self.model = model.to(device)
        self.opt = opt

        # Load scheduler parameters from the YAML config file
        with open("config/diffuser_params.yaml", "r") as yaml_file:
            diffuser_params = yaml.safe_load(yaml_file)

        # Select the scheduler parameters matching diffuser_name
        if diffuser_name in diffuser_params:
            params = diffuser_params[diffuser_name]
            scheduler_class_name = params["scheduler_class"]
            additional_params = params["additional_params"]

            # Align the scheduler with the training-time diffusion parameters
            additional_params["num_train_timesteps"] = opt.diffusion_steps
            additional_params["beta_schedule"] = opt.beta_schedule
            additional_params["prediction_type"] = opt.prediction_type

            try:
                scheduler_class = globals()[scheduler_class_name]
            except KeyError:
                raise ValueError(f"Class '{scheduler_class_name}' not found.")
            self.scheduler = scheduler_class(**additional_params)
        else:
            raise ValueError(f"Unsupported diffuser_name: {diffuser_name}")

    def generate_batch(self, caption, m_lens):
        B = len(caption)
        T = m_lens.max()
        shape = (B, T, self.model.input_feats)

        # Sample the initial noise x_T
        sample = torch.randn(shape, device=self.device, dtype=self.torch_dtype)

        # Set the denoising timesteps
        self.scheduler.set_timesteps(self.num_inference_steps, self.device)
        timesteps = [
            torch.tensor([t] * B, device=self.device).long()
            for t in self.scheduler.timesteps
        ]

        # Cache the text embedding so it is computed only once per batch
        enc_text = self.model.encode_text(caption, self.device)

        for t in timesteps:
            # 1. Model prediction
            with torch.no_grad():
                if getattr(self.model, "cond_mask_prob", 0) > 0:
                    predict = self.model.forward_with_cfg(sample, t, enc_text=enc_text)
                else:
                    predict = self.model(sample, t, enc_text=enc_text)

            # 2. Compute the less noisy motion and set x_t -> x_{t-1}
            sample = self.scheduler.step(predict, t[0], sample).prev_sample

        return sample

    def generate(self, caption, m_lens, batch_size=32):
        N = len(caption)
        infer_mode = ""
        if getattr(self.model, "cond_mask_prob", 0) > 0:
            infer_mode = " with classifier-free guidance"
        print(
            f"\nUsing {self.diffuser_name} diffusion scheduler to generate {N} motions"
            f"{infer_mode}, sampling {self.num_inference_steps} steps."
        )
        self.model.eval()

        all_output = []
        t_sum = 0
        cur_idx = 0
        for batch_idx in tqdm.tqdm(range(math.ceil(N / batch_size))):
            if cur_idx + batch_size >= N:
                batch_caption = caption[cur_idx:]
                batch_m_lens = m_lens[cur_idx:]
            else:
                batch_caption = caption[cur_idx : cur_idx + batch_size]
                batch_m_lens = m_lens[cur_idx : cur_idx + batch_size]

            torch.cuda.synchronize()
            start_time = time.time()
            output = self.generate_batch(batch_caption, batch_m_lens)
            torch.cuda.synchronize()
            now_time = time.time()

            # The average inference time is accumulated only after the GPU has
            # warmed up during the first 50 sampling steps.
            if (batch_idx + 1) * self.num_inference_steps >= 50:
                t_sum += now_time - start_time

            # Crop each motion to its ground-truth / predicted length
            B = output.shape[0]
            for i in range(B):
                all_output.append(output[i, : batch_m_lens[i]])

            cur_idx += batch_size

        # Calculate the average inference time per batch (excluding warm-up)
        t_eval = t_sum / (batch_idx - 1)
        print(
            "The average generation time of a batch motion (bs=%d) is %f seconds"
            % (batch_size, t_eval)
        )
        return all_output, t_eval
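

# A minimal usage sketch, assuming placeholder objects not defined in this file.
# The `opt` namespace and the motion model below are hypothetical: any objects
# exposing the attributes used above (diffusion_steps, beta_schedule,
# prediction_type on `opt`; input_feats, encode_text, and optionally
# cond_mask_prob / forward_with_cfg on the model) would work.
#
#   from types import SimpleNamespace
#
#   opt = SimpleNamespace(
#       diffusion_steps=1000, beta_schedule="linear", prediction_type="epsilon"
#   )
#   pipeline = DiffusePipeline(
#       opt=opt,
#       model=motion_model,          # hypothetical pretrained motion model
#       diffuser_name="dpmsolver",   # must be a key in diffuser_params.yaml
#       num_inference_steps=10,
#       device=torch.device("cuda"),
#   )
#   motions, avg_time = pipeline.generate(
#       caption=["a person walks forward"], m_lens=torch.tensor([120])
#   )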