| |
|
| | import os
|
| | from typing import Optional, List
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | from diffusers import DiffusionPipeline, StableDiffusionPipeline
|
| | from diffusers.utils import BaseOutput
|
| |
|
| | import torch, torch.nn as nn, os
|
| | from typing import Optional
|
| |
|
# Label order of the safety classifier's output logits. Index 3 ('safe')
# matches the default `safe_class_index=3` used by SafeDiffusionGuidance.
CLASS_NAMES = ['gore', 'hate', 'medical', 'safe', 'sexual']
|
| |
|
class SafetyClassifier1280(nn.Module):
    """Small CNN that classifies 1280-channel UNet mid-block features.

    Input is any (N, 1280, H, W) feature map; it is first pooled to a fixed
    8x8 spatial size, then passed through two conv stages and an MLP head.

    Args:
        num_classes: number of output classes (default 5, see CLASS_NAMES).
    """

    def __init__(self, num_classes: int = 5):
        super().__init__()
        # Normalize arbitrary spatial sizes to 8x8 so the conv stack below
        # sees a fixed input shape.
        self.pre = nn.AdaptiveAvgPool2d((8, 8))
        self.model = nn.Sequential(
            nn.Conv2d(1280, 512, 3, padding=1),
            nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            nn.Conv2d(512, 256, 3, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m: nn.Module) -> None:
        # BUGFIX: this method was referenced by `self.apply(self._init_weights)`
        # but never defined, so instantiating the class raised AttributeError.
        # Kaiming init for ReLU networks; actual values are overwritten when a
        # checkpoint is loaded, so this only needs to be sane, not special.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return (N, num_classes) logits for feature maps x of shape (N, 1280, H, W)."""
        x = self.pre(x)
        return self.model(x)
|
| |
|
| |
|
| | def _find_weights_path() -> str:
|
| |
|
| | env_p = os.getenv("SDG_CLASSIFIER_WEIGHTS")
|
| | if env_p and os.path.exists(env_p): return env_p
|
| | for p in ["safety_classifier_1280.pth", os.path.join("classifiers","safety_classifier_1280.pth")]:
|
| | if os.path.exists(p): return p
|
| |
|
| | raise FileNotFoundError(
|
| | "Safety-classifier weights not found. Provide via env SDG_CLASSIFIER_WEIGHTS, "
|
| | "place 'safety_classifier_1280.pth' at repo root or 'classifiers/', "
|
| | "or pass `classifier_weights=...` to the pipeline call."
|
| | )
|
| |
|
def load_classifier_1280(weights_path: str, device=None, dtype=torch.float32):
    """Build a SafetyClassifier1280, load *weights_path* into it, and return it in eval mode.

    Accepts either a raw state dict or a training checkpoint that wraps the
    state dict under the 'model_state_dict' key.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    state = torch.load(weights_path, map_location="cpu", weights_only=False)
    if isinstance(state, dict) and "model_state_dict" in state:
        state = state["model_state_dict"]

    classifier = SafetyClassifier1280().to(device or "cpu", dtype=dtype)
    classifier.load_state_dict(state, strict=True)
    classifier.eval()
    return classifier
|
| |
|
| | def _here(*paths: str) -> str:
|
| | return os.path.join(os.path.dirname(__file__), *paths)
|
| |
|
| |
|
def pick_weights_path() -> str:
    """Return the first existing weights file among the known locations.

    Checks the SDG_CLASSIFIER_WEIGHTS env override first, then paths relative
    to this module, then paths relative to the current working directory.

    Raises:
        FileNotFoundError: when no candidate path exists.
    """
    search_order = (
        os.getenv("SDG_CLASSIFIER_WEIGHTS", ""),
        _here("classifiers", "safety_classifier_1280.pth"),
        _here("safety_classifier_1280.pth"),
        "classifiers/safety_classifier_1280.pth",
        "safety_classifier_1280.pth",
    )
    found = next((c for c in search_order if c and os.path.exists(c)), None)
    if found is not None:
        return found
    raise FileNotFoundError(
        "Safety-classifier weights not found. Place 'safety_classifier_1280.pth' "
        "in repo root or 'classifiers/' (or set SDG_CLASSIFIER_WEIGHTS, or pass "
        "`classifier_weights=...` to the call())."
    )
|
| |
|
| |
|
| |
|
class SDGOutput(BaseOutput):
    """Return type of SafeDiffusionGuidance.__call__; `images` holds the generated PIL image(s)."""
    images: List
|
| |
|
| |
|
class SafeDiffusionGuidance(DiffusionPipeline):
    """
    Minimal custom pipeline that loads a base Stable Diffusion pipeline on demand
    and applies mid-UNet classifier-guided denoising for safety.

    For the first ``mid_fraction`` of the denoising schedule, each step first
    computes the gradient of an "unsafe probability" (from a small classifier
    on the UNet mid-block features) w.r.t. the latents and nudges the latents
    toward the 'safe' class, then performs the ordinary classifier-free
    guidance denoising step.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Heavy base pipeline is created lazily in _ensure_base().
        self.base_pipe_ = None

    def _ensure_base(
        self,
        base_pipe: Optional[StableDiffusionPipeline],
        base_model_id: str,
        torch_dtype: torch.dtype,
    ) -> StableDiffusionPipeline:
        """Return the base pipeline: caller-supplied > cached > freshly loaded.

        A caller-supplied ``base_pipe`` always wins and replaces the cache.
        """
        if base_pipe is not None:
            self.base_pipe_ = base_pipe
            return self.base_pipe_
        if self.base_pipe_ is None:
            self.base_pipe_ = StableDiffusionPipeline.from_pretrained(
                base_model_id,
                torch_dtype=torch_dtype,
                safety_checker=None,  # the custom mid-UNet guidance replaces it
                requires_safety_checker=False,
            ).to(self.device)
        return self.base_pipe_

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        safety_scale: float = 5.0,
        mid_fraction: float = 1.0,
        safe_class_index: int = 3,
        classifier_weights: Optional[str] = None,
        base_pipe: Optional[StableDiffusionPipeline] = None,
        base_model_id: str = "runwayml/stable-diffusion-v1-5",
        generator: Optional[torch.Generator] = None,
        **kwargs,
    ) -> SDGOutput:
        """Generate one safety-guided image.

        Args:
            prompt: text prompt.
            negative_prompt: optional negative prompt (empty string when None).
            num_inference_steps: scheduler steps.
            guidance_scale: classifier-free guidance weight.
            safety_scale: weight on the safety loss; 0 disables guidance.
            mid_fraction: fraction of steps (from the start) with safety guidance.
            safe_class_index: index of the 'safe' class in classifier logits.
            classifier_weights: explicit checkpoint path; auto-discovered when None.
            base_pipe: pre-built StableDiffusionPipeline to reuse.
            base_model_id: hub id used when no pipeline is supplied or cached.
            generator: RNG for the initial latents.
            **kwargs: may carry 'height'/'width' (default 512x512).

        Returns:
            SDGOutput with a single PIL image.
        """
        base = self._ensure_base(base_pipe, base_model_id, torch_dtype=torch.float16)
        device = getattr(base, "_execution_device", base.device)
        dtype = base.unet.dtype

        # --- text conditioning: [uncond, cond] stacked for CFG ---
        tok = base.tokenizer
        max_len = tok.model_max_length
        # truncation=True: over-long prompts must not exceed the encoder window.
        txt = tok([prompt], padding="max_length", max_length=max_len,
                  truncation=True, return_tensors="pt")
        cond = base.text_encoder(txt.input_ids.to(device)).last_hidden_state
        neg = negative_prompt if negative_prompt is not None else ""
        uncond_txt = tok([neg], padding="max_length", max_length=max_len,
                         truncation=True, return_tensors="pt")
        uncond = base.text_encoder(uncond_txt.input_ids.to(device)).last_hidden_state
        cond_embeds = torch.cat([uncond, cond], dim=0)

        # --- initial latents (VAE downsamples by 8) ---
        h = kwargs.pop("height", 512)
        w = kwargs.pop("width", 512)
        latents = torch.randn(
            (1, base.unet.config.in_channels, h // 8, w // 8),
            device=device, generator=generator, dtype=dtype,
        )

        base.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = base.scheduler.timesteps

        # --- safety classifier over UNet mid-block features ---
        # BUGFIX: previously called undefined `pick_weights_for_pipe(base)`
        # (NameError); the discovery helper is `pick_weights_path()`.
        weights = classifier_weights or pick_weights_path()
        clf = load_classifier_1280(weights, device=device, dtype=torch.float32).eval()

        mid = {}

        def hook(_, __, out):
            # mid_block may return a tuple; keep only the feature tensor.
            mid["feat"] = out[0] if isinstance(out, tuple) else out

        handle = base.unet.mid_block.register_forward_hook(hook)

        # Base step size for the latent nudge; scaled by the scheduler sigma
        # at the current step when the scheduler exposes sigmas.
        base_alpha = 1e-3

        try:
            for i, t in enumerate(timesteps):
                do_guide = (i / len(timesteps)) <= mid_fraction and safety_scale > 0

                if do_guide:
                    # Differentiable pass: gradient of "unsafe probability"
                    # w.r.t. the latents (overrides the outer no_grad).
                    with torch.enable_grad():
                        lg = latents.detach().clone().requires_grad_(True)
                        lin = base.scheduler.scale_model_input(lg, t)
                        lcat = torch.cat([lin, lin], dim=0)
                        _ = base.unet(lcat, t, encoder_hidden_states=cond_embeds).sample
                        # BUGFIX: the feature must NOT be detached here —
                        # detaching severed the graph, leaving lg.grad = None
                        # and crashing the update below. `.to()` keeps grad.
                        feat = mid["feat"].to(torch.float32)
                        logits = clf(feat)
                        probs = torch.softmax(logits, dim=-1)
                        unsafe = 1.0 - probs[:, safe_class_index].mean()
                        loss = safety_scale * unsafe
                        loss.backward()

                    alpha = base_alpha
                    if hasattr(base.scheduler, "sigmas"):
                        idx = min(i, len(base.scheduler.sigmas) - 1)
                        alpha = base_alpha * float(base.scheduler.sigmas[idx])

                    # Robustness: skip the nudge if no gradient reached lg.
                    if lg.grad is not None:
                        latents = (lg - alpha * lg.grad).detach()

                # Ordinary classifier-free-guidance denoising step.
                lat_in = base.scheduler.scale_model_input(latents, t)
                lat_cat = torch.cat([lat_in, lat_in], dim=0)
                noise_pred = base.unet(lat_cat, t, encoder_hidden_states=cond_embeds).sample
                n_uncond, n_text = noise_pred.chunk(2)
                noise = n_uncond + guidance_scale * (n_text - n_uncond)
                latents = base.scheduler.step(noise, t, latents).prev_sample
        finally:
            # Always detach the hook, even if a step raises (was leaked before).
            handle.remove()

        img = base.decode_latents(latents)
        pil = base.image_processor.postprocess(img, output_type="pil")[0]
        return SDGOutput(images=[pil])
|
| |
|
| |
|
# Public API for `from module import *`: only the pipeline class is exported.
__all__ = ["SafeDiffusionGuidance"]
|
| |
|