ClefChen commited on
Commit
1360799
1 Parent(s): fe27810

Update app_phi2.py

Browse files
Files changed (1) hide show
  1. app_phi2.py +42 -17
app_phi2.py CHANGED
@@ -12,6 +12,14 @@ from model import RNN_model
12
  from timeit import default_timer as timer
13
  from typing import Tuple, Dict
14
 
 
 
 
 
 
 
 
 
15
  # Import data
16
  df= pd.read_csv('Symptom2Disease.csv')
17
  df.drop('Unnamed: 0', axis= 1, inplace= True)
@@ -47,17 +55,32 @@ class_names= {0: 'Acne',
47
  23: 'urinary tract infection'
48
  }
49
 
50
- vectorizer= nltk_u.vectorizer()
51
- vectorizer.fit(train_data.text)
 
 
52
 
 
 
 
 
 
 
 
 
 
53
 
 
 
54
 
55
- # Model and transforms preparation
56
- model= RNN_model()
57
- # Load state dict
58
- model.load_state_dict(torch.load(
59
- f= 'pretrained_symtom_to_disease_model.pth',
60
- map_location= torch.device('cpu')))
 
 
61
  # Disease Advice
62
  disease_advice = {
63
  'Acne': "Maintain a proper skincare routine, avoid excessive touching of the affected areas, and consider using over-the-counter topical treatments. If severe, consult a dermatologist.",
@@ -174,17 +197,19 @@ with gr.Blocks(css = """#col_container { margin-left: auto; margin-right: auto;}
174
  elif message.lower() in goodbyes:
175
  bot_message= random.choice(goodbye_replies)
176
  else:
 
 
177
  #bot_message= random.choice(goodbye_replies)
178
-
179
- transform_text= vectorizer.transform([message])
180
- transform_text= torch.tensor(transform_text.toarray()).to(torch.float32)
181
- model.eval()
182
- with torch.inference_mode():
183
- y_logits=model(transform_text)
184
- pred_prob= torch.argmax(torch.softmax(y_logits, dim=1), dim=1)
185
 
186
- test_pred= class_names[pred_prob.item()]
187
- bot_message = f' Based on your symptoms, I believe you are having {test_pred} and I would advice you {disease_advice[test_pred]}'
188
  chat_history.append((message, bot_message))
189
  time.sleep(2)
190
  return "", chat_history
 
12
  from timeit import default_timer as timer
13
  from typing import Tuple, Dict
14
 
15
+ import torch
16
+ from transformers import AutoModel, AutoTokenizer
17
+
18
+ # 导入预训练模型和分词器
19
+ model_name = "microsoft/phi2-base"
20
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
21
+ model = AutoModel.from_pretrained(model_name)
22
+
23
  # Import data
24
  df= pd.read_csv('Symptom2Disease.csv')
25
  df.drop('Unnamed: 0', axis= 1, inplace= True)
 
55
  23: 'urinary tract infection'
56
  }
57
 
58
+ # 数据预处理
59
+ def preprocess(text):
60
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
61
+ return inputs
62
 
63
+ # 模型预测逻辑
64
+ def get_prediction(inputs):
65
+ model.eval()
66
+ with torch.no_grad():
67
+ outputs = model(**inputs)
68
+ logits = outputs.last_hidden_state[:, 0, :] # 取CLS标记的输出进行分类
69
+ pred_prob = torch.softmax(logits, dim=1)
70
+ pred = torch.argmax(pred_prob, dim=1)
71
+ return class_names[pred.item()]
72
 
73
+ # vectorizer= nltk_u.vectorizer()
74
+ # vectorizer.fit(train_data.text)
75
 
76
+
77
+
78
+ # # Model and transforms preparation
79
+ # model= RNN_model()
80
+ # # Load state dict
81
+ # model.load_state_dict(torch.load(
82
+ # f= 'pretrained_symtom_to_disease_model.pth',
83
+ # map_location= torch.device('cpu')))
84
  # Disease Advice
85
  disease_advice = {
86
  'Acne': "Maintain a proper skincare routine, avoid excessive touching of the affected areas, and consider using over-the-counter topical treatments. If severe, consult a dermatologist.",
 
197
  elif message.lower() in goodbyes:
198
  bot_message= random.choice(goodbye_replies)
199
  else:
200
+ inputs = preprocess(message)
201
+ bot_message = f"Based on your symptoms, I believe you may have {get_prediction(inputs)}."
202
  #bot_message= random.choice(goodbye_replies)
203
+
204
+ # transform_text= vectorizer.transform([message])
205
+ # transform_text= torch.tensor(transform_text.toarray()).to(torch.float32)
206
+ # model.eval()
207
+ # with torch.inference_mode():
208
+ # y_logits=model(transform_text)
209
+ # pred_prob= torch.argmax(torch.softmax(y_logits, dim=1), dim=1)
210
 
211
+ # test_pred= class_names[pred_prob.item()]
212
+ # bot_message = f' Based on your symptoms, I believe you are having {test_pred} and I would advice you {disease_advice[test_pred]}'
213
  chat_history.append((message, bot_message))
214
  time.sleep(2)
215
  return "", chat_history