Avijit Ghosh committed
Commit 31a0f6f
1 parent: f56644b

add dropdown

Files changed (1): app.py (+83, −103)
app.py CHANGED
@@ -1,70 +1,73 @@
 import gradio as gr
 import torch
-# from diffusers import AutoPipelineForText2Image
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, StableDiffusionPipeline, StableDiffusionXLPipeline, EulerDiscreteScheduler, UNet2DConditionModel
 from transformers import BlipProcessor, BlipForConditionalGeneration
 from pathlib import Path
-import stone
-import requests
-import io
-import os
+from safetensors.torch import load_file
+from huggingface_hub import hf_hub_download
 from PIL import Image
-import spaces
-
 import matplotlib.pyplot as plt
-import numpy as np
 from matplotlib.colors import hex2color
-from huggingface_hub import list_models
-
-# Fetch models from Hugging Face Hub
-models = list_models(task="text-to-image")
-## Step 1: Filter the models
-filtered_models = [model for model in models if model.library_name == "diffusers"]
-
-# Step 2: Sort the filtered models by downloads in descending order
-sorted_models = sorted(filtered_models, key=lambda x: x.downloads, reverse=True)
-
-# Step 3: Select the top 5 models with only one model per company
-top_models = []
-companies_seen = set()
-
-for model in sorted_models:
-    company_name = model.id.split('/')[0] # Assuming the company name is the first part of the model id
-    if company_name not in companies_seen:
-        top_models.append(model)
-        companies_seen.add(company_name)
-    if len(top_models) == 5:
-        break
-
-# Get the ids of the top models
-model_names = [model.id for model in top_models]
-
-print(model_names)
+import stone
+import os
 
-# Initial pipeline setup
-default_model = model_names[0]
-print(default_model)
-pipeline_text2image = DiffusionPipeline.from_pretrained(
-    default_model
-)
-pipeline_text2image = pipeline_text2image.to("cuda")
+# Define model initialization functions
+def load_model(model_name):
+    if model_name == "stabilityai/sdxl-turbo":
+        pipeline = DiffusionPipeline.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,
+            variant="fp16"
+        ).to("cuda")
+    elif model_name == "runwayml/stable-diffusion-v1-5":
+        pipeline = StableDiffusionPipeline.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16
+        ).to("cuda")
+    elif model_name == "ByteDance/SDXL-Lightning":
+        base = "stabilityai/stable-diffusion-xl-base-1.0"
+        ckpt = "sdxl_lightning_4step_unet.safetensors"
+        unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
+        unet.load_state_dict(load_file(hf_hub_download(model_name, ckpt), device="cuda"))
+        pipeline = StableDiffusionXLPipeline.from_pretrained(
+            base,
+            unet=unet,
+            torch_dtype=torch.float16,
+            variant="fp16"
+        ).to("cuda")
+        pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config, timestep_spacing="trailing")
+    elif model_name == "segmind/SSD-1B":
+        pipeline = StableDiffusionXLPipeline.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,
+            use_safetensors=True,
+            variant="fp16"
+        ).to("cuda")
+    else:
+        raise ValueError("Unknown model name")
+    return pipeline
+
+choices=[
+    "stabilityai/sdxl-turbo",
+    "runwayml/stable-diffusion-v1-5",
+    "ByteDance/SDXL-Lightning",
+    "segmind/SSD-1B"
+]
+
+for model_name in choices:
+    load_model(model_name)
+
+# Initialize the default model
+default_model = "stabilityai/sdxl-turbo"
+
+pipeline_text2image = load_model(default_model)
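Review note: the warm-up loop `for model_name in choices: load_model(model_name)` builds all four fp16 pipelines on the GPU at startup and immediately discards them, so its only lasting effect is a populated weight cache, paid for in boot time and peak VRAM. A minimal sketch of a lighter warm-up, assuming the goal is just to pre-fetch weights; `snapshot_download` pulls whole repos, which is coarser than what `load_model` actually reads, and it is a suggestion, not part of this commit:

```python
# Sketch (not in this commit): fill the local cache without instantiating
# pipelines on the GPU; load_model() still does the real loading on first use.
# SDXL-Lightning additionally needs the SDXL base weights, added explicitly.
from huggingface_hub import snapshot_download

for repo_id in choices + ["stabilityai/stable-diffusion-xl-base-1.0"]:
    snapshot_download(repo_id)
```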
 
-@spaces.GPU
 def getimgen(prompt):
-
-    return pipeline_text2image(
-        prompt=prompt,
-        guidance_scale=0.0,
-        num_inference_steps=2
-    ).images[0]
+    return pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=2).images[0]
 
 blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
-blip_model = BlipForConditionalGeneration.from_pretrained(
-    "Salesforce/blip-image-captioning-large",
-    torch_dtype=torch.float16
-).to("cuda")
+blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
 
-@spaces.GPU
 def blip_caption_image(image, prefix):
     inputs = blip_processor(image, prefix, return_tensors="pt").to("cuda", torch.float16)
     out = blip_model.generate(**inputs)
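Review note: `getimgen` applies `guidance_scale=0.0, num_inference_steps=2` to whichever pipeline is active. Those are SDXL-Turbo settings; SD 1.5 and SSD-1B normally run with CFG around 7.5 and 20+ steps, and the Lightning checkpoint loaded above is the 4-step UNet. A hedged sketch of per-model settings: the step and guidance values are commonly recommended defaults for these models, not values from this commit, and `GEN_KWARGS` plus the extra `model_name` parameter are illustrative:

```python
# Sketch (suggestion): per-model sampler settings instead of one fixed pair.
GEN_KWARGS = {
    "stabilityai/sdxl-turbo": {"guidance_scale": 0.0, "num_inference_steps": 2},
    "ByteDance/SDXL-Lightning": {"guidance_scale": 0.0, "num_inference_steps": 4},
    "runwayml/stable-diffusion-v1-5": {"guidance_scale": 7.5, "num_inference_steps": 25},
    "segmind/SSD-1B": {"guidance_scale": 7.5, "num_inference_steps": 25},
}

def getimgen(prompt, model_name):
    return pipeline_text2image(prompt=prompt, **GEN_KWARGS[model_name]).images[0]
```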
@@ -80,69 +83,37 @@ def genderfromcaption(caption):
 
 def genderplot(genlist):
     order = ["Man", "Woman", "Unsure"]
-
-    # Sort the list based on the order of keys
     words = sorted(genlist, key=lambda x: order.index(x))
-
-    # Define colors for each category
     colors = {"Man": "lightgreen", "Woman": "darkgreen", "Unsure": "lightgrey"}
-
-    # Map each word to its corresponding color
     word_colors = [colors[word] for word in words]
-
-    # Plot the colors in a grid with reduced spacing
     fig, axes = plt.subplots(2, 5, figsize=(5,5))
-
-    # Adjust spacing between subplots
     plt.subplots_adjust(hspace=0.1, wspace=0.1)
-
     for i, ax in enumerate(axes.flat):
         ax.set_axis_off()
         ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
-
     return fig
 
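Review note: `genderplot` fills all ten grid cells from `word_colors[i]`, so it relies on `genlist` always holding exactly ten entries (true here, since the app generates ten images per run). The commit guards `skintoneplot` below against short lists but leaves this function unguarded; the same one-line check, sketched here as a suggestion, would make the two plots symmetric:

```python
# Sketch: the guard this commit adds to skintoneplot(), mirrored here.
# Safe today because genlist always has 10 items; cheap insurance otherwise.
for i, ax in enumerate(axes.flat):
    ax.set_axis_off()
    if i < len(word_colors):
        ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
```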
 def skintoneplot(hex_codes):
-    # Convert hex codes to RGB values
+    hex_codes = [code for code in hex_codes if code is not None]
     rgb_values = [hex2color(hex_code) for hex_code in hex_codes]
-
-    # Calculate luminance for each color
     luminance_values = [0.299 * r + 0.587 * g + 0.114 * b for r, g, b in rgb_values]
-
-    # Sort hex codes based on luminance in descending order (dark to light)
     sorted_hex_codes = [code for _, code in sorted(zip(luminance_values, hex_codes), reverse=True)]
-
-    # Plot the colors in a grid with reduced spacing
     fig, axes = plt.subplots(2, 5, figsize=(5,5))
-
-    # Adjust spacing between subplots
     plt.subplots_adjust(hspace=0.1, wspace=0.1)
-
     for i, ax in enumerate(axes.flat):
         ax.set_axis_off()
-        ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
-
+        if i < len(sorted_hex_codes):
+            ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
     return fig
 
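Review note: `skintoneplot` ranks swatches by Rec. 601 luma (0.299 R + 0.587 G + 0.114 B). With `reverse=True` the sort is descending luma, i.e. lightest first; the comment deleted above claimed "dark to light", which contradicted the code. A standalone check with made-up hex values (hypothetical samples, not swatches from the app):

```python
from matplotlib.colors import hex2color

hex_codes = ["#3b2219", "#e7c1b8", "#81654f"]  # hypothetical sample swatches
rgb = [hex2color(h) for h in hex_codes]
luma = [0.299 * r + 0.587 * g + 0.114 * b for r, g, b in rgb]
print([h for _, h in sorted(zip(luma, hex_codes), reverse=True)])
# ['#e7c1b8', '#81654f', '#3b2219'] -- lightest to darkest
```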
-@spaces.GPU
 def generate_images_plots(prompt, model_name):
-    print(model_name)
-    # Update the pipeline to use the selected model
     global pipeline_text2image
-    pipeline_text2image = DiffusionPipeline.from_pretrained(
-        model_name
-    )
-    pipeline_text2image = pipeline_text2image.to("cuda")
-
+    pipeline_text2image = load_model(model_name)
     foldername = "temp"
-    # Generate 10 images
-    images = [getimgen(prompt) for _ in range(10)]
-
     Path(foldername).mkdir(parents=True, exist_ok=True)
-
+    images = [getimgen(prompt) for _ in range(10)]
     genders = []
     skintones = []
-
     for image, i in zip(images, range(10)):
         prompt_prefix = "photo of a "
         caption = blip_caption_image(image, prefix=prompt_prefix)
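Review note: `generate_images_plots` now calls `load_model(model_name)` on every click, so repeated runs with the same dropdown selection still pay a full reload from disk. A minimal sketch of skipping the reload when the model has not changed; `_current_model` and `get_pipeline` are illustrative names, not part of this commit:

```python
# Sketch (suggestion): only reload when the dropdown selection changed.
_current_model = default_model

def get_pipeline(model_name):
    global pipeline_text2image, _current_model
    if model_name != _current_model:
        pipeline_text2image = load_model(model_name)
        _current_model = model_name
    return pipeline_text2image
```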
@@ -153,26 +124,35 @@ def generate_images_plots(prompt, model_name):
             skintones.append(tone)
         except:
             skintones.append(None)
-
         genders.append(genderfromcaption(caption))
-
-    print(genders, skintones)
-
     return images, skintoneplot(skintones), genderplot(genders)
 
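Review note: the `except:` above keeps a run alive when skin-tone extraction finds no face, but a bare except also swallows everything else, including `KeyboardInterrupt`. The `stone` call it protects sits in lines this diff does not show, so the sketch below uses `extract_tone` as a hypothetical stand-in for that elided code:

```python
# Sketch (suggestion): narrow the bare except and surface real failures.
try:
    tone = extract_tone(image)  # hypothetical stand-in for the elided code
    skintones.append(tone)
except Exception as e:  # still broad, but no longer catches KeyboardInterrupt
    print(f"skin tone extraction failed on image {i}: {e}")
    skintones.append(None)
```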
-with gr.Blocks(title = "Skin Tone and Gender bias in Text to Image Models") as demo:
-
+with gr.Blocks(title="Skin Tone and Gender bias in Text to Image Models") as demo:
     gr.Markdown("# Skin Tone and Gender bias in Text to Image Models")
-
-    model_dropdown = gr.Dropdown(label="Choose a model", choices=model_names, value=default_model)
+    model_dropdown = gr.Dropdown(
+        label="Choose a model",
+        choices=[
+            "stabilityai/sdxl-turbo",
+            "runwayml/stable-diffusion-v1-5",
+            "ByteDance/SDXL-Lightning",
+            "segmind/SSD-1B"
+        ],
+        value=default_model
+    )
     prompt = gr.Textbox(label="Enter the Prompt")
-    gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery",
-                         columns=[5], rows=[2], object_fit="contain", height="auto")
+    gallery = gr.Gallery(
+        label="Generated images",
+        show_label=False,
+        elem_id="gallery",
+        columns=[5],
+        rows=[2],
+        object_fit="contain",
+        height="auto"
+    )
     btn = gr.Button("Generate images", scale=0)
     with gr.Row(equal_height=True):
         skinplot = gr.Plot(label="Skin Tone")
         genplot = gr.Plot(label="Gender")
-
     btn.click(generate_images_plots, inputs=[prompt, model_dropdown], outputs=[gallery, skinplot, genplot])
 
-demo.launch(debug=True)
+demo.launch(debug=True)
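Review note: the dropdown's `choices` argument re-types the four model ids that already exist in the module-level `choices` list defined next to `load_model`. Reusing that list keeps the UI and the loader from drifting apart if a model is added or swapped; a one-line sketch, same `gr.Dropdown` call as in the commit, just fed the shared list:

```python
model_dropdown = gr.Dropdown(label="Choose a model", choices=choices, value=default_model)
```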
 