# Source: ai-challenge/src/distilbert.py (1.96 kB)
# Last commit: a6c83a1 "Initial commit" by Sebastian Kułaga
import json
import pandas as pd
from transformers import pipeline
class Distilbert:
    """
    Wrapper around a Hugging Face ``text-classification`` pipeline that maps
    the pipeline's ``LABEL_<n>`` outputs to the label names stored in
    ``labels.json``.
    """

    def __init__(self, model_path: str):
        """
        Build the text-classification pipeline.

        Parameters:
        model_path (str): Value passed as ``model=`` to ``transformers.pipeline``.
        """
        self.pipe = pipeline("text-classification", model=model_path)

    def predict_query_type(self, custom_query: str) -> str:
        """
        Predict the type of a given custom query using a pre-trained classifier.

        Parameters:
        custom_query (str): The input query for which the type needs to be predicted.

        Returns:
        str: The predicted label for the input query.

        Notes:
        The label name is looked up by position: the n-th key of
        ``labels.json`` is assumed to belong to the pipeline output
        ``LABEL_<n>`` (TODO confirm against the training label mapping).
        """
        # Ask the pipeline for the score of every label, not just the best one.
        preds = self.pipe(custom_query, top_k=None)

        # Materialise the keys as a list: a dict keys view cannot be indexed.
        labels = list(self.load_lables())

        # Create a DataFrame from the predictions
        preds_df = pd.DataFrame(preds)

        # Strip the literal 'LABEL_' prefix and convert the remainder to an integer id.
        preds_df['label'] = preds_df['label'].str.replace('LABEL_', '', regex=False)
        preds_df['label'] = pd.to_numeric(preds_df['label'], errors='coerce').astype('Int64')

        # Order rows by label id so that, after the index reset, a row's
        # position lines up with its label id (when every label is present).
        preds_df = preds_df.sort_values('label').reset_index(drop=True)

        # Find the index of the maximum score
        max_score_index = preds_df['score'].idxmax()

        # Return the label name corresponding to the maximum score
        return labels[max_score_index]

    def load_lables(self) -> dict:
        """
        Read 'labels.json' from the current working directory and return its content.

        The file is expected to hold a JSON object whose keys are the label
        names; ``predict_query_type`` uses those keys, in file order, as the
        names of the classifier outputs. The values are presumably the
        integer label ids (TODO confirm against the file).

        Returns:
        dict: the parsed content of 'labels.json'.

        Raises:
        FileNotFoundError: if 'labels.json' does not exist in the working directory.
        json.JSONDecodeError: if the file is not valid JSON.
        """
        # JSON is UTF-8 by specification; do not depend on the platform default.
        with open('labels.json', encoding='utf-8') as json_file:
            return json.load(json_file)