File size: 5,784 Bytes
dd6f4e5
31ae848
dd6f4e5
31ae848
 
 
 
 
 
dd6f4e5
31ae848
 
 
 
 
 
 
dd6f4e5
31ae848
dd6f4e5
 
 
 
31ae848
dd6f4e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31ae848
dd6f4e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Fine-tuned causal LM (placed automatically across available devices) and the
# tokenizer of the OpenOrca-Platypus2-13B base model it is paired with.
# NOTE(review): trust_remote_code=True lets the Hub repositories supply Python code
# that is executed on load — confirm both repositories are trusted/pinned.
model = AutoModelForCausalLM.from_pretrained(
    "E-Hospital/open-orca-platypus-2-lora-medical",
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    "Open-Orca/OpenOrca-Platypus2-13B",
    trust_remote_code=True,
)

def ask_bot(question):
    """Generate a reply to *question* with the module-level ``model`` and ``tokenizer``.

    The full decoded sequence (prompt + generation) is split on ``->:`` and the
    last piece is returned. When *question* ends with ``->:`` (the prompt format
    the chatbot endpoint uses), that piece is the generated reply; without the
    marker, the whole decoded text is returned. Sampling is enabled
    (``do_sample=True``, ``top_k=50``), so output varies between calls.
    """
    # Send inputs to the device the model actually lives on; with device_map="auto"
    # that may be CPU or a specific GPU, so a hard-coded 'cuda' can fail.
    input_ids = tokenizer.encode(question, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # NOTE(review): max_length includes the prompt tokens; long prompts leave
        # little or no room for new tokens — consider max_new_tokens.
        output = model.generate(input_ids, max_length=500, num_return_sequences=1, do_sample=True, top_k=50)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text.split("->:")[-1]

if __name__ == "__main__":
    # Manual smoke test of the model. Guarded so that importing this module (for
    # example when an ASGI server loads ``app``) does not run a full generation at
    # import time and then throw the result away.
    print(ask_bot("I have diabetes. What should I do?"))

import mysql.connector
import re
from datetime import datetime
from typing import Any, List, Mapping, Optional

from langchain.memory import ConversationBufferMemory
from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langchain.memory import ConversationSummaryBufferMemory

from langchain.memory import ConversationSummaryMemory

class CustomLLM(LLM):
    """LangChain ``LLM`` adapter around this module's Hugging Face ``model`` and ``tokenizer``.

    Generation samples (``do_sample=True``, ``top_k=50``) with ``max_length=500``,
    so repeated calls with the same prompt can return different text.
    """

    # Reported through ``_identifying_params`` only; generation does not read it.
    n: int

    @property
    def _llm_type(self) -> str:
        """Return the identifier LangChain records for this LLM type."""
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for *prompt*.

        Returns:
            The decoded text after the last ``->:`` marker (the decoded output
            includes the prompt, which is expected to end with ``->:``).

        Raises:
            ValueError: if *stop* sequences are supplied; they are not supported.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        # Put inputs on the model's device; with device_map="auto" the model may sit
        # on CPU or a specific GPU, so a hard-coded 'cuda' can fail.
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            # NOTE(review): max_length counts prompt tokens too; long prompts (e.g.
            # with conversation history) leave little or no room for new tokens.
            output = model.generate(input_ids, max_length=500, num_return_sequences=1, do_sample=True, top_k=50)
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        return generated_text.split("->:")[-1]

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"n": self.n}



def ask_bot(question):
    """Generate a reply to *question* with the module-level ``model`` and ``tokenizer``.

    This redefines the identical ``ask_bot`` declared earlier in the module; after
    import, this definition is the one bound to the name. The decoded sequence
    (prompt + generation) is split on ``->:`` and the last piece is returned, so a
    *question* ending in ``->:`` yields just the generated reply. Sampling is
    enabled, so output varies between calls.
    """
    # Match the model's placement (device_map="auto") instead of assuming 'cuda'.
    input_ids = tokenizer.encode(question, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # NOTE(review): max_length includes prompt tokens — consider max_new_tokens.
        output = model.generate(input_ids, max_length=500, num_return_sequences=1, do_sample=True, top_k=50)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text.split("->:")[-1]


class DbHandler():
    """Wrapper around one MySQL connection used for the ``chatbot_conversation`` table.

    ``__init__`` opens the connection and a single cursor; call ``close_db`` when
    done (it is safe to call more than once).
    """

    def __init__(self):
        import os  # local import: the module does not import os at top level

        # NOTE(review): these credentials are committed to source control. The
        # environment variables let deployments override them; the hard-coded
        # fallbacks are kept only for backward compatibility — rotate the password
        # and remove the fallbacks.
        self.db_con = mysql.connector.connect(
            host=os.environ.get("CHATBOT_DB_HOST", "frwahxxknm9kwy6c.cbetxkdyhwsb.us-east-1.rds.amazonaws.com"),
            user=os.environ.get("CHATBOT_DB_USER", "j6qbx3bgjysst4jr"),
            password=os.environ.get("CHATBOT_DB_PASSWORD", "mcbsdk2s27ldf37t"),
            port=int(os.environ.get("CHATBOT_DB_PORT", "3306")),
            database=os.environ.get("CHATBOT_DB_NAME", "nkw2tiuvgv6ufu1z"),
        )
        self.cursorObject = self.db_con.cursor()
        self._closed = False

    def insert(self, fields, values):
        """Insert one row into ``chatbot_conversation`` and commit.

        Args:
            fields: column names. They are interpolated into the SQL text (query
                parameters cannot bind identifiers), so each must be a plain
                Python-style identifier.
            values: row values in the same order as *fields*; bound as query
                parameters, so quotes and other special characters are safe.

        Returns:
            True when the row was committed, False on any error (which is printed).
        """
        try:
            fields = list(fields)
            values = tuple(values)
            bad = [f for f in fields if not (isinstance(f, str) and f.isidentifier())]
            if bad:
                raise ValueError(f"invalid column name(s): {bad!r}")
            if len(fields) != len(values):
                raise ValueError(f"{len(fields)} fields but {len(values)} values")

            columns = ", ".join(fields)
            placeholders = ", ".join(["%s"] * len(values))
            query = f"INSERT INTO chatbot_conversation ({columns}) VALUES ({placeholders})"

            self.cursorObject.execute(query, values)
            self.db_con.commit()
            return True
        except Exception as e:
            print(e)
            return False

    def get_history(self, patient_id):
        """Return all rows for *patient_id*, oldest first, or None on error.

        *patient_id* arrives from the request path, so it is bound as a query
        parameter rather than formatted into the SQL.
        """
        try:
            query = "SELECT * FROM chatbot_conversation WHERE patient_id = %s ORDER BY timestamp ASC;"
            self.cursorObject.execute(query, (patient_id,))
            return self.cursorObject.fetchall()
        except Exception as e:
            print(e)
            return None

    def close_db(self):
        """Close the cursor and the connection; later calls are no-ops."""
        if getattr(self, "_closed", False):
            return
        self._closed = True
        try:
            self.cursorObject.close()
        finally:
            # Close the connection even if closing the cursor fails.
            self.db_con.close()





def get_conversation_history(db, patient_id):
    """Return column 5 of the newest stored row for *patient_id*, or ``""``.

    *db* must provide ``get_history(patient_id)`` returning rows oldest-first (or
    a falsy value when there are none / on error); the last row is taken as newest.
    """
    rows = db.get_history(patient_id)
    if not rows:
        return ""
    latest_row = rows[-1]
    # Presumably the summarized_text column (assuming a leading id column before
    # the inserted fields) — confirm against the table schema.
    return latest_row[5]

# Shared LLM instance used by the chat endpoint and its summary memory.
llm = CustomLLM(n=10)

app = FastAPI()
# NOTE(review): wildcard origins together with allow_credentials=True — browsers
# reject a literal "*" for credentialed requests, and the middleware may echo the
# caller's Origin instead, effectively allowing any site. Confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

@app.get('/healthcheck')
async def root():
    # Liveness probe: constant payload, touches neither the model nor the database.
    payload = {'status': 'running'}
    return payload

# Instruction prefix placed at the start of every model prompt.
_CHATBOT_INSTRUCTION = (
    "You are now a medical chatbot, and I am a patient. I will describe my "
    "conditions and symptoms and you will give me medical suggestions"
)


def _build_chat_prompt(history, user_input):
    """Return the model prompt for *user_input*, including prior *history* if any.

    The prompt ends with the ``->:`` marker; the generation code keeps only the
    text after the last ``->:`` as the reply. Segments are separated explicitly so
    they do not run together (e.g. "...suggestionsI have ...").
    """
    if history:
        return (
            f"{_CHATBOT_INSTRUCTION}. "
            f"The following is the patient's previous conversation with you: {history} "
            f"This is the current question: {user_input} ->:"
        )
    return f"{_CHATBOT_INSTRUCTION}. {user_input} ->:"


@app.post('/{patient_id}')
def chatbot(patient_id, user_data: dict=None):
    """Answer one patient message and store the exchange.

    Expects a JSON body shaped like ``{"userObject": {"userInput": {"message": str}}}``.
    Returns ``{"response": <reply>}``; responds 400 when the message is missing.
    """
    from fastapi import HTTPException  # fastapi is already a dependency of this module

    try:
        user_input = user_data["userObject"]["userInput"].get("message")
    except (TypeError, KeyError, AttributeError):
        # Missing body (None) or wrong shape.
        user_input = None
    if not isinstance(user_input, str) or not user_input.strip():
        raise HTTPException(
            status_code=400,
            detail="Request body must contain userObject.userInput.message",
        )

    db = DbHandler()
    try:
        history = get_conversation_history(db, patient_id)
        response = llm._call(_build_chat_prompt(history, user_input))
        # Single quotes are stripped from stored and returned text, as in the
        # original handler (a guard against quote-breaking SQL); kept so the API
        # response and stored data keep the same form.
        response = response.replace("'", "")
        human_text = user_input.replace("'", "")

        # NOTE(review): the memory starts empty on every request, so the stored
        # summary covers only this exchange rather than accumulating across turns.
        memory = ConversationSummaryBufferMemory(llm=llm, max_token_limit=200)
        # Save the exchange exactly once; saving it twice duplicates the turn in the
        # buffer/summary and can trigger extra summarization calls to the LLM.
        memory.save_context({"input": user_input}, {"output": response})
        summary = memory.load_memory_variables({})["history"].replace("'", "")

        db.insert(
            ("patient_id", "patient_text", "ai_text", "timestamp", "summarized_text"),
            (patient_id, human_text, response, str(datetime.now()), summary),
        )
        return {"response": response}
    finally:
        # The only close: runs on success and on error alike.
        db.close_db()