# Case-Classification / distillbert-classification-finetuning.py
# Jorge Fioranelli — commit 6f0a968: "Added zero-shot and distilbert models"
import ktrain
from ktrain import text
import pandas as pd
from sklearn.model_selection import train_test_split
import os
# Restrict CUDA to GPU 0, with devices enumerated in PCI bus order.
# NOTE(review): ktrain's own examples set these variables *before* `import ktrain`;
# they only take effect if no GPU has been initialised yet — confirm that setting
# them after the imports above is sufficient in this environment.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "0"})
# Labelled tickets; the rest of the script reads the "Text" and "Category" columns.
data = pd.read_csv("data/internet_provider.csv")  # point this at your own CSV if needed
# Expected label set, handed to the model as `class_names`.
categories = ["Slow Connection", "Billing", "Setup", "No Connectivity"]

# 80 / 10 / 10 split: carve off 20% as a hold-out, then halve the hold-out into
# validation and test sets. The fixed seed keeps the split reproducible.
split_kwargs = dict(random_state=42, shuffle=True)
train_data, holdout_data = train_test_split(data, test_size=0.2, **split_kwargs)
val_data, test_data = train_test_split(holdout_data, test_size=0.5, **split_kwargs)
# `model` is ktrain's Transformer *preprocessor* (tokenisation + label handling);
# the trainable network itself comes from get_classifier() below.
model_name = "distilbert-base-uncased"
model = text.Transformer(model_name=model_name, maxlen=512, class_names=categories)

# The preprocessed datasets get their own names so the train/val/test DataFrames
# stay intact (the row spot-check at the end of this block still needs them).
trn = model.preprocess_train(train_data["Text"].tolist(), train_data["Category"].tolist())
# Held-out splits go through preprocess_test so they reuse the label encoding
# fitted on the training set instead of re-deriving it from their own labels.
val = model.preprocess_test(val_data["Text"].tolist(), val_data["Category"].tolist())
tst = model.preprocess_test(test_data["Text"].tolist(), test_data["Category"].tolist())

classifier = model.get_classifier()
learner = ktrain.get_learner(classifier, train_data=trn, val_data=val, batch_size=16)

# Exploratory learning-rate sweep (plot only): the rate passed to fit_onecycle
# below is chosen by hand, not read from this sweep.
learner.lr_find(show_plot=True, max_epochs=20)
learner.fit_onecycle(0.0001, 1)  # max LR 1e-4, one epoch

# Label the reports with the class order the preprocessor actually uses; with
# string labels ktrain may derive its own order, which need not match `categories`.
class_names = model.get_classes()
learner.validate(class_names=class_names)                # validation split
learner.validate(val_data=tst, class_names=class_names)  # untouched test split
learner.view_top_losses(n=5, preproc=model)

# Spot-check one raw training row (assumes more than 100 training rows).
print(train_data.iloc[100])
# Bundle the fine-tuned network with its preprocessor into one inference object.
predictor = ktrain.get_predictor(learner.model, preproc=model)

x = "I have issues with my internet connection"


def _report(pred, sample):
    """Print the prediction for *sample* and the predictor's explanation of it.

    Returns the prediction so the caller can keep a reference to it.
    """
    result = pred.predict(sample)
    print(f"prediction: {result}")
    print(pred.explain(sample))
    return result


prediction = _report(predictor, x)

# Save to disk, load it back, and repeat the same check on the reloaded predictor.
predictor.save("distilbest-model")
predictor = ktrain.load_predictor("distilbest-model")
prediction = _report(predictor, x)