Spaces:
Sleeping
Sleeping
import json | |
import pandas as pd | |
from transformers import pipeline | |
class Distilbert:
    """
    Wrapper around a Hugging Face text-classification pipeline that maps the
    model's predicted class id to a human-readable label name read from a
    JSON label file.
    """

    def __init__(self, model_path: str, labels_path: str = 'labels.json'):
        """
        Parameters:
            model_path (str): Model name or local path passed to
                ``transformers.pipeline("text-classification", ...)``.
            labels_path (str): JSON file mapping label names to integer class
                ids. Defaults to 'labels.json' (relative to the current working
                directory), which was previously hard-coded.
        """
        self.pipe = pipeline("text-classification", model=model_path)
        self.labels_path = labels_path

    def predict_query_type(self, custom_query):
        """
        Predict the type of a given custom query using the pre-trained classifier.

        Parameters:
            custom_query (str): The input query whose type should be predicted.

        Returns:
            str: The label name for the highest-scoring class. When the pipeline
            emits 'LABEL_<id>' (or a bare integer), the id is translated through
            the label file. Any other label is assumed to already be a readable
            name and is returned unchanged.

        Raises:
            ValueError: If the classifier returns no predictions.
            KeyError: If the predicted class id has no entry in the label file.
        """
        # top_k=None requests a score for every class, not only the best one.
        preds = self.pipe(custom_query, top_k=None)
        if not preds:
            raise ValueError("the classifier returned no predictions")

        best_label = str(max(preds, key=lambda pred: pred['score'])['label'])

        # Without id2label names in the model config, the pipeline emits
        # 'LABEL_<id>'. Strip the prefix to recover the integer class id.
        prefix = 'LABEL_'
        raw_id = best_label[len(prefix):] if best_label.startswith(prefix) else best_label
        try:
            class_id = int(raw_id)
        except ValueError:
            # Not an id: the model already supplies a readable label name.
            return best_label

        # The label file maps name -> id. Invert it so the lookup is by class id
        # and does not depend on the order of keys in the JSON file.
        labels_dict = self.load_lables()
        id_to_label = {int(idx): name for name, idx in labels_dict.items()}
        return id_to_label[class_id]

    def load_lables(self, path=None):
        """
        Read the label file and return its parsed contents.

        The file is expected to hold a JSON object mapping each label name (str)
        to its integer class id, e.g. {"greeting": 0, "question": 1}.

        Parameters:
            path (str, optional): File to read. Defaults to ``self.labels_path``,
                or 'labels.json' when that attribute is not set.

        Returns:
            dict: Label name -> integer class id, exactly as stored in the file.
        """
        if path is None:
            path = getattr(self, 'labels_path', 'labels.json')
        with open(path, encoding='utf-8') as json_file:
            return json.load(json_file)

    # Correctly spelled alias. The original (misspelled) name is kept for
    # backward compatibility with existing callers.
    load_labels = load_lables