danielhajialigol committed on
Commit
4067c90
1 Parent(s): d1017a6

adding ner models

Browse files
Files changed (4) hide show
  1. app.py +39 -22
  2. discharge_embeddings.pt +2 -2
  3. requirements.txt +1 -0
  4. utils.py +46 -4
app.py CHANGED
@@ -4,32 +4,38 @@ import pandas as pd
4
  import torch
5
 
6
  from model import MimicTransformer
7
- from utils import load_rule, get_attribution, get_drg_link, get_icd_annotations, visualize_attn
8
- from transformers import set_seed
9
 
10
  set_seed(42)
 
 
 
 
 
 
 
11
 
12
  def read_model(model, path):
13
  model.load_state_dict(torch.load(path, map_location=torch.device('cpu')), strict=False)
14
  return model
15
 
16
- model_path = 'checkpoint_0_9113.bin'
17
  mimic = MimicTransformer(cutoff=512)
18
-
19
- related_tensor = torch.load('discharge_embeddings.pt')
20
-
21
- # get model and results
22
  mimic = read_model(model=mimic, path=model_path)
23
- all_summaries = pd.read_csv('all_summaries.csv')['SUMMARIES'].to_list()
24
-
25
  tokenizer = mimic.tokenizer
26
  mimic.eval()
27
 
28
- ex1 = """Radiologic studies also included a chest CT, which confirmed cavitary lesions in the left lung apex consistent with infectious tuberculosis. This also moderate-sized left pleural effusion."""
29
- ex2 = """We have discharged Mrs Smith on regular oral Furosemide (40mg OD) and we have requested an outpatient ultrasound of her renal tract which will be performed in the next few weeks. We will review Mrs Smith in the Cardiology Outpatient Clinic in 6 weeks time."""
30
- ex3 = """Blood tests revealed a raised BNP. An ECG showed evidence of left-ventricular hypertrophy and echocardiography revealed grossly impaired ventricular function (ejection fraction 35%). A chest X-ray demonstrated bilateral pleural effusions, with evidence of upper lobe diversion."""
31
- ex4 = """Mrs Smith presented to A&E with worsening shortness of breath and ankle swelling. On arrival, she was tachypnoeic and hypoxic (oxygen saturation 82% on air). Clinical examination revealed reduced breath sounds and dullness to percussion in both lung bases. There was also a significant degree of lower limb oedema extending up to the mid-thigh bilaterally."""
32
- examples = [ex1, ex2, ex3, ex4]
 
 
 
 
 
 
33
  related_summaries = [[ex1]]
34
  related_chosen = []
35
  related_attn = []
@@ -59,9 +65,14 @@ def get_model_results(text):
59
  'logits': logits
60
  }
61
 
62
- def find_related_summaries(raw_embedding):
63
- raw_embedding = torch.nn.functional.normalize(raw_embedding)
64
- scores = torch.mm(related_tensor, raw_embedding.transpose(1,0))
 
 
 
 
 
65
  scores_indices = scores.topk(k=5, dim=0)
66
  indices, scores = scores_indices[-1], torch.round(100 * scores_indices[0], decimals=2)
67
  summaries = []
@@ -74,8 +85,13 @@ def find_related_summaries(raw_embedding):
74
 
75
 
76
  def run(text, related_discharges=False):
 
77
  model_results = get_model_results(text=text)
78
  drg_code = model_results['class']
 
 
 
 
79
  drg_link = get_drg_link(drg_code=drg_code)
80
  icd_results = get_icd_annotations(text=text)
81
  row = rule_df[rule_df['DRG_CODE'] == drg_code]
@@ -85,7 +101,7 @@ def run(text, related_discharges=False):
85
  model_results['icd_results'] = icd_results
86
  global related_summaries
87
  # related_summaries = generate_similar_summeries()
88
- related_summaries = find_related_summaries(model_results['logits'])
89
  if related_discharges:
90
  return visualize_attn(model_results=model_results)
91
  return (
@@ -193,10 +209,11 @@ def main():
193
 
194
  # input to related summaries
195
  with gr.Row() as row:
196
- input_related = gr.TextArea(label="Input up to 3 Related Discharge Summary/Summaries Here", visible=False)
197
- with gr.Row() as row:
198
- rmv_related_btn = gr.Button(value='Remove Related Summary', visible=False)
199
- sbm_btn = gr.Button(value="Submit Related Summaries", components=[input_related], visible=False)
 
200
 
201
  with gr.Row() as row:
202
  related = gr.Dataset(samples=[], components=[input_related], visible=False, type='index')
 
4
  import torch
5
 
6
  from model import MimicTransformer
7
+ from utils import load_rule, get_attribution, get_diseases, get_drg_link, get_icd_annotations, visualize_attn
8
+ from transformers import AutoTokenizer, AutoModel, set_seed, pipeline
9
 
10
  set_seed(42)
11
+ model_path = 'checkpoint_0_9113.bin'
12
+ related_tensor = torch.load('discharge_embeddings.pt')
13
+ all_summaries = pd.read_csv('all_summaries.csv')['SUMMARIES'].to_list()
14
+
15
+ similarity_tokenizer = AutoTokenizer.from_pretrained('kamalkraj/BioSimCSE-BioLinkBERT-BASE')
16
+ similarity_model = AutoModel.from_pretrained('kamalkraj/BioSimCSE-BioLinkBERT-BASE')
17
+ similarity_model.eval()
18
 
19
  def read_model(model, path):
20
  model.load_state_dict(torch.load(path, map_location=torch.device('cpu')), strict=False)
21
  return model
22
 
 
23
  mimic = MimicTransformer(cutoff=512)
 
 
 
 
24
  mimic = read_model(model=mimic, path=model_path)
 
 
25
  tokenizer = mimic.tokenizer
26
  mimic.eval()
27
 
28
+ # disease ner model
29
+ pipe = pipeline("token-classification", model="alvaroalon2/biobert_diseases_ner")
30
+
31
+ #
32
+
33
+ ex1 = """HEAD CT: Head CT showed no intracranial hemorrhage or mass effect, but old infarction consistent with past medical history."""
34
+ ex2 = """Radiologic studies also included a chest CT, which confirmed cavitary lesions in the left lung apex consistent with infectious tuberculosis. This also moderate-sized left pleural effusion."""
35
+ ex3 = """We have discharged Mrs Smith on regular oral Furosemide (40mg OD) and we have requested an outpatient ultrasound of her renal tract which will be performed in the next few weeks. We will review Mrs Smith in the Cardiology Outpatient Clinic in 6 weeks time."""
36
+ ex4 = """Blood tests revealed a raised BNP. An ECG showed evidence of left-ventricular hypertrophy and echocardiography revealed grossly impaired ventricular function (ejection fraction 35%). A chest X-ray demonstrated bilateral pleural effusions, with evidence of upper lobe diversion."""
37
+ ex5 = """Mrs Smith presented to A&E with worsening shortness of breath and ankle swelling. On arrival, she was tachypnoeic and hypoxic (oxygen saturation 82% on air). Clinical examination revealed reduced breath sounds and dullness to percussion in both lung bases. There was also a significant degree of lower limb oedema extending up to the mid-thigh bilaterally."""
38
+ examples = [ex1, ex2, ex3, ex4, ex5]
39
  related_summaries = [[ex1]]
40
  related_chosen = []
41
  related_attn = []
 
65
  'logits': logits
66
  }
67
 
68
+ def find_related_summaries(text):
69
+ inputs = similarity_tokenizer(
70
+ text, padding='max_length', truncation=True, return_tensors='pt', max_length=512
71
+ )
72
+ outputs = similarity_model(**inputs)
73
+ embedding = mean_pooling(outputs, attention_mask=inputs.attention_mask)
74
+ embedding = torch.nn.functional.normalize(embedding)
75
+ scores = torch.mm(related_tensor, embedding.transpose(1,0))
76
  scores_indices = scores.topk(k=5, dim=0)
77
  indices, scores = scores_indices[-1], torch.round(100 * scores_indices[0], decimals=2)
78
  summaries = []
 
85
 
86
 
87
  def run(text, related_discharges=False):
88
+ # initial drg results
89
  model_results = get_model_results(text=text)
90
  drg_code = model_results['class']
91
+
92
+ # find diseases
93
+ diseases = get_diseases(text=text, pipe=pipe)
94
+ model_results['diseases'] = diseases
95
  drg_link = get_drg_link(drg_code=drg_code)
96
  icd_results = get_icd_annotations(text=text)
97
  row = rule_df[rule_df['DRG_CODE'] == drg_code]
 
101
  model_results['icd_results'] = icd_results
102
  global related_summaries
103
  # related_summaries = generate_similar_summeries()
104
+ related_summaries = find_related_summaries(text=text)
105
  if related_discharges:
106
  return visualize_attn(model_results=model_results)
107
  return (
 
209
 
210
  # input to related summaries
211
  with gr.Row() as row:
212
+ with gr.Column(scale=5) as col:
213
+ input_related = gr.TextArea(label="Input up to 3 Related Discharge Summary/Summaries Here", visible=False)
214
+ with gr.Column(scale=1) as col:
215
+ rmv_related_btn = gr.Button(value='Remove Related Summary', visible=False)
216
+ sbm_btn = gr.Button(value="Submit Related Summaries", components=[input_related], visible=False)
217
 
218
  with gr.Row() as row:
219
  related = gr.Dataset(samples=[], components=[input_related], visible=False, type='index')
discharge_embeddings.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a2f32db3a0e8f504091f5ae649dd9c0d90368497310343518a8b64cb15a0cfc5
3
- size 29520786
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5596bf755d73898c6544b6254c2415283e2891deb6e3748f51ae9fb10794baee
3
+ size 30720786
requirements.txt CHANGED
@@ -4,3 +4,4 @@ gradio
4
  transformers
5
  captum
6
  tqdm
 
 
4
  transformers
5
  captum
6
  tqdm
7
+ sentence-transformers
utils.py CHANGED
@@ -20,6 +20,28 @@ class PyTMinMaxScalerVectorized(object):
20
  scale = 1.0 / (tensor.max(dim=0, keepdim=True)[0] - tensor.min(dim=0, keepdim=True)[0])
21
  tensor.mul_(scale).sub_(tensor.min(dim=0, keepdim=True)[0])
22
  return tensor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  def find_end(text):
25
  """Find the end of the report."""
@@ -230,7 +252,12 @@ def visualize_attn(model_results):
230
  raw_input_ids=tokens,
231
  convergence_score=1
232
  )
233
- return visualize_text(viz_record, drg_link=model_results['drg_link'], icd_annotations=model_results['icd_results'])
 
 
 
 
 
234
 
235
 
236
  def modify_attn_html(attn_html):
@@ -238,7 +265,7 @@ def modify_attn_html(attn_html):
238
  htmls = [attn_split[0]]
239
  for html in attn_split[1:]:
240
  # wrap around href tag
241
- href_html = f'<a href="https://espn.com" \
242
  <mark{html} \
243
  </a>'
244
  htmls.append(href_html)
@@ -258,36 +285,51 @@ def get_icd_html(icd_list):
258
  if len(icd_list) == 0:
259
  return '<td><text style="padding-right:2em"><b>N/A</b></text></td>'
260
  final_html = '<td>'
 
261
  for icd_dict in icd_list:
262
  text, link = icd_dict['text'], icd_dict['link']
 
 
263
  tmp_html = visualization.format_classname(classname=text)
264
  html = modify_code_html(html=tmp_html, link=link, icd=True)
265
  final_html += html
 
266
  return final_html + '</td>'
267
 
 
 
 
 
 
 
 
 
 
268
 
269
 
270
  # copied out of captum because we need raw html instead of a jupyter widget
271
- def visualize_text(datarecord, drg_link, icd_annotations):
272
  dom = ["<table width: 100%>"]
273
  rows = [
274
  "<th style='text-align: left'>Predicted DRG</th>"
275
  "<th style='text-align: left'>Word Importance</th>"
 
276
  "<th style='text-align: left'>ICD Codes</th>"
277
  ]
278
  pred_class_html = visualization.format_classname(datarecord.pred_class)
279
  icd_class_html = get_icd_html(icd_annotations)
 
280
  pred_class_html = modify_drg_html(html=pred_class_html, drg_link=drg_link)
281
  word_attn_html = visualization.format_word_importances(
282
  datarecord.raw_input_ids, datarecord.word_attributions
283
  )
284
- word_attn_html = modify_attn_html(word_attn_html)
285
  rows.append(
286
  "".join(
287
  [
288
  "<tr>",
289
  pred_class_html,
290
  word_attn_html,
 
291
  icd_class_html,
292
  "<tr>",
293
  ]
 
20
  scale = 1.0 / (tensor.max(dim=0, keepdim=True)[0] - tensor.min(dim=0, keepdim=True)[0])
21
  tensor.mul_(scale).sub_(tensor.min(dim=0, keepdim=True)[0])
22
  return tensor
23
+
24
def get_diseases(text, pipe):
    """Collect disease mentions from ``text`` using a token-level NER pipeline.

    ``pipe`` is called once with ``text`` and must return an iterable of
    dicts carrying ``'entity'`` (the tag), ``'start'`` and ``'end'``
    (character offsets into ``text``). Consecutive ``B-DISEASE`` /
    ``I-DISEASE`` tokens are merged into one character span, and the
    matching slice of ``text`` is returned for each span.

    Args:
        text: the raw text that was (or will be) tagged.
        pipe: callable producing per-token tagging results for ``text``.

    Returns:
        A list of disease strings in order of appearance. Mentions of two
        characters or fewer are dropped; duplicates are kept.
    """
    diseases = []
    span = None  # (start, end) offsets of the mention currently being built

    def _flush(pending):
        # Emit the pending mention, applying the same length filter everywhere.
        if pending is None:
            return
        disease = text[pending[0]: pending[1]]
        if len(disease) > 2:
            diseases.append(disease)

    for result in pipe(text):
        ent = result['entity']
        if ent == 'B-DISEASE':
            # A new mention begins: close the previous one instead of
            # silently overwriting it.
            _flush(span)
            span = (result['start'], result['end'])
        elif ent == 'I-DISEASE':
            if span is None:
                # Continuation tag with no open mention: start one here
                # rather than indexing into an empty span.
                span = (result['start'], result['end'])
            else:
                span = (span[0], result['end'])
        else:
            # Any other tag ends the current mention.
            _flush(span)
            span = None

    # The text may end while a mention is still open.
    _flush(span)
    return diseases
45
 
46
  def find_end(text):
47
  """Find the end of the report."""
 
252
  raw_input_ids=tokens,
253
  convergence_score=1
254
  )
255
+ return visualize_text(
256
+ viz_record,
257
+ drg_link=model_results['drg_link'],
258
+ icd_annotations=model_results['icd_results'],
259
+ diseases=model_results['diseases']
260
+ )
261
 
262
 
263
  def modify_attn_html(attn_html):
 
265
  htmls = [attn_split[0]]
266
  for html in attn_split[1:]:
267
  # wrap around href tag
268
+ href_html = f'<a href="https://" \
269
  <mark{html} \
270
  </a>'
271
  htmls.append(href_html)
 
285
  if len(icd_list) == 0:
286
  return '<td><text style="padding-right:2em"><b>N/A</b></text></td>'
287
  final_html = '<td>'
288
+ icd_set = set()
289
  for icd_dict in icd_list:
290
  text, link = icd_dict['text'], icd_dict['link']
291
+ if text in icd_set:
292
+ continue
293
  tmp_html = visualization.format_classname(classname=text)
294
  html = modify_code_html(html=tmp_html, link=link, icd=True)
295
  final_html += html
296
+ icd_set.add(text)
297
  return final_html + '</td>'
298
 
299
+
300
def get_disease_html(diseases):
    """Render the detected diseases as a single HTML table cell.

    Args:
        diseases: list of disease strings; may contain duplicates.

    Returns:
        An HTML ``<td>`` string. With no diseases the cell shows ``N/A``;
        otherwise it shows the distinct diseases, comma-separated, in the
        order they were first seen.
    """
    from html import escape

    if not diseases:
        return '<td><text style="padding-right:2em"><b>N/A</b></text></td>'
    # dict.fromkeys de-duplicates while keeping first-seen order, so the cell
    # content is stable between runs (set ordering of strings is not).
    unique = list(dict.fromkeys(diseases))
    # The strings are slices of user-supplied text, so escape them before
    # they are embedded in markup.
    diseases_str = ', '.join(escape(disease, quote=False) for disease in unique)
    # NOTE(review): format_classname appears to return a complete
    # ``<td>...</td>`` cell (same shape as the N/A branch above), so no
    # extra closing tag is appended - confirm against the captum version in use.
    return visualization.format_classname(classname=diseases_str)
307
+
308
 
309
 
310
  # copied out of captum because we need raw html instead of a jupyter widget
311
+ def visualize_text(datarecord, drg_link, icd_annotations, diseases):
312
  dom = ["<table width: 100%>"]
313
  rows = [
314
  "<th style='text-align: left'>Predicted DRG</th>"
315
  "<th style='text-align: left'>Word Importance</th>"
316
+ "<th style='text-align: left'>Diseases</th>"
317
  "<th style='text-align: left'>ICD Codes</th>"
318
  ]
319
  pred_class_html = visualization.format_classname(datarecord.pred_class)
320
  icd_class_html = get_icd_html(icd_annotations)
321
+ disease_html = get_disease_html(diseases)
322
  pred_class_html = modify_drg_html(html=pred_class_html, drg_link=drg_link)
323
  word_attn_html = visualization.format_word_importances(
324
  datarecord.raw_input_ids, datarecord.word_attributions
325
  )
 
326
  rows.append(
327
  "".join(
328
  [
329
  "<tr>",
330
  pred_class_html,
331
  word_attn_html,
332
+ disease_html,
333
  icd_class_html,
334
  "<tr>",
335
  ]