Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
import plotly.graph_objects as go | |
import numpy as np | |
import torch | |
from typing import Dict, List, Tuple | |
import re | |
from typing import Callable, Union, Dict | |
class TimeSeriesEditor:
    """Editing state and plotting helpers for a conditional time-series generator.

    The editor turns user supplied anchor points, point groups and trend
    functions into a pair of ``(seq_length, feature_dim)`` tensors (values and
    weights), hands them to ``trainer.predict_weighted_points`` and renders the
    returned sample with Plotly.  The most recent sample and its inputs are
    cached on the instance so that purely visual updates (re-scaling, range
    shifts, frequency filtering) do not need to re-run the model.
    """

    def __init__(self, seq_length: int, feature_dim: int, trainer):
        """Store the model handle and set up default display/edit state.

        Args:
            seq_length: Number of time steps in a generated series.
            feature_dim: Number of features per time step.
            trainer: Object exposing ``predict_weighted_points``.
        """
        self.seq_length = seq_length
        self.feature_dim = feature_dim
        self.trainer = trainer

        # Sampler hyper-parameters; left unset here and assigned by the caller.
        self.coef = None
        self.stepsize = None
        self.sampling_steps = None

        self.feature_names = ["revenue", "download", "daily active user"]

        # Cache of the latest model run, reused by the redraw-only callbacks.
        self.latest_sample = None
        self.latest_observed_points = None
        self.latest_observed_mask = None
        self.latest_gradient_control_signal = None
        self.latest_model_control_signal = None

        # Display scale per feature: real-world amount per 0.1 model units.
        self.feature_scales = {
            0: 1000000,  # Revenue: $1M per 0.1
            1: 100000,   # Download: 100K downloads per 0.1
            2: 10000,    # AU: 10K active users per 0.1
        }
        self.feature_units = {
            0: "$",          # Revenue
            1: "downloads",  # Download
            2: "users",      # AU
        }
        self.show_normalized = True

        # Multipliers for 5 frequency bands; index 0 is the highest band and
        # index 4 the lowest (see ``apply_frequency_filter``).
        self.freq_bands = np.ones(5)

        self.function_parser = FunctionParser()
        # List of (start, end, feature, callable, confidence) tuples.
        self.trending_controls = []
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _to_numpy(values) -> np.ndarray:
        """Return ``values`` as a NumPy array, detaching tensors if needed."""
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    def _scale_factor(self, feat_idx: int):
        """Return the multiplier from model units to displayed units."""
        if self.show_normalized:
            return 1
        # Scales are expressed "per 0.1", hence the factor of 10.
        return self.feature_scales[feat_idx] * 10

    def _build_metrics(self, sample: np.ndarray, extra: Dict = None) -> Dict:
        """Build the metrics dict (display mode plus per-feature totals).

        ``extra`` entries are inserted right after ``show_normalized`` so the
        key order shown in the UI is stable.
        """
        metrics = {'show_normalized': self.show_normalized}
        if extra:
            metrics.update(extra)
        for feat_idx in range(self.feature_dim):
            total = np.sum(sample[:, feat_idx]) * self._scale_factor(feat_idx)
            metrics[f'total_{self.feature_names[feat_idx]}'] = self.format_value(total, feat_idx)
        return metrics

    def format_value(self, value: float, feature_idx: int) -> str:
        """Format value with appropriate units and notation"""
        if self.show_normalized:
            return f"{value:.4f}"
        if feature_idx == 0:  # Revenue: unit goes in front
            return f"{self.feature_units[feature_idx]}{value:,.2f}"
        # Downloads and AU: unit goes behind
        return f"{value:,.0f} {self.feature_units[feature_idx]}"

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------
    def create_plot(self, sample: np.ndarray, observed_points: torch.Tensor,
                    observed_mask: torch.Tensor,
                    gradient_control_signal: Dict, metrics: Dict) -> "List[go.Figure]":
        """Build one Plotly figure per feature.

        Args:
            sample: Array indexed as ``sample[time, feature]``.
            observed_points: Anchor values, same layout as ``sample``.
            observed_mask: Anchor weights; a weight of 0 means "no anchor".
            gradient_control_signal: Optional dict; only ``"peak_points"`` is
                read here (drawn as dashed vertical lines).
            metrics: Accepted for interface compatibility; not used.

        Returns:
            A list with ``feature_dim`` figures.
        """
        figures = []
        time_axis = np.arange(self.seq_length)
        points_np = self._to_numpy(observed_points)
        weights_np = self._to_numpy(observed_mask)
        peak_points = (gradient_control_signal or {}).get("peak_points") or []

        for feat_idx in range(self.feature_dim):
            fig = go.Figure()
            scale_factor = self._scale_factor(feat_idx)

            # Predicted line
            predicted_values = sample[:, feat_idx] * scale_factor
            fig.add_trace(go.Scatter(
                x=time_axis,
                y=predicted_values,
                mode='lines',
                name='Predicted',
                line=dict(color='green', width=2),
                showlegend=True
            ))

            # An anchor exists wherever its weight is positive.  Selecting on
            # the weight (rather than on value > 0) keeps anchors whose target
            # value is exactly 0 visible.
            observed = weights_np[:, feat_idx] > 0
            ox = time_axis[observed]
            oy = points_np[observed, feat_idx] * scale_factor
            # Lower weight -> longer error bar (weight 1.0 gives no bar).
            error_y = (1 - weights_np[observed, feat_idx]) / 5

            fig.add_trace(go.Scatter(
                x=ox,
                y=oy,
                mode='markers',
                name='Observed',
                marker=dict(
                    color='rgba(255, 0, 0, 0.5)',
                    symbol='x',
                ),
                error_y=dict(
                    type='data',
                    array=error_y * scale_factor,
                    visible=True,
                    thickness=0.5,
                    width=2,
                    color='blue'
                ),
                showlegend=True
            ))

            # Outline of a fixed-width band around the prediction.  The line is
            # fully transparent and no fill is set, so only the legend entry
            # is visible.
            uncertainty = 0.05  # Base uncertainty level
            upper_bound = predicted_values + uncertainty * scale_factor
            lower_bound = predicted_values - uncertainty * scale_factor
            fig.add_trace(go.Scatter(
                x=np.concatenate([time_axis, time_axis[::-1]]),
                y=np.concatenate([upper_bound, lower_bound[::-1]]),
                line=dict(color='rgba(255,255,255,0)'),
                name='Prediction Interval',
                showlegend=True
            ))

            for peak_point in peak_points:
                fig.add_vline(x=peak_point, line_dash="dash", line_color="red")

            total_value = np.sum(sample[:, feat_idx]) * scale_factor
            annotations = [dict(
                x=0.02,
                y=1.1,
                xref="paper",
                yref="paper",
                text=f"Total {self.feature_names[feat_idx]}: {self.format_value(total_value, feat_idx)}",
                showarrow=False
            )]

            if self.show_normalized:
                y_title = f'{self.feature_names[feat_idx]} (Normalized)'
            else:
                unit = self.feature_units[feat_idx]
                y_title = f'{self.feature_names[feat_idx]} ({unit})'

            fig.update_layout(
                title=dict(
                    text=f'Feature: {self.feature_names[feat_idx]}',
                    x=0.5,
                    y=0.95
                ),
                xaxis_title='Time',
                yaxis_title=y_title,
                height=400,
                showlegend=True,
                dragmode='select',
                annotations=annotations,
                margin=dict(r=200)  # Add right margin for legend
            )
            figures.append(fig)
        return figures

    def update_scaling(self,
                       revenue_scale: float,
                       download_scale: float,
                       au_scale: float,
                       show_normalized: bool) -> "Tuple[List[go.Figure], Dict]":
        """Update the scaling parameters and redraw plots"""
        if self.latest_sample is None:
            return [], {}

        self.feature_scales = {
            0: revenue_scale,
            1: download_scale,
            2: au_scale
        }
        self.show_normalized = show_normalized

        metrics = self._build_metrics(self.latest_sample)
        figures = self.create_plot(
            self.latest_sample,
            self.latest_observed_points,
            self.latest_observed_mask,
            self.latest_gradient_control_signal,
            metrics
        )
        return figures, metrics

    # ------------------------------------------------------------------
    # Parsing of user input
    # ------------------------------------------------------------------
    def parse_data_points(self, df) -> Dict:
        """Parse data points from DataFrame with columns: time,feature,value

        Returns ``{time: {feature: (value, 1.0)}}``.  Rows with missing or
        non-numeric cells, or with a time/feature index outside the model
        grid, are skipped.
        """
        data_dict = {}
        if df is None or df.empty:
            return data_dict

        for _, row in df.iterrows():
            if pd.isna(row['time']) or pd.isna(row['feature']) or pd.isna(row['value']):
                continue
            try:
                time_idx = int(row['time'])
                feature_idx = int(row['feature'])
                value = float(row['value'])
            except (ValueError, TypeError):
                continue
            # Out-of-range indices would crash (or silently wrap around) when
            # the dict is later written into fixed-size tensors.
            if not (0 <= time_idx < self.seq_length and 0 <= feature_idx < self.feature_dim):
                continue
            data_dict.setdefault(time_idx, {})[feature_idx] = (value, 1.0)
        return data_dict

    def parse_point_groups(self, df) -> Dict:
        """Parse point groups from DataFrame with columns: start,end,interval,feature,value,weight

        Each row expands to anchors at ``start, start+interval, ... <= end``.
        A missing weight defaults to 1.0.  Invalid rows (missing cells,
        non-numeric cells, zero interval, feature out of range) are skipped.
        """
        data_dict = {}
        if df is None or df.empty:
            return data_dict

        required = ('start', 'end', 'interval', 'feature', 'value')
        for _, row in df.iterrows():
            if any(pd.isna(row[name]) for name in required):
                continue
            try:
                start = int(row['start'])
                end = int(row['end'])
                interval = int(row['interval'])
                feature = int(row['feature'])
                value = float(row['value'])
                raw_weight = row.get('weight', 1.0)
                weight = 1.0 if pd.isna(raw_weight) else float(raw_weight)
                # range() raises ValueError for a zero step; handled below.
                steps = range(start, end + 1, interval)
            except (ValueError, TypeError):
                continue
            if not 0 <= feature < self.feature_dim:
                continue
            for t in steps:
                if 0 <= t < self.seq_length:
                    data_dict.setdefault(t, {})[feature] = (value, weight)
        return data_dict

    def to_tensor(self, observed_points_dict, seq_length, feature_dim):
        """Scatter ``{time: {feature: (value, weight)}}`` into two dense tensors.

        Returns:
            ``(observed_points, observed_weights)``, both of shape
            ``(seq_length, feature_dim)`` and zero where nothing was given.
        """
        observed_points = torch.zeros((seq_length, feature_dim))
        observed_weights = torch.zeros((seq_length, feature_dim))
        for seq, feature_dict in observed_points_dict.items():
            for feature, (value, weight) in feature_dict.items():
                observed_points[seq, feature] = value
                observed_weights[seq, feature] = weight
        return observed_points, observed_weights

    def apply_direct_edits(self, sample: np.ndarray, edit_params: Dict) -> np.ndarray:
        """Apply direct edits to the sample array

        Adds each area's delta to ``sample[start:end, feature]`` on a copy and
        clips the result to ``[0, 1]``.  Areas naming a feature outside the
        sample are ignored.
        """
        edited_sample = sample.copy()
        if edit_params.get("enable_direct_area"):
            for start, end, feat_idx, target in self.parse_area_selections(edit_params["direct_areas"]):
                if not 0 <= feat_idx < edited_sample.shape[1]:
                    continue
                edited_sample[start:end, feat_idx] += target
            edited_sample = np.clip(edited_sample, 0, 1)
        return edited_sample

    def parse_area_selections(self, area_text: str) -> List[Tuple]:
        """Parse area selection text into (start, end, feature, target) tuples

        Entries are separated by ``;`` or newlines; malformed entries are
        skipped.
        """
        areas = []
        if not area_text or not area_text.strip():
            return areas

        for line in area_text.replace('\n', ';').strip().split(';'):
            if not line.strip():
                continue
            try:
                start, end, feat, target = map(float, line.strip().split(','))
                areas.append((int(start), int(end), int(feat), target))
            except (ValueError, IndexError):
                continue
        return areas

    # ------------------------------------------------------------------
    # Trend constraints
    # ------------------------------------------------------------------
    def apply_trending_mask(self, points: torch.Tensor, mask: torch.Tensor, consider_last_generated=False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply trending functions as soft constraints through masks

        For every stored trend ``(start, end, feature, func, confidence)`` the
        function is evaluated on ``linspace(0, 1, end - start)``, min-max
        normalised to ``[0, 1]`` and written into ``points`` wherever no anchor
        exists yet (``mask == 0``); those positions receive ``confidence`` as
        their weight.  Trends that are out of range, fail to evaluate, produce
        non-finite values or are constant (cannot be normalised) are skipped.

        Both tensors are modified in place; the returned mask is additionally
        clamped to ``[0, 1]``.  ``consider_last_generated`` is currently unused.

        NOTE(review): nothing is applied until a first sample exists
        (``latest_sample is None``) even though the sample is not read here -
        confirm whether that gate is still wanted.
        """
        if not self.trending_controls or self.latest_sample is None:
            return points, mask

        for start, end, feat_idx, func, confidence in self.trending_controls:
            if start < 0 or end > self.seq_length or start >= end:
                continue
            if not 0 <= feat_idx < points.shape[1]:
                continue

            # x is normalised to [0, 1] over the segment.
            x = np.linspace(0, 1, end - start)
            try:
                # broadcast_to lets scalar-valued expressions (e.g. "pi")
                # behave like a constant trend instead of crashing later.
                y = np.broadcast_to(np.asarray(func(x), dtype=float), x.shape)
            except Exception as e:
                print(f"Error applying function: {e}")
                continue

            if not np.all(np.isfinite(y)):
                print("Error applying function: result contains non-finite values")
                continue
            y_min, y_max = np.min(y), np.max(y)
            if y_max == y_min:
                # A constant cannot be min-max scaled; dividing by the zero
                # range would write NaNs into the constraints.
                print("Error applying function: constant trend cannot be normalised")
                continue
            y = (y - y_min) / (y_max - y_min)

            # Only fill positions that have no explicit anchor.
            point_segment = points[start:end, feat_idx]
            mask_segment = mask[start:end, feat_idx]
            mask_zero = (mask_segment == 0)
            trend = torch.tensor(y, dtype=points.dtype)
            point_segment[mask_zero] = trend[mask_zero]
            mask_segment[mask_zero] = float(confidence)

        mask = mask.clamp(0, 1)
        return points, mask

    # ------------------------------------------------------------------
    # Model update
    # ------------------------------------------------------------------
    def update_model(self,
                     figures: "List[go.Figure]",
                     data_points: pd.DataFrame,
                     point_groups: pd.DataFrame,
                     enable_area_control: bool,
                     area_selections: str,
                     enable_auc: bool,
                     auc_value: float,
                     enable_peaks: bool,
                     peak_points: str,
                     peak_alpha: float,
                     auc_weight: float,
                     peak_weight: float,
                     enable_trending: bool = True,
                     enable_trending_with_diff: bool = False,
                     trending_params: str = ""
                     ) -> "Tuple[List[go.Figure], pd.DataFrame, pd.DataFrame, Dict]":
        """Run the model with the current controls and redraw everything.

        ``data_points`` and ``point_groups`` are the tables parsed by
        ``parse_data_points`` / ``parse_point_groups``; individual points win
        over group points at the same position.  ``figures``,
        ``enable_area_control`` and ``area_selections`` are accepted for
        interface compatibility and not used.

        Returns:
            ``(figures, data_points, point_groups, metrics)`` with the two
            tables passed through unchanged.
        """
        individual_points_dict = self.parse_data_points(data_points)
        combined_points_dict = self.parse_point_groups(point_groups)
        for t, feat_dict in individual_points_dict.items():
            combined_points_dict.setdefault(t, {}).update(feat_dict)

        observed_points, observed_mask = self.to_tensor(
            combined_points_dict,
            self.seq_length,
            self.feature_dim
        )

        # A single malformed entry discards the whole peak list.
        peak_points_list = []
        if enable_peaks and peak_points:
            try:
                peak_points_list = [int(x.strip()) for x in peak_points.split(',') if x.strip()]
            except ValueError:
                peak_points_list = []

        if enable_trending and trending_params:
            self.parse_trending_parameters(trending_params)
            observed_points, observed_mask = self.apply_trending_mask(
                observed_points, observed_mask,
                consider_last_generated=enable_trending_with_diff
            )

        gradient_control_signal = {}
        if enable_auc:
            gradient_control_signal["auc"] = auc_value
            gradient_control_signal["auc_weight"] = auc_weight
        if enable_peaks:
            gradient_control_signal.update({
                "peak_points": peak_points_list,
                "peak_alpha": peak_alpha,
                "peak_weight": peak_weight
            })

        # Kept (empty) for the cached state; not passed to the model.
        model_control_signal = {}

        with torch.no_grad():
            observed_points = observed_points.to(self.device)
            observed_mask = observed_mask.to(self.device)
            # NOTE(review): the rest of this class treats the result as a
            # NumPy array indexed [time, feature] - confirm against the trainer.
            sample = self.trainer.predict_weighted_points(
                observed_points,   # (seq_length, feature_dim)
                observed_mask,     # (seq_length, feature_dim)
                self.coef,
                self.stepsize,
                self.sampling_steps,
                gradient_control_signal=gradient_control_signal
            )
            # Plotting works on CPU data.
            observed_points = observed_points.cpu()
            observed_mask = observed_mask.cpu()

        self.latest_sample = sample
        self.latest_observed_points = observed_points
        self.latest_observed_mask = observed_mask
        self.latest_gradient_control_signal = gradient_control_signal
        self.latest_model_control_signal = model_control_signal

        metrics = self._build_metrics(sample)
        figures = self.create_plot(sample, observed_points, observed_mask, gradient_control_signal, metrics)
        return figures, data_points, point_groups, metrics

    def update_additional_edit(
            self,
            enable_direct_area: bool,
            direct_areas: str):
        """Redraw the latest sample, optionally with direct range shifts applied.

        Returns ``(figures, metrics)``, or ``([], {})`` when no sample has been
        generated yet (same convention as the other redraw methods).
        """
        if self.latest_sample is None:
            return [], {}

        if enable_direct_area:
            sample = self.apply_direct_edits(self.latest_sample, {
                "enable_direct_area": enable_direct_area,
                "direct_areas": direct_areas
            })
        else:
            sample = self.latest_sample

        metrics = self._build_metrics(sample)
        figures = self.create_plot(
            sample,
            self.latest_observed_points,
            self.latest_observed_mask,
            self.latest_gradient_control_signal,
            metrics
        )
        return figures, metrics

    # ------------------------------------------------------------------
    # Frequency filtering
    # ------------------------------------------------------------------
    def apply_frequency_filter(self, signal: np.ndarray) -> np.ndarray:
        """Apply FFT-based frequency filtering using the current band multipliers

        The non-DC bins of the real FFT are split into 5 contiguous bands from
        lowest to highest frequency; band ``i`` is scaled by
        ``freq_bands[4 - i]``.  The DC bin is scaled by ``freq_bands[4]`` (the
        "very low" multiplier).  Working on the one-sided spectrum keeps the
        result real without having to mirror masks onto negative frequencies.
        """
        signal = np.asarray(signal)
        if signal.size == 0:
            return signal.copy()

        spectrum = np.fft.rfft(signal)
        spectrum[0] *= self.freq_bands[4]
        bands = np.array_split(np.arange(1, spectrum.size), 5)
        for i, bins in enumerate(bands):
            # Bands may be empty for very short signals; indexing is a no-op then.
            spectrum[bins] *= self.freq_bands[4 - i]
        return np.fft.irfft(spectrum, n=signal.size)

    def update_frequency_bands(self, band_idx: int, value: float) -> "Tuple[List[go.Figure], Dict]":
        """Update a frequency band multiplier and recompute the filtered signal"""
        if self.latest_sample is None:
            return [], {}

        # UI number widgets deliver floats, which NumPy rejects as indices.
        self.freq_bands[int(band_idx)] = value

        filtered_sample = self.latest_sample.copy()
        for feat_idx in range(self.feature_dim):
            filtered_sample[:, feat_idx] = self.apply_frequency_filter(
                self.latest_sample[:, feat_idx]
            )
        # Ensure values remain in valid range
        filtered_sample = np.clip(filtered_sample, 0, 1)

        metrics = self._build_metrics(
            filtered_sample,
            extra={'frequency_bands': self.freq_bands.tolist()}
        )
        figures = self.create_plot(
            filtered_sample,
            self.latest_observed_points,
            self.latest_observed_mask,
            self.latest_gradient_control_signal,
            metrics
        )
        return figures, metrics

    def parse_trending_parameters(self, trending_text: str) -> List[Tuple]:
        """Parse trending control parameters into (start, end, feature, function, confidence) tuples

        Each entry has the form ``start,end,feature,function,confidence`` and
        entries are separated by ``;`` or newlines.  The function part may
        itself contain commas (e.g. ``pow(x,2)``): the first three commas and
        the last one delimit the fields.  Malformed entries are skipped.

        The parsed list is also stored in ``self.trending_controls``.  Empty
        input returns ``[]`` and leaves the stored controls untouched.
        """
        trending_params = []
        if not trending_text or not trending_text.strip():
            return trending_params

        for line in trending_text.replace('\n', ';').strip().split(';'):
            line = line.strip()
            if not line:
                continue
            head = line.split(',', 3)
            if len(head) != 4:
                continue
            tail = head[3].rsplit(',', 1)
            if len(tail) != 2:
                continue
            try:
                start, end, feat = (int(part) for part in head[:3])
                confidence = float(tail[1])
            except ValueError:
                continue

            function_str = tail[0].strip()
            try:
                func = self.function_parser.string_to_function(function_str)
            except ValueError as e:
                print(f"Error parsing function '{function_str}': {e}")
                continue
            trending_params.append((start, end, feat, func, confidence))

        self.trending_controls = trending_params  # Store the parsed parameters
        return trending_params
def create_gradio_interface(editor: TimeSeriesEditor):
    """Build the Gradio Blocks app that drives ``editor``.

    The left column holds the controls (scaling, anchor points, point groups,
    trend functions, statistics and a few hidden experimental panels); the
    right column holds one plot per feature.  Every callback returns one value
    per plot followed by the metrics dict.

    Returns:
        The ``gr.Blocks`` application (not yet launched).
    """
    point_columns = ["time", "feature", "value"]
    group_columns = ["start", "end", "interval", "feature", "value", "weight"]
    band_names = ("Very High", "High", "Mid", "Low", "Very Low")

    with gr.Blocks() as app:
        gr.Markdown("# Time Series Editor")
        gr.Markdown("## Instruction: Scroll Down + Click [Update Figure] [~10s]")
        metrics_display = gr.JSON(label="Metrics", value={})

        with gr.Row():
            # Column widths keep the original 1 : 1.2 ratio using integer
            # scales, which is what Gradio expects.
            with gr.Column(scale=5):
                gr.Markdown("## Scaling Parameters")
                with gr.Accordion("Open for More Detail", open=False):
                    revenue_scale = gr.Number(
                        label="Revenue Scale ($ per 0.1 in model)",
                        value=1000000
                    )
                    download_scale = gr.Number(
                        label="Download Scale (downloads per 0.1 in model)",
                        value=100000
                    )
                    au_scale = gr.Number(
                        label="Active Users Scale (users per 0.1 in model)",
                        value=10000
                    )
                    show_normalized = gr.Checkbox(
                        label="Show Normalized Values (0-1 scale)",
                        value=True
                    )
                    update_scaling_btn = gr.Button("Update Scaling")

                gr.Markdown("## Time Series Control Panel")
                with gr.Group():
                    gr.Markdown("### Fixed Point Control")
                    data_points_df = gr.Dataframe(
                        headers=point_columns,
                        datatype=["number", "number", "number"],
                        value=[[0, 0, 0.04], [2, 0, 0.58], [6, 0, 0.27], [58, 0, 0.8], [60, 0, 0.5]],
                        col_count=(3, "fixed"),  # Fix number of columns
                        interactive=True
                    )
                    add_data_point_btn = gr.Button("Add Data Point")

                    def add_data_point(df):
                        """Append an empty row (feature 0) to the anchor table."""
                        new_row = pd.DataFrame([[None, 0, None]], columns=point_columns)
                        return pd.concat([df, new_row], ignore_index=True)

                    add_data_point_btn.click(
                        fn=add_data_point,
                        inputs=[data_points_df],
                        outputs=[data_points_df]
                    )

                with gr.Group():
                    gr.Markdown("### Group of Anchor Point Control with Confidence")
                    point_groups_df = gr.Dataframe(
                        headers=group_columns,
                        datatype=["number", "number", "number", "number", "number", "number"],
                        value=[[0, 50, 10, 0, 0.5, 0.1], [100, 150, 50, 0, 0.1, 0.5]],
                        col_count=(6, "fixed"),  # Fix number of columns
                        interactive=True
                    )
                    add_point_group_btn = gr.Button("Add Point Group")

                    def add_point_group(df):
                        """Append an empty row (feature 0) to the group table."""
                        new_row = pd.DataFrame([[None, None, None, 0, None, None]], columns=group_columns)
                        return pd.concat([df, new_row], ignore_index=True)

                    add_point_group_btn.click(
                        fn=add_point_group,
                        inputs=[point_groups_df],
                        outputs=[point_groups_df]
                    )

                with gr.Group():
                    gr.Markdown("### Trending Control")
                    # The examples include the confidence field: entries
                    # without it are rejected by the parser.
                    gr.Markdown("""
                    Enter trending control parameters in the format:
                    ```
                    start_time,end_time,feature,function,confidence
                    ```
                    Examples:
                    - Linear trend: `0,100,0,x,0.05`
                    - Sine wave: `0,100,0,sin(2*pi*x),0.05`
                    - Exponential: `0,100,0,exp(-x),0.05`

                    Separate multiple trends with semicolons.
                    """)
                    enable_trending_control = gr.Checkbox(label="Enable Trending Control", value=False)
                    enable_trending_control_with_diff = gr.Checkbox(label="Consider Last Generated", value=False)
                    trending_control = gr.Textbox(
                        label="Trending Control Parameters",
                        lines=2,
                        placeholder="Enter parameters: start_time,end_time,feature,function,confidence; separated by semicolons",
                        value="200,250,0,sin(2*pi*x),0.05"
                    )

                # Area control (hidden)
                with gr.Group(visible=False):
                    gr.Markdown("### Area Control")
                    enable_area_control = gr.Checkbox(label="Enable Area Control", value=False)
                    area_selections = gr.Textbox(
                        label="Area Selections (format: start_time,end_time,feature,target_value)",
                        lines=2,
                        placeholder="Enter areas: start,end,feature,target; separated by semicolons",
                    )

                # Total-sum (AUC) control
                gr.Markdown("### Statistics Control")
                enable_auc = gr.Checkbox(label="Enable Total Sum Control", value=True)
                auc_input = gr.Number(label="Target Sum Value", value=-150)
                auc_weight_input = gr.Number(label="Sum Weight", value=10.0)

                # Peak control (hidden)
                with gr.Group(visible=False):
                    gr.Markdown("### Peak Control")
                    enable_peaks = gr.Checkbox(label="Enable Peak Control", value=False)
                    peak_points_input = gr.Textbox(label="Peak Points (comma-separated)", value="100,200")
                    peak_alpha_input = gr.Number(label="Peak Alpha", value=10)
                    peak_weight_input = gr.Number(label="Peak Weight", value=1.0)

                update_model_btn = gr.Button("Update Figure")

                gr.Markdown("## Extend Edit", visible=False)
                with gr.Tab("Range Shift", visible=False):
                    enable_direct_area = gr.Checkbox(label="Enable Direct Edits", value=False)  # range shift
                    direct_areas = gr.Textbox(
                        label="Direct Edit Areas (format: start_time,end_time,feature,delta)",
                        lines=2,
                        placeholder="Enter areas: start,end,feature,delta; separated by semicolons",
                        value="150,200,0,-0.1"
                    )
                    update_additional_btn = gr.Button("Update Additional Edit")

                # Frequency controls (hidden)
                with gr.Group(visible=False):
                    gr.Markdown("Adjust multipliers for different frequency bands (0-2)")
                    freq_bands = [
                        gr.Slider(
                            minimum=0, maximum=2, step=0.1, value=1.0,
                            label=f"Band {i+1}: {band_names[i]} Freq",
                        ) for i in range(5)
                    ]

                gr.Markdown("### Feature Index Reference:")
                for idx, name in enumerate(editor.feature_names):
                    gr.Markdown(f"- {idx}: {name}")

            with gr.Column(scale=6):
                gr.Markdown("""
                ### Plot Legend
                - **Points with Error Bars**: Observed values where:
                  - Point position = observed value
                  - Error bar size = uncertainty (inversely proportional to weight)
                - **Green Line**: Model prediction
                - **Vertical Red Lines**: Peak points (if enabled)
                """)
                plots = [gr.Plot() for _ in range(editor.feature_dim)]

        def pack_outputs(figs, metrics):
            """Return exactly one value per plot followed by the metrics.

            The editor returns an empty figure list when nothing has been
            generated yet; padding with ``None`` keeps the number of returned
            values equal to the number of declared outputs.
            """
            figs = list(figs)
            figs.extend([None] * (len(plots) - len(figs)))
            return [*figs, metrics]

        def update_scaling_callback(revenue_scale, download_scale, au_scale, show_normalized):
            figs, metrics = editor.update_scaling(
                revenue_scale,
                download_scale,
                au_scale,
                show_normalized
            )
            return pack_outputs(figs, metrics)

        def update_model_callback(
                data_points_df,
                point_groups_df,
                enable_area_control,
                area_selections,
                enable_auc,
                auc,
                auc_weight,
                enable_peaks,
                peak_points,
                peak_alpha,
                peak_weight,
                enable_trending,
                enable_trending_with_diff,
                trending_params
        ):
            # Note the argument order expected by ``update_model``:
            # auc_weight comes after peak_alpha there.
            figs, _, _, metrics = editor.update_model(
                plots,
                data_points_df,
                point_groups_df,
                enable_area_control,
                area_selections,
                enable_auc,
                auc,
                enable_peaks,
                peak_points,
                peak_alpha,
                auc_weight,
                peak_weight,
                enable_trending,
                enable_trending_with_diff,
                trending_params
            )
            return pack_outputs(figs, metrics)

        def update_additional_callback(enable_direct_area, direct_areas):
            figs, metrics = editor.update_additional_edit(
                enable_direct_area,
                direct_areas
            )
            return pack_outputs(figs, metrics)

        def update_freq_band(band_idx, value):
            figs, metrics = editor.update_frequency_bands(int(band_idx), value)
            return pack_outputs(figs, metrics)

        plot_outputs = [*plots, metrics_display]
        # Shared by the button click and the initial page load.
        model_inputs = [
            data_points_df,
            point_groups_df,
            enable_area_control,
            area_selections,
            enable_auc,
            auc_input,
            auc_weight_input,
            enable_peaks,
            peak_points_input,
            peak_alpha_input,
            peak_weight_input,
            enable_trending_control,
            enable_trending_control_with_diff,
            trending_control
        ]

        update_model_btn.click(
            fn=update_model_callback,
            inputs=model_inputs,
            outputs=plot_outputs
        )
        update_scaling_btn.click(
            fn=update_scaling_callback,
            inputs=[
                revenue_scale,
                download_scale,
                au_scale,
                show_normalized
            ],
            outputs=plot_outputs
        )
        update_additional_btn.click(
            fn=update_additional_callback,
            inputs=[enable_direct_area, direct_areas],
            outputs=plot_outputs
        )

        # The band index is bound as a default argument: it stays an int and
        # each slider keeps its own index.
        for i, slider in enumerate(freq_bands):
            slider.change(
                fn=lambda value, band_idx=i: update_freq_band(band_idx, value),
                inputs=[slider],
                outputs=plot_outputs
            )

        # Generate an initial figure as soon as the page loads.
        app.load(
            fn=update_model_callback,
            inputs=model_inputs,
            outputs=plot_outputs
        )
    return app
class FunctionParser:
    """Convert user-typed math expressions in ``x`` into NumPy-backed callables.

    Only the names listed in ``math_functions`` (plus the variable ``x``) may
    appear in an expression.
    """

    def __init__(self):
        # Names available inside an expression: NumPy functions and constants.
        self.math_functions = {
            'sin': np.sin,
            'cos': np.cos,
            'tan': np.tan,
            'exp': np.exp,
            'log': np.log,
            'sqrt': np.sqrt,
            'abs': np.abs,
            'pow': np.power,
            'pi': np.pi,
            'e': np.e,
            'asin': np.arcsin,
            'acos': np.arccos,
            'atan': np.arctan,
            'sinh': np.sinh,
            'cosh': np.cosh,
            'tanh': np.tanh
        }

    def validate_expression(self, expression: str) -> bool:
        """
        Validate the mathematical expression for basic syntax errors.

        Returns:
            True when the expression passes both checks.

        Raises:
            ValueError: on unbalanced parentheses or a character that cannot
                occur in any supported expression.
        """
        if expression.count('(') != expression.count(')'):
            raise ValueError("Unbalanced parentheses in expression")

        # Character-level filter only; identifiers are checked separately in
        # ``string_to_function`` once the expression has been normalised.
        valid_chars = set('0123456789.+-*/()^ xXepi,')
        valid_chars.update(''.join(self.math_functions.keys()))
        if not all(c in valid_chars or c.isspace() for c in expression.lower()):
            raise ValueError("Expression contains invalid characters")
        return True

    def preprocess_expression(self, expression: str) -> str:
        """
        Preprocess the expression to handle various input formats.

        Strips spaces, rewrites ``^`` as ``**``, inserts the implicit ``*`` in
        forms like ``2x`` and ``)x``, and lower-cases the result.
        """
        expression = expression.replace(' ', '')
        expression = expression.replace('^', '**')
        # Ensure multiplication is explicit
        expression = re.sub(r'(\d+)([a-zA-Z])', r'\1*\2', expression)
        expression = re.sub(r'(\))([\w])', r'\1*\2', expression)
        # Replace X with x for consistency
        return expression.lower()

    def _check_identifiers(self, expression: str) -> None:
        """Raise ValueError if ``expression`` uses a name outside the whitelist.

        The character filter alone still lets through words built from
        permitted letters (for example ``open``, ``exit`` or ``getattr``), so
        every identifier must be a known function/constant or ``x``.
        """
        allowed = set(self.math_functions) | {'x'}
        unknown = sorted(
            {name for name in re.findall(r'[a-z_][a-z0-9_]*', expression) if name not in allowed}
        )
        if unknown:
            raise ValueError(f"Unknown name(s) in expression: {', '.join(unknown)}")

    def string_to_function(self, expression: str) -> Callable[[Union[float, np.ndarray]], Union[float, np.ndarray]]:
        """
        Convert a string mathematical expression to a callable function.

        Args:
            expression (str): Mathematical expression (e.g., "sin(x) + x^2")

        Returns:
            Callable: A function that takes x as input and returns the evaluated result

        Raises:
            ValueError: if the expression fails validation, is not valid
                syntax, uses an unknown name, or raises when evaluated at
                ``x = 1.0``.

        Example:
            >>> f = FunctionParser().string_to_function("sin(x) + x^2")
            >>> f(0.5)
            0.729425...
        """
        self.validate_expression(expression)
        processed_expr = self.preprocess_expression(expression)
        self._check_identifiers(processed_expr)

        # SECURITY: the text comes from a UI field and is passed to eval().
        # Builtins are removed and names are whitelisted above, so only the
        # listed NumPy functions/constants and ``x`` are reachable.
        namespace = dict(self.math_functions)
        namespace["__builtins__"] = {}

        try:
            func = eval(f"lambda x: {processed_expr}", namespace)
        except SyntaxError as e:
            raise ValueError(f"Invalid expression syntax: {str(e)}") from e
        except Exception as e:
            raise ValueError(f"Error creating function: {str(e)}") from e

        # Smoke-test once so that broken expressions fail here rather than
        # later during sampling.
        try:
            func(1.0)
        except Exception as e:
            raise ValueError(f"Invalid function: {str(e)}") from e
        return func
def demonstrate_usage():
    """
    Print the result of parsing and evaluating a handful of sample expressions.

    Each expression is evaluated once at a scalar and once on a small array;
    any failure is reported and the next expression is tried.
    """
    sample_expressions = (
        "x^2 + 2*x + 1",
        "sin(x) + cos(x)",
        "exp(-x^2)",
        "log(x + 1)",
        "sqrt(1 - x^2)",
    )
    scalar_input = 0.5
    fn_parser = FunctionParser()

    print("Testing various mathematical expressions:")
    for expression in sample_expressions:
        try:
            print(f"\nExpression: {expression}")
            evaluate = fn_parser.string_to_function(expression)
            print(f"f({scalar_input}) = {evaluate(scalar_input)}")
            # Same callable, this time on a NumPy array
            print(f"f(array) = {evaluate(np.linspace(0, 1, 5))}")
        except Exception as err:
            print(f"Error: {str(err)}")
# Example usage: | |
if __name__ == "__main__":
    import os

    os.environ["WANDB_ENABLED"] = "false"
    print(os.getcwd())

    # Always a torch.device, regardless of whether CUDA is present.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    from models.Tiffusion import tiffusion

    # Single source of truth for the data shape shared by model and editor.
    seq_length = 365
    feature_dim = 3
    print(f"seq_length: {seq_length}, feature_dim: {feature_dim}")

    model = tiffusion.Tiffusion(
        seq_length=seq_length,
        feature_size=feature_dim,
        n_layer_enc=6,
        n_layer_dec=4,
        d_model=128,
        timesteps=500,
        sampling_timesteps=200,
        loss_type='l1',
        beta_schedule='cosine',
        n_heads=8,
        mlp_hidden_times=4,
        attn_pd=0.0,
        resid_pd=0.0,
        kernel_size=1,
        padding_size=0,
        control_signal=[]
    ).to(device)
    # The checkpoint stores the weights under the "model" key.
    checkpoint = torch.load("./weight/checkpoint-10.pt", map_location=device, weights_only=True)
    model.load_state_dict(checkpoint["model"])

    # Sampler settings handed to the editor.
    coef = 1.0e-2
    stepsize = 5.0e-2
    sampling_steps = 100  # Adjustable between 100-500 for speed/accuracy tradeoff

    editor = TimeSeriesEditor(seq_length, feature_dim, model)
    editor.coef = coef
    editor.stepsize = stepsize
    editor.sampling_steps = sampling_steps

    app = create_gradio_interface(editor)
    app.launch(show_api=False)