# MultiApp / app.py
# Avatarr05's picture
# Update app.py
# de1b7f1 verified
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter
import rasterio
import multiprocessing
import time
import torch
from pickle import load
import warnings
import gradio as gr
import os
from matplotlib.pyplot import figure
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.ticker as ticker
from matplotlib.animation import FuncAnimation
from matplotlib import rc
from rasterio.plot import show
from huggingface_hub import hf_hub_download
# Suppress all Python warnings for the whole process (this also hides
# deprecation warnings raised by the libraries imported above).
warnings.filterwarnings("ignore")
# Set matplotlib's default HTML representation of animations to 'jshtml'.
rc('animation', html='jshtml')
# ---------------------------
# Trait list (unchanged)
# ---------------------------
# Display names for the predicted traits; the position in this list is used as
# the prediction column index when titling the trait maps.
# NOTE(review): only 8 names are listed although the GeoTIFF written by this
# app is called "twentyTraitPredictions" - confirm whether names are missing.
Traits = ["cab", "cw", "cm", "LAI", "cp", "cbc", "car", "anth"]
# ---------------------------
# Spectral preprocessing
# ---------------------------
def filter_segment(features_noWtab, order=1, der=False):
    """Apply a Savitzky-Golay filter to one spectral segment.

    The filter uses a fixed 65-sample window and runs along the last axis,
    i.e. across the columns of the frame.

    Args:
        features_noWtab: DataFrame holding one contiguous spectral segment,
            one spectrum per row.
        order: Polynomial order used when smoothing. Ignored when ``der`` is
            true, in which case a first-order polynomial is always used.
        der: When true, return the first derivative instead of the smoothed
            signal.

    Returns:
        A new DataFrame carrying the input's column labels and a fresh
        default row index.
    """
    segment = features_noWtab.copy()
    # Derivative mode always fits a first-order polynomial.
    polyorder, deriv = (1, 1) if der else (order, 0)
    filtered = savgol_filter(segment, 65, polyorder, deriv=deriv)
    return pd.DataFrame(data=filtered, columns=segment.columns)
def feature_preparation(features, inval=(1351, 1431, 1801, 2051), frmax=2451, order=1, der=False):
    """Clean reflectance spectra and smooth them segment by segment.

    Steps, in order:
      1. Cast column labels to ``int`` and blank out values outside [0, 1].
      2. Average a forward fill and a backward fill along the rows, then
         linearly interpolate the remaining gaps along the columns.
      3. Drop the excluded column ranges (presumably water-absorption
         regions, judging by the original ``wt_ab`` name) and the columns
         2451-2500.
      4. Filter the three remaining segments with ``filter_segment`` and
         concatenate them column-wise.
      5. Clip negative filter output to zero.

    Args:
        features: DataFrame of spectra, one row per spectrum, with column
            labels castable to ``int``. All labels in the dropped ranges must
            be present, otherwise ``DataFrame.drop`` raises ``KeyError``.
        inval: Four column labels ``(a, b, c, d)``; columns ``a..b-1`` and
            ``c..d-1`` are removed. Only read, never modified.
        frmax: Upper label (inclusive slice bound) of the last segment.
        order: Polynomial order forwarded to ``filter_segment``.
        der: Derivative flag forwarded to ``filter_segment``.

    Returns:
        DataFrame of filtered spectra with non-negative values.
    """
    spectra = features.copy()
    spectra.columns = spectra.columns.astype('int')

    # Values outside the [0, 1] range are treated as missing.
    spectra[spectra < 0] = np.nan
    spectra[spectra > 1] = np.nan

    # ffill/bfill run along the rows (pandas default axis); a cell that has no
    # valid value on one side stays NaN here and is handled by the
    # column-wise interpolation below.
    spectra = (spectra.ffill() + spectra.bfill()) / 2
    spectra = spectra.interpolate(method='linear', axis=1, limit_direction='both')

    excluded = (
        list(range(inval[0], inval[1]))
        + list(range(inval[2], inval[3]))
        + list(range(2451, 2501))
    )
    features_noWtab = spectra.drop(excluded, axis=1)

    # Label-based slices are inclusive on both ends.
    segments = [
        filter_segment(features_noWtab.loc[:, :inval[0] - 1], order=order, der=der),
        filter_segment(features_noWtab.loc[:, inval[1]:inval[2] - 1], order=order, der=der),
        filter_segment(features_noWtab.loc[:, inval[3]:frmax], order=order, der=der),
    ]
    inter = pd.concat(segments, axis=1, join='inner')

    # Filtering can undershoot below zero; clamp those values.
    inter[inter < 0] = 0
    return inter
def plot_fig(features, save=False, file=None, figsize=(15,10)):
    """Plot every row of ``features`` as one line and show the figure.

    Args:
        features: DataFrame whose rows are the curves to draw; the transposed
            frame is handed to ``plt.plot``.
        save: When true, also write the figure to ``file + '.pdf'`` and
            ``file + '.svg'`` before showing it.
        file: Output path without extension; required when ``save`` is true.
        figsize: Figure size passed to ``plt.figure``.
    """
    plt.figure(figsize=figsize)
    plt.plot(features.T)

    # The y axis spans from zero to the largest value in the frame.
    upper = features.max().max()
    plt.ylim(0, upper)

    if save:
        for extension in ('.pdf', '.svg'):
            plt.savefig(file + extension, bbox_inches='tight', dpi=1000)
    plt.show()
# ---------------------------
# Image handling
# ---------------------------
def image_processing(enmap_im_path, bands_path):
    """Read a raster scene and flatten it into a pixel-by-band table.

    Args:
        enmap_im_path: Path of the raster file opened with ``rasterio``.
        bands_path: Path of a CSV file with a ``bands`` column holding one
            value per raster band; these values become the column labels.

    Returns:
        A tuple ``(src, df, idx_null)`` where ``src`` is the still-open
        rasterio dataset (the caller is responsible for closing it), ``df``
        has one row per pixel and one column per band, and ``idx_null`` is
        the index of the rows that are entirely NaN after masking.

    Raises:
        ValueError: If the number of raster bands differs from the number of
            entries in the CSV.
    """
    bands = pd.read_csv(bands_path)['bands'].astype(float)
    src = rasterio.open(enmap_im_path)
    try:
        array = src.read()

        # (bands, rows, cols) -> (pixels, bands); the mean over the length-1
        # trailing axis only drops that axis.
        sp_px = array.reshape(array.shape[0], -1, 1).mean(axis=2).T

        # Explicit check instead of `assert`, which disappears under `python -O`.
        if sp_px.shape[1] != bands.shape[0]:
            raise ValueError("Mismatch between image bands and CSV bands!")

        df = pd.DataFrame(sp_px, columns=bands.to_list())

        # Mask everything below (smallest per-band 1% quantile + 10).
        df[df < df.quantile(0.01).min() + 10] = np.nan

        # Rows in which every band was masked.
        idx_null = df[df.T.isna().all()].index
    except Exception:
        # Do not leak the dataset handle when processing fails.
        src.close()
        raise
    return src, df, idx_null
def process_dataframe(veg_spec):
    """Resample a chunk of pixel spectra onto integer labels 400-2500 and filter it.

    Args:
        veg_spec: DataFrame with one row per pixel and numeric column labels.
            Values are divided by 10000 before filtering.

    Returns:
        DataFrame produced by ``feature_preparation`` with duplicated column
        labels removed (first occurrence kept), restricted to labels >= 400.
    """
    existing = veg_spec.columns.tolist()
    # Build the lookup set once instead of rescanning the column list for
    # each of the 2101 candidate labels.
    present = set(existing)
    missing = [label for label in range(400, 2501) if label not in present]

    # Insert the missing integer labels as empty columns, in sorted order.
    veg_reindex = veg_spec.reindex(columns=sorted(existing + missing))
    veg_reindex = veg_reindex / 10000
    veg_reindex.columns = veg_reindex.columns.astype(int)

    # The frame is passed on with possibly duplicated integer labels;
    # duplicates are removed only after filtering.
    inter = feature_preparation(veg_reindex, order=1)
    inter = inter.loc[:, ~inter.columns.duplicated()]
    return inter.loc[:, 400:]
def transform_data(df):
    """Run ``process_dataframe`` over row chunks of ``df`` in parallel.

    The rows are cut into contiguous chunks with the same sizes
    ``np.array_split`` would produce (the first ``n % k`` chunks get one extra
    row), but via ``DataFrame.iloc`` so the split does not rely on the
    deprecated ``DataFrame.swapaxes``. No more chunks than rows are created,
    so no worker receives an empty frame unless ``df`` itself is empty.

    Args:
        df: DataFrame with one row per pixel.

    Returns:
        The concatenated per-chunk results with a fresh 0..n-1 row index.
    """
    num_cpus = multiprocessing.cpu_count()
    n_rows = len(df)
    num_chunks = max(1, min(num_cpus, n_rows))

    base, extra = divmod(n_rows, num_chunks)
    df_chunks = []
    start = 0
    for position in range(num_chunks):
        stop = start + base + (1 if position < extra else 0)
        df_chunks.append(df.iloc[start:stop])
        start = stop

    print("Starting data transformation ...")
    with multiprocessing.Pool(num_chunks) as pool:
        results = pool.map(process_dataframe, df_chunks)
        pool.close()
        pool.join()
    df_transformed = pd.concat(results).reset_index(drop=True)
    print("Transformation complete.")
    return df_transformed
# ---------------------------
# Model loading (PyTorch)
# ---------------------------
def load_model(dir_data, gp=None):
    """Load a PyTorch model and its optional scaler from a directory.

    Replaces the original TensorFlow-based loading logic.

    Args:
        dir_data: Directory expected to contain ``model.pt`` and, optionally,
            ``scaler_global.pkl``.
        gp: Unused; kept for backward compatibility with existing callers.

    Returns:
        A tuple ``(model, scaler_list)``. ``model`` is the object returned by
        ``torch.load`` after ``eval()`` has been called on it; ``scaler_list``
        is the unpickled scaler, or ``None`` when no scaler file exists.

    Raises:
        FileNotFoundError: If ``model.pt`` is not present in ``dir_data``.
    """
    model_path = os.path.join(dir_data, "model.pt")
    scaler_path = os.path.join(dir_data, "scaler_global.pkl")

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model weights not found in {dir_data}")

    # NOTE(review): `.eval()` requires the file to hold a whole pickled module,
    # not a state dict - confirm the checkpoint format and the `weights_only`
    # default of the installed torch version.
    model = torch.load(model_path, map_location="cpu")
    model.eval()

    scaler_list = None
    if os.path.exists(scaler_path):
        # SECURITY: unpickling executes arbitrary code; only load trusted files.
        with open(scaler_path, "rb") as scaler_file:
            scaler_list = load(scaler_file)
    return model, scaler_list
# ---------------------------
# Visualization utilities
# ---------------------------
def animation_preds(src, preds_tr, Traits=Traits):
    """Render an animated GIF that cycles through every predicted trait map.

    The left panel shows a false-colour composite built from raster bands 72,
    47 and 28 (each divided by 10000); the right panel shows one trait map per
    frame. The colour range of each frame spans the values below that trait's
    99th percentile.

    Every prediction column is animated, starting with column 0. Columns that
    have no entry in ``Traits`` are titled ``"Trait <index> map"`` instead of
    raising ``IndexError``.

    Args:
        src: Open rasterio dataset of the scene; used for the band reads and
            for the (rows, cols) shape of the maps.
        preds_tr: DataFrame with one row per pixel and one column per trait.
        Traits: Display names, indexed by prediction column position.

    Returns:
        The path of the written GIF, ``"Traits_predictions.gif"``.

    Raises:
        ValueError: If ``preds_tr`` has no columns.
    """
    from matplotlib.animation import FuncAnimation
    import matplotlib.ticker as ticker

    n_traits = preds_tr.shape[1]
    if n_traits == 0:
        raise ValueError("No trait predictions to animate.")
    map_shape = (src.shape[0], src.shape[1])

    def trait_title(tr):
        # Fall back to a generic label when the name list is too short.
        name = Traits[tr] if tr < len(Traits) else f"Trait {tr}"
        return f"{name} map"

    def trait_layer(tr):
        # Returns the 2-D map plus the colour limits taken from the values
        # strictly below the 99th percentile (NaNs are ignored).
        values = pd.Series(np.asarray(preds_tr.iloc[:, tr]))
        clipped = values[values < values.quantile(0.99)]
        return values.to_numpy().reshape(map_shape), clipped.min(), clipped.max()

    def update(frame):
        image, minv, maxv = trait_layer(frame)
        pred_im.set_array(image)
        pred_im.set_clim(vmin=minv, vmax=maxv)
        ax2.set_title(trait_title(frame))
        return pred_im

    plt.rc('font', size=3)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(3, 2), dpi=300,
                                   sharex=True, sharey=True,
                                   gridspec_kw={'width_ratios': [1, 1.09]})
    try:
        nir = src.read(72) / 10000
        red = src.read(47) / 10000
        green = src.read(28) / 10000
        nrg = np.dstack((nir, red, green))
        ax1.imshow(nrg)

        image, minv, maxv = trait_layer(0)
        pred_im = ax2.imshow(image, vmin=minv, vmax=maxv)
        plt.colorbar(pred_im, ax=ax2, fraction=0.04, pad=0.04)
        ax1.set(title="Original scene (False Color)")
        ax2.set(title=trait_title(0))

        for ax in (ax1, ax2):
            ax.set_aspect("equal")
            ax.axis("off")
            ax.xaxis.set_major_locator(ticker.NullLocator())
            ax.yaxis.set_major_locator(ticker.NullLocator())

        animation = FuncAnimation(fig, update, frames=range(n_traits), interval=1000)
        animation.save("Traits_predictions.gif")
    finally:
        # Release the figure; otherwise each request keeps one alive in pyplot.
        plt.close(fig)
    return "Traits_predictions.gif"
def geo_tiff_save(src, preds):
    """Write the trait predictions as a multi-band float32 GeoTIFF.

    Args:
        src: Open rasterio dataset supplying height, width, CRS and transform.
        preds: DataFrame with one row per pixel and integer column labels
            ``0..n-1``; column ``k`` is written to band ``k + 1``.

    Returns:
        The path of the written file, ``"./twentyTraitPredictions.tif"``.
    """
    height, width = src.height, src.width
    n_bands = preds.shape[1]
    new_image_path = "./twentyTraitPredictions.tif"

    profile = {
        "driver": "GTiff",
        "width": width,
        "height": height,
        "count": n_bands,
        "dtype": "float32",
        "crs": src.crs,
        "transform": src.transform,
    }
    with rasterio.open(new_image_path, "w", **profile) as new_image:
        for column in range(n_bands):
            # Raster band numbers are 1-based, prediction columns 0-based.
            band = np.array(preds.loc[:, column]).reshape((height, width))
            new_image.write(band, column + 1)
    return new_image_path
# -------------------------------
# Model configuration
# -------------------------------
# Hugging Face Hub repository that hosts the pretrained checkpoints.
repo_id = "Avatarr05/Multi-trait_SSL"
# Maps (model type, scene range) - the two choices offered in the UI - to the
# checkpoint file path inside the Hub repository.
model_file_map = {
("MAE", "Full Range"): "mae/MAE_FR_400-2449_FT_155.pt",
("MAE", "Half Range"): "mae/MAE_HR_VNIR_400-899_FT_155.pt",
("GAN", "Full Range"): "Gans_models/checkpoints_GanFR_seed140/best_model.pt",
("GAN", "Half Range"): "Gans_models/checkpoints_GanHR_seed140/best_model.pt",
}
# In-process cache of loaded (model, scaler) pairs, keyed like model_file_map.
_model_cache = {}
def load_pretrained_model(model_name, range_type):
    """Download (once) and load the pretrained model and its optional scaler.

    The checkpoint is loaded from the exact file returned by
    ``hf_hub_download``. The files listed in ``model_file_map`` are not named
    ``model.pt``, so looking for that fixed name in the download directory can
    never succeed.

    Args:
        model_name: Model type, first element of a ``model_file_map`` key.
        range_type: Scene range, second element of a ``model_file_map`` key.

    Returns:
        A tuple ``(model, scaler)``; ``scaler`` is ``None`` when no
        ``scaler_global.pkl`` sits next to the downloaded checkpoint.

    Raises:
        ValueError: If no checkpoint is registered for the requested pair.
    """
    key = (model_name, range_type)
    cached = _model_cache.get(key)
    if cached is not None:
        return cached

    if key not in model_file_map:
        raise ValueError(f"No pretrained weights found for {model_name} ({range_type})")

    # Download from the Hugging Face repo (served from the local cache when
    # already present).
    file_path = hf_hub_download(repo_id=repo_id, filename=model_file_map[key])

    # NOTE(review): `.eval()` requires the file to hold a whole pickled module,
    # not a state dict - confirm the checkpoint format and the `weights_only`
    # default of the installed torch version.
    best_model = torch.load(file_path, map_location="cpu")
    best_model.eval()

    # NOTE(review): only the checkpoint file is requested above, so a scaler is
    # found only if it already exists in the same local directory - confirm
    # where the scaler lives in the repo and download it explicitly if needed.
    scaler_path = os.path.join(os.path.dirname(file_path), "scaler_global.pkl")
    scaler_list = None
    if os.path.exists(scaler_path):
        # SECURITY: unpickling executes arbitrary code; only load trusted files.
        with open(scaler_path, "rb") as scaler_file:
            scaler_list = load(scaler_file)

    _model_cache[key] = (best_model, scaler_list)
    return best_model, scaler_list
# -------------------------------
# Core function: regression + visualization
# -------------------------------
def apply_regression(input_image, input_csv, model_choice, range_choice):
    """Predict trait maps for an uploaded scene and export them.

    Args:
        input_image: Path of the uploaded raster scene (.tif).
        input_csv: Path of the CSV file describing the raster bands.
        model_choice: Model type selected in the UI.
        range_choice: Scene range selected in the UI. It only selects the
            checkpoint; the preprocessing below is the same for both ranges.

    Returns:
        A tuple ``(gif_path, raster_path)``: the animated trait-map GIF and
        the multi-band GeoTIFF holding the predictions.
    """
    # 1. Load model + scaler.
    best_model, scaler_list = load_pretrained_model(model_choice, range_choice)
    best_model.eval()

    # 2. Read the scene; `src` stays open because the export steps need it.
    src, df, idx_null = image_processing(input_image, input_csv)
    try:
        df_transformed = transform_data(df)

        # 3. Run inference (forward pass without gradient tracking).
        with torch.no_grad():
            x = torch.tensor(df_transformed.values, dtype=torch.float32)
            tf_preds = best_model(x).numpy()

        # 4. Reverse the target scaling when a scaler is available.
        if scaler_list is not None:
            tf_preds = scaler_list.inverse_transform(tf_preds)

        # 5. Build the prediction table and blank out fully masked pixels.
        preds = pd.DataFrame(tf_preds)
        preds.loc[idx_null] = np.nan

        # 6. Generate the visualization and save the GeoTIFF.
        fig = animation_preds(src, preds)
        raster_path = geo_tiff_save(src, preds)
    finally:
        # Always release the raster handle, even when a step above fails.
        src.close()
    return fig, raster_path
# -------------------------------
# Gradio interface
# -------------------------------
# Web UI: two file uploads plus model/range selectors in, an animated preview
# and a downloadable GeoTIFF out.
iface = gr.Interface(
    fn=apply_regression,
    inputs=[
        gr.File(type="filepath", label="Upload Hyperspectral Scene (.tif)"),
        gr.File(type="filepath", label="Upload Band Information (.csv)"),
        gr.Dropdown(["MAE", "GAN"], label="Select Model Type"),
        gr.Radio(["Full Range", "Half Range"], label="Scene Range"),
    ],
    outputs=[
        gr.Image(label="Predicted Trait Maps (Animation)", show_download_button=False),
        gr.File(label="Download Predicted GeoTIFF"),
    ],
    title="🛰️ Multi-Trait Prediction from Hyperspectral Scenes (PyTorch)",
    description=(
        "Upload your hyperspectral scene (.tif) and its corresponding CSV file. "
        "The selected pretrained model will process the data, predict multiple traits, "
        "and generate both an animated visualization and a downloadable GeoTIFF."
    ),
    # article=copyright_html,
    theme="soft",
)

# Launch the Gradio app only when this file is run as the main script.
# transform_data() starts a multiprocessing pool; under the spawn/forkserver
# start methods each worker re-imports this module, and an unguarded launch()
# would start another server inside every worker.
if __name__ == "__main__":
    iface.launch()  # share=False