cakiki commited on
Commit
1675e3b
1 Parent(s): 9f2f0a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -9
app.py CHANGED
@@ -14,13 +14,12 @@ nos = [1,2,3,4,5,6,7,8,9,10]
14
  index = np.load("indexes/knn_10752_65.npy")
15
  ds = load_dataset("SDBiaseval/identities", split="train")
16
 
17
- def get_index(gender, ethnicity, model, no):
18
- df = ds.remove_columns(["image","image_path"]).to_pandas()
19
- return df.loc[(df['ethnicity'] == ethnicity) & (df['gender'] == gender) & (df['no'] == no) & (df['model'] == model)].index[0]
20
-
21
  def get_nearest_64(gender, ethnicity, model, no):
22
- ix = get_index(gender, ethnicity, model, no)
23
- return ds.select([index[ix][0]])["image"], ds.select(index[ix][1:5])["image"]
 
 
 
24
 
25
  with gr.Blocks() as demo:
26
  gender = gr.Radio(gender_labels, label="Gender label")
@@ -29,7 +28,5 @@ with gr.Blocks() as demo:
29
  no = gr.Radio(nos, label="Image number")
30
 
31
  button = gr.Button(value="Get nearest neighbors")
32
- button.click(get_nearest_64, inputs=[gender, ethnicity, model, no], outputs=[gr.Image(), gr.Gallery()])
33
  demo.launch()
34
- # demo = gr.Interface(fn=query_db, inputs="text", outputs=[gr.Image(), gr.Gallery().style(grid=[8])])
35
- # demo.launch(debug=True)
 
14
  index = np.load("indexes/knn_10752_65.npy")
15
  ds = load_dataset("SDBiaseval/identities", split="train")
16
 
 
 
 
 
17
  def get_nearest_64(gender, ethnicity, model, no):
18
+ df = ds.remove_columns(["image","image_path"]).to_pandas()
19
+ ix = df.loc[(df['ethnicity'] == ethnicity) & (df['gender'] == gender) & (df['no'] == no) & (df['model'] == model)].index[0]
20
+ image = ds.select([index[ix][0]])["image"]
21
+ neighbor_list = ds.select(index[ix][1:])["image"]
22
+ return image, neighbor_list
23
 
24
  with gr.Blocks() as demo:
25
  gender = gr.Radio(gender_labels, label="Gender label")
 
28
  no = gr.Radio(nos, label="Image number")
29
 
30
  button = gr.Button(value="Get nearest neighbors")
31
+ button.click(get_nearest_64, inputs=[gender, ethnicity, model, no], outputs=[gr.Image(), gr.Gallery().style(grid=[8], height="auto")])
32
  demo.launch()