from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader
from torch.nn import Linear, Module
from typing import Dict, List
from collections import Counter, defaultdict
from itertools import chain
import torch

class MimicTransformer(Module):
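    """
    Sequence-classification model for DRG prediction on MIMIC text. Alongside the
    classification outputs, a linear head over the layer- and head-averaged
    self-attention maps produces per-token explanation (XAI) logits.
    """
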
    def __init__(self, num_labels=738, tokenizer_name='clinical', cutoff=512):
        """
        :param args:
        """
        super().__init__()
        self.tokenizer_name = self.find_tokenizer(tokenizer_name)
        self.num_labels = num_labels
        self.config = AutoConfig.from_pretrained(self.tokenizer_name, num_labels=self.num_labels)
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, config=self.config)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.tokenizer_name, config=self.config)
        if 'longformer' in self.tokenizer_name:
            # Longformer checkpoints support longer inputs; use their positional limit
            self.cutoff = self.model.config.max_position_embeddings
        else:
            self.cutoff = cutoff
        # maps each token's attention row (length == cutoff) to a single explanation logit
        self.linear = Linear(in_features=self.cutoff, out_features=1)
 
    def parse_icds(self, instances: List[Dict]):
        """
        Collect the token ids of every ICD mention in the batch and assign each
        token id a single BIO tag (B-ATTN / I-ATTN), keeping the more frequent
        tag when a token id occurs in both positions.
        """
        token_list = defaultdict(set)
        token_freq_list = []
        for instance in instances:
            # flatten the nested ICD annotations and de-duplicate mentions by start offset
            icds = list(chain(*instance['icd']))
            icd_dict_list = list({icd['start']: icd for icd in icds}.values())
            for icd_dict in icd_dict_list:
                icd_ent = icd_dict['text']
                icd_tokenized = self.tokenizer(icd_ent, add_special_tokens=False)['input_ids']
                icd_dict['tokens'] = icd_tokenized
                icd_dict['labels'] = []
                for i, token in enumerate(icd_tokenized):
                    # the first wordpiece of a mention is B-ATTN, the rest are I-ATTN
                    label = "B-ATTN" if i == 0 else "I-ATTN"
                    icd_dict['labels'].append(label)
                    token_list[token].add(label)
                    token_freq_list.append(str(token) + ": " + label)
        token_tag_freqs = Counter(token_freq_list)
        # if a token id was seen with both tags, keep only the more frequent one
        for token in token_list:
            if len(token_list[token]) == 2:
                inside_count = token_tag_freqs[str(token) + ": I-ATTN"]
                begin_count = token_tag_freqs[str(token) + ": B-ATTN"]
                if begin_count > inside_count:
                    token_list[token].remove('I-ATTN')
                else:
                    token_list[token].remove('B-ATTN')
        return token_list
    

    def collate_mimic(self, instances: List[Dict], device='cuda'):
        """
        Tokenize a batch of instances and build the training tensors: input ids
        ('text'), DRG labels ('drg'), raw entries ('entry'), the ICD token-tag map
        ('icds'), and binary per-token explanation labels ('xai').
        """
        tokenized = [
            self.tokenizer.encode(
                ' '.join(instance['description']), max_length=self.cutoff, truncation=True, padding='max_length'
            ) for instance in instances
        ]
        entries = [instance['entry'] for instance in instances]
        labels = torch.tensor([x['drg'] for x in instances], dtype=torch.long).to(device).unsqueeze(1)
        inputs = torch.tensor(tokenized, dtype=torch.long).to(device)
        icds = self.parse_icds(instances)
        xai_labels = torch.zeros(size=inputs.shape, dtype=torch.float32).to(device)
        # mark every input position whose token id belongs to an ICD mention
        for i, row in enumerate(inputs):
            for j, ele in enumerate(row):
                if ele.item() in icds:
                    xai_labels[i][j] = 1
        return {
            'text': inputs,
            'drg': labels,
            'entry': entries,
            'icds': icds,
            'xai': xai_labels
        }

    def forward(self, input_ids, attention_mask=None, drg_labels=None):
        if drg_labels is not None:
            cls_results = self.model(input_ids, attention_mask=attention_mask, labels=drg_labels, output_attentions=True)
        else:
            cls_results = self.model(input_ids, attention_mask=attention_mask, output_attentions=True)
        # last_attn = cls_results[-1][-1] # (batch, attn_heads, tokens, tokens)
        # cls_results[-1] holds one attention tensor per layer; average over layers,
        # then over all but the last three heads to get (batch, tokens, tokens)
        last_attn = torch.mean(torch.stack(cls_results[-1]), dim=0)
        last_layer_attn = torch.mean(last_attn[:, :-3, :, :], dim=1)
        # per-token explanation logits: (batch, tokens)
        xai_logits = self.linear(last_layer_attn).squeeze(dim=-1)
        return (cls_results, xai_logits)
    
    def find_tokenizer(self, tokenizer_name):
        """
        Map a shorthand tokenizer name to its Hugging Face checkpoint id.

        :param tokenizer_name: 'clinical_longformer', 'clinical', or anything else for BERT base
        :return: pretrained checkpoint name
        """
        if tokenizer_name == 'clinical_longformer':
            return 'yikuan8/Clinical-Longformer'
        elif tokenizer_name == 'clinical':
            return 'emilyalsentzer/Bio_ClinicalBERT'
        else:
            # standard transformer
            return 'bert-base-uncased'
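

if __name__ == '__main__':
    # Minimal usage sketch, assuming instances shaped like the fields collate_mimic
    # reads: 'description' (list of strings), 'entry', 'drg' (integer label), and
    # nested 'icd' spans carrying 'start'/'text'. The values below are illustrative only.
    model = MimicTransformer(num_labels=738, tokenizer_name='clinical', cutoff=512)
    instances = [{
        'description': ['acute', 'kidney', 'failure'],
        'entry': 'example-entry',
        'drg': 0,
        'icd': [[{'start': 0, 'text': 'acute kidney failure'}]],
    }]
    batch = model.collate_mimic(instances, device='cpu')
    cls_results, xai_logits = model(batch['text'], drg_labels=batch['drg'])
    print('DRG loss:', cls_results[0].item())
    print('XAI logits shape:', tuple(xai_logits.shape))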