import gradio as gr import pandas as pd import plotly.graph_objects as go import numpy as np import torch from typing import Dict, List, Tuple import re from typing import Callable, Union, Dict class TimeSeriesEditor: def __init__(self, seq_length: int, feature_dim: int, trainer): # Existing initialization self.seq_length = seq_length self.feature_dim = feature_dim self.trainer = trainer self.coef = None self.stepsize = None self.sampling_steps = None self.feature_names = ["revenue", "download", "daily active user"]# * 20 # self.feature_names = [f"Feature {i}" for i in range(self.feature_dim)] # Store the latest model output self.latest_sample = None self.latest_observed_points = None self.latest_observed_mask = None self.latest_gradient_control_signal = None self.latest_model_control_signal = None # self.latest_metrics # Define scales for each feature self.feature_scales = { 0: 1000000, # Revenue: $1M per 0.1 1: 100000, # Download: 100K downloads per 0.1 2: 10000 # AU: 10K active users per 0.1 } self.feature_units = { 0: "$", # Revenue 1: "downloads", # Download 2: "users" # AU } self.show_normalized = True # Add frequency band multipliers self.freq_bands = np.ones(5) # 5 frequency bands, initially all set to 1.0 self.function_parser = FunctionParser() self.trending_controls = [ # (200, 250, 0, self.function_parser.string_to_function("sin(2*pi*x)"), 0.05) # 200,250,0,sin(2*pi*x),0.05 ] self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def format_value(self, value: float, feature_idx: int) -> str: """Format value with appropriate units and notation""" if self.show_normalized: return f"{value:.4f}" else: if feature_idx == 0: # Revenue return f"{self.feature_units[feature_idx]}{value:,.2f}" else: # Downloads and AU return f"{value:,.0f} {self.feature_units[feature_idx]}" def create_plot(self, sample: np.ndarray, observed_points: torch.Tensor, observed_mask: torch.Tensor, gradient_control_signal: Dict, metrics: Dict) -> List[go.Figure]: figures = [] # Get weights from model_control_signal (will be all 1s if not provided) weights = observed_mask for feat_idx in range(self.feature_dim): fig = go.Figure() # Scale values if needed scale_factor = self.feature_scales[feat_idx] * 10 if not self.show_normalized else 1 # Plot predicted line predicted_values = sample[:, feat_idx] * scale_factor fig.add_trace(go.Scatter( x=np.arange(self.seq_length), y=predicted_values, mode='lines', name='Predicted', line=dict(color='green', width=2), showlegend=True )) # Calculate and plot confidence bands based on weights # Lower weights = larger uncertainty bands mask = observed_points[:, feat_idx] > 0 ox = np.arange(0, self.seq_length)[mask] oy = observed_points[mask, feat_idx].numpy() * scale_factor weights_masked = 1 - weights[mask, feat_idx].numpy() # Calculate error bars - inverse relationship with weight # Weight of 1.0 gives minimal uncertainty (0.02) # Weight of 0.1 gives larger uncertainty (0.2) # error_y = 0.02 / weights_masked error_y = weights_masked / 5 # Plot observed points with error bars - changed symbol to 'cross' fig.add_trace(go.Scatter( x=ox, y=oy, mode='markers', name='Observed', marker=dict( # special red color='rgba(255, 0, 0, 0.5)', # size=10, symbol='x', # Changed from 'circle' to 'x' for cross symbol ), error_y=dict( type='data', array=error_y * scale_factor, visible=True, thickness=0.5, width=2, color='blue' ), showlegend=True )) # Add shaded confidence bands around the predicted line # This shows the general uncertainty in the prediction uncertainty = 0.05 # Base uncertainty level upper_bound = predicted_values + uncertainty * scale_factor lower_bound = predicted_values - uncertainty * scale_factor fig.add_trace(go.Scatter( x=np.concatenate([np.arange(self.seq_length), np.arange(self.seq_length)[::-1]]), y=np.concatenate([upper_bound, lower_bound[::-1]]), # fill='toself', # fillcolor='rgba(0,100,0,0.1)', line=dict(color='rgba(255,255,255,0)'), name='Prediction Interval', showlegend=True )) # Add vertical lines for peak points if gradient_control_signal.get("peak_points"): for peak_point in gradient_control_signal["peak_points"]: fig.add_vline(x=peak_point, line_dash="dash", line_color="red") # Add metrics annotations total_value = np.sum(sample[:, feat_idx]) * (self.feature_scales[feat_idx] * 10 if not self.show_normalized else 1) annotations = [dict( x=0.02, y=1.1, xref="paper", yref="paper", text=f"Total {self.feature_names[feat_idx]}: {self.format_value(total_value, feat_idx)}", showarrow=False )] # Update y-axis title based on feature and scaling if self.show_normalized: y_title = f'{self.feature_names[feat_idx]} (Normalized)' else: unit = self.feature_units[feat_idx] y_title = f'{self.feature_names[feat_idx]} ({unit})' # Create a more informative legend for uncertainty legend_text = ( "Prediction with Confidence Bands
" "• Blue points: Observed values with uncertainty
" "• Green line: Predicted values
" # "• Shaded area: Prediction uncertainty
" "• Error bars: Observation uncertainty (larger = lower weight)" ) fig.update_layout( title=dict( text=f'Feature: {self.feature_names[feat_idx]}', x=0.5, y=0.95 ), xaxis_title='Time', yaxis_title=y_title, height=400, showlegend=True, dragmode='select', annotations=[ *annotations, # dict( # x=1.15, # y=0.5, # xref="paper", # yref="paper", # text=legend_text, # showarrow=False, # align="left", # bordercolor="black", # borderwidth=1, # borderpad=4, # bgcolor="white", # ) ], margin=dict(r=200) # Add right margin for legend ) figures.append(fig) return figures def update_scaling(self, revenue_scale: float, download_scale: float, au_scale: float, show_normalized: bool) -> Tuple[List[go.Figure], Dict]: """Update the scaling parameters and redraw plots""" if self.latest_sample is None: return [], {} # Update scales self.feature_scales = { 0: revenue_scale, 1: download_scale, 2: au_scale } self.show_normalized = show_normalized # Calculate metrics metrics = { 'show_normalized': self.show_normalized } for feat_idx in range(self.feature_dim): total = np.sum(self.latest_sample[:, feat_idx]) * (self.feature_scales[feat_idx] * 10 if not self.show_normalized else 1) metrics[f'total_{self.feature_names[feat_idx]}'] = self.format_value(total, feat_idx) # Update plots figures = self.create_plot( self.latest_sample, self.latest_observed_points, self.latest_observed_mask, self.latest_gradient_control_signal, metrics ) return figures, metrics def parse_data_points(self, df) -> Dict: """Parse data points from DataFrame with columns: time,feature,value""" data_dict = {} if df is None or df.empty: return data_dict for _, row in df.iterrows(): # Skip if any required value is NaN if pd.isna(row['time']) or pd.isna(row['feature']) or pd.isna(row['value']): continue try: time_idx = int(row['time']) feature_idx = int(row['feature']) value = float(row['value']) if time_idx not in data_dict: data_dict[time_idx] = {} data_dict[time_idx][feature_idx] = (value, 1.0) except (ValueError, TypeError): continue return data_dict def parse_point_groups(self, df) -> Dict: """Parse point groups from DataFrame with columns: start,end,interval,feature,value,weight""" data_dict = {} if df is None or df.empty: return data_dict for _, row in df.iterrows(): # Skip if any required value is NaN if pd.isna(row['start']) or pd.isna(row['end']) or pd.isna(row['interval']) or \ pd.isna(row['feature']) or pd.isna(row['value']): continue try: start = int(row['start']) end = int(row['end']) interval = int(row['interval']) feature = int(row['feature']) value = float(row['value']) weight = float(row.get('weight', 1.0)) if not pd.isna(row.get('weight')) else 1.0 for t in range(start, end + 1, interval): if 0 <= t < self.seq_length: if t not in data_dict: data_dict[t] = {} data_dict[t][feature] = (value, weight) except (ValueError, TypeError): continue return data_dict def to_tensor(self, observed_points_dict, seq_length, feature_dim): observed_points = torch.zeros((seq_length, feature_dim)) observed_weights = torch.zeros((seq_length, feature_dim)) for seq, feature_dict in observed_points_dict.items(): for feature, (value, weight) in feature_dict.items(): observed_points[seq, feature] = value observed_weights[seq, feature] = weight return observed_points, observed_weights def apply_direct_edits(self, sample: np.ndarray, edit_params: Dict) -> np.ndarray: """Apply direct edits to the sample array""" edited_sample = sample.copy() if edit_params.get("enable_direct_area"): areas = self.parse_area_selections(edit_params["direct_areas"]) for area in areas: start, end, feat_idx, target = area edited_sample[start:end, feat_idx] += target edited_sample = np.clip(edited_sample, 0, 1) return edited_sample def parse_area_selections(self, area_text: str) -> List[Tuple]: """Parse area selection text into (start, end, feature, target) tuples""" areas = [] if not area_text.strip(): return areas area_text = area_text.replace('\n', ';') for line in area_text.strip().split(';'): if not line.strip(): continue try: start, end, feat, target = map(float, line.strip().split(',')) areas.append((int(start), int(end), int(feat), target)) except (ValueError, IndexError): continue return areas def apply_trending_mask(self, points: torch.Tensor, mask: torch.Tensor, consider_last_generated=False) -> Tuple[torch.Tensor, torch.Tensor]: """Apply trending functions as soft constraints through masks""" if not self.trending_controls or self.latest_sample is None: return points, mask for start, end, feat_idx, func, confidence in self.trending_controls: if start < 0 or end > self.seq_length or start >= end: continue # Generate x values normalized between 0 and 1 for the segment x = np.linspace(0, 1, end - start) try: # Calculate the function values y = func(x) # Scale the function output to 0-1 range y = (y - np.min(y)) / (np.max(y) - np.min(y)) # points[start:end, feat_idx] = torch.tensor(y, dtype=points.dtype) # mask[start:end, feat_idx] = max(mask[start:end, feat_idx], min(1.0, confidence * abs( # self.latest_sample[start:end, feat_idx] - y # ))) # Use lower weight for trending constraints except Exception as e: print(f"Error applying function: {e}") continue # Apply the trend as soft constraints mask_zero = (mask[start:end, feat_idx] == 0) points[start:end, feat_idx][mask_zero] = torch.tensor(y, dtype=points.dtype)[mask_zero] mask[start:end, feat_idx][mask_zero] = torch.tensor(confidence * np.ones_like(y), dtype=mask.dtype)[mask_zero] # mask[start:end, feat_idx][mask_zero] = torch.tensor((confidence * np.abs(self.latest_sample[start:end, feat_idx] - y)), dtype=mask.dtype)[mask_zero] mask = mask.clamp(0, 1) return points, mask def update_model(self, figures: List[go.Figure], data_points: str, point_groups: str, enable_area_control: bool, area_selections: str, enable_auc: bool, auc_value: float, enable_peaks: bool, peak_points: str, peak_alpha: float, auc_weight: float, peak_weight: float, enable_trending: bool = True, enable_trending_with_diff: bool = False, trending_params: str = "" ) -> Tuple[List[go.Figure], str, str, Dict]: # Parse both point groups and individual data points individual_points_dict = self.parse_data_points(data_points) group_points_dict = self.parse_point_groups(point_groups) # Merge dictionaries, giving preference to individual points combined_points_dict = group_points_dict.copy() for t, feat_dict in individual_points_dict.items(): if t not in combined_points_dict: combined_points_dict[t] = {} for f, v in feat_dict.items(): combined_points_dict[t][f] = v # Convert to tensor observed_points, observed_weights = self.to_tensor( combined_points_dict, self.seq_length, self.feature_dim ) observed_mask = observed_weights # Parse peak points peak_points_list = [] if enable_peaks and peak_points: try: peak_points_list = [int(x.strip()) for x in peak_points.split(',') if x.strip()] except ValueError: peak_points_list = [] # Apply trending control if enabled if enable_trending and trending_params: self.parse_trending_parameters(trending_params) observed_points, observed_mask = self.apply_trending_mask(observed_points, observed_mask, consider_last_generated=enable_trending_with_diff) # Build gradient control signal # IMPORTANT gradient_control_signal = {} if enable_auc: gradient_control_signal["auc"] = auc_value gradient_control_signal["auc_weight"] = auc_weight if enable_peaks: gradient_control_signal.update({ "peak_points": peak_points_list, "peak_alpha": peak_alpha, "peak_weight": peak_weight }) # Build model control signal model_control_signal = {} # if enable_area_control and area_selections: # areas = self.parse_area_selections(area_selections) # if areas: # model_control_signal["selected_areas"] = areas # Run prediction with torch.no_grad(): # to cuda observed_points = observed_points.to(self.device) observed_mask = observed_mask.to(self.device) sample = self.trainer.predict_weighted_points( observed_points, # (seq_length, feature_dim) observed_mask, # (seq_length, feature_dim) self.coef, # fixed self.stepsize, # fixed self.sampling_steps, # fixed # model_control_signal=model_control_signal, gradient_control_signal=gradient_control_signal ) observed_points = observed_points.cpu() observed_mask = observed_mask.cpu() # Store latest results self.latest_sample = sample self.latest_observed_points = observed_points self.latest_observed_mask = observed_mask self.latest_gradient_control_signal = gradient_control_signal self.latest_model_control_signal = model_control_signal # Calculate metrics metrics = { 'show_normalized': self.show_normalized } for feat_idx in range(self.feature_dim): total = np.sum(sample[:, feat_idx]) * (self.feature_scales[feat_idx] * 10 if not self.show_normalized else 1) metrics[f'total_{self.feature_names[feat_idx]}'] = self.format_value(total, feat_idx) # Update plots figures = self.create_plot(sample, observed_points, observed_mask, gradient_control_signal, metrics) return figures, data_points, point_groups, metrics def update_additional_edit( self, enable_direct_area: bool, direct_areas: str): # Apply direct edits if enabled if enable_direct_area: sample = self.apply_direct_edits(self.latest_sample, { "enable_direct_area": enable_direct_area, "direct_areas": direct_areas }) else: sample = self.latest_sample # Calculate metrics metrics = { 'show_normalized': self.show_normalized } for feat_idx in range(self.feature_dim): total = np.sum(sample[:, feat_idx]) * (self.feature_scales[feat_idx] * 10 if not self.show_normalized else 1) metrics[f'total_{self.feature_names[feat_idx]}'] = self.format_value(total, feat_idx) # Update plots figures = self.create_plot( sample, self.latest_observed_points, self.latest_observed_mask, self.latest_gradient_control_signal, metrics ) return figures, metrics def apply_frequency_filter(self, signal: np.ndarray) -> np.ndarray: """Apply FFT-based frequency filtering using the current band multipliers""" # Get FFT of the signal fft = np.fft.fft(signal) freqs = np.fft.fftfreq(len(signal)) # Split frequencies into 5 bands # Exclude DC component (0 frequency) from bands pos_freqs = freqs[1:len(freqs)//2] freq_ranges = np.array_split(pos_freqs, 5) # Apply band multipliers filtered_fft = fft.copy() # Handle DC component separately (lowest frequency) filtered_fft[0] *= self.freq_bands[4] # Apply very low freq multiplier to DC # Apply multipliers to each frequency band for i, freq_range in enumerate(freq_ranges): # Get indices for this frequency band band_mask = np.logical_and( freqs >= freq_range[0], freqs <= freq_range[-1] ) # Apply multiplier to positive and negative frequencies symmetrically filtered_fft[band_mask] *= self.freq_bands[4-i] filtered_fft[np.flip(band_mask)] *= self.freq_bands[4-i] # Convert back to time domain filtered_signal = np.real(np.fft.ifft(filtered_fft)) return filtered_signal def update_frequency_bands(self, band_idx: int, value: float) -> Tuple[List[go.Figure], Dict]: """Update a frequency band multiplier and recompute the filtered signal""" if self.latest_sample is None: return [], {} # Update the specified band multiplier self.freq_bands[band_idx] = value # Apply frequency filtering to each feature filtered_sample = self.latest_sample.copy() for feat_idx in range(self.feature_dim): filtered_sample[:, feat_idx] = self.apply_frequency_filter( self.latest_sample[:, feat_idx] ) # Ensure values remain in valid range filtered_sample = np.clip(filtered_sample, 0, 1) # Calculate metrics metrics = { 'show_normalized': self.show_normalized, 'frequency_bands': self.freq_bands.tolist() } for feat_idx in range(self.feature_dim): total = np.sum(filtered_sample[:, feat_idx]) * (self.feature_scales[feat_idx] * 10 if not self.show_normalized else 1) metrics[f'total_{self.feature_names[feat_idx]}'] = self.format_value(total, feat_idx) # Update plots figures = self.create_plot( filtered_sample, self.latest_observed_points, self.latest_observed_mask, self.latest_gradient_control_signal, metrics ) return figures, metrics def parse_trending_parameters(self, trending_text: str) -> List[Tuple]: """Parse trending control parameters into (start, end, feature, function) tuples""" trending_params = [] if not trending_text.strip(): return trending_params trending_text = trending_text.replace('\n', ';') for line in trending_text.strip().split(';'): if not line.strip(): continue try: # Split by comma and handle the function part separately parts = line.strip().split(',', 4) if len(parts) != 5: continue start, end, feat = map(int, parts[:3]) function_str = parts[3].strip() confidence = float(parts[4]) # Convert the function string to a callable try: func = self.function_parser.string_to_function(function_str) trending_params.append((start, end, feat, func, confidence)) except ValueError as e: print(f"Error parsing function '{function_str}': {e}") continue except (ValueError, IndexError): continue self.trending_controls = trending_params # Store the parsed parameters return trending_params def create_gradio_interface(editor: TimeSeriesEditor): with gr.Blocks() as app: gr.Markdown("# Time Series Editor") gr.Markdown("## Instruction: Scroll Down + Click [Update Figure] [~10s]") metrics_display = gr.JSON(label="Metrics", value={}) with gr.Row(): with gr.Column(scale=1): # with Tab(): # Scaling Parameters Section # with gr.Group(): gr.Markdown("## Scaling Parameters") with gr.Accordion("Open for More Detail", open=False): revenue_scale = gr.Number( label="Revenue Scale ($ per 0.1 in model)", value=1000000 ) download_scale = gr.Number( label="Download Scale (downloads per 0.1 in model)", value=100000 ) au_scale = gr.Number( label="Active Users Scale (users per 0.1 in model)", value=10000 ) show_normalized = gr.Checkbox( label="Show Normalized Values (0-1 scale)", value=True ) update_scaling_btn = gr.Button("Update Scaling") # TS Section gr.Markdown("## Time Series Control Panel") # with gr.Accordion("Open for More Detail"): with gr.Group(): gr.Markdown("### Fixed Point Control") data_points_df = gr.Dataframe( headers=["time", "feature", "value"], datatype=["number", "number", "number"], # label="Anchor Point Control", value=[[0, 0, 0.04], [2, 0, 0.58], [6, 0, 0.27], [58, 0, 0.8], [60, 0, 0.5]], col_count=(3, "fixed"), # Fix number of columns interactive=True ) add_data_point_btn = gr.Button("Add Data Point") def add_data_point(df): new_row = pd.DataFrame([[None, 0, None]], columns=["time", "feature", "value"]) return pd.concat([df, new_row], ignore_index=True) add_data_point_btn.click( fn=add_data_point, inputs=[data_points_df], outputs=[data_points_df] ) with gr.Group(): gr.Markdown("### Group of Anchor Point Control with Confidence") point_groups_df = gr.Dataframe( headers=["start", "end", "interval", "feature", "value", "weight"], datatype=["number", "number", "number", "number", "number", "number"], # label="Group of Anchor Point Control", value=[[0, 50, 10, 0, 0.5, 0.1], [100, 150, 50, 0, 0.1, 0.5]], col_count=(6, "fixed"), # Fix number of columns interactive=True ) add_point_group_btn = gr.Button("Add Point Group") def add_point_group(df): new_row = pd.DataFrame([[None, None, None, 0, None, None]], columns=["start", "end", "interval", "feature", "value", "weight"]) return pd.concat([df, new_row], ignore_index=True) add_point_group_btn.click( fn=add_point_group, inputs=[point_groups_df], outputs=[point_groups_df] ) with gr.Group(): # with gr.Tab("Trending Control"): gr.Markdown("### Trending Control") gr.Markdown(""" Enter trending control parameters in the format: ``` start_time,end_time,feature,function,confidence ``` Examples: - Linear trend: `0,100,0,x` - Sine wave: `0,100,0,sin(2*pi*x)` - Exponential: `0,100,0,exp(-x)` Separate multiple trends with semicolons. """) enable_trending_control = gr.Checkbox(label="Enable Trending Control", value=False) enable_trending_control_with_diff = gr.Checkbox(label="Consider Last Generated", value=False) trending_control = gr.Textbox( label="Trending Control Parameters", lines=2, placeholder="Enter parameters: start_time,end_time,feature,function,condifdence; separated by semicolons", value="200,250,0,sin(2*pi*x),0.05" ) # Area Control Parameters with gr.Group(visible=False): gr.Markdown("### Area Control") enable_area_control = gr.Checkbox(label="Enable Area Control", value=False) area_selections = gr.Textbox( label="Area Selections (format: start_time,end_time,feature,target_value)", lines=2, placeholder="Enter areas: start,end,feature,target; separated by semicolons", ) # AUC Parameters gr.Markdown("### Statistics Control") enable_auc = gr.Checkbox(label="Enable Total Sum Control", value=True) auc_input = gr.Number(label="Target Sum Value", value=-150) auc_weight_input = gr.Number(label="Sum Weight", value=10.0) # Peak Parameters with gr.Group(visible=False): gr.Markdown("### Peak Control") enable_peaks = gr.Checkbox(label="Enable Peak Control", value=False) peak_points_input = gr.Textbox(label="Peak Points (comma-separated)", value="100,200") peak_alpha_input = gr.Number(label="Peak Alpha", value=10) peak_weight_input = gr.Number(label="Peak Weight", value=1.0) update_model_btn = gr.Button("Update Figure") gr.Markdown("## Extend Edit", visible=False) with gr.Tab("Range Shift", visible=False): # gr.Markdown("### Direct Edit Control") enable_direct_area = gr.Checkbox(label="Enable Direct Edits", value=False) # range shift direct_areas = gr.Textbox( label="Direct Edit Areas (format: start_time,end_time,feature,delta)", lines=2, placeholder="Enter areas: start,end,feature,delta; separated by semicolons", value="150,200,0,-0.1" ) update_additional_btn = gr.Button("Update Additional Edit") # with gr.Tab("Trending Control"): # gr.Markdown("### Trending Control") # gr.Markdown(""" # Enter trending control parameters in the format: # ``` # start_time,end_time,feature,function # ``` # Examples: # - Linear trend: `0,100,0,x` # - Sine wave: `0,100,0,sin(2*pi*x)` # - Exponential: `0,100,0,exp(-x)` # Separate multiple trends with semicolons. # """) # enable_trending_control = gr.Checkbox(label="Enable Trending Control", value=False) # enable_trending_control_with_diff = gr.Checkbox(label="Consider Last Generated", value=False) # trending_control = gr.Textbox( # label="Trending Control Parameters", # lines=2, # placeholder="Enter parameters: start_time,end_time,feature,function,condifdence; separated by semicolons", # value="0,100,0,sin(2*pi*x),0.3" # ) # with gr.Tab("Frequency Controls", visible=False): with gr.Group(visible=False): gr.Markdown("Adjust multipliers for different frequency bands (0-2)") freq_bands = [ gr.Slider( minimum=0, maximum=2, step=0.1, value=1.0, label=f"Band {i+1}: {'Very High' if i==0 else 'High' if i==1 else 'Mid' if i==2 else 'Low' if i==3 else 'Very Low'} Freq", ) for i in range(5) ] gr.Markdown("### Feature Index Reference:") for idx, name in enumerate(editor.feature_names): gr.Markdown(f"- {idx}: {name}") with gr.Column(scale=1.2): gr.Markdown(""" ### Plot Legend - **Points with Error Bars**: Observed values where: - Point position = observed value - Error bar size = uncertainty (inversely proportional to weight) - **Green Line**: Model prediction - **Vertical Red Lines**: Peak points (if enabled) """) plots = [gr.Plot() for _ in range(editor.feature_dim)] # - **Shaded Area**: General prediction uncertainty def update_scaling_callback(revenue_scale, download_scale, au_scale, show_normalized): figs, metrics = editor.update_scaling( revenue_scale, download_scale, au_scale, show_normalized ) return [*figs, metrics] def update_model_callback( data_points_df, point_groups_df, enable_area_control, area_selections, enable_auc, auc, auc_weight, enable_peaks, peak_points, peak_alpha, peak_weight, enable_trending, enable_trending_with_diff, trending_params ): figs, _, _, metrics = editor.update_model( plots, data_points_df, point_groups_df, enable_area_control, area_selections, enable_auc, auc, enable_peaks, peak_points, peak_alpha, auc_weight, peak_weight, enable_trending, enable_trending_with_diff, trending_params ) return [*figs, metrics] # Update the click handler update_model_btn.click( fn=update_model_callback, inputs=[ data_points_df, point_groups_df, enable_area_control, area_selections, enable_auc, auc_input, auc_weight_input, enable_peaks, peak_points_input, peak_alpha_input, peak_weight_input, enable_trending_control, enable_trending_control_with_diff, trending_control ], outputs=[*plots, metrics_display] ) def update_additional_callback(enable_direct_area, direct_areas): figs, metrics = editor.update_additional_edit( enable_direct_area, direct_areas ) return [*figs, metrics] def update_freq_band(band_idx, value): figs, metrics = editor.update_frequency_bands(band_idx, value) return [*figs, metrics] update_scaling_btn.click( fn=update_scaling_callback, inputs=[ revenue_scale, download_scale, au_scale, show_normalized ], outputs=[*plots, metrics_display] ) update_additional_btn.click( fn=update_additional_callback, inputs=[enable_direct_area, direct_areas], outputs=[*plots, metrics_display] ) # Add event handlers for frequency band sliders for i, slider in enumerate(freq_bands): slider.change( fn=update_freq_band, inputs=[gr.Number(value=i, visible=False), slider], outputs=[*plots, metrics_display] ) app.load( fn=update_model_callback, inputs=[ data_points_df, point_groups_df, enable_area_control, area_selections, enable_auc, auc_input, auc_weight_input, enable_peaks, peak_points_input, peak_alpha_input, peak_weight_input, enable_trending_control, enable_trending_control_with_diff, trending_control ], outputs=[*plots, metrics_display] ) return app class FunctionParser: def __init__(self): # Define available mathematical functions and constants self.math_functions = { 'sin': np.sin, 'cos': np.cos, 'tan': np.tan, 'exp': np.exp, 'log': np.log, 'sqrt': np.sqrt, 'abs': np.abs, 'pow': np.power, 'pi': np.pi, 'e': np.e, 'asin': np.arcsin, 'acos': np.arccos, 'atan': np.arctan, 'sinh': np.sinh, 'cosh': np.cosh, 'tanh': np.tanh } def validate_expression(self, expression: str) -> bool: """ Validate the mathematical expression for basic syntax errors. """ # Check for balanced parentheses if expression.count('(') != expression.count(')'): raise ValueError("Unbalanced parentheses in expression") # Check for invalid characters valid_chars = set('0123456789.+-*/()^ xXepi,') valid_chars.update(''.join(self.math_functions.keys())) if not all(c in valid_chars or c.isspace() for c in expression.lower()): raise ValueError("Expression contains invalid characters") return True def preprocess_expression(self, expression: str) -> str: """ Preprocess the expression to handle various input formats. """ # Remove whitespace expression = expression.replace(' ', '') # Convert ^ to ** for exponentiation expression = expression.replace('^', '**') # Ensure multiplication is explicit expression = re.sub(r'(\d+)([a-zA-Z])', r'\1*\2', expression) expression = re.sub(r'(\))([\w])', r'\1*\2', expression) # Replace X with x for consistency expression = expression.lower() return expression def string_to_function(self, expression: str) -> Callable[[Union[float, np.ndarray]], Union[float, np.ndarray]]: """ Convert a string mathematical expression to a callable function. Args: expression (str): Mathematical expression (e.g., "sin(x) + x^2") Returns: Callable: A function that takes x as input and returns the evaluated result Example: >>> f = string_to_function("sin(x) + x^2") >>> f(0.5) 0.729321... """ # Validate and preprocess the expression self.validate_expression(expression) processed_expr = self.preprocess_expression(expression) # Create the function namespace namespace = self.math_functions.copy() try: # Create the lambda function func = eval(f"lambda x: {processed_expr}", namespace) # Test the function with a simple input test_value = 1.0 try: func(test_value) except Exception as e: raise ValueError(f"Invalid function: {str(e)}") return func except SyntaxError as e: raise ValueError(f"Invalid expression syntax: {str(e)}") except Exception as e: raise ValueError(f"Error creating function: {str(e)}") @staticmethod def demonstrate_usage(): """ Demonstrate various uses of the function parser. """ parser = FunctionParser() # Test cases test_expressions = [ "x^2 + 2*x + 1", "sin(x) + cos(x)", "exp(-x^2)", "log(x + 1)", "sqrt(1 - x^2)", ] print("Testing various mathematical expressions:") x_test = 0.5 for expr in test_expressions: try: print(f"\nExpression: {expr}") func = parser.string_to_function(expr) result = func(x_test) print(f"f({x_test}) = {result}") # Test with numpy array x_array = np.linspace(0, 1, 5) results = func(x_array) print(f"f(array) = {results}") except Exception as e: print(f"Error: {str(e)}") # Example usage: if __name__ == "__main__": import os import torch import numpy as np # assert torch.cuda.is_available(), "CUDA must be available" os.environ["WANDB_ENABLED"] = "false" print(os.getcwd()) device = torch.device(f"cuda:0") if torch.cuda.is_available() else "cpu" print(f"Device: {device}") print(f"Using device: {device}") from models.Tiffusion import tiffusion model = tiffusion.Tiffusion( seq_length=365, feature_size=3, n_layer_enc=6, n_layer_dec=4, d_model=128, timesteps=500, sampling_timesteps=200, loss_type='l1', beta_schedule='cosine', n_heads=8, mlp_hidden_times=4, attn_pd=0.0, resid_pd=0.0, kernel_size=1, padding_size=0, control_signal=[] ).to(device) model.load_state_dict(torch.load("./weight/checkpoint-10.pt", map_location=device, weights_only=True)["model"]) coef = 1.0e-2 stepsize = 5.0e-2 sampling_steps = 100 # Adjustable between 100-500 for speed/accuracy tradeoff seq_length = 365 feature_dim = 3 print(f"seq_length: {seq_length}, feature_dim: {feature_dim}") editor = TimeSeriesEditor(seq_length, feature_dim, model) editor.coef = coef editor.stepsize = stepsize editor.sampling_steps = sampling_steps app = create_gradio_interface(editor) app.launch(show_api=False)