paragon-analytics commited on
Commit
f387393
1 Parent(s): f3760ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import streamlit as st
2
  import gradio as gr
 
3
  import torch
4
  import tensorflow as tf
5
  from transformers import RobertaTokenizer, RobertaModel
@@ -15,12 +16,19 @@ def adr_predict(x):
15
  output = model(**encoded_input)
16
  scores = output[0][0].detach().numpy()
17
  scores = tf.nn.softmax(scores)
18
- return {"Severe Reaction": float(scores.numpy()[1]), "Non-severe Reaction": float(scores.numpy()[0])}
 
 
 
 
 
 
 
19
 
20
  def main(text):
21
  text = str(text).lower()
22
  obj = adr_predict(text)
23
- return obj
24
 
25
  title = "Welcome to **ADR Detector** 🪐"
26
  description1 = """
@@ -41,18 +49,15 @@ with gr.Blocks(title=title) as demo:
41
  # color_map={"+++": "royalblue","++": "cornflowerblue",
42
  # "+": "lightsteelblue", "NA":"white"})
43
  # NER = gr.HTML(label = 'NER:')
44
- # intp =gr.HighlightedText(label="Word Scores",
45
- # combine_adjacent=False).style(color_map={"++": "darkgreen","+": "green",
46
- # "--": "darkred",
47
- # "-": "red", "NA":"white"})
48
 
49
  submit_btn.click(
50
  main,
51
  [text],
52
- [label], api_name="adr"
53
  )
54
 
55
  gr.Markdown("### Click on any of the examples below to see to what extent they contain resilience messaging:")
56
- gr.Examples([["I have minor pain."],["I have severe pain."]], [text], [label], main, cache_examples=True)
57
 
58
  demo.launch()
 
1
  import streamlit as st
2
  import gradio as gr
3
+ import shap
4
  import torch
5
  import tensorflow as tf
6
  from transformers import RobertaTokenizer, RobertaModel
 
16
  output = model(**encoded_input)
17
  scores = output[0][0].detach().numpy()
18
  scores = tf.nn.softmax(scores)
19
+
20
+ # build a pipeline object to do predictions
21
+ pred = transformers.pipeline("text-classification", model=model,
22
+ tokenizer=tokenizer, device=0, return_all_scores=True)
23
+ explainer = shap.Explainer(pred)
24
+ shap_values = explainer([x])
25
+ shap_plot = shap.plots.text(shap_values)
26
+ return {"Severe Reaction": float(scores.numpy()[1]), "Non-severe Reaction": float(scores.numpy()[0])}, shap_plot
27
 
28
  def main(text):
29
  text = str(text).lower()
30
  obj = adr_predict(text)
31
+ return obj[0],obj[1]
32
 
33
  title = "Welcome to **ADR Detector** 🪐"
34
  description1 = """
 
49
  # color_map={"+++": "royalblue","++": "cornflowerblue",
50
  # "+": "lightsteelblue", "NA":"white"})
51
  # NER = gr.HTML(label = 'NER:')
52
+ shap_plot = gr.HighlightedText(label="Word Scores",combine_adjacent=False)
 
 
 
53
 
54
  submit_btn.click(
55
  main,
56
  [text],
57
+ [label,shap_plot], api_name="adr"
58
  )
59
 
60
  gr.Markdown("### Click on any of the examples below to see to what extent they contain resilience messaging:")
61
+ gr.Examples([["I have minor pain."],["I have severe pain."]], [text], [label,shap_plot], main, cache_examples=True)
62
 
63
  demo.launch()