Spaces:
Sleeping
Sleeping
Rajan Sharma
commited on
Update rag.py
Browse files
rag.py
CHANGED
|
@@ -1,26 +1,20 @@
|
|
| 1 |
from typing import List, Tuple
|
| 2 |
import numpy as np
|
| 3 |
-
import
|
| 4 |
-
from settings import COHERE_API_KEY, COHERE_EMBED_MODEL
|
| 5 |
|
| 6 |
class RAGIndex:
|
| 7 |
def __init__(self):
|
| 8 |
-
self.client = cohere.Client(api_key=COHERE_API_KEY) if COHERE_API_KEY else None
|
| 9 |
self.texts: List[str] = []
|
| 10 |
self.vecs: np.ndarray | None = None
|
| 11 |
|
| 12 |
-
def _embed(self, texts: List[str]) -> np.ndarray:
|
| 13 |
-
if not texts: return np.zeros((0, 384), dtype="float32")
|
| 14 |
-
if not self.client:
|
| 15 |
-
# Fallback: random embeddings (avoid crash; not ideal)
|
| 16 |
-
return np.random.normal(size=(len(texts), 384)).astype("float32")
|
| 17 |
-
resp = self.client.embed(texts=texts, model=COHERE_EMBED_MODEL)
|
| 18 |
-
vecs = np.array(getattr(resp, "embeddings", []) or getattr(resp, "data", []), dtype="float32")
|
| 19 |
-
return vecs
|
| 20 |
-
|
| 21 |
def add(self, chunks: List[str]):
|
| 22 |
if not chunks: return
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
if self.vecs is None:
|
| 25 |
self.vecs = new_vecs
|
| 26 |
self.texts = list(chunks)
|
|
@@ -30,7 +24,9 @@ class RAGIndex:
|
|
| 30 |
|
| 31 |
def retrieve(self, query: str, k: int = 6) -> List[Tuple[str, float]]:
|
| 32 |
if not self.texts: return []
|
| 33 |
-
|
|
|
|
|
|
|
| 34 |
sims = (self.vecs @ qv) / (np.linalg.norm(self.vecs, axis=1) * (np.linalg.norm(qv) + 1e-9))
|
| 35 |
idx = np.argsort(-sims)[:k]
|
| 36 |
return [(self.texts[i], float(sims[i])) for i in idx]
|
|
|
|
| 1 |
from typing import List, Tuple
|
| 2 |
import numpy as np
|
| 3 |
+
from llm_router import cohere_embed
|
|
|
|
| 4 |
|
| 5 |
class RAGIndex:
|
| 6 |
def __init__(self):
|
|
|
|
| 7 |
self.texts: List[str] = []
|
| 8 |
self.vecs: np.ndarray | None = None
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
def add(self, chunks: List[str]):
|
| 11 |
if not chunks: return
|
| 12 |
+
new_vecs_list = cohere_embed(chunks)
|
| 13 |
+
if not new_vecs_list:
|
| 14 |
+
# fallback: random to avoid crash (not ideal for accuracy)
|
| 15 |
+
new_vecs = np.random.normal(size=(len(chunks), 384)).astype("float32")
|
| 16 |
+
else:
|
| 17 |
+
new_vecs = np.array(new_vecs_list, dtype="float32")
|
| 18 |
if self.vecs is None:
|
| 19 |
self.vecs = new_vecs
|
| 20 |
self.texts = list(chunks)
|
|
|
|
| 24 |
|
| 25 |
def retrieve(self, query: str, k: int = 6) -> List[Tuple[str, float]]:
|
| 26 |
if not self.texts: return []
|
| 27 |
+
qv_list = cohere_embed([query])
|
| 28 |
+
if not qv_list: return []
|
| 29 |
+
qv = np.array(qv_list[0], dtype="float32")
|
| 30 |
sims = (self.vecs @ qv) / (np.linalg.norm(self.vecs, axis=1) * (np.linalg.norm(qv) + 1e-9))
|
| 31 |
idx = np.argsort(-sims)[:k]
|
| 32 |
return [(self.texts[i], float(sims[i])) for i in idx]
|