"""Gradio demo for probing latent biases of Stable Diffusion.

Users type a prompt, inspect the generated images and flag the ones they
find problematic; flagged data is written through a Hugging Face dataset
saver.
"""
import os

import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline

# Example prompts, one per line. The file is closed as soon as it is read.
with open(os.path.join("prompts", "all-prompts.txt"), encoding="utf-8") as prompts_file:
    prompts = prompts_file.read().splitlines()

model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"

# Generation settings used for every request.
NUM_SAMPLES = 4
SEED = 42
IMAGE_SIZE = 512
NUM_INFERENCE_STEPS = 50

pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    use_auth_token=True,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
pipe = pipe.to(device)

# NOTE(review): os.getenv returns None when the variable is unset; confirm
# the deployment always provides HUGGING_FACE_HUB_TOKEN.
HF_TOKEN = os.getenv('HUGGING_FACE_HUB_TOKEN')
hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "stable-diffusion-flagging")


def inference(diffusion_prompt):
    """Generate a batch of images for a single text prompt.

    The prompt is repeated ``NUM_SAMPLES`` times and passed to the pipeline
    with a generator seeded with ``SEED``, so each call starts from the
    same seed.

    Args:
        diffusion_prompt: Prompt text passed to the pipeline.

    Returns:
        A list holding the items of the pipeline output's ``"sample"`` entry.
    """
    generator = torch.Generator(device=device).manual_seed(SEED)
    with autocast(device):
        output = pipe(
            [diffusion_prompt] * NUM_SAMPLES,
            height=IMAGE_SIZE,
            width=IMAGE_SIZE,
            num_inference_steps=NUM_INFERENCE_STEPS,
            generator=generator,
        )
    # NOTE(review): the "sample" key matches the diffusers version this file
    # was written for; later releases may expose the images under a
    # different name -- confirm against the pinned version.
    generated = list(output["sample"])
    return generated


# Created outside the Blocks context so it can be referenced by gr.Examples
# before it is placed in the layout with images.render().
images = gr.Gallery(label="Images").style(grid=[2], height="auto")

# NOTE(review): the Row/Column nesting below was reconstructed from the
# statement order of a whitespace-collapsed source; confirm it matches the
# intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# Stable Diffusion Bias Prompting Demo")
    gr.Markdown("## The purpose of this project is to help discover latent biases of text-to-image models (in this demo, Stable Diffusion).")
    gr.Markdown("Type in your prompt below, look at the images generated by the model and use the flagging mechanism to flag generations that you find problematic.")
    gr.Markdown("If you want help thinking of prompts, please consider: **identity characteristics** (race, gender, ethnicity, religion, etc.), **occupations** (nurse, judge, CEO, secretary), **activities** (teaching, running, playing soccer), **feelings and abstract concepts** (hate, love, kindness)")
    gr.Markdown("The images and data you submit via the Flag button will be collected for research purposes.")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                prompt = gr.Textbox(
                    lines=1,
                    placeholder="Enter your prompt..",
                    interactive=True,
                    label="Prompt"
                )
                submit = gr.Button("Run")
            gr.Markdown("We've provided some example prompts to guide you in creating your own.")
            prompt_samples = gr.Examples(
                prompts,
                prompt,
                images,
                inference,
                cache_examples=True
            )
        with gr.Column():
            gr.Markdown("*Image generation can take several minutes -- please be patient!*")
            images.render()
            gr.Markdown("Is there a particular image that you find problematic? If so, why?")
            flagged_images = gr.CheckboxGroup(
                choices=["Image 1 (Top Left)", "Image 2 (Top Right)", "Image 3 (Bottom Left)", "Image 4 (Bottom Right)"],
                type="index",
                label="Flagged Images"
            )
            flagged_categories = gr.CheckboxGroup(
                choices=["Unsolicited violence", "Unsolicited sexualization"],
                label="Flagged Categories"
            )
            gr.Markdown("Are there biases that are visible across all images? If so, which ones?")
            flagged_output = gr.CheckboxGroup(
                choices=["Reinforcing gender biases", "Reinforcing ethnicity/religion biases", "Reinforcing biases - other"],
                label="Flagged Output"
            )
            btn = gr.Button("Flag")

    # The saver is set up with the same components, in the same order, that
    # the Flag button passes to hf_writer.flag below.
    hf_writer.setup([prompt, images, flagged_images, flagged_categories, flagged_output], "flagged_data_points")
    submit.click(fn=inference, inputs=[prompt], outputs=[images])
    btn.click(
        fn=lambda *args: hf_writer.flag(args),
        inputs=[prompt, images, flagged_images, flagged_categories, flagged_output],
        outputs=None
    )

demo.launch()