| | """ |
| | モデル関連のテスト |
| | """ |
| | import pytest |
| | import torch |
| |
|
| | from src.models.base import ModelConfig, BaseLanguageModel |
| | from src.models.registry import ModelRegistry, DEFAULT_MODEL_KEY |
| | from src.models.gpt2 import GPT2Model, GPT2_SMALL_CONFIG |
| |
|
| | |
| | from src.models.gpt_oss import GPTOSSModel, GPT_OSS_20B_CONFIG |
| | from src.models.pythia import PythiaModel, PYTHIA_410M_CONFIG, PYTHIA_1B_CONFIG |
| | from src.models.olmo import OLMoModel, OLMO_1B_CONFIG, OLMO_7B_CONFIG |
| | from src.models.bloom import BLOOMModel, BLOOM_560M_CONFIG |
| |
|
| | |
| | from src.models.llama import LlamaModel, LLAMA_3_2_1B_CONFIG, LLAMA_3_2_3B_CONFIG |
| | from src.models.qwen import QwenModel, QWEN_2_5_0_5B_CONFIG, QWEN_2_5_1_5B_CONFIG |
| | from src.models.mistral import MistralModel, MISTRAL_7B_CONFIG |
| |
|
| |
|
class TestModelConfig:
    """Tests for constructing ModelConfig and for its immutability."""

    def test_config_is_immutable(self):
        """Assigning to a field of an existing config must raise."""
        cfg = ModelConfig(
            name="Test",
            model_id="test",
            embedding_dim=768,
            vocab_size=50000,
        )
        with pytest.raises(Exception):
            cfg.name = "Changed"

    def test_config_attributes(self):
        """Every constructor argument is exposed unchanged as an attribute."""
        expected = {
            "name": "Test Model",
            "model_id": "test-model",
            "embedding_dim": 1024,
            "vocab_size": 30000,
        }
        cfg = ModelConfig(**expected)
        for field_name, value in expected.items():
            assert getattr(cfg, field_name) == value
| |
|
| |
|
class TestModelRegistry:
    """Tests for the ModelRegistry lookup API."""

    def test_list_models(self):
        """The registry lists at least one model, including the default key."""
        models = ModelRegistry.list_models()
        assert len(models) > 0
        assert DEFAULT_MODEL_KEY in models

    def test_get_model(self):
        """Looking up the default key returns a BaseLanguageModel instance."""
        model = ModelRegistry.get(DEFAULT_MODEL_KEY)
        assert isinstance(model, BaseLanguageModel)

    def test_get_nonexistent_model(self):
        """Looking up an unknown key raises KeyError."""
        with pytest.raises(KeyError):
            ModelRegistry.get("nonexistent-model")

    def test_get_config(self):
        """The default key resolves to a ModelConfig."""
        config = ModelRegistry.get_config(DEFAULT_MODEL_KEY)
        assert config is not None
        assert isinstance(config, ModelConfig)

    def test_get_all_configs(self):
        """Every entry returned by get_all_configs() is a ModelConfig."""
        configs = ModelRegistry.get_all_configs()
        assert len(configs) > 0
        for key, config in configs.items():
            # Include the key so a failure points at the offending registry entry.
            assert isinstance(config, ModelConfig), (
                f"config for {key!r} is {type(config).__name__}, not ModelConfig"
            )
| |
|
| |
|
class TestGPT2Model:
    """Unit tests for GPT2Model that never call load()."""

    @staticmethod
    def _fresh():
        """Construct a GPT2Model from the small config without loading it."""
        return GPT2Model(GPT2_SMALL_CONFIG)

    def test_config(self):
        """The model keeps the config it was built with (768-dim embeddings)."""
        model = self._fresh()
        assert model.config == GPT2_SMALL_CONFIG
        assert model.config.embedding_dim == 768

    def test_is_loaded_initial(self):
        """A freshly constructed model reports is_loaded as falsy."""
        assert not self._fresh().is_loaded

    def test_generate_noise(self):
        """Noise has shape (batch_size, seq_len, 768)."""
        expected_shape = (2, 16, 768)
        noise = self._fresh().generate_noise(seq_len=16, batch_size=2)
        assert noise.shape == expected_shape
| |
|
| |
|
@pytest.mark.slow
class TestGPT2ModelIntegration:
    """Integration tests for GPT2Model that call load() (marked slow)."""

    @pytest.fixture(scope="class")
    def loaded_model(self):
        """Load the model once and share it across all tests in this class.

        Loading is the expensive step. The tests below only call is_loaded,
        generate_noise, forward_with_noise and decode_indices, so one load per
        class suffices. NOTE(review): switch back to function scope if
        forward_with_noise is ever found to mutate model state.
        """
        model = GPT2Model(GPT2_SMALL_CONFIG)
        model.load()
        return model

    def test_load(self, loaded_model):
        """After load() the model reports itself as loaded."""
        assert loaded_model.is_loaded

    def test_forward_with_noise(self, loaded_model):
        """forward_with_noise returns logits of shape (1, seq_len, vocab_size)."""
        noise = loaded_model.generate_noise(seq_len=8)
        # Only the clean logits are checked here; the second output is discarded.
        logits, _corrupted_logits = loaded_model.forward_with_noise(noise)

        # generate_noise is called without batch_size; the test expects batch 1.
        assert logits.shape[0] == 1
        assert logits.shape[1] == 8
        assert logits.shape[2] == loaded_model.config.vocab_size

    def test_decode_indices(self, loaded_model):
        """decode_indices returns one string per input index."""
        indices = [100, 200, 300]
        decoded = loaded_model.decode_indices(indices)

        assert len(decoded) == 3
        assert all(isinstance(s, str) for s in decoded)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class TestGPTOSSModel:
    """Unit tests for GPTOSSModel built from the 20B config; load() is never called."""

    @staticmethod
    def _build():
        """Construct a GPTOSSModel for the 20B config without loading it."""
        return GPTOSSModel(GPT_OSS_20B_CONFIG)

    def test_config(self):
        """The 20B config has 4096-dim embeddings and a 128000-token vocabulary."""
        cfg = self._build().config
        assert cfg == GPT_OSS_20B_CONFIG
        assert cfg.embedding_dim == 4096
        assert cfg.vocab_size == 128000

    def test_is_loaded_initial(self):
        """A freshly constructed model reports is_loaded as falsy."""
        instance = self._build()
        assert not instance.is_loaded

    def test_generate_noise(self):
        """Noise has shape (batch_size, seq_len, 4096)."""
        noise = self._build().generate_noise(seq_len=16, batch_size=2)
        assert noise.shape == (2, 16, 4096)
| |
|
| |
|
class TestPythiaModel:
    """Unit tests for PythiaModel that never call load()."""

    @staticmethod
    def _assert_config(config, embedding_dim, vocab_size):
        """Build a PythiaModel from *config* and check its stored dimensions."""
        model = PythiaModel(config)
        assert model.config == config
        assert model.config.embedding_dim == embedding_dim
        assert model.config.vocab_size == vocab_size

    def test_config_410m(self):
        """Pythia 410M: 1024-dim embeddings, 50304-token vocabulary."""
        self._assert_config(PYTHIA_410M_CONFIG, 1024, 50304)

    def test_config_1b(self):
        """Pythia 1B: 2048-dim embeddings, 50304-token vocabulary."""
        self._assert_config(PYTHIA_1B_CONFIG, 2048, 50304)

    def test_is_loaded_initial(self):
        """A freshly constructed model reports is_loaded as falsy."""
        assert not PythiaModel(PYTHIA_410M_CONFIG).is_loaded

    def test_generate_noise(self):
        """Noise for the 410M model has shape (batch_size, seq_len, 1024)."""
        sample = PythiaModel(PYTHIA_410M_CONFIG).generate_noise(seq_len=16, batch_size=2)
        assert sample.shape == (2, 16, 1024)
| |
|
| |
|
class TestOLMoModel:
    """Unit tests for OLMoModel that never call load()."""

    @staticmethod
    def _assert_config(config, embedding_dim, vocab_size):
        """Build an OLMoModel from *config* and check its stored dimensions."""
        model = OLMoModel(config)
        assert model.config == config
        assert model.config.embedding_dim == embedding_dim
        assert model.config.vocab_size == vocab_size

    def test_config_1b(self):
        """OLMo 1B: 2048-dim embeddings, 50304-token vocabulary."""
        self._assert_config(OLMO_1B_CONFIG, embedding_dim=2048, vocab_size=50304)

    def test_config_7b(self):
        """OLMo 7B: 4096-dim embeddings, 50304-token vocabulary."""
        self._assert_config(OLMO_7B_CONFIG, embedding_dim=4096, vocab_size=50304)

    def test_is_loaded_initial(self):
        """A freshly constructed model reports is_loaded as falsy."""
        fresh = OLMoModel(OLMO_1B_CONFIG)
        assert not fresh.is_loaded

    def test_generate_noise(self):
        """Noise for the 1B model has shape (batch_size, seq_len, 2048)."""
        fresh = OLMoModel(OLMO_1B_CONFIG)
        assert fresh.generate_noise(seq_len=16, batch_size=2).shape == (2, 16, 2048)
| |
|
| |
|
class TestBLOOMModel:
    """Unit tests for BLOOMModel built from the 560M config; load() is never called."""

    @staticmethod
    def _build():
        """Construct a BLOOMModel for the 560M config without loading it."""
        return BLOOMModel(BLOOM_560M_CONFIG)

    def test_config(self):
        """BLOOM 560M: 1024-dim embeddings, 250880-token vocabulary."""
        cfg = self._build().config
        assert cfg == BLOOM_560M_CONFIG
        assert cfg.embedding_dim == 1024
        assert cfg.vocab_size == 250880

    def test_is_loaded_initial(self):
        """A freshly constructed model reports is_loaded as falsy."""
        assert not self._build().is_loaded

    def test_generate_noise(self):
        """Noise has shape (batch_size, seq_len, 1024)."""
        batch, seq = 2, 16
        noise = self._build().generate_noise(seq_len=seq, batch_size=batch)
        assert noise.shape == (batch, seq, 1024)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class TestLlamaModel:
    """Unit tests for LlamaModel that never call load()."""

    @staticmethod
    def _config_of(config):
        """Build a LlamaModel from *config*, check it is stored as-is, and return it."""
        stored = LlamaModel(config).config
        assert stored == config
        return stored

    def test_config_1b(self):
        """Llama 3.2 1B: 2048-dim embeddings, 128256-token vocabulary."""
        cfg = self._config_of(LLAMA_3_2_1B_CONFIG)
        assert cfg.embedding_dim == 2048
        assert cfg.vocab_size == 128256

    def test_config_3b(self):
        """Llama 3.2 3B: 3072-dim embeddings, 128256-token vocabulary."""
        cfg = self._config_of(LLAMA_3_2_3B_CONFIG)
        assert cfg.embedding_dim == 3072
        assert cfg.vocab_size == 128256

    def test_is_loaded_initial(self):
        """A freshly constructed model reports is_loaded as falsy."""
        llama = LlamaModel(LLAMA_3_2_1B_CONFIG)
        assert not llama.is_loaded

    def test_generate_noise(self):
        """Noise for the 1B model has shape (batch_size, seq_len, 2048)."""
        llama = LlamaModel(LLAMA_3_2_1B_CONFIG)
        out = llama.generate_noise(seq_len=16, batch_size=2)
        assert out.shape == (2, 16, 2048)
| |
|
| |
|
class TestQwenModel:
    """Unit tests for QwenModel that never call load()."""

    @staticmethod
    def _config_of(config):
        """Build a QwenModel from *config*, check it is stored as-is, and return it."""
        stored = QwenModel(config).config
        assert stored == config
        return stored

    def test_config_0_5b(self):
        """Qwen2.5 0.5B: 896-dim embeddings, 151936-token vocabulary."""
        cfg = self._config_of(QWEN_2_5_0_5B_CONFIG)
        assert cfg.embedding_dim == 896
        assert cfg.vocab_size == 151936

    def test_config_1_5b(self):
        """Qwen2.5 1.5B: 1536-dim embeddings, 151936-token vocabulary."""
        cfg = self._config_of(QWEN_2_5_1_5B_CONFIG)
        assert cfg.embedding_dim == 1536
        assert cfg.vocab_size == 151936

    def test_is_loaded_initial(self):
        """A freshly constructed model reports is_loaded as falsy."""
        qwen = QwenModel(QWEN_2_5_0_5B_CONFIG)
        assert not qwen.is_loaded

    def test_generate_noise(self):
        """Noise for the 0.5B model has shape (batch_size, seq_len, 896)."""
        qwen = QwenModel(QWEN_2_5_0_5B_CONFIG)
        out = qwen.generate_noise(seq_len=16, batch_size=2)
        assert out.shape == (2, 16, 896)
| |
|
| |
|
class TestMistralModel:
    """Unit tests for MistralModel built from the 7B config; load() is never called."""

    @staticmethod
    def _build():
        """Construct a MistralModel for the 7B config without loading it."""
        return MistralModel(MISTRAL_7B_CONFIG)

    def test_config(self):
        """Mistral 7B: 4096-dim embeddings, 32768-token vocabulary."""
        cfg = self._build().config
        assert cfg == MISTRAL_7B_CONFIG
        assert cfg.embedding_dim == 4096
        assert cfg.vocab_size == 32768

    def test_is_loaded_initial(self):
        """A freshly constructed model reports is_loaded as falsy."""
        mistral = self._build()
        assert not mistral.is_loaded

    def test_generate_noise(self):
        """Noise has shape (batch_size, seq_len, 4096)."""
        mistral = self._build()
        assert mistral.generate_noise(seq_len=16, batch_size=2).shape == (2, 16, 4096)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class TestModelRegistryNewModels:
    """Registry tests for the newly added model families."""

    # Single source of truth for the keys both tests below parametrize over.
    # Previously the same list was duplicated verbatim, which invited drift
    # (a model added to one test but forgotten in the other).
    NEW_MODEL_KEYS = (
        "gpt-oss-20b",
        "pythia-410m",
        "pythia-1b",
        "olmo-1b",
        "olmo-7b",
        "bloom-560m",
        "llama-3.2-1b",
        "llama-3.2-3b",
        "qwen2.5-0.5b",
        "qwen2.5-1.5b",
        "mistral-7b",
    )

    @pytest.mark.parametrize("model_key", NEW_MODEL_KEYS)
    def test_model_registered(self, model_key):
        """Each new model key appears in ModelRegistry.list_models()."""
        models = ModelRegistry.list_models()
        assert model_key in models

    @pytest.mark.parametrize("model_key", NEW_MODEL_KEYS)
    def test_model_instance_creation(self, model_key):
        """Each new model key yields a BaseLanguageModel that is not yet loaded."""
        model = ModelRegistry.get(model_key)
        assert isinstance(model, BaseLanguageModel)
        assert not model.is_loaded
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@pytest.mark.slow
class TestPythiaModelIntegration:
    """Integration tests for Pythia, using the smaller 410M model as representative."""

    @pytest.fixture(scope="class")
    def loaded_model(self):
        """Load the 410M model once and share it across all tests in this class.

        Loading is the expensive step. The tests below only call is_loaded,
        generate_noise, forward_with_noise and decode_indices, so one load per
        class suffices. NOTE(review): switch back to function scope if
        forward_with_noise is ever found to mutate model state.
        """
        model = PythiaModel(PYTHIA_410M_CONFIG)
        model.load()
        return model

    def test_load(self, loaded_model):
        """After load() the model reports itself as loaded."""
        assert loaded_model.is_loaded

    def test_forward_with_noise(self, loaded_model):
        """forward_with_noise returns logits of shape (1, seq_len, vocab_size)."""
        noise = loaded_model.generate_noise(seq_len=8)
        # Only the clean logits are checked here; the second output is discarded.
        logits, _corrupted_logits = loaded_model.forward_with_noise(noise)

        # generate_noise is called without batch_size; the test expects batch 1.
        assert logits.shape[0] == 1
        assert logits.shape[1] == 8
        assert logits.shape[2] == loaded_model.config.vocab_size

    def test_decode_indices(self, loaded_model):
        """decode_indices returns one string per input index."""
        indices = [100, 200, 300]
        decoded = loaded_model.decode_indices(indices)

        assert len(decoded) == 3
        assert all(isinstance(s, str) for s in decoded)
| |
|
| |
|
@pytest.mark.slow
class TestBLOOMModelIntegration:
    """Integration tests for BLOOMModel that call load() (marked slow)."""

    @pytest.fixture(scope="class")
    def loaded_model(self):
        """Load the 560M model once and share it across all tests in this class.

        Loading is the expensive step. The tests below only call is_loaded,
        generate_noise and forward_with_noise, so one load per class
        suffices. NOTE(review): switch back to function scope if
        forward_with_noise is ever found to mutate model state.
        """
        model = BLOOMModel(BLOOM_560M_CONFIG)
        model.load()
        return model

    def test_load(self, loaded_model):
        """After load() the model reports itself as loaded."""
        assert loaded_model.is_loaded

    def test_forward_with_noise(self, loaded_model):
        """forward_with_noise returns logits of shape (1, seq_len, vocab_size)."""
        noise = loaded_model.generate_noise(seq_len=8)
        # Only the clean logits are checked here; the second output is discarded.
        logits, _corrupted_logits = loaded_model.forward_with_noise(noise)

        # generate_noise is called without batch_size; the test expects batch 1.
        assert logits.shape[0] == 1
        assert logits.shape[1] == 8
        assert logits.shape[2] == loaded_model.config.vocab_size
| |
|
| |
|
@pytest.mark.slow
class TestQwenModelIntegration:
    """Integration tests for Qwen, using the smaller 0.5B model as representative."""

    @pytest.fixture(scope="class")
    def loaded_model(self):
        """Load the 0.5B model once and share it across all tests in this class.

        Loading is the expensive step. The tests below only call is_loaded,
        generate_noise and forward_with_noise, so one load per class
        suffices. NOTE(review): switch back to function scope if
        forward_with_noise is ever found to mutate model state.
        """
        model = QwenModel(QWEN_2_5_0_5B_CONFIG)
        model.load()
        return model

    def test_load(self, loaded_model):
        """After load() the model reports itself as loaded."""
        assert loaded_model.is_loaded

    def test_forward_with_noise(self, loaded_model):
        """forward_with_noise returns logits of shape (1, seq_len, vocab_size)."""
        noise = loaded_model.generate_noise(seq_len=8)
        # Only the clean logits are checked here; the second output is discarded.
        logits, _corrupted_logits = loaded_model.forward_with_noise(noise)

        # generate_noise is called without batch_size; the test expects batch 1.
        assert logits.shape[0] == 1
        assert logits.shape[1] == 8
        assert logits.shape[2] == loaded_model.config.vocab_size
| |
|