File size: 12,507 Bytes
244baf9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 |
import argparse
import os
import sys
import time
import traceback
import uuid
from typing import List, Union

import numpy as np
import axengine
import torch
from PIL import Image
from transformers import CLIPTokenizer, PreTrainedTokenizer
from diffusers import DPMSolverMultistepScheduler
# Logging configuration: DEBUG_MODE switches debug output on/off entirely,
# LOG_TIMESTAMP controls whether each debug line carries a local-time prefix.
DEBUG_MODE = True
LOG_TIMESTAMP = True


def debug_log(msg):
    """Print ``msg`` as a ``[DEBUG]`` line; do nothing when DEBUG_MODE is off.

    When LOG_TIMESTAMP is true the line is prefixed with the current local
    time formatted as ``[YYYY-mm-dd HH:MM:SS] ``.
    """
    if not DEBUG_MODE:
        return
    if LOG_TIMESTAMP:
        stamp = f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] "
    else:
        stamp = ""
    print(f"{stamp}[DEBUG] {msg}")
def get_args(argv=None):
    """Parse the command-line options of the DPM++ txt2img pipeline.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps normal
            command-line behavior.

    Returns:
        argparse.Namespace holding the parsed options.

    Invalid command-line arguments are handled by argparse itself (it prints
    usage and exits with status 2). Any other unexpected error while
    building the parser is printed with a traceback and the process exits
    with status 1.
    """
    try:
        parser = argparse.ArgumentParser(
            prog="StableDiffusion",
            description="Generate picture with the input prompt using DPM++ sampler",
        )
        parser.add_argument(
            "--prompt", type=str, required=False,
            default="masterpiece, best quality, 1girl, (colorful),(delicate eyes and face), volumatic light, ray tracing, bust shot ,extremely detailed CG unity 8k wallpaper,solo,smile,intricate skirt,((flying petal)),(Flowery meadow) sky, cloudy_sky, moonlight, moon, night, (dark theme:1.3), light, fantasy, windy, magic sparks, dark castle,white hair",
            help="the input text prompt",
        )
        parser.add_argument("--text_model_dir", type=str, required=False, default="./models/",
                            help="Path to text encoder and tokenizer files")
        parser.add_argument("--unet_model", type=str, required=False, default="./models/unet.axmodel",
                            help="Path to unet axmodel model")
        parser.add_argument("--vae_decoder_model", type=str, required=False, default="./models/vae_decoder.axmodel",
                            help="Path to vae decoder axmodel model")
        # Precomputed DPM++ time embeddings; the default file name indicates
        # it was generated for 20 steps (the --num_inference_steps default).
        parser.add_argument("--time_input", type=str, required=False,
                            default="./models/time_input_dpmpp_20steps.npy",
                            help="Path to time input file")
        # This is a directory: the image file name is generated at save time.
        parser.add_argument("--save_dir", type=str, required=False, default="./txt2img_output_axe",
                            help="Directory to save the output image")
        parser.add_argument("--num_inference_steps", type=int, default=20,
                            help="Number of inference steps for DPM++ sampler")
        parser.add_argument("--guidance_scale", type=float, default=7.5, help="Guidance scale for CFG")
        parser.add_argument("--seed", type=int, default=None, help="Random seed")
        return parser.parse_args(argv)
    except Exception as e:
        print(f"参数解析失败: {str(e)}")
        traceback.print_exc()
        sys.exit(1)
def get_embeds(prompt, negative_prompt, tokenizer_dir, text_encoder_dir):
    """Encode the negative and positive prompts with the CLIP text encoder.

    Both prompts are tokenized with padding/truncation to 77 tokens and run
    through ``sd15_text_encoder_sim.axmodel`` found in ``text_encoder_dir``.
    The encoder session is created once and reused for both prompts.

    Args:
        prompt: Positive prompt text.
        negative_prompt: Negative prompt text.
        tokenizer_dir: Directory passed to ``CLIPTokenizer.from_pretrained``.
        text_encoder_dir: Directory containing the text-encoder axmodel.

    Returns:
        ``(neg_embeds, pos_embeds)``, the first output of the encoder for each
        prompt, each validated to have shape ``(1, 77, 768)``.

    On any failure (missing model, shape mismatch, runtime error) the error
    and traceback are printed and the process exits with status 1.
    """
    try:
        debug_log(f"开始处理提示词: {prompt[:50]}...")
        start_time = time.time()
        tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir)

        model_path = os.path.join(text_encoder_dir, "sd15_text_encoder_sim.axmodel")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"文本编码器模型不存在: {model_path}")
        # Load the NPU text encoder once and share it between both prompts
        # instead of creating a new session per prompt.
        session = axengine.InferenceSession(model_path)

        def process_prompt(prompt_text):
            """Tokenize one prompt and return the encoder's first output."""
            inputs = tokenizer(
                prompt_text,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt",
            )
            debug_log(f"Tokenizer输出形状: {inputs.input_ids.shape}")
            # The compiled encoder is fed int32 token ids.
            outputs = session.run(None, {"input_ids": inputs.input_ids.numpy().astype(np.int32)})[0]
            debug_log(f"文本编码器输出形状: {outputs.shape} | dtype: {outputs.dtype}")
            return outputs

        neg_embeds = process_prompt(negative_prompt)
        pos_embeds = process_prompt(prompt)
        debug_log(f"文本编码完成 | 耗时: {(time.time()-start_time):.2f}s")

        # Validate shapes before they reach the UNet.
        if neg_embeds.shape != (1, 77, 768) or pos_embeds.shape != (1, 77, 768):
            raise ValueError(f"嵌入形状异常: 负面{neg_embeds.shape}, 正面{pos_embeds.shape}")
        return neg_embeds, pos_embeds
    except Exception as e:
        print(f"获取嵌入失败: {str(e)}")
        traceback.print_exc()
        sys.exit(1)
def _create_scheduler(num_inference_steps):
    """Build a DPM++ multistep scheduler (Karras sigmas) set to *num_inference_steps*."""
    debug_log("初始化调度器...")
    scheduler_start = time.time()
    scheduler = DPMSolverMultistepScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        algorithm_type="dpmsolver++",
        use_karras_sigmas=True,
    )
    scheduler.set_timesteps(num_inference_steps)
    debug_log(f"调度器初始化完成 | 耗时: {(time.time()-scheduler_start):.2f}s")
    return scheduler


def _predict_noise_cfg(unet_session, latent, time_emb, neg_embeds, pos_embeds, guidance_scale):
    """Run the UNet for both prompt embeddings and combine them with classifier-free guidance.

    Returns ``neg + guidance_scale * (pos - neg)`` of the two UNet outputs.
    """
    # NOTE(review): the compiled UNet receives the precomputed time embedding
    # at this internal node name rather than a raw timestep; presumably the
    # graph was cut there at export time — confirm against the model.
    time_input_name = "/down_blocks.0/resnets.0/act_1/Mul_output_0"

    debug_log("运行UNET(负面提示)...")
    unet_neg_start = time.time()
    noise_pred_neg = unet_session.run(None, {
        "sample": latent,
        time_input_name: time_emb,
        "encoder_hidden_states": neg_embeds,
    })[0]
    debug_log(f"UNET(负面)完成 | 形状: {noise_pred_neg.shape} | 耗时: {(time.time()-unet_neg_start):.2f}s")

    debug_log("运行UNET(正面提示)...")
    unet_pos_start = time.time()
    noise_pred_pos = unet_session.run(None, {
        "sample": latent,
        time_input_name: time_emb,
        "encoder_hidden_states": pos_embeds,
    })[0]
    debug_log(f"UNET(正面)完成 | 耗时: {(time.time()-unet_pos_start):.2f}s")

    debug_log(f"应用CFG指导(scale={guidance_scale})...")
    noise_pred = noise_pred_neg + guidance_scale * (noise_pred_pos - noise_pred_neg)
    debug_log(f"噪声预测范围: [{noise_pred.min():.3f}, {noise_pred.max():.3f}]")
    return noise_pred


def _vae_output_to_uint8(image):
    """Convert a batch-of-one NCHW VAE output to an HWC uint8 array.

    Values are mapped from [-1, 1] to [0, 255]; anything outside that range
    is clipped.
    """
    image = np.transpose(image, (0, 2, 3, 1)).squeeze(axis=0)
    image_denorm = np.clip(image / 2 + 0.5, 0, 1)
    return (image_denorm * 255).round().astype("uint8")


def main():
    """Run the txt2img pipeline end to end and save one PNG into ``--save_dir``.

    Steps: parse args, seed RNGs, validate model paths, build the DPM++
    scheduler, load the UNet/VAE NPU sessions, encode prompts, run the CFG
    denoising loop, decode with the VAE, and write ``<uuid4>.png``.

    Any failure is printed with a traceback and the process exits with
    status 1.
    """
    try:
        debug_log("程序启动")
        args = get_args()
        debug_log(f"参数解析完成 | 随机种子: {args.seed} | 推理步数: {args.num_inference_steps}")

        # Compare against None so that an explicit `--seed 0` is honored
        # instead of being replaced by the current time.
        seed = args.seed if args.seed is not None else int(time.time())
        torch.manual_seed(seed)
        np.random.seed(seed)
        debug_log(f"随机种子设置完成: {seed}")

        # Validate model paths up front so failures happen before slow loads.
        tokenizer_dir = os.path.join(args.text_model_dir, 'tokenizer')
        text_encoder_dir = os.path.join(args.text_model_dir, 'text_encoder')
        for path in (args.unet_model, args.vae_decoder_model, tokenizer_dir, text_encoder_dir):
            if not os.path.exists(path):
                raise FileNotFoundError(f"模型路径不存在: {path}")

        scheduler = _create_scheduler(args.num_inference_steps)

        # Load the NPU models.
        debug_log("加载NPU模型...")
        model_load_start = time.time()
        unet_session_main = axengine.InferenceSession(args.unet_model)
        vae_decoder = axengine.InferenceSession(args.vae_decoder_model)
        debug_log(f"模型加载完成 | 总耗时: {(time.time()-model_load_start):.2f}s")
        debug_log(f"UNET输入信息: {[str(inp) for inp in unet_session_main.get_inputs()]}")
        debug_log(f"VAE输入信息: {[str(inp) for inp in vae_decoder.get_inputs()]}")

        # Encode prompts.
        embed_start = time.time()
        neg_embeds, pos_embeds = get_embeds(
            args.prompt,
            "sketch, duplicate, ugly...",  # simplified negative prompt
            tokenizer_dir,
            text_encoder_dir,
        )
        debug_log(f"提示词处理完成 | 总耗时: {(time.time()-embed_start):.2f}s")

        # Initial latent noise, scaled by the scheduler's initial sigma.
        # NOTE(review): the fixed shape presumably matches the shape the UNet
        # axmodel was compiled for — confirm before changing it.
        latents_shape = [1, 4, 60, 40]
        generator = torch.Generator().manual_seed(seed)
        latent = torch.randn(latents_shape, generator=generator)
        init_scale = scheduler.init_noise_sigma
        latent = latent * init_scale
        debug_log(f"潜在变量初始化 | 形状: {latent.shape} | 缩放系数: {init_scale}")
        latent = latent.numpy().astype(np.float32)  # NPU session is fed float32 numpy
        debug_log(f"潜在变量转换完成 | dtype: {latent.dtype}")

        # Load precomputed time embeddings; row i is used at scheduler step i.
        # NOTE(review): the file presumably has to be generated for this exact
        # scheduler config and step count; truncating a file made for more
        # steps would likely pair embeddings with the wrong timesteps.
        debug_log(f"加载时间嵌入: {args.time_input}")
        time_data = np.load(args.time_input)
        if len(time_data) < args.num_inference_steps:
            raise ValueError(f"时间嵌入不足: 需要{args.num_inference_steps}, 实际{len(time_data)}")
        time_data = time_data[:args.num_inference_steps]
        debug_log(f"时间嵌入验证通过 | 形状: {time_data.shape}")

        # Denoising loop.
        debug_log("开始采样循环...")
        total_step_time = 0  # wall time of all completed steps (UNet + scheduler)
        for step_idx, timestep in enumerate(scheduler.timesteps.numpy().astype(np.int64)):
            step_start = time.time()
            debug_log(f"\n--- 步骤 {step_idx+1}/{args.num_inference_steps} [ts={timestep}] ---")
            try:
                if np.isnan(latent).any():
                    raise ValueError("潜在变量包含NaN值!")

                time_emb = np.expand_dims(time_data[step_idx], axis=0)  # add batch dim
                debug_log(f"时间嵌入形状: {time_emb.shape}")

                noise_pred = _predict_noise_cfg(
                    unet_session_main, latent, time_emb, neg_embeds, pos_embeds, args.guidance_scale
                )

                debug_log("更新潜在变量...")
                scheduler_start = time.time()
                latent_tensor = scheduler.step(
                    model_output=torch.from_numpy(noise_pred),
                    timestep=timestep,
                    sample=torch.from_numpy(latent),
                ).prev_sample
                debug_log(f"调度器更新完成 | 耗时: {(time.time()-scheduler_start):.2f}s")

                latent = latent_tensor.numpy().astype(np.float32)
                debug_log(f"更新后潜在变量范围: [{latent.min():.3f}, {latent.max():.3f}]")

                step_time = time.time() - step_start
                total_step_time += step_time
                debug_log(f"步骤完成 | 单步耗时: {step_time:.2f}s | 累计耗时: {total_step_time:.2f}s")
            except Exception as e:
                print(f"步骤 {step_idx+1} 执行失败: {str(e)}")
                traceback.print_exc()
                sys.exit(1)

        # VAE decode. 0.18215 is the latent scaling factor used by Stable
        # Diffusion v1 VAEs; it is undone before decoding.
        debug_log("\n开始VAE解码...")
        vae_start = time.time()
        try:
            latent = latent / 0.18215
            debug_log(f"VAE输入范围: [{latent.min():.3f}, {latent.max():.3f}]")
            image = vae_decoder.run(None, {"latent": latent})[0]
            debug_log(f"VAE输出形状: {image.shape} | 耗时: {(time.time()-vae_start):.2f}s")
        except Exception as e:
            print(f"VAE解码失败: {str(e)}")
            traceback.print_exc()
            sys.exit(1)

        # Save the result.
        debug_log("保存结果...")
        try:
            image = _vae_output_to_uint8(image)
            debug_log(f"图像形状: {image.shape} | dtype: {image.dtype}")
            pil_image = Image.fromarray(image[:, :, :3])
            # Create the output directory on demand; saving into a missing
            # directory would otherwise fail.
            os.makedirs(args.save_dir, exist_ok=True)
            save_path = os.path.join(args.save_dir, f"{uuid.uuid4()}.png")
            pil_image.save(save_path)
            debug_log(f"图像保存成功: {save_path}")
        except Exception as e:
            print(f"保存失败: {str(e)}")
            traceback.print_exc()
            sys.exit(1)
    except Exception as e:
        print(f"主流程执行失败: {str(e)}")
        traceback.print_exc()
        sys.exit(1)
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not on import.
    main()
|