|
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer |
|
import torch |
|
import joblib |
|
|
|
|
|
# Load the fine-tuned DistilBERT sequence classifier and its tokenizer from
# the current directory (expects the saved config/weights/vocab files here).
model = DistilBertForSequenceClassification.from_pretrained(".")

tokenizer = DistilBertTokenizer.from_pretrained(".")




# Mapping from the model's integer class index to a human-readable label,
# saved alongside the model at training time.
# NOTE(review): joblib.load unpickles arbitrary objects — only load a
# label_mapping.joblib file from a trusted source.
label_mapping = joblib.load("label_mapping.joblib")
|
|
|
def predict(text):
    """Classify *text* and return its human-readable label.

    The input is tokenized (truncated/padded to at most 512 tokens), run
    through the DistilBERT classifier with gradient tracking disabled, and
    the argmax logit index is translated via ``label_mapping``.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    # Inference only — no autograd bookkeeping needed.
    with torch.no_grad():
        logits = model(**encoded).logits

    class_idx = torch.argmax(logits, dim=1).item()
    return label_mapping[class_idx]
|
|
|
|
|
print(predict("Your test text here")) |