drkareemkamal committed
Commit e9f3fa2 · verified · 1 Parent(s): 7d71f93

Update model.py

Files changed (1):
  model.py +102 -102
model.py CHANGED
@@ -1,102 +1,102 @@
- #from langchain import PromptTemplate
- from langchain_core.prompts import PromptTemplate
-
- from langchain_community.embeddings import HuggingFaceBgeEmbeddings
- from langchain_community.vectorstores import FAISS
- from langchain_community.llms.ctransformers import CTransformers
- #from langchain.chains import RetrievalQA
- from langchain.chains.retrieval_qa.base import RetrievalQA
- import chainlit as cl
- from transformers import AutoModel
-
- DB_FAISS_PATH = 'vectorstores/'
-
- custom_prompt_template = '''
- use the following pieces of information to answer the user's questions/
- If you don't know the answer, please just say that don't know the answer, don't try to make uo an answer.
-
- Content : {}
- Question : {question}
-
- only return the helpful answer below and nothing else.
- '''
-
- def set_custom_prompt():
-     """
-     Prompt template for QA retrieval for vector stores
-     """
-     prompt = PromptTemplate(template = custom_prompt_template,
-                             input_variables = ['context','question'])
-
-     return prompt
-
- def load_llm():
-     llm = CTransformers(
-         model = 'TheBloke/Llama-2-7B-Chat-GGML',
-         #model = AutoModel.from_pretrained("TheBloke/Llama-2-7B-Chat-GGML"),
-         model_type = 'llama',
-         max_new_token = 512,
-         temperature = 0.5
-     )
-     return llm
-
- def retrieval_qa_chain(llm,prompt,db):
-     qa_chain = RetrievalQA.from_chain_type(
-         llm = llm,
-         chain_type = 'stuff',
-         retriever = db.as_retriever(search_kwargs= {'k': 2}),
-         return_source_documents = True,
-         chain_type_kwargs = {'prompt': prompt}
-     )
-
-     return qa_chain
-
- def qa_bot():
-     embeddings = HuggingFaceBgeEmbeddings(model_name = 'sentence-transformers/all-MiniLM-L6-v2',
-                                           model_kwargs = {'device':'cpu'})
-
-     db = FAISS.load_local(DB_FAISS_PATH,embeddings)
-     llm = load_llm()
-     qa_prompt = set_custom_prompt()
-     qa = retrieval_qa_chain(llm,qa_prompt, db)
-
-     return qa
-
- def final_result(query):
-     qa_result = qa_bot()
-     response = qa_result({'quert' : query})
-
-     return response
-
-
- ## Chainlit
- @cl.on_chat_start
- async def start():
-     chain = qa_bot()
-     msg = cl.Message(content = 'Starting the bot...')
-     await msg.send()
-
-     msg.conteny = "Hi Welcome to the medical Bot. What is your query?"
-     await msg.update()
-     cl.user_session.set('chain', chain)
-
- @cl.on_message
- async def main(message):
-     chain = cl.user_session.set('chain')
-     cb = cl.AsyncLangchainCallbackHandler(
-         stream_final_answer= True,
-         answer_prefix_tokens= ['FINAL','ANSWER']
-     )
-     cb.answer_reached = True
-     res = await chain.acall(message,callbacks = [cb])
-     answer = res['result']
-     sources = res['sources_documents']
-
-     if sources :
-         answer += f"\nSources :" + str(sources)
-
-     else :
-         answer += f"\nNo Rources Found"
-
-     await cl.Message(content=answer).send()
-
 
+ #from langchain import PromptTemplate
+ from langchain_core.prompts import PromptTemplate
+
+ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.llms.ctransformers import CTransformers
+ #from langchain.chains import RetrievalQA
+ from langchain.chains.retrieval_qa.base import RetrievalQA
+ import chainlit as cl
+
+ DB_FAISS_PATH = 'vectorstores/'
+
+ custom_prompt_template = '''
+ Use the following pieces of information to answer the user's question.
+ If you don't know the answer, just say that you don't know; don't try to make up an answer.
+
+ Context : {context}
+ Question : {question}
+
+ Only return the helpful answer below and nothing else.
+ '''
+
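+ # With RetrievalQA's 'stuff' chain, {context} receives the retrieved documents
+ # and {question} the user's query, matching input_variables below.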
+ def set_custom_prompt():
+     """
+     Prompt template for QA retrieval for vector stores
+     """
+     prompt = PromptTemplate(template = custom_prompt_template,
+                             input_variables = ['context','question'])
+
+     return prompt
+
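+ # GGML builds are CPU-friendly quantized checkpoints; ctransformers can pull
+ # them straight from the Hugging Face Hub by repo id.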
+ def load_llm():
+     llm = CTransformers(
+         model = 'TheBloke/Llama-2-7B-Chat-GGML',
+         #model = AutoModel.from_pretrained("TheBloke/Llama-2-7B-Chat-GGML"),
+         model_type = 'llama',
+         # Generation settings belong in the ctransformers config dict
+         # ('max_new_tokens', plural, not 'max_new_token').
+         config = {'max_new_tokens': 512, 'temperature': 0.5}
+     )
+     return llm
+
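+ # 'stuff' packs the top-k retrieved chunks straight into {context}; k=2 helps
+ # keep the assembled prompt within the model's context window.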
+ def retrieval_qa_chain(llm,prompt,db):
+     qa_chain = RetrievalQA.from_chain_type(
+         llm = llm,
+         chain_type = 'stuff',
+         retriever = db.as_retriever(search_kwargs= {'k': 2}),
+         return_source_documents = True,
+         chain_type_kwargs = {'prompt': prompt}
+     )
+
+     return qa_chain
+
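+ # The embedding model here must match the one used to build the FAISS index,
+ # otherwise similarity search returns meaningless neighbors.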
+ def qa_bot():
+     embeddings = HuggingFaceBgeEmbeddings(model_name = 'sentence-transformers/all-MiniLM-L6-v2',
+                                           model_kwargs = {'device':'cpu'})
+
+     # Recent langchain_community releases require opting in to pickle deserialization.
+     db = FAISS.load_local(DB_FAISS_PATH, embeddings,
+                           allow_dangerous_deserialization = True)
+     llm = load_llm()
+     qa_prompt = set_custom_prompt()
+     qa = retrieval_qa_chain(llm, qa_prompt, db)
+
+     return qa
+
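+ # Synchronous entry point for use outside Chainlit (e.g. quick CLI testing).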
+ def final_result(query):
+     qa_result = qa_bot()
+     response = qa_result({'query' : query})
+
+     return response
+
+
+ ## Chainlit
+ @cl.on_chat_start
+ async def start():
+     chain = qa_bot()
+     msg = cl.Message(content = 'Starting the bot...')
+     await msg.send()
+
+     msg.content = "Hi, welcome to the Medical Bot. What is your query?"
+     await msg.update()
+     cl.user_session.set('chain', chain)
+
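+ # Handle each incoming message: fetch this session's chain and stream the
+ # final answer back through the async LangChain callback handler.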
+ @cl.on_message
+ async def main(message):
+     chain = cl.user_session.get('chain')
+     cb = cl.AsyncLangchainCallbackHandler(
+         stream_final_answer = True,
+         answer_prefix_tokens = ['FINAL','ANSWER']
+     )
+     cb.answer_reached = True
+     res = await chain.acall(message.content, callbacks = [cb])
+     answer = res['result']
+     sources = res['source_documents']
+
+     if sources:
+         answer += "\nSources: " + str(sources)
+     else:
+         answer += "\nNo sources found"
+
+     await cl.Message(content=answer).send()
+
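
A minimal sketch of how this module might be exercised outside Chainlit, assuming a FAISS index built with the same MiniLM embeddings already exists under vectorstores/ (the query string below is illustrative, not from the commit):

    from model import final_result

    response = final_result("What are the common symptoms of dehydration?")
    print(response["result"])             # the model's answer
    print(response["source_documents"])   # the two retrieved chunks (k=2)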