HamOrSpam / app.py
Rodrigo Uribe
changed outputs
674d60f
raw
history blame contribute delete
No virus
1.48 kB
import gradio as gr
from transformers import TFDistilBertForSequenceClassification, DistilBertTokenizerFast
import tensorflow as tf
# Load the tokenizer and the classifier once, at import time, from the
# Hugging Face Hub; classify_text reuses these module-level objects on
# every call instead of reloading them per request.
model_name = "Buebito/HamOrSpam_Model" # Hub repository id; change this to serve a different model
# Tokenizer and model are loaded from the same repository id.
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
model = TFDistilBertForSequenceClassification.from_pretrained(model_name)
def classify_text(text):
    """Classify a message as ham or spam and report per-class probabilities.

    Args:
        text: The message to classify. The tokenizer truncates input to at
            most 512 tokens.

    Returns:
        A ``(label, probabilities)`` tuple. ``label`` is ``"ham"`` or
        ``"spam"``; ``probabilities`` is a dict with the keys ``"ham"`` and
        ``"spam"`` whose values are built-in Python floats, so the dict can
        be encoded by any JSON serializer.
    """
    # Tokenize the text into TensorFlow tensors for the model.
    inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True, max_length=512)
    # The first element of the model output is treated as the class scores.
    predictions = model(inputs.data)[0]
    # Normalize the scores into probabilities over the last axis.
    probabilities = tf.nn.softmax(predictions, axis=-1)
    # Only the first row is read: one input text yields one prediction.
    prediction_index = tf.argmax(probabilities, axis=-1).numpy()[0]
    # Index 0 is reported as ham, any other index as spam.
    label = "ham" if prediction_index == 0 else "spam"
    # .numpy() yields NumPy float32 scalars, which the standard JSON encoder
    # rejects; convert to built-in floats before handing them to the UI.
    ham_prob = float(probabilities[0][0].numpy())
    spam_prob = float(probabilities[0][1].numpy())
    return label, {"ham": ham_prob, "spam": spam_prob}
# Build the UI components first: one free-text input and two outputs
# (the predicted label and the per-class probabilities).
text_input = gr.Textbox(lines=2, placeholder="Enter Text Here...")
classification_output = gr.Label(label="Classification")
probabilities_output = gr.JSON(label="Probabilities")

# Wire the components to classify_text; the two outputs receive the
# function's two return values in order.
iface = gr.Interface(
    fn=classify_text,
    inputs=text_input,
    outputs=[classification_output, probabilities_output],
)

# Start serving the app.
iface.launch()