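# Page-header residue from the code host (not part of the program):
# user "samhardyhey", apparent commit message "lint", commit a4b887b.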
import gradio as gr
import srsly
from clear_bow.classifier import DictionaryClassifier
from flair.data import Sentence
from flair.models import TARSClassifier
from joblib import load
# Registry of ready-to-use models, keyed by name. Every entry is loaded eagerly
# at import time from files under ./model_files.
loaded_models = {}

# Dictionary classifier configured from a JSON label dictionary on disk.
loaded_models["dictionary_classifier"] = DictionaryClassifier(
    label_dictionary=srsly.read_json(
        "./model_files/dictionary_classifier/label_dictionary.json"
    ),
    classifier_type="multi_label",
)

# NOTE(review): joblib.load unpickles the file, so model.joblib must come from
# a trusted source.
loaded_models["sklearn_linear_svc"] = load(
    "./model_files/sklearn_linear_svc/model.joblib"
)

# Flair TARS classifier restored from its saved checkpoint.
loaded_models["flair_tars"] = TARSClassifier.load(
    "./model_files/flair_tars/final-model.pt"
)
def predict_dictionary_classifier(reddit_comment, dictionary_classifier_model):
    """Score a single comment with the dictionary-based classifier.

    Args:
        reddit_comment: Raw comment text to classify.
        dictionary_classifier_model: Object exposing ``predict_single(text)``.

    Returns:
        Whatever ``predict_single`` returns for the comment, unchanged.
    """
    return dictionary_classifier_model.predict_single(reddit_comment)
def predict_linear_svc(reddit_comment, sklearn_linear_svc_model):
    """Classify a single comment with the linear SVC model.

    Args:
        reddit_comment: Raw comment text to classify.
        sklearn_linear_svc_model: Fitted model exposing ``predict(batch)`` and a
            ``multi_label_classes_`` sequence of label names. Each row of the
            prediction must support ``toarray()``.

    Returns:
        dict mapping every label name to its predicted indicator as a plain
        ``int``.
    """
    # predict() works on a batch: wrap the single comment, then take row 0.
    first_row = sklearn_linear_svc_model.predict([reddit_comment])[0]
    # toarray() yields a 2-D array even for one row, so index once more.
    indicators = first_row.toarray()[0]
    # Cast to built-in int so the dict renders cleanly as text instead of
    # exposing NumPy scalar reprs.
    return {
        label: int(flag)
        for label, flag in zip(
            sklearn_linear_svc_model.multi_label_classes_, indicators
        )
    }
def predict_flair_tars(text, flair_tars_model):
    """Classify a single text with the Flair TARS model.

    Args:
        text: Raw text to classify.
        flair_tars_model: Model exposing ``get_current_label_dictionary()`` and
            ``predict(sentence)``.

    Returns:
        dict mapping every label in the model's current label dictionary to a
        confidence rounded to two decimals. Labels the model did not attach to
        the sentence keep the default ``0.0``.
    """
    sentence = Sentence(text)
    labels = flair_tars_model.get_current_label_dictionary().get_items()
    # The return value of predict() is not used; the predictions are read back
    # from ``sentence.labels`` below.
    flair_tars_model.predict(sentence)

    # Start every known label at 0.0 so the result always has the same keys.
    pred_dict = dict.fromkeys(labels, 0.0)
    for predicted in sentence.labels:
        fields = predicted.to_dict()  # serialise once, read both fields
        pred_dict[fields["value"]] = round(float(fields["confidence"]), 2)
    return pred_dict
def model_selector(model_type, reddit_comment):
    """Route a comment to the predictor chosen in the UI.

    Args:
        model_type: One of ``"dictionary_classifier"``, ``"linear_svc"`` or
            ``"flair_tars"``.
        reddit_comment: Raw comment text to classify.

    Returns:
        The selected predictor's result, or ``None`` when ``model_type`` is
        not one of the known choices (for example, nothing was selected).
    """
    if model_type == "dictionary_classifier":
        return predict_dictionary_classifier(
            reddit_comment, loaded_models["dictionary_classifier"]
        )
    if model_type == "linear_svc":
        # The UI choice is "linear_svc" but the registry key is
        # "sklearn_linear_svc".
        return predict_linear_svc(reddit_comment, loaded_models["sklearn_linear_svc"])
    if model_type == "flair_tars":
        return predict_flair_tars(reddit_comment, loaded_models["flair_tars"])
    # Unknown or missing selection: no prediction.
    return None
# Gradio UI: a radio button picks the model, a text box takes the comment, and
# the prediction is shown as text.
demo = gr.Interface(
    fn=model_selector,
    inputs=[
        gr.Radio(["dictionary_classifier", "linear_svc", "flair_tars"]),
        "text",
    ],
    outputs="text",
    title="Few-shot multi-label classification",
    description="A comparison of models, ranging from cheap to expensive. Enjoy!",
    # One example per model: [model choice, comment text].
    examples=[
        [
            "dictionary_classifier",
            "Do you really have a $2,080,000 mortgage for an investment property that rents for $700 a week?",
        ],
        [
            "linear_svc",
            "I like the genie analogy. Anecdotally from peers and from own experience, if a role is advertised as 100% in office, then it’s a hard no.",
        ],
        [
            "flair_tars",
            "It’s a seller’s market now, transitioning into a balanced/buyer’s market. Prices are still historically high, but it’s clear the peak is behind us. Rising interest rates doing exactly as expected.",
        ],
    ],
)
# Launch the Gradio app only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    demo.launch()