Rajan Sharma committed on
Commit
3982b77
·
verified ·
1 Parent(s): 521ffa1

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +10 -14
rag.py CHANGED
@@ -1,26 +1,20 @@
1
  from typing import List, Tuple
2
  import numpy as np
3
- import cohere
4
- from settings import COHERE_API_KEY, COHERE_EMBED_MODEL
5
 
6
  class RAGIndex:
7
  def __init__(self):
8
- self.client = cohere.Client(api_key=COHERE_API_KEY) if COHERE_API_KEY else None
9
  self.texts: List[str] = []
10
  self.vecs: np.ndarray | None = None
11
 
12
- def _embed(self, texts: List[str]) -> np.ndarray:
13
- if not texts: return np.zeros((0, 384), dtype="float32")
14
- if not self.client:
15
- # Fallback: random embeddings (avoid crash; not ideal)
16
- return np.random.normal(size=(len(texts), 384)).astype("float32")
17
- resp = self.client.embed(texts=texts, model=COHERE_EMBED_MODEL)
18
- vecs = np.array(getattr(resp, "embeddings", []) or getattr(resp, "data", []), dtype="float32")
19
- return vecs
20
-
21
  def add(self, chunks: List[str]):
22
  if not chunks: return
23
- new_vecs = self._embed(chunks)
 
 
 
 
 
24
  if self.vecs is None:
25
  self.vecs = new_vecs
26
  self.texts = list(chunks)
@@ -30,7 +24,9 @@ class RAGIndex:
30
 
31
  def retrieve(self, query: str, k: int = 6) -> List[Tuple[str, float]]:
32
  if not self.texts: return []
33
- qv = self._embed([query])[0]
 
 
34
  sims = (self.vecs @ qv) / (np.linalg.norm(self.vecs, axis=1) * (np.linalg.norm(qv) + 1e-9))
35
  idx = np.argsort(-sims)[:k]
36
  return [(self.texts[i], float(sims[i])) for i in idx]
 
1
  from typing import List, Tuple
2
  import numpy as np
3
+ from llm_router import cohere_embed
 
4
 
5
  class RAGIndex:
6
  def __init__(self):
 
7
  self.texts: List[str] = []
8
  self.vecs: np.ndarray | None = None
9
 
 
 
 
 
 
 
 
 
 
10
  def add(self, chunks: List[str]):
11
  if not chunks: return
12
+ new_vecs_list = cohere_embed(chunks)
13
+ if not new_vecs_list:
14
+ # fallback: random to avoid crash (not ideal for accuracy)
15
+ new_vecs = np.random.normal(size=(len(chunks), 384)).astype("float32")
16
+ else:
17
+ new_vecs = np.array(new_vecs_list, dtype="float32")
18
  if self.vecs is None:
19
  self.vecs = new_vecs
20
  self.texts = list(chunks)
 
24
 
25
  def retrieve(self, query: str, k: int = 6) -> List[Tuple[str, float]]:
26
  if not self.texts: return []
27
+ qv_list = cohere_embed([query])
28
+ if not qv_list: return []
29
+ qv = np.array(qv_list[0], dtype="float32")
30
  sims = (self.vecs @ qv) / (np.linalg.norm(self.vecs, axis=1) * (np.linalg.norm(qv) + 1e-9))
31
  idx = np.argsort(-sims)[:k]
32
  return [(self.texts[i], float(sims[i])) for i in idx]