"""
custom_modeling.py – model-agnostic toxicity wrapper
----------------------------------------------------
Place in repo root together with:
  • toxic.keras
Add to config.json:
  "auto_map": { "AutoModelForCausalLM": "custom_modeling.SafeGenerationModel" }
"""

import importlib
from functools import lru_cache

import torch
import transformers
import tensorflow as tf
from huggingface_hub import hf_hub_download


# ------------------------------------------------------------------ #
# 1)  MIXIN – toxicity filtering logic                               #
# ------------------------------------------------------------------ #
class _SafeGenerationMixin:
    _toxicity_model = None
    _tox_threshold   = 0.6

    # Canned replacement messages: one for toxic prompts, one for toxic outputs.
    _safe_in_msg  = "Sorry, I can’t help with that request."
    _safe_out_msg = "I’m sorry, but I can’t continue with that."

    _tokenizer = None

    # ---- helpers ----------------------------------------------------
    def _device(self):
        return next(self.parameters()).device

    @property
    def _tox_model(self):
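        # Lazily download and cache the bundled Keras classifier on first use.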
        if self._toxicity_model is None:
            path = hf_hub_download(
                repo_id=self.config.name_or_path,
                filename="toxic.keras",
            )
            self._toxicity_model = tf.keras.models.load_model(path, compile=False)
        return self._toxicity_model

    def _ensure_tokenizer(self):
        if self._tokenizer is None:
            try:
                self._tokenizer = transformers.AutoTokenizer.from_pretrained(
                    self.config.name_or_path, trust_remote_code=True
                )
            except Exception:
                # No tokenizer available: prompt/output filtering is skipped
                # in generate(), and _safe_ids() raises if it is ever needed.
                pass

    def _is_toxic(self, text: str) -> bool:
        """Return True when the classifier scores *text* at or above the threshold."""
        if not text.strip():
            return False
        inputs = tf.constant([text], dtype=tf.string)
        prob   = float(self._tox_model.predict(inputs)[0, 0])
        return prob >= self._tox_threshold
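
    # A compatible `toxic.keras` is any Keras model mapping a 1-D string tensor
    # to a (batch, 1) probability. Minimal sketch of one (hypothetical, not part
    # of this repo; training data and classifier quality are out of scope):
    #
    #   vec = tf.keras.layers.TextVectorization(max_tokens=20_000, output_mode="multi_hot")
    #   vec.adapt(training_texts)                     # training_texts: list[str]
    #   clf = tf.keras.Sequential([
    #       tf.keras.layers.Input(shape=(), dtype=tf.string),
    #       vec,
    #       tf.keras.layers.Dense(1, activation="sigmoid"),
    #   ])
    #   clf.save("toxic.keras")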

    def _safe_ids(self, message: str, length: int | None = None):
        """Encode *message* and pad/truncate to *length* tokens (if given)."""
        self._ensure_tokenizer()
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer unavailable for safe-message encoding.")

        ids = self._tokenizer(message, return_tensors="pt")["input_ids"][0]
        if length is not None:
            pad_id = (
                self.config.eos_token_id
                if self.config.eos_token_id is not None
                else (self.config.pad_token_id or 0)
            )
            if ids.size(0) < length:
                ids = torch.cat(
                    [ids, ids.new_full((length - ids.size(0),), pad_id)], dim=0
                )
            else:
                ids = ids[:length]
        return ids.to(self._device())

    # ---- main override ---------------------------------------------
    def generate(self, *args, **kwargs):
        self._ensure_tokenizer()

        # 1) prompt toxicity (only the first sequence in the batch is inspected)
        prompt_txt = None
        if self._tokenizer is not None:
            if "input_ids" in kwargs:
                prompt_txt = self._tokenizer.decode(
                    kwargs["input_ids"][0].tolist(), skip_special_tokens=True
                )
            elif args:
                prompt_txt = self._tokenizer.decode(
                    args[0][0].tolist(), skip_special_tokens=True
                )

        if prompt_txt and self._is_toxic(prompt_txt):
            return self._safe_ids(self._safe_in_msg).unsqueeze(0)

        # 2) normal generation
        outputs = super().generate(*args, **kwargs)

        # 3) output toxicity
        if self._tokenizer is None:
            return outputs

        # generate() can return a GenerateOutput object rather than a plain
        # tensor (e.g. with return_dict_in_generate=True); filter either way.
        sequences = outputs.sequences if hasattr(outputs, "sequences") else outputs

        new_seqs = []
        for seq in sequences.detach().cpu():
            txt = self._tokenizer.decode(seq.tolist(), skip_special_tokens=True)
            if self._is_toxic(txt):
                # .cpu() keeps every entry on the same device before stacking.
                new_seqs.append(
                    self._safe_ids(self._safe_out_msg, length=seq.size(0)).cpu()
                )
            else:
                new_seqs.append(seq)
        filtered = torch.stack(new_seqs, dim=0).to(self._device())

        if hasattr(outputs, "sequences"):
            outputs.sequences = filtered
            return outputs
        return filtered


# ------------------------------------------------------------------ #
# 2)  utilities: resolve base class & cache subclass                 #
# ------------------------------------------------------------------ #
@lru_cache(None)
def _get_base_cls(arch: str):
    """Resolve an architecture name from config.json to its transformers class."""
    if hasattr(transformers, arch):
        return getattr(transformers, arch)
    # Fallback convention: "FooForCausalLM" lives in
    # transformers.models.foo.modeling_foo. This holds for most architectures,
    # though not for module names that contain underscores.
    stem = arch.replace("ForCausalLM", "").lower()
    module = importlib.import_module(f"transformers.models.{stem}.modeling_{stem}")
    return getattr(module, arch)


@lru_cache(None)
def _make_safe_subclass(base_cls):
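    # The mixin is listed first so Python's MRO dispatches generate() to
    # _SafeGenerationMixin before the base model's implementation.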
    return type(
        f"SafeGeneration_{base_cls.__name__}",
        (_SafeGenerationMixin, base_cls),
        {},
    )


# ------------------------------------------------------------------ #
# 3)  Dispatcher class – referenced by auto_map                      #
# ------------------------------------------------------------------ #
class SafeGenerationModel:
    """Dispatcher referenced by auto_map: resolves the real architecture,
    wraps it with the toxicity mixin, then loads the weights."""

    @classmethod
    def from_pretrained(cls, repo_id, *model_args, **kwargs):
        kwargs.setdefault("trust_remote_code", True)
        # Drop "auto" so the wrapped from_pretrained falls back to its default dtype.
        if kwargs.get("torch_dtype") == "auto":
            kwargs.pop("torch_dtype")

        config = transformers.AutoConfig.from_pretrained(repo_id, **kwargs)
        if not getattr(config, "architectures", None):
            raise ValueError("`config.architectures` missing in config.json.")
        arch_str = config.architectures[0]

        Base = _get_base_cls(arch_str)
        Safe = _make_safe_subclass(Base)

        kwargs.pop("config", None)    # we pass the freshly resolved config below
        return Safe.from_pretrained(repo_id, *model_args, config=config, **kwargs)
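

# ------------------------------------------------------------------ #
# 4)  Optional smoke test (hypothetical; not used by auto_map)       #
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    # Minimal sketch: pass the id of a repo that ships this file plus
    # toxic.keras, e.g. `python custom_modeling.py some-org/some-model`.
    import sys

    repo = sys.argv[1]
    model = SafeGenerationModel.from_pretrained(repo)
    tok = transformers.AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    enc = tok("Hello!", return_tensors="pt").to(model.device)
    out = model.generate(**enc, max_new_tokens=20)
    print(tok.decode(out[0], skip_special_tokens=True))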