# drug-ner / app.py
# muzamilhxmi
# Update app.py
# 0bf47ff verified
import gradio as gr
import torch
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
# 🔥 Label mapping: integer label ID -> entity type name.
# The position of each name in the tuple is its label ID.
LABEL_MAP = dict(
    enumerate(
        (
            "DOSAGE",
            "DRUG_NAME",
            "EVENT",
            "LOCATION",
            "OTHER",
            "ROA",
            "SYMPTOM",
            "TEMPORAL",
        )
    )
)
# 🔹 Hugging Face model identifier; replace with your own model if needed.
MODEL_NAME = "blaikhole/distilbert-drug-ner"

# ✅ Pick the device index handed to the pipeline: 0 when CUDA is available,
# -1 otherwise.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# ✅ Load the model, move it to the matching torch device, then load the tokenizer.
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)
model = model.to("cuda" if device == 0 else "cpu")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# ✅ Build the Hugging Face NER pipeline from the loaded model and tokenizer.
ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    device=device,
)
def _decode_label(raw_label, label_map):
    """Translate a raw pipeline tag into a display label, or None for plain text."""
    try:
        label_id = int(raw_label.replace("LABEL_", ""))
    except ValueError:
        # The tag is not of the form "LABEL_<n>"; use it verbatim rather
        # than crashing on the int() conversion.
        label = raw_label
    else:
        label = label_map.get(label_id, "OTHER")
    # "OTHER" marks ordinary text, which is represented by a None label.
    return None if label == "OTHER" else label


def merge_subwords_and_decode(entities, label_map=None):
    """Merge subword tokens and map label IDs to readable labels.

    Tokens whose ``"word"`` starts with ``"##"`` are glued onto the token
    before them (without a space) and keep that token's label. Consecutive
    tokens that share a label are joined with a single space into one
    segment. Tokens labelled ``"OTHER"`` are treated as normal text and get
    the label ``None``.

    Args:
        entities: Iterable of dicts, each with a ``"word"`` string and an
            ``"entity"`` tag. Tags of the form ``"LABEL_<n>"`` are looked up
            in ``label_map``; any other tag is used as the label verbatim.
        label_map: Optional mapping from integer label ID to label name.
            Defaults to the module-level ``LABEL_MAP``. IDs missing from the
            mapping are treated as ``"OTHER"``.

    Returns:
        A list of ``(text, label)`` tuples in input order, where ``label``
        is a string or ``None``. An empty input yields an empty list.
    """
    if label_map is None:
        label_map = LABEL_MAP

    merged_entities = []
    current_word = ""
    current_label = None

    for entity in entities:
        word = entity["word"]
        label = _decode_label(entity["entity"], label_map)

        # Subword continuation: append without a space.
        if word.startswith("##"):
            if not current_word:
                # Orphan subword with nothing to attach to: start the
                # segment with the subword's own label instead of None.
                current_label = label
            current_word += word[2:]
            continue

        # Same label as the running segment: extend it with a space.
        if current_word and current_label == label:
            current_word += " " + word
            continue

        # Label changed: flush the finished segment and start a new one.
        if current_word:
            merged_entities.append((current_word, current_label))
        current_word = word
        current_label = label

    # Flush the trailing segment, if any.
    if current_word:
        merged_entities.append((current_word, current_label))
    return merged_entities
# 🔹 Function to Run NER
def analyze_text(text):
    """Run the NER pipeline on ``text`` and return highlight-ready segments.

    Args:
        text: Raw input text to analyze.

    Returns:
        The list returned by ``merge_subwords_and_decode`` for the
        pipeline's raw predictions, or an empty list when ``text`` is
        ``None``, empty, or whitespace-only (the model is not invoked in
        that case).
    """
    # Nothing to tag: skip the model call entirely.
    if not text or (isinstance(text, str) and not text.strip()):
        return []

    entities = ner_pipeline(text)  # Get raw predictions
    return merge_subwords_and_decode(entities)  # Fix labels & subwords
# 🔹 Example sentence used to pre-fill the input box.
EXAMPLE_SENT = "The patient was prescribed oxycodone and experienced dizziness."

# 🔹 Gradio UI: a text box, a button, and a highlighted-text view of the result.
with gr.Blocks() as demo:
    gr.Markdown("# 🏥 Drug Named Entity Recognition (NER)")
    text_input = gr.Textbox(label="Enter Text", lines=5, value=EXAMPLE_SENT)
    analyze_button = gr.Button("Run NER Model")
    output = gr.HighlightedText(label="NER Result", combine_adjacent=True)

    # Clicking the button sends the text box contents through analyze_text
    # and displays the returned segments in the highlighted-text component.
    analyze_button.click(analyze_text, inputs=[text_input], outputs=[output])

demo.launch()