|
|
|
|
|
|
|
|
|
|
|
from typing import List, Union |
|
|
import numpy as np |
|
|
import axengine |
|
|
import torch |
|
|
from PIL import Image |
|
|
from transformers import CLIPTokenizer, PreTrainedTokenizer |
|
|
import time |
|
|
import argparse |
|
|
import uuid |
|
|
import os |
|
|
import traceback |
|
|
from diffusers import DPMSolverMultistepScheduler |
|
|
|
|
|
|
|
|
# Global switches for diagnostic output: DEBUG_MODE turns debug_log() on or
# off, LOG_TIMESTAMP controls whether each debug line carries a timestamp.
DEBUG_MODE = True
LOG_TIMESTAMP = True


def debug_log(msg):
    """Print ``msg`` as a ``[DEBUG]`` line when ``DEBUG_MODE`` is enabled.

    When ``LOG_TIMESTAMP`` is true the line is prefixed with the local time
    formatted as ``[YYYY-MM-DD HH:MM:SS] ``. Nothing is printed when
    ``DEBUG_MODE`` is false.

    Args:
        msg: Value interpolated into the log line.
    """
    if not DEBUG_MODE:
        return

    prefix = ""
    if LOG_TIMESTAMP:
        prefix = "[" + time.strftime('%Y-%m-%d %H:%M:%S') + "] "
    print(f"{prefix}[DEBUG] {msg}")
|
|
|
|
|
def get_args(argv=None):
    """Parse the command-line options of the text-to-image script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the process command line, which is
            what every existing ``get_args()`` caller gets.

    Returns:
        argparse.Namespace with the attributes ``prompt``,
        ``text_model_dir``, ``unet_model``, ``vae_decoder_model``,
        ``time_input``, ``save_dir``, ``num_inference_steps``,
        ``guidance_scale`` and ``seed`` (``None`` when not supplied).

    Raises:
        SystemExit: With status 1 when building the parser or parsing fails
            with an ordinary exception (the error and traceback are printed
            first). Invalid command-line arguments make argparse itself
            raise ``SystemExit``; that is not an ``Exception`` subclass and
            therefore passes through the handler below untouched.
    """
    try:
        parser = argparse.ArgumentParser(
            prog="StableDiffusion",
            description="Generate picture with the input prompt using DPM++ sampler",
        )
        parser.add_argument(
            "--prompt", type=str, required=False,
            default="masterpiece, best quality, 1girl, (colorful),(delicate eyes and face), volumatic light, ray tracing, bust shot ,extremely detailed CG unity 8k wallpaper,solo,smile,intricate skirt,((flying petal)),(Flowery meadow) sky, cloudy_sky, moonlight, moon, night, (dark theme:1.3), light, fantasy, windy, magic sparks, dark castle,white hair",
            help="the input text prompt",
        )
        parser.add_argument(
            "--text_model_dir", type=str, required=False, default="./models/",
            help="Path to text encoder and tokenizer files",
        )
        parser.add_argument(
            "--unet_model", type=str, required=False,
            default="./models/unet.axmodel",
            help="Path to unet axmodel model",
        )
        parser.add_argument(
            "--vae_decoder_model", type=str, required=False,
            default="./models/vae_decoder.axmodel",
            help="Path to vae decoder axmodel model",
        )
        parser.add_argument(
            "--time_input", type=str, required=False,
            default="./models/time_input_dpmpp_20steps.npy",
            help="Path to time input file",
        )
        # The value is joined with a generated file name by the caller, so it
        # is a directory, not a file path.
        parser.add_argument(
            "--save_dir", type=str, required=False,
            default="./txt2img_output_axe",
            help="Directory where the generated image is saved",
        )
        parser.add_argument(
            "--num_inference_steps", type=int, default=20,
            help="Number of inference steps for DPM++ sampler",
        )
        parser.add_argument(
            "--guidance_scale", type=float, default=7.5,
            help="Guidance scale for CFG",
        )
        parser.add_argument("--seed", type=int, default=None, help="Random seed")
        return parser.parse_args(argv)
    except Exception as e:
        print(f"参数解析失败: {str(e)}")
        traceback.print_exc()
        # `exit()` is injected by the `site` module for interactive use and
        # may be absent; raising SystemExit always works.
        raise SystemExit(1) from e
|
|
|
|
|
def get_embeds(prompt, negative_prompt, tokenizer_dir, text_encoder_dir):
    """Encode the negative and positive prompts, validating the output shape.

    Both prompts are tokenized with a ``CLIPTokenizer`` (padded/truncated to
    77 tokens) and passed through the ``sd15_text_encoder_sim.axmodel`` text
    encoder found in ``text_encoder_dir``.

    Args:
        prompt: Positive prompt text.
        negative_prompt: Negative prompt text.
        tokenizer_dir: Directory handed to ``CLIPTokenizer.from_pretrained``.
        text_encoder_dir: Directory containing
            ``sd15_text_encoder_sim.axmodel``.

    Returns:
        Tuple ``(neg_embeds, pos_embeds)``; each is the first output of the
        text encoder and has been checked to have shape ``(1, 77, 768)``.

    Raises:
        SystemExit: With status 1 on any failure (missing model file,
            tokenizer/encoder error, unexpected embedding shape). The error
            message and traceback are printed before exiting.
    """
    expected_shape = (1, 77, 768)
    try:
        debug_log(f"开始处理提示词: {prompt[:50]}...")
        start_time = time.time()

        tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir)

        # Locate and load the text encoder once; the same session serves
        # both prompts instead of re-loading the model for each of them.
        model_path = os.path.join(text_encoder_dir, "sd15_text_encoder_sim.axmodel")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"文本编码器模型不存在: {model_path}")
        session = axengine.InferenceSession(model_path)

        def process_prompt(prompt_text):
            """Tokenize one prompt and return the text encoder's first output."""
            inputs = tokenizer(
                prompt_text,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt",
            )
            debug_log(f"Tokenizer输出形状: {inputs.input_ids.shape}")

            # The token ids are fed to the encoder as int32.
            input_ids = inputs.input_ids.numpy().astype(np.int32)
            outputs = session.run(None, {"input_ids": input_ids})[0]
            debug_log(f"文本编码器输出形状: {outputs.shape} | dtype: {outputs.dtype}")
            return outputs

        neg_embeds = process_prompt(negative_prompt)
        pos_embeds = process_prompt(prompt)
        debug_log(f"文本编码完成 | 耗时: {(time.time()-start_time):.2f}s")

        if neg_embeds.shape != expected_shape or pos_embeds.shape != expected_shape:
            raise ValueError(f"嵌入形状异常: 负面{neg_embeds.shape}, 正面{pos_embeds.shape}")

        return neg_embeds, pos_embeds
    except Exception as e:
        print(f"获取嵌入失败: {str(e)}")
        traceback.print_exc()
        # `exit()` comes from the `site` module and is not guaranteed to
        # exist; raise SystemExit directly.
        raise SystemExit(1) from e
|
|
|
|
|
def _predict_noise(unet_session, latent, time_emb, text_embeds):
    """Run one UNet forward pass and return its first output."""
    return unet_session.run(None, {
        "sample": latent,
        "/down_blocks.0/resnets.0/act_1/Mul_output_0": time_emb,
        "encoder_hidden_states": text_embeds,
    })[0]


def _run_sampling_loop(unet_session, scheduler, latent, time_data,
                       neg_embeds, pos_embeds, guidance_scale, num_steps):
    """Iterate over the scheduler timesteps with CFG and return the final latent."""
    debug_log("开始采样循环...")
    total_step_time = 0
    timesteps = scheduler.timesteps.numpy().astype(np.int64)
    for step_idx, timestep in enumerate(timesteps):
        step_start = time.time()
        debug_log(f"\n--- 步骤 {step_idx+1}/{num_steps} [ts={timestep}] ---")

        try:
            # Abort early instead of propagating NaNs through later steps.
            if np.isnan(latent).any():
                raise ValueError("潜在变量包含NaN值!")

            # Per-step entry of the time-input file, with a leading axis added.
            time_emb = np.expand_dims(time_data[step_idx], axis=0)
            debug_log(f"时间嵌入形状: {time_emb.shape}")

            debug_log("运行UNET(负面提示)...")
            unet_neg_start = time.time()
            noise_pred_neg = _predict_noise(unet_session, latent, time_emb, neg_embeds)
            debug_log(f"UNET(负面)完成 | 形状: {noise_pred_neg.shape} | 耗时: {(time.time()-unet_neg_start):.2f}s")

            debug_log("运行UNET(正面提示)...")
            unet_pos_start = time.time()
            noise_pred_pos = _predict_noise(unet_session, latent, time_emb, pos_embeds)
            debug_log(f"UNET(正面)完成 | 耗时: {(time.time()-unet_pos_start):.2f}s")

            # Classifier-free guidance: move from the negative prediction
            # towards the positive one by `guidance_scale`.
            debug_log(f"应用CFG指导(scale={guidance_scale})...")
            noise_pred = noise_pred_neg + guidance_scale * (noise_pred_pos - noise_pred_neg)
            debug_log(f"噪声预测范围: [{noise_pred.min():.3f}, {noise_pred.max():.3f}]")

            # The scheduler works on torch tensors; the NPU sessions on numpy.
            debug_log("更新潜在变量...")
            scheduler_start = time.time()
            latent_tensor = scheduler.step(
                model_output=torch.from_numpy(noise_pred),
                timestep=timestep,
                sample=torch.from_numpy(latent),
            ).prev_sample
            debug_log(f"调度器更新完成 | 耗时: {(time.time()-scheduler_start):.2f}s")

            latent = latent_tensor.numpy().astype(np.float32)
            debug_log(f"更新后潜在变量范围: [{latent.min():.3f}, {latent.max():.3f}]")

            step_time = time.time() - step_start
            total_step_time += step_time
            debug_log(f"步骤完成 | 单步耗时: {step_time:.2f}s | 累计耗时: {total_step_time:.2f}s")
        except Exception as e:
            print(f"步骤 {step_idx+1} 执行失败: {str(e)}")
            traceback.print_exc()
            raise SystemExit(1) from e
    return latent


def _decode_latent(vae_decoder, latent):
    """Rescale the latent by 1/0.18215 and return the VAE decoder's first output."""
    debug_log("\n开始VAE解码...")
    vae_start = time.time()
    try:
        latent = latent / 0.18215
        debug_log(f"VAE输入范围: [{latent.min():.3f}, {latent.max():.3f}]")
        image = vae_decoder.run(None, {"latent": latent})[0]
        debug_log(f"VAE输出形状: {image.shape} | 耗时: {(time.time()-vae_start):.2f}s")
        return image
    except Exception as e:
        print(f"VAE解码失败: {str(e)}")
        traceback.print_exc()
        raise SystemExit(1) from e


def _save_image(image, save_dir):
    """Convert the decoder output to 8-bit RGB and save it as a PNG in ``save_dir``."""
    debug_log("保存结果...")
    try:
        # Channels-first batch of one -> channels-last single image.
        image = np.transpose(image, (0, 2, 3, 1)).squeeze(axis=0)
        # Map from [-1, 1] to [0, 1], clipping anything outside.
        image_denorm = np.clip(image / 2 + 0.5, 0, 1)
        image = (image_denorm * 255).round().astype("uint8")
        debug_log(f"图像形状: {image.shape} | dtype: {image.dtype}")

        pil_image = Image.fromarray(image[:, :, :3])
        # The output directory is not guaranteed to exist yet.
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f"{uuid.uuid4()}.png")
        pil_image.save(save_path)
        debug_log(f"图像保存成功: {save_path}")
    except Exception as e:
        print(f"保存失败: {str(e)}")
        traceback.print_exc()
        raise SystemExit(1) from e


def main():
    """Run the text-to-image pipeline end to end and save one PNG.

    Steps, in order: parse arguments, seed torch/numpy, verify the model
    paths exist, build the DPM-Solver++ scheduler, load the UNet and VAE
    decoder sessions, encode the prompts, draw the initial latent, load the
    per-step time input, run the guided sampling loop, decode with the VAE
    and write the image into ``--save_dir``.

    Raises:
        SystemExit: With status 1 when any stage fails; the error message
            and traceback are printed first.
    """
    try:
        debug_log("程序启动")
        args = get_args()
        debug_log(f"参数解析完成 | 随机种子: {args.seed} | 推理步数: {args.num_inference_steps}")

        # Compare against None so that an explicit `--seed 0` is honoured
        # rather than silently replaced by a time-based seed.
        seed = args.seed if args.seed is not None else int(time.time())
        torch.manual_seed(seed)
        np.random.seed(seed)
        debug_log(f"随机种子设置完成: {seed}")

        tokenizer_dir = os.path.join(args.text_model_dir, 'tokenizer')
        text_encoder_dir = os.path.join(args.text_model_dir, 'text_encoder')
        for path in (args.unet_model, args.vae_decoder_model, tokenizer_dir, text_encoder_dir):
            if not os.path.exists(path):
                raise FileNotFoundError(f"模型路径不存在: {path}")

        debug_log("初始化调度器...")
        scheduler_start = time.time()
        scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            algorithm_type="dpmsolver++",
            use_karras_sigmas=True,
        )
        scheduler.set_timesteps(args.num_inference_steps)
        debug_log(f"调度器初始化完成 | 耗时: {(time.time()-scheduler_start):.2f}s")

        debug_log("加载NPU模型...")
        model_load_start = time.time()
        unet_session_main = axengine.InferenceSession(args.unet_model)
        vae_decoder = axengine.InferenceSession(args.vae_decoder_model)
        debug_log(f"模型加载完成 | 总耗时: {(time.time()-model_load_start):.2f}s")
        debug_log(f"UNET输入信息: {[str(inp) for inp in unet_session_main.get_inputs()]}")
        debug_log(f"VAE输入信息: {[str(inp) for inp in vae_decoder.get_inputs()]}")

        embed_start = time.time()
        neg_embeds, pos_embeds = get_embeds(
            args.prompt,
            "sketch, duplicate, ugly...",
            tokenizer_dir,
            text_encoder_dir,
        )
        debug_log(f"提示词处理完成 | 总耗时: {(time.time()-embed_start):.2f}s")

        # Initial noise, scaled by the scheduler's initial sigma.
        latents_shape = [1, 4, 60, 40]
        generator = torch.Generator().manual_seed(seed)
        latent = torch.randn(latents_shape, generator=generator)
        init_scale = scheduler.init_noise_sigma
        latent = latent * init_scale
        debug_log(f"潜在变量初始化 | 形状: {latent.shape} | 缩放系数: {init_scale}")
        latent = latent.numpy().astype(np.float32)
        debug_log(f"潜在变量转换完成 | dtype: {latent.dtype}")

        debug_log(f"加载时间嵌入: {args.time_input}")
        time_data = np.load(args.time_input)
        if len(time_data) < args.num_inference_steps:
            raise ValueError(f"时间嵌入不足: 需要{args.num_inference_steps}, 实际{len(time_data)}")
        time_data = time_data[:args.num_inference_steps]
        debug_log(f"时间嵌入验证通过 | 形状: {time_data.shape}")

        latent = _run_sampling_loop(
            unet_session_main, scheduler, latent, time_data,
            neg_embeds, pos_embeds, args.guidance_scale, args.num_inference_steps,
        )
        image = _decode_latent(vae_decoder, latent)
        _save_image(image, args.save_dir)
    except Exception as e:
        print(f"主流程执行失败: {str(e)}")
        traceback.print_exc()
        # `exit()` is a `site`-module convenience that may be unavailable;
        # raise SystemExit directly.
        raise SystemExit(1) from e
|
|
|
|
|
# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|
|