|
import streamlit as st |
|
from pathlib import Path |
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Render the SatVision few-shot comparison page.

    A slider selects how many training samples the models were trained on.
    For every example found for that setting, one grid row shows, left to
    right: the MOD09GA 3-2-1 input chip, the SatVision-B prediction, the
    UNet (CNN) prediction and the MCD12Q1 land-cover target.
    """
    st.title("SatVision Few-Shot Comparison")

    selected_option = st.select_slider(
        "## Number of training samples",
        options=[10, 100, 500, 1000, 5000])

    st.markdown('Move slider to select how many training '
                'samples the models were trained on')

    images = load_images(selected_option, Path('./images/images'))
    labels = load_labels(selected_option, Path('./images/labels'))
    preds = load_predictions(selected_option, Path('./images/predictions'))

    # Column titles, in the same order as the per-example tuples below.
    headers = (
        '## MOD09GA 3-2-1 Image Chip',
        '## SatVision-B Prediction',
        '## UNet (CNN) Prediction',
        '## MCD12Q1 LandCover Target',
    )

    # Materialize the examples so the grid can be sized to how many were
    # actually found; a hard-coded row count leaves empty rows behind (or
    # overflows with IndexError when there are more examples than rows).
    examples = list(zip(images, preds['svb'], preds['unet'], labels))

    grid = make_grid(len(examples), len(headers))

    for row_index, (row, example) in enumerate(zip(grid, examples)):
        if row_index == 0:
            # Titles go into the first row's cells, above its images.
            for cell, header in zip(row, headers):
                cell.markdown(header)

        # Each item of `example` is a (path, caption) pair for one column.
        # NOTE(review): newer Streamlit releases deprecate `use_column_width`
        # in favour of `use_container_width`; confirm against the pinned
        # Streamlit version before changing it.
        for cell, (path, caption) in zip(row, example):
            cell.image(path, caption, use_column_width=True)

    st.text("Additional Information:")
    st.text("This is a placeholder for additional information about the images.")
|
|
|
|
|
|
|
|
|
def load_images(selected_option: int, image_dir: Path):
    """Return ``(path, caption)`` pairs for the input image chips.

    Args:
        selected_option: Number of training samples chosen on the slider;
            used by ``find_images`` to select the matching files.
        image_dir: Directory that ``find_images`` searches.

    Returns:
        A list of ``(path_as_str, caption)`` tuples, in the order returned
        by ``find_images``, with captions numbered from 1. These are plain
        values suitable for ``st.image(path, caption)``, not Streamlit
        objects.
    """
    image_paths = find_images(selected_option, image_dir)

    return [
        (str(path), f"MOD09GA 3-2-1 H18v04 2019 Example {number}")
        for number, path in enumerate(image_paths, start=1)
    ]
|
|
|
|
|
|
|
|
|
def find_images(selected_option: int, image_dir: Path):
    """Return the sorted input-image chip paths for *selected_option*.

    Files are matched in *image_dir* with the glob pattern
    ``ft_demo_{selected_option}_*_img.png``.

    Args:
        selected_option: Number of training samples chosen on the slider.
        image_dir: Directory containing the image chips.

    Returns:
        A sorted list of exactly three matching ``Path`` objects.

    Raises:
        ValueError: If the number of matches is not exactly three, or if the
            first match (in sorted order) is not the ``1071`` example.
    """
    pattern = f'ft_demo_{selected_option}_*_img.png'
    matches = sorted(image_dir.glob(pattern))

    # Explicit raises rather than `assert`, so the checks still run under
    # `python -O` instead of degrading into an IndexError below.
    if len(matches) != 3:
        raise ValueError(
            f"Should be 3 images matching {pattern!r} in {image_dir}, "
            f"found {len(matches)}")

    # The demo expects the 1071 example to sort first. Check the file name,
    # not the full path, so a directory name cannot satisfy the check.
    if '1071' not in matches[0].name:
        raise ValueError(f"Should be 1071, got {matches[0].name}")

    return matches
|
|
|
|
|
|
|
|
|
|
|
def load_labels(selected_option, label_dir: Path):
    """Pair each land-cover target image with a numbered caption.

    Args:
        selected_option: Number of training samples chosen on the slider;
            forwarded to ``find_labels``.
        label_dir: Directory that ``find_labels`` searches.

    Returns:
        A list of ``(path_as_str, caption)`` tuples, in ``find_labels``
        order, with captions numbered from 1.
    """
    target_paths = find_labels(selected_option, label_dir)

    return [
        (str(target_path), f"MCD12Q1 LandCover Target Example {position}")
        for position, target_path in enumerate(target_paths, start=1)
    ]
|
|
|
|
|
|
|
|
|
|
|
def find_labels(selected_option: int, label_dir: Path):
    """Return the sorted land-cover target image paths for *selected_option*.

    Files are matched in *label_dir* with the glob pattern
    ``ft_demo_{selected_option}_*_label.png``.

    Args:
        selected_option: Number of training samples chosen on the slider.
        label_dir: Directory containing the label images.

    Returns:
        A sorted list of exactly three matching ``Path`` objects.

    Raises:
        ValueError: If the number of matches is not exactly three, or if the
            first match (in sorted order) is not the ``1071`` example.
    """
    pattern = f'ft_demo_{selected_option}_*_label.png'
    matches = sorted(label_dir.glob(pattern))

    # Explicit raises rather than `assert`, so validation survives `python -O`.
    if len(matches) != 3:
        raise ValueError(
            f"Should be 3 label images matching {pattern!r} in {label_dir}, "
            f"found {len(matches)}")

    # The first example must be 1071. Check only the file name so a
    # directory name cannot satisfy the check.
    if '1071' not in matches[0].name:
        raise ValueError(f"Should be 1071, got {matches[0].name}")

    return matches
|
|
|
|
|
|
|
|
|
|
|
def load_predictions(selected_option: int, pred_dir: Path):
    """Return captioned prediction image paths for both models.

    Args:
        selected_option: Number of training samples chosen on the slider;
            forwarded to ``find_preds``.
        pred_dir: Root prediction directory passed to ``find_preds``.

    Returns:
        A dict with keys ``'svb'`` (SatVision-B) and ``'unet'`` (UNet
        baseline). Each value is a list of ``(path_as_str, caption)`` tuples,
        in ``find_preds`` order, with captions numbered from 1.
    """
    # Output key -> (model name passed to find_preds, caption prefix).
    # The UNet baseline's files are looked up under the 'cnn' model name.
    # Entries are processed in this order, so 'svb' is searched first.
    models = {
        'svb': ('svb', "SatVision-B Prediction Example"),
        'unet': ('cnn', "Unet Prediction Example"),
    }

    prediction_dict = {}
    for key, (model, caption_prefix) in models.items():
        paths = find_preds(selected_option, pred_dir, model)
        prediction_dict[key] = [
            (str(path), f"{caption_prefix} {number}")
            for number, path in enumerate(paths, start=1)
        ]

    return prediction_dict
|
|
|
|
|
|
|
|
|
|
|
def find_preds(selected_option: int, pred_dir: Path, model: str):
    """Return the sorted prediction image paths for one model.

    Predictions live in ``pred_dir / str(selected_option) / model``. For
    ``model == 'cnn'`` the glob pattern is
    ``ft_cnn_demo_{selected_option}_*_pred.png``; for any other model it is
    ``ft_demo_{selected_option}_*_pred.png``.

    Args:
        selected_option: Number of training samples chosen on the slider.
        pred_dir: Root prediction directory.
        model: Model sub-directory name (e.g. ``'svb'`` or ``'cnn'``).

    Returns:
        A sorted list of exactly three matching ``Path`` objects.

    Raises:
        FileNotFoundError: If the model-specific directory does not exist.
        ValueError: If the number of matches is not exactly three, or if the
            first match (in sorted order) is not the ``1071`` example.
    """
    prefix = 'ft_cnn_demo' if model == 'cnn' else 'ft_demo'
    pattern = f'{prefix}_{selected_option}_*_pred.png'

    model_specific_dir = pred_dir / str(selected_option) / model

    # Explicit raises rather than `assert`, so validation survives `python -O`.
    if not model_specific_dir.is_dir():
        raise FileNotFoundError(f'{model_specific_dir} does not exist')

    matches = sorted(model_specific_dir.glob(pattern))

    if len(matches) != 3:
        raise ValueError(
            f"Should be 3 prediction images matching {pattern!r} in "
            f"{model_specific_dir}, found {len(matches)}")

    # The first example must be 1071. Check only the file name so the
    # directory components cannot satisfy the check.
    if '1071' not in matches[0].name:
        raise ValueError(f"Should be 1071, got {matches[0].name}")

    return matches
|
|
|
|
|
|
|
|
|
|
|
def make_grid(cols, rows, *, gap='large'):
    """Create a grid of Streamlit columns and return it as nested lists.

    NOTE: despite the parameter names, *cols* is the number of grid ROWS
    (one ``st.container`` each) and *rows* is the number of COLUMNS created
    inside each of them, so ``grid[r][c]`` addresses row ``r``, column
    ``c``. The names are kept so existing keyword callers keep working.

    Args:
        cols: Number of grid rows to create.
        rows: Number of columns per row, passed to ``st.columns``.
        gap: Column gap forwarded to ``st.columns`` (default ``'large'``).

    Returns:
        A list of length *cols*. Each item is the list of column objects
        returned by ``st.columns`` for that row.
    """
    grid = []
    for _ in range(cols):
        # Each row's columns are created inside their own container.
        with st.container():
            grid.append(st.columns(rows, gap=gap))
    return grid
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the page only when this file is executed as a script,
    # not when it is imported as a module.
    main()