File size: 1,816 Bytes
20b3e10
8ea625c
20b3e10
 
 
8609803
b68cd97
a85d0c6
 
 
 
 
 
b68cd97
5a6a374
20b3e10
 
8ea625c
1675e3b
 
0475e0a
 
 
769e78e
d052173
a85d0c6
b68cd97
8720503
 
 
 
 
95932fb
 
442d337
4868380
442d337
 
4868380
b68cd97
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from functools import lru_cache

import gradio as gr
import numpy as np
import pandas as pd
from datasets import load_dataset


# Choices offered by the UI radios. get_nearest_64 compares these values
# verbatim against the dataset's gender / ethnicity / model / no columns,
# so every entry must match the stored column values exactly.
gender_labels = [
    'man',
    'non-binary',
    'woman',
    'no_gender_specified',
]

ethnicity_labels = [
    'African-American',
    'American_Indian',
    'Black',
    'Caucasian',
    'East_Asian',
    'First_Nations',
    'Hispanic',
    'Indigenous_American',
    'Latino',
    'Latinx',
    'Multiracial',
    'Native_American',
    'Pacific_Islander',
    'South_Asian',
    'Southeast_Asian',
    'White',
    'no_ethnicity_specified',
]
models = ['DallE', 'SD_14', 'SD_2']
# Image numbers 1 through 10 inclusive.
nos = list(range(1, 11))
# Precomputed nearest-neighbour table. For a dataset row ix, get_nearest_64
# shows index[ix][0] as the query image and index[ix][1:] as its neighbours.
# NOTE(review): the file name suggests a 10752 x 65 array (one query plus
# 64 neighbours per row) — confirm against the data.
index = np.load(file="indexes/knn_10752_65.npy")
# Training split of the identities dataset. get_nearest_64 reads its image,
# image_path, gender, ethnicity, model and no columns.
ds = load_dataset(path="SDBiaseval/identities", split="train")

@lru_cache(maxsize=1)
def _identity_metadata():
    """Return the dataset's non-image columns as a pandas DataFrame.

    The ``image`` and ``image_path`` columns are dropped before conversion so
    no image data is materialised. The result is cached so the conversion
    runs once per process instead of on every button click.
    """
    return ds.remove_columns(["image", "image_path"]).to_pandas()


def get_nearest_64(gender, ethnicity, model, no):
    """Look up an identity image and its precomputed nearest neighbours.

    Args:
        gender: label matched against the dataset's ``gender`` column.
        ethnicity: label matched against the ``ethnicity`` column.
        model: model name matched against the ``model`` column.
        no: image number matched against the ``no`` column.

    Returns:
        A tuple ``(image, gallery)``.
        ``ix`` is the first dataset row that matches all four selections.
        ``image`` is the dataset image at row ``index[ix][0]``.
        ``gallery`` is a list of ``(image, caption)`` pairs for the rows
        ``index[ix][1:]``. Each caption is the last ``/``-separated part of
        that row's ``image_path``.

    Raises:
        ValueError: if any selection is ``None``, or if no dataset row
            matches the requested combination.
    """
    selections = {"gender": gender, "ethnicity": ethnicity, "model": model, "no": no}
    missing = [name for name, value in selections.items() if value is None]
    if missing:
        # An unselected radio presumably arrives as None. Report it clearly
        # instead of failing later with an IndexError on an empty match.
        raise ValueError(f"Please select a value for: {', '.join(missing)}")

    df = _identity_metadata()
    matches = df.loc[(df['ethnicity'] == ethnicity) & (df['gender'] == gender)
                     & (df['no'] == no) & (df['model'] == model)].index
    if len(matches) == 0:
        raise ValueError(
            f"No image found for gender={gender!r}, ethnicity={ethnicity!r}, "
            f"model={model!r}, no={no!r}"
        )
    # NOTE: relies on to_pandas() producing a default RangeIndex, so this
    # label is also the dataset row position used to index the k-NN table.
    ix = matches[0]

    knn_row = index[ix]
    image = ds.select([knn_row[0]])["image"][0]
    neighbors = ds.select(knn_row[1:])
    neighbor_captions = [path.split("/")[-1] for path in neighbors["image_path"]]
    return image, list(zip(neighbors["image"], neighbor_captions))

# UI layout: selectors for the four identity attributes, a button, and an
# output row showing the query image next to a gallery of its neighbours.
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: gender, model and image-number selectors.
        with gr.Column():
            gender = gr.Radio(gender_labels, label="Gender label")
            model = gr.Radio(models, label="Model")
            # NOTE(review): `nos` holds ints. Confirm the Radio hands back
            # ints (not strings) so `df['no'] == no` in get_nearest_64 matches.
            no = gr.Radio(nos, label="Image number")
        # Right column: ethnicity selector, which has the longest choice list.
        with gr.Column():
            ethnicity = gr.Radio(ethnicity_labels, label="Ethnicity label")
    button = gr.Button(value="Get nearest neighbors")
    with gr.Row():
        image = gr.Image()
        # Gradio 3.x styling API; grid=8 presumably means 8 images per row.
        gallery = gr.Gallery().style(grid=8)
    # Input order must match get_nearest_64(gender, ethnicity, model, no).
    # Its (image, gallery items) return tuple fills the two outputs in order.
    button.click(get_nearest_64, inputs=[gender, ethnicity, model, no], outputs=[image, gallery])
# Starts the server as soon as this module runs; there is no __main__ guard.
demo.launch()