# Copyright (c) Alibaba, Inc. and its affiliates.
import time
from abc import ABC, abstractmethod
from typing import Dict, List, Literal

import numpy as np
import torch
from transformers.trainer_utils import EvalPrediction

from swift.utils import Serializer, get_logger

logger = get_logger()


class Metric(ABC):
    """Base class for stateful metrics: register state with `add_state`, then implement `update`/`compute`."""

    def __init__(self):
        self._default = {}
        self._default_factory = {}

    def add_state(self, name: str, default=None, default_factory=None) -> None:
        """Register a piece of state together with its reset value (`default`) or reset factory (`default_factory`)."""
        if not hasattr(self, '_default'):
            raise AttributeError('Please call super().__init__() first.')
        if default is None:
            assert default_factory is not None, 'Please provide either `default` or `default_factory`.'
            self._default_factory[name] = default_factory
            assert name not in self._default, f'self._default: {self._default}'
            default = default_factory()
        else:
            self._default[name] = default
            assert name not in self._default_factory, f'self._default_factory: {self._default_factory}'
        setattr(self, name, default)

    def reset(self):
        for k, v in self._default.items():
            setattr(self, k, v)
        for k, v in self._default_factory.items():
            setattr(self, k, v())

    @abstractmethod
    def update(self, *args, **kwargs):
        pass

    @abstractmethod
    def compute(self):
        pass


class InferStats(Metric):
    """Tracks prompt/generated token counts per request id and derives throughput statistics."""

    def __init__(self):
        super().__init__()
        self.add_state('start_runtime', default_factory=lambda: time.perf_counter())
        self.add_state('num_prompt_tokens', default_factory=dict)
        self.add_state('num_generated_tokens', default_factory=dict)

    def update(self, output):
        # Keyed by request id, so repeated updates for the same request overwrite rather than double-count.
        id_ = output.id
        self.num_prompt_tokens[id_] = output.usage.prompt_tokens
        self.num_generated_tokens[id_] = output.usage.completion_tokens

    def compute(self):
        # Runtime is measured from instantiation (or the last `reset`).
        runtime = time.perf_counter() - self.start_runtime
        num_samples = len(self.num_generated_tokens)
        num_generated_tokens = sum(self.num_generated_tokens.values())
        return {
            'num_prompt_tokens': sum(self.num_prompt_tokens.values()),
            'num_generated_tokens': num_generated_tokens,
            'num_samples': num_samples,
            'runtime': runtime,
            'samples/s': num_samples / runtime,
            'tokens/s': num_generated_tokens / runtime,
        }
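# Usage sketch for InferStats (illustrative only, not part of this module). `update` merely
# assumes an object exposing `.id` and `.usage.prompt_tokens` / `.usage.completion_tokens`,
# such as a chat-completion response; the SimpleNamespace stand-in below is hypothetical.
#
#     from types import SimpleNamespace
#
#     stats = InferStats()
#     resp = SimpleNamespace(id='req-0', usage=SimpleNamespace(prompt_tokens=12, completion_tokens=48))
#     stats.update(resp)
#     print(stats.compute())  # includes num_samples, runtime, samples/s, tokens/s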
class MeanMetric(Metric):
    """Running mean over accumulated values; returns `nan_value` when nothing has been accumulated."""

    def __init__(self, nan_value=0):
        super().__init__()
        self.nan_value = nan_value
        self.add_state('state', default=0.)
        self.add_state('count', default=0)

    def update(self, state):
        # Accepts a scalar, a list/tuple of scalars, or a 0-/1-D torch.Tensor/np.ndarray.
        if isinstance(state, (torch.Tensor, np.ndarray)):
            state = state.tolist()
        if isinstance(state, (list, tuple)):
            count = len(state)
            state = sum(state)
        else:
            count = 1
        self.state += state
        self.count += count

    def compute(self):
        return {
            'value': self.state / self.count if self.count > 0 else self.nan_value,
        }


def compute_rouge_bleu(preds: List[str], labels: List[str]):
    """Compute corpus-averaged ROUGE-1/2/L and BLEU-4; jieba tokenization handles Chinese text."""
    import jieba
    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
    from rouge.rouge import Rouge

    score_dict = {key: MeanMetric() for key in ['rouge-1', 'rouge-2', 'rouge-l', 'bleu-4']}
    for pred, label in zip(preds, labels):
        hypothesis = list(jieba.cut(pred))
        reference = list(jieba.cut(label))
        if not hypothesis or not reference:
            continue
        rouge = Rouge()
        scores = rouge.get_scores(' '.join(hypothesis), ' '.join(reference))[0]
        for k, v in scores.items():
            score_dict[k].update(v['f'])
        # BLEU-4 is computed at character level, with smoothing to avoid zero scores on short texts.
        bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
        score_dict['bleu-4'].update(bleu_score)
    return {k: round(v.compute()['value'] * 100, 6) for k, v in score_dict.items()}


def compute_nlg_metrics(prediction) -> Dict[str, float]:
    preds, labels = prediction[0], prediction[1]
    new_preds, new_labels = [], []
    for i in range(preds.shape[0]):
        # Recover the original strings from their tensor-serialized form before scoring.
        new_preds.append(Serializer.from_tensor(preds[i]))
        new_labels.append(Serializer.from_tensor(labels[i]))
    return compute_rouge_bleu(new_preds, new_labels)


def compute_acc(preds,
                labels,
                *,
                acc_strategy: Literal['token', 'seq'] = 'token',
                is_encoder_decoder: bool = False) -> Dict[str, List[float]]:
    if isinstance(preds, torch.Tensor):
        if torch.is_floating_point(labels):
            return {}
        preds = preds.cpu().numpy()
        labels = labels.cpu().numpy()
    if preds.ndim >= 2 and not is_encoder_decoder:
        # Decoder-only models predict the next token: shift so preds[..., i] aligns with labels[..., i + 1].
        labels = labels[..., 1:]
        preds = preds[..., :-1]
    if np.issubdtype(labels.dtype, np.floating) or preds.shape != labels.shape:
        return {}

    # Positions labeled -100 (prompt/padding) are excluded from the accuracy.
    masks = labels != -100
    if acc_strategy == 'token' or preds.ndim == 1:
        acc_list = (preds[masks] == labels[masks]).tolist()
    else:
        # 'seq' strategy: a sample counts as correct only if every unmasked token matches.
        acc_list = []
        for i, m in enumerate(masks):
            acc_list.append(np.all(preds[i, m] == labels[i, m]))
    return {f'{acc_strategy}_acc' if preds.ndim >= 2 else 'acc': acc_list}


def compute_acc_metrics(eval_prediction: EvalPrediction,
                        *,
                        acc_strategy: Literal['token', 'seq'] = 'token',
                        is_encoder_decoder: bool = False) -> Dict[str, float]:
    metric = compute_acc(
        eval_prediction.predictions,
        eval_prediction.label_ids,
        acc_strategy=acc_strategy,
        is_encoder_decoder=is_encoder_decoder)
    if len(metric) == 0:
        return {}
    return {k: sum(v) / len(v) for k, v in metric.items()}


def preprocess_logits_for_acc(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Reduce logits to predicted token ids up front, so the trainer need not keep vocab-sized logits around.
    if isinstance(logits, (list, tuple)):
        logits = logits[0]
    preds = logits.argmax(dim=-1)
    return preds


# Add your own metric calculation method here; pass `--metric xxx` during training to select it.
METRIC_MAPPING = {
    'acc': (compute_acc_metrics, preprocess_logits_for_acc),
    'nlg': (compute_nlg_metrics, None),
}


def get_metric(metric: str):
    return METRIC_MAPPING[metric]
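# A minimal smoke test (illustrative sketch, not part of the plugin API): exercises
# MeanMetric accumulation and both compute_acc strategies on synthetic predictions.
if __name__ == '__main__':
    mean = MeanMetric()
    mean.update(torch.tensor([1., 2., 3.]))  # batched update: count += 3
    mean.update(4.)  # scalar update: count += 1
    assert mean.compute()['value'] == 2.5

    # preds[..., i] predicts labels[..., i + 1]; -100 marks ignored (prompt/padding) positions.
    labels = torch.tensor([[10, 20, 30, -100], [10, 20, 30, 40]])
    preds = torch.tensor([[20, 30, 0, 0], [20, 30, 40, 0]])
    print(compute_acc(preds, labels, acc_strategy='token'))  # {'token_acc': [True, True, True, True, True]}
    print(compute_acc(preds, labels, acc_strategy='seq'))  # {'seq_acc': [True, True]}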