---
license: creativeml-openrail-m
library_name: diffusers
---
Model Details
Repository: https://github.com/LetterLiGo/SafeGen_CCS2024
Paper: SafeGen: Mitigating Sexually Explicit Content Generation in Text-to-Image Models (to appear in ACM CCS 2024).
License: The CreativeML OpenRAIL M license is an Open RAIL M license, adapted from the work that BigScience and the RAIL Initiative are jointly carrying in the area of responsible AI licensing. See also the article about the BLOOM Open RAIL license on which our license is based.
Cite as:

```bibtex
@inproceedings{li2024safegen_CCS,
  title={SafeGen: Mitigating Sexually Explicit Content Generation in Text-to-Image Models},
  author={Li, Xinfeng and Yang, Yuchen and Deng, Jiangyi and Yan, Chen and Chen, Yanjiao and Ji, Xiaoyu and Xu, Wenyuan},
  booktitle={arXiv preprint arXiv:2404.06666},
  year={2024}
}
```
How to use the pretrained weights for inference?
Bash Script
```bash
#!/bin/bash
# Safety level applied by Safe Latent Diffusion: WEAK, MEDIUM, STRONG, or MAX.
safety_config="MAX"
prompts_path="<path-to-prompts>"
model_name="SafeGen_SLD_max"
image_nums=1
evaluation_folder="<path-to-save-images>/${safety_config}_${image_nums}"
model_version="<path-to-SafeGen-Pretrained-Weights>"

# Quote the variables so paths containing spaces are passed as single arguments.
python3 SafeGen_SLD_inference.py \
    --model_name "${model_name}" \
    --model_version "${model_version}" \
    --prompts_path "${prompts_path}" \
    --save_path "${evaluation_folder}" \
    --safety_config "${safety_config}" \
    --num_samples "${image_nums}" \
    --from_case 0
```
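The script saves each image as <save_path>/<model_name>/<case_number>_<num>.png and skips every row whose case_number is below --from_case, so an interrupted run can be resumed. The helper below is a hypothetical sketch (next_from_case is not part of the SafeGen repository) that scans the output folder and suggests a --from_case value. It assumes the CSV rows are processed in ascending case_number order and that case numbers are non-negative integers.

```python
import re
from pathlib import Path

# File names written by generate_images(): "{case_number}_{num}.png".
OUTPUT_NAME = re.compile(r"^(\d+)_(\d+)\.png$")


def next_from_case(save_path, model_name, num_samples):
    """Suggest a --from_case value for resuming an interrupted run.

    Hypothetical helper. Assumes rows are processed in ascending case_number
    order, so every case below the largest one on disk is already complete.
    """
    folder = Path(save_path) / model_name
    if not folder.is_dir():
        return 0  # Nothing has been generated yet.

    # Count saved samples per case number, ignoring unrelated files.
    counts = {}
    for entry in folder.iterdir():
        match = OUTPUT_NAME.match(entry.name)
        if match:
            case = int(match.group(1))
            counts[case] = counts.get(case, 0) + 1

    if not counts:
        return 0
    last = max(counts)
    # Redo the last case if fewer than num_samples images were saved for it.
    return last + 1 if counts[last] >= num_samples else last
```

With the bash example, save_path is the value of ${evaluation_folder} and model_name is SafeGen_SLD_max; pass the returned number as --from_case.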
Python Script
```python
'''
@filename: SafeGen_SLD_inference.py
@author: Xinfeng Li
@function: SafeGen can be integrated seamlessly with text-dependent defenses, such as Safe Latent Diffusion (Schramowski et al., CVPR 2023).
'''
from diffusers import StableDiffusionPipelineSafe
from diffusers.pipelines.stable_diffusion_safe import SafetyConfig
import argparse
import pandas as pd
import os
import torch
from PIL import Image

device = "cuda"

# SLD presets selectable through --safety_config.
SAFETY_CONFIGS = {
    "WEAK": SafetyConfig.WEAK,
    "MEDIUM": SafetyConfig.MEDIUM,
    "STRONG": SafetyConfig.STRONG,
    "MAX": SafetyConfig.MAX,
}


def image_grid(imgs, rows=2, cols=3):
    # Tile equally sized PIL images into a rows x cols grid, filled row by row.
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


row, col = 2, 3


def generate_images(model_name, prompts_path, save_path, device='cuda:0', safety_config="MAX", guidance_scale=7.5, from_case=0, num_samples=10, model_version="AIML-TUDA/stable-diffusion-safe"):
    '''
    Generate images with the diffusers Safe Latent Diffusion pipeline.

    The prompts must be in a CSV file with the headers
        1. 'case_number' (used to name the image files)
        2. 'prompt' (the prompt used to generate the image)
        3. 'random_seed' (the initial seed for the Gaussian noise fed to the diffusion model)

    Parameters
    ----------
    model_name : str
        Name of the model; used as the output sub-folder name.
    prompts_path : str
        Path to the CSV file with prompts and corresponding seeds.
    save_path : str
        Save directory for images.
    device : str, optional
        Device on which to load the model. The default is 'cuda:0'.
    safety_config : str, optional
        SLD safety level: 'WEAK', 'MEDIUM', 'STRONG', or 'MAX'. The default is 'MAX'.
    guidance_scale : float, optional
        Classifier-free guidance scale. The default is 7.5.
    from_case : int, optional
        Rows whose case_number is below this value are skipped. The default is 0.
    num_samples : int, optional
        Number of samples generated per prompt. The default is 10.
    model_version : str, optional
        Hub ID or local path of the pipeline weights.

    Returns
    -------
    None.
    '''
    if safety_config not in SAFETY_CONFIGS:
        raise ValueError(f"safety_config must be one of {list(SAFETY_CONFIGS)}, got {safety_config!r}")
    pipeline = StableDiffusionPipelineSafe.from_pretrained(model_version)
    print(pipeline.safety_concept)
    pipeline = pipeline.to(device)
    df = pd.read_csv(prompts_path)
    folder_path = f'{save_path}/{model_name}'
    os.makedirs(folder_path, exist_ok=True)
    for _, row in df.iterrows():
        case_number = row.case_number
        if case_number < from_case:
            continue
        prompt = [str(row.prompt)] * num_samples
        # Seed a generator on the pipeline's device so the initial noise is reproducible.
        generator = torch.Generator(device).manual_seed(int(row.random_seed))
        out_images = pipeline(prompt=prompt, generator=generator, guidance_scale=guidance_scale,
                              **SAFETY_CONFIGS[safety_config]).images
        for num, im in enumerate(out_images):
            im.save(f"{folder_path}/{case_number}_{num}.png")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        prog='generateImages',
        description='Generate Images using Diffusers Code')
    parser.add_argument('--model_name', help='name of model', type=str, required=True)
    parser.add_argument('--model_version', help='Hub ID or local path of the model', type=str, required=False, default='AIML-TUDA/stable-diffusion-safe')
    parser.add_argument('--prompts_path', help='path to csv file with prompts', type=str, required=True)
    parser.add_argument('--save_path', help='folder where to save images', type=str, required=True)
    parser.add_argument('--device', help='cuda device to run on', type=str, required=False, default='cuda:0')
    parser.add_argument('--guidance_scale', help='guidance to run eval', type=float, required=False, default=7.5)
    # Note: --image_size and --ddim_steps are parsed but not used by generate_images().
    parser.add_argument('--image_size', help='image size used to train', type=int, required=False, default=512)
    parser.add_argument('--from_case', help='continue generating from case_number', type=int, required=False, default=0)
    parser.add_argument('--num_samples', help='number of samples per prompt', type=int, required=False, default=1)
    parser.add_argument('--ddim_steps', help='ddim steps of inference used to train', type=int, required=False, default=100)
    parser.add_argument('--safety_config', help='safety level', type=str, required=False, default='MAX',
                        choices=['WEAK', 'MEDIUM', 'STRONG', 'MAX'])
    args = parser.parse_args()

    generate_images(args.model_name, args.prompts_path, args.save_path, device=args.device,
                    safety_config=args.safety_config, guidance_scale=args.guidance_scale,
                    num_samples=args.num_samples, from_case=args.from_case,
                    model_version=args.model_version)
```
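generate_images() reads three columns from the prompts file: case_number, prompt, and random_seed (the code accesses row.random_seed). A minimal sketch for building such a file with Python's standard csv module follows; write_prompt_csv and its seed scheme (base_seed + case_number) are illustrative assumptions, not part of the SafeGen repository.

```python
import csv
from pathlib import Path

# Columns accessed by generate_images(): row.case_number, row.prompt, row.random_seed.
FIELDNAMES = ["case_number", "prompt", "random_seed"]


def write_prompt_csv(path, prompts, base_seed=0):
    """Write prompts to a CSV usable as --prompts_path (hypothetical helper).

    Case numbers are 0, 1, 2, ... and each row's seed is base_seed + case_number.
    """
    path = Path(path)
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDNAMES)
        writer.writeheader()
        for case_number, prompt in enumerate(prompts):
            writer.writerow({
                "case_number": case_number,
                "prompt": prompt,  # csv quotes prompts containing commas or quotes
                "random_seed": base_seed + case_number,
            })
    return path
```

Fixing one seed per row keeps results comparable across safety levels: rerunning the script with a different --safety_config starts every prompt from the same initial noise.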
To evaluate SafeGen's own capability without an external safety filter, comment out the `run_safety_checker` call in step 9 of the pipeline and set its outputs to None. The code is in
`<Your_conda_env>/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py`
(the python3.8 part of the path depends on your environment's Python version):
```python
# 8. Post-processing
image = self.decode_latents(latents)

# 9. Run safety checker (disabled to evaluate SafeGen on its own)
# image, has_nsfw_concept, flagged_images = self.run_safety_checker(
#     image, device, prompt_embeds.dtype, enable_safety_guidance
# )
has_nsfw_concept = None
flagged_images = None
```
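Editing the installed package changes behavior for every project in that environment. A less invasive option is to override the method on a single pipeline instance. The sketch below assumes your diffusers version calls self.run_safety_checker(image, device, prompt_embeds.dtype, enable_safety_guidance) and unpacks a three-item result, as in the excerpt above; disable_post_hoc_checker is a hypothetical helper, not a diffusers API.

```python
def disable_post_hoc_checker(pipeline):
    """Bypass the post-hoc image safety checker on this pipeline instance only.

    Hypothetical helper. Assumes the pipeline calls
    self.run_safety_checker(image, device, dtype, enable_safety_guidance)
    and expects (image, has_nsfw_concept, flagged_images) back.
    The installed package and other pipeline instances are left untouched.
    """

    def passthrough(image, device, dtype, enable_safety_guidance):
        # Same effect as the edit above: images pass through unfiltered,
        # and both flags are None.
        return image, None, None

    # An instance attribute shadows the class method, so no `self` is passed.
    pipeline.run_safety_checker = passthrough
    return pipeline
```

For example, apply it right after StableDiffusionPipelineSafe.from_pretrained(model_version) in generate_images(). SLD guidance from the safety_config presets still runs during denoising; only the final image filter is skipped.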