MSarah's picture
Update app.py
04703f6 verified
raw
history blame contribute delete
No virus
1.45 kB
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Hugging Face Hub repository id of the fine-tuned complaint classifier.
# Point this at a different repository to serve another checkpoint.
model_name = "MSarah/ComplainClassificationFineTunedBert"

# Load the classification model, then its tokenizer, from the same repository
# id so the two stay paired. Both are created once, at import time, and shared
# by every prediction.
model, tokenizer = (
    AutoModelForSequenceClassification.from_pretrained(model_name),
    AutoTokenizer.from_pretrained(model_name),
)
# Category name for each model output class id. Ids 0-3 have dedicated
# labels; any other id falls back to _FALLBACK_CATEGORY, exactly as the
# original if/elif chain's final ``else`` branch did.
_CATEGORY_BY_CLASS_ID = {
    0: "Account Services",
    1: "Others",
    2: "Mortgage/Loan",
    3: "Credit card or prepaid card",
}
_FALLBACK_CATEGORY = "Theft/Dispute Reporting"


def predict(text):
    """Classify a complaint and return its category name.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        text (str): The complaint text to classify.

    Returns:
        str: The predicted complaint category.
    """
    # truncation=True clips over-long inputs instead of erroring out.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    # One input string -> one row of logits, so argmax yields a single id
    # and .item() converts it to a Python int.
    predicted_class = torch.argmax(logits, dim=-1).item()
    return _CATEGORY_BY_CLASS_ID.get(predicted_class, _FALLBACK_CATEGORY)
# Gradio web UI: one text box in, the predicted complaint category out.
import gradio as gr

iface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title="Complaint Classification",
    description="Enter text to classify the complaint category",
)

# Serve the UI when this file is executed as a script (e.g. `python app.py`),
# but not when it is merely imported by another module or a test.
if __name__ == "__main__":
    iface.launch()