# Import and class names setup
import gradio as gr
import os
import torch
import random
import nltk_u
import pandas as pd
from sklearn.model_selection import train_test_split
import time
from model import RNN_model
from timeit import default_timer as timer
from typing import Tuple, Dict
import torch
from transformers import AutoModel, AutoTokenizer
# Load the pretrained checkpoint shared by the model and its tokenizer.
model_name = "microsoft/phi-2"

# NOTE(review): AutoModel resolves to the bare transformer for this checkpoint;
# it returns hidden states rather than per-disease scores, so nothing here is a
# trained disease classifier. Confirm whether a classification head (or the
# disabled RNN_model pipeline further down) is what was intended.
model = AutoModel.from_pretrained(model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# The tokenizer may come without a padding token; padding=True in preprocess
# needs one, so fall back to reusing the end-of-sequence token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Load the Symptom2Disease dataset. The unnamed first column ('Unnamed: 0',
# presumably a saved row index) is dropped BEFORE de-duplication so that it
# does not take part in deciding which rows are duplicates.
df = (
    pd.read_csv('Symptom2Disease.csv')
    .drop(columns='Unnamed: 0')
    .drop_duplicates()
)
# Hold out 15% of the rows; the fixed random_state keeps the split reproducible.
train_data, test_data = train_test_split(df, test_size=0.15, random_state=42)
# Prediction index -> disease label. Labels are spelled exactly like the keys
# of disease_advice (including the lower-case entries), in this fixed order.
class_names = dict(enumerate([
    'Acne',
    'Arthritis',
    'Bronchial Asthma',
    'Cervical spondylosis',
    'Chicken pox',
    'Common Cold',
    'Dengue',
    'Dimorphic Hemorrhoids',
    'Fungal infection',
    'Hypertension',
    'Impetigo',
    'Jaundice',
    'Malaria',
    'Migraine',
    'Pneumonia',
    'Psoriasis',
    'Typhoid',
    'Varicose Veins',
    'allergy',
    'diabetes',
    'drug reaction',
    'gastroesophageal reflux disease',
    'peptic ulcer disease',
    'urinary tract infection',
]))
# Data preprocessing
def preprocess(text, *, max_length=512):
    """Tokenize user text into PyTorch tensors for ``get_prediction``.

    Args:
        text: The text to tokenize. ``padding=True`` only matters if several
            texts are passed at once; note that ``get_prediction`` expects a
            single text.
        max_length: Maximum number of tokens kept; longer input is truncated.
            Defaults to 512, the value previously hard-coded here.

    Returns:
        The tokenizer output with ``return_tensors="pt"``, ready to be unpacked
        into the model call.
    """
    return tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
# Model prediction logic
def get_prediction(inputs):
    """Predict a disease label for one tokenized piece of text.

    Args:
        inputs: Tokenizer output as produced by ``preprocess``: a mapping of
            tensors unpacked into ``model(**inputs)``. An ``attention_mask``
            entry is used for pooling when present. Only a batch of one text
            is supported; ``.item()`` below raises for larger batches.

    Returns:
        The ``class_names`` label for the highest-scoring index, or
        ``"Unknown"`` when that index is not a key of ``class_names``.
    """
    import logging

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)

    # A model with a classification head exposes per-class scores directly.
    logits = getattr(outputs, "logits", None)
    if logits is None:
        # NOTE(review): a bare AutoModel returns hidden states, not class
        # scores, so the argmax below ranges over hidden units and most
        # indices fall outside class_names. A trained classification head is
        # needed for meaningful predictions.
        hidden = outputs.last_hidden_state
        batch_size, seq_len = hidden.shape[0], hidden.shape[1]
        # phi-2 is decoder-only: there is no CLS token and position 0 only
        # sees the first token. Pool the last real (non-padding) token
        # instead, which has attended to the whole input.
        mask = inputs.get("attention_mask")
        if mask is None:
            last_idx = torch.full(
                (batch_size,), seq_len - 1, dtype=torch.long, device=hidden.device
            )
        else:
            # argmax returns the FIRST maximum, so on the flipped mask it
            # finds the last 1; this works for left and right padding alike.
            last_idx = (seq_len - 1 - mask.flip(dims=[1]).argmax(dim=1)).to(hidden.device)
        logits = hidden[torch.arange(batch_size, device=hidden.device), last_idx]

    # softmax is monotonic, so argmax over the raw scores picks the same index.
    pred = torch.argmax(logits, dim=-1).item()
    label = class_names.get(pred)
    if label is None:
        logging.getLogger(__name__).warning(
            "Prediction index %d not found in class_names.", pred
        )
        return "Unknown"  # fallback response for unmapped indices
    return label
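# Disabled: earlier pipeline using the project's RNN_model with trained
# weights from 'pretrained_symtom_to_disease_model.pth' and an nltk_u
# vectorizer fitted on train_data.text. Kept here for reference.
# vectorizer= nltk_u.vectorizer()
# vectorizer.fit(train_data.text)
# # Model and transforms preparation
# model= RNN_model()
# # Load state dict
# model.load_state_dict(torch.load(
# f= 'pretrained_symtom_to_disease_model.pth',
# map_location= torch.device('cpu')))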
# Advice text keyed by disease label. Keys are spelled exactly like the labels
# in class_names, in the same order.
disease_advice = dict([
    ('Acne', "Maintain a proper skincare routine, avoid excessive touching of the affected areas, and consider using over-the-counter topical treatments. If severe, consult a dermatologist."),
    ('Arthritis', "Stay active with gentle exercises, manage weight, and consider pain-relief strategies like hot/cold therapy. Consult a rheumatologist for tailored guidance."),
    ('Bronchial Asthma', "Follow prescribed inhaler and medication regimen, avoid triggers like smoke and allergens, and have an asthma action plan. Regular check-ups with a pulmonologist are important."),
    ('Cervical spondylosis', "Maintain good posture, do neck exercises, and use ergonomic support. Physical therapy and pain management techniques might be helpful."),
    ('Chicken pox', "Rest, maintain hygiene, and avoid scratching. Consult a doctor for appropriate antiviral treatment."),
    ('Common Cold', "Get plenty of rest, stay hydrated, and consider over-the-counter remedies for symptom relief. Seek medical attention if symptoms worsen or last long."),
    ('Dengue', "Stay hydrated, rest, and manage fever with acetaminophen. Seek medical care promptly, as dengue can escalate quickly."),
    ('Dimorphic Hemorrhoids', "Follow a high-fiber diet, maintain good hygiene, and consider stool softeners. Consult a doctor if symptoms persist."),
    ('Fungal infection', "Keep the affected area clean and dry, use antifungal creams, and avoid sharing personal items. Consult a dermatologist if it persists."),
    ('Hypertension', "Follow a balanced diet, exercise regularly, reduce salt intake, and take prescribed medications. Regular check-ups with a healthcare provider are important."),
    ('Impetigo', "Keep the affected area clean, use prescribed antibiotics, and avoid close contact. Consult a doctor for proper treatment."),
    ('Jaundice', "Get plenty of rest, maintain hydration, and follow a doctor's advice for diet and medications. Regular monitoring is important."),
    ('Malaria', "Take prescribed antimalarial medications, rest, and manage fever. Seek medical attention for severe cases."),
    ('Migraine', "Identify triggers, manage stress, and consider pain-relief medications. Consult a neurologist for personalized management."),
    ('Pneumonia', "Follow prescribed antibiotics, rest, stay hydrated, and monitor symptoms. Seek immediate medical attention for severe cases."),
    ('Psoriasis', "Moisturize, use prescribed creams, and avoid triggers. Consult a dermatologist for effective management."),
    ('Typhoid', "Take prescribed antibiotics, rest, and stay hydrated. Dietary precautions are important. Consult a doctor for proper treatment."),
    ('Varicose Veins', "Elevate legs, exercise regularly, and wear compression stockings. Consult a vascular specialist for evaluation and treatment options."),
    ('allergy', "Identify triggers, manage exposure, and consider antihistamines. Consult an allergist for comprehensive management."),
    ('diabetes', "Follow a balanced diet, exercise, monitor blood sugar levels, and take prescribed medications. Regular visits to an endocrinologist are essential."),
    ('drug reaction', "Discontinue the suspected medication, seek medical attention if symptoms are severe, and inform healthcare providers about the reaction."),
    ('gastroesophageal reflux disease', "Follow dietary changes, avoid large meals, and consider medications. Consult a doctor for personalized management."),
    ('peptic ulcer disease', "Avoid spicy and acidic foods, take prescribed medications, and manage stress. Consult a gastroenterologist for guidance."),
    ('urinary tract infection', "Stay hydrated, take prescribed antibiotics, and maintain good hygiene. Consult a doctor for appropriate treatment."),
])
# User-facing introduction text; fixes the "RECOMENDATIONS" misspelling.
howto = """Welcome to the Medical Chatbot, powered by Gradio.
Currently, the chatbot can WELCOME YOU, PREDICT DISEASE based on your symptoms and SUGGEST POSSIBLE SOLUTIONS AND RECOMMENDATIONS, and BID YOU FAREWELL.
How to Start: Simply type your messages in the textbox to chat with the Chatbot and press enter!
The bot will respond based on the best possible answers to your messages.
"""
# Create the gradio demo
with gr.Blocks(css = """#col_container { margin-left: auto; margin-right: auto;} #chatbot {height: 520px; overflow: auto;}""") as demo:
gr.HTML('