evijit HF staff commited on
Commit
b0be682
1 Parent(s): a5d42f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -30
app.py CHANGED
@@ -6,9 +6,10 @@ from diffusers import (
6
  StableDiffusionXLPipeline,
7
  EulerDiscreteScheduler,
8
  UNet2DConditionModel,
9
- StableDiffusion3Pipeline
 
10
  )
11
- from transformers import BlipProcessor, BlipForConditionalGeneration
12
  from pathlib import Path
13
  from safetensors.torch import load_file
14
  from huggingface_hub import hf_hub_download
@@ -21,11 +22,9 @@ import spaces
21
 
22
  access_token = os.getenv("AccessTokenSD3")
23
 
24
-
25
  from huggingface_hub import login
26
  login(token = access_token)
27
 
28
-
29
  # Define model initialization functions
30
  def load_model(model_name):
31
  if model_name == "stabilityai/sdxl-turbo":
@@ -34,11 +33,6 @@ def load_model(model_name):
34
  torch_dtype=torch.float16,
35
  variant="fp16"
36
  ).to("cuda")
37
- elif model_name == "runwayml/stable-diffusion-v1-5":
38
- pipeline = StableDiffusionPipeline.from_pretrained(
39
- model_name,
40
- torch_dtype=torch.float16
41
- ).to("cuda")
42
  elif model_name == "ByteDance/SDXL-Lightning":
43
  base = "stabilityai/stable-diffusion-xl-base-1.0"
44
  ckpt = "sdxl_lightning_4step_unet.safetensors"
@@ -70,29 +64,40 @@ def load_model(model_name):
70
  scheduler=scheduler,
71
  torch_dtype=torch.float16
72
  ).to("cuda")
 
 
 
73
  else:
74
  raise ValueError("Unknown model name")
75
  return pipeline
76
 
77
  # Initialize the default model
78
- default_model = "stabilityai/stable-diffusion-3-medium-diffusers"
79
  pipeline_text2image = load_model(default_model)
80
 
81
  @spaces.GPU
82
  def getimgen(prompt, model_name):
83
  if model_name == "stabilityai/sdxl-turbo":
84
- return pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=2).images[0]
85
- elif model_name == "runwayml/stable-diffusion-v1-5":
86
- return pipeline_text2image(prompt).images[0]
87
  elif model_name == "ByteDance/SDXL-Lightning":
88
- return pipeline_text2image(prompt, num_inference_steps=4, guidance_scale=0).images[0]
89
  elif model_name == "segmind/SSD-1B":
90
  neg_prompt = "ugly, blurry, poor quality"
91
- return pipeline_text2image(prompt=prompt, negative_prompt=neg_prompt).images[0]
92
  elif model_name == "stabilityai/stable-diffusion-3-medium-diffusers":
93
- return pipeline_text2image(prompt=prompt, negative_prompt="", num_inference_steps=28, guidance_scale=7.0).images[0]
94
  elif model_name == "stabilityai/stable-diffusion-2":
95
- return pipeline_text2image(prompt=prompt).images[0]
 
 
 
 
 
 
 
 
 
 
96
 
97
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
98
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
@@ -136,6 +141,42 @@ def skintoneplot(hex_codes):
136
  ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
137
  return fig
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  @spaces.GPU(duration=200)
140
  def generate_images_plots(prompt, model_name):
141
  global pipeline_text2image
@@ -145,6 +186,8 @@ def generate_images_plots(prompt, model_name):
145
  images = [getimgen(prompt, model_name) for _ in range(10)]
146
  genders = []
147
  skintones = []
 
 
148
  for image, i in zip(images, range(10)):
149
  prompt_prefix = "photo of a "
150
  caption = blip_caption_image(image, prefix=prompt_prefix)
@@ -156,37 +199,38 @@ def generate_images_plots(prompt, model_name):
156
  except:
157
  skintones.append(None)
158
  genders.append(genderfromcaption(caption))
159
- return images, skintoneplot(skintones), genderplot(genders)
 
 
160
 
161
- with gr.Blocks(title="Skin Tone and Gender bias in Text-to-Image Generation Models") as demo:
162
- gr.Markdown("# Skin Tone and Gender bias in Text to Image Models")
163
  gr.Markdown('''
164
- In this demo, we explore the potential biases in text-to-image models by generating multiple images based on user prompts and analyzing the gender and skin tone of the generated subjects. Here's how the analysis works:
165
-
166
  1. **Image Generation**: For each prompt, 10 images are generated using the selected model.
167
  2. **Gender Detection**: The [BLIP caption generator](https://huggingface.co/Salesforce/blip-image-captioning-large) is used to elicit gender markers by identifying words like "man," "boy," "woman," and "girl" in the captions.
168
  3. **Skin Tone Classification**: The [skin-tone-classifier library](https://github.com/ChenglongMa/SkinToneClassifier) is used to extract the skin tones of the generated subjects.
169
-
170
-
171
  #### Visualization
172
-
173
  We create visual grids to represent the data:
174
-
175
  - **Skin Tone Grids**: Skin tones are plotted as exact hex codes rather than using the Fitzpatrick scale, which can be [problematic and limiting for darker skin tones](https://arxiv.org/pdf/2309.05148).
176
  - **Gender Grids**: Light green denotes men, dark green denotes women, and grey denotes cases where the BLIP caption did not specify a binary gender.
177
-
 
 
178
  This demo provides an insightful look into how current text-to-image models handle sensitive attributes, shedding light on areas for improvement and further study.
179
  [Here is an article](https://medium.com/@evijit/analysis-of-ai-generated-images-of-indian-people-for-colorism-and-sexism-b80ff946759f) showing how this space can be used to perform such analyses, using colorism and sexism in India as an example.
180
  ''')
181
  model_dropdown = gr.Dropdown(
182
  label="Choose a model",
183
  choices=[
 
184
  "stabilityai/stable-diffusion-3-medium-diffusers",
185
  "stabilityai/sdxl-turbo",
186
  "ByteDance/SDXL-Lightning",
187
  "stabilityai/stable-diffusion-2",
188
- "runwayml/stable-diffusion-v1-5",
189
- "segmind/SSD-1B"
190
  ],
191
  value=default_model
192
  )
@@ -204,6 +248,9 @@ This demo provides an insightful look into how current text-to-image models hand
204
  with gr.Row(equal_height=True):
205
  skinplot = gr.Plot(label="Skin Tone")
206
  genplot = gr.Plot(label="Gender")
207
- btn.click(generate_images_plots, inputs=[prompt, model_dropdown], outputs=[gallery, skinplot, genplot])
 
 
 
208
 
209
  demo.launch(debug=True)
 
6
  StableDiffusionXLPipeline,
7
  EulerDiscreteScheduler,
8
  UNet2DConditionModel,
9
+ StableDiffusion3Pipeline,
10
+ FluxPipeline
11
  )
12
+ from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
13
  from pathlib import Path
14
  from safetensors.torch import load_file
15
  from huggingface_hub import hf_hub_download
 
22
 
23
  access_token = os.getenv("AccessTokenSD3")
24
 
 
25
  from huggingface_hub import login
26
  login(token = access_token)
27
 
 
28
  # Define model initialization functions
29
  def load_model(model_name):
30
  if model_name == "stabilityai/sdxl-turbo":
 
33
  torch_dtype=torch.float16,
34
  variant="fp16"
35
  ).to("cuda")
 
 
 
 
 
36
  elif model_name == "ByteDance/SDXL-Lightning":
37
  base = "stabilityai/stable-diffusion-xl-base-1.0"
38
  ckpt = "sdxl_lightning_4step_unet.safetensors"
 
64
  scheduler=scheduler,
65
  torch_dtype=torch.float16
66
  ).to("cuda")
67
+ elif model_name == "black-forest-labs/FLUX.1-dev":
68
+ pipeline = FluxPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16)
69
+ pipeline.enable_model_cpu_offload()
70
  else:
71
  raise ValueError("Unknown model name")
72
  return pipeline
73
 
74
  # Initialize the default model
75
+ default_model = "black-forest-labs/FLUX.1-dev"
76
  pipeline_text2image = load_model(default_model)
77
 
78
  @spaces.GPU
79
  def getimgen(prompt, model_name):
80
  if model_name == "stabilityai/sdxl-turbo":
81
+ return pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=2, height=512, width=512).images[0]
 
 
82
  elif model_name == "ByteDance/SDXL-Lightning":
83
+ return pipeline_text2image(prompt, num_inference_steps=4, guidance_scale=0, height=512, width=512).images[0]
84
  elif model_name == "segmind/SSD-1B":
85
  neg_prompt = "ugly, blurry, poor quality"
86
+ return pipeline_text2image(prompt=prompt, negative_prompt=neg_prompt, height=512, width=512).images[0]
87
  elif model_name == "stabilityai/stable-diffusion-3-medium-diffusers":
88
+ return pipeline_text2image(prompt=prompt, negative_prompt="", num_inference_steps=28, guidance_scale=7.0, height=512, width=512).images[0]
89
  elif model_name == "stabilityai/stable-diffusion-2":
90
+ return pipeline_text2image(prompt=prompt, height=512, width=512).images[0]
91
+ elif model_name == "black-forest-labs/FLUX.1-dev":
92
+ return pipeline_text2image(
93
+ prompt,
94
+ height=512,
95
+ width=512,
96
+ guidance_scale=3.5,
97
+ num_inference_steps=50,
98
+ max_sequence_length=512,
99
+ generator=torch.Generator("cpu").manual_seed(0)
100
+ ).images[0]
101
 
102
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
103
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
 
141
  ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
142
  return fig
143
 
144
+ def age_detector(image):
145
+ pipe = pipeline('image-classification', model="dima806/faces_age_detection", device=0)
146
+ result = pipe(image)
147
+ max_score_item = max(result, key=lambda item: item['score'])
148
+ return max_score_item['label']
149
+
150
+ def ageplot(agelist):
151
+ order = ["YOUNG", "MIDDLE", "OLD"]
152
+ words = sorted(agelist, key=lambda x: order.index(x))
153
+ colors = {"YOUNG": "skyblue", "MIDDLE": "royalblue", "OLD": "darkblue"}
154
+ word_colors = [colors[word] for word in words]
155
+ fig, axes = plt.subplots(2, 5, figsize=(5,5))
156
+ plt.subplots_adjust(hspace=0.1, wspace=0.1)
157
+ for i, ax in enumerate(axes.flat):
158
+ ax.set_axis_off()
159
+ ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
160
+ return fig
161
+
162
+ def is_nsfw(image):
163
+ classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
164
+ result = classifier(image)
165
+ max_score_item = max(result, key=lambda item: item['score'])
166
+ return max_score_item['label']
167
+
168
+ def nsfwplot(nsfwlist):
169
+ order = ["normal", "nsfw"]
170
+ words = sorted(nsfwlist, key=lambda x: order.index(x))
171
+ colors = {"normal": "mistyrose", "nsfw": "red"}
172
+ word_colors = [colors[word] for word in words]
173
+ fig, axes = plt.subplots(2, 5, figsize=(5,5))
174
+ plt.subplots_adjust(hspace=0.1, wspace=0.1)
175
+ for i, ax in enumerate(axes.flat):
176
+ ax.set_axis_off()
177
+ ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
178
+ return fig
179
+
180
  @spaces.GPU(duration=200)
181
  def generate_images_plots(prompt, model_name):
182
  global pipeline_text2image
 
186
  images = [getimgen(prompt, model_name) for _ in range(10)]
187
  genders = []
188
  skintones = []
189
+ ages = []
190
+ nsfws = []
191
  for image, i in zip(images, range(10)):
192
  prompt_prefix = "photo of a "
193
  caption = blip_caption_image(image, prefix=prompt_prefix)
 
199
  except:
200
  skintones.append(None)
201
  genders.append(genderfromcaption(caption))
202
+ ages.append(age_detector(image))
203
+ nsfws.append(is_nsfw(image))
204
+ return images, skintoneplot(skintones), genderplot(genders), ageplot(ages), nsfwplot(nsfws)
205
 
206
+ with gr.Blocks(title="Demographic bias in Text-to-Image Generation Models") as demo:
207
+ gr.Markdown("# Demographic bias in Text to Image Models")
208
  gr.Markdown('''
209
+ In this demo, we explore the potential biases in text-to-image models by generating multiple images based on user prompts and analyzing the gender, skin tone, age, and potential sexual nature of the generated subjects. Here's how the analysis works:
 
210
  1. **Image Generation**: For each prompt, 10 images are generated using the selected model.
211
  2. **Gender Detection**: The [BLIP caption generator](https://huggingface.co/Salesforce/blip-image-captioning-large) is used to elicit gender markers by identifying words like "man," "boy," "woman," and "girl" in the captions.
212
  3. **Skin Tone Classification**: The [skin-tone-classifier library](https://github.com/ChenglongMa/SkinToneClassifier) is used to extract the skin tones of the generated subjects.
213
+ 4. **Age Detection**: The [Faces Age Detection model](https://huggingface.co/dima806/faces_age_detection) is used to identify the age of the generated subjects.
214
+ 5. **NFAA Detection**: The [Falconsai/nsfw_image_detection](https://huggingface.co/Falconsai/nsfw_image_detection) model is used to identify whether the generated images are NFAA (not for all audiences).
215
  #### Visualization
 
216
  We create visual grids to represent the data:
 
217
  - **Skin Tone Grids**: Skin tones are plotted as exact hex codes rather than using the Fitzpatrick scale, which can be [problematic and limiting for darker skin tones](https://arxiv.org/pdf/2309.05148).
218
  - **Gender Grids**: Light green denotes men, dark green denotes women, and grey denotes cases where the BLIP caption did not specify a binary gender.
219
+ - **Age Grids**: Light blue denotes people between 18 and 30, blue denotes people between 30 and 50, and dark blue denotes people older than 50.
220
+ - **NFAA Grids**: Light red denotes FAA images, and dark red denotes NFAA images.
221
+
222
  This demo provides an insightful look into how current text-to-image models handle sensitive attributes, shedding light on areas for improvement and further study.
223
  [Here is an article](https://medium.com/@evijit/analysis-of-ai-generated-images-of-indian-people-for-colorism-and-sexism-b80ff946759f) showing how this space can be used to perform such analyses, using colorism and sexism in India as an example.
224
  ''')
225
  model_dropdown = gr.Dropdown(
226
  label="Choose a model",
227
  choices=[
228
+ "black-forest-labs/FLUX.1-dev",
229
  "stabilityai/stable-diffusion-3-medium-diffusers",
230
  "stabilityai/sdxl-turbo",
231
  "ByteDance/SDXL-Lightning",
232
  "stabilityai/stable-diffusion-2",
233
+ "segmind/SSD-1B",
 
234
  ],
235
  value=default_model
236
  )
 
248
  with gr.Row(equal_height=True):
249
  skinplot = gr.Plot(label="Skin Tone")
250
  genplot = gr.Plot(label="Gender")
251
+ with gr.Row(equal_height=True):
252
+ agesplot = gr.Plot(label="Age")
253
+ nsfwsplot = gr.Plot(label="NFAA")
254
+ btn.click(generate_images_plots, inputs=[prompt, model_dropdown], outputs=[gallery, skinplot, genplot, agesplot, nsfwsplot])
255
 
256
  demo.launch(debug=True)