Sakil committed on
Commit
96563ef
1 Parent(s): 7618d7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -36,22 +36,22 @@ def retrieval_qa_chain(llm, prompt, db):
36
  return qa_chain
37
 
38
  # Loading the model
39
- def load_llm():
40
  # Load the locally downloaded model here
41
  llm = CTransformers(
42
  model="llama-2-7b-chat.ggmlv3.q8_0.bin",
43
  model_type="llama",
44
- max_new_tokens=512,
45
- temperature=0.5
46
  )
47
  return llm
48
 
49
  # QA Model Function
50
- def qa_bot():
51
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
52
  model_kwargs={'device': 'cpu'})
53
  db = FAISS.load_local(DB_FAISS_PATH, embeddings)
54
- llm = load_llm()
55
  qa_prompt = set_custom_prompt()
56
  qa = retrieval_qa_chain(llm, qa_prompt, db)
57
 
@@ -60,7 +60,10 @@ def qa_bot():
60
  def main():
61
  st.title("AI ChatBot LLM")
62
 
63
- qa_result = qa_bot()
 
 
 
64
 
65
  user_input = st.text_input("Enter your question:")
66
 
@@ -74,6 +77,9 @@ def main():
74
  st.write("Sources:", sources)
75
  else:
76
  st.write("No sources found")
 
 
 
77
 
78
  if __name__ == "__main__":
79
- main()
 
36
  return qa_chain
37
 
38
  # Loading the model
39
+ def load_llm(max_new_tokens, temperature):
40
  # Load the locally downloaded model here
41
  llm = CTransformers(
42
  model="llama-2-7b-chat.ggmlv3.q8_0.bin",
43
  model_type="llama",
44
+ max_new_tokens=max_new_tokens,
45
+ temperature=temperature
46
  )
47
  return llm
48
 
49
  # QA Model Function
50
+ def qa_bot(max_new_tokens, temperature):
51
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
52
  model_kwargs={'device': 'cpu'})
53
  db = FAISS.load_local(DB_FAISS_PATH, embeddings)
54
+ llm = load_llm(max_new_tokens, temperature)
55
  qa_prompt = set_custom_prompt()
56
  qa = retrieval_qa_chain(llm, qa_prompt, db)
57
 
 
60
  def main():
61
  st.title("AI ChatBot LLM")
62
 
63
+ max_new_tokens = st.slider("Max New Tokens", min_value=1, max_value=1000, value=512)
64
+ temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, step=0.1, value=0.5)
65
+
66
+ qa_result = qa_bot(max_new_tokens, temperature)
67
 
68
  user_input = st.text_input("Enter your question:")
69
 
 
77
  st.write("Sources:", sources)
78
  else:
79
  st.write("No sources found")
80
+
81
+ if st.button("Clear"):
82
+ st.text_input("Enter your question:", value="")
83
 
84
  if __name__ == "__main__":
85
+ main()