m97j committed on
Commit
2aa7bf4
·
1 Parent(s): 307f6b8

Initial codes commit

Browse files
app.py CHANGED
@@ -4,14 +4,16 @@ from fastapi import FastAPI, Request, Form
4
  from fastapi.responses import HTMLResponse
5
  from fastapi.templating import Jinja2Templates
6
  from api.endpoints import router
7
- from db.initializer import initialize
 
8
  from service.search import search
9
 
10
  templates = Jinja2Templates(directory="templates")
11
 
12
  @asynccontextmanager
13
- async def lifespan(_app: FastAPI):
14
- initialize()
 
15
  yield
16
 
17
  app = FastAPI(lifespan=lifespan)
 
4
  from fastapi.responses import HTMLResponse
5
  from fastapi.templating import Jinja2Templates
6
  from api.endpoints import router
7
+ from db.initializer import initialize_dbs
8
+ from models.initializer import initialize_models
9
  from service.search import search
10
 
11
  templates = Jinja2Templates(directory="templates")
12
 
13
  @asynccontextmanager
14
+ async def lifespan(app: FastAPI):
15
+ initialize_dbs()
16
+ initialize_models(app)
17
  yield
18
 
19
  app = FastAPI(lifespan=lifespan)
config.py CHANGED
@@ -5,7 +5,7 @@ import os
5
  HF_TOKEN = os.getenv("HF_TOKEN")
6
 
7
  # HF datasets repo info
8
- HF_REPO_ID = os.getenv("HF_REPO_ID", "m97j/pls-datasets")
9
  HF_INDEX_FILE = os.getenv("HF_INDEX_FILE", "faiss/faiss_index_flat.faiss")
10
  HF_IDS_FILE = os.getenv("HF_IDS_FILE", "faiss/vector_ids.npy")
11
 
@@ -18,8 +18,16 @@ HF_CORPUS_SPLIT = os.getenv("HF_CORPUS_SPLIT", "train")
18
  MARKER_DIR = os.getenv("MARKER_DIR", "rag/state")
19
  CORPUS_READY_MARK = os.path.join(MARKER_DIR, ".corpus_ready")
20
 
21
- # Embedding / LLM model
22
- EMBED_MODEL = os.getenv("EMBED_MODEL", "intfloat/multilingual-e5-large")
 
 
 
 
 
 
 
 
23
  TOP_K = int(os.getenv("TOP_K", "5"))
24
 
25
 
 
5
  HF_TOKEN = os.getenv("HF_TOKEN")
6
 
7
  # HF datasets repo info
8
+ HF_DS_REPO_ID = os.getenv("HF_REPO_ID", "m97j/pls-datasets")
9
  HF_INDEX_FILE = os.getenv("HF_INDEX_FILE", "faiss/faiss_index_flat.faiss")
10
  HF_IDS_FILE = os.getenv("HF_IDS_FILE", "faiss/vector_ids.npy")
11
 
 
18
  MARKER_DIR = os.getenv("MARKER_DIR", "rag/state")
19
  CORPUS_READY_MARK = os.path.join(MARKER_DIR, ".corpus_ready")
20
 
21
+ # Embedding model
22
+ HF_MODEL_REPO_ID = os.getenv("HF_MODEL_REPO_ID", "m97j/pragmatic-search")
23
+ EMBED_MODEL = os.getenv("EMBED_MODEL", "model_quantized.onnx")
24
+ EMBED_DIR = os.getenv("EMBED_DIR", "embedder")
25
+
26
+ # Reranking model
27
+ RERANK_MODEL = os.getenv("RERANK_MODEL", "model_quantized.onnx")
28
+ RERANK_DIR = os.getenv("RERANK_DIR", "reranker")
29
+
30
+ # Retrieval settings
31
  TOP_K = int(os.getenv("TOP_K", "5"))
32
 
33
 
db/initializer.py CHANGED
@@ -2,7 +2,7 @@
2
  import faiss
3
  import numpy as np
4
  from huggingface_hub import hf_hub_download
5
- from config import HF_REPO_ID, HF_INDEX_FILE, HF_IDS_FILE
6
  from modules.retriever import set_index
7
  from modules import corpus
8
 
@@ -10,8 +10,8 @@ _vector_ids = None
10
 
11
  def _load_index_in_memory():
12
  """HF Hub์—์„œ ์ธ๋ฑ์Šค/ID ๋งคํ•‘์„ ๋ฐ›์•„ ๋ฉ”๋ชจ๋ฆฌ์— ๋กœ๋“œ"""
13
- index_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_INDEX_FILE, repo_type="dataset")
14
- ids_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_IDS_FILE, repo_type="dataset")
15
  index = faiss.read_index(index_path)
16
  set_index(index)
17
  global _vector_ids
@@ -21,7 +21,7 @@ def get_vector_ids():
21
  global _vector_ids
22
  return _vector_ids
23
 
24
- def initialize():
25
  # 1) 코퍼스 준비 (최초 1회만 다운로드)
26
  corpus.prepare_corpus()
27
  # 2) 인덱스/ID 매핑 메모리 로드
 
2
  import faiss
3
  import numpy as np
4
  from huggingface_hub import hf_hub_download
5
+ from config import HF_DS_REPO_ID, HF_INDEX_FILE, HF_IDS_FILE
6
  from modules.retriever import set_index
7
  from modules import corpus
8
 
 
10
 
11
  def _load_index_in_memory():
12
  """HF Hub์—์„œ ์ธ๋ฑ์Šค/ID ๋งคํ•‘์„ ๋ฐ›์•„ ๋ฉ”๋ชจ๋ฆฌ์— ๋กœ๋“œ"""
13
+ index_path = hf_hub_download(repo_id=HF_DS_REPO_ID, filename=HF_INDEX_FILE, repo_type="dataset")
14
+ ids_path = hf_hub_download(repo_id=HF_DS_REPO_ID, filename=HF_IDS_FILE, repo_type="dataset")
15
  index = faiss.read_index(index_path)
16
  set_index(index)
17
  global _vector_ids
 
21
  global _vector_ids
22
  return _vector_ids
23
 
24
+ def initialize_dbs():
25
  # 1) 코퍼스 준비 (최초 1회만 다운로드)
26
  corpus.prepare_corpus()
27
  # 2) 인덱스/ID 매핑 메모리 로드
models/embedder.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # rag/models/embedder.py
2
+ from typing import List
3
+ import numpy as np
4
+ import onnxruntime as ort
5
+ from fastapi import Request
6
+
7
+ def _l2_normalize(vec: np.ndarray) -> List[float]:
8
+ norm = np.linalg.norm(vec) or 1.0
9
+ return (vec / norm).tolist()
10
+
11
+ def get_embedding(request: Request, text: str) -> List[float]:
12
+ """
13
+ request.app.state.embedder_sess : ONNX Runtime InferenceSession
14
+ request.app.state.embedder_tokenizer : 토크나이저
15
+ """
16
+ tokenizer = request.app.state.embedder_tokenizer
17
+ sess: ort.InferenceSession = request.app.state.embedder_sess
18
+
19
+ inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True, max_length=256)
20
+ ort_inputs = {k: v for k, v in inputs.items()}
21
+ ort_outs = sess.run(None, ort_inputs)
22
+ # 일반적으로 첫 번째 출력이 [batch, dim] 임베딩
23
+ vec = ort_outs[0][0]
24
+ return _l2_normalize(vec)
25
+
26
+
models/initializer.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # rag/models/initializer.py
2
+ from transformers import AutoTokenizer
3
+ import onnxruntime as ort
4
+ from huggingface_hub import hf_hub_download
5
+ from fastapi import FastAPI
6
+ from config import HF_MODEL_REPO_ID, EMBED_MODEL, EMBED_DIR, RERANK_MODEL, RERANK_DIR
7
+
8
+ def initialize_models(app: FastAPI):
9
+ # Embedder
10
+ embedder_tokenizer = AutoTokenizer.from_pretrained(
11
+ HF_MODEL_REPO_ID,
12
+ subfolder=EMBED_DIR # 토크나이저 관련 파일이 embedder/ 안에 있으므로 지정
13
+ )
14
+ embedder_model_path = hf_hub_download(
15
+ repo_id=HF_MODEL_REPO_ID,
16
+ filename=EMBED_MODEL,
17
+ subfolder=EMBED_DIR
18
+ )
19
+ embedder_sess = ort.InferenceSession(embedder_model_path, providers=["CPUExecutionProvider"])
20
+
21
+ # Reranker
22
+ reranker_tokenizer = AutoTokenizer.from_pretrained(
23
+ HF_MODEL_REPO_ID,
24
+ subfolder=RERANK_DIR # 토크나이저 관련 파일이 reranker/ 안에 있으므로 지정
25
+ )
26
+ reranker_model_path = hf_hub_download(
27
+ repo_id=HF_MODEL_REPO_ID,
28
+ filename=RERANK_MODEL,
29
+ subfolder=RERANK_DIR
30
+ )
31
+ reranker_sess = ort.InferenceSession(reranker_model_path, providers=["CPUExecutionProvider"])
32
+
33
+ # FastAPI app.state์— ์ €์žฅ โ†’ ์ „์—ญ ๊ณต์œ 
34
+ app.state.embedder_tokenizer = embedder_tokenizer
35
+ app.state.embedder_sess = embedder_sess
36
+ app.state.reranker_tokenizer = reranker_tokenizer
37
+ app.state.reranker_sess = reranker_sess
{modules → models}/reranker.py RENAMED
@@ -1,39 +1,32 @@
1
- # rag/modules/reranker.py
2
  import os
3
  from typing import List, Dict
4
- from huggingface_hub import InferenceClient
 
5
 
6
- # 환경변수에서 모델명과 토큰 불러오기
7
- HF_TOKEN = os.getenv("HF_TOKEN")
8
- RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-large")
9
-
10
- _client = InferenceClient(model=RERANK_MODEL, token=HF_TOKEN)
11
-
12
- # threshold ๊ฐ’์€ ํ™˜๊ฒฝ๋ณ€์ˆ˜๋‚˜ config์—์„œ ๊ด€๋ฆฌ ๊ฐ€๋Šฅ
13
  THRESHOLD = float(os.getenv("RERANK_THRESHOLD", "0.3"))
14
 
15
- def rerank(query: str, contexts: List[Dict]) -> List[Dict]:
16
  """
 
 
17
  contexts: [{"id": ..., "text": ...}, ...]
18
- 반환: threshold 이상 점수만 포함된 reranked contexts
19
  """
20
  if not contexts:
21
  return []
22
 
23
- # reranker 입력: (query, passage) 쌍 리스트
24
- pairs = [(query, ctx["text"]) for ctx in contexts]
25
 
26
- # Inference API 호출 → 각 쌍에 대한 점수 반환
27
- scores = _client.rerank(inputs=pairs)
 
 
 
28
 
29
- # scores๋Š” [{"score": float}, ...] ํ˜•ํƒœ
30
  for ctx, sc in zip(contexts, scores):
31
- ctx["score"] = sc["score"]
32
 
33
- # 점수 내림차순 정렬
34
  reranked = sorted(contexts, key=lambda x: x["score"], reverse=True)
35
-
36
- # threshold 이상만 필터링
37
  reranked = [c for c in reranked if c["score"] >= THRESHOLD]
38
-
39
  return reranked
 
1
+ # rag/models/reranker.py
2
  import os
3
  from typing import List, Dict
4
+ import onnxruntime as ort
5
+ from fastapi import Request
6
 
 
 
 
 
 
 
 
7
  THRESHOLD = float(os.getenv("RERANK_THRESHOLD", "0.3"))
8
 
9
+ def rerank(request: Request, query: str, contexts: List[Dict]) -> List[Dict]:
10
  """
11
+ request.app.state.reranker_sess : ONNX Runtime InferenceSession
12
+ request.app.state.reranker_tokenizer : 토크나이저
13
  contexts: [{"id": ..., "text": ...}, ...]
 
14
  """
15
  if not contexts:
16
  return []
17
 
18
+ tokenizer = request.app.state.reranker_tokenizer
19
+ sess: ort.InferenceSession = request.app.state.reranker_sess
20
 
21
+ pairs = [(query, ctx["text"]) for ctx in contexts]
22
+ inputs = tokenizer(pairs, return_tensors="np", padding=True, truncation=True, max_length=256)
23
+ ort_inputs = {k: v for k, v in inputs.items()}
24
+ scores = sess.run(None, ort_inputs)[0] # [batch, 1] 형태라고 가정
25
+ scores = scores.squeeze(-1)
26
 
 
27
  for ctx, sc in zip(contexts, scores):
28
+ ctx["score"] = float(sc)
29
 
 
30
  reranked = sorted(contexts, key=lambda x: x["score"], reverse=True)
 
 
31
  reranked = [c for c in reranked if c["score"] >= THRESHOLD]
 
32
  return reranked
modules/embedder.py DELETED
@@ -1,23 +0,0 @@
1
- # rag/modules/embedder.py
2
- import math
3
- from typing import List
4
- from huggingface_hub import InferenceClient
5
- from config import EMBED_MODEL, HF_TOKEN
6
-
7
- # 모델과 토큰 지정
8
- _client = InferenceClient(model=EMBED_MODEL, token=HF_TOKEN)
9
-
10
- def _l2_normalize(vec: List[float]) -> List[float]:
11
- norm = math.sqrt(sum(x * x for x in vec)) or 1.0
12
- return [x / norm for x in vec]
13
-
14
- def get_embedding(text: str) -> List[float]:
15
- # embeddings 태스크를 직접 지정
16
- response = _client.post(
17
- json={"inputs": text},
18
- task="embeddings"
19
- )
20
- # 항상 [batch_size, embedding_dim] 형태 반환
21
- vec = response[0]
22
- return _l2_normalize(vec)
23
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
service/search.py CHANGED
@@ -1,8 +1,8 @@
1
  # rag/service/search.py
2
- from modules.embedder import get_embedding
3
  from modules.retriever import retrieve_ids
4
  from modules.corpus import fetch_contexts_by_ids
5
- from modules.reranker import rerank
6
 
7
  def search(query: str) -> list[dict]:
8
  embedding = get_embedding(query)
 
1
  # rag/service/search.py
2
+ from models.embedder import get_embedding
3
  from modules.retriever import retrieve_ids
4
  from modules.corpus import fetch_contexts_by_ids
5
+ from models.reranker import rerank
6
 
7
  def search(query: str) -> list[dict]:
8
  embedding = get_embedding(query)