from ..base_handler import ModelHandler
from transformers import AutoTokenizer
import torch
import time


class TokenClassificationHandler(ModelHandler):
    """Handler for token classification models (e.g. NER), comparing an
    original model against its quantized counterpart."""

    def __init__(self, model_name, model_class, quantization_type, test_text):
        super().__init__(model_name, model_class, quantization_type, test_text)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def run_inference(self, model, text):
        """Run a single forward pass and return (outputs, latency in seconds)."""
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(self.device)
        start_time = time.time()
        with torch.no_grad():
            outputs = model(**inputs)
        end_time = time.time()
        return outputs, end_time - start_time

    def decode_output(self, model, outputs):
        # The model output only carries logits, so re-tokenize the test text
        # (as compare_outputs does) to recover the tokens the predictions refer to.
        input_ids = self.tokenizer(self.test_text, return_tensors='pt', truncation=True)['input_ids'][0]
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
        labels = torch.argmax(outputs.logits, dim=-1).squeeze(0).tolist()
        decoded_labels = [model.config.id2label[label] for label in labels]
        return dict(zip(tokens, decoded_labels))

    def compare_outputs(self, original_outputs, quantized_outputs):
        """Compare outputs for token classification models."""
        if original_outputs is None or quantized_outputs is None:
            return None
        orig_logits = original_outputs.logits.cpu().numpy()
        quant_logits = quantized_outputs.logits.cpu().numpy()
        orig_preds = orig_logits.argmax(axis=-1)
        quant_preds = quant_logits.argmax(axis=-1)
        # Tokenize with the same truncation setting as run_inference so the
        # token list lines up with the predictions.
        input_tokens = self.tokenizer.convert_ids_to_tokens(
            self.tokenizer(self.test_text, return_tensors='pt', truncation=True)['input_ids'][0]
        )
        orig_labels = [self.original_model.config.id2label[p] for p in orig_preds[0]]
        quant_labels = [self.quantized_model.config.id2label[p] for p in quant_preds[0]]
        original_results = list(zip(input_tokens, orig_labels))
        quantized_results = list(zip(input_tokens, quant_labels))
        # Per-token agreement between the original and quantized predictions.
        token_matches = sum(o_label == q_label for o_label, q_label in zip(orig_labels, quant_labels))
        total_tokens = len(orig_labels)
        metrics = {
            'original_predictions': original_results,
            'quantized_predictions': quantized_results,
            'token_level_accuracy': float(token_matches) / total_tokens if total_tokens > 0 else 0.0,
            'sequence_exact_match': float((orig_preds == quant_preds).all()),
            'logits_mse': float(((orig_logits - quant_logits) ** 2).mean()),
        }
        return metrics
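

# Usage sketch (not part of the handler above): this assumes the ModelHandler base
# class accepts exactly these four constructor arguments and exposes self.device,
# as the methods above rely on. The checkpoint id, quantization_type value, and
# test sentence are illustrative placeholders, not values taken from this repo.
if __name__ == "__main__":
    from transformers import AutoModelForTokenClassification

    model_name = "dslim/bert-base-NER"  # hypothetical example checkpoint
    handler = TokenClassificationHandler(
        model_name=model_name,
        model_class=AutoModelForTokenClassification,
        quantization_type="dynamic",  # placeholder; valid values depend on the base handler
        test_text="Hugging Face is based in New York City.",
    )
    # Load a model explicitly just to exercise run_inference/decode_output; the base
    # handler may already manage self.original_model / self.quantized_model itself.
    model = AutoModelForTokenClassification.from_pretrained(model_name).to(handler.device)
    outputs, latency = handler.run_inference(model, handler.test_text)
    print(f"Latency: {latency:.4f}s")
    print(handler.decode_output(model, outputs))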