ashishkgpian's picture
Update app.py
a7ec3bd verified
raw
history blame
1.62 kB
import gradio as gr
from transformers import pipeline
# Build the text-classification pipeline once at import time; the request
# handler below reuses this single module-level object on every call.
classifier = pipeline(
    task="text-classification",
    model="ashishkgpian/biobert_icd9_classifier_ehr",
)
def classify_symptoms(text):
    """
    Classify medical symptoms and return the top ICD9 code predictions.

    Args:
        text (str): Input medical symptom description.

    Returns:
        list[dict] | str: On success, one dict per prediction (at most five
        are requested) with the keys "ICD9 Code" (the predicted label) and
        "Confidence" (the score rendered as a percentage with two decimals).
        If the input is blank, or anything raises while classifying or
        formatting, a string starting with
        "Error processing classification: " is returned instead.
    """
    # A blank description carries no symptoms; sending it to the model would
    # only yield meaningless codes, so report it via the same error-string
    # convention used for failures below. Non-string values are left to the
    # classifier, which keeps the previous behaviour for them.
    if text is None or (isinstance(text, str) and not text.strip()):
        return "Error processing classification: input text is empty"

    try:
        # Ask the pipeline for the five highest-scoring labels.
        results = classifier(text, top_k=5)
        # Reshape each prediction into a reader-friendly record.
        return [
            {
                "ICD9 Code": result['label'],
                "Confidence": f"{result['score']:.2%}",
            }
            for result in results
        ]
    except Exception as e:
        # Deliberately broad: this function is the UI boundary, so any
        # failure is surfaced to the user as text instead of crashing the app.
        return f"Error processing classification: {str(e)}"
# Assemble the Gradio UI: build the input and output components first, then
# wire them to classify_symptoms through a single Interface.
_symptom_box = gr.Textbox(
    label="Enter Medical Symptoms",
    placeholder="Describe patient symptoms here...",
)
_results_view = gr.JSON(label="Top 5 ICD9 Classifications")

# Each example is a one-element row because the interface has one input.
_sample_inputs = [
    [prompt]
    for prompt in (
        "Patient experiencing chest pain and shortness of breath",
        "Recurring headaches with nausea",
        "Diabetic symptoms including frequent urination",
    )
]

demo = gr.Interface(
    fn=classify_symptoms,
    inputs=_symptom_box,
    outputs=_results_view,
    title="BioBERT ICD9 Symptom Classifier",
    description="Classify medical symptoms into ICD9 diagnostic codes using a fine-tuned BioBERT model.",
    theme="huggingface",
    examples=_sample_inputs,
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()