# Copyright (c) Alibaba, Inc. and its affiliates.
import time
from abc import ABC, abstractmethod
from typing import Dict, List, Literal

import numpy as np
import torch
from transformers.trainer_utils import EvalPrediction

from swift.utils import Serializer, get_logger

logger = get_logger()


class Metric(ABC):
    """Base class for metrics: accumulate inputs via ``update`` and read results via ``compute``."""

    def __init__(self):
        self._default = {}
        self._default_factory = {}

    def add_state(self, name: str, default=None, default_factory=None) -> None:
        """Register a resettable state attribute from either a constant default or a default_factory."""
        if not hasattr(self, '_default'):
            raise AttributeError('Please call super().__init__() first.')
        if default is None:
            self._default_factory[name] = default_factory
            assert name not in self._default, f'self._default: {self._default}'
            default = default_factory()
        else:
            self._default[name] = default
            assert name not in self._default_factory, f'self._default_factory: {self._default_factory}'
        setattr(self, name, default)

    def reset(self):
        # Restore every registered state to its default (or a freshly constructed default).
        for k, v in self._default.items():
            setattr(self, k, v)
        for k, v in self._default_factory.items():
            setattr(self, k, v())

    @abstractmethod
    def update(self, *args, **kwargs):
        pass

    @abstractmethod
    def compute(self):
        pass


class InferStats(Metric):
    """Collects inference throughput statistics: prompt/generated token counts, samples/s and tokens/s."""

    def __init__(self):
        super().__init__()
        self.add_state('start_runtime', default_factory=lambda: time.perf_counter())
        self.add_state('num_prompt_tokens', default_factory=dict)
        self.add_state('num_generated_tokens', default_factory=dict)

    def update(self, output):
        id_ = output.id
        self.num_prompt_tokens[id_] = output.usage.prompt_tokens
        self.num_generated_tokens[id_] = output.usage.completion_tokens

    def compute(self):
        runtime = time.perf_counter() - self.start_runtime
        num_samples = len(self.num_generated_tokens)
        num_generated_tokens = sum(self.num_generated_tokens.values())
        return {
            'num_prompt_tokens': sum(self.num_prompt_tokens.values()),
            'num_generated_tokens': num_generated_tokens,
            'num_samples': num_samples,
            'runtime': runtime,
            'samples/s': num_samples / runtime,
            'tokens/s': num_generated_tokens / runtime,
        }
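
# Minimal usage sketch for InferStats (``engine`` and ``requests`` are placeholders; the real
# inputs are the chat-completion-style response objects produced by swift's inference engines,
# which expose ``.id`` and ``.usage.prompt_tokens`` / ``.usage.completion_tokens``):
#
#     metrics = InferStats()
#     for response in engine.infer(requests):
#         metrics.update(response)
#     print(metrics.compute())  # {'num_prompt_tokens': ..., 'tokens/s': ..., ...}
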
class MeanMetric(Metric):
    """Tracks a running mean; returns ``nan_value`` when no values have been recorded."""

    def __init__(self, nan_value=0):
        super().__init__()
        self.nan_value = nan_value
        self.add_state('state', default=0.)
        self.add_state('count', default=0)

    def update(self, state: torch.Tensor):
        # Accepts scalars, sequences, numpy arrays or tensors; sequences contribute one count per element.
        if isinstance(state, (torch.Tensor, np.ndarray)):
            state = state.tolist()
        if isinstance(state, (list, tuple)):
            count = len(state)
            state = sum(state)
        else:
            count = 1
        self.state += state
        self.count += count

    def compute(self):
        return {
            'value': self.state / self.count if self.count > 0 else self.nan_value,
        }
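
# MeanMetric usage sketch (values are illustrative):
#
#     acc = MeanMetric()
#     acc.update([1.0, 0.0, 1.0])  # three samples
#     acc.update(1.0)              # one more sample
#     acc.compute()                # {'value': 0.75}
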
def compute_rouge_bleu(preds: List[str], labels: List[str]):
    import jieba
    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
    from rouge.rouge import Rouge

    # Average ROUGE-1/2/L F1 on jieba tokens and character-level sentence BLEU over all
    # (pred, label) pairs; jieba handles both Chinese and space-delimited text.
    score_dict = {key: MeanMetric() for key in ['rouge-1', 'rouge-2', 'rouge-l', 'bleu-4']}
    for pred, label in zip(preds, labels):
        hypothesis = list(jieba.cut(pred))
        reference = list(jieba.cut(label))
        if not hypothesis or not reference:
            continue
        rouge = Rouge()
        scores = rouge.get_scores(' '.join(hypothesis), ' '.join(reference))[0]
        for k, v in scores.items():
            score_dict[k].update(v['f'])
        bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
        score_dict['bleu-4'].update(bleu_score)
    return {k: round(v.compute()['value'] * 100, 6) for k, v in score_dict.items()}
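
# compute_rouge_bleu usage sketch (toy strings; exact scores depend on the jieba/rouge/nltk versions):
#
#     compute_rouge_bleu(['the cat sat'], ['the cat sat on the mat'])
#     # -> {'rouge-1': ..., 'rouge-2': ..., 'rouge-l': ..., 'bleu-4': ...}  (values in [0, 100])
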
def compute_nlg_metrics(prediction) -> Dict[str, float]:
    # ``prediction`` is a (predictions, label_ids) pair of serialized string tensors;
    # Serializer.from_tensor converts each row back to text before ROUGE/BLEU scoring.
    preds, labels = prediction[0], prediction[1]
    new_preds, new_labels = [], []
    for i in range(preds.shape[0]):
        new_preds.append(Serializer.from_tensor(preds[i]))
        new_labels.append(Serializer.from_tensor(labels[i]))
    return compute_rouge_bleu(new_preds, new_labels)


def compute_acc(preds,
                labels,
                *,
                acc_strategy: Literal['token', 'seq'] = 'token',
                is_encoder_decoder: bool = False) -> Dict[str, List[float]]:
    if isinstance(preds, torch.Tensor):
        if torch.is_floating_point(labels):
            return {}
        preds = preds.cpu().numpy()
        labels = labels.cpu().numpy()
    if preds.ndim >= 2 and not is_encoder_decoder:
        # Decoder-only models predict the next token, so shift preds/labels by one position.
        labels = labels[..., 1:]
        preds = preds[..., :-1]
    if np.issubdtype(labels.dtype, np.floating) or preds.shape != labels.shape:
        return {}
    masks = labels != -100
    if acc_strategy == 'token' or preds.ndim == 1:
        # Token-level accuracy: one 0/1 entry per unmasked token.
        acc_list = (preds[masks] == labels[masks]).tolist()
    else:
        # Sequence-level accuracy: one 0/1 entry per sample, requiring every unmasked token to match.
        acc_list = []
        for i, m in enumerate(masks):
            acc_list.append(np.all(preds[i, m] == labels[i, m]))
    return {f'{acc_strategy}_acc' if preds.ndim >= 2 else 'acc': acc_list}


def compute_acc_metrics(eval_prediction: EvalPrediction,
                        *,
                        acc_strategy: Literal['token', 'seq'] = 'token',
                        is_encoder_decoder: bool = False) -> Dict[str, float]:
    metric = compute_acc(
        eval_prediction.predictions,
        eval_prediction.label_ids,
        acc_strategy=acc_strategy,
        is_encoder_decoder=is_encoder_decoder)
    if len(metric) == 0:
        return {}
    # Average the 0/1 accuracy lists from compute_acc into scalar metrics.
    return {k: sum(v) / len(v) for k, v in metric.items()}
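
# compute_acc_metrics usage sketch (toy EvalPrediction; -100 marks ignored label positions, and for
# decoder-only models the arrays are shifted by one position before comparison):
#
#     ep = EvalPrediction(predictions=np.array([[1, 2, 3]]), label_ids=np.array([[-100, 2, 3]]))
#     compute_acc_metrics(ep)                      # {'token_acc': ...}
#     compute_acc_metrics(ep, acc_strategy='seq')  # {'seq_acc': ...}
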
def preprocess_logits_for_acc(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Reduce logits to argmax token ids before they are gathered for metric computation,
    # which keeps evaluation memory bounded.
    if isinstance(logits, (list, tuple)):
        logits = logits[0]
    preds = logits.argmax(dim=-1)
    return preds


# Add your own metric calculation method here and select it with `--metric xxx` when training.
METRIC_MAPPING = {
    'acc': (compute_acc_metrics, preprocess_logits_for_acc),
    'nlg': (compute_nlg_metrics, None),
}


def get_metric(metric: str):
    return METRIC_MAPPING[metric]
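
# To register a custom metric (hypothetical names), add an entry to METRIC_MAPPING and select it
# with ``--metric my_metric``:
#
#     def compute_my_metrics(eval_prediction: EvalPrediction) -> Dict[str, float]:
#         ...
#
#     METRIC_MAPPING['my_metric'] = (compute_my_metrics, None)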