sdutta28 committed on
Commit 15c875a
1 Parent(s): e00c87e

Added LIME explainability

.gitignore CHANGED
@@ -191,4 +191,8 @@ cython_debug/
 
 # Support for Project snippet scope
 
-# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,data
+# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,data
+
+static/nltk
+.vscode
+try.py
app.py CHANGED
@@ -1,7 +1,8 @@
 from components.get_predictions import get_predictions
-from gradio.components import Textbox
+from gradio.components import Textbox, IOComponent, Plot
 from gradio.interface import Interface
 from gradio.themes import Monochrome
+from components.utils import initialize
 
 
 def get_input_fields() -> Textbox:
@@ -11,23 +12,25 @@ def get_input_fields() -> Textbox:
         Textbox: Input Field as gradio TextBox
     """
     return Textbox(
-        lines=2,
+        lines=10,
         placeholder="Enter The Text",
         value="",
         label="Text to Predict",
     )
 
 
-def get_output_fields() -> list[Textbox]:
+def get_output_fields() -> list[str | IOComponent]:
     """Gets Output Fields
 
     Returns:
-        list[Textbox...]: output fields as gradio textbox
+        list[str | IOComponent]: output fields as gradio textbox
     """
 
     return [
         Textbox(type="text", label="Aggression Prediction"),
         Textbox(type="text", label="Misogyny Prediction"),
+        Plot(label="Explanation of Aggression", scale=1),
+        Plot(label="Explanation of Misogyny", scale=1),
     ]
 
 
@@ -51,6 +54,7 @@ def get_interface() -> Interface:
 
 
 if __name__ == "__main__":
+    initialize()
     interface = get_interface()
 
     # Launch the interface
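
The four output components returned by get_output_fields now have to line up with a four-element return value from get_predictions. The diff does not show the body of get_interface, so the following is only a sketch of how the pieces plausibly fit together; the function body, theme usage, and launch call are assumptions inferred from the imports and the "# Launch the interface" comment visible above.

# Sketch only: get_interface() exists in app.py (see the hunk header above), but its
# body is not part of this diff, so everything inside it here is an assumption.
# (Imports are those already at the top of app.py in the diff above.)
def get_interface() -> Interface:
    return Interface(
        fn=get_predictions,           # expected to return (label_A, label_B, fig_A, fig_B)
        inputs=get_input_fields(),    # the single Textbox defined above
        outputs=get_output_fields(),  # two Textboxes followed by the two new Plots
        theme=Monochrome(),           # assumed, based on the Monochrome import
    )


if __name__ == "__main__":
    initialize()                 # NLTK download + warning filter from components.utils
    interface = get_interface()

    # Launch the interface
    interface.launch()           # assumed launch call; hidden by the diff context
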
components/config.py CHANGED
@@ -12,6 +12,7 @@ class Settings:
         0: "NGEN - Non Misogynistic Content",
         1: "GEN - Misogynistic Content",
     }
+    NUM_EXPLAINER_FEATURES: int = 10
 
 
 app_config = Settings()
components/get_predictions.py CHANGED
@@ -1,29 +1,77 @@
 import components.utils as utils
 from components.config import app_config
-import joblib
+from components.models import (
+    pipeline_task_A,
+    pipeline_task_B,
+    explainer_task_A,
+    explainer_task_B,
+)
+from lime.lime_text import LimeTextExplainer
+from typing import Any
+from matplotlib.figure import Figure
 
 
-def get_predictions(text: str) -> tuple[str, str]:
+def predict_for_pipeline(
+    model_pipeline: Any,
+    explainer: LimeTextExplainer,
+    cleaned_data: list[str],
+    labels: list,
+) -> tuple[int, Figure | None]:
+    """Generates Prediction and Explanation given the cleaned text
+
+    Args:
+        model_pipeline (Any): Joblib imported model pipeline
+        explainer (LimeTextExplainer): text explainer
+        cleaned_data (list[str]): cleaned text
+        labels(list): list of integers as labels
+
+    Returns:
+        tuple[int, Figure]: class prediction and LIME explanation as matplotlib figure
+    """
+
+    explanation = explainer.explain_instance(
+        cleaned_data[0],
+        model_pipeline.predict_proba,
+        num_features=app_config.NUM_EXPLAINER_FEATURES,
+        labels=labels,
+    )
+
+    class_prediction = model_pipeline.predict(cleaned_data)[0]
+    return class_prediction, explanation.as_pyplot_figure(label=1)
+
+
+def get_predictions(text: str) -> tuple:
     """Gets Predictions for the Texts
 
     Args:
         text (str): The input text to get predictions for
 
     Returns:
-        tuple[str, str]: Predictions for task A and task B
+        tuple[str, Any]: Predictions for task A and task B
+        along with Figures
     """
 
     cleaned_data = [utils.clean_one_text(text)]
 
-    # Load Models
-    model_1 = joblib.load(app_config.TASK_A_MODEL_PATH)
-    model_2 = joblib.load(app_config.TASK_B_MODEL_PATH)
+    prediction_task_A = predict_for_pipeline(
+        pipeline_task_A,
+        explainer_task_A,
+        cleaned_data,
+        [0, 1, 2],
+    )
+    prediction_task_B = predict_for_pipeline(
+        pipeline_task_B,
+        explainer_task_B,
+        cleaned_data,
+        [0, 1],
+    )
 
-    # Predictions
-    pred_1 = model_1.predict(cleaned_data)[0]
-    pred_2 = model_2.predict(cleaned_data)[0]
+    print(prediction_task_A)
+    print(prediction_task_B)
 
     return (
-        app_config.TASK_A_MAP[pred_1],
-        app_config.TASK_B_MAP[pred_2],
+        app_config.TASK_A_MAP[prediction_task_A[0]],
+        app_config.TASK_B_MAP[prediction_task_B[0]],
+        prediction_task_A[1],
+        prediction_task_B[1],
    )
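
The heart of the change is predict_for_pipeline: explain_instance perturbs the input text, queries the pipeline's predict_proba, and returns per-token weights that as_pyplot_figure turns into a matplotlib Figure for the new Plot outputs. Below is a self-contained sketch of the same pattern against a toy scikit-learn pipeline; the corpus, labels, and model choice are illustrative stand-ins, not the project's persisted task pipelines.

# Toy end-to-end LIME example; only the lime/scikit-learn calls mirror the diff above.
from lime.lime_text import LimeTextExplainer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

texts = ["have a lovely day", "you are awful and stupid"] * 20   # placeholder corpus
labels = [0, 1] * 20                                             # placeholder labels

# A pipeline mapping raw strings to class probabilities -- the contract that
# explain_instance needs from its classifier_fn argument.
pipeline = make_pipeline(TfidfVectorizer(), LogisticRegression())
pipeline.fit(texts, labels)

explainer = LimeTextExplainer()
explanation = explainer.explain_instance(
    "you are awful",             # single cleaned string, like cleaned_data[0]
    pipeline.predict_proba,      # classifier_fn
    num_features=10,             # plays the role of app_config.NUM_EXPLAINER_FEATURES
    labels=[0, 1],
)

prediction = pipeline.predict(["you are awful"])[0]
fig = explanation.as_pyplot_figure(label=1)   # matplotlib Figure handed to Gradio's Plot
fig.savefig("lime_explanation.png")           # illustrative; the app returns the Figure instead
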
components/models.py ADDED
@@ -0,0 +1,12 @@
+import joblib
+from components.config import app_config
+from lime.lime_text import LimeTextExplainer
+
+
+# Takes in a string and outputs list
+pipeline_task_A = joblib.load(app_config.TASK_A_MODEL_PATH)
+pipeline_task_B = joblib.load(app_config.TASK_B_MODEL_PATH)
+
+# LIME text explainer
+explainer_task_A: LimeTextExplainer = LimeTextExplainer()
+explainer_task_B: LimeTextExplainer = LimeTextExplainer()
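
Because these joblib.load calls sit at module level, both pipelines and both explainers are built exactly once per process, when components.models is first imported; app.py and get_predictions.py then share the same objects. If import-time loading ever became a start-up concern, one alternative is lazy loading behind a cache; the snippet below is a sketch of that option only, not something this commit does.

# Illustrative lazy-loading variant; names mirror the module above but the approach
# is an alternative, not part of this commit.
from functools import lru_cache

import joblib
from components.config import app_config
from lime.lime_text import LimeTextExplainer


@lru_cache(maxsize=None)
def get_pipeline_task_A():
    # Deserialized on first use, then cached for every later prediction.
    return joblib.load(app_config.TASK_A_MODEL_PATH)


@lru_cache(maxsize=None)
def get_explainer_task_A() -> LimeTextExplainer:
    return LimeTextExplainer()
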
components/utils.py CHANGED
@@ -1,8 +1,8 @@
 import string
 import nltk
 import re
-
-nltk.download("stopwords")
+from nltk.stem.porter import PorterStemmer
+import warnings
 
 
 # Cleans one text
@@ -32,8 +32,17 @@ def clean_one_text(text: str) -> str:
 
     s.difference_update(not_words)
 
-    stmr = nltk.stem.porter.PorterStemmer()
+    stmr = PorterStemmer()
     tokens = [token for token in tk.tokenize(new_string) if token.lower() not in s]
     clean_tokens = [stmr.stem(token) for token in tokens]
     text = " ".join(clean_tokens)
     return text
+
+
+def setup_nltk():
+    nltk.download("stopwords")
+
+
+def initialize():
+    warnings.filterwarnings("ignore")
+    setup_nltk()
requirements.txt CHANGED
@@ -36,11 +36,14 @@ httpcore==0.17.2
 httpx==0.24.1
 huggingface-hub==0.15.1
 idna==3.4
+imageio==2.31.1
 itsdangerous==2.1.2
 Jinja2==3.1.2
 joblib==1.2.0
 jsonschema==4.17.3
 kiwisolver==1.4.4
+lazy_loader==0.2
+lime==0.2.0.1
 linkify-it-py==2.0.2
 markdown-it-py==2.2.0
 markdown2==2.4.8
@@ -51,6 +54,7 @@ mdurl==0.1.2
 monotonic==1.6
 multidict==6.0.4
 mypy-extensions==1.0.0
+networkx==3.1
 nltk==3.8.1
 numpy==1.24.3
 orjson==3.9.1
@@ -71,9 +75,11 @@ pyrsistent==0.19.3
 python-dateutil==2.8.2
 python-multipart==0.0.6
 pytz==2023.3
+PyWavelets==1.4.1
 PyYAML==6.0
 regex==2023.6.3
 requests==2.31.0
+scikit-image==0.21.0
 scikit-learn==1.2.2
 scipy==1.10.1
 semantic-version==2.10.0
@@ -81,6 +87,7 @@ six==1.16.0
 sniffio==1.3.0
 starlette==0.27.0
 threadpoolctl==3.1.0
+tifffile==2023.4.12
 tomli==2.0.1
 toolz==0.12.0
 tqdm==4.65.0
  tqdm==4.65.0