Phoenix21 commited on
Commit
7fb6500
·
verified ·
1 Parent(s): 4ac5e81

Create classification_chain.py

Browse files
Files changed (1) hide show
  1. classification_chain.py +24 -0
classification_chain.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # classification_chain.py
2
+ import os
3
+ from langchain.chains import LLMChain
4
+ from langchain_groq import ChatGroq
5
+
6
+ # We'll import the classification_prompt from prompts.py
7
+ from prompts import classification_prompt
8
+
9
+ def get_classification_chain() -> LLMChain:
10
+ """
11
+ Builds the classification chain (LLMChain) using ChatGroq and the classification prompt.
12
+ """
13
+ # Initialize the ChatGroq model (Gemma2-9b-It) with your GROQ_API_KEY
14
+ chat_groq_model = ChatGroq(
15
+ model="Gemma2-9b-It",
16
+ groq_api_key=os.environ["GROQ_API_KEY"] # must be set in environment
17
+ )
18
+
19
+ # Build an LLMChain
20
+ classification_chain = LLMChain(
21
+ llm=chat_groq_model,
22
+ prompt=classification_prompt
23
+ )
24
+ return classification_chain