Spaces:
Build error
Build error
| import gradio as gr | |
| import srsly | |
| from clear_bow.classifier import DictionaryClassifier | |
| from flair.data import Sentence | |
| from flair.models import TARSClassifier | |
| from joblib import load | |
# Every model is loaded once, at import time, and kept in this mapping so that
# each prediction request reuses the same in-memory instance.
loaded_models = dict(
    # Dictionary classifier in multi-label mode; its label dictionary is read
    # from a JSON file on disk.
    dictionary_classifier=DictionaryClassifier(
        classifier_type="multi_label",
        label_dictionary=srsly.read_json(
            "./model_files/dictionary_classifier/label_dictionary.json"
        ),
    ),
    # Deserialized with joblib (pickle-based) — only ever point this at a
    # trusted, locally shipped file.
    sklearn_linear_svc=load("./model_files/sklearn_linear_svc/model.joblib"),
    # Flair TARS checkpoint.
    flair_tars=TARSClassifier.load("./model_files/flair_tars/final-model.pt"),
)
def predict_dictionary_classifier(reddit_comment, dictionary_classifier_model):
    """Classify a single comment with the dictionary-based classifier.

    Args:
        reddit_comment: Raw comment text to classify.
        dictionary_classifier_model: Object exposing ``predict_single``; in
            this app, the ``"dictionary_classifier"`` entry of
            ``loaded_models``.

    Returns:
        Whatever ``predict_single`` returns for the comment, passed through
        unchanged.
    """
    prediction = dictionary_classifier_model.predict_single(reddit_comment)
    return prediction
def predict_linear_svc(reddit_comment, sklearn_linear_svc_model):
    """Classify a single comment with the linear SVC model.

    Args:
        reddit_comment: Raw comment text to classify.
        sklearn_linear_svc_model: Fitted model exposing ``predict`` (which
            takes a list of texts) and ``multi_label_classes_``, a sequence of
            label names aligned with the prediction columns.

    Returns:
        Dict mapping each label name to this comment's predicted value, as
        plain Python numbers.
    """
    predictions = sklearn_linear_svc_model.predict([reddit_comment])
    row = predictions[0]
    # Predictions may be sparse or dense. Densify only when needed, and
    # flatten so that both a (1, n_labels) sparse row and a 1-D sparse array
    # yield exactly n_labels values.
    if hasattr(row, "toarray"):
        row = row.toarray().reshape(-1)
    # Convert NumPy scalars to native numbers. Under NumPy 2 their repr is
    # e.g. ``np.int64(1)``, which would otherwise leak into the text output.
    if hasattr(row, "tolist"):
        row = row.tolist()
    return dict(zip(sklearn_linear_svc_model.multi_label_classes_, row))
def predict_flair_tars(text, flair_tars_model):
    """Score a single text with the Flair TARS model.

    Every label in the model's current label dictionary starts at ``0.0``.
    Each label the model attaches to the sentence is then overwritten with
    its confidence, rounded to two decimals.

    Args:
        text: Raw comment text to classify.
        flair_tars_model: Loaded ``TARSClassifier``.

    Returns:
        Dict mapping label name to confidence.
    """
    known_labels = flair_tars_model.get_current_label_dictionary().get_items()
    sentence = Sentence(text)
    flair_tars_model.predict(sentence)  # attaches predicted labels in place

    scores = dict.fromkeys(known_labels, 0.0)
    for predicted in sentence.labels:
        as_dict = predicted.to_dict()
        scores[as_dict["value"]] = round(float(as_dict["confidence"]), 2)
    return scores
def model_selector(model_type, reddit_comment):
    """Route a comment to the predictor selected in the UI.

    Args:
        model_type: One of ``"dictionary_classifier"``, ``"linear_svc"`` or
            ``"flair_tars"``, the radio choices offered by the demo.
        reddit_comment: Raw comment text to classify.

    Returns:
        The selected predictor's output for the comment.

    Raises:
        ValueError: If ``model_type`` is not a known choice, including
            ``None`` when nothing is selected. Raising makes the UI report a
            failure instead of silently showing an empty result.
    """
    # The lambdas defer the ``loaded_models`` lookups, so only the chosen
    # model is ever touched.
    predictors = {
        "dictionary_classifier": lambda: predict_dictionary_classifier(
            reddit_comment, loaded_models["dictionary_classifier"]
        ),
        "linear_svc": lambda: predict_linear_svc(
            reddit_comment, loaded_models["sklearn_linear_svc"]
        ),
        "flair_tars": lambda: predict_flair_tars(
            reddit_comment, loaded_models["flair_tars"]
        ),
    }
    predictor = predictors.get(model_type)
    if predictor is None:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {sorted(predictors)}"
        )
    return predictor()
# Gradio UI: choose a model, enter a comment, and see the predictor's output
# rendered as text.
demo = gr.Interface(
    model_selector,
    [
        # Preselect a model so a submission never reaches model_selector with
        # no choice made (the Radio would otherwise send None).
        gr.Radio(
            ["dictionary_classifier", "linear_svc", "flair_tars"],
            value="dictionary_classifier",
            label="Model",
        ),
        "text",
    ],
    "text",
    examples=[
        [
            "dictionary_classifier",
            "Do you really have a $2,080,000 mortgage for an investment property that rents for $700 a week?",
        ],
        [
            "linear_svc",
            "I like the genie analogy. Anecdotally from peers and from own experience, if a role is advertised as 100% in office, then it’s a hard no.",
        ],
        [
            "flair_tars",
            "It’s a seller’s market now, transitioning into a balanced/buyer’s market. Prices are still historically high, but it’s clear the peak is behind us. Rising interest rates doing exactly as expected.",
        ],
    ],
    title="Few-shot multi-label classification",
    description="A comparison of models, ranging from cheap to expensive. Enjoy!",
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()