LetterJohn's picture
Update README.md
556797f verified
|
raw
history blame
7.73 kB
metadata
license: creativeml-openrail-m
library_name: diffusers

Model Details

How to use the pretrained weights for inference?

Bash Script

#!/bin/bash
# Launch SafeGen + Safe Latent Diffusion inference over a CSV file of prompts.
# Replace the three <...> placeholders below before running.
safety_config="MAX"                      # one of: WEAK, MEDIUM, STRONG, MAX
prompts_path="<path-to-prompts>"         # csv file with prompts and seeds
model_name="SafeGen_SLD_max"             # used as the sub-folder name for saved images
image_nums=1                             # images generated per prompt
evaluation_folder="<path-to-save-images>/${safety_config}_${image_nums}"
model_version="<path-to-SafeGen-Pretrained-Weights>"

# Every expansion is quoted so that paths containing spaces are passed to the
# Python script as a single argument.
python3 SafeGen_SLD_inference.py \
       --model_name "${model_name}" \
       --model_version "${model_version}" \
       --prompts_path "${prompts_path}" \
       --save_path "${evaluation_folder}" \
       --safety_config "${safety_config}" \
       --num_samples "${image_nums}" \
       --from_case 0

Python Script

'''
@filename: SafeGen_SLD_inference.py
@author: Xinfeng Li
@function: SafeGen can be integrated seamlessly with text-dependent defenses, such as Safe Latent Diffusion (Schramowski et al., CVPR 2023).
'''
from diffusers import StableDiffusionPipelineSafe
from diffusers.pipelines.stable_diffusion_safe import SafetyConfig
import argparse
import pandas as pd
import os
import torch
from PIL import Image

device="cuda"

def image_grid(imgs, rows=2, cols=3):
    """Paste a sequence of images onto one RGB canvas, filling it row by row.

    The cell size is taken from the first image in ``imgs``; the canvas is
    ``cols`` cells wide and ``rows`` cells tall.

    Parameters
    ----------
    imgs : sequence of PIL images
        images to place on the canvas, in row-major order.
    rows : int, optional
        number of grid rows. The default is 2.
    cols : int, optional
        number of grid columns. The default is 3.

    Returns
    -------
    The newly created canvas image.
    """
    tile_w, tile_h = imgs[0].size
    sheet = Image.new("RGB", size=(cols * tile_w, rows * tile_h))

    for index, tile in enumerate(imgs):
        # divmod yields (row, column) of the cell for a row-major index.
        grid_row, grid_col = divmod(index, cols)
        sheet.paste(tile, box=(grid_col * tile_w, grid_row * tile_h))
    return sheet

# Grid dimensions (rows, columns). Not referenced by the code shown in this
# script -- presumably meant to be passed to image_grid; TODO confirm.
row, col = 2, 3

def generate_images(model_name, prompts_path, save_path, device='cuda:0', safety_config="MAX", guidance_scale = 7.5, from_case=0, num_samples=10, model_version="AIML-TUDA/stable-diffusion-safe"):
    '''
    Generate images for every prompt in a csv file using a Safe Latent
    Diffusion pipeline and save them as png files.

    The program requires the prompts to be in a csv format with headers
        1. 'case_number' (used for file naming of image)
        2. 'prompt' (the prompt used to generate image)
        3. 'random_seed' (the initial seed to generate gaussian noise for diffusion input)

    Images are written to '<save_path>/<model_name>/<case_number>_<index>.png'.

    Parameters
    ----------
    model_name : str
        name of the model; used as the sub-folder name under save_path.
    prompts_path : str
        path for the csv file with prompts and corresponding seeds.
    save_path : str
        save directory for images.
    device : str, optional
        device used for the pipeline and for the random generator.
        The default is 'cuda:0'.
    safety_config : str, optional
        safety level, one of 'WEAK', 'MEDIUM', 'STRONG', 'MAX'.
        The default is 'MAX'.
    guidance_scale : float, optional
        guidance scale forwarded to the pipeline. The default is 7.5.
    from_case : int, optional
        rows whose case_number is smaller than this value are skipped.
        The default is 0.
    num_samples : int, optional
        number of samples generated per prompt. The default is 10.
    model_version : str, optional
        model id or local path given to from_pretrained. None falls back to
        'AIML-TUDA/stable-diffusion-safe', which is also the default.

    Raises
    ------
    ValueError
        if safety_config is not one of the supported levels.

    Returns
    -------
    None.
    '''
    # Validate before the (expensive) model load; an unknown level previously
    # surfaced only as a NameError inside the generation loop.
    valid_configs = ("WEAK", "MEDIUM", "STRONG", "MAX")
    if safety_config not in valid_configs:
        raise ValueError(
            f"Unknown safety_config {safety_config!r}; expected one of {', '.join(valid_configs)}"
        )
    safety_kwargs = getattr(SafetyConfig, safety_config)

    # The command-line entry point may hand over None when no path was given.
    if model_version is None:
        model_version = "AIML-TUDA/stable-diffusion-safe"

    pipeline = StableDiffusionPipelineSafe.from_pretrained(model_version)
    print(pipeline.safety_concept)
    pipeline = pipeline.to(device)

    df = pd.read_csv(prompts_path)

    folder_path = f'{save_path}/{model_name}'
    os.makedirs(folder_path, exist_ok=True)

    for _, record in df.iterrows():
        case_number = record.case_number
        if case_number < from_case:
            continue

        prompt = [str(record.prompt)] * num_samples
        # Seed the generator on the same device the pipeline runs on instead
        # of a hard-coded "cuda".
        generator = torch.Generator(device).manual_seed(int(record.random_seed))
        out_images = pipeline(
            prompt=prompt,
            generator=generator,
            guidance_scale=guidance_scale,
            **safety_kwargs,
        ).images

        for num, im in enumerate(out_images):
            im.save(f"{folder_path}/{case_number}_{num}.png")

if __name__=='__main__':
    # Command-line entry point: parse the options and run generate_images.
    parser = argparse.ArgumentParser(
                    prog = 'generateImages',
                    description = 'Generate Images using Diffusers Code')
    parser.add_argument('--model_name', help='name of model', type=str, required=True)
    # Default mirrors generate_images so that omitting the flag no longer
    # forwards None to the model loader.
    parser.add_argument('--model_version', help='path of model', type=str, required=False, default='AIML-TUDA/stable-diffusion-safe')
    parser.add_argument('--prompts_path', help='path to csv file with prompts', type=str, required=True)
    parser.add_argument('--save_path', help='folder where to save images', type=str, required=True)
    parser.add_argument('--device', help='cuda device to run on', type=str, required=False, default='cuda:0')
    parser.add_argument('--guidance_scale', help='guidance to run eval', type=float, required=False, default=7.5)
    # --image_size and --ddim_steps are accepted for command-line
    # compatibility; the code shown here does not use them.
    parser.add_argument('--image_size', help='image size used to train', type=int, required=False, default=512)
    parser.add_argument('--from_case', help='continue generating from case_number', type=int, required=False, default=0)
    parser.add_argument('--num_samples', help='number of samples per prompt', type=int, required=False, default=1)
    parser.add_argument('--ddim_steps', help='ddim steps of inference used to train', type=int, required=False, default=100)
    # A required option never uses its default, so the flag is optional now;
    # choices rejects unsupported levels at parse time with a clear message.
    parser.add_argument('--safety_config', help='safety level [WEAK, MEDIUM, STRONG, MAX]', type=str, required=False,
                        default="MAX", choices=['WEAK', 'MEDIUM', 'STRONG', 'MAX'])
    args = parser.parse_args()

    generate_images(args.model_name, args.prompts_path, args.save_path, device=args.device,
                    safety_config=args.safety_config, guidance_scale=args.guidance_scale,
                    num_samples=args.num_samples, from_case=args.from_case,
                    model_version=args.model_version)

If you want to evaluate SafeGen's capability on its own, without the integration of an external safety filter:

You can comment out the call to the run_safety_checker function in Step 9 of the pipeline, as shown below. This code can be found in

<Your_conda_env>/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
# 8. Post-processing
  image = self.decode_latents(latents)

# 9. Run safety checker
  # image, has_nsfw_concept, flagged_images = self.run_safety_checker(
  #     image, device, prompt_embeds.dtype, enable_safety_guidance
  # )
has_nsfw_concept = None; flagged_images = None