Edward J. Schwartz commited on
Commit
af9812a
1 Parent(s): e0ddff8

Use pipeline for interpretation

Browse files
Files changed (2) hide show
  1. app.py +10 -9
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import shap
 
3
 
4
  import os
5
  import re
@@ -9,6 +10,8 @@ import tempfile
9
 
10
  model = gr.load("ejschwartz/oo-method-test-model-bylibrary", src="models")
11
 
 
 
12
  def get_all_dis(bname, addrs=None):
13
 
14
  anafile = tempfile.NamedTemporaryFile(prefix=os.path.basename(bname) + "_", suffix=".bat_ana")
@@ -122,17 +125,15 @@ with gr.Blocks() as demo:
122
  clazz: gr.Label.update(top_k)
123
  }
124
 
 
 
 
 
 
125
  def interpretation_function(text):
126
 
127
- def model_wrap(input):
128
- print(f"model_wrap input = {input}", file=sys.stderr)
129
- return model.fn(input)['confidences']
130
-
131
- print(text, file=sys.stderr)
132
- out = model_wrap(text)
133
- print(out, file=sys.stderr)
134
- explainer = shap.Explainer(model_wrap)
135
- shap_values = explainer(text)
136
 
137
  # Dimensions are (batch size, text size, number of classes)
138
  # Since we care about positive sentiment, use index 1
 
1
  import gradio as gr
2
  import shap
3
+ import transformers
4
 
5
  import os
6
  import re
 
10
 
11
  model = gr.load("ejschwartz/oo-method-test-model-bylibrary", src="models")
12
 
13
+ model_interp = transformers.pipeline("text-classification", "ejschwartz/oo-method-test-model-bylibrary")
14
+
15
  def get_all_dis(bname, addrs=None):
16
 
17
  anafile = tempfile.NamedTemporaryFile(prefix=os.path.basename(bname) + "_", suffix=".bat_ana")
 
125
  clazz: gr.Label.update(top_k)
126
  }
127
 
128
+ # XXX: Ideally we'd use the gr.load model, which uses the huggingface
129
+ # inference API. But shap library appears to use information in the
130
+ # transformers pipeline, and I don't feel like figuring out how to
131
+ # reimplement that, so we'll just use a regular transformers pipeline here
132
+ # for interpretation.
133
  def interpretation_function(text):
134
 
135
+ explainer = shap.Explainer(model_interp)
136
+ shap_values = explainer([text])
 
 
 
 
 
 
 
137
 
138
  # Dimensions are (batch size, text size, number of classes)
139
  # Since we care about positive sentiment, use index 1
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  gradio
2
  pandas
3
  numpy >= 1.22.4
4
- shap
 
 
1
  gradio
2
  pandas
3
  numpy >= 1.22.4
4
+ shap
5
+ transformers[torch]