cakiki commited on
Commit
8ea625c
1 Parent(s): a85d0c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  from datasets import load_dataset
3
  import numpy as np
4
 
@@ -14,8 +15,11 @@ index = np.load("indexes/knn_10752_65.npy")
14
  ds = load_dataset("SDBiaseval/identities", split="train")
15
 
16
  def get_index(gender, ethnicity, model, no):
17
- pass
18
 
19
- def get_nearest_64(ix):
20
- return ds.select(index[ix][1:])
 
21
 
 
 
 
1
  import gradio as gr
2
+ import pandas as pd
3
  from datasets import load_dataset
4
  import numpy as np
5
 
 
15
  ds = load_dataset("SDBiaseval/identities", split="train")
16
 
17
  def get_index(gender, ethnicity, model, no):
18
+ return df.loc[(df['ethnicity'] == ethnicity) & (df['gender'] == gender) & (df['no'] == no) & (df['model'] == model)].index[0]
19
 
20
+ def get_nearest_64(gender, ethnicity, model, no):
21
+ ix = get_index(gender, ethnicity, model, no)
22
+ return ds.select(index[ix][0])["image"], ds.select(index[ix][1:])["image"]
23
 
24
+ # demo = gr.Interface(fn=query_db, inputs="text", outputs=[gr.Image(), gr.Gallery().style(grid=[8])])
25
+ # demo.launch(debug=True)