# Hugging Face Spaces status: Runtime error
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Fine-tuned complaint classifier hosted on the Hugging Face Hub.
model_name = "MSarah/ComplainClassificationFineTunedBert"

# Load the weights, then the matching tokenizer, once at import time so that
# every call to predict() reuses the same objects.
model, tokenizer = (
    AutoModelForSequenceClassification.from_pretrained(model_name),
    AutoTokenizer.from_pretrained(model_name),
)
# Maps the classifier's output index to a human-readable complaint category.
# NOTE(review): this order must match the label ids used when the model was
# fine-tuned; confirm against the training setup / model config.
_COMPLAINT_LABELS = {
    0: "Account Services",
    1: "Others",
    2: "Mortgage/Loan",
    3: "Credit card or prepaid card",
    4: "Theft/Dispute Reporting",
}
# Any index not listed above maps to this label, mirroring the catch-all
# branch of the original if/elif chain.
_FALLBACK_LABEL = "Theft/Dispute Reporting"


def predict(text: str) -> str:
    """Classify a complaint into one of the known complaint categories.

    Args:
        text: The complaint text to classify. Input longer than the
            tokenizer allows is truncated (``truncation=True``).

    Returns:
        The predicted complaint category label.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    # Inference only: disable autograd tracking to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    # .item() assumes a single input string (one row of logits).
    predicted_class = torch.argmax(logits, dim=-1).item()
    return _COMPLAINT_LABELS.get(predicted_class, _FALLBACK_LABEL)
# Gradio provides the web UI served by this app.
import gradio as gr

# Text in, predicted category text out.
iface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title="Complaint Classification",
    description="Enter text to classify the complaint category",
)

# Start the web server when this file is run as a script. The original left
# launch() commented out, so running the script built the interface and then
# exited without serving it (presumably the cause of the Space's
# "Runtime error" status). The guard keeps plain imports from starting a server.
if __name__ == "__main__":
    iface.launch()