Model Details

How to use the pretrained weights for inference?

Bash Script

#!/bin/bash
# Run SafeGen + Safe Latent Diffusion inference over a CSV of prompts.
# Replace every <...> placeholder before running.
# Abort on the first failing command or on use of an unset variable.
set -euo pipefail

safety_config="MAX"                      # one of WEAK, MEDIUM, STRONG, MAX
prompts_path="<path-to-prompts>"         # CSV with case_number, prompt, random_seed columns
model_name="SafeGen_SLD_max"             # sub-directory name for the generated images
image_nums=1                             # images generated per prompt
evaluation_folder="<path-to-save-images>/${safety_config}_${image_nums}"
model_version="<path-to-SafeGen-Pretrained-Weights>"

# Quote every expansion so paths containing spaces are passed as single arguments.
python3 SafeGen_SLD_inference.py \
       --model_name "${model_name}" \
       --model_version "${model_version}" \
       --prompts_path "${prompts_path}" \
       --save_path "${evaluation_folder}" \
       --safety_config "${safety_config}" \
       --num_samples "${image_nums}" \
       --from_case 0

Python Script

'''
@filename: SafeGen_SLD_inference.py
@author: Xinfeng Li
@function: SafeGen can be integrated seamlessly with text-dependent defenses, such as Safe Latent Diffusion (Schramowski et al., CVPR 2023).
'''
from diffusers import StableDiffusionPipelineSafe
from diffusers.pipelines.stable_diffusion_safe import SafetyConfig
import argparse
import pandas as pd
import os
import torch
from PIL import Image

# Module-level default device string. generate_images takes its own `device`
# parameter, so it does not read this value.
device = "cuda"

def image_grid(imgs, rows=2, cols=3):
    """Tile images row-major onto a single RGB canvas and return it.

    Every cell takes the size of ``imgs[0]``. Image ``k`` is pasted at column
    ``k % cols`` and row ``k // cols``. The canvas measures ``cols * width`` by
    ``rows * height`` pixels.
    """
    cell_w, cell_h = imgs[0].size
    canvas = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for idx, picture in enumerate(imgs):
        grid_row, grid_col = divmod(idx, cols)
        canvas.paste(picture, box=(grid_col * cell_w, grid_row * cell_h))
    return canvas

# Default grid shape (rows, columns); the same values as image_grid's defaults.
row = 2
col = 3

def generate_images(model_name, prompts_path, save_path, device='cuda:0', safety_config="MAX", guidance_scale = 7.5, from_case=0, num_samples=10, model_version="AIML-TUDA/stable-diffusion-safe"):
    '''
    Generate images for every prompt in a CSV with a Safe Latent Diffusion pipeline.

    The program requires the prompts to be in a csv format with headers
        1. 'case_number' (used for file naming of image)
        2. 'prompt' (the prompt used to generate image)
        3. 'random_seed' (the initial seed for the Gaussian noise used as diffusion input)

    Parameters
    ----------
    model_name : str
        name of the model; used as the output sub-directory under save_path.
    prompts_path : str
        path for the csv file with prompts and corresponding seeds.
    save_path : str
        save directory for images.
    device : str, optional
        torch device for the pipeline and the seeded random generator.
        The default is 'cuda:0'.
    safety_config : str, optional
        One of 'WEAK', 'MEDIUM', 'STRONG', 'MAX'; selects the matching
        SafetyConfig preset passed to the pipeline. The default is 'MAX'.
    guidance_scale : float, optional
        classifier-free guidance scale passed to the pipeline. The default is 7.5.
    from_case : int, optional
        Rows whose case_number is below this value are skipped. The default is 0.
    num_samples : int, optional
        number of samples generated per prompt. The default is 10.
    model_version : str, optional
        model id or local path passed to StableDiffusionPipelineSafe.from_pretrained.

    Returns
    -------
    None. Images are written to '{save_path}/{model_name}/{case_number}_{i}.png'.

    Raises
    ------
    ValueError
        If safety_config is not one of the supported levels. The check runs
        before the model is loaded.
    '''
    safety_levels = ("WEAK", "MEDIUM", "STRONG", "MAX")
    # Validate up front: an unknown level would otherwise leave the generated
    # images unbound (or stale from a previous prompt) inside the loop.
    if safety_config not in safety_levels:
        raise ValueError(
            f"Unknown safety_config {safety_config!r}; expected one of {safety_levels}"
        )

    pipeline = StableDiffusionPipelineSafe.from_pretrained(model_version)
    print(pipeline.safety_concept)
    pipeline = pipeline.to(device)

    sld_presets = {
        "WEAK": SafetyConfig.WEAK,
        "MEDIUM": SafetyConfig.MEDIUM,
        "STRONG": SafetyConfig.STRONG,
        "MAX": SafetyConfig.MAX,
    }
    sld_kwargs = sld_presets[safety_config]

    df = pd.read_csv(prompts_path)

    folder_path = f'{save_path}/{model_name}'
    os.makedirs(folder_path, exist_ok=True)

    for _, case in df.iterrows():
        case_number = case.case_number
        if case_number < from_case:
            continue

        prompt = [str(case.prompt)] * num_samples
        # Create the seeded generator on the pipeline's device. A hard-coded
        # "cuda" breaks CPU runs and CUDA device indices other than the current one.
        generator = torch.Generator(device=device).manual_seed(int(case.random_seed))
        out_images = pipeline(
            prompt=prompt,
            generator=generator,
            guidance_scale=guidance_scale,
            **sld_kwargs,
        ).images

        for num, im in enumerate(out_images):
            im.save(f"{folder_path}/{case_number}_{num}.png")

if __name__=='__main__':
    # Command-line entry point: parse arguments and generate images.
    parser = argparse.ArgumentParser(
                    prog = 'generateImages',
                    description = 'Generate Images using Diffusers Code')
    parser.add_argument('--model_name', help='name of model', type=str, required=True)
    # Default matches generate_images' own default; without one, a missing flag
    # would hand None to from_pretrained.
    parser.add_argument('--model_version', help='path of model', type=str, required=False,
                        default='AIML-TUDA/stable-diffusion-safe')
    parser.add_argument('--prompts_path', help='path to csv file with prompts', type=str, required=True)
    parser.add_argument('--save_path', help='folder where to save images', type=str, required=True)
    parser.add_argument('--device', help='cuda device to run on', type=str, required=False, default='cuda:0')
    parser.add_argument('--guidance_scale', help='guidance to run eval', type=float, required=False, default=7.5)
    # NOTE: --image_size and --ddim_steps are accepted for command-line
    # compatibility but are not forwarded to generate_images.
    parser.add_argument('--image_size', help='image size used to train', type=int, required=False, default=512)
    parser.add_argument('--from_case', help='continue generating from case_number', type=int, required=False, default=0)
    parser.add_argument('--num_samples', help='number of samples per prompt', type=int, required=False, default=1)
    parser.add_argument('--ddim_steps', help='ddim steps of inference used to train', type=int, required=False, default=100)
    # Optional with default MAX; `choices` rejects unknown levels at parse time,
    # before the model is loaded.
    parser.add_argument('--safety_config', help='safety level [WEAK, MEDIUM, STRONG, MAX]', type=str, required=False,
                        default="MAX", choices=['WEAK', 'MEDIUM', 'STRONG', 'MAX'])
    args = parser.parse_args()

    generate_images(args.model_name, args.prompts_path, args.save_path,
                    device=args.device, safety_config=args.safety_config,
                    guidance_scale=args.guidance_scale, num_samples=args.num_samples,
                    from_case=args.from_case, model_version=args.model_version)

To evaluate SafeGen's own capability without an external safety filter, disable the pipeline's built-in safety checker.

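Comment out the `run_safety_checker` call in Step 9 and set `has_nsfw_concept` and `flagged_images` to `None`, as shown below. The code is located in the following file (replace `python3.8` with your environment's Python version):

<Your_conda_env>/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py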
# Excerpt from diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py.
# Indentation below is illustrative; keep the enclosing method's indentation when editing.
# 8. Post-processing
  image = self.decode_latents(latents)

# 9. Run safety checker
  # Disabled so SafeGen's own behavior can be observed without the external safety filter.
  # image, has_nsfw_concept, flagged_images = self.run_safety_checker(
  #     image, device, prompt_embeds.dtype, enable_safety_guidance
  # )
# Stand-in for the disabled call: no NSFW flags and no flagged images are reported.
# NOTE(review): presumably the rest of the pipeline then returns the decoded images
# unmodified; verify against your installed diffusers version.
has_nsfw_concept = None; flagged_images = None
Downloads last month
50
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.