stackoverflow / app.py
mikachou's picture
add probability feature
f41c648
raw
history blame
No virus
1.49 kB
from typing import Iterator

import gradio as gr
import joblib
import numpy as np
import spacy
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.preprocessing import MultiLabelBinarizer
# Resources shared by lemmatize() and predict(); loaded once at import time.
# NOTE: joblib.load unpickles its input, so these files must come from a
# trusted source.

# spaCy pipeline used by lemmatize() for tokenization and lemmas.
nlp = spacy.load('en_core_web_sm')
# Fitted vectorizer; predict() calls its .transform() on the lemmatized text.
tfidf = joblib.load('./tfidf.joblib')
# Fitted estimator; predict() calls .predict() / .predict_proba() on it.
model = joblib.load('./model.joblib')
# Tag encoder; predict() uses .inverse_transform() and reads .ts.count from it.
tags_binarizer = joblib.load('./tags.joblib')
def lemmatize(s: str) -> Iterator[str]:
    """Yield the lower-cased lemma of each retained token in *s*.

    The text is processed by the module-level spaCy pipeline ``nlp``. Tokens
    that spaCy flags as whitespace, punctuation, stop words or digits are
    dropped; the remaining tokens are mapped to ``token.lemma_.lower()``.

    Args:
        s: Raw text to process.

    Returns:
        A lazy, single-pass iterator of lemma strings in document order.
    """
    doc = nlp(s)
    return (
        token.lemma_.lower()
        for token in doc
        if not (
            token.is_space
            or token.is_punct
            or token.is_stop
            or token.is_digit
        )
    )
def predict(title: str, post: str, predict_proba: bool):
    """Predict tags for a question made of a title and a post body.

    The title and post are joined with a single space, lemmatized with
    ``lemmatize``, and vectorized with the module-level ``tfidf``.

    Args:
        title: Question title.
        post: Question body.
        predict_proba: When true, return per-tag scores instead of the
            decoded tag labels.

    Returns:
        If ``predict_proba`` is false, the output of
        ``tags_binarizer.inverse_transform`` applied to ``model.predict``.
        Otherwise a list of ``(tag, score)`` pairs built by zipping the
        sorted keys of ``tags_binarizer.ts.count`` with the first row of
        ``model.predict_proba``.
    """
    document = " ".join(lemmatize(title + " " + post))
    features = tfidf.transform(np.array([document]))

    if not predict_proba:
        return tags_binarizer.inverse_transform(model.predict(features))

    scores = model.predict_proba(features)[0]
    # NOTE(review): assumes the sorted keys of tags_binarizer.ts.count line up
    # with the column order of model.predict_proba -- confirm against training.
    ordered_counts = dict(sorted(tags_binarizer.ts.count.items()))
    return list(zip(list(ordered_counts), scores))
# Gradio UI wiring. The three input components are passed positionally to
# predict() as (title, post, predict_proba); whatever predict() returns is
# shown in a 10-line output textbox.
demo = gr.Interface(
fn=predict,
inputs=[
gr.Textbox(label="Title", lines=1, placeholder="Title..."),
gr.Textbox(label="Post", lines=10, placeholder="Post..."),
gr.Checkbox(label="Proba?")],
outputs=gr.Textbox(lines=10))
# Launched unconditionally: there is no __main__ guard, so importing this
# module also launches the app.
demo.launch()