# app.py — page header captured with this file (author DexterSptizu,
# commit 4c072f6 "Update app.py", verified; file size 2.48 kB).
import html

import gradio as gr
from transformers import pipeline
# Medical NER pipeline, built once at import time so every request reuses it.
# With aggregation_strategy="simple" it returns grouped entities; classify_text
# reads their 'entity_group', 'start' and 'end' keys.
pipe = pipeline(
    task="token-classification",
    model="Clinical-AI-Apollo/Medical-NER",
    aggregation_strategy="simple",
)
# Highlight background color per entity type. Keys are the model's
# 'entity_group' labels; types not listed here fall back to gray in
# classify_text. Every color must stay visible on a white page: the label
# badge inside each highlight is white, so a near-white color would hide it.
entity_colors = {
    "AGE": "#ffadad",
    "SEX": "#ffd6a5",
    "DISEASE_DISORDER": "#caffbf",
    "SIGN_SYMPTOM": "#9bf6ff",
    "LAB_VALUE": "#a0c4ff",
    "THERAPEUTIC_PROCEDURE": "#bdb2ff",
    "CLINICAL_EVENT": "#ffc6ff",
    # Was "#fffffc" (near-white), which made this highlight invisible.
    "DIAGNOSTIC_PROCEDURE": "#b8e0d2",
    "DETAILED_DESCRIPTION": "#fdffb6",
    "BIOLOGICAL_STRUCTURE": "#ffb5a7",
}
def _entity_span_html(fragment: str, entity: str) -> str:
    """Return the highlighted HTML for one entity occurrence.

    Args:
        fragment: Raw (unescaped) entity text as it appears in the input.
        entity: Entity type label; also selects the highlight color.

    Returns:
        A ``<span>`` in the entity's color holding the escaped fragment,
        followed by a small white badge with the escaped label.
    """
    color = entity_colors.get(entity, "#e0e0e0")  # Default to gray if entity type not defined
    return (
        f"<span style='background-color:{color}; padding:2px; border-radius:5px;'>"
        f"{html.escape(fragment)}"
        "<span style='display:inline-block; background-color:#fff; color:#000; "
        "border-radius:3px; padding:2px; margin-left:5px; font-size:10px;'>"
        f"{html.escape(entity)}</span>"
        "</span>"
    )


def classify_text(text: str) -> str:
    """Run medical NER on ``text`` and return it as color-highlighted HTML.

    Every entity returned by ``pipe`` is wrapped in a colored span (color from
    ``entity_colors``; gray for unlisted types), followed by a badge naming
    its type. Text outside entities is emitted unchanged apart from escaping.

    All text that reaches the output, whether user input or model labels, is
    HTML-escaped. The result is rendered as raw HTML by ``gr.HTML``, so
    unescaped input such as ``<img src=x onerror=...>`` would otherwise be
    injected into the page.

    Args:
        text: Input text to classify. ``None``, empty, or whitespace-only
            input is returned as-is without calling the model.

    Returns:
        The highlighted text wrapped in a styled ``<div>``. The style includes
        ``white-space: pre-wrap`` so line breaks in multi-line input stay
        visible.
    """
    text = text or ""
    # Nothing to classify in empty / whitespace-only input; skip the model.
    entities = pipe(text) if text.strip() else []

    parts = []
    last_pos = 0
    # Walk spans in document order so slicing between them is well defined.
    for res in sorted(entities, key=lambda r: r["start"]):
        start, end = res["start"], res["end"]
        if start < last_pos:
            # Overlaps an already-rendered span; emitting it would repeat text.
            continue
        # Plain text between the previous entity and this one.
        parts.append(html.escape(text[last_pos:start]))
        # Slice the input rather than using res['word']: the pipeline rebuilds
        # 'word' from tokens, so it may not match the user's text exactly
        # (spacing/casing). Leading whitespace stays outside the highlight.
        fragment = text[start:end]
        core = fragment.lstrip()
        parts.append(html.escape(fragment[: len(fragment) - len(core)]))
        parts.append(_entity_span_html(core, res["entity_group"]))
        last_pos = end
    # Remaining text after the last entity.
    parts.append(html.escape(text[last_pos:]))

    body = "".join(parts)
    return (
        "<div style='font-family: Arial, sans-serif; line-height: 1.5; "
        f"white-space: pre-wrap;'>{body}</div>"
    )
# Web UI: a multi-line text box feeds classify_text, and its HTML string is
# shown by a gr.HTML output component.
demo = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(label="Enter Medical Text", lines=5),
    outputs=gr.HTML(label="Entity Classification with Highlighting and Labels"),
    title="Medical Entity Classification",
    description=(
        "Enter medical-related text, and the model will classify "
        "medical entities with color highlighting and labels."
    ),
    # Each example is a one-element list, matching the single text-box input.
    examples=[
        [sample]
        for sample in (
            "45 year old woman diagnosed with CAD",
            "A 65-year-old male presents with acute chest pain and a history of hypertension.",
            "The patient underwent a laparoscopic cholecystectomy.",
        )
    ],
)

# Launch the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()