paragon-analytics committed
Commit
518ac36
1 Parent(s): 3a53d21

Update app.py

Files changed (1):
  1. app.py +4 -91
app.py CHANGED
@@ -12,7 +12,6 @@ from transformers import AutoModelForSequenceClassification
 from transformers import TFAutoModelForSequenceClassification
 from transformers import AutoTokenizer
 import matplotlib.pyplot as plt
-# from transformers_interpret import SequenceClassificationExplainer
 
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
@@ -25,90 +24,17 @@ pred = transformers.pipeline("text-classification", model=model,
 
 explainer = shap.Explainer(pred)
 
-def interpretation_function(text):
-    shap_values = explainer([text])
-    scores = list(zip(shap_values.data[0], shap_values.values[0, :, 1]))
-    return scores
-
-# model = AutoModelForSequenceClassification.from_pretrained("paragon-analytics/ADRv1")
-# modelc = AutoModelForSequenceClassification.from_pretrained("paragon-analytics/ADRv1").cuda
-
-
-# cls_explainer = SequenceClassificationExplainer(
-#     model,
-#     tokenizer)
-
-# # define a prediction function
-# def f(x):
-#     tv = torch.tensor([tokenizer.encode(v, padding='max_length', max_length=500, truncation=True) for v in x]).cuda()
-#     outputs = modelc(tv)[0].detach().cpu().numpy()
-#     scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T
-#     val = sp.special.logit(scores[:,1]) # use one vs rest logit units
-#     return val
-
 def adr_predict(x):
     encoded_input = tokenizer(x, return_tensors='pt')
     output = model(**encoded_input)
     scores = output[0][0].detach().numpy()
     scores = tf.nn.softmax(scores)
-
-    # # build a pipeline object to do predictions
-    # pred = transformers.pipeline("text-classification", model=model,
-    #                              tokenizer=tokenizer, device=0, return_all_scores=True)
-    # explainer = shap.Explainer(pred)
-    # shap_values = explainer([x])
-    # shap_plot = shap.plots.text(shap_values)
-
-    # word_attributions = cls_explainer(str(x))
-    # # scores = list(zip(shap_values.data[0], shap_values.values[0, :, 1]))
-    # letter = []
-    # score = []
-    # for i in word_attributions:
-    #     if i[1]>0.5:
-    #         a = "++"
-    #     elif (i[1]<=0.5) and (i[1]>0.1):
-    #         a = "+"
-    #     elif (i[1]>=-0.5) and (i[1]<-0.1):
-    #         a = "-"
-    #     elif i[1]<-0.5:
-    #         a = "--"
-    #     else:
-    #         a = "NA"
-
-    #     letter.append(i[0])
-    #     score.append(a)
-
-    # word_attributions = [(letter[i], score[i]) for i in range(0, len(letter))]
-
-    # # SHAP:
-    # # build an explainer using a token masker
-    # explainer = shap.Explainer(f, tokenizer)
-    # shap_values = explainer(str(x), fixed_context=1)
-    # scores = list(zip(shap_values.data[0], shap_values.values[0, :, 1]))
-    # # plot the first sentence's explanation
-    # # plt = shap.plots.text(shap_values[0],display=False)
-
-    # shap_scores = interpretation_function(str(x).lower())
-
+
     shap_values = explainer([str(x).lower()])
     local_plot = shap.plots.text(shap_values[0], display=False)
-
-    # local_plot = (
-    #     ""
-    #     + plot
-    #     ""
-    # )
-
-    # plt.tight_layout()
-    # local_plot = plt.gcf()
-    # plt.rcParams['figure.figsize'] = 6,4
-    # plt.close()
-
-
-
+
     return {"Severe Reaction": float(scores.numpy()[1]), "Non-severe Reaction": float(scores.numpy()[0])}, local_plot
-    # shap_scores
-    # , word_attributions ,scores
+
 
 def main(prob1):
     text = str(prob1).lower()
@@ -128,31 +54,18 @@ with gr.Blocks(title=title) as demo:
 
     with gr.Column(visible=True) as output_col:
         label = gr.Label(label = "Predicted Label")
-        # impplot = gr.HighlightedText(label="Important Words", combine_adjacent=False).style(
-        #     color_map={"+++": "royalblue","++": "cornflowerblue",
-        #                "+": "lightsteelblue", "NA":"white"})
-        # NER = gr.HTML(label = 'NER:')
-        # intp = gr.HighlightedText(label="Word Scores",
-        #     combine_adjacent=False).style(color_map={"++": "darkred","+": "red",
-        #                                              "--": "darkblue",
-        #                                              "-": "blue", "NA":"white"})
-
-        # interpretation = gr.components.Interpretation(prob1)
         local_plot = gr.HTML(label = 'Shap:')
-        # local_plot = gr.Plot(label = 'Shap:')
-
 
     submit_btn.click(
         main,
         [prob1],
         [label
-        # ,intp
         ,local_plot
         ], api_name="adr"
     )
 
     gr.Markdown("### Click on any of the examples below to see to what extent they contain resilience messaging:")
-    gr.Examples([["I have minor pain."],["I have severe pain."]], [prob1], [label,local_plot
+    gr.Examples([["I have severe pain."],["I have minor pain."]], [prob1], [label,local_plot
     ], main, cache_examples=True)
 
 demo.launch()
 
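For reference, the pattern this commit converges on: one text-classification pipeline, one module-level shap.Explainer built on it, and shap.plots.text(..., display=False) feeding its HTML output straight into a gr.HTML component. Below is a minimal self-contained sketch of that pattern, not the app itself: the checkpoint name is taken from the commented-out code above, predict_and_explain is a hypothetical condensed stand-in for the app's adr_predict/main pair, and class probabilities are read from the pipeline output instead of re-running the model and applying tf.nn.softmax.

    import shap
    import transformers
    import gradio as gr
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    # Checkpoint assumed from the commented-out code in app.py.
    checkpoint = "paragon-analytics/ADRv1"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

    # One pipeline and one Explainer, built once at import time, as in the commit.
    pred = transformers.pipeline("text-classification", model=model,
                                 tokenizer=tokenizer, return_all_scores=True)
    explainer = shap.Explainer(pred)

    # Hypothetical condensed stand-in for the app's adr_predict/main pair.
    def predict_and_explain(text):
        text = str(text).lower()
        # Per-class probabilities straight from the pipeline output.
        scores = {d["label"]: d["score"] for d in pred(text)[0]}
        # With display=False, shap.plots.text returns an HTML string
        # that a gr.HTML component can render directly.
        shap_values = explainer([text])
        html = shap.plots.text(shap_values[0], display=False)
        return scores, html

    with gr.Blocks() as demo:
        prob1 = gr.Textbox(label="Input text")
        submit_btn = gr.Button("Submit")
        label = gr.Label(label="Predicted Label")
        local_plot = gr.HTML(label="Shap:")
        submit_btn.click(predict_and_explain, [prob1],
                         [label, local_plot], api_name="adr")

    demo.launch()

Building the Explainer directly on the pipeline lets SHAP choose the tokenizer and masker itself, which is what made the removed paths (transformers_interpret, the custom logit function f with a token masker, and the matplotlib plotting leftovers) redundant.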