File size: 6,460 Bytes
956fa05
 
31a0f6f
956fa05
 
31a0f6f
 
956fa05
 
 
31a0f6f
 
ab041ea
f56644b
31a0f6f
de81f33
31a0f6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
956fa05
ab041ea
de81f33
 
 
 
 
 
 
 
 
 
956fa05
64fe77f
31a0f6f
956fa05
ab041ea
956fa05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31a0f6f
956fa05
 
 
 
 
 
 
31a0f6f
 
956fa05
 
ab041ea
f56644b
 
31a0f6f
956fa05
 
de81f33
956fa05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31a0f6f
f56644b
31a0f6f
 
 
 
 
 
 
 
 
 
956fa05
31a0f6f
 
 
 
 
 
 
 
 
956fa05
 
 
 
f56644b
956fa05
de81f33
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import gradio as gr
import torch
from diffusers import DiffusionPipeline, StableDiffusionPipeline, StableDiffusionXLPipeline, EulerDiscreteScheduler, UNet2DConditionModel
from transformers import BlipProcessor, BlipForConditionalGeneration
from pathlib import Path
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import hex2color
import stone
import os
import spaces

# Model construction: maps a dropdown model id to a ready-to-use fp16 CUDA pipeline.
@spaces.GPU
def load_model(model_name):
    """Build the float16 text-to-image pipeline for ``model_name`` on CUDA.

    Supported ids: "stabilityai/sdxl-turbo", "runwayml/stable-diffusion-v1-5",
    "ByteDance/SDXL-Lightning" (SDXL base pipeline with the 4-step Lightning
    UNet weights swapped in and a trailing-spacing Euler scheduler) and
    "segmind/SSD-1B".

    Raises:
        ValueError: for any other model id.
    """
    half = torch.float16
    if model_name == "stabilityai/sdxl-turbo":
        pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=half, variant="fp16")
    elif model_name == "runwayml/stable-diffusion-v1-5":
        pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=half)
    elif model_name == "ByteDance/SDXL-Lightning":
        base_repo = "stabilityai/stable-diffusion-xl-base-1.0"
        lightning_weights = "sdxl_lightning_4step_unet.safetensors"
        # Build an SDXL-shaped UNet, then overwrite its weights with the Lightning checkpoint.
        lightning_unet = UNet2DConditionModel.from_config(base_repo, subfolder="unet").to("cuda", half)
        checkpoint_path = hf_hub_download(model_name, lightning_weights)
        lightning_unet.load_state_dict(load_file(checkpoint_path, device="cuda"))
        pipe = StableDiffusionXLPipeline.from_pretrained(
            base_repo,
            unet=lightning_unet,
            torch_dtype=half,
            variant="fp16",
        )
        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
    elif model_name == "segmind/SSD-1B":
        pipe = StableDiffusionXLPipeline.from_pretrained(
            model_name,
            torch_dtype=half,
            use_safetensors=True,
            variant="fp16",
        )
    else:
        raise ValueError("Unknown model name")
    # Single exit: DiffusionPipeline.to() moves every component and returns the pipeline itself.
    return pipe.to("cuda")

# Model loaded at import time; it is also the dropdown's initial selection.
default_model = "stabilityai/sdxl-turbo"
pipeline_text2image = load_model(model_name=default_model)

@spaces.GPU
def getimgen(prompt, model_name):
    """Generate a single image for ``prompt`` with the global ``pipeline_text2image``.

    ``model_name`` only selects the per-model sampling arguments; it does not
    load anything, so it must name the model the global pipeline was built from.

    Returns:
        The first PIL image produced by the pipeline.

    Raises:
        ValueError: if ``model_name`` is not a supported model id (previously the
            function fell through and silently returned ``None``, which only
            failed later when the caller tried to save the image).
    """
    if model_name == "stabilityai/sdxl-turbo":
        # 2 denoising steps with classifier-free guidance disabled.
        return pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=2).images[0]
    elif model_name == "runwayml/stable-diffusion-v1-5":
        # Pipeline defaults for steps and guidance.
        return pipeline_text2image(prompt).images[0]
    elif model_name == "ByteDance/SDXL-Lightning":
        # Matches the 4-step UNet checkpoint loaded by load_model.
        return pipeline_text2image(prompt, num_inference_steps=4, guidance_scale=0).images[0]
    elif model_name == "segmind/SSD-1B":
        neg_prompt = "ugly, blurry, poor quality"
        return pipeline_text2image(prompt=prompt, negative_prompt=neg_prompt).images[0]
    # Same error as load_model so unsupported ids fail loudly and consistently.
    raise ValueError("Unknown model name")

# BLIP captioner used to infer gender words from each generated image.
_blip_repo = "Salesforce/blip-image-captioning-large"
blip_processor = BlipProcessor.from_pretrained(_blip_repo)
blip_model = BlipForConditionalGeneration.from_pretrained(_blip_repo, torch_dtype=torch.float16).to("cuda")

@spaces.GPU
def blip_caption_image(image, prefix):
    """Caption ``image`` with BLIP, conditioning generation on the text ``prefix``.

    Returns the decoded text of the first generated sequence with special
    tokens stripped.
    """
    model_inputs = blip_processor(image, prefix, return_tensors="pt")
    model_inputs = model_inputs.to("cuda", torch.float16)
    generated_ids = blip_model.generate(**model_inputs)
    first_sequence = generated_ids[0]
    return blip_processor.decode(first_sequence, skip_special_tokens=True)

def genderfromcaption(caption):
    """Classify a caption as "Man", "Woman" or "Unsure" by keyword lookup.

    The caption is lower-cased and every non-letter character is treated as a
    word separator, so "Man", "man.", "man," and "man's" all count as the word
    "man". (A plain ``str.split()`` left punctuation attached and missed them.)
    Whole words are compared, so "woman" or "human" never match "man".
    Male keywords are checked first: a caption mentioning both (e.g.
    "a man and a woman") is classified as "Man".
    """
    normalized = "".join(ch if ch.isalpha() else " " for ch in caption.lower())
    words = set(normalized.split())
    if words & {"man", "boy"}:
        return "Man"
    if words & {"woman", "girl"}:
        return "Woman"
    return "Unsure"

def genderplot(genlist):
    """Draw the gender labels as a 2x5 grid of coloured tiles.

    Labels must be "Man", "Woman" or "Unsure"; they are drawn in that order
    (lightgreen, darkgreen, lightgrey). At most 10 tiles are shown; if fewer
    than 10 labels are given the remaining cells stay empty instead of raising
    IndexError (mirrors skintoneplot).

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if a label is not one of the three known values.
    """
    order = ["Man", "Woman", "Unsure"]
    colors = {"Man": "lightgreen", "Woman": "darkgreen", "Unsure": "lightgrey"}
    word_colors = [colors[word] for word in sorted(genlist, key=order.index)]
    fig, axes = plt.subplots(2, 5, figsize=(5, 5))
    # Adjust this figure explicitly rather than pyplot's global "current figure",
    # which another concurrent request could have replaced.
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    for i, ax in enumerate(axes.flat):
        ax.set_axis_off()
        if i < len(word_colors):
            ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
    return fig

def skintoneplot(hex_codes):
    """Draw skin-tone hex colours as a 2x5 grid of tiles, lightest first.

    ``None`` entries (images where no face colour was found) are dropped.
    Colours are ordered by descending luma (0.299 R + 0.587 G + 0.114 B), with
    ties broken by descending hex string. At most 10 tiles are shown; unused
    cells stay empty.

    Returns:
        The matplotlib Figure.
    """
    hex_codes = [code for code in hex_codes if code is not None]
    rgb_values = [hex2color(hex_code) for hex_code in hex_codes]
    luminance_values = [0.299 * r + 0.587 * g + 0.114 * b for r, g, b in rgb_values]
    sorted_hex_codes = [code for _, code in sorted(zip(luminance_values, hex_codes), reverse=True)]
    fig, axes = plt.subplots(2, 5, figsize=(5, 5))
    # Adjust this figure explicitly rather than pyplot's global "current figure",
    # which another concurrent request could have replaced.
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    for i, ax in enumerate(axes.flat):
        ax.set_axis_off()
        if i < len(sorted_hex_codes):
            ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
    return fig

@spaces.GPU
def generate_images_plots(prompt, model_name):
    """Generate 10 images for ``prompt`` and summarize their skin tone and gender.

    For each image, a BLIP caption (prefixed "photo of a ") is mapped to a
    gender label. ``stone.process`` supplies the dominant colour of the first
    detected face; it is ``None`` when that lookup fails for any reason, e.g.
    no face is found.

    Side effect: rebinds the module-level ``pipeline_text2image`` to a freshly
    loaded pipeline for ``model_name``.

    Returns:
        (list of PIL images, skin-tone Figure, gender Figure)
    """
    import tempfile  # only this function needs scratch files

    global pipeline_text2image
    pipeline_text2image = load_model(model_name)
    images = [getimgen(prompt, model_name) for _ in range(10)]
    prompt_prefix = "photo of a "
    genders = []
    skintones = []
    # Per-call scratch directory: the previous fixed "temp/image_<i>.png" paths
    # were shared by all requests, so concurrent users could overwrite each
    # other's images before stone read them. The directory is removed on exit.
    with tempfile.TemporaryDirectory() as foldername:
        for i, image in enumerate(images):
            caption = blip_caption_image(image, prefix=prompt_prefix)
            image_path = Path(foldername) / f"image_{i}.png"
            image.save(image_path)
            try:
                skintoneres = stone.process(str(image_path), return_report_image=False)
                tone = skintoneres['faces'][0]['dominant_colors'][0]['color']
            except Exception:
                # Best-effort: any failure here (e.g. no face detected) just
                # yields no colour; KeyboardInterrupt/SystemExit still propagate.
                tone = None
            skintones.append(tone)
            genders.append(genderfromcaption(caption))
    return images, skintoneplot(skintones), genderplot(genders)

# UI: choose a model, enter a prompt, generate 10 images plus the two summary plots.
with gr.Blocks(title="Skin Tone and Gender bias in Text to Image Models") as demo:
    gr.Markdown("# Skin Tone and Gender bias in Text to Image Models")
    model_dropdown = gr.Dropdown(
        label="Choose a model",
        # Must stay in sync with the ids handled by load_model / getimgen.
        choices=[
            "stabilityai/sdxl-turbo",
            "runwayml/stable-diffusion-v1-5",
            "ByteDance/SDXL-Lightning",
            "segmind/SSD-1B",
        ],
        value=default_model,
    )
    prompt = gr.Textbox(label="Enter the Prompt")
    gallery = gr.Gallery(
        label="Generated images",
        show_label=False,
        elem_id="gallery",
        columns=[5],
        rows=[2],
        object_fit="contain",
        height="auto",
    )
    btn = gr.Button("Generate images", scale=0)
    with gr.Row(equal_height=True):
        skinplot = gr.Plot(label="Skin Tone")
        genplot = gr.Plot(label="Gender")
    btn.click(generate_images_plots, inputs=[prompt, model_dropdown], outputs=[gallery, skinplot, genplot])

# Launch only when run as a script (`python app.py`), not when the module is
# imported, e.g. by Gradio's reload mode, which looks up `demo` itself.
if __name__ == "__main__":
    demo.launch(debug=True)