Anonymous Authors commited on
Commit
cd8be09
1 Parent(s): ad7b8e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -11,11 +11,12 @@ ethnicity_labels = ['African-American', 'American_Indian', 'Black', 'Caucasian',
11
  'Southeast_Asian', 'White', 'no_ethnicity_specified']
12
  models = ['DallE', 'SD_14', 'SD_2']
13
  nos = [1,2,3,4,5,6,7,8,9,10]
14
- index = np.load("indexes/knn_10752_65.npy")
15
  ds = load_dataset("tti-bias/identities", split="train")
16
 
17
- def get_nearest_64(gender="man", ethnicity="Hispanic", model="SD_14", no=1):
18
  df = ds.remove_columns(["image","image_path"]).to_pandas()
 
19
  ix = df.loc[(df['ethnicity'] == ethnicity) & (df['gender'] == gender) & (df['no'] == no) & (df['model'] == model)].index[0]
20
  image = ds.select([index[ix][0]])["image"][0]
21
  neighbors = ds.select(index[ix][1:25])
@@ -33,6 +34,7 @@ with gr.Blocks() as demo:
33
  with gr.Row():
34
  with gr.Column():
35
  model = gr.Radio(models, label="Model")
 
36
  gender = gr.Radio(gender_labels, label="Gender label")
37
  no = gr.Radio(nos, label="Image number")
38
  with gr.Column():
@@ -41,5 +43,5 @@ with gr.Blocks() as demo:
41
  with gr.Row():
42
  image = gr.Image()
43
  gallery = gr.Gallery().style(grid=4)
44
- button.click(get_nearest_64, inputs=[gender, ethnicity, model, no], outputs=[image, gallery])
45
  demo.launch()
 
11
  'Southeast_Asian', 'White', 'no_ethnicity_specified']
12
  models = ['DallE', 'SD_14', 'SD_2']
13
  nos = [1,2,3,4,5,6,7,8,9,10]
14
+ indexes = [768, 1536, 10752]
15
  ds = load_dataset("tti-bias/identities", split="train")
16
 
17
+ def get_nearest_64(gender="man", ethnicity="Hispanic", model="SD_14", no=1, index):
18
  df = ds.remove_columns(["image","image_path"]).to_pandas()
19
+ index = np.load(f"indexes/knn_{index}_65.npy")
20
  ix = df.loc[(df['ethnicity'] == ethnicity) & (df['gender'] == gender) & (df['no'] == no) & (df['model'] == model)].index[0]
21
  image = ds.select([index[ix][0]])["image"][0]
22
  neighbors = ds.select(index[ix][1:25])
 
34
  with gr.Row():
35
  with gr.Column():
36
  model = gr.Radio(models, label="Model")
37
+ index = gr.Radio(indexes, label="Visual vocabulary size")
38
  gender = gr.Radio(gender_labels, label="Gender label")
39
  no = gr.Radio(nos, label="Image number")
40
  with gr.Column():
 
43
  with gr.Row():
44
  image = gr.Image()
45
  gallery = gr.Gallery().style(grid=4)
46
+ button.click(get_nearest_64, inputs=[gender, ethnicity, model, no, index], outputs=[image, gallery])
47
  demo.launch()