| """
|
| Tests for TouchGrass Trainer.
|
| """
|
|
|
| import pytest
|
| import torch
|
| from unittest.mock import MagicMock, patch
|
|
|
| from TouchGrass.training.trainer import TouchGrassTrainer
|
|
|
|
|
class TestTouchGrassTrainer:
    """Test suite for TouchGrassTrainer."""

    def setup_method(self):
        """Set up fresh mock fixtures before each test."""
        self.device = "cpu"
        self.d_model = 768
        self.vocab_size = 32000

        # Mock model exposing trainable parameters (read by the optimizer
        # and by gradient clipping).
        self.model = MagicMock()
        self.model.parameters.return_value = [torch.randn(10, requires_grad=True)]

        # Mock tokenizer; pad_token_id is typically consulted during batch prep.
        self.tokenizer = MagicMock()
        self.tokenizer.pad_token_id = 0

        # Mock loss function. The returned loss must require grad so that a
        # real ``loss.backward()`` inside the trainer does not raise
        # (a plain ``torch.tensor(0.5)`` would).
        self.loss_fn = MagicMock()
        self.loss_fn.return_value = {
            "total_loss": torch.tensor(0.5, requires_grad=True)
        }

        self.optimizer = MagicMock()
        self.optimizer.step = MagicMock()
        self.optimizer.zero_grad = MagicMock()

        self.scheduler = MagicMock()
        self.scheduler.step = MagicMock()

        self.config = {
            "batch_size": 4,
            "gradient_accumulation_steps": 1,
            "learning_rate": 2e-4,
            "max_grad_norm": 1.0,
            "num_epochs": 1,
            "save_steps": 100,
            "eval_steps": 50,
            "output_dir": "test_output",
        }

    # ------------------------------------------------------------------
    # Private helpers (deduplicate the 17 copies of trainer construction
    # and the 10 copies of the batch literal in the original tests).
    # ------------------------------------------------------------------

    def _build_trainer(self, config=None, scheduler=None, device=None):
        """Construct a TouchGrassTrainer from the shared mock fixtures.

        ``scheduler`` is only forwarded when provided, so calls without it
        use exactly the same keyword shape as the original tests.
        """
        kwargs = {
            "model": self.model,
            "tokenizer": self.tokenizer,
            "loss_fn": self.loss_fn,
            "optimizer": self.optimizer,
            "config": self.config if config is None else config,
            "device": self.device if device is None else device,
        }
        if scheduler is not None:
            kwargs["scheduler"] = scheduler
        return TouchGrassTrainer(**kwargs)

    def _make_batch(self, batch_size=4, seq_len=10):
        """Return a random batch shaped like real tokenized training data."""
        return {
            "input_ids": torch.randint(0, self.vocab_size, (batch_size, seq_len)),
            "attention_mask": torch.ones(batch_size, seq_len),
            "labels": torch.randint(0, self.vocab_size, (batch_size, seq_len)),
        }

    def test_trainer_initialization(self):
        """Test trainer initialization."""
        trainer = self._build_trainer(scheduler=self.scheduler)

        assert trainer.model == self.model
        assert trainer.tokenizer == self.tokenizer
        assert trainer.loss_fn == self.loss_fn
        assert trainer.optimizer == self.optimizer
        assert trainer.scheduler == self.scheduler
        assert trainer.config == self.config

    def test_trainer_required_components(self):
        """Test that all required components are present."""
        trainer = self._build_trainer()

        assert hasattr(trainer, "train")
        assert hasattr(trainer, "evaluate")
        assert hasattr(trainer, "save_checkpoint")
        assert hasattr(trainer, "load_checkpoint")

    def test_prepare_batch(self):
        """Test batch preparation."""
        trainer = self._build_trainer()

        prepared = trainer._prepare_batch(self._make_batch())
        assert "input_ids" in prepared
        assert "attention_mask" in prepared
        assert "labels" in prepared

    def test_training_step(self):
        """Test single training step."""
        trainer = self._build_trainer()

        loss = trainer._training_step(self._make_batch())
        assert isinstance(loss, torch.Tensor) or loss is not None

    def test_evaluation_step(self):
        """Test single evaluation step."""
        trainer = self._build_trainer()

        metrics = trainer._evaluation_step(self._make_batch())
        assert isinstance(metrics, dict)

    def test_gradient_accumulation(self):
        """Test gradient accumulation."""
        config = self.config.copy()
        config["gradient_accumulation_steps"] = 2

        trainer = self._build_trainer(config=config)

        assert trainer.gradient_accumulation_steps == 2

    def test_checkpoint_saving(self, tmp_path):
        """Test checkpoint saving."""
        config = self.config.copy()
        config["output_dir"] = str(tmp_path / "checkpoints")

        trainer = self._build_trainer(config=config)

        trainer.save_checkpoint(step=100)

        # save_checkpoint is expected to create the configured output
        # directory; the original test invoked it but asserted nothing.
        assert (tmp_path / "checkpoints").exists()

    def test_learning_rate_scheduler_step(self):
        """Test that scheduler is stepped correctly."""
        config = self.config.copy()
        config["learning_rate"] = 1e-3

        trainer = self._build_trainer(config=config, scheduler=self.scheduler)

        trainer._training_step(self._make_batch())

        # The original test ran a step but never verified the scheduler.
        # With gradient_accumulation_steps == 1, a single training step
        # should advance the LR schedule.
        self.scheduler.step.assert_called()

    def test_gradient_clipping(self):
        """Test gradient clipping."""
        config = self.config.copy()
        config["max_grad_norm"] = 1.0

        trainer = self._build_trainer(config=config)

        assert trainer.max_grad_norm == 1.0

    def test_mixed_precision_flag(self):
        """Test mixed precision training flag."""
        config = self.config.copy()
        config["mixed_precision"] = True

        trainer = self._build_trainer(config=config)

        assert trainer.mixed_precision is True

    def test_device_assignment(self):
        """Test that model and data are moved to correct device."""
        trainer = self._build_trainer(device="cpu")

        assert trainer.device == "cpu"

    def test_optimizer_zero_grad_called(self):
        """Test that optimizer.zero_grad is called."""
        trainer = self._build_trainer()

        trainer._training_step(self._make_batch())

        self.optimizer.zero_grad.assert_called()

    def test_optimizer_step_called(self):
        """Test that optimizer.step is called."""
        trainer = self._build_trainer()

        trainer._training_step(self._make_batch())

        self.optimizer.step.assert_called()

    def test_loss_fn_called_with_outputs(self):
        """Test that loss function is called with model outputs."""
        trainer = self._build_trainer()

        trainer._training_step(self._make_batch())

        self.loss_fn.assert_called()

    def test_training_loop(self):
        """Test full training loop (simplified)."""
        trainer = self._build_trainer()

        train_dataloader = [self._make_batch()]
        eval_dataloader = [self._make_batch()]

        metrics = trainer.train(train_dataloader, eval_dataloader)
        assert isinstance(metrics, dict)

    def test_evaluation_loop(self):
        """Test evaluation loop."""
        trainer = self._build_trainer()

        eval_dataloader = [self._make_batch()]

        metrics = trainer.evaluate(eval_dataloader)
        assert isinstance(metrics, dict)

    def test_config_validation(self):
        """Test that config has required keys."""
        required_keys = ["batch_size", "learning_rate", "num_epochs", "output_dir"]

        # Dropping any required key must raise a ValueError naming that key.
        for key in required_keys:
            config = self.config.copy()
            del config[key]
            with pytest.raises(ValueError, match=key):
                self._build_trainer(config=config)

    def test_model_mode_training(self):
        """Test that model is set to training mode."""
        trainer = self._build_trainer()

        trainer._training_step(self._make_batch())

        self.model.train.assert_called()

    def test_model_mode_evaluation(self):
        """Test that model is set to eval mode during evaluation."""
        trainer = self._build_trainer()

        trainer._evaluation_step(self._make_batch())

        self.model.eval.assert_called()
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main returns an exit code; propagate it so failures are not
    # masked (the original discarded it, always exiting 0).
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|
|