test / whitebox_utils /classifier.py
adamtayzzz's picture
Update whitebox_utils/classifier.py
b84d9de
raw
history blame
No virus
4.74 kB
import numpy as np
import torch
from transformers import glue_convert_examples_to_features as convert_examples_to_features
from transformers import InputExample
class MyClassifier:
    """Wrap an early-exit sequence-classification model behind a text-in, probabilities-out API.

    The wrapped ``model`` is expected to expose its encoder as ``model.albert``
    (``model_type='albert'``) or ``model.bert`` (``model_type='bert'``). That
    encoder object is driven through ``set_mode``, ``reset_stats``,
    ``set_patience``, ``set_confi_threshold``, ``set_exit_pos`` (ALBERT only),
    ``config.num_hidden_layers`` and ``current_exit_layer``. For any other
    ``model_type`` the exit-control methods are silent no-ops and the model is
    simply run as-is.

    Args:
        model: The classification model; put into eval mode on construction.
        tokenizer: Tokenizer passed to ``glue_convert_examples_to_features``.
        label_list: Label list passed to ``glue_convert_examples_to_features``.
        output_mode: Output mode passed to ``glue_convert_examples_to_features``.
        exit_type: Mode name forwarded to the encoder's ``set_mode`` before each
            ``get_prob`` call; ``'patience'`` and ``'confi'`` additionally
            configure the exit criterion from ``exit_value`` at construction.
        exit_value: Value given to ``set_patience`` (for ``'patience'``) or to
            ``set_confi_threshold`` (for ``'confi'``).
        model_type: ``'albert'`` or ``'bert'``; selects the encoder attribute.
        max_length: Maximum sequence length used when building features.
    """

    def __init__(self, model, tokenizer, label_list, output_mode, exit_type, exit_value,
                 model_type='albert', max_length=128):
        self.model = model
        self.model.eval()
        self.model_type = model_type
        self.tokenizer = tokenizer
        self.label_list = label_list
        self.output_mode = output_mode
        self.max_length = max_length
        self.exit_type = exit_type
        self.exit_value = exit_value
        # Running counter that gives every tokenized example a unique guid.
        self.count = 0
        # Start in mode 'all' with the encoder's exit statistics reset.
        self.reset_status(mode='all', stats=True)
        if exit_type == 'patience':
            self.set_patience(patience=exit_value)
        elif exit_type == 'confi':
            self.set_threshold(confidence_threshold=exit_value)

    def _backbone(self):
        """Return the early-exit encoder selected by ``self.model_type``, or ``None`` if unsupported."""
        if self.model_type == 'albert':
            return self.model.albert
        if self.model_type == 'bert':
            return self.model.bert
        return None

    def tokenize(self, input_, idx):
        """Convert one input into model-ready id tensors.

        Args:
            input_: Either a ``(text_a, text_b)`` sequence, where a ``text_b`` of
                ``"<none>"`` (or a missing second element) means a
                single-sentence input, or a plain string used as ``text_a``.
            idx: Number used to build the example guid ``"dev_{idx}"``.

        Returns:
            ``(input_ids, attention_mask, token_type_ids)`` as ``torch.long``
            tensors with one row per generated feature.
        """
        if isinstance(input_, str):
            # A bare string is one sentence; indexing it would split characters.
            text_a, text_b = input_, None
        else:
            text_a = input_[0]
            has_pair = len(input_) > 1 and input_[1] != "<none>"
            text_b = input_[1] if has_pair else None
        examples = [InputExample(guid=f"dev_{idx}", text_a=text_a, text_b=text_b, label=None)]
        features = convert_examples_to_features(
            examples,
            self.tokenizer,
            label_list=self.label_list,
            max_length=self.max_length,
            output_mode=self.output_mode,
        )
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
        return all_input_ids, all_attention_mask, all_token_type_ids

    def set_threshold(self, confidence_threshold):
        """Forward ``confidence_threshold`` to the encoder's ``set_confi_threshold`` (no-op for unknown model types)."""
        backbone = self._backbone()
        if backbone is not None:
            backbone.set_confi_threshold(confidence_threshold)

    def set_patience(self, patience):
        """Forward ``patience`` to the encoder's ``set_patience`` (no-op for unknown model types)."""
        backbone = self._backbone()
        if backbone is not None:
            backbone.set_patience(patience)

    def set_exit_position(self, exit_pos):
        """Forward ``exit_pos`` to the ALBERT encoder's ``set_exit_pos``."""
        # NOTE(review): only the ALBERT path forwards the exit position; for
        # model_type='bert' this is a no-op. Confirm whether the BERT encoder
        # supports set_exit_pos before extending it.
        if self.model_type == 'albert':
            self.model.albert.set_exit_pos(exit_pos)

    def reset_status(self, mode, stats=False):
        """Set the encoder's mode and, when ``stats`` is true, reset its statistics."""
        backbone = self._backbone()
        if backbone is None:
            return
        backbone.set_mode(mode)
        if stats:
            backbone.reset_stats()

    def get_exit_number(self):
        """Return the encoder's ``config.num_hidden_layers``, or ``None`` for unknown model types."""
        backbone = self._backbone()
        return None if backbone is None else backbone.config.num_hidden_layers

    def get_current_exit(self):
        """Return the encoder's ``current_exit_layer``, or ``None`` for unknown model types."""
        backbone = self._backbone()
        return None if backbone is None else backbone.current_exit_layer

    def get_pred(self, input_):
        """Return the predicted label index for each input.

        The argmax is taken over the last axis of ``get_prob``'s result. This
        handles the 2-D ``(num_inputs, num_labels)`` array produced when the
        model returns ``(1, num_labels)`` logits, as well as results with extra
        leading axes. An empty input yields an empty integer array.
        """
        # TODO (translated from the original note): revise the prediction
        # algorithm that derives the predicted labels.
        probs = self.get_prob(input_)
        if probs.size == 0:
            return np.zeros(0, dtype=np.int64)
        return probs.argmax(axis=-1)

    def _predict_probs(self, input_):
        """Run the model on each input one at a time and stack the softmax outputs."""
        first_param = next(self.model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
        ret = []
        # Inference only: avoid building an autograd graph for every forward pass.
        with torch.no_grad():
            for sent in input_:
                self.count += 1
                input_ids, attention_mask, token_type_ids = self.tokenize(sent, idx=self.count)
                inputs = {
                    "input_ids": input_ids.to(device),
                    "attention_mask": attention_mask.to(device),
                    "token_type_ids": token_type_ids.to(device),
                }
                # Assumes the first model output holds the logits as
                # (batch, num_labels); TODO confirm for every exit mode.
                logits = self.model(**inputs)[0]
                ret.append(torch.softmax(logits, dim=1)[0].detach().cpu().numpy())
        return np.array(ret)

    def get_prob(self, input_):
        """Return class probabilities for each input using the configured exit mode.

        Switches the encoder to ``self.exit_type`` (without resetting its exit
        statistics) before running the inputs.

        Returns:
            ``np.ndarray`` stacking one softmax vector per input.
        """
        self.reset_status(mode=self.exit_type, stats=False)
        return self._predict_probs(input_)

    def get_prob_time(self, input_, exit_position):
        """Return class probabilities for each input in ``'exact'`` mode at ``exit_position``.

        The exit position is only forwarded for ``model_type='albert'``.

        Returns:
            ``np.ndarray`` stacking one softmax vector per input.
        """
        self.reset_status(mode='exact', stats=False)
        self.set_exit_position(exit_position)
        return self._predict_probs(input_)