|
|
"""Plotting functions for 1D, 2D, and map visualizations.""" |
|
|
|
|
|
import io
import os
from typing import Any, Dict, Literal, Optional, Tuple, Union

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from matplotlib.axes import Axes
from matplotlib.figure import Figure

try:
    import cartopy.crs as ccrs
    import cartopy.feature as cfeature

    HAS_CARTOPY = True
except ImportError:
    HAS_CARTOPY = False

from .utils import identify_coordinates, get_crs, is_geographic, format_value
|
|
|
|
|
|
|
|
def setup_matplotlib():
    """
    Put pyplot into a headless, default-styled state.

    Selects the non-interactive 'Agg' backend, then applies matplotlib's
    built-in 'default' style sheet.
    """
    backend, style_sheet = 'Agg', 'default'
    plt.switch_backend(backend)
    plt.style.use(style_sheet)
|
|
|
|
|
|
|
|
def plot_1d(da: xr.DataArray, x_dim: Optional[str] = None, **style) -> Figure:
    """
    Create a 1D line plot.

    Args:
        da: Input DataArray. It should be 1D or have only one varying
            dimension; any other length-1 dimensions (e.g. a single time
            step) are dropped before plotting.
        x_dim: Dimension to use as the x-axis. When None, the first dimension
            with more than one element is used.
        **style: Optional keys: ``color`` ('blue'), ``linewidth`` (1.5),
            ``linestyle`` ('-'), ``marker`` (''), ``markersize`` (4),
            ``alpha`` (1.0) and ``grid`` (True).

    Returns:
        matplotlib Figure

    Raises:
        ValueError: If no dimension has more than one element (when
            auto-detecting) or ``x_dim`` is not a dimension of ``da``.
    """
    setup_matplotlib()

    if x_dim is None:
        x_dim = next((dim for dim in da.dims if da.sizes[dim] > 1), None)
        if x_dim is None:
            raise ValueError("No suitable dimension found for 1D plot")

    if x_dim not in da.dims:
        raise ValueError(f"Dimension '{x_dim}' not found in DataArray")

    # matplotlib needs y's leading axis to match len(x): drop the other
    # length-1 dimensions and move x_dim to the front. Without this, e.g. a
    # (time=1, lon=N) array fails with an x/y shape-mismatch error.
    singleton_dims = {d: 0 for d in da.dims if d != x_dim and da.sizes[d] == 1}
    y_data = da.isel(singleton_dims).transpose(x_dim, ...)
    x_data = da.coords[x_dim]

    line_style = {
        'color': style.get('color', 'blue'),
        'linewidth': style.get('linewidth', 1.5),
        'linestyle': style.get('linestyle', '-'),
        'marker': style.get('marker', ''),
        'markersize': style.get('markersize', 4),
        'alpha': style.get('alpha', 1.0),
    }

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(x_data, y_data, **line_style)

    ax.set_xlabel(f"{x_dim} ({x_data.attrs.get('units', '')})")
    ax.set_ylabel(f"{da.name or 'Value'} ({da.attrs.get('units', '')})")

    title = da.attrs.get('long_name', da.name or 'Data')
    ax.set_title(title)

    if style.get('grid', True):
        ax.grid(True, alpha=0.3)

    # Rotate/format tick labels for time-like axes. str() because dimension
    # names may be any hashable, not only strings.
    if 'time' in str(x_dim).lower() or x_data.dtype.kind == 'M':
        fig.autofmt_xdate()

    fig.tight_layout()
    return fig
|
|
|
|
|
|
|
|
def plot_2d(da: xr.DataArray, kind: Literal["image", "contour"] = "image",
            x_dim: Optional[str] = None, y_dim: Optional[str] = None, **style) -> Figure:
    """
    Create a 2D plot (image or filled contour).

    Args:
        da: Input DataArray. It should be 2D; additional length-1 dimensions
            are dropped before plotting.
        kind: 'image' (drawn with ``imshow``) or 'contour' (``contourf``).
        x_dim, y_dim: Dimensions to use for the axes. Missing values are taken
            from ``identify_coordinates`` ('X' / 'Y'), falling back to the
            last and second-to-last dimensions of ``da``.
        **style: Optional keys: ``cmap`` ('viridis'), ``vmin``/``vmax``
            (data min/max), ``levels`` (20, contour only), ``contour_lines``
            (False, contour only) and ``colorbar`` (True).

    Returns:
        matplotlib Figure

    Raises:
        ValueError: For an unknown ``kind``, fewer than two dimensions,
            x/y dimensions that are missing or identical, or extra
            dimensions of length > 1.
    """
    # Validate up front: an unknown kind previously left the plotted artist
    # unbound and failed later with an UnboundLocalError.
    if kind not in ("image", "contour"):
        raise ValueError(f"Unknown plot kind '{kind}'; expected 'image' or 'contour'")

    setup_matplotlib()

    if da.ndim < 2:
        raise ValueError(f"plot_2d requires at least 2 dimensions, got {da.ndim}")

    if x_dim is None or y_dim is None:
        coords = identify_coordinates(da)
        if x_dim is None:
            x_dim = coords.get('X', da.dims[-1])
        if y_dim is None:
            y_dim = coords.get('Y', da.dims[-2])

    if x_dim not in da.dims or y_dim not in da.dims:
        raise ValueError(f"Dimensions {x_dim}, {y_dim} not found in DataArray")
    if x_dim == y_dim:
        raise ValueError(f"x_dim and y_dim must differ, both are '{x_dim}'")

    # Drop length-1 extra dimensions (e.g. a single time step) so that the
    # (y, x) transpose succeeds; xarray raises ValueError if any extra
    # dimension longer than 1 remains.
    singleton_dims = {d: 0 for d in da.dims if d not in (x_dim, y_dim) and da.sizes[d] == 1}
    da_plot = da.isel(singleton_dims).transpose(y_dim, x_dim)

    # imshow below uses origin='lower' with a min..max extent, which is only
    # correct for ascending coordinates. Sort so that descending axes (e.g.
    # latitude running 90 -> -90) are not drawn upside-down / mirrored.
    sort_dims = [d for d in (y_dim, x_dim) if d in da_plot.coords]
    if sort_dims:
        da_plot = da_plot.sortby(sort_dims)

    x_coord = da_plot.coords[x_dim]
    y_coord = da_plot.coords[y_dim]
    values = da_plot.values

    cmap = style.get('cmap', 'viridis')
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    # Only reduce over the data when the caller did not supply the limits.
    vmin = style.get('vmin')
    if vmin is None:
        vmin = float(da_plot.min().values)
    vmax = style.get('vmax')
    if vmax is None:
        vmax = float(da_plot.max().values)
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

    fig, ax = plt.subplots(figsize=(10, 8))

    if kind == "image":
        extent = [float(x_coord.min()), float(x_coord.max()),
                  float(y_coord.min()), float(y_coord.max())]
        mappable = ax.imshow(values, extent=extent, aspect='auto',
                             origin='lower', cmap=cmap, norm=norm)
    else:
        levels = style.get('levels', 20)
        if isinstance(levels, int):
            levels = np.linspace(vmin, vmax, levels)

        X, Y = np.meshgrid(x_coord.values, y_coord.values)
        mappable = ax.contourf(X, Y, values, levels=levels, cmap=cmap, norm=norm)

        if style.get('contour_lines', False):
            cs = ax.contour(X, Y, values, levels=levels, colors='k', linewidths=0.5)
            ax.clabel(cs, inline=True, fontsize=8)

    if style.get('colorbar', True):
        cbar = fig.colorbar(mappable, ax=ax)
        cbar.set_label(f"{da.name or 'Value'} ({da.attrs.get('units', '')})")

    ax.set_xlabel(f"{x_dim} ({x_coord.attrs.get('units', '')})")
    ax.set_ylabel(f"{y_dim} ({y_coord.attrs.get('units', '')})")

    title = da.attrs.get('long_name', da.name or 'Data')
    ax.set_title(title)

    fig.tight_layout()
    return fig
|
|
|
|
|
|
|
|
def plot_map(da: xr.DataArray, proj: str = "PlateCarree", **style) -> Figure:
    """
    Create a map plot with cartopy.

    Args:
        da: Input DataArray with geographic coordinates. Besides the
            longitude/latitude dimensions it may only carry length-1
            dimensions, which are dropped before plotting.
        proj: Map projection name: 'PlateCarree', 'Robinson', 'Mollweide',
            'Orthographic', 'NorthPolarStereo', 'SouthPolarStereo', 'Miller'
            or 'InterruptedGoodeHomolosine'. Unknown names silently fall
            back to 'PlateCarree'.
        **style: Optional keys: ``cmap`` ('viridis'), ``vmin``/``vmax``
            (data min/max), ``plot_type`` ('pcolormesh' by default, or
            'contourf'), ``levels`` (20, contourf only), ``coastlines``
            (True), ``borders``/``ocean``/``land`` (False), ``gridlines``
            (True), ``extent`` (passed to ``ax.set_extent`` in PlateCarree
            coordinates; the map is global otherwise) and ``colorbar`` (True).

    Returns:
        matplotlib Figure

    Raises:
        ImportError: If cartopy is not installed.
        ValueError: If ``da`` is not recognised as geographic or its
            longitude/latitude coordinates cannot be identified.
    """
    if not HAS_CARTOPY:
        raise ImportError("Cartopy is required for map plotting")

    setup_matplotlib()

    if not is_geographic(da):
        raise ValueError("DataArray does not appear to have geographic coordinates")

    coords = identify_coordinates(da)
    if 'X' not in coords or 'Y' not in coords:
        raise ValueError("Could not identify longitude/latitude coordinates")
    lon_dim = coords['X']
    lat_dim = coords['Y']

    # Map names to CRS classes and instantiate only the requested one,
    # instead of constructing every projection on each call.
    projection_classes = {
        'PlateCarree': ccrs.PlateCarree,
        'Robinson': ccrs.Robinson,
        'Mollweide': ccrs.Mollweide,
        'Orthographic': ccrs.Orthographic,
        'NorthPolarStereo': ccrs.NorthPolarStereo,
        'SouthPolarStereo': ccrs.SouthPolarStereo,
        'Miller': ccrs.Miller,
        'InterruptedGoodeHomolosine': ccrs.InterruptedGoodeHomolosine,
    }
    projection = projection_classes.get(proj, ccrs.PlateCarree)()

    # Prepare the data before creating the figure so a failure here does not
    # leave an orphan figure behind. Length-1 extra dimensions (e.g. a single
    # time step) are dropped so the (lat, lon) transpose succeeds.
    singleton_dims = {d: 0 for d in da.dims
                      if d not in (lon_dim, lat_dim) and da.sizes[d] == 1}
    da_plot = da.isel(singleton_dims).transpose(lat_dim, lon_dim)
    lons = da_plot.coords[lon_dim].values
    lats = da_plot.coords[lat_dim].values

    cmap = style.get('cmap', 'viridis')
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    # Only reduce over the data when the caller did not supply the limits.
    vmin = style.get('vmin')
    if vmin is None:
        vmin = float(da_plot.min().values)
    vmax = style.get('vmax')
    if vmax is None:
        vmax = float(da_plot.max().values)

    # The data coordinates are interpreted as lon/lat via the PlateCarree CRS.
    data_crs = ccrs.PlateCarree()

    fig, ax = plt.subplots(figsize=(12, 8), subplot_kw={'projection': projection})

    plot_type = style.get('plot_type', 'pcolormesh')
    if plot_type == 'contourf':
        levels = style.get('levels', 20)
        if isinstance(levels, int):
            levels = np.linspace(vmin, vmax, levels)
        mappable = ax.contourf(lons, lats, da_plot.values, levels=levels,
                               cmap=cmap, transform=data_crs)
    else:
        # Any other plot_type falls back to pcolormesh.
        mappable = ax.pcolormesh(lons, lats, da_plot.values, cmap=cmap,
                                 transform=data_crs,
                                 vmin=vmin, vmax=vmax, shading='auto')

    if style.get('coastlines', True):
        ax.coastlines(resolution='50m', color='black', linewidth=0.5)
    if style.get('borders', False):
        ax.add_feature(cfeature.BORDERS, linewidth=0.5)
    if style.get('ocean', False):
        ax.add_feature(cfeature.OCEAN, color='lightblue', alpha=0.5)
    if style.get('land', False):
        ax.add_feature(cfeature.LAND, color='lightgray', alpha=0.5)

    if style.get('gridlines', True):
        gl = ax.gridlines(draw_labels=True, alpha=0.5)
        gl.top_labels = False
        gl.right_labels = False

    if 'extent' in style:
        ax.set_extent(style['extent'], crs=data_crs)
    else:
        ax.set_global()

    if style.get('colorbar', True):
        cbar = fig.colorbar(mappable, ax=ax, orientation='horizontal',
                            pad=0.05, shrink=0.8)
        cbar.set_label(f"{da.name or 'Value'} ({da.attrs.get('units', '')})")

    title = da.attrs.get('long_name', da.name or 'Data')
    ax.set_title(title, pad=20)

    fig.tight_layout()
    return fig
|
|
|
|
|
|
|
|
def export_fig(fig: Figure, fmt: Literal["png", "svg", "pdf"] = "png",
               dpi: int = 150, out_path: Optional[str] = None) -> Union[str, bytes]:
    """
    Export a figure to a file, or render it to an in-memory byte string.

    Args:
        fig: matplotlib Figure to render.
        fmt: Output format, passed to ``Figure.savefig`` as ``format``.
        dpi: Resolution passed to ``savefig`` (relevant for raster formats).
        out_path: Destination file path. When None, nothing is written to
            disk and the encoded figure is returned as bytes instead.

    Returns:
        ``out_path`` unchanged when a path was given, otherwise the encoded
        figure as ``bytes``.
    """
    # Both destinations share the same rendering options; 'tight' crops
    # surrounding whitespace.
    save_kwargs = {'format': fmt, 'dpi': dpi, 'bbox_inches': 'tight'}

    if out_path is not None:
        fig.savefig(out_path, **save_kwargs)
        return out_path

    with io.BytesIO() as buf:
        fig.savefig(buf, **save_kwargs)
        return buf.getvalue()
|
|
|
|
|
|
|
|
def create_subplot_figure(n_plots: int, ncols: int = 2) -> Tuple[Figure, np.ndarray]:
    """
    Create a figure holding a grid of subplots for ``n_plots`` panels.

    The grid has ``ncols`` columns and as many rows as needed; cells beyond
    ``n_plots`` are hidden. Each cell is allotted 6 x 4 inches.

    Args:
        n_plots: Number of panels that will be drawn (must be >= 1).
        ncols: Number of grid columns (must be >= 1).

    Returns:
        ``(fig, axes)`` where ``axes`` is always a 2-D object array of shape
        ``(nrows, ncols)``; iterate ``axes.flat`` to visit panels in order.

    Raises:
        ValueError: If ``n_plots`` or ``ncols`` is less than 1.
    """
    if n_plots < 1:
        raise ValueError(f"n_plots must be >= 1, got {n_plots}")
    if ncols < 1:
        raise ValueError(f"ncols must be >= 1, got {ncols}")

    nrows = (n_plots + ncols - 1) // ncols  # ceiling division

    # squeeze=False makes plt.subplots return a 2-D array even for a single
    # row or column. Previously ncols == 1 yielded a 1-D array while other
    # layouts were 2-D.
    fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 4 * nrows),
                             squeeze=False)

    for unused_ax in axes.flat[n_plots:]:
        unused_ax.set_visible(False)

    return fig, axes
|
|
|
|
|
|
|
|
def add_statistics_text(ax: Axes, da: xr.DataArray, x: float = 0.02, y: float = 0.98):
    """
    Overlay a small box of summary statistics on ``ax``.

    The box lists the minimum, maximum, mean and standard deviation of
    ``da`` (each to 3 significant figures). It is anchored by its top edge at
    ``(x, y)``, given in axes-fraction coordinates.

    Args:
        ax: Axes to annotate.
        da: Data whose statistics are reported.
        x, y: Anchor position in axes coordinates.
    """
    summary = '\n'.join((
        f"Min: {float(da.min().values):.3g}",
        f"Max: {float(da.max().values):.3g}",
        f"Mean: {float(da.mean().values):.3g}",
        f"Std: {float(da.std().values):.3g}",
    ))
    box_props = {'boxstyle': 'round', 'facecolor': 'white', 'alpha': 0.8}
    ax.text(
        x,
        y,
        summary,
        transform=ax.transAxes,
        bbox=box_props,
        verticalalignment='top',
        fontsize=8,
    )