AutoQuantNX / src /handlers /nlp_models /token_classification_handler.py
smokxy's picture
Upload folder using huggingface_hub
9bf1d31 verified
from ..base_handler import ModelHandler
from transformers import AutoTokenizer
import torch
import time
class TokenClassificationHandler(ModelHandler):
def __init__(self, model_name, model_class, quantization_type, test_text):
super().__init__(model_name, model_class, quantization_type, test_text)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
def run_inference(self, model, text):
inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(self.device)
start_time = time.time()
with torch.no_grad():
outputs = model(**inputs)
end_time = time.time()
return outputs, end_time - start_time
def decode_output(self, model, outputs):
tokens = self.tokenizer.convert_ids_to_tokens(outputs['input_ids'][0])
labels = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()
decoded_labels = [model.config.id2label[label] for label in labels]
return dict(zip(tokens, decoded_labels))
def compare_outputs(self, original_outputs, quantized_outputs):
"""Compare outputs for token classification models"""
if original_outputs is None or quantized_outputs is None:
return None
orig_logits = original_outputs.logits.cpu().numpy()
quant_logits = quantized_outputs.logits.cpu().numpy()
orig_preds = orig_logits.argmax(axis=-1)
quant_preds = quant_logits.argmax(axis=-1)
input_tokens = self.tokenizer.convert_ids_to_tokens(
self.tokenizer(self.test_text, return_tensors='pt')['input_ids'][0]
)
orig_labels = [self.original_model.config.id2label[p] for p in orig_preds[0]]
quant_labels = [self.quantized_model.config.id2label[p] for p in quant_preds[0]]
original_results = list(zip(input_tokens, orig_labels))
quantized_results = list(zip(input_tokens, quant_labels))
token_matches = sum(o_label == q_label for o_label, q_label in zip(orig_labels, quant_labels))
total_tokens = len(orig_labels)
metrics = {
'original_predictions': original_results,
'quantized_predictions': quantized_results,
'token_level_accuracy': float(token_matches) / total_tokens if total_tokens > 0 else 0.0,
'sequence_exact_match': float((orig_preds == quant_preds).all()),
'logits_mse': ((orig_logits - quant_logits) ** 2).mean(),
}
return metrics