import streamlit as st
from pathlib import Path


# -----------------------------------------------------------------------------
# main
# -----------------------------------------------------------------------------
def main():
    """Render the SatVision few-shot comparison page.

    The user picks a training-set size. For that size the page shows one row
    per example chip: the MOD09GA input image, the SatVision-B, UNet and
    LS-pretrained UNet predictions, and the MCD12Q1 land-cover target.
    """
    st.title("SatVision Few-Shot Comparison")
    st.write("")

    selected_option = st.selectbox(
        "Number of training samples", [10, 1000, 5000]
    )
    # The control above is a selectbox, so the hint must not mention a slider.
    st.markdown(
        "Use the dropdown to select how many training "
        "samples the models were trained on"
    )

    images = load_images(selected_option, Path("./images/images"))
    labels = load_labels(selected_option, Path("./images/labels"))
    preds = load_predictions(selected_option, Path("./images/predictions"))

    # One tuple per example: (image, svb, unet, unet-ls, label), each element
    # an (image path, caption) pair. Materialised so the grid can be sized
    # from the real number of examples instead of a hard-coded constant.
    rows = list(
        zip(images, preds["svb"], preds["unet"], preds["unet-ls"], labels)
    )
    column_titles = [
        "### MOD09GA [3-2-1] Image Chip",
        "### SatVision-B Prediction",
        "### UNet (CNN) Prediction",
        "### UNet (CNN) LS Pretrained Prediction",
        "### MCD12Q1 LandCover Target",
    ]

    st.write("")
    for title_col, title in zip(st.columns(len(column_titles)), column_titles):
        title_col.markdown(title)
    st.write("")

    # make_grid's first argument is the number of rows it creates and its
    # second the number of columns per row.
    grid = make_grid(len(rows), len(column_titles))
    for grid_row, row in zip(grid, rows):
        for cell, (path, caption) in zip(grid_row, row):
            cell.image(path, caption=caption, use_column_width=True)

    st.markdown("### Few-Shot Learning with SatVision-Base")
    description = (
        "Pre-trained vision transformers (we use SwinV2) offer a "
        "good advantage when looking to apply a model to a task with very "
        "little labeled training data. We pre-trained SatVision-Base on 26 "
        "million MODIS Surface Reflectance image patches. This allows the "
        "SatVision-Base models to learn relevant features and representations "
        "from a diverse range of scenes. This knowledge can be transferred to "
        "a few-shot learning task, enabling the model to leverage its "
        "understanding of spatial patterns, textures, and contextual "
        "information."
    )
    st.markdown(description)


# -----------------------------------------------------------------------------
# load_images
# -----------------------------------------------------------------------------
def load_images(selected_option: int, image_dir: Path):
    """Return ``(path, caption)`` pairs for the input image chips.

    Args:
        selected_option: Training-set size chosen in the UI; it is embedded
            in the file names being searched for.
        image_dir: Directory holding the ``ft_demo_*_img.png`` chips.

    Returns:
        A list of ``(str path, caption)`` tuples, numbered from 1.
    """
    image_paths = find_images(selected_option, image_dir)
    return [
        (str(path), f"MOD09GA 3-2-1 H18v04 2019 Example {i}")
        for i, path in enumerate(image_paths, 1)
    ]


# -----------------------------------------------------------------------------
# find_images
# -----------------------------------------------------------------------------
def find_images(selected_option: int, image_dir: Path):
    """Return the sorted image-chip paths for ``selected_option``.

    Raises:
        ValueError: if the glob does not yield exactly three files, or the
            first sorted file name does not contain "1071". The latter is a
            sanity check that the sort order lines up with the labels and
            predictions these paths are zipped with in ``main``.
    """
    pattern = f"ft_demo_{selected_option}_*_img.png"
    matches = sorted(image_dir.glob(pattern))
    # Explicit raises instead of assert: asserts vanish under ``python -O``.
    if len(matches) != 3:
        raise ValueError(
            f"Expected 3 images matching {pattern!r} in {image_dir}, "
            f"found {len(matches)}"
        )
    # Check the file name only; the directory path could contain "1071" too.
    if "1071" not in matches[0].name:
        raise ValueError(f"Expected first image {matches[0]} to be tile 1071")
    return matches


# -----------------------------------------------------------------------------
# load_labels
# -----------------------------------------------------------------------------
def load_labels(selected_option: int, label_dir: Path):
    """Return ``(path, caption)`` pairs for the land-cover target images.

    Args:
        selected_option: Training-set size chosen in the UI.
        label_dir: Directory holding the ``ft_demo_*_label.png`` images.

    Returns:
        A list of ``(str path, caption)`` tuples, numbered from 1.
    """
    label_paths = find_labels(selected_option, label_dir)
    return [
        (str(path), f"MCD12Q1 LandCover Target Example {i}")
        for i, path in enumerate(label_paths, 1)
    ]
# -----------------------------------------------------------------------------
# find_labels
# -----------------------------------------------------------------------------
def find_labels(selected_option: int, label_dir: Path):
    """Return the sorted label-image paths for ``selected_option``.

    Raises:
        ValueError: see ``_check_matches``.
    """
    pattern = f"ft_demo_{selected_option}_*_label.png"
    matches = sorted(label_dir.glob(pattern))
    _check_matches(matches, pattern, label_dir)
    return matches


# -----------------------------------------------------------------------------
# load_predictions
# -----------------------------------------------------------------------------
def load_predictions(selected_option: int, pred_dir: Path):
    """Return per-model ``(path, caption)`` lists of prediction images.

    Args:
        selected_option: Training-set size chosen in the UI.
        pred_dir: Root directory; model outputs live under
            ``pred_dir / <selected_option> / <model>``.

    Returns:
        A dict with keys ``"svb"``, ``"unet"`` and ``"unet-ls"``. Each value
        is a list of ``(str path, caption)`` tuples, numbered from 1.
    """
    # result key -> (model sub-directory understood by find_preds, caption)
    models = {
        "svb": ("svb", "SatVision-B Prediction"),
        "unet": ("cnn", "Unet Prediction"),
        "unet-ls": ("cnn-ls", "Unet LS Pre-trained Prediction"),
    }
    prediction_dict = {}
    for key, (model, caption) in models.items():
        paths = find_preds(selected_option, pred_dir, model)
        prediction_dict[key] = [
            (str(path), f"{caption} Example {i}")
            for i, path in enumerate(paths, 1)
        ]
    return prediction_dict


# -----------------------------------------------------------------------------
# find_preds
# -----------------------------------------------------------------------------
def find_preds(selected_option: int, pred_dir: Path, model: str):
    """Return the sorted prediction paths of ``model`` for ``selected_option``.

    ``"cnn"`` and ``"cnn-ls"`` use their own file-name schemes. Any other
    model name falls back to the SatVision-B (``ft_demo_*_pred.png``) scheme.

    Raises:
        FileNotFoundError: if ``pred_dir / selected_option / model`` is not a
            directory.
        ValueError: see ``_check_matches``.
    """
    patterns = {
        "cnn": f"ft_cnn_demo_{selected_option}_*cnn-plain_pred.png",
        "cnn-ls": f"ft_cnn_demo_{selected_option}_*cnn-ls_pred.png",
    }
    pattern = patterns.get(model, f"ft_demo_{selected_option}_*_pred.png")

    model_specific_dir = pred_dir / str(selected_option) / model
    if not model_specific_dir.is_dir():
        raise FileNotFoundError(f"{model_specific_dir} does not exist")

    matches = sorted(model_specific_dir.glob(pattern))
    _check_matches(matches, pattern, model_specific_dir)
    return matches


# -----------------------------------------------------------------------------
# _check_matches
# -----------------------------------------------------------------------------
def _check_matches(matches, pattern: str, directory: Path,
                   expected: int = 3, first_tile: str = "1071"):
    """Validate a sorted glob result used to build one grid column.

    Uses explicit raises rather than ``assert`` so the checks survive
    ``python -O``.

    Raises:
        ValueError: if ``matches`` does not hold ``expected`` paths, or the
            first path's *file name* does not contain ``first_tile``. That
            second check confirms the sort order lines up across images,
            labels and predictions. Only the name is checked, so a directory
            that happens to contain the tile id cannot satisfy it.
    """
    if len(matches) != expected:
        raise ValueError(
            f"Expected {expected} files matching {pattern!r} in {directory}, "
            f"found {len(matches)}"
        )
    if matches and first_tile not in matches[0].name:
        raise ValueError(
            f"Expected first match {matches[0]} to be tile {first_tile}"
        )


# -----------------------------------------------------------------------------
# make_grid
# -----------------------------------------------------------------------------
def make_grid(cols, rows):
    """Create a grid of Streamlit columns, one container per grid row.

    The parameter names are swapped relative to their effect. They are kept
    for backward compatibility: ``cols`` is the number of grid rows created,
    and ``rows`` is the number of columns in each row.

    Returns:
        A list of length ``cols``. Each entry is the list returned by
        ``st.columns(rows, gap="large")``, so ``grid[r][c]`` addresses the
        cell at row ``r``, column ``c``.
    """
    grid = []
    for _ in range(cols):
        with st.container():
            grid.append(st.columns(rows, gap="large"))
    return grid


# -----------------------------------------------------------------------------
# Main execution
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    main()