| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from collections import OrderedDict |
| from copy import copy |
| import numpy as np |
| import os |
| import math |
| from PIL import Image |
| from polygraphy.backend.common import bytes_from_path |
| from polygraphy.backend.trt import CreateConfig, Profile |
| from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine |
| from polygraphy.backend.trt import util as trt_util |
| from polygraphy import cuda |
| import random |
| from scipy import integrate |
| import tensorrt as trt |
| import torch |
|
|
| TRT_LOGGER = trt.Logger(trt.Logger.WARNING) |
|
|
| class Engine(): |
| def __init__( |
| self, |
| model_name, |
| engine_dir, |
| ): |
| self.engine_path = os.path.join(engine_dir, model_name+'.plan') |
| self.engine = None |
| self.context = None |
| self.buffers = OrderedDict() |
| self.tensors = OrderedDict() |
|
|
| def __del__(self): |
| [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray) ] |
| del self.engine |
| del self.context |
| del self.buffers |
| del self.tensors |
|
|
| def build(self, onnx_path, fp16, input_profile=None, enable_preview=False): |
| print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") |
| p = Profile() |
| if input_profile: |
| for name, dims in input_profile.items(): |
| assert len(dims) == 3 |
| p.add(name, min=dims[0], opt=dims[1], max=dims[2]) |
|
|
| preview_features = [] |
| if enable_preview: |
| trt_version = [int(i) for i in trt.__version__.split(".")] |
| |
| if trt_version[0] > 8 or \ |
| (trt_version[0] == 8 and (trt_version[1] > 5 or (trt_version[1] == 5 and trt_version[2] >= 1))): |
| preview_features = [trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805] |
|
|
| engine = engine_from_network(network_from_onnx_path(onnx_path), config=CreateConfig(fp16=fp16, profiles=[p], |
| preview_features=preview_features)) |
| save_engine(engine, path=self.engine_path) |
|
|
| def activate(self): |
| print(f"Loading TensorRT engine: {self.engine_path}") |
| self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) |
| self.context = self.engine.create_execution_context() |
|
|
| def allocate_buffers(self, shape_dict=None, device='cuda'): |
| for idx in range(trt_util.get_bindings_per_profile(self.engine)): |
| binding = self.engine[idx] |
| if shape_dict and binding in shape_dict: |
| shape = shape_dict[binding] |
| else: |
| shape = self.engine.get_binding_shape(binding) |
| dtype = trt_util.np_dtype_from_trt(self.engine.get_binding_dtype(binding)) |
| if self.engine.binding_is_input(binding): |
| self.context.set_binding_shape(idx, shape) |
| |
| np_type_tensor = np.empty(shape=[], dtype=dtype) |
| torch_type_tensor = torch.from_numpy(np_type_tensor) |
| tensor = torch.empty(tuple(shape), dtype=torch_type_tensor.dtype).to(device=device) |
| self.tensors[binding] = tensor |
| self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype) |
|
|
| def infer(self, feed_dict, stream): |
| start_binding, end_binding = trt_util.get_active_profile_bindings(self.context) |
| |
| device_buffers = copy(self.buffers) |
| for name, buf in feed_dict.items(): |
| assert isinstance(buf, cuda.DeviceView) |
| device_buffers[name] = buf |
| bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()] |
| noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr) |
| if not noerror: |
| raise ValueError(f"ERROR: inference failed.") |
|
|
| return self.tensors |
|
|
| class LMSDiscreteScheduler(): |
| def __init__( |
| self, |
| device = 'cuda', |
| beta_start = 0.00085, |
| beta_end = 0.012, |
| num_train_timesteps = 1000, |
| ): |
| self.num_train_timesteps = num_train_timesteps |
| self.order = 4 |
|
|
| self.beta_start = beta_start |
| self.beta_end = beta_end |
| betas = (torch.linspace(beta_start**0.5, beta_end**0.5, self.num_train_timesteps, dtype=torch.float32) ** 2) |
| alphas = 1.0 - betas |
| self.alphas_cumprod = torch.cumprod(alphas, dim=0) |
|
|
| sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) |
| sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) |
| self.sigmas = torch.from_numpy(sigmas) |
|
|
| |
| self.init_noise_sigma = self.sigmas.max() |
|
|
| self.device = device |
|
|
| def set_timesteps(self, steps): |
| self.num_inference_steps = steps |
|
|
| timesteps = np.linspace(0, self.num_train_timesteps - 1, steps, dtype=float)[::-1].copy() |
| sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) |
| sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) |
| sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) |
| self.sigmas = torch.from_numpy(sigmas).to(device=self.device) |
|
|
| |
| self.timesteps = torch.from_numpy(timesteps).to(device=self.device).float() |
| self.derivatives = [] |
|
|
| def scale_model_input(self, sample: torch.FloatTensor, idx, *args, **kwargs) -> torch.FloatTensor: |
| return sample * self.latent_scales[idx] |
|
|
| def configure(self): |
| order = self.order |
| self.lms_coeffs = [] |
| self.latent_scales = [1./((sigma**2 + 1) ** 0.5) for sigma in self.sigmas] |
|
|
| def get_lms_coefficient(order, t, current_order): |
| """ |
| Compute a linear multistep coefficient. |
| """ |
| def lms_derivative(tau): |
| prod = 1.0 |
| for k in range(order): |
| if current_order == k: |
| continue |
| prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k]) |
| return prod |
| integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0] |
| return integrated_coeff |
|
|
| for step_index in range(self.num_inference_steps): |
| order = min(step_index + 1, order) |
| self.lms_coeffs.append([get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)]) |
|
|
| def step(self, output, latents, idx, timestep): |
| |
| |
| sigma = self.sigmas[idx] |
| pred_original_sample = latents - sigma * output |
| |
| derivative = (latents - pred_original_sample) / sigma |
| self.derivatives.append(derivative) |
| if len(self.derivatives) > self.order: |
| self.derivatives.pop(0) |
| |
| prev_sample = latents + sum( |
| coeff * derivative for coeff, derivative in zip(self.lms_coeffs[idx], reversed(self.derivatives)) |
| ) |
|
|
| return prev_sample |
|
|
| class DPMScheduler(): |
| def __init__( |
| self, |
| beta_start = 0.00085, |
| beta_end = 0.012, |
| num_train_timesteps = 1000, |
| solver_order = 2, |
| predict_epsilon = True, |
| thresholding = False, |
| dynamic_thresholding_ratio = 0.995, |
| sample_max_value = 1.0, |
| algorithm_type = "dpmsolver++", |
| solver_type = "midpoint", |
| lower_order_final = True, |
| device = 'cuda', |
| ): |
| |
| self.betas = ( |
| torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 |
| ) |
|
|
| self.device = device |
| self.alphas = 1.0 - self.betas |
| self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) |
| |
| self.alpha_t = torch.sqrt(self.alphas_cumprod) |
| self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) |
| self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) |
|
|
| |
| self.init_noise_sigma = 1.0 |
|
|
| self.algorithm_type = algorithm_type |
| self.predict_epsilon = predict_epsilon |
| self.thresholding = thresholding |
| self.dynamic_thresholding_ratio = dynamic_thresholding_ratio |
| self.sample_max_value = sample_max_value |
| self.lower_order_final = lower_order_final |
|
|
| |
| if algorithm_type not in ["dpmsolver", "dpmsolver++"]: |
| raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") |
| if solver_type not in ["midpoint", "heun"]: |
| raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") |
|
|
| |
| self.num_inference_steps = None |
| self.solver_order = solver_order |
| self.num_train_timesteps = num_train_timesteps |
| self.solver_type = solver_type |
|
|
| self.first_order_first_coef = [] |
| self.first_order_second_coef = [] |
|
|
| self.second_order_first_coef = [] |
| self.second_order_second_coef = [] |
| self.second_order_third_coef = [] |
|
|
| self.third_order_first_coef = [] |
| self.third_order_second_coef = [] |
| self.third_order_third_coef = [] |
| self.third_order_fourth_coef = [] |
|
|
| def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: |
| return sample |
|
|
| def configure(self): |
| lower_order_nums = 0 |
| for step_index in range(self.num_inference_steps): |
| step_idx = step_index |
| timestep = self.timesteps[step_idx] |
|
|
| prev_timestep = 0 if step_idx == len(self.timesteps) - 1 else self.timesteps[step_idx + 1] |
|
|
| self.dpm_solver_first_order_coefs_precompute(timestep, prev_timestep) |
|
|
| timestep_list = [self.timesteps[step_index - 1], timestep] |
| self.multistep_dpm_solver_second_order_coefs_precompute(timestep_list, prev_timestep) |
| |
| timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] |
| self.multistep_dpm_solver_third_order_coefs_precompute(timestep_list, prev_timestep) |
|
|
| if lower_order_nums < self.solver_order: |
| lower_order_nums += 1 |
|
|
| def dpm_solver_first_order_coefs_precompute(self, timestep, prev_timestep): |
| lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] |
| alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] |
| sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] |
| h = lambda_t - lambda_s |
| if self.algorithm_type == "dpmsolver++": |
| self.first_order_first_coef.append(sigma_t / sigma_s) |
| self.first_order_second_coef.append(alpha_t * (torch.exp(-h) - 1.0)) |
| elif self.algorithm_type == "dpmsolver": |
| self.first_order_first_coef.append(alpha_t / alpha_s) |
| self.first_order_second_coef.append(sigma_t * (torch.exp(h) - 1.0)) |
|
|
| def multistep_dpm_solver_second_order_coefs_precompute(self, timestep_list, prev_timestep): |
| t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] |
| lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] |
| alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] |
| sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] |
| h = lambda_t - lambda_s0 |
| if self.algorithm_type == "dpmsolver++": |
| |
| if self.solver_type == "midpoint": |
| self.second_order_first_coef.append(sigma_t / sigma_s0) |
| self.second_order_second_coef.append((alpha_t * (torch.exp(-h) - 1.0))) |
| self.second_order_third_coef.append(0.5 * (alpha_t * (torch.exp(-h) - 1.0))) |
| elif self.solver_type == "heun": |
| self.second_order_first_coef.append(sigma_t / sigma_s0) |
| self.second_order_second_coef.append((alpha_t * (torch.exp(-h) - 1.0))) |
| self.second_order_third_coef.append(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) |
| elif self.algorithm_type == "dpmsolver": |
| |
| if self.solver_type == "midpoint": |
| self.second_order_first_coef.append(alpha_t / alpha_s0) |
| self.second_order_second_coef.append((sigma_t * (torch.exp(h) - 1.0))) |
| self.second_order_third_coef.append(0.5 * (sigma_t * (torch.exp(h) - 1.0))) |
| elif self.solver_type == "heun": |
| self.second_order_first_coef.append(alpha_t / alpha_s0) |
| self.second_order_second_coef.append((sigma_t * (torch.exp(h) - 1.0))) |
| self.second_order_third_coef.append((sigma_t * ((torch.exp(h) - 1.0) / h - 1.0))) |
|
|
| def multistep_dpm_solver_third_order_coefs_precompute(self, timestep_list, prev_timestep): |
| t, s0 = prev_timestep, timestep_list[-1] |
| lambda_t, lambda_s0 = ( |
| self.lambda_t[t], |
| self.lambda_t[s0] |
| ) |
| alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] |
| sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] |
| h = lambda_t - lambda_s0 |
| if self.algorithm_type == "dpmsolver++": |
| self.third_order_first_coef.append(sigma_t / sigma_s0) |
| self.third_order_second_coef.append(alpha_t * (torch.exp(-h) - 1.0)) |
| self.third_order_third_coef.append(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) |
| self.third_order_fourth_coef.append(alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) |
| elif self.algorithm_type == "dpmsolver": |
| self.third_order_first_coef.append(alpha_t / alpha_s0) |
| self.third_order_second_coef.append(sigma_t * (torch.exp(h) - 1.0)) |
| self.third_order_third_coef.append(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) |
| self.third_order_fourth_coef.append(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) |
|
|
| def set_timesteps(self, num_inference_steps): |
| self.num_inference_steps = num_inference_steps |
| timesteps = ( |
| np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) |
| .round()[::-1][:-1] |
| .copy() |
| .astype(np.int32) |
| ) |
| self.timesteps = torch.from_numpy(timesteps).to(self.device) |
| self.model_outputs = [ |
| None, |
| ] * self.solver_order |
| self.lower_order_nums = 0 |
|
|
| def convert_model_output( |
| self, model_output, timestep, sample |
| ): |
| |
| if self.algorithm_type == "dpmsolver++": |
| if self.predict_epsilon: |
| alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] |
| x0_pred = (sample - sigma_t * model_output) / alpha_t |
| else: |
| x0_pred = model_output |
| if self.thresholding: |
| |
| dynamic_max_val = torch.quantile( |
| torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.dynamic_thresholding_ratio, dim=1 |
| ) |
| dynamic_max_val = torch.maximum( |
| dynamic_max_val, |
| self.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device), |
| )[(...,) + (None,) * (x0_pred.ndim - 1)] |
| x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val |
| return x0_pred |
| |
| elif self.algorithm_type == "dpmsolver": |
| if self.predict_epsilon: |
| return model_output |
| else: |
| alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] |
| epsilon = (sample - alpha_t * model_output) / sigma_t |
| return epsilon |
|
|
| def dpm_solver_first_order_update( |
| self, |
| idx, |
| model_output, |
| sample |
| ): |
| first_coef = self.first_order_first_coef[idx] |
| second_coef = self.first_order_second_coef[idx] |
|
|
| if self.algorithm_type == "dpmsolver++": |
| x_t = first_coef * sample - second_coef * model_output |
| elif self.algorithm_type == "dpmsolver": |
| x_t = first_coef * sample - second_coef * model_output |
| return x_t |
|
|
| def multistep_dpm_solver_second_order_update( |
| self, |
| idx, |
| model_output_list, |
| timestep_list, |
| prev_timestep, |
| sample |
| ): |
| t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] |
| m0, m1 = model_output_list[-1], model_output_list[-2] |
| lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] |
| h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 |
| r0 = h_0 / h |
| D0, D1 = m0, (1.0 / r0) * (m0 - m1) |
|
|
| first_coef = self.second_order_first_coef[idx] |
| second_coef = self.second_order_second_coef[idx] |
| third_coef = self.second_order_third_coef[idx] |
|
|
| if self.algorithm_type == "dpmsolver++": |
| |
| if self.solver_type == "midpoint": |
| x_t = ( |
| first_coef * sample |
| - second_coef * D0 |
| - third_coef * D1 |
| ) |
| elif self.solver_type == "heun": |
| x_t = ( |
| first_coef * sample |
| - second_coef * D0 |
| + third_coef * D1 |
| ) |
| elif self.algorithm_type == "dpmsolver": |
| |
| if self.solver_type == "midpoint": |
| x_t = ( |
| first_coef * sample |
| - second_coef * D0 |
| - third_coef * D1 |
| ) |
| elif self.solver_type == "heun": |
| x_t = ( |
| first_coef * sample |
| - second_coef * D0 |
| - third_coef * D1 |
| ) |
| return x_t |
|
|
| def multistep_dpm_solver_third_order_update( |
| self, |
| idx, |
| model_output_list, |
| timestep_list, |
| prev_timestep, |
| sample |
| ): |
| t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] |
| m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] |
| lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( |
| self.lambda_t[t], |
| self.lambda_t[s0], |
| self.lambda_t[s1], |
| self.lambda_t[s2], |
| ) |
| h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 |
| r0, r1 = h_0 / h, h_1 / h |
| D0 = m0 |
| D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) |
| D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) |
| D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) |
|
|
| first_coef = self.third_order_first_coef[idx] |
| second_coef = self.third_order_second_coef[idx] |
| third_coef = self.third_order_third_coef[idx] |
| fourth_coef = self.third_order_fourth_coef[idx] |
|
|
| if self.algorithm_type == "dpmsolver++": |
| |
| x_t = ( |
| first_coef * sample |
| - second_coef * D0 |
| + third_coef * D1 |
| - fourth_coef * D2 |
| ) |
| elif self.algorithm_type == "dpmsolver": |
| |
| x_t = ( |
| first_coef * sample |
| - second_coef * D0 |
| - third_coef * D1 |
| - fourth_coef * D2 |
| ) |
| return x_t |
|
|
| def step(self, output, latents, step_index, timestep): |
| if self.num_inference_steps is None: |
| raise ValueError( |
| "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" |
| ) |
|
|
| prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] |
| lower_order_final = ( |
| (step_index == len(self.timesteps) - 1) and self.lower_order_final and len(self.timesteps) < 15 |
| ) |
| lower_order_second = ( |
| (step_index == len(self.timesteps) - 2) and self.lower_order_final and len(self.timesteps) < 15 |
| ) |
|
|
| output = self.convert_model_output(output, timestep, latents) |
| for i in range(self.solver_order - 1): |
| self.model_outputs[i] = self.model_outputs[i + 1] |
| self.model_outputs[-1] = output |
|
|
| if self.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: |
| prev_sample = self.dpm_solver_first_order_update(step_index, output, latents) |
| elif self.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: |
| timestep_list = [self.timesteps[step_index - 1], timestep] |
| prev_sample = self.multistep_dpm_solver_second_order_update( |
| step_index, self.model_outputs, timestep_list, prev_timestep, latents |
| ) |
| else: |
| timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] |
| prev_sample = self.multistep_dpm_solver_third_order_update( |
| step_index, self.model_outputs, timestep_list, prev_timestep, latents |
| ) |
|
|
| if self.lower_order_nums < self.solver_order: |
| self.lower_order_nums += 1 |
|
|
| return prev_sample |
|
|
| def save_image(images, image_path_dir, image_name_prefix): |
| """ |
| Save the generated images to png files. |
| """ |
| images = ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy() |
| for i in range(images.shape[0]): |
| image_path = os.path.join(image_path_dir, image_name_prefix+str(i+1)+'-'+str(random.randint(1000,9999))+'.png') |
| print(f"Saving image {i+1} / {images.shape[0]} to: {image_path}") |
| Image.fromarray(images[i]).save(image_path) |
|
|