| | """Grid alignment and combination operations.""" |
| |
|
| | from typing import Literal, Dict, Any, Tuple |
| | import numpy as np |
| | import xarray as xr |
| | from .utils import identify_coordinates, get_crs, is_geographic |
| |
|
| |
|
def _overlap_axis(dim, coord_a, coord_b) -> np.ndarray:
    """Return an evenly spaced axis covering only the overlap of two 1-D coordinates.

    The spacing is the finer of the two grids' steps. Each step is taken from
    the first two points of its coordinate, or is 1.0 for a single-point
    coordinate. Every returned point lies inside both coordinate ranges, so
    interpolating either array onto the axis never extrapolates.

    Args:
        dim: Dimension name (used only in error messages).
        coord_a, coord_b: The two coordinates along ``dim``.

    Returns:
        1-D float array of target positions.

    Raises:
        ValueError: If the two ranges do not overlap or the grid step is zero.
    """
    lo = max(float(coord_a.min()), float(coord_b.min()))
    hi = min(float(coord_a.max()), float(coord_b.max()))
    if lo > hi:
        raise ValueError(f"Coordinates along '{dim}' do not overlap")

    step_a = float(coord_a[1] - coord_a[0]) if len(coord_a) > 1 else 1.0
    step_b = float(coord_b[1] - coord_b[0]) if len(coord_b) > 1 else 1.0
    step = min(abs(step_a), abs(step_b))
    if step == 0:
        raise ValueError(f"Zero grid spacing along '{dim}'; cannot build a common axis")

    # Count the points explicitly. np.arange(lo, hi + step, step) overshoots
    # `hi` whenever (hi - lo) is not a multiple of `step`, and interpolating
    # outside the data range yields an all-NaN slice. The small epsilon
    # absorbs float round-off in the division; np.minimum clips the last
    # point so round-off cannot push it past `hi`.
    n_points = int(np.floor((hi - lo) / step + 1e-9)) + 1
    return np.minimum(lo + step * np.arange(n_points), hi)


def align_for_combine(a: xr.DataArray, b: xr.DataArray, method: str = "reindex") -> Tuple[xr.DataArray, xr.DataArray]:
    """
    Align two DataArrays onto a shared grid so they can be combined element-wise.

    Args:
        a, b: Input DataArrays.
        method: Alignment method.
            - 'reindex': for each shared dimension, both arrays are reindexed
              (nearest neighbour) onto whichever coordinate has more points
              (ties go to ``a``).
            - 'interp': for each shared dimension, both arrays are interpolated
              onto an evenly spaced axis spanning only the overlap of the two
              coordinate ranges, at the finer of the two grid steps.

    Returns:
        Tuple ``(a_aligned, b_aligned)``.

    Raises:
        ValueError: If both arrays carry a CRS and they differ, if they share
            no dimensions, if ``method`` is unknown, or (for 'interp') if a
            shared coordinate does not overlap or has zero spacing.
    """
    crs_a = get_crs(a)
    crs_b = get_crs(b)
    if crs_a and crs_b and not crs_a.equals(crs_b):
        raise ValueError(f"CRS mismatch: {crs_a} vs {crs_b}")

    # Follow a's dimension order so the processing order is deterministic
    # (iterating a set is not).
    common_dims = [dim for dim in a.dims if dim in b.dims]
    if not common_dims:
        raise ValueError("No common dimensions found for alignment")

    if method == "reindex":
        a_aligned, b_aligned = a, b
        for dim in common_dims:
            coord_a = a.coords[dim]
            coord_b = b.coords[dim]
            target = coord_a if len(coord_a) >= len(coord_b) else coord_b
            # NOTE(review): no tolerance is passed, so target points outside the
            # shorter coordinate's range are filled with its edge values.
            a_aligned = a_aligned.reindex({dim: target}, method='nearest')
            b_aligned = b_aligned.reindex({dim: target}, method='nearest')

    elif method == "interp":
        common_coords = {
            dim: _overlap_axis(dim, a.coords[dim], b.coords[dim])
            for dim in common_dims
        }
        a_aligned = a.interp(common_coords)
        b_aligned = b.interp(common_coords)

    else:
        raise ValueError(f"Unknown alignment method: {method}")

    return a_aligned, b_aligned
| |
|
| |
|
def combine(a: xr.DataArray, b: xr.DataArray, op: Literal["sum", "avg", "diff"] = "sum") -> xr.DataArray:
    """
    Combine two DataArrays element-wise after aligning them.

    Both inputs are first aligned with ``align_for_combine`` (default
    'reindex' method) and then combined.

    Args:
        a, b: Input DataArrays.
        op: 'sum' (a + b), 'avg' ((a + b) / 2) or 'diff' (a - b).

    Returns:
        Combined DataArray named ``"{a.name}_{op}_{b.name}"``. It gets a
        descriptive ``long_name`` attribute, built from each input's
        ``long_name`` or, failing that, its name. ``units`` is copied only when
        both inputs declare the same, non-missing units.

    Raises:
        ValueError: If ``op`` is unknown, or if alignment fails.
    """
    # Validate up front so a bad `op` is reported before any alignment work
    # (and is not masked by an alignment error).
    if op not in ("sum", "avg", "diff"):
        raise ValueError(f"Unknown operation: {op}")

    a_aligned, b_aligned = align_for_combine(a, b)

    label_a = a.attrs.get('long_name', a.name)
    label_b = b.attrs.get('long_name', b.name)
    if op == "sum":
        result = a_aligned + b_aligned
        long_name = f"{label_a} + {label_b}"
    elif op == "avg":
        result = (a_aligned + b_aligned) / 2
        long_name = f"Average of {label_a} and {label_b}"
    else:  # "diff"
        result = a_aligned - b_aligned
        long_name = f"{label_a} - {label_b}"

    result.name = f"{a.name}_{op}_{b.name}"
    result.attrs['long_name'] = long_name

    # Propagate units only when they are actually declared. If both inputs
    # lack units, do not invent an empty units string (which metadata
    # conventions may read as "dimensionless").
    units_a = a.attrs.get('units')
    if units_a is not None and units_a == b.attrs.get('units'):
        result.attrs['units'] = units_a

    return result
| |
|
| |
|
def section(da: xr.DataArray, along: str, fixed: Dict[str, Any]) -> xr.DataArray:
    """
    Extract a cross-section of ``da`` that keeps the ``along`` dimension.

    Each entry of ``fixed`` whose key is a dimension of ``da`` is selected by
    nearest-neighbour lookup (``.sel(..., method='nearest')``). Keys that are
    not dimensions of ``da`` are silently ignored.

    Args:
        da: Input DataArray.
        along: Dimension to keep for the section (e.g. 'time', 'lat').
        fixed: Mapping ``{dim: value}`` of dimensions to fix.

    Returns:
        The cross-section, carrying a copy of ``da.attrs``. Its ``long_name``
        gets a suffix listing the fixed positions. For numeric requests on a
        numeric coordinate, the suffix shows the grid value actually selected,
        not the value requested.

    Raises:
        ValueError: If ``along`` is not a dimension of ``da``, or if the
            selection removes it (e.g. ``along`` itself was fixed to a scalar).
    """
    if along not in da.dims:
        raise ValueError(f"Dimension '{along}' not found in DataArray")

    # Only dimensions of `da` take part. A single nearest-neighbour .sel
    # resolves every value, so no per-dimension pre-lookup is needed.
    selection = {dim: value for dim, value in fixed.items() if dim in da.dims}

    result = da.sel(selection, method='nearest') if selection else da

    if along not in result.dims:
        raise ValueError(f"Section operation removed the '{along}' dimension")

    result.attrs = da.attrs.copy()

    labels = []
    for dim, requested in selection.items():
        if isinstance(requested, (int, float)):
            shown = requested
            picked = result.coords.get(dim)
            # A scalar .sel leaves the chosen grid point behind as a 0-d
            # coordinate. Report that value so the label does not claim data
            # at a location that was only approximately matched.
            if picked is not None and picked.ndim == 0 and np.issubdtype(picked.dtype, np.number):
                shown = picked.item()
            labels.append(f"{dim}={shown:.3f}")
        else:
            labels.append(f"{dim}={requested}")

    if labels:
        long_name = result.attrs.get('long_name', result.name)
        result.attrs['long_name'] = f"{long_name} ({', '.join(labels)})"

    return result
| |
|
| |
|
def aggregate_spatial(da: xr.DataArray, method: str = "mean") -> xr.DataArray:
    """
    Reduce ``da`` over its horizontal (X and Y) dimensions.

    Both horizontal axes are collapsed, giving a domain-wide statistic. This is
    not a zonal mean, which would reduce over the X/longitude axis only.

    The axes come from ``identify_coordinates``:
    - If a reported name is a dimension, it is reduced directly.
    - If it is a non-dimension coordinate (e.g. a 2-D lon/lat on a curvilinear
      grid), the dimensions that coordinate spans are reduced instead.

    Args:
        da: Input DataArray.
        method: Aggregation method: 'mean', 'sum' or 'std'.

    Returns:
        The reduced DataArray, carrying a copy of ``da.attrs``. Its
        ``long_name`` is prefixed with the capitalised method
        (e.g. "Mean of ...").

    Raises:
        ValueError: If no X/Y axes are found or ``method`` is unknown.
    """
    coords = identify_coordinates(da)

    spatial_dims = []
    for axis in ('X', 'Y'):
        if axis not in coords:
            continue
        name = coords[axis]
        if name in da.dims or name not in da.coords:
            # Dimension name (the common case). Anything unknown is passed
            # through too, so xarray reports it exactly as before.
            candidates = (name,)
        else:
            # Non-dimension coordinate: reduce over the dims it is defined on.
            candidates = da.coords[name].dims
        for dim in candidates:
            if dim not in spatial_dims:
                spatial_dims.append(dim)

    if not spatial_dims:
        raise ValueError("No spatial dimensions found for aggregation")

    reducers = {"mean": da.mean, "sum": da.sum, "std": da.std}
    if method not in reducers:
        raise ValueError(f"Unknown aggregation method: {method}")
    # NOTE(review): the reduction is unweighted. On a regular lat/lon grid a
    # spatial mean is usually area-weighted (e.g. by cos(latitude)); confirm
    # whether callers expect that.
    result = reducers[method](dim=spatial_dims)

    result.attrs = da.attrs.copy()
    long_name = result.attrs.get('long_name', result.name)
    result.attrs['long_name'] = f"{method.capitalize()} of {long_name}"

    return result