first messis demo app version
- .gitignore +2 -0
- inference.py +235 -0
- main.py +15 -0
- messis/README.md +7 -0
- messis/__init__.py +0 -0
- messis/dataloader.py +287 -0
- messis/messis.py +919 -0
- messis/prithvi.py +555 -0
- pages/1_Select_Location.py +78 -0
- pages/2_Perform_Crop_Classification.py +99 -0
- requirements.txt +23 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
hf_cache
__pycache__
inference.py
ADDED
@@ -0,0 +1,235 @@
import os
import torch
import yaml
import json
import rasterio
from rasterio.windows import Window
from rasterio.transform import rowcol
from pyproj import Transformer
from torchvision import transforms
import numpy as np
from rasterio.features import shapes
from shapely.geometry import shape
import geopandas as gpd

from messis.messis import LogConfusionMatrix

class InferenceDataLoader:
    def __init__(self, features_path, labels_path, field_ids_path, stats_path, window_size=224, n_timesteps=3, fold_indices=None, debug=False):
        self.features_path = features_path
        self.labels_path = labels_path
        self.field_ids_path = field_ids_path
        self.stats_path = stats_path
        self.window_size = window_size
        self.n_timesteps = n_timesteps
        self.fold_indices = fold_indices if fold_indices is not None else []
        self.debug = debug

        # Load normalization stats
        self.means, self.stds = self.load_stats()

        # Set up the transformer for coordinate conversion
        self.transformer = Transformer.from_crs("EPSG:4326", "EPSG:32632", always_xy=True)

    def load_stats(self):
        """Load normalization statistics for dataset from YAML file."""
        if self.debug:
            print(f"Loading mean/std stats from {self.stats_path}")
        assert os.path.exists(self.stats_path), f"Mean/std stats file not found at {self.stats_path}"

        with open(self.stats_path, 'r') as file:
            stats = yaml.safe_load(file)

        mean_list, std_list, n_list = [], [], []
        for fold in self.fold_indices:
            key = f'fold_{fold}'
            if key not in stats:
                raise ValueError(f"Mean/std stats for fold {fold} not found in {self.stats_path}")
            if self.debug:
                print(f"Stats with selected test fold {fold}: {stats[key]} over {self.n_timesteps} timesteps.")
            mean_list.append(torch.tensor(stats[key]['mean']))  # list of 6 means
            std_list.append(torch.tensor(stats[key]['std']))    # list of 6 stds
            n_list.append(stats[key]['n_chips'])                 # list of 6 ns

        means, stds = [], []
        for channel in range(mean_list[0].shape[0]):
            means.append(torch.stack([mean_list[i][channel] for i in range(len(mean_list))]).mean())
            variances = torch.stack([std_list[i][channel] ** 2 for i in range(len(std_list))])
            n = torch.tensor([n_list[i] for i in range(len(n_list))], dtype=torch.float32)
            combined_variance = torch.sum(variances * (n - 1)) / (torch.sum(n) - len(n_list))
            stds.append(torch.sqrt(combined_variance))

        return means * self.n_timesteps, stds * self.n_timesteps

    def identify_window(self, path, lon, lat):
        """Identify the 224x224 window centered on the clicked coordinates (lon, lat) from the specified GeoTIFF."""
        with rasterio.open(path) as src:
            # Transform the coordinates from WGS84 to UTM (EPSG:32632)
            utm_x, utm_y = self.transformer.transform(lon, lat)

            try:
                py, px = rowcol(src.transform, utm_x, utm_y)  # rowcol returns (row, col)
            except ValueError:
                raise ValueError("Coordinates out of bounds for this raster.")

            if self.debug:
                print(f"Row: {py}, Column: {px}")

            half_window_size = self.window_size // 2

            col_off = px - half_window_size
            row_off = py - half_window_size

            if col_off < 0:
                col_off = 0
            if row_off < 0:
                row_off = 0
            if col_off + self.window_size > src.width:
                col_off = src.width - self.window_size
            if row_off + self.window_size > src.height:
                row_off = src.height - self.window_size

            window = Window(col_off, row_off, self.window_size, self.window_size)
            window_transform = src.window_transform(window)
            crs = src.crs

        return window, window_transform, crs

    def extract_window(self, path, window):
        """Extract data from the specified window from the GeoTIFF."""
        with rasterio.open(path) as src:
            window_data = src.read(window=window)

        if self.debug:
            print(f"Extracted window data from {path}")
            print(f"Min: {window_data.min()}, Max: {window_data.max()}")

        return window_data

    def prepare_data_for_model(self, features_data):
        """Prepare the window data for model inference."""
        # Convert to tensor
        features_data = torch.tensor(features_data, dtype=torch.float32)

        # Normalize
        normalize = transforms.Normalize(mean=self.means, std=self.stds)
        features_data = normalize(features_data)

        # Permute the dimensions if needed
        height, width = features_data.shape[-2:]
        features_data = features_data.view(self.n_timesteps, 6, height, width).permute(1, 0, 2, 3)

        # Add batch dimension
        features_data = features_data.unsqueeze(0)

        return features_data

    def get_data(self, lon, lat):
        """Extract, normalize, and prepare data for inference, including labels and field IDs."""
        # Identify the window and get the georeferencing information
        window, features_transform, features_crs = self.identify_window(self.features_path, lon, lat)

        # Extract data from the GeoTIFF, labels, and field IDs
        features_data = self.extract_window(self.features_path, window)
        label_data = self.extract_window(self.labels_path, window)
        field_ids_data = self.extract_window(self.field_ids_path, window)

        # Prepare the window data for the model
        prepared_features_data = self.prepare_data_for_model(features_data)

        # Convert labels and field_ids to tensors (without normalization)
        label_data = torch.tensor(label_data, dtype=torch.long)
        field_ids_data = torch.tensor(field_ids_data, dtype=torch.long)

        # Return the prepared data along with transform and CRS
        return prepared_features_data, label_data, field_ids_data, features_transform, features_crs

def crop_predictions_to_gdf(field_ids, targets, predictions, transform, crs, class_names):
    """
    Convert field_ids, targets, and predictions tensors to field polygons with corresponding class reference.

    :param field_ids: PyTorch tensor of shape (1, 224, 224) representing individual fields
    :param targets: PyTorch tensor of shape (1, 224, 224) for targets
    :param predictions: PyTorch tensor of shape (1, 224, 224) for predictions
    :param transform: Affine transform for georeferencing
    :param crs: Coordinate reference system (CRS) of the data
    :param class_names: Dictionary mapping class indices to class names
    :return: GeoPandas DataFrame with polygons, prediction class labels, and target class labels
    """
    field_array = field_ids.squeeze().cpu().numpy().astype(np.int32)
    target_array = targets.squeeze().cpu().numpy().astype(np.int8)
    pred_array = predictions.squeeze().cpu().numpy().astype(np.int8)

    polygons = []
    field_values = []
    target_values = []
    pred_values = []

    for geom, field_value in shapes(field_array, transform=transform):
        polygons.append(shape(geom))
        field_values.append(field_value)

        # Get a single value from the field area for targets and predictions
        target_value = target_array[field_array == field_value][0]
        pred_value = pred_array[field_array == field_value][0]

        target_values.append(target_value)
        pred_values.append(pred_value)

    gdf = gpd.GeoDataFrame({
        'geometry': polygons,
        'field_id': field_values,
        'target': target_values,
        'prediction': pred_values
    }, crs=crs)

    gdf['prediction_class'] = gdf['prediction'].apply(lambda x: class_names[x])
    gdf['target_class'] = gdf['target'].apply(lambda x: class_names[x])

    gdf['correct'] = gdf['target'] == gdf['prediction']

    gdf = gdf[gdf.geometry.area > 250]  # Threshold for small polygons

    return gdf

def perform_inference(lon, lat, model, config, debug=False):
    features_path = "./data/stacked_features.tif"
    labels_path = "./data/labels.tif"
    field_ids_path = "./data/field_ids.tif"
    stats_path = "./data/chips_stats.yaml"

    loader = InferenceDataLoader(features_path, labels_path, field_ids_path, stats_path, n_timesteps=9, fold_indices=[0], debug=True)

    # Coordinates must be in EPSG:4326 and lon lat order - are converted to the CRS of the raster
    satellite_data, label_data, field_ids_data, features_transform, features_crs = loader.get_data(lon, lat)

    if debug:
        # Print the shape of the extracted data
        print(satellite_data.shape)
        print(label_data.shape)
        print(field_ids_data.shape)

    with open('./data/dataset_info.json', 'r') as file:
        dataset_info = json.load(file)
    class_names = dataset_info['tier3']

    tiers_dict = {k: v for k, v in config.hparams.get('heads_spec').items() if v.get('is_metrics_tier', False)}
    tiers = list(tiers_dict.keys())

    # Perform inference
    model.eval()
    with torch.no_grad():
        output = model(satellite_data)['tier3_refinement_head']

    pixelwise_outputs_stacked, majority_outputs_stacked = LogConfusionMatrix.get_pixelwise_and_majority_outputs(output, tiers, field_ids=field_ids_data, dataset_info=dataset_info)
    majority_tier3_predictions = majority_outputs_stacked[2]  # Tier 3 predictions

    # Convert the predictions to a GeoDataFrame
    gdf = crop_predictions_to_gdf(field_ids_data, label_data, majority_tier3_predictions, features_transform, features_crs, class_names)

    # Simple GeoDataFrame with only the necessary columns
    gdf = gdf[['prediction_class', 'target_class', 'correct', 'geometry']]
    gdf.columns = ['Prediction', 'Target', 'Correct', 'geometry']
    # gdf = gdf[gdf['Target'] != 'Background']

    return gdf
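For orientation, here is a minimal usage sketch of the pieces defined in this file when called directly. The data paths mirror the ones hard-coded in perform_inference; the coordinates are an illustrative example, and the trained model is assumed to be loaded elsewhere (it only needs to return a dict containing a 'tier3_refinement_head' entry, as perform_inference expects).

# Sketch only, not part of this commit:
loader = InferenceDataLoader(
    "./data/stacked_features.tif", "./data/labels.tif",
    "./data/field_ids.tif", "./data/chips_stats.yaml",
    n_timesteps=9, fold_indices=[0],
)
# Example WGS84 coordinates in lon/lat order; they are converted to EPSG:32632 internally.
features, labels, field_ids, transform, crs = loader.get_data(lon=8.54, lat=47.38)
print(features.shape)  # expected (1, 6, 9, 224, 224): batch, bands, timesteps, H, W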
main.py
ADDED
@@ -0,0 +1,15 @@
import streamlit as st

def main():
    st.set_page_config(layout="wide", page_title="Messis 🌾 - Crop Classification 🌎")

    st.title("Messis 🌾 - Crop Classification 🌎")

    st.write("Welcome to the Messis Crop Classification app. Use the sidebar to navigate between selecting coordinates and performing inference.")

    st.page_link("main.py", label="Home", icon="🏠")
    st.page_link("pages/1_Select_Location.py", label="Select Location", icon="📍")
    st.page_link("pages/2_Perform_Crop_Classification.py", label="Perform Crop Classification", icon="🔍")

if __name__ == "__main__":
    main()
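This entry point follows Streamlit's multipage convention (an entry script plus a pages/ directory), so locally the demo would typically be started with `streamlit run main.py`; on the Space itself the runtime is assumed to launch it automatically.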
messis/README.md
ADDED
@@ -0,0 +1,7 @@
# About

This package contains the code for the crop classification model Messis.

It can be found on Hugging Face at [this link](https://huggingface.co/crop-classification/messis).

TODO: Add more information about the model.
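Because the Messis class in messis/messis.py mixes in PyTorchModelHubMixin, the model can in principle be pulled straight from that Hugging Face repo. The sketch below assumes a checkpoint was pushed with the same mixin so that from_pretrained can rebuild the module with its saved hyperparameters; the repo id and checkpoint availability are not guaranteed by this commit.

# Sketch only, assuming a pushed checkpoint:
from messis.messis import Messis

model = Messis.from_pretrained("crop-classification/messis")
model.eval()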
messis/__init__.py
ADDED
File without changes
messis/dataloader.py
ADDED
@@ -0,0 +1,287 @@
import random
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pytorch_lightning import LightningDataModule
import os
import re
import yaml
import rasterio
import dvc.api


params = dvc.api.params_show()
N_TIMESTEPS = params['number_of_timesteps']

class ToTensorTransform(object):
    def __init__(self, dtype):
        self.dtype = dtype

    def __call__(self, data):
        return torch.tensor(data, dtype=self.dtype)

class NormalizeTransform(object):
    def __init__(self, means, stds):
        self.mean = means
        self.std = stds

    def __call__(self, data):
        return transforms.Normalize(self.mean, self.std)(data)

class PermuteTransform:
    def __call__(self, data):
        height, width = data.shape[-2:]

        # Ensure the channel dimension is as expected
        if data.shape[0] != N_TIMESTEPS * 6:
            raise ValueError(f"Expected {N_TIMESTEPS*6} channels, got {data.shape[0]}")

        # Step 1: Reshape the data to group the N_TIMESTEPS*6 bands into N_TIMESTEPS groups of 6 bands
        data = data.view(N_TIMESTEPS, 6, height, width)

        # Step 2: Permute to bring the bands to the front
        data = data.permute(1, 0, 2, 3)  # NOTE: Prithvi wants it bands first # after this, shape is (6, N_TIMESTEPS, height, width)
        return data

class RandomFlipAndJitterTransform:
    """
    Apply random horizontal and vertical flips, and channel jitter to the input image and corresponding mask.

    Parameters:
    -----------
    flip_prob : float, optional (default=0.5)
        Probability of applying horizontal and vertical flips to the image and mask.
        Each flip (horizontal and vertical) is applied independently based on this probability.

    jitter_std : float, optional (default=0.02)
        Standard deviation of the Gaussian noise added to the image channels for jitter.
        This value controls the intensity of the random noise applied to the image channels.

    Effects of Parameters:
    ----------------------
    flip_prob:
        - Higher flip_prob increases the likelihood of the image and mask being flipped.
        - A value of 0 means no flipping, while a value of 1 means always flip.

    jitter_std:
        - Higher jitter_std increases the intensity of the noise added to the image channels.
        - A value of 0 means no noise, while larger values add more significant noise.
    """
    def __init__(self, flip_prob=0.5, jitter_std=0.02):
        self.flip_prob = flip_prob
        self.jitter_std = jitter_std

    def __call__(self, img, mask, field_ids):
        # Shapes (..., H, W) | img: torch.Size([6, N_TIMESTEPS, 224, 224]), mask: torch.Size([N_TIMESTEPS, 224, 224]), field_ids: torch.Size([1, 224, 224])

        # Temporarily convert field_ids to int32 for flipping (flip not implemented for uint16)
        field_ids = field_ids.to(torch.int32)

        # Random horizontal flip
        if random.random() < self.flip_prob:
            img = torch.flip(img, [2])
            mask = torch.flip(mask, [1])
            field_ids = torch.flip(field_ids, [1])

        # Random vertical flip
        if random.random() < self.flip_prob:
            img = torch.flip(img, [3])
            mask = torch.flip(mask, [2])
            field_ids = torch.flip(field_ids, [2])

        # Convert field_ids back to uint16
        field_ids = field_ids.to(torch.uint16)

        # Channel jitter
        noise = torch.randn(img.size()) * self.jitter_std
        img += noise

        return img, mask, field_ids

def get_img_transforms():
    return transforms.Compose([])

def get_mask_transforms():
    return transforms.Compose([])

class GeospatialDataset(Dataset):
    def __init__(self, data_dir, fold_indicies, transform_img=None, transform_mask=None, transform_field_ids=None, debug=False, subset_size=None, data_augmentation=None):
        self.data_dir = data_dir
        self.chips_dir = os.path.join(data_dir, 'chips')
        self.transform_img = transform_img
        self.transform_mask = transform_mask
        self.transform_field_ids = transform_field_ids
        self.debug = debug
        self.images = []
        self.masks = []
        self.field_ids = []
        self.data_augmentation = data_augmentation

        self.means, self.stds = self.load_stats(fold_indicies, N_TIMESTEPS)
        self.transform_img_load = self.get_img_load_transforms(self.means, self.stds)
        self.transform_mask_load = self.get_mask_load_transforms()
        self.transform_field_ids_load = self.get_field_ids_load_transforms()

        # Adjust file selection based on fold
        for file in os.listdir(self.chips_dir):
            if re.match(f".*_fold_[{''.join([str(f) for f in fold_indicies])}]_merged.tif", file):
                self.images.append(file)
                mask_file = file.replace("_merged.tif", "_mask.tif")
                self.masks.append(mask_file)
                field_ids_file = file.replace("_merged.tif", "_field_ids.tif")
                self.field_ids.append(field_ids_file)

        assert len(self.images) == len(self.masks), "Number of images and masks do not match"

        # If subset_size is specified, randomly select a subset of the data
        if subset_size is not None and len(self.images) > subset_size:
            print(f"Randomly selecting {subset_size} samples from {len(self.images)} samples.")
            selected_indices = random.sample(range(len(self.images)), subset_size)
            self.images = [self.images[i] for i in selected_indices]
            self.masks = [self.masks[i] for i in selected_indices]
            self.field_ids = [self.field_ids[i] for i in selected_indices]

    def load_stats(self, fold_indicies, n_timesteps=3):
        """Load normalization statistics for dataset from YAML file."""
        stats_path = os.path.join(self.data_dir, 'chips_stats.yaml')
        if self.debug:
            print(f"Loading mean/std stats from {stats_path}")
        assert os.path.exists(stats_path), f"mean/std stats file for dataset not found at {stats_path}"
        with open(stats_path, 'r') as file:
            stats = yaml.safe_load(file)
        mean_list, std_list, n_list = [], [], []
        for fold in fold_indicies:
            key = f'fold_{fold}'
            if key not in stats:
                raise ValueError(f"mean/std stats for fold {fold} not found in {stats_path}")
            if self.debug:
                print(f"Stats with selected test fold {fold}: {stats[key]} over {n_timesteps} timesteps.")
            mean_list.append(torch.Tensor(stats[key]['mean']))  # list of 6 means
            std_list.append(torch.Tensor(stats[key]['std']))    # list of 6 stds
            n_list.append(stats[key]['n_chips'])                 # list of 6 ns
        # aggregate means and stds over all folds
        means, stds = [], []
        for channel in range(mean_list[0].shape[0]):
            means.append(torch.stack([mean_list[i][channel] for i in range(len(mean_list))]).mean())
            # stds are waaaay more complex to aggregate
            # \sqrt{\frac{\sum_{i=1}^{n} \sigma_i^2 (n_i - 1)}{\sum_{i=1}^{n} n_i - n}}
            variances = torch.stack([std_list[i][channel] ** 2 for i in range(len(std_list))])
            n = torch.tensor([n_list[i] for i in range(len(n_list))], dtype=torch.float32)
            combined_variance = torch.sum(variances * (n - 1)) / (torch.sum(n) - len(n_list))
            stds.append(torch.sqrt(combined_variance))

        # make means and stds into 2d arrays, as torchvision would otherwise convert it into a 3d tensor which is incompatible with our 4d temporal images
        # https://github.com/pytorch/vision/blob/6e18cea3485066b7277785415bf2e0422dbdb9da/torchvision/transforms/_functional_tensor.py#L923
        return means * n_timesteps, stds * n_timesteps

    def get_img_load_transforms(self, means, stds):
        return transforms.Compose([
            ToTensorTransform(torch.float32),
            NormalizeTransform(means, stds),
            PermuteTransform()
        ])

    def get_mask_load_transforms(self):
        return transforms.Compose([
            ToTensorTransform(torch.uint8)
        ])

    def get_field_ids_load_transforms(self):
        return transforms.Compose([
            ToTensorTransform(torch.uint16)
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.chips_dir, self.images[idx])
        mask_path = os.path.join(self.chips_dir, self.masks[idx])
        field_ids_path = os.path.join(self.chips_dir, self.field_ids[idx])

        img = rasterio.open(img_path).read().astype('uint16')
        mask = rasterio.open(mask_path).read().astype('uint8')
        field_ids = rasterio.open(field_ids_path).read().astype('uint16')

        # Apply our base transforms
        img = self.transform_img_load(img)
        mask = self.transform_mask_load(mask)
        field_ids = self.transform_field_ids_load(field_ids)

        # Apply additional transforms passed from GeospatialDataModule if applicable
        if self.transform_img is not None:
            img = self.transform_img(img)
        if self.transform_mask is not None:
            mask = self.transform_mask(mask)
        if self.transform_field_ids is not None:
            field_ids = self.transform_field_ids(field_ids)

        # Apply data augmentation if enabled
        if self.data_augmentation is not None and self.data_augmentation.get('enabled', True):
            img, mask, field_ids = RandomFlipAndJitterTransform(
                flip_prob=self.data_augmentation.get('flip_prob', 0.5),
                jitter_std=self.data_augmentation.get('jitter_std', 0.02)
            )(img, mask, field_ids)

        # Load targets for given tiers
        num_tiers = mask.shape[0]
        targets = ()
        for i in range(num_tiers):
            targets += (mask[i, :, :].type(torch.long),)

        return img, (targets, field_ids)

class GeospatialDataModule(LightningDataModule):
    def __init__(self, data_dir, train_folds, val_folds, test_folds, batch_size=8, num_workers=4, debug=False, subsets=None, data_augmentation=None):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.debug = debug
        self.subsets = subsets if subsets is not None else {}
        self.data_augmentation = data_augmentation if data_augmentation is not None else {}

        GeospatialDataModule.validate_folds(train_folds, val_folds, test_folds)
        self.train_folds = train_folds
        self.val_folds = val_folds
        self.test_folds = test_folds

        # NOTE: Transforms on this level not used for now
        self.transform_img = get_img_transforms()
        self.transform_mask = get_mask_transforms()

    @staticmethod
    def validate_folds(train, val, test):
        if train is None or val is None or test is None:
            raise ValueError("All fold sets must be specified")

        if len(set(train) & set(val)) > 0 or len(set(train) & set(test)) > 0 or len(set(val) & set(test)) > 0:
            raise ValueError("Folds must be mutually exclusive")

    def setup(self, stage=None):
        print(f"Setting up GeospatialDataModule for stage: {stage}. Data augmentation config: {self.data_augmentation}")
        common_params = {
            'data_dir': self.data_dir,
            'debug': self.debug,
            'data_augmentation': self.data_augmentation
        }
        common_params_val_test = {  # Never augment validation or test data
            **common_params,
            'data_augmentation': {
                'enabled': False
            }
        }
        if stage in ('fit', None):
            self.train_dataset = GeospatialDataset(fold_indicies=self.train_folds, subset_size=self.subsets.get('train', None), **common_params)
            self.val_dataset = GeospatialDataset(fold_indicies=self.val_folds, subset_size=self.subsets.get('val', None), **common_params_val_test)
        if stage in ('test', None):
            self.test_dataset = GeospatialDataset(fold_indicies=self.test_folds, subset_size=self.subsets.get('test', None), **common_params_val_test)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True)
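For reference, a minimal sketch of wiring up the data module defined above for training. The data directory and fold split are placeholders, and a DVC params.yaml containing number_of_timesteps must be present because the module reads it through dvc.api at import time.

# Sketch only, assuming a hypothetical 7-fold split; folds must be mutually
# exclusive, which validate_folds enforces.
from messis.dataloader import GeospatialDataModule

dm = GeospatialDataModule(
    data_dir="./data",  # expects a 'chips/' subfolder and chips_stats.yaml
    train_folds=[1, 2, 3, 4, 5],
    val_folds=[6],
    test_folds=[0],
    batch_size=8,
    data_augmentation={"enabled": True, "flip_prob": 0.5, "jitter_std": 0.02},
)
dm.setup("fit")
img, (targets, field_ids) = dm.train_dataset[0]
print(img.shape)  # (6, N_TIMESTEPS, 224, 224) after PermuteTransform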
messis/messis.py
ADDED
@@ -0,0 +1,919 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
from torchmetrics import classification
|
5 |
+
import wandb
|
6 |
+
from matplotlib import pyplot as plt
|
7 |
+
import numpy as np
|
8 |
+
import matplotlib.ticker as ticker
|
9 |
+
from matplotlib.colors import ListedColormap
|
10 |
+
from huggingface_hub import PyTorchModelHubMixin
|
11 |
+
from lion_pytorch import Lion
|
12 |
+
|
13 |
+
import json
|
14 |
+
|
15 |
+
from messis.prithvi import TemporalViTEncoder, ConvTransformerTokensToEmbeddingNeck, ConvTransformerTokensToEmbeddingBottleneckNeck
|
16 |
+
|
17 |
+
|
18 |
+
def safe_shape(x):
|
19 |
+
if isinstance(x, tuple):
|
20 |
+
# loop through tuple
|
21 |
+
shape_info = '(tuple) : '
|
22 |
+
for i in x:
|
23 |
+
shape_info += str(i.shape) + ', '
|
24 |
+
return shape_info
|
25 |
+
if isinstance(x, list):
|
26 |
+
# loop through list
|
27 |
+
shape_info = '(list) : '
|
28 |
+
for i in x:
|
29 |
+
shape_info += str(i.shape) + ', '
|
30 |
+
return shape_info
|
31 |
+
return x.shape
|
32 |
+
|
33 |
+
class ConvModule(nn.Module):
|
34 |
+
"""
|
35 |
+
A simple convolutional module including Conv, BatchNorm, and ReLU layers.
|
36 |
+
"""
|
37 |
+
def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
|
38 |
+
super(ConvModule, self).__init__()
|
39 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=False)
|
40 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
41 |
+
self.relu = nn.ReLU(inplace=True)
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
x = self.conv(x)
|
45 |
+
x = self.bn(x)
|
46 |
+
return self.relu(x)
|
47 |
+
|
48 |
+
class HierarchicalFCNHead(nn.Module):
|
49 |
+
"""
|
50 |
+
Hierarchical FCN Head for semantic segmentation.
|
51 |
+
"""
|
52 |
+
def __init__(self, in_channels, out_channels, num_classes, num_convs=2, kernel_size=3, dilation=1, dropout_p=0.1, debug=False):
|
53 |
+
super(HierarchicalFCNHead, self).__init__()
|
54 |
+
|
55 |
+
self.debug = debug
|
56 |
+
|
57 |
+
self.convs = nn.Sequential(*[
|
58 |
+
ConvModule(
|
59 |
+
in_channels if i == 0 else out_channels,
|
60 |
+
out_channels,
|
61 |
+
kernel_size,
|
62 |
+
padding=dilation * (kernel_size // 2),
|
63 |
+
dilation=dilation
|
64 |
+
) for i in range(num_convs)
|
65 |
+
])
|
66 |
+
|
67 |
+
self.conv_seg = nn.Conv2d(out_channels, num_classes, kernel_size=1)
|
68 |
+
self.dropout = nn.Dropout2d(p=dropout_p)
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
if self.debug:
|
72 |
+
print('HierarchicalFCNHead forward INP: ', safe_shape(x))
|
73 |
+
x = self.convs(x)
|
74 |
+
features = self.dropout(x)
|
75 |
+
output = self.conv_seg(features)
|
76 |
+
if self.debug:
|
77 |
+
print('HierarchicalFCNHead forward features OUT: ', safe_shape(features))
|
78 |
+
print('HierarchicalFCNHead forward output OUT: ', safe_shape(output))
|
79 |
+
return output, features
|
80 |
+
|
81 |
+
class LabelRefinementHead(nn.Module):
|
82 |
+
"""
|
83 |
+
Similar to the label refinement module introduced in the ZueriCrop paper, this module refines the predictions for tier 3.
|
84 |
+
It takes the raw predictions from head 1, head 2 and head 3 and refines them to produce the final prediction for tier 3.
|
85 |
+
According to ZueriCrop, this helps with making the predictions more consistent across the different tiers.
|
86 |
+
"""
|
87 |
+
def __init__(self, input_channels, num_classes):
|
88 |
+
super(LabelRefinementHead, self).__init__()
|
89 |
+
|
90 |
+
self.cnn_layers = nn.Sequential(
|
91 |
+
# 1x1 Convolutional layer
|
92 |
+
nn.Conv2d(in_channels=input_channels, out_channels=128, kernel_size=1, stride=1, padding=0),
|
93 |
+
nn.BatchNorm2d(128),
|
94 |
+
nn.ReLU(inplace=True),
|
95 |
+
|
96 |
+
# 3x3 Convolutional layer
|
97 |
+
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
|
98 |
+
nn.BatchNorm2d(128),
|
99 |
+
nn.ReLU(inplace=True),
|
100 |
+
nn.Dropout(p=0.5),
|
101 |
+
|
102 |
+
# Skip connection (implemented in forward method)
|
103 |
+
|
104 |
+
# Another 3x3 Convolutional layer
|
105 |
+
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
|
106 |
+
nn.BatchNorm2d(128),
|
107 |
+
nn.ReLU(inplace=True),
|
108 |
+
|
109 |
+
# 1x1 Convolutional layer to adjust the number of output channels to num_classes
|
110 |
+
nn.Conv2d(in_channels=128, out_channels=num_classes, kernel_size=1, stride=1, padding=0),
|
111 |
+
nn.Dropout(p=0.5)
|
112 |
+
)
|
113 |
+
|
114 |
+
def forward(self, x):
|
115 |
+
# Apply initial conv layer
|
116 |
+
y = self.cnn_layers[0:3](x)
|
117 |
+
|
118 |
+
# Save for skip connection
|
119 |
+
y_skip = y
|
120 |
+
|
121 |
+
# Apply the next two conv layers
|
122 |
+
y = self.cnn_layers[3:9](y)
|
123 |
+
|
124 |
+
# Skip connection (element-wise addition)
|
125 |
+
y = y + y_skip
|
126 |
+
|
127 |
+
# Apply the last conv layer
|
128 |
+
y = self.cnn_layers[9:](y)
|
129 |
+
return y
|
130 |
+
|
131 |
+
class HierarchicalClassifier(nn.Module):
|
132 |
+
def __init__(
|
133 |
+
self,
|
134 |
+
heads_spec,
|
135 |
+
dropout_p=0.1,
|
136 |
+
img_size=256,
|
137 |
+
patch_size=16,
|
138 |
+
num_frames=3,
|
139 |
+
bands=[0, 1, 2, 3, 4, 5],
|
140 |
+
backbone_weights_path=None,
|
141 |
+
freeze_backbone=True,
|
142 |
+
use_bottleneck_neck=False,
|
143 |
+
bottleneck_reduction_factor=4,
|
144 |
+
loss_ignore_background=False,
|
145 |
+
debug=False
|
146 |
+
):
|
147 |
+
super(HierarchicalClassifier, self).__init__()
|
148 |
+
|
149 |
+
self.embed_dim = 768
|
150 |
+
if num_frames % 3 != 0:
|
151 |
+
raise ValueError("The number of frames must be a multiple of 3, it is currently: ", num_frames)
|
152 |
+
self.num_frames = num_frames
|
153 |
+
self.hp, self.wp = img_size // patch_size, img_size // patch_size
|
154 |
+
self.heads_spec = heads_spec
|
155 |
+
self.dropout_p = dropout_p
|
156 |
+
self.loss_ignore_background = loss_ignore_background
|
157 |
+
self.debug = debug
|
158 |
+
|
159 |
+
if self.debug:
|
160 |
+
print('hp and wp: ', self.hp, self.wp)
|
161 |
+
|
162 |
+
self.prithvi = TemporalViTEncoder(
|
163 |
+
img_size=img_size,
|
164 |
+
patch_size=patch_size,
|
165 |
+
num_frames=3,
|
166 |
+
tubelet_size=1,
|
167 |
+
in_chans=len(bands),
|
168 |
+
embed_dim=self.embed_dim,
|
169 |
+
depth=12,
|
170 |
+
num_heads=8,
|
171 |
+
mlp_ratio=4.0,
|
172 |
+
norm_pix_loss=False,
|
173 |
+
pretrained=backbone_weights_path,
|
174 |
+
debug=self.debug
|
175 |
+
)
|
176 |
+
|
177 |
+
# (Un)freeze the backbone
|
178 |
+
for param in self.prithvi.parameters():
|
179 |
+
param.requires_grad = not freeze_backbone
|
180 |
+
|
181 |
+
# Neck to transform the token-based output of the transformer into a spatial feature map
|
182 |
+
number_of_necks = self.num_frames // 3
|
183 |
+
if use_bottleneck_neck:
|
184 |
+
self.necks = nn.ModuleList([ConvTransformerTokensToEmbeddingBottleneckNeck(
|
185 |
+
embed_dim=self.embed_dim * 3,
|
186 |
+
output_embed_dim=self.embed_dim * 3,
|
187 |
+
drop_cls_token=True,
|
188 |
+
Hp=self.hp,
|
189 |
+
Wp=self.wp,
|
190 |
+
bottleneck_reduction_factor=bottleneck_reduction_factor
|
191 |
+
) for _ in range(number_of_necks)])
|
192 |
+
else:
|
193 |
+
self.necks = nn.ModuleList([ConvTransformerTokensToEmbeddingNeck(
|
194 |
+
embed_dim=self.embed_dim * 3,
|
195 |
+
output_embed_dim=self.embed_dim * 3,
|
196 |
+
drop_cls_token=True,
|
197 |
+
Hp=self.hp,
|
198 |
+
Wp=self.wp,
|
199 |
+
) for _ in range(number_of_necks)])
|
200 |
+
|
201 |
+
# Initialize heads and loss weights based on tiers
|
202 |
+
self.heads = nn.ModuleDict()
|
203 |
+
self.loss_weights = {}
|
204 |
+
self.total_classes = 0
|
205 |
+
|
206 |
+
# Build HierarchicalFCNHeads
|
207 |
+
head_count = 0
|
208 |
+
for head_name, head_info in self.heads_spec.items():
|
209 |
+
head_type = head_info['type']
|
210 |
+
num_classes = head_info['num_classes_to_predict']
|
211 |
+
loss_weight = head_info['loss_weight']
|
212 |
+
|
213 |
+
if head_type == 'HierarchicalFCNHead':
|
214 |
+
num_classes = head_info['num_classes_to_predict']
|
215 |
+
loss_weight = head_info['loss_weight']
|
216 |
+
kernel_size = head_info.get('kernel_size', 3)
|
217 |
+
num_convs = head_info.get('num_convs', 1)
|
218 |
+
num_channels = head_info.get('num_channels', 256)
|
219 |
+
self.total_classes += num_classes
|
220 |
+
|
221 |
+
self.heads[head_name] = HierarchicalFCNHead(
|
222 |
+
in_channels=(self.embed_dim * self.num_frames) if head_count == 0 else num_channels,
|
223 |
+
out_channels=num_channels,
|
224 |
+
num_classes=num_classes,
|
225 |
+
num_convs=num_convs,
|
226 |
+
kernel_size=kernel_size,
|
227 |
+
dropout_p=self.dropout_p,
|
228 |
+
debug=self.debug
|
229 |
+
)
|
230 |
+
self.loss_weights[head_name] = loss_weight
|
231 |
+
|
232 |
+
# NOTE: LabelRefinementHead must be the last in the dict, otherwise the total_classes will be incorrect
|
233 |
+
if head_type == 'LabelRefinementHead':
|
234 |
+
self.refinement_head = LabelRefinementHead(input_channels=self.total_classes, num_classes=num_classes)
|
235 |
+
self.refinement_head_name = head_name
|
236 |
+
self.loss_weights[head_name] = loss_weight
|
237 |
+
|
238 |
+
head_count += 1
|
239 |
+
|
240 |
+
self.loss_func = nn.CrossEntropyLoss(ignore_index=-1)
|
241 |
+
|
242 |
+
def forward(self, x):
|
243 |
+
if self.debug:
|
244 |
+
print(f"Input shape: {safe_shape(x)}") # torch.Size([4, 6, 9, 224, 224])
|
245 |
+
|
246 |
+
# Extract features from the base model
|
247 |
+
if len(self.necks) == 1:
|
248 |
+
features = [x]
|
249 |
+
else:
|
250 |
+
features = torch.chunk(x, len(self.necks), dim=2)
|
251 |
+
features = [self.prithvi(x) for x in features]
|
252 |
+
|
253 |
+
if self.debug:
|
254 |
+
print(f"Features shape after base model: {', '.join([safe_shape(f) for f in features])}") # (tuple) : torch.Size([4, 589, 768]), , (tuple) : torch.Size
|
255 |
+
|
256 |
+
# Process through the neck
|
257 |
+
features = [neck(feat_) for feat_, neck in zip(features, self.necks)]
|
258 |
+
|
259 |
+
if self.debug:
|
260 |
+
print(f"Features shape after neck: {', '.join([safe_shape(f) for f in features])}") # (tuple) : torch.Size([4, 2304, 224, 224]), , (tuple) : torch.Size
|
261 |
+
|
262 |
+
# Remove from tuple
|
263 |
+
features = [feat[0] for feat in features]
|
264 |
+
# stack the features to create a tensor of torch.Size([4, 6912, 224, 224])
|
265 |
+
features = torch.concatenate(features, dim=1)
|
266 |
+
if self.debug:
|
267 |
+
print(f"Features shape after removing tuple: {safe_shape(features)}") # torch.Size([4, 6912, 224, 224])
|
268 |
+
|
269 |
+
# Process through the heads
|
270 |
+
outputs = {}
|
271 |
+
for tier_name, head in self.heads.items():
|
272 |
+
output, features = head(features)
|
273 |
+
outputs[tier_name] = output
|
274 |
+
|
275 |
+
if self.debug:
|
276 |
+
print(f"Features shape after {tier_name} head: {safe_shape(features)}")
|
277 |
+
print(f"Output shape after {tier_name} head: {safe_shape(output)}")
|
278 |
+
|
279 |
+
# Process through the classification refinement head
|
280 |
+
output_concatenated = torch.cat(list(outputs.values()), dim=1)
|
281 |
+
output_refinement_head = self.refinement_head(output_concatenated)
|
282 |
+
outputs[self.refinement_head_name] = output_refinement_head
|
283 |
+
|
284 |
+
return outputs
|
285 |
+
|
286 |
+
def calculate_loss(self, outputs, targets):
|
287 |
+
total_loss = 0
|
288 |
+
loss_per_head = {}
|
289 |
+
for head_name, output in outputs.items():
|
290 |
+
if self.debug:
|
291 |
+
print(f"Target index for {head_name}: {self.heads_spec[head_name]['target_idx']}")
|
292 |
+
target = targets[self.heads_spec[head_name]['target_idx']]
|
293 |
+
loss_target = target
|
294 |
+
if self.loss_ignore_background:
|
295 |
+
loss_target = target.clone() # Clone as original target needed in backward pass
|
296 |
+
loss_target[loss_target == 0] = -1 # Set background class to ignore_index -1 for loss calculation
|
297 |
+
loss = self.loss_func(output, loss_target)
|
298 |
+
loss_per_head[f'{head_name}'] = loss
|
299 |
+
total_loss += loss * self.loss_weights[head_name]
|
300 |
+
|
301 |
+
return total_loss, loss_per_head
|
302 |
+
|
303 |
+
class Messis(pl.LightningModule, PyTorchModelHubMixin):
|
304 |
+
def __init__(self, hparams):
|
305 |
+
super().__init__()
|
306 |
+
self.save_hyperparameters(hparams)
|
307 |
+
|
308 |
+
self.model = HierarchicalClassifier(
|
309 |
+
heads_spec=hparams['heads_spec'],
|
310 |
+
dropout_p=hparams.get('dropout_p'),
|
311 |
+
img_size=hparams.get('img_size'),
|
312 |
+
patch_size=hparams.get('patch_size'),
|
313 |
+
num_frames=hparams.get('num_frames'),
|
314 |
+
bands=hparams.get('bands'),
|
315 |
+
backbone_weights_path=hparams.get('backbone_weights_path'),
|
316 |
+
freeze_backbone=hparams['freeze_backbone'],
|
317 |
+
use_bottleneck_neck=hparams.get('use_bottleneck_neck'),
|
318 |
+
bottleneck_reduction_factor=hparams.get('bottleneck_reduction_factor'),
|
319 |
+
loss_ignore_background=hparams.get('loss_ignore_background'),
|
320 |
+
debug=hparams.get('debug')
|
321 |
+
)
|
322 |
+
|
323 |
+
def forward(self, x):
|
324 |
+
return self.model(x)
|
325 |
+
|
326 |
+
def training_step(self, batch, batch_idx):
|
327 |
+
return self.__step(batch, batch_idx, "train")
|
328 |
+
|
329 |
+
def validation_step(self, batch, batch_idx):
|
330 |
+
return self.__step(batch, batch_idx, "val")
|
331 |
+
|
332 |
+
def test_step(self, batch, batch_idx):
|
333 |
+
return self.__step(batch, batch_idx, "test")
|
334 |
+
|
335 |
+
def configure_optimizers(self):
|
336 |
+
# select case on optimizer
|
337 |
+
match self.hparams.get('optimizer', 'Adam'):
|
338 |
+
case 'Adam':
|
339 |
+
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.get('lr', 1e-3))
|
340 |
+
case 'AdamW':
|
341 |
+
optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.get('lr', 1e-3), weight_decay=self.hparams.get('optimizer_weight_decay', 0.01))
|
342 |
+
case 'SGD':
|
343 |
+
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.get('lr', 1e-3), momentum=self.hparams.get('optimizer_momentum', 0.9))
|
344 |
+
case 'Lion':
|
345 |
+
# https://github.com/lucidrains/lion-pytorch | Typically lr 3-10 times lower than Adam and weight_decay 3-10 times higher
|
346 |
+
optimizer = Lion(self.parameters(), lr=self.hparams.get('lr', 1e-4), weight_decay=self.hparams.get('optimizer_weight_decay', 0.1))
|
347 |
+
case _:
|
348 |
+
raise ValueError(f"Optimizer {self.hparams.get('optimizer')} not supported")
|
349 |
+
return optimizer
|
350 |
+
|
351 |
+
def __step(self, batch, batch_idx, stage):
|
352 |
+
inputs, targets = batch
|
353 |
+
targets = torch.stack(targets[0])
|
354 |
+
outputs = self(inputs)
|
355 |
+
loss, loss_per_head = self.model.calculate_loss(outputs, targets)
|
356 |
+
loss_per_head_named = {f'{stage}_loss_{head}': loss_per_head[head] for head in loss_per_head}
|
357 |
+
loss_proportions = { f'{stage}_loss_{head}_proportion': round(loss_per_head[head].item() / loss.item(), 2) for head in loss_per_head}
|
358 |
+
loss_detail_dict = {**loss_per_head_named, **loss_proportions}
|
359 |
+
|
360 |
+
if self.hparams.get('debug'):
|
361 |
+
print(f"Step Inputs shape: {safe_shape(inputs)}")
|
362 |
+
print(f"Step Targets shape: {safe_shape(targets)}")
|
363 |
+
print(f"Step Outputs dict keys: {outputs.keys()}")
|
364 |
+
|
365 |
+
# NOTE: All metrics other than loss are tracked by callbacks (LogMessisMetrics)
|
366 |
+
self.log_dict({f'{stage}_loss': loss, **loss_detail_dict}, on_step=True, on_epoch=True, prog_bar=True, logger=True)
|
367 |
+
return {'loss': loss, 'outputs': outputs}
|
368 |
+
|
369 |
+
class LogConfusionMatrix(pl.Callback):
|
370 |
+
def __init__(self, hparams, dataset_info_file, debug=False):
|
371 |
+
super().__init__()
|
372 |
+
|
373 |
+
assert hparams.get('heads_spec') is not None, "heads_spec must be defined in the hparams"
|
374 |
+
self.tiers_dict = {k: v for k, v in hparams.get('heads_spec').items() if v.get('is_metrics_tier', False)}
|
375 |
+
self.last_tier_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_last_tier', False)), None)
|
376 |
+
self.final_head_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_final_head', False)), None)
|
377 |
+
|
378 |
+
assert self.last_tier_name is not None, "No tier found with 'is_last_tier' set to True"
|
379 |
+
assert self.final_head_name is not None, "No head found with 'is_final_head' set to True"
|
380 |
+
|
381 |
+
self.tiers = list(self.tiers_dict.keys())
|
382 |
+
self.phases = ['train', 'val', 'test']
|
383 |
+
self.modes = ['pixelwise', 'majority']
|
384 |
+
self.debug = debug
|
385 |
+
|
386 |
+
if debug:
|
387 |
+
print(f"Final head identified as: {self.final_head_name}")
|
388 |
+
print(f"LogConfusionMatrix Metrics over | Phases: {self.phases}, Tiers: {self.tiers}, Modes: {self.modes}")
|
389 |
+
|
390 |
+
with open(dataset_info_file, 'r') as f:
|
391 |
+
self.dataset_info = json.load(f)
|
392 |
+
|
393 |
+
# Initialize confusion matrices
|
394 |
+
self.metrics_to_compute = ['confusion_matrix']
|
395 |
+
self.metrics = {phase: {tier: {mode: self.__init_metrics(tier, phase) for mode in self.modes} for tier in self.tiers} for phase in self.phases}
|
396 |
+
|
397 |
+
def __init_metrics(self, tier, phase):
|
398 |
+
num_classes = self.tiers_dict[tier]['num_classes_to_predict']
|
399 |
+
confusion_matrix = classification.MulticlassConfusionMatrix(num_classes=num_classes)
|
400 |
+
|
401 |
+
return {
|
402 |
+
'confusion_matrix': confusion_matrix
|
403 |
+
}
|
404 |
+
|
405 |
+
def setup(self, trainer, pl_module, stage=None):
|
406 |
+
# Move all metrics to the correct device at the start of the training/validation
|
407 |
+
device = pl_module.device
|
408 |
+
for phase_metrics in self.metrics.values():
|
409 |
+
for tier_metrics in phase_metrics.values():
|
410 |
+
for mode_metrics in tier_metrics.values():
|
411 |
+
for metric in self.metrics_to_compute:
|
412 |
+
mode_metrics[metric].to(device)
|
413 |
+
|
414 |
+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
415 |
+
self.__update_confusion_matrices(trainer, pl_module, outputs, batch, batch_idx, 'train')
|
416 |
+
|
417 |
+
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
418 |
+
self.__update_confusion_matrices(trainer, pl_module, outputs, batch, batch_idx, 'val')
|
419 |
+
|
420 |
+
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
421 |
+
self.__update_confusion_matrices(trainer, pl_module, outputs, batch, batch_idx, 'test')
|
422 |
+
|
423 |
+
def __update_confusion_matrices(self, trainer, pl_module, outputs, batch, batch_idx, phase):
|
424 |
+
if trainer.sanity_checking:
|
425 |
+
return
|
426 |
+
|
427 |
+
targets = torch.stack(batch[1][0]) # (tiers, batch, H, W)
|
428 |
+
outputs = outputs['outputs'][self.final_head_name] # (batch, C, H, W)
|
429 |
+
field_ids = batch[1][1].permute(1, 0, 2, 3)[0]
|
430 |
+
|
431 |
+
pixelwise_outputs, majority_outputs = LogConfusionMatrix.get_pixelwise_and_majority_outputs(outputs, self.tiers, field_ids, self.dataset_info)
|
432 |
+
|
433 |
+
for preds, mode in zip([pixelwise_outputs, majority_outputs], self.modes):
|
434 |
+
# Update all metrics
|
435 |
+
assert len(preds) == len(targets), f"Number of predictions and targets do not match: {len(preds)} vs {len(targets)}"
|
436 |
+
assert len(preds) == len(self.tiers), f"Number of predictions and tiers do not match: {len(preds)} vs {len(self.tiers)}"
|
437 |
+
|
438 |
+
for pred, target, tier in zip(preds, targets, self.tiers):
|
439 |
+
if self.debug:
|
440 |
+
print(f"Updating confusion matrix for {phase} {tier} {mode}")
|
441 |
+
metrics = self.metrics[phase][tier][mode]
|
442 |
+
# flatten and remove background class if the mode is majority (such that the background class is not included in the confusion matrix)
|
443 |
+
if mode == 'majority':
|
444 |
+
pred = pred[target != 0]
|
445 |
+
target = target[target != 0]
|
446 |
+
metrics['confusion_matrix'].update(pred, target)
|
447 |
+
|
448 |
+
|
449 |
+
@staticmethod
|
450 |
+
def get_pixelwise_and_majority_outputs(refinement_head_outputs, tiers, field_ids, dataset_info):
|
451 |
+
"""
|
452 |
+
Get the pixelwise and majority predictions from the model outputs.
|
453 |
+
The pixelwise tier predictions are derived from the refinement_head_outputs predictions.
|
454 |
+
The majority last tier predictions are derived from the refinement_head_outputs. And then the majority lower-tier predictions are derived from the majority highest-tier predictions.
|
455 |
+
|
456 |
+
Also sets the background to 0 for all field majority predictions (regardless of what the model predicts for the background class).
|
457 |
+
As this is a classification task and not a segmentation task and the field boundaries are known beforehand and not of any interest.
|
458 |
+
|
459 |
+
Args:
|
460 |
+
refinement_head_outputs (torch.Tensor(batch, C, H, W)): The probability outputs from the model for the refined tier.
|
461 |
+
tiers (list of str): List of tiers e.g. ['tier1', 'tier2', 'tier3'].
|
462 |
+
field_ids (torch.Tensor(batch, H, W)): The field IDs for each prediction.
|
463 |
+
dataset_info (dict): The dataset information.
|
464 |
+
|
465 |
+
Returns:
|
466 |
+
torch.Tensor(tiers, batch, H, W): The pixelwise predictions.
|
467 |
+
torch.Tensor(tiers, batch, H, W): The majority predictions.
|
468 |
+
"""
|
469 |
+
|
470 |
+
# Assuming the highest tier is the last one in the list
|
471 |
+
highest_tier = tiers[-1]
|
472 |
+
|
473 |
+
pixelwise_highest_tier = torch.softmax(refinement_head_outputs, dim=1).argmax(dim=1) # (batch, H, W)
|
474 |
+
majority_highest_tier = LogConfusionMatrix.get_field_majority_preds(refinement_head_outputs, field_ids)
|
475 |
+
|
476 |
+
tier_mapping = {tier: dataset_info[f'{highest_tier}_to_{tier}'] for tier in tiers if tier != highest_tier}
|
477 |
+
|
478 |
+
pixelwise_outputs = {highest_tier: pixelwise_highest_tier}
|
479 |
+
majority_outputs = {highest_tier: majority_highest_tier}
|
480 |
+
|
481 |
+
# Initialize pixelwise and majority outputs for each tier
|
482 |
+
for tier in tiers:
|
483 |
+
if tier != highest_tier:
|
484 |
+
pixelwise_outputs[tier] = torch.zeros_like(pixelwise_highest_tier)
|
485 |
+
majority_outputs[tier] = torch.zeros_like(majority_highest_tier)
|
486 |
+
|
487 |
+
# Map the highest tier to lower tiers
|
488 |
+
for i, mappings in enumerate(zip(*tier_mapping.values())):
|
489 |
+
for j, tier in enumerate(tier_mapping.keys()):
|
490 |
+
pixelwise_outputs[tier][pixelwise_highest_tier == i] = mappings[j]
|
491 |
+
majority_outputs[tier][majority_highest_tier == i] = mappings[j]
|
492 |
+
|
493 |
+
pixelwise_outputs_stacked = torch.stack([pixelwise_outputs[tier] for tier in tiers])
|
494 |
+
majority_outputs_stacked = torch.stack([majority_outputs[tier] for tier in tiers])
|
495 |
+
|
496 |
+
# Ensure these are tensors
|
497 |
+
assert isinstance(pixelwise_outputs_stacked, torch.Tensor), "pixelwise_outputs_stacked is not a tensor"
|
498 |
+
assert isinstance(majority_outputs_stacked, torch.Tensor), "majority_outputs_stacked is not a tensor"
|
499 |
+
|
500 |
+
return pixelwise_outputs_stacked, majority_outputs_stacked
|
501 |
+
|
502 |
+
|
503 |
+
@staticmethod
|
504 |
+
def get_field_majority_preds(output, field_ids):
|
505 |
+
"""
|
506 |
+
Get the majority prediction for each field in the batch. The majority excludes the background class.
|
507 |
+
|
508 |
+
Args:
|
509 |
+
output (torch.Tensor(batch, C, H, W)): The probability outputs from the model (tier3_refined)
|
510 |
+
field_ids (torch.Tensor(batch, H, W)): The field IDs for each prediction.
|
511 |
+
|
512 |
+
Returns:
|
513 |
+
torch.Tensor(batch, H, W): The majority predictions.
|
514 |
+
"""
|
515 |
+
# remove the background class
|
516 |
+
pixelwise = torch.softmax(output[:, 1:, :, :], dim=1).argmax(dim=1) + 1 # (batch, H, W)
|
517 |
+
majority_preds = torch.zeros_like(pixelwise)
|
518 |
+
for batch in range(len(pixelwise)):
|
519 |
+
field_ids_batch = field_ids[batch]
|
520 |
+
for field_id in np.unique(field_ids_batch.cpu().numpy()):
|
521 |
+
if field_id == 0:
|
522 |
+
continue
|
523 |
+
field_mask = field_ids_batch == field_id
|
524 |
+
flattened_pred = pixelwise[batch][field_mask].view(-1) # Flatten the prediction
|
525 |
+
flattened_pred = flattened_pred[flattened_pred != 0] # Exclude background class
|
526 |
+
if len(flattened_pred) == 0:
|
527 |
+
continue
|
528 |
+
mode_pred, _ = torch.mode(flattened_pred) # Compute mode prediction
|
529 |
+
majority_preds[batch][field_mask] = mode_pred.item()
|
530 |
+
return majority_preds
|
531 |
+
|
532 |
+
def on_train_epoch_end(self, trainer, pl_module):
|
533 |
+
# Log and then reset the confusion matrices after training epoch
|
534 |
+
self.__log_and_reset_confusion_matrices(trainer, pl_module, 'train')
|
535 |
+
|
536 |
+
def on_validation_epoch_end(self, trainer, pl_module):
|
537 |
+
# Log and then reset the confusion matrices after validation epoch
|
538 |
+
self.__log_and_reset_confusion_matrices(trainer, pl_module, 'val')
|
539 |
+
|
540 |
+
def on_test_epoch_end(self, trainer, pl_module):
|
541 |
+
# Log and then reset the confusion matrices after test epoch
|
542 |
+
self.__log_and_reset_confusion_matrices(trainer, pl_module, 'test')
|
543 |
+
|
544 |
+
def __log_and_reset_confusion_matrices(self, trainer, pl_module, phase):
|
545 |
+
if trainer.sanity_checking:
|
546 |
+
return
|
547 |
+
|
548 |
+
for tier in self.tiers:
|
549 |
+
for mode in self.modes:
|
550 |
+
metrics = self.metrics[phase][tier][mode]
|
551 |
+
confusion_matrix = metrics['confusion_matrix']
|
552 |
+
if self.debug:
|
553 |
+
print(f"Logging and resetting confusion matrix for {phase} {tier} Update count: {confusion_matrix._update_count}")
|
554 |
+
matrix = confusion_matrix.compute() # columns are predictions and rows are targets
|
555 |
+
|
556 |
+
# Calculate percentages
|
557 |
+
matrix = matrix.float()
|
558 |
+
row_sums = matrix.sum(dim=1, keepdim=True)
|
559 |
+
matrix_percent = matrix / row_sums
|
560 |
+
|
561 |
+
# Ensure percentages sum to 1 for each row or handle NaNs
|
562 |
+
row_sum_check = matrix_percent.sum(dim=1)
|
563 |
+
valid_rows = ~torch.isnan(row_sum_check)
|
564 |
+
if valid_rows.any():
|
565 |
+
assert torch.allclose(row_sum_check[valid_rows], torch.ones_like(row_sum_check[valid_rows]), atol=1e-2), "Percentages do not sum to 1 for some valid rows"
|
566 |
+
|
567 |
+
# Sort the matrix and labels by the total number of instances
|
568 |
+
sorted_indices = row_sums.squeeze().argsort(descending=True)
|
569 |
+
matrix_percent = matrix_percent[sorted_indices, :] # sort rows
|
570 |
+
matrix_percent = matrix_percent[:, sorted_indices] # sort columns
|
571 |
+
class_labels = [self.dataset_info[tier][i] for i in sorted_indices]
|
572 |
+
row_sums_sorted = row_sums[sorted_indices]
|
573 |
+
|
574 |
+
# Check for zero rows after sorting
|
575 |
+
zero_rows = (row_sums_sorted == 0).squeeze()
|
576 |
+
|
577 |
+
fig, ax = plt.subplots(figsize=(matrix.size(0), matrix.size(0)), dpi=140)
|
578 |
+
|
579 |
+
ax.matshow(matrix_percent.cpu().numpy(), cmap='viridis')
|
580 |
+
|
581 |
+
ax.xaxis.set_major_locator(ticker.FixedLocator(range(matrix.size(1) + 1)))
|
582 |
+
ax.yaxis.set_major_locator(ticker.FixedLocator(range(matrix.size(0) + 1)))
|
583 |
+
|
584 |
+
ax.set_xticklabels(class_labels + [''], rotation=45)
|
585 |
+
ax.set_yticklabels(class_labels + [''])
|
586 |
+
|
587 |
+
# Add total number of instances to the y-axis labels
|
588 |
+
y_labels = [f'{class_labels[i]} [n={int(row_sums_sorted[i].item()):,.0f}]'.replace(',', "'") for i in range(matrix.size(0))]
|
589 |
+
ax.set_yticklabels(y_labels + [''])
|
590 |
+
|
591 |
+
ax.set_xlabel('Predictions')
|
592 |
+
ax.set_ylabel('Targets')
|
593 |
+
|
594 |
+
# Move x-axis label and ticks to the top
|
595 |
+
ax.xaxis.set_label_position('top')
|
596 |
+
ax.xaxis.set_ticks_position('top')
|
597 |
+
|
598 |
+
fig.tight_layout()
|
599 |
+
|
600 |
+
for i in range(matrix.size(0)):
|
601 |
+
for j in range(matrix.size(1)):
|
602 |
+
if zero_rows[i]:
|
603 |
+
ax.text(j, i, 'N/A', ha='center', va='center', color='black')
|
604 |
+
else:
|
605 |
+
ax.text(j, i, f'{matrix_percent[i, j]:.2f}', ha='center', va='center', color='#F88379', weight='bold') # coral red
|
606 |
+
trainer.logger.experiment.log({f"{phase}_{tier}_confusion_matrix_{mode}": wandb.Image(fig)})
|
607 |
+
plt.close()
|
608 |
+
confusion_matrix.reset()
|
609 |
+
|
610 |
+
class LogMessisMetrics(pl.Callback):
|
611 |
+
def __init__(self, hparams, dataset_info_file, debug=False):
|
612 |
+
super().__init__()
|
613 |
+
|
614 |
+
assert hparams.get('heads_spec') is not None, "heads_spec must be defined in the hparams"
|
615 |
+
self.tiers_dict = {k: v for k, v in hparams.get('heads_spec').items() if v.get('is_metrics_tier', False)}
|
616 |
+
self.last_tier_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_last_tier', False)), None)
|
617 |
+
self.final_head_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_final_head', False)), None)
|
618 |
+
|
619 |
+
assert self.last_tier_name is not None, "No tier found with 'is_last_tier' set to True"
|
620 |
+
assert self.final_head_name is not None, "No head found with 'is_final_head' set to True"
|
621 |
+
|
622 |
+
self.tiers = list(self.tiers_dict.keys())
|
623 |
+
self.phases = ['train', 'val', 'test']
|
624 |
+
self.modes = ['pixelwise', 'majority']
|
625 |
+
self.debug = debug
|
626 |
+
|
627 |
+
if debug:
|
628 |
+
print(f"Last tier identified as: {self.last_tier_name}")
|
629 |
+
print(f"Final head identified as: {self.final_head_name}")
|
630 |
+
print(f"LogMessisMetrics Metrics over | Phases: {self.phases}, Tiers: {self.tiers}, Modes: {self.modes}")
|
631 |
+
|
632 |
+
with open(dataset_info_file, 'r') as f:
|
633 |
+
self.dataset_info = json.load(f)
|
634 |
+
|
635 |
+
# Initialize metrics
|
636 |
+
self.metrics_to_compute = ['accuracy', 'weighted_accuracy', 'precision', 'weighted_precision', 'recall', 'weighted_recall', 'f1', 'weighted_f1', 'cohen_kappa']
|
637 |
+
self.metrics = {phase: {tier: {mode: self.__init_metrics(tier, phase) for mode in self.modes} for tier in self.tiers} for phase in self.phases}
|
638 |
+
self.images_to_log = {phase: {mode: None for mode in self.modes} for phase in self.phases}
|
639 |
+
self.images_to_log_targets = {phase: None for phase in self.phases}
|
640 |
+
self.field_ids_to_log_targets = {phase: None for phase in self.phases}
|
641 |
+
self.inputs_to_log = {phase: None for phase in self.phases}
|
642 |
+
|
643 |
+
def __init_metrics(self, tier, phase):
|
644 |
+
num_classes = self.tiers_dict[tier]['num_classes_to_predict']
|
645 |
+
|
646 |
+
accuracy = classification.MulticlassAccuracy(num_classes=num_classes, average='macro')
|
647 |
+
weighted_accuracy = classification.MulticlassAccuracy(num_classes=num_classes, average='weighted')
|
648 |
+
per_class_accuracies = {
|
649 |
+
class_index: classification.BinaryAccuracy() for class_index in range(num_classes)
|
650 |
+
}
|
651 |
+
precision = classification.MulticlassPrecision(num_classes=num_classes, average='macro')
|
652 |
+
weighted_precision = classification.MulticlassPrecision(num_classes=num_classes, average='weighted')
|
653 |
+
recall = classification.MulticlassRecall(num_classes=num_classes, average='macro')
|
654 |
+
weighted_recall = classification.MulticlassRecall(num_classes=num_classes, average='weighted')
|
655 |
+
f1 = classification.MulticlassF1Score(num_classes=num_classes, average='macro')
|
656 |
+
weighted_f1 = classification.MulticlassF1Score(num_classes=num_classes, average='weighted')
|
657 |
+
cohen_kappa = classification.MulticlassCohenKappa(num_classes=num_classes)
|
658 |
+
|
659 |
+
return {
|
660 |
+
'accuracy': accuracy,
|
661 |
+
'weighted_accuracy': weighted_accuracy,
|
662 |
+
'per_class_accuracies': per_class_accuracies,
|
663 |
+
'precision': precision,
|
664 |
+
'weighted_precision': weighted_precision,
|
665 |
+
'recall': recall,
|
666 |
+
'weighted_recall': weighted_recall,
|
667 |
+
'f1': f1,
|
668 |
+
'weighted_f1': weighted_f1,
|
669 |
+
'cohen_kappa': cohen_kappa
|
670 |
+
}
|
671 |
+
|
672 |
+
def setup(self, trainer, pl_module, stage=None):
|
673 |
+
# Move all metrics to the correct device at the start of the training/validation
|
674 |
+
device = pl_module.device
|
675 |
+
for phase_metrics in self.metrics.values():
|
676 |
+
for tier_metrics in phase_metrics.values():
|
677 |
+
for mode_metrics in tier_metrics.values():
|
678 |
+
for metric in self.metrics_to_compute:
|
679 |
+
mode_metrics[metric].to(device)
|
680 |
+
for class_accuracy in mode_metrics['per_class_accuracies'].values():
|
681 |
+
class_accuracy.to(device)
|
682 |
+
|
683 |
+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
684 |
+
self.__on_batch_end(trainer, pl_module, outputs, batch, batch_idx, 'train')
|
685 |
+
|
686 |
+
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
687 |
+
self.__on_batch_end(trainer, pl_module, outputs, batch, batch_idx, 'val')
|
688 |
+
|
689 |
+
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
690 |
+
self.__on_batch_end(trainer, pl_module, outputs, batch, batch_idx, 'test')
|
691 |
+
|
692 |
+
def __on_batch_end(self, trainer: pl.Trainer, pl_module, outputs, batch, batch_idx, phase):
|
693 |
+
if trainer.sanity_checking:
|
694 |
+
return
|
695 |
+
if self.debug:
|
696 |
+
print(f"{phase} batch ended. Updating metrics...")
|
697 |
+
|
698 |
+
targets = torch.stack(batch[1][0]) # (tiers, batch, H, W)
|
699 |
+
outputs = outputs['outputs'][self.final_head_name] # (batch, C, H, W)
|
700 |
+
field_ids = batch[1][1].permute(1, 0, 2, 3)[0]
|
701 |
+
|
702 |
+
pixelwise_outputs, majority_outputs = LogConfusionMatrix.get_pixelwise_and_majority_outputs(outputs, self.tiers, field_ids, self.dataset_info)
|
703 |
+
|
704 |
+
for preds, mode in zip([pixelwise_outputs, majority_outputs], self.modes):
|
705 |
+
|
706 |
+
# Update all metrics
|
707 |
+
assert preds.shape == targets.shape, f"Shapes of predictions and targets do not match: {preds.shape} vs {targets.shape}"
|
708 |
+
assert preds.shape[0] == len(self.tiers), f"Number of tiers in predictions and tiers do not match: {preds.shape[0]} vs {len(self.tiers)}"
|
709 |
+
|
710 |
+
self.images_to_log[phase][mode] = preds[-1]
|
711 |
+
|
712 |
+
for pred, target, tier in zip(preds, targets, self.tiers):
|
713 |
+
# flatten and remove background class if the mode is majority (such that the background class is not considered in the metrics)
|
714 |
+
if mode == 'majority':
|
715 |
+
pred = pred[target != 0]
|
716 |
+
target = target[target != 0]
|
717 |
+
metrics = self.metrics[phase][tier][mode]
|
718 |
+
for metric in self.metrics_to_compute:
|
719 |
+
metrics[metric].update(pred, target)
|
720 |
+
if self.debug:
|
721 |
+
print(f"{phase} {tier} {mode} {metric} updated. Update count: {metrics[metric]._update_count}")
|
722 |
+
self.__update_per_class_metrics(pred, target, metrics['per_class_accuracies'])
|
723 |
+
|
724 |
+
self.images_to_log_targets[phase] = targets[-1]
|
725 |
+
self.field_ids_to_log_targets[phase] = field_ids
|
726 |
+
self.inputs_to_log[phase] = batch[0]
|
727 |
+
|
728 |
+
def __update_per_class_metrics(self, preds, targets, per_class_accuracies):
|
729 |
+
for class_index, class_accuracy in per_class_accuracies.items():
|
730 |
+
if not (targets == class_index).any():
|
731 |
+
continue
|
732 |
+
|
733 |
+
if class_index == 0:
|
734 |
+
# Mask out non-background elements for background class (0)
|
735 |
+
class_mask = targets != 0
|
736 |
+
else:
|
737 |
+
# Mask out background elements for other classes
|
738 |
+
class_mask = targets == 0
|
739 |
+
|
740 |
+
preds_fields = preds[~class_mask]
|
741 |
+
targets_fields = targets[~class_mask]
|
742 |
+
|
743 |
+
# Prepare for binary classification (needs to be float)
|
744 |
+
preds_class = (preds_fields == class_index).float()
|
745 |
+
targets_class = (targets_fields == class_index).float()
|
746 |
+
|
747 |
+
class_accuracy.update(preds_class, targets_class)
|
748 |
+
|
749 |
+
if self.debug:
|
750 |
+
print(f"Shape of preds_fields: {preds_fields.shape}")
|
751 |
+
print(f"Shape of targets_fields: {targets_fields.shape}")
|
752 |
+
print(f"Unique values in preds_fields: {torch.unique(preds_fields)}")
|
753 |
+
print(f"Unique values in targets_fields: {torch.unique(targets_fields)}")
|
754 |
+
print(f"Per-class metrics for class {class_index} updated. Update count: {per_class_accuracies[class_index]._update_count}")
|
755 |
+
|
756 |
+
def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
757 |
+
self.__on_epoch_end(trainer, pl_module, 'train')
|
758 |
+
|
759 |
+
def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
760 |
+
self.__on_epoch_end(trainer, pl_module, 'val')
|
761 |
+
|
762 |
+
def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
763 |
+
self.__on_epoch_end(trainer, pl_module, 'test')
|
764 |
+
|
765 |
+
def __on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, phase):
|
766 |
+
if trainer.sanity_checking:
|
767 |
+
return # Skip during sanity check (avoid warning about metric compute being called before update)
|
768 |
+
for tier in self.tiers:
|
769 |
+
for mode in self.modes:
|
770 |
+
metrics = self.metrics[phase][tier][mode]
|
771 |
+
|
772 |
+
# Calculate and reset in tier: Accuracy, WeightedAccuracy, Precision, Recall, F1, Cohen's Kappa
|
773 |
+
metrics_dict = {metric: metrics[metric].compute() for metric in self.metrics_to_compute}
|
774 |
+
pl_module.log_dict({f"{phase}_{metric}_{tier}_{mode}": v for metric, v in metrics_dict.items()}, on_step=False, on_epoch=True)
|
775 |
+
for metric in self.metrics_to_compute:
|
776 |
+
metrics[metric].reset()
|
777 |
+
|
778 |
+
# Per-class metrics
|
779 |
+
# NOTE: Some literature reports "per class accuracy" but what they actually mean is "per class recall".
|
780 |
+
# Using the accuracy formula per class has no value in our imbalanced multi-class setting (TN's inflate scores!)
|
781 |
+
# We calculate all 4 metrics. This allows us to calculate any macro/micro score later if needed.
|
782 |
+
class_metrics = []
|
783 |
+
class_names_mapping = self.dataset_info[tier.split('_')[0] if '_refined' in tier else tier]
|
784 |
+
for class_index, class_accuracy in metrics['per_class_accuracies'].items():
|
785 |
+
if class_accuracy._update_count == 0:
|
786 |
+
continue # Skip if no updates have been made
|
787 |
+
tp, tn, fp, fn = class_accuracy.tp, class_accuracy.tn, class_accuracy.fp, class_accuracy.fn
|
788 |
+
recall = (tp / (tp + fn)).item() if tp + fn > 0 else 0
|
789 |
+
precision = (tp / (tp + fp)).item() if tp + fp > 0 else 0
|
790 |
+
f1 = (2 * (precision * recall) / (precision + recall)) if precision + recall > 0 else 0
|
791 |
+
n_of_class = (tp + fn).item()
|
792 |
+
class_metrics.append([class_index, class_names_mapping[class_index], precision, recall, f1, class_accuracy.compute().item(), n_of_class])
|
793 |
+
class_accuracy.reset()
|
794 |
+
wandb_table = wandb.Table(data=class_metrics, columns=["Class Index", "Class Name", "Precision", "Recall", "F1", "Accuracy", "N"])
|
795 |
+
trainer.logger.experiment.log({f"{phase}_per_class_metrics_{tier}_{mode}": wandb_table})
|
796 |
+
|
797 |
+
# use the same n_classes for all images, such that they are comparable
|
798 |
+
n_classes = max([
|
799 |
+
torch.max(self.images_to_log_targets[phase]),
|
800 |
+
torch.max(self.images_to_log[phase]["majority"]),
|
801 |
+
torch.max(self.images_to_log[phase]["pixelwise"])
|
802 |
+
])
|
803 |
+
images = [LogMessisMetrics.process_images(self.images_to_log[phase][mode], n_classes) for mode in self.modes]
|
804 |
+
images.append(LogMessisMetrics.create_positive_negative_image(self.images_to_log[phase]["majority"], self.images_to_log_targets[phase]))
|
805 |
+
images.append(LogMessisMetrics.process_images(self.images_to_log_targets[phase], n_classes))
|
806 |
+
images.append(LogMessisMetrics.process_images(self.field_ids_to_log_targets[phase].cpu()))
|
807 |
+
|
808 |
+
examples = []
|
809 |
+
for i in range(len(images[0])):
|
810 |
+
example = np.concatenate([img[i] for img in images], axis=0)
|
811 |
+
examples.append(wandb.Image(example, caption=f"From Top to Bottom: {self.modes[0]}, {self.modes[1]}, right/wrong classifications, target, fields"))
|
812 |
+
|
813 |
+
trainer.logger.experiment.log({f"{phase}_examples": examples})
|
814 |
+
|
815 |
+
# Log segmentation masks
|
816 |
+
batch_input_data = self.inputs_to_log[phase].cpu() # shape [BS, 6, N_TIMESTEPS, 224, 224]
|
817 |
+
ground_truth_masks = self.images_to_log_targets[phase].cpu().numpy()
|
818 |
+
pixel_wise_masks = self.images_to_log[phase]["pixelwise"].cpu().numpy()
|
819 |
+
field_majority_masks = self.images_to_log[phase]["majority"].cpu().numpy()
|
820 |
+
correctness_masks = self.create_positive_negative_segmentation_mask(field_majority_masks, ground_truth_masks)
|
821 |
+
class_labels = {idx: name for idx, name in enumerate(self.dataset_info[self.last_tier_name])}
|
822 |
+
|
823 |
+
segmentation_masks = []
|
824 |
+
for input_data, ground_truth_mask, pixel_wise_mask, field_majority_mask, correctness_mask in zip(batch_input_data, ground_truth_masks, pixel_wise_masks, field_majority_masks, correctness_masks):
|
825 |
+
middle_timestep_index = input_data.shape[1] // 2 # Get the middle timestamp index
|
826 |
+
gamma = 2.5 # Gamma for brightness adjustment
|
827 |
+
rgb_image = input_data[:3, middle_timestep_index, :, :].permute(1, 2, 0).numpy() # Shape [224, 224, 3]
|
828 |
+
rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min())
|
829 |
+
rgb_image = np.power(rgb_image, 1.0 / gamma)
|
830 |
+
rgb_image = (rgb_image * 255).astype(np.uint8)
|
831 |
+
|
832 |
+
mask_img = wandb.Image(
|
833 |
+
rgb_image,
|
834 |
+
masks={
|
835 |
+
"predictions_pixel_wise": {"mask_data": pixel_wise_mask, "class_labels": class_labels},
|
836 |
+
"predictions_field_majority": {"mask_data": field_majority_mask, "class_labels": class_labels},
|
837 |
+
"ground_truth": {"mask_data": ground_truth_mask, "class_labels": class_labels},
|
838 |
+
"correctness": {"mask_data": correctness_mask, "class_labels": { 0: "Background", 1: "Wrong", 2: "Right" }},
|
839 |
+
},
|
840 |
+
)
|
841 |
+
segmentation_masks.append(mask_img)
|
842 |
+
|
843 |
+
trainer.logger.experiment.log({f"{phase}_segmentation_mask": segmentation_masks})
|
844 |
+
|
845 |
+
if self.debug:
|
846 |
+
print(f"{phase} epoch ended. Logging & resetting metrics...", trainer.sanity_checking)
|
847 |
+
|
848 |
+
@staticmethod
|
849 |
+
def create_positive_negative_segmentation_mask(field_majority_masks, ground_truth_masks):
|
850 |
+
"""
|
851 |
+
Create a tensor that shows the positive and negative classifications of the model.
|
852 |
+
|
853 |
+
Args:
|
854 |
+
field_majority_masks (np.ndarray): The field majority masks generated by the model.
|
855 |
+
ground_truth_masks (np.ndarray): The ground truth masks.
|
856 |
+
|
857 |
+
Returns:
|
858 |
+
np.ndarray: An array with values:
|
859 |
+
- 0 where the target is 0,
|
860 |
+
- 2 where the prediction matches the target,
|
861 |
+
- 1 where the prediction does not match the target.
|
862 |
+
"""
|
863 |
+
correctness_mask = np.zeros_like(ground_truth_masks, dtype=int)
|
864 |
+
|
865 |
+
matches = (field_majority_masks == ground_truth_masks) & (ground_truth_masks != 0)
|
866 |
+
correctness_mask[matches] = 2
|
867 |
+
|
868 |
+
mismatches = (field_majority_masks != ground_truth_masks) & (ground_truth_masks != 0)
|
869 |
+
correctness_mask[mismatches] = 1
|
870 |
+
|
871 |
+
return correctness_mask
|
872 |
+
|
873 |
+
@staticmethod
|
874 |
+
def create_positive_negative_image(generated_images, target_images):
|
875 |
+
"""
|
876 |
+
Create an image that shows the positive and negative classifications of the model.
|
877 |
+
|
878 |
+
Args:
|
879 |
+
generated_images (torch.Tensor): The images generated by the model.
|
880 |
+
target_images (torch.Tensor): The target images.
|
881 |
+
|
882 |
+
Returns:
|
883 |
+
list: A list of processed images.
|
884 |
+
"""
|
885 |
+
classification_masks = generated_images == target_images
|
886 |
+
processed_imgs = []
|
887 |
+
for mask, target in zip(classification_masks, target_images):
|
888 |
+
# color the background white, right classifications green, wrong classifications red
|
889 |
+
colored_img = torch.zeros((mask.shape[0], mask.shape[1], 3), dtype=torch.uint8)
|
890 |
+
mask = mask.bool() # Convert to boolean tensor
|
891 |
+
colored_img[mask] = torch.tensor([0, 255, 0], dtype=torch.uint8)
|
892 |
+
colored_img[~mask] = torch.tensor([255, 0, 0], dtype=torch.uint8)
|
893 |
+
colored_img[target == 0] = torch.tensor([0, 0, 0], dtype=torch.uint8)
|
894 |
+
processed_imgs.append(colored_img.cpu())
|
895 |
+
return processed_imgs
|
896 |
+
|
897 |
+
@staticmethod
|
898 |
+
def process_images(imgs, max=None):
|
899 |
+
"""
|
900 |
+
Process a batch of images to be logged on wandb.
|
901 |
+
|
902 |
+
Args:
|
903 |
+
imgs (torch.Tensor): A batch of images with shape (B, H, W) to be processed.
|
904 |
+
max (float, optional): The maximum value to normalize the images. Defaults to None. If None, the maximum value in the batch is used.
|
905 |
+
"""
|
906 |
+
if max is None:
|
907 |
+
max = np.max(imgs.cpu().numpy())
|
908 |
+
normalized_img = imgs / max
|
909 |
+
processed_imgs = []
|
910 |
+
for img in normalized_img.cpu().numpy():
|
911 |
+
if max < 60:
|
912 |
+
cmap = ListedColormap(plt.get_cmap('tab20').colors + plt.get_cmap('tab20b').colors + plt.get_cmap('tab20c').colors)
|
913 |
+
else:
|
914 |
+
cmap = plt.get_cmap('viridis')
|
915 |
+
colored_img = cmap(img)
|
916 |
+
colored_img[img == 0] = [0, 0, 0, 1]
|
917 |
+
colored_img_uint8 = (colored_img[:, :, :3] * 255).astype(np.uint8)
|
918 |
+
processed_imgs.append(colored_img_uint8)
|
919 |
+
return processed_imgs
|
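The per-field majority vote above is easiest to see on a toy example. The sketch below is illustrative only (not part of the committed files) and uses small hand-made tensors; it reproduces the core idea of get_field_majority_preds: drop the background channel, take the per-pixel argmax, then assign every pixel of a field the most frequent prediction inside that field.

import torch

logits = torch.randn(1, 4, 4, 4)             # (batch, classes, H, W); class 0 is background
field_ids = torch.tensor([[[0, 0, 1, 1],
                           [0, 0, 1, 1],
                           [2, 2, 2, 2],
                           [2, 2, 2, 2]]])    # (batch, H, W); 0 marks "no field"

# Per-pixel prediction without the background channel (indices shifted back by +1)
pixelwise = torch.softmax(logits[:, 1:], dim=1).argmax(dim=1) + 1

# Per-field majority vote: every pixel of a field gets the field's mode prediction
majority = torch.zeros_like(pixelwise)
for b in range(pixelwise.shape[0]):
    for fid in field_ids[b].unique():
        if fid == 0:
            continue
        mask = field_ids[b] == fid
        majority[b][mask] = torch.mode(pixelwise[b][mask].flatten()).values

print(pixelwise[0])
print(majority[0])  # fields 1 and 2 are now spatially homogeneous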
messis/prithvi.py
ADDED
@@ -0,0 +1,555 @@
1 |
+
from safetensors import safe_open
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from timm.models.layers import to_2tuple
|
7 |
+
from timm.models.vision_transformer import Block
|
8 |
+
|
9 |
+
# Taken and adapted from Prithvi's `geospatial_fm.py`, for the purpose of avoiding MMCV/MMSegmentation dependencies
|
10 |
+
|
11 |
+
def _convTranspose2dOutput(
|
12 |
+
input_size: int,
|
13 |
+
stride: int,
|
14 |
+
padding: int,
|
15 |
+
dilation: int,
|
16 |
+
kernel_size: int,
|
17 |
+
output_padding: int,
|
18 |
+
):
|
19 |
+
"""
|
20 |
+
Calculate the output size of a ConvTranspose2d.
|
21 |
+
Taken from: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
|
22 |
+
"""
|
23 |
+
return (
|
24 |
+
(input_size - 1) * stride
|
25 |
+
- 2 * padding
|
26 |
+
+ dilation * (kernel_size - 1)
|
27 |
+
+ output_padding
|
28 |
+
+ 1
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor):
|
33 |
+
"""
|
34 |
+
embed_dim: output dimension for each position
|
35 |
+
pos: a list of positions to be encoded: size (M,)
|
36 |
+
out: (M, D)
|
37 |
+
"""
|
38 |
+
assert embed_dim % 2 == 0
|
39 |
+
omega = np.arange(embed_dim // 2, dtype=np.float32)
|
40 |
+
omega /= embed_dim / 2.0
|
41 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
42 |
+
|
43 |
+
pos = pos.reshape(-1) # (M,)
|
44 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
45 |
+
|
46 |
+
emb_sin = np.sin(out) # (M, D/2)
|
47 |
+
emb_cos = np.cos(out) # (M, D/2)
|
48 |
+
|
49 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
50 |
+
return emb
|
51 |
+
|
52 |
+
|
53 |
+
def get_3d_sincos_pos_embed(embed_dim: int, grid_size: tuple, cls_token: bool = False):
|
54 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
55 |
+
# All rights reserved.
|
56 |
+
|
57 |
+
# This source code is licensed under the license found in the
|
58 |
+
# LICENSE file in the root directory of this source tree.
|
59 |
+
# --------------------------------------------------------
|
60 |
+
# Position embedding utils
|
61 |
+
# --------------------------------------------------------
|
62 |
+
"""
|
63 |
+
grid_size: 3d tuple of grid size: t, h, w
|
64 |
+
return:
|
65 |
+
pos_embed: L, D
|
66 |
+
"""
|
67 |
+
|
68 |
+
assert embed_dim % 16 == 0
|
69 |
+
|
70 |
+
t_size, h_size, w_size = grid_size
|
71 |
+
|
72 |
+
w_embed_dim = embed_dim // 16 * 6
|
73 |
+
h_embed_dim = embed_dim // 16 * 6
|
74 |
+
t_embed_dim = embed_dim // 16 * 4
|
75 |
+
|
76 |
+
w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
|
77 |
+
h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
|
78 |
+
t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
|
79 |
+
|
80 |
+
w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
|
81 |
+
h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
|
82 |
+
t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
|
83 |
+
|
84 |
+
pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
|
85 |
+
|
86 |
+
if cls_token:
|
87 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
88 |
+
return pos_embed
|
89 |
+
|
90 |
+
|
91 |
+
class Norm2d(nn.Module):
|
92 |
+
def __init__(self, embed_dim: int):
|
93 |
+
super().__init__()
|
94 |
+
self.ln = nn.LayerNorm(embed_dim, eps=1e-6)
|
95 |
+
|
96 |
+
def forward(self, x):
|
97 |
+
x = x.permute(0, 2, 3, 1)
|
98 |
+
x = self.ln(x)
|
99 |
+
x = x.permute(0, 3, 1, 2).contiguous()
|
100 |
+
return x
|
101 |
+
|
102 |
+
|
103 |
+
class PatchEmbed(nn.Module):
|
104 |
+
"""Frames of 2D Images to Patch Embedding
|
105 |
+
The 3D version of timm.models.vision_transformer.PatchEmbed
|
106 |
+
"""
|
107 |
+
|
108 |
+
def __init__(
|
109 |
+
self,
|
110 |
+
img_size: int = 224,
|
111 |
+
patch_size: int = 16,
|
112 |
+
num_frames: int = 3,
|
113 |
+
tubelet_size: int = 1,
|
114 |
+
in_chans: int = 3,
|
115 |
+
embed_dim: int = 768,
|
116 |
+
norm_layer: nn.Module = None,
|
117 |
+
flatten: bool = True,
|
118 |
+
bias: bool = True,
|
119 |
+
):
|
120 |
+
super().__init__()
|
121 |
+
img_size = to_2tuple(img_size)
|
122 |
+
patch_size = to_2tuple(patch_size)
|
123 |
+
self.img_size = img_size
|
124 |
+
self.patch_size = patch_size
|
125 |
+
self.num_frames = num_frames
|
126 |
+
self.tubelet_size = tubelet_size
|
127 |
+
self.grid_size = (
|
128 |
+
num_frames // tubelet_size,
|
129 |
+
img_size[0] // patch_size[0],
|
130 |
+
img_size[1] // patch_size[1],
|
131 |
+
)
|
132 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
|
133 |
+
self.flatten = flatten
|
134 |
+
|
135 |
+
self.proj = nn.Conv3d(
|
136 |
+
in_chans,
|
137 |
+
embed_dim,
|
138 |
+
kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
|
139 |
+
stride=(tubelet_size, patch_size[0], patch_size[1]),
|
140 |
+
bias=bias,
|
141 |
+
)
|
142 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
143 |
+
|
144 |
+
def forward(self, x):
|
145 |
+
B, C, T, H, W = x.shape
|
146 |
+
assert (
|
147 |
+
H == self.img_size[0]
|
148 |
+
), f"Input image height ({H}) doesn't match model ({self.img_size[0]})."
|
149 |
+
assert (
|
150 |
+
W == self.img_size[1]
|
151 |
+
), f"Input image width ({W}) doesn't match model ({self.img_size[1]})."
|
152 |
+
x = self.proj(x)
|
153 |
+
Hp, Wp = x.shape[3], x.shape[4]
|
154 |
+
if self.flatten:
|
155 |
+
x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
|
156 |
+
x = self.norm(x)
|
157 |
+
return x, Hp, Wp
|
158 |
+
|
159 |
+
|
160 |
+
class ConvTransformerTokensToEmbeddingNeck(nn.Module):
|
161 |
+
"""
|
162 |
+
Neck that transforms the token-based output of the transformer into a single embedding suitable for processing with standard layers.
|
163 |
+
Performs 4 ConvTranspose2d operations on the rearranged input with kernel_size=2 and stride=2
|
164 |
+
"""
|
165 |
+
|
166 |
+
def __init__(
|
167 |
+
self,
|
168 |
+
embed_dim: int,
|
169 |
+
output_embed_dim: int,
|
170 |
+
# num_frames: int = 1,
|
171 |
+
Hp: int = 14,
|
172 |
+
Wp: int = 14,
|
173 |
+
drop_cls_token: bool = True,
|
174 |
+
):
|
175 |
+
"""
|
176 |
+
|
177 |
+
Args:
|
178 |
+
embed_dim (int): Input embedding dimension
|
179 |
+
output_embed_dim (int): Output embedding dimension
|
180 |
+
Hp (int, optional): Height (in patches) of embedding to be upscaled. Defaults to 14.
|
181 |
+
Wp (int, optional): Width (in patches) of embedding to be upscaled. Defaults to 14.
|
182 |
+
drop_cls_token (bool, optional): Whether there is a cls_token, which should be dropped. This assumes the cls token is the first token. Defaults to True.
|
183 |
+
"""
|
184 |
+
super().__init__()
|
185 |
+
self.drop_cls_token = drop_cls_token
|
186 |
+
self.Hp = Hp
|
187 |
+
self.Wp = Wp
|
188 |
+
self.H_out = Hp
|
189 |
+
self.W_out = Wp
|
190 |
+
# self.num_frames = num_frames
|
191 |
+
|
192 |
+
kernel_size = 2
|
193 |
+
stride = 2
|
194 |
+
dilation = 1
|
195 |
+
padding = 0
|
196 |
+
output_padding = 0
|
197 |
+
for _ in range(4):
|
198 |
+
self.H_out = _convTranspose2dOutput(
|
199 |
+
self.H_out, stride, padding, dilation, kernel_size, output_padding
|
200 |
+
)
|
201 |
+
self.W_out = _convTranspose2dOutput(
|
202 |
+
self.W_out, stride, padding, dilation, kernel_size, output_padding
|
203 |
+
)
|
204 |
+
|
205 |
+
self.embed_dim = embed_dim
|
206 |
+
self.output_embed_dim = output_embed_dim
|
207 |
+
self.fpn1 = nn.Sequential(
|
208 |
+
nn.ConvTranspose2d(
|
209 |
+
self.embed_dim,
|
210 |
+
self.output_embed_dim,
|
211 |
+
kernel_size=kernel_size,
|
212 |
+
stride=stride,
|
213 |
+
dilation=dilation,
|
214 |
+
padding=padding,
|
215 |
+
output_padding=output_padding,
|
216 |
+
),
|
217 |
+
Norm2d(self.output_embed_dim),
|
218 |
+
nn.GELU(),
|
219 |
+
nn.ConvTranspose2d(
|
220 |
+
self.output_embed_dim,
|
221 |
+
self.output_embed_dim,
|
222 |
+
kernel_size=kernel_size,
|
223 |
+
stride=stride,
|
224 |
+
dilation=dilation,
|
225 |
+
padding=padding,
|
226 |
+
output_padding=output_padding,
|
227 |
+
),
|
228 |
+
)
|
229 |
+
self.fpn2 = nn.Sequential(
|
230 |
+
nn.ConvTranspose2d(
|
231 |
+
self.output_embed_dim,
|
232 |
+
self.output_embed_dim,
|
233 |
+
kernel_size=kernel_size,
|
234 |
+
stride=stride,
|
235 |
+
dilation=dilation,
|
236 |
+
padding=padding,
|
237 |
+
output_padding=output_padding,
|
238 |
+
),
|
239 |
+
Norm2d(self.output_embed_dim),
|
240 |
+
nn.GELU(),
|
241 |
+
nn.ConvTranspose2d(
|
242 |
+
self.output_embed_dim,
|
243 |
+
self.output_embed_dim,
|
244 |
+
kernel_size=kernel_size,
|
245 |
+
stride=stride,
|
246 |
+
dilation=dilation,
|
247 |
+
padding=padding,
|
248 |
+
output_padding=output_padding,
|
249 |
+
),
|
250 |
+
)
|
251 |
+
|
252 |
+
def forward(self, x):
|
253 |
+
x = x[0]
|
254 |
+
if self.drop_cls_token:
|
255 |
+
x = x[:, 1:, :]
|
256 |
+
x = x.permute(0, 2, 1).reshape(x.shape[0], -1, self.Hp, self.Wp)
|
257 |
+
|
258 |
+
x = self.fpn1(x)
|
259 |
+
x = self.fpn2(x)
|
260 |
+
|
261 |
+
x = x.reshape((-1, self.output_embed_dim, self.H_out, self.W_out))
|
262 |
+
out = tuple([x])
|
263 |
+
return out
|
264 |
+
|
265 |
+
class ConvTransformerTokensToEmbeddingBottleneckNeck(nn.Module):
|
266 |
+
"""
|
267 |
+
Neck that transforms the token-based output of the transformer into a single embedding suitable for processing with standard layers.
|
268 |
+
Performs ConvTranspose2d operations with bottleneck layers to reduce channels.
|
269 |
+
"""
|
270 |
+
|
271 |
+
def __init__(
|
272 |
+
self,
|
273 |
+
embed_dim: int,
|
274 |
+
output_embed_dim: int,
|
275 |
+
Hp: int = 14,
|
276 |
+
Wp: int = 14,
|
277 |
+
drop_cls_token: bool = True,
|
278 |
+
bottleneck_reduction_factor: int = 4,
|
279 |
+
):
|
280 |
+
"""
|
281 |
+
Args:
|
282 |
+
embed_dim (int): Input embedding dimension
|
283 |
+
output_embed_dim (int): Output embedding dimension
|
284 |
+
Hp (int, optional): Height (in patches) of embedding to be upscaled. Defaults to 14.
|
285 |
+
Wp (int, optional): Width (in patches) of embedding to be upscaled. Defaults to 14.
|
286 |
+
drop_cls_token (bool, optional): Whether there is a cls_token, which should be dropped. Defaults to True.
|
287 |
+
bottleneck_reduction_factor (int, optional): Factor by which the channel count is reduced in the bottleneck layers. Defaults to 4.
|
288 |
+
"""
|
289 |
+
super().__init__()
|
290 |
+
self.drop_cls_token = drop_cls_token
|
291 |
+
self.Hp = Hp
|
292 |
+
self.Wp = Wp
|
293 |
+
self.H_out = Hp
|
294 |
+
self.W_out = Wp
|
295 |
+
|
296 |
+
kernel_size = 2
|
297 |
+
stride = 2
|
298 |
+
dilation = 1
|
299 |
+
padding = 0
|
300 |
+
output_padding = 0
|
301 |
+
for _ in range(4):
|
302 |
+
self.H_out = _convTranspose2dOutput(
|
303 |
+
self.H_out, stride, padding, dilation, kernel_size, output_padding
|
304 |
+
)
|
305 |
+
self.W_out = _convTranspose2dOutput(
|
306 |
+
self.W_out, stride, padding, dilation, kernel_size, output_padding
|
307 |
+
)
|
308 |
+
|
309 |
+
self.embed_dim = embed_dim
|
310 |
+
self.output_embed_dim = output_embed_dim
|
311 |
+
bottleneck_dim = self.embed_dim // bottleneck_reduction_factor
|
312 |
+
|
313 |
+
self.fpn1 = nn.Sequential(
|
314 |
+
nn.Conv2d(
|
315 |
+
self.embed_dim,
|
316 |
+
bottleneck_dim,
|
317 |
+
kernel_size=1
|
318 |
+
),
|
319 |
+
Norm2d(bottleneck_dim),
|
320 |
+
nn.GELU(),
|
321 |
+
nn.ConvTranspose2d(
|
322 |
+
bottleneck_dim,
|
323 |
+
bottleneck_dim,
|
324 |
+
kernel_size=kernel_size,
|
325 |
+
stride=stride,
|
326 |
+
padding=padding,
|
327 |
+
output_padding=output_padding
|
328 |
+
),
|
329 |
+
Norm2d(bottleneck_dim),
|
330 |
+
nn.GELU(),
|
331 |
+
nn.ConvTranspose2d(
|
332 |
+
bottleneck_dim,
|
333 |
+
bottleneck_dim,
|
334 |
+
kernel_size=kernel_size,
|
335 |
+
stride=stride,
|
336 |
+
padding=padding,
|
337 |
+
output_padding=output_padding
|
338 |
+
),
|
339 |
+
Norm2d(bottleneck_dim),
|
340 |
+
nn.GELU(),
|
341 |
+
nn.Conv2d(
|
342 |
+
bottleneck_dim,
|
343 |
+
self.output_embed_dim,
|
344 |
+
kernel_size=1
|
345 |
+
),
|
346 |
+
Norm2d(self.output_embed_dim),
|
347 |
+
nn.GELU(),
|
348 |
+
)
|
349 |
+
|
350 |
+
self.fpn2 = nn.Sequential(
|
351 |
+
nn.Conv2d(
|
352 |
+
self.output_embed_dim,
|
353 |
+
bottleneck_dim,
|
354 |
+
kernel_size=1
|
355 |
+
),
|
356 |
+
Norm2d(bottleneck_dim),
|
357 |
+
nn.GELU(),
|
358 |
+
nn.ConvTranspose2d(
|
359 |
+
bottleneck_dim,
|
360 |
+
bottleneck_dim,
|
361 |
+
kernel_size=kernel_size,
|
362 |
+
stride=stride,
|
363 |
+
padding=padding,
|
364 |
+
output_padding=output_padding
|
365 |
+
),
|
366 |
+
Norm2d(bottleneck_dim),
|
367 |
+
nn.GELU(),
|
368 |
+
nn.ConvTranspose2d(
|
369 |
+
bottleneck_dim,
|
370 |
+
bottleneck_dim,
|
371 |
+
kernel_size=kernel_size,
|
372 |
+
stride=stride,
|
373 |
+
padding=padding,
|
374 |
+
output_padding=output_padding
|
375 |
+
),
|
376 |
+
Norm2d(bottleneck_dim),
|
377 |
+
nn.GELU(),
|
378 |
+
nn.Conv2d(
|
379 |
+
bottleneck_dim,
|
380 |
+
self.output_embed_dim,
|
381 |
+
kernel_size=1
|
382 |
+
),
|
383 |
+
Norm2d(self.output_embed_dim),
|
384 |
+
nn.GELU(),
|
385 |
+
)
|
386 |
+
|
387 |
+
def forward(self, x):
|
388 |
+
x = x[0]
|
389 |
+
if self.drop_cls_token:
|
390 |
+
x = x[:, 1:, :]
|
391 |
+
x = x.permute(0, 2, 1).reshape(x.shape[0], -1, self.Hp, self.Wp)
|
392 |
+
|
393 |
+
x = self.fpn1(x)
|
394 |
+
x = self.fpn2(x)
|
395 |
+
|
396 |
+
x = x.reshape((-1, self.output_embed_dim, self.H_out, self.W_out))
|
397 |
+
out = tuple([x])
|
398 |
+
return out
|
399 |
+
|
400 |
+
class TemporalViTEncoder(nn.Module):
|
401 |
+
"""Encoder from an ViT with capability to take in temporal input.
|
402 |
+
|
403 |
+
This class defines an encoder taken from a ViT architecture.
|
404 |
+
"""
|
405 |
+
|
406 |
+
def __init__(
|
407 |
+
self,
|
408 |
+
img_size: int = 224,
|
409 |
+
patch_size: int = 16,
|
410 |
+
num_frames: int = 1,
|
411 |
+
tubelet_size: int = 1,
|
412 |
+
in_chans: int = 3,
|
413 |
+
embed_dim: int = 1024,
|
414 |
+
depth: int = 24,
|
415 |
+
num_heads: int = 16,
|
416 |
+
mlp_ratio: float = 4.0,
|
417 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
418 |
+
norm_pix_loss: bool = False,
|
419 |
+
pretrained: str = None,
|
420 |
+
debug=False
|
421 |
+
):
|
422 |
+
"""
|
423 |
+
|
424 |
+
Args:
|
425 |
+
img_size (int, optional): Input image size. Defaults to 224.
|
426 |
+
patch_size (int, optional): Patch size to be used by the transformer. Defaults to 16.
|
427 |
+
num_frames (int, optional): Number of frames (temporal dimension) to be input to the encoder. Defaults to 1.
|
428 |
+
tubelet_size (int, optional): Tubelet size used in patch embedding. Defaults to 1.
|
429 |
+
in_chans (int, optional): Number of input channels. Defaults to 3.
|
430 |
+
embed_dim (int, optional): Embedding dimension. Defaults to 1024.
|
431 |
+
depth (int, optional): Encoder depth. Defaults to 24.
|
432 |
+
num_heads (int, optional): Number of heads used in the encoder blocks. Defaults to 16.
|
433 |
+
mlp_ratio (float, optional): Ratio to be used for the size of the MLP in encoder blocks. Defaults to 4.0.
|
434 |
+
norm_layer (nn.Module, optional): Norm layer to be used. Defaults to nn.LayerNorm.
|
435 |
+
norm_pix_loss (bool, optional): Whether to use Norm Pix Loss. Defaults to False.
|
436 |
+
pretrained (str, optional): Path to pretrained encoder weights. Defaults to None.
|
437 |
+
"""
|
438 |
+
super().__init__()
|
439 |
+
|
440 |
+
# --------------------------------------------------------------------------
|
441 |
+
# MAE encoder specifics
|
442 |
+
self.embed_dim = embed_dim
|
443 |
+
self.patch_embed = PatchEmbed(
|
444 |
+
img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim
|
445 |
+
)
|
446 |
+
num_patches = self.patch_embed.num_patches
|
447 |
+
self.num_frames = num_frames
|
448 |
+
|
449 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
450 |
+
self.pos_embed = nn.Parameter(
|
451 |
+
torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False
|
452 |
+
) # fixed sin-cos embedding
|
453 |
+
|
454 |
+
self.blocks = nn.ModuleList(
|
455 |
+
[
|
456 |
+
Block(
|
457 |
+
embed_dim,
|
458 |
+
num_heads,
|
459 |
+
mlp_ratio,
|
460 |
+
qkv_bias=True,
|
461 |
+
norm_layer=norm_layer,
|
462 |
+
)
|
463 |
+
for _ in range(depth)
|
464 |
+
]
|
465 |
+
)
|
466 |
+
self.norm = norm_layer(embed_dim)
|
467 |
+
|
468 |
+
self.norm_pix_loss = norm_pix_loss
|
469 |
+
self.pretrained = pretrained
|
470 |
+
self.debug = debug
|
471 |
+
|
472 |
+
self.initialize_weights()
|
473 |
+
|
474 |
+
def initialize_weights(self):
|
475 |
+
# initialize (and freeze) pos_embed by sin-cos embedding
|
476 |
+
pos_embed = get_3d_sincos_pos_embed(
|
477 |
+
self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True
|
478 |
+
)
|
479 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
480 |
+
|
481 |
+
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
|
482 |
+
w = self.patch_embed.proj.weight.data
|
483 |
+
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
484 |
+
|
485 |
+
# TODO: FIX huggingface config
|
486 |
+
# load pretrained weights
|
487 |
+
# if self.pretrained:
|
488 |
+
# if self.pretrained.endswith('.safetensors'):
|
489 |
+
# self._load_safetensors_weights()
|
490 |
+
# elif self.pretrained == 'huggingface':
|
491 |
+
# print("TemporalViTEncoder | Using HuggingFace pretrained weights.")
|
492 |
+
# else:
|
493 |
+
# self._load_pt_weights()
|
494 |
+
# else:
|
495 |
+
# self.apply(self._init_weights)
|
496 |
+
|
497 |
+
def _load_safetensors_weights(self):
|
498 |
+
with safe_open(self.pretrained, framework='pt', device='cpu') as f:
|
499 |
+
checkpoint_state_dict = {k: f.get_tensor(k) for k in f.keys()}  # safe_open handles expose keys()/get_tensor(), not items()
|
500 |
+
missing_keys, unexpected_keys = self.load_state_dict(checkpoint_state_dict, strict=False)
|
501 |
+
if missing_keys:
|
502 |
+
print("TemporalViTEncoder | Warning: Missing keys in the state dict:", missing_keys)
|
503 |
+
if unexpected_keys:
|
504 |
+
print("TemporalViTEncoder | Warning: Unexpected keys in the state dict:", unexpected_keys)
|
505 |
+
print(f"TemporalViTEncoder | Loaded pretrained weights from '{self.pretrained}' (safetensors).")
|
506 |
+
|
507 |
+
def _load_pt_weights(self):
|
508 |
+
checkpoint = torch.load(self.pretrained, map_location='cpu')
|
509 |
+
checkpoint_state_dict = checkpoint.get('state_dict', checkpoint)
|
510 |
+
missing_keys, unexpected_keys = self.load_state_dict(checkpoint_state_dict, strict=False)
|
511 |
+
if missing_keys:
|
512 |
+
print("TemporalViTEncoder | Warning: Missing keys in the state dict:", missing_keys)
|
513 |
+
if unexpected_keys:
|
514 |
+
print("TemporalViTEncoder | Warning: Unexpected keys in the state dict:", unexpected_keys)
|
515 |
+
print(f"TemporalViTEncoder | Loaded pretrained weights from '{self.pretrained}' (pt file).")
|
516 |
+
|
517 |
+
def _init_weights(self, m):
|
518 |
+
print("TemporalViTEncoder | Newly Initializing weights...")
|
519 |
+
if isinstance(m, nn.Linear):
|
520 |
+
# we use xavier_uniform following official JAX ViT:
|
521 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
522 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
523 |
+
nn.init.constant_(m.bias, 0)
|
524 |
+
elif isinstance(m, nn.LayerNorm):
|
525 |
+
nn.init.constant_(m.bias, 0)
|
526 |
+
nn.init.constant_(m.weight, 1.0)
|
527 |
+
|
528 |
+
def forward(self, x):
|
529 |
+
if self.debug:
|
530 |
+
print('TemporalViTEncoder IN:', x.shape)
|
531 |
+
|
532 |
+
# embed patches
|
533 |
+
x, _, _ = self.patch_embed(x)
|
534 |
+
|
535 |
+
if self.debug:
|
536 |
+
print('TemporalViTEncoder EMBED:', x.shape)
|
537 |
+
|
538 |
+
# add pos embed w/o cls token
|
539 |
+
x = x + self.pos_embed[:, 1:, :]
|
540 |
+
|
541 |
+
# append cls token
|
542 |
+
cls_token = self.cls_token + self.pos_embed[:, :1, :]
|
543 |
+
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
|
544 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
545 |
+
|
546 |
+
# apply Transformer blocks
|
547 |
+
for blk in self.blocks:
|
548 |
+
x = blk(x)
|
549 |
+
|
550 |
+
x = self.norm(x)
|
551 |
+
|
552 |
+
if self.debug:
|
553 |
+
print('TemporalViTEncoder OUT:', x.shape)
|
554 |
+
|
555 |
+
return tuple([x])
|
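To make the tensor shapes above concrete, here is a rough, illustrative walk-through (not part of the repository) of the encoder plus upscaling neck. The hyperparameters (6 bands, 3 timesteps, embed_dim=768, depth=12, num_heads=12) and the choice of neck embed_dim = num_frames * encoder embed_dim are assumptions for the example, following what the reshape in the neck's forward() implies.

import torch
from messis.prithvi import TemporalViTEncoder, ConvTransformerTokensToEmbeddingNeck

encoder = TemporalViTEncoder(
    img_size=224, patch_size=16, num_frames=3, tubelet_size=1,
    in_chans=6, embed_dim=768, depth=12, num_heads=12,   # illustrative hyperparameters
)
neck = ConvTransformerTokensToEmbeddingNeck(
    embed_dim=768 * 3,        # tokens of all 3 frames end up stacked channel-wise in the reshape
    output_embed_dim=256,
    Hp=14, Wp=14,             # 224 / 16 patches per side
    drop_cls_token=True,
)

x = torch.randn(2, 6, 3, 224, 224)    # (B, C, T, H, W)
tokens = encoder(x)                   # tuple, tokens[0]: (2, 1 + 3*14*14, 768)
features = neck(tokens)               # tuple, features[0]: (2, 256, 224, 224) after four 2x upscales
print(tokens[0].shape, features[0].shape)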
pages/1_Select_Location.py
ADDED
@@ -0,0 +1,78 @@
1 |
+
import streamlit as st
|
2 |
+
from streamlit_folium import st_folium
|
3 |
+
import folium
|
4 |
+
from geopy.geocoders import Nominatim
|
5 |
+
|
6 |
+
# Define the bounding box
|
7 |
+
ZUERICH_BBOX = [8.364, 47.240, 9.0405, 47.69894]
|
8 |
+
|
9 |
+
def within_bbox(lat, lon, bbox):
|
10 |
+
"""Check if a point is within the given bounding box."""
|
11 |
+
return bbox[1] <= lat <= bbox[3] and bbox[0] <= lon <= bbox[2]
|
12 |
+
|
13 |
+
def select_coordinates():
|
14 |
+
st.title("Step 1: Select Location")
|
15 |
+
|
16 |
+
instructions = """
|
17 |
+
1. Choose a crop classification location. Search for a location or click on the map.
|
18 |
+
2. Proceed to the "Perform Crop Classification" step.
|
19 |
+
|
20 |
+
_Note:_ The location must be within the green ZüriCrop area.
|
21 |
+
"""
|
22 |
+
st.sidebar.header("Instructions")
|
23 |
+
st.sidebar.markdown(instructions)
|
24 |
+
|
25 |
+
# Initialize a map centered around the midpoint of the bounding box
|
26 |
+
midpoint_lat = (ZUERICH_BBOX[1] + ZUERICH_BBOX[3]) / 2
|
27 |
+
midpoint_lon = (ZUERICH_BBOX[0] + ZUERICH_BBOX[2]) / 2
|
28 |
+
m = folium.Map(location=[midpoint_lat, midpoint_lon], zoom_start=9)
|
29 |
+
|
30 |
+
# Add the bounding box to the map as a rectangle
|
31 |
+
folium.Rectangle(
|
32 |
+
bounds=[[ZUERICH_BBOX[1], ZUERICH_BBOX[0]], [ZUERICH_BBOX[3], ZUERICH_BBOX[2]]],
|
33 |
+
color="green",
|
34 |
+
fill=True,
|
35 |
+
fill_opacity=0.1
|
36 |
+
).add_to(m)
|
37 |
+
|
38 |
+
# Search for a location
|
39 |
+
geolocator = Nominatim(user_agent="streamlit-app")
|
40 |
+
location_query = st.text_input("Search for a location:")
|
41 |
+
|
42 |
+
if location_query:
|
43 |
+
location = geolocator.geocode(location_query)
|
44 |
+
if location:
|
45 |
+
lat, lon = location.latitude, location.longitude
|
46 |
+
folium.Marker([lat, lon], tooltip=location.address).add_to(m)
|
47 |
+
m.location = [lat, lon]
|
48 |
+
m.zoom_start = 12
|
49 |
+
|
50 |
+
if within_bbox(lat, lon, ZUERICH_BBOX):
|
51 |
+
st.success(f"Location found: {location.address}. It is within the bounding box.")
|
52 |
+
st.session_state["selected_location"] = (lat, lon)
|
53 |
+
else:
|
54 |
+
st.error(f"Location found: {location.address}. It is outside the bounding box.")
|
55 |
+
else:
|
56 |
+
st.error("Location not found. Please try again.")
|
57 |
+
|
58 |
+
# Add a click event listener to capture coordinates
|
59 |
+
m.add_child(folium.LatLngPopup())
|
60 |
+
|
61 |
+
# Display the map using streamlit-folium
|
62 |
+
st_data = st_folium(m, height=500, width=800)
|
63 |
+
|
64 |
+
# Check if the user clicked within the bounding box
|
65 |
+
if st_data["last_clicked"]:
|
66 |
+
lat, lon = st_data["last_clicked"]["lat"], st_data["last_clicked"]["lng"]
|
67 |
+
if within_bbox(lat, lon, ZUERICH_BBOX):
|
68 |
+
st.success(f"Selected Location: Latitude {lat}, Longitude {lon}")
|
69 |
+
st.session_state["selected_location"] = (lat, lon)
|
70 |
+
else:
|
71 |
+
st.error(f"Selected Location is outside the allowed area. Please select a location within the bounding box.")
|
72 |
+
|
73 |
+
# Proceed to the next step
|
74 |
+
link_disabled = "selected_location" not in st.session_state
|
75 |
+
st.sidebar.page_link("pages/2_Perform_Crop_Classification.py", label="Proceed to Crop Classification", icon="🌾", disabled=link_disabled)
|
76 |
+
|
77 |
+
if __name__ == "__main__":
|
78 |
+
select_coordinates()
|
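As a quick illustration of the bounding-box convention used above (not part of the app): ZUERICH_BBOX is ordered (min_lon, min_lat, max_lon, max_lat), so latitude is compared against indices 1 and 3 and longitude against 0 and 2.

ZUERICH_BBOX = [8.364, 47.240, 9.0405, 47.69894]  # (min_lon, min_lat, max_lon, max_lat)

def within_bbox(lat, lon, bbox):
    return bbox[1] <= lat <= bbox[3] and bbox[0] <= lon <= bbox[2]

print(within_bbox(47.3769, 8.5417, ZUERICH_BBOX))  # Zurich city centre -> True
print(within_bbox(46.9480, 7.4474, ZUERICH_BBOX))  # Bern -> False, outside the ZüriCrop area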
pages/2_Perform_Crop_Classification.py
ADDED
@@ -0,0 +1,99 @@
1 |
+
import streamlit as st
|
2 |
+
import leafmap.foliumap as leafmap
|
3 |
+
from transformers import PretrainedConfig
|
4 |
+
from folium import Icon
|
5 |
+
|
6 |
+
from messis.messis import Messis
|
7 |
+
from inference import perform_inference
|
8 |
+
|
9 |
+
st.set_page_config(layout="wide")
|
10 |
+
|
11 |
+
GEOTIFF_PATH = "./data/stacked_features.tif"
|
12 |
+
|
13 |
+
# Load the model
|
14 |
+
@st.cache_resource
|
15 |
+
def load_model():
|
16 |
+
config = PretrainedConfig.from_pretrained('crop-classification/messis', revision='47d9ca4')
|
17 |
+
model = Messis.from_pretrained('crop-classification/messis', cache_dir='./hf_cache/', revision='47d9ca4')
|
18 |
+
return model, config
|
19 |
+
model, config = load_model()
|
20 |
+
|
21 |
+
def perform_inference_step():
|
22 |
+
st.title("Step 2: Perform Crop Classification")
|
23 |
+
|
24 |
+
if "selected_location" not in st.session_state:
|
25 |
+
st.error("No location selected. Please select a location first.")
|
26 |
+
st.page_link("pages/1_Select_Location.py", label="Select Location", icon="📍")
|
27 |
+
return
|
28 |
+
|
29 |
+
lat, lon = st.session_state["selected_location"]
|
30 |
+
|
31 |
+
# Sidebar
|
32 |
+
st.sidebar.header("Settings")
|
33 |
+
|
34 |
+
# Timestep Slider
|
35 |
+
timestep = st.sidebar.slider("Select Timestep", 1, 9, 5)
|
36 |
+
|
37 |
+
# Band Dropdown
|
38 |
+
band_options = {
|
39 |
+
"RGB": [1, 2, 3], # Adjust indices based on the actual bands in your GeoTIFF
|
40 |
+
"NIR": [4],
|
41 |
+
"SWIR1": [5],
|
42 |
+
"SWIR2": [6]
|
43 |
+
}
|
44 |
+
vmin_vmax = {
|
45 |
+
"RGB": (89, 1878),
|
46 |
+
"NIR": (165, 5468),
|
47 |
+
"SWIR1": (120, 3361),
|
48 |
+
"SWIR2": (94, 2700)
|
49 |
+
}
|
50 |
+
selected_band = st.sidebar.selectbox("Select Satellite Band to Display", options=list(band_options.keys()), index=0)
|
51 |
+
|
52 |
+
# Calculate the band indices based on the selected timestep
|
53 |
+
selected_bands = [band + (timestep - 1) * 6 for band in band_options[selected_band]]
|
54 |
+
|
55 |
+
instructions = """
|
56 |
+
Click the button "Perform Crop Classification".
|
57 |
+
|
58 |
+
_Note:_
|
59 |
+
- Messis will classify the crop types for the fields in your selected location.
|
60 |
+
- Hover over the fields to see the predicted and true crop type.
|
61 |
+
- The satellite images might take a few seconds to load.
|
62 |
+
"""
|
63 |
+
st.sidebar.header("Instructions")
|
64 |
+
st.sidebar.markdown(instructions)
|
65 |
+
|
66 |
+
# Initialize the map
|
67 |
+
m = leafmap.Map(center=(lat, lon), zoom=10, draw_control=False)
|
68 |
+
|
69 |
+
# Perform inference
|
70 |
+
if st.sidebar.button("Perform Crop Classification", type="primary"):
|
71 |
+
predictions = perform_inference(lon, lat, model, config, debug=True)
|
72 |
+
|
73 |
+
m.add_data(predictions,
|
74 |
+
layer_name = "Predictions",
|
75 |
+
column="Correct",
|
76 |
+
add_legend=False,
|
77 |
+
style_function=lambda x: {"fillColor": "green" if x["properties"]["Correct"] else "red", "color": "black", "weight": 0, "fillOpacity": 0.25},
|
78 |
+
)
|
79 |
+
st.success("Inference completed!")
|
80 |
+
|
81 |
+
# GeoTIFF Satellite Imagery with selected timestep and band
|
82 |
+
m.add_raster(
|
83 |
+
GEOTIFF_PATH,
|
84 |
+
layer_name="Satellite Image",
|
85 |
+
bands=selected_bands,
|
86 |
+
fit_bounds=True,
|
87 |
+
vmin=vmin_vmax[selected_band][0],
|
88 |
+
vmax=vmin_vmax[selected_band][1],
|
89 |
+
)
|
90 |
+
|
91 |
+
# Show the POI on the map
|
92 |
+
poi_icon = Icon(color="green", prefix="fa", icon="crosshairs")
|
93 |
+
m.add_marker(location=[lat, lon], popup="Selected Location", layer_name="POI", icon=poi_icon)
|
94 |
+
|
95 |
+
# Display the map in the Streamlit app
|
96 |
+
m.to_streamlit()
|
97 |
+
|
98 |
+
if __name__ == "__main__":
|
99 |
+
perform_inference_step()
|
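The band-index arithmetic above assumes the stacked GeoTIFF holds 6 bands per timestep, so 1-based band b of timestep t sits at index b + (t - 1) * 6. A small illustrative check (not part of the app):

band_options = {"RGB": [1, 2, 3], "NIR": [4], "SWIR1": [5], "SWIR2": [6]}

def bands_for(selected_band, timestep, bands_per_timestep=6):
    # Map 1-based per-timestep band indices to indices in the stacked GeoTIFF
    return [band + (timestep - 1) * bands_per_timestep for band in band_options[selected_band]]

print(bands_for("RGB", 1))  # [1, 2, 3]
print(bands_for("RGB", 5))  # [25, 26, 27]
print(bands_for("NIR", 9))  # [52]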
requirements.txt
ADDED
@@ -0,0 +1,23 @@
1 |
+
torch==2.3.0
|
2 |
+
PyYAML==6.0.1
|
3 |
+
rasterio==1.3.10
|
4 |
+
torchvision==0.18.0
|
5 |
+
shapely==2.0.4
|
6 |
+
geopandas==0.14.4
|
7 |
+
pytorch-lightning==2.2.3
|
8 |
+
dvc==3.50.1
|
9 |
+
streamlit==1.37.0
|
10 |
+
leafmap==0.36.6
|
11 |
+
transformers==4.41.2
|
12 |
+
folium==0.17.0
|
13 |
+
streamlit-folium==0.22.0
|
14 |
+
geopy==2.4.1
|
15 |
+
localtileserver==0.10.3
|
16 |
+
xarray==2024.7.0
|
17 |
+
scipy==1.14.0
|
18 |
+
mapclassify==2.8.0
|
19 |
+
wandb==0.16.6
|
20 |
+
numpy==1.26.4
|
21 |
+
lion-pytorch==0.2.2
|
22 |
+
timm==0.9.16
|
23 |
+
pyproj
|
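For completeness: assuming a fresh Python environment, the pinned dependencies above are installed the usual way with `pip install -r requirements.txt` before launching the Streamlit demo.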