"""
Nanotron evaluation script (LightEval on a nanotron/brrr checkpoint)
Usage:
```
export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations
torchrun --nproc_per_node=8 run_evals.py --checkpoint-config-path ./pretrained/Mistral-7B-v0.1/config.yaml \
--lighteval-override ./lighteval_eval_config.yaml
```
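
An illustrative `--lighteval-override` file might look like the sketch below. This is only an
assumption about the YAML shape: the exact schema is defined by `BrrrConfig` and its lighteval
section, but the field names mirror the attributes read in this script
(`lighteval.parallelism.{dp,pp,tp}`, `lighteval.batch_size`, `lighteval.tasks.*`,
`lighteval.logging.*`, `lighteval.wandb`); the values are placeholders.
```
lighteval:
  parallelism:
    dp: 1
    pp: 1
    tp: 8
  batch_size: 16
  tasks:
    tasks: "my_task_group"  # hypothetical: a task string or a group name from custom_tasks_file
    custom_tasks_file: null
    max_samples: 1000
    dataset_loading_processes: 8
    num_fewshot_seeds: 1
  logging:
    local_output_path: /scratch/eval-results
    push_results_to_hub: false
    push_details_to_hub: false
    push_results_to_tensorboard: true
    hub_repo_tensorboard: null
    tensorboard_metric_prefix: eval
  wandb: null  # or a block with wandb_project / wandb_entity / wandb_run_name
```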
"""
# flake8: noqa: C901
import argparse
import os
import random
import time
from dataclasses import asdict
from pathlib import Path
import numpy as np
import torch
from huggingface_hub import HFSummaryWriter
from lighteval.evaluator import evaluate, make_results_table
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.logging.hierarchical_logger import hlog, htrack, htrack_block
from lighteval.logging.info_loggers import (
DetailsLogger,
)
from lighteval.models.model_loader import ModelInfo
from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
from lighteval.tasks.registry import Registry, get_custom_tasks, taskinfo_selector
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import get_config_from_file
from nanotron.logging import get_logger, log_rank
from nanotron.parallel.context import ParallelContext
from nanotron.utils import local_ranks_zero_first
from brrr.config import BrrrConfig
from brrr.experiment_loggers import flatten_dict, obj_to_markdown
from brrr.s3_checkpoints import fs_copy
from brrr.utils import check_env
from lighteval.models.brrr_models import BRRRModel
from modeling_mistral import MistralForTraining
from config_mistral import MistralConfig
logger = get_logger(__name__)
TOKEN = os.getenv("HF_TOKEN")
CACHE_DIR = os.getenv("HF_HOME", "/scratch")


def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint-config-path",
type=str,
required=True,
help="Path to the brr checkpoint YAML or python config file, potentially on S3",
)
parser.add_argument(
"--lighteval-override",
type=str,
help="Path to an optional YAML or python Lighteval config to override part of the checkpoint Lighteval config",
)
parser.add_argument(
"--tokenizer",
type=str,
help="Local or hub path of an optional tokenizer (if not indicated in the checkpoint)",
)
parser.add_argument(
"--s5cmd-path",
type=str,
default="/admin/home/thomwolf/miniconda3/envs/b4r/bin/s5cmd",
help="Path to s5cmd install",
)
parser.add_argument(
"--s5cmd-numworkers",
type=int,
default=64,
help="s5cmd num workers (optional)",
)
parser.add_argument(
"--s5cmd-concurrency",
type=int,
default=10,
help="s5cmd concurrency (optional)",
)
parser.add_argument(
"--cache-dir",
type=str,
default="",
help="Cache directory",
)
return parser


def push_results_to_wandb(  # noqa: C901
config: BrrrConfig, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail]
):
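    """Push aggregated LightEval results (and a small sample of per-task details) to Weights & Biases."""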
# config: BrrrConfig = get_config_from_dict(config, config_class=BrrrConfig)
lighteval_config = config.lighteval
try:
global_step = config.general.step
except ValueError:
global_step = 0
if config.lighteval.logging.tensorboard_metric_prefix is not None:
prefix = config.lighteval.logging.tensorboard_metric_prefix
else:
prefix = "eval"
output_dir_tb = Path(lighteval_config.logging.local_output_path) / "tb" / (config.general.run + "_" + prefix)
output_dir_tb.mkdir(parents=True, exist_ok=True)
os.environ["WANDB_DISABLE_SERVICE"] = "True"
import wandb
wandb.tensorboard.patch(root_logdir=config.lighteval.logging.local_output_path)
hlog("Starting wandb with WANDB_DISABLE_SERVICE=True")
wandb.init(
project=config.lighteval.wandb.wandb_project,
entity=config.lighteval.wandb.wandb_entity,
name=config.lighteval.wandb.wandb_run_name,
config=config.as_dict(),
# sync_tensorboard=True,
resume=True,
)
wb_dict = {}
bench_averages = {}
for name, values in results.items():
        split_name = name.split("|")
        if len(split_name) == 3:
            _, task_name, _ = split_name
else:
task_name = name
bench_suite = None
if ":" in task_name:
bench_suite = task_name.split(":")[0] # e.g. MMLU
hlog(f"bench_suite {bench_suite} in {task_name}")
for metric, value in values.items():
if "stderr" in metric:
continue
if bench_suite not in bench_averages:
bench_averages[bench_suite] = {}
bench_averages[bench_suite][metric] = bench_averages[bench_suite].get(metric, []) + [float(value)]
hlog(f"Pushing {task_name} {values} to tensorboard")
for metric, value in values.items():
if "stderr" in metric:
wb_dict[f"stderr_{metric}/{task_name}"] = value
elif bench_suite is not None:
wb_dict[f"{bench_suite}-{metric}/{task_name}"] = value
else:
wb_dict[f"{metric}/{task_name}"] = value
    # Log per-suite averages (e.g. MMLU)
    for name, values in bench_averages.items():
        for metric, metric_values in values.items():
            hlog(f"Pushing average {name} {metric} {sum(metric_values) / len(metric_values)} to wandb")
            wb_dict[f"{metric}/{name}"] = sum(metric_values) / len(metric_values)
for task_name, task_details in details.items():
if len(task_details) <= 1:
continue
columns = list(flatten_dict(asdict(task_details[0])).keys())
table = wandb.Table(columns=columns)
table.add_data(*[str(v) for v in flatten_dict(asdict(task_details[0])).values()])
table.add_data(*[str(v) for v in flatten_dict(asdict(task_details[1])).values()])
wandb.log({f"eval_details_{task_name}": table}, step=global_step, commit=False)
wandb.log(dict(wb_dict.items()), step=global_step, commit=True)
# tb_context.add_text("eval_sizes", obj_to_markdown(sizes), global_step=global_step)
# We are doing parallel evaluations of multiple checkpoints and recording the steps not in order
# This messes up with tensorboard, so the easiest is to rename files in the order of the checkpoints
# See: https://github.com/tensorflow/tensorboard/issues/5958
# But tensorboardX don't let us control the prefix of the files (only the suffix), so we need to do it ourselves before commiting the files
hlog(f"Pushed to wandb" f" at {output_dir_tb} and global_step {global_step}")


def push_results_to_tensorboard(  # noqa: C901
config: BrrrConfig, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail]
):
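    """Push aggregated LightEval results and per-task detail samples to a Hub-synced tensorboard via HFSummaryWriter."""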
# config: BrrrConfig = get_config_from_dict(config, config_class=BrrrConfig)
lighteval_config = config.lighteval
try:
global_step = config.general.step
except ValueError:
global_step = 0
if config.lighteval.logging.tensorboard_metric_prefix is not None:
prefix = config.lighteval.logging.tensorboard_metric_prefix
else:
prefix = "eval"
output_dir_tb = Path(lighteval_config.logging.local_output_path) / "tb" / (config.general.run + "_" + prefix)
output_dir_tb.mkdir(parents=True, exist_ok=True)
tb_context = HFSummaryWriter(
logdir=str(output_dir_tb),
repo_id=lighteval_config.logging.hub_repo_tensorboard,
repo_private=True,
path_in_repo="tb",
        commit_every=6000,  # Very long interval so that we can rename our files and trigger the push ourselves (see below)
)
bench_averages = {}
for name, values in results.items():
        split_name = name.split("|")
        if len(split_name) == 3:
            _, task_name, _ = split_name
else:
task_name = name
bench_suite = None
if ":" in task_name:
bench_suite = task_name.split(":")[0] # e.g. MMLU
hlog(f"bench_suite {bench_suite} in {task_name}")
for metric, value in values.items():
if "stderr" in metric:
continue
if bench_suite not in bench_averages:
bench_averages[bench_suite] = {}
bench_averages[bench_suite][metric] = bench_averages[bench_suite].get(metric, []) + [float(value)]
hlog(f"Pushing {task_name} {values} to tensorboard")
for metric, value in values.items():
if "stderr" in metric:
tb_context.add_scalar(f"stderr_{prefix}/{task_name}/{metric}", value, global_step=global_step)
elif bench_suite is not None:
tb_context.add_scalar(f"{prefix}_{bench_suite}/{task_name}/{metric}", value, global_step=global_step)
else:
tb_context.add_scalar(f"{prefix}/{task_name}/{metric}", value, global_step=global_step)
    # Log per-suite averages (e.g. MMLU)
    for name, values in bench_averages.items():
        for metric, metric_values in values.items():
            hlog(f"Pushing average {name} {metric} {sum(metric_values) / len(metric_values)} to tensorboard")
            tb_context.add_scalar(f"{prefix}/{name}/{metric}", sum(metric_values) / len(metric_values), global_step=global_step)
tb_context.add_text("eval_config", obj_to_markdown(results), global_step=global_step)
# tb_context.add_text("eval_sizes", obj_to_markdown(sizes), global_step=global_step)
for task_name, task_details in details.items():
tb_context.add_text(
f"eval_details_{task_name}",
obj_to_markdown({"0": task_details[0], "1": task_details[1] if len(task_details) > 1 else {}}),
global_step=global_step,
)
    # We are doing parallel evaluations of multiple checkpoints, and the steps are not recorded in order.
    # This confuses tensorboard, so the easiest fix is to rename the event files in checkpoint order.
    # See: https://github.com/tensorflow/tensorboard/issues/5958
    # But tensorboardX doesn't let us control the prefix of the files (only the suffix), so we rename them ourselves before committing the files.
tb_context.close() # flushes the unfinished write operations
time.sleep(5)
files = os.listdir(output_dir_tb)
for file in files:
os.rename(os.path.join(output_dir_tb, file), os.path.join(output_dir_tb, f"{global_step:07d}_{file}"))
# Now we can push to the hub
tb_context.scheduler.trigger()
hlog(
f"Pushed to tensorboard at https://huggingface.co/tensorboard/{lighteval_config.logging.hub_repo_tensorboard}/"
f" at {output_dir_tb} and global_step {global_step}"
)


@htrack()
def main(args):
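    """Load the checkpoint config, build the nanotron model and the LightEval tasks, run the evaluation, then save/push results."""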
cache_dir = args.cache_dir or CACHE_DIR
check_env()
dist.initialize_torch_distributed()
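    # Fetch the checkpoint config (copying it locally from S3 if needed) and optionally override its lighteval section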
with htrack_block("get config"):
if not args.checkpoint_config_path.endswith(".yaml"):
raise ValueError("The checkpoint path should point to a YAML file")
local_config_path = args.checkpoint_config_path
if args.checkpoint_config_path.startswith("s3:/"):
local_config_path = args.checkpoint_config_path.replace("s3:/", cache_dir)
with local_ranks_zero_first():
if os.environ.get("LOCAL_RANK", None) == "0":
os.makedirs(os.path.dirname(local_config_path), exist_ok=True)
fs_copy(args.checkpoint_config_path, local_config_path)
brrr_config: BrrrConfig = get_config_from_file(local_config_path, config_class=BrrrConfig, model_config_class=MistralConfig)
if args.lighteval_override:
            local_override_path = args.lighteval_override
if args.lighteval_override.startswith("s3:/"):
local_override_path = args.lighteval_override.replace("s3:/", cache_dir)
with local_ranks_zero_first():
if os.environ.get("LOCAL_RANK", None) == "0":
os.makedirs(os.path.dirname(local_override_path), exist_ok=True)
fs_copy(args.lighteval_override, local_override_path)
lighteval_brrr_config: BrrrConfig = get_config_from_file(local_override_path, config_class=BrrrConfig)
lighteval_config = lighteval_brrr_config.lighteval
brrr_config.lighteval = lighteval_config
else:
local_override_path = ""
lighteval_config = brrr_config.lighteval
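        # Build the parallel context (tp / pp / dp) from the lighteval parallelism settings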
parallel_context = ParallelContext(
tensor_parallel_size=lighteval_config.parallelism.tp,
pipeline_parallel_size=lighteval_config.parallelism.pp,
data_parallel_size=lighteval_config.parallelism.dp,
)
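        # The evaluation tracker collects the general config, metrics and per-sample details, and handles saving/pushing them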
evaluation_tracker = EvaluationTracker(token=TOKEN)
evaluation_tracker.general_config_logger.log_args_info(
num_fewshot_seeds=1,
override_batch_size=None,
max_samples=lighteval_config.tasks.max_samples,
job_id=os.environ.get("SLURM_JOB_ID", None),
config=brrr_config.as_dict(),
)
with htrack_block("Test all gather"):
hlog("Test gather tensor")
# Do a first NCCL sync to warmup and try to avoid Timeout after model/data loading
log_rank(
f"[TEST] Running NCCL sync for ranks {list(range(parallel_context.world_pg.size()))}",
logger=logger,
level=logging.WARNING,
group=parallel_context.dp_pg,
rank=0,
)
test_tensor = torch.tensor([dist.get_rank(parallel_context.world_pg)], device=torch.device("cuda"))
test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(parallel_context.world_pg.size())]
dist.all_gather(test_tensor_list, test_tensor, group=parallel_context.world_pg, async_op=False)
dist.barrier()
log_rank(
f"[TEST] NCCL sync for ranks {[t.item() for t in test_tensor_list]}",
logger=logger,
level=logging.WARNING,
group=parallel_context.dp_pg,
rank=0,
)
del test_tensor_list
del test_tensor
with htrack_block("Model loading"):
# We need to load the model in the main process first to avoid downloading the model multiple times
model = BRRRModel(
checkpoint_path=args.checkpoint_config_path.replace("config.yaml", ""),
model_args=brrr_config.model,
tokenizer=brrr_config.tokenizer,
parallel_context=parallel_context,
parallel_config=lighteval_config.parallelism,
lighteval_config=lighteval_config,
batch_size=lighteval_config.batch_size,
cache_dir=os.environ.get("HF_HOME", "/scratch"),
debug_one_layer_model=False,
s5cmd_path=args.s5cmd_path,
s5cmd_numworkers=args.s5cmd_numworkers,
s5cmd_concurrency=args.s5cmd_concurrency,
            model_class=MistralForTraining,
)
model_info = ModelInfo(model_name=f"{brrr_config.general.run}/{brrr_config.general.step}")
evaluation_tracker.general_config_logger.log_model_info(model_info)
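    # Resolve the task selection (possibly through a custom tasks file), load the datasets and build the evaluation requests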
with htrack_block("Tasks loading"):
with local_ranks_zero_first():
tasks_selection = lighteval_config.tasks.tasks
if lighteval_config.tasks.custom_tasks_file:
_, tasks_groups_dict = get_custom_tasks(lighteval_config.tasks.custom_tasks_file)
if tasks_groups_dict and lighteval_config.tasks.tasks in tasks_groups_dict:
tasks_selection = tasks_groups_dict[lighteval_config.tasks.tasks]
task_names_list, few_shots_dict = taskinfo_selector(tasks_selection)
task_dict = Registry(cache_dir=cache_dir).get_task_dict(
task_names_list, custom_tasks_file=lighteval_config.tasks.custom_tasks_file
)
# Loading all the dataset in a distributed manner
LightevalTask.load_datasets(task_dict.values(), lighteval_config.tasks.dataset_loading_processes)
evaluation_tracker.task_config_logger.log(task_dict)
hlog("Loading documents, and requests")
requests, docs = create_requests_from_tasks(
task_dict=task_dict,
fewshot_dict=few_shots_dict,
num_fewshot_seeds=lighteval_config.tasks.num_fewshot_seeds or 1,
lm=model,
max_samples=lighteval_config.tasks.max_samples,
evaluation_tracker=evaluation_tracker,
                use_chat_template=False,
)
with htrack_block("Setting seeds and waiting for all processes"):
hlog(f"setting seed to {1234} for random and numpy")
random.seed(1234)
np.random.seed(1234)
dist.barrier()
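    # Run the evaluation itself; results are aggregated and pushed from rank 0 below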
with htrack_block("Evaluation"):
hlog(f"Evaluate on {len(task_names_list)} tasks.")
evaluation_tracker = evaluate(
lm=model,
requests_dict=requests,
docs=docs,
task_dict=task_dict,
override_bs=lighteval_config.batch_size,
evaluation_tracker=evaluation_tracker,
)
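    # Only the global rank 0 process compiles, saves and pushes the results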
if dist.get_rank(parallel_context.world_pg) == 0:
with htrack_block("Compiling and saving results"):
evaluation_tracker.general_config_logger.log_end_time()
evaluation_tracker.metrics_logger.aggregate(task_dict=task_dict, bootstrap_iters=1000)
evaluation_tracker.details_logger.aggregate()
if lighteval_config.logging.local_output_path:
evaluation_tracker.save(
output_dir=lighteval_config.logging.local_output_path,
push_results_to_hub=lighteval_config.logging.push_results_to_hub,
push_details_to_hub=lighteval_config.logging.push_details_to_hub,
public=False,
push_results_to_tensorboard=lighteval_config.logging.push_results_to_tensorboard,
)
if lighteval_config.logging.push_results_to_tensorboard:
push_results_to_tensorboard(
config=brrr_config,
results=evaluation_tracker.metrics_logger.metric_aggregated,
details=evaluation_tracker.details_logger.details,
)
if lighteval_config.wandb is not None:
push_results_to_wandb(
config=brrr_config,
results=evaluation_tracker.metrics_logger.metric_aggregated,
details=evaluation_tracker.details_logger.details,
)
final_dict = evaluation_tracker.generate_final_dict()
hlog(make_results_table(final_dict))
return final_dict


if __name__ == "__main__":
parser = get_parser()
args, unknowns = parser.parse_known_args()
main(args)