"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import random
from collections import OrderedDict
from functools import reduce
from tkinter import N
import streamlit as st
from lavis.common.registry import registry
from lavis.datasets.builders import dataset_zoo, load_dataset
from lavis.datasets.builders.base_dataset_builder import load_dataset_config
from PIL import Image
# Grid shape (rows, columns) used by show_samples() for image datasets.
IMAGE_LAYOUT = (3, 4)
# Grid shape (rows, columns) used by show_samples() for video datasets.
VIDEO_LAYOUT = (1, 2)
# Labels of the sequential-browsing buttons.
PREV_STR, NEXT_STR = "Prev", "Next"
def sample_dataset(dataset, indices):
    """Return ``dataset.displ_item(i)`` for every ``i`` in ``indices``, in order.

    ``displ_item`` is looked up only when there is at least one index, so an empty
    ``indices`` yields ``[]`` for any ``dataset``.
    """
    displayed = []
    for position in indices:
        displayed.append(dataset.displ_item(position))
    return displayed
def get_concat_v(im1, im2):
    """Join two images side by side on a white RGB canvas.

    Despite the ``_v`` suffix, the images are placed horizontally: ``im1`` at the
    left edge and ``im2`` to its right after a 5-pixel gap. Both are top-aligned,
    and the canvas is as tall as the taller image.
    """
    gap = 5
    width = im1.width + gap + im2.width
    height = max(im1.height, im2.height)
    combined = Image.new("RGB", (width, height), "White")
    for picture, x_offset in ((im1, 0), (im2, im1.width + gap)):
        combined.paste(picture, (x_offset, 0))
    return combined
def resize_img_w(raw_img, new_w=224, list_item_w=196):
    """Resize an image to a fixed width while preserving its aspect ratio.

    If ``raw_img`` is a list or tuple of images, each one is resized to
    ``list_item_w`` pixels wide and the results are joined side by side with
    ``get_concat_v``. ``new_w`` does not apply in that case.

    Args:
        raw_img: an image exposing ``.size`` as ``(width, height)`` and
            ``.resize((w, h))`` (e.g. ``PIL.Image.Image``), or a list/tuple of them.
        new_w: target width in pixels for a single image.
        list_item_w: target width for each image when ``raw_img`` is a sequence.

    Returns:
        The resized image (concatenated, for sequences).

    Raises:
        ValueError: if ``raw_img`` is an empty sequence or an image has a
            non-positive width.
    """
    if isinstance(raw_img, (list, tuple)):
        if not raw_img:
            # reduce() on an empty list would raise an opaque TypeError.
            raise ValueError("resize_img_w() received an empty list of images.")
        resized_imgs = [resize_img_w(img, list_item_w, list_item_w) for img in raw_img]
        return reduce(get_concat_v, resized_imgs)

    w, h = raw_img.size
    if w <= 0:
        raise ValueError(f"Cannot resize an image with width {w}.")
    # Exact integer arithmetic keeps the original truncating behaviour without
    # float error (which could yield a width of new_w - 1). The height is clamped
    # to at least 1 so very wide images never request a zero-height resize.
    new_h = max(1, int(h * new_w // w))
    return raw_img.resize((int(new_w), new_h))
def get_visual_key(dataset):
    """Return the kind of visual content in ``dataset``: ``"image"`` or ``"video"``.

    Only the first item is inspected. A first item containing ``"image"`` or
    ``"image0"`` (NLVR2-style paired images) maps to ``"image"``; one containing
    ``"video"`` maps to ``"video"``. Image keys take precedence.

    Raises:
        ValueError: if the first item has none of these keys.
        IndexError (or whatever ``dataset[0]`` raises): if ``dataset`` is empty.
    """
    # Index the dataset once instead of once per key check.
    first = dataset[0]
    if "image" in first or "image0" in first:
        return "image"
    if "video" in first:
        return "video"
    raise ValueError("Visual key not found.")
def gather_items(samples, exclude=()):
    """Copy each sample mapping, leaving out the keys listed in ``exclude``.

    Args:
        samples: iterable of mappings.
        exclude: collection of keys to omit. The default is an empty tuple, an
            immutable default that avoids the shared-mutable-default pitfall.

    Returns:
        A list with one ``OrderedDict`` per sample, holding the remaining keys in
        their original iteration order.
    """
    return [
        OrderedDict((k, s[k]) for k in s.keys() if k not in exclude)
        for s in samples
    ]
@st.cache(allow_output_mutation=True)
def load_dataset_cache(name):
    """Load the LAVIS dataset called ``name`` via ``load_dataset``, cached by ``st.cache``.

    ``allow_output_mutation=True`` lets callers use the returned object (a mapping
    of split names to splits, as the ``__main__`` block uses it) without Streamlit
    treating changes to it as a cache problem.
    """
    splits = load_dataset(name)
    return splits
def format_text(text):
    """Render a mapping as Markdown: one ``**key**: value`` entry per item, separated by blank lines."""
    lines = []
    for key, value in text.items():
        lines.append(f"**{key}**: {value}")
    return "\n\n".join(lines)
def _window_start(base, offset, n_samples, n_total, is_next):
    """Return the first index of the next/previous sequential page, never below 0.

    ``base + offset`` is the start of the page currently on screen. "Next" moves one
    page forward, capped so the last page stays full; "Prev" moves one page back.
    """
    current = base + offset
    if is_next:
        start = min(current + n_samples, n_total - n_samples)
    else:
        start = current - n_samples
    # When the dataset is smaller than one page, n_total - n_samples is negative.
    # A negative start would make range(start, end) yield negative indices, which
    # silently wrap around to the end of the dataset.
    return max(0, start)


def show_samples(dataset, offset=0, is_next=False):
    """Render one page of samples from ``dataset`` as a grid on the Streamlit page.

    The grid shape comes from IMAGE_LAYOUT or VIDEO_LAYOUT, chosen by
    ``get_visual_key``. The module-level ``shuffle`` flag (set by the ``__main__``
    block) selects between a random page and a sequential page.

    Args:
        dataset: one dataset split. It must support ``len()``, indexing (used by
            ``get_visual_key``) and ``displ_item()`` (used by ``sample_dataset``).
        offset: added to ``st.session_state.start_idx`` to recover the start of the
            page currently shown. The ``__main__`` block passes
            ``last_start - start_idx``.
        is_next: in sequential mode, move forward (True) or backward (False).

    Side effects:
        Writes ``st.session_state.last_start`` (sequential mode only) and
        ``st.session_state.n_display``.
    """
    visual_key = get_visual_key(dataset)
    num_rows, num_cols = IMAGE_LAYOUT if visual_key == "image" else VIDEO_LAYOUT
    n_samples = num_rows * num_cols
    n_total = len(dataset)

    if not shuffle:
        # Use the validated start index from session state rather than the raw
        # "Begin from" text, which may be non-numeric (int() would then raise).
        start = _window_start(
            st.session_state.start_idx, offset, n_samples, n_total, is_next
        )
        st.session_state.last_start = start
        end = min(start + n_samples, n_total)
        indices = list(range(start, end))
    else:
        # random.sample raises ValueError if asked for more items than exist.
        indices = random.sample(range(n_total), min(n_samples, n_total))

    samples = sample_dataset(dataset, indices)
    texts = [format_text(s) for s in gather_items(samples, exclude=["image", "video"])]
    # Only images are rendered; each video cell shows a placeholder GIF instead.
    images = (
        [resize_img_w(s[visual_key]) for s in samples] if visual_key == "image" else []
    )

    # NOTE(review): these markdown calls render only whitespace; the HTML they
    # presumably once carried (unsafe_allow_html=True) seems to be missing — confirm.
    st.markdown(
        "\n",
        unsafe_allow_html=True,
    )
    # Lay out one row per num_cols samples; no empty rows after the samples run out.
    for row_start in range(0, len(texts), num_cols):
        with st.container():
            for col, idx in zip(st.columns(num_cols), range(row_start, len(texts))):
                col.markdown(texts[idx])
                if visual_key == "image":
                    col.image(images[idx], use_column_width=True, clamp=True)
                elif visual_key == "video":
                    col.markdown(
                        "![Alt Text](https://media.giphy.com/media/vFKqnCdLPNOKc/giphy.gif)"
                    )
    st.markdown(
        "\n",
        unsafe_allow_html=True,
    )
    st.session_state.n_display = n_samples
if __name__ == "__main__":
    # Streamlit entry point. Sidebar widgets pick the dataset, split and browsing
    # mode; st.session_state carries the browsing position across script reruns;
    # show_samples() renders one page of samples.
    st.set_page_config(
        page_title="LAVIS Dataset Explorer",
        # layout="wide",
        initial_sidebar_state="expanded",
    )
    dataset_name = st.sidebar.selectbox("Dataset:", dataset_zoo.get_names())
    function = st.sidebar.selectbox("Function:", ["Browser"], index=0)
    if function == "Browser":
        # Defaults to shuffled browsing (index=0 -> True). show_samples() reads this
        # module-level ``shuffle`` as a global.
        shuffle = st.sidebar.selectbox("Shuffled:", [True, False], index=0)
        dataset = load_dataset_cache(dataset_name)
        # ``dataset`` is indexed by split name; its keys are offered as splits.
        split = st.sidebar.selectbox("Split:", dataset.keys())
        dataset_len = len(dataset[split])
        st.success(
            f"Loaded {dataset_name}/{split} with **{dataset_len}** records. **Image/video directory**: {dataset[split].vis_root}"
        )
        # One-time initialisation of the session-state keys used below.
        if "last_dataset" not in st.session_state:
            st.session_state.last_dataset = dataset_name
            st.session_state.last_split = split
        # last_start: start index of the most recently displayed sequential page
        # (show_samples updates it; a new "Begin from" value resets it).
        if "last_start" not in st.session_state:
            st.session_state.last_start = 0
        # start_idx: the last valid integer entered in "Begin from".
        if "start_idx" not in st.session_state:
            st.session_state.start_idx = 0
        if "shuffle" not in st.session_state:
            st.session_state.shuffle = shuffle
        # first_run requests one automatic render (no button press) on the very
        # first run, after the dataset or split changes, or after the shuffle mode
        # changes.
        if "first_run" not in st.session_state:
            st.session_state.first_run = True
        elif (
            st.session_state.last_dataset != dataset_name
            or st.session_state.last_split != split
        ):
            st.session_state.first_run = True
            st.session_state.last_dataset = dataset_name
            st.session_state.last_split = split
        elif st.session_state.shuffle != shuffle:
            st.session_state.shuffle = shuffle
            st.session_state.first_run = True
        if not shuffle:
            # Sequential mode: "Prev" sits in the narrow left column, "Next" to its right.
            n_col, p_col = st.columns([0.05, 1])
            prev_button = n_col.button(PREV_STR)
            next_button = p_col.button(NEXT_STR)
        else:
            # Shuffled mode only needs a button that draws a new random page.
            next_button = st.button(NEXT_STR)
        if not shuffle:
            start_idx = st.sidebar.text_input(f"Begin from (total {dataset_len})", 0)
            if not start_idx.isdigit():
                st.error(f"Input to 'Begin from' must be digits, found {start_idx}.")
            else:
                # A changed "Begin from" value resets the browsing position.
                if int(start_idx) != st.session_state.start_idx:
                    st.session_state.start_idx = int(start_idx)
                    st.session_state.last_start = int(start_idx)
            # prev_button is only defined in sequential mode, so it is checked here.
            if prev_button:
                show_samples(
                    dataset[split],
                    offset=st.session_state.last_start - st.session_state.start_idx,
                    is_next=False,
                )
        # offset = last_start - start_idx, so show_samples pages relative to the
        # start of the page currently on screen.
        if next_button:
            show_samples(
                dataset[split],
                offset=st.session_state.last_start - st.session_state.start_idx,
                is_next=True,
            )
        if st.session_state.first_run:
            st.session_state.first_run = False
            # NOTE(review): in sequential mode, is_next=True advances one page past
            # last_start, so the first render appears to begin one page after
            # "Begin from" rather than at it — confirm this is intended.
            show_samples(
                dataset[split],
                offset=st.session_state.last_start - st.session_state.start_idx,
                is_next=True,
            )