Spaces:
Sleeping
Sleeping
Merge pull request #10 from Dutta-SD/develop
Browse files- .gitignore +5 -1
- app.py +8 -4
- components/config.py +1 -0
- components/get_predictions.py +59 -11
- components/models.py +12 -0
- components/utils.py +12 -3
- requirements.txt +7 -0
.gitignore
CHANGED
@@ -191,4 +191,8 @@ cython_debug/
|
|
191 |
|
192 |
# Support for Project snippet scope
|
193 |
|
194 |
-
# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,data
|
|
|
|
|
|
|
|
|
|
191 |
|
192 |
# Support for Project snippet scope
|
193 |
|
194 |
+
# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,data
|
195 |
+
|
196 |
+
static/nltk
|
197 |
+
.vscode
|
198 |
+
try.py
|
app.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
from components.get_predictions import get_predictions
|
2 |
-
from gradio.components import Textbox
|
3 |
from gradio.interface import Interface
|
4 |
from gradio.themes import Monochrome
|
|
|
5 |
|
6 |
|
7 |
def get_input_fields() -> Textbox:
|
@@ -11,23 +12,25 @@ def get_input_fields() -> Textbox:
|
|
11 |
Textbox: Input Field as gradio TextBox
|
12 |
"""
|
13 |
return Textbox(
|
14 |
-
lines=
|
15 |
placeholder="Enter The Text",
|
16 |
value="",
|
17 |
label="Text to Predict",
|
18 |
)
|
19 |
|
20 |
|
21 |
-
def get_output_fields() -> list[
|
22 |
"""Gets Output Fields
|
23 |
|
24 |
Returns:
|
25 |
-
list[
|
26 |
"""
|
27 |
|
28 |
return [
|
29 |
Textbox(type="text", label="Aggression Prediction"),
|
30 |
Textbox(type="text", label="Misogyny Prediction"),
|
|
|
|
|
31 |
]
|
32 |
|
33 |
|
@@ -51,6 +54,7 @@ def get_interface() -> Interface:
|
|
51 |
|
52 |
|
53 |
if __name__ == "__main__":
|
|
|
54 |
interface = get_interface()
|
55 |
|
56 |
# Launch the interface
|
|
|
1 |
from components.get_predictions import get_predictions
|
2 |
+
from gradio.components import Textbox, IOComponent, Plot
|
3 |
from gradio.interface import Interface
|
4 |
from gradio.themes import Monochrome
|
5 |
+
from components.utils import initialize
|
6 |
|
7 |
|
8 |
def get_input_fields() -> Textbox:
|
|
|
12 |
Textbox: Input Field as gradio TextBox
|
13 |
"""
|
14 |
return Textbox(
|
15 |
+
lines=10,
|
16 |
placeholder="Enter The Text",
|
17 |
value="",
|
18 |
label="Text to Predict",
|
19 |
)
|
20 |
|
21 |
|
22 |
+
def get_output_fields() -> list[str | IOComponent]:
|
23 |
"""Gets Output Fields
|
24 |
|
25 |
Returns:
|
26 |
+
list[str | IOComponent]: output fields as gradio textbox
|
27 |
"""
|
28 |
|
29 |
return [
|
30 |
Textbox(type="text", label="Aggression Prediction"),
|
31 |
Textbox(type="text", label="Misogyny Prediction"),
|
32 |
+
Plot(label="Explanation of Aggression", scale=1),
|
33 |
+
Plot(label="Explanation of Misogyny", scale=1),
|
34 |
]
|
35 |
|
36 |
|
|
|
54 |
|
55 |
|
56 |
if __name__ == "__main__":
|
57 |
+
initialize()
|
58 |
interface = get_interface()
|
59 |
|
60 |
# Launch the interface
|
components/config.py
CHANGED
@@ -12,6 +12,7 @@ class Settings:
|
|
12 |
0: "NGEN - Non Misogynistic Content",
|
13 |
1: "GEN - Misogynistic Content",
|
14 |
}
|
|
|
15 |
|
16 |
|
17 |
app_config = Settings()
|
|
|
12 |
0: "NGEN - Non Misogynistic Content",
|
13 |
1: "GEN - Misogynistic Content",
|
14 |
}
|
15 |
+
NUM_EXPLAINER_FEATURES: int = 10
|
16 |
|
17 |
|
18 |
app_config = Settings()
|
components/get_predictions.py
CHANGED
@@ -1,29 +1,77 @@
|
|
1 |
import components.utils as utils
|
2 |
from components.config import app_config
|
3 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
|
6 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
"""Gets Predictions for the Texts
|
8 |
|
9 |
Args:
|
10 |
text (str): The input text to get predictions for
|
11 |
|
12 |
Returns:
|
13 |
-
tuple[str,
|
|
|
14 |
"""
|
15 |
|
16 |
cleaned_data = [utils.clean_one_text(text)]
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
pred_2 = model_2.predict(cleaned_data)[0]
|
25 |
|
26 |
return (
|
27 |
-
app_config.TASK_A_MAP[
|
28 |
-
app_config.TASK_B_MAP[
|
|
|
|
|
29 |
)
|
|
|
1 |
import components.utils as utils
|
2 |
from components.config import app_config
|
3 |
+
from components.models import (
|
4 |
+
pipeline_task_A,
|
5 |
+
pipeline_task_B,
|
6 |
+
explainer_task_A,
|
7 |
+
explainer_task_B,
|
8 |
+
)
|
9 |
+
from lime.lime_text import LimeTextExplainer
|
10 |
+
from typing import Any
|
11 |
+
from matplotlib.figure import Figure
|
12 |
|
13 |
|
14 |
+
def predict_for_pipeline(
|
15 |
+
model_pipeline: Any,
|
16 |
+
explainer: LimeTextExplainer,
|
17 |
+
cleaned_data: list[str],
|
18 |
+
labels: list,
|
19 |
+
) -> tuple[int, Figure | None]:
|
20 |
+
"""Generates Prediction and Explanation given the cleaned text
|
21 |
+
|
22 |
+
Args:
|
23 |
+
model_pipeline (Any): Joblib imported model pipeline
|
24 |
+
explainer (LimeTextExplainer): text explainer
|
25 |
+
cleaned_data (list[str]): cleaned text
|
26 |
+
labels(list): list of integers as labels
|
27 |
+
|
28 |
+
Returns:
|
29 |
+
tuple[int, Figure]: class prediction and LIME explanation as matplotlib figure
|
30 |
+
"""
|
31 |
+
|
32 |
+
explanation = explainer.explain_instance(
|
33 |
+
cleaned_data[0],
|
34 |
+
model_pipeline.predict_proba,
|
35 |
+
num_features=app_config.NUM_EXPLAINER_FEATURES,
|
36 |
+
labels=labels,
|
37 |
+
)
|
38 |
+
|
39 |
+
class_prediction = model_pipeline.predict(cleaned_data)[0]
|
40 |
+
return class_prediction, explanation.as_pyplot_figure(label=1)
|
41 |
+
|
42 |
+
|
43 |
+
def get_predictions(text: str) -> tuple:
|
44 |
"""Gets Predictions for the Texts
|
45 |
|
46 |
Args:
|
47 |
text (str): The input text to get predictions for
|
48 |
|
49 |
Returns:
|
50 |
+
tuple[str, Any]: Predictions for task A and task B
|
51 |
+
along with Figures
|
52 |
"""
|
53 |
|
54 |
cleaned_data = [utils.clean_one_text(text)]
|
55 |
|
56 |
+
prediction_task_A = predict_for_pipeline(
|
57 |
+
pipeline_task_A,
|
58 |
+
explainer_task_A,
|
59 |
+
cleaned_data,
|
60 |
+
[0, 1, 2],
|
61 |
+
)
|
62 |
+
prediction_task_B = predict_for_pipeline(
|
63 |
+
pipeline_task_B,
|
64 |
+
explainer_task_B,
|
65 |
+
cleaned_data,
|
66 |
+
[0, 1],
|
67 |
+
)
|
68 |
|
69 |
+
print(prediction_task_A)
|
70 |
+
print(prediction_task_B)
|
|
|
71 |
|
72 |
return (
|
73 |
+
app_config.TASK_A_MAP[prediction_task_A[0]],
|
74 |
+
app_config.TASK_B_MAP[prediction_task_B[0]],
|
75 |
+
prediction_task_A[1],
|
76 |
+
prediction_task_B[1],
|
77 |
)
|
components/models.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import joblib
|
2 |
+
from components.config import app_config
|
3 |
+
from lime.lime_text import LimeTextExplainer
|
4 |
+
|
5 |
+
|
6 |
+
# Takes in a string and outputs list
|
7 |
+
pipeline_task_A = joblib.load(app_config.TASK_A_MODEL_PATH)
|
8 |
+
pipeline_task_B = joblib.load(app_config.TASK_B_MODEL_PATH)
|
9 |
+
|
10 |
+
# LIME text explainer
|
11 |
+
explainer_task_A: LimeTextExplainer = LimeTextExplainer()
|
12 |
+
explainer_task_B: LimeTextExplainer = LimeTextExplainer()
|
components/utils.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
import string
|
2 |
import nltk
|
3 |
import re
|
4 |
-
|
5 |
-
|
6 |
|
7 |
|
8 |
# Cleans one text
|
@@ -32,8 +32,17 @@ def clean_one_text(text: str) -> str:
|
|
32 |
|
33 |
s.difference_update(not_words)
|
34 |
|
35 |
-
stmr =
|
36 |
tokens = [token for token in tk.tokenize(new_string) if token.lower() not in s]
|
37 |
clean_tokens = [stmr.stem(token) for token in tokens]
|
38 |
text = " ".join(clean_tokens)
|
39 |
return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import string
|
2 |
import nltk
|
3 |
import re
|
4 |
+
from nltk.stem.porter import PorterStemmer
|
5 |
+
import warnings
|
6 |
|
7 |
|
8 |
# Cleans one text
|
|
|
32 |
|
33 |
s.difference_update(not_words)
|
34 |
|
35 |
+
stmr = PorterStemmer()
|
36 |
tokens = [token for token in tk.tokenize(new_string) if token.lower() not in s]
|
37 |
clean_tokens = [stmr.stem(token) for token in tokens]
|
38 |
text = " ".join(clean_tokens)
|
39 |
return text
|
40 |
+
|
41 |
+
|
42 |
+
def setup_nltk():
|
43 |
+
nltk.download("stopwords")
|
44 |
+
|
45 |
+
|
46 |
+
def initialize():
|
47 |
+
warnings.filterwarnings("ignore")
|
48 |
+
setup_nltk()
|
requirements.txt
CHANGED
@@ -36,11 +36,14 @@ httpcore==0.17.2
|
|
36 |
httpx==0.24.1
|
37 |
huggingface-hub==0.15.1
|
38 |
idna==3.4
|
|
|
39 |
itsdangerous==2.1.2
|
40 |
Jinja2==3.1.2
|
41 |
joblib==1.2.0
|
42 |
jsonschema==4.17.3
|
43 |
kiwisolver==1.4.4
|
|
|
|
|
44 |
linkify-it-py==2.0.2
|
45 |
markdown-it-py==2.2.0
|
46 |
markdown2==2.4.8
|
@@ -51,6 +54,7 @@ mdurl==0.1.2
|
|
51 |
monotonic==1.6
|
52 |
multidict==6.0.4
|
53 |
mypy-extensions==1.0.0
|
|
|
54 |
nltk==3.8.1
|
55 |
numpy==1.24.3
|
56 |
orjson==3.9.1
|
@@ -71,9 +75,11 @@ pyrsistent==0.19.3
|
|
71 |
python-dateutil==2.8.2
|
72 |
python-multipart==0.0.6
|
73 |
pytz==2023.3
|
|
|
74 |
PyYAML==6.0
|
75 |
regex==2023.6.3
|
76 |
requests==2.31.0
|
|
|
77 |
scikit-learn==1.2.2
|
78 |
scipy==1.10.1
|
79 |
semantic-version==2.10.0
|
@@ -81,6 +87,7 @@ six==1.16.0
|
|
81 |
sniffio==1.3.0
|
82 |
starlette==0.27.0
|
83 |
threadpoolctl==3.1.0
|
|
|
84 |
tomli==2.0.1
|
85 |
toolz==0.12.0
|
86 |
tqdm==4.65.0
|
|
|
36 |
httpx==0.24.1
|
37 |
huggingface-hub==0.15.1
|
38 |
idna==3.4
|
39 |
+
imageio==2.31.1
|
40 |
itsdangerous==2.1.2
|
41 |
Jinja2==3.1.2
|
42 |
joblib==1.2.0
|
43 |
jsonschema==4.17.3
|
44 |
kiwisolver==1.4.4
|
45 |
+
lazy_loader==0.2
|
46 |
+
lime==0.2.0.1
|
47 |
linkify-it-py==2.0.2
|
48 |
markdown-it-py==2.2.0
|
49 |
markdown2==2.4.8
|
|
|
54 |
monotonic==1.6
|
55 |
multidict==6.0.4
|
56 |
mypy-extensions==1.0.0
|
57 |
+
networkx==3.1
|
58 |
nltk==3.8.1
|
59 |
numpy==1.24.3
|
60 |
orjson==3.9.1
|
|
|
75 |
python-dateutil==2.8.2
|
76 |
python-multipart==0.0.6
|
77 |
pytz==2023.3
|
78 |
+
PyWavelets==1.4.1
|
79 |
PyYAML==6.0
|
80 |
regex==2023.6.3
|
81 |
requests==2.31.0
|
82 |
+
scikit-image==0.21.0
|
83 |
scikit-learn==1.2.2
|
84 |
scipy==1.10.1
|
85 |
semantic-version==2.10.0
|
|
|
87 |
sniffio==1.3.0
|
88 |
starlette==0.27.0
|
89 |
threadpoolctl==3.1.0
|
90 |
+
tifffile==2023.4.12
|
91 |
tomli==2.0.1
|
92 |
toolz==0.12.0
|
93 |
tqdm==4.65.0
|