# Case-Classification / distillbert-classification-run.py
# Author: Jorge Fioranelli
# Commit 6f0a968: "Added zero-shot and distilbert models"
import ktrain
from ktrain import text
import pandas as pd
from sklearn.model_selection import train_test_split
import os
# Enumerate CUDA devices in PCI bus order and expose only GPU 0 to the process.
# NOTE(review): these are assigned after `import ktrain` at the top of the
# file (which presumably pulls in TensorFlow). They only take effect if CUDA
# has not been initialised yet; consider setting them before that import.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Directory of the saved ktrain predictor loaded by `main`.
MODEL_PATH = "models/distilbert-base-uncased-finetuned-internet-provider"
# Sample text classified when the script is run directly.
DEFAULT_TEXT = "I have issues with my internet connection"


def format_class_probabilities(labels, probs):
    """Return one ``"<label> : <prob>"`` line per class.

    Each line is exactly what ``print(label, ":", prob)`` would write.

    Args:
        labels: Class names, in the order the predictor reports them.
        probs: Per-class probabilities aligned index-for-index with *labels*.

    Returns:
        A list of strings, one per label (empty if there are no labels).

    Raises:
        ValueError: If *labels* and *probs* differ in length, because the
            label/probability pairing would then be meaningless.
    """
    if len(labels) != len(probs):
        raise ValueError(
            f"got {len(labels)} labels but {len(probs)} probabilities"
        )
    # Join with single spaces, mirroring print()'s default separator.
    return [
        " ".join((str(label), ":", str(prob)))
        for label, prob in zip(labels, probs)
    ]


def main(text=DEFAULT_TEXT, model_path=MODEL_PATH):
    """Load the saved predictor, classify *text*, and print the results.

    Prints the predicted label, then one ``"<label> : <prob>"`` line per
    class returned by ``predictor.get_classes()``.

    Args:
        text: The input string to classify.
        model_path: Directory passed to ``ktrain.load_predictor``.
    """
    predictor = ktrain.load_predictor(model_path)

    prediction = predictor.predict(text)
    print(f"prediction: {prediction}")

    labels = predictor.get_classes()
    probs = predictor.predict_proba(text)
    for line in format_class_probabilities(labels, probs):
        print(line)


if __name__ == "__main__":
    main()