mikachou commited on
Commit
4b67ac0
1 Parent(s): d743c08

add chart with proba

Browse files
Files changed (1) hide show
  1. app.py +68 -14
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import joblib
3
  import spacy
4
  import numpy as np
 
5
  from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
6
  from sklearn.preprocessing import MultiLabelBinarizer
7
  from sklearn.base import BaseEstimator, TransformerMixin
@@ -21,31 +22,84 @@ def lemmatize(s: str) -> iter:
21
  # lemmatize
22
  return map(lambda token: token.lemma_.lower(), tokens)
23
 
24
- def predict(title: str , post: str, predict_proba: bool):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  text = title + " " + post
26
  lemmes = np.array([' '.join(list(lemmatize(text)))])
27
 
28
  X = tfidf.transform(lemmes)
29
 
30
- if predict_proba:
31
- y_proba = model.predict_proba(X)[0]
32
- tags = list(dict(sorted(tags_binarizer.ts.count.items())).keys())
33
-
34
- result = list(zip(tags, y_proba))
35
- else:
36
- y_bin = model.predict(X)
37
- y_tags = tags_binarizer.inverse_transform(y_bin)
38
 
39
- result = y_tags
 
40
 
41
- return result
42
 
43
  demo = gr.Interface(
44
  fn=predict,
45
  inputs=[
46
  gr.Textbox(label="Title", lines=1, placeholder="Title..."),
47
- gr.Textbox(label="Post", lines=10, placeholder="Post..."),
48
- gr.Checkbox(label="Proba?")],
49
- outputs=gr.Textbox(lines=10))
50
 
51
  demo.launch()
 
2
  import joblib
3
  import spacy
4
  import numpy as np
5
+ import matplotlib.pyplot as plt
6
  from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
7
  from sklearn.preprocessing import MultiLabelBinarizer
8
  from sklearn.base import BaseEstimator, TransformerMixin
 
22
  # lemmatize
23
  return map(lambda token: token.lemma_.lower(), tokens)
24
 
25
+ def plot(tags, proba):
26
+ plt.style.use('dark_background')
27
+ plt.rcParams.update({'font.size': 16})
28
+
29
+ fig, ax = plt.subplots(figsize=(12,9))
30
+
31
+ ax.barh(tags, proba, align='center', color='darkred')
32
+ ax.set_yticks(tags, labels=tags)
33
+ ax.invert_yaxis() # labels read top-to-bottom
34
+ ax.set_xlabel('Score')
35
+ ax.set_title('Score/Tag')
36
+
37
+ for i, v in enumerate(proba):
38
+ ax.text(v - 0.065, i + 0.05, str(round(v, 2)))
39
+
40
+ plt.xlim(0, 1)
41
+ plt.show()
42
+
43
+ def predict_words(X):
44
+ y_bin = model.predict(X)
45
+ y_tags = " ".join(tags_binarizer.inverse_transform(y_bin)[0])
46
+
47
+ return y_tags
48
+
49
+ def proba_chart(X):
50
+ y_proba = model.predict_proba(X)[0]
51
+ tags = list(dict(sorted(tags_binarizer.ts.count.items())).keys())
52
+
53
+ # combine
54
+ data = list(zip(tags, y_proba))
55
+
56
+ # sort
57
+ data = sorted(data, key=lambda tag_value: tag_value[1], reverse=True)
58
+
59
+ # keep values >= min_score
60
+ data = list(filter(lambda tag_value: tag_value[1] >= 0.1, data))
61
+
62
+ # we have our two dimensions for chart
63
+ tags, proba = zip(*data)
64
+
65
+ # build chart
66
+ plt.style.use('dark_background')
67
+ plt.rcParams.update({'font.size': 16})
68
+
69
+ fig, ax = plt.subplots(figsize=(12,9))
70
+
71
+ ax.barh(tags, proba, align='center', color='darkred')
72
+ ax.set_yticks(tags, labels=tags)
73
+ ax.invert_yaxis() # labels read top-to-bottom
74
+ ax.set_xlabel('Score')
75
+ ax.set_title('Score/Tag')
76
+
77
+ for i, v in enumerate(proba):
78
+ ax.text(v - 0.065, i + 0.05, str(round(v, 2)))
79
+
80
+ plt.xlim(0, 1)
81
+
82
+ return fig
83
+
84
+ def predict(title: str , post: str):
85
  text = title + " " + post
86
  lemmes = np.array([' '.join(list(lemmatize(text)))])
87
 
88
  X = tfidf.transform(lemmes)
89
 
90
+ # predicted words
91
+ words = predict_words(X)
 
 
 
 
 
 
92
 
93
+ # proba chart
94
+ chart = proba_chart(X)
95
 
96
+ return words, chart
97
 
98
  demo = gr.Interface(
99
  fn=predict,
100
  inputs=[
101
  gr.Textbox(label="Title", lines=1, placeholder="Title..."),
102
+ gr.Textbox(label="Post", lines=20, placeholder="Post...")],
103
+ outputs=[gr.Textbox(label="Tags"), gr.Plot()])
 
104
 
105
  demo.launch()