gabruarya committed on
Commit
e5eb788
1 Parent(s): 922177d

updated for groq

Browse files
Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -11,12 +11,11 @@ from langchain.prompts import PromptTemplate
11
  from langchain.chains import RetrievalQA
12
  import streamlit.components.v1 as components
13
  from langchain_groq import ChatGroq
 
 
14
  import time
15
 
16
  HUGGINGFACEHUB_API_TOKEN = st.secrets['HUGGINGFACEHUB_API_TOKEN']
17
- os.environ["HUGGINGFACEHUB_API_TOKEN"] = HUGGINGFACEHUB_API_TOKEN
18
- speed = 10
19
-
20
 
21
  @dataclass
22
  class Message:
@@ -36,7 +35,7 @@ def initialize_session_state():
36
  if "conversation" not in st.session_state:
37
  llama = LlamaAPI(st.secrets["LlamaAPI"])
38
  model = ChatLlamaAPI(client=llama)
39
- chat = ChatGroq(temperature=0, groq_api_key=st.secrets["Groq_api"], model_name="mixtral-8x7b-32768")
40
 
41
  embeddings = download_hugging_face_embeddings()
42
 
@@ -61,13 +60,22 @@ def initialize_session_state():
61
  """
62
 
63
  PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
64
- chain_type_kwargs = {"prompt": PROMPT}
65
- retrieval_chain = RetrievalQA.from_chain_type(llm=chat,
 
 
 
 
 
 
 
 
66
  chain_type="stuff",
67
  retriever=docsearch.as_retriever(
68
  search_kwargs={'k': 2}),
69
  return_source_documents=True,
70
- chain_type_kwargs=chain_type_kwargs,
 
71
  )
72
 
73
  st.session_state.conversation = retrieval_chain
@@ -78,7 +86,7 @@ def on_click_callback():
78
  response = st.session_state.conversation(
79
  human_prompt
80
  )
81
- llm_response = response['result']
82
  st.session_state.history.append(
83
  Message("👤 Human", human_prompt)
84
  )
 
11
  from langchain.chains import RetrievalQA
12
  import streamlit.components.v1 as components
13
  from langchain_groq import ChatGroq
14
+ from langchain.chains import ConversationalRetrievalChain
15
+ from langchain.memory import ChatMessageHistory, ConversationBufferMemory
16
  import time
17
 
18
  HUGGINGFACEHUB_API_TOKEN = st.secrets['HUGGINGFACEHUB_API_TOKEN']
 
 
 
19
 
20
  @dataclass
21
  class Message:
 
35
  if "conversation" not in st.session_state:
36
  llama = LlamaAPI(st.secrets["LlamaAPI"])
37
  model = ChatLlamaAPI(client=llama)
38
+ chat = ChatGroq(temperature=0.5, groq_api_key=st.secrets["Groq_api"], model_name="mixtral-8x7b-32768")
39
 
40
  embeddings = download_hugging_face_embeddings()
41
 
 
60
  """
61
 
62
  PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
63
+
64
+ #chain_type_kwargs = {"prompt": PROMPT}
65
+ message_history = ChatMessageHistory()
66
+ memory = ConversationBufferMemory(
67
+ memory_key="chat_history",
68
+ output_key="answer",
69
+ chat_memory=message_history,
70
+ return_messages=True,
71
+ )
72
+ retrieval_chain = ConversationalRetrievalChain.from_llm(llm=chat,
73
  chain_type="stuff",
74
  retriever=docsearch.as_retriever(
75
  search_kwargs={'k': 2}),
76
  return_source_documents=True,
77
+ combine_docs_chain_kwargs={"prompt": PROMPT},
78
+ memory= memory
79
  )
80
 
81
  st.session_state.conversation = retrieval_chain
 
86
  response = st.session_state.conversation(
87
  human_prompt
88
  )
89
+ llm_response = response['answer']
90
  st.session_state.history.append(
91
  Message("👤 Human", human_prompt)
92
  )