| |
| from fastapi import FastAPI, HTTPException |
| from fastapi.responses import StreamingResponse |
| from pydantic import BaseModel |
| from typing import Optional |
| import uvicorn |
| import numpy as np |
| import os |
| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig |
| from config import LLM_HF_MODEL, LLM_AX_MODEL, LLM_API_PORT |
|
|
|
|
# FastAPI application exposing a local inference endpoint.
app = FastAPI(title="Fast-API", description="本地推理接口")
# Torch device is used only for the tokenized input tensors; the model itself
# runs on the axengine runtime, not on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"


"""
axengine-related imports
"""
from ml_dtypes import bfloat16
from utils.infer_func import InferManager


# Process-level singletons, populated once by init_model() at startup.
tokenizer = None  # HuggingFace tokenizer loaded from LLM_HF_MODEL
imer = None       # axengine InferManager driving the compiled model
embeds = None     # token-embedding table (numpy array) loaded from .npy
|
|
def init_model():
    """Lazily load the tokenizer, the axengine inference manager and the
    token-embedding table into module-level globals.

    Fix over the original: all three globals (plus the tokenizer) are now
    initialized together under a single guard with an early return, so
    calling this more than once (e.g. repeated startup hooks) is a no-op
    instead of reloading the tokenizer every time.
    """
    global tokenizer, imer, embeds
    if tokenizer is not None:
        # Already initialized — nothing to do.
        return
    # The HF config only describes the architecture; the weights are executed
    # by the axengine runtime from the compiled model directory LLM_AX_MODEL.
    cfg = AutoConfig.from_pretrained(LLM_HF_MODEL)
    imer = InferManager(cfg, LLM_AX_MODEL, model_type="qwen2")
    # Embedding matrix is exported separately because the axengine graph
    # consumes input embeddings, not raw token ids (see /generate handler).
    embeds = np.load(os.path.join(LLM_AX_MODEL, "model.embed_tokens.weight.npy"))
    tokenizer = AutoTokenizer.from_pretrained(LLM_HF_MODEL, trust_remote_code=True)
    print("✅ 模型加载完成。")
|
|
| |
@app.on_event("startup")
async def startup_event():
    """Load model artifacts once when the server process starts."""
    # NOTE: on_event-style hooks are the older FastAPI idiom (lifespan is the
    # modern form); kept as-is to avoid restructuring the module.
    init_model()
|
|
class GenRequest(BaseModel):
    """Request body for POST /generate."""

    # User prompt; the handler wraps it in a chat template together with a
    # fixed system prompt.
    prompt: str
    # NOTE(review): max_tokens / temperature / top_p are accepted here but the
    # generation loop in generate_text never reads them — confirm whether
    # sampling parameters should be forwarded to the decoder.
    max_tokens: Optional[int] = 1024
    temperature: Optional[float] = 0.6
    top_p: Optional[float] = 0.9
|
|
class GenResponse(BaseModel):
    """Declared response schema for /generate.

    NOTE(review): the endpoint actually returns a StreamingResponse of SSE
    events, so this model documents the OpenAPI schema only.
    """

    # Full generated text.
    text: str
|
|
@app.post("/generate", response_model=GenResponse)
def generate_text(req: GenRequest):
    """Stream a token-by-token completion for ``req.prompt`` as SSE.

    Despite ``response_model=GenResponse``, the handler returns a
    ``StreamingResponse`` with ``text/event-stream`` lines of the form
    ``data: {"token": "..."}``.

    NOTE(review): ``req.max_tokens``, ``req.temperature`` and ``req.top_p``
    are never used below — confirm whether they should influence decoding.
    """
    try:
        # Build the chat-formatted input: fixed (Chinese) system prompt plus
        # the user's prompt.
        messages = [
            {"role": "system", "content": "你的名字叫做 [AXERA-RAG 助手]. 你是一个高效、精准的问答助手. 你可以根据上下文内容, 回答用户提出的问题, 回答时不要提及多余的、无用的内容, 且仅输出你的回答."},
            {"role": "user", "content": req.prompt}
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        # Tokenize to torch tensors; only the numpy view of input_ids is used
        # below, so `device` matters little here.
        model_inputs = tokenizer([text], return_tensors="pt").to(device)

        """
        Model inference on the axengine runtime
        """
        input_ids = model_inputs['input_ids']
        # Manual embedding lookup: gather rows of the exported embedding table
        # for each prompt token, then cast to bfloat16 for the axengine model.
        inputs_embeds = np.take(embeds, input_ids.cpu().numpy(), axis=0)
        prefill_data = inputs_embeds
        prefill_data = prefill_data.astype(bfloat16)
        token_ids = input_ids[0].cpu().numpy().tolist()
        generated_text = ""

        def generate_stream():
            # Shared with the enclosing scope so state survives across yields.
            nonlocal token_ids, generated_text
            # Prefill the prompt; assumes imer.prefill returns the token list
            # extended with the first generated token — TODO confirm against
            # InferManager.
            token_ids = imer.prefill(tokenizer, token_ids, prefill_data[0], slice_len=128)
            generated_text += tokenizer.decode(token_ids[-1], skip_special_tokens=True)

            # First token from prefill, escaped the same way as loop tokens.
            prefill_word = tokenizer.decode(token_ids[-1], skip_special_tokens=True)
            prefill_word = prefill_word.strip().replace("\n", "\\n").replace("\"", "\\\"")

            seq_len = len(token_ids) - 1
            prefill_len = 128
            for step_idx in range(imer.max_seq_len):
                # prefill_len is a constant 128 here, so this reduces to
                # step_idx < seq_len: skip the steps already covered by prefill.
                if prefill_len > 0 and step_idx < seq_len:
                    continue
                token_ids, next_token_id = imer.decode_next_token(tokenizer, token_ids, embeds, slice_len=128, step_idx=step_idx)
                # NOTE(review): comparing a *token id* against seq_len looks
                # suspicious — this was presumably meant to be
                # `step_idx > seq_len`; verify before relying on EOS handling.
                if next_token_id == tokenizer.eos_token_id and next_token_id > seq_len:
                    break
                try:
                    if next_token_id is not None:
                        word = tokenizer.decode([next_token_id], skip_special_tokens=True)
                        generated_text += word
                        # Prepend the prefill token to the first streamed chunk.
                        if prefill_word is not None:
                            word = prefill_word + word
                            prefill_word = None

                        # Manual JSON escaping: only newline and double quote
                        # are handled (backslashes are not) — NOTE(review):
                        # may emit invalid JSON for some tokens.
                        word = word.strip().replace("\n", "\\n").replace("\"", "\\\"")

                        yield f"data: {{\"token\": \"{word}\"}}\n\n"
                except Exception as e:
                    # Best-effort: log and keep streaming on decode failures.
                    print(f"Error decoding token {next_token_id}: {e}")

        # NOTE(review): exceptions raised inside generate_stream() happen
        # after the response has started, so the outer except below cannot
        # convert them into an HTTP 500.
        return StreamingResponse(
            generate_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no"
            }
        )

    except Exception as e:
        # Pre-stream failures (templating, tokenization, embedding lookup)
        # surface as HTTP 500 with the error message.
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
if __name__ == "__main__":
    # Run the API directly; reload is disabled because model state lives in
    # process-level globals and would be lost on reload.
    uvicorn.run(app, host="0.0.0.0", port=LLM_API_PORT, reload=False)
|
|