neuralleap committed on
Commit f2eec33
1 Parent(s): cf220f4

Update main.py

Files changed (1)
  1. main.py +29 -7
main.py CHANGED
@@ -81,9 +81,11 @@ text_pipeline = pipeline(
     repetition_penalty=1.15,
     streamer=streamer,
 )
-global llm,llm2
+global llm,llm2,llm3
 llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
 llm2 = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
+llm3 = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
+
 # when the user query is not related to trained PDF data model will give the response from own knowledge
 SYSTEM_PROMPT = "give answer from external data's. don't use the provided context"
 
@@ -96,7 +98,16 @@ Question: {question}
 )
 prompt = PromptTemplate(template=template, input_variables=["context", "question"])
 
-global qa_chain,qa_chain_a
+global qa_chain,qa_chain_a,qa_chain_v
+
+qa_chain_v = RetrievalQA.from_chain_type(
+    llm=llm3,
+    chain_type="stuff",
+    retriever=store.as_retriever(search_kwargs={"k": 2}),
+    return_source_documents=True,
+    chain_type_kwargs={"prompt": prompt},
+)
+
 qa_chain = RetrievalQA.from_chain_type(
     llm=llm,
     chain_type="stuff",
@@ -147,11 +158,22 @@ end_sys_prompts = "\n\ngive correct treatment and most related diagnosis with IC
 
 
 def refresh_model():
-    global llm,llm2
+    global llm,llm2,llm3
     llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
     llm2 = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
+    llm3 = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
+
-    global qa_chain,qa_chain_a
+    global qa_chain,qa_chain_a,qa_chain_v
+
+
+    qa_chain_v = RetrievalQA.from_chain_type(
+        llm=llm3,
+        chain_type="stuff",
+        retriever=store.as_retriever(search_kwargs={"k": 2}),
+        return_source_documents=True,
+        chain_type_kwargs={"prompt": prompt},
+    )
 
     qa_chain = RetrievalQA.from_chain_type(
         llm=llm,
         chain_type="stuff",
@@ -182,7 +204,7 @@ app.add_middleware(
 @app.post("/llm_response/")
 async def llm_response(chain,id,mode):
     id = int(id)
-    global qa_chain,qa_chain_a
+    global qa_chain,qa_chain_a,qa_chain_v
     refresh_model()
 
     def QA():
@@ -207,7 +229,7 @@ async def llm_response(chain,id,mode):
         return str(result_ex['result']).split("\n\n")
 
     if str(mode)=="dirrect_QA" and id==3:
-        diagnosis_and_treatment = qa_chain(sys+chain+end_sys_prompts)
+        diagnosis_and_treatment = qa_chain_v(sys+chain+end_sys_prompts)
         diagnosis_and_treatment = str(diagnosis_and_treatment['result'])
         print(diagnosis_and_treatment)
         print("dirrect answer")
@@ -245,7 +267,7 @@ async def llm_response(chain,id,mode):
         return str(question['result'])
 
     if id==13:
-        diagnosis_and_treatment = qa_chain(sys+chain+end_sys_prompts)
+        diagnosis_and_treatment = qa_chain_v(sys+chain+end_sys_prompts)
         diagnosis_and_treatment = str(diagnosis_and_treatment['result'])
         print(mode,diagnosis_and_treatment)
         report = qa_chain_a(report_prompt_template+sys+chain+"\n\ntreatment & diagnosis with ICD code below\n"+diagnosis_and_treatment)
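
For orientation, a minimal sketch of the pattern this commit extends: building the new qa_chain_v on a dedicated llm3 wrapper and calling it on the id 3 / id 13 paths. It assumes the store vector index, text_pipeline, prompt, and the sys/chain/end_sys_prompts strings defined elsewhere in main.py, plus the legacy LangChain imports the file already uses; treat it as an illustration, not code from the commit.

    from langchain.chains import RetrievalQA
    from langchain.llms import HuggingFacePipeline

    # A separate wrapper around the same shared text_pipeline; no extra
    # model weights are loaded, the split only isolates the chains.
    llm3 = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})

    # "stuff" chain: fetch the top-2 chunks from the PDF index and stuff
    # them into the {context} slot of `prompt` before generation.
    qa_chain_v = RetrievalQA.from_chain_type(
        llm=llm3,
        chain_type="stuff",
        retriever=store.as_retriever(search_kwargs={"k": 2}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt},
    )

    # Calling the chain returns a dict: the answer is under "result" and,
    # since return_source_documents=True, the chunks under "source_documents".
    out = qa_chain_v(sys + chain + end_sys_prompts)
    print(out["result"])

Because llm, llm2, and llm3 all wrap the same text_pipeline object, refresh_model() only rebuilds the chain wiring; it does not reload the underlying model.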