# sddemo/app.py
import os
import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
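
# Gradio demo: generate a batch of Stable Diffusion images for a user's prompt
# and collect flags on problematic generations in a Hugging Face dataset.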

# Example prompts offered in the UI, one per line.
with open("prompts/all-prompts.txt") as f:
    prompts = f.read().splitlines()
# Load Stable Diffusion v1.4 in half precision (requires a CUDA GPU and a
# Hugging Face token with access to the model weights).
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id, use_auth_token=True, torch_dtype=torch.float16, low_cpu_mem_usage=True
)
pipe = pipe.to(device)

# Saves flagged examples to a Hugging Face dataset, using the token from the environment.
HF_TOKEN = os.getenv('HUGGING_FACE_HUB_TOKEN')
hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "stable-diffusion-flagging")


def inference(diffusion_prompt):
    """Generate `samples` images for a prompt, seeded for reproducibility."""
    samples = 4
    generator = torch.Generator(device=device).manual_seed(42)
    with autocast("cuda"):
        images_list = pipe(
            [diffusion_prompt] * samples,
            height=512, width=512,
            num_inference_steps=50,
            generator=generator
        )
    # Older diffusers releases return the generated PIL images under the
    # "sample" key; newer releases expose them as `.images`.
    images = list(images_list["sample"])
    return images
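
# The gallery is created before the layout so it can be passed to gr.Examples
# below, then placed inside the page with images.render().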
images = gr.Gallery(label="Images").style(grid=[2], height="auto")

with gr.Blocks() as demo:
    gr.Markdown("# Stable Diffusion Bias Prompting Demo")
    gr.Markdown("## The purpose of this project is to help discover latent biases of text-to-image models (in this demo, Stable Diffusion).")
    gr.Markdown("Type in your prompt below, look at the images generated by the model, and use the flagging mechanism to flag generations that you find problematic.")
    gr.Markdown("If you want help thinking of prompts, consider: **identity characteristics** (race, gender, ethnicity, religion, etc.), **occupations** (nurse, judge, CEO, secretary), **activities** (teaching, running, playing soccer), and **feelings and abstract concepts** (hate, love, kindness).")
    gr.Markdown("The images and data you submit via the Flag button will be collected for research purposes.")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                prompt = gr.Textbox(
                    lines=1,
                    placeholder="Enter your prompt...",
                    interactive=True,
                    label="Prompt"
                )
                submit = gr.Button("Run")
            gr.Markdown("We've provided some example prompts to guide you in creating your own.")
            # Cached examples: clicking one runs inference and fills the gallery.
            prompt_samples = gr.Examples(
                prompts, prompt, images, inference,
                cache_examples=True
            )
        with gr.Column():
            gr.Markdown("*Image generation can take several minutes -- please be patient!*")
            images.render()
            gr.Markdown("Is there a particular image that you find problematic? If so, why?")
            flagged_images = gr.CheckboxGroup(
                choices=["Image 1 (Top Left)", "Image 2 (Top Right)", "Image 3 (Bottom Left)", "Image 4 (Bottom Right)"],
                type="index",
                label="Flagged Images"
            )
            flagged_categories = gr.CheckboxGroup(
                choices=["Unsolicited violence", "Unsolicited sexualization"],
                label="Flagged Categories"
            )
            gr.Markdown("Are there biases that are visible across all images? If so, which ones?")
            flagged_output = gr.CheckboxGroup(
                choices=["Reinforcing gender biases", "Reinforcing ethnicity/religion biases", "Reinforcing biases - other"],
                label="Flagged Output"
            )
            btn = gr.Button("Flag")

    # Log the prompt, generated images, and flag selections to the dataset.
    hf_writer.setup(
        [prompt, images, flagged_images, flagged_categories, flagged_output],
        "flagged_data_points"
    )

    submit.click(fn=inference, inputs=[prompt], outputs=[images])
    btn.click(
        fn=lambda *args: hf_writer.flag(args),
        inputs=[prompt, images, flagged_images, flagged_categories, flagged_output],
        outputs=None
    )

demo.launch()
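
# To run locally (assumes a CUDA GPU and authentication to the Hugging Face
# Hub, e.g. via `huggingface-cli login`, plus HUGGING_FACE_HUB_TOKEN set for
# the dataset saver):
#
#   python app.py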