Initial codes commit

- app.py +5 -3
- config.py +11 -3
- db/initializer.py +4 -4
- models/embedder.py +26 -0
- models/initializer.py +37 -0
- {modules → models}/reranker.py +14 -21
- modules/embedder.py +0 -23
- service/search.py +2 -2
app.py
CHANGED

@@ -4,14 +4,16 @@ from fastapi import FastAPI, Request, Form
 from fastapi.responses import HTMLResponse
 from fastapi.templating import Jinja2Templates
 from api.endpoints import router
-from db.initializer import
+from db.initializer import initialize_dbs
+from models.initializer import initialize_models
 from service.search import search

 templates = Jinja2Templates(directory="templates")

 @asynccontextmanager
-async def lifespan(
-
+async def lifespan(app: FastAPI):
+    initialize_dbs()
+    initialize_models(app)
     yield

 app = FastAPI(lifespan=lifespan)
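The lifespan hook now runs both initializers before the app serves traffic. A minimal smoke test (our sketch, not part of the commit; it assumes the modules above import cleanly and models are reachable) would confirm that startup populates app.state:

# Hypothetical smoke test: entering the TestClient context triggers the
# lifespan, so initialize_dbs() and initialize_models(app) run first.
from fastapi.testclient import TestClient
from app import app

with TestClient(app):
    assert hasattr(app.state, "embedder_sess")       # set by initialize_models
    assert hasattr(app.state, "reranker_tokenizer")  # set by initialize_models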
config.py
CHANGED

@@ -5,7 +5,7 @@ import os
 HF_TOKEN = os.getenv("HF_TOKEN")

 # HF datasets repo info
-
+HF_DS_REPO_ID = os.getenv("HF_REPO_ID", "m97j/pls-datasets")
 HF_INDEX_FILE = os.getenv("HF_INDEX_FILE", "faiss/faiss_index_flat.faiss")
 HF_IDS_FILE = os.getenv("HF_IDS_FILE", "faiss/vector_ids.npy")

@@ -18,8 +18,16 @@ HF_CORPUS_SPLIT = os.getenv("HF_CORPUS_SPLIT", "train")
 MARKER_DIR = os.getenv("MARKER_DIR", "rag/state")
 CORPUS_READY_MARK = os.path.join(MARKER_DIR, ".corpus_ready")

-# Embedding
-
+# Embedding model
+HF_MODEL_REPO_ID = os.getenv("HF_MODEL_REPO_ID", "m97j/pragmatic-search")
+EMBED_MODEL = os.getenv("EMBED_MODEL", "model_quantized.onnx")
+EMBED_DIR = os.getenv("EMBED_DIR", "embedder")
+
+# Reranking model
+RERANK_MODEL = os.getenv("RERANK_MODEL", "model_quantized.onnx")
+RERANK_DIR = os.getenv("RERANK_DIR", "reranker")
+
+# Retrieval settings
 TOP_K = int(os.getenv("TOP_K", "5"))
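Every setting is read once via os.getenv at import time, so deployments override defaults through the environment rather than by editing code. For example (hypothetical values, our illustration):

# Overrides must be in place before config is first imported.
import os
os.environ["HF_MODEL_REPO_ID"] = "my-org/my-search-models"  # hypothetical repo
os.environ["TOP_K"] = "10"

import config
assert config.TOP_K == 10
assert config.EMBED_MODEL == "model_quantized.onnx"  # default unchanged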
db/initializer.py
CHANGED

@@ -2,7 +2,7 @@
 import faiss
 import numpy as np
 from huggingface_hub import hf_hub_download
-from config import
+from config import HF_DS_REPO_ID, HF_INDEX_FILE, HF_IDS_FILE
 from modules.retriever import set_index
 from modules import corpus

@@ -10,8 +10,8 @@ _vector_ids = None

 def _load_index_in_memory():
     """Fetch the index/ID mapping from the HF Hub and load them into memory"""
-    index_path = hf_hub_download(repo_id=
-    ids_path = hf_hub_download(repo_id=
+    index_path = hf_hub_download(repo_id=HF_DS_REPO_ID, filename=HF_INDEX_FILE, repo_type="dataset")
+    ids_path = hf_hub_download(repo_id=HF_DS_REPO_ID, filename=HF_IDS_FILE, repo_type="dataset")
     index = faiss.read_index(index_path)
     set_index(index)
     global _vector_ids

@@ -21,7 +21,7 @@ def get_vector_ids():
     global _vector_ids
     return _vector_ids

-def
+def initialize_dbs():
     # 1) Prepare the corpus (downloaded only on the first run)
     corpus.prepare_corpus()
     # 2) Load the index/ID mapping into memory
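modules/retriever.py itself is not part of this commit; based on how set_index and retrieve_ids are called here and in service/search.py, a plausible sketch of its contract looks like this (assumed signatures, our reconstruction):

# Hypothetical sketch of modules/retriever.py, inferred from its call sites.
import faiss
import numpy as np

_index = None  # module-level FAISS index, set once at startup

def set_index(index: faiss.Index) -> None:
    """Keep the FAISS index loaded by db.initializer for later queries."""
    global _index
    _index = index

def retrieve_ids(embedding: list, top_k: int) -> list:
    """Return positions of the top_k nearest vectors for one query embedding."""
    query = np.asarray([embedding], dtype="float32")
    _, positions = _index.search(query, top_k)
    return positions[0].tolist()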
models/embedder.py
ADDED

@@ -0,0 +1,26 @@
+# rag/models/embedder.py
+from typing import List
+import numpy as np
+import onnxruntime as ort
+from fastapi import Request
+
+def _l2_normalize(vec: np.ndarray) -> List[float]:
+    norm = np.linalg.norm(vec) or 1.0
+    return (vec / norm).tolist()
+
+def get_embedding(request: Request, text: str) -> List[float]:
+    """
+    request.app.state.embedder_sess : ONNX Runtime InferenceSession
+    request.app.state.embedder_tokenizer : tokenizer
+    """
+    tokenizer = request.app.state.embedder_tokenizer
+    sess: ort.InferenceSession = request.app.state.embedder_sess
+
+    inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True, max_length=256)
+    ort_inputs = {k: v for k, v in inputs.items()}
+    ort_outs = sess.run(None, ort_inputs)
+    # Typically the first output is the [batch, dim] embedding
+    vec = ort_outs[0][0]
+    return _l2_normalize(vec)
+
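The new embedder pulls its session and tokenizer off app.state, so callers must pass the incoming Request through. A hypothetical endpoint (route name and wiring are ours, not in the commit) would use it like this:

# Hypothetical usage sketch: get_embedding reaches app.state via the Request.
from fastapi import APIRouter, Request
from models.embedder import get_embedding

router = APIRouter()

@router.get("/embed")  # assumed route, for illustration only
def embed(request: Request, q: str):
    vector = get_embedding(request, q)  # L2-normalized list of floats
    return {"dim": len(vector)}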
models/initializer.py
ADDED

@@ -0,0 +1,37 @@
+# rag/models/initializer.py
+from transformers import AutoTokenizer
+import onnxruntime as ort
+from huggingface_hub import hf_hub_download
+from fastapi import FastAPI
+from config import HF_MODEL_REPO_ID, EMBED_MODEL, EMBED_DIR, RERANK_MODEL, RERANK_DIR
+
+def initialize_models(app: FastAPI):
+    # Embedder
+    embedder_tokenizer = AutoTokenizer.from_pretrained(
+        HF_MODEL_REPO_ID,
+        subfolder=EMBED_DIR  # the tokenizer files live under embedder/, so point there
+    )
+    embedder_model_path = hf_hub_download(
+        repo_id=HF_MODEL_REPO_ID,
+        filename=EMBED_MODEL,
+        subfolder=EMBED_DIR
+    )
+    embedder_sess = ort.InferenceSession(embedder_model_path, providers=["CPUExecutionProvider"])
+
+    # Reranker
+    reranker_tokenizer = AutoTokenizer.from_pretrained(
+        HF_MODEL_REPO_ID,
+        subfolder=RERANK_DIR  # the tokenizer files live under reranker/, so point there
+    )
+    reranker_model_path = hf_hub_download(
+        repo_id=HF_MODEL_REPO_ID,
+        filename=RERANK_MODEL,
+        subfolder=RERANK_DIR
+    )
+    reranker_sess = ort.InferenceSession(reranker_model_path, providers=["CPUExecutionProvider"])
+
+    # Store on FastAPI app.state → shared across the app
+    app.state.embedder_tokenizer = embedder_tokenizer
+    app.state.embedder_sess = embedder_sess
+    app.state.reranker_tokenizer = reranker_tokenizer
+    app.state.reranker_sess = reranker_sess
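A common failure mode with exported ONNX models is a mismatch between what the tokenizer emits (e.g. token_type_ids) and the inputs the graph declares. A small startup check like the following (our addition, not in the commit) fails fast instead of erroring on the first query:

# Optional sanity check: compare tokenizer output keys to ONNX graph inputs.
import onnxruntime as ort

def check_io(sess: ort.InferenceSession, sample_inputs: dict) -> None:
    graph_inputs = {i.name for i in sess.get_inputs()}
    missing = graph_inputs - set(sample_inputs)  # graph expects these, tokenizer lacks them
    extra = set(sample_inputs) - graph_inputs    # tokenizer emits these, graph rejects them
    if missing or extra:
        raise ValueError(f"ONNX I/O mismatch: missing={missing}, extra={extra}")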
{modules → models}/reranker.py
RENAMED

@@ -1,39 +1,32 @@
-# rag/
+# rag/models/reranker.py
 import os
 from typing import List, Dict
-
+import onnxruntime as ort
+from fastapi import Request

-# Load the model name and token from environment variables
-HF_TOKEN = os.getenv("HF_TOKEN")
-RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-large")
-
-_client = InferenceClient(model=RERANK_MODEL, token=HF_TOKEN)
-
-# The threshold value can also be managed via environment variables or config
 THRESHOLD = float(os.getenv("RERANK_THRESHOLD", "0.3"))

-def rerank(query: str, contexts: List[Dict]) -> List[Dict]:
+def rerank(request: Request, query: str, contexts: List[Dict]) -> List[Dict]:
     """
+    request.app.state.reranker_sess : ONNX Runtime InferenceSession
+    request.app.state.reranker_tokenizer : tokenizer
     contexts: [{"id": ..., "text": ...}, ...]
-    Returns: reranked contexts containing only scores above the threshold
     """
     if not contexts:
         return []

-
-
+    tokenizer = request.app.state.reranker_tokenizer
+    sess: ort.InferenceSession = request.app.state.reranker_sess

-
-
+    pairs = [(query, ctx["text"]) for ctx in contexts]
+    inputs = tokenizer(pairs, return_tensors="np", padding=True, truncation=True, max_length=256)
+    ort_inputs = {k: v for k, v in inputs.items()}
+    scores = sess.run(None, ort_inputs)[0]  # assumed shape: [batch, 1]
+    scores = scores.squeeze(-1)

-    # scores has the form [{"score": float}, ...]
     for ctx, sc in zip(contexts, scores):
-        ctx["score"] = sc
+        ctx["score"] = float(sc)

-    # Sort by score in descending order
     reranked = sorted(contexts, key=lambda x: x["score"], reverse=True)
-
-    # Keep only scores above the threshold
     reranked = [c for c in reranked if c["score"] >= THRESHOLD]
-
     return reranked
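The reranker now scores (query, context) pairs locally with an ONNX cross-encoder instead of calling a remote InferenceClient. A hypothetical call site (ours, for illustration) looks like this:

# Hypothetical endpoint wiring for rerank (route and data are assumptions).
from fastapi import APIRouter, Request
from models.reranker import rerank

router = APIRouter()

@router.post("/rerank")  # assumed route, for illustration only
def rerank_endpoint(request: Request, query: str):
    contexts = [
        {"id": 1, "text": "Plain-language summary of clause A."},
        {"id": 2, "text": "Unrelated boilerplate."},
    ]
    # Each context gains a float "score"; results come back sorted descending,
    # filtered by RERANK_THRESHOLD (default 0.3).
    return rerank(request, query, contexts)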
modules/embedder.py
DELETED

@@ -1,23 +0,0 @@
-# rag/modules/embedder.py
-import math
-from typing import List
-from huggingface_hub import InferenceClient
-from config import EMBED_MODEL, HF_TOKEN
-
-# Specify the model and token
-_client = InferenceClient(model=EMBED_MODEL, token=HF_TOKEN)
-
-def _l2_normalize(vec: List[float]) -> List[float]:
-    norm = math.sqrt(sum(x * x for x in vec)) or 1.0
-    return [x / norm for x in vec]
-
-def get_embedding(text: str) -> List[float]:
-    # Call the embeddings task directly
-    response = _client.post(
-        json={"inputs": text},
-        task="embeddings"
-    )
-    # Always returns shape [batch_size, embedding_dim]
-    vec = response[0]
-    return _l2_normalize(vec)
-
service/search.py
CHANGED

@@ -1,8 +1,8 @@
 # rag/service/search.py
-from
+from models.embedder import get_embedding
 from modules.retriever import retrieve_ids
 from modules.corpus import fetch_contexts_by_ids
-from
+from models.reranker import rerank

 def search(query: str) -> list[dict]:
     embedding = get_embedding(query)
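One thing to watch in this hunk: the new get_embedding and rerank both take a Request as their first argument, yet search(query) still calls get_embedding(query). Unless lines outside this hunk thread the request through, the call sites presumably need an update along these lines (our sketch with assumed signatures, not part of the commit):

# Sketch of the request-threading search() would likely need.
from fastapi import Request
from config import TOP_K
from models.embedder import get_embedding
from models.reranker import rerank
from modules.retriever import retrieve_ids
from modules.corpus import fetch_contexts_by_ids

def search(request: Request, query: str) -> list[dict]:
    embedding = get_embedding(request, query)  # now needs the Request
    ids = retrieve_ids(embedding, TOP_K)       # assumed signature
    contexts = fetch_contexts_by_ids(ids)      # hydrate id -> text
    return rerank(request, query, contexts)    # local ONNX cross-encoder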