Spaces:
Sleeping
Sleeping
import multiprocessing

import numpy as np
import tensorflow as tf
from sklearn.dummy import DummyClassifier
from tqdm import tqdm
from transformers import DistilBertTokenizerFast
class PredictProba(DummyClassifier):
    """Scikit-learn-compatible wrapper around a TFLite DistilBERT classifier.

    Subclassing DummyClassifier and exposing `classes_`, `fit` and
    `predict_proba` lets this estimator be dropped into scikit-learn
    meta-estimators (e.g. a calibration classifier) while the actual
    probabilities come from batched TFLite inference.
    """

    def __init__(self, tflite_model_path: str, classes_: list, n_tokens: int):
        # `classes_` is a required attribute for an estimator to be used
        # in a calibration classifier.
        self.classes_ = classes_
        self.n_tokens = n_tokens
        self.tflite_model_path = tflite_model_path

    def fit(self, x, y):
        """No-op fit.

        A `fit` method is required for an estimator to be used in a
        calibration classifier; the TFLite model is already trained.
        """
        print('called fit')
        return self

    def get_token_batches(self, attention_mask, input_ids, batch_size: int = 8):
        """Split token tensors into batches of at most `batch_size` rows.

        FIX: the original signature was missing `self`, but the method is
        invoked as `self.get_token_batches(...)` in `inference`, which
        silently shifted every argument by one (`attention_mask` received
        the instance, `batch_size` received `input_ids`).

        Returns:
            (attention_mask_batches, input_ids_batches): parallel lists of
            slices; the last batch may be smaller than `batch_size`.
        """
        n_texts = len(attention_mask)
        # At least one batch, even when n_texts <= batch_size (or 0),
        # matching the original's explicit n_batches = 1 special case.
        n_batches = max(1, int(np.ceil(n_texts / batch_size)))
        attention_mask_batches = []
        input_ids_batches = []
        for i in range(n_batches):
            start = i * batch_size
            # Last batch takes everything that is left.
            end = None if i == n_batches - 1 else start + batch_size
            attention_mask_batches.append(attention_mask[start:end])
            input_ids_batches.append(input_ids[start:end])
        return attention_mask_batches, input_ids_batches

    def get_batch_inference(self, batch_size, attention_mask, input_ids):
        """Run one batch through the TFLite interpreter.

        A fresh interpreter is built per call so this is safe to execute in
        a separate worker process. Input/output tensors are resized to the
        actual batch size before invocation.

        Returns:
            ndarray of shape (batch_size, len(self.classes_)) with the
            model's output probabilities.
        """
        interpreter = tf.lite.Interpreter(model_path=self.tflite_model_path)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()[0]
        # Resize to this batch's size, then re-allocate before setting data.
        interpreter.resize_tensor_input(input_details[0]['index'], [batch_size, self.n_tokens])
        interpreter.resize_tensor_input(input_details[1]['index'], [batch_size, self.n_tokens])
        interpreter.resize_tensor_input(output_details['index'], [batch_size, len(self.classes_)])
        interpreter.allocate_tensors()
        # NOTE(review): assumes input 0 is the attention mask and input 1 the
        # token ids in this converted model — confirm against the conversion
        # script if the model changes.
        interpreter.set_tensor(input_details[0]["index"], attention_mask)
        interpreter.set_tensor(input_details[1]["index"], input_ids)
        interpreter.invoke()
        return interpreter.get_tensor(output_details["index"])

    def inference(self, texts):
        """Tokenize `texts` and run batched TFLite inference in parallel.

        Batches are dispatched to a process pool (one task per batch) and
        results are concatenated in submission order.

        Returns:
            ndarray of shape (len(texts), len(self.classes_)).
        """
        model_checkpoint = "distilbert-base-uncased"
        tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
        tokens = tokenizer(texts, max_length=self.n_tokens, padding="max_length",
                           truncation=True, return_tensors="tf")
        attention_mask, input_ids = tokens['attention_mask'], tokens['input_ids']
        attention_mask_batches, input_ids_batches = self.get_token_batches(attention_mask, input_ids)
        # FIX: the original never closed the pool, leaking worker processes;
        # the context manager guarantees close/terminate on exit.
        with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
            async_results = [
                pool.apply_async(self.get_batch_inference,
                                 args=(len(mask_batch), mask_batch, ids_batch))
                for mask_batch, ids_batch in zip(attention_mask_batches, input_ids_batches)
            ]
            # .get() inside the `with` block so workers are still alive.
            batch_predictions = [r.get(timeout=360) for r in tqdm(async_results)]
        return np.concatenate(batch_predictions, axis=0)

    def predict_proba(self, X, y=None):
        """Return class probabilities for the texts in `X`."""
        return self.inference(X)