yvokeller commited on
Commit
37d5e3e
·
1 Parent(s): 72b9dc9

load geotifs from AWS S3

Browse files
.gitignore CHANGED
@@ -1,4 +1,6 @@
1
  hf_cache
2
  __pycache__
3
  .DS_Store
4
- data/
 
 
 
1
  hf_cache
2
  __pycache__
3
  .DS_Store
4
+ data/*
5
+ !data/dataset_info.json
6
+ !data/chips_stats.yaml
data/chips_stats.yaml ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fold_0:
2
+ mean:
3
+ - 570.7305297851562
4
+ - 691.6322021484375
5
+ - 436.3498229980469
6
+ - 3222.44775390625
7
+ - 1939.10009765625
8
+ - 1180.9752197265625
9
+ n_chips: 72
10
+ std:
11
+ - 466.30303955078125
12
+ - 355.8695373535156
13
+ - 305.64422607421875
14
+ - 1125.038330078125
15
+ - 681.1266479492188
16
+ - 632.9395751953125
17
+ fold_1:
18
+ mean:
19
+ - 578.840087890625
20
+ - 700.3932495117188
21
+ - 447.7603759765625
22
+ - 3214.770751953125
23
+ - 1935.7547607421875
24
+ - 1180.948974609375
25
+ n_chips: 72
26
+ std:
27
+ - 488.18951416015625
28
+ - 380.0794677734375
29
+ - 339.2040100097656
30
+ - 1149.1806640625
31
+ - 684.062255859375
32
+ - 638.1996459960938
33
+ fold_2:
34
+ mean:
35
+ - 576.6995239257812
36
+ - 696.1321411132812
37
+ - 441.8569030761719
38
+ - 3244.72998046875
39
+ - 1951.02734375
40
+ - 1187.2139892578125
41
+ n_chips: 72
42
+ std:
43
+ - 479.5519104003906
44
+ - 371.16168212890625
45
+ - 329.15521240234375
46
+ - 1062.248779296875
47
+ - 642.8277587890625
48
+ - 624.2898559570312
49
+ fold_3:
50
+ mean:
51
+ - 556.559814453125
52
+ - 678.816162109375
53
+ - 427.0155944824219
54
+ - 3229.951904296875
55
+ - 1929.81103515625
56
+ - 1163.3214111328125
57
+ n_chips: 71
58
+ std:
59
+ - 461.0555725097656
60
+ - 352.8609619140625
61
+ - 301.42449951171875
62
+ - 1098.100830078125
63
+ - 670.2280883789062
64
+ - 628.5005493164062
65
+ fold_4:
66
+ mean:
67
+ - 565.2382202148438
68
+ - 683.7120361328125
69
+ - 431.3344421386719
70
+ - 3241.815185546875
71
+ - 1936.545654296875
72
+ - 1176.8934326171875
73
+ n_chips: 71
74
+ std:
75
+ - 470.9222412109375
76
+ - 360.7673645019531
77
+ - 309.9328918457031
78
+ - 1090.716796875
79
+ - 671.9968872070312
80
+ - 639.7266845703125
81
+ fold_5:
82
+ mean:
83
+ - 568.527587890625
84
+ - 691.322021484375
85
+ - 444.148193359375
86
+ - 3116.252197265625
87
+ - 1885.5509033203125
88
+ - 1155.75927734375
89
+ n_chips: 71
90
+ std:
91
+ - 473.07781982421875
92
+ - 364.54290771484375
93
+ - 317.82159423828125
94
+ - 1228.0555419921875
95
+ - 732.2979736328125
96
+ - 648.6840209960938
97
+ num_classes_tier1: 6
98
+ num_classes_tier2: 17
99
+ num_classes_tier3: 49
data/dataset_info.json ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "tier1": [
3
+ "Background",
4
+ "Grassland",
5
+ "Field crops",
6
+ "Orchards",
7
+ "Special crops",
8
+ "Forest"
9
+ ],
10
+ "tier2": [
11
+ "Background",
12
+ "Meadow",
13
+ "SmallGrainCereal",
14
+ "LargeGrainCereal",
15
+ "Pasture",
16
+ "BroadLeafRowCrop",
17
+ "VegetableCrop",
18
+ "OrchardCrop",
19
+ "Fallow",
20
+ "Hedge",
21
+ "Berries",
22
+ "TreeCrop",
23
+ "CropMix",
24
+ "BiodiversityArea",
25
+ "Forest",
26
+ "Multiple",
27
+ "Gardens"
28
+ ],
29
+ "tier3": [
30
+ "Background",
31
+ "Meadow",
32
+ "WinterWheat",
33
+ "Maize",
34
+ "Pasture",
35
+ "Sugar_beets",
36
+ "WinterBarley",
37
+ "WinterRapeseed",
38
+ "Vegetables",
39
+ "Potatoes",
40
+ "Wheat",
41
+ "Sunflowers",
42
+ "Vines",
43
+ "Spelt",
44
+ "Fallow",
45
+ "Hedge",
46
+ "Apples",
47
+ "Soy",
48
+ "Peas",
49
+ "Oat",
50
+ "Berries",
51
+ "EinkornWheat",
52
+ "Field bean",
53
+ "SummerWheat",
54
+ "Rye",
55
+ "TreeCrop",
56
+ "StoneFruit",
57
+ "MixedCrop",
58
+ "Sorghum",
59
+ "Grain",
60
+ "Chicory",
61
+ "Pears",
62
+ "SummerBarley",
63
+ "Biodiversity encouragement area",
64
+ "Forest",
65
+ "Linen",
66
+ "Legumes",
67
+ "Pumpkin",
68
+ "Tobacco",
69
+ "Buckwheat",
70
+ "Hemp",
71
+ "SummerRapeseed",
72
+ "Hops",
73
+ "Beets",
74
+ "Multiple",
75
+ "Lupine",
76
+ "Mustard",
77
+ "Gardens",
78
+ "Chestnut"
79
+ ],
80
+ "tier3_to_tier1": [
81
+ 0,
82
+ 1,
83
+ 2,
84
+ 2,
85
+ 1,
86
+ 2,
87
+ 2,
88
+ 2,
89
+ 2,
90
+ 2,
91
+ 2,
92
+ 2,
93
+ 3,
94
+ 2,
95
+ 4,
96
+ 4,
97
+ 3,
98
+ 2,
99
+ 2,
100
+ 2,
101
+ 4,
102
+ 2,
103
+ 2,
104
+ 2,
105
+ 2,
106
+ 3,
107
+ 3,
108
+ 2,
109
+ 2,
110
+ 2,
111
+ 2,
112
+ 3,
113
+ 2,
114
+ 1,
115
+ 5,
116
+ 2,
117
+ 2,
118
+ 2,
119
+ 2,
120
+ 2,
121
+ 2,
122
+ 2,
123
+ 3,
124
+ 2,
125
+ 4,
126
+ 2,
127
+ 2,
128
+ 4,
129
+ 3
130
+ ],
131
+ "tier3_to_tier2": [
132
+ 0,
133
+ 1,
134
+ 2,
135
+ 3,
136
+ 4,
137
+ 5,
138
+ 2,
139
+ 5,
140
+ 6,
141
+ 5,
142
+ 2,
143
+ 5,
144
+ 7,
145
+ 2,
146
+ 8,
147
+ 9,
148
+ 7,
149
+ 5,
150
+ 5,
151
+ 2,
152
+ 10,
153
+ 2,
154
+ 5,
155
+ 2,
156
+ 2,
157
+ 11,
158
+ 7,
159
+ 12,
160
+ 3,
161
+ 2,
162
+ 6,
163
+ 7,
164
+ 2,
165
+ 13,
166
+ 14,
167
+ 5,
168
+ 5,
169
+ 6,
170
+ 5,
171
+ 2,
172
+ 5,
173
+ 5,
174
+ 7,
175
+ 5,
176
+ 15,
177
+ 5,
178
+ 5,
179
+ 16,
180
+ 7
181
+ ]
182
+ }
inference.py CHANGED
@@ -14,6 +14,9 @@ import geopandas as gpd
14
 
15
  from messis.messis import LogConfusionMatrix
16
 
 
 
 
17
  class InferenceDataLoader:
18
  def __init__(self, features_path, labels_path, field_ids_path, stats_path, window_size=224, n_timesteps=3, fold_indices=None, debug=False):
19
  self.features_path = features_path
@@ -24,7 +27,7 @@ class InferenceDataLoader:
24
  self.n_timesteps = n_timesteps
25
  self.fold_indices = fold_indices if fold_indices is not None else []
26
  self.debug = debug
27
-
28
  # Load normalization stats
29
  self.means, self.stds = self.load_stats()
30
 
@@ -69,20 +72,20 @@ class InferenceDataLoader:
69
  if self.debug:
70
  print("Source Transform", src.transform)
71
  print(f"UTM X: {utm_x}, UTM Y: {utm_y}")
72
-
73
  try:
74
  px, py = rowcol(src.transform, utm_x, utm_y)
75
  except ValueError:
76
  raise ValueError("Coordinates out of bounds for this raster.")
77
-
78
  if self.debug:
79
  print(f"Row: {py}, Column: {px}")
80
-
81
  half_window_size = self.window_size // 2
82
-
83
  row_off = px - half_window_size
84
  col_off = py - half_window_size
85
-
86
  if row_off < 0:
87
  row_off = 0
88
  if col_off < 0:
@@ -91,7 +94,7 @@ class InferenceDataLoader:
91
  row_off = src.width - self.window_size
92
  if col_off + self.window_size > src.height:
93
  col_off = src.height - self.window_size
94
-
95
  window = Window(col_off, row_off, self.window_size, self.window_size)
96
  window_transform = src.window_transform(window)
97
  if self.debug:
@@ -109,7 +112,7 @@ class InferenceDataLoader:
109
  if self.debug:
110
  print(f"Extracted window data from {path}")
111
  print(f"Min: {window_data.min()}, Max: {window_data.max()}")
112
-
113
  return window_data
114
 
115
  def prepare_data_for_model(self, features_data):
@@ -199,10 +202,12 @@ def crop_predictions_to_gdf(field_ids, targets, predictions, transform, crs, cla
199
  return gdf
200
 
201
  def perform_inference(lon, lat, model, config, debug=False):
202
- features_path = "./data/stacked_features.tif"
203
- labels_path = "./data/labels.tif"
204
- field_ids_path = "./data/field_ids.tif"
 
205
  stats_path = "./data/chips_stats.yaml"
 
206
 
207
  loader = InferenceDataLoader(features_path, labels_path, field_ids_path, stats_path, n_timesteps=9, fold_indices=[0], debug=True)
208
 
@@ -215,8 +220,9 @@ def perform_inference(lon, lat, model, config, debug=False):
215
  print(label_data.shape)
216
  print(field_ids_data.shape)
217
 
218
- with open('./data/dataset_info.json', 'r') as file:
219
  dataset_info = json.load(file)
 
220
  class_names = dataset_info['tier3']
221
 
222
  tiers_dict = {k: v for k, v in config.hparams.get('heads_spec').items() if v.get('is_metrics_tier', False)}
@@ -236,6 +242,5 @@ def perform_inference(lon, lat, model, config, debug=False):
236
  # Simple GeoDataFrame with only the necessary columns
237
  gdf = gdf[['prediction_class', 'target_class', 'correct', 'geometry']]
238
  gdf.columns = ['Prediction', 'Target', 'Correct', 'geometry']
239
- # gdf = gdf[gdf['Target'] != 'Background']
240
 
241
  return gdf
 
14
 
15
  from messis.messis import LogConfusionMatrix
16
 
17
+ import rasterio
18
+ import rasterio.env
19
+
20
  class InferenceDataLoader:
21
  def __init__(self, features_path, labels_path, field_ids_path, stats_path, window_size=224, n_timesteps=3, fold_indices=None, debug=False):
22
  self.features_path = features_path
 
27
  self.n_timesteps = n_timesteps
28
  self.fold_indices = fold_indices if fold_indices is not None else []
29
  self.debug = debug
30
+
31
  # Load normalization stats
32
  self.means, self.stds = self.load_stats()
33
 
 
72
  if self.debug:
73
  print("Source Transform", src.transform)
74
  print(f"UTM X: {utm_x}, UTM Y: {utm_y}")
75
+
76
  try:
77
  px, py = rowcol(src.transform, utm_x, utm_y)
78
  except ValueError:
79
  raise ValueError("Coordinates out of bounds for this raster.")
80
+
81
  if self.debug:
82
  print(f"Row: {py}, Column: {px}")
83
+
84
  half_window_size = self.window_size // 2
85
+
86
  row_off = px - half_window_size
87
  col_off = py - half_window_size
88
+
89
  if row_off < 0:
90
  row_off = 0
91
  if col_off < 0:
 
94
  row_off = src.width - self.window_size
95
  if col_off + self.window_size > src.height:
96
  col_off = src.height - self.window_size
97
+
98
  window = Window(col_off, row_off, self.window_size, self.window_size)
99
  window_transform = src.window_transform(window)
100
  if self.debug:
 
112
  if self.debug:
113
  print(f"Extracted window data from {path}")
114
  print(f"Min: {window_data.min()}, Max: {window_data.max()}")
115
+
116
  return window_data
117
 
118
  def prepare_data_for_model(self, features_data):
 
202
  return gdf
203
 
204
  def perform_inference(lon, lat, model, config, debug=False):
205
+ features_path = "s3://messis-demo/stacked_features_cog.tif"
206
+ labels_path = "s3://messis-demo/labels_cog.tif"
207
+ field_ids_path = "s3://messis-demo/field_ids_cog.tif"
208
+
209
  stats_path = "./data/chips_stats.yaml"
210
+ dataset_info_path = "./data/dataset_info.json"
211
 
212
  loader = InferenceDataLoader(features_path, labels_path, field_ids_path, stats_path, n_timesteps=9, fold_indices=[0], debug=True)
213
 
 
220
  print(label_data.shape)
221
  print(field_ids_data.shape)
222
 
223
+ with open(dataset_info_path, 'r') as file:
224
  dataset_info = json.load(file)
225
+
226
  class_names = dataset_info['tier3']
227
 
228
  tiers_dict = {k: v for k, v in config.hparams.get('heads_spec').items() if v.get('is_metrics_tier', False)}
 
242
  # Simple GeoDataFrame with only the necessary columns
243
  gdf = gdf[['prediction_class', 'target_class', 'correct', 'geometry']]
244
  gdf.columns = ['Prediction', 'Target', 'Correct', 'geometry']
 
245
 
246
  return gdf
pages/2_Perform_Crop_Classification.py CHANGED
@@ -91,7 +91,7 @@ def perform_inference_step():
91
  # Add COG
92
  m.add_cog_layer(
93
  url="https://messis-demo.s3.amazonaws.com/stacked_features_cog.tif",
94
- name="AWS COG",
95
  bands=selected_bands,
96
  rescale=f"{vmin_vmax[selected_band][0]},{vmin_vmax[selected_band][1]}",
97
  zoom_to_layer=True
 
91
  # Add COG
92
  m.add_cog_layer(
93
  url="https://messis-demo.s3.amazonaws.com/stacked_features_cog.tif",
94
+ name="Sentinel-2 Satellite Imagery",
95
  bands=selected_bands,
96
  rescale=f"{vmin_vmax[selected_band][0]},{vmin_vmax[selected_band][1]}",
97
  zoom_to_layer=True