| | import os |
| | import faiss |
| | import pickle |
| | from PyPDF2 import PdfReader |
| | from transformers import AutoTokenizer, AutoModel, AutoConfig |
| | from torch.nn import functional as F |
| | import torch |
| | from config import INDEX_FILE, EMBEDDINGS_FILE, LLM_API_URL, EMBED_AX_MODEL, EMBED_HF_MODEL |
| | import numpy as np |
| | import requests |
| | import json |
| | import re |
| | import chardet |
| |
|
# Select the torch device once at import time; used for tokenized inputs below.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Left padding so the last position of every row is a real token, which is
# what last_token_pool relies on for batched inputs.
tokenizer = AutoTokenizer.from_pretrained(EMBED_HF_MODEL, padding_side="left")

"""
axengine-related:
load the embedding model
"""
from ml_dtypes import bfloat16
from utils.infer_func import InferManager

# Token-embedding matrix exported next to the axengine model; embeddings are
# looked up on the host (np.take) before prefilling the NPU model.
embeds = np.load(os.path.join(EMBED_AX_MODEL, "model.embed_tokens.weight.npy"))
cfg = AutoConfig.from_pretrained(EMBED_HF_MODEL)
imer = InferManager(cfg, EMBED_AX_MODEL, device_id=0)

"""
torch path to load the embedding model (kept for reference):
model = AutoModel.from_pretrained(EMBED_HF_MODEL).to(device)
model.eval()
embedder = model
"""
| |
|
def last_token_pool(last_hidden_states, attention_mask):
    """Pool each sequence down to the hidden state of its last real token.

    With left padding every sequence ends at the final position, so the last
    column can be taken directly; otherwise the index of the last non-padding
    token is looked up per row from the attention mask.
    """
    # All rows end in a real token exactly when the mask's last column is all ones.
    if attention_mask[:, -1].sum() == attention_mask.shape[0]:
        return last_hidden_states[:, -1]
    last_positions = attention_mask.sum(dim=1) - 1
    rows = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[rows, last_positions]
| |
|
def encode_texts(texts):
    """Embed a list of texts with the instruction-tuned retrieval model.

    Each text is wrapped in the retrieval instruction template, tokenized
    with left padding, run through the axengine model, pooled to the last
    token and L2-normalized.

    Args:
        texts: list[str] of queries/passages to embed.

    Returns:
        float32 numpy array of shape (len(texts), hidden_dim).
    """
    task_desc = "Given a web search query, retrieve relevant passages that answer the query"
    inputs = [f"Instruct: {task_desc}\nQuery: {t}" for t in texts]

    inputs_tokenized = tokenizer(
        inputs, padding=True, truncation=True, max_length=8192, return_tensors="pt"
    )
    inputs_tokenized = {k: v.to(device) for k, v in inputs_tokenized.items()}

    """
    torch path (kept for reference):
    with torch.no_grad():
        outputs = model(**inputs_tokenized)
        embeddings = last_token_pool(outputs.last_hidden_state, inputs_tokenized["attention_mask"])
        embeddings = F.normalize(embeddings, p=2, dim=1)
    """

    # axengine path: look up input embeddings on the host, then prefill the
    # NPU model one batch row at a time.
    input_ids = inputs_tokenized['input_ids']
    inputs_embeds = np.take(embeds, input_ids.cpu().numpy(), axis=0)
    prefill_data = inputs_embeds.astype(bfloat16)

    batch_num, seq_len, seq_dim = inputs_embeds.shape
    last_hidden_state = np.zeros((batch_num, seq_len, seq_dim), dtype=bfloat16)
    for batch_idx in range(batch_num):
        # Bug fix: use each row's own token ids — the original computed
        # token_ids once from input_ids[0] and reused it for every batch
        # row, so all items after the first were prefilled with the wrong ids.
        token_ids = input_ids[batch_idx].cpu().numpy().tolist()
        last_hidden_state[batch_idx] = imer.prefill(
            tokenizer, token_ids, prefill_data[batch_idx],
            slice_len=128, return_last_hidden_state=True
        )

    embeddings = last_token_pool(
        torch.from_numpy(last_hidden_state.astype(np.float32)),
        inputs_tokenized['attention_mask']
    )
    embeddings = F.normalize(embeddings, p=2, dim=1)

    return embeddings.cpu().numpy()
| |
|
| | |
def load_pdf_chunks(pdf_path, chunk_size=500, chunk_overlap=100):
    """Read a PDF and split its text into fixed-size overlapping chunks.

    Args:
        pdf_path: path to the PDF file.
        chunk_size: length of each chunk in characters.
        chunk_overlap: number of characters consecutive chunks share.

    Returns:
        list[str] of text chunks.

    Raises:
        ValueError: if chunk_overlap >= chunk_size (the loop would never advance).
    """
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    reader = PdfReader(pdf_path)
    # Bug fix: extract_text() can return None for image-only pages, which
    # previously raised TypeError on the "+ '\n'" concatenation.
    all_text = "".join((page.extract_text() or "") + "\n" for page in reader.pages)

    chunks = []
    start = 0
    step = chunk_size - chunk_overlap
    while start < len(all_text):
        # Slicing clamps at the end of the string, so no min() is needed.
        chunks.append(all_text[start:start + chunk_size])
        start += step
    return chunks
| |
|
| | |
def load_txt_chunks(txt_path, chunk_size=20, chunk_overlap=5):
    """Read a text file with encoding auto-detection and split into overlapping chunks.

    Decoding order: chardet's guess, then GBK, then latin-1 (which maps every
    byte and therefore cannot fail). Whitespace runs are collapsed to single
    spaces before chunking.

    Args:
        txt_path: path to the text file.
        chunk_size: length of each chunk in characters.
        chunk_overlap: number of characters consecutive chunks share.

    Returns:
        list[str] of text chunks.
    """
    with open(txt_path, 'rb') as f:
        raw_data = f.read()
    result = chardet.detect(raw_data)
    encoding = result['encoding'] if result['encoding'] else 'utf-8'

    try:
        with open(txt_path, 'r', encoding=encoding) as f:
            all_text = f.read()
    # LookupError added: chardet may name a codec Python doesn't recognize.
    except (UnicodeDecodeError, LookupError):
        try:
            with open(txt_path, 'r', encoding='gbk') as f:
                all_text = f.read()
        # Bug fix: the bare `except:` also swallowed KeyboardInterrupt and
        # SystemExit; only decode/codec failures should trigger the fallback.
        except (UnicodeDecodeError, LookupError):
            with open(txt_path, 'r', encoding='latin-1') as f:
                all_text = f.read()
    all_text = re.sub(r'\s+', ' ', all_text).strip()
    chunks = []
    start = 0
    while start < len(all_text):
        end = min(start + chunk_size, len(all_text))
        chunks.append(all_text[start:end])
        start += chunk_size - chunk_overlap
    return chunks
| |
|
| | |
def build_index(file_path):
    """Chunk a .pdf/.txt file, embed the chunks and persist a FAISS index.

    An inner-product index over L2-normalized vectors (cosine similarity)
    is written to INDEX_FILE; the raw text chunks are pickled to
    EMBEDDINGS_FILE so search hits can be mapped back to text.

    Raises:
        ValueError: for unsupported file extensions.
    """
    lowered = file_path.lower()
    if lowered.endswith('.pdf'):
        chunks = load_pdf_chunks(file_path)
    elif lowered.endswith('.txt'):
        chunks = load_txt_chunks(file_path)
    else:
        raise ValueError(f"不支持的文件类型: {file_path}")

    embeddings = encode_texts(chunks)
    faiss.normalize_L2(embeddings)
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)

    faiss.write_index(index, INDEX_FILE)
    with open(EMBEDDINGS_FILE, "wb") as f:
        pickle.dump(chunks, f)

    return f"✅ 成功构建索引: {len(chunks)}个片段"
| |
|
def index_exists():
    """Return True when both the FAISS index and the pickled chunks exist on disk."""
    return all(os.path.exists(p) for p in (INDEX_FILE, EMBEDDINGS_FILE))
| |
|
def get_top_k(query, k=3):
    """Return up to k stored text chunks most similar to the query.

    Args:
        query: the search query string.
        k: maximum number of chunks to return.

    Returns:
        list[str] of matching chunks; empty if no index has been built.
    """
    if not index_exists():
        return []
    index = faiss.read_index(INDEX_FILE)
    with open(EMBEDDINGS_FILE, "rb") as f:
        texts = pickle.load(f)

    query_vec = encode_texts([query])
    D, I = index.search(query_vec, k)
    # Bug fix: when the index holds fewer than k vectors FAISS pads the
    # result with -1, which previously aliased texts[-1]; drop the padding.
    return [texts[i] for i in I[0] if i >= 0]
| |
|
def ask_question(query):
    """Answer a question in one shot via the LLM HTTP endpoint, grounded on retrieved chunks.

    Builds a RAG prompt from the top-k chunks, POSTs it to LLM_API_URL and
    returns the stripped "text" field of the JSON reply. Any connection
    failure or non-JSON body yields the same fallback message the original
    used for a missing "text" key.
    """
    context = "\n".join(get_top_k(query))
    prompt = f"""上下文内容是你可以参考的资料, 用户问题才是你需要回答的内容.
[上下文内容]:
- {context}\n

[用户问题]:
- {query}\n

[简洁的输出回答]:
"""
    print("DEBUG: prompt is \n", prompt)

    try:
        # Bug fix: without a timeout a stalled LLM service blocks forever.
        response = requests.post(
            LLM_API_URL, json={"prompt": prompt, "max_tokens": 1024}, timeout=300
        )
        # Bug fix: .json() raises on a non-JSON body; treat it as no response.
        return response.json().get("text", "❌ LLM 接口未响应").strip()
    except (requests.RequestException, ValueError):
        return "❌ LLM 接口未响应"
| |
|
def stream_answer(query):
    """Stream the LLM's answer token by token (generator of str).

    Builds the same RAG prompt as ask_question, POSTs it to LLM_API_URL with
    streaming enabled, and yields each SSE "data:" payload's token. Errors
    are yielded as user-visible messages rather than raised so the UI can
    display them inline.
    """
    context = "\n".join(get_top_k(query))
    prompt = f"""上下文内容是你可以参考的资料, 用户问题才是你需要回答的内容.
[上下文内容]:
- {context}\n

[用户问题]:
- {query}\n

[简洁的输出回答]:
"""
    print("DEBUG: prompt is \n", prompt)

    data = {
        "prompt": prompt,
        "max_tokens": 1024,
        "temperature": 0.6,
        "top_p": 0.9
    }

    try:
        # Bug fix: no timeout could block forever. With stream=True the
        # second value is the per-read timeout between chunks, not a cap
        # on total generation time.
        with requests.post(
            LLM_API_URL,
            json=data,
            stream=True,
            timeout=(10, 300)
        ) as response:
            if response.status_code != 200:
                yield f"⚠️ 请求错误:{response.status_code}"
                return

            for chunk in response.iter_lines():
                if not chunk:
                    continue
                line = chunk.decode('utf-8').strip()
                # SSE data frames start with "data:"; skip other lines.
                if not line.startswith('data:'):
                    continue
                # Bug fix: replace('data:', '') stripped every occurrence,
                # mangling payloads that contain 'data:'; remove only the
                # leading prefix (json.loads tolerates the leading space).
                json_data = line[len('data:'):]
                try:
                    event = json.loads(json_data)
                    if 'token' in event:
                        yield event['token']
                    elif event.get('end') or event.get('finish_reason'):
                        return
                except json.JSONDecodeError:
                    # Not JSON — pass the raw payload through unchanged.
                    yield json_data
    except Exception as e:
        yield f"⚠️ 连接错误:{str(e)}"
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|