paragon-analytics commited on
Commit
7f48a24
·
1 Parent(s): 518ac36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -24,6 +24,11 @@ pred = transformers.pipeline("text-classification", model=model,
24
 
25
  explainer = shap.Explainer(pred)
26
 
 
 
 
 
 
27
  def adr_predict(x):
28
  encoded_input = tokenizer(x, return_tensors='pt')
29
  output = model(**encoded_input)
@@ -32,14 +37,16 @@ def adr_predict(x):
32
 
33
  shap_values = explainer([str(x).lower()])
34
  local_plot = shap.plots.text(shap_values[0], display=False)
 
 
35
 
36
- return {"Severe Reaction": float(scores.numpy()[1]), "Non-severe Reaction": float(scores.numpy()[0])}, local_plot
37
 
38
 
39
  def main(prob1):
40
  text = str(prob1).lower()
41
  obj = adr_predict(text)
42
- return obj[0],obj[1]
43
 
44
  title = "Welcome to **ADR Detector** 🪐"
45
  description1 = """This app takes text (up to a few sentences) and predicts to what extent the text describes severe (or non-severe) adverse reaction to medicaitons."""
@@ -60,12 +67,12 @@ with gr.Blocks(title=title) as demo:
60
  main,
61
  [prob1],
62
  [label
63
- ,local_plot
64
  ], api_name="adr"
65
  )
66
 
67
  gr.Markdown("### Click on any of the examples below to see to what extent they contain resilience messaging:")
68
- gr.Examples([["I have severe pain."],["I have minor pain."]], [prob1], [label,local_plot
69
  ], main, cache_examples=True)
70
 
71
  demo.launch()
 
24
 
25
  explainer = shap.Explainer(pred)
26
 
27
+
28
+ ##
29
+ classifier = transformers.pipeline("text-classification", model = "cross-encoder/qnli-electra-base")
30
+ ##
31
+
32
  def adr_predict(x):
33
  encoded_input = tokenizer(x, return_tensors='pt')
34
  output = model(**encoded_input)
 
37
 
38
  shap_values = explainer([str(x).lower()])
39
  local_plot = shap.plots.text(shap_values[0], display=False)
40
+
41
+ med = classifier(x+str("There is a medication."))[0]
42
 
43
+ return {"Severe Reaction": float(scores.numpy()[1]), "Non-severe Reaction": float(scores.numpy()[0])}, local_plot, med
44
 
45
 
46
  def main(prob1):
47
  text = str(prob1).lower()
48
  obj = adr_predict(text)
49
+ return obj[0],obj[1],obj[2]
50
 
51
  title = "Welcome to **ADR Detector** 🪐"
52
  description1 = """This app takes text (up to a few sentences) and predicts to what extent the text describes severe (or non-severe) adverse reaction to medicaitons."""
 
67
  main,
68
  [prob1],
69
  [label
70
+ ,local_plot, med
71
  ], api_name="adr"
72
  )
73
 
74
  gr.Markdown("### Click on any of the examples below to see to what extent they contain resilience messaging:")
75
+ gr.Examples([["I have severe pain."],["I have minor pain."]], [prob1], [label,local_plot, med
76
  ], main, cache_examples=True)
77
 
78
  demo.launch()