sonali-tamhankar committed on
Commit 1128416
1 Parent(s): bfe7bca

Delete model.py

Files changed (1)
  1. model.py +0 -94
model.py DELETED
@@ -1,94 +0,0 @@
- from langchain.document_loaders import PyPDFLoader, DirectoryLoader
- from langchain import PromptTemplate
- from langchain.embeddings import HuggingFaceEmbeddings
- from langchain.vectorstores import FAISS
- from langchain.llms import CTransformers
- from langchain.chains import RetrievalQA
- import chainlit as cl
-
- DB_FAISS_PATH = '.'
-
- custom_prompt_template = """Use the following pieces of information to answer the user's question.
- If you don't know the answer, just say that you don't know, don't try to make up an answer.
-
- Context: {context}
- Question: {question}
-
- Only return the helpful answer below and nothing else.
- Helpful answer:
- """
-
- def set_custom_prompt():
-     """
-     Prompt template for QA retrieval for each vectorstore
-     """
-     prompt = PromptTemplate(template=custom_prompt_template,
-                             input_variables=['context', 'question'])
-     return prompt
-
- #Retrieval QA Chain
- def retrieval_qa_chain(llm, prompt, db):
-     qa_chain = RetrievalQA.from_chain_type(llm=llm,
-                                            chain_type='stuff',
-                                            retriever=db.as_retriever(search_kwargs={'k': 2}),
-                                            return_source_documents=True,
-                                            chain_type_kwargs={'prompt': prompt}
-                                            )
-     return qa_chain
-
- #Loading the model
- def load_llm():
-     # Load the locally downloaded model here
-     llm = CTransformers(
-         model = "llama-2-7b-chat.ggmlv3.q8_0.bin",
-         model_type="llama",
-         max_new_tokens = 512,
-         temperature = 0.5
-     )
-     return llm
-
- #QA Model Function
- def qa_bot():
-     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
-                                        model_kwargs={'device': 'cpu'})
-     db = FAISS.load_local(DB_FAISS_PATH, embeddings)
-     llm = load_llm()
-     qa_prompt = set_custom_prompt()
-     qa = retrieval_qa_chain(llm, qa_prompt, db)
-
-     return qa
-
- #output function
- def final_result(query):
-     qa_result = qa_bot()
-     response = qa_result({'query': query})
-     return response
-
- #chainlit code
- @cl.on_chat_start
- async def start():
-     chain = qa_bot()
-     msg = cl.Message(content="Starting the bot...")
-     await msg.send()
-     msg.content = "Hi, Welcome to Medical Bot. What is your query?"
-     await msg.update()
-
-     cl.user_session.set("chain", chain)
-
- @cl.on_message
- async def main(message):
-     chain = cl.user_session.get("chain")
-     cb = cl.AsyncLangchainCallbackHandler(
-         stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
-     )
-     cb.answer_reached = True
-     res = await chain.acall(message, callbacks=[cb])
-     answer = res["result"]
-     sources = res["source_documents"]
-
-     if sources:
-         answer += f"\nSources:" + str(sources)
-     else:
-         answer += "\nNo sources found"
-
-     await cl.Message(content=answer).send()