from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification, set_seed
from torch.utils.data import DataLoader
from torch.nn import Linear, Module
from typing import Dict, List
from collections import Counter, defaultdict
from itertools import chain
import torch

# Seed all RNGs for reproducibility.
torch.manual_seed(0)
set_seed(34)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)

class MimicTransformer(Module):
    """Wraps a pretrained clinical transformer with an attention-based explanation head."""

    def __init__(self, num_labels=738, tokenizer_name='clinical', cutoff=512):
        """
        :param num_labels: number of DRG classes for sequence classification
        :param tokenizer_name: shorthand resolved to a model hub id by `find_tokenizer`
        :param cutoff: maximum sequence length for non-Longformer models
        """
        super().__init__()
        self.tokenizer_name = self.find_tokenizer(tokenizer_name)
        self.num_labels = num_labels
        self.config = AutoConfig.from_pretrained(self.tokenizer_name, num_labels=self.num_labels)
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, config=self.config)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.tokenizer_name, config=self.config)
        self.model.eval()
        # Case-insensitive check: the resolved hub id is 'yikuan8/Clinical-Longformer'.
        if 'longformer' in self.tokenizer_name.lower():
            # Longformer handles long inputs; use its full position-embedding budget.
            self.cutoff = self.model.config.max_position_embeddings
        else:
            self.cutoff = cutoff
        # Projects each token's attention row (length `cutoff`) to one explanation logit.
        self.linear = Linear(in_features=self.cutoff, out_features=1)
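
    # Instantiation sketch (weights download from the Hugging Face hub; the
    # 'clinical_longformer' shorthand resolves to yikuan8/Clinical-Longformer
    # and raises `cutoff` to that checkpoint's max_position_embeddings):
    #
    #   model = MimicTransformer(tokenizer_name='clinical_longformer')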

    def parse_icds(self, instances: List[Dict]):
        """Collect token ids of ICD evidence spans and give each token a BIO-style tag."""
        token_list = defaultdict(set)
        token_freq_list = []
        for instance in instances:
            icds = list(chain(*instance['icd']))
            # Deduplicate ICD mentions by character start offset.
            icd_dict_list = list({icd['start']: icd for icd in icds}.values())
            for icd_dict in icd_dict_list:
                icd_ent = icd_dict['text']
                icd_tokenized = self.tokenizer(icd_ent, add_special_tokens=False)['input_ids']
                icd_dict['tokens'] = icd_tokenized
                icd_dict['labels'] = []
                for i, token in enumerate(icd_tokenized):
                    # First sub-token of a mention is B-ATTN; the rest are I-ATTN.
                    label = "B-ATTN" if i == 0 else "I-ATTN"
                    icd_dict['labels'].append(label)
                    token_list[token].add(label)
                    token_freq_list.append(str(token) + ": " + label)
        # A token id can occur with both tags; keep only its more frequent one.
        token_tag_freqs = Counter(token_freq_list)
        for token in token_list:
            if len(token_list[token]) == 2:
                inside_count = token_tag_freqs[str(token) + ": I-ATTN"]
                begin_count = token_tag_freqs[str(token) + ": B-ATTN"]
                if begin_count > inside_count:
                    token_list[token].remove('I-ATTN')
                else:
                    token_list[token].remove('B-ATTN')
        return token_list
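
    # A minimal sketch of the instance schema `parse_icds` expects. The field
    # names come from the code above; the concrete values are hypothetical:
    #
    #   instance = {
    #       'description': ['acute', 'chest', 'pain'],   # note tokens
    #       'entry': 'example_entry_0',                  # record identifier
    #       'drg': 280,                                  # DRG class index
    #       'icd': [[{'start': 0, 'text': 'acute chest pain'}]],
    #   }
    #   token_list = model.parse_icds([instance])  # {token_id: {'B-ATTN'} or {'I-ATTN'}}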

    def collate_mimic(self, instances: List[Dict], device='cuda'):
        """Batch instances into padded input ids, DRG labels, and binary token-level XAI labels."""
        tokenized = [
            self.tokenizer.encode(
                ' '.join(instance['description']), max_length=self.cutoff, truncation=True, padding='max_length'
            ) for instance in instances
        ]
        entries = [instance['entry'] for instance in instances]
        labels = torch.tensor([x['drg'] for x in instances], dtype=torch.long).to(device).unsqueeze(1)
        inputs = torch.tensor(tokenized, dtype=torch.long).to(device)
        icds = self.parse_icds(instances)
        # Mark every input position whose token id belongs to an ICD evidence span.
        xai_labels = torch.zeros(size=inputs.shape, dtype=torch.float32).to(device)
        for i, row in enumerate(inputs):
            for j, ele in enumerate(row):
                if ele.item() in icds:
                    xai_labels[i][j] = 1
        return {
            'text': inputs,
            'drg': labels,
            'entry': entries,
            'icds': icds,
            'xai': xai_labels
        }
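
    # `collate_mimic` is shaped to serve as a DataLoader collate_fn. A sketch,
    # where `train_set` is a hypothetical dataset of instance dicts as above:
    #
    #   loader = DataLoader(train_set, batch_size=8,
    #                       collate_fn=lambda batch: model.collate_mimic(batch, device='cpu'))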

    def forward(self, input_ids, attention_mask=None, drg_labels=None):
        # Truth-testing a multi-element tensor raises, so compare against None explicitly.
        if drg_labels is not None:
            cls_results = self.model(input_ids, attention_mask=attention_mask, labels=drg_labels, output_attentions=True)
        else:
            cls_results = self.model(input_ids, attention_mask=attention_mask, output_attentions=True)
        # cls_results[-1] is the per-layer attention tuple, each (batch, attn_heads, tokens, tokens).
        # last_attn = cls_results[-1][-1]  # alternative: use the last layer only
        # Average over layers, then over all but the last three attention heads.
        attn_mean_layers = torch.mean(torch.stack(cls_results[-1]), dim=0)
        attn_mean_heads = torch.mean(attn_mean_layers[:, :-3, :, :], dim=1)
        # Project each token's attention row to a single explanation logit: (batch, tokens).
        xai_logits = self.linear(attn_mean_heads).squeeze(dim=-1)
        return (cls_results, xai_logits)
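
    # Sketch of one supervised step on a batch from `collate_mimic`. The BCE
    # pairing for the explanation head is an assumption, not defined in this file:
    #
    #   cls_out, xai_logits = model(batch['text'], drg_labels=batch['drg'])
    #   drg_loss = cls_out.loss  # sequence-classification loss from the base model
    #   xai_loss = torch.nn.functional.binary_cross_entropy_with_logits(
    #       xai_logits, batch['xai'])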

    def find_tokenizer(self, tokenizer_name):
        """
        :param tokenizer_name: shorthand model name ('clinical_longformer', 'clinical', or other)
        :return: the corresponding Hugging Face model hub id
        """
        if tokenizer_name == 'clinical_longformer':
            return 'yikuan8/Clinical-Longformer'
        elif tokenizer_name == 'clinical':
            return 'emilyalsentzer/Bio_ClinicalBERT'
        else:
            # standard transformer
            return 'bert-base-uncased'
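

# A minimal end-to-end sketch. The instance below is hypothetical, shaped only
# to satisfy `collate_mimic`/`parse_icds`; real MIMIC records will differ.
if __name__ == '__main__':
    model = MimicTransformer(num_labels=738, tokenizer_name='clinical', cutoff=512)
    instances = [{
        'description': ['patient', 'admitted', 'with', 'acute', 'myocardial', 'infarction'],
        'entry': 'example_entry_0',  # hypothetical record identifier
        'drg': 280,                  # hypothetical DRG class index
        'icd': [[{'start': 0, 'text': 'acute myocardial infarction'}]],
    }]
    batch = model.collate_mimic(instances, device='cpu')
    with torch.no_grad():
        cls_out, xai_logits = model(batch['text'], drg_labels=batch['drg'])
    print(cls_out.loss, xai_logits.shape)  # xai_logits: (batch, cutoff)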