# File size: 10,686 Bytes
# de1b7f1 38f1e1e
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter
import rasterio
import multiprocessing
import time
import torch
from pickle import load
import warnings
import gradio as gr
import os
from matplotlib.pyplot import figure
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.ticker as ticker
from matplotlib.animation import FuncAnimation
from matplotlib import rc
from rasterio.plot import show
from huggingface_hub import hf_hub_download
# Suppress every warning for the whole process (no category or module filter is given).
warnings.filterwarnings("ignore")
# Set matplotlib's ``animation.html`` rcParam to 'jshtml'.
rc('animation', html='jshtml')
# ---------------------------
# Trait list (unchanged)
# ---------------------------
# Trait names; animation_preds indexes this list with the prediction column
# number to build each map title.
Traits = ["cab", "cw", "cm", "LAI", "cp", "cbc", "car", "anth"]
# ---------------------------
# Spectral preprocessing
# ---------------------------
def filter_segment(features_noWtab, order=1, der=False, window_length=65):
    """Apply a Savitzky-Golay filter to every row of a spectral segment.

    The filter runs along the last axis, i.e. across the columns of each row.

    Args:
        features_noWtab: DataFrame with one spectrum per row and one band per
            column.
        order: Polynomial order used when smoothing. It is ignored when
            ``der`` is true: the derivative is always fitted with a
            first-order polynomial.
        der: When true, return the first derivative of each row instead of
            the smoothed values.
        window_length: Number of neighbouring columns in the filter window.
            Defaults to 65, the value that used to be hard-coded.

    Returns:
        A DataFrame with the same column labels as the input. The input row
        index is not carried over; rows are renumbered from 0.

    Raises:
        ValueError: Propagated from ``savgol_filter``, e.g. when the segment
            has fewer columns than ``window_length``.
    """
    segment = features_noWtab.copy()
    if der:
        filtered = savgol_filter(segment, window_length, 1, deriv=1)
    else:
        filtered = savgol_filter(segment, window_length, order)
    return pd.DataFrame(data=filtered, columns=segment.columns)
def feature_preparation(features, inval=(1351, 1431, 1801, 2051), frmax=2451, order=1, der=False):
    """Clean reflectance spectra, drop excluded bands and filter each segment.

    Steps, in order:
      1. Column labels are cast to ``int``.
      2. Values below 0 or above 1 are replaced by NaN.
      3. Each gap is replaced by the mean of the forward- and backward-filled
         values taken along the row axis (neighbouring rows of the same
         column); where only one side exists the result stays NaN.
      4. Remaining gaps are linearly interpolated across the columns of each
         row, in both directions.
      5. Columns ``inval[0]..inval[1]-1``, ``inval[2]..inval[3]-1`` and
         2451..2500 are dropped.
      6. The three remaining column ranges are filtered independently with
         ``filter_segment`` and concatenated side by side.
      7. Negative results are clipped to 0.

    Args:
        features: DataFrame with one spectrum per row; column labels must be
            convertible to ``int``.
        inval: Four column labels ``(a, b, c, d)`` delimiting the two excluded
            ranges ``[a, b)`` and ``[c, d)``. Any indexable sequence works.
        frmax: Upper label of the last segment slice. The dropped 2451..2500
            range is fixed and does not follow this value.
        order: Polynomial order forwarded to ``filter_segment``.
        der: Derivative flag forwarded to ``filter_segment``.

    Returns:
        DataFrame of filtered values; rows are renumbered from 0 because
        ``filter_segment`` does not keep the input index.

    Raises:
        KeyError: From ``DataFrame.drop`` when one of the excluded column
            labels is not present in ``features``.
    """
    other = features.copy()
    other.columns = other.columns.astype('int')

    # Mask values outside the [0, 1] range before gap filling.
    other[other < 0] = np.nan
    other[other > 1] = np.nan
    other = (other.ffill() + other.bfill()) / 2
    other = other.interpolate(method='linear', axis=1, limit_direction='both')

    # Column labels to discard; the last range is intentionally a fixed literal.
    wt_ab = (
        list(range(inval[0], inval[1]))
        + list(range(inval[2], inval[3]))
        + list(range(2451, 2501))
    )
    features_noWtab = other.drop(wt_ab, axis=1)

    # Filter each contiguous range on its own so the window never spans a gap.
    fr1 = filter_segment(features_noWtab.loc[:, :inval[0] - 1], order=order, der=der)
    fr2 = filter_segment(features_noWtab.loc[:, inval[1]:inval[2] - 1], order=order, der=der)
    fr3 = filter_segment(features_noWtab.loc[:, inval[3]:frmax], order=order, der=der)

    inter = pd.concat([fr1, fr2, fr3], axis=1, join='inner')
    inter[inter < 0] = 0
    return inter
def plot_fig(features, save=False, file=None, figsize=(15,10)):
    """Draw every row of ``features`` as one line and show the figure.

    Args:
        features: DataFrame whose rows are plotted against its column labels.
        save: When true, also write the figure to disk before showing it.
        file: Output path without extension; '.pdf' and '.svg' are appended.
            Must be a string when ``save`` is true.
        figsize: Figure size passed to ``plt.figure``.
    """
    plt.figure(figsize=figsize)
    upper_limit = features.max().max()
    plt.plot(features.T)
    plt.ylim(0, upper_limit)
    if save:
        for extension in ('.pdf', '.svg'):
            plt.savefig(file + extension, bbox_inches='tight', dpi=1000)
    plt.show()
# ---------------------------
# Image handling
# ---------------------------
def image_processing(enmap_im_path, bands_path):
    """Read a raster scene and flatten it into a pixel-by-band DataFrame.

    Args:
        enmap_im_path: Path of the raster file opened with ``rasterio``.
        bands_path: Path of a CSV file with a ``bands`` column; its values
            (cast to float) become the DataFrame column labels.

    Returns:
        A tuple ``(src, df, idx_null)``:
          * ``src`` - the open rasterio dataset; the caller is responsible for
            closing it.
          * ``df`` - one row per pixel (row-major order), one column per band.
            Values below ``df.quantile(0.01).min() + 10`` are set to NaN.
          * ``idx_null`` - index of the rows in which every band is NaN.

    Raises:
        ValueError: When the raster band count differs from the number of
            entries in the CSV.
    """
    bands = pd.read_csv(bands_path)['bands'].astype(float)
    src = rasterio.open(enmap_im_path)
    try:
        array = src.read()

        # (bands, rows, cols) -> (pixels, bands)
        sp_px = array.reshape(array.shape[0], -1).T
        if not np.issubdtype(sp_px.dtype, np.floating):
            # NaN masking below needs a floating dtype.
            sp_px = sp_px.astype(np.float64)

        # Explicit check instead of `assert`, which is removed under `python -O`.
        if sp_px.shape[1] != bands.shape[0]:
            raise ValueError("Mismatch between image bands and CSV bands!")

        df = pd.DataFrame(sp_px, columns=bands.to_list())
        df[df < df.quantile(0.01).min() + 10] = np.nan
        idx_null = df.index[df.isna().all(axis=1).to_numpy()]
    except Exception:
        # Do not leak the dataset handle when reading or validation fails.
        src.close()
        raise
    return src, df, idx_null
def process_dataframe(veg_spec):
    """Resample a block of spectra onto integer labels 400-2500 and filter it.

    Args:
        veg_spec: DataFrame with one spectrum per row and sortable numeric
            column labels.

    Returns:
        The output of ``feature_preparation`` with duplicated column labels
        removed (first occurrence kept), restricted to labels >= 400.
    """
    # Build the membership set once instead of re-listing the columns for
    # every candidate wavelength.
    existing = veg_spec.columns.tolist()
    present = set(existing)
    missing = [wl for wl in range(400, 2501) if wl not in present]

    # Missing integer labels are added as all-NaN columns.
    veg_reindex = veg_spec.reindex(columns=sorted(existing + missing))
    veg_reindex = veg_reindex / 10000
    # Truncating the labels to int can create duplicated column labels.
    veg_reindex.columns = veg_reindex.columns.astype(int)

    # NOTE(review): the original also built a de-duplicated copy of
    # `veg_reindex` here but never used it; feature_preparation has always
    # received the frame with duplicates, and that behaviour is kept.
    inter = feature_preparation(veg_reindex, order=1)
    inter = inter.loc[:, ~inter.columns.duplicated()]
    return inter.loc[:, 400:]
def transform_data(df):
    """Run ``process_dataframe`` over row chunks of ``df`` in parallel.

    The rows are split into contiguous, nearly equal chunks (one per CPU, but
    never more chunks than rows), each chunk is processed in a worker process,
    and the results are concatenated in the original row order.

    Args:
        df: DataFrame with one spectrum per row.

    Returns:
        The concatenated, processed DataFrame with a fresh 0..n-1 index.
    """
    num_cpus = multiprocessing.cpu_count()
    # Avoid empty chunks (and idle workers) when there are fewer rows than CPUs.
    n_chunks = max(1, min(num_cpus, len(df)))

    # Split row positions rather than the DataFrame itself: np.array_split on
    # a DataFrame goes through a deprecated pandas code path.
    row_groups = np.array_split(np.arange(len(df)), n_chunks)
    df_chunks = [df.iloc[rows] for rows in row_groups]

    print("Starting data transformation ...")
    with multiprocessing.Pool(n_chunks) as pool:
        results = pool.map(process_dataframe, df_chunks)
        pool.close()
        pool.join()
    df_transformed = pd.concat(results).reset_index(drop=True)
    print("Transformation complete.")
    return df_transformed
# ---------------------------
# Model loading (PyTorch)
# ---------------------------
def load_model(dir_data, gp=None):
    """Load a PyTorch model and its optional scaler from a directory.

    The directory must contain ``model.pt``; ``scaler_global.pkl`` is loaded
    when present.

    Args:
        dir_data: Directory holding the model (and, optionally, the scaler).
        gp: Not used by this function.

    Returns:
        ``(model, scaler_list)`` where ``model`` has been switched to eval
        mode and ``scaler_list`` is the unpickled scaler object, or ``None``
        when no scaler file exists.

    Raises:
        FileNotFoundError: When ``model.pt`` is missing from ``dir_data``.
    """
    model_path = os.path.join(dir_data, "model.pt")
    scaler_path = os.path.join(dir_data, "scaler_global.pkl")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model weights not found in {dir_data}")

    # SECURITY: both torch.load (with weights_only=False) and pickle.load can
    # execute arbitrary code; only use them on files from a trusted source.
    try:
        # `.eval()` below needs a full pickled module, which recent PyTorch
        # releases refuse to load unless weights_only is disabled explicitly.
        model = torch.load(model_path, map_location="cpu", weights_only=False)
    except TypeError:
        # Older PyTorch releases do not accept the `weights_only` keyword.
        model = torch.load(model_path, map_location="cpu")
    model.eval()

    scaler_list = None
    if os.path.exists(scaler_path):
        # Context manager so the file handle is closed deterministically.
        with open(scaler_path, "rb") as scaler_file:
            scaler_list = load(scaler_file)
    return model, scaler_list
# ---------------------------
# Visualization utilities
# ---------------------------
def animation_preds(src, preds_tr, Traits=Traits):
    """Save an animated GIF cycling through the predicted trait maps.

    The left panel shows a false-colour composite built from raster bands 72,
    47 and 28 (each divided by 10000); the right panel shows one prediction
    column per frame, reshaped to the raster's shape. The colour limits of
    each frame are the min/max of the values below that column's 99th
    percentile.

    Every prediction column gets a frame, starting with column 0. (The
    original iterated over frames 1..19 regardless of the input, which skipped
    the first trait and failed for fewer than 20 columns or trait names.)

    Args:
        src: Open rasterio dataset providing ``shape`` and ``read``.
        preds_tr: DataFrame with one row per pixel and one column per trait.
        Traits: Trait names used for the titles; columns beyond the end of
            this list are titled "Trait <column number>".

    Returns:
        The path of the written GIF, ``"Traits_predictions.gif"``.

    Raises:
        ValueError: When ``preds_tr`` has no columns.
    """
    from matplotlib.animation import FuncAnimation
    import matplotlib.ticker as ticker

    n_traits = preds_tr.shape[1]
    if n_traits == 0:
        raise ValueError("preds_tr has no prediction columns to animate")
    n_rows, n_cols = src.shape[0], src.shape[1]

    def trait_layer(tr):
        # Image plus colour limits for prediction column number `tr`.
        column = pd.Series(np.asarray(preds_tr.iloc[:, tr]))
        visible = column[column < column.quantile(0.99)]
        image = column.to_numpy().reshape(n_rows, n_cols)
        return image, visible.min(), visible.max()

    def trait_title(tr):
        # Fall back to a generic label when no name exists for this column.
        name = Traits[tr] if tr < len(Traits) else f"Trait {tr}"
        return f"{name} map"

    def update(frame):
        image, minv, maxv = trait_layer(frame)
        pred_im.set_array(image)
        pred_im.set_clim(vmin=minv, vmax=maxv)
        ax2.set_title(trait_title(frame))
        return pred_im

    # Global font-size change; kept because the tiny figure relies on it.
    plt.rc('font', size=3)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(3, 2), dpi=300,
                                   sharex=True, sharey=True,
                                   gridspec_kw={'width_ratios': [1, 1.09]})
    try:
        nir = src.read(72) / 10000
        red = src.read(47) / 10000
        green = src.read(28) / 10000
        nrg = np.dstack((nir, red, green))
        ax1.imshow(nrg)

        first_image, minv, maxv = trait_layer(0)
        pred_im = ax2.imshow(first_image, vmin=minv, vmax=maxv)
        plt.colorbar(pred_im, ax=ax2, fraction=0.04, pad=0.04)
        ax1.set(title="Original scene (False Color)")
        ax2.set(title=trait_title(0))
        for ax in (ax1, ax2):
            ax.set_aspect("equal")
            ax.axis("off")
            ax.xaxis.set_major_locator(ticker.NullLocator())
            ax.yaxis.set_major_locator(ticker.NullLocator())

        animation = FuncAnimation(fig, update, frames=range(n_traits), interval=1000)
        animation.save("Traits_predictions.gif")
    finally:
        # Release the figure; otherwise every call leaves one open in pyplot.
        plt.close(fig)
    return "Traits_predictions.gif"
def geo_tiff_save(src, preds):
    """Write the prediction columns to a multi-band float32 GeoTIFF.

    Args:
        src: rasterio dataset supplying height, width, CRS and transform.
        preds: DataFrame with one row per pixel; column label ``k`` (0-based
            integer) is written to band ``k + 1``.

    Returns:
        The path of the written file, ``"./twentyTraitPredictions.tif"``.
    """
    height, width = src.height, src.width
    n_bands = preds.shape[1]
    new_image_path = "./twentyTraitPredictions.tif"
    profile = dict(
        driver="GTiff",
        width=width,
        height=height,
        count=n_bands,
        dtype="float32",
        crs=src.crs,
        transform=src.transform,
    )
    with rasterio.open(new_image_path, "w", **profile) as new_image:
        for column in range(n_bands):
            layer = np.array(preds.loc[:, column]).reshape((height, width))
            # Raster bands are numbered from 1.
            new_image.write(layer, column + 1)
    return new_image_path
# -------------------------------
# Model configuration
# -------------------------------
# Hugging Face repository passed to hf_hub_download by load_pretrained_model.
repo_id = "Avatarr05/Multi-trait_SSL"
# (model type, scene range) -> path of the checkpoint file inside the repository.
model_file_map = {
    ("MAE", "Full Range"): "mae/MAE_FR_400-2449_FT_155.pt",
    ("MAE", "Half Range"): "mae/MAE_HR_VNIR_400-899_FT_155.pt",
    ("GAN", "Full Range"): "Gans_models/checkpoints_GanFR_seed140/best_model.pt",
    ("GAN", "Half Range"): "Gans_models/checkpoints_GanHR_seed140/best_model.pt",
}
# In-process cache of (model, scaler) pairs, keyed like model_file_map.
_model_cache = {}
def load_pretrained_model(model_name, range_type):
    """Download (once) and load the pretrained model for a model/range pair.

    Results are cached in ``_model_cache`` so repeated calls with the same
    arguments return the same objects without touching the network or disk.

    The checkpoint is loaded straight from the downloaded file. The original
    code handed the download directory to ``load_model``, which only accepts
    a file named ``model.pt``; none of the mapped checkpoints has that name,
    so every call failed with ``FileNotFoundError``.

    Args:
        model_name: Model type, first element of a ``model_file_map`` key.
        range_type: Scene range, second element of a ``model_file_map`` key.

    Returns:
        ``(best_model, scaler_list)``; ``scaler_list`` is ``None`` when no
        ``scaler_global.pkl`` sits next to the downloaded checkpoint.

    Raises:
        ValueError: When no checkpoint is mapped for the requested pair.
    """
    key = (model_name, range_type)
    if key in _model_cache:
        return _model_cache[key]
    if key not in model_file_map:
        raise ValueError(f"No pretrained weights found for {model_name} ({range_type})")
    model_path = model_file_map[key]

    # Download from the Hugging Face repo (served from the local cache when present).
    file_path = hf_hub_download(repo_id=repo_id, filename=model_path)

    # SECURITY: torch.load (weights_only=False) and pickle.load can execute
    # arbitrary code; the repository must be trusted.
    # NOTE(review): `.eval()` assumes the checkpoint stores a full module, not
    # a state_dict - confirm against the published files.
    try:
        best_model = torch.load(file_path, map_location="cpu", weights_only=False)
    except TypeError:
        # Older PyTorch releases do not accept the `weights_only` keyword.
        best_model = torch.load(file_path, map_location="cpu")
    best_model.eval()

    # NOTE(review): only the checkpoint is downloaded here, so the scaler is
    # found only if it already exists beside it in the local cache.
    scaler_path = os.path.join(os.path.dirname(file_path), "scaler_global.pkl")
    scaler_list = None
    if os.path.exists(scaler_path):
        with open(scaler_path, "rb") as scaler_file:
            scaler_list = load(scaler_file)

    _model_cache[key] = (best_model, scaler_list)
    return best_model, scaler_list
# -------------------------------
# Core function: regression + visualization
# -------------------------------
def apply_regression(input_image, input_csv, model_choice, range_choice):
    """Predict trait maps for an uploaded scene and export them.

    Args:
        input_image: Path of the hyperspectral scene (.tif).
        input_csv: Path of the CSV file describing the scene's bands.
        model_choice: Model type forwarded to ``load_pretrained_model``.
        range_choice: Scene range forwarded to ``load_pretrained_model``.

    Returns:
        ``(fig, raster_path)``: the path of the animated GIF produced by
        ``animation_preds`` and the path of the GeoTIFF produced by
        ``geo_tiff_save``.
    """
    # 1. Load model + scaler
    best_model, scaler_list = load_pretrained_model(model_choice, range_choice)
    best_model.eval()

    # 2. Preprocess input data
    src, df, idx_null = image_processing(input_image, input_csv)
    try:
        df_transformed = transform_data(df)

        # 3. Run inference (single forward pass, gradients disabled)
        with torch.no_grad():
            x = torch.tensor(df_transformed.values, dtype=torch.float32)
            tf_preds = best_model(x).numpy()

        # 4. Reverse scaling
        if scaler_list is not None:
            tf_preds = scaler_list.inverse_transform(tf_preds)

        # 5. Build prediction DataFrame; pixels flagged as empty get NaN
        preds = pd.DataFrame(tf_preds)
        preds.loc[idx_null] = np.nan

        # 6. Generate visualization and save GeoTIFF
        fig = animation_preds(src, preds)
        raster_path = geo_tiff_save(src, preds)
    finally:
        # image_processing returns the dataset open; close it on every path so
        # repeated requests do not accumulate file handles.
        src.close()
    return fig, raster_path
# -------------------------------
# Gradio interface
# -------------------------------
# Gradio UI: two file uploads plus model/range selectors in, animated map and
# GeoTIFF download out.
iface = gr.Interface(
    fn=apply_regression,
    inputs=[
        gr.File(type="filepath", label="Upload Hyperspectral Scene (.tif)"),
        gr.File(type="filepath", label="Upload Band Information (.csv)"),
        gr.Dropdown(["MAE", "GAN"], label="Select Model Type"),
        gr.Radio(["Full Range", "Half Range"], label="Scene Range"),
    ],
    outputs=[
        gr.Image(label="Predicted Trait Maps (Animation)", show_download_button=False),
        gr.File(label="Download Predicted GeoTIFF"),
    ],
    title="🛰️ Multi-Trait Prediction from Hyperspectral Scenes (PyTorch)",
    description=(
        "Upload your hyperspectral scene (.tif) and its corresponding CSV file. "
        "The selected pretrained model will process the data, predict multiple traits, "
        "and generate both an animated visualization and a downloadable GeoTIFF."
    ),
    # article=copyright_html,
    theme="soft",
)

# Launch the Gradio app only when run as a script. transform_data uses
# multiprocessing; with the "spawn" start method each worker re-imports this
# module, and an unguarded launch() would start the app again in every worker.
if __name__ == "__main__":
    iface.launch()  # share=False