File size: 11,288 Bytes
244baf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
import argparse
import io
import os
import threading
import time
import traceback
import uuid
from contextlib import asynccontextmanager
from typing import List, Union

import axengine
import numpy as np
import torch
from diffusers import DPMSolverMultistepScheduler
from fastapi import FastAPI, HTTPException, Response
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from pydantic import BaseModel
from transformers import CLIPTokenizer, PreTrainedTokenizer
# Logging configuration: master switch for debug output, and whether each
# debug line is prefixed with a local-time timestamp.
DEBUG_MODE = True
LOG_TIMESTAMP = True

def debug_log(msg):
    """Print *msg* tagged with ``[DEBUG]`` when DEBUG_MODE is enabled.

    If LOG_TIMESTAMP is set, the line is prefixed with the current local time
    as ``[YYYY-MM-DD HH:MM:SS] ``. Does nothing when DEBUG_MODE is off.
    """
    if not DEBUG_MODE:
        return
    prefix = ""
    if LOG_TIMESTAMP:
        prefix = "[" + time.strftime('%Y-%m-%d %H:%M:%S') + "] "
    print(prefix + "[DEBUG] " + f"{msg}")
        
# Service configuration: on-disk location of every model artifact the service
# loads at startup (paths are relative to the working directory).
MODEL_PATHS = dict(
    tokenizer="./models/tokenizer",
    text_encoder="./models/text_encoder/sd15_text_encoder_sim.axmodel",
    unet="./models/unet.axmodel",
    vae="./models/vae_decoder.axmodel",
    time_embeddings="./models/time_input_dpmpp_20steps.npy",
)

class DiffusionModels:
    """Holder for the tokenizer, axengine sessions and time embeddings.

    All fields start as ``None`` and ``models_loaded`` as False;
    ``load_models`` fills them from ``MODEL_PATHS``.
    """

    def __init__(self):
        # Nothing is loaded until load_models() is called.
        self.models_loaded = False
        self.tokenizer = self.text_encoder = self.unet = self.vae = self.time_embeddings = None

    def load_models(self):
        """Load every model listed in MODEL_PATHS into memory.

        Each attribute is assigned as soon as its model loads, so a failure
        part-way leaves the earlier ones set. On failure a message is printed
        and the original exception is re-raised.
        """
        try:
            self.tokenizer = CLIPTokenizer.from_pretrained(MODEL_PATHS["tokenizer"])
            # The three compiled networks share one loader; load in order.
            for key in ("text_encoder", "unet", "vae"):
                setattr(self, key, axengine.InferenceSession(MODEL_PATHS[key]))
            self.time_embeddings = np.load(MODEL_PATHS["time_embeddings"])
            self.models_loaded = True
            print("所有模型已成功加载到内存")
        except Exception as e:
            print(f"模型加载失败: {str(e)}")
            raise

# Process-wide model container shared by every request handler.
diffusion_models = DiffusionModels()

@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load all models before requests are served.

    Code before ``yield`` runs at startup and code after it at shutdown;
    no shutdown cleanup is performed yet.
    """
    # Startup: populate the shared model container (raises on failure).
    models = diffusion_models
    models.load_models()
    # The application serves requests while suspended here.
    yield
    # Shutdown: TODO add whatever teardown axengine sessions require.

app = FastAPI(lifespan=lifespan)

class GenerationRequest(BaseModel):
    """JSON body accepted by ``POST /generate``."""

    positive_prompt: str
    negative_prompt: str = ""
    num_inference_steps: int = 20
    guidance_scale: float = 7.5
    # Optional seed: omit it or send null to get a time-derived seed.
    # Declared nullable explicitly: under pydantic v2 a bare ``int`` annotation
    # with a None default rejects an explicit JSON null and publishes a
    # non-nullable integer in the OpenAPI schema.
    seed: Union[int, None] = None

# Serializes inference across requests. The axengine sessions in
# `diffusion_models` are shared by every request and their thread-safety is
# not established here (NOTE(review): confirm), and the pipeline may write a
# fixed-path debug image. One generation at a time keeps the previous
# behavior (the old blocking handler also ran one request at a time) while
# leaving the event loop free.
_INFERENCE_LOCK = threading.Lock()


def _generate_png_bytes(**generation_kwargs) -> bytes:
    """Run generate_diffusion_image under the inference lock; return PNG bytes."""
    with _INFERENCE_LOCK:
        image = generate_diffusion_image(**generation_kwargs)
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    return buffer.getvalue()


@app.post("/generate")
async def generate_image(request: GenerationRequest):
    """Generate an image for *request* and return it as ``image/png``.

    Responses:
        400: invalid request (prompt longer than 1000 characters, or a
            ValueError raised by the pipeline's argument validation).
        500: generation failed; the detail carries an error ID that is also
            printed to the server log.
    """
    # Input validation: a client error is reported as 400, not 500.
    if len(request.positive_prompt) > 1000:
        raise HTTPException(status_code=400, detail="提示词过长")

    try:
        # The pipeline is synchronous and long-running; running it in a
        # worker thread stops this async endpoint from blocking the event loop.
        png_bytes = await run_in_threadpool(
            _generate_png_bytes,
            positive_prompt=request.positive_prompt,
            negative_prompt=request.negative_prompt,
            num_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            seed=request.seed,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        error_id = str(uuid.uuid4())
        print(f"Error [{error_id}]: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"生成失败,错误ID:{error_id}"
        ) from e

    return Response(content=png_bytes, media_type="image/png")
        
        
        
def get_embeds(prompt, negative_prompt):
    """Encode the positive and negative prompts with the CLIP text encoder.

    Args:
        prompt: Positive prompt text.
        negative_prompt: Negative prompt text.

    Returns:
        tuple: ``(neg_embeds, pos_embeds)``. The NEGATIVE embedding comes
        first. Each is the text encoder's first output and is checked to have
        shape (1, 77, 768).

    Raises:
        RuntimeError: Tokenization or encoding failed, or an embedding had an
            unexpected shape. The original exception is chained as __cause__.
    """

    def process_prompt(prompt_text):
        # Pad/truncate to a fixed 77 tokens so the compiled text encoder
        # always receives the same input shape.
        inputs = diffusion_models.tokenizer(
            prompt_text,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt"
        )
        debug_log(f"Tokenizer输出形状: {inputs.input_ids.shape}")

        outputs = diffusion_models.text_encoder.run(None, {"input_ids": inputs.input_ids.numpy().astype(np.int32)})[0]
        debug_log(f"文本编码器输出形状: {outputs.shape} | dtype: {outputs.dtype}")
        return outputs

    try:
        debug_log(f"开始处理提示词: {prompt[:50]}...")
        start_time = time.time()

        neg_embeds = process_prompt(negative_prompt)
        pos_embeds = process_prompt(prompt)
        debug_log(f"文本编码完成 | 耗时: {(time.time()-start_time):.2f}s")

        # Shape check: the downstream UNet is compiled for this exact shape.
        expected_shape = (1, 77, 768)
        if neg_embeds.shape != expected_shape or pos_embeds.shape != expected_shape:
            raise ValueError(f"嵌入形状异常: 负面{neg_embeds.shape}, 正面{pos_embeds.shape}")

        return neg_embeds, pos_embeds
    except Exception as e:
        print(f"获取嵌入失败: {str(e)}")
        traceback.print_exc()
        # Raise instead of exit(1): exiting from inside a request handler
        # would try to terminate the whole server process instead of failing
        # just this request.
        raise RuntimeError(f"获取嵌入失败: {str(e)}") from e


def generate_diffusion_image(
    positive_prompt: str,
    negative_prompt: str,
    num_steps: int = 20,
    guidance_scale: float = 7.5,
    seed: int = None
) -> Image.Image:
    """Generate one image with the axengine Stable Diffusion pipeline.

    Input shapes fed to the compiled models are kept fixed (text embeddings
    (1, 77, 768), latents (1, 4, 60, 40)); only the values vary per call.

    Args:
        positive_prompt (str): Prompt to steer towards; must be non-empty.
        negative_prompt (str): Prompt for the negative branch of
            classifier-free guidance.
        num_steps (int): Number of denoising steps (default 20). Must be >= 1
            and no larger than the number of preloaded time embeddings.
        guidance_scale (float): Classifier-free guidance scale in
            [1.0, 20.0] (default 7.5).
        seed (int | None): RNG seed in [0, 0xFFFFFFFF]. When None, a seed is
            derived from the current time in milliseconds.

    Returns:
        PIL.Image.Image: The decoded image (first three channels of the VAE
        output).

    Raises:
        ValueError: An argument is invalid. Raised before any inference runs
            and not wrapped, so callers can tell bad input from failures.
        RuntimeError: The models are not loaded, or encoding, sampling or
            decoding failed (the original exception is chained as __cause__).
    """
    # =====================================================================
    # 0. Argument validation. This is deliberately outside the try below so a
    #    ValueError reaches the caller as documented, instead of being
    #    re-wrapped as RuntimeError by the generic handler.
    # =====================================================================
    if not positive_prompt:
        raise ValueError("正向提示词不能为空")
    # The chained comparison also rejects NaN, which `x < 1 or x > 20` let through.
    if not 1.0 <= guidance_scale <= 20.0:
        raise ValueError("引导系数需在1.0-20.0之间")
    if num_steps < 1:
        raise ValueError(f"推理步数必须至少为1: {num_steps}")
    # np.random.seed only accepts seeds in [0, 2**32 - 1].
    if seed is not None and not 0 <= seed <= 0xFFFFFFFF:
        raise ValueError(f"随机种子需在0-{0xFFFFFFFF}之间: {seed}")

    time_embeddings = diffusion_models.time_embeddings
    if time_embeddings is None:
        raise RuntimeError("生成失败: 模型尚未加载")
    if len(time_embeddings) < num_steps:
        raise ValueError(f"时间嵌入不足: 需要{num_steps}步 当前加载{len(time_embeddings)}步")

    try:
        debug_log("开始生成流程...")
        start_time = time.time()

        # =================================================================
        # 1. Seeding. The torch/numpy global RNGs are seeded too; the initial
        #    latent below uses its own explicitly seeded generator.
        # =================================================================
        if seed is None:
            seed = int(time.time() * 1000) % 0xFFFFFFFF
        torch.manual_seed(seed)
        np.random.seed(seed)
        debug_log(f"初始随机种子: {seed}")

        # =================================================================
        # 2. Text encoding (get_embeds returns negative first, then positive).
        # =================================================================
        embed_start = time.time()
        neg_emb, pos_emb = get_embeds(
            positive_prompt,
            negative_prompt,
        )
        debug_log(f"文本编码完成 | 耗时: {time.time()-embed_start:.2f}s")

        # =================================================================
        # 3. Scheduler and initial latent (fixed shape (1, 4, 60, 40)).
        # =================================================================
        scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            algorithm_type="dpmsolver++",
            use_karras_sigmas=True
        )
        scheduler.set_timesteps(num_steps)

        latents_shape = (1, 4, 60, 40)
        latent = torch.randn(latents_shape, generator=torch.Generator().manual_seed(seed))
        latent = latent * scheduler.init_noise_sigma
        latent = latent.numpy().astype(np.float32)
        debug_log(f"潜在变量初始化 | 形状: {latent.shape} sigma:{scheduler.init_noise_sigma:.3f}")

        # =================================================================
        # 4. Per-step time embeddings (positional slice of the preloaded array).
        # NOTE(review): the file name suggests these were precomputed for a
        # 20-step DPM++ schedule. If so, for num_steps < 20 the first
        # num_steps entries do not match scheduler.timesteps of a num_steps
        # schedule. TODO confirm, and consider requiring an exact match.
        # =================================================================
        time_steps = time_embeddings[:num_steps]

        # =================================================================
        # 5. Sampling loop (input shapes stay fixed).
        # =================================================================
        debug_log("开始采样循环...")
        for step_idx, timestep in enumerate(scheduler.timesteps.numpy().astype(np.int64)):
            step_start = time.time()

            # Add a leading batch axis to this step's embedding.
            # NOTE(review): the original comment states shape [1, 1]; TODO confirm.
            time_emb = np.expand_dims(time_steps[step_idx], axis=0)

            # Two UNet passes, negative then positive. The compiled UNet takes
            # the time embedding at an internal activation input.
            noise_pred_neg = diffusion_models.unet.run(None, {
                "sample": latent,
                "/down_blocks.0/resnets.0/act_1/Mul_output_0": time_emb,
                "encoder_hidden_states": neg_emb
            })[0]

            noise_pred_pos = diffusion_models.unet.run(None, {
                "sample": latent,
                "/down_blocks.0/resnets.0/act_1/Mul_output_0": time_emb,
                "encoder_hidden_states": pos_emb
            })[0]

            # Classifier-free guidance: extrapolate from the negative
            # prediction towards the positive one by guidance_scale.
            noise_pred = noise_pred_neg + guidance_scale * (noise_pred_pos - noise_pred_neg)

            # The diffusers scheduler operates on torch tensors.
            latent_tensor = torch.from_numpy(latent)
            noise_pred_tensor = torch.from_numpy(noise_pred)

            debug_log("更新潜在变量...")
            scheduler_start = time.time()
            latent_tensor = scheduler.step(
                model_output=noise_pred_tensor,
                timestep=timestep,
                sample=latent_tensor
            ).prev_sample
            debug_log(f"调度器更新完成 | 耗时: {(time.time()-scheduler_start):.2f}s")

            # Back to float32 numpy for the next axengine call.
            latent = latent_tensor.numpy().astype(np.float32)
            debug_log(f"更新后潜在变量范围: [{latent.min():.3f}, {latent.max():.3f}]")

            debug_log(f"步骤 {step_idx+1}/{num_steps} | 耗时: {time.time()-step_start:.2f}s")

        # =================================================================
        # 6. VAE decode.
        # NOTE(review): the original comment claimed a 768x512 output, but a
        # standard SD VAE upsamples 8x, which would make a (60, 40) latent
        # 480x320. TODO confirm against the compiled VAE.
        # =================================================================
        debug_log("开始VAE解码...")
        # Undo the latent scaling; 0.18215 is presumably the SD 1.x VAE factor.
        latent = latent / 0.18215
        image = diffusion_models.vae.run(None, {"latent": latent})[0]

        # Drop size-1 axes and move channels last. This assumes the output is
        # (1, C, H, W); TODO confirm. Then map [-1, 1] to [0, 255] and clip.
        image = np.transpose(image.squeeze(), (1, 2, 0))
        image = np.clip((image / 2 + 0.5) * 255, 0, 255).astype(np.uint8)
        pil_image = Image.fromarray(image[..., :3])  # keep the first three channels
        if DEBUG_MODE:
            # Debug artifact at a fixed path; each call overwrites it.
            pil_image.save("./api.png")
        debug_log(f"总耗时: {time.time()-start_time:.2f}s")
        return pil_image

    except Exception as e:
        error_msg = f"生成失败: {str(e)}"
        debug_log(error_msg)
        traceback.print_exc()
        raise RuntimeError(error_msg) from e