import os
import torch
import random
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from typing import Optional, List
from dataclasses import dataclass, field

# NEW: Import the central cleanup function
from .utils import dbg, cleanup_memory

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

@dataclass
class StableLLMConfig:
    hidden_dim: int
    num_layers: int
    layer_list: List[torch.nn.Module] = field(default_factory=list, repr=False)

class LLM:
    # __init__ and _populate_stable_config remain exactly as in the previous version.
    def __init__(self, model_id: str, device: str = "auto", seed: int = 42):
        self.model_id = model_id
        self.seed = seed
        self.set_all_seeds(self.seed)
        token = os.environ.get("HF_TOKEN")
        if not token and ("gemma" in model_id or "llama" in model_id):
            print("[WARN] No HF_TOKEN set...", flush=True)
        kwargs = {"torch_dtype": torch.bfloat16} if torch.cuda.is_available() else {}
        dbg(f"Loading tokenizer for '{model_id}'...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, token=token)
        dbg(f"Loading model '{model_id}' with kwargs: {kwargs}")
        self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, token=token, **kwargs)
        try:
            self.model.set_attn_implementation('eager')
            dbg("Successfully set attention implementation to 'eager'.")
        except Exception as e:
            print(f"[WARN] Could not set 'eager' attention: {e}.", flush=True)
        self.model.eval()
        self.config = self.model.config
        self.stable_config = self._populate_stable_config()
        print(f"[INFO] Model '{model_id}' loaded on device: {self.model.device}", flush=True)

    def _populate_stable_config(self) -> StableLLMConfig:
        hidden_dim = 0
        try:
            hidden_dim = self.model.get_input_embeddings().weight.shape[1]
        except AttributeError:
            hidden_dim = getattr(self.config, 'hidden_size', getattr(self.config, 'd_model', 0))
        num_layers = 0
        layer_list = []
        try:
            if hasattr(self.model, 'model') and hasattr(self.model.model, 'language_model') and hasattr(self.model.model.language_model, 'layers'):
                 layer_list = self.model.model.language_model.layers
            elif hasattr(self.model, 'model') and hasattr(self.model.model, 'layers'):
                 layer_list = self.model.model.layers
            elif hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'h'):
                 layer_list = self.model.transformer.h
            if layer_list:
                num_layers = len(layer_list)
        except (AttributeError, TypeError):
            pass
        if num_layers == 0:
            num_layers = getattr(self.config, 'num_hidden_layers', getattr(self.config, 'num_layers', 0))
        if hidden_dim <= 0 or num_layers <= 0 or not layer_list:
            dbg("--- CRITICAL: Failed to auto-determine model configuration. ---")
            dbg(self.model)
        assert hidden_dim > 0, "Could not determine hidden dimension."
        assert num_layers > 0, "Could not determine number of layers."
        assert layer_list, "Could not find the list of transformer layers."
        dbg(f"Populated stable config: hidden_dim={hidden_dim}, num_layers={num_layers}")
        return StableLLMConfig(hidden_dim=hidden_dim, num_layers=num_layers, layer_list=layer_list)

    def set_all_seeds(self, seed: int):
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        set_seed(seed)
        torch.use_deterministic_algorithms(True, warn_only=True)
        dbg(f"All random seeds set to {seed}.")

    @torch.no_grad()
    def generate_text(self, prompt: str, max_new_tokens: int, temperature: float) -> str:
        self.set_all_seeds(self.seed)
        messages = [{"role": "user", "content": prompt}]
        inputs = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)
        outputs = self.model.generate(
            inputs, max_new_tokens=max_new_tokens, temperature=temperature, do_sample=temperature > 0,
        )
        response_tokens = outputs[0, inputs.shape[-1]:]
        return self.tokenizer.decode(response_tokens, skip_special_tokens=True)

def get_or_load_model(model_id: str, seed: int) -> LLM:
    """Loads a fresh, isolated instance of the model on every call."""
    dbg(f"--- Force-reloading model '{model_id}' for total run isolation ---")
    cleanup_memory()  # Clean up memory *before* a new model is loaded.
    return LLM(model_id=model_id, seed=seed)

# NEW: Explicit function for releasing resources
def release_model(llm: Optional[LLM]):
    """
    Explicitly releases the resources of an LLM object and calls the central
    memory-cleanup function.
    """
    if llm is None:
        return
    dbg(f"Releasing model instance for '{llm.model_id}'.")
    del llm  # Drops only this local reference; the caller must also drop theirs for the model to be freed.
    cleanup_memory()
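
# --- Illustrative usage sketch (an assumption, not part of the original module) ---
# Shows the intended load -> generate -> release cycle. The model id, prompt, and
# generation settings below are placeholders; any Hugging Face model that ships a
# chat template should work. Because this module uses a relative import, run it as
# part of its package (e.g. `python -m your_package.your_module`), not as a script.
if __name__ == "__main__":
    llm = get_or_load_model("Qwen/Qwen2.5-0.5B-Instruct", seed=42)  # placeholder model id
    try:
        answer = llm.generate_text(
            "Explain in one sentence what a tokenizer does.",
            max_new_tokens=64,
            temperature=0.7,
        )
        print(answer)
    finally:
        release_model(llm)
        llm = None  # Drop the caller-side reference so cleanup_memory() can actually reclaim the model.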