"""Tests for train/validation splitting in ``load_train_valid_data``.

Covers the "custom" validation strategy (separate validation dataframe) and
the "automatic" strategy, which must keep whole conversation chains within a
single split.
"""

import os
import pathlib
import random
import unittest
from unittest.mock import MagicMock

import pandas as pd
import pytest

from llm_studio.app_utils.default_datasets import (
    prepare_default_dataset_causal_language_modeling,
)
from llm_studio.src.datasets.conversation_chain_handler import ConversationChainHandler
from llm_studio.src.utils.data_utils import load_train_valid_data


@pytest.fixture
def cfg_mock():
    """Minimal config mock carrying the dataset fields the loader reads."""
    cfg = MagicMock()
    cfg.dataset.train_dataframe = "/path/to/train/data"
    cfg.dataset.validation_dataframe = "/path/to/validation/data"
    cfg.dataset.system_column = "None"
    cfg.dataset.prompt_column = "prompt"
    cfg.dataset.answer_column = "answer"
    cfg.dataset.validation_size = 0.2
    return cfg


@pytest.fixture
def read_dataframe_drop_missing_labels_mock(monkeypatch):
    """Patch the dataframe reader so every read returns a fixed 100-row frame."""
    data = {
        "prompt": [f"Prompt{i}" for i in range(100)],
        "answer": [f"Answer{i}" for i in range(100)],
        "id": list(range(100)),
    }
    df = pd.DataFrame(data)
    mock = MagicMock(return_value=df)
    monkeypatch.setattr(
        "llm_studio.src.utils.data_utils.read_dataframe_drop_missing_labels", mock
    )
    return mock


# Shuffled ids 0..99 partitioned into 13 interleaved groups. These stand in
# for conversation chains: rows of one group must never be split between the
# train and validation sets.
numbers = list(range(100))
random.shuffle(numbers)
groups = [numbers[n::13] for n in range(13)]


@pytest.fixture
def conversation_chain_ids_mock(monkeypatch):
    """Force ConversationChainHandler to report the precomputed id groups."""

    def mocked_init(self, *args, **kwargs):
        self.conversation_chain_ids = groups

    with unittest.mock.patch.object(
        ConversationChainHandler, "__init__", new=mocked_init
    ):
        yield


def test_get_data_custom_validation_strategy(
    cfg_mock, read_dataframe_drop_missing_labels_mock
):
    cfg_mock.dataset.validation_strategy = "custom"
    train_df, val_df = load_train_valid_data(cfg_mock)
    # BUG FIX: the original `assert len(train_df), len(val_df) == 100` was
    # parsed as `assert len(train_df)` with the comparison used only as the
    # assertion *message*, so the length check never ran. Both frames come
    # from the mocked reader, which always yields 100 rows.
    assert len(train_df) == 100
    assert len(val_df) == 100


def test_get_data_automatic_split(
    cfg_mock, read_dataframe_drop_missing_labels_mock, conversation_chain_ids_mock
):
    cfg_mock.dataset.validation_strategy = "automatic"
    train_df, val_df = load_train_valid_data(cfg_mock)

    train_ids = set(train_df["id"].tolist())
    val_ids = set(val_df["id"].tolist())

    # Splits must be disjoint and together cover all 100 rows.
    assert len(train_ids.intersection(val_ids)) == 0
    assert len(train_ids) + len(val_ids) == 100

    # No conversation group may contribute rows to both splits.
    shared_groups = [
        i for i in groups if not train_ids.isdisjoint(i) and not val_ids.isdisjoint(i)
    ]
    assert len(shared_groups) == 0


def test_oasst_data_automatic_split(tmp_path: pathlib.Path):
    """End-to-end check on the default OASST dataset across validation sizes."""
    prepare_default_dataset_causal_language_modeling(tmp_path)
    assert len(os.listdir(tmp_path)) > 0, tmp_path

    cfg_mock = MagicMock()
    for file in os.listdir(tmp_path):
        if file.endswith(".pq"):
            cfg_mock.dataset.train_dataframe = os.path.join(tmp_path, file)
    cfg_mock.dataset.system_column = "None"
    cfg_mock.dataset.prompt_column = ("instruction",)
    cfg_mock.dataset.answer_column = "output"
    cfg_mock.dataset.parent_id_column = "parent_id"
    cfg_mock.dataset.validation_strategy = "automatic"

    for validation_size in [0.05, 0.1, 0.2, 0.3, 0.4, 0.5]:
        cfg_mock.dataset.validation_size = validation_size
        train_df, val_df = load_train_valid_data(cfg_mock)
        # Parent/child rows of a conversation must land in the same split:
        # no train row may reference a validation row, and vice versa.
        assert set(train_df["parent_id"].dropna().values).isdisjoint(
            set(val_df["id"].dropna().values)
        )
        assert set(val_df["parent_id"].dropna().values).isdisjoint(
            set(train_df["id"].dropna().values)
        )
        # Realized split ratio should track the requested size (5% tolerance).
        assert (len(val_df) / (len(train_df) + len(val_df))) == pytest.approx(
            validation_size, 0.05
        )