Arif committed on
Commit
d3aefee
·
1 Parent(s): be05fd6

Updated device selection mps/cpu/gpu

Browse files
Files changed (1) hide show
  1. src/retrieval/rag_chain.py +19 -3
src/retrieval/rag_chain.py CHANGED
@@ -1,4 +1,3 @@
1
- # src/retrieval/rag_chain.py
2
  import sys
3
  import os
4
  from dotenv import load_dotenv
@@ -21,16 +20,33 @@ DB_PATH = os.getenv("CHROMA_DB_PATH", "data/chroma_db")
21
  COLLECTION_NAME = os.getenv("CHROMA_COLLECTION_NAME", "rag_experiments")
22
  EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def format_docs(docs):
25
  return "\n\n".join(doc.page_content for doc in docs)
26
 
27
  def build_rag_chain():
28
  """Builds and returns the RAG chain using LCEL."""
29
 
30
- # 1. Initialize Embeddings
31
  embeddings = HuggingFaceEmbeddings(
32
  model_name=EMBEDDING_MODEL,
33
- model_kwargs={'device': 'mps'}
34
  )
35
 
36
  # 2. Initialize Retriever
 
 
1
  import sys
2
  import os
3
  from dotenv import load_dotenv
 
20
  COLLECTION_NAME = os.getenv("CHROMA_COLLECTION_NAME", "rag_experiments")
21
  EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
22
 
23
+ # --- CRITICAL FIX: Detect Device ---
24
def get_device():
    """Return the torch device string to run the embedding model on.

    The choice is made by asking torch which backends are actually usable
    instead of guessing from the operating system name:

    * ``"cuda"`` when a CUDA GPU is available,
    * ``"mps"`` when the Apple Metal backend is available,
    * ``"cpu"`` otherwise (including when torch cannot be imported).

    Returns:
        One of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    try:
        # Imported lazily so that merely importing this module does not
        # require torch to be importable.
        import torch
    except ImportError:
        return "cpu"

    if torch.cuda.is_available():
        return "cuda"

    # Not every macOS host has a usable MPS backend (e.g. Intel Macs), and
    # older torch releases have no ``torch.backends.mps`` attribute at all,
    # so look it up defensively before querying it.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"

    return "cpu"
35
+
36
# Resolve the compute device once, at import time; build_rag_chain() passes
# this value to the embedding model.
DEVICE = get_device()
# The value is the device string returned by get_device(), not an OS
# platform name, so label it accordingly.
print(f"🖥️ Using device: {DEVICE}")
38
+ # --------------------------------
39
+
40
  def format_docs(docs):
41
  return "\n\n".join(doc.page_content for doc in docs)
42
 
43
  def build_rag_chain():
44
  """Builds and returns the RAG chain using LCEL."""
45
 
46
+ # 1. Initialize Embeddings with detected device
47
  embeddings = HuggingFaceEmbeddings(
48
  model_name=EMBEDDING_MODEL,
49
+ model_kwargs={'device': DEVICE} # Use detected device
50
  )
51
 
52
  # 2. Initialize Retriever