# Web-page header captured along with this file; commented out so the module parses:
# hazarri's picture
# Update app.py
# b397916 verified
from transformers import pipeline
import gradio as gr
# Build the zero-shot classification pipeline once, at import time, from the
# DeBERTa-v3 large zero-shot checkpoint; classify_side_effect reuses it.
classifier = pipeline(
    task="zero-shot-classification",
    model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
)
# Severity scale scored when the caller supplies no candidate labels of its
# own (the API tab lets callers pass their own list instead).
DEFAULT_LABELS = [
    "mild",
    "moderate",
    "severe",
    "life-threatening",
    "death",
]
def classify_side_effect(text, candidate_labels=None):
    """
    Classify the severity of a side effect using zero-shot classification.

    Args:
        text (str | None): Input text describing the side effect. ``None``
            and whitespace-only strings are treated as empty input.
        candidate_labels (str | list[str] | None): Optional labels to score
            the text against. A single string is split on commas (this is
            what the API tab's Textbox sends). Each label is stripped and
            blank entries are dropped; if nothing remains, ``DEFAULT_LABELS``
            is used.

    Returns:
        dict: ``{"labels", "scores", "top_label", "top_score"}``, where
        ``labels``/``scores`` are in the order the pipeline returns them
        and ``top_*`` come from index 0. Returns ``{"error": "Empty input"}``
        when ``text`` is empty.
    """
    if not text or not text.strip():
        return {"error": "Empty input"}

    # The API tab delivers one comma-separated string, which may be "" or
    # just whitespace/commas. Normalise it to a clean list so a blank field
    # falls back to the defaults instead of handing the pipeline an empty
    # label set.
    if isinstance(candidate_labels, str):
        candidate_labels = candidate_labels.split(",")
    labels = [
        label.strip()
        for label in (candidate_labels or [])
        if label and label.strip()
    ]
    if not labels:
        labels = DEFAULT_LABELS

    result = classifier(text, candidate_labels=labels)
    # Cast to plain floats so the dict is JSON-serialisable for the API.
    scores = [float(score) for score in result["scores"]]
    return {
        "labels": result["labels"],
        "scores": scores,
        "top_label": result["labels"][0],
        "top_score": scores[0],
    }
# Programmatic endpoint: side-effect text plus an optional comma-separated
# label list go in; classify_side_effect's full result dict comes out as JSON.
api = gr.Interface(
    classify_side_effect,
    [
        gr.Textbox(label="Side Effect Text"),
        gr.Textbox(label="Candidate Labels (comma-separated, optional)"),
    ],
    gr.JSON(label="Classification Result"),
    title="Zero-Shot ADR Severity Classifier API",
    description=(
        "Predicts the severity level of a side effect using "
        "DeBERTa-v3 large zero-shot classification."
    ),
)
# Add a user-friendly UI for manual testing
def _label_confidences(text):
    """Adapt ``classify_side_effect`` output to the format ``gr.Label`` renders.

    ``gr.Label`` expects a ``{label: confidence}`` mapping (or a plain
    string), not the structured API dict with list values, so the scores are
    zipped back onto their labels. An error result (e.g. empty input) is
    shown as its message text instead of crashing the component.
    """
    result = classify_side_effect(text)
    if "error" in result:
        return result["error"]
    return dict(zip(result["labels"], result["scores"]))


demo = gr.Interface(
    fn=_label_confidences,
    inputs=gr.Textbox(label="Enter a side effect or sentence"),
    outputs=gr.Label(label="Top Predicted Severity"),
    title="Zero-Shot ADR Severity Classifier",
    description="Classifies a side effect sentence into severity categories using DeBERTa v3."
)
# Present the interactive demo and the JSON API as two tabs of one app.
demo_and_api = gr.TabbedInterface(
    interface_list=[demo, api],
    tab_names=["Demo", "API"],
)

if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly.
    demo_and_api.launch()