neuralleap committed
Commit c760308
1 Parent(s): c3d6b94

Update main.py

Files changed (1)
  1. main.py +192 -44

main.py CHANGED
@@ -4,62 +4,210 @@ from fastapi.responses import FileResponse, HTMLResponse
 import os
 import io
 
-#import httpcore
-#setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
-
-import googletrans
-from googletrans import Translator
-translator = Translator()
-lan = googletrans.LANGUAGES
-keys = list(lan.keys())
-vals = list(lan.values())
-
-#from gradio_client import Client
-
-#client = Client("physician-ai/speech-to-text")
-
-#print(client.view_api())
-app = FastAPI()
-
-@app.post("/translate/")
-async def translate(text,language):
-    return {"translated_text": translator.translate(text,dest=keys[vals.index(language)]).text}
-
-#@app.post("/speech_to_text/")
-#async def speech_to_text(file: UploadFile = File(...)):
-    # Save the file with a specific name
-    #file_path = "inputvoice.mp3"
-    #with open(file_path, "wb") as f:
-        #f.write(file.file.read())
-    #print("saved")
-    #respond = client.predict(file_path,api_name="/get_stt")
-    #print(respond.result())
-    #return respond
-
-os.environ["COQUI_TOS_AGREED"] = "1"
-from TTS.api import TTS
-
-model_names = TTS().list_models()
-m = model_names[0]
-print(model_names)
-global xtts
-xtts = TTS(m, gpu=True)
-#tts.to("cpu") # no GPU or Amd
-xtts.to("cuda")
-
-@app.get("/text-to-speech/")
-def text_to_speech(text,language):
-    global xtts
-    audio_file = 'text_to_speech.wav'
-    if language=="vietnamese":
-        from gtts import gTTS
-        tts = gTTS(text)
-        tts.save(audio_file)
-    else:
-        xtts.tts_to_file(text,speaker_wav="input.wav",language=keys[vals.index(language)],file_path=audio_file)
-    return FileResponse(audio_file, media_type='audio/mpeg')
+import torch
+from auto_gptq import AutoGPTQForCausalLM
+from langchain import HuggingFacePipeline, PromptTemplate
+from langchain.chains import RetrievalQA
+from langchain.document_loaders import PyPDFDirectoryLoader
+from langchain.embeddings import HuggingFaceInstructEmbeddings
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.vectorstores import Chroma
+from langchain.vectorstores import FAISS
+from transformers import AutoTokenizer, TextStreamer, pipeline
+
+DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+# Instructor embeddings plus a prebuilt FAISS index loaded from disk
+embeddings = HuggingFaceInstructEmbeddings(
+    model_name="hkunlp/instructor-large", model_kwargs={"device": DEVICE}
+)
+new_db = FAISS.load_local("faiss_index", embeddings)
+
+# 4-bit GPTQ build of Llama-2-13B-chat
+model_name_or_path = "TheBloke/Llama-2-13B-chat-GPTQ"
+model_basename = "model"
+
+tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
+
+model = AutoGPTQForCausalLM.from_quantized(
+    model_name_or_path,
+    revision="gptq-4bit-128g-actorder_True",
+    model_basename=model_basename,
+    use_safetensors=True,
+    trust_remote_code=True,
+    device=DEVICE,
+    inject_fused_attention=False,
+    quantize_config=None,
+)
+
+# Default system prompt, used when no custom system prompt is set
+DEFAULT_SYSTEM_PROMPT = """
+You are a helpful, respectful and honest assistant. Give an answer for any question.
+""".strip()
+
+
+def generate_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
+    # Wrap the user prompt and system prompt in Llama-2 chat markup
+    return f"""
+[INST] <<SYS>>
+{system_prompt}
+<</SYS>>
+{prompt} [/INST]
+""".strip()
+
+# Set up the RAG pipeline
+streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+text_pipeline = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    max_new_tokens=4096,
+    temperature=2,
+    top_p=0.95,
+    repetition_penalty=1.15,
+    streamer=streamer,
+)
+global llm, llm2
+llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
+llm2 = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
+# When the user query is not related to the indexed PDF data, the model answers from its own knowledge
+SYSTEM_PROMPT = "give answer from external data. don't use the provided context"
+
+template = generate_prompt(
+    """
+{context}
+Question: {question}
+""",
+    system_prompt=SYSTEM_PROMPT,
+)
+prompt = PromptTemplate(template=template, input_variables=["context", "question"])
+
+global qa_chain, qa_chain_a
+qa_chain = RetrievalQA.from_chain_type(
+    llm=llm,
+    chain_type="stuff",
+    retriever=new_db.as_retriever(search_kwargs={"k": 2}),
+    return_source_documents=True,
+    chain_type_kwargs={"prompt": prompt},
+)
+
+qa_chain_a = RetrievalQA.from_chain_type(
+    llm=llm2,
+    chain_type="stuff",
+    retriever=new_db.as_retriever(search_kwargs={"k": 2}),
+    return_source_documents=True,
+    chain_type_kwargs={"prompt": prompt},
+)
+
+report_prompt_template = """
+this is the report format
+Patient Name: [Insert name here]<br>
+Age: [Insert age here]<br>
+Sex: [Insert here]<br>
+Chief Complaint: [Insert here]<br>
+History of Present Illness: [Insert here]<br>
+Past Medical History: [Insert here]<br>
+Medication List: [Insert here]<br>
+Social History: [Insert here]<br>
+Family History: [Insert here]<br>
+Review of Systems: [Insert here]<br>
+ICD Code: [Insert here]
+convert the details below into the above format. don't add any other details. don't use the provided PDF data.\n\n"""
+
+# 4. Predefined questions; asking these guides the model along the right path
+final_question = {
+    8: "Do you have a history of medical conditions, such as allergies, chronic illnesses, or previous surgeries? If so, please provide details.",
+    9: "What medications are you currently taking, including supplements and vitamins?",
+    10: "Can you please describe your family medical history (particularly close relatives): does anyone in your immediate family suffer from similar symptoms or health issues?",
+    11: "Can you please describe your social history: marital status, occupation, living arrangements, education level, and support system.",
+    12: "Could you describe your symptoms, and have you noticed any changes or discomfort related to your respiratory, cardiovascular, gastrointestinal, or other body systems?",
+}
+
+# 1. Base system prompt that has Llama behave like a family physician
+sys = "You are a general family physician.\n\n"
+
+# 5. Prompt for getting the diagnosis with an ICD code from the conversation; it also handles questions unrelated to diagnosis
+end_sys_prompts = "\n\ngive correct treatment and the most related diagnosis with ICD code, don't ask any questions. if the question is not related to the provided data, don't give an answer from the provided data"
+
+
+def refresh_model():
+    # Rebuild the LLM wrappers and QA chains so each request starts from a clean state
+    global llm, llm2
+    llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
+    llm2 = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 2})
+
+    global qa_chain, qa_chain_a
+    qa_chain = RetrievalQA.from_chain_type(
+        llm=llm,
+        chain_type="stuff",
+        retriever=new_db.as_retriever(search_kwargs={"k": 2}),
+        return_source_documents=True,
+        chain_type_kwargs={"prompt": prompt},
+    )
+
+    qa_chain_a = RetrievalQA.from_chain_type(
+        llm=llm2,
+        chain_type="stuff",
+        retriever=new_db.as_retriever(search_kwargs={"k": 2}),
+        return_source_documents=True,
+        chain_type_kwargs={"prompt": prompt},
+    )
+    print("model refreshed")
+
+app = FastAPI()
+
+@app.post("/llm_response/")
+async def llm_response(chain, id, mode):
+    id = int(id)
+    global qa_chain, qa_chain_a
+    refresh_model()
+
+    if id < 13:
+        if id >= 8:
+            return final_question[id]
+        else:
+            if id < 5:
+                # 2. Prompt that keeps question-asking natural, based on the patient's response and symptom type
+                question = qa_chain(sys + chain + """\n\nask a single small question to get details based on the patient response, and don't ask the
+same question again, and don't provide treatment and diagnosis; ask the next small and short question,
+never ask the same question again and again, always only ask the next single small question""")
+            else:
+                # 3. Prompt that guides the model to ask yes/no questions, based on the patient's response and symptom type
+                question = qa_chain(sys + chain + """\n\nask a single small question to get details based on the patient response, and don't ask the
+same question again, and don't provide treatment and diagnosis; ask the next small and short question in yes-or-no format,
+never ask the same question again and again, always only ask the next single small question""")
+            try:
+                if "Patient:" in str(question['result']) or "Patient response:" in str(question['result']):
+                    return str((str(question['result']).split("\n\n")[-1]).split(":")[-1])
+                else:
+                    return str(question['result']).split("\n\n")[1]
+            except Exception:
+                if "Patient:" in str(question['result']) or "Patient response:" in str(question['result']):
+                    return str(question['result']).split(":")[-1]
+                else:
+                    return str(question['result'])
+
+    if id == 16:
+        diagnosis_and_treatment = qa_chain(sys + chain + end_sys_prompts)
+        diagnosis_and_treatment = str(diagnosis_and_treatment['result'])
+
+        if mode != "h&p":
+            return diagnosis_and_treatment
+        else:
+            report = qa_chain_a(report_prompt_template + sys + chain + "\n\ntreatment & diagnosis with ICD code below\n" + diagnosis_and_treatment)
+            return str(report['result'])
+
+    result_ex = qa_chain(sys + chain + """\n\n\nalways give a small, single response based on the patient
+response. don't give a multiline response; always respond based on the last patient response""")
+    if "Patient:" in str(result_ex['result']) or "Patient response:" in str(result_ex['result']) or "Patient Response" in str(result_ex['result']):
+        return str((str(result_ex['result']).split("\n\n")[-1]).split(":")[-1])
+    else:
+        return str(result_ex['result']).split("\n\n")[1]
213