0-Parth-D committed on
Commit
143bd7b
·
1 Parent(s): 55bac67

Added compatibility to cloud models

Browse files
requirements.txt CHANGED
@@ -7,6 +7,7 @@ langchain-text-splitters
7
  langchain-chroma
8
  sentence-transformers
9
  python-dotenv
 
10
 
11
  fastapi
12
  uvicorn
 
7
  langchain-chroma
8
  sentence-transformers
9
  python-dotenv
10
+ langchain-groq
11
 
12
  fastapi
13
  uvicorn
src/rag_code_assistant/agent.py CHANGED
@@ -12,6 +12,7 @@ from pydantic import BaseModel
12
 
13
  from langchain_pinecone import PineconeVectorStore # Changed from Chroma
14
  from langchain_ollama import ChatOllama
 
15
  from langchain_huggingface import HuggingFaceEmbeddings
16
  from langchain_core.tools.retriever import create_retriever_tool
17
  from langchain.agents import create_agent
@@ -33,10 +34,35 @@ def load_vectorstore():
33
  )
34
 
35
  def load_llm():
36
- return ChatOllama(
37
- model="llama3.1",
38
- temperature=0.1,
39
- base_url=os.environ["OLLAMA_BASE_URL"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  )
41
 
42
  def load_retriever(vectorstore):
 
12
 
13
  from langchain_pinecone import PineconeVectorStore # Changed from Chroma
14
  from langchain_ollama import ChatOllama
15
+ from langchain_groq import ChatGroq
16
  from langchain_huggingface import HuggingFaceEmbeddings
17
  from langchain_core.tools.retriever import create_retriever_tool
18
  from langchain.agents import create_agent
 
34
  )
35
 
36
def load_llm():
    """
    Load the chat LLM with fallback logic.

    Order of preference:
      1. Local Ollama (development) — used when OLLAMA_BASE_URL is set.
      2. Groq Cloud (production, e.g. Hugging Face Spaces) — used when
         GROQ_API_KEY is set.

    Returns:
        A ChatOllama or ChatGroq instance.

    Raises:
        ValueError: if neither OLLAMA_BASE_URL nor GROQ_API_KEY is set.
    """
    # Use .get() so a missing variable yields None instead of raising
    # KeyError — otherwise the fallback below would be unreachable.
    ollama_url = os.environ.get("OLLAMA_BASE_URL")

    # If OLLAMA_BASE_URL is set, use local Ollama (for demo purposes)
    if ollama_url:
        print("🔧 Using local Ollama LLM (Development Mode)")
        return ChatOllama(
            model="llama3.1",
            temperature=0.1,
            base_url=ollama_url,
        )

    # Otherwise, use Groq Cloud (for production on Hugging Face).
    # .get() again: os.environ[...] would raise KeyError here and the
    # friendly ValueError below would never fire.
    groq_api_key = os.environ.get("GROQ_API_KEY")
    if not groq_api_key:
        raise ValueError(
            "Neither OLLAMA_BASE_URL nor GROQ_API_KEY found! "
            "Please set one in your environment variables."
        )

    print("☁️ Using Groq Cloud LLM (Production Mode)")
    return ChatGroq(
        api_key=groq_api_key,
        model_name="llama-3.3-70b-versatile",  # Fast, smart, and free!
        temperature=0.1
    )
67
 
68
  def load_retriever(vectorstore):