import gradio as gr
import torch
from PIL import Image
import base64
from io import BytesIO
import pandas as pd
import numpy as np
import random as rd
import math
from diffusers import StableDiffusionPipeline
from transformers import CLIPProcessor, CLIPModel, Pix2StructProcessor, Pix2StructForConditionalGeneration, ViltProcessor, ViltForQuestionAnswering, BlipProcessor, BlipForQuestionAnswering, AutoProcessor, AutoModelForCausalLM
import openai
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
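# ds_manager wraps the persistent store of previously generated images:
# getSavedSentences(prompt) returns a tuple whose second element is a DataFrame
# with a 'b64' column of saved images for that prompt, and saveSentences(df)
# persists a DataFrame of new ["prompt", "b64"] rows.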
import ds_manager as ds_mgr
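# Module-level state shared between the Gradio handlers below:
#   MISSING_C - which concept lacked saved images on retrieval
#               (0 = both missing, 1 = concept 1 missing, 2 = concept 2 missing, None = not set)
#   C1_B64s / C2_B64s - base64-encoded test images per concept (used for tooltip previews)
#   C1_PILs / C2_PILs - the same images decoded to PIL for the tested models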
MISSING_C = None
C1_B64s = []
C2_B64s = []
C1_PILs = []
C2_PILs = []
def updateErrorMsg(isError, text):
    return gr.Markdown.update(visible=isError, value=text)
def moveStep1():
    variants = ["primary","secondary","secondary"]
    #inter = [True, False, False]
    tabs = [True, False, False]

    return (gr.update(variant=variants[0]),
            gr.update(variant=variants[1]),
            gr.update(variant=variants[2]),
            gr.update(visible=tabs[0]),
            gr.update(visible=tabs[1]),
            gr.update(visible=tabs[2]))
# Interaction with top tabs
def moveStep1_clear():
    variants = ["primary","secondary","secondary"]
    #inter = [True, False, False]
    tabs = [True, False, False]

    return (gr.update(variant=variants[0]),
            gr.update(variant=variants[1]),
            gr.update(variant=variants[2]),
            gr.update(visible=tabs[0]),
            gr.update(visible=tabs[1]),
            gr.update(visible=tabs[2]),
            gr.Textbox.update(value=""),
            gr.Textbox.update(value=""),
            gr.Textbox.update(value=""),
            gr.Textbox.update(value=""))
def moveStep2():
    variants = ["secondary","primary","secondary"]
    #inter = [True, True, False]
    tabs = [False, True, False]

    return (gr.update(variant=variants[0]),
            gr.update(variant=variants[1]),
            gr.update(variant=variants[2]),
            gr.update(visible=tabs[0]),
            gr.update(visible=tabs[1]),
            gr.update(visible=tabs[2]))
def moveStep3():
    variants = ["secondary","secondary","primary"]
    #inter = [True, True, False]
    tabs = [False, False, True]

    return (gr.update(variant=variants[0]),
            gr.update(variant=variants[1]),
            gr.update(variant=variants[2]),
            gr.update(visible=tabs[0]),
            gr.update(visible=tabs[1]),
            gr.update(visible=tabs[2]))
def decode_b64(b64s):
    decoded = []
    for b64 in b64s:
        decoded.append(Image.open(BytesIO(base64.b64decode(b64))))
    return decoded
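# generate() first checks the dataset manager for images already saved under the
# given prompt; only on a miss does it call the OpenAI image endpoint (the
# pre-1.0 openai.Image.create API) for four 256x256 images, persist their base64
# strings via ds_mgr, and return (PIL images, base64 strings).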
def generate(prompt, openai_key):
    prompt = prompt.lower().strip()
    _, retrieved, _ = ds_mgr.getSavedSentences(prompt)
    print(f"retrieved: {retrieved}")
    if len(retrieved.index) > 0:
        update_value = decode_b64(list(retrieved['b64']))
        print(f"update_value: {update_value}")
        return update_value, list(retrieved['b64'])

    openai.api_key = openai_key
    response = openai.Image.create(
        prompt=prompt,
        n=4,
        size="256x256",
        response_format='b64_json'
    )
    image_b64s = []
    save_b64s = []
    for image in response['data']:
        image_b64s.append(image['b64_json'])
        save_b64s.append([prompt, image['b64_json']])
    save_df = pd.DataFrame(save_b64s, columns=["prompt", "b64"])
    print(f"save_df: {save_b64s}")

    # save (save_df)
    ds_mgr.saveSentences(save_df)
    images = decode_b64(image_b64s)
    # images = pipe(prompt, height=256, width=256, num_images_per_prompt=2).images
    #print(images)
    # return (
    #     gr.update(value=images)
    # )
    return images, image_b64s
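# clip() scores every image against every term in both text groups in a single
# CLIP forward pass, softmaxes each image's logits over all terms, then sums the
# probability mass that falls on group 1 vs. group 2 terms.
# Hypothetical usage sketch (values are illustrative, not real outputs):
#   probs1, probs2 = clip(C1_PILs, C2_PILs, ["brother", "father"], ["sister", "mother"])
#   probs1 -> {0: {'g1': 0.71, 'g2': 0.29}, 1: {'g1': 0.46, 'g2': 0.54}, ...}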
def clip(imgs1, imgs2, g1, g2):
    """
    imgs1: list of PIL Images
    imgs2: list of PIL Images
    g1: list of str (test-concepts 1)
    g2: list of str (test-concepts 2)

    returns avg_probs_imgs1, avg_probs_imgs2 - dicts for imgs1, imgs2
    ({img index: {'g1': probability, 'g2': probability}})
    """
    # One call of CLIP processor + model - may need to batch later
    inputs = clip_processor(text=g1 + g2, images=imgs1 + imgs2,
                            return_tensors="pt", padding=True)
    outputs = clip_model(**inputs)

    logits_imgs1 = outputs.logits_per_image[:len(imgs1)]
    logits_imgs2 = outputs.logits_per_image[len(imgs1):]

    probs_imgs1 = torch.softmax(logits_imgs1, dim=1)
    probs_imgs2 = torch.softmax(logits_imgs2, dim=1)

    avg_probs_imgs1 = {}
    avg_probs_imgs2 = {}

    # Calculate the probabilities of prompts in g1 and g2 for each image in imgs1
    for idx, img_probs in enumerate(probs_imgs1):
        prob_g1 = img_probs[:len(g1)].sum().item()
        prob_g2 = img_probs[len(g1):].sum().item()
        avg_probs_imgs1[idx] = {'g1': prob_g1, 'g2': prob_g2}

    # Calculate the probabilities of prompts in g1 and g2 for each image in imgs2
    for idx, img_probs in enumerate(probs_imgs2):
        prob_g1 = img_probs[:len(g1)].sum().item()
        prob_g2 = img_probs[len(g1):].sum().item()
        avg_probs_imgs2[idx] = {'g1': prob_g1, 'g2': prob_g2}

    print(f"avg_probs_imgs1:\n{avg_probs_imgs1}")
    print(f"avg_probs_imgs2:\n{avg_probs_imgs2}")
    # Can do an average probability over all images - need to decide how we are using this
    return avg_probs_imgs1, avg_probs_imgs2
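# vilt_test() approximates the same per-image group scores with a VQA model: for
# each concept-1 image it asks "Is the image of a {term}?" with a term drawn at
# random from group 1 (and from group 2 for concept-2 images), takes the two
# highest answer probabilities, divides each by the total probability mass, and
# assigns the larger share to g1 or g2 depending on whether the top answer
# contains "yes". This is a heuristic score, not a calibrated probability.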
def vilt_test(imgs1, imgs2, g1, g2, model, processor):
    avg_probs_imgs1 = {}
    avg_probs_imgs2 = {}

    for i, img in enumerate(imgs1):
        g1c = rd.choice(g1)
        g2c = rd.choice(g2)
        encoding = processor(img, f'Is the image of a {g1c}?', return_tensors="pt")
        outputs = model(**encoding)
        logits = outputs.logits
        idx = logits.argmax(-1).item()
        ans = model.config.id2label[idx]
        print("Predicted answer:", model.config.id2label[idx])
        logitsList = torch.softmax(logits, dim=1).flatten().tolist()
        m = max(logitsList)
        s = -math.inf
        for logit in logitsList:
            if s <= logit < m:
                s = logit
        t = sum(logitsList)
        pm, ps = m/t, s/t
        if 'yes' in ans:
            avg_probs_imgs1[i] = {'g1': pm, 'g2': ps}
        else:
            avg_probs_imgs1[i] = {'g1': ps, 'g2': pm}

    for i, img in enumerate(imgs2):
        g2c = rd.choice(g2)
        g1c = rd.choice(g1)
        encoding = processor(img, f'Is the image of a {g2c}?', return_tensors="pt")
        outputs = model(**encoding)
        logits = outputs.logits
        idx = logits.argmax(-1).item()
        ans = model.config.id2label[idx]
        print("Predicted answer:", model.config.id2label[idx])
        logitsList = torch.softmax(logits, dim=1).flatten().tolist()
        m = max(logitsList)
        s = -math.inf
        for logit in logitsList:
            if s <= logit < m:
                s = logit
        t = sum(logitsList)
        pm, ps = m/t, s/t
        if 'yes' in ans:
            avg_probs_imgs2[i] = {'g1': ps, 'g2': pm}
        else:
            avg_probs_imgs2[i] = {'g1': pm, 'g2': ps}

    print(f"avg_probs_imgs1:\n{avg_probs_imgs1}")
    print(f"avg_probs_imgs2:\n{avg_probs_imgs2}")
    return avg_probs_imgs1, avg_probs_imgs2
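# bloombergViz() renders one 20x20 colored block per test image: dark green
# (#065b41) when the image scored at least as high on group 1 terms, light green
# (#35d4ac) when it scored higher on group 2 terms. Hovering a block reveals the
# underlying image via the tooltip CSS classes defined in css_adds below. The
# `att`, `numblocks` and `concept_images` parameters are currently unused.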
def bloombergViz(att, numblocks, score, concept_images, concept_b64s, onRight=False):
    leftColor = "#065b41"  #"#555"
    rightColor = "#35d4ac"  #"#999"
    # if flip:
    #     leftColor = "#35d4ac" #"#999"
    #     rightColor = "#065b41" #"#555"
    spanClass = "tooltiptext_left"
    if onRight:
        spanClass = "tooltiptext_right"

    # g1p is indices of score where g1 >= g2
    # g2p is indices of score where g2 > g1
    g1p = []
    g2p = []
    print(f"score: {score}")
    for i in score:
        if score[i]['g1'] >= score[i]['g2']:
            g1p.append(i)
        else:
            g2p.append(i)

    res = ""
    for i in g1p:
        disp = concept_b64s[i]
        res += f"<div style='height:20px;width:20px;background-color:{leftColor};display:inline-block;position:relative' id='filled'><span class='{spanClass}' style='color:#FFF'><center><img src='data:image/jpeg;base64,{disp}'></center><br>This image was identified as more likely to depict a group 1 term.</span></div> "
    for i in g2p:
        disp = concept_b64s[i]
        res += f"<div style='height:20px;width:20px;background-color:{rightColor};display:inline-block;position:relative' id='empty'><span class='{spanClass}' style='color:#FFF'><center><img src='data:image/jpeg;base64,{disp}'></center><br>This image was identified as more likely to depict a group 2 term.</span></div> "
    return res
def att_bloombergViz(att, numblocks, scores, concept_images, concept_b64s, onRight=False):
    viz = bloombergViz(att, numblocks, scores, concept_images, concept_b64s, onRight)
    attHTML = f"<div style='border-style:solid;border-color:#999;border-radius:12px'>{att}: %<br>{viz}</div><br>"
    return attHTML
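# Handler for the "Get Images" button. Looks up saved images for every
# comma-separated term of both generation concepts, fills the module-level
# caches, and sets MISSING_C when one or both concepts have no saved images
# (which switches the UI to the Dall-E generation path). Returns a tuple of
# Gradio component updates in the order listed in get_sent_btn.click(...).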
def retrieveImgs(concept1, concept2, group1, group2, progress=gr.Progress()):
    global MISSING_C, C1_B64s, C2_B64s, C1_PILs, C2_PILs
    print(f"concept1: {concept1}. concept2: {concept2}. group1: {group1}. group2: {group2}")
    print("RETRIEVE IMAGES CLICKED!")
    G_MISSING_SPEC = []
    variants = ["secondary","primary","secondary"]
    inter = [True, True, False]
    tabs = [True, False]
    bias_gen_states = [True, False]
    bias_gen_label = "Generate New Images"
    bias_test_label = "Test Model for Social Bias"
    num2gen_update = gr.update(visible=True)  # update the number of new images to generate
    prog_vis = [True]
    err_update = updateErrorMsg(False, "")
    info_msg_update = gr.Markdown.update(visible=False, value="")
    openai_gen_row_update = gr.Row.update(visible=True)
    tested_model_dropdown_update = gr.Dropdown.update(visible=False)
    tested_model_row_update = gr.Row.update(visible=False)

    c1s = concept1.split(',')
    c2s = concept2.split(',')
    c1s = [c1.strip() for c1 in c1s]
    c2s = [c2.strip() for c2 in c2s]

    C1_PILs = []
    C2_PILs = []
    C1_B64s = []
    C2_B64s = []
    if not c1s or not c2s:
        print("No terms entered!")
        err_update = updateErrorMsg(True, "Please enter terms!")
        variants = ["primary","secondary","secondary"]
        inter = [True, False, False]
        tabs = [True, False]
        prog_vis = [False]
    else:
        tabs = [False, True]
        progress(0, desc="Fetching saved images...")
        for c1 in c1s:
            _, retrieved, _ = ds_mgr.getSavedSentences(c1)
            print(f"retrieved: {retrieved}")
            if len(retrieved.index) > 0:
                C1_B64s += list(retrieved['b64'])
                C1_PILs += decode_b64(list(retrieved['b64']))
        print(f"c1_retrieved: {C1_B64s}")
        for c2 in c2s:
            _, retrieved, _ = ds_mgr.getSavedSentences(c2)
            print(f"retrieved: {retrieved}")
            if len(retrieved.index) > 0:
                C2_B64s += list(retrieved['b64'])
                C2_PILs += decode_b64(list(retrieved['b64']))
        print(f"c2_retrieved: {C2_B64s}")

        if not C1_PILs or not C2_PILs:
            err_update = updateErrorMsg(True, "No images were found for one or both concepts. Please enter an OpenAI key and use Dall-E to generate new test images, or change the bias specification!")
            if not C1_PILs and not C2_PILs:
                MISSING_C = 0
            elif not C1_PILs:
                MISSING_C = 1
            elif not C2_PILs:
                MISSING_C = 2
        else:
            print('there exist images for both!')
            bias_gen_states = [False, True]
            openai_gen_row_update = gr.Row.update(visible=False)
            tested_model_dropdown_update = gr.Dropdown.update(visible=True)
            tested_model_row_update = gr.Row.update(visible=True)

    print(len(C1_PILs), len(C2_PILs), len(C1_B64s), len(C2_B64s))
    print(f"Will these show up?: {concept1}, {concept2}, {group1}, {group2}")
    print(f"C1_B64s, C1_PILs: {C1_B64s} || {C1_PILs}")
    print(f"C2_B64s, C2_PILs: {C2_B64s} || {C2_PILs}")
    return (
        err_update,  # error message
        openai_gen_row_update,  # OpenAI generation
        num2gen_update,  # number of images to generate
        tested_model_row_update,  # tested model row
        tested_model_dropdown_update,  # tested model dropdown
        info_msg_update,  # images retrieved info update
        gr.update(visible=prog_vis),  # progress bar top
        gr.update(variant=variants[0], interactive=inter[0]),  # breadcrumb btn1
        gr.update(variant=variants[1], interactive=inter[1]),  # breadcrumb btn2
        gr.update(variant=variants[2], interactive=inter[2]),  # breadcrumb btn3
        gr.update(visible=tabs[0]),  # tab 1
        gr.update(visible=tabs[1]),  # tab 2
        gr.Accordion.update(visible=bias_gen_states[1], label=f"Test images ({len(C1_PILs) + len(C2_PILs)})"),  # accordion
        gr.update(visible=True),  # row images
        gr.update(value=C1_PILs+C2_PILs),  # test images
        gr.Button.update(visible=bias_gen_states[0], value=bias_gen_label),  # gen btn
        gr.Button.update(visible=bias_gen_states[1], value=bias_test_label),  # bias test btn
        gr.update(value=concept1),  # concept1_fixed
        gr.update(value=concept2),  # concept2_fixed
        gr.update(value=group1),  # group1_fixed
        gr.update(value=group2)   # group2_fixed
    )
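# Handler for the "Generate New Images" button. Validates the OpenAI key and
# calls generate() for every term of both concepts, then reveals the tested
# model selection once images exist for both sides. num_imgs2gen is accepted
# from the slider but not currently used (generate() always requests 4 images).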
def generateImgs(concept1, concept2, openai_key, num_imgs2gen, progress=gr.Progress()):
    global MISSING_C, C1_B64s, C2_B64s, C1_PILs, C2_PILs
    err_update = updateErrorMsg(False, "")
    bias_test_label = "Test Model Using Imbalanced Images"
    # Defaults in case MISSING_C was never set by retrieveImgs (avoids a NameError below)
    bias_gen_states = [True, False]
    online_gen_visible = True
    test_model_visible = False
    if MISSING_C == 0:
        bias_gen_states = [True, False]
        online_gen_visible = True
        test_model_visible = False
    elif MISSING_C == 1 or MISSING_C == 2:
        bias_gen_states = [True, True]
        online_gen_visible = True
        test_model_visible = True
    info_msg_update = gr.Markdown.update(visible=False, value="")

    c1s = concept1.split(',')
    c2s = concept2.split(',')

    C1_PILs = []
    C2_PILs = []
    if not c1s or not c2s:
        print("No terms entered!")
        err_update = updateErrorMsg(True, "Please enter terms!")
        variants = ["primary","secondary","secondary"]
        inter = [True, False, False]
        tabs = [True, False]
        prog_vis = [False]
    else:
        if len(openai_key) == 0:
            print("Empty OpenAI key!!!")
            err_update = updateErrorMsg(True, "Please enter an OpenAI key!")
        elif len(openai_key) < 10:
            print("Wrong length OpenAI key!!!")
            err_update = updateErrorMsg(True, "Please enter a correct OpenAI key!")
        else:
            progress(0, desc="Dall-E generation...")
            C1_PILs = []
            C1_B64s = []
            for c1 in c1s:
                prompt = c1
                PILs, c1_b64s = generate(prompt, openai_key)
                C1_PILs += PILs
                C1_B64s += c1_b64s
            C2_PILs = []
            C2_B64s = []
            for c2 in c2s:
                prompt = c2
                PILs, c2_b64s = generate(prompt, openai_key)
                C2_PILs += PILs
                C2_B64s += c2_b64s
            bias_gen_states = [False, True]
            online_gen_visible = False
            test_model_visible = True
            bias_test_label = "Test Model for Social Bias"

    return (err_update,  # error message, if any
            info_msg_update,  # info message about the number of imgs and coverage
            gr.Row.update(visible=online_gen_visible),  # online gen row
            gr.Row.update(visible=test_model_visible),  # tested model row
            gr.Dropdown.update(visible=test_model_visible),  # tested model selection dropdown
            gr.Accordion.update(visible=test_model_visible, label=f"Test images ({len(C1_PILs)+len(C2_PILs)})"),  # accordion
            gr.update(visible=True),  # row images
            gr.update(value=C1_PILs+C2_PILs),  # test images
            gr.update(visible=bias_gen_states[0]),  # gen btn
            gr.update(visible=bias_gen_states[1], value=bias_test_label)  # bias btn
            )
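# Handler for the bias test button. Scores both image sets with the selected
# model (CLIP or ViLT), builds the per-concept block visualizations, and reports
# a single bias score: the mean, over all images, of the probability assigned to
# the image's "own" group (group 1 for concept-1 images, group 2 for concept-2
# images), rounded to two decimals.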
def startBiasTest(test_imgs, concept1, concept2, group1, group2, model_name, progress=gr.Progress()):
    global C1_B64s, C2_B64s, C1_PILs, C2_PILs
    variants = ["secondary","secondary","primary"]
    inter = [True, True, True]
    tabs = [False, False, True]

    err_update = updateErrorMsg(False, "")
    if len(test_imgs) == 0:
        err_update = updateErrorMsg(True, "There are no images! (How'd you get here?)")
    progress(0, desc="Starting social bias testing...")

    g1 = group1.split(', ')
    g2 = group2.split(', ')

    avg_probs_imgs1, avg_probs_imgs2 = None, None
    if model_name.lower() == 'clip':
        avg_probs_imgs1, avg_probs_imgs2 = clip(C1_PILs, C2_PILs, g1, g2)
    elif 'vilt' in model_name.lower():
        avg_probs_imgs1, avg_probs_imgs2 = vilt_test(C1_PILs, C2_PILs, g1, g2, vilt_model, vilt_processor)
    else:
        print("that's not right")

    c1_html = att_bloombergViz(concept1, len(avg_probs_imgs1), avg_probs_imgs1, C1_PILs, C1_B64s, False)
    c2_html = att_bloombergViz(concept2, len(avg_probs_imgs2), avg_probs_imgs2, C2_PILs, C2_B64s, True)

    model_bias_dict_n = 0.0
    for key in avg_probs_imgs1:
        model_bias_dict_n += avg_probs_imgs1[key]['g1']
    for key in avg_probs_imgs2:
        model_bias_dict_n += avg_probs_imgs2[key]['g2']
    model_bias_dict_d = len(avg_probs_imgs1) + len(avg_probs_imgs2)
    model_bias_dict = {f'bias score for {model_name} on {len(C1_PILs) + len(C2_PILs)} images': round(model_bias_dict_n/model_bias_dict_d, 2)}

    group_labels_html_update = gr.HTML.update(
        value=f"<div style='height:20px;width:20px;background-color:#065b41;display:inline-block;vertical-align:top'></div><div style='display:inline-block;vertical-align:top'> Image more likely classified as a Group 1 ({group1}) term </div> <div style='height:20px;width:20px;background-color:#35d4ac;display:inline-block;vertical-align:top'></div><div style='display:inline-block;vertical-align:top'> Image more likely classified as a Group 2 ({group2}) term </div>")

    return (err_update,  # error message
            gr.Markdown.update(visible=True),  # bar progress
            gr.Button.update(variant=variants[0], interactive=inter[0]),  # top breadcrumb button 1
            gr.Button.update(variant=variants[1], interactive=inter[1]),  # top breadcrumb button 2
            gr.Button.update(variant=variants[2], interactive=inter[2]),  # top breadcrumb button 3
            gr.update(visible=tabs[0]),  # content tab/column 1
            gr.update(visible=tabs[1]),  # content tab/column 2
            gr.update(visible=tabs[2]),  # content tab/column 3
            model_bias_dict,  # per model bias score
            gr.update(value=c1_html),  # c1 bloomberg viz
            gr.update(value=c2_html),  # c2 bloomberg viz
            gr.update(value=concept1),  # c1_fixed
            gr.update(value=concept2),  # c2_fixed
            gr.update(value=group1),  # g1_fixed
            gr.update(value=group2),  # g2_fixed
            group_labels_html_update  # group_labels_html
            )
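# Two theme variants are defined below; `soft` is the one actually passed to
# gr.Blocks, while `theme` is currently unused.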
theme = gr.themes.Soft().set(
    button_small_radius='*radius_xxs',
    background_fill_primary='*neutral_50',
    border_color_primary='*primary_50'
)

soft = gr.themes.Soft(
    primary_hue="slate",
    spacing_size="sm",
    radius_size="md"
).set(
    # body_background_fill="white",
    button_primary_background_fill='*primary_400'
)
css_adds = "#group_row {background: white; border-color: white;} \
#attribute_row {background: white; border-color: white;} \
#tested_model_row {background: white; border-color: white;} \
#button_row {background: white; border-color: white} \
#examples_elem .label {display: none}\
#con1_words {border-color: #E5E7EB;} \
#con2_words {border-color: #E5E7EB;} \
#grp1_words {border-color: #E5E7EB;} \
#grp2_words {border-color: #E5E7EB;} \
#con1_words_fixed {border-color: #E5E7EB;} \
#con2_words_fixed {border-color: #E5E7EB;} \
#grp1_words_fixed {border-color: #E5E7EB;} \
#grp2_words_fixed {border-color: #E5E7EB;} \
#con1_words_fixed input {box-shadow:None; border-width:0} \
#con1_words_fixed .scroll-hide {box-shadow:None; border-width:0} \
#con2_words_fixed input {box-shadow:None; border-width:0} \
#con2_words_fixed .scroll-hide {box-shadow:None; border-width:0} \
#grp1_words_fixed input {box-shadow:None; border-width:0} \
#grp1_words_fixed .scroll-hide {box-shadow:None; border-width:0} \
#grp2_words_fixed input {box-shadow:None; border-width:0} \
#grp2_words_fixed .scroll-hide {box-shadow:None; border-width:0} \
#tested_model_drop {border-color: #E5E7EB;} \
#gen_model_check {border-color: white;} \
#gen_model_check .wrap {border-color: white;} \
#gen_model_check .form {border-color: white;} \
#open_ai_key_box {border-color: #E5E7EB;} \
#gen_col {border-color: white;} \
#gen_col .form {border-color: white;} \
#res_label {background-color: #F8FAFC;} \
#per_attrib_label_elem {background-color: #F8FAFC;} \
#accordion {border-color: #E5E7EB} \
#err_msg_elem p {color: #FF0000; cursor: pointer} \
#res_label .bar {background-color: #35d4ac; } \
#bloomberg_legend {background: white; border-color: white} \
#bloomberg_att1 {background: white; border-color: white} \
#bloomberg_att2 {background: white; border-color: white} \
.tooltiptext_left {visibility: hidden;max-width:50ch;min-width:25ch;top: 100%;left: 0%;background-color: #222;text-align: center;border-radius: 6px;padding: 5px 0;position: absolute;z-index: 1;} \
.tooltiptext_right {visibility: hidden;max-width:50ch;min-width:25ch;top: 100%;right: 0%;background-color: #222;text-align: center;border-radius: 6px;padding: 5px 0;position: absolute;z-index: 1;} \
#filled:hover .tooltiptext_left {visibility: visible;} \
#empty:hover .tooltiptext_left {visibility: visible;} \
#filled:hover .tooltiptext_right {visibility: visible;} \
#empty:hover .tooltiptext_right {visibility: visible;}"
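# The UI is a three-step wizard: Step 1 collects the bias specification, Step 2
# retrieves or generates test images, Step 3 shows the bias test results. The
# breadcrumb buttons (s1/s2/s3) and the tab1/tab2/tab3 columns are shown and
# hidden together by the moveStep* helpers and the handlers above.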
with gr.Blocks(theme=soft, title="Social Bias Testing in Image-To-Text Models",
               css=css_adds) as iface:
    with gr.Row():
        s1_btn = gr.Button(value="Step 1: Bias Specification", variant="primary", visible=True, interactive=True, size='sm')  #.style(size='sm')
        s2_btn = gr.Button(value="Step 2: Test Images", variant="secondary", visible=True, interactive=False, size='sm')  #.style(size='sm')
        s3_btn = gr.Button(value="Step 3: Bias Testing", variant="secondary", visible=True, interactive=False, size='sm')  #.style(size='sm')
    err_message = gr.Markdown("", visible=False, elem_id="err_msg_elem")
    bar_progress = gr.Markdown(" ")

    # Page 1
    with gr.Column(visible=True) as tab1:
        with gr.Column():
            gr.Markdown("#### Enter concepts to generate")  # #group_row
            with gr.Row(elem_id="generation_row"):
                concept1 = gr.Textbox(label="Image Generation Concept 1", max_lines=1, elem_id="con1_words", elem_classes="input_words", placeholder="ceo, executive")
                concept2 = gr.Textbox(label="Image Generation Concept 2", max_lines=1, elem_id="con2_words", elem_classes="input_words", placeholder="nurse, janitor")
            gr.Markdown("#### Enter concepts to test")  # #attribute_row
            with gr.Row(elem_id="group_row"):
                group1 = gr.Textbox(label="Text Caption Concept 1", max_lines=1, elem_id="grp1_words", elem_classes="input_words", placeholder="brother, father")
                group2 = gr.Textbox(label="Text Caption Concept 2", max_lines=1, elem_id="grp2_words", elem_classes="input_words", placeholder="sister, mother")
            with gr.Row():
                gr.Markdown(" ")
                get_sent_btn = gr.Button(value="Get Images", variant="primary", visible=True)
                gr.Markdown(" ")
    # Page 2
    with gr.Column(visible=False) as tab2:
        info_imgs_found = gr.Markdown(value="", visible=False)  # info_sentences_found

        gr.Markdown("### Tested Social Bias Specification", visible=True)
        with gr.Row():
            concept1_fixed = gr.Textbox(label="Image Generation Concept 1", max_lines=1, elem_id="con1_words_fixed", elem_classes="input_words", interactive=False, visible=True)  # group1_words_fixed
            concept2_fixed = gr.Textbox(label='Image Generation Concept 2', max_lines=1, elem_id="con2_words_fixed", elem_classes="input_words", interactive=False, visible=True)  # group2_fixed
        with gr.Row():
            group1_fixed = gr.Textbox(label='Text Caption Concept 1', max_lines=1, elem_id="grp1_words_fixed", elem_classes="input_words", interactive=False, visible=True)  # att1_words_fixed
            group2_fixed = gr.Textbox(label='Text Caption Concept 2', max_lines=1, elem_id="grp2_words_fixed", elem_classes="input_words", interactive=False, visible=True)  # att2_fixed

        with gr.Row():
            with gr.Column():
                with gr.Row(visible=False) as online_gen_row:
                    with gr.Column():
                        gen_title = gr.Markdown("### Generate Additional Images", visible=True)
                        # OpenAI Key for generator
                        openai_key = gr.Textbox(lines=1, label="OpenAI API Key", value=None,
                                                placeholder="starts with sk-",
                                                info="Please provide the key for an OpenAI account to generate new test images",
                                                visible=True,
                                                interactive=True,
                                                elem_id="open_ai_key_box")
                        num_imgs2gen = gr.Slider(2, 20, value=2, step=1,
                                                 interactive=True,
                                                 visible=True,
                                                 container=True)
                with gr.Row(visible=False) as tested_model_row:
                    with gr.Column():
                        gen_title = gr.Markdown("### Select Tested Model", visible=True)
                        tested_model_name = gr.Dropdown(["CLIP", "ViLT"], value="CLIP",
                                                        multiselect=None,
                                                        interactive=True,
                                                        label="Tested model",
                                                        elem_id="tested_model_drop",
                                                        visible=True
                                                        )
        with gr.Row():
            gr.Markdown(" ")
            gen_btn = gr.Button(value="Generate New Images", variant="primary", visible=True)
            bias_btn = gr.Button(value="Test Model for Social Bias", variant="primary", visible=False)
            gr.Markdown(" ")

        with gr.Row(visible=False) as row_imgs:  # row_sentences
            with gr.Accordion(label="Test Images", open=False, visible=False) as acc_test_imgs:  # acc_test_sentences
                test_imgs = gr.Gallery(show_label=False)  # test_sentences, output
    # Page 3
    with gr.Column(visible=False) as tab3:
        gr.Markdown("### Tested Social Bias Specification", visible=True)
        with gr.Row():
            concept1_fixed2 = gr.Textbox(label="Image Generation Concept 1", max_lines=1, elem_id="con1_words_fixed", elem_classes="input_words", interactive=False)  # group1_words_fixed
            concept2_fixed2 = gr.Textbox(label='Image Generation Concept 2', max_lines=1, elem_id="con2_words_fixed", elem_classes="input_words", interactive=False)  # group2_fixed
        with gr.Row():
            group1_fixed2 = gr.Textbox(label='Text Caption Concept 1', max_lines=1, elem_id="grp1_words_fixed", elem_classes="input_words", interactive=False)  # att1_words_fixed
            group2_fixed2 = gr.Textbox(label='Text Caption Concept 2', max_lines=1, elem_id="grp2_words_fixed", elem_classes="input_words", interactive=False)  # att2_fixed

        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown("### Bias Test Results")
                with gr.Row():
                    with gr.Column(scale=2):
                        lbl_model_bias = gr.Markdown("**Model Bias** - % stereotyped choices (↑ more bias)")
                        model_bias_label = gr.Label(num_top_classes=1, label="% stereotyped choices (↑ more bias)",
                                                    elem_id="res_label",
                                                    show_label=False)
                with gr.Row():
                    with gr.Column(variant="compact", elem_id="bloomberg_legend"):
                        group_labels_html = gr.HTML(value="<div style='height:20px;width:20px;background-color:#065b41;display:inline-block;vertical-align:top'></div><div style='display:inline-block;vertical-align:top'> Social group 1 more probable in the image </div> <div style='height:20px;width:20px;background-color:#35d4ac;display:inline-block;vertical-align:top'></div><div style='display:inline-block;vertical-align:top'> Social group 2 more probable in the image </div>")
                with gr.Row():
                    with gr.Column(variant="compact", elem_id="bloomberg_att1"):
                        gr.Markdown("#### Text Caption Concept Probability for Image Generation Concept 1")
                        c1_results = gr.HTML()
                    with gr.Column(variant="compact", elem_id="bloomberg_att2"):
                        gr.Markdown("#### Text Caption Concept Probability for Image Generation Concept 2")
                        c2_results = gr.HTML()
                gr.HTML(value="Visualization inspired by <a href='https://www.bloomberg.com/graphics/2023-generative-ai-bias/' target='_blank'>Bloomberg article on bias in text-to-image models</a>.")
                save_msg = gr.HTML(value="<span style=\"color:black\">Bias test result saved! </span>", visible=False)

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    gr.Markdown(" ")
                    with gr.Column():
                        new_bias_button = gr.Button("Try New Bias Test", variant="primary")
                    gr.Markdown(" ")
    # Get images
    get_sent_btn.click(fn=retrieveImgs,  # retrieveSentences
                       inputs=[concept1, concept2, group1, group2],
                       outputs=[err_message, online_gen_row, num_imgs2gen, tested_model_row, tested_model_name, info_imgs_found, bar_progress, s1_btn, s2_btn, s3_btn, tab1, tab2, acc_test_imgs, row_imgs, test_imgs, gen_btn, bias_btn,
                                concept1_fixed, concept2_fixed, group1_fixed, group2_fixed]
                       )

    # Request generating new images
    gen_btn.click(fn=generateImgs,  # generateSentences
                  inputs=[concept1, concept2, openai_key, num_imgs2gen],
                  outputs=[err_message, info_imgs_found, online_gen_row,
                           tested_model_row, tested_model_name, acc_test_imgs, row_imgs, test_imgs, gen_btn, bias_btn]
                  )

    # Test bias
    bias_btn.click(fn=startBiasTest,
                   inputs=[test_imgs, concept1, concept2, group1, group2, tested_model_name],
                   outputs=[err_message, bar_progress, s1_btn, s2_btn, s3_btn, tab1, tab2, tab3, model_bias_label,
                            c1_results, c2_results, concept1_fixed2, concept2_fixed2, group1_fixed2, group2_fixed2,
                            group_labels_html]
                   )

    # top breadcrumbs
    s1_btn.click(fn=moveStep1,
                 inputs=[],
                 outputs=[s1_btn, s2_btn, s3_btn, tab1, tab2, tab3])

    # top breadcrumbs
    s2_btn.click(fn=moveStep2,
                 inputs=[],
                 outputs=[s1_btn, s2_btn, s3_btn, tab1, tab2, tab3])

    # top breadcrumbs
    s3_btn.click(fn=moveStep3,
                 inputs=[],
                 outputs=[s1_btn, s2_btn, s3_btn, tab1, tab2, tab3])

    new_bias_button.click(fn=moveStep1_clear,
                          inputs=[],
                          outputs=[s1_btn, s2_btn, s3_btn, tab1, tab2, tab3, concept1, concept2, group1, group2])

iface.queue(concurrency_count=2).launch()