Caleb Spradlin commited on
Commit
18903a3
·
1 Parent(s): 4d01101

changed images

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +82 -56
  2. images/images/ft_demo_1000_1071_img.png +0 -0
  3. images/images/ft_demo_1000_1076_img.png +0 -0
  4. images/images/ft_demo_1000_1541_img.png +0 -0
  5. images/images/ft_demo_100_1071_img.png +0 -0
  6. images/images/ft_demo_100_1076_img.png +0 -0
  7. images/images/ft_demo_100_1541_img.png +0 -0
  8. images/images/ft_demo_10_1071_img.png +0 -0
  9. images/images/ft_demo_10_1076_img.png +0 -0
  10. images/images/ft_demo_10_1541_img.png +0 -0
  11. images/images/ft_demo_5000_1071_img.png +0 -0
  12. images/images/ft_demo_5000_1076_img.png +0 -0
  13. images/images/ft_demo_5000_1541_img.png +0 -0
  14. images/images/ft_demo_500_1071_img.png +0 -0
  15. images/images/ft_demo_500_1076_img.png +0 -0
  16. images/images/ft_demo_500_1541_img.png +0 -0
  17. images/labels/ft_demo_1000_1071_label.png +0 -0
  18. images/labels/ft_demo_1000_1076_label.png +0 -0
  19. images/labels/ft_demo_1000_1541_label.png +0 -0
  20. images/labels/ft_demo_100_1071_label.png +0 -0
  21. images/labels/ft_demo_100_1076_label.png +0 -0
  22. images/labels/ft_demo_100_1541_label.png +0 -0
  23. images/labels/ft_demo_10_1071_label.png +0 -0
  24. images/labels/ft_demo_10_1076_label.png +0 -0
  25. images/labels/ft_demo_10_1541_label.png +0 -0
  26. images/labels/ft_demo_5000_1071_label.png +0 -0
  27. images/labels/ft_demo_5000_1076_label.png +0 -0
  28. images/labels/ft_demo_5000_1541_label.png +0 -0
  29. images/labels/ft_demo_500_1071_label.png +0 -0
  30. images/labels/ft_demo_500_1076_label.png +0 -0
  31. images/labels/ft_demo_500_1541_label.png +0 -0
  32. images/predictions/10/cnn/ft_cnn_demo_10_1071_pred.png +0 -0
  33. images/predictions/10/cnn/ft_cnn_demo_10_1076_pred.png +0 -0
  34. images/predictions/10/cnn/ft_cnn_demo_10_1541_pred.png +0 -0
  35. images/predictions/10/svb/ft_demo_10_1071_pred.png +0 -0
  36. images/predictions/10/svb/ft_demo_10_1076_pred.png +0 -0
  37. images/predictions/10/svb/ft_demo_10_1541_pred.png +0 -0
  38. images/predictions/100/cnn/ft_cnn_demo_100_1071_pred.png +0 -0
  39. images/predictions/100/cnn/ft_cnn_demo_100_1076_pred.png +0 -0
  40. images/predictions/100/cnn/ft_cnn_demo_100_1541_pred.png +0 -0
  41. images/predictions/100/svb/ft_demo_100_1071_pred.png +0 -0
  42. images/predictions/100/svb/ft_demo_100_1076_pred.png +0 -0
  43. images/predictions/100/svb/ft_demo_100_1541_pred.png +0 -0
  44. images/predictions/1000/cnn/ft_cnn_demo_1000_1071_pred.png +0 -0
  45. images/predictions/1000/cnn/ft_cnn_demo_1000_1076_pred.png +0 -0
  46. images/predictions/1000/cnn/ft_cnn_demo_1000_1541_pred.png +0 -0
  47. images/predictions/1000/svb/ft_demo_1000_1071_pred.png +0 -0
  48. images/predictions/1000/svb/ft_demo_1000_1076_pred.png +0 -0
  49. images/predictions/1000/svb/ft_demo_1000_1541_pred.png +0 -0
  50. images/predictions/500/cnn/ft_cnn_demo_500_1071_pred.png +0 -0
app.py CHANGED
@@ -7,41 +7,64 @@ from pathlib import Path
7
  # -----------------------------------------------------------------------------
8
  def main():
9
  st.title("SatVision Few-Shot Comparison")
10
-
11
- selected_option = st.select_slider(
12
- "## Number of training samples",
13
- options=[10, 100, 500, 1000, 5000])
14
 
15
- st.markdown('Move slider to select how many training ' + \
16
- 'samples the models were trained on')
17
 
18
- images = load_images(selected_option, Path('./images/images'))
 
 
 
 
 
 
19
 
20
- labels = load_labels(selected_option, Path('./images/labels'))
21
 
22
- preds = load_predictions(selected_option, Path('./images/predictions'))
23
 
24
- zipped_st_images = zip(images, preds['svb'], preds['unet'], labels)
25
 
26
- grid = make_grid(4, 4)
 
 
 
 
27
 
28
- for i, (image_data, svb_data, unet_data, label_data) in \
29
- enumerate(zipped_st_images):
 
 
30
 
31
- if i == 0:
 
 
32
 
33
- grid[0][0].markdown(f'## MOD09GA 3-2-1 Image Chip')
34
- grid[0][1].markdown(f'## SatVision-B Prediction')
35
- grid[0][2].markdown(f'## UNet (CNN) Prediction')
36
- grid[0][3].markdown(f'## MCD12Q1 LandCover Target')
 
 
 
37
 
38
  grid[i][0].image(image_data[0], image_data[1], use_column_width=True)
39
  grid[i][1].image(svb_data[0], svb_data[1], use_column_width=True)
40
  grid[i][2].image(unet_data[0], unet_data[1], use_column_width=True)
41
  grid[i][3].image(label_data[0], label_data[1], use_column_width=True)
42
 
43
- st.text("Additional Information:")
44
- st.text("This is a placeholder for additional information about the images.")
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  # -----------------------------------------------------------------------------
47
  # load_images
@@ -53,23 +76,25 @@ def load_images(selected_option: str, image_dir: Path):
53
 
54
  image_paths = find_images(selected_option, image_dir)
55
 
56
- images = [(str(path), f"MOD09GA 3-2-1 H18v04 2019 Example {i}") for \
57
- i, path in enumerate(image_paths, 1)]
 
 
58
 
59
  return images
60
 
 
61
  # -----------------------------------------------------------------------------
62
  # find_images
63
  # -----------------------------------------------------------------------------
64
  def find_images(selected_option: str, image_dir: Path):
65
-
66
- images_regex = f'ft_demo_{selected_option}_*_img.png'
67
 
68
  images_matching_regex = sorted(image_dir.glob(images_regex))
69
 
70
  assert len(images_matching_regex) == 3, "Should be 3 images matching regex"
71
 
72
- assert '1071' in str(images_matching_regex[0]), 'Should be 1071'
73
 
74
  return images_matching_regex
75
 
@@ -80,8 +105,10 @@ def find_images(selected_option: str, image_dir: Path):
80
  def load_labels(selected_option, label_dir: Path):
81
  label_paths = find_labels(selected_option, label_dir)
82
 
83
- labels = [(str(path), f"MCD12Q1 LandCover Target Example {i}") for \
84
- i, path in enumerate(label_paths, 1)]
 
 
85
 
86
  return labels
87
 
@@ -90,15 +117,13 @@ def load_labels(selected_option, label_dir: Path):
90
  # find_labels
91
  # -----------------------------------------------------------------------------
92
  def find_labels(selected_option: str, label_dir: Path):
93
-
94
- labels_regex = f'ft_demo_{selected_option}_*_label.png'
95
 
96
  labels_matching_regex = sorted(label_dir.glob(labels_regex))
97
 
98
- assert len(labels_matching_regex) == 3, \
99
- "Should be 3 label images matching regex"
100
 
101
- assert '1071' in str(labels_matching_regex[0]), 'Should be 1071'
102
 
103
  return labels_matching_regex
104
 
@@ -107,17 +132,21 @@ def find_labels(selected_option: str, label_dir: Path):
107
  # load_predictions
108
  # -----------------------------------------------------------------------------
109
  def load_predictions(selected_option: str, pred_dir: Path):
110
- svb_pred_paths = find_preds(selected_option, pred_dir, 'svb')
 
 
111
 
112
- unet_pred_paths = find_preds(selected_option, pred_dir, 'cnn')
 
 
 
113
 
114
- svb_preds = [(str(path), f"SatVision-B Prediction Example {i}") for \
115
- i, path in enumerate(svb_pred_paths, 1)]
 
 
116
 
117
- unet_preds = [(str(path), f"Unet Prediction Example {i}") for \
118
- i, path in enumerate(unet_pred_paths, 1)]
119
-
120
- prediction_dict = {'svb': svb_preds, 'unet': unet_preds}
121
 
122
  return prediction_dict
123
 
@@ -126,24 +155,23 @@ def load_predictions(selected_option: str, pred_dir: Path):
126
  # find_preds
127
  # -----------------------------------------------------------------------------
128
  def find_preds(selected_option: int, pred_dir: Path, model: str):
129
-
130
- if model == 'cnn':
131
-
132
- pred_regex = f'ft_cnn_demo_{selected_option}_*_pred.png'
133
 
134
  else:
135
- pred_regex = f'ft_demo_{selected_option}_*_pred.png'
136
 
137
  model_specific_dir = pred_dir / str(selected_option) / model
138
 
139
- assert model_specific_dir.exists(), f'{model_specific_dir} does not exist'
140
 
141
  preds_matching_regex = sorted(model_specific_dir.glob(pred_regex))
142
 
143
- assert len(preds_matching_regex) == 3, \
144
- "Should be 3 prediction images matching regex"
 
145
 
146
- assert '1071' in str(preds_matching_regex[0]), 'Should be 1071'
147
 
148
  return preds_matching_regex
149
 
@@ -151,20 +179,18 @@ def find_preds(selected_option: int, pred_dir: Path, model: str):
151
  # -----------------------------------------------------------------------------
152
  # make_grid
153
  # -----------------------------------------------------------------------------
154
- def make_grid(cols,rows):
155
-
156
- grid = [0]*cols
157
 
158
  for i in range(cols):
159
-
160
  with st.container():
161
-
162
- grid[i] = st.columns(rows, gap='large')
163
 
164
  return grid
165
 
 
166
  # -----------------------------------------------------------------------------
167
  # Main execution
168
  # -----------------------------------------------------------------------------
169
  if __name__ == "__main__":
170
- main()
 
7
  # -----------------------------------------------------------------------------
8
  def main():
9
  st.title("SatVision Few-Shot Comparison")
 
 
 
 
10
 
11
+ st.write("")
 
12
 
13
+ selected_option = st.selectbox(
14
+ "Number of training samples", [10, 100, 500, 1000, 5000]
15
+ )
16
+ st.markdown(
17
+ "Move slider to select how many training "
18
+ + "samples the models were trained on"
19
+ )
20
 
21
+ images = load_images(selected_option, Path("./images/images"))
22
 
23
+ labels = load_labels(selected_option, Path("./images/labels"))
24
 
25
+ preds = load_predictions(selected_option, Path("./images/predictions"))
26
 
27
+ zipped_st_images = zip(images, preds["svb"], preds["unet"], labels)
28
+
29
+ st.write("")
30
+
31
+ titleCol0, titleCol1, titleCol2, titleCol3 = st.columns(4)
32
 
33
+ titleCol0.markdown(f"### MOD09GA [3-2-1] Image Chip")
34
+ titleCol1.markdown(f"### SatVision-B Prediction")
35
+ titleCol2.markdown(f"### UNet (CNN) Prediction")
36
+ titleCol3.markdown(f"### MCD12Q1 LandCover Target")
37
 
38
+ st.write("")
39
+
40
+ grid = make_grid(4, 4)
41
 
42
+ for i, (image_data, svb_data, unet_data, label_data) in enumerate(zipped_st_images):
43
+ # if i == 0:
44
+
45
+ # grid[0][0].markdown(f'## MOD09GA 3-2-1 Image Chip')
46
+ # grid[0][1].markdown(f'## SatVision-B Prediction')
47
+ # grid[0][2].markdown(f'## UNet (CNN) Prediction')
48
+ # grid[0][3].markdown(f'## MCD12Q1 LandCover Target')
49
 
50
  grid[i][0].image(image_data[0], image_data[1], use_column_width=True)
51
  grid[i][1].image(svb_data[0], svb_data[1], use_column_width=True)
52
  grid[i][2].image(unet_data[0], unet_data[1], use_column_width=True)
53
  grid[i][3].image(label_data[0], label_data[1], use_column_width=True)
54
 
55
+ st.markdown("### Few-Shot Learning with SatVision-Base")
56
+ description = (
57
+ "Pre-trained vision transformers (we use SwinV2) offers a "
58
+ + "good advantage when looking to apply a model to a task with very little"
59
+ + " labeled training data. We pre-trained SatVision-Base on 26 million "
60
+ + " MODIS Surface Reflectance image patches. This allows the "
61
+ + " SatVision-Base models to learn relevant features and representations"
62
+ + " from a diverse range of scenes. This knowledge can be transferred to a"
63
+ + " few-shot learning task, enabling the model to leverage its"
64
+ + " understanding of spatial patterns, textures, and contextual information"
65
+ )
66
+ st.markdown(description)
67
+
68
 
69
  # -----------------------------------------------------------------------------
70
  # load_images
 
76
 
77
  image_paths = find_images(selected_option, image_dir)
78
 
79
+ images = [
80
+ (str(path), f"MOD09GA 3-2-1 H18v04 2019 Example {i}")
81
+ for i, path in enumerate(image_paths, 1)
82
+ ]
83
 
84
  return images
85
 
86
+
87
  # -----------------------------------------------------------------------------
88
  # find_images
89
  # -----------------------------------------------------------------------------
90
  def find_images(selected_option: str, image_dir: Path):
91
+ images_regex = f"ft_demo_{selected_option}_*_img.png"
 
92
 
93
  images_matching_regex = sorted(image_dir.glob(images_regex))
94
 
95
  assert len(images_matching_regex) == 3, "Should be 3 images matching regex"
96
 
97
+ assert "1071" in str(images_matching_regex[0]), "Should be 1071"
98
 
99
  return images_matching_regex
100
 
 
105
  def load_labels(selected_option, label_dir: Path):
106
  label_paths = find_labels(selected_option, label_dir)
107
 
108
+ labels = [
109
+ (str(path), f"MCD12Q1 LandCover Target Example {i}")
110
+ for i, path in enumerate(label_paths, 1)
111
+ ]
112
 
113
  return labels
114
 
 
117
  # find_labels
118
  # -----------------------------------------------------------------------------
119
  def find_labels(selected_option: str, label_dir: Path):
120
+ labels_regex = f"ft_demo_{selected_option}_*_label.png"
 
121
 
122
  labels_matching_regex = sorted(label_dir.glob(labels_regex))
123
 
124
+ assert len(labels_matching_regex) == 3, "Should be 3 label images matching regex"
 
125
 
126
+ assert "1071" in str(labels_matching_regex[0]), "Should be 1071"
127
 
128
  return labels_matching_regex
129
 
 
132
  # load_predictions
133
  # -----------------------------------------------------------------------------
134
  def load_predictions(selected_option: str, pred_dir: Path):
135
+ svb_pred_paths = find_preds(selected_option, pred_dir, "svb")
136
+
137
+ unet_pred_paths = find_preds(selected_option, pred_dir, "cnn")
138
 
139
+ svb_preds = [
140
+ (str(path), f"SatVision-B Prediction Example {i}")
141
+ for i, path in enumerate(svb_pred_paths, 1)
142
+ ]
143
 
144
+ unet_preds = [
145
+ (str(path), f"Unet Prediction Example {i}")
146
+ for i, path in enumerate(unet_pred_paths, 1)
147
+ ]
148
 
149
+ prediction_dict = {"svb": svb_preds, "unet": unet_preds}
 
 
 
150
 
151
  return prediction_dict
152
 
 
155
  # find_preds
156
  # -----------------------------------------------------------------------------
157
  def find_preds(selected_option: int, pred_dir: Path, model: str):
158
+ if model == "cnn":
159
+ pred_regex = f"ft_cnn_demo_{selected_option}_*_pred.png"
 
 
160
 
161
  else:
162
+ pred_regex = f"ft_demo_{selected_option}_*_pred.png"
163
 
164
  model_specific_dir = pred_dir / str(selected_option) / model
165
 
166
+ assert model_specific_dir.exists(), f"{model_specific_dir} does not exist"
167
 
168
  preds_matching_regex = sorted(model_specific_dir.glob(pred_regex))
169
 
170
+ assert (
171
+ len(preds_matching_regex) == 3
172
+ ), "Should be 3 prediction images matching regex"
173
 
174
+ assert "1071" in str(preds_matching_regex[0]), "Should be 1071"
175
 
176
  return preds_matching_regex
177
 
 
179
  # -----------------------------------------------------------------------------
180
  # make_grid
181
  # -----------------------------------------------------------------------------
182
+ def make_grid(cols, rows):
183
+ grid = [0] * cols
 
184
 
185
  for i in range(cols):
 
186
  with st.container():
187
+ grid[i] = st.columns(rows, gap="large")
 
188
 
189
  return grid
190
 
191
+
192
  # -----------------------------------------------------------------------------
193
  # Main execution
194
  # -----------------------------------------------------------------------------
195
  if __name__ == "__main__":
196
+ main()
images/images/ft_demo_1000_1071_img.png CHANGED
images/images/ft_demo_1000_1076_img.png CHANGED
images/images/ft_demo_1000_1541_img.png CHANGED
images/images/ft_demo_100_1071_img.png CHANGED
images/images/ft_demo_100_1076_img.png CHANGED
images/images/ft_demo_100_1541_img.png CHANGED
images/images/ft_demo_10_1071_img.png CHANGED
images/images/ft_demo_10_1076_img.png CHANGED
images/images/ft_demo_10_1541_img.png CHANGED
images/images/ft_demo_5000_1071_img.png CHANGED
images/images/ft_demo_5000_1076_img.png CHANGED
images/images/ft_demo_5000_1541_img.png CHANGED
images/images/ft_demo_500_1071_img.png CHANGED
images/images/ft_demo_500_1076_img.png CHANGED
images/images/ft_demo_500_1541_img.png CHANGED
images/labels/ft_demo_1000_1071_label.png CHANGED
images/labels/ft_demo_1000_1076_label.png CHANGED
images/labels/ft_demo_1000_1541_label.png CHANGED
images/labels/ft_demo_100_1071_label.png CHANGED
images/labels/ft_demo_100_1076_label.png CHANGED
images/labels/ft_demo_100_1541_label.png CHANGED
images/labels/ft_demo_10_1071_label.png CHANGED
images/labels/ft_demo_10_1076_label.png CHANGED
images/labels/ft_demo_10_1541_label.png CHANGED
images/labels/ft_demo_5000_1071_label.png CHANGED
images/labels/ft_demo_5000_1076_label.png CHANGED
images/labels/ft_demo_5000_1541_label.png CHANGED
images/labels/ft_demo_500_1071_label.png CHANGED
images/labels/ft_demo_500_1076_label.png CHANGED
images/labels/ft_demo_500_1541_label.png CHANGED
images/predictions/10/cnn/ft_cnn_demo_10_1071_pred.png CHANGED
images/predictions/10/cnn/ft_cnn_demo_10_1076_pred.png CHANGED
images/predictions/10/cnn/ft_cnn_demo_10_1541_pred.png CHANGED
images/predictions/10/svb/ft_demo_10_1071_pred.png CHANGED
images/predictions/10/svb/ft_demo_10_1076_pred.png CHANGED
images/predictions/10/svb/ft_demo_10_1541_pred.png CHANGED
images/predictions/100/cnn/ft_cnn_demo_100_1071_pred.png CHANGED
images/predictions/100/cnn/ft_cnn_demo_100_1076_pred.png CHANGED
images/predictions/100/cnn/ft_cnn_demo_100_1541_pred.png CHANGED
images/predictions/100/svb/ft_demo_100_1071_pred.png CHANGED
images/predictions/100/svb/ft_demo_100_1076_pred.png CHANGED
images/predictions/100/svb/ft_demo_100_1541_pred.png CHANGED
images/predictions/1000/cnn/ft_cnn_demo_1000_1071_pred.png CHANGED
images/predictions/1000/cnn/ft_cnn_demo_1000_1076_pred.png CHANGED
images/predictions/1000/cnn/ft_cnn_demo_1000_1541_pred.png CHANGED
images/predictions/1000/svb/ft_demo_1000_1071_pred.png CHANGED
images/predictions/1000/svb/ft_demo_1000_1076_pred.png CHANGED
images/predictions/1000/svb/ft_demo_1000_1541_pred.png CHANGED
images/predictions/500/cnn/ft_cnn_demo_500_1071_pred.png CHANGED