import os
import torch
import random
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from typing import Optional, List
from dataclasses import dataclass, field
# NEW: import the central cleanup function
from .utils import dbg, cleanup_memory
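# Required for deterministic cuBLAS kernels once torch.use_deterministic_algorithms()
# is active; it must be set before the first CUDA context is created.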
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
@dataclass
class StableLLMConfig:
    """Normalized view of a loaded model: hidden size, layer count, and the decoder layer list."""
    hidden_dim: int
    num_layers: int
    layer_list: List[torch.nn.Module] = field(default_factory=list, repr=False)
class LLM:
    # __init__ and _populate_stable_config remain exactly as in the previous version.
    def __init__(self, model_id: str, device: str = "auto", seed: int = 42):
        self.model_id = model_id
        self.seed = seed
        self.set_all_seeds(self.seed)
        token = os.environ.get("HF_TOKEN")
        if not token and ("gemma" in model_id or "llama" in model_id):
            print("[WARN] No HF_TOKEN set; loading gated models such as Gemma or Llama may fail.", flush=True)
        kwargs = {"torch_dtype": torch.bfloat16} if torch.cuda.is_available() else {}
        dbg(f"Loading tokenizer for '{model_id}'...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, token=token)
        dbg(f"Loading model '{model_id}' with kwargs: {kwargs}")
        self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, token=token, **kwargs)
        try:
            # Prefer eager attention; fused kernels (SDPA/FlashAttention) can be non-deterministic.
            self.model.set_attn_implementation('eager')
            dbg("Successfully set attention implementation to 'eager'.")
        except Exception as e:
            print(f"[WARN] Could not set 'eager' attention: {e}.", flush=True)
        self.model.eval()
        self.config = self.model.config
        self.stable_config = self._populate_stable_config()
        print(f"[INFO] Model '{model_id}' loaded on device: {self.model.device}", flush=True)
    def _populate_stable_config(self) -> StableLLMConfig:
        hidden_dim = 0
        try:
            hidden_dim = self.model.get_input_embeddings().weight.shape[1]
        except AttributeError:
            hidden_dim = getattr(self.config, 'hidden_size', getattr(self.config, 'd_model', 0))
        num_layers = 0
        layer_list = []
        try:
            # Probe the known module layouts in order: multimodal wrappers
            # (model.language_model.layers), Llama/Gemma-style (model.layers),
            # and GPT-2-style (transformer.h).
            if hasattr(self.model, 'model') and hasattr(self.model.model, 'language_model') and hasattr(self.model.model.language_model, 'layers'):
                layer_list = self.model.model.language_model.layers
            elif hasattr(self.model, 'model') and hasattr(self.model.model, 'layers'):
                layer_list = self.model.model.layers
            elif hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'h'):
                layer_list = self.model.transformer.h
            if layer_list:
                num_layers = len(layer_list)
        except (AttributeError, TypeError):
            pass
        if num_layers == 0:
            num_layers = getattr(self.config, 'num_hidden_layers', getattr(self.config, 'num_layers', 0))
        if hidden_dim <= 0 or num_layers <= 0 or not layer_list:
            dbg("--- CRITICAL: Failed to auto-determine model configuration. ---")
            dbg(self.model)
        assert hidden_dim > 0, "Could not determine hidden dimension."
        assert num_layers > 0, "Could not determine number of layers."
        assert layer_list, "Could not find the list of transformer layers."
        dbg(f"Populated stable config: hidden_dim={hidden_dim}, num_layers={num_layers}")
        return StableLLMConfig(hidden_dim=hidden_dim, num_layers=num_layers, layer_list=layer_list)
    def set_all_seeds(self, seed: int):
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        set_seed(seed)
        torch.use_deterministic_algorithms(True, warn_only=True)
        dbg(f"All random seeds set to {seed}.")
    @torch.no_grad()
    def generate_text(self, prompt: str, max_new_tokens: int, temperature: float) -> str:
        self.set_all_seeds(self.seed)
        messages = [{"role": "user", "content": prompt}]
        inputs = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)
        do_sample = temperature > 0
        gen_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
        if do_sample:
            # Only pass temperature when sampling; temperature=0 with greedy
            # decoding triggers a transformers warning.
            gen_kwargs["temperature"] = temperature
        outputs = self.model.generate(inputs, **gen_kwargs)
        # Decode only the newly generated tokens, not the prompt.
        response_tokens = outputs[0, inputs.shape[-1]:]
        return self.tokenizer.decode(response_tokens, skip_special_tokens=True)
def get_or_load_model(model_id: str, seed: int) -> LLM:
"""Lädt bei jedem Aufruf eine frische, isolierte Instanz des Modells."""
dbg(f"--- Force-reloading model '{model_id}' for total run isolation ---")
cleanup_memory() # Bereinige Speicher, *bevor* ein neues Modell geladen wird.
return LLM(model_id=model_id, seed=seed)
# NEW: explicit function for releasing resources
def release_model(llm: Optional[LLM]):
"""
Gibt die Ressourcen eines LLM-Objekts explizit frei und ruft die zentrale
Speicherbereinigungs-Funktion auf.
"""
if llm is None:
return
dbg(f"Releasing model instance for '{llm.model_id}'.")
del llm
cleanup_memory()
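
# A minimal usage sketch, not part of the module's API. Assumptions: the file is
# run as part of its package (e.g. `python -m <package>.<module>`, because of the
# relative `.utils` import), and the model id below is illustrative; any chat
# model whose tokenizer ships a chat template works.
if __name__ == "__main__":
    llm = get_or_load_model("Qwen/Qwen2.5-0.5B-Instruct", seed=42)
    print(llm.generate_text("Say hello in one sentence.", max_new_tokens=32, temperature=0.0))
    release_model(llm)
    llm = None  # drop the caller-side reference so cleanup can actually free the model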