zseid
commited on
Commit
•
f44c507
1
Parent(s):
5c44a55
switch to stable diffusion 2 for generation, euler scheduler, add xformers
Browse files- app.py +10 -12
- requirements.txt +2 -1
app.py
CHANGED
@@ -18,19 +18,19 @@ from saac.prompt_generation.prompts import generate_prompts,generate_occupations
|
|
18 |
from saac.prompt_generation.prompt_utils import score_prompt
|
19 |
from saac.image_analysis.process import process_image_pil
|
20 |
from saac.evaluation.eval_utils import generate_countplot, lumia_violinplot, process_analysis, generate_histplot,rgb_intensity,EVAL_DATA_DIRECTORY
|
21 |
-
from saac.evaluation.evaluate import
|
22 |
from datasets import load_dataset
|
23 |
-
from diffusers import DiffusionPipeline, PNDMScheduler
|
24 |
|
25 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
26 |
STABLE_MODELS = ["Stable Diffusion v1.5", "Midjourney"]
|
27 |
results = dict()
|
28 |
-
results[STABLE_MODELS[0]] = process_analysis(os.path.join(EVAL_DATA_DIRECTORY,'raw',"stable_diffusion_raw_processed.csv"))
|
29 |
-
results[STABLE_MODELS[1]] = process_analysis(os.path.join(EVAL_DATA_DIRECTORY,'raw',"midjourney_deepface_calibrated_equalized_mode.csv"))
|
30 |
|
31 |
-
scheduler = PNDMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler", prediction_type="v_prediction",revision="fp16",
|
32 |
-
|
33 |
-
pipe = DiffusionPipeline.from_pretrained("
|
34 |
pipe = pipe.to(device)
|
35 |
|
36 |
LOOKS = sorted(list(generate_traits()['tag']))#["beautiful", "stunning", "handsome", "ugly", "plain", "repulsive", "arrogant", "trustworthy"]
|
@@ -48,8 +48,7 @@ def fig2img(fig):
|
|
48 |
|
49 |
def trait_graph(model,hist=True):
|
50 |
tda_res,occ_res = results[model]
|
51 |
-
pass_gen =
|
52 |
-
pass_skin = evaluate_skin_by_adjectives(tda_res)
|
53 |
fig = None
|
54 |
if not hist:
|
55 |
fig = generate_countplot(tda_res, 'tda_sentiment_val', 'gender_detected_val',
|
@@ -78,10 +77,9 @@ def trait_graph(model,hist=True):
|
|
78 |
return pass_skin,pass_gen,fig2img(fig2),fig2img(fig)
|
79 |
def occ_graph(model):
|
80 |
tda_res,occ_result = results[model]
|
81 |
-
pass_skin =
|
82 |
-
pass_gen = evaluate_gender_by_occupation(occ_result)
|
83 |
fig = generate_histplot(occ_result, 'a_median', 'gender_detected_val',
|
84 |
-
title='Gender Distribution by Median
|
85 |
xlabel= 'Median Annual Salary',
|
86 |
ylabel= 'Count',)
|
87 |
fig2 = lumia_violinplot(df=occ_result, x_col='a_median',
|
|
|
18 |
from saac.prompt_generation.prompt_utils import score_prompt
|
19 |
from saac.image_analysis.process import process_image_pil
|
20 |
from saac.evaluation.eval_utils import generate_countplot, lumia_violinplot, process_analysis, generate_histplot,rgb_intensity,EVAL_DATA_DIRECTORY
|
21 |
+
from saac.evaluation.evaluate import evaluate_by_adjectives,evaluate_by_occupation
|
22 |
from datasets import load_dataset
|
23 |
+
from diffusers import DiffusionPipeline, PNDMScheduler,EulerDiscreteScheduler
|
24 |
|
25 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
26 |
STABLE_MODELS = ["Stable Diffusion v1.5", "Midjourney"]
|
27 |
results = dict()
|
28 |
+
results[STABLE_MODELS[0]] = process_analysis(os.path.join(EVAL_DATA_DIRECTORY,'raw',"stable_diffusion_raw_processed.csv"),filtered=True)
|
29 |
+
results[STABLE_MODELS[1]] = process_analysis(os.path.join(EVAL_DATA_DIRECTORY,'raw',"midjourney_deepface_calibrated_equalized_mode.csv"),filtered=True)
|
30 |
|
31 |
+
# scheduler = PNDMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler", prediction_type="v_prediction",revision="fp16",torch_dtype=torch.float16)
|
32 |
+
esched = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2-base",subfolder="scheduler")
|
33 |
+
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base", scheduler=esched,)
|
34 |
pipe = pipe.to(device)
|
35 |
|
36 |
LOOKS = sorted(list(generate_traits()['tag']))#["beautiful", "stunning", "handsome", "ugly", "plain", "repulsive", "arrogant", "trustworthy"]
|
|
|
48 |
|
49 |
def trait_graph(model,hist=True):
|
50 |
tda_res,occ_res = results[model]
|
51 |
+
pass_gen,pass_skin = evaluate_by_adjectives(adjective_df=tda_res)
|
|
|
52 |
fig = None
|
53 |
if not hist:
|
54 |
fig = generate_countplot(tda_res, 'tda_sentiment_val', 'gender_detected_val',
|
|
|
77 |
return pass_skin,pass_gen,fig2img(fig2),fig2img(fig)
|
78 |
def occ_graph(model):
|
79 |
tda_res,occ_result = results[model]
|
80 |
+
pass_gen,pass_skin = evaluate_by_occupation(occupation_df=occ_result)
|
|
|
81 |
fig = generate_histplot(occ_result, 'a_median', 'gender_detected_val',
|
82 |
+
title='Gender Distribution by Median Annualg Salary',
|
83 |
xlabel= 'Median Annual Salary',
|
84 |
ylabel= 'Count',)
|
85 |
fig2 = lumia_violinplot(df=occ_result, x_col='a_median',
|
requirements.txt
CHANGED
@@ -3,4 +3,5 @@ gradio
|
|
3 |
transformers
|
4 |
diffusers
|
5 |
torch
|
6 |
-
accelerate
|
|
|
|
3 |
transformers
|
4 |
diffusers
|
5 |
torch
|
6 |
+
accelerate
|
7 |
+
xformers
|