from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info("Using device: %s", device)
if torch.cuda.is_available():
    num_of_gpus = torch.cuda.device_count()
    logger.info("Number of GPUs: %s", num_of_gpus)

model = AutoModelForCausalLM.from_pretrained(
    "E-Hospital/open-orca-platypus-2-lora-medical",
    trust_remote_code=True,
)
model.to(device)  # move the model to the same device the input tensors will use
model.eval()
tokenizer = AutoTokenizer.from_pretrained("Open-Orca/OpenOrca-Platypus2-13B", trust_remote_code=True)
def ask_bot(question):
    """Generate a reply and return only the text after the '->:' marker."""
    input_ids = tokenizer.encode(question, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(input_ids, max_length=200, num_return_sequences=1, do_sample=True, top_k=50)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # Prompts end with '->:'; everything after the last marker is the answer.
    response = generated_text.split("->:")[-1]
    return response
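# Usage sketch (the question below is illustrative, not from the original):
#   answer = ask_bot("I have had a dry cough for two weeks. What could it be? ->:")
#   print(answer)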
import mysql.connector
import uvicorn
from datetime import datetime
from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.memory import ConversationSummaryBufferMemory
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
class CustomLLM(LLM):
    n: int

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            output = model.generate(input_ids, max_length=100, num_return_sequences=1, do_sample=True, top_k=50)
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        # Keep only the text after the '->:' prompt marker.
        response = generated_text.split("->:")[-1]
        return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"n": self.n}
class DbHandler():
    def __init__(self):
        self.db_con = mysql.connector.connect(
            host="frwahxxknm9kwy6c.cbetxkdyhwsb.us-east-1.rds.amazonaws.com",
            user="j6qbx3bgjysst4jr",
            password="mcbsdk2s27ldf37t",
            port=3306,
            database="nkw2tiuvgv6ufu1z")
        self.cursorObject = self.db_con.cursor()

    def insert(self, fields, values):
        try:
            # Build a parameterized query so the driver escapes the values.
            fields_str = ', '.join(fields)
            placeholders = ', '.join(['%s'] * len(values))
            query = f"INSERT INTO chatbot_conversation ({fields_str}) VALUES ({placeholders})"
            self.cursorObject.execute(query, tuple(values))
            self.db_con.commit()
            return True
        except Exception as e:
            print(e)
            return False

    def get_history(self, patient_id):
        try:
            # Parameterized to avoid SQL injection via patient_id.
            query = "SELECT * FROM chatbot_conversation WHERE patient_id = %s ORDER BY timestamp ASC"
            self.cursorObject.execute(query, (patient_id,))
            data = self.cursorObject.fetchall()
            return data
        except Exception as e:
            print(e)
            return None

    def close_db(self):
        self.db_con.close()
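# Usage sketch (illustrative values; the field names mirror the commented-out
# insert call in chatbot() below):
#   db = DbHandler()
#   db.insert(("patient_id", "patient_text", "ai_text", "timestamp", "summarized_text"),
#             ("1", "hello", "hi there", str(datetime.now()), ""))
#   rows = db.get_history("1")
#   db.close_db()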
def get_conversation_history(db, patient_id):
    conversations = db.get_history(patient_id)
    if conversations:
        # Column 5 of the most recent row is assumed to be summarized_text
        # (with an auto-increment id as column 0).
        return conversations[-1][5]
    return ""
llm = CustomLLM(n=10)
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)

@app.get('/')
async def root():
    return {'status': 'running'}
# NOTE: the route path below is illustrative; adjust it to match the
# deployment's routing.
@app.post('/chatbot/{patient_id}')
def chatbot(patient_id, user_data: dict = None):
    user_input = user_data["userObject"]["userInput"].get("message")
    db = DbHandler()
    try:
        history = get_conversation_history(db, patient_id)
        memory = ConversationSummaryBufferMemory(llm=llm, max_token_limit=200)
        prompt = "You are now a medical chatbot, and I am a patient. I will describe my conditions and symptoms and you will give me medical suggestions. "
        if history:
            human_input = (prompt + "The following is the patient's previous conversation with you: "
                           + history + " This is the current question: " + user_input + " ->:")
        else:
            human_input = prompt + user_input + " ->:"
        human_text = user_input.replace("'", "")
        # response = llm._call(human_input)
        response = ask_bot(human_input)
        # response = response.replace("'", "")
        # memory.save_context({"input": user_input}, {"output": response})
        # summary = memory.load_memory_variables({})
        # ai_text = response.replace("'", "")
        # memory.save_context({"input": user_input}, {"output": ai_text})
        # summary = memory.load_memory_variables({})
        # db.insert(("patient_id", "patient_text", "ai_text", "timestamp", "summarized_text"), (patient_id, human_text, ai_text, str(datetime.now()), summary['history'].replace("'", "")))
        return {"response": response}
    finally:
        # Close exactly once here; the finally block runs on both the normal
        # and the error path, so no close is needed inside the try body.
        db.close_db()
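# Example request body the endpoint above expects (shape inferred from the
# user_data lookups; the message text is illustrative):
#   {"userObject": {"userInput": {"message": "I have a sore throat."}}}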
if __name__ == '__main__':
    uvicorn.run('main:app', reload=True)