Avijit Ghosh commited on
Commit
17497eb
1 Parent(s): e36c39f
Files changed (1) hide show
  1. app.py +18 -18
app.py CHANGED
@@ -4,19 +4,15 @@ from diffusers import DiffusionPipeline, StableDiffusionPipeline, StableDiffusio
4
  from transformers import BlipProcessor, BlipForConditionalGeneration
5
  from pathlib import Path
6
  from safetensors.torch import load_file
7
- from huggingface_hub import hf_hub_download
8
  from PIL import Image
9
  import matplotlib.pyplot as plt
10
  from matplotlib.colors import hex2color
11
  import stone
12
  import os
13
- import spaces
14
-
15
- from huggingface_hub import login
16
- login()
17
 
18
  # Define model initialization functions
19
- def load_model(model_name):
20
  if model_name == "stabilityai/sdxl-turbo":
21
  pipeline = DiffusionPipeline.from_pretrained(
22
  model_name,
@@ -48,9 +44,12 @@ def load_model(model_name):
48
  variant="fp16"
49
  ).to("cuda")
50
  elif model_name == "stabilityai/stable-diffusion-3-medium-diffusers":
 
 
51
  pipeline = StableDiffusion3Pipeline.from_pretrained(
52
  model_name,
53
- torch_dtype=torch.float16
 
54
  ).to("cuda")
55
  else:
56
  raise ValueError("Unknown model name")
@@ -60,10 +59,10 @@ def load_model(model_name):
60
  default_model = "stabilityai/sdxl-turbo"
61
  pipeline_text2image = load_model(default_model)
62
 
63
- @spaces.GPU
64
- def getimgen(prompt, model_name):
65
  global pipeline_text2image
66
- pipeline_text2image = load_model(model_name)
67
  if model_name == "stabilityai/sdxl-turbo":
68
  return pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=2).images[0]
69
  elif model_name == "runwayml/stable-diffusion-v1-5":
@@ -79,7 +78,7 @@ def getimgen(prompt, model_name):
79
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
80
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
81
 
82
- @spaces.GPU
83
  def blip_caption_image(image, prefix):
84
  inputs = blip_processor(image, prefix, return_tensors="pt").to("cuda", torch.float16)
85
  out = blip_model.generate(**inputs)
@@ -118,13 +117,11 @@ def skintoneplot(hex_codes):
118
  ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
119
  return fig
120
 
121
- @spaces.GPU
122
- def generate_images_plots(prompt, model_name):
123
- global pipeline_text2image
124
- pipeline_text2image = load_model(model_name)
125
  foldername = "temp"
126
  Path(foldername).mkdir(parents=True, exist_ok=True)
127
- images = [getimgen(prompt, model_name) for _ in range(10)]
128
  genders = []
129
  skintones = []
130
  for image, i in zip(images, range(10)):
@@ -142,6 +139,9 @@ def generate_images_plots(prompt, model_name):
142
 
143
  with gr.Blocks(title="Skin Tone and Gender bias in Text to Image Models") as demo:
144
  gr.Markdown("# Skin Tone and Gender bias in Text to Image Models")
 
 
 
145
  model_dropdown = gr.Dropdown(
146
  label="Choose a model",
147
  choices=[
@@ -167,6 +167,6 @@ with gr.Blocks(title="Skin Tone and Gender bias in Text to Image Models") as dem
167
  with gr.Row(equal_height=True):
168
  skinplot = gr.Plot(label="Skin Tone")
169
  genplot = gr.Plot(label="Gender")
170
- btn.click(generate_images_plots, inputs=[prompt, model_dropdown], outputs=[gallery, skinplot, genplot])
171
 
172
- demo.launch(debug=True)
 
4
  from transformers import BlipProcessor, BlipForConditionalGeneration
5
  from pathlib import Path
6
  from safetensors.torch import load_file
7
+ from huggingface_hub import hf_hub_download, list_models
8
  from PIL import Image
9
  import matplotlib.pyplot as plt
10
  from matplotlib.colors import hex2color
11
  import stone
12
  import os
 
 
 
 
13
 
14
  # Define model initialization functions
15
+ def load_model(model_name, token=None):
16
  if model_name == "stabilityai/sdxl-turbo":
17
  pipeline = DiffusionPipeline.from_pretrained(
18
  model_name,
 
44
  variant="fp16"
45
  ).to("cuda")
46
  elif model_name == "stabilityai/stable-diffusion-3-medium-diffusers":
47
+ if token is None:
48
+ raise ValueError("Hugging Face token is required to access this model")
49
  pipeline = StableDiffusion3Pipeline.from_pretrained(
50
  model_name,
51
+ torch_dtype=torch.float16,
52
+ use_auth_token=token
53
  ).to("cuda")
54
  else:
55
  raise ValueError("Unknown model name")
 
59
  default_model = "stabilityai/sdxl-turbo"
60
  pipeline_text2image = load_model(default_model)
61
 
62
+ @gr.outputs.Image(type="pil")
63
+ def getimgen(prompt, model_name, token):
64
  global pipeline_text2image
65
+ pipeline_text2image = load_model(model_name, token)
66
  if model_name == "stabilityai/sdxl-turbo":
67
  return pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=2).images[0]
68
  elif model_name == "runwayml/stable-diffusion-v1-5":
 
78
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
79
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
80
 
81
+ @gr.outputs.Textbox(type="str")
82
  def blip_caption_image(image, prefix):
83
  inputs = blip_processor(image, prefix, return_tensors="pt").to("cuda", torch.float16)
84
  out = blip_model.generate(**inputs)
 
117
  ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
118
  return fig
119
 
120
+ @gr.outputs.Image(type="pil")
121
+ def generate_images_plots(prompt, model_name, token):
 
 
122
  foldername = "temp"
123
  Path(foldername).mkdir(parents=True, exist_ok=True)
124
+ images = [getimgen(prompt, model_name, token) for _ in range(10)]
125
  genders = []
126
  skintones = []
127
  for image, i in zip(images, range(10)):
 
139
 
140
  with gr.Blocks(title="Skin Tone and Gender bias in Text to Image Models") as demo:
141
  gr.Markdown("# Skin Tone and Gender bias in Text to Image Models")
142
+ gr.LoginButton() # Add a login button for Hugging Face
143
+ profile = gr.State()
144
+ token = gr.State()
145
  model_dropdown = gr.Dropdown(
146
  label="Choose a model",
147
  choices=[
 
167
  with gr.Row(equal_height=True):
168
  skinplot = gr.Plot(label="Skin Tone")
169
  genplot = gr.Plot(label="Gender")
170
+ btn.click(generate_images_plots, inputs=[prompt, model_dropdown, token], outputs=[gallery, skinplot, genplot])
171
 
172
+ demo.launch()