Michael-Geis
added slider bar for user controlled confidence on tags
7af8946
raw
history blame
1.55 kB
import json
def postprocess(model_output, threshold_probability):
    """Convert raw classifier output into a sorted list of subject names.

    Args:
        model_output: Sequence whose first element is an iterable of
            mappings, each providing a ``"label"`` and a ``"score"`` entry.
            Only the first element is inspected.
        threshold_probability: Cut-off value; a label is kept only when its
            score is strictly greater than this.

    Returns:
        The values found in ``./data/arxiv-label-dict.json`` for every kept
        label, sorted in ascending order. An empty list is returned when
        ``model_output`` is empty or no score exceeds the threshold.

    Raises:
        OSError: If the label dictionary file cannot be opened.
        KeyError: If a kept label has no entry in the label dictionary.
    """
    # Nothing was predicted, so there is nothing to translate; this also
    # avoids an IndexError on ``model_output[0]`` below.
    if not model_output:
        return []

    # JSON is UTF-8 by specification; pin the encoding so decoding does not
    # depend on the platform's locale default.
    with open("./data/arxiv-label-dict.json", "r", encoding="utf-8") as file:
        subject_dict = json.load(file)

    predicted_tags = [
        result["label"]
        for result in model_output[0]
        if result["score"] > threshold_probability
    ]
    return sorted(subject_dict[tag] for tag in predicted_tags)
# class ModelOutputDecoder(BaseEstimator, TransformerMixin):
# def fit(self, X, y=None):
# return self
# def transform(self, X, y=None):
# if y is None:
# return X
# ## Load label dictionary
# with open("./data/arxiv-label-dict.json") as file:
# string_dict = file.read()
# label_dict = json.loads(string_dict)
# col_list = list(label_dict.keys())
# def decode_label(label):
# ## For a row of y (individual label), returns the list of English subjects corresponding to this label
# return [label_dict[col_list[index]] for index in np.where(label == 1)[0]]
# num_rows, _ = y.shape
# decoded_labels = []
# for i in range(num_rows):
# decoded_labels.append(decode_label(y[i, :]))
# decoded_labels_as_series = pd.Series(
# decoded_labels, name="decoded_labels", index=X.index
# )
# return pd.merge(
# left=X,
# left_index=True,
# right=decoded_labels_as_series,
# right_index=True,
# validate="1:1",
# )