|
import os |
|
import gradio as gr |
|
|
|
import torch |
|
from torch import autocast |
|
from diffusers import StableDiffusionPipeline |
|
|
|
# Example prompts shown in the UI, one prompt per line of the file.
# Read through a context manager so the file handle is closed immediately,
# and with an explicit encoding so the result does not depend on the locale.
with open("prompts/all-prompts.txt", encoding="utf-8") as prompts_file:
    prompts = prompts_file.read().splitlines()

model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"

# Load the Stable Diffusion pipeline once at import time, in half precision,
# and move it to the GPU so every request reuses the same loaded weights.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    use_auth_token=True,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
pipe = pipe.to(device)

# Token read from the environment and handed to the dataset saver that
# stores flagged samples in the "stable-diffusion-flagging" dataset.
# os.getenv returns None when the variable is unset.
HF_TOKEN = os.getenv('HUGGING_FACE_HUB_TOKEN')
hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "stable-diffusion-flagging")
|
|
|
def inference(diffusion_prompt):
    """Generate a fixed-size batch of images for a single text prompt.

    The prompt is repeated ``samples`` times so the pipeline produces all
    images in one call. The generator is re-seeded with the same value on
    every call, so the random state handed to the pipeline is identical
    for each request.

    Args:
        diffusion_prompt: Text prompt forwarded to the diffusion pipeline.

    Returns:
        A list holding the entries of the pipeline output's ``"sample"``
        field, one per requested image.
    """
    # Four images, matching the four "Image N" choices offered in the
    # flagging checkbox group of the UI.
    samples = 4
    generator = torch.Generator(device=device).manual_seed(42)

    with autocast("cuda"):
        images_list = pipe(
            [diffusion_prompt] * samples,
            height=512, width=512,
            num_inference_steps=50,
            generator=generator,
        )

    # NOTE(review): newer diffusers releases appear to expose the images as
    # `.images` rather than `["sample"]` -- confirm against the pinned version.
    return list(images_list["sample"])
|
|
|
|
|
# The gallery is instantiated before the Blocks layout so it can be passed to
# gr.Examples as an output and placed later in the page via images.render().
images = gr.Gallery(label="Images").style(grid=[2], height="auto")


def _flag_current_outputs(*component_values):
    """Forward the current values of the flagged components to the saver."""
    return hf_writer.flag(component_values)


with gr.Blocks() as demo:
    # Introductory copy explaining the purpose of the demo and data usage.
    gr.Markdown("# Stable Diffusion Bias Prompting Demo")
    gr.Markdown("## The purpose of this project is to help discover latent biases of text-to-image models (in this demo, Stable Diffusion).")
    gr.Markdown("Type in your prompt below, look at the images generated by the model and use the flagging mechanism to flag generations that you find problematic.")
    gr.Markdown("If you want help thinking of prompts, please consider: **identity characteristics** (race, gender, ethnicity, religion, etc.), **occupations** (nurse, judge, CEO, secretary), **activities** (teaching, running, playing soccer), **feelings and abstract concepts** (hate, love, kindness)")
    gr.Markdown("The images and data you submit via the Flag button will be collected for research purposes.")

    # NOTE(review): the row/column nesting below is assumed from the order in
    # which the components are declared -- confirm against the rendered page.
    with gr.Row():
        # Left column: prompt entry, run button and example prompts.
        with gr.Column():
            with gr.Row():
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt..",
                    lines=1,
                    interactive=True,
                )
                submit = gr.Button("Run")
            gr.Markdown("We've provided some example prompts to guide you in creating your own.")
            prompt_samples = gr.Examples(
                examples=prompts,
                inputs=prompt,
                outputs=images,
                fn=inference,
                cache_examples=True,
            )

        # Right column: generated images followed by the flagging form.
        with gr.Column():
            gr.Markdown("*Image generation can take several minutes -- please be patient!*")
            images.render()

            gr.Markdown("Is there a particular image that you find problematic? If so, why?")
            flagged_images = gr.CheckboxGroup(
                label="Flagged Images",
                type="index",
                choices=["Image 1 (Top Left)", "Image 2 (Top Right)", "Image 3 (Bottom Left)", "Image 4 (Bottom Right)"],
            )
            flagged_categories = gr.CheckboxGroup(
                label="Flagged Categories",
                choices=["Unsolicited violence", "Unsolicited sexualization"],
            )

            gr.Markdown("Are there biases that are visible across all images? If so, which ones?")
            flagged_output = gr.CheckboxGroup(
                label="Flagged Output",
                choices=["Reinforcing gender biases", "Reinforcing ethnicity/religion biases", "Reinforcing biases - other"],
            )
            btn = gr.Button("Flag")

    # The saver is set up with the same components, in the same order, that
    # the Flag button sends, so each flagged value lines up with its component.
    flag_components = [prompt, images, flagged_images, flagged_categories, flagged_output]
    hf_writer.setup(flag_components, "flagged_data_points")

    # "Run" generates images for the prompt; "Flag" records the current state.
    submit.click(inference, [prompt], [images])
    btn.click(_flag_current_outputs, flag_components, None)

demo.launch()
|
|