Spaces:
Build error
Build error
import gradio as gr | |
import srsly | |
from clear_bow.classifier import DictionaryClassifier | |
from flair.data import Sentence | |
from flair.models import TARSClassifier | |
from joblib import load | |
# All three models are loaded once, at import time, and kept in this dict so
# every request reuses the same in-memory instances. The "./model_files/..."
# paths are relative, so they are resolved against the process's current
# working directory rather than this file's location.
# NOTE: joblib.load unpickles its file -- only ever point it at trusted files.
loaded_models = {}
loaded_models["dictionary_classifier"] = DictionaryClassifier(
    classifier_type="multi_label",
    label_dictionary=srsly.read_json(
        "./model_files/dictionary_classifier/label_dictionary.json"
    ),
)
loaded_models["sklearn_linear_svc"] = load(
    "./model_files/sklearn_linear_svc/model.joblib"
)
loaded_models["flair_tars"] = TARSClassifier.load(
    "./model_files/flair_tars/final-model.pt"
)
def predict_dictionary_classifier(reddit_comment, dictionary_classifier_model):
    """Classify a single comment with the dictionary-based classifier.

    Args:
        reddit_comment: Raw comment text to classify.
        dictionary_classifier_model: Object exposing ``predict_single``
            (in this app, a ``clear_bow`` ``DictionaryClassifier``).

    Returns:
        Whatever ``predict_single`` returns for the text, passed through
        unchanged.
    """
    predict_single = dictionary_classifier_model.predict_single
    prediction = predict_single(reddit_comment)
    return prediction
def predict_linear_svc(reddit_comment, sklearn_linear_svc_model):
    """Predict per-label membership for one comment with the linear SVC model.

    Args:
        reddit_comment: Raw comment text to classify.
        sklearn_linear_svc_model: A fitted model whose ``predict`` accepts a
            list of texts and returns a ``(n_texts, n_labels)`` result, either
            a scipy sparse matrix/array or a dense array. It must also expose
            ``multi_label_classes_`` (presumably the label names in column
            order, attached when the model was saved -- not a standard
            sklearn attribute; confirm against the training code).

    Returns:
        dict mapping each label name to its prediction as a plain Python
        ``float``. This matches the float values produced by
        ``predict_flair_tars``.
    """
    predictions = sklearn_linear_svc_model.predict([reddit_comment])
    # Densify the whole result before taking the first row. Indexing a sparse
    # result first and calling ``.toarray()`` on the row breaks for dense
    # outputs, which have no ``toarray``.
    if hasattr(predictions, "toarray"):
        predictions = predictions.toarray()
    row = predictions[0]
    # float() turns NumPy scalars into plain floats for a clean text display.
    return {
        label: float(value)
        for label, value in zip(sklearn_linear_svc_model.multi_label_classes_, row)
    }
def predict_flair_tars(text, flair_tars_model):
    """Score a comment against every label known to the Flair TARS model.

    Every label in the model's current label dictionary starts at ``0.0``.
    Each label the model actually predicts for the text is then overwritten
    with its confidence, rounded to two decimals.

    Args:
        text: Raw comment text to classify.
        flair_tars_model: A loaded Flair ``TARSClassifier``.

    Returns:
        dict mapping label name to a float confidence.
    """
    sentence = Sentence(text)
    # Read the label set before predicting, matching the original call order.
    known_labels = flair_tars_model.get_current_label_dictionary().get_items()
    flair_tars_model.predict(sentence)

    scores = dict.fromkeys(known_labels, 0.0)
    for predicted in sentence.labels:
        serialized = predicted.to_dict()
        scores[serialized["value"]] = round(float(serialized["confidence"]), 2)
    return scores
def model_selector(model_type, reddit_comment):
    """Route a comment to the predictor selected in the UI.

    Args:
        model_type: One of ``"dictionary_classifier"``, ``"linear_svc"`` or
            ``"flair_tars"``.
        reddit_comment: Comment text to classify.

    Returns:
        The prediction produced by the matching ``predict_*`` helper.

    Raises:
        ValueError: If ``model_type`` is not a supported choice. This includes
            ``None``, which is what arrives when no model was selected.
    """
    if model_type == "dictionary_classifier":
        return predict_dictionary_classifier(
            reddit_comment, loaded_models["dictionary_classifier"]
        )
    if model_type == "linear_svc":
        # The UI choice name differs from the key used in ``loaded_models``.
        return predict_linear_svc(reddit_comment, loaded_models["sklearn_linear_svc"])
    if model_type == "flair_tars":
        return predict_flair_tars(reddit_comment, loaded_models["flair_tars"])
    # Fail loudly instead of falling through and returning None.
    raise ValueError(
        f"Unknown model_type {model_type!r}; expected one of "
        "'dictionary_classifier', 'linear_svc' or 'flair_tars'."
    )
# Gradio UI: a radio button picks the model, a textbox takes the comment, and
# the prediction is rendered as text. Each example row pre-fills both inputs.
demo = gr.Interface(
    fn=model_selector,
    inputs=[gr.Radio(["dictionary_classifier", "linear_svc", "flair_tars"]), "text"],
    outputs="text",
    title="Few-shot multi-label classification",
    description="A comparison of models, ranging from cheap to expensive. Enjoy!",
    examples=[
        [
            "dictionary_classifier",
            "Do you really have a $2,080,000 mortgage for an investment property that rents for $700 a week?",
        ],
        [
            "linear_svc",
            "I like the genie analogy. Anecdotally from peers and from own experience, if a role is advertised as 100% in office, then it’s a hard no.",
        ],
        [
            "flair_tars",
            "It’s a seller’s market now, transitioning into a balanced/buyer’s market. Prices are still historically high, but it’s clear the peak is behind us. Rising interest rates doing exactly as expected.",
        ],
    ],
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()