# RAG.axera / rag_engine.py
# Author: yongqiang — initial commit "Initialize the repository" (1ed9a31)
import os
import faiss
import pickle
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModel, AutoConfig
from torch.nn import functional as F
import torch
from config import INDEX_FILE, EMBEDDINGS_FILE, LLM_API_URL, EMBED_AX_MODEL, EMBED_HF_MODEL
import numpy as np
import requests
import json
import re
import chardet  # used to detect text file encodings
device = "cuda" if torch.cuda.is_available() else "cpu"
# ========== Transformers 加载 embedding 模型 ==========
tokenizer = AutoTokenizer.from_pretrained(EMBED_HF_MODEL, padding_side="left")
"""
axengine 相关
加载 embedding 模型
"""
from ml_dtypes import bfloat16
from utils.infer_func import InferManager
embeds = np.load(os.path.join(EMBED_AX_MODEL, "model.embed_tokens.weight.npy"))
cfg = AutoConfig.from_pretrained(EMBED_HF_MODEL)
imer = InferManager(cfg, EMBED_AX_MODEL, device_id=0) # 如果运行在 axcl 上, device_id 可以指定除 0 之外可访问的卡 id
"""
torch 加载 embedding 模型
model = AutoModel.from_pretrained(EMBED_HF_MODEL).to(device)
model.eval()
embedder = model
"""
def last_token_pool(last_hidden_states, attention_mask):
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
if left_padding:
return last_hidden_states[:, -1]
else:
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = last_hidden_states.shape[0]
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
def encode_texts(texts):
task_desc = "Given a web search query, retrieve relevant passages that answer the query"
inputs = [f"Instruct: {task_desc}\nQuery: {t}" for t in texts]
inputs_tokenized = tokenizer(
inputs, padding=True, truncation=True, max_length=8192, return_tensors="pt"
)
inputs_tokenized = {k: v.to(device) for k, v in inputs_tokenized.items()}
"""
torch 相关
with torch.no_grad():
outputs = model(**inputs_tokenized)
embeddings = last_token_pool(outputs.last_hidden_state, inputs_tokenized["attention_mask"])
embeddings = F.normalize(embeddings, p=2, dim=1)
"""
"""
axengine 相关
"""
input_ids = inputs_tokenized['input_ids']
inputs_embeds = np.take(embeds, input_ids.cpu().numpy(), axis=0)
prefill_data = inputs_embeds
prefill_data = prefill_data.astype(bfloat16)
token_ids = input_ids[0].cpu().numpy().tolist()
token_len = len(token_ids)
batch_num, seq_len, seq_dim = inputs_embeds.shape
last_hidden_state = np.zeros((batch_num, seq_len, seq_dim), dtype=bfloat16)
for batch_idx in range(batch_num):
last_hidden_state[batch_idx] = imer.prefill(tokenizer, token_ids, prefill_data[batch_idx], slice_len=128, return_last_hidden_state=True)
embeddings = last_token_pool(torch.from_numpy(last_hidden_state.astype(np.float32)), inputs_tokenized['attention_mask'])
# normalize embeddings
embeddings = F.normalize(embeddings, p=2, dim=1)
return embeddings.cpu().numpy()
# Read a PDF and split it into chunks
def load_pdf_chunks(pdf_path, chunk_size=500, chunk_overlap=100):
reader = PdfReader(pdf_path)
all_text = ""
for page in reader.pages:
all_text += page.extract_text() + "\n"
# 按字符长度切分
chunks = []
start = 0
while start < len(all_text):
end = min(start + chunk_size, len(all_text))
chunks.append(all_text[start:end])
start += chunk_size - chunk_overlap
return chunks
# Read a TXT file and split it into chunks
def load_txt_chunks(txt_path, chunk_size=20, chunk_overlap=5):
with open(txt_path, 'rb') as f:
raw_data = f.read()
result = chardet.detect(raw_data)
encoding = result['encoding'] if result['encoding'] else 'utf-8'
try:
with open(txt_path, 'r', encoding=encoding) as f:
all_text = f.read()
except UnicodeDecodeError:
try:
with open(txt_path, 'r', encoding='gbk') as f:
all_text = f.read()
except:
with open(txt_path, 'r', encoding='latin-1') as f:
all_text = f.read()
all_text = re.sub(r'\s+', ' ', all_text).strip()
chunks = []
start = 0
while start < len(all_text):
end = min(start + chunk_size, len(all_text))
chunks.append(all_text[start:end])
start += chunk_size - chunk_overlap
return chunks
# Build and persist the vector index
def build_index(file_path):
# 根据文件类型选择加载方法
if file_path.lower().endswith('.pdf'):
chunks = load_pdf_chunks(file_path)
elif file_path.lower().endswith('.txt'):
chunks = load_txt_chunks(file_path)
else:
raise ValueError(f"不支持的文件类型: {file_path}")
embeddings = encode_texts(chunks) # use transformers model
faiss.normalize_L2(embeddings)
dim = embeddings.shape[1]
index = faiss.IndexFlatIP(dim)
index.add(embeddings)
# 保存
faiss.write_index(index, INDEX_FILE)
with open(EMBEDDINGS_FILE, "wb") as f:
pickle.dump(chunks, f)
return f"✅ 成功构建索引: {len(chunks)}个片段"
def index_exists():
return os.path.exists(INDEX_FILE) and os.path.exists(EMBEDDINGS_FILE)
def get_top_k(query, k=3):
if not index_exists():
return []
index = faiss.read_index(INDEX_FILE)
with open(EMBEDDINGS_FILE, "rb") as f:
texts = pickle.load(f)
# query_vec = model.encode([query])
query_vec = encode_texts([query]) # use transformers model
D, I = index.search(query_vec, k)
return [texts[i] for i in I[0]]
def ask_question(query):
context = "\n".join(get_top_k(query))
prompt = f"""上下文内容是你可以参考的资料, 用户问题才是你需要回答的内容.
[上下文内容]:
- {context}\n
[用户问题]:
- {query}\n
[简洁的输出回答]:
"""
print("DEBUG: prompt is \n", prompt)
# 向本地 LLM API 发请求
response = requests.post(LLM_API_URL, json={"prompt": prompt, "max_tokens": 1024})
return response.json().get("text", "❌ LLM 接口未响应").strip()
def stream_answer(query):
context = "\n".join(get_top_k(query))
prompt = f"""上下文内容是你可以参考的资料, 用户问题才是你需要回答的内容.
[上下文内容]:
- {context}\n
[用户问题]:
- {query}\n
[简洁的输出回答]:
"""
print("DEBUG: prompt is \n", prompt)
"""流式获取答案并逐个token生成的函数"""
data = {
"prompt": prompt,
"max_tokens": 1024,
"temperature": 0.6,
"top_p": 0.9
}
try:
# 发送流式请求
with requests.post(
LLM_API_URL,
json=data,
stream=True
) as response:
# 检查响应状态
if response.status_code != 200:
yield f"⚠️ 请求错误:{response.status_code}"
return
# 处理流式数据
for chunk in response.iter_lines():
# 过滤心跳和空行
if chunk and b'data:' in chunk:
# 提取JSON数据
line = chunk.decode('utf-8').strip()
json_data = line.replace('data:', '')
try:
# 解析JSON格式
event = json.loads(json_data)
if 'token' in event:
yield event['token']
elif event.get('end') or event.get('finish_reason'):
return
except json.JSONDecodeError:
# 如果后端返回的是文本
yield json_data
except Exception as e:
yield f"⚠️ 连接错误:{str(e)}"
# if __name__ == "__main__":
#     import argparse
#     parser = argparse.ArgumentParser(description="Build a PDF index and answer questions")
#     parser.add_argument("--pdf", type=str, required=True, help="path to the PDF file")
#     args = parser.parse_args()
#     build_index(args.pdf)
#     print("🚗🚗🌲🌲 Index build complete!")