|
import os |
|
import gradio as gr |
|
from datasets import load_dataset |
|
|
|
# Access token for the Hugging Face Hub, read from the `auth_token`
# environment variable; `os.environ.get` yields None when it is unset.
# NOTE(review): newer `datasets` releases deprecate `use_auth_token` in favour
# of `token` - confirm against the installed version before switching.
auth_token = os.environ.get("auth_token")
visit_bench_all = load_dataset("mlfoundations/VisIT-Bench", use_auth_token=auth_token)

print('visit_bench_all')
print(visit_bench_all)
print('dataset keys:')
print(visit_bench_all.keys())

# The viewer works on exactly one entry of the loaded dataset mapping.
# Validate with an explicit exception rather than `assert`, which is removed
# when Python runs with -O and would let the script silently pick the first
# of several entries.
dataset_keys = list(visit_bench_all.keys())
if len(dataset_keys) != 1:
    raise ValueError(
        f"Expected exactly one dataset key, got {len(dataset_keys)}: {dataset_keys}"
    )
dataset_key = dataset_keys[0]
visit_bench = visit_bench_all[dataset_key]

print('first item:')
print(visit_bench[0])
|
|
|
# Number of rows shown per page of the viewer.
LINES_NUMBER = 20


def _image_cell(image):
    """Build the HTML for one `image` entry: a thumbnail linking to its path."""
    path = image["path"]
    return f'<a href="{path}" target="_blank"><img src="{path}" style="width:100%; max-width:800px; height:auto;"></a>'


df = visit_bench.to_pandas()
print(f"Got {len(df)} items in dataframe")

# Shuffle the rows so the first page differs between launches.
df = df.sample(frac=1)

# Replace each image entry with clickable HTML pointing at its "path".
df['image'] = df['image'].apply(_image_cell)

# Put the image column first, keep the remaining columns in their original
# order, then discard the 'visual' column.
cols = ['image'] + [name for name in df.columns if name != 'image']
df = df.reindex(columns=cols).drop(columns=['visual'])
|
|
|
def display_df():
    """Return the first ``LINES_NUMBER`` rows of the module-level ``df``."""
    return df.head(LINES_NUMBER)
|
|
|
def display_next(dataframe, end):
    """Return the next page of rows together with the updated paging cursor.

    Args:
        dataframe: The rows currently displayed. Only its length is used, as
            the starting position when ``end`` is empty.
        end: Position one past the last row already shown. A falsy value
            (``None`` or ``0``) means no cursor has been stored yet.

    Returns:
        A ``(rows, end)`` tuple: the slice of the module-level ``df`` to show
        next and the new cursor. When a full page no longer fits, ``df`` is
        reshuffled and paging restarts from the top. The slice holds
        ``LINES_NUMBER`` rows unless ``df`` itself is smaller than one page,
        in which case all of its rows are returned.
    """
    global df

    page_size = int(LINES_NUMBER)
    start = int(end or len(dataframe))
    end = start + page_size

    # Wrap only when the requested page would run past the last row. The
    # previous `end >= len(df) - 1` test wrapped while a complete page was
    # still available, so the final full page was never displayed.
    if end > len(df):
        start, end = 0, page_size
        df = df.sample(frac=1)
        print("Shuffle")

    # No length assertion here: a dataset smaller than one page simply
    # yields a short slice instead of crashing the callback.
    return df.iloc[start:end], end
|
|
|
# First page of rows, rendered before any button is pressed.
initial_dataframe = display_df()

with gr.Blocks() as demo:
    # Table styling: wrap text in ordinary cells, and give the first (image)
    # column a fixed width with vertically centred content.
    style = """
    <style>
        .gradio-container table tr td, .gradio-container table tr th {
            padding: 10px 20px; /* Increase padding, adjust as needed */
            white-space: normal; /* Ensure text wraps in both data cells and headers */
            word-wrap: break-word; /* Break the word to prevent overflow */
            vertical-align: top; /* Align text to the top of the cell */
        }
        .gradio-container table tr td:nth-child(1), .gradio-container table tr th:nth-child(1) {
            white-space: nowrap; /* Prevent wrapping in image column and its header */
            vertical-align: middle; /* Align images to the middle of the cell */
            width: 600px; /* Set a specific width for the image column and its header */
        }
    </style>
    """
    gr.HTML(style)
    gr.Markdown("<h1><center>VisIT-Bench Dataset Viewer</center></h1>")

    with gr.Row():
        # Hidden field carrying the paging cursor between "Next Rows" clicks.
        num_end = gr.Number(visible=False)
        b1 = gr.Button("Get Initial dataframe")
        b2 = gr.Button("Next Rows")

    with gr.Row():
        # NOTE(review): `datatype` has 9 entries while `column_widths` has 10;
        # confirm both against the actual number of dataframe columns.
        out_dataframe = gr.Dataframe(
            value=initial_dataframe,
            row_count=LINES_NUMBER,
            interactive=False,
            datatype=["markdown"] + ["str"] * 2 + ["bool"] * 3 + ["str"] * 3,
            column_widths=[300] + [120] * 9,
        )

    # "Get Initial dataframe" shows the first page again; "Next Rows" advances
    # the cursor and writes both the new page and the new cursor back.
    b1.click(
        fn=display_df,
        outputs=out_dataframe,
        api_name="initial_dataframe",
    )
    b2.click(
        fn=display_next,
        inputs=[out_dataframe, num_end],
        outputs=[out_dataframe, num_end],
        api_name="next_rows",
    )

demo.launch(debug=True, show_error=True)
|
|