# SD1.5-LLM8850 / dpm20_infer.py
# Author: LittleMouse
# Commit message: Add File
# Commit: 244baf9
from typing import List, Union
import numpy as np
import axengine
import torch
from PIL import Image
from transformers import CLIPTokenizer, PreTrainedTokenizer
import time
import argparse
import uuid
import os
import traceback
from diffusers import DPMSolverMultistepScheduler
# Logging configuration: switch debug output on/off and choose whether each
# line is prefixed with a local-time timestamp.
DEBUG_MODE = True
LOG_TIMESTAMP = True


def debug_log(msg):
    """Print ``msg`` tagged with ``[DEBUG]`` when DEBUG_MODE is enabled.

    If LOG_TIMESTAMP is also enabled, the line is prefixed with the current
    local time formatted as ``[YYYY-MM-DD HH:MM:SS] ``. Nothing is printed
    when DEBUG_MODE is False.
    """
    if not DEBUG_MODE:
        return
    prefix = ""
    if LOG_TIMESTAMP:
        prefix = f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] "
    print(f"{prefix}[DEBUG] {msg}")
def get_args(argv=None):
    """Parse command-line options for DPM++ Stable Diffusion 1.5 inference.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) reads the real command
            line, exactly as before.

    Returns:
        argparse.Namespace with the prompt, model paths, output directory,
        number of inference steps, CFG guidance scale and seed (``None`` when
        not given).

    Invalid options, including a non-positive ``--num_inference_steps``, are
    reported by argparse, which exits with status 2. Any other unexpected error
    while building the parser is printed with its traceback and the process
    exits with status 1.
    """
    try:
        parser = argparse.ArgumentParser(
            prog="StableDiffusion",
            description="Generate picture with the input prompt using DPM++ sampler"
        )
        parser.add_argument("--prompt", type=str, required=False,
                            default="masterpiece, best quality, 1girl, (colorful),(delicate eyes and face), volumatic light, ray tracing, bust shot ,extremely detailed CG unity 8k wallpaper,solo,smile,intricate skirt,((flying petal)),(Flowery meadow) sky, cloudy_sky, moonlight, moon, night, (dark theme:1.3), light, fantasy, windy, magic sparks, dark castle,white hair",
                            help="the input text prompt")
        parser.add_argument("--text_model_dir", type=str, required=False, default="./models/",
                            help="Path to text encoder and tokenizer files")
        parser.add_argument("--unet_model", type=str, required=False, default="./models/unet.axmodel",
                            help="Path to unet axmodel model")
        parser.add_argument("--vae_decoder_model", type=str, required=False, default="./models/vae_decoder.axmodel",
                            help="Path to vae decoder axmodel model")
        # DPM++ time-embedding file. The default's name refers to a 20-step
        # schedule; presumably it must be regenerated when the step count
        # changes (TODO confirm).
        parser.add_argument("--time_input", type=str, required=False,
                            default="./models/time_input_dpmpp_20steps.npy",
                            help="Path to time input file (needs one entry per inference step)")
        parser.add_argument("--save_dir", type=str, required=False, default="./txt2img_output_axe",
                            help="Directory in which the output image is saved")
        parser.add_argument("--num_inference_steps", type=int, default=20,
                            help="Number of inference steps for DPM++ sampler")
        parser.add_argument("--guidance_scale", type=float, default=7.5, help="Guidance scale for CFG")
        parser.add_argument("--seed", type=int, default=None, help="Random seed")
        args = parser.parse_args(argv)
        # parser.error raises SystemExit, which is not caught by the handler below.
        if args.num_inference_steps < 1:
            parser.error("--num_inference_steps must be a positive integer")
        return args
    except Exception as e:
        print(f"参数解析失败: {str(e)}")
        traceback.print_exc()
        # SystemExit does not depend on the `site`-provided exit() builtin.
        raise SystemExit(1) from e
def get_embeds(prompt, negative_prompt, tokenizer_dir, text_encoder_dir):
    """Encode the negative and positive prompts with the CLIP text encoder.

    The tokenizer is loaded with ``CLIPTokenizer.from_pretrained(tokenizer_dir)``.
    The compiled encoder ``sd15_text_encoder_sim.axmodel`` is loaded once from
    ``text_encoder_dir`` and reused for both prompts. Each prompt is padded or
    truncated to 77 tokens.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt.
        tokenizer_dir: Directory holding the tokenizer files.
        text_encoder_dir: Directory holding the text-encoder axmodel.

    Returns:
        Tuple ``(neg_embeds, pos_embeds)``, each validated to have shape
        ``(1, 77, 768)``.

    On any failure (missing model, wrong output shape, runtime error) the error
    and traceback are printed and the process exits with status 1.
    """
    try:
        debug_log(f"开始处理提示词: {prompt[:50]}...")
        start_time = time.time()
        tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir)

        model_path = os.path.join(text_encoder_dir, "sd15_text_encoder_sim.axmodel")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"文本编码器模型不存在: {model_path}")
        # Create the NPU session once and share it between both prompts
        # instead of reloading the model for every prompt.
        session = axengine.InferenceSession(model_path)

        def process_prompt(prompt_text):
            """Tokenize one prompt to 77 tokens and return the encoder output."""
            inputs = tokenizer(
                prompt_text,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt"
            )
            debug_log(f"Tokenizer输出形状: {inputs.input_ids.shape}")
            # Token ids are handed to the axmodel as an int32 numpy array.
            outputs = session.run(None, {"input_ids": inputs.input_ids.numpy().astype(np.int32)})[0]
            debug_log(f"文本编码器输出形状: {outputs.shape} | dtype: {outputs.dtype}")
            return outputs

        neg_embeds = process_prompt(negative_prompt)
        pos_embeds = process_prompt(prompt)
        debug_log(f"文本编码完成 | 耗时: {(time.time()-start_time):.2f}s")

        # Both embeddings must match the shape the UNet is fed with.
        if neg_embeds.shape != (1, 77, 768) or pos_embeds.shape != (1, 77, 768):
            raise ValueError(f"嵌入形状异常: 负面{neg_embeds.shape}, 正面{pos_embeds.shape}")
        return neg_embeds, pos_embeds
    except Exception as e:
        print(f"获取嵌入失败: {str(e)}")
        traceback.print_exc()
        raise SystemExit(1) from e
def main():
    """Generate one image with Stable Diffusion 1.5 on the NPU and save it as PNG.

    The steps are:
    1. Parse the arguments and seed torch/numpy.
    2. Check that the model paths exist.
    3. Build a DPM++ (``dpmsolver++``, Karras sigmas) scheduler.
    4. Load the UNet and VAE-decoder axmodels and encode the prompts.
    5. Run the classifier-free-guidance sampling loop.
    6. Decode the final latent with the VAE.
    7. Save the image under a random UUID file name in ``args.save_dir``.
       The directory is created if needed.

    Any failure is printed with its traceback and the process exits with
    status 1.
    """
    try:
        debug_log("程序启动")
        args = get_args()
        debug_log(f"参数解析完成 | 随机种子: {args.seed} | 推理步数: {args.num_inference_steps}")

        # Fall back to a time-based seed only when --seed was omitted; a
        # truthiness test would silently replace an explicit seed of 0.
        seed = args.seed if args.seed is not None else int(time.time())
        torch.manual_seed(seed)
        np.random.seed(seed)
        debug_log(f"随机种子设置完成: {seed}")

        # Fail fast if any model file or directory is missing.
        model_paths = [
            args.unet_model,
            args.vae_decoder_model,
            os.path.join(args.text_model_dir, 'tokenizer'),
            os.path.join(args.text_model_dir, 'text_encoder')
        ]
        for path in model_paths:
            if not os.path.exists(path):
                raise FileNotFoundError(f"模型路径不存在: {path}")

        # DPM++ multistep scheduler (scaled-linear betas, Karras sigmas).
        debug_log("初始化调度器...")
        scheduler_start = time.time()
        scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            algorithm_type="dpmsolver++",
            use_karras_sigmas=True
        )
        scheduler.set_timesteps(args.num_inference_steps)
        timesteps = scheduler.timesteps.numpy().astype(np.int64)
        # The loop runs once per scheduler timestep, so this count drives both
        # the time-embedding validation and the progress log.
        num_steps = len(timesteps)
        debug_log(f"调度器初始化完成 | 耗时: {(time.time()-scheduler_start):.2f}s")

        # Load the NPU models.
        debug_log("加载NPU模型...")
        model_load_start = time.time()
        unet_session_main = axengine.InferenceSession(args.unet_model)
        vae_decoder = axengine.InferenceSession(args.vae_decoder_model)
        debug_log(f"模型加载完成 | 总耗时: {(time.time()-model_load_start):.2f}s")
        debug_log(f"UNET输入信息: {[str(inp) for inp in unet_session_main.get_inputs()]}")
        debug_log(f"VAE输入信息: {[str(inp) for inp in vae_decoder.get_inputs()]}")

        # Encode the prompts.
        embed_start = time.time()
        neg_embeds, pos_embeds = get_embeds(
            args.prompt,
            "sketch, duplicate, ugly...",  # shortened negative prompt
            os.path.join(args.text_model_dir, 'tokenizer'),
            os.path.join(args.text_model_dir, 'text_encoder')
        )
        debug_log(f"提示词处理完成 | 总耗时: {(time.time()-embed_start):.2f}s")

        # Initial latent: seeded Gaussian noise scaled by the scheduler's
        # initial sigma. The fixed shape is presumably what the UNet/VAE
        # axmodels were compiled for (TODO confirm).
        latents_shape = [1, 4, 60, 40]
        generator = torch.Generator().manual_seed(seed)
        latent = torch.randn(latents_shape, generator=generator)
        init_scale = scheduler.init_noise_sigma
        latent = latent * init_scale
        debug_log(f"潜在变量初始化 | 形状: {latent.shape} | 缩放系数: {init_scale}")
        latent = latent.numpy().astype(np.float32)  # NPU sessions are fed float32 numpy arrays
        debug_log(f"潜在变量转换完成 | dtype: {latent.dtype}")

        # Precomputed time embeddings: one row is consumed per scheduler timestep.
        # NOTE(review): the file presumably matches this exact DPM++ schedule;
        # confirm it was generated for the chosen step count.
        debug_log(f"加载时间嵌入: {args.time_input}")
        time_data = np.load(args.time_input)
        if len(time_data) < num_steps:
            raise ValueError(f"时间嵌入不足: 需要{num_steps}, 实际{len(time_data)}")
        time_data = time_data[:num_steps]
        debug_log(f"时间嵌入验证通过 | 形状: {time_data.shape}")

        # Sampling loop.
        debug_log("开始采样循环...")
        total_sampling_time = 0.0
        for step_idx, timestep in enumerate(timesteps):
            step_start = time.time()
            debug_log(f"\n--- 步骤 {step_idx+1}/{num_steps} [ts={timestep}] ---")
            try:
                if np.isnan(latent).any():
                    raise ValueError("潜在变量包含NaN值!")

                # Add a leading batch axis to this step's time embedding.
                time_emb = np.expand_dims(time_data[step_idx], axis=0)
                debug_log(f"时间嵌入形状: {time_emb.shape}")

                # Two UNet passes per step, one with each prompt embedding.
                # NOTE(review): the time embedding is fed under an internal
                # tensor name, so the exported UNet graph presumably starts
                # below the time-embedding layers; confirm against the export.
                debug_log("运行UNET(负面提示)...")
                unet_neg_start = time.time()
                noise_pred_neg = unet_session_main.run(None, {
                    "sample": latent,
                    "/down_blocks.0/resnets.0/act_1/Mul_output_0": time_emb,
                    "encoder_hidden_states": neg_embeds
                })[0]
                debug_log(f"UNET(负面)完成 | 形状: {noise_pred_neg.shape} | 耗时: {(time.time()-unet_neg_start):.2f}s")

                debug_log("运行UNET(正面提示)...")
                unet_pos_start = time.time()
                noise_pred_pos = unet_session_main.run(None, {
                    "sample": latent,
                    "/down_blocks.0/resnets.0/act_1/Mul_output_0": time_emb,
                    "encoder_hidden_states": pos_embeds
                })[0]
                debug_log(f"UNET(正面)完成 | 耗时: {(time.time()-unet_pos_start):.2f}s")

                # Classifier-free guidance: move from the negative-prompt
                # prediction towards the positive one by guidance_scale.
                debug_log(f"应用CFG指导(scale={args.guidance_scale})...")
                noise_pred = noise_pred_neg + args.guidance_scale * (noise_pred_pos - noise_pred_neg)
                debug_log(f"噪声预测范围: [{noise_pred.min():.3f}, {noise_pred.max():.3f}]")

                # The diffusers scheduler works on torch tensors.
                latent_tensor = torch.from_numpy(latent)
                noise_pred_tensor = torch.from_numpy(noise_pred)

                debug_log("更新潜在变量...")
                scheduler_start = time.time()
                latent_tensor = scheduler.step(
                    model_output=noise_pred_tensor,
                    timestep=timestep,
                    sample=latent_tensor
                ).prev_sample
                debug_log(f"调度器更新完成 | 耗时: {(time.time()-scheduler_start):.2f}s")

                latent = latent_tensor.numpy().astype(np.float32)
                debug_log(f"更新后潜在变量范围: [{latent.min():.3f}, {latent.max():.3f}]")

                # Whole-step wall time (UNet passes + scheduler update).
                step_time = time.time() - step_start
                total_sampling_time += step_time
                debug_log(f"步骤完成 | 单步耗时: {step_time:.2f}s | 累计耗时: {total_sampling_time:.2f}s")
            except Exception as e:
                print(f"步骤 {step_idx+1} 执行失败: {str(e)}")
                traceback.print_exc()
                raise SystemExit(1) from e

        # VAE decoding.
        debug_log("\n开始VAE解码...")
        vae_start = time.time()
        try:
            # Undo the Stable Diffusion latent scaling factor before decoding.
            latent = latent / 0.18215
            debug_log(f"VAE输入范围: [{latent.min():.3f}, {latent.max():.3f}]")
            image = vae_decoder.run(None, {"latent": latent})[0]
            debug_log(f"VAE输出形状: {image.shape} | 耗时: {(time.time()-vae_start):.2f}s")
        except Exception as e:
            print(f"VAE解码失败: {str(e)}")
            traceback.print_exc()
            raise SystemExit(1) from e

        # Save the result.
        debug_log("保存结果...")
        try:
            # (1, C, H, W) -> (H, W, C); then map the decoder output
            # (presumably in [-1, 1]) to [0, 255] uint8.
            image = np.transpose(image, (0, 2, 3, 1)).squeeze(axis=0)
            image_denorm = np.clip(image / 2 + 0.5, 0, 1)
            image = (image_denorm * 255).round().astype("uint8")
            debug_log(f"图像形状: {image.shape} | dtype: {image.dtype}")
            pil_image = Image.fromarray(image[:, :, :3])
            # Saving into a missing directory fails, so create it first.
            os.makedirs(args.save_dir, exist_ok=True)
            save_path = os.path.join(args.save_dir, f"{uuid.uuid4()}.png")
            pil_image.save(save_path)
            debug_log(f"图像保存成功: {save_path}")
        except Exception as e:
            print(f"保存失败: {str(e)}")
            traceback.print_exc()
            raise SystemExit(1) from e
    except Exception as e:
        print(f"主流程执行失败: {str(e)}")
        traceback.print_exc()
        raise SystemExit(1) from e
if __name__ == "__main__":
    # Run the generation pipeline only when executed as a script, not on import.
    main()