"""Run a fine-tuned DistilBERT sequence classifier on input text.

The model weights/config and tokenizer files are loaded from the current
working directory (``"."``); the class-index -> label lookup is loaded from
``label_mapping.joblib``.
"""

from typing import List, Union

import joblib
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

# Load the model and tokenizer saved in the current directory.
model = DistilBertForSequenceClassification.from_pretrained(".")
tokenizer = DistilBertTokenizer.from_pretrained(".")

# Load the label mapping.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code -- only load label_mapping.joblib from a trusted source.
# Presumably maps predicted class index -> label (it is indexed with the
# argmax index below); confirm against the script that produced it.
label_mapping = joblib.load("label_mapping.joblib")


def predict(text: Union[str, List[str]]):
    """Return the predicted label(s) for *text*.

    Args:
        text: A single input string, or a list of strings to classify as one
            batch. Inputs longer than 512 tokens are truncated.

    Returns:
        For a single string, the one label looked up in ``label_mapping``.
        For a list, a list of labels in the same order as the inputs
        (an empty list for empty input).
    """
    single = isinstance(text, str)
    texts = [text] if single else list(text)
    if not texts:
        # The tokenizer cannot handle an empty batch; nothing to classify.
        return []

    # padding=True pads every sequence in the batch to the longest one, so a
    # list of texts can be stacked into one tensor.
    inputs = tokenizer(
        texts,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    # One argmax per row; tolist() works for any batch size, unlike .item().
    predicted_classes = torch.argmax(outputs.logits, dim=1).tolist()
    labels = [label_mapping[idx] for idx in predicted_classes]
    return labels[0] if single else labels


if __name__ == "__main__":
    # Demo run only when executed as a script, not on import.
    print(predict("Your test text here"))