|
import numpy as np |
|
import torch |
|
|
|
from transformers import glue_convert_examples_to_features as convert_examples_to_features |
|
from transformers import InputExample |
|
|
|
class MyClassifier():
    """Inference wrapper around a multi-exit ALBERT/BERT sequence classifier.

    The wrapper tokenizes raw sentence pairs one at a time, forwards them
    through ``model`` and returns class probabilities as NumPy arrays. It
    also forwards exit-related settings (mode, patience, confidence
    threshold, exit position) to the encoder sub-module selected by
    ``model_type`` (``model.albert`` or ``model.bert``).

    Args:
        model: Model exposing an ``albert`` or ``bert`` sub-module with the
            hooks used below (``set_mode``, ``reset_stats``, ...). It is put
            into eval mode on construction.
        tokenizer: Tokenizer handed to ``convert_examples_to_features``.
        label_list: Label list handed to ``convert_examples_to_features``.
        output_mode: Output mode handed to ``convert_examples_to_features``.
        exit_type: Mode passed to the encoder's ``set_mode`` before each
            ``get_prob`` call; ``'patience'`` and ``'confi'`` additionally
            make the constructor forward ``exit_value`` to the encoder.
        exit_value: Patience (for ``'patience'``) or confidence threshold
            (for ``'confi'``).
        model_type: ``'albert'`` or ``'bert'``. Any other value turns every
            encoder hook into a silent no-op.
        max_length: Maximum sequence length used during tokenization.
    """

    def __init__(self, model, tokenizer, label_list, output_mode, exit_type,
                 exit_value, model_type='albert', max_length=128):
        self.model = model
        self.model.eval()
        self.model_type = model_type
        self.tokenizer = tokenizer
        self.label_list = label_list
        self.output_mode = output_mode
        self.max_length = max_length
        self.exit_type = exit_type
        self.exit_value = exit_value
        # Running number of processed sentences; only used to build guids.
        self.count = 0
        self.reset_status(mode='all', stats=True)
        if exit_type == 'patience':
            self.set_patience(patience=exit_value)
        elif exit_type == 'confi':
            self.set_threshold(confidence_threshold=exit_value)

    def _backbone(self):
        """Return the encoder sub-module for ``model_type``, or ``None``.

        ``None`` is returned for unsupported model types so that callers can
        keep the original "silently do nothing" behaviour.
        """
        if self.model_type in ('albert', 'bert'):
            return getattr(self.model, self.model_type)
        return None

    def tokenize(self, input_, idx):
        """Convert one ``(text_a, text_b)`` pair into model input tensors.

        Args:
            input_: Indexable pair; ``input_[1] == "<none>"`` marks a
                single-sentence input (``text_b`` becomes ``None``).
            idx: Number used to build the example guid.

        Returns:
            Tuple ``(input_ids, attention_mask, token_type_ids)`` of
            ``torch.long`` tensors built from the single converted example.
        """
        text_a, text_b = input_[0], input_[1]
        if text_b == "<none>":
            text_b = None

        examples = [
            InputExample(guid=f"dev_{idx}", text_a=text_a, text_b=text_b, label=None)
        ]
        features = convert_examples_to_features(
            examples,
            self.tokenizer,
            label_list=self.label_list,
            max_length=self.max_length,
            output_mode=self.output_mode,
        )

        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
        return all_input_ids, all_attention_mask, all_token_type_ids

    def set_threshold(self, confidence_threshold):
        """Forward the confidence threshold to the encoder."""
        backbone = self._backbone()
        if backbone is not None:
            backbone.set_confi_threshold(confidence_threshold)

    def set_patience(self, patience):
        """Forward the patience value to the encoder."""
        backbone = self._backbone()
        if backbone is not None:
            backbone.set_patience(patience)

    def set_exit_position(self, exit_pos):
        """Tell the ALBERT encoder which exit position to use.

        The original code assigned ``exit_pos`` to ``albert.set_exit_pos``,
        which replaces a setter method (if one exists) instead of calling it.
        The setter is now called when it is callable; otherwise the plain
        attribute assignment is kept, so both encoder designs keep working.

        NOTE(review): only ``'albert'`` is handled, as in the original; with
        ``model_type='bert'`` this is a no-op - confirm whether the BERT
        encoder exposes an equivalent hook.
        """
        if self.model_type != 'albert':
            return
        backbone = self.model.albert
        setter = getattr(backbone, 'set_exit_pos', None)
        if callable(setter):
            setter(exit_pos)
        else:
            backbone.set_exit_pos = exit_pos

    def reset_status(self, mode, stats=False):
        """Set the encoder mode and optionally reset its statistics."""
        backbone = self._backbone()
        if backbone is None:
            return
        backbone.set_mode(mode)
        if stats:
            backbone.reset_stats()

    def get_exit_number(self):
        """Return the encoder's ``config.num_hidden_layers`` (``None`` if unsupported)."""
        backbone = self._backbone()
        if backbone is None:
            return None
        return backbone.config.num_hidden_layers

    def get_current_exit(self):
        """Return the encoder's ``current_exit_layer`` (``None`` if unsupported)."""
        backbone = self._backbone()
        if backbone is None:
            return None
        return backbone.current_exit_layer

    def get_pred(self, input_):
        """Return the arg-max class index for every input pair.

        The class axis is the last axis of what ``get_prob`` returns; the
        original hard-coded ``axis=2``, which is out of bounds for the 2-D
        array that ``get_prob`` builds from per-sentence probability vectors.
        """
        return self.get_prob(input_).argmax(axis=-1)

    def _forward(self, sent):
        """Tokenize one input pair and return the model's first output."""
        self.count += 1
        input_ids, attention_mask, token_type_ids = self.tokenize(sent, idx=self.count)
        # Inference only: results are converted to NumPy, so no autograd
        # graph is needed.
        with torch.no_grad():
            return self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )[0]

    def get_prob(self, input_):
        """Return softmax probabilities for every input pair.

        The encoder is switched to ``self.exit_type`` mode first. Each input
        is forwarded on its own and the first row of the softmax over
        ``dim=1`` of the model's first output is collected.

        Returns:
            ``np.ndarray`` stacking one probability vector per input.
        """
        self.reset_status(mode=self.exit_type, stats=False)
        probs = [
            torch.softmax(self._forward(sent), dim=1)[0].detach().cpu().numpy()
            for sent in input_
        ]
        return np.array(probs)

    def get_prob_time(self, input_, exit_position):
        """Return softmax probabilities in ``'exact'`` mode.

        The encoder is switched to ``'exact'`` mode and given
        ``exit_position``. The model's first output is treated as an iterable
        of logits tensors, each of which is soft-maxed over ``dim=1``.

        Returns:
            ``np.ndarray`` stacking, per input, one probability vector for
            each element of the model's first output.
        """
        self.reset_status(mode='exact', stats=False)
        self.set_exit_position(exit_position)
        ret = []
        for sent in input_:
            per_exit_logits = self._forward(sent)
            ret.append([
                torch.softmax(logits, dim=1)[0].detach().cpu().numpy()
                for logits in per_exit_logits
            ])
        return np.array(ret)