ClefChen commited on
Commit
416611e
1 Parent(s): bd5b53c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -23
app.py CHANGED
@@ -12,6 +12,19 @@ from model import RNN_model
12
  from timeit import default_timer as timer
13
  from typing import Tuple, Dict
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # Import data
16
  df= pd.read_csv('Symptom2Disease.csv')
17
  df.drop('Unnamed: 0', axis= 1, inplace= True)
@@ -47,17 +60,36 @@ class_names= {0: 'Acne',
47
  23: 'urinary tract infection'
48
  }
49
 
50
- vectorizer= nltk_u.vectorizer()
51
- vectorizer.fit(train_data.text)
52
-
53
-
54
-
55
- # Model and transforms preparation
56
- model= RNN_model()
57
- # Load state dict
58
- model.load_state_dict(torch.load(
59
- f= 'pretrained_symtom_to_disease_model.pth',
60
- map_location= torch.device('cpu')))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # Disease Advice
62
  disease_advice = {
63
  'Acne': "Maintain a proper skincare routine, avoid excessive touching of the affected areas, and consider using over-the-counter topical treatments. If severe, consult a dermatologist.",
@@ -90,7 +122,6 @@ howto= """Welcome to the <b>Medical Chatbot</b>, powered by Gradio.
90
  Currently, the chatbot can WELCOME YOU, PREDICT DISEASE based on your symptoms and SUGGEST POSSIBLE SOLUTIONS AND RECOMENDATIONS, and BID YOU FAREWELL.
91
  <b>How to Start:</b> Simply type your messages in the textbox to chat with the Chatbot and press enter!<br><br>
92
  The bot will respond based on the best possible answers to your messages.
93
-
94
  """
95
 
96
 
@@ -175,17 +206,19 @@ with gr.Blocks(css = """#col_container { margin-left: auto; margin-right: auto;}
175
  elif message.lower() in goodbyes:
176
  bot_message= random.choice(goodbye_replies)
177
  else:
 
 
178
  #bot_message= random.choice(goodbye_replies)
179
-
180
- transform_text= vectorizer.transform([message])
181
- transform_text= torch.tensor(transform_text.toarray()).to(torch.float32)
182
- model.eval()
183
- with torch.inference_mode():
184
- y_logits=model(transform_text)
185
- pred_prob= torch.argmax(torch.softmax(y_logits, dim=1), dim=1)
186
 
187
- test_pred= class_names[pred_prob.item()]
188
- bot_message = f' Based on your symptoms, I believe you are having {test_pred} and I would advice you {disease_advice[test_pred]}'
189
  chat_history.append((message, bot_message))
190
  time.sleep(2)
191
  return "", chat_history
@@ -194,5 +227,4 @@ with gr.Blocks(css = """#col_container { margin-left: auto; margin-right: auto;}
194
 
195
 
196
  # Launch the demo
197
- demo.launch()
198
-
 
12
  from timeit import default_timer as timer
13
  from typing import Tuple, Dict
14
 
15
+ import torch
16
+ from transformers import AutoModel, AutoTokenizer
17
+
18
+ # 导入预训练模型和分词器
19
+ model_name = "microsoft/phi-2"
20
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
21
+
22
+ # 设置填充令牌,如果分词器没有默认的填充令牌
23
+ if tokenizer.pad_token is None:
24
+ tokenizer.pad_token = tokenizer.eos_token
25
+
26
+ model = AutoModel.from_pretrained(model_name)
27
+
28
  # Import data
29
  df= pd.read_csv('Symptom2Disease.csv')
30
  df.drop('Unnamed: 0', axis= 1, inplace= True)
 
60
  23: 'urinary tract infection'
61
  }
62
 
63
+ # 数据预处理
64
+ def preprocess(text):
65
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
66
+ return inputs
67
+
68
+ # 模型预测逻辑
69
+ def get_prediction(inputs):
70
+ model.eval()
71
+ with torch.no_grad():
72
+ outputs = model(**inputs)
73
+ logits = outputs.last_hidden_state[:, 0, :] # 取CLS标记的输出进行分类
74
+ pred_prob = torch.softmax(logits, dim=1)
75
+ pred = torch.argmax(pred_prob, dim=1).item()
76
+ if pred in class_names:
77
+ return class_names[pred]
78
+ else:
79
+ print(f"Warning: Prediction index {pred} not found in class_names.")
80
+ return "Unknown" # 或者其他默认的响应
81
+
82
+ # vectorizer= nltk_u.vectorizer()
83
+ # vectorizer.fit(train_data.text)
84
+
85
+
86
+
87
+ # # Model and transforms preparation
88
+ # model= RNN_model()
89
+ # # Load state dict
90
+ # model.load_state_dict(torch.load(
91
+ # f= 'pretrained_symtom_to_disease_model.pth',
92
+ # map_location= torch.device('cpu')))
93
  # Disease Advice
94
  disease_advice = {
95
  'Acne': "Maintain a proper skincare routine, avoid excessive touching of the affected areas, and consider using over-the-counter topical treatments. If severe, consult a dermatologist.",
 
122
  Currently, the chatbot can WELCOME YOU, PREDICT DISEASE based on your symptoms and SUGGEST POSSIBLE SOLUTIONS AND RECOMENDATIONS, and BID YOU FAREWELL.
123
  <b>How to Start:</b> Simply type your messages in the textbox to chat with the Chatbot and press enter!<br><br>
124
  The bot will respond based on the best possible answers to your messages.
 
125
  """
126
 
127
 
 
206
  elif message.lower() in goodbyes:
207
  bot_message= random.choice(goodbye_replies)
208
  else:
209
+ inputs = preprocess(message)
210
+ bot_message = f"Based on your symptoms, I believe you may have {get_prediction(inputs)}."
211
  #bot_message= random.choice(goodbye_replies)
212
+
213
+ # transform_text= vectorizer.transform([message])
214
+ # transform_text= torch.tensor(transform_text.toarray()).to(torch.float32)
215
+ # model.eval()
216
+ # with torch.inference_mode():
217
+ # y_logits=model(transform_text)
218
+ # pred_prob= torch.argmax(torch.softmax(y_logits, dim=1), dim=1)
219
 
220
+ # test_pred= class_names[pred_prob.item()]
221
+ # bot_message = f' Based on your symptoms, I believe you are having {test_pred} and I would advice you {disease_advice[test_pred]}'
222
  chat_history.append((message, bot_message))
223
  time.sleep(2)
224
  return "", chat_history
 
227
 
228
 
229
  # Launch the demo
230
+ demo.launch()