"""
custom_modeling.py – model-agnostic toxicity wrapper
----------------------------------------------------
Place in repo root together with:
• toxic.keras
Add to config.json:
"auto_map": { "AutoModelForCausalLM": "custom_modeling.SafeGenerationModel" }
"""
import importlib
from functools import lru_cache
import torch
import transformers
import tensorflow as tf
from huggingface_hub import hf_hub_download
# ------------------------------------------------------------------ #
# 1) MIXIN – toxicity filtering logic #
# ------------------------------------------------------------------ #
class _SafeGenerationMixin:
    _toxicity_model = None
    _tox_threshold = 0.6
    # Separate refusal messages for toxic prompts vs. toxic completions
    _safe_in_msg = "Sorry, I can’t help with that request."
    _safe_out_msg = "I’m sorry, but I can’t continue with that."
    _tokenizer = None

    # ---- helpers ----------------------------------------------------
    def _device(self):
        return next(self.parameters()).device

    @property
    def _tox_model(self):
        # Lazily download and load the Keras toxicity classifier on first use.
        if self._toxicity_model is None:
            path = hf_hub_download(
                repo_id=self.config.name_or_path,
                filename="toxic.keras",
            )
            self._toxicity_model = tf.keras.models.load_model(path, compile=False)
        return self._toxicity_model

    def _ensure_tokenizer(self):
        if self._tokenizer is None:
            try:
                self._tokenizer = transformers.AutoTokenizer.from_pretrained(
                    self.config.name_or_path, trust_remote_code=True
                )
            except Exception:
                pass

    def _is_toxic(self, text: str) -> bool:
        if not text.strip():
            return False
        inputs = tf.constant([text], dtype=tf.string)
        prob = float(self._tox_model.predict(inputs)[0, 0])
        return prob >= self._tox_threshold

    def _safe_ids(self, message: str, length: int | None = None):
        """Encode *message* and pad/truncate to *length* tokens (if given)."""
        self._ensure_tokenizer()
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer unavailable for safe-message encoding.")
        ids = self._tokenizer(message, return_tensors="pt")["input_ids"][0]
        if length is not None:
            pad_id = (
                self.config.eos_token_id
                if self.config.eos_token_id is not None
                else (self.config.pad_token_id or 0)
            )
            if ids.size(0) < length:
                ids = torch.cat(
                    [ids, ids.new_full((length - ids.size(0),), pad_id)], dim=0
                )
            else:
                ids = ids[:length]
        return ids.to(self._device())

    # ---- main override ---------------------------------------------
    def generate(self, *args, **kwargs):
        self._ensure_tokenizer()
        # 1) prompt toxicity
        prompt_txt = None
        if self._tokenizer is not None:
            if "input_ids" in kwargs:
                prompt_txt = self._tokenizer.decode(
                    kwargs["input_ids"][0].tolist(), skip_special_tokens=True
                )
            elif args:
                prompt_txt = self._tokenizer.decode(
                    args[0][0].tolist(), skip_special_tokens=True
                )
        if prompt_txt and self._is_toxic(prompt_txt):
            return self._safe_ids(self._safe_in_msg).unsqueeze(0)
        # 2) normal generation
        outputs = super().generate(*args, **kwargs)
        # 3) output toxicity
        if self._tokenizer is None:
            return outputs
        if not isinstance(outputs, torch.Tensor):
            # e.g. return_dict_in_generate=True yields a ModelOutput; pass it through.
            return outputs
        new_seqs = []
        for seq in outputs.detach().cpu():
            txt = self._tokenizer.decode(seq.tolist(), skip_special_tokens=True)
            if self._is_toxic(txt):
                # keep the replacement on CPU so torch.stack sees a single device
                new_seqs.append(
                    self._safe_ids(self._safe_out_msg, length=seq.size(0)).cpu()
                )
            else:
                new_seqs.append(seq)
        return torch.stack(new_seqs, dim=0).to(self._device())

# ------------------------------------------------------------------ #
# 2) utilities: resolve base class & cache subclass #
# ------------------------------------------------------------------ #
@lru_cache(None)
def _get_base_cls(arch: str):
    if hasattr(transformers, arch):
        return getattr(transformers, arch)
    stem = arch.replace("ForCausalLM", "").lower()
    module = importlib.import_module(f"transformers.models.{stem}.modeling_{stem}")
    return getattr(module, arch)


@lru_cache(None)
def _make_safe_subclass(base_cls):
    return type(
        f"SafeGeneration_{base_cls.__name__}",
        (_SafeGenerationMixin, base_cls),
        {},
    )

# ------------------------------------------------------------------ #
# 3) Dispatcher class – referenced by auto_map #
# ------------------------------------------------------------------ #
class SafeGenerationModel:
    @classmethod
    def from_pretrained(cls, repo_id, *model_args, **kwargs):
        kwargs.setdefault("trust_remote_code", True)
        if kwargs.get("torch_dtype") == "auto":
            kwargs.pop("torch_dtype")
        config = transformers.AutoConfig.from_pretrained(repo_id, **kwargs)
        if not getattr(config, "architectures", None):
            raise ValueError("`config.architectures` missing in config.json.")
        arch_str = config.architectures[0]
        Base = _get_base_cls(arch_str)
        Safe = _make_safe_subclass(Base)
        kwargs.pop("config", None)  # avoid duplicate
        return Safe.from_pretrained(repo_id, *model_args, config=config, **kwargs)
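

# ------------------------------------------------------------------ #
# 4) local smoke test – minimal sketch, not part of the wrapper #
# ------------------------------------------------------------------ #
# The repo id below is a placeholder; run this only against a repo that
# actually ships toxic.keras, a tokenizer, and this file.
if __name__ == "__main__":
    repo = "your-username/your-safe-model"  # hypothetical repo id
    model = SafeGenerationModel.from_pretrained(repo)
    tok = transformers.AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    ids = tok("Hello, how are you?", return_tensors="pt")["input_ids"]
    out = model.generate(input_ids=ids, max_new_tokens=20)
    print(tok.decode(out[0], skip_special_tokens=True))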