import copy
import math
from types import MethodType

import torch
from mmengine.dist import (barrier, broadcast_object_list, collect_results,
                           get_rank, get_world_size, is_main_process)
from mmengine.model import is_model_wrapper
from mmengine.runner import ValLoop as MMENGINE_ValLoop
from transformers import GenerationConfig

from xtuner.tools.utils import get_stop_criteria, is_cn_string
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
                          PROMPT_TEMPLATE)
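# Map the dtype strings accepted in configs to the corresponding torch dtypes.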
TORCH_DTYPE_MAP = dict(
    fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')


class TestLoop(MMENGINE_ValLoop):
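    """Distributed test loop for multi-dataset evaluation.

    Iterates over every sub-dataset of a ``ConcatDataset``-style
    ``dataloader.dataset``, shards each sub-dataset across ranks, runs
    ``model.predict_forward`` on every sample, gathers the per-rank results,
    evaluates them on the main process and broadcasts the resulting metric
    to all ranks.
    """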

    def __init__(self, runner, dataloader, evaluator=None, torch_dtype='fp16',
                 select_metric='first') -> None:
        # Bypass MMENGINE_ValLoop.__init__ (which requires an evaluator) and
        # initialize the BaseLoop directly; evaluation is delegated to each
        # dataset's own `evaluate()` method.
        super(MMENGINE_ValLoop, self).__init__(runner, dataloader)
        self._runner = runner
        self.torch_dtype = torch_dtype
        if torch_dtype is not None:
            self.torch_dtype = TORCH_DTYPE_MAP[torch_dtype]
        self.select_metric = select_metric

    def run(self) -> dict:
        """Launch the test loop."""
        self.runner.logger.info('==================== Start test loop ===================')
        self.runner.call_hook('before_test')
        self.runner.call_hook('before_test_epoch')

        if is_model_wrapper(self.runner.model):
            model = self.runner.model.module
        else:
            model = self.runner.model

        model.gradient_checkpointing_disable()
        model.eval()
        model.cuda()

        rank = get_rank()
        metrics = []

        current_run_total_ids = 0
        for _, dataset in enumerate(self.dataloader.dataset.datasets):
            if not hasattr(model, 'preparing_for_generation'):
                model.preparing_for_generation = MethodType(
                    default_preparing_for_generation, model)
                self.runner.logger.warning(
                    'The model does not implement preparing_for_generation(); '
                    'falling back to the default implementation.')
            model.preparing_for_generation(dataset.metainfo)

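            # Shard the dataset evenly across ranks: each rank runs inference
            # on its own contiguous slice, and the per-rank results are
            # gathered with `collect_results` after the barrier below.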
            results = []
            n_samples = len(dataset)
            per_rank_samples = math.ceil(n_samples / get_world_size())
            per_rank_ids = range(per_rank_samples * rank,
                                 min(n_samples, per_rank_samples * (rank + 1)))
            for idx in per_rank_ids:
                data_batch = dataset[idx]
                self.run_iter(current_run_total_ids, data_batch, results, model)
                current_run_total_ids += 1

            barrier()
            self.runner.logger.info('==================== Start collecting results ===================')
            results = collect_results(results, len(dataset))
            self.runner.logger.info('==================== Start dataset evaluation ===================')
            if is_main_process():
                metric = dataset.evaluate(results, self.runner.work_dir)
                objects = [metric]
            else:
                objects = [None]
            broadcast_object_list(objects)
            metric = objects[0]
            metrics.append(metric)

        if self.select_metric == 'first':
            metrics = metrics[0]
        else:
            raise NotImplementedError

        self.runner.logger.info('================ Ending test loop ================')
        self.runner.call_hook('after_test_epoch', metrics=metrics)
        self.runner.call_hook('after_test')
        return metrics

    @torch.no_grad()
    def run_iter(self, idx, data_batch, results, model):
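        """Run inference on a single sample and append the prediction."""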
        assert 'text_prompts' in data_batch and 'pixel_values' in data_batch \
            and 'img_id' in data_batch
        prediction = {'img_id': data_batch['img_id']}

        self.runner.call_hook(
            'before_test_iter', batch_idx=idx, data_batch=data_batch)

        outputs = model.predict_forward(**data_batch)
        prediction.update(outputs)
        results.append(prediction)

        self.runner.call_hook(
            'after_test_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)


def default_preparing_for_generation(self, metainfo):
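    """Default fallback bound to models that lack `preparing_for_generation`.

    Sets up the prompt template, stopping criteria and a default
    `GenerationConfig`, optionally overridden by
    `metainfo['generation_kwargs']`.
    """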
    assert hasattr(self, 'tokenizer'), 'The model does not have a tokenizer.'

    self.bot_name = 'BOT'
    template = PROMPT_TEMPLATE['internlm2_chat']
    self.template = template
    stop_words = []
    stop_words += template.get('STOP_WORDS', [])
    stop_criteria = get_stop_criteria(
        tokenizer=self.tokenizer, stop_words=stop_words)
    self.stop_criteria = stop_criteria

    default_generation_kwargs = dict(
        max_new_tokens=2048,
        do_sample=False,
        eos_token_id=self.tokenizer.eos_token_id,
        pad_token_id=(
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else self.tokenizer.eos_token_id
        ),
    )
    default_generation_kwargs.update(metainfo.get('generation_kwargs', {}))
    self.gen_config = GenerationConfig(**default_generation_kwargs)
    return


class AnnoLoop(MMENGINE_ValLoop):
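    """Annotation-generation variant of :class:`TestLoop`.

    Follows the same per-rank sharding and result-collection scheme, but
    keeps the model on its current device, skips `preparing_for_generation`,
    and calls `model.save_step(last=True)` after each dataset so models that
    write annotations incrementally can flush their final outputs.
    """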

    def __init__(self, runner, dataloader, evaluator=None, torch_dtype='fp16',
                 select_metric='first') -> None:
        # Same initialization as TestLoop: bypass MMENGINE_ValLoop.__init__,
        # which requires an evaluator.
        super(MMENGINE_ValLoop, self).__init__(runner, dataloader)
        self._runner = runner
        self.torch_dtype = torch_dtype
        if torch_dtype is not None:
            self.torch_dtype = TORCH_DTYPE_MAP[torch_dtype]
        self.select_metric = select_metric

    def run(self) -> dict:
        """Launch the test loop."""
        self.runner.logger.info('==================== Start test loop ===================')
        self.runner.call_hook('before_test')
        self.runner.call_hook('before_test_epoch')

        if is_model_wrapper(self.runner.model):
            model = self.runner.model.module
        else:
            model = self.runner.model

        model.eval()

        rank = get_rank()
        metrics = []

        current_run_total_ids = 0
        for _, dataset in enumerate(self.dataloader.dataset.datasets):
            results = []
            n_samples = len(dataset)
            per_rank_samples = math.ceil(n_samples / get_world_size())
            per_rank_ids = range(per_rank_samples * rank,
                                 min(n_samples, per_rank_samples * (rank + 1)))
            for idx in per_rank_ids:
                data_batch = dataset[idx]
                self.run_iter(current_run_total_ids, data_batch, results, model)
                current_run_total_ids += 1
            if hasattr(model, 'save_step'):
                # Let models that stream annotations to disk flush their
                # remaining buffered outputs.
                model.save_step(last=True)

            barrier()
            self.runner.logger.info('==================== Start collecting results ===================')
            results = collect_results(results, len(dataset))
            self.runner.logger.info('==================== Start dataset evaluation ===================')
            if is_main_process():
                metric = dataset.evaluate(results, self.runner.work_dir)
                objects = [metric]
            else:
                objects = [None]
            broadcast_object_list(objects)
            metric = objects[0]
            metrics.append(metric)

        if self.select_metric == 'first':
            metrics = metrics[0]
        else:
            raise NotImplementedError

        self.runner.logger.info('================ Ending test loop ================')
        self.runner.call_hook('after_test_epoch', metrics=metrics)
        self.runner.call_hook('after_test')
        return metrics

    @torch.no_grad()
    def run_iter(self, idx, data_batch, results, model):
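        """Run inference on a single sample and append the prediction.

        Unlike ``TestLoop.run_iter``, no particular keys are required in
        ``data_batch`` and no ``img_id`` is recorded.
        """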
        prediction = {}

        self.runner.call_hook(
            'before_test_iter', batch_idx=idx, data_batch=data_batch)

        outputs = model.predict_forward(**data_batch)
        prediction.update(outputs)
        results.append(prediction)

        self.runner.call_hook(
            'after_test_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)
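

# Usage sketch (hypothetical): in an mmengine/xtuner-style config these loops
# are typically wired in through `test_cfg`. The dataloader and evaluator
# entries below are placeholders, not part of this module.
#
#     test_dataloader = dict(
#         batch_size=1,
#         num_workers=0,
#         dataset=dict(type=ConcatDataset, datasets=[...]),
#     )
#     test_cfg = dict(type=TestLoop, select_metric='first')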