import time
from abc import ABC, abstractmethod
from typing import Dict, List, Literal

import numpy as np
import torch
from transformers.trainer_utils import EvalPrediction

from swift.utils import Serializer, get_logger

logger = get_logger()

class Metric(ABC):
    """Base class for streaming metrics: accumulate with `update`, aggregate with `compute`."""

    def __init__(self):
        self._default = {}
        self._default_factory = {}

    def add_state(self, name: str, default=None, default_factory=None) -> None:
        """Register a resettable state attribute.

        Use `default` for immutable values and `default_factory` for mutable ones
        (e.g. `dict`), so that `reset` produces a fresh object instead of sharing it.
        """
        if not hasattr(self, '_default'):
            raise AttributeError('Please call super().__init__() first.')
        if default is None:
            assert default_factory is not None, 'Either `default` or `default_factory` is required.'
            assert name not in self._default, f'self._default: {self._default}'
            self._default_factory[name] = default_factory
            default = default_factory()
        else:
            assert name not in self._default_factory, f'self._default_factory: {self._default_factory}'
            self._default[name] = default
        setattr(self, name, default)

    def reset(self):
        for k, v in self._default.items():
            setattr(self, k, v)
        for k, v in self._default_factory.items():
            setattr(self, k, v())

    @abstractmethod
    def update(self, *args, **kwargs):
        pass

    @abstractmethod
    def compute(self):
        pass
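

# A minimal subclass sketch showing the intended lifecycle: register state in
# `__init__`, accumulate in `update`, aggregate in `compute` (illustrative only):
#
#     class SumMetric(Metric):
#         def __init__(self):
#             super().__init__()
#             self.add_state('total', default=0)
#
#         def update(self, value):
#             self.total += value
#
#         def compute(self):
#             return {'total': self.total}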


class InferStats(Metric):
    """Collect per-request token usage and report aggregate throughput."""

    def __init__(self):
        super().__init__()
        self.add_state('start_runtime', default_factory=lambda: time.perf_counter())
        self.add_state('num_prompt_tokens', default_factory=dict)
        self.add_state('num_generated_tokens', default_factory=dict)

    def update(self, output):
        id_ = output.id
        self.num_prompt_tokens[id_] = output.usage.prompt_tokens
        self.num_generated_tokens[id_] = output.usage.completion_tokens

    def compute(self):
        runtime = time.perf_counter() - self.start_runtime
        num_samples = len(self.num_generated_tokens)
        num_generated_tokens = sum(self.num_generated_tokens.values())
        return {
            'num_prompt_tokens': sum(self.num_prompt_tokens.values()),
            'num_generated_tokens': num_generated_tokens,
            'num_samples': num_samples,
            'runtime': runtime,
            'samples/s': num_samples / runtime,
            'tokens/s': num_generated_tokens / runtime,
        }
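

# Usage sketch (assumes each `resp` exposes `.id` and `.usage` with
# `prompt_tokens`/`completion_tokens`, as swift's inference responses do):
#
#     stats = InferStats()
#     for resp in responses:
#         stats.update(resp)
#     stats.compute()  # {'num_samples': ..., 'runtime': ..., 'tokens/s': ..., ...}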


class MeanMetric(Metric):
    """Running mean over scalars, sequences, or tensors.

    `nan_value` is returned by `compute` if no values have been accumulated.
    """

    def __init__(self, nan_value=0):
        super().__init__()
        self.nan_value = nan_value
        self.add_state('state', default=0.)
        self.add_state('count', default=0)

    def update(self, state):
        # Tensors/arrays become Python scalars or lists; sequences contribute
        # their length to the count, plain scalars contribute 1.
        if isinstance(state, (torch.Tensor, np.ndarray)):
            state = state.tolist()

        if isinstance(state, (list, tuple)):
            count = len(state)
            state = sum(state)
        else:
            count = 1

        self.state += state
        self.count += count

    def compute(self):
        return {
            'value': self.state / self.count if self.count > 0 else self.nan_value,
        }
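

# Usage sketch: scalars, sequences, and tensors all fold into one running mean.
#
#     m = MeanMetric()
#     m.update(torch.tensor([1., 3.]))  # state=4., count=2
#     m.update(5.)                      # state=9., count=3
#     m.compute()                       # {'value': 3.0}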


def compute_rouge_bleu(preds: List[str], labels: List[str]) -> Dict[str, float]:
    import jieba
    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
    from rouge.rouge import Rouge
    score_dict = {key: MeanMetric() for key in ['rouge-1', 'rouge-2', 'rouge-l', 'bleu-4']}
    rouge = Rouge()

    for pred, label in zip(preds, labels):
        hypothesis = list(jieba.cut(pred))
        reference = list(jieba.cut(label))
        if not hypothesis or not reference:
            continue
        # ROUGE is computed on jieba word tokens; BLEU-4 on raw characters.
        scores = rouge.get_scores(' '.join(hypothesis), ' '.join(reference))[0]
        for k, v in scores.items():
            score_dict[k].update(v['f'])
        bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
        score_dict['bleu-4'].update(bleu_score)

    return {k: round(v.compute()['value'] * 100, 6) for k, v in score_dict.items()}
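

# Usage sketch (needs the third-party `jieba`, `nltk` and `rouge` packages):
#
#     compute_rouge_bleu(['the cat sat'], ['the cat sat down'])
#     # -> {'rouge-1': ..., 'rouge-2': ..., 'rouge-l': ..., 'bleu-4': ...}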


def compute_nlg_metrics(prediction: EvalPrediction) -> Dict[str, float]:
    preds, labels = prediction.predictions, prediction.label_ids
    new_preds, new_labels = [], []
    for i in range(preds.shape[0]):
        new_preds.append(Serializer.from_tensor(preds[i]))
        new_labels.append(Serializer.from_tensor(labels[i]))
    return compute_rouge_bleu(new_preds, new_labels)
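

# Note (assumption): `Serializer.from_tensor` from swift.utils is expected to
# recover the strings that were packed into tensors for the eval loop, so the
# ROUGE/BLEU scores above are computed on text rather than on raw token ids.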


def compute_acc(preds,
                labels,
                *,
                acc_strategy: Literal['token', 'seq'] = 'token',
                is_encoder_decoder: bool = False) -> Dict[str, List[float]]:
    if isinstance(preds, torch.Tensor):
        if torch.is_floating_point(labels):
            return {}
        preds = preds.cpu().numpy()
        labels = labels.cpu().numpy()
    if preds.ndim >= 2 and not is_encoder_decoder:
        # Decoder-only models predict the next token, so align preds[i] with labels[i + 1].
        labels = labels[..., 1:]
        preds = preds[..., :-1]
    if np.issubdtype(labels.dtype, np.floating) or preds.shape != labels.shape:
        return {}

    masks = labels != -100
    if acc_strategy == 'token' or preds.ndim == 1:
        acc_list = (preds[masks] == labels[masks]).tolist()
    else:
        # 'seq': a sequence counts as correct only if every unmasked token matches.
        acc_list = []
        for i, m in enumerate(masks):
            acc_list.append(np.all(preds[i, m] == labels[i, m]))
    return {(f'{acc_strategy}_acc' if preds.ndim >= 2 else 'acc'): acc_list}
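

# Worked example (decoder-only, so preds/labels are shifted by one position):
#
#     preds = torch.tensor([[5, 6, 7, 0]])
#     labels = torch.tensor([[-100, 5, 6, 9]])
#     # after shifting: preds -> [[5, 6, 7]], labels -> [[5, 6, 9]]
#     compute_acc(preds, labels)                      # {'token_acc': [True, True, False]}
#     compute_acc(preds, labels, acc_strategy='seq')  # {'seq_acc': [False]}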


def compute_acc_metrics(eval_prediction: EvalPrediction,
                        *,
                        acc_strategy: Literal['token', 'seq'] = 'token',
                        is_encoder_decoder: bool = False) -> Dict[str, float]:
    metric = compute_acc(
        eval_prediction.predictions,
        eval_prediction.label_ids,
        acc_strategy=acc_strategy,
        is_encoder_decoder=is_encoder_decoder)
    if len(metric) == 0:
        return {}
    # Reduce each per-token/per-sequence boolean list to its mean accuracy.
    return {k: sum(v) / len(v) for k, v in metric.items()}
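

# e.g. a compute_acc result of {'token_acc': [True, True, False]} reduces to
# {'token_acc': 0.666...}, since booleans sum as 0/1.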


def preprocess_logits_for_acc(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    if isinstance(logits, (list, tuple)):
        # Some models return a tuple (logits, ...); keep only the logits.
        logits = logits[0]
    preds = logits.argmax(dim=-1)
    return preds
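

# This matches the signature of the HF Trainer's `preprocess_logits_for_metrics`
# hook: reducing logits to argmax token ids up front avoids accumulating
# vocab-sized logit tensors over the whole eval set.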


# Maps a metric name to a (compute_metrics, preprocess_logits_for_metrics) pair;
# the second element may be None when no logits preprocessing is needed.
METRIC_MAPPING = {
    'acc': (compute_acc_metrics, preprocess_logits_for_acc),
    'nlg': (compute_nlg_metrics, None),
}


def get_metric(metric: str):
    return METRIC_MAPPING[metric]
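

# Usage sketch: wiring a metric pair into a HF `Trainer` (illustrative only):
#
#     compute_metrics, preprocess_logits_for_metrics = get_metric('acc')
#     trainer = Trainer(
#         model=model,
#         args=training_args,
#         eval_dataset=eval_dataset,
#         compute_metrics=compute_metrics,
#         preprocess_logits_for_metrics=preprocess_logits_for_metrics,
#     )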