E-Hospital committed on
Commit 31ae848
1 Parent(s): 29c74d5

Update main.py

Files changed (1)
  1. main.py +17 -53
main.py CHANGED
@@ -1,38 +1,27 @@
- # -*- coding: utf-8 -*-
- """TestAPI.ipynb
-
- Automatically generated by Colaboratory.
-
- Original file is located at
-     https://colab.research.google.com/drive/1WToaz7kQoFpI0_M8j6uWPigBrKlkL4ml
- """
-
-
-
- from transformers import AutoTokenizer,AutoModelForCausalLM
- import os
- os.environ["CUDA_VISIBLE_DEVICES"]="0"
- import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch

- from typing import List
-
+ model = AutoModelForCausalLM.from_pretrained(
+     "E-Hospital/open-orca-platypus-2-lora-medical",
+     trust_remote_code=True,
+     device_map = "auto",
+ )
+ tokenizer = AutoTokenizer.from_pretrained("Open-Orca/OpenOrca-Platypus2-13B", trust_remote_code=True)

- from peft import PeftModel
- from transformers import AutoTokenizer, AutoModelForCausalLM
- from typing import Any, List, Mapping, Optional
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
- from langchain.callbacks.manager import CallbackManagerForLLMRun
- from langchain.llms.base import LLM
+ def ask_bot(question):
+     input_ids = tokenizer.encode(question, return_tensors="pt").to('cuda')
+     with torch.no_grad():
+         output = model.generate(input_ids, max_length=500, num_return_sequences=1, do_sample=True, top_k=50)
+     generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+     response = generated_text.split("->:")[-1]
+     return response

- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
+ ask_bot("I have diabetes. What should I do?")

  import mysql.connector
  import re
  from datetime import datetime
-
+ from typing import Any, List, Mapping, Optional

  from langchain.memory import ConversationBufferMemory
  from typing import Any, List, Mapping, Optional
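Note: the new code loads the already-merged checkpoint E-Hospital/open-orca-platypus-2-lora-medical directly and no longer passes load_in_8bit=True. If GPU memory becomes tight again, the same checkpoint can still be quantized at load time; a minimal sketch, assuming bitsandbytes is installed (this is not part of the commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 8-bit variant of the load above; BitsAndBytesConfig replaces the old
# load_in_8bit=True keyword that this commit removes further down.
quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    "E-Hospital/open-orca-platypus-2-lora-medical",
    trust_remote_code=True,
    device_map="auto",
    quantization_config=quant_config,
)
tokenizer = AutoTokenizer.from_pretrained(
    "Open-Orca/OpenOrca-Platypus2-13B", trust_remote_code=True
)

The ask_bot helper above should work unchanged on top of this variant, since generation still runs on the GPU chosen by device_map.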
@@ -46,28 +35,6 @@ from langchain.memory import ConversationSummaryBufferMemory

  from langchain.memory import ConversationSummaryMemory

-
- model_name = "Open-Orca/OpenOrca-Platypus2-13B"
-
-
- tokenizer = AutoTokenizer.from_pretrained(
-     model_name, trust_remote_code=True)
-
- model = AutoModelForCausalLM.from_pretrained(
-     model_name,
-     trust_remote_code=True,
-     load_in_8bit = True,
-     device_map = "auto",
- )
-
-
-
- model = PeftModel.from_pretrained(model, "teslalord/open-orca-platypus-2-medical")
-
- model = model.merge_and_unload()
-
-
-
  class CustomLLM(LLM):
      n: int
      # custom_model: llm # Replace with the actual type of your custom model
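Note: the block removed here rebuilt the model on every start by loading the 13B base, attaching the teslalord/open-orca-platypus-2-medical LoRA adapter, and merging it. Presumably the merged weights were exported once to the E-Hospital/open-orca-platypus-2-lora-medical repo that the new code loads; a one-off export would look roughly like this (a sketch under that assumption, not shown in the commit):

from peft import PeftModel
from transformers import AutoModelForCausalLM

# Load the base model once, attach the medical LoRA adapter, and bake the
# adapter weights into the base so serving no longer needs peft at all.
base = AutoModelForCausalLM.from_pretrained(
    "Open-Orca/OpenOrca-Platypus2-13B",
    trust_remote_code=True,
    device_map="auto",
)
merged = PeftModel.from_pretrained(base, "teslalord/open-orca-platypus-2-medical")
merged = merged.merge_and_unload()

# Save locally (or push_to_hub) so the API can load the merged checkpoint directly.
merged.save_pretrained("open-orca-platypus-2-lora-medical")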
@@ -137,7 +104,7 @@ class DbHandler():

      def get_history(self, patient_id):
          try:
-             query = f"SELECT * FROM chatbot_conversation WHERE patient_id = '{patient_id}' ORDER BY timestamp ASC;"
+             query = f"SELECT * FROM chatbot_conversation WHERE patient_id = {patient_id} ORDER BY timestamp ASC;"
              self.cursorObject.execute(query)
              data = self.cursorObject.fetchall()
              return data
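Note: both versions of the query build SQL with an f-string; the change only drops the quotes around patient_id, so the value is still interpolated into the statement. mysql.connector also accepts parameterized queries, which handle quoting in the driver and avoid SQL injection; a sketch of the same lookup with a placeholder (not what this commit does):

# Same SELECT as above, but patient_id is passed as a bound parameter
# instead of being formatted into the SQL text.
query = (
    "SELECT * FROM chatbot_conversation "
    "WHERE patient_id = %s ORDER BY timestamp ASC;"
)
self.cursorObject.execute(query, (patient_id,))
data = self.cursorObject.fetchall()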
@@ -159,8 +126,6 @@ def get_conversation_history(db, patient_id):
          return conversations[-1][5]
      return ""

-
-
  llm = CustomLLM(n=10)
  app = FastAPI()

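Note: CustomLLM(n=10) is the LangChain wrapper that the conversation memories above feed into; its implementation sits outside this diff. For orientation, a custom LLM in this LangChain version only needs _llm_type and _call, and here _call would presumably hand the prompt to the ask_bot helper added in the first hunk. A hypothetical minimal version (the real class in main.py may differ):

from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class MinimalCustomLLM(LLM):
    """Hypothetical stand-in for the CustomLLM defined in main.py."""

    n: int  # same field the real class declares

    @property
    def _llm_type(self) -> str:
        return "open-orca-platypus-2-medical"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Delegate generation to the ask_bot helper defined near the top of main.py.
        return ask_bot(prompt)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"n": self.n}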
@@ -201,4 +166,3 @@ def chatbot(patient_id, user_data: dict=None):
          return {"response": response}
      finally:
          db.close_db()
-
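Note: the endpoint keeps its per-request DbHandler and closes it in the finally block; this hunk only trims a trailing blank line. An equivalent way to express the same cleanup is a yield dependency, sketched below with a hypothetical route path (the real decorator and body are outside this diff):

from fastapi import Depends, FastAPI

app = FastAPI()


def get_db():
    # FastAPI runs the code after `yield` once the response is sent,
    # mirroring the explicit try/finally in the current chatbot endpoint.
    db = DbHandler()  # DbHandler is the class defined earlier in main.py
    try:
        yield db
    finally:
        db.close_db()


@app.post("/chatbot/{patient_id}")  # hypothetical path, not taken from the diff
def chatbot(patient_id: str, user_data: dict = None, db=Depends(get_db)):
    # ...build the prompt, call the LLM, and log the turn via db as the real endpoint does...
    return {"response": "stubbed"}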