# HTK/retro_reader/base.py
import os
import gc
import time
import json
import math
import collections
from datetime import datetime
from typing import Optional, List, Dict, Tuple, Callable, Any, Union
import torch
import numpy as np
from transformers import (
    is_datasets_available,
    is_torch_xla_available,
)
from transformers.trainer_utils import (
PredictionOutput,
EvalPrediction,
EvalLoopOutput,
denumpify_detensorize,
speed_metrics,
)
from transformers.utils import logging
from transformers.debug_utils import DebugOption
if is_datasets_available():
import datasets
# if is_torch_xla_available():
# import torch_xla.core.xla_model as xm # type: ignore
# import torch_xla.debug.metrics as met # type: ignore
from transformers import Trainer
logger = logging.get_logger(__name__)
class ToMixin:
    """Mixin that moves optimizer and scheduler state between devices."""
    def _optimizer_to(self, device: str = "cpu"):
        """
        Move the optimizer state to the specified device.

        Args:
            device (str, optional): The device to move the optimizer state to. Defaults to "cpu".
        """
        for param in self.optimizer.state.values():
            if isinstance(param, torch.Tensor):
                param.data = param.data.to(device)
                if param._grad is not None:
                    param._grad.data = param._grad.data.to(device)
            elif isinstance(param, dict):
                # Per-parameter optimizer state is usually a dict of tensors
                # (e.g. Adam's exp_avg / exp_avg_sq), so move each entry.
                for subparam in param.values():
                    if isinstance(subparam, torch.Tensor):
                        subparam.data = subparam.data.to(device)
                        if subparam._grad is not None:
                            subparam._grad.data = subparam._grad.data.to(device)
    def _scheduler_to(self, device: str = "cpu") -> None:
        """
        Move the scheduler state to the specified device.

        Args:
            device (str, optional): The device to move the scheduler state to. Defaults to "cpu".

        Returns:
            None
        """
        for param in self.lr_scheduler.__dict__.values():
            if isinstance(param, torch.Tensor):
                param.data = param.data.to(device)
                if param._grad is not None:
                    param._grad.data = param._grad.data.to(device)
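# Usage sketch for the mixin (hedged; the subclass name is hypothetical): any
# Trainer subclass that also inherits ToMixin can offload optimizer/scheduler
# state, e.g. before checkpointing:
#
#     class MyReader(Trainer, ToMixin):
#         ...
#
#     reader = MyReader(model=model, args=training_args, train_dataset=train_ds)
#     reader.train()
#     reader._optimizer_to("cpu")  # optimizer tensors now live on the CPU
#     reader._scheduler_to("cpu")  # scheduler tensors follow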
class BaseReader(Trainer, ToMixin):
    name: Optional[str] = None

    def __init__(
        self,
        *args,  # Passed to Trainer.__init__
        data_args: Optional[Dict[str, Any]] = None,  # Additional arguments for data loading
        eval_examples: Optional["datasets.Dataset"] = None,  # Evaluation examples
        **kwargs,  # Passed to Trainer.__init__
    ):
        """
        Initializes the BaseReader.

        Args:
            *args: Positional arguments passed to Trainer.__init__.
            data_args (Optional[Dict[str, Any]]): Additional arguments for data loading. Defaults to None, stored as an empty dict.
            eval_examples (Optional[datasets.Dataset]): Evaluation examples. Defaults to None.
            **kwargs: Keyword arguments passed to Trainer.__init__.
        """
        # Call the parent class's __init__ method with the given arguments
        super().__init__(*args, **kwargs)
        # Avoid a shared mutable default for data_args
        self.data_args = data_args if data_args is not None else {}
        # Keep the unprocessed evaluation examples for postprocessing
        self.eval_examples = eval_examples
def free_memory(self):
"""
Move the model, optimizer and scheduler state to the CPU, empty the CUDA cache and garbage collect.
This method is useful to free up GPU memory before checkpointing the model or saving it to disk.
"""
self.model.to("cpu")
self._optimizer_to("cpu")
self._scheduler_to("cpu")
torch.cuda.empty_cache()
gc.collect()
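    # Usage sketch (hedged; both reader names are hypothetical): call
    # free_memory between stages of a multi-model pipeline so the next model
    # gets the whole GPU:
    #
    #     sketch_reader.train()
    #     sketch_reader.free_memory()  # everything back on the CPU, cache cleared
    #     intensive_reader.train()    # second-stage reader now has full GPU memory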
    def postprocess(
        self,
        output: EvalLoopOutput,
        eval_examples: Optional["datasets.Dataset"] = None,
        eval_dataset: Optional["datasets.Dataset"] = None,
        mode: str = "evaluate",
    ) -> Union[Any, PredictionOutput]:
        """
        Postprocess the evaluation loop output.

        This method is called after the evaluation loop has finished and before the evaluation metrics are computed.
        It receives the output of the evaluation loop and can be used to modify it before it is passed to the compute_metrics function.
        The base implementation returns the output unchanged; subclasses override it to map raw model outputs back to the original examples.

        Args:
            output (EvalLoopOutput): The output of the evaluation loop.
            eval_examples (Optional[datasets.Dataset]): The unprocessed examples the predictions refer to.
            eval_dataset (Optional[datasets.Dataset]): The tokenized dataset the loop ran on.
            mode (str): Either "evaluate" or "predict".

        Returns:
            Union[Any, PredictionOutput]: The modified output that will be passed to the compute_metrics function.
        """
        return output
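    # Override sketch (hedged; my_decode_spans is a hypothetical helper): a
    # span-extraction subclass would map start/end logits back to answer text:
    #
    #     class ExtractiveReader(BaseReader):
    #         def postprocess(self, output, eval_examples=None, eval_dataset=None, mode="evaluate"):
    #             predictions = my_decode_spans(eval_examples, eval_dataset, output.predictions)
    #             if mode == "predict":
    #                 return predictions
    #             references = [{"id": ex["id"], "answers": ex["answers"]} for ex in eval_examples]
    #             return EvalPrediction(predictions=predictions, label_ids=references)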
def evaluate(
self,
eval_dataset: Optional[datasets.Dataset] = None,
eval_examples: Optional[datasets.Dataset] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> Dict[str, float]:
"""
Evaluate the model on the given dataset.
Args:
eval_dataset (Optional[datasets.Dataset], optional): The evaluation dataset. Defaults to None.
eval_examples (Optional[datasets.Dataset], optional): The evaluation examples. Defaults to None.
ignore_keys (Optional[List[str]], optional): Keys to ignore when calculating metrics. Defaults to None.
metric_key_prefix (str, optional): The prefix for metric keys. Defaults to "eval".
Returns:
Dict[str, float]: The evaluation metrics.
"""
# Start tracking memory usage
self._memory_tracker.start()
# Set eval_dataset and eval_dataloader
eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
eval_dataloader = self.get_eval_dataloader(eval_dataset)
# Set eval_examples
eval_examples = self.eval_examples if eval_examples is None else eval_examples
# Start timing
start_time = time.time()
        # Temporarily disable compute_metrics so the evaluation loop returns
        # raw predictions; metrics are computed after postprocessing below
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
# Set eval_loop
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
try:
# Run evaluation loop
output = eval_loop(
eval_dataloader,
description="Evaluation",
# Only gather predictions if there are metrics to compute
prediction_loss_only=True if compute_metrics is None else None,
ignore_keys=ignore_keys,
metric_key_prefix=metric_key_prefix,
)
finally:
# Restore compute_metrics
self.compute_metrics = compute_metrics
        # Restore all columns of eval_dataset (the loop may have dropped those unused by the model)
if isinstance(eval_dataset, datasets.Dataset):
eval_dataset.set_format(
type=eval_dataset.format["type"],
columns=list(eval_dataset.features.keys()),
)
# Postprocess output
eval_preds = self.postprocess(output, eval_examples, eval_dataset, mode="evaluate")
# Compute metrics
metrics = {}
if self.compute_metrics is not None:
metrics = self.compute_metrics(eval_preds)
# Make metrics JSON-serializable
metrics = denumpify_detensorize(metrics)
# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
# Add speed metrics
total_batch_size = self.args.eval_batch_size * self.args.world_size
metrics.update(
speed_metrics(
metric_key_prefix,
start_time,
num_samples=output.num_samples,
num_steps=math.ceil(output.num_samples / total_batch_size),
)
)
# Log metrics
self.log(metrics)
# Log and save evaluation results
filename = "eval_results.txt"
eval_result_file = self.name + '_' + filename if self.name else filename
        with open(os.path.join(self.args.output_dir, eval_result_file), "w") as writer:
            logger.info("***** Eval results *****")
            writer.write("***** Eval results *****\n")
            writer.write(f"{datetime.now()}\n")
            for key in sorted(metrics.keys()):
                logger.info(f"  {key} = {metrics[key]}")
                writer.write(f"{key} = {metrics[key]}\n")
            writer.write("\n")
        # if DebugOption.TPU_METRICS_DEBUG in self.args.debug and is_torch_xla_available():
        #     # Log debug metrics for PyTorch/XLA
        #     xm.master_print(met.metrics_report())
# Call callback handler on evaluate
self.control = self.callback_handler.on_evaluate(
self.args, self.state, self.control, metrics
)
# Stop tracking memory usage and update metrics
self._memory_tracker.stop_and_update_metrics(metrics)
return metrics
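    # Usage sketch (hedged; my_squad_metric is a hypothetical callable): with a
    # metric function attached, evaluate() runs the loop, postprocesses against
    # the raw examples, then scores the result:
    #
    #     reader.compute_metrics = my_squad_metric
    #     metrics = reader.evaluate(eval_dataset=tokenized_eval, eval_examples=raw_eval)
    #     print(metrics)  # keys are prefixed with "eval_", plus speed metrics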
def predict(
self,
test_dataset: datasets.Dataset,
test_examples: Optional[datasets.Dataset] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "test",
        mode: str = "predict",
) -> PredictionOutput:
"""
Predicts on the given test dataset and returns the predictions.
Args:
test_dataset (datasets.Dataset): The test dataset.
test_examples (Optional[datasets.Dataset], optional): The test examples. Defaults to None.
ignore_keys (Optional[List[str]], optional): Keys to ignore when calculating metrics. Defaults to None.
metric_key_prefix (str, optional): The prefix for metric keys. Defaults to "test".
            mode (str, optional): The mode of prediction. Defaults to "predict".
Returns:
PredictionOutput: The predictions.
"""
# Start tracking memory usage
self._memory_tracker.start()
# Get the test dataloader
test_dataloader = self.get_test_dataloader(test_dataset)
start_time = time.time()
# Set compute_metrics to None and store it for later use
compute_metrics = self.compute_metrics
self.compute_metrics = None
# Get the evaluation loop
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
try:
# Run the evaluation loop
output = eval_loop(
test_dataloader,
description="Prediction",
ignore_keys=ignore_keys,
metric_key_prefix=metric_key_prefix,
)
finally:
# Reset compute_metrics to its original value
self.compute_metrics = compute_metrics
        # Restore all columns of test_dataset (the loop may have dropped those unused by the model)
if isinstance(test_dataset, datasets.Dataset):
test_dataset.set_format(
type=test_dataset.format["type"],
columns=list(test_dataset.features.keys()),
)
# Postprocess the output and return the predictions
predictions = self.postprocess(output, test_examples, test_dataset, mode=mode)
# Stop tracking memory usage and update metrics
self._memory_tracker.stop_and_update_metrics(output.metrics)
return predictions
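# End-to-end sketch (hedged; the datasets and model name are placeholders):
#
#     from transformers import AutoModelForQuestionAnswering, TrainingArguments
#
#     model = AutoModelForQuestionAnswering.from_pretrained("bert-base-uncased")
#     args = TrainingArguments(output_dir="outputs", per_device_eval_batch_size=8)
#     reader = BaseReader(model=model, args=args, eval_dataset=tokenized_eval,
#                         eval_examples=raw_eval)
#     metrics = reader.evaluate()  # writes eval_results.txt to output_dir
#     preds = reader.predict(test_dataset=tokenized_test, test_examples=raw_test)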