Arif
committed on
Commit
·
d3aefee
1
Parent(s):
be05fd6
Updated device selection mps/cpu/gpu
Browse files- src/retrieval/rag_chain.py +19 -3
src/retrieval/rag_chain.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
# src/retrieval/rag_chain.py
|
| 2 |
import sys
|
| 3 |
import os
|
| 4 |
from dotenv import load_dotenv
|
|
@@ -21,16 +20,33 @@ DB_PATH = os.getenv("CHROMA_DB_PATH", "data/chroma_db")
|
|
| 21 |
COLLECTION_NAME = os.getenv("CHROMA_COLLECTION_NAME", "rag_experiments")
|
| 22 |
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
def format_docs(docs):
    """Join the ``page_content`` of every document, separated by a blank line."""
    contents = [doc.page_content for doc in docs]
    return "\n\n".join(contents)
|
| 26 |
|
| 27 |
def build_rag_chain():
|
| 28 |
"""Builds and returns the RAG chain using LCEL."""
|
| 29 |
|
| 30 |
-
# 1. Initialize Embeddings
|
| 31 |
embeddings = HuggingFaceEmbeddings(
|
| 32 |
model_name=EMBEDDING_MODEL,
|
| 33 |
-
model_kwargs={'device':
|
| 34 |
)
|
| 35 |
|
| 36 |
# 2. Initialize Retriever
|
|
|
|
|
|
|
| 1 |
import sys
|
| 2 |
import os
|
| 3 |
from dotenv import load_dotenv
|
|
|
|
| 20 |
COLLECTION_NAME = os.getenv("CHROMA_COLLECTION_NAME", "rag_experiments")
|
| 21 |
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
|
| 22 |
|
| 23 |
+
# --- CRITICAL FIX: Detect Device ---
|
| 24 |
+
def get_device():
    """Return the torch device string to use for the embedding model.

    Returns:
        "mps" on macOS when the installed PyTorch build reports the Metal
        backend as usable; "cpu" in every other case (Linux, Windows, other
        platforms, or a Mac on which MPS is unavailable).
    """
    import platform

    if platform.system() != "Darwin":
        # Linux (HF, Cloud), Windows and anything else run on the CPU.
        return "cpu"

    # Being on macOS is not enough: Intel Macs, older macOS releases and
    # torch builds without Metal support cannot create an "mps" device, so
    # confirm availability before selecting it.
    try:
        import torch
    except ImportError:
        return "cpu"

    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


# Resolved once at import time; used as the embedding model's device.
DEVICE = get_device()
print(f"🖥️ Detected Platform: {DEVICE}")
|
| 38 |
+
# --------------------------------
|
| 39 |
+
|
| 40 |
def format_docs(docs):
    """Join the ``page_content`` of every document, separated by a blank line."""
    contents = [doc.page_content for doc in docs]
    return "\n\n".join(contents)
|
| 42 |
|
| 43 |
def build_rag_chain():
|
| 44 |
"""Builds and returns the RAG chain using LCEL."""
|
| 45 |
|
| 46 |
+
# 1. Initialize Embeddings with detected device
|
| 47 |
embeddings = HuggingFaceEmbeddings(
|
| 48 |
model_name=EMBEDDING_MODEL,
|
| 49 |
+
model_kwargs={'device': DEVICE} # Use detected device
|
| 50 |
)
|
| 51 |
|
| 52 |
# 2. Initialize Retriever
|