File size: 2,473 Bytes
01c612b
 
 
 
 
 
 
 
 
 
 
 
 
 
7674524
01c612b
14b962b
01c612b
 
4925660
f4ad3a9
01c612b
 
 
 
 
 
 
 
 
14b962b
01c612b
4a36a1b
01c612b
 
 
 
4925660
01c612b
 
 
 
 
 
14b962b
01c612b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import gradio as gr
from datasets import load_from_disk
import numpy as np


# Choices offered in the UI radio groups; values must match the metadata
# columns ('gender', 'ethnicity', 'model', 'no') of the dataset below.
gender_labels = ['man', 'non-binary', 'woman', 'no_gender_specified', ]

ethnicity_labels = ['African-American', 'American_Indian', 'Black', 'Caucasian', 'East_Asian',
                    'First_Nations', 'Hispanic', 'Indigenous_American', 'Latino', 'Latinx',
                    'Multiracial', 'Native_American', 'Pacific_Islander', 'South_Asian',
                    'Southeast_Asian', 'White', 'no_ethnicity_specified']
models = ['DallE', 'SD_14', 'SD_2']
# Image number within each (model, gender, ethnicity) group.
nos = [1,2,3,4,5,6,7,8,9,10]

# Dataset pre-sorted by colorfulness, so adjacent rows are colorfulness
# nearest neighbors. Loaded once at startup from the local "color-sorted" dir.
ds = load_from_disk("color-sorted")

def get_nearest(gender, ethnicity, model, no):
    """Return the selected image and its colorfulness nearest neighbors.

    The dataset ``ds`` is sorted by colorfulness, so rows adjacent to the
    selected row are its nearest neighbors along that 1-D index.

    Args:
        gender: one of ``gender_labels``.
        ethnicity: one of ``ethnicity_labels``.
        model: one of ``models``.
        no: image number, one of ``nos``.

    Returns:
        A tuple ``(image, gallery_items)`` where ``gallery_items`` is a list
        of ``(image, caption)`` pairs for up to 10 neighbors on each side.

    Raises:
        IndexError: if no dataset row matches the selection.
    """
    # NOTE(perf): converting the whole dataset to pandas on every request is
    # wasteful; consider building this DataFrame once at module load.
    df = ds.remove_columns(["image", "image_path"]).to_pandas()
    ix = df.loc[
        (df['ethnicity'] == ethnicity)
        & (df['gender'] == gender)
        & (df['no'] == no)
        & (df['model'] == model)
    ].index[0]
    image = ds.select([ix])["image"][0]
    # Symmetric window of up to 10 neighbors per side. The original end bound
    # min(ix + 10, len(ds) - 1) was off by one: range() excludes its end, so
    # the window was asymmetric (ix-10 .. ix+9) and the last dataset row could
    # never be shown. min(ix + 11, len(ds)) fixes both.
    neighbors = ds.select(range(max(ix - 10, 0), min(ix + 11, len(ds))))
    neighbor_images = neighbors["image"]
    # Derive captions from the file name: keep the middle underscore-separated
    # tokens ([4:-3]) — assumes the generated-file naming scheme; TODO confirm.
    neighbor_captions = [path.split("/")[-1] for path in neighbors["image_path"]]
    neighbor_captions = [' '.join(name.split("_")[4:-3]) for name in neighbor_captions]
    neighbor_models = neighbors["model"]
    neighbor_captions = [f"{cap} {mdl}" for cap, mdl in zip(neighbor_captions, neighbor_models)]
    return image, list(zip(neighbor_images, neighbor_captions))

# UI layout: selection radios on the left/right, a trigger button, then the
# selected image next to a gallery of its colorfulness neighbors.
with gr.Blocks() as demo:
    gr.Markdown("# Colorfulness Nearest Neighbors Explorer")
    gr.Markdown("### Colorfulness 1-D index of the _identities_ dataset of images generated by 3 models")
    gr.Markdown("#### Choose one of the generated identity images to see its nearest neighbors according to colorfulness")
    gr.HTML("""<span style="color:red">⚠️ <b>DISCLAIMER: the images displayed by this tool were generated by text-to-image models and may depict offensive stereotypes or contain explicit content.</b></span>""")
    with gr.Row():
        with gr.Column():
            model = gr.Radio(models, label="Model")
            gender = gr.Radio(gender_labels, label="Gender label")
            no = gr.Radio(nos, label="Image number")
        with gr.Column():
            ethnicity = gr.Radio(ethnicity_labels, label="Ethnicity label")
    button = gr.Button(value="Get nearest neighbors")
    with gr.Row():
        image = gr.Image()
        # NOTE(review): Gallery().style(grid=...) is the legacy (pre-4.x)
        # Gradio API; newer versions use Gallery(columns=4). Confirm the
        # pinned gradio version before upgrading.
        gallery = gr.Gallery().style(grid=4)
    # Wire the button to get_nearest; input order must match its signature.
    button.click(get_nearest, inputs=[gender, ethnicity, model, no], outputs=[image, gallery])
demo.launch()