# CHEMISTral7Bv0.3 — tests/test_data.py
# (Hugging Face upload header: Clemspace, "Initial model upload", commit cb9e677)
import copy
import json
from pathlib import Path
import numpy as np
import pytest
from mistral_common.protocol.instruct.messages import FinetuningAssistantMessage
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from finetune.data.args import DataArgs, InstructArgs
from finetune.data.data_loader import build_data_loader
from finetune.data.dataset import (
DataFile,
SampleType,
get_dataset_iterator,
get_rng,
lazy_load_and_yield,
maybe_chunk_lines,
parse_data_sources,
preload_and_yield,
)
from finetune.data.tokenize import build_instruct_sample, encode
from .test_utils import spawn_for_all_world_sizes
# fmt: off
# Reference checksums for the distributed data-loader tests below.
# EXPECTED_X / EXPECTED_Y / EXPECTED_MASKS are indexed as
#   [is_instruct][dp_world_size - 1][dp_rank]
# and hold the per-step sums of batch.x, batch.y and batch.y_mask
# (10 // dp_world_size steps per rank; pretrain batches have no mask -> 0).
EXPECTED_X = [
    [
        # for pretrain
        [
            # for DP=1
            [2051851, 1961139, 2000184, 2081307, 2341123, 1225437, 1739008, 724695, 570810, 632094]
        ],
        [
            # for DP=2
            [2020745, 1938377, 2244286, 2042079, 1824023],
            [2103241, 2032118, 1868430, 1093072, 770996],
        ]
    ],
    [
        # for instruct
        [
            # for DP=1
            [1379941, 1438894, 965536, 1019713, 889921, 999322, 1647173, 941080, 1281597, 1584884]
        ],
        [
            # for DP=2
            [1379941, 1438894, 889899, 1005451, 876854],
            [1034325, 999322, 982295, 941080, 725946],
        ]
    ]
]
EXPECTED_Y = [
    [
        # for pretrain
        [
            # for DP=1
            [2081367, 1961098, 1970714, 2110856, 2334822, 1251057, 1745267, 699854, 571600, 660015]
        ],
        [
            # for DP=2
            [2021840, 1966833, 2223275, 2063077, 1824011],
            [2132793, 2002569, 1870876, 1122569, 757126],
        ]
    ],
    [
        # for instruct
        [
            # for DP=1
            [1409448, 1430886, 937609, 1019339, 889921, 970976, 1660330, 942631, 1308399, 1583658]
        ],
        [
            # for DP=2
            [1409448, 1430886, 895531, 990091, 863522],
            [1041462, 970976, 991091, 942631, 737311]
        ]
    ]
]
EXPECTED_MASKS = [
    [
        # for pretrain
        [
            # for DP=1
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        ],
        [
            # for DP=2
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
        ]
    ],
    [
        # for instruct
        [
            # for DP=1
            [47, 0, 34, 0, 0, 82, 0, 0, 0, 0]
        ],
        [
            # for DP=2
            [47, 0, 4, 0, 0],
            [19, 82, 0, 0, 23],
        ]
    ]
]
# Function-calling variants (shuffle + dynamic chunking enabled), indexed as
#   [dp_world_size - 1][dp_rank].
EXPECTED_X_FUNC = [
    [
        # for DP=1
        [1005531, 1551735, 1261711, 1531024, 1280259, 1069883, 858107, 1021583, 1203265, 1242999],
    ],
    [
        # for DP=2
        [985281, 1217766, 1442139, 1533790, 1253607],
        [1005531, 1551735, 1261711, 1531024, 1280259],
    ]
]
EXPECTED_Y_FUNC = [
    [
        # for DP=1
        [977126, 1580120, 1233326, 1559463, 1280241, 1042456, 879031, 994127, 1196263, 1270581],
    ],
    [
        # for DP=2
        [957934, 1218899, 1441783, 1533011, 1224541],
        [977126, 1580120, 1233326, 1559463, 1280241],
    ]
]
EXPECTED_MASKS_FUNC = [
    [
        # for DP=1
        [91, 0, 0, 0, 0, 77, 0, 0, 53, 0],
    ],
    [
        # for DP=2
        [16, 47, 0, 86, 98],
        [91, 0, 0, 0, 0],
    ]
]
# fmt: on
class MockTokenizer:
    """Identity tokenizer stub: "encoding" a string simply returns it unchanged."""

    def encode(self, content: str, *args, **kwargs) -> str:
        # No real tokenization in tests — extra positional/keyword args are ignored.
        return content
class MockInstructTokenizerBaseBase:
    """Minimal instruct-tokenizer stand-in: messages pass through as raw text."""

    def __init__(self):
        # Wraps the identity MockTokenizer so sample.tokens stays a list of strings.
        self.tokenizer = MockTokenizer()

    def start(self):
        # The mock emits no BOS token — sequences start empty.
        return []

    def encode_user_message(self, message, *args, **kwargs):
        # A user message "encodes" to its raw content string.
        return message.content

    def encode_assistant_message(self, message, *args, **kwargs):
        # Assistant messages get the same pass-through treatment.
        return message.content
def stringify(samples):
    """Concatenate each sample's token list into one string per sample.

    With the mock tokenizers above, `sample.tokens` is a list of raw strings,
    so the joined result is the original text of the sample.
    """
    return ["".join(sample.tokens) for sample in samples]
@pytest.mark.parametrize(
    ("world_size", "model_parallel", "is_instruct"),
    [
        (1, 1, False),
        (2, 1, False),
        (2, 2, False),
        (1, 1, True),
        (2, 1, True),
        (2, 2, True),
    ],
)
def test_data_loader_dist(world_size, model_parallel, is_instruct):
    """Spawn `world_size` workers and check that each rank's data loader yields
    the precomputed deterministic batches (see `_check_data_loader_dist`)."""
    spawn_for_all_world_sizes(
        _check_data_loader_dist,
        world_sizes=[world_size],
        args=[model_parallel, is_instruct],
        deterministic=True,
    )
def _check_data_loader_dist(
    rank: int,
    world_size: int,
    filename: str,
    filename_rpc: str,
    model_parallel: int,
    is_instruct: bool,
):
    """Worker entry point: build the data loader for this rank and compare the
    per-step sums of ``batch.x`` / ``batch.y`` / ``batch.y_mask`` against the
    precomputed expectations.

    ``filename`` / ``filename_rpc`` are the distributed-init files passed by
    ``spawn_for_all_world_sizes`` and are unused here.
    """
    # Data-parallel group is the world split by model parallelism.
    dp_world_size = world_size // model_parallel
    dp_rank = rank // model_parallel
    seed = 0
    seq_len = 100
    batch_size = 1
    instruct = InstructArgs(shuffle=False, dynamic_chunk_fn_call=False)
    if is_instruct:
        data_args = DataArgs(
            data="",
            instruct_data="tests/fixtures/sample_instruct.jsonl:.1,tests/fixtures/sample_instruct_2.jsonl:.1,tests/fixtures/sample_instruct_3.jsonl:.1",
            instruct=instruct,
        )
    else:
        # At the moment we have to pass some instruction-finetuning data even
        # for the pretrain configuration, hence the tiny :.01 weight.
        data_args = DataArgs(
            data="tests/fixtures/sample_pretrain_1.jsonl:1.0,tests/fixtures/sample_pretrain_2.jsonl:1.0",
            instruct_data="tests/fixtures/sample_instruct.jsonl:.01",
            instruct=instruct,
        )
    instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer
    data_loader = build_data_loader(
        instruct_tokenizer,
        data_args,
        batch_size,
        seq_len,
        seed=seed,
        rank=dp_rank,
        world_size=dp_world_size,
        is_eval=False,
    )
    x_sums = []
    y_sums = []
    masks = []
    # 10 training samples in total, split evenly across the DP group.
    num_samples = 10 // dp_world_size
    for _ in range(num_samples):
        batch = next(data_loader)
        x_sums.append(batch.x.sum())
        y_sums.append(batch.y.sum())
        # Pretrain batches carry no loss mask (y_mask is None) -> count 0.
        mask_sum = batch.y_mask.sum() if batch.y_mask is not None else 0
        masks.append(mask_sum)
    # Expectations are indexed [is_instruct][dp_world_size - 1][dp_rank].
    expected_x_sums = EXPECTED_X[is_instruct][dp_world_size - 1][dp_rank]
    expected_y_sums = EXPECTED_Y[is_instruct][dp_world_size - 1][dp_rank]
    expected_masks = EXPECTED_MASKS[is_instruct][dp_world_size - 1][dp_rank]
    print(f"rank: {rank}, world_size: {world_size}, x: {x_sums}")
    print(f"rank: {rank}, world_size: {world_size}, y: {y_sums}")
    # FIX: this debug line previously printed the mask sums with a misleading
    # "x shape:" label.
    print(f"rank: {rank}, world_size: {world_size}, masks: {masks}")
    assert x_sums == expected_x_sums, x_sums
    assert y_sums == expected_y_sums, y_sums
    assert masks == expected_masks, masks
@pytest.mark.parametrize("world_size", [1, 2])
def test_data_loader_dist_fn_call(world_size):
    """Spawn workers and check the function-calling data loader is
    deterministic per rank (see `_check_data_loader_dist_fn_call`)."""
    spawn_for_all_world_sizes(
        _check_data_loader_dist_fn_call,
        world_sizes=[world_size],
        deterministic=True,
    )
def _check_data_loader_dist_fn_call(
    rank: int,
    world_size: int,
    filename: str,
    filename_rpc: str,
):
    """Worker: verify deterministic batches for a function-calling dataset with
    shuffling and dynamic chunking enabled.

    `filename` / `filename_rpc` are distributed-init files and are unused here.
    """
    # Pure data parallelism here — no model-parallel split.
    dp_world_size, dp_rank = world_size, rank
    seed, seq_len, batch_size = 0, 100, 1
    data_args = DataArgs(
        data="",
        instruct_data="tests/fixtures/sample_instruct_fn_call_short.jsonl:.3",
        instruct=InstructArgs(shuffle=True, dynamic_chunk_fn_call=True),
    )
    instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer
    data_loader = build_data_loader(
        instruct_tokenizer,
        data_args,
        batch_size,
        seq_len,
        seed=seed,
        rank=dp_rank,
        world_size=dp_world_size,
        is_eval=False,
    )
    x_sums, y_sums, masks = [], [], []
    # 10 samples in total, split evenly across ranks.
    for _ in range(10 // dp_world_size):
        batch = next(data_loader)
        x_sums.append(batch.x.sum())
        y_sums.append(batch.y.sum())
        masks.append(0 if batch.y_mask is None else batch.y_mask.sum())
    # Expectations are indexed [dp_world_size - 1][dp_rank].
    expected_x_sums = EXPECTED_X_FUNC[dp_world_size - 1][dp_rank]
    expected_y_sums = EXPECTED_Y_FUNC[dp_world_size - 1][dp_rank]
    expected_masks = EXPECTED_MASKS_FUNC[dp_world_size - 1][dp_rank]
    assert x_sums == expected_x_sums, x_sums
    assert y_sums == expected_y_sums, y_sums
    assert masks == expected_masks, masks
def test_data_loader_equal_fsdp():
    """Spawn a 2-process world and check model-parallel batches match DDP
    (see `_check_data_loader_equal_fsdp`)."""
    spawn_for_all_world_sizes(
        _check_data_loader_equal_fsdp,
        world_sizes=[2],
        deterministic=True,
    )
def _check_data_loader_equal_fsdp(
    rank: int,
    world_size: int,
    filename: str,
    filename_rpc: str,
):
    """Worker: two rank-specific loaders (rank 0 and rank 1), stepped in
    lockstep, must together reproduce the interleaved DP=2 reference sums.

    Spawned with world_size == 2. ``filename`` / ``filename_rpc`` are
    distributed-init files and are unused here.
    """
    # FIX: removed dead code — the original computed
    # `world_size // model_parallel` and `rank // model_parallel` and
    # discarded both results (along with the unused `model_parallel = 2`).
    seed = 0
    seq_len = 100
    batch_size = 1
    instruct = InstructArgs(shuffle=False, dynamic_chunk_fn_call=False)
    data_args = DataArgs(
        data="",
        instruct_data="tests/fixtures/sample_instruct.jsonl:.1,tests/fixtures/sample_instruct_2.jsonl:.1,tests/fixtures/sample_instruct_3.jsonl:.1",
        instruct=instruct,
    )
    instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer
    data_loader_0 = build_data_loader(
        instruct_tokenizer,
        data_args,
        batch_size,
        seq_len,
        seed=seed,
        rank=0,
        world_size=world_size,
        is_eval=False,
    )
    data_loader_1 = build_data_loader(
        instruct_tokenizer,
        data_args,
        batch_size,
        seq_len,
        seed=seed,
        rank=1,
        world_size=world_size,
        is_eval=False,
    )
    x_sums = []
    y_sums = []
    num_samples = 10 // 2  # run 5 * 2 training steps
    for _ in range(num_samples):
        # Alternate loaders so sums appear interleaved rank0, rank1, rank0, ...
        batch = next(data_loader_0)
        x_sums.append(batch.x.sum())
        y_sums.append(batch.y.sum())
        batch = next(data_loader_1)
        x_sums.append(batch.x.sum())
        y_sums.append(batch.y.sum())
    # check that mp can match ddp for both ranks: interleave the DP=2
    # per-rank expectations ([a0, b0, a1, b1, ...]).
    expected_x_sums = [
        y for x in zip(EXPECTED_X[1][1][0], EXPECTED_X[1][1][1]) for y in x
    ]
    expected_y_sums = [
        y for x in zip(EXPECTED_Y[1][1][0], EXPECTED_Y[1][1][1]) for y in x
    ]
    assert x_sums == expected_x_sums, x_sums
    assert y_sums == expected_y_sums, y_sums
def test_dynamic_fn_call_chunk():
    """`maybe_chunk_lines` must emit one extra chunk per intermediate
    assistant message of every `only_last` sample."""
    jsonl_file = Path("tests/fixtures/sample_instruct_fn_call_short.jsonl")
    with jsonl_file.open() as file_handle:
        non_chunked_samples = [
            build_instruct_sample(json.loads(line)) for line in file_handle
        ]
    # Every assistant message except the final one yields one chunk.
    num_expected_chunks = sum(
        sum(isinstance(m, FinetuningAssistantMessage) for m in sample.messages) - 1
        for sample in non_chunked_samples
        if sample.only_last
    )
    with jsonl_file.open() as file_handle:
        lines = file_handle.readlines()
    chunked_samples = [
        build_instruct_sample(json.loads(line)) for line in maybe_chunk_lines(lines)
    ]
    assert num_expected_chunks == len(chunked_samples)
def test_dynamic_fn_call_chunk_integration():
    """Chunking the single-step fixture plus its originals must reconstruct
    the multi-step conversations, in reverse order."""
    multi_file = Path("tests/fixtures/sample_instruct_fn_call_multi.jsonl")
    with multi_file.open() as file_handle:
        multi_samples = [
            build_instruct_sample(json.loads(line)) for line in file_handle
        ]
    single_file = Path("tests/fixtures/sample_instruct_fn_call_single.jsonl")
    with single_file.open() as file_handle:
        chunked_samples = [
            build_instruct_sample(json.loads(line)) for line in file_handle
        ]
    with single_file.open() as file_handle:
        lines = file_handle.readlines()
    # Append the dynamically generated chunks after the originals.
    for line in maybe_chunk_lines(lines):
        chunked_samples.append(build_instruct_sample(json.loads(line)))
    assert list(reversed(multi_samples)) == chunked_samples
def test_fn_call():
    """Eval loader over a function-calling fixture: the loss mask must select
    exactly the last message (plain content or tool calls) of every sample."""
    batch_size = 1
    data_args = DataArgs(
        data="",
        instruct_data="",
        eval_instruct_data="tests/fixtures/sample_instruct_fn_call.jsonl",
    )
    instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer
    seq_len = 10000
    data_loader = build_data_loader(
        instruct_tokenizer,
        data_args,
        batch_size,
        seq_len,
        seed=None,
        rank=0,
        world_size=1,
        is_eval=True,
    )
    all_loss_strings = []
    for batch in data_loader:
        # Without an explicit mask, every position contributes to the loss.
        y_mask = (
            np.asarray(batch.y_mask, int)
            if batch.y_mask is not None
            else np.ones_like(batch.x)
        )
        start_index = end_index = 0
        # Batches pack several samples; batch.sizes delimits each one.
        for size in batch.sizes:
            end_index += size
            tokens = batch.y[start_index:end_index]
            mask = y_mask[start_index:end_index]
            tokens_for_loss = [int(y) for i, y in enumerate(tokens) if mask[i] == 1]
            start_index += size
            decoded = instruct_tokenizer.tokenizer.decode(tokens_for_loss)
            if len(decoded) > 0:
                all_loss_strings.append(decoded)
    # Verify that the loss is always computed only over the last message
    # of each conversation in the fixture.
    expected_loss_strings = []
    with open(data_args.eval_instruct_data, "r") as f:
        for line in f:
            data = json.loads(line)
            last_message = data["interactions"][-1]
            if "content" in last_message:
                expected_loss_strings.append(last_message["content"])
            elif "tool_calls" in last_message:
                tool_calls = last_message["tool_calls"]
                # FIX: use each call's own arguments — the original read
                # tool_calls[0]'s arguments once and reused them for every call.
                calls = [
                    {
                        "name": call["function"]["name"],
                        "arguments": json.loads(call["function"]["arguments"]),
                    }
                    for call in tool_calls
                ]
                expected_loss_strings.append(json.dumps(calls))
    assert expected_loss_strings == all_loss_strings
def test_data_weighting():
    """Weight 0 on the last message must zero the loss mask while leaving the
    tokens untouched; weight 1 must match the unweighted encoding."""
    data_args = DataArgs(
        data="",
        instruct_data="",
        eval_instruct_data="tests/fixtures/sample_instruct.jsonl",
    )
    instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer
    with Path(data_args.eval_instruct_data).open() as file_handle:
        # Only the first sample of the fixture is needed.
        data = json.loads(next(file_handle))

    def encode_with_weight(weight):
        # Encode a deep copy whose last interaction carries `weight`.
        weighted = copy.deepcopy(data)
        weighted["interactions"][-1]["weight"] = weight
        return encode(weighted, instruct_tokenizer, SampleType.INSTRUCT)

    token_sample = encode(data, instruct_tokenizer, SampleType.INSTRUCT)
    token_sample_weight_0 = encode_with_weight(0)
    token_sample_weight_1 = encode_with_weight(1)
    # Tokens are identical regardless of weighting.
    assert (
        token_sample.tokens
        == token_sample_weight_0.tokens
        == token_sample_weight_1.tokens
    )
    # Weight 1 behaves like the default; weight 0 masks everything out.
    assert token_sample.masks == token_sample_weight_1.masks
    assert token_sample.masks != token_sample_weight_0.masks
    assert not any(token_sample_weight_0.masks)
def test_eval_dataloader():
    """Eval data must be shard-invariant: the global masked sums of x, y and
    the mask itself are identical for every world size and sequence length."""
    batch_size = 1
    data_args = DataArgs(
        data="",
        instruct_data="",
        eval_instruct_data="tests/fixtures/sample_instruct.jsonl,tests/fixtures/sample_instruct_2.jsonl,tests/fixtures/sample_instruct_3.jsonl",
    )
    instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer
    # make sure that for every seq len the same data is seen
    for world_size in [1, 2, 8]:
        for seq_len in [10, 100, 1000, 10000]:
            # One loader per rank, all built up front.
            data_loaders = [
                build_data_loader(
                    instruct_tokenizer,
                    data_args,
                    batch_size,
                    seq_len,
                    seed=None,
                    rank=rank,
                    world_size=world_size,
                    is_eval=True,
                )
                for rank in range(world_size)
            ]
            x_sums, y_sums, y_masks = [], [], []
            for data_loader in data_loaders:
                for batch in data_loader:
                    # No mask means every position counts.
                    mask = (
                        np.asarray(batch.y_mask, int)
                        if batch.y_mask is not None
                        else np.ones_like(batch.x)
                    )
                    x_sums.append((batch.x * mask).sum())
                    y_sums.append((batch.y * mask).sum())
                    y_masks.append(mask.sum())
                    # Each eval batch is packed to exactly seq_len tokens.
                    assert len(batch.x) == len(mask) == len(batch.y) == seq_len
            assert sum(x_sums) == 71404835
            assert sum(y_sums) == 71404795
            assert sum(y_masks) == 5538
def test_shuffle_data():
    """Every epoch must present the same set of samples in a different order."""
    instruct_tokenizer = MockInstructTokenizerBaseBase()
    data_args = DataArgs(data="", instruct_data="", eval_instruct_data="")
    data_file = Path("tests/fixtures/sample_instruct_long_1.jsonl")
    dataset_iterator = get_dataset_iterator(
        source=DataFile(path=data_file, sample_type=SampleType.INSTRUCT),
        instruct_args=data_args.instruct,
        instruct_tokenizer=instruct_tokenizer,
        rank=0,
        world_size=1,
        is_finite=False,
        seed=0,
        shuffle_at_epoch=True,
    )
    # Encode the file in its on-disk order as the initial reference.
    with data_file.open() as f:
        encoded = [
            encode(
                json.loads(line),
                instruct_tokenizer=instruct_tokenizer,
                as_type=SampleType.INSTRUCT,
            )
            for line in f
        ]
    prev_lines = stringify(encoded)
    num_lines = len(prev_lines)
    # run 4 epochs
    for _ in range(4):
        epoch_samples = [next(dataset_iterator) for _ in range(num_lines)]
        # epoch finished!
        # check that order is different but all lines have the same hash
        lines = stringify(epoch_samples)
        assert lines != prev_lines, "No shuffling - make sure dataset is shuffled!"
        assert sorted(lines) == sorted(
            prev_lines
        ), "datasets need to match at every epoch"
        prev_lines = lines
@pytest.mark.parametrize("world_size", [1, 2])
def test_shuffle_data_same_as_no_shuffle(world_size):
    """Spawn workers and check shuffled vs. unshuffled loading yields the same
    samples (see `_check_shuffle_data_same_as_no_shuffle`)."""
    spawn_for_all_world_sizes(
        _check_shuffle_data_same_as_no_shuffle,
        world_sizes=[world_size],
        deterministic=True,
    )
def _check_shuffle_data_same_as_no_shuffle(
    rank: int,
    world_size: int,
    filename: str,
    filename_rpc: str,
):
    """Worker: for every source file, the preloaded+shuffled iterator must
    yield exactly the same multiset of samples as the lazy, unshuffled one.

    `filename` / `filename_rpc` are distributed-init files and are unused here.
    """
    instruct_tokenizer = MockInstructTokenizerBaseBase()
    instruct = InstructArgs(shuffle=False, dynamic_chunk_fn_call=False)
    data_args = DataArgs(
        data="tests/fixtures/sample_pretrain_1.jsonl:1.0,tests/fixtures/sample_pretrain_2.jsonl:1.0",
        instruct_data="tests/fixtures/sample_instruct.jsonl:.1,tests/fixtures/sample_instruct_2.jsonl:.1,tests/fixtures/sample_instruct_3.jsonl:.1",
        instruct=instruct,
    )
    sources, _ = parse_data_sources(
        pretrain_data=data_args.data, instruct_data=data_args.instruct_data
    )
    rng = get_rng(0, rank)
    for source in sources:
        # Only instruct data is eligible for chunking.
        chunk_dataset = source.sample_type == SampleType.INSTRUCT
        for jsonl_file in source.jsonl_files:
            # samples[1] collects the shuffled pass, samples[0] the lazy pass.
            samples = [[], []]
            for shuffle in [True, False]:
                print(jsonl_file)
                if shuffle:
                    iterator = preload_and_yield(
                        jsonl_file,
                        chunk_dataset=chunk_dataset,
                        rank=rank,
                        world_size=world_size,
                        rng=rng,
                        instruct_tokenizer=instruct_tokenizer,
                        sample_type=source.sample_type,
                    )
                else:
                    iterator = lazy_load_and_yield(
                        jsonl_file,
                        rank=rank,
                        world_size=world_size,
                        instruct_tokenizer=instruct_tokenizer,
                        sample_type=source.sample_type,
                    )
                samples[shuffle].extend(iterator)
            # Same content regardless of ordering.
            assert sorted(stringify(samples[0])) == sorted(stringify(samples[1]))