AdamyaG commited on
Commit
763f886
1 Parent(s): 21533a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -17
app.py CHANGED
@@ -6,10 +6,7 @@ from langchain_core.messages import AIMessage, HumanMessage
6
  from langchain_community.document_loaders import WebBaseLoader
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
8
  from langchain_community.vectorstores import Chroma
9
- # from langchain_openai import OpenAIEmbeddings, ChatOpenAI
10
- from langchain.llms import HuggingFaceHub
11
  from langchain.embeddings import HuggingFaceEmbeddings
12
- from dotenv import load_dotenv
13
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
14
  from langchain.chains import create_history_aware_retriever, create_retrieval_chain
15
  from langchain.chains.combine_documents import create_stuff_documents_chain
@@ -32,7 +29,11 @@ def get_vectorstore_from_url(url):
32
  return vector_store
33
 
34
  def get_context_retriever_chain(vector_store):
35
- llm = HuggingFaceHub(repo_id = "HuggingFaceH4/zephyr-7b-beta", model_kwargs = {"temperature":0.5, "max_length":512})
 
 
 
 
36
  retriever = vector_store.as_retriever()
37
 
38
  prompt = ChatPromptTemplate.from_messages([
@@ -46,8 +47,11 @@ def get_context_retriever_chain(vector_store):
46
  return retriever_chain
47
 
48
  def get_conversational_rag_chain(retriever_chain):
49
- llm = HuggingFaceHub(repo_id = "HuggingFaceH4/zephyr-7b-beta", model_kwargs = {"temperature":0.5, "max_length":512})
50
-
 
 
 
51
  prompt = ChatPromptTemplate.from_messages([
52
  ("system", "Answer the user's questions based on the below context:\n\n{context}"),
53
  MessagesPlaceholder(variable_name="chat_history"),
@@ -100,14 +104,10 @@ else:
100
 
101
 
102
  # conversation
103
- first = 0
104
- for i, message in enumerate(st.session_state.chat_history):
105
- if first < 3:
106
- first += 1
107
- else:
108
- if isinstance(message, AIMessage):
109
- with st.chat_message("AI"):
110
- st.write(message.content)
111
- elif isinstance(message, HumanMessage):
112
- with st.chat_message("Human"):
113
- st.write(message.content)
 
6
  from langchain_community.document_loaders import WebBaseLoader
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
8
  from langchain_community.vectorstores import Chroma
 
 
9
  from langchain.embeddings import HuggingFaceEmbeddings
 
10
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
11
  from langchain.chains import create_history_aware_retriever, create_retrieval_chain
12
  from langchain.chains.combine_documents import create_stuff_documents_chain
 
29
  return vector_store
30
 
31
  def get_context_retriever_chain(vector_store):
32
+ # llm = HuggingFaceHub(repo_id = "HuggingFaceH4/zephyr-7b-beta", model_kwargs = {"temperature":0.5, "max_length":512})
33
+ llm = HuggingFaceEndpoint(
34
+ repo_id="HuggingFaceH4/zephyr-7b-beta",
35
+ task="text-generation",
36
+ max_new_tokens=512)
37
  retriever = vector_store.as_retriever()
38
 
39
  prompt = ChatPromptTemplate.from_messages([
 
47
  return retriever_chain
48
 
49
  def get_conversational_rag_chain(retriever_chain):
50
+ # llm = HuggingFaceHub(repo_id = "HuggingFaceH4/zephyr-7b-beta", model_kwargs = {"temperature":0.5, "max_length":512})
51
+ llm = HuggingFaceEndpoint(
52
+ repo_id="HuggingFaceH4/zephyr-7b-beta",
53
+ task="text-generation",
54
+ max_new_tokens=512)
55
  prompt = ChatPromptTemplate.from_messages([
56
  ("system", "Answer the user's questions based on the below context:\n\n{context}"),
57
  MessagesPlaceholder(variable_name="chat_history"),
 
104
 
105
 
106
  # conversation
107
+ for message in st.session_state.chat_history:
108
+ if isinstance(message, AIMessage):
109
+ with st.chat_message("AI"):
110
+ st.write(message.content)
111
+ elif isinstance(message, HumanMessage):
112
+ with st.chat_message("Human"):
113
+ st.write(message.content)