Sandaruth commited on
Commit
2ffda8f
1 Parent(s): 870ee5f

update model

Browse files
Files changed (3) hide show
  1. Retrieval.py +34 -0
  2. app.py +19 -15
  3. model.py +1 -11
Retrieval.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from model import llm, vectorstore, splitter, embedding, QA_PROMPT
3
+
4
+
5
+ # Chain for Web
6
+ from langchain.chains import RetrievalQA
7
+
8
+ bsic_chain = RetrievalQA.from_chain_type(
9
+ llm=llm,
10
+ chain_type="stuff",
11
+ retriever = vectorstore.as_retriever(search_kwargs={"k": 4}),
12
+ return_source_documents= True,
13
+ input_key="question",
14
+ chain_type_kwargs={"prompt": QA_PROMPT},
15
+ )
16
+
17
+
18
+
19
+ from langchain.retrievers.multi_query import MultiQueryRetriever
20
+ # from kk import MultiQueryRetriever
21
+
22
+ retriever_from_llm = MultiQueryRetriever.from_llm(
23
+ retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),
24
+ llm=llm,
25
+ )
26
+
27
+ multiQuery_chain = RetrievalQA.from_chain_type(
28
+ llm=llm,
29
+ chain_type="stuff",
30
+ retriever = retriever_from_llm,
31
+ return_source_documents= True,
32
+ input_key="question",
33
+ chain_type_kwargs={"prompt": QA_PROMPT},
34
+ )
app.py CHANGED
@@ -1,10 +1,10 @@
1
  import streamlit as st
2
- from model import Web_qa
3
  import time
4
 
5
  from htmlTemplates import css, bot_template, user_template, source_template
6
 
7
- st.set_page_config(page_title="Chat with ATrad",page_icon=":currency_exchange:")
8
  st.write(css, unsafe_allow_html=True)
9
 
10
  def main():
@@ -17,8 +17,11 @@ def main():
17
  4. Source documents will be displayed in the sidebar.
18
  """)
19
 
 
 
 
 
20
  # Button to connect to Google link ------------------------------------------------
21
-
22
  st.sidebar.markdown('<a href="https://drive.google.com/drive/folders/13v6LsaYH9wEwvqVtlLG1U4OiUHgZ7hY4?usp=sharing" target="_blank" style="display: inline-block;'
23
  'background-color: #475063; color: white; padding: 10px 20px; text-align: center;border: 1px solid white;'
24
  'text-decoration: none; cursor: pointer; border-radius: 5px;">Sources</a>',
@@ -27,8 +30,8 @@ def main():
27
  st.title("ATrad Chat App")
28
 
29
  # Chat area -----------------------------------------------------------------------
30
-
31
  user_input = st.text_input("", key="user_input",placeholder="Type your question here...")
 
32
  # JavaScript code to submit the form on Enter key press
33
  js_submit = f"""
34
  document.addEventListener("keydown", function(event) {{
@@ -38,31 +41,32 @@ def main():
38
  }});
39
  """
40
  st.markdown(f'<script>{js_submit}</script>', unsafe_allow_html=True)
 
41
  if st.button("Send"):
42
  if user_input:
43
-
44
  with st.spinner('Waiting for response...'):
45
-
46
  # Add bot response here (you can replace this with your bot logic)
47
- response, metadata, source_documents = generate_bot_response(user_input)
48
- st.write(user_template.replace(
49
- "{{MSG}}", user_input), unsafe_allow_html=True)
50
- st.write(bot_template.replace(
51
- "{{MSG}}", response ), unsafe_allow_html=True)
52
 
53
  # Source documents
54
- print("metadata", metadata)
55
  st.sidebar.title("Source Documents")
56
  for i, doc in enumerate(source_documents, 1):
57
- tit=metadata[i-1]["source"].split("\\")[-1]
58
  with st.sidebar.expander(f"{tit}"):
59
  st.write(doc) # Assuming the Document object can be directly written to display its content
60
 
61
- def generate_bot_response(user_input):
62
  # Simple bot logic (replace with your actual bot logic)
63
  start_time = time.time()
64
  print(f"User Input: {user_input}")
65
- res = Web_qa(user_input)
 
 
 
 
 
66
  response = res['result']
67
  metadata = [i.metadata for i in res.get("source_documents", [])]
68
  end_time = time.time()
 
1
  import streamlit as st
2
+ from Retrieval import bsic_chain, multiQuery_chain
3
  import time
4
 
5
  from htmlTemplates import css, bot_template, user_template, source_template
6
 
7
+ st.set_page_config(page_title="Chat with ATrad", page_icon=":currency_exchange:")
8
  st.write(css, unsafe_allow_html=True)
9
 
10
  def main():
 
17
  4. Source documents will be displayed in the sidebar.
18
  """)
19
 
20
+ # Dropdown to select model --------------------------------------------------------
21
+ model_selection = st.sidebar.selectbox("Select Model", ["Basic", "MultiQuery"])
22
+ print(model_selection)
23
+
24
  # Button to connect to Google link ------------------------------------------------
 
25
  st.sidebar.markdown('<a href="https://drive.google.com/drive/folders/13v6LsaYH9wEwvqVtlLG1U4OiUHgZ7hY4?usp=sharing" target="_blank" style="display: inline-block;'
26
  'background-color: #475063; color: white; padding: 10px 20px; text-align: center;border: 1px solid white;'
27
  'text-decoration: none; cursor: pointer; border-radius: 5px;">Sources</a>',
 
30
  st.title("ATrad Chat App")
31
 
32
  # Chat area -----------------------------------------------------------------------
 
33
  user_input = st.text_input("", key="user_input",placeholder="Type your question here...")
34
+
35
  # JavaScript code to submit the form on Enter key press
36
  js_submit = f"""
37
  document.addEventListener("keydown", function(event) {{
 
41
  }});
42
  """
43
  st.markdown(f'<script>{js_submit}</script>', unsafe_allow_html=True)
44
+
45
  if st.button("Send"):
46
  if user_input:
 
47
  with st.spinner('Waiting for response...'):
 
48
  # Add bot response here (you can replace this with your bot logic)
49
+ response, metadata, source_documents = generate_bot_response(user_input, model_selection)
50
+ st.write(user_template.replace("{{MSG}}", user_input), unsafe_allow_html=True)
51
+ st.write(bot_template.replace("{{MSG}}", response), unsafe_allow_html=True)
 
 
52
 
53
  # Source documents
 
54
  st.sidebar.title("Source Documents")
55
  for i, doc in enumerate(source_documents, 1):
56
+ tit = metadata[i-1]["source"].split("\\")[-1]
57
  with st.sidebar.expander(f"{tit}"):
58
  st.write(doc) # Assuming the Document object can be directly written to display its content
59
 
60
+ def generate_bot_response(user_input, model):
61
  # Simple bot logic (replace with your actual bot logic)
62
  start_time = time.time()
63
  print(f"User Input: {user_input}")
64
+
65
+ if model == "Basic":
66
+ res = bsic_chain(user_input)
67
+ elif model == "MultiQuery":
68
+ res = multiQuery_chain(user_input)
69
+
70
  response = res['result']
71
  metadata = [i.metadata for i in res.get("source_documents", [])]
72
  end_time = time.time()
model.py CHANGED
@@ -68,15 +68,5 @@ from langchain.prompts import PromptTemplate
68
  QA_PROMPT = PromptTemplate(input_variables=["context", "question"],template=qa_template_V2,)
69
 
70
 
71
- # Chain for Web
72
- from langchain.chains import RetrievalQA
73
-
74
- Web_qa = RetrievalQA.from_chain_type(
75
- llm=llm,
76
- chain_type="stuff",
77
- retriever = vectorstore.as_retriever(search_kwargs={"k": 4}),
78
- return_source_documents= True,
79
- input_key="question",
80
- chain_type_kwargs={"prompt": QA_PROMPT},
81
- )
82
 
 
68
  QA_PROMPT = PromptTemplate(input_variables=["context", "question"],template=qa_template_V2,)
69
 
70
 
71
+
 
 
 
 
 
 
 
 
 
 
72