danielhajialigol committed
Commit eca4ff8
1 Parent(s): 9901139

fix padding issue

Files changed (2)
  1. app.py +2 -1
  2. utils.py +3 -1
app.py CHANGED
@@ -6,7 +6,7 @@ import torch
 from model import MimicTransformer
 from utils import load_rule, get_attribution, get_diseases, get_drg_link, get_icd_annotations, visualize_attn
 from transformers import AutoTokenizer, AutoModel, set_seed, pipeline
-set_seed(42)
+
 model_path = 'checkpoint_0_9113.bin'
 related_tensor = torch.load('discharge_embeddings.pt')
 all_summaries = pd.read_csv('all_summaries.csv')['SUMMARIES'].to_list()
@@ -56,6 +56,7 @@ def get_model_results(text):
     logits = outputs[0][0]
     out = logits.detach().cpu()[0]
     drg_code = i2d[out.argmax().item()]
+    print(out.topk(5))
     prob = torch.nn.functional.softmax(out).max()
     return {
         'class': drg_code,
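
The new print(out.topk(5)) line dumps the five highest-scoring DRG logits for debugging. Below is a minimal sketch of how those top-5 logits relate to the probability returned by get_model_results, assuming out is a 1-D logit tensor and i2d maps class indices to DRG codes as in app.py; the tensor values and the i2d contents are illustrative placeholders only, and softmax is called with an explicit dim=-1 to avoid the implicit-dim warning raised by the existing torch.nn.functional.softmax(out) call.

import torch

# Illustrative stand-ins for the real objects in app.py.
out = torch.tensor([2.1, 0.3, 5.7, 1.2, 4.4, 0.9])          # hypothetical DRG logits
i2d = {i: f"DRG-{i:03d}" for i in range(len(out))}           # hypothetical index-to-DRG map

top = out.topk(5)                                            # what the added print() shows
probs = torch.nn.functional.softmax(out, dim=-1)             # explicit dim avoids the warning
for score, idx in zip(top.values.tolist(), top.indices.tolist()):
    print(f"{i2d[idx]}: logit={score:.2f}, prob={probs[idx].item():.3f}")
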
utils.py CHANGED
@@ -204,7 +204,9 @@ def tokenize_icds(tokenizer, annotations, token_ids):
 
 def get_attribution(text, tokenizer, model_outputs, inputs, k=7):
     tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
-    padding_idx = tokens.index('[PAD]')
+    padding_idx = 512
+    if '[PAD]' in tokens:
+        padding_idx = tokens.index('[PAD]')
     tokens = tokens[:padding_idx][1:-1]
     attn = model_outputs[-1][0]
     agg_attn, final_text = reconstruct_text(tokenizer=tokenizer, tokens=tokens, attn=attn)
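
The padding fix covers the case where a discharge summary fills the whole 512-token window: with no '[PAD]' token in the sequence, tokens.index('[PAD]') raises ValueError and get_attribution crashes. The following is a minimal, self-contained sketch of the guarded lookup, assuming a BERT-style tokenizer with a 512-token limit; bert-base-uncased is only a stand-in for whatever tokenizer the Space actually loads, and the note text is synthetic.

from transformers import AutoTokenizer

# Stand-in tokenizer; the real app loads its own checkpoint.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# A note long enough to hit max_length, so the encoded sequence has no '[PAD]'.
long_note = "patient admitted with chest pain " * 200
inputs = tokenizer(long_note, padding='max_length', truncation=True,
                   max_length=512, return_tensors='pt')
tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])

# Before the fix: tokens.index('[PAD]') raises ValueError on a fully truncated note.
# After the fix: default to the max length and only look up '[PAD]' when present.
padding_idx = 512
if '[PAD]' in tokens:
    padding_idx = tokens.index('[PAD]')
tokens = tokens[:padding_idx][1:-1]   # also drops [CLS] and the final [SEP]
print(len(tokens))                    # 510 for a full-length note

With the guard in place, a fully truncated note keeps all 510 non-special tokens, while shorter notes are still cut at the first '[PAD]' exactly as before.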