File size: 5,770 Bytes
dd6f4e5
31ae848
c3d009e
4d61bc4
c3d009e
 
9110acb
c3d009e
d235373
 
 
 
 
 
91e75bf
dd6f4e5
31ae848
 
3de3aa5
31ae848
 
dd6f4e5
31ae848
e0f4606
31ae848
f18e25c
31ae848
 
 
dd6f4e5
 
 
 
31ae848
dd6f4e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0f4606
dd6f4e5
8fc450b
dd6f4e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31ae848
dd6f4e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f18e25c
 
8fc450b
 
 
 
 
 
 
dd6f4e5
 
 
 
4c4c643
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import logging

# Root logger: output depends on whatever handlers/level the host process set.
logger = logging.getLogger()

# Prefer CUDA when torch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger.info("Getting Device")
logger.info(device)
if torch.cuda.is_available():
    num_of_gpus = torch.cuda.device_count()
    logger.info("Getting gpus")
    logger.info(num_of_gpus)


# NOTE(review): trust_remote_code=True allows the model/tokenizer repositories
# to run their own Python code at load time; confirm both are trusted.
model = AutoModelForCausalLM.from_pretrained(
    "E-Hospital/open-orca-platypus-2-lora-medical",
    trust_remote_code=True
)
# Prompts are moved to `device` before generation, so the weights must live on
# the same device; without this, generate() fails with a device mismatch as
# soon as a GPU is available.
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained("Open-Orca/OpenOrca-Platypus2-13B", trust_remote_code=True)

def ask_bot(question):
    """Generate the model's reply to *question*.

    Args:
        question: Full prompt text. Callers end it with the ``->:`` marker,
            because everything up to the last marker in the decoded output
            is discarded.

    Returns:
        The decoded text following the last ``->:`` marker, or the whole
        decoded text when the marker does not occur.
    """
    # Tokenise with an attention mask and put the tensors on the device the
    # model actually lives on, so inputs and weights can never disagree.
    encoded = tokenizer(question, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # max_new_tokens bounds only the reply. The previous max_length=200
        # also counted the prompt, so a long prompt left no room to answer.
        output = model.generate(
            encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_new_tokens=200,
            num_return_sequences=1,
            do_sample=True,
            top_k=50,
        )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # The decoded sequence echoes the prompt; keep only what follows the marker.
    return generated_text.split("->:")[-1]

import mysql.connector
import re
from datetime import datetime
from typing import Any, List, Mapping, Optional

from langchain.memory import ConversationBufferMemory
from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langchain.memory import ConversationSummaryBufferMemory

from langchain.memory import ConversationSummaryMemory

class CustomLLM(LLM):
    """LangChain ``LLM`` adapter around the module-level model and tokenizer."""

    # Only reported through _identifying_params; _call does not read it.
    n: int

    @property
    def _llm_type(self) -> str:
        """Return the identifier LangChain uses for this LLM type."""
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for *prompt*.

        Args:
            prompt: Text to complete; everything up to the last ``->:``
                marker in the decoded output is discarded.
            stop: Must be None; stop sequences are not supported.
            run_manager: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The decoded text following the last ``->:`` marker.

        Raises:
            ValueError: If *stop* is provided.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        # Put the inputs on the device the model actually lives on so the
        # tensors and the weights can never disagree.
        encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            # max_new_tokens bounds only the completion. The previous
            # max_length=100 also counted the prompt, so prompts of 100+
            # tokens left no room for any output.
            output = model.generate(
                encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
                max_new_tokens=100,
                num_return_sequences=1,
                do_sample=True,
                top_k=50,
            )
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        return generated_text.split("->:")[-1]

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"n": self.n}



class DbHandler():
    """Thin wrapper around one MySQL connection to the conversation table."""

    # Column names cannot be bound as query parameters, so they are checked
    # against this pattern before being interpolated into the SQL text.
    _IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")

    def __init__(self):
        """Open the connection and create the cursor shared by all queries."""
        # NOTE(review): these credentials are committed in source; they should
        # be rotated and loaded from configuration or the environment.
        self.db_con = mysql.connector.connect(
            host="frwahxxknm9kwy6c.cbetxkdyhwsb.us-east-1.rds.amazonaws.com",
            user="j6qbx3bgjysst4jr",
            password="mcbsdk2s27ldf37t",
            port=3306,
            database="nkw2tiuvgv6ufu1z")
        self.cursorObject = self.db_con.cursor()

    def insert(self, fields, values):
        """Insert one row into ``chatbot_conversation``.

        Args:
            fields: Column names, in the same order as *values*.
            values: Values to store. They are bound as query parameters, so
                quotes or SQL fragments inside them are stored literally.

        Returns:
            True when the row was committed; False when validation or the
            database call failed (the error is logged).
        """
        try:
            fields = list(fields)
            values = list(values)
            if not fields or len(fields) != len(values):
                raise ValueError("fields and values must be non-empty and of equal length")
            for name in fields:
                if not self._IDENTIFIER.fullmatch(name):
                    raise ValueError(f"invalid column name: {name!r}")

            fields_str = ', '.join(fields)
            placeholders = ', '.join(['%s'] * len(values))
            query = f"INSERT INTO chatbot_conversation ({fields_str}) VALUES ({placeholders})"

            self.cursorObject.execute(query, tuple(values))
            self.db_con.commit()
            return True
        except Exception:
            # Best-effort contract kept from the original: report failure via
            # the return value, but log the traceback instead of printing.
            logging.getLogger(__name__).exception("insert into chatbot_conversation failed")
            return False

    def get_history(self, patient_id):
        """Fetch every conversation row for *patient_id*, oldest first.

        Args:
            patient_id: Patient identifier; bound as a query parameter because
                it originates from the request URL.

        Returns:
            The rows from ``fetchall()``, or None if the query failed
            (the error is logged).
        """
        try:
            query = "SELECT * FROM chatbot_conversation WHERE patient_id = %s ORDER BY timestamp ASC;"
            self.cursorObject.execute(query, (patient_id,))
            return self.cursorObject.fetchall()
        except Exception:
            logging.getLogger(__name__).exception("loading conversation history failed")
            return None

    def close_db(self):
        """Close the underlying database connection."""
        self.db_con.close()





def get_conversation_history(db, patient_id):
    """Return field 5 of the patient's most recent conversation row.

    Args:
        db: Object exposing ``get_history(patient_id)``.
        patient_id: Identifier forwarded unchanged to ``db.get_history``.

    Returns:
        Element 5 of the last row returned by ``db.get_history``, or ``""``
        when no rows come back (empty result or None).
    """
    rows = db.get_history(patient_id)
    if not rows:
        return ""
    # Rows arrive oldest-first, so the last one is the latest exchange.
    # NOTE(review): index 5 is presumably the summarized_text column —
    # confirm against the table schema.
    latest_row = rows[-1]
    return latest_row[5]

# Shared LLM wrapper instance for the service.
llm = CustomLLM(n=10)

# HTTP application that the route handlers below are registered on.
app = FastAPI()

# NOTE(review): this CORS policy accepts every origin, method and header while
# also allowing credentials — confirm that this openness is intended.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/healthcheck")
async def root():
    """Liveness probe: always report that the service is running."""
    status_payload = {"status": "running"}
    return status_payload

@app.post('/{patient_id}')
def chatbot(patient_id, user_data: dict=None):
    """Answer a patient's message, using their stored history as context.

    Args:
        patient_id: Patient identifier taken from the URL path.
        user_data: Request body; the message is read from
            ``user_data["userObject"]["userInput"]["message"]``.

    Returns:
        ``{"response": <model reply>}``.

    Raises:
        HTTPException: 422 when the body is missing or holds no text message.
    """
    # Local import keeps this handler self-contained; the module only imports
    # FastAPI itself from the fastapi package.
    from fastapi import HTTPException

    # A missing or malformed body used to surface as TypeError/KeyError
    # (HTTP 500); report it as a client error instead.
    try:
        user_input = user_data["userObject"]["userInput"].get("message")
    except (TypeError, KeyError, AttributeError):
        user_input = None
    if not isinstance(user_input, str):
        raise HTTPException(
            status_code=422,
            detail="Request body must contain userObject.userInput.message as text.",
        )

    db = DbHandler()
    try:
        history = get_conversation_history(db, patient_id)
        prompt = "You are now a medical chatbot, and I am a patient. I will describe my conditions and symptoms and you will give me medical suggestions"
        # Separators keep the instruction, the history and the question from
        # running into each other ("...suggestionsThe following...").
        if history:
            human_input = (
                prompt
                + ". The following is the patient's previous conversation with you: "
                + history
                + " This is the current question: "
                + user_input
                + " ->:"
            )
        else:
            human_input = prompt + ". " + user_input + " ->:"
        # TODO: the exchange is not persisted; the db.insert call that stored
        # the conversation and its summary is currently disabled.
        return {"response": ask_bot(human_input)}
    finally:
        # Single close on every path (previously the connection was closed
        # twice on success).
        db.close_db()

if __name__=='__main__':
    # uvicorn was referenced without being imported, so running the file as a
    # script raised NameError before the server could start.
    import uvicorn

    # NOTE(review): 'main:app' assumes this module is importable as `main`;
    # confirm it matches the actual file name.
    uvicorn.run('main:app', reload=True)