Spaces:
Runtime error
Runtime error
Commit
•
f2eec33
1
Parent(s):
cf220f4
Update main.py
Browse files
main.py
CHANGED
@@ -81,9 +81,11 @@ text_pipeline = pipeline(
|
|
81 |
repetition_penalty=1.15,
|
82 |
streamer=streamer,
|
83 |
)
|
84 |
-
global llm,llm2
|
85 |
llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
|
86 |
llm2 = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
|
|
|
|
|
87 |
# when the user query is not related to trained PDF data model will give the response from own knowledge
|
88 |
SYSTEM_PROMPT = "give answer from external data's. don't use the provided context"
|
89 |
|
@@ -96,7 +98,16 @@ Question: {question}
|
|
96 |
)
|
97 |
prompt = PromptTemplate(template=template, input_variables=["context", "question"])
|
98 |
|
99 |
-
global qa_chain,qa_chain_a
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
qa_chain = RetrievalQA.from_chain_type(
|
101 |
llm=llm,
|
102 |
chain_type="stuff",
|
@@ -147,11 +158,22 @@ end_sys_prompts = "\n\ngive correct treatment and most related diagnosis with IC
|
|
147 |
|
148 |
|
149 |
def refresh_model():
|
150 |
-
global llm,llm2
|
151 |
llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
|
152 |
llm2 = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
-
global qa_chain,qa_chain_a
|
155 |
qa_chain = RetrievalQA.from_chain_type(
|
156 |
llm=llm,
|
157 |
chain_type="stuff",
|
@@ -182,7 +204,7 @@ app.add_middleware(
|
|
182 |
@app.post("/llm_response/")
|
183 |
async def llm_response(chain,id,mode):
|
184 |
id = int(id)
|
185 |
-
global qa_chain,qa_chain_a
|
186 |
refresh_model()
|
187 |
|
188 |
def QA():
|
@@ -207,7 +229,7 @@ async def llm_response(chain,id,mode):
|
|
207 |
return str(result_ex['result']).split("\n\n")
|
208 |
|
209 |
if str(mode)=="dirrect_QA" and id==3:
|
210 |
-
diagnosis_and_treatment =
|
211 |
diagnosis_and_treatment = str(diagnosis_and_treatment['result'])
|
212 |
print(diagnosis_and_treatment)
|
213 |
print("dirrect answer")
|
@@ -245,7 +267,7 @@ async def llm_response(chain,id,mode):
|
|
245 |
return str(question['result'])
|
246 |
|
247 |
if id==13:
|
248 |
-
diagnosis_and_treatment =
|
249 |
diagnosis_and_treatment = str(diagnosis_and_treatment['result'])
|
250 |
print(mode,diagnosis_and_treatment)
|
251 |
report = qa_chain_a(report_prompt_template+sys+chain+"\n\ntreatment & diagnosis with ICD code below\n"+diagnosis_and_treatment)
|
|
|
81 |
repetition_penalty=1.15,
|
82 |
streamer=streamer,
|
83 |
)
|
84 |
+
global llm,llm2,llm3
|
85 |
llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
|
86 |
llm2 = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
|
87 |
+
llm3 = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
|
88 |
+
|
89 |
# when the user query is not related to trained PDF data model will give the response from own knowledge
|
90 |
SYSTEM_PROMPT = "give answer from external data's. don't use the provided context"
|
91 |
|
|
|
98 |
)
|
99 |
prompt = PromptTemplate(template=template, input_variables=["context", "question"])
|
100 |
|
101 |
+
global qa_chain,qa_chain_a,qa_chain_v
|
102 |
+
|
103 |
+
qa_chain_v = RetrievalQA.from_chain_type(
|
104 |
+
llm=llm3,
|
105 |
+
chain_type="stuff",
|
106 |
+
retriever=store.as_retriever(search_kwargs={"k": 2}),
|
107 |
+
return_source_documents=True,
|
108 |
+
chain_type_kwargs={"prompt": prompt},
|
109 |
+
)
|
110 |
+
|
111 |
qa_chain = RetrievalQA.from_chain_type(
|
112 |
llm=llm,
|
113 |
chain_type="stuff",
|
|
|
158 |
|
159 |
|
160 |
def refresh_model():
|
161 |
+
global llm,llm2,llm3
|
162 |
llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
|
163 |
llm2 = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
|
164 |
+
llm3 = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
|
165 |
+
|
166 |
+
global qa_chain,qa_chain_a,qa_chain_v
|
167 |
+
|
168 |
+
|
169 |
+
qa_chain_v = RetrievalQA.from_chain_type(
|
170 |
+
llm=llm3,
|
171 |
+
chain_type="stuff",
|
172 |
+
retriever=store.as_retriever(search_kwargs={"k": 2}),
|
173 |
+
return_source_documents=True,
|
174 |
+
chain_type_kwargs={"prompt": prompt},
|
175 |
+
)
|
176 |
|
|
|
177 |
qa_chain = RetrievalQA.from_chain_type(
|
178 |
llm=llm,
|
179 |
chain_type="stuff",
|
|
|
204 |
@app.post("/llm_response/")
|
205 |
async def llm_response(chain,id,mode):
|
206 |
id = int(id)
|
207 |
+
global qa_chain,qa_chain_a,qa_chain_v
|
208 |
refresh_model()
|
209 |
|
210 |
def QA():
|
|
|
229 |
return str(result_ex['result']).split("\n\n")
|
230 |
|
231 |
if str(mode)=="dirrect_QA" and id==3:
|
232 |
+
diagnosis_and_treatment = qa_chain_v(sys+chain+end_sys_prompts)
|
233 |
diagnosis_and_treatment = str(diagnosis_and_treatment['result'])
|
234 |
print(diagnosis_and_treatment)
|
235 |
print("dirrect answer")
|
|
|
267 |
return str(question['result'])
|
268 |
|
269 |
if id==13:
|
270 |
+
diagnosis_and_treatment = qa_chain_v(sys+chain+end_sys_prompts)
|
271 |
diagnosis_and_treatment = str(diagnosis_and_treatment['result'])
|
272 |
print(mode,diagnosis_and_treatment)
|
273 |
report = qa_chain_a(report_prompt_template+sys+chain+"\n\ntreatment & diagnosis with ICD code below\n"+diagnosis_and_treatment)
|