H2OTest / tests /src /utils /test_data_utils.py
elineve's picture
Upload 301 files
07423df
raw
history blame
3.68 kB
import os
import pathlib
import random
import unittest
from unittest.mock import MagicMock
import pandas as pd
import pytest
from llm_studio.app_utils.default_datasets import (
prepare_default_dataset_causal_language_modeling,
)
from llm_studio.src.datasets.conversation_chain_handler import ConversationChainHandler
from llm_studio.src.utils.data_utils import load_train_valid_data
@pytest.fixture
def cfg_mock():
    """Provide a MagicMock config describing a prompt/answer dataset.

    Only the ``dataset`` attributes below are set explicitly; any other
    attribute access falls back to MagicMock's auto-created children.
    ``system_column`` is the string ``"None"`` (presumably the project's
    sentinel for "no system column" — confirm against the config schema).
    """
    dataset_settings = {
        "train_dataframe": "/path/to/train/data",
        "validation_dataframe": "/path/to/validation/data",
        "system_column": "None",
        "prompt_column": "prompt",
        "answer_column": "answer",
        "validation_size": 0.2,
    }
    config = MagicMock()
    for attribute, value in dataset_settings.items():
        setattr(config.dataset, attribute, value)
    return config
@pytest.fixture
def read_dataframe_drop_missing_labels_mock(monkeypatch):
    """Patch the data reader so it returns a canned 100-row DataFrame.

    ``llm_studio.src.utils.data_utils.read_dataframe_drop_missing_labels`` is
    replaced by a MagicMock whose every call returns the same frame with
    columns ``prompt`` ("Prompt0".."Prompt99"), ``answer`` ("Answer0"..
    "Answer99") and ``id`` (0..99). The mock itself is returned so tests can
    inspect its calls.
    """
    row_ids = range(100)
    canned_frame = pd.DataFrame(
        {
            "prompt": ["Prompt" + str(row) for row in row_ids],
            "answer": ["Answer" + str(row) for row in row_ids],
            "id": list(row_ids),
        }
    )
    fake_reader = MagicMock(return_value=canned_frame)
    monkeypatch.setattr(
        "llm_studio.src.utils.data_utils.read_dataframe_drop_missing_labels",
        fake_reader,
    )
    return fake_reader
# Ids 0-99, shuffled and dealt round-robin into 13 disjoint groups that the
# conversation_chain_ids_mock fixture reports as conversation chains.
# A private, seeded RNG keeps the grouping identical on every run (so a failure
# in test_get_data_automatic_split is reproducible) and leaves the global
# ``random`` state untouched at import time.
numbers = list(range(100))
random.Random(42).shuffle(numbers)
groups = [numbers[n::13] for n in range(13)]
@pytest.fixture
def conversation_chain_ids_mock(monkeypatch):
    """Make ConversationChainHandler report the module-level ``groups`` as chains.

    ``ConversationChainHandler.__init__`` is replaced for the duration of the
    test, so the real constructor does not run. Each instance only gets a
    ``conversation_chain_ids`` attribute holding a copy of ``groups``.
    ``monkeypatch`` restores the original constructor on teardown.
    """

    def mocked_init(self, *args, **kwargs):
        # Hand out copies so the code under test cannot mutate the reference
        # grouping that test_get_data_automatic_split checks against.
        self.conversation_chain_ids = [list(chain) for chain in groups]

    monkeypatch.setattr(ConversationChainHandler, "__init__", mocked_init)
def test_get_data_custom_validation_strategy(
    cfg_mock, read_dataframe_drop_missing_labels_mock
):
    """A custom validation strategy yields the frames returned by the reader.

    The patched reader returns the same 100-row frame for every call. Assuming
    load_train_valid_data reads the train and validation paths without further
    filtering (confirm against data_utils), both splits hold 100 rows.
    """
    cfg_mock.dataset.validation_strategy = "custom"
    train_df, val_df = load_train_valid_data(cfg_mock)
    # Two explicit checks: "assert a, b == c" would treat the comparison as the
    # failure message and only verify that ``a`` is truthy.
    assert len(train_df) == 100
    assert len(val_df) == 100
def test_get_data_automatic_split(
    cfg_mock, read_dataframe_drop_missing_labels_mock, conversation_chain_ids_mock
):
    """An automatic split never puts one conversation chain on both sides.

    The mocked ConversationChainHandler reports ``groups`` as its chains, so
    each group's ids must end up entirely in train or entirely in validation.
    """
    cfg_mock.dataset.validation_strategy = "automatic"
    train_df, val_df = load_train_valid_data(cfg_mock)

    train_ids, val_ids = (set(frame["id"].tolist()) for frame in (train_df, val_df))

    # The splits must be disjoint and together cover all 100 rows.
    assert not train_ids & val_ids
    assert len(train_ids) + len(val_ids) == 100

    # A chain straddling both splits would leak conversation context across them.
    straddling_chains = [
        chain
        for chain in groups
        if train_ids.intersection(chain) and val_ids.intersection(chain)
    ]
    assert not straddling_chains
def test_oasst_data_automatic_split(tmp_path: pathlib.Path):
    """Automatic splits of the default causal-LM dataset respect parent links.

    The project helper prepares the default dataset into ``tmp_path``. For
    several validation sizes, the test checks three things:
    - no training row has its parent in the validation split;
    - no validation row has its parent in the training split;
    - the realised validation fraction is within 5% (relative) of the
      requested size.
    """
    prepare_default_dataset_causal_language_modeling(tmp_path)
    assert len(os.listdir(tmp_path)) > 0, tmp_path

    # Fail loudly if no parquet file was written. Otherwise train_dataframe
    # would silently remain an auto-created MagicMock attribute. Sorting makes
    # the choice deterministic should several files ever be produced.
    parquet_files = sorted(tmp_path.glob("*.pq"))
    assert parquet_files, f"no .pq file produced in {tmp_path}"

    cfg_mock = MagicMock()
    cfg_mock.dataset.train_dataframe = str(parquet_files[-1])
    cfg_mock.dataset.system_column = "None"
    cfg_mock.dataset.prompt_column = ("instruction",)
    cfg_mock.dataset.answer_column = "output"
    cfg_mock.dataset.parent_id_column = "parent_id"
    cfg_mock.dataset.validation_strategy = "automatic"

    for validation_size in [0.05, 0.1, 0.2, 0.3, 0.4, 0.5]:
        cfg_mock.dataset.validation_size = validation_size
        train_df, val_df = load_train_valid_data(cfg_mock)
        assert set(train_df["parent_id"].dropna().values).isdisjoint(
            set(val_df["id"].dropna().values)
        )
        assert set(val_df["parent_id"].dropna().values).isdisjoint(
            set(train_df["id"].dropna().values)
        )
        assert (len(val_df) / (len(train_df) + len(val_df))) == pytest.approx(
            validation_size, rel=0.05
        )