Caleb Spradlin
commited on
Commit
·
18903a3
1
Parent(s):
4d01101
changed images
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +82 -56
- images/images/ft_demo_1000_1071_img.png +0 -0
- images/images/ft_demo_1000_1076_img.png +0 -0
- images/images/ft_demo_1000_1541_img.png +0 -0
- images/images/ft_demo_100_1071_img.png +0 -0
- images/images/ft_demo_100_1076_img.png +0 -0
- images/images/ft_demo_100_1541_img.png +0 -0
- images/images/ft_demo_10_1071_img.png +0 -0
- images/images/ft_demo_10_1076_img.png +0 -0
- images/images/ft_demo_10_1541_img.png +0 -0
- images/images/ft_demo_5000_1071_img.png +0 -0
- images/images/ft_demo_5000_1076_img.png +0 -0
- images/images/ft_demo_5000_1541_img.png +0 -0
- images/images/ft_demo_500_1071_img.png +0 -0
- images/images/ft_demo_500_1076_img.png +0 -0
- images/images/ft_demo_500_1541_img.png +0 -0
- images/labels/ft_demo_1000_1071_label.png +0 -0
- images/labels/ft_demo_1000_1076_label.png +0 -0
- images/labels/ft_demo_1000_1541_label.png +0 -0
- images/labels/ft_demo_100_1071_label.png +0 -0
- images/labels/ft_demo_100_1076_label.png +0 -0
- images/labels/ft_demo_100_1541_label.png +0 -0
- images/labels/ft_demo_10_1071_label.png +0 -0
- images/labels/ft_demo_10_1076_label.png +0 -0
- images/labels/ft_demo_10_1541_label.png +0 -0
- images/labels/ft_demo_5000_1071_label.png +0 -0
- images/labels/ft_demo_5000_1076_label.png +0 -0
- images/labels/ft_demo_5000_1541_label.png +0 -0
- images/labels/ft_demo_500_1071_label.png +0 -0
- images/labels/ft_demo_500_1076_label.png +0 -0
- images/labels/ft_demo_500_1541_label.png +0 -0
- images/predictions/10/cnn/ft_cnn_demo_10_1071_pred.png +0 -0
- images/predictions/10/cnn/ft_cnn_demo_10_1076_pred.png +0 -0
- images/predictions/10/cnn/ft_cnn_demo_10_1541_pred.png +0 -0
- images/predictions/10/svb/ft_demo_10_1071_pred.png +0 -0
- images/predictions/10/svb/ft_demo_10_1076_pred.png +0 -0
- images/predictions/10/svb/ft_demo_10_1541_pred.png +0 -0
- images/predictions/100/cnn/ft_cnn_demo_100_1071_pred.png +0 -0
- images/predictions/100/cnn/ft_cnn_demo_100_1076_pred.png +0 -0
- images/predictions/100/cnn/ft_cnn_demo_100_1541_pred.png +0 -0
- images/predictions/100/svb/ft_demo_100_1071_pred.png +0 -0
- images/predictions/100/svb/ft_demo_100_1076_pred.png +0 -0
- images/predictions/100/svb/ft_demo_100_1541_pred.png +0 -0
- images/predictions/1000/cnn/ft_cnn_demo_1000_1071_pred.png +0 -0
- images/predictions/1000/cnn/ft_cnn_demo_1000_1076_pred.png +0 -0
- images/predictions/1000/cnn/ft_cnn_demo_1000_1541_pred.png +0 -0
- images/predictions/1000/svb/ft_demo_1000_1071_pred.png +0 -0
- images/predictions/1000/svb/ft_demo_1000_1076_pred.png +0 -0
- images/predictions/1000/svb/ft_demo_1000_1541_pred.png +0 -0
- images/predictions/500/cnn/ft_cnn_demo_500_1071_pred.png +0 -0
app.py
CHANGED
@@ -7,41 +7,64 @@ from pathlib import Path
|
|
7 |
# -----------------------------------------------------------------------------
|
8 |
def main():
|
9 |
st.title("SatVision Few-Shot Comparison")
|
10 |
-
|
11 |
-
selected_option = st.select_slider(
|
12 |
-
"## Number of training samples",
|
13 |
-
options=[10, 100, 500, 1000, 5000])
|
14 |
|
15 |
-
st.
|
16 |
-
'samples the models were trained on')
|
17 |
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
|
21 |
|
22 |
-
|
23 |
|
24 |
-
|
25 |
|
26 |
-
|
|
|
|
|
|
|
|
|
27 |
|
28 |
-
|
29 |
-
|
|
|
|
|
30 |
|
31 |
-
|
|
|
|
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
37 |
|
38 |
grid[i][0].image(image_data[0], image_data[1], use_column_width=True)
|
39 |
grid[i][1].image(svb_data[0], svb_data[1], use_column_width=True)
|
40 |
grid[i][2].image(unet_data[0], unet_data[1], use_column_width=True)
|
41 |
grid[i][3].image(label_data[0], label_data[1], use_column_width=True)
|
42 |
|
43 |
-
st.
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
# -----------------------------------------------------------------------------
|
47 |
# load_images
|
@@ -53,23 +76,25 @@ def load_images(selected_option: str, image_dir: Path):
|
|
53 |
|
54 |
image_paths = find_images(selected_option, image_dir)
|
55 |
|
56 |
-
images = [
|
57 |
-
|
|
|
|
|
58 |
|
59 |
return images
|
60 |
|
|
|
61 |
# -----------------------------------------------------------------------------
|
62 |
# find_images
|
63 |
# -----------------------------------------------------------------------------
|
64 |
def find_images(selected_option: str, image_dir: Path):
|
65 |
-
|
66 |
-
images_regex = f'ft_demo_{selected_option}_*_img.png'
|
67 |
|
68 |
images_matching_regex = sorted(image_dir.glob(images_regex))
|
69 |
|
70 |
assert len(images_matching_regex) == 3, "Should be 3 images matching regex"
|
71 |
|
72 |
-
assert
|
73 |
|
74 |
return images_matching_regex
|
75 |
|
@@ -80,8 +105,10 @@ def find_images(selected_option: str, image_dir: Path):
|
|
80 |
def load_labels(selected_option, label_dir: Path):
|
81 |
label_paths = find_labels(selected_option, label_dir)
|
82 |
|
83 |
-
labels = [
|
84 |
-
|
|
|
|
|
85 |
|
86 |
return labels
|
87 |
|
@@ -90,15 +117,13 @@ def load_labels(selected_option, label_dir: Path):
|
|
90 |
# find_labels
|
91 |
# -----------------------------------------------------------------------------
|
92 |
def find_labels(selected_option: str, label_dir: Path):
|
93 |
-
|
94 |
-
labels_regex = f'ft_demo_{selected_option}_*_label.png'
|
95 |
|
96 |
labels_matching_regex = sorted(label_dir.glob(labels_regex))
|
97 |
|
98 |
-
assert len(labels_matching_regex) == 3,
|
99 |
-
"Should be 3 label images matching regex"
|
100 |
|
101 |
-
assert
|
102 |
|
103 |
return labels_matching_regex
|
104 |
|
@@ -107,17 +132,21 @@ def find_labels(selected_option: str, label_dir: Path):
|
|
107 |
# load_predictions
|
108 |
# -----------------------------------------------------------------------------
|
109 |
def load_predictions(selected_option: str, pred_dir: Path):
|
110 |
-
svb_pred_paths = find_preds(selected_option, pred_dir,
|
|
|
|
|
111 |
|
112 |
-
|
|
|
|
|
|
|
113 |
|
114 |
-
|
115 |
-
|
|
|
|
|
116 |
|
117 |
-
|
118 |
-
i, path in enumerate(unet_pred_paths, 1)]
|
119 |
-
|
120 |
-
prediction_dict = {'svb': svb_preds, 'unet': unet_preds}
|
121 |
|
122 |
return prediction_dict
|
123 |
|
@@ -126,24 +155,23 @@ def load_predictions(selected_option: str, pred_dir: Path):
|
|
126 |
# find_preds
|
127 |
# -----------------------------------------------------------------------------
|
128 |
def find_preds(selected_option: int, pred_dir: Path, model: str):
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
pred_regex = f'ft_cnn_demo_{selected_option}_*_pred.png'
|
133 |
|
134 |
else:
|
135 |
-
pred_regex = f
|
136 |
|
137 |
model_specific_dir = pred_dir / str(selected_option) / model
|
138 |
|
139 |
-
assert model_specific_dir.exists(), f
|
140 |
|
141 |
preds_matching_regex = sorted(model_specific_dir.glob(pred_regex))
|
142 |
|
143 |
-
assert
|
144 |
-
|
|
|
145 |
|
146 |
-
assert
|
147 |
|
148 |
return preds_matching_regex
|
149 |
|
@@ -151,20 +179,18 @@ def find_preds(selected_option: int, pred_dir: Path, model: str):
|
|
151 |
# -----------------------------------------------------------------------------
|
152 |
# make_grid
|
153 |
# -----------------------------------------------------------------------------
|
154 |
-
def make_grid(cols,rows):
|
155 |
-
|
156 |
-
grid = [0]*cols
|
157 |
|
158 |
for i in range(cols):
|
159 |
-
|
160 |
with st.container():
|
161 |
-
|
162 |
-
grid[i] = st.columns(rows, gap='large')
|
163 |
|
164 |
return grid
|
165 |
|
|
|
166 |
# -----------------------------------------------------------------------------
|
167 |
# Main execution
|
168 |
# -----------------------------------------------------------------------------
|
169 |
if __name__ == "__main__":
|
170 |
-
main()
|
|
|
7 |
# -----------------------------------------------------------------------------
|
8 |
def main():
|
9 |
st.title("SatVision Few-Shot Comparison")
|
|
|
|
|
|
|
|
|
10 |
|
11 |
+
st.write("")
|
|
|
12 |
|
13 |
+
selected_option = st.selectbox(
|
14 |
+
"Number of training samples", [10, 100, 500, 1000, 5000]
|
15 |
+
)
|
16 |
+
st.markdown(
|
17 |
+
"Move slider to select how many training "
|
18 |
+
+ "samples the models were trained on"
|
19 |
+
)
|
20 |
|
21 |
+
images = load_images(selected_option, Path("./images/images"))
|
22 |
|
23 |
+
labels = load_labels(selected_option, Path("./images/labels"))
|
24 |
|
25 |
+
preds = load_predictions(selected_option, Path("./images/predictions"))
|
26 |
|
27 |
+
zipped_st_images = zip(images, preds["svb"], preds["unet"], labels)
|
28 |
+
|
29 |
+
st.write("")
|
30 |
+
|
31 |
+
titleCol0, titleCol1, titleCol2, titleCol3 = st.columns(4)
|
32 |
|
33 |
+
titleCol0.markdown(f"### MOD09GA [3-2-1] Image Chip")
|
34 |
+
titleCol1.markdown(f"### SatVision-B Prediction")
|
35 |
+
titleCol2.markdown(f"### UNet (CNN) Prediction")
|
36 |
+
titleCol3.markdown(f"### MCD12Q1 LandCover Target")
|
37 |
|
38 |
+
st.write("")
|
39 |
+
|
40 |
+
grid = make_grid(4, 4)
|
41 |
|
42 |
+
for i, (image_data, svb_data, unet_data, label_data) in enumerate(zipped_st_images):
|
43 |
+
# if i == 0:
|
44 |
+
|
45 |
+
# grid[0][0].markdown(f'## MOD09GA 3-2-1 Image Chip')
|
46 |
+
# grid[0][1].markdown(f'## SatVision-B Prediction')
|
47 |
+
# grid[0][2].markdown(f'## UNet (CNN) Prediction')
|
48 |
+
# grid[0][3].markdown(f'## MCD12Q1 LandCover Target')
|
49 |
|
50 |
grid[i][0].image(image_data[0], image_data[1], use_column_width=True)
|
51 |
grid[i][1].image(svb_data[0], svb_data[1], use_column_width=True)
|
52 |
grid[i][2].image(unet_data[0], unet_data[1], use_column_width=True)
|
53 |
grid[i][3].image(label_data[0], label_data[1], use_column_width=True)
|
54 |
|
55 |
+
st.markdown("### Few-Shot Learning with SatVision-Base")
|
56 |
+
description = (
|
57 |
+
"Pre-trained vision transformers (we use SwinV2) offers a "
|
58 |
+
+ "good advantage when looking to apply a model to a task with very little"
|
59 |
+
+ " labeled training data. We pre-trained SatVision-Base on 26 million "
|
60 |
+
+ " MODIS Surface Reflectance image patches. This allows the "
|
61 |
+
+ " SatVision-Base models to learn relevant features and representations"
|
62 |
+
+ " from a diverse range of scenes. This knowledge can be transferred to a"
|
63 |
+
+ " few-shot learning task, enabling the model to leverage its"
|
64 |
+
+ " understanding of spatial patterns, textures, and contextual information"
|
65 |
+
)
|
66 |
+
st.markdown(description)
|
67 |
+
|
68 |
|
69 |
# -----------------------------------------------------------------------------
|
70 |
# load_images
|
|
|
76 |
|
77 |
image_paths = find_images(selected_option, image_dir)
|
78 |
|
79 |
+
images = [
|
80 |
+
(str(path), f"MOD09GA 3-2-1 H18v04 2019 Example {i}")
|
81 |
+
for i, path in enumerate(image_paths, 1)
|
82 |
+
]
|
83 |
|
84 |
return images
|
85 |
|
86 |
+
|
87 |
# -----------------------------------------------------------------------------
|
88 |
# find_images
|
89 |
# -----------------------------------------------------------------------------
|
90 |
def find_images(selected_option: str, image_dir: Path):
|
91 |
+
images_regex = f"ft_demo_{selected_option}_*_img.png"
|
|
|
92 |
|
93 |
images_matching_regex = sorted(image_dir.glob(images_regex))
|
94 |
|
95 |
assert len(images_matching_regex) == 3, "Should be 3 images matching regex"
|
96 |
|
97 |
+
assert "1071" in str(images_matching_regex[0]), "Should be 1071"
|
98 |
|
99 |
return images_matching_regex
|
100 |
|
|
|
105 |
def load_labels(selected_option, label_dir: Path):
|
106 |
label_paths = find_labels(selected_option, label_dir)
|
107 |
|
108 |
+
labels = [
|
109 |
+
(str(path), f"MCD12Q1 LandCover Target Example {i}")
|
110 |
+
for i, path in enumerate(label_paths, 1)
|
111 |
+
]
|
112 |
|
113 |
return labels
|
114 |
|
|
|
117 |
# find_labels
|
118 |
# -----------------------------------------------------------------------------
|
119 |
def find_labels(selected_option: str, label_dir: Path):
|
120 |
+
labels_regex = f"ft_demo_{selected_option}_*_label.png"
|
|
|
121 |
|
122 |
labels_matching_regex = sorted(label_dir.glob(labels_regex))
|
123 |
|
124 |
+
assert len(labels_matching_regex) == 3, "Should be 3 label images matching regex"
|
|
|
125 |
|
126 |
+
assert "1071" in str(labels_matching_regex[0]), "Should be 1071"
|
127 |
|
128 |
return labels_matching_regex
|
129 |
|
|
|
132 |
# load_predictions
|
133 |
# -----------------------------------------------------------------------------
|
134 |
def load_predictions(selected_option: str, pred_dir: Path):
|
135 |
+
svb_pred_paths = find_preds(selected_option, pred_dir, "svb")
|
136 |
+
|
137 |
+
unet_pred_paths = find_preds(selected_option, pred_dir, "cnn")
|
138 |
|
139 |
+
svb_preds = [
|
140 |
+
(str(path), f"SatVision-B Prediction Example {i}")
|
141 |
+
for i, path in enumerate(svb_pred_paths, 1)
|
142 |
+
]
|
143 |
|
144 |
+
unet_preds = [
|
145 |
+
(str(path), f"Unet Prediction Example {i}")
|
146 |
+
for i, path in enumerate(unet_pred_paths, 1)
|
147 |
+
]
|
148 |
|
149 |
+
prediction_dict = {"svb": svb_preds, "unet": unet_preds}
|
|
|
|
|
|
|
150 |
|
151 |
return prediction_dict
|
152 |
|
|
|
155 |
# find_preds
|
156 |
# -----------------------------------------------------------------------------
|
157 |
def find_preds(selected_option: int, pred_dir: Path, model: str):
|
158 |
+
if model == "cnn":
|
159 |
+
pred_regex = f"ft_cnn_demo_{selected_option}_*_pred.png"
|
|
|
|
|
160 |
|
161 |
else:
|
162 |
+
pred_regex = f"ft_demo_{selected_option}_*_pred.png"
|
163 |
|
164 |
model_specific_dir = pred_dir / str(selected_option) / model
|
165 |
|
166 |
+
assert model_specific_dir.exists(), f"{model_specific_dir} does not exist"
|
167 |
|
168 |
preds_matching_regex = sorted(model_specific_dir.glob(pred_regex))
|
169 |
|
170 |
+
assert (
|
171 |
+
len(preds_matching_regex) == 3
|
172 |
+
), "Should be 3 prediction images matching regex"
|
173 |
|
174 |
+
assert "1071" in str(preds_matching_regex[0]), "Should be 1071"
|
175 |
|
176 |
return preds_matching_regex
|
177 |
|
|
|
179 |
# -----------------------------------------------------------------------------
|
180 |
# make_grid
|
181 |
# -----------------------------------------------------------------------------
|
182 |
+
def make_grid(cols, rows):
|
183 |
+
grid = [0] * cols
|
|
|
184 |
|
185 |
for i in range(cols):
|
|
|
186 |
with st.container():
|
187 |
+
grid[i] = st.columns(rows, gap="large")
|
|
|
188 |
|
189 |
return grid
|
190 |
|
191 |
+
|
192 |
# -----------------------------------------------------------------------------
|
193 |
# Main execution
|
194 |
# -----------------------------------------------------------------------------
|
195 |
if __name__ == "__main__":
|
196 |
+
main()
|
images/images/ft_demo_1000_1071_img.png
CHANGED
images/images/ft_demo_1000_1076_img.png
CHANGED
images/images/ft_demo_1000_1541_img.png
CHANGED
images/images/ft_demo_100_1071_img.png
CHANGED
images/images/ft_demo_100_1076_img.png
CHANGED
images/images/ft_demo_100_1541_img.png
CHANGED
images/images/ft_demo_10_1071_img.png
CHANGED
images/images/ft_demo_10_1076_img.png
CHANGED
images/images/ft_demo_10_1541_img.png
CHANGED
images/images/ft_demo_5000_1071_img.png
CHANGED
images/images/ft_demo_5000_1076_img.png
CHANGED
images/images/ft_demo_5000_1541_img.png
CHANGED
images/images/ft_demo_500_1071_img.png
CHANGED
images/images/ft_demo_500_1076_img.png
CHANGED
images/images/ft_demo_500_1541_img.png
CHANGED
images/labels/ft_demo_1000_1071_label.png
CHANGED
images/labels/ft_demo_1000_1076_label.png
CHANGED
images/labels/ft_demo_1000_1541_label.png
CHANGED
images/labels/ft_demo_100_1071_label.png
CHANGED
images/labels/ft_demo_100_1076_label.png
CHANGED
images/labels/ft_demo_100_1541_label.png
CHANGED
images/labels/ft_demo_10_1071_label.png
CHANGED
images/labels/ft_demo_10_1076_label.png
CHANGED
images/labels/ft_demo_10_1541_label.png
CHANGED
images/labels/ft_demo_5000_1071_label.png
CHANGED
images/labels/ft_demo_5000_1076_label.png
CHANGED
images/labels/ft_demo_5000_1541_label.png
CHANGED
images/labels/ft_demo_500_1071_label.png
CHANGED
images/labels/ft_demo_500_1076_label.png
CHANGED
images/labels/ft_demo_500_1541_label.png
CHANGED
images/predictions/10/cnn/ft_cnn_demo_10_1071_pred.png
CHANGED
images/predictions/10/cnn/ft_cnn_demo_10_1076_pred.png
CHANGED
images/predictions/10/cnn/ft_cnn_demo_10_1541_pred.png
CHANGED
images/predictions/10/svb/ft_demo_10_1071_pred.png
CHANGED
images/predictions/10/svb/ft_demo_10_1076_pred.png
CHANGED
images/predictions/10/svb/ft_demo_10_1541_pred.png
CHANGED
images/predictions/100/cnn/ft_cnn_demo_100_1071_pred.png
CHANGED
images/predictions/100/cnn/ft_cnn_demo_100_1076_pred.png
CHANGED
images/predictions/100/cnn/ft_cnn_demo_100_1541_pred.png
CHANGED
images/predictions/100/svb/ft_demo_100_1071_pred.png
CHANGED
images/predictions/100/svb/ft_demo_100_1076_pred.png
CHANGED
images/predictions/100/svb/ft_demo_100_1541_pred.png
CHANGED
images/predictions/1000/cnn/ft_cnn_demo_1000_1071_pred.png
CHANGED
images/predictions/1000/cnn/ft_cnn_demo_1000_1076_pred.png
CHANGED
images/predictions/1000/cnn/ft_cnn_demo_1000_1541_pred.png
CHANGED
images/predictions/1000/svb/ft_demo_1000_1071_pred.png
CHANGED
images/predictions/1000/svb/ft_demo_1000_1076_pred.png
CHANGED
images/predictions/1000/svb/ft_demo_1000_1541_pred.png
CHANGED
images/predictions/500/cnn/ft_cnn_demo_500_1071_pred.png
CHANGED