import numpy as np
import torch
from transformers import glue_convert_examples_to_features as convert_examples_to_features
from transformers import InputExample


class MyClassifier:
    """Inference wrapper around an early-exit ALBERT/BERT classifier.

    The wrapped ``model`` must expose its backbone as ``model.albert`` or
    ``model.bert`` (selected by ``model_type``).  That backbone is expected to
    provide ``set_mode``, ``reset_stats``, ``set_patience``,
    ``set_confi_threshold``, ``set_exit_pos``, ``current_exit_layer`` and
    ``config.num_hidden_layers``; these are project-specific extensions that
    are defined outside this file.

    For any other ``model_type`` the configuration methods silently do
    nothing and the getters return ``None``.
    """

    def __init__(self, model, tokenizer, label_list, output_mode, exit_type,
                 exit_value, model_type='albert', max_length=128):
        """Store the collaborators and configure the exit criterion.

        Args:
            model: Classifier whose first output element holds the logits.
            tokenizer: Tokenizer handed to ``convert_examples_to_features``.
            label_list: Label list handed to ``convert_examples_to_features``.
            output_mode: Output mode handed to ``convert_examples_to_features``.
            exit_type: Mode name used by :meth:`get_prob`.  ``'patience'`` and
                ``'confi'`` additionally make ``exit_value`` get applied as the
                patience / confidence threshold.
            exit_value: Value for the criterion selected by ``exit_type``.
            model_type: ``'albert'`` or ``'bert'``.
            max_length: Maximum sequence length used when tokenizing.
        """
        self.model = model
        self.model.eval()
        self.model_type = model_type
        self.tokenizer = tokenizer
        self.label_list = label_list
        self.output_mode = output_mode
        self.max_length = max_length
        self.exit_type = exit_type
        self.exit_value = exit_value
        # Running counter, only used to build a unique guid per example.
        self.count = 0

        self.reset_status(mode='all', stats=True)
        if exit_type == 'patience':
            self.set_patience(patience=exit_value)
        elif exit_type == 'confi':
            self.set_threshold(confidence_threshold=exit_value)

    def _backbone(self):
        """Return ``model.albert`` / ``model.bert``, or None for other types."""
        if self.model_type in ('albert', 'bert'):
            return getattr(self.model, self.model_type)
        return None

    def tokenize(self, input_, idx):
        """Convert one ``(text_a, text_b)`` pair into model input tensors.

        Args:
            input_: Indexable with at least two items; an empty string as the
                second item means "single sentence" (``text_b`` is None).
            idx: Number used to build the example guid ``dev_<idx>``.

        Returns:
            Tuple ``(input_ids, attention_mask, token_type_ids)`` of
            ``torch.long`` tensors, each built from the single example.
        """
        text_a = input_[0]
        text_b = None if input_[1] == "" else input_[1]
        example = InputExample(
            guid=f"dev_{idx}", text_a=text_a, text_b=text_b, label=None
        )
        features = convert_examples_to_features(
            [example],
            self.tokenizer,
            label_list=self.label_list,
            max_length=self.max_length,
            output_mode=self.output_mode,
        )
        all_input_ids = torch.tensor(
            [f.input_ids for f in features], dtype=torch.long
        )
        all_attention_mask = torch.tensor(
            [f.attention_mask for f in features], dtype=torch.long
        )
        all_token_type_ids = torch.tensor(
            [f.token_type_ids for f in features], dtype=torch.long
        )
        return all_input_ids, all_attention_mask, all_token_type_ids

    def set_threshold(self, confidence_threshold):
        """Forward ``confidence_threshold`` to the backbone."""
        backbone = self._backbone()
        if backbone is not None:
            backbone.set_confi_threshold(confidence_threshold)

    def set_patience(self, patience):
        """Forward ``patience`` to the backbone."""
        backbone = self._backbone()
        if backbone is not None:
            backbone.set_patience(patience)

    def set_exit_position(self, exit_pos):
        """Forward ``exit_pos`` to the backbone's ``set_exit_pos``."""
        backbone = self._backbone()
        if backbone is None:
            return
        setter = getattr(backbone, 'set_exit_pos', None)
        if callable(setter):
            # Call the setter like the sibling set_* methods do; a plain
            # assignment would replace the method with the value and leave
            # the exit position unset.
            setter(exit_pos)
        else:
            # The backbone keeps `set_exit_pos` as a plain attribute.
            backbone.set_exit_pos = exit_pos

    def reset_status(self, mode, stats=False):
        """Set the backbone's mode and optionally reset its statistics."""
        backbone = self._backbone()
        if backbone is None:
            return
        backbone.set_mode(mode)
        if stats:
            backbone.reset_stats()

    def get_exit_number(self):
        """Return the backbone's ``config.num_hidden_layers`` (or None)."""
        backbone = self._backbone()
        if backbone is None:
            return None
        return backbone.config.num_hidden_layers

    def get_current_exit(self):
        """Return the backbone's ``current_exit_layer`` (or None)."""
        backbone = self._backbone()
        if backbone is None:
            return None
        return backbone.current_exit_layer

    # TODO: revise the prediction algorithm used to obtain the predictions.
    def get_pred(self, input_):
        """Return the argmax over the last axis of :meth:`get_prob`."""
        # The last axis is used rather than a hard-coded axis index, so this
        # works for both 2-D and 3-D probability arrays.
        return self.get_prob(input_).argmax(axis=-1)

    def _forward(self, sent):
        """Tokenize one input and return the first element of the model output."""
        self.count += 1
        input_ids, attention_mask, token_type_ids = self.tokenize(
            sent, idx=self.count
        )
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            return self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )[0]

    def get_prob(self, input_):
        """Run every input in ``exit_type`` mode and collect probabilities.

        Args:
            input_: Iterable of inputs accepted by :meth:`tokenize`.

        Returns:
            ``np.ndarray`` stacking, per input, the first row of the softmax
            (over dim 1) of the model's first output.
        """
        # Switch the backbone to the configured exit mode, keeping its stats.
        self.reset_status(mode=self.exit_type, stats=False)
        probs = []
        for sent in input_:
            logits = self._forward(sent)
            probs.append(
                torch.softmax(logits, dim=1)[0].detach().cpu().numpy()
            )
        return np.array(probs)

    def get_prob_time(self, input_, exit_position):
        """Run every input in ``'exact'`` mode at ``exit_position``.

        Args:
            input_: Iterable of inputs accepted by :meth:`tokenize`.
            exit_position: Value forwarded to :meth:`set_exit_position`.

        Returns:
            ``np.ndarray`` stacking, per input, one probability vector for
            each element of the model's first output (the first row of that
            element's softmax over dim 1).
        """
        self.reset_status(mode='exact', stats=False)
        self.set_exit_position(exit_position)
        probs = []
        for sent in input_:
            per_exit_logits = self._forward(sent)
            probs.append([
                torch.softmax(logits, dim=1)[0].detach().cpu().numpy()
                for logits in per_exit_logits
            ])
        return np.array(probs)