| | """Animation functionality for creating MP4 videos from multi-dimensional data.""" |
| |
|
| | import os |
| | import tempfile |
| | import subprocess |
| | from typing import Optional, Callable, List |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | from matplotlib.animation import FuncAnimation |
| | import xarray as xr |
| |
|
| | from .plot import plot_1d, plot_2d, plot_map, setup_matplotlib |
| | from .utils import identify_coordinates, format_value |
| |
|
| |
|
def check_ffmpeg():
    """
    Check if FFmpeg is available.

    Runs ``ffmpeg -version`` and inspects the outcome.

    Returns:
        True if the command could be started and exited with status 0,
        False otherwise.
    """
    try:
        subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True)
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (not on PATH) as well as
        # PermissionError and similar launch failures (found but not runnable).
        return False
    return True
| |
|
| |
|
def _anim_detect_plot_func(da: xr.DataArray, dim: str) -> Callable:
    """Pick plot_1d / plot_map / plot_2d from the dims left after removing ``dim``."""
    remaining_dims = [d for d in da.dims if d != dim]
    n_remaining = len(remaining_dims)
    if n_remaining == 1:
        return plot_1d
    if n_remaining == 2:
        coords = identify_coordinates(da)
        has_geo = 'X' in coords and 'Y' in coords
        return plot_map if has_geo else plot_2d
    raise ValueError(f"Cannot auto-detect plot type for {n_remaining}D data")


def _anim_map_projection(name: str):
    """Return the cartopy projection named ``name`` (PlateCarree if unrecognised)."""
    import cartopy.crs as ccrs  # optional dependency, only needed for maps

    projections = {
        'PlateCarree': ccrs.PlateCarree,
        'Robinson': ccrs.Robinson,
        'Mollweide': ccrs.Mollweide,
    }
    return projections.get(name, ccrs.PlateCarree)()


def _anim_line_updater(ax, da, dim, plot_kwargs, frame_title):
    """Set up static 1-D axes on ``ax`` and return a per-frame line-update callback."""
    initial_frame = da.isel({dim: 0})
    x_dim = initial_frame.dims[0]
    x_coord = initial_frame.coords[x_dim]
    # A dimension coordinate only varies along its own dim, so it is the
    # same for every frame and can be read once.
    x_values = x_coord.values

    (line,) = ax.plot([], [])
    ax.set_xlim(float(x_coord.min()), float(x_coord.max()))
    ax.set_ylim(plot_kwargs['vmin'], plot_kwargs['vmax'])
    ax.set_xlabel(f"{x_dim} ({x_coord.attrs.get('units', '')})")
    ax.set_ylabel(f"{da.name or 'Value'} ({da.attrs.get('units', '')})")

    def update(frame_idx):
        line.set_data(x_values, da.isel({dim: frame_idx}).values)
        ax.set_title(frame_title(frame_idx))
        return (line,)

    return update


def _anim_image_updater(ax, da, dim, plot_kwargs, frame_title):
    """Return a per-frame callback that redraws a 2-D ``imshow`` of the frame on ``ax``."""
    initial_frame = da.isel({dim: 0})
    coords = identify_coordinates(initial_frame)
    x_dim = coords.get('X', initial_frame.dims[-1])
    y_dim = coords.get('Y', initial_frame.dims[-2])

    x_coord = initial_frame.coords[x_dim]
    y_coord = initial_frame.coords[y_dim]
    extent = [float(x_coord.min()), float(x_coord.max()),
              float(y_coord.min()), float(y_coord.max())]
    x_label = f"{x_dim} ({x_coord.attrs.get('units', '')})"
    y_label = f"{y_dim} ({y_coord.attrs.get('units', '')})"
    cmap = plot_kwargs.get('cmap', 'viridis')

    def update(frame_idx):
        ax.clear()
        # origin='lower' with a (min, max) extent puts row/column 0 at the
        # minimum coordinate, so sort ascending first; otherwise a descending
        # axis (e.g. latitude stored north-to-south) is drawn flipped.
        frame = (da.isel({dim: frame_idx})
                 .sortby([y_dim, x_dim])
                 .transpose(y_dim, x_dim))
        im = ax.imshow(frame.values, extent=extent, aspect='auto', origin='lower',
                       cmap=cmap, vmin=plot_kwargs['vmin'], vmax=plot_kwargs['vmax'])
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_title(frame_title(frame_idx))
        return [im]

    return update


def _anim_map_updater(ax, da, dim, plot_kwargs, frame_title):
    """Return a per-frame callback redrawing a lon/lat field on the cartopy GeoAxes ``ax``."""
    import cartopy.crs as ccrs  # optional dependency, only needed for maps

    initial_frame = da.isel({dim: 0})
    coords = identify_coordinates(initial_frame)
    lon_dim = coords['X']
    lat_dim = coords['Y']
    lons = initial_frame.coords[lon_dim].values
    lats = initial_frame.coords[lat_dim].values
    cmap = plot_kwargs.get('cmap', 'viridis')
    # Data values are interpreted as plain lon/lat (PlateCarree), regardless
    # of the display projection chosen for the axes.
    data_crs = ccrs.PlateCarree()

    def update(frame_idx):
        ax.clear()
        frame_data = da.isel({dim: frame_idx})
        im = ax.pcolormesh(lons, lats, frame_data.transpose(lat_dim, lon_dim).values,
                           cmap=cmap, vmin=plot_kwargs['vmin'], vmax=plot_kwargs['vmax'],
                           transform=data_crs, shading='auto')
        if plot_kwargs.get('coastlines', True):
            ax.coastlines(resolution='50m', color='black', linewidth=0.5)
        if plot_kwargs.get('gridlines', True):
            ax.gridlines(alpha=0.5)
        ax.set_global()
        ax.set_title(frame_title(frame_idx))
        return [im]

    return update


def animate_over_dim(da: xr.DataArray, dim: str, plot_func: Callable = None,
                     fps: int = 10, out: str = "animation.mp4",
                     figsize: tuple = (10, 8), **plot_kwargs) -> str:
    """
    Create an MP4 animation over a specified dimension.

    Frame ``i`` shows ``da.isel({dim: i})``. All frames share one value range
    (``vmin``/``vmax``), computed from the whole array unless supplied, so
    frames are visually comparable.

    Args:
        da: Input DataArray
        dim: Dimension to animate over
        plot_func: Plotting function to use (auto-detected if None). Only
            ``plot_1d``, ``plot_2d`` and ``plot_map`` are supported; the value
            selects the frame renderer and is not itself called.
        fps: Frames per second (must be positive)
        out: Output file path
        figsize: Figure size
        **plot_kwargs: Additional plotting parameters. Keys read here:
            ``vmin``, ``vmax``, ``cmap``; for maps also ``proj``
            ('PlateCarree', 'Robinson' or 'Mollweide'), ``coastlines`` and
            ``gridlines``.

    Returns:
        Path to the created animation file

    Raises:
        RuntimeError: If FFmpeg is unavailable or writing the video fails.
        ValueError: If ``dim`` is missing, there are fewer than 2 frames,
            ``fps`` is not positive, or the plot type is unsupported.
    """
    from matplotlib.animation import FFMpegWriter

    if not check_ffmpeg():
        raise RuntimeError("FFmpeg is required for creating MP4 animations")

    if dim not in da.dims:
        raise ValueError(f"Dimension '{dim}' not found in DataArray")

    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    setup_matplotlib()

    coord_vals = da.coords[dim].values
    n_frames = len(coord_vals)
    if n_frames < 2:
        raise ValueError(f"Need at least 2 frames for animation, got {n_frames}")

    if plot_func is None:
        plot_func = _anim_detect_plot_func(da, dim)
    elif plot_func not in (plot_1d, plot_2d, plot_map):
        raise ValueError(
            "animate_over_dim only supports plot_1d, plot_2d or plot_map as plot_func"
        )

    # A shared colour / y range keeps frames comparable across the animation.
    if 'vmin' not in plot_kwargs:
        plot_kwargs['vmin'] = float(da.min().values)
    if 'vmax' not in plot_kwargs:
        plot_kwargs['vmax'] = float(da.max().values)

    title_base = da.attrs.get('long_name', da.name or 'Data')

    def frame_title(frame_idx):
        return f"{title_base} - {dim}={format_value(coord_vals[frame_idx], dim)}"

    # Map frames must be drawn on a cartopy GeoAxes: a plain Axes has no
    # coastlines() and does not understand cartopy CRS transforms.
    subplot_kw = None
    if plot_func == plot_map:
        subplot_kw = {'projection': _anim_map_projection(plot_kwargs.get('proj', 'PlateCarree'))}

    fig, ax = plt.subplots(figsize=figsize, subplot_kw=subplot_kw)
    try:
        if plot_func == plot_1d:
            update = _anim_line_updater(ax, da, dim, plot_kwargs, frame_title)
        elif plot_func == plot_map:
            update = _anim_map_updater(ax, da, dim, plot_kwargs, frame_title)
        else:
            update = _anim_image_updater(ax, da, dim, plot_kwargs, frame_title)

        anim = FuncAnimation(fig, update, frames=n_frames, interval=1000 // fps, blit=False)

        try:
            writer = FFMpegWriter(fps=fps, metadata=dict(artist='TensorView'), bitrate=1800)
            anim.save(out, writer=writer)
        except Exception as e:
            raise RuntimeError(f"Failed to create animation: {str(e)}") from e
    finally:
        # Always release the figure, including when setup fails.
        plt.close(fig)

    return out
| |
|
| |
|
def create_frame_sequence(da: xr.DataArray, dim: str, plot_func: Callable = None,
                          output_dir: str = "frames", **plot_kwargs) -> List[str]:
    """
    Create a sequence of individual frame images.

    Frame ``i`` is ``plot_func(da.isel({dim: i}), **plot_kwargs)`` saved as a
    PNG named ``frame_<i>.png`` with ``i`` zero-padded. The padding is at
    least 4 digits and grows with the frame count, so names sort
    lexicographically in frame order.

    Args:
        da: Input DataArray
        dim: Dimension to animate over
        plot_func: Plotting function to use; it must return a matplotlib
            Figure (auto-detected if None)
        output_dir: Directory to save frames (created if missing)
        **plot_kwargs: Additional plotting parameters; ``vmin``/``vmax``
            default to the global min/max of ``da``

    Returns:
        List of frame file paths

    Raises:
        ValueError: If ``dim`` is missing or the plot type cannot be detected.
    """
    if dim not in da.dims:
        raise ValueError(f"Dimension '{dim}' not found in DataArray")

    os.makedirs(output_dir, exist_ok=True)

    coord_vals = da.coords[dim].values
    n_frames = len(coord_vals)

    if plot_func is None:
        remaining_dims = [d for d in da.dims if d != dim]
        n_remaining = len(remaining_dims)

        coords = identify_coordinates(da)
        has_geo = 'X' in coords and 'Y' in coords

        if n_remaining == 1:
            plot_func = plot_1d
        elif n_remaining == 2 and has_geo:
            plot_func = plot_map
        elif n_remaining == 2:
            plot_func = plot_2d
        else:
            raise ValueError(f"Cannot auto-detect plot type for {n_remaining}D data")

    # A shared value range keeps frames comparable.
    if 'vmin' not in plot_kwargs:
        plot_kwargs['vmin'] = float(da.min().values)
    if 'vmax' not in plot_kwargs:
        plot_kwargs['vmax'] = float(da.max().values)

    # Keep every index the same width: consumers glob 'frame_*.png' and rely
    # on lexicographic order, which breaks once indices outgrow the padding.
    width = max(4, len(str(n_frames - 1)))
    title_base = da.attrs.get('long_name', da.name or 'Data')

    frame_paths = []
    for i, coord_val in enumerate(coord_vals):
        fig = plot_func(da.isel({dim: i}), **plot_kwargs)
        try:
            coord_str = format_value(coord_val, dim)
            fig.suptitle(f"{title_base} - {dim}={coord_str}")

            frame_path = os.path.join(output_dir, f"frame_{i:0{width}d}.png")
            fig.savefig(frame_path, dpi=150, bbox_inches='tight')
        finally:
            # Close even on failure so figures do not accumulate in pyplot.
            plt.close(fig)
        frame_paths.append(frame_path)

    return frame_paths
| |
|
| |
|
def frames_to_mp4(frame_dir: str, output_path: str, fps: int = 10, cleanup: bool = True) -> str:
    """
    Convert a directory of frame images to MP4 video.

    Encodes every ``frame_*.png`` in ``frame_dir`` (in the order FFmpeg's glob
    returns them) to H.264 / yuv420p. Odd frame sizes are padded by one pixel
    to even dimensions, which that pixel format requires.

    Args:
        frame_dir: Directory containing frame images
        output_path: Output MP4 file path (overwritten if it exists)
        fps: Frames per second
        cleanup: Whether to delete frame files (and ``frame_dir`` itself, if
            it is then empty) after a successful conversion

    Returns:
        Path to created MP4 file

    Raises:
        RuntimeError: If FFmpeg is unavailable, no frames are found, or
            FFmpeg fails.
    """
    import glob

    if not check_ffmpeg():
        raise RuntimeError("FFmpeg is required for MP4 conversion")

    # Escape the directory part so wildcard characters in frame_dir are
    # treated literally by Python's glob.
    frame_files = glob.glob(os.path.join(glob.escape(frame_dir), 'frame_*.png'))
    if not frame_files:
        raise RuntimeError(f"No frame_*.png files found in {frame_dir!r}")

    cmd = [
        'ffmpeg', '-y',
        '-framerate', str(fps),
        '-pattern_type', 'glob',
        '-i', os.path.join(frame_dir, 'frame_*.png'),
        # libx264 with yuv420p rejects odd widths/heights; tightly cropped
        # PNG frames frequently have them, so pad up to the next even size.
        '-vf', 'pad=ceil(iw/2)*2:ceil(ih/2)*2',
        '-c:v', 'libx264',
        '-pix_fmt', 'yuv420p',
        '-crf', '18',
        output_path,
    ]

    try:
        subprocess.run(cmd, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        stderr = e.stderr.decode(errors='replace') if e.stderr else ''
        raise RuntimeError(f"FFmpeg failed: {stderr}") from e

    if cleanup:
        for frame_file in frame_files:
            os.remove(frame_file)
        try:
            os.rmdir(frame_dir)
        except OSError:
            # Best effort: the directory may still hold unrelated files.
            pass

    return output_path
| |
|
| |
|
def create_gif(da: xr.DataArray, dim: str, output_path: str = "animation.gif",
               duration: int = 200, plot_func: Callable = None, **plot_kwargs) -> str:
    """
    Create an animated GIF.

    Frames are rendered to PNGs in a temporary directory by
    ``create_frame_sequence`` and then assembled with Pillow into a GIF that
    loops forever.

    Args:
        da: Input DataArray
        dim: Dimension to animate over
        output_path: Output GIF file path
        duration: Duration per frame in milliseconds
        plot_func: Plotting function to use
        **plot_kwargs: Additional plotting parameters

    Returns:
        Path to created GIF file

    Raises:
        ImportError: If Pillow is not installed.
        ValueError: If ``dim`` yields no frames.
    """
    try:
        from PIL import Image
    except ImportError as err:
        raise ImportError("Pillow is required for GIF creation") from err

    with tempfile.TemporaryDirectory() as temp_dir:
        frame_paths = create_frame_sequence(da, dim, plot_func, temp_dir, **plot_kwargs)
        if not frame_paths:
            raise ValueError(f"No frames to write: dimension '{dim}' is empty")

        images = []
        for frame_path in frame_paths:
            # Image.open keeps the file handle open; copy the pixels into memory
            # and close it so the temporary directory can be removed
            # (open handles block deletion on Windows).
            with Image.open(frame_path) as img:
                images.append(img.copy())

        images[0].save(output_path, save_all=True, append_images=images[1:],
                       duration=duration, loop=0)

    return output_path