Caleb Spradlin
Initial commit
4d01101
raw
history blame
5.81 kB
import streamlit as st
from pathlib import Path
# -----------------------------------------------------------------------------
# main
# -----------------------------------------------------------------------------
def main():
    """
    Render the SatVision few-shot comparison page.

    A slider chooses how many training samples the models were trained on.
    For that setting, each grid row shows four images side by side: the
    MOD09GA input chip, the SatVision-B prediction, the UNet (CNN)
    prediction, and the MCD12Q1 land-cover target. Column headings are
    written above the first row.
    """
    st.title("SatVision Few-Shot Comparison")

    selected_option = st.select_slider(
        "## Number of training samples",
        options=[10, 100, 500, 1000, 5000])

    st.markdown('Move slider to select how many training '
                'samples the models were trained on')

    images = load_images(selected_option, Path('./images/images'))
    labels = load_labels(selected_option, Path('./images/labels'))
    preds = load_predictions(selected_option, Path('./images/predictions'))

    # Materialize the rows so the grid is sized to the data. A hard-coded row
    # count left an empty row (3 examples vs. 4 rows) and would raise
    # IndexError if more examples were added.
    rows = list(zip(images, preds['svb'], preds['unet'], labels))

    # One heading per column, in the same order as the tuples in `rows`.
    headers = (
        '## MOD09GA 3-2-1 Image Chip',
        '## SatVision-B Prediction',
        '## UNet (CNN) Prediction',
        '## MCD12Q1 LandCover Target',
    )
    # make_grid(number_of_rows, columns_per_row)
    grid = make_grid(len(rows), len(headers))

    for row_idx, row in enumerate(rows):
        for col_idx, (path, caption) in enumerate(row):
            cell = grid[row_idx][col_idx]
            if row_idx == 0:
                # Headings sit in the first row's cells, above its images.
                cell.markdown(headers[col_idx])
            cell.image(path, caption, use_column_width=True)

    st.text("Additional Information:")
    st.text("This is a placeholder for additional information about the images.")
# -----------------------------------------------------------------------------
# load_images
# -----------------------------------------------------------------------------
def load_images(selected_option: str, image_dir: Path):
    """
    Pair each input image chip for ``selected_option`` with a caption.

    Looks up the chips with ``find_images`` and returns a list of
    ``(path_as_str, caption)`` tuples. Captions are numbered from 1 in the
    order ``find_images`` returns the paths.
    """
    captioned = []
    for number, image_path in enumerate(
            find_images(selected_option, image_dir), start=1):
        caption = f"MOD09GA 3-2-1 H18v04 2019 Example {number}"
        captioned.append((str(image_path), caption))
    return captioned
# -----------------------------------------------------------------------------
# find_images
# -----------------------------------------------------------------------------
def find_images(selected_option: str, image_dir: Path):
    """
    Return the sorted input-chip paths for one training-sample setting.

    Matches files named ``ft_demo_<selected_option>_*_img.png`` directly
    inside ``image_dir``.

    Raises:
        ValueError: if the number of matches is not exactly 3, or if the
            first match (in sorted order) does not contain '1071' in its
            file name.
    """
    pattern = f'ft_demo_{selected_option}_*_img.png'
    matches = sorted(image_dir.glob(pattern))
    # Explicit raises rather than `assert`, so the checks survive `python -O`.
    if len(matches) != 3:
        raise ValueError(
            f"Should be 3 images matching {pattern!r} in {image_dir}, "
            f"found {len(matches)}")
    # Check only the file name. Testing str(path) could pass just because a
    # parent directory happens to contain '1071'.
    if '1071' not in matches[0].name:
        raise ValueError(f"Should be 1071, got {matches[0].name}")
    return matches
# -----------------------------------------------------------------------------
# load_labels
# -----------------------------------------------------------------------------
def load_labels(selected_option, label_dir: Path):
    """
    Pair each land-cover target image for ``selected_option`` with a caption.

    Uses ``find_labels`` to locate the files and returns a list of
    ``(path_as_str, caption)`` tuples. Captions are numbered from 1 in the
    order ``find_labels`` returns the paths.
    """
    label_paths = find_labels(selected_option, label_dir)
    result = []
    for position, label_path in enumerate(label_paths, start=1):
        result.append(
            (str(label_path), f"MCD12Q1 LandCover Target Example {position}"))
    return result
# -----------------------------------------------------------------------------
# find_labels
# -----------------------------------------------------------------------------
def find_labels(selected_option: str, label_dir: Path):
    """
    Return the sorted land-cover target paths for one training-sample setting.

    Matches files named ``ft_demo_<selected_option>_*_label.png`` directly
    inside ``label_dir``.

    Raises:
        ValueError: if the number of matches is not exactly 3, or if the
            first match (in sorted order) does not contain '1071' in its
            file name.
    """
    pattern = f'ft_demo_{selected_option}_*_label.png'
    matches = sorted(label_dir.glob(pattern))
    # Explicit raises rather than `assert`, so the checks survive `python -O`.
    if len(matches) != 3:
        raise ValueError(
            f"Should be 3 label images matching {pattern!r} in {label_dir}, "
            f"found {len(matches)}")
    # Check only the file name. Testing str(path) could pass just because a
    # parent directory happens to contain '1071'.
    if '1071' not in matches[0].name:
        raise ValueError(f"Should be 1071, got {matches[0].name}")
    return matches
# -----------------------------------------------------------------------------
# load_predictions
# -----------------------------------------------------------------------------
def load_predictions(selected_option: str, pred_dir: Path):
    """
    Collect captioned prediction images for both models.

    Returns a dict with two keys:
      - ``'svb'``: SatVision-B predictions, found under model name ``'svb'``;
      - ``'unet'``: UNet predictions, found under model name ``'cnn'``.
    Each value is a list of ``(path_as_str, caption)`` tuples. Captions are
    numbered from 1 in the order ``find_preds`` returns the paths.
    """
    svb_paths = find_preds(selected_option, pred_dir, 'svb')
    cnn_paths = find_preds(selected_option, pred_dir, 'cnn')
    return {
        'svb': [(str(path), f"SatVision-B Prediction Example {n}")
                for n, path in enumerate(svb_paths, start=1)],
        'unet': [(str(path), f"Unet Prediction Example {n}")
                 for n, path in enumerate(cnn_paths, start=1)],
    }
# -----------------------------------------------------------------------------
# find_preds
# -----------------------------------------------------------------------------
def find_preds(selected_option: int, pred_dir: Path, model: str):
    """
    Return the sorted prediction paths for one model and sample setting.

    Predictions are read from ``pred_dir / str(selected_option) / model``.
    For ``model == 'cnn'`` the files are named
    ``ft_cnn_demo_<selected_option>_*_pred.png``; for any other model they
    are named ``ft_demo_<selected_option>_*_pred.png``.

    Raises:
        FileNotFoundError: if the model-specific directory is missing or is
            not a directory.
        ValueError: if the number of matches is not exactly 3, or if the
            first match (in sorted order) does not contain '1071' in its
            file name.
    """
    prefix = 'ft_cnn_demo' if model == 'cnn' else 'ft_demo'
    pattern = f'{prefix}_{selected_option}_*_pred.png'
    model_specific_dir = pred_dir / str(selected_option) / model
    # is_dir() rather than exists(): a plain file at this path would
    # otherwise fall through to a misleading "found 0" count error.
    if not model_specific_dir.is_dir():
        raise FileNotFoundError(
            f'{model_specific_dir} does not exist or is not a directory')
    matches = sorted(model_specific_dir.glob(pattern))
    # Explicit raises rather than `assert`, so the checks survive `python -O`.
    if len(matches) != 3:
        raise ValueError(
            f"Should be 3 prediction images matching {pattern!r} in "
            f"{model_specific_dir}, found {len(matches)}")
    # Check only the file name. Testing str(path) could pass just because a
    # parent directory happens to contain '1071'.
    if '1071' not in matches[0].name:
        raise ValueError(f"Should be 1071, got {matches[0].name}")
    return matches
# -----------------------------------------------------------------------------
# make_grid
# -----------------------------------------------------------------------------
def make_grid(cols, rows):
    """
    Build a grid of Streamlit column cells and return it as a nested list.

    Note the parameter names are the reverse of their meaning (kept for
    compatibility):
      - ``cols`` is the number of grid ROWS; one ``st.container()`` with a set
        of columns is created for each;
      - ``rows`` is the number of COLUMNS in each row, passed to
        ``st.columns(rows, gap='large')``.

    Returns a list of length ``cols``. Each item is whatever ``st.columns``
    returned for that row, so a cell is addressed as ``grid[row][col]``.
    """
    grid = []
    for _ in range(cols):
        container = st.container()
        with container:
            grid.append(st.columns(rows, gap='large'))
    return grid
# -----------------------------------------------------------------------------
# Main execution
# -----------------------------------------------------------------------------
# Render the page when this file is executed as a script; importing the
# module does not render anything.
if __name__ == "__main__":
    main()