import math
import time

import torch
import tqdm
import yaml
from diffusers import (
    DDIMScheduler,
    DDPMScheduler,
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    PNDMScheduler,
)

class DiffusePipeline(object):
    def __init__(
        self,
        opt,
        model,
        diffuser_name,
        num_inference_steps,
        device,
        torch_dtype=torch.float16,
    ):
        self.device = device
        self.torch_dtype = torch_dtype
        self.diffuser_name = diffuser_name
        self.num_inference_steps = num_inference_steps
        if self.torch_dtype == torch.float16:
            model = model.half()
        self.model = model.to(device)
        self.opt = opt

        # Load scheduler parameters from the YAML config
        with open("config/diffuser_params.yaml", "r") as yaml_file:
            diffuser_params = yaml.safe_load(yaml_file)

        # Select the scheduler parameters matching diffuser_name
        if diffuser_name in diffuser_params:
            params = diffuser_params[diffuser_name]
            scheduler_class_name = params["scheduler_class"]
            additional_params = params["additional_params"]

            # Align the scheduler with the training configuration
            additional_params["num_train_timesteps"] = opt.diffusion_steps
            additional_params["beta_schedule"] = opt.beta_schedule
            additional_params["prediction_type"] = opt.prediction_type

            try:
                scheduler_class = globals()[scheduler_class_name]
            except KeyError:
                raise ValueError(f"Class '{scheduler_class_name}' not found.")

            self.scheduler = scheduler_class(**additional_params)
        else:
            raise ValueError(f"Unsupported diffuser_name: {diffuser_name}")

    def generate_batch(self, caption, m_lens):
        B = len(caption)
        T = m_lens.max()
        shape = (B, T, self.model.input_feats)

        # Sample the initial noise x_T
        sample = torch.randn(shape, device=self.device, dtype=self.torch_dtype)

        # Set the denoising timesteps
        self.scheduler.set_timesteps(self.num_inference_steps, self.device)
        timesteps = [
            torch.tensor([t] * B, device=self.device).long()
            for t in self.scheduler.timesteps
        ]

        # Cache the text embedding so it is computed only once per batch
        enc_text = self.model.encode_text(caption, self.device)

        for i, t in enumerate(timesteps):
            # 1. Predict the model output for the current timestep
            with torch.no_grad():
                if getattr(self.model, "cond_mask_prob", 0) > 0:
                    predict = self.model.forward_with_cfg(sample, t, enc_text=enc_text)
                else:
                    predict = self.model(sample, t, enc_text=enc_text)

            # 2. Compute the less noisy sample: x_t -> x_{t-1}
            sample = self.scheduler.step(predict, t[0], sample).prev_sample

        return sample
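
    # Note: forward_with_cfg is assumed to implement classifier-free guidance
    # inside the model, i.e. to run conditional and unconditional passes and
    # blend them, roughly pred = uncond + scale * (cond - uncond). The actual
    # implementation lives in the model class, not in this pipeline.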

    def generate(self, caption, m_lens, batch_size=32):
        N = len(caption)
        infer_mode = ""
        if getattr(self.model, "cond_mask_prob", 0) > 0:
            infer_mode = " with classifier-free guidance"
        print(
            f"\nUsing the {self.diffuser_name} scheduler{infer_mode} to generate "
            f"{N} motions, sampling {self.num_inference_steps} steps."
        )

        self.model.eval()

        all_output = []
        t_sum = 0
        t_count = 0
        cur_idx = 0
        for batch_idx in tqdm.tqdm(range(math.ceil(N / batch_size))):
            batch_caption = caption[cur_idx : cur_idx + batch_size]
            batch_m_lens = m_lens[cur_idx : cur_idx + batch_size]

            torch.cuda.synchronize()
            start_time = time.time()
            output = self.generate_batch(batch_caption, batch_m_lens)
            torch.cuda.synchronize()
            now_time = time.time()

            # Only time batches after the GPU has warmed up for 50 denoising steps
            if (batch_idx + 1) * self.num_inference_steps >= 50:
                t_sum += now_time - start_time
                t_count += 1

            # Crop each motion to its ground-truth/predicted length
            B = output.shape[0]
            for i in range(B):
                all_output.append(output[i, : batch_m_lens[i]])

            cur_idx += batch_size

        # Calculate the average inference time over the timed batches
        t_eval = t_sum / max(t_count, 1)
        print(
            "The average generation time of a batch of motions (bs=%d) is %f seconds"
            % (batch_size, t_eval)
        )
        return all_output, t_eval
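
# Example usage (a sketch; `opt` and `model` come from the surrounding project,
# and the scheduler name must match a key in config/diffuser_params.yaml):
#
#   pipeline = DiffusePipeline(
#       opt, model,
#       diffuser_name="dpmsolver",
#       num_inference_steps=10,
#       device=torch.device("cuda"),
#   )
#   motions, t_eval = pipeline.generate(captions, m_lens, batch_size=32)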