Avijit Ghosh committed
Commit f56644b
1 Parent(s): 64fe77f

playing around with model options

Files changed (3)
  1. app copy.py +149 -0
  2. app.py +48 -19
  3. test.ipynb +277 -0
app copy.py ADDED
@@ -0,0 +1,149 @@
+import gradio as gr
+import torch
+from diffusers import AutoPipelineForText2Image
+from transformers import BlipProcessor, BlipForConditionalGeneration
+from pathlib import Path
+import stone
+import requests
+import io
+import os
+from PIL import Image
+import spaces
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.colors import hex2color
+
+
+pipeline_text2image = AutoPipelineForText2Image.from_pretrained(
+    "stabilityai/sdxl-turbo",
+    torch_dtype=torch.float16,
+    variant="fp16",
+)
+pipeline_text2image = pipeline_text2image.to("cuda")
+
+
+@spaces.GPU
+def getimgen(prompt):
+
+    return pipeline_text2image(
+        prompt=prompt,
+        guidance_scale=0.0,
+        num_inference_steps=2
+    ).images[0]
+
+
+blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+blip_model = BlipForConditionalGeneration.from_pretrained(
+    "Salesforce/blip-image-captioning-large",
+    torch_dtype=torch.float16
+).to("cuda")
+
+
+@spaces.GPU
+def blip_caption_image(image, prefix):
+    inputs = blip_processor(image, prefix, return_tensors="pt").to("cuda", torch.float16)
+    out = blip_model.generate(**inputs)
+    return blip_processor.decode(out[0], skip_special_tokens=True)
+
+def genderfromcaption(caption):
+    cc = caption.split()
+    if "man" in cc or "boy" in cc:
+        return "Man"
+    elif "woman" in cc or "girl" in cc:
+        return "Woman"
+    return "Unsure"
+
+def genderplot(genlist):
+    order = ["Man", "Woman", "Unsure"]
+
+    # Sort the list based on the order of keys
+    words = sorted(genlist, key=lambda x: order.index(x))
+
+    # Define colors for each category
+    colors = {"Man": "lightgreen", "Woman": "darkgreen", "Unsure": "lightgrey"}
+
+    # Map each word to its corresponding color
+    word_colors = [colors[word] for word in words]
+
+    # Plot the colors in a grid with reduced spacing
+    fig, axes = plt.subplots(2, 5, figsize=(5,5))
+
+    # Adjust spacing between subplots
+    plt.subplots_adjust(hspace=0.1, wspace=0.1)
+
+    for i, ax in enumerate(axes.flat):
+        ax.set_axis_off()
+        ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
+
+    return fig
+
+def skintoneplot(hex_codes):
+    # Convert hex codes to RGB values
+    rgb_values = [hex2color(hex_code) for hex_code in hex_codes]
+
+    # Calculate luminance for each color
+    luminance_values = [0.299 * r + 0.587 * g + 0.114 * b for r, g, b in rgb_values]
+
+    # Sort hex codes based on luminance in descending order (dark to light)
+    sorted_hex_codes = [code for _, code in sorted(zip(luminance_values, hex_codes), reverse=True)]
+
+    # Plot the colors in a grid with reduced spacing
+    fig, axes = plt.subplots(2, 5, figsize=(5,5))
+
+    # Adjust spacing between subplots
+    plt.subplots_adjust(hspace=0.1, wspace=0.1)
+
+    for i, ax in enumerate(axes.flat):
+        ax.set_axis_off()
+        ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
+
+    return fig
+
+@spaces.GPU
+def generate_images_plots(prompt):
+    foldername = "temp"
+    # Generate 10 images
+    images = [getimgen(prompt) for _ in range(10)]
+
+    Path(foldername).mkdir(parents=True, exist_ok=True)
+
+    genders = []
+    skintones = []
+
+    for image, i in zip(images, range(10)):
+        prompt_prefix = "photo of a "
+        caption = blip_caption_image(image, prefix=prompt_prefix)
+        image.save(f"{foldername}/image_{i}.png")
+        try:
+            skintoneres = stone.process(f"{foldername}/image_{i}.png", return_report_image=False)
+            tone = skintoneres['faces'][0]['dominant_colors'][0]['color']
+            skintones.append(tone)
+        except:
+            skintones.append(None)
+
+        genders.append(genderfromcaption(caption))
+
+    print(genders, skintones)
+
+    return images, skintoneplot(skintones), genderplot(genders)
+
+
+with gr.Blocks(title = "Skin Tone and Gender bias in SDXL Demo - Inference API") as demo:
+
+    gr.Markdown("# Skin Tone and Gender bias in SDXL Demo")
+
+    prompt = gr.Textbox(label="Enter the Prompt")
+    gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery",
+                         columns=[5], rows=[2], object_fit="contain", height="auto")
+    btn = gr.Button("Generate images", scale=0)
+    with gr.Row(equal_height=True):
+        skinplot = gr.Plot(label="Skin Tone")
+        genplot = gr.Plot(label="Gender")
+
+
+    btn.click(generate_images_plots, inputs = prompt, outputs = [gallery, skinplot, genplot])
+
+
+
+demo.launch(debug=True)
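
One caveat with the script above: when stone.process raises, generate_images_plots appends None to skintones, but skintoneplot then passes that None straight to hex2color, which fails. A minimal None-tolerant sketch, assuming a grey placeholder swatch is acceptable (the name skintoneplot_safe and the placeholder argument are illustrative, not part of the commit):

def skintoneplot_safe(hex_codes, placeholder="#cccccc"):
    # Substitute a neutral grey wherever skin-tone extraction returned None,
    # so hex2color never sees a missing value.
    codes = [code if code is not None else placeholder for code in hex_codes]
    rgb_values = [hex2color(code) for code in codes]
    # Rec. 601 luma, used only to order the swatches from light to dark
    luminance_values = [0.299 * r + 0.587 * g + 0.114 * b for r, g, b in rgb_values]
    sorted_codes = [code for _, code in sorted(zip(luminance_values, codes), reverse=True)]
    fig, axes = plt.subplots(2, 5, figsize=(5, 5))
    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    for ax, code in zip(axes.flat, sorted_codes):
        ax.set_axis_off()
        ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=code))
    return fig

The rest of the function is unchanged; only the placeholder substitution and the zip over axes differ from the committed version.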
app.py CHANGED
@@ -1,6 +1,7 @@
 import gradio as gr
 import torch
-from diffusers import AutoPipelineForText2Image
+# from diffusers import AutoPipelineForText2Image
+from diffusers import DiffusionPipeline
 from transformers import BlipProcessor, BlipForConditionalGeneration
 from pathlib import Path
 import stone
@@ -13,16 +14,41 @@ import spaces
 import matplotlib.pyplot as plt
 import numpy as np
 from matplotlib.colors import hex2color
-
-
-pipeline_text2image = AutoPipelineForText2Image.from_pretrained(
-    "stabilityai/sdxl-turbo",
-    torch_dtype=torch.float16,
-    variant="fp16",
+from huggingface_hub import list_models
+
+# Fetch models from Hugging Face Hub
+models = list_models(task="text-to-image")
+## Step 1: Filter the models
+filtered_models = [model for model in models if model.library_name == "diffusers"]
+
+# Step 2: Sort the filtered models by downloads in descending order
+sorted_models = sorted(filtered_models, key=lambda x: x.downloads, reverse=True)
+
+# Step 3: Select the top 5 models with only one model per company
+top_models = []
+companies_seen = set()
+
+for model in sorted_models:
+    company_name = model.id.split('/')[0]  # Assuming the company name is the first part of the model id
+    if company_name not in companies_seen:
+        top_models.append(model)
+        companies_seen.add(company_name)
+    if len(top_models) == 5:
+        break
+
+# Get the ids of the top models
+model_names = [model.id for model in top_models]
+
+print(model_names)
+
+# Initial pipeline setup
+default_model = model_names[0]
+print(default_model)
+pipeline_text2image = DiffusionPipeline.from_pretrained(
+    default_model
 )
 pipeline_text2image = pipeline_text2image.to("cuda")
 
-
 @spaces.GPU
 def getimgen(prompt):
 
@@ -32,14 +58,12 @@ def getimgen(prompt):
         num_inference_steps=2
     ).images[0]
 
-
 blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
 blip_model = BlipForConditionalGeneration.from_pretrained(
     "Salesforce/blip-image-captioning-large",
     torch_dtype=torch.float16
 ).to("cuda")
 
-
 @spaces.GPU
 def blip_caption_image(image, prefix):
     inputs = blip_processor(image, prefix, return_tensors="pt").to("cuda", torch.float16)
@@ -101,7 +125,15 @@ def skintoneplot(hex_codes):
     return fig
 
 @spaces.GPU
-def generate_images_plots(prompt):
+def generate_images_plots(prompt, model_name):
+    print(model_name)
+    # Update the pipeline to use the selected model
+    global pipeline_text2image
+    pipeline_text2image = DiffusionPipeline.from_pretrained(
+        model_name
+    )
+    pipeline_text2image = pipeline_text2image.to("cuda")
+
     foldername = "temp"
     # Generate 10 images
     images = [getimgen(prompt) for _ in range(10)]
@@ -128,11 +160,11 @@ def generate_images_plots(prompt):
 
     return images, skintoneplot(skintones), genderplot(genders)
 
 
-with gr.Blocks(title = "Skin Tone and Gender bias in SDXL Demo - Inference API") as demo:
-
-    gr.Markdown("# Skin Tone and Gender bias in SDXL Demo")
+with gr.Blocks(title = "Skin Tone and Gender bias in Text to Image Models") as demo:
+
+    gr.Markdown("# Skin Tone and Gender bias in Text to Image Models")
 
+    model_dropdown = gr.Dropdown(label="Choose a model", choices=model_names, value=default_model)
     prompt = gr.Textbox(label="Enter the Prompt")
     gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery",
                          columns=[5], rows=[2], object_fit="contain", height="auto")
@@ -141,9 +173,6 @@ with gr.Blocks(title = "Skin Tone and Gender bias in SDXL Demo - Inference API")
     skinplot = gr.Plot(label="Skin Tone")
     genplot = gr.Plot(label="Gender")
 
 
-    btn.click(generate_images_plots, inputs = prompt, outputs = [gallery, skinplot, genplot])
-
-
-
-demo.launch(debug=True)
+    btn.click(generate_images_plots, inputs=[prompt, model_dropdown], outputs=[gallery, skinplot, genplot])
+
+demo.launch(debug=True)
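
Two things stand out in this version: DiffusionPipeline.from_pretrained is now called without torch_dtype, so every model is loaded in full fp32 precision, and the pipeline is reloaded inside generate_images_plots on every click even when the dropdown value has not changed. A hedged sketch of one way to handle both, assuming the selected checkpoints work in fp16 (get_pipeline and _pipelines are illustrative names, not part of the commit):

_pipelines = {}

def get_pipeline(model_name):
    # Cache one pipeline per model id so switching the dropdown back and forth
    # does not re-download and re-initialize the weights on every click.
    if model_name not in _pipelines:
        pipe = DiffusionPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # assumption: the checkpoint runs in fp16, as sdxl-turbo did
        )
        _pipelines[model_name] = pipe.to("cuda")
    return _pipelines[model_name]

Note also that getimgen keeps guidance_scale=0.0 and num_inference_steps=2, which are SDXL-Turbo settings; most other text-to-image checkpoints need more steps and a non-zero guidance scale to produce usable images.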
 
 
 
 
test.ipynb ADDED
@@ -0,0 +1,277 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from diffusers import AutoPipelineForText2Image\n",
+    "import torch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "vae/diffusion_pytorch_model.safetensors not found\n"
+     ]
+    },
+    [… 15 "display_data" outputs elided: Jupyter widget progress bars for "Fetching 15 files" and the individual component downloads (tokenizer, feature_extractor, safety_checker, text_encoder, scheduler, unet, vae) …]
+   ],
+   "source": [
+    "# model = \"CompVis/ldm-text2im-large-256\"\n",
+    "model = \"sd-dreambooth-library/colorful-ball\"\n",
+    "# model = \"stabilityai/sdxl-turbo\"\n",
+    "\n",
+    "pipeline_text2image = AutoPipelineForText2Image.from_pretrained(\n",
+    "    model,\n",
+    "    torch_dtype=torch.float16,\n",
+    ")\n",
+    "pipeline_text2image = pipeline_text2image.to(\"cuda\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "gradio",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
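
The notebook swaps the model id by hand between runs. A small sketch (not in the notebook) of how the same check could loop over the candidate ids it mentions and report which ones load as diffusers pipelines:

from diffusers import DiffusionPipeline
import torch

# Same three ids that appear (commented or active) in the notebook cell above.
candidates = [
    "CompVis/ldm-text2im-large-256",
    "sd-dreambooth-library/colorful-ball",
    "stabilityai/sdxl-turbo",
]

for model_id in candidates:
    try:
        # Weights stay on CPU here; this only checks that the repo loads as a pipeline.
        pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
        print(f"{model_id}: loaded {type(pipe).__name__}")
        del pipe
    except Exception as err:  # e.g. gated repo, missing components, non-diffusers layout
        print(f"{model_id}: failed to load ({err})")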