E-Hospital commited on
Commit
dd6f4e5
1 Parent(s): cce5ae2

Upload 2 files

Browse files
Files changed (2) hide show
  1. main.py.py +204 -0
  2. requirements.txt.txt +17 -0
main.py.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """TestAPI.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1WToaz7kQoFpI0_M8j6uWPigBrKlkL4ml
8
+ """
9
+
10
+
11
+
12
+ from transformers import AutoTokenizer,AutoModelForCausalLM
13
+ import os
14
+ os.environ["CUDA_VISIBLE_DEVICES"]="0"
15
+ import torch
16
+ from transformers import AutoTokenizer, AutoModelForCausalLM
17
+
18
+ from typing import List
19
+
20
+
21
+ from peft import PeftModel
22
+ from transformers import AutoTokenizer, AutoModelForCausalLM
23
+ from typing import Any, List, Mapping, Optional
24
+ import torch
25
+ from transformers import AutoTokenizer, AutoModelForCausalLM
26
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
27
+ from langchain.llms.base import LLM
28
+
29
+ import torch
30
+ from transformers import AutoTokenizer, AutoModelForCausalLM
31
+
32
+ import mysql.connector
33
+ import re
34
+ from datetime import datetime
35
+
36
+
37
+ from langchain.memory import ConversationBufferMemory
38
+ from typing import Any, List, Mapping, Optional
39
+
40
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
41
+ from langchain.llms.base import LLM
42
+
43
+ from fastapi import FastAPI
44
+ from fastapi.middleware.cors import CORSMiddleware
45
+ from langchain.memory import ConversationSummaryBufferMemory
46
+
47
+ from langchain.memory import ConversationSummaryMemory
48
+
49
+
50
+ model_name = "Open-Orca/OpenOrca-Platypus2-13B"
51
+
52
+
53
+ tokenizer = AutoTokenizer.from_pretrained(
54
+ model_name, trust_remote_code=True)
55
+
56
+ model = AutoModelForCausalLM.from_pretrained(
57
+ model_name,
58
+ trust_remote_code=True,
59
+ load_in_8bit = True,
60
+ device_map = "auto",
61
+ )
62
+
63
+
64
+
65
+ model = PeftModel.from_pretrained(model, "teslalord/open-orca-platypus-2-medical")
66
+
67
+ model = model.merge_and_unload()
68
+
69
+
70
+
71
+ class CustomLLM(LLM):
72
+ n: int
73
+ # custom_model: llm # Replace with the actual type of your custom model
74
+
75
+ @property
76
+ def _llm_type(self) -> str:
77
+ return "custom"
78
+
79
+ def _call(
80
+ self,
81
+ prompt: str,
82
+ stop: Optional[List[str]] = None,
83
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
84
+ **kwargs: Any,
85
+ ) -> str:
86
+ if stop is not None:
87
+ raise ValueError("stop kwargs are not permitted.")
88
+
89
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to('cuda')
90
+ with torch.no_grad():
91
+ output = model.generate(input_ids, max_length=500, num_return_sequences=1, do_sample=True, top_k=50)
92
+ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
93
+ response = generated_text.split("->:")[-1]
94
+ return response
95
+
96
+ @property
97
+ def _identifying_params(self) -> Mapping[str, Any]:
98
+ """Get the identifying parameters."""
99
+ return {"n": self.n}
100
+
101
+
102
+
103
+ def ask_bot(question):
104
+ input_ids = tokenizer.encode(question, return_tensors="pt").to('cuda')
105
+ with torch.no_grad():
106
+ output = model.generate(input_ids, max_length=500, num_return_sequences=1, do_sample=True, top_k=50)
107
+ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
108
+ response = generated_text.split("->:")[-1]
109
+ return response
110
+
111
+
112
+ class DbHandler():
113
+ def __init__(self):
114
+ self.db_con = mysql.connector.connect(
115
+ host="frwahxxknm9kwy6c.cbetxkdyhwsb.us-east-1.rds.amazonaws.com",
116
+ user="j6qbx3bgjysst4jr",
117
+ password="mcbsdk2s27ldf37t",
118
+ port=3306,
119
+ database="nkw2tiuvgv6ufu1z")
120
+ self.cursorObject = self.db_con.cursor()
121
+
122
+ def insert(self, fields, values):
123
+ try:
124
+ # Convert the lists to comma-separated strings
125
+ fields_str = ', '.join(fields)
126
+ values_str = ', '.join([f"'{v}'" for v in values]) # Wrap values in single quotes for SQL strings
127
+
128
+ # Construct the SQL query
129
+ query = f"INSERT INTO chatbot_conversation ({fields_str}) VALUES ({values_str})"
130
+
131
+ self.cursorObject.execute(query)
132
+ self.db_con.commit()
133
+ return True
134
+ except Exception as e:
135
+ print(e)
136
+ return False
137
+
138
+ def get_history(self, patient_id):
139
+ try:
140
+ query = f"SELECT * FROM chatbot_conversation WHERE patient_id = '{patient_id}' ORDER BY timestamp ASC;"
141
+ self.cursorObject.execute(query)
142
+ data = self.cursorObject.fetchall()
143
+ return data
144
+ except Exception as e:
145
+ print(e)
146
+ return None
147
+
148
+
149
+ def close_db(self):
150
+ self.db_con.close()
151
+
152
+
153
+
154
+
155
+
156
+ def get_conversation_history(db, patient_id):
157
+ conversations = db.get_history(patient_id)
158
+ if conversations:
159
+ return conversations[-1][5]
160
+ return ""
161
+
162
+
163
+
164
+ llm = CustomLLM(n=10)
165
+ app = FastAPI()
166
+
167
+ app.add_middleware(
168
+ CORSMiddleware,
169
+ allow_origins=['*'],
170
+ allow_credentials=True,
171
+ allow_methods=['*'],
172
+ allow_headers=['*'],
173
+ )
174
+
175
+ @app.get('/healthcheck')
176
+ async def root():
177
+ return {'status': 'running'}
178
+
179
+ @app.post('/{patient_id}')
180
+ def chatbot(patient_id, user_data: dict=None):
181
+ user_input = user_data["userObject"]["userInput"].get("message")
182
+ db = DbHandler()
183
+ try:
184
+ history = get_conversation_history(db, patient_id)
185
+ memory = ConversationSummaryBufferMemory(llm=llm, max_token_limit=200)
186
+ prompt = "You are now a medical chatbot, and I am a patient. I will describe my conditions and symptoms and you will give me medical suggestions"
187
+ if history:
188
+ human_input = prompt + "The following is the patient's previous conversation with you: " + history + "This is the current question: " + user_input + " ->:"
189
+ else:
190
+ human_input = prompt + user_input + " ->:"
191
+ human_text = user_input.replace("'", "")
192
+ response = llm._call(human_input)
193
+ response = response.replace("'", "")
194
+ memory.save_context({"input": user_input}, {"output": response})
195
+ summary = memory.load_memory_variables({})
196
+ ai_text = response.replace("'", "")
197
+ memory.save_context({"input": user_input}, {"output": ai_text})
198
+ summary = memory.load_memory_variables({})
199
+ db.insert(("patient_id", "patient_text", "ai_text", "timestamp", "summarized_text"), (patient_id, human_text, ai_text, str(datetime.now()), summary['history'].replace("'", "")))
200
+ db.close_db()
201
+ return {"response": response}
202
+ finally:
203
+ db.close_db()
204
+
requirements.txt.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ nest-asyncio
3
+ pyngrok
4
+ uvicorn
5
+ langchain
6
+ mysql-connector-python
7
+ transformers
8
+ accelerate
9
+ evaluate
10
+ datasets
11
+ peft
12
+ torch
13
+ huggingface_hub
14
+ bitsandbytes
15
+ loralib
16
+ git+https://github.com/huggingface/transformers.git@main
17
+ git+https://github.com/huggingface/peft.git