Mahmudm's picture
Upload 183 files
82c899e verified
# import os
# import numpy as np
# import matplotlib.pyplot as plt
# import matplotlib.animation as animation
# import matplotlib.ticker as mticker
# import cartopy.crs as ccrs
# import cartopy.feature as cfeature
# from adjustText import adjust_text
# import cartopy.io.shapereader as shpreader
# from .interpolation import interpolate_grid
# from .basemaps import draw_etopo_basemap
# def animate_all_z_levels(animator, output_folder: str, fps: int = 2, threshold: float = 0.1):
# os.makedirs(output_folder, exist_ok=True)
# countries_shp = shpreader.natural_earth(resolution='110m', category='cultural', name='admin_0_countries')
# reader = shpreader.Reader(countries_shp)
# country_geoms = list(reader.records())
# for z_index, z_val in enumerate(animator.levels):
# fig = plt.figure(figsize=(16, 7))
# proj = ccrs.PlateCarree()
# ax1 = fig.add_subplot(1, 2, 1, projection=proj)
# ax2 = fig.add_subplot(1, 2, 2, projection=proj)
# valid_mask = np.stack([
# ds['ash_concentration'].values[z_index] for ds in animator.datasets
# ]).max(axis=0) > 0
# y_idx, x_idx = np.where(valid_mask)
# if y_idx.size == 0 or x_idx.size == 0:
# print(f"Z level {z_val} km has no valid data. Skipping...")
# plt.close()
# continue
# y_min, y_max = y_idx.min(), y_idx.max()
# x_min, x_max = x_idx.min(), x_idx.max()
# buffer_y = int((y_max - y_min) * 0.5)
# buffer_x = int((x_max - x_min) * 0.5)
# y_start = max(0, y_min - buffer_y)
# y_end = min(animator.lat_grid.shape[0], y_max + buffer_y + 1)
# x_start = max(0, x_min - buffer_x)
# x_end = min(animator.lon_grid.shape[1], x_max + buffer_x + 1)
# lat_zoom = animator.lats[y_start:y_end]
# lon_zoom = animator.lons[x_start:x_end]
# lon_zoom_grid, lat_zoom_grid = np.meshgrid(lon_zoom, lat_zoom)
# valid_frames = []
# for t in range(len(animator.datasets)):
# data = animator.datasets[t]['ash_concentration'].values[z_index]
# interp = interpolate_grid(data, animator.lon_grid, animator.lat_grid)
# interp = np.where(interp < 0, np.nan, interp)
# if np.isfinite(interp).sum() > 0:
# valid_frames.append(t)
# if not valid_frames:
# print(f"No valid frames for Z={z_val} km. Skipping animation.")
# plt.close()
# continue
# def update(t):
# ax1.clear()
# ax2.clear()
# data = animator.datasets[t]['ash_concentration'].values[z_index]
# interp = interpolate_grid(data, animator.lon_grid, animator.lat_grid)
# interp = np.where(interp < 0, np.nan, interp)
# zoom_plot = interp[y_start:y_end, x_start:x_end]
# valid_vals = interp[np.isfinite(interp)]
# if valid_vals.size == 0:
# return []
# min_val = np.nanmin(valid_vals)
# max_val = np.nanmax(valid_vals)
# log_cutoff = 1e-3
# log_ratio = max_val / (min_val + 1e-6)
# use_log = min_val > log_cutoff and log_ratio > 100
# if use_log:
# data_for_plot = np.where(interp > log_cutoff, interp, np.nan)
# levels = np.logspace(np.log10(log_cutoff), np.log10(max_val), 20)
# scale_label = "Hybrid Log"
# else:
# data_for_plot = interp
# levels = np.linspace(0, max_val, 20)
# scale_label = "Linear"
# draw_etopo_basemap(ax1, mode='stock')
# draw_etopo_basemap(ax2, mode='stock')
# c1 = ax1.contourf(animator.lons, animator.lats, data_for_plot, levels=levels,
# cmap="rainbow", alpha=0.6, transform=proj)
# ax1.contour(animator.lons, animator.lats, data_for_plot, levels=levels,
# colors='black', linewidths=0.5, transform=proj)
# ax1.set_title(f"T{t+1} | Alt: {z_val} km (Full - {scale_label})")
# ax1.set_extent([animator.lons.min(), animator.lons.max(), animator.lats.min(), animator.lats.max()])
# ax1.coastlines()
# ax1.add_feature(cfeature.BORDERS, linestyle=':')
# ax1.add_feature(cfeature.LAND)
# ax1.add_feature(cfeature.OCEAN)
# c2 = ax2.contourf(lon_zoom_grid, lat_zoom_grid, zoom_plot, levels=levels,
# cmap="rainbow", alpha=0.4, transform=proj)
# ax2.contour(lon_zoom_grid, lat_zoom_grid, zoom_plot, levels=levels,
# colors='black', linewidths=0.5, transform=proj)
# ax2.set_title(f"T{t+1} | Alt: {z_val} km (Zoom - {scale_label})")
# ax2.set_extent([lon_zoom.min(), lon_zoom.max(), lat_zoom.min(), lat_zoom.max()])
# ax2.coastlines()
# ax2.add_feature(cfeature.BORDERS, linestyle=':')
# ax2.add_feature(cfeature.LAND)
# ax2.add_feature(cfeature.OCEAN)
# ax2.text(animator.lons[0], animator.lats[0], animator.country_label, fontsize=9, color='white',
# transform=proj, bbox=dict(facecolor='black', alpha=0.5))
# texts_ax1, texts_ax2 = [], []
# for country in country_geoms:
# name = country.attributes['NAME_LONG']
# geom = country.geometry
# try:
# lon, lat = geom.centroid.x, geom.centroid.y
# if (lon_zoom.min() <= lon <= lon_zoom.max()) and (lat_zoom.min() <= lat <= lat_zoom.max()):
# text = ax2.text(lon, lat, name, fontsize=6, transform=proj,
# ha='center', va='center', color='white',
# bbox=dict(facecolor='black', alpha=0.5, linewidth=0))
# texts_ax2.append(text)
# if (animator.lons.min() <= lon <= animator.lons.max()) and (animator.lats.min() <= lat <= animator.lats.max()):
# text = ax1.text(lon, lat, name, fontsize=6, transform=proj,
# ha='center', va='center', color='white',
# bbox=dict(facecolor='black', alpha=0.5, linewidth=0))
# texts_ax1.append(text)
# except:
# continue
# adjust_text(texts_ax1, ax=ax1, only_move={'points': 'y', 'text': 'y'},
# arrowprops=dict(arrowstyle="->", color='white', lw=0.5))
# adjust_text(texts_ax2, ax=ax2, only_move={'points': 'y', 'text': 'y'},
# arrowprops=dict(arrowstyle="->", color='white', lw=0.5))
# if np.nanmax(valid_vals) > threshold:
# alert_text = f"⚠ Exceeds {threshold} g/m³!"
# for ax in [ax1, ax2]:
# ax.text(0.99, 0.01, alert_text, transform=ax.transAxes,
# ha='right', va='bottom', fontsize=10, color='red',
# bbox=dict(facecolor='white', alpha=0.8, edgecolor='red'))
# ax1.contour(animator.lons, animator.lats, interp, levels=[threshold], colors='red', linewidths=2, transform=proj)
# ax2.contour(lon_zoom_grid, lat_zoom_grid, zoom_plot, levels=[threshold], colors='red', linewidths=2, transform=proj)
# if not hasattr(update, "colorbar"):
# update.colorbar = fig.colorbar(c1, ax=[ax1, ax2], orientation='vertical',
# label="Ash concentration (g/m³)")
# formatter = mticker.FuncFormatter(lambda x, _: f'{x:.2g}')
# update.colorbar.ax.yaxis.set_major_formatter(formatter)
# if use_log:
# update.colorbar.ax.text(1.05, 1.02, "log scale", transform=update.colorbar.ax.transAxes,
# fontsize=9, color='gray', rotation=90, ha='left', va='bottom')
# return []
# ani = animation.FuncAnimation(fig, update, frames=valid_frames, blit=False)
# gif_path = os.path.join(output_folder, f"ash_T1-Tn_Z{z_index+1}.gif")
# ani.save(gif_path, writer='pillow', fps=fps)
# plt.close()
# print(f"✅ Saved animation for Z={z_val} km to {gif_path}")
###################################################################################################################
# import os
# import numpy as np
# import matplotlib.pyplot as plt
# import matplotlib.animation as animation
# import matplotlib.ticker as mticker
# import cartopy.crs as ccrs
# import cartopy.feature as cfeature
# from adjustText import adjust_text
# import cartopy.io.shapereader as shpreader
# from .interpolation import interpolate_grid
# from .basemaps import draw_etopo_basemap
# def animate_all_z_levels(animator, output_folder: str, fps: int = 2, threshold: float = 0.1):
# os.makedirs(output_folder, exist_ok=True)
# countries_shp = shpreader.natural_earth(resolution='110m', category='cultural', name='admin_0_countries')
# reader = shpreader.Reader(countries_shp)
# country_geoms = list(reader.records())
# # Compute consistent zoom window across all z-levels and time frames
# valid_mask_all = np.zeros_like(animator.datasets[0]['ash_concentration'].values[0], dtype=bool)
# for ds in animator.datasets:
# for z in range(len(animator.levels)):
# valid_mask_all |= ds['ash_concentration'].values[z] > 0
# y_idx_all, x_idx_all = np.where(valid_mask_all)
# if y_idx_all.size == 0 or x_idx_all.size == 0:
# raise ValueError("No valid data found across any Z level or frame.")
# y_min, y_max = y_idx_all.min(), y_idx_all.max()
# x_min, x_max = x_idx_all.min(), x_idx_all.max()
# buffer_y = int((y_max - y_min) * 0.5)
# buffer_x = int((x_max - x_min) * 0.5)
# y_start = max(0, y_min - buffer_y)
# y_end = min(animator.lat_grid.shape[0], y_max + buffer_y + 1)
# x_start = max(0, x_min - buffer_x)
# x_end = min(animator.lon_grid.shape[1], x_max + buffer_x + 1)
# lat_zoom = animator.lats[y_start:y_end]
# lon_zoom = animator.lons[x_start:x_end]
# lon_zoom_grid, lat_zoom_grid = np.meshgrid(lon_zoom, lat_zoom)
# for z_index, z_val in enumerate(animator.levels):
# fig = plt.figure(figsize=(16, 7))
# proj = ccrs.PlateCarree()
# ax1 = fig.add_subplot(1, 2, 1, projection=proj)
# ax2 = fig.add_subplot(1, 2, 2, projection=proj)
# valid_frames = []
# for t in range(len(animator.datasets)):
# data = animator.datasets[t]['ash_concentration'].values[z_index]
# interp = interpolate_grid(data, animator.lon_grid, animator.lat_grid)
# interp = np.where(interp < 0, np.nan, interp)
# if np.isfinite(interp).sum() > 0:
# valid_frames.append(t)
# if not valid_frames:
# print(f"No valid frames for Z={z_val} km. Skipping animation.")
# plt.close()
# continue
# def update(t):
# ax1.clear()
# ax2.clear()
# data = animator.datasets[t]['ash_concentration'].values[z_index]
# interp = interpolate_grid(data, animator.lon_grid, animator.lat_grid)
# interp = np.where(interp < 0, np.nan, interp)
# zoom_plot = interp[y_start:y_end, x_start:x_end]
# valid_vals = interp[np.isfinite(interp)]
# if valid_vals.size == 0:
# return []
# min_val = np.nanmin(valid_vals)
# max_val = np.nanmax(valid_vals)
# log_cutoff = 1e-3
# log_ratio = max_val / (min_val + 1e-6)
# use_log = min_val > log_cutoff and log_ratio > 100
# if use_log:
# data_for_plot = np.where(interp > log_cutoff, interp, np.nan)
# levels = np.logspace(np.log10(log_cutoff), np.log10(max_val), 20)
# scale_label = "Hybrid Log"
# else:
# data_for_plot = interp
# levels = np.linspace(0, max_val, 20)
# scale_label = "Linear"
# draw_etopo_basemap(ax1, mode='stock')
# draw_etopo_basemap(ax2, mode='stock')
# c1 = ax1.contourf(animator.lons, animator.lats, data_for_plot, levels=levels,
# cmap="rainbow", alpha=0.6, transform=proj)
# ax1.contour(animator.lons, animator.lats, data_for_plot, levels=levels,
# colors='black', linewidths=0.5, transform=proj)
# ax1.set_title(f"T{t+1} | Alt: {z_val} km (Full - {scale_label})")
# ax1.set_extent([animator.lons.min(), animator.lons.max(), animator.lats.min(), animator.lats.max()])
# ax1.coastlines()
# ax1.add_feature(cfeature.BORDERS, linestyle=':')
# ax1.add_feature(cfeature.LAND)
# ax1.add_feature(cfeature.OCEAN)
# c2 = ax2.contourf(lon_zoom_grid, lat_zoom_grid, zoom_plot, levels=levels,
# cmap="rainbow", alpha=0.4, transform=proj)
# ax2.contour(lon_zoom_grid, lat_zoom_grid, zoom_plot, levels=levels,
# colors='black', linewidths=0.5, transform=proj)
# ax2.set_title(f"T{t+1} | Alt: {z_val} km (Zoom - {scale_label})")
# ax2.set_extent([lon_zoom.min(), lon_zoom.max(), lat_zoom.min(), lat_zoom.max()])
# ax2.coastlines()
# ax2.add_feature(cfeature.BORDERS, linestyle=':')
# ax2.add_feature(cfeature.LAND)
# ax2.add_feature(cfeature.OCEAN)
# ax2.text(animator.lons[0], animator.lats[0], animator.country_label, fontsize=9, color='white',
# transform=proj, bbox=dict(facecolor='black', alpha=0.5))
# texts_ax1, texts_ax2 = [], []
# for country in country_geoms:
# name = country.attributes['NAME_LONG']
# geom = country.geometry
# try:
# lon, lat = geom.centroid.x, geom.centroid.y
# if (lon_zoom.min() <= lon <= lon_zoom.max()) and (lat_zoom.min() <= lat <= lat_zoom.max()):
# text = ax2.text(lon, lat, name, fontsize=6, transform=proj,
# ha='center', va='center', color='white',
# bbox=dict(facecolor='black', alpha=0.5, linewidth=0))
# texts_ax2.append(text)
# if (animator.lons.min() <= lon <= animator.lons.max()) and (animator.lats.min() <= lat <= animator.lats.max()):
# text = ax1.text(lon, lat, name, fontsize=6, transform=proj,
# ha='center', va='center', color='white',
# bbox=dict(facecolor='black', alpha=0.5, linewidth=0))
# texts_ax1.append(text)
# except:
# continue
# adjust_text(texts_ax1, ax=ax1, only_move={'points': 'y', 'text': 'y'},
# arrowprops=dict(arrowstyle="->", color='white', lw=0.5))
# adjust_text(texts_ax2, ax=ax2, only_move={'points': 'y', 'text': 'y'},
# arrowprops=dict(arrowstyle="->", color='white', lw=0.5))
# if np.nanmax(valid_vals) > threshold:
# alert_text = f"⚠ Exceeds {threshold} g/m³!"
# for ax in [ax1, ax2]:
# ax.text(0.99, 0.01, alert_text, transform=ax.transAxes,
# ha='right', va='bottom', fontsize=10, color='red',
# bbox=dict(facecolor='white', alpha=0.8, edgecolor='red'))
# ax1.contour(animator.lons, animator.lats, interp, levels=[threshold], colors='red', linewidths=2, transform=proj)
# ax2.contour(lon_zoom_grid, lat_zoom_grid, zoom_plot, levels=[threshold], colors='red', linewidths=2, transform=proj)
# if not hasattr(update, "colorbar"):
# update.colorbar = fig.colorbar(c1, ax=[ax1, ax2], orientation='vertical',
# label="Ash concentration (g/m³)")
# formatter = mticker.FuncFormatter(lambda x, _: f'{x:.2g}')
# update.colorbar.ax.yaxis.set_major_formatter(formatter)
# if use_log:
# update.colorbar.ax.text(1.05, 1.02, "log scale", transform=update.colorbar.ax.transAxes,
# fontsize=9, color='gray', rotation=90, ha='left', va='bottom')
# return []
# ani = animation.FuncAnimation(fig, update, frames=valid_frames, blit=False)
# gif_path = os.path.join(output_folder, f"ash_T1-Tn_Z{z_index+1}.gif")
# ani.save(gif_path, writer='pillow', fps=fps)
# plt.close()
# print(f"✅ Saved animation for Z={z_val} km to {gif_path}")
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.ticker as mticker
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from adjustText import adjust_text
import cartopy.io.shapereader as shpreader
from .interpolation import interpolate_grid
from .basemaps import draw_etopo_basemap
def animate_all_z_levels(animator, output_folder: str, fps: int = 2, threshold: float = 0.1,
zoom_width_deg: float = 6.0, zoom_height_deg: float = 6.0):
os.makedirs(output_folder, exist_ok=True)
countries_shp = shpreader.natural_earth(resolution='110m', category='cultural', name='admin_0_countries')
reader = shpreader.Reader(countries_shp)
country_geoms = list(reader.records())
# Find the most active region (max concentration point)
max_conc = -np.inf
center_lat = center_lon = None
for ds in animator.datasets:
for z in range(len(animator.levels)):
data = ds['ash_concentration'].values[z]
if np.max(data) > max_conc:
max_conc = np.max(data)
max_idx = np.unravel_index(np.argmax(data), data.shape)
center_lat = animator.lat_grid[max_idx]
center_lon = animator.lon_grid[max_idx]
if center_lat is None or center_lon is None:
raise ValueError("No valid concentration found to determine zoom center.")
# Compute fixed zoom extents in lat/lon degrees
lon_zoom_min = center_lon - zoom_width_deg / 2
lon_zoom_max = center_lon + zoom_width_deg / 2
lat_zoom_min = center_lat - zoom_height_deg / 2
lat_zoom_max = center_lat + zoom_height_deg / 2
# Create zoom grids for plotting
lat_zoom = animator.lats[(animator.lats >= lat_zoom_min) & (animator.lats <= lat_zoom_max)]
lon_zoom = animator.lons[(animator.lons >= lon_zoom_min) & (animator.lons <= lon_zoom_max)]
lon_zoom_grid, lat_zoom_grid = np.meshgrid(lon_zoom, lat_zoom)
for z_index, z_val in enumerate(animator.levels):
fig = plt.figure(figsize=(16, 7))
proj = ccrs.PlateCarree()
ax1 = fig.add_subplot(1, 2, 1, projection=proj)
ax2 = fig.add_subplot(1, 2, 2, projection=proj)
valid_frames = []
for t in range(len(animator.datasets)):
data = animator.datasets[t]['ash_concentration'].values[z_index]
interp = interpolate_grid(data, animator.lon_grid, animator.lat_grid)
interp = np.where(interp < 0, np.nan, interp)
if np.isfinite(interp).sum() > 0:
valid_frames.append(t)
if not valid_frames:
print(f"No valid frames for Z={z_val} km. Skipping animation.")
plt.close()
continue
def update(t):
ax1.clear()
ax2.clear()
data = animator.datasets[t]['ash_concentration'].values[z_index]
interp = interpolate_grid(data, animator.lon_grid, animator.lat_grid)
interp = np.where(interp < 0, np.nan, interp)
# Extract zoom window from interpolated data
lat_idx = np.where((animator.lats >= lat_zoom_min) & (animator.lats <= lat_zoom_max))[0]
lon_idx = np.where((animator.lons >= lon_zoom_min) & (animator.lons <= lon_zoom_max))[0]
zoom_plot = interp[np.ix_(lat_idx, lon_idx)]
valid_vals = interp[np.isfinite(interp)]
if valid_vals.size == 0:
return []
min_val = np.nanmin(valid_vals)
max_val = np.nanmax(valid_vals)
log_cutoff = 1e-3
log_ratio = max_val / (min_val + 1e-6)
use_log = min_val > log_cutoff and log_ratio > 100
if use_log:
data_for_plot = np.where(interp > log_cutoff, interp, np.nan)
levels = np.logspace(np.log10(log_cutoff), np.log10(max_val), 20)
scale_label = "Hybrid Log"
else:
data_for_plot = interp
levels = np.linspace(0, max_val, 20)
scale_label = "Linear"
draw_etopo_basemap(ax1, mode='stock')
draw_etopo_basemap(ax2, mode='stock')
c1 = ax1.contourf(animator.lons, animator.lats, data_for_plot, levels=levels,
cmap="rainbow", alpha=0.6, transform=proj)
ax1.contour(animator.lons, animator.lats, data_for_plot, levels=levels,
colors='black', linewidths=0.5, transform=proj)
ax1.set_title(f"T{t+1} | Alt: {z_val} km (Full - {scale_label})")
ax1.set_extent([animator.lons.min(), animator.lons.max(), animator.lats.min(), animator.lats.max()])
ax1.coastlines()
ax1.add_feature(cfeature.BORDERS, linestyle=':')
ax1.add_feature(cfeature.LAND)
ax1.add_feature(cfeature.OCEAN)
c2 = ax2.contourf(lon_zoom_grid, lat_zoom_grid, zoom_plot, levels=levels,
cmap="rainbow", alpha=0.4, transform=proj)
ax2.contour(lon_zoom_grid, lat_zoom_grid, zoom_plot, levels=levels,
colors='black', linewidths=0.5, transform=proj)
ax2.set_title(f"T{t+1} | Alt: {z_val} km (Zoom - {scale_label})")
ax2.set_extent([lon_zoom_min, lon_zoom_max, lat_zoom_min, lat_zoom_max])
ax2.coastlines()
ax2.add_feature(cfeature.BORDERS, linestyle=':')
ax2.add_feature(cfeature.LAND)
ax2.add_feature(cfeature.OCEAN)
ax2.text(animator.lons[0], animator.lats[0], animator.country_label, fontsize=9, color='white',
transform=proj, bbox=dict(facecolor='black', alpha=0.5))
texts_ax1, texts_ax2 = [], []
for country in country_geoms:
name = country.attributes['NAME_LONG']
geom = country.geometry
try:
lon, lat = geom.centroid.x, geom.centroid.y
if (lon_zoom_min <= lon <= lon_zoom_max) and (lat_zoom_min <= lat <= lat_zoom_max):
text = ax2.text(lon, lat, name, fontsize=6, transform=proj,
ha='center', va='center', color='white',
bbox=dict(facecolor='black', alpha=0.5, linewidth=0))
texts_ax2.append(text)
if (animator.lons.min() <= lon <= animator.lons.max()) and (animator.lats.min() <= lat <= animator.lats.max()):
text = ax1.text(lon, lat, name, fontsize=6, transform=proj,
ha='center', va='center', color='white',
bbox=dict(facecolor='black', alpha=0.5, linewidth=0))
texts_ax1.append(text)
except:
continue
adjust_text(texts_ax1, ax=ax1, only_move={'points': 'y', 'text': 'y'},
arrowprops=dict(arrowstyle="->", color='white', lw=0.5))
adjust_text(texts_ax2, ax=ax2, only_move={'points': 'y', 'text': 'y'},
arrowprops=dict(arrowstyle="->", color='white', lw=0.5))
if np.nanmax(valid_vals) > threshold:
alert_text = f"⚠ Exceeds {threshold} g/m³!"
for ax in [ax1, ax2]:
ax.text(0.99, 0.01, alert_text, transform=ax.transAxes,
ha='right', va='bottom', fontsize=10, color='red',
bbox=dict(facecolor='white', alpha=0.8, edgecolor='red'))
ax1.contour(animator.lons, animator.lats, interp, levels=[threshold], colors='red', linewidths=2, transform=proj)
ax2.contour(lon_zoom_grid, lat_zoom_grid, zoom_plot, levels=[threshold], colors='red', linewidths=2, transform=proj)
if not hasattr(update, "colorbar"):
update.colorbar = fig.colorbar(c1, ax=[ax1, ax2], orientation='vertical',
label="Ash concentration (g/m³)")
formatter = mticker.FuncFormatter(lambda x, _: f'{x:.2g}')
update.colorbar.ax.yaxis.set_major_formatter(formatter)
if use_log:
update.colorbar.ax.text(1.05, 1.02, "log scale", transform=update.colorbar.ax.transAxes,
fontsize=9, color='gray', rotation=90, ha='left', va='bottom')
return []
ani = animation.FuncAnimation(fig, update, frames=valid_frames, blit=False)
gif_path = os.path.join(output_folder, f"ash_T1-Tn_Z{z_index+1}.gif")
ani.save(gif_path, writer='pillow', fps=fps)
plt.close()
print(f"✅ Saved animation for Z={z_val} km to {gif_path}")