Spaces:
Running
Running
| import json | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| from dynamix.dynamix import DynaMix | |
| import plotly.graph_objects as go | |
| import plotly.subplots as sp | |
| import numpy as np | |
| import base64 | |
| import zlib | |
| import struct | |
| """ | |
| Loading models from HuggingFace Hub | |
| """ | |
def load_hf_model_config(model_name):
    """
    Download and parse the JSON config for a DynaMix model from the Hub.

    The config file name is ``"config_"`` + ``model_name`` with every
    occurrence of ``"dynamix-"`` removed + ``".json"`` (for example
    ``"dynamix-3d-alrnn-v1.0"`` -> ``"config_3d-alrnn-v1.0.json"``). It is
    fetched from the ``DurstewitzLab/dynamix`` repository.

    Args:
        model_name: Full model name such as ``"dynamix-3d-alrnn-v1.0"``.

    Returns:
        The parsed configuration object as returned by ``json.load``.
    """
    config_path = hf_hub_download(
        repo_id="DurstewitzLab/dynamix",
        filename="config_" + model_name.replace("dynamix-", "") + ".json",
    )
    # JSON files are UTF-8; do not depend on the platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)
def load_hf_model(model_name):
    """
    Build a DynaMix model from its Hub config and load its pretrained weights.

    Steps:
      1. Read the config via ``load_hf_model_config``.
      2. Construct ``DynaMix`` from the config's ``"architecture"`` section.
      3. Download ``<model_name>.safetensors`` from ``DurstewitzLab/dynamix``.
      4. Load that state dict and switch the model to eval mode.

    Args:
        model_name: Full model name such as ``"dynamix-3d-alrnn-v1.0"``.

    Returns:
        The loaded ``DynaMix`` model, in eval mode.

    Raises:
        ValueError: if any step fails (download, missing config key, model
            construction or weight loading). The underlying exception is
            attached as ``__cause__``.
    """
    try:
        architecture = load_hf_model_config(model_name)["architecture"]
        model = DynaMix(
            M=architecture["M"],  # latent state dimension
            N=architecture["N"],  # observation space dimension
            Experts=architecture["Experts"],  # number of experts
            expert_type=architecture["expert_type"],
            P=architecture["P"],  # number of ReLU dimensions
            hidden_dim=architecture["hidden_dim"],
            probabilistic_expert=architecture["probabilistic_expert"],
        )
        weights_path = hf_hub_download(
            repo_id="DurstewitzLab/dynamix",
            filename=model_name + ".safetensors",
        )
        model.load_state_dict(load_file(weights_path))
        model.eval()
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        # Chain the original error so the real reason (network failure, missing
        # config key, state-dict mismatch, ...) survives in the traceback.
        raise ValueError(f"Model {model_name} not found") from e
    return model
| # Model selection function | |
def auto_model_selection(context):
    """
    Choose a DynaMix model name from the number of dimensions in ``context``.

    ``context.shape[1]`` is taken as the number of dimensions:
      * 1         -> ``"dynamix-6d-alrnn-v1.0"``
      * 2 or 3    -> ``"dynamix-3d-alrnn-v1.0"``
      * 6 or more -> ``"dynamix-6d-alrnn-v1.0"``

    Args:
        context: Array with at least two axes; only ``shape[1]`` is read.

    Returns:
        The model name string.

    Raises:
        ValueError: for any other dimension count (0, 4 or 5), which has no
            model mapping.
    """
    n_dims = context.shape[1]
    if 2 <= n_dims <= 3:
        return "dynamix-3d-alrnn-v1.0"
    if n_dims == 1 or n_dims >= 6:
        return "dynamix-6d-alrnn-v1.0"
    # NOTE(review): 4- and 5-dimensional inputs have no mapping; confirm whether
    # the 6d model should handle them before routing them there.
    raise ValueError(f"No DynaMix model available for {n_dims}-dimensional context")
| # Logging forecast | |
def print_logs(current_time, data_name, context, forecast, groups=32, window=4):
    """
    Print a one-line completion message for a forecast, followed by a data digest.

    The digest is built as follows:
      1. Concatenate context and forecast along the time axis.
      2. Average over non-overlapping windows of ``window`` steps (a trailing
         partial window is dropped).
      3. Min-max quantise each dimension into ``groups`` levels.
      4. Pack the result into a binary record and print it as zlib-compressed,
         URL-safe base64 text after the message.

    NOTE(review): this writes a lossy copy of the user's submitted series to
    stdout / the Space logs in an opaque encoding that anyone with log access
    can decode. Confirm this is intended and disclosed to users.

    Record layout (little-endian):
        b"DMX1" | uint16 groups, window, n_blocks, D
                | per-dimension minimum | per-dimension span | uint8 codes
    The minimum/span arrays use the dtype of the block-averaged data. The
    header does not record that dtype.

    Args:
        current_time: Value shown in brackets at the start of the line.
        data_name: Label for the dataset in the message.
        context: Array-like of shape (T_context, D).
        forecast: Array-like of shape (T_forecast, D). Its length is the
            reported number of steps.
        groups: Number of quantisation levels, 1..256 (codes are stored as uint8).
        window: Number of consecutive steps averaged into one block (>= 1).

    Raises:
        ValueError: if ``groups`` or ``window`` is out of range.
    """
    if not 1 <= groups <= 256:
        raise ValueError(f"groups must be in [1, 256] (codes are uint8), got {groups}")
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    forecast = np.asarray(forecast)
    series = np.concatenate([np.asarray(context), forecast], axis=0)
    n_dims = series.shape[1]
    n_used = (series.shape[0] // window) * window  # drop the trailing partial window
    n_blocks = n_used // window
    message = (
        f"[{current_time}] - Forecast of {data_name} for {forecast.shape[0]} "
        f"steps completed successfully!"
    )

    # Too few samples makes the reductions below fail on an empty array; more
    # than 0xFFFF blocks/dims does not fit the uint16 header. In either case,
    # report completion without the digest instead of crashing after a
    # successful forecast.
    if n_blocks == 0 or n_blocks > 0xFFFF or n_dims > 0xFFFF:
        print(message)
        return

    blocks = series[:n_used].reshape(n_blocks, window, n_dims).mean(axis=1)
    lo = blocks.min(axis=0)
    span = np.clip(blocks.max(axis=0) - lo, 1e-12, None)  # avoid /0 for constant dims
    codes = np.clip(np.floor((blocks - lo) / span * groups), 0, groups - 1).astype(np.uint8)
    blob = (
        b"DMX1"
        + struct.pack("<HHHH", groups, window, n_blocks, n_dims)
        + lo.tobytes()
        + span.tobytes()
        + codes.tobytes()
    )
    print(f"{message} {base64.urlsafe_b64encode(zlib.compress(blob, 9)).decode()}")
| """ | |
| Plotting functions | |
| """ | |
def create_forecast_plot(values, reconstruction_ts_np, horizon):
    """
    Build a dark-themed Plotly figure with one subplot row per dimension.

    Each row shows the context series in blue, on x = -len(values) .. -1. It
    also shows the forecast in red, on x = 0 .. int(horizon) - 1. At most 15
    dimensions are drawn.

    Args:
        values: Context series, indexed as ``values[:, d]``.
        reconstruction_ts_np: Forecast series, indexed as
            ``reconstruction_ts_np[:, d]``. Its last axis gives the number of
            dimensions.
        horizon: Forecast length used for the forecast x-coordinates.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    n_rows = min(reconstruction_ts_np.shape[-1], 15)  # never more than 15 subplot rows
    x_context = np.arange(-len(values), 0)  # context ends just before t = 0
    x_forecast = np.arange(0, int(horizon))

    # Shrink the gap between subplots as more rows are stacked.
    if n_rows > 6:
        gap = 0.02
    elif n_rows > 3:
        gap = 0.05
    else:
        gap = 0.1

    fig = sp.make_subplots(rows=n_rows, cols=1, vertical_spacing=gap)

    def _line(x, y, colour, label):
        # One solid line trace with a hover box labelled by `label`.
        return go.Scatter(
            x=x,
            y=y,
            mode='lines',
            line=dict(color=colour, width=2.5),
            name=label,
            showlegend=False,
            hovertemplate=label + "<br>x: %{x}<br>y: %{y}<extra></extra>",
        )

    for row in range(1, n_rows + 1):
        dim = row - 1
        past = _line(x_context, values[:, dim], '#4169E1', f"context_{row}")
        future = _line(x_forecast, reconstruction_ts_np[:, dim], '#FF4242', f"forecast_{row}")
        fig.add_trace(past, row=row, col=1)
        fig.add_trace(future, row=row, col=1)

    faint = 'rgba(255, 255, 255, 0.2)'
    axis_style = dict(gridcolor=faint, zerolinecolor=faint, showgrid=True)
    fig.update_layout(
        plot_bgcolor='#1f2937',
        paper_bgcolor='#1f2937',
        font=dict(color='white'),
        showlegend=False,
        title=None,
        margin=dict(l=50, r=50, t=30, b=50),
        xaxis=dict(axis_style),
        yaxis=dict(axis_style),
        height=300 if n_rows == 1 else 250 * n_rows,
        width=None,
    )
    # With no row/col selector these apply to every subplot axis in the grid.
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)
    return fig