Avijit Ghosh commited on
Commit
e7204ee
1 Parent(s): 17497eb
Files changed (1) hide show
  1. app.py +18 -13
app.py CHANGED
@@ -4,12 +4,16 @@ from diffusers import DiffusionPipeline, StableDiffusionPipeline, StableDiffusio
4
  from transformers import BlipProcessor, BlipForConditionalGeneration
5
  from pathlib import Path
6
  from safetensors.torch import load_file
7
- from huggingface_hub import hf_hub_download, list_models
8
  from PIL import Image
9
  import matplotlib.pyplot as plt
10
  from matplotlib.colors import hex2color
11
  import stone
12
  import os
 
 
 
 
13
 
14
  # Define model initialization functions
15
  def load_model(model_name, token=None):
@@ -59,10 +63,10 @@ def load_model(model_name, token=None):
59
  default_model = "stabilityai/sdxl-turbo"
60
  pipeline_text2image = load_model(default_model)
61
 
62
- @gr.outputs.Image(type="pil")
63
- def getimgen(prompt, model_name, token):
64
  global pipeline_text2image
65
- pipeline_text2image = load_model(model_name, token)
66
  if model_name == "stabilityai/sdxl-turbo":
67
  return pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=2).images[0]
68
  elif model_name == "runwayml/stable-diffusion-v1-5":
@@ -78,7 +82,7 @@ def getimgen(prompt, model_name, token):
78
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
79
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
80
 
81
- @gr.outputs.Textbox(type="str")
82
  def blip_caption_image(image, prefix):
83
  inputs = blip_processor(image, prefix, return_tensors="pt").to("cuda", torch.float16)
84
  out = blip_model.generate(**inputs)
@@ -117,11 +121,13 @@ def skintoneplot(hex_codes):
117
  ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
118
  return fig
119
 
120
- @gr.outputs.Image(type="pil")
121
- def generate_images_plots(prompt, model_name, token):
 
 
122
  foldername = "temp"
123
  Path(foldername).mkdir(parents=True, exist_ok=True)
124
- images = [getimgen(prompt, model_name, token) for _ in range(10)]
125
  genders = []
126
  skintones = []
127
  for image, i in zip(images, range(10)):
@@ -139,9 +145,6 @@ def generate_images_plots(prompt, model_name, token):
139
 
140
  with gr.Blocks(title="Skin Tone and Gender bias in Text to Image Models") as demo:
141
  gr.Markdown("# Skin Tone and Gender bias in Text to Image Models")
142
- gr.LoginButton() # Add a login button for Hugging Face
143
- profile = gr.State()
144
- token = gr.State()
145
  model_dropdown = gr.Dropdown(
146
  label="Choose a model",
147
  choices=[
@@ -163,10 +166,12 @@ with gr.Blocks(title="Skin Tone and Gender bias in Text to Image Models") as dem
163
  object_fit="contain",
164
  height="auto"
165
  )
 
 
166
  btn = gr.Button("Generate images", scale=0)
167
  with gr.Row(equal_height=True):
168
  skinplot = gr.Plot(label="Skin Tone")
169
  genplot = gr.Plot(label="Gender")
170
- btn.click(generate_images_plots, inputs=[prompt, model_dropdown, token], outputs=[gallery, skinplot, genplot])
171
 
172
- demo.launch()
 
4
  from transformers import BlipProcessor, BlipForConditionalGeneration
5
  from pathlib import Path
6
  from safetensors.torch import load_file
7
+ from huggingface_hub import hf_hub_download
8
  from PIL import Image
9
  import matplotlib.pyplot as plt
10
  from matplotlib.colors import hex2color
11
  import stone
12
  import os
13
+ import spaces
14
+
15
+ from huggingface_hub import login
16
+ login()
17
 
18
  # Define model initialization functions
19
  def load_model(model_name, token=None):
 
63
  default_model = "stabilityai/sdxl-turbo"
64
  pipeline_text2image = load_model(default_model)
65
 
66
+ @spaces.GPU
67
+ def getimgen(prompt, model_name):
68
  global pipeline_text2image
69
+ pipeline_text2image = load_model(model_name)
70
  if model_name == "stabilityai/sdxl-turbo":
71
  return pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=2).images[0]
72
  elif model_name == "runwayml/stable-diffusion-v1-5":
 
82
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
83
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
84
 
85
+ @spaces.GPU
86
  def blip_caption_image(image, prefix):
87
  inputs = blip_processor(image, prefix, return_tensors="pt").to("cuda", torch.float16)
88
  out = blip_model.generate(**inputs)
 
121
  ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
122
  return fig
123
 
124
+ @spaces.GPU
125
+ def generate_images_plots(prompt, model_name):
126
+ global pipeline_text2image
127
+ pipeline_text2image = load_model(model_name)
128
  foldername = "temp"
129
  Path(foldername).mkdir(parents=True, exist_ok=True)
130
+ images = [getimgen(prompt, model_name) for _ in range(10)]
131
  genders = []
132
  skintones = []
133
  for image, i in zip(images, range(10)):
 
145
 
146
  with gr.Blocks(title="Skin Tone and Gender bias in Text to Image Models") as demo:
147
  gr.Markdown("# Skin Tone and Gender bias in Text to Image Models")
 
 
 
148
  model_dropdown = gr.Dropdown(
149
  label="Choose a model",
150
  choices=[
 
166
  object_fit="contain",
167
  height="auto"
168
  )
169
+ gr.LoginButton()
170
+ gr.Markdown('### You need to log in to your Hugging Face account to run Stable Diffusion 3')
171
  btn = gr.Button("Generate images", scale=0)
172
  with gr.Row(equal_height=True):
173
  skinplot = gr.Plot(label="Skin Tone")
174
  genplot = gr.Plot(label="Gender")
175
+ btn.click(generate_images_plots, inputs=[prompt, model_dropdown], outputs=[gallery, skinplot, genplot])
176
 
177
+ demo.launch(debug=True)