zseid committed on
Commit
5d822cb
1 Parent(s): 73b9323

call FACIA graphs, use CLIP prompt graph GUI base

Files changed (2)
  1. app.py +236 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,236 @@
+ # initial fork from https://huggingface.co/spaces/Manjushri/SD-2.1-CPU/raw/main/app.py, https://huggingface.co/spaces/bhautikj/sd_clip_bias
+
+ import gradio as gr
+ import torch
+ import numpy as np
+
+ import matplotlib
+ matplotlib.use("Agg")
+ import matplotlib.pyplot as plt
+
+ from PIL import Image
+ import collections
+ import pandas as pd
+ import io
+ from saac.prompt_generation.prompts import generate_prompts, generate_occupations, generate_traits
+ from saac.prompt_generation.prompt_utils import score_prompt
+ from saac.evaluation.eval_utils import generate_countplot, lumia_violinplot, process_analysis, generate_histplot
+ from datasets import load_dataset
+ from diffusers import DiffusionPipeline, PNDMScheduler
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ STABLE_MODELS = ["runwayml/stable-diffusion-v1-5", "Midjourney"]
+
+ # TODO Make model selectable input
+ scheduler = PNDMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler", prediction_type="v_prediction")
+ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
+ pipe = pipe.to(device)
+
+ tokenizer = pipe.tokenizer
+ text_encoder = pipe.text_encoder
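+ # Note: only the tokenizer and text encoder are used in this app; the pipeline
+ # itself is never called to generate images here.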
+
+ #
+ # AUTHORS NOTE:
+ # Being absolutely explicit: the genders and ethnicities listed here
+ # are reductive. This is done not to erase but to make the analysis here
+ # simpler, tractable and easier to understand
+ #
+ GENDERS = ["male", "female"]
+ ETHNICITIES = ["black", "white", "asian"]
+ LOOKS = list(generate_traits()['tag'])  # ["beautiful", "stunning", "handsome", "ugly", "plain", "repulsive", "arrogant", "trustworthy"]
+ JOBS = list(generate_occupations()['tag'])  # ["CEO", "doctor", "nurse", "cashier", "janitor", "engineer", "pilot", "dentist", "leader"]
+ RENDERPREFIX = "a high quality photo of a"
+
+ def echoToken(token):
+     res = getMostSimilar(tokenizer, text_encoder, token)
+     return ",".join(res)
+
+ def getEmbeddingForToken(tokenizer, token):
+     token_ids = tokenizer.encode(token)[1:-1]
+     if len(token_ids) != 1:
+         raise ValueError(f"Expected a single-token term, got {len(token_ids)} tokens")
+     token_id = token_ids[0]
+     return token_id, text_encoder.get_input_embeddings().weight.data[token_id].unsqueeze(0)
+
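+ # Rank every vocabulary embedding by cosine similarity to the query token.
+ # The closest match is always the query token itself, hence the res[1:] slice.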
+ def getMostSimilar(tokenizer, text_encoder, token, numResults=50):
+     internal_embs = text_encoder.text_model.embeddings.token_embedding.weight
+     tID, tok = getEmbeddingForToken(tokenizer, token)
+
+     cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
+     scores = cos(internal_embs.to("cpu").to(torch.float32), tok.to("cpu").to(torch.float32))
+     sorted_scores, sorted_ids = torch.sort(scores, descending=True)
+     best_ids = sorted_ids[0:numResults].detach().numpy()
+     best_scores = sorted_scores[0:numResults].detach().numpy()
+
+     res = []
+     for best_id, best_score in zip(best_ids, best_scores):
+         #res.append((tokenizer.decode(best_id), best_score))
+         res.append("[" + tokenizer.decode(best_id) + "," + str(best_score) + "]")
+     return res[1:]
+
+ def computeTermSimilarity(tokenizer, text_encoder, termA, termB):
+     inputs = tokenizer([termA, termB], padding=True, return_tensors="pt").to("cpu")
+     outputs = text_encoder(**inputs)
+     cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
+     val = cos(outputs.pooler_output[0], outputs.pooler_output[1]).item()
+     return float(val)
+
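+ # Each compute* helper scores how similar an ethnicity/gender-qualified prompt is
+ # to a neutral baseline prompt in CLIP text-embedding space.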
+ def computeJob(tokenizer, text_encoder, job):
+     res = {}
+     neutralPrompt = " ".join([RENDERPREFIX, job])
+     titleText = neutralPrompt
+     for gender in GENDERS:
+         for ethnicity in ETHNICITIES:
+             prompt = " ".join([RENDERPREFIX, ethnicity, gender, job])
+             val = computeTermSimilarity(tokenizer, text_encoder, prompt, neutralPrompt)
+             res[prompt] = val
+
+     return titleText, sorted(res.items(), reverse=True)
+
+ def computeLook(tokenizer, text_encoder, look):
+     res = {}
+     titleText = " ".join([RENDERPREFIX,
+                           look,
+                           "[",
+                           "|".join(GENDERS),
+                           "]"])
+
+     for gender in GENDERS:
+         neutralPromptGender = " ".join([RENDERPREFIX, look, gender])
+         for ethnicity in ETHNICITIES:
+             prompt = " ".join([RENDERPREFIX, look, ethnicity, gender])
+             val = computeTermSimilarity(tokenizer, text_encoder, prompt, neutralPromptGender)
+             res[prompt] = val
+
+     return titleText, sorted(res.items(), reverse=True)
+
+ # via https://stackoverflow.com/questions/57316491/how-to-convert-matplotlib-figure-to-pil-image-object-without-saving-image
+ def fig2img(fig):
+     """Convert a Matplotlib figure to a PIL Image and return it"""
+     buf = io.BytesIO()
+     fig.savefig(buf)
+     buf.seek(0)
+     img = Image.open(buf)
+     return img
+
+ def computePlot(title, results, scaleXAxis=True):
+     x = list(map(lambda x: x[0], results))
+     y = list(map(lambda x: x[1], results))
+
+     fig, ax = plt.subplots(1, 1, figsize=(10, 5))
+     y_pos = np.arange(len(x))
+
+     hbars = ax.barh(y_pos, y, left=0, align='center')
+     ax.set_yticks(y_pos, labels=x)
+     ax.invert_yaxis()  # labels read top-to-bottom
+     ax.set_xlabel('Cosine similarity - take care to note compressed X-axis')
+     ax.set_title('Similarity to "' + title + '"')
+
+     # Label with specially formatted floats
+     ax.bar_label(hbars, fmt='%.3f')
+     minR = np.min(y)
+     maxR = np.max(y)
+     diffR = maxR - minR
+
+     if scaleXAxis:
+         ax.set_xlim(left=minR - 0.1 * diffR, right=maxR + 0.1 * diffR)
+     else:
+         ax.set_xlim(left=0.0, right=1.0)
+     plt.tight_layout()
+     plt.close()
+     return fig2img(fig)
+
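+ # computeJobBias/computeLookBias plot the raw CLIP similarity rankings; they are
+ # defined for the embedding-space probe but are not wired into the tabbed GUI below.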
+ def computeJobBias(job):
+     title, results = computeJob(tokenizer, text_encoder, job)
+     return computePlot(title, results)
+
+ def computeLookBias(look):
+     title, results = computeLook(tokenizer, text_encoder, look)
+     return computePlot(title, results)
+
+ def trait_graph(trait, hist=True):
+     # NOTE: the selected trait is not used yet; the plots summarize the full
+     # trait-descriptor analysis returned by process_analysis().
+     tda_res, occ_result = process_analysis()
+     fig = None
+     if not hist:
+         fig = generate_countplot(tda_res, 'tda_sentiment_val', 'gender_detected_val',
+                                  title='Gender Count by Trait Sentiment',
+                                  xlabel='Trait Sentiment',
+                                  ylabel='Count',
+                                  legend_title='Gender')
+     else:
+         df = tda_res
+         df['tda_sentiment_val'] = pd.Categorical(df['tda_sentiment_val'],
+                                                  ['very negative', 'negative', 'neutral', 'positive', 'very positive'])
+         fig = generate_histplot(tda_res, 'tda_sentiment_val', 'gender_detected_val',
+                                 title='Gender Distribution by Trait Sentiment',
+                                 xlabel='Trait Sentiment',
+                                 ylabel='Count')
+
+     fig2 = lumia_violinplot(df=tda_res,
+                             x_col='tda_compound',
+                             rgb_col='skincolor',
+                             n_bins=21,
+                             widths_val=0.05,
+                             points_val=100,
+                             x_label='TDA Sentiment',
+                             y_label='Skincolor Intensity',
+                             title='Skin Color Intensity, Binned by TDA Sentiment')
+     return fig2img(fig), fig2img(fig2)
+
+ def occ_graph(occ):
+     # NOTE: the selected occupation is not used yet; the plots summarize the full
+     # occupation analysis returned by process_analysis().
+     tda_res, occ_result = process_analysis()
+     fig = generate_histplot(occ_result, 'a_median', 'gender_detected_val',
+                             title='Gender Distribution by Median Annual Salary',
+                             xlabel='Median Annual Salary',
+                             ylabel='Count')
+     fig2 = lumia_violinplot(df=occ_result, x_col='a_median',
+                             rgb_col='skincolor',
+                             n_bins=21,
+                             widths_val=7500.0,
+                             points_val=100,
+                             x_label='Median Salary',
+                             y_label='Skincolor Intensity',
+                             title='Skin Color Intensity, Binned by Median Salary')
+     return fig2img(fig), fig2img(fig2)
+
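+ # Build one Gradio Interface per analysis and expose them together as tabs.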
+ if __name__ == '__main__':
+     disclaimerString = ""
+
+     jobInterface = gr.Interface(fn=occ_graph,
+                                 inputs=[gr.Dropdown(JOBS, label="occupation")],
+                                 outputs=['image', 'image'],
+                                 description="Referencing a specific profession comes loaded with associations of gender and ethnicity."
+                                             " Text to image models provide an opportunity to explicitly specify an underrepresented group, but first we must understand our default behavior.",
+                                 title="How occupation affects txt2img gender and skin color representation",
+                                 article="To view how mentioning a particular occupation affects the gender and skin colors in faces of text to image generators, select a job."
+                                         " Promotional materials, advertising, and even criminal sketches which do not explicitly specify a gender or ethnicity term will tend towards the displayed distributions.")
+
+     affectInterface = gr.Interface(fn=trait_graph,
+                                    inputs=[gr.Dropdown(LOOKS, label="trait")],
+                                    outputs=['image', 'image'],
+                                    description="Certain adjectives can reinforce harmful stereotypes associated with gender roles and ethnic backgrounds."
+                                                " Text to image models provide an opportunity to understand how prompting a particular human expression could be triggering,"
+                                                " or why an uncommon combination might provide important examples to minorities without default representation.",
+                                    title="How word sentiment affects txt2img gender and skin color representation",
+                                    article="To view how characterizing a person with a positive, negative, or neutral term influences the gender and skin color composition of AI-generated faces, select a trait.")
+
+     jobInterfaceManual = gr.Interface(fn=score_prompt,
+                                       inputs=[gr.Textbox()],
+                                       outputs='text',
+                                       description="Analyze a prompt",
+                                       title="Understand which prompts require further engineering to represent genders and skin colors equally",
+                                       article="Try modifying a trait or occupation prompt to shift the result toward the minority representation!")
+
+     toolInterface = gr.Interface(fn=lambda t: trait_graph(t, hist=False),
+                                  inputs=[gr.Dropdown(STABLE_MODELS, label="text-to-image model")],
+                                  outputs=['image', 'image'],
+                                  title="How different models fare in gender and skin color representation across a variety of prompts",
+                                  description="The training set, vocabulary, pre and post processing of generative AI tools doesn't treat everyone equally. "
+                                              "Within a 95% margin of statistical error, the following tests expose bias in gender and skin color.",
+                                  article="To learn more about this process, <a href=\"http://github.com/TRSS-Research/SAAC.git\">visit the repo</a>.")
+
+     gr.TabbedInterface(
+         [jobInterface, affectInterface, jobInterfaceManual, toolInterface],
+         ["Occupational Bias", "Adjectival Bias", "Prompt analysis", "FACIA model auditing"],
+         title="Text-to-Image Bias Explorer"
+     ).launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ git+https://github.com/TRSS-Research/SAAC.git
+ gradio