Jatinydv commited on
Commit
fc55ef6
1 Parent(s): cb40971

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -37
app.py CHANGED
@@ -1,10 +1,10 @@
1
- import gradio as gr
2
  from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
3
  from langchain.prompts import PromptTemplate
4
  from langchain_community.embeddings import HuggingFaceEmbeddings
5
  from langchain_community.vectorstores import FAISS
6
  from langchain_community.llms import CTransformers
7
  from langchain.chains import RetrievalQA
 
8
 
9
  DB_FAISS_PATH = 'vectorstore/db_faiss'
10
 
@@ -19,59 +19,48 @@ Helpful answer:
19
  """
20
 
21
  def set_custom_prompt():
 
 
 
22
  prompt = PromptTemplate(template=custom_prompt_template,
23
  input_variables=['context', 'question'])
24
  return prompt
25
 
26
- def retrieval_qa_chain(llm, prompt, db):
27
- qa_chain = RetrievalQA.from_chain_type(llm=llm,
28
- chain_type='stuff',
29
- retriever=db.as_retriever(search_kwargs={'k': 2}),
30
- return_source_documents=True,
31
- chain_type_kwargs={'prompt': prompt}
32
- )
33
- return qa_chain
34
-
35
  def load_llm():
 
36
  llm = CTransformers(
37
- model="TheBloke/Llama-2-7B-Chat-GGML",
38
  model_type="llama",
39
- max_new_tokens=512,
40
- temperature=0.5
41
  )
42
  return llm
43
 
44
- def qa_bot():
45
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
46
  model_kwargs={'device': 'cpu'})
47
  db = FAISS.load_local(DB_FAISS_PATH, embeddings)
48
  llm = load_llm()
49
  qa_prompt = set_custom_prompt()
50
- qa = retrieval_qa_chain(llm, qa_prompt, db)
51
- return qa
 
 
 
 
 
 
52
 
53
- # Define a function to respond to messages using your QA model
54
- def respond(message, history, system_message, max_tokens, temperature, top_p):
55
- qa_result = qa_bot()
56
- response = qa_result({'query': message})
57
  return response
58
 
59
- # Create a Gradio interface using the respond function
60
- demo = gr.ChatInterface(
61
- respond,
62
- additional_inputs=[
63
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
64
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
65
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
66
- gr.Slider(
67
- minimum=0.1,
68
- maximum=1.0,
69
- value=0.95,
70
- step=0.05,
71
- label="Top-p (nucleus sampling)",
72
- ),
73
- ],
74
  )
75
 
76
- if __name__ == "__main__":
77
- demo.launch()
 
 
 
1
  from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
2
  from langchain.prompts import PromptTemplate
3
  from langchain_community.embeddings import HuggingFaceEmbeddings
4
  from langchain_community.vectorstores import FAISS
5
  from langchain_community.llms import CTransformers
6
  from langchain.chains import RetrievalQA
7
+ import gradio as gr
8
 
9
  DB_FAISS_PATH = 'vectorstore/db_faiss'
10
 
 
19
  """
20
 
21
  def set_custom_prompt():
22
+ """
23
+ Prompt template for QA retrieval for each vectorstore
24
+ """
25
  prompt = PromptTemplate(template=custom_prompt_template,
26
  input_variables=['context', 'question'])
27
  return prompt
28
 
 
 
 
 
 
 
 
 
 
29
  def load_llm():
30
+ # Load the locally downloaded model here
31
  llm = CTransformers(
32
+ model = "TheBloke/Llama-2-7B-Chat-GGML",
33
  model_type="llama",
34
+ max_new_tokens = 512,
35
+ temperature = 0.5
36
  )
37
  return llm
38
 
39
+ def qa_bot(query):
40
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
41
  model_kwargs={'device': 'cpu'})
42
  db = FAISS.load_local(DB_FAISS_PATH, embeddings)
43
  llm = load_llm()
44
  qa_prompt = set_custom_prompt()
45
+ qa_chain = RetrievalQA.from_chain_type(llm=llm,
46
+ chain_type='stuff',
47
+ retriever=db.as_retriever(search_kwargs={'k': 2}),
48
+ return_source_documents=True,
49
+ chain_type_kwargs={'prompt': qa_prompt}
50
+ )
51
+ result = qa_chain({'query': query})
52
+ response = result['answers'][0]['text'] if result['answers'] else "Sorry, I don't have an answer for that."
53
 
 
 
 
 
54
  return response
55
 
56
+ iface = gr.Interface(
57
+ fn=qa_bot,
58
+ inputs="text",
59
+ outputs="text",
60
+ title="Medical Query Bot",
61
+ description="Enter your medical query to get an answer."
 
 
 
 
 
 
 
 
 
62
  )
63
 
64
+ if __name__ == '__main__':
65
+ iface.launch()
66
+