import pytest
import torch

from src.config import (
    MAX_SEQ_LENGTH,
    DTYPE,
    LOAD_IN_4BIT,
    DEVICE_MAP,
    EOS_TOKEN,
    ALPACA_PROMPT_TEMPLATE,
    TRAIN_ARGS,
)
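
# The constants imported above are assumed to look roughly like the following
# (a hypothetical sketch of src/config.py, not the actual file):
#
#     MAX_SEQ_LENGTH = 2048                 # context window; a power of two
#     DTYPE = None                          # or e.g. torch.float16
#     LOAD_IN_4BIT = True                   # 4-bit quantized loading
#     DEVICE_MAP = {"": 0}                  # whole model on CUDA device 0
#     EOS_TOKEN = tokenizer.eos_token       # filled in once the tokenizer loads
#     ALPACA_PROMPT_TEMPLATE = "...###Instruction:\n{}...###Input:\n{}...###Response:\n{}"
#     TRAIN_ARGS = {"per_device_train_batch_size": 2, "learning_rate": 2e-4, ...}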


# Test that the required configuration values are present.
def test_required_config_keys():
    assert MAX_SEQ_LENGTH is not None, "MAX_SEQ_LENGTH is not set."
    assert TRAIN_ARGS is not None, "TRAIN_ARGS is not set."
    assert ALPACA_PROMPT_TEMPLATE is not None, "ALPACA_PROMPT_TEMPLATE is not set."
    assert DEVICE_MAP is not None, "DEVICE_MAP is not set."


# Test that MAX_SEQ_LENGTH is a positive power of two.
def test_max_seq_length():
    assert isinstance(MAX_SEQ_LENGTH, int), "MAX_SEQ_LENGTH should be an integer."
    assert MAX_SEQ_LENGTH > 0, "MAX_SEQ_LENGTH should be greater than 0."
    assert (MAX_SEQ_LENGTH & (MAX_SEQ_LENGTH - 1)) == 0, "MAX_SEQ_LENGTH should be a power of two."
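
# Worked example of the bitwise check above: n & (n - 1) clears the lowest set
# bit of n, so the result is 0 only when n has exactly one set bit, i.e. when
# n is a power of two (given n > 0). For instance,
# 2048 & 2047 == 0b100000000000 & 0b011111111111 == 0, while 3000 & 2999 != 0.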


# Test that the TRAIN_ARGS dictionary contains the required fields and types.
def test_train_args():
    required_keys = [
        "per_device_train_batch_size",
        "gradient_accumulation_steps",
        "warmup_steps",
        "max_steps",
        "learning_rate",
        "fp16",
        "bf16",
        "logging_steps",
        "optim",
        "weight_decay",
        "lr_scheduler_type",
        "seed",
        "output_dir",
    ]
    for key in required_keys:
        assert key in TRAIN_ARGS, f"Missing {key} in TRAIN_ARGS."
    # Check the types of specific fields.
    assert isinstance(TRAIN_ARGS["per_device_train_batch_size"], int), "per_device_train_batch_size should be an integer."
    assert isinstance(TRAIN_ARGS["learning_rate"], float), "learning_rate should be a float."
    assert isinstance(TRAIN_ARGS["output_dir"], str), "output_dir should be a string."


# Test that DEVICE_MAP references a valid CUDA device.
def test_device_map():
    device = DEVICE_MAP.get("", None)
    assert device is not None, "DEVICE_MAP should reference a CUDA device."
    assert isinstance(device, int), "DEVICE_MAP should be an integer (CUDA device ID)."
    # Missing CUDA hardware is an environment issue, not a config bug, so skip
    # rather than fail when no GPU is present.
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available; cannot validate the CUDA device ID.")
    assert device < torch.cuda.device_count(), "DEVICE_MAP points to a nonexistent CUDA device."
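
# For context (a hypothetical usage sketch): a device map of {"": 0} assigns
# the entire model (the empty-string root module) to CUDA device 0 at load
# time, e.g.:
#
#     model = AutoModelForCausalLM.from_pretrained(model_name, device_map=DEVICE_MAP)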


# Test that EOS_TOKEN is set dynamically from the tokenizer.
def test_eos_token():
    assert EOS_TOKEN is not None, "EOS_TOKEN should be dynamically set based on the tokenizer."


# Test ALPACA_PROMPT_TEMPLATE for expected placeholders and section headers.
def test_alpaca_prompt_template():
    test_instruction = "Test Instruction"
    test_input = "Test Input"
    test_output = "Test Output"
    formatted_prompt = ALPACA_PROMPT_TEMPLATE.format(test_instruction, test_input, test_output)
    # The raw template must contain positional placeholders; after .format(),
    # they are consumed, so check the substituted values in the result instead.
    assert "{}" in ALPACA_PROMPT_TEMPLATE, "ALPACA_PROMPT_TEMPLATE should contain '{}' placeholders."
    assert test_instruction in formatted_prompt, "Formatted prompt should contain the instruction."
    assert test_input in formatted_prompt, "Formatted prompt should contain the input."
    assert test_output in formatted_prompt, "Formatted prompt should contain the output."
    # The template's section headers should survive formatting.
    assert "###Instruction:" in formatted_prompt, "ALPACA_PROMPT_TEMPLATE should contain '###Instruction:'."
    assert "###Input:" in formatted_prompt, "ALPACA_PROMPT_TEMPLATE should contain '###Input:'."
    assert "###Response:" in formatted_prompt, "ALPACA_PROMPT_TEMPLATE should contain '###Response:'."


# Test that the LOAD_IN_4BIT setting is a boolean.
def test_load_in_4bit():
    assert isinstance(LOAD_IN_4BIT, bool), "LOAD_IN_4BIT should be a boolean."


# Test DTYPE: it should be None (auto-detect) or a torch dtype.
def test_dtype():
    # torch dtypes such as torch.float16 are instances of torch.dtype, not of
    # Python's built-in `type`, so isinstance(DTYPE, type) would be the wrong check.
    assert DTYPE is None or isinstance(DTYPE, torch.dtype), "DTYPE should be None or a torch.dtype."