# Clinical_Decisions / model.py
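"""Model definitions for the Clinical_Decisions project.

Defines a transformer-based classifier (MyModel) with sequence-level
(`phenos`), token-level (`decisions`), and segmented long-document
(`generate`) inference paths, CNN and LSTM baselines, and a `load_model`
factory that also builds the loss criterion, optimizer, and LR scheduler.
"""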
import copy

import torch
from torch import nn
from torch.optim import AdamW
from transformers import AutoModel, get_linear_schedule_with_warmup

try:
    # Optional dependency (pytorch-crf): only needed when args.use_crf is set.
    from torchcrf import CRF
except ImportError:
    CRF = None


class MyModel(nn.Module):
    def __init__(self, args, backbone):
        super().__init__()
        self.args = args
        self.backbone = backbone
        self.cls_id = 0
        hidden_dim = self.backbone.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, args.num_labels),
        )
        if args.distil_att:
            # Learnable per-dimension vector enabled by the distil_att option.
            self.distil_att = nn.Parameter(torch.ones(self.backbone.config.hidden_size))

    def forward(self, x, mask):
        x = x.to(self.backbone.device)
        mask = mask.to(self.backbone.device)
        out = self.backbone(x, attention_mask=mask, output_attentions=True)
        # Token-level logits over the full sequence.
        return out, self.classifier(out.last_hidden_state)

    def decisions(self, x, mask):
        x = x.to(self.backbone.device)
        mask = mask.to(self.backbone.device)
        out = self.backbone(x, attention_mask=mask, output_attentions=False)
        return out, self.classifier(out.last_hidden_state)

    def phenos(self, x, mask):
        x = x.to(self.backbone.device)
        mask = mask.to(self.backbone.device)
        out = self.backbone(x, attention_mask=mask, output_attentions=True)
        # Sequence-level logits from the pooled representation.
        return out, self.classifier(out.pooler_output)

    def generate(self, x, mask, choice=None):
        # Handle inputs longer than args.max_len by splitting them into
        # segments and aggregating the per-segment outputs.
        outs = []
        if self.args.task == 'seq' or choice == 'seq':
            for i, offset in enumerate(range(0, x.shape[1], self.args.max_len - 1)):
                if i == 0:
                    segment = x[:, offset:offset + self.args.max_len - 1]
                    segment_mask = mask[:, offset:offset + self.args.max_len - 1]
                else:
                    # Prepend a CLS token id (self.cls_id) to every later segment.
                    segment = torch.cat(
                        (torch.ones((x.shape[0], 1), dtype=torch.long).to(x.device) * self.cls_id,
                         x[:, offset:offset + self.args.max_len - 1]), dim=1)
                    segment_mask = torch.cat(
                        (torch.ones((mask.shape[0], 1)).to(mask.device),
                         mask[:, offset:offset + self.args.max_len - 1]), dim=1)
                logits = self.phenos(segment, segment_mask)[1]
                outs.append(logits)
            # Max-pool sequence-level logits across segments.
            return torch.max(torch.stack(outs, 1), 1).values
        elif self.args.task == 'token':
            for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)):
                segment = x[:, offset:offset + self.args.max_len]
                segment_mask = mask[:, offset:offset + self.args.max_len]
                h = self.decisions(segment, segment_mask)[0].last_hidden_state
                outs.append(h)
            # Concatenate segment representations, then classify every token.
            h = torch.cat(outs, 1)
            return self.classifier(h)


class CNN(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.emb = nn.Embedding(args.vocab_size, args.emb_size)
        self.model = nn.Sequential(
            nn.Conv1d(args.emb_size, args.hidden_size, args.kernels[0],
                      padding='same' if args.task == 'token' else 'valid'),
            nn.ReLU(),
            nn.MaxPool1d(1),  # kernel size 1: identity, no pooling
            nn.Conv1d(args.hidden_size, args.hidden_size, args.kernels[1],
                      padding='same' if args.task == 'token' else 'valid'),
            nn.ReLU(),
            nn.MaxPool1d(1),
            nn.Conv1d(args.hidden_size, args.hidden_size, args.kernels[2],
                      padding='same' if args.task == 'token' else 'valid'),
            nn.ReLU(),
            nn.MaxPool1d(1),
        )
        if args.task == 'seq':
            # Assumes a fixed input length of 512; each 'valid' conv shrinks the
            # sequence by kernel_size - 1.
            out_shape = 512 - args.kernels[0] - args.kernels[1] - args.kernels[2] + 3
        elif args.task == 'token':
            out_shape = 1
        self.classifier = nn.Linear(args.hidden_size * out_shape, args.num_labels)
        self.dropout = nn.Dropout()
        self.args = args
        self.device = None

    def forward(self, x, _):
        x = x.to(self.device)
        bs = x.shape[0]
        x = self.emb(x)
        x = x.transpose(1, 2)  # (batch, emb, seq) layout for Conv1d
        x = self.model(x)
        x = self.dropout(x)
        if self.args.task == 'token':
            x = x.transpose(1, 2)
            h = self.classifier(x)
            return x, h
        elif self.args.task == 'seq':
            x = x.reshape(bs, -1)
            x = self.classifier(x)
            return x

    def generate(self, x, _):
        # Segment long inputs, pad the last segment to max_len, and aggregate.
        outs = []
        for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)):
            segment = x[:, offset:offset + self.args.max_len]
            n = segment.shape[1]
            if n != self.args.max_len:
                segment = torch.nn.functional.pad(segment, (0, self.args.max_len - n))
            if self.args.task == 'seq':
                logits = self(segment, None)
                outs.append(logits)
            elif self.args.task == 'token':
                h = self(segment, None)[0]
                h = h[:, :n]  # drop padded positions
                outs.append(h)
        if self.args.task == 'seq':
            return torch.max(torch.stack(outs, 1), 1).values
        elif self.args.task == 'token':
            h = torch.cat(outs, 1)
            return self.classifier(h)


class LSTM(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.emb = nn.Embedding(args.vocab_size, args.emb_size)
        self.model = nn.LSTM(args.emb_size, args.hidden_size, num_layers=args.num_layers,
                             batch_first=True, bidirectional=True)
        # Sequence task classifies the concatenated final hidden states of all
        # layers/directions; token task classifies each step's bidirectional output.
        dim = 2 * args.num_layers * args.hidden_size if args.task == 'seq' else 2 * args.hidden_size
        self.classifier = nn.Linear(dim, args.num_labels)
        self.dropout = nn.Dropout()
        self.args = args
        self.device = None

    def forward(self, x, _):
        x = x.to(self.device)
        x = self.emb(x)
        o, (x, _) = self.model(x)
        o_out = self.classifier(o) if self.args.task == 'token' else None
        if self.args.task == 'seq':
            x = torch.cat([h for h in x], 1)
            x = self.dropout(x)
            x = self.classifier(x)
        return (x, o), o_out

    def generate(self, x, _):
        # Segment long inputs and aggregate per-segment outputs.
        outs = []
        for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)):
            segment = x[:, offset:offset + self.args.max_len]
            if self.args.task == 'seq':
                logits = self(segment, None)[0][0]
                outs.append(logits)
            elif self.args.task == 'token':
                h = self(segment, None)[0][1]
                outs.append(h)
        if self.args.task == 'seq':
            return torch.max(torch.stack(outs, 1), 1).values
        elif self.args.task == 'token':
            h = torch.cat(outs, 1)
            return self.classifier(h)


def load_model(args, device):
    if args.model == 'lstm':
        model = LSTM(args).to(device)
        model.device = device
    elif args.model == 'cnn':
        model = CNN(args).to(device)
        model.device = device
    else:
        model = MyModel(args, AutoModel.from_pretrained(args.model_name)).to(device)

    if args.ckpt:
        model.load_state_dict(torch.load(args.ckpt, map_location=device), strict=False)

    if args.distil:
        # Frozen second model (token-level over UMLS tags) used for distillation.
        args2 = copy.deepcopy(args)
        args2.task = 'token'
        # args2.num_labels = args.num_decs
        args2.num_labels = args.num_umls_tags
        model_B = MyModel(args2, AutoModel.from_pretrained(args.model_name)).to(device)
        model_B.load_state_dict(torch.load(args.distil_ckpt, map_location=device), strict=False)
        for p in model_B.parameters():
            p.requires_grad = False
    else:
        model_B = None

    if args.label_encoding == 'multiclass':
        if args.use_crf:
            crit = CRF(args.num_labels, batch_first=True).to(device)
        else:
            crit = nn.CrossEntropyLoss(reduction='none')
    else:
        crit = nn.BCEWithLogitsLoss(
            pos_weight=torch.ones(args.num_labels).to(device) * args.pos_weight,
            reduction='none',
        )

    optimizer = AdamW(model.parameters(), lr=args.lr)
    # Linear warmup over the first 10% of steps, then linear decay.
    lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                                   int(0.1 * args.total_steps), args.total_steps)
    return model, crit, optimizer, lr_scheduler, model_B
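

# Minimal usage sketch (illustrative only). The attribute names below mirror the
# fields this module reads from `args`; the values are placeholders, not the
# project's actual configuration, which comes from its argument parser.
if __name__ == '__main__':
    from argparse import Namespace

    example_args = Namespace(
        model='lstm', task='seq', vocab_size=30522, emb_size=128,
        hidden_size=256, num_layers=1, num_labels=9, max_len=512,
        ckpt=None, distil=False, label_encoding='multilabel',
        pos_weight=1.0, lr=2e-5, total_steps=1000,
    )
    model, crit, optimizer, lr_scheduler, model_B = load_model(example_args, 'cpu')
    print(model)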