"""
Pythia model implementation.

Fully open-source models released by EleutherAI; both the training data
(The Pile) and the architecture are publicly available.
"""
| | from typing import List, Tuple |
| |
|
| | import torch |
| | from transformers import GPTNeoXForCausalLM, AutoTokenizer |
| |
|
| | from .base import BaseLanguageModel, ModelConfig |
| |
|
| |
|
| | |
# Configuration for the 410M-parameter Pythia checkpoint.
# NOTE(review): embedding_dim/vocab_size appear to follow the published
# EleutherAI checkpoint dimensions — confirm against the model card.
PYTHIA_410M_CONFIG = ModelConfig(
    name="Pythia 410M",
    model_id="EleutherAI/pythia-410m",
    embedding_dim=1024,
    vocab_size=50304,
)
| |
|
| | |
# Configuration for the 1B-parameter Pythia checkpoint.
# NOTE(review): embedding_dim/vocab_size appear to follow the published
# EleutherAI checkpoint dimensions — confirm against the model card.
PYTHIA_1B_CONFIG = ModelConfig(
    name="Pythia 1B",
    model_id="EleutherAI/pythia-1b",
    embedding_dim=2048,
    vocab_size=50304,
)
| |
|
| |
|
class PythiaModel(BaseLanguageModel):
    """Pythia model implementation (GPT-NeoX based).

    Fully open-source model released by EleutherAI; the training data
    (The Pile) is publicly available as well.
    """

    # Scale factor applied to the Gaussian noise added to the output
    # logits in forward_with_noise().
    LOGITS_NOISE_SCALE = 10.0

    def load(self) -> None:
        """Load the model and tokenizer; no-op if already loaded.

        Raises:
            RuntimeError: If downloading/loading the pretrained weights
                or tokenizer fails.
        """
        if self._is_loaded:
            return

        try:
            self._model = GPTNeoXForCausalLM.from_pretrained(self._config.model_id)
            self._tokenizer = AutoTokenizer.from_pretrained(self._config.model_id)
            self._model.eval()
            self._is_loaded = True
        except Exception as e:
            # Chain the original exception ("from e") so the root cause
            # (network failure, missing files, ...) is not discarded.
            raise RuntimeError(
                f"Failed to load model {self._config.model_id}: {e}"
            ) from e

    def forward_with_noise(
        self, noise: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run a forward pass on a noise input and also corrupt the output.

        Args:
            noise: Input noise tensor, expected shape
                [batch, seq_len, embedding_dim] (must match the model's
                embedding dimension).

        Returns:
            Tuple of (logits, corrupted_logits), where corrupted_logits
            is logits plus Gaussian noise scaled by the logits' own
            standard deviation times LOGITS_NOISE_SCALE.

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")

        with torch.no_grad():
            outputs = self._model(inputs_embeds=noise)
            logits = outputs.logits

        # Noise amplitude tracks the spread of the logits themselves so
        # the corruption level is comparable across models/inputs.
        logits_noise = (
            torch.randn_like(logits) * logits.std() * self.LOGITS_NOISE_SCALE
        )
        corrupted_logits = logits + logits_noise

        return logits, corrupted_logits

    def decode_indices(self, indices: List[int]) -> List[str]:
        """Decode token indices one at a time into their string pieces.

        Args:
            indices: Token ids to decode individually.

        Returns:
            One decoded string per input id.

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")

        # Decode each id separately so per-token pieces are preserved
        # (batch decode would merge them into one string).
        return [self._tokenizer.decode([i]) for i in indices]
| |
|