danielhajialigol commited on
Commit
6a4a8e0
1 Parent(s): b3e501a

added drg and icd external link functionality

Browse files
Files changed (2) hide show
  1. app.py +7 -4
  2. utils.py +37 -5
app.py CHANGED
@@ -2,10 +2,9 @@ import numpy as np
2
  import gradio as gr
3
  import pandas as pd
4
  import torch
5
- import random
6
 
7
  from model import MimicTransformer
8
- from utils import load_rule, get_attribution, get_drg_link, visualize_attn
9
  from transformers import set_seed
10
 
11
  set_seed(42)
@@ -21,7 +20,7 @@ related_tensor = torch.load('discharge_embeddings.pt')
21
 
22
  # get model and results
23
  mimic = read_model(model=mimic, path=model_path)
24
- all_summaries = pd.read_csv('all_summaries.csv')['SUMMARIES'][:10000].to_list()
25
 
26
  tokenizer = mimic.tokenizer
27
  mimic.eval()
@@ -78,9 +77,12 @@ def run(text, related_discharges=False):
78
  model_results = get_model_results(text=text)
79
  drg_code = model_results['class']
80
  drg_link = get_drg_link(drg_code=drg_code)
 
81
  row = rule_df[rule_df['DRG_CODE'] == drg_code]
82
  drg_description = row['DESCRIPTION'].values[0]
83
  model_results['class_dsc'] = drg_description
 
 
84
  global related_summaries
85
  # related_summaries = generate_similar_summeries()
86
  related_summaries = find_related_summaries(model_results['logits'])
@@ -129,7 +131,8 @@ def prettify_text(nested_list):
129
  idx = 1
130
  string = ''
131
  for li in nested_list:
132
- string += f'({idx})\n{li[0]}\n\n'
 
133
  idx += 1
134
  return string
135
 
 
2
  import gradio as gr
3
  import pandas as pd
4
  import torch
 
5
 
6
  from model import MimicTransformer
7
+ from utils import load_rule, get_attribution, get_drg_link, get_icd_annotations, visualize_attn
8
  from transformers import set_seed
9
 
10
  set_seed(42)
 
20
 
21
  # get model and results
22
  mimic = read_model(model=mimic, path=model_path)
23
+ all_summaries = pd.read_csv('all_summaries.csv')['SUMMARIES'].to_list()
24
 
25
  tokenizer = mimic.tokenizer
26
  mimic.eval()
 
77
  model_results = get_model_results(text=text)
78
  drg_code = model_results['class']
79
  drg_link = get_drg_link(drg_code=drg_code)
80
+ icd_results = get_icd_annotations(text=text)
81
  row = rule_df[rule_df['DRG_CODE'] == drg_code]
82
  drg_description = row['DESCRIPTION'].values[0]
83
  model_results['class_dsc'] = drg_description
84
+ model_results['drg_link'] = drg_link
85
+ model_results['icd_results'] = icd_results
86
  global related_summaries
87
  # related_summaries = generate_similar_summeries()
88
  related_summaries = find_related_summaries(model_results['logits'])
 
131
  idx = 1
132
  string = ''
133
  for li in nested_list:
134
+ delimiters = 99 * '='
135
+ string += f'({idx})\n{li[0]}\n{delimiters}\n'
136
  idx += 1
137
  return string
138
 
utils.py CHANGED
@@ -66,7 +66,12 @@ def clean_text(text):
66
  return new_text
67
 
68
  def get_drg_link(drg_code):
69
- return f'https://www.aapc.com/codes/icd9-codes/{drg_code}'
 
 
 
 
 
70
 
71
  def prettify(dict_list, k):
72
  li = [di[k] for di in dict_list]
@@ -179,7 +184,7 @@ def reconstruct_text(tokenizer, tokens, attn):
179
  # final representation of text
180
  final_text = ' '.join(reconstructed_tokens).replace(' .', '.')
181
  final_text = final_text.replace(' ,', ',')
182
- assert final_text == reconstructed_text
183
  return aggregated_attn, reconstructed_tokens
184
 
185
  def load_rule(path):
@@ -225,7 +230,7 @@ def visualize_attn(model_results):
225
  raw_input_ids=tokens,
226
  convergence_score=1
227
  )
228
- return visualize_text(viz_record)
229
 
230
 
231
  def modify_attn_html(attn_html):
@@ -233,20 +238,46 @@ def modify_attn_html(attn_html):
233
  htmls = [attn_split[0]]
234
  for html in attn_split[1:]:
235
  # wrap around href tag
236
- href_html = f'<a href="espn.com" \
237
  <mark{html} \
238
  </a>'
239
  htmls.append(href_html)
240
  return "".join(htmls)
241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  # copied out of captum because we need raw html instead of a jupyter widget
243
- def visualize_text(datarecord):
244
  dom = ["<table width: 100%>"]
245
  rows = [
246
  "<th style='text-align: left'>Predicted DRG</th>"
247
  "<th style='text-align: left'>Word Importance</th>"
 
248
  ]
249
  pred_class_html = visualization.format_classname(datarecord.pred_class)
 
 
250
  word_attn_html = visualization.format_word_importances(
251
  datarecord.raw_input_ids, datarecord.word_attributions
252
  )
@@ -257,6 +288,7 @@ def visualize_text(datarecord):
257
  "<tr>",
258
  pred_class_html,
259
  word_attn_html,
 
260
  "<tr>",
261
  ]
262
  )
 
66
  return new_text
67
 
68
  def get_drg_link(drg_code):
69
+ drg_code = str(drg_code)
70
+ if len(drg_code) == 1:
71
+ drg_code = '00' + drg_code
72
+ elif len(drg_code) == 2:
73
+ drg_code = '0' + drg_code
74
+ return f'https://www.findacode.com/code.php?set=DRG&c={drg_code}'
75
 
76
  def prettify(dict_list, k):
77
  li = [di[k] for di in dict_list]
 
184
  # final representation of text
185
  final_text = ' '.join(reconstructed_tokens).replace(' .', '.')
186
  final_text = final_text.replace(' ,', ',')
187
+ # final_text == reconstructed_text
188
  return aggregated_attn, reconstructed_tokens
189
 
190
  def load_rule(path):
 
230
  raw_input_ids=tokens,
231
  convergence_score=1
232
  )
233
+ return visualize_text(viz_record, drg_link=model_results['drg_link'], icd_annotations=model_results['icd_results'])
234
 
235
 
236
  def modify_attn_html(attn_html):
 
238
  htmls = [attn_split[0]]
239
  for html in attn_split[1:]:
240
  # wrap around href tag
241
+ href_html = f'<a href="https://espn.com" \
242
  <mark{html} \
243
  </a>'
244
  htmls.append(href_html)
245
  return "".join(htmls)
246
 
247
+ def modify_code_html(html, link, icd=False):
248
+ html = html.split('<td>')[1].split('</td>')[0]
249
+ href_html = f'<td><a href="{link}"{html}</a></td>'
250
+ if icd:
251
+ href_html = href_html.replace('<td>', '').replace('</td>', '')
252
+ return href_html
253
+
254
+ def modify_drg_html(html, drg_link):
255
+ return modify_code_html(html=html, link=drg_link, icd=False)
256
+
257
+ def get_icd_html(icd_list):
258
+ if len(icd_list) == 0:
259
+ return '<td><text style="padding-right:2em"><b>N/A</b></text></td>'
260
+ final_html = '<td>'
261
+ for icd_dict in icd_list:
262
+ text, link = icd_dict['text'], icd_dict['link']
263
+ tmp_html = visualization.format_classname(classname=text)
264
+ html = modify_code_html(html=tmp_html, link=link, icd=True)
265
+ final_html += html
266
+ return final_html + '</td>'
267
+
268
+
269
+
270
  # copied out of captum because we need raw html instead of a jupyter widget
271
+ def visualize_text(datarecord, drg_link, icd_annotations):
272
  dom = ["<table width: 100%>"]
273
  rows = [
274
  "<th style='text-align: left'>Predicted DRG</th>"
275
  "<th style='text-align: left'>Word Importance</th>"
276
+ "<th style='text-align: left'>ICD Codes</th>"
277
  ]
278
  pred_class_html = visualization.format_classname(datarecord.pred_class)
279
+ icd_class_html = get_icd_html(icd_annotations)
280
+ pred_class_html = modify_drg_html(html=pred_class_html, drg_link=drg_link)
281
  word_attn_html = visualization.format_word_importances(
282
  datarecord.raw_input_ids, datarecord.word_attributions
283
  )
 
288
  "<tr>",
289
  pred_class_html,
290
  word_attn_html,
291
+ icd_class_html,
292
  "<tr>",
293
  ]
294
  )