zseid commited on
Commit
f44c507
1 Parent(s): 5c44a55

switch to stable diffusion 2 for generation, euler scheduler, add xformers

Browse files
Files changed (2) hide show
  1. app.py +10 -12
  2. requirements.txt +2 -1
app.py CHANGED
@@ -18,19 +18,19 @@ from saac.prompt_generation.prompts import generate_prompts,generate_occupations
18
  from saac.prompt_generation.prompt_utils import score_prompt
19
  from saac.image_analysis.process import process_image_pil
20
  from saac.evaluation.eval_utils import generate_countplot, lumia_violinplot, process_analysis, generate_histplot,rgb_intensity,EVAL_DATA_DIRECTORY
21
- from saac.evaluation.evaluate import evaluate_gender_by_adjectives,evaluate_gender_by_occupation,evaluate_skin_by_adjectives,evaluate_skin_by_occupation
22
  from datasets import load_dataset
23
- from diffusers import DiffusionPipeline, PNDMScheduler
24
 
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
  STABLE_MODELS = ["Stable Diffusion v1.5", "Midjourney"]
27
  results = dict()
28
- results[STABLE_MODELS[0]] = process_analysis(os.path.join(EVAL_DATA_DIRECTORY,'raw',"stable_diffusion_raw_processed.csv"))
29
- results[STABLE_MODELS[1]] = process_analysis(os.path.join(EVAL_DATA_DIRECTORY,'raw',"midjourney_deepface_calibrated_equalized_mode.csv"))
30
 
31
- scheduler = PNDMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler", prediction_type="v_prediction",revision="fp16",
32
- torch_dtype=torch.float16)
33
- pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
34
  pipe = pipe.to(device)
35
 
36
  LOOKS = sorted(list(generate_traits()['tag']))#["beautiful", "stunning", "handsome", "ugly", "plain", "repulsive", "arrogant", "trustworthy"]
@@ -48,8 +48,7 @@ def fig2img(fig):
48
 
49
  def trait_graph(model,hist=True):
50
  tda_res,occ_res = results[model]
51
- pass_gen = evaluate_gender_by_adjectives(tda_res)
52
- pass_skin = evaluate_skin_by_adjectives(tda_res)
53
  fig = None
54
  if not hist:
55
  fig = generate_countplot(tda_res, 'tda_sentiment_val', 'gender_detected_val',
@@ -78,10 +77,9 @@ def trait_graph(model,hist=True):
78
  return pass_skin,pass_gen,fig2img(fig2),fig2img(fig)
79
  def occ_graph(model):
80
  tda_res,occ_result = results[model]
81
- pass_skin = evaluate_skin_by_occupation(occ_result)
82
- pass_gen = evaluate_gender_by_occupation(occ_result)
83
  fig = generate_histplot(occ_result, 'a_median', 'gender_detected_val',
84
- title='Gender Distribution by Median Annual Salary',
85
  xlabel= 'Median Annual Salary',
86
  ylabel= 'Count',)
87
  fig2 = lumia_violinplot(df=occ_result, x_col='a_median',
 
18
  from saac.prompt_generation.prompt_utils import score_prompt
19
  from saac.image_analysis.process import process_image_pil
20
  from saac.evaluation.eval_utils import generate_countplot, lumia_violinplot, process_analysis, generate_histplot,rgb_intensity,EVAL_DATA_DIRECTORY
21
+ from saac.evaluation.evaluate import evaluate_by_adjectives,evaluate_by_occupation
22
  from datasets import load_dataset
23
+ from diffusers import DiffusionPipeline, PNDMScheduler,EulerDiscreteScheduler
24
 
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
  STABLE_MODELS = ["Stable Diffusion v1.5", "Midjourney"]
27
  results = dict()
28
+ results[STABLE_MODELS[0]] = process_analysis(os.path.join(EVAL_DATA_DIRECTORY,'raw',"stable_diffusion_raw_processed.csv"),filtered=True)
29
+ results[STABLE_MODELS[1]] = process_analysis(os.path.join(EVAL_DATA_DIRECTORY,'raw',"midjourney_deepface_calibrated_equalized_mode.csv"),filtered=True)
30
 
31
+ # scheduler = PNDMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler", prediction_type="v_prediction",revision="fp16",torch_dtype=torch.float16)
32
+ esched = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2-base",subfolder="scheduler")
33
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base", scheduler=esched,)
34
  pipe = pipe.to(device)
35
 
36
  LOOKS = sorted(list(generate_traits()['tag']))#["beautiful", "stunning", "handsome", "ugly", "plain", "repulsive", "arrogant", "trustworthy"]
 
48
 
49
  def trait_graph(model,hist=True):
50
  tda_res,occ_res = results[model]
51
+ pass_gen,pass_skin = evaluate_by_adjectives(adjective_df=tda_res)
 
52
  fig = None
53
  if not hist:
54
  fig = generate_countplot(tda_res, 'tda_sentiment_val', 'gender_detected_val',
 
77
  return pass_skin,pass_gen,fig2img(fig2),fig2img(fig)
78
  def occ_graph(model):
79
  tda_res,occ_result = results[model]
80
+ pass_gen,pass_skin = evaluate_by_occupation(occupation_df=occ_result)
 
81
  fig = generate_histplot(occ_result, 'a_median', 'gender_detected_val',
82
+ title='Gender Distribution by Median Annualg Salary',
83
  xlabel= 'Median Annual Salary',
84
  ylabel= 'Count',)
85
  fig2 = lumia_violinplot(df=occ_result, x_col='a_median',
requirements.txt CHANGED
@@ -3,4 +3,5 @@ gradio
3
  transformers
4
  diffusers
5
  torch
6
- accelerate
 
 
3
  transformers
4
  diffusers
5
  torch
6
+ accelerate
7
+ xformers