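# NOTE: this file was captured from a web page; the capture began with the
# page banner "Spaces: Configuration error", which is not part of the module.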
import os | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib.animation as animation | |
import matplotlib.ticker as mticker | |
import cartopy.crs as ccrs | |
import cartopy.feature as cfeature | |
import cartopy.io.shapereader as shpreader | |
from adjustText import adjust_text | |
from .interpolation import interpolate_grid | |
from .basemaps import draw_etopo_basemap | |
import imageio.v2 as imageio | |
import shutil | |
import tempfile | |
class Plot_3DField_Data:
    """Visualize 3D spatiotemporal field data (e.g. ash concentration).

    The class uses matplotlib and cartopy to create:

    - animated GIFs of the field over time at a given altitude,
    - animated GIFs over altitude for a given timestep,
    - static ``.jpg`` frames with metadata annotations and a zoomed view.

    Every figure has two panels: the full domain on the left and a zoom
    window centred on the location of the overall maximum on the right.

    Parameters
    ----------
    animator : object
        A container holding the dataset, including:

        - ``datasets``: list of xarray-like objects indexable with
          ``'ash_concentration'`` and exposing ``.attrs``
        - ``lons``, ``lats``: 1D longitude and latitude arrays
        - ``lat_grid``, ``lon_grid``: 2D grid arrays for spatial mapping
        - ``levels``: 1D array of vertical altitude levels (e.g. in km)
    output_dir : str
        Output directory, resolved relative to the ``NAME_OUTPUT_DIR``
        environment variable (or the system temp directory when unset).
        Defaults to "plots".
    cmap : str
        Matplotlib colormap name. Defaults to "rainbow".
    fps : int
        Frames per second for GIFs. Defaults to 2.
    include_metadata : bool
        Whether animations are annotated with run metadata. Defaults to True.
    threshold : float
        Value threshold (e.g. in g/m³) to highlight exceedances. Defaults to 0.1.
    zoom_width_deg : float
        Width of the zoomed-in region in degrees. Defaults to 6.0.
    zoom_height_deg : float
        Height of the zoomed-in region in degrees. Defaults to 6.0.
    zoom_level : int
        Zoom level passed to ``draw_etopo_basemap``. Defaults to 7.
    basemap_type : str
        Basemap mode passed to ``draw_etopo_basemap``. Defaults to "basemap".

    Methods
    -------
    plot_single_z_level(z_km, filename)
        Animate the field over time at one altitude level.
    plot_vertical_profile_at_time(t_index, filename=None)
        Run ``plot_single_z_level`` for every altitude level.
    animate_altitude(t_index, output_path)
        Animate the altitude slices of one timestep.
    animate_all_altitude_profiles(output_folder='altitude_profiles')
        Run ``animate_altitude`` for every timestep.
    export_frames_as_jpgs(include_metadata=True)
        Export every (level, timestep) frame as a static ``.jpg`` image.
    """

    # Lower end of the logarithmic contour scale; values at or below it are
    # masked when a log scale is used.
    _LOG_CUTOFF = 1e-3
    # Number of contour levels on either scale.
    _N_LEVELS = 20

    def __init__(self, animator, output_dir="plots", cmap="rainbow", fps=2,
                 include_metadata=True, threshold=0.1,
                 zoom_width_deg=6.0, zoom_height_deg=6.0, zoom_level=7, basemap_type="basemap"):
        self.animator = animator
        self.output_dir = os.path.abspath(
            os.path.join(
                os.environ.get("NAME_OUTPUT_DIR", tempfile.gettempdir()),
                output_dir
            )
        )
        os.makedirs(self.output_dir, exist_ok=True)
        self.cmap = cmap
        self.fps = fps
        self.include_metadata = include_metadata
        self.threshold = threshold
        self.zoom_width = zoom_width_deg
        self.zoom_height = zoom_height_deg
        self.zoom_level = zoom_level
        self.basemap_type = basemap_type

        # Load the country shapefile once; the records are reused for labels
        # on every frame.
        countries_shp = shpreader.natural_earth(
            resolution='110m',
            category='cultural',
            name='admin_0_countries'
        )
        self.country_geoms = list(shpreader.Reader(countries_shp).records())

        # Cache the full-domain extent bounds.
        self.lon_min = np.min(self.animator.lons)
        self.lon_max = np.max(self.animator.lons)
        self.lat_min = np.min(self.animator.lats)
        self.lat_max = np.max(self.animator.lats)

    def _make_dirs(self, path):
        """Create the parent directory of ``path`` (a file path) if missing."""
        # abspath first so that a bare filename resolves to the current
        # working directory instead of an empty directory name.
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    def _get_zoom_indices(self, center_lat, center_lon):
        """Return grid indices and bounds of the zoom window around a point.

        Returns
        -------
        tuple
            ``(lat_idx, lon_idx, lon_min, lon_max, lat_min, lat_max)`` where
            the index arrays select the entries of ``animator.lats`` /
            ``animator.lons`` that fall inside the window (bounds inclusive).
        """
        half_w = self.zoom_width / 2
        half_h = self.zoom_height / 2
        lon_min, lon_max = center_lon - half_w, center_lon + half_w
        lat_min, lat_max = center_lat - half_h, center_lat + half_h
        lats, lons = self.animator.lats, self.animator.lons
        lat_idx = np.where((lats >= lat_min) & (lats <= lat_max))[0]
        lon_idx = np.where((lons >= lon_min) & (lons <= lon_max))[0]
        return lat_idx, lon_idx, lon_min, lon_max, lat_min, lat_max

    def _get_max_concentration_location(self):
        """Locate the largest finite concentration over all datasets and levels.

        Returns
        -------
        tuple
            ``(center_lat, center_lon)`` read from ``animator.lat_grid`` /
            ``animator.lon_grid`` at the position of the maximum, or
            ``(None, None)`` when no finite value exists.
        """
        best = -np.inf
        center_lat = center_lon = None
        n_levels = len(self.animator.levels)
        for ds in self.animator.datasets:
            field = ds['ash_concentration'].values
            for z in range(n_levels):
                layer = np.asarray(field[z])
                finite = np.isfinite(layer)
                if not finite.any():
                    continue
                # Non-finite cells must not win the argmax: a plain np.max
                # would return NaN and make the whole layer be ignored.
                candidates = np.where(finite, layer, -np.inf)
                idx = np.unravel_index(np.argmax(candidates), candidates.shape)
                if candidates[idx] > best:
                    best = candidates[idx]
                    center_lat = self.animator.lat_grid[idx]
                    center_lon = self.animator.lon_grid[idx]
        return center_lat, center_lon

    def _add_country_labels(self, ax, extent):
        """Label every country whose centroid lies inside ``extent``.

        ``extent`` is ``[lon_min, lon_max, lat_min, lat_max]``.
        """
        proj = ccrs.PlateCarree()
        texts = []
        for country in self.country_geoms:
            try:
                name = country.attributes['NAME_LONG']
                centroid = country.geometry.centroid
                lon, lat = centroid.x, centroid.y
            except Exception:
                # Best effort: a record without a usable name or centroid is
                # skipped, but KeyboardInterrupt/SystemExit still propagate.
                continue
            if extent[0] <= lon <= extent[1] and extent[2] <= lat <= extent[3]:
                texts.append(
                    ax.text(lon, lat, name, fontsize=6, transform=proj,
                            ha='center', va='center', color='white',
                            bbox=dict(facecolor='black', alpha=0.5, linewidth=0))
                )
        if texts:
            adjust_text(texts, ax=ax, only_move={'points': 'y', 'text': 'y'},
                        arrowprops=dict(arrowstyle="->", color='white', lw=0.5))

    def _plot_frame(self, ax, data, lons, lats, title, levels, scale_label, proj):
        """Draw basemap, filled contours and map features on one axes.

        Returns the filled contour set so a colorbar can be attached.
        ``scale_label`` is accepted for signature compatibility and unused.
        """
        draw_etopo_basemap(ax, mode=self.basemap_type, zoom=self.zoom_level)
        c = ax.contourf(lons, lats, data, levels=levels, cmap=self.cmap, alpha=0.6, transform=proj)
        ax.contour(lons, lats, data, levels=levels, colors='black', linewidths=0.5, transform=proj)
        ax.set_title(title)
        ax.set_extent([lons.min(), lons.max(), lats.min(), lats.max()])
        ax.coastlines()
        ax.add_feature(cfeature.BORDERS, linestyle=':')
        ax.add_feature(cfeature.LAND)
        ax.add_feature(cfeature.OCEAN)
        return c

    def _draw_metadata_sidebar(self, fig, meta_dict):
        """Write the run metadata as a monospace text box on the figure.

        The text is attached to the figure (not an axes), so clearing axes
        does not remove it; call this once per figure.
        """
        lines = [
            f"Run name: {meta_dict.get('run_name', 'N/A')}",
            f"Run time: {meta_dict.get('run_time', 'N/A')}",
            f"Met data: {meta_dict.get('met_data', 'N/A')}",
            f"Start release: {meta_dict.get('start_of_release', 'N/A')}",
            f"End release: {meta_dict.get('end_of_release', 'N/A')}",
            f"Source strength: {meta_dict.get('source_strength', 'N/A')} g/s",
            f"Release loc: {meta_dict.get('release_location', 'N/A')}",
            f"Release height: {meta_dict.get('release_height', 'N/A')} m asl",
            f"Run duration: {meta_dict.get('run_duration', 'N/A')}"
        ]
        full_text = "\n".join(lines)
        fig.text(0.1, 0.095, full_text, va='center', ha='left',
                 fontsize=9, family='monospace', color='black',
                 bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'))

    def _interpolate_layer(self, data):
        """Interpolate one 2D layer onto the grid and mask negative values as NaN."""
        interp = interpolate_grid(data, self.animator.lon_grid, self.animator.lat_grid)
        return np.where(interp < 0, np.nan, interp)

    @classmethod
    def _contour_levels(cls, valid_vals):
        """Choose contour levels for a non-empty array of finite values.

        Returns
        -------
        tuple
            ``(levels, use_log)``. A log scale is used when the minimum is
            above the cutoff and the dynamic range exceeds 100; otherwise the
            levels are linear from 0 to the maximum.
        """
        min_val = float(np.min(valid_vals))
        max_val = float(np.max(valid_vals))
        use_log = min_val > cls._LOG_CUTOFF and (max_val / (min_val + 1e-6)) > 100
        if use_log:
            levels = np.logspace(np.log10(cls._LOG_CUTOFF), np.log10(max_val), cls._N_LEVELS)
        else:
            # contourf requires strictly increasing levels; an all-zero layer
            # would otherwise yield twenty identical levels.
            upper = max_val if max_val > 0 else 1.0
            levels = np.linspace(0, upper, cls._N_LEVELS)
        return levels, use_log

    @staticmethod
    def _add_colorbar(fig, mappable, axes, **kwargs):
        """Attach a vertical, labelled colorbar with compact tick formatting."""
        cbar = fig.colorbar(mappable, ax=axes, orientation='vertical',
                            label="Ash concentration (g/m³)", **kwargs)
        cbar.ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{x:.2g}'))
        return cbar

    def _render_panels(self, ax1, ax2, interp, zoom, title_prefix, proj,
                       exceed_label="⚠ Max Threshold Exceeded"):
        """Draw the full-domain panel (``ax1``) and the zoom panel (``ax2``).

        Parameters
        ----------
        interp : ndarray
            2D field indexed ``[lat, lon]``.
        zoom : tuple
            Result of ``_get_zoom_indices``.
        title_prefix : str
            Leading part of both panel titles; the scale name is appended.
        exceed_label : str
            Text prefix of the warning shown when the threshold is exceeded.

        Returns
        -------
        The filled contour set of the full panel, or ``None`` when ``interp``
        holds no finite value (nothing is drawn in that case).
        """
        valid_vals = interp[np.isfinite(interp)]
        if valid_vals.size == 0:
            return None

        lat_idx, lon_idx, lon_min, lon_max, lat_min, lat_max = zoom
        lons, lats = self.animator.lons, self.animator.lats
        lon_zoom, lat_zoom = lons[lon_idx], lats[lat_idx]
        zoom_plot = interp[np.ix_(lat_idx, lon_idx)]
        # Contouring needs at least a 2x2 grid; a window that misses the grid
        # is reported in the title instead of raising.
        can_zoom = min(zoom_plot.shape) >= 2

        levels, use_log = self._contour_levels(valid_vals)
        data_for_plot = np.where(interp > self._LOG_CUTOFF, interp, np.nan) if use_log else interp
        scale_label = "Log" if use_log else "Linear"

        contours = self._plot_frame(ax1, data_for_plot, lons, lats,
                                    f"{title_prefix} (Full - {scale_label})",
                                    levels, scale_label, proj)
        if can_zoom:
            self._plot_frame(ax2, zoom_plot, lon_zoom, lat_zoom,
                             f"{title_prefix} (Zoom - {scale_label})",
                             levels, scale_label, proj)
        else:
            ax2.set_title(f"{title_prefix} (Zoom - no grid points in window)")
        self._add_country_labels(ax1, [lons.min(), lons.max(), lats.min(), lats.max()])
        if can_zoom:
            self._add_country_labels(ax2, [lon_min, lon_max, lat_min, lat_max])

        # Outline the threshold and add a warning only when it is exceeded.
        peak = valid_vals.max()
        if peak > self.threshold:
            ax1.contour(lons, lats, interp, levels=[self.threshold],
                        colors='red', linewidths=2, transform=proj)
            if can_zoom:
                ax2.contour(lon_zoom, lat_zoom, zoom_plot, levels=[self.threshold],
                            colors='red', linewidths=2, transform=proj)
            ax2.text(0.99, 0.01, f"{exceed_label}: {peak:.2f} > {self.threshold} g/m³",
                     transform=ax2.transAxes, ha='right', va='bottom',
                     fontsize=9, color='red',
                     bbox=dict(facecolor='white', alpha=0.8, edgecolor='red'))
        return contours

    def plot_single_z_level(self, z_km, filename="z_level.gif"):
        """Animate the field over time at altitude ``z_km`` and save a GIF.

        The GIF is written to ``<output_dir>/z_levels/<filename>``. Timesteps
        with no finite, non-negative data are left out. Nothing is written
        (a message is printed) when the level is unknown or no data exists.
        """
        # Tolerant match so that e.g. 0.30000000000000004 still finds 0.3.
        matches = np.flatnonzero(np.isclose(np.asarray(self.animator.levels), z_km))
        if matches.size == 0:
            print(f"Z level {z_km} km not found.")
            return
        z_index = int(matches[0])
        datasets = self.animator.datasets

        center_lat, center_lon = self._get_max_concentration_location()
        if center_lat is None or center_lon is None:
            print(f"No valid frames for Z={z_km} km.")
            return
        zoom = self._get_zoom_indices(center_lat, center_lon)

        valid_frames = [
            t for t, ds in enumerate(datasets)
            if np.isfinite(self._interpolate_layer(ds['ash_concentration'].values[z_index])).any()
        ]
        if not valid_frames:
            print(f"No valid frames for Z={z_km} km.")
            return

        output_path = os.path.join(self.output_dir, "z_levels", filename)
        proj = ccrs.PlateCarree()
        fig = plt.figure(figsize=(16, 8))
        ax1 = fig.add_subplot(1, 2, 1, projection=proj)
        ax2 = fig.add_subplot(1, 2, 2, projection=proj)
        # The colorbar is created once, from the first rendered frame.
        state = {"colorbar": None}

        def update(t):
            ax1.clear()
            ax2.clear()
            interp = self._interpolate_layer(datasets[t]['ash_concentration'].values[z_index])
            contours = self._render_panels(ax1, ax2, interp, zoom,
                                           f"T{t + 1} | Alt: {z_km} km", proj)
            if contours is not None and state["colorbar"] is None:
                state["colorbar"] = self._add_colorbar(fig, contours, [ax1, ax2])
            return []

        try:
            if self.include_metadata:
                self._draw_metadata_sidebar(fig, datasets[0].attrs)
            self._make_dirs(output_path)
            fig.tight_layout()
            ani = animation.FuncAnimation(fig, update, frames=valid_frames,
                                          blit=False, cache_frame_data=False)
            ani.save(output_path, writer='pillow', fps=self.fps, dpi=300)
        finally:
            # Close this figure explicitly, even if saving fails.
            plt.close(fig)
        print(f"✅ Saved Z-level animation to {output_path}")

    def plot_vertical_profile_at_time(self, t_index, filename=None):
        """Write one time animation per altitude level.

        Each level is passed to ``plot_single_z_level`` with the file name
        ``vertical_profiles_timeSlice/TimeSlices_Z<level>km.gif``.

        NOTE(review): ``t_index`` and ``filename`` are accepted for backward
        compatibility but do not influence the output; use
        ``animate_altitude`` for a per-timestep altitude animation.
        """
        for z_val in self.animator.levels:
            gif_name = f"TimeSlices_Z{z_val:.1f}km.gif"
            self.plot_single_z_level(z_val, filename=os.path.join("vertical_profiles_timeSlice", gif_name))

    def animate_altitude(self, t_index: int, output_path: str):
        """Animate the altitude slices of timestep ``t_index`` into a GIF.

        Levels with no finite, non-negative data are left out. Nothing is
        written (a message is printed) when ``t_index`` is out of range or no
        level holds data.
        """
        datasets = self.animator.datasets
        if not (0 <= t_index < len(datasets)):
            print(f"Invalid time index {t_index}. Must be between 0 and {len(datasets) - 1}.")
            return
        ds = datasets[t_index]

        center_lat, center_lon = self._get_max_concentration_location()
        if center_lat is None or center_lon is None:
            print(f"No valid data found for time T{t_index + 1}. Skipping...")
            return
        zoom = self._get_zoom_indices(center_lat, center_lon)

        z_indices_with_data = [
            z_index for z_index in range(len(self.animator.levels))
            if np.isfinite(self._interpolate_layer(ds['ash_concentration'].values[z_index])).any()
        ]
        if not z_indices_with_data:
            print(f"No valid Z-levels at time T{t_index + 1}.")
            return

        proj = ccrs.PlateCarree()
        fig = plt.figure(figsize=(18, 7))
        ax1 = fig.add_subplot(1, 2, 1, projection=proj)
        ax2 = fig.add_subplot(1, 2, 2, projection=proj)
        state = {"colorbar": None}

        def update(z_index):
            ax1.clear()
            ax2.clear()
            interp = self._interpolate_layer(ds['ash_concentration'].values[z_index])
            title_prefix = f"T{t_index + 1} | Alt: {self.animator.levels[z_index]} km"
            contours = self._render_panels(ax1, ax2, interp, zoom, title_prefix, proj)
            if contours is not None and state["colorbar"] is None:
                state["colorbar"] = self._add_colorbar(fig, contours, [ax1, ax2], shrink=0.75)
            return []

        try:
            # Drawn once: figure-level text survives ax.clear(), so adding it
            # per frame would stack a new text box on every frame.
            if self.include_metadata:
                self._draw_metadata_sidebar(fig, ds.attrs)
            self._make_dirs(output_path)
            fig.tight_layout(rect=[0.02, 0.02, 0.98, 0.98])
            ani = animation.FuncAnimation(fig, update, frames=z_indices_with_data,
                                          blit=False, cache_frame_data=False)
            ani.save(output_path, writer='pillow', fps=self.fps, dpi=300)
        finally:
            plt.close(fig)
        print(f"✅ Saved vertical profile animation for T{t_index + 1} to {output_path}")

    def animate_all_altitude_profiles(self, output_folder='altitude_profiles'):
        """Run ``animate_altitude`` for every timestep.

        GIFs are written to ``<output_dir>/<output_folder>/vertical_TNN.gif``.
        """
        target_dir = os.path.join(self.output_dir, output_folder)
        os.makedirs(target_dir, exist_ok=True)
        for t_index in range(len(self.animator.datasets)):
            output_path = os.path.join(target_dir, f"vertical_T{t_index + 1:02d}.gif")
            print(f"🔄 Generating vertical profile animation for T{t_index + 1}...")
            self.animate_altitude(t_index, output_path)

    def export_frames_as_jpgs(self, include_metadata: bool = True):
        """Save every (level, timestep) frame as a ``.jpg`` image.

        Frames are written to
        ``<output_dir>/frames/Z<level>km/frame_NNNN.jpg``; frames with no
        finite, non-negative data are skipped.

        Parameters
        ----------
        include_metadata : bool
            Whether each frame carries the run-metadata text box.
        """
        output_folder = os.path.join(self.output_dir, "frames")
        os.makedirs(output_folder, exist_ok=True)
        datasets = self.animator.datasets

        # The zoom window depends only on the global maximum, so it is
        # computed once instead of once per frame.
        center_lat, center_lon = self._get_max_concentration_location()
        if center_lat is None or center_lon is None:
            print("No valid data found. Skipping frame export.")
            return
        zoom = self._get_zoom_indices(center_lat, center_lon)
        meta = datasets[0].attrs
        proj = ccrs.PlateCarree()

        for z_index, z_val in enumerate(self.animator.levels):
            z_dir = os.path.join(output_folder, f"Z{z_val:.1f}km")
            os.makedirs(z_dir, exist_ok=True)
            for t, ds in enumerate(datasets):
                interp = self._interpolate_layer(ds['ash_concentration'].values[z_index])
                if not np.isfinite(interp).any():
                    continue
                frame_path = os.path.join(z_dir, f"frame_{t+1:04d}.jpg")
                fig = plt.figure(figsize=(16, 8))
                try:
                    ax1 = fig.add_subplot(1, 2, 1, projection=proj)
                    ax2 = fig.add_subplot(1, 2, 2, projection=proj)
                    contours = self._render_panels(ax1, ax2, interp, zoom,
                                                   f"T{t+1} | Alt: {z_val} km", proj,
                                                   exceed_label="⚠ Max")
                    if include_metadata:
                        self._draw_metadata_sidebar(fig, meta)
                    self._add_colorbar(fig, contours, [ax1, ax2], shrink=0.75, pad=0.03)
                    fig.savefig(frame_path, bbox_inches='tight')
                finally:
                    plt.close(fig)
                print(f"📸 Saved {frame_path}")