medichat_assignment / tests /test_config.py
helamouri's picture
update model
eca6215
import pytest
import torch
from src.config import (MAX_SEQ_LENGTH, DTYPE, LOAD_IN_4BIT, DEVICE_MAP, EOS_TOKEN,
ALPACA_PROMPT_TEMPLATE, TRAIN_ARGS)
# Test that required configuration keys are present
def test_required_config_keys():
assert MAX_SEQ_LENGTH is not None, "MAX_SEQ_LENGTH is not set."
assert TRAIN_ARGS is not None, "TRAIN_ARGS is not set."
assert ALPACA_PROMPT_TEMPLATE is not None, "ALPACA_PROMPT_TEMPLATE is not set."
assert DEVICE_MAP is not None, "DEVICE_MAP is not set."
# Test that MAX_SEQ_LENGTH is a power of two
def test_max_seq_length():
assert isinstance(MAX_SEQ_LENGTH, int), "MAX_SEQ_LENGTH should be an integer."
assert MAX_SEQ_LENGTH > 0, "MAX_SEQ_LENGTH should be greater than 0."
assert (MAX_SEQ_LENGTH & (MAX_SEQ_LENGTH - 1)) == 0, "MAX_SEQ_LENGTH should be a power of two."
# Test that TRAIN_ARGS dictionary contains required fields and types
def test_train_args():
required_keys = [
"per_device_train_batch_size",
"gradient_accumulation_steps",
"warmup_steps",
"max_steps",
"learning_rate",
"fp16",
"bf16",
"logging_steps",
"optim",
"weight_decay",
"lr_scheduler_type",
"seed",
"output_dir"
]
for key in required_keys:
assert key in TRAIN_ARGS, f"Missing {key} in TRAIN_ARGS."
# Check types of specific fields
assert isinstance(TRAIN_ARGS["per_device_train_batch_size"], int), "per_device_train_batch_size should be an integer."
assert isinstance(TRAIN_ARGS["learning_rate"], float), "learning_rate should be a float."
assert isinstance(TRAIN_ARGS["output_dir"], str), "output_dir should be a string."
# Test that the DEVICE_MAP references a valid CUDA device
@pytest.mark.gpu
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
def test_device_map():
device = DEVICE_MAP.get('', None)
assert device is not None, "DEVICE_MAP should reference a CUDA device."
assert isinstance(device, int), "DEVICE_MAP should be an integer (CUDA device ID)."
assert torch.cuda.is_available(), "CUDA is not available, but DEVICE_MAP points to a CUDA device."
# Test that the EOS_TOKEN is set dynamically based on the tokenizer
def test_eos_token():
assert EOS_TOKEN is not None, "EOS_TOKEN should be dynamically set based on tokenizer."
# Test the ALPACA_PROMPT_TEMPLATE for expected formatting
def test_alpaca_prompt_template():
test_instruction = "Test Instruction"
test_input = "Test Input"
test_output = "Test Output"
formatted_prompt = ALPACA_PROMPT_TEMPLATE.format(test_instruction, test_input, test_output)
# Ensure that the prompt template contains the required placeholders
assert "{}" in formatted_prompt, "ALPACA_PROMPT_TEMPLATE should contain placeholders."
assert "###Instruction:" in formatted_prompt, "ALPACA_PROMPT_TEMPLATE should contain '###Instruction'."
assert "###Input:" in formatted_prompt, "ALPACA_PROMPT_TEMPLATE should contain '###Input'."
assert "###Response:" in formatted_prompt, "ALPACA_PROMPT_TEMPLATE should contain '###Response'."
# Test that the LOAD_IN_4BIT setting is a boolean
def test_load_in_4bit():
assert isinstance(LOAD_IN_4BIT, bool), "LOAD_IN_4BIT should be a boolean."
# Test for the DTYPE (should be None or a valid data type)
def test_dtype():
assert DTYPE is None or isinstance(DTYPE, type), "DTYPE should be None or a valid data type."