# TSEditor / app.py
import re
from typing import Callable, Dict, List, Tuple, Union

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
class TimeSeriesEditor:
def __init__(self, seq_length: int, feature_dim: int, trainer):
# Existing initialization
self.seq_length = seq_length
self.feature_dim = feature_dim
self.trainer = trainer
self.coef = None
self.stepsize = None
self.sampling_steps = None
        self.feature_names = ["revenue", "download", "daily active user"]
# Store the latest model output
self.latest_sample = None
self.latest_observed_points = None
self.latest_observed_mask = None
self.latest_gradient_control_signal = None
self.latest_model_control_signal = None
# Define scales for each feature
self.feature_scales = {
0: 1000000, # Revenue: $1M per 0.1
1: 100000, # Download: 100K downloads per 0.1
2: 10000 # AU: 10K active users per 0.1
}
self.feature_units = {
0: "$", # Revenue
1: "downloads", # Download
2: "users" # AU
}
self.show_normalized = True
# Add frequency band multipliers
self.freq_bands = np.ones(5) # 5 frequency bands, initially all set to 1.0
self.function_parser = FunctionParser()
        # Populated by parse_trending_parameters(); each entry is a
        # (start, end, feature_idx, callable, confidence) tuple, e.g.
        # parsed from "200,250,0,sin(2*pi*x),0.05"
        self.trending_controls = []
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def format_value(self, value: float, feature_idx: int) -> str:
"""Format value with appropriate units and notation"""
if self.show_normalized:
return f"{value:.4f}"
else:
if feature_idx == 0: # Revenue
return f"{self.feature_units[feature_idx]}{value:,.2f}"
else: # Downloads and AU
return f"{value:,.0f} {self.feature_units[feature_idx]}"
def create_plot(self, sample: np.ndarray, observed_points: torch.Tensor,
observed_mask: torch.Tensor,
gradient_control_signal: Dict, metrics: Dict) -> List[go.Figure]:
figures = []
        # Use the observed mask as per-point confidence weights (1.0 = fully trusted)
        weights = observed_mask
for feat_idx in range(self.feature_dim):
fig = go.Figure()
# Scale values if needed
scale_factor = self.feature_scales[feat_idx] * 10 if not self.show_normalized else 1
# Plot predicted line
predicted_values = sample[:, feat_idx] * scale_factor
fig.add_trace(go.Scatter(
x=np.arange(self.seq_length),
y=predicted_values,
mode='lines',
name='Predicted',
line=dict(color='green', width=2),
showlegend=True
))
            # Gather observed points; error bars grow as the point weight shrinks
            mask = (observed_points[:, feat_idx] > 0).numpy()
            ox = np.arange(self.seq_length)[mask]
            oy = observed_points.numpy()[mask, feat_idx] * scale_factor
            weights_masked = 1 - weights.numpy()[mask, feat_idx]
            # Error bars scale linearly with (1 - weight):
            # weight 1.0 -> 0 (no uncertainty), weight 0.1 -> 0.18
            error_y = weights_masked / 5
            # Plot observed points with error bars (cross markers)
            fig.add_trace(go.Scatter(
                x=ox,
                y=oy,
                mode='markers',
                name='Observed',
                marker=dict(
                    color='rgba(255, 0, 0, 0.5)',
                    symbol='x',
                ),
error_y=dict(
type='data',
array=error_y * scale_factor,
visible=True,
thickness=0.5,
width=2,
color='blue'
),
showlegend=True
))
            # Shaded confidence band around the predicted line (the fill is
            # intentionally disabled; re-enable fill='toself' below to draw it)
            uncertainty = 0.05  # base uncertainty level
upper_bound = predicted_values + uncertainty * scale_factor
lower_bound = predicted_values - uncertainty * scale_factor
fig.add_trace(go.Scatter(
x=np.concatenate([np.arange(self.seq_length), np.arange(self.seq_length)[::-1]]),
y=np.concatenate([upper_bound, lower_bound[::-1]]),
# fill='toself',
# fillcolor='rgba(0,100,0,0.1)',
line=dict(color='rgba(255,255,255,0)'),
name='Prediction Interval',
showlegend=True
))
# Add vertical lines for peak points
if gradient_control_signal.get("peak_points"):
for peak_point in gradient_control_signal["peak_points"]:
fig.add_vline(x=peak_point, line_dash="dash", line_color="red")
# Add metrics annotations
total_value = np.sum(sample[:, feat_idx]) * (self.feature_scales[feat_idx] * 10 if not self.show_normalized else 1)
annotations = [dict(
x=0.02,
y=1.1,
xref="paper",
yref="paper",
text=f"Total {self.feature_names[feat_idx]}: {self.format_value(total_value, feat_idx)}",
showarrow=False
)]
# Update y-axis title based on feature and scaling
if self.show_normalized:
y_title = f'{self.feature_names[feat_idx]} (Normalized)'
else:
unit = self.feature_units[feat_idx]
y_title = f'{self.feature_names[feat_idx]} ({unit})'
fig.update_layout(
title=dict(
text=f'Feature: {self.feature_names[feat_idx]}',
x=0.5,
y=0.95
),
xaxis_title='Time',
yaxis_title=y_title,
height=400,
showlegend=True,
dragmode='select',
                annotations=annotations,
margin=dict(r=200) # Add right margin for legend
)
figures.append(fig)
return figures
def update_scaling(self,
revenue_scale: float,
download_scale: float,
au_scale: float,
show_normalized: bool) -> Tuple[List[go.Figure], Dict]:
"""Update the scaling parameters and redraw plots"""
if self.latest_sample is None:
return [], {}
# Update scales
self.feature_scales = {
0: revenue_scale,
1: download_scale,
2: au_scale
}
self.show_normalized = show_normalized
# Calculate metrics
metrics = {
'show_normalized': self.show_normalized
}
for feat_idx in range(self.feature_dim):
total = np.sum(self.latest_sample[:, feat_idx]) * (self.feature_scales[feat_idx] * 10 if not self.show_normalized else 1)
metrics[f'total_{self.feature_names[feat_idx]}'] = self.format_value(total, feat_idx)
# Update plots
figures = self.create_plot(
self.latest_sample,
self.latest_observed_points,
self.latest_observed_mask,
self.latest_gradient_control_signal,
metrics
)
return figures, metrics
def parse_data_points(self, df) -> Dict:
"""Parse data points from DataFrame with columns: time,feature,value"""
data_dict = {}
if df is None or df.empty:
return data_dict
for _, row in df.iterrows():
# Skip if any required value is NaN
if pd.isna(row['time']) or pd.isna(row['feature']) or pd.isna(row['value']):
continue
try:
time_idx = int(row['time'])
feature_idx = int(row['feature'])
value = float(row['value'])
if time_idx not in data_dict:
data_dict[time_idx] = {}
data_dict[time_idx][feature_idx] = (value, 1.0)
except (ValueError, TypeError):
continue
return data_dict
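    # Illustrative example: a DataFrame with rows
    #   time=2, feature=0, value=0.58
    #   time=6, feature=0, value=0.27
    # parses to {2: {0: (0.58, 1.0)}, 6: {0: (0.27, 1.0)}}
    # (individual points always receive full weight 1.0).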
def parse_point_groups(self, df) -> Dict:
"""Parse point groups from DataFrame with columns: start,end,interval,feature,value,weight"""
data_dict = {}
if df is None or df.empty:
return data_dict
for _, row in df.iterrows():
# Skip if any required value is NaN
if pd.isna(row['start']) or pd.isna(row['end']) or pd.isna(row['interval']) or \
pd.isna(row['feature']) or pd.isna(row['value']):
continue
            try:
                start = int(row['start'])
                end = int(row['end'])
                interval = int(row['interval'])
                feature = int(row['feature'])
                value = float(row['value'])
                weight = float(row.get('weight', 1.0)) if not pd.isna(row.get('weight')) else 1.0
                if interval <= 0:
                    continue  # avoid a zero or negative range step
                for t in range(start, end + 1, interval):
if 0 <= t < self.seq_length:
if t not in data_dict:
data_dict[t] = {}
data_dict[t][feature] = (value, weight)
except (ValueError, TypeError):
continue
return data_dict
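    # Illustrative example: the row start=0, end=50, interval=10, feature=0,
    # value=0.5, weight=0.1 expands to anchor points at t = 0, 10, ..., 50,
    # each stored as {t: {0: (0.5, 0.1)}}.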
def to_tensor(self, observed_points_dict, seq_length, feature_dim):
observed_points = torch.zeros((seq_length, feature_dim))
observed_weights = torch.zeros((seq_length, feature_dim))
for seq, feature_dict in observed_points_dict.items():
for feature, (value, weight) in feature_dict.items():
observed_points[seq, feature] = value
observed_weights[seq, feature] = weight
return observed_points, observed_weights
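    # Illustrative example: {2: {0: (0.58, 1.0)}} becomes a pair of dense
    # (seq_length, feature_dim) tensors with points[2, 0] == 0.58 and
    # weights[2, 0] == 1.0; all unobserved cells stay 0.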
def apply_direct_edits(self, sample: np.ndarray, edit_params: Dict) -> np.ndarray:
"""Apply direct edits to the sample array"""
edited_sample = sample.copy()
if edit_params.get("enable_direct_area"):
areas = self.parse_area_selections(edit_params["direct_areas"])
for area in areas:
start, end, feat_idx, target = area
edited_sample[start:end, feat_idx] += target
edited_sample = np.clip(edited_sample, 0, 1)
return edited_sample
def parse_area_selections(self, area_text: str) -> List[Tuple]:
"""Parse area selection text into (start, end, feature, target) tuples"""
areas = []
if not area_text.strip():
return areas
area_text = area_text.replace('\n', ';')
for line in area_text.strip().split(';'):
if not line.strip():
continue
try:
start, end, feat, target = map(float, line.strip().split(','))
areas.append((int(start), int(end), int(feat), target))
except (ValueError, IndexError):
continue
return areas
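    # Illustrative example: "150,200,0,-0.1;300,320,1,0.05" parses to
    # [(150, 200, 0, -0.1), (300, 320, 1, 0.05)]; apply_direct_edits then
    # shifts sample[150:200, 0] by -0.1 and sample[300:320, 1] by +0.05.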
def apply_trending_mask(self, points: torch.Tensor, mask: torch.Tensor, consider_last_generated=False) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply trending functions as soft constraints through masks"""
if not self.trending_controls or self.latest_sample is None:
return points, mask
for start, end, feat_idx, func, confidence in self.trending_controls:
if start < 0 or end > self.seq_length or start >= end:
continue
            # Generate x values normalized between 0 and 1 for the segment
            x = np.linspace(0, 1, end - start)
            try:
                # Evaluate the function and rescale its output to [0, 1]
                y = func(x)
                y_range = np.max(y) - np.min(y)
                if y_range == 0:
                    continue  # a constant function carries no trend information
                y = (y - np.min(y)) / y_range
            except Exception as e:
                print(f"Error applying function: {e}")
                continue
            # Apply the trend as soft constraints only where no point is observed
            mask_zero = (mask[start:end, feat_idx] == 0)
            points[start:end, feat_idx][mask_zero] = torch.tensor(y, dtype=points.dtype)[mask_zero]
            mask[start:end, feat_idx][mask_zero] = torch.tensor(confidence * np.ones_like(y), dtype=mask.dtype)[mask_zero]
mask = mask.clamp(0, 1)
return points, mask
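    # Illustrative example: with the trending control "200,250,0,sin(2*pi*x),0.05",
    # time steps 200..249 of feature 0 that have no user-observed point receive
    # one period of a sine (rescaled to [0, 1]) as pseudo-observations with a
    # low weight of 0.05, nudging rather than pinning the sampler.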
def update_model(self,
figures: List[go.Figure],
data_points: str,
point_groups: str,
enable_area_control: bool,
area_selections: str,
enable_auc: bool,
auc_value: float,
enable_peaks: bool,
peak_points: str,
peak_alpha: float,
auc_weight: float,
peak_weight: float,
enable_trending: bool = True,
enable_trending_with_diff: bool = False,
trending_params: str = ""
) -> Tuple[List[go.Figure], str, str, Dict]:
# Parse both point groups and individual data points
individual_points_dict = self.parse_data_points(data_points)
group_points_dict = self.parse_point_groups(point_groups)
# Merge dictionaries, giving preference to individual points
combined_points_dict = group_points_dict.copy()
for t, feat_dict in individual_points_dict.items():
if t not in combined_points_dict:
combined_points_dict[t] = {}
for f, v in feat_dict.items():
combined_points_dict[t][f] = v
# Convert to tensor
observed_points, observed_weights = self.to_tensor(
combined_points_dict,
self.seq_length,
self.feature_dim
)
observed_mask = observed_weights
# Parse peak points
peak_points_list = []
if enable_peaks and peak_points:
try:
peak_points_list = [int(x.strip()) for x in peak_points.split(',') if x.strip()]
except ValueError:
peak_points_list = []
# Apply trending control if enabled
if enable_trending and trending_params:
self.parse_trending_parameters(trending_params)
observed_points, observed_mask = self.apply_trending_mask(observed_points, observed_mask, consider_last_generated=enable_trending_with_diff)
        # Build the gradient control signal
        gradient_control_signal = {}
if enable_auc:
gradient_control_signal["auc"] = auc_value
gradient_control_signal["auc_weight"] = auc_weight
if enable_peaks:
gradient_control_signal.update({
"peak_points": peak_points_list,
"peak_alpha": peak_alpha,
"peak_weight": peak_weight
})
        # Build the model control signal (area-based control is currently disabled)
        model_control_signal = {}
# Run prediction
        with torch.no_grad():
            # Move control tensors to the model device
            observed_points = observed_points.to(self.device)
            observed_mask = observed_mask.to(self.device)
            sample = self.trainer.predict_weighted_points(
                observed_points,      # (seq_length, feature_dim)
                observed_mask,        # (seq_length, feature_dim)
                self.coef,            # fixed
                self.stepsize,        # fixed
                self.sampling_steps,  # fixed
                gradient_control_signal=gradient_control_signal
            )
observed_points = observed_points.cpu()
observed_mask = observed_mask.cpu()
# Store latest results
self.latest_sample = sample
self.latest_observed_points = observed_points
self.latest_observed_mask = observed_mask
self.latest_gradient_control_signal = gradient_control_signal
self.latest_model_control_signal = model_control_signal
# Calculate metrics
metrics = {
'show_normalized': self.show_normalized
}
for feat_idx in range(self.feature_dim):
total = np.sum(sample[:, feat_idx]) * (self.feature_scales[feat_idx] * 10 if not self.show_normalized else 1)
metrics[f'total_{self.feature_names[feat_idx]}'] = self.format_value(total, feat_idx)
# Update plots
figures = self.create_plot(sample, observed_points, observed_mask, gradient_control_signal, metrics)
return figures, data_points, point_groups, metrics
def update_additional_edit(
self,
enable_direct_area: bool,
        direct_areas: str):
        if self.latest_sample is None:
            return [], {}
        # Apply direct edits if enabled
        if enable_direct_area:
            sample = self.apply_direct_edits(self.latest_sample, {
                "enable_direct_area": enable_direct_area,
                "direct_areas": direct_areas
            })
        else:
            sample = self.latest_sample
# Calculate metrics
metrics = {
'show_normalized': self.show_normalized
}
for feat_idx in range(self.feature_dim):
total = np.sum(sample[:, feat_idx]) * (self.feature_scales[feat_idx] * 10 if not self.show_normalized else 1)
metrics[f'total_{self.feature_names[feat_idx]}'] = self.format_value(total, feat_idx)
# Update plots
figures = self.create_plot(
sample,
self.latest_observed_points,
self.latest_observed_mask,
self.latest_gradient_control_signal,
metrics
)
return figures, metrics
def apply_frequency_filter(self, signal: np.ndarray) -> np.ndarray:
"""Apply FFT-based frequency filtering using the current band multipliers"""
# Get FFT of the signal
fft = np.fft.fft(signal)
freqs = np.fft.fftfreq(len(signal))
# Split frequencies into 5 bands
# Exclude DC component (0 frequency) from bands
pos_freqs = freqs[1:len(freqs)//2]
freq_ranges = np.array_split(pos_freqs, 5)
        # Apply band multipliers
        filtered_fft = fft.copy()
        # Handle the DC component separately
        filtered_fft[0] *= self.freq_bands[4]  # very-low-frequency multiplier also scales DC
        # Apply multipliers to each frequency band
        for i, freq_range in enumerate(freq_ranges):
            if len(freq_range) == 0:
                continue
            # Positive frequencies in this band
            pos_mask = np.logical_and(freqs >= freq_range[0], freqs <= freq_range[-1])
            # Mirror the band onto the negative frequencies so the spectrum
            # stays conjugate-symmetric and the inverse FFT stays real
            neg_mask = np.logical_and(freqs <= -freq_range[0], freqs >= -freq_range[-1])
            # freq_ranges is ordered low->high while freq_bands is ordered
            # high->low (slider 0 = very high), hence the 4 - i index
            filtered_fft[pos_mask] *= self.freq_bands[4 - i]
            filtered_fft[neg_mask] *= self.freq_bands[4 - i]
# Convert back to time domain
filtered_signal = np.real(np.fft.ifft(filtered_fft))
return filtered_signal
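    # Illustrative example (a sketch, assuming an editor instance with a
    # generated sample): with freq_bands = [1, 1, 1, 1, 0] every band keeps its
    # energy except the very-low band, so slow drift (and the mean, via DC) is
    # removed while faster variation is preserved:
    #   editor.freq_bands = np.array([1.0, 1.0, 1.0, 1.0, 0.0])
    #   detrended = editor.apply_frequency_filter(sample[:, 0])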
def update_frequency_bands(self, band_idx: int, value: float) -> Tuple[List[go.Figure], Dict]:
"""Update a frequency band multiplier and recompute the filtered signal"""
if self.latest_sample is None:
return [], {}
# Update the specified band multiplier
self.freq_bands[band_idx] = value
# Apply frequency filtering to each feature
filtered_sample = self.latest_sample.copy()
for feat_idx in range(self.feature_dim):
filtered_sample[:, feat_idx] = self.apply_frequency_filter(
self.latest_sample[:, feat_idx]
)
# Ensure values remain in valid range
filtered_sample = np.clip(filtered_sample, 0, 1)
# Calculate metrics
metrics = {
'show_normalized': self.show_normalized,
'frequency_bands': self.freq_bands.tolist()
}
for feat_idx in range(self.feature_dim):
total = np.sum(filtered_sample[:, feat_idx]) * (self.feature_scales[feat_idx] * 10 if not self.show_normalized else 1)
metrics[f'total_{self.feature_names[feat_idx]}'] = self.format_value(total, feat_idx)
# Update plots
figures = self.create_plot(
filtered_sample,
self.latest_observed_points,
self.latest_observed_mask,
self.latest_gradient_control_signal,
metrics
)
return figures, metrics
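    # Illustrative usage: halving the very-high band (slider index 0) smooths
    # the displayed curve without changing the observed anchor points:
    #   figures, metrics = editor.update_frequency_bands(0, 0.5)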
def parse_trending_parameters(self, trending_text: str) -> List[Tuple]:
"""Parse trending control parameters into (start, end, feature, function) tuples"""
trending_params = []
if not trending_text.strip():
return trending_params
trending_text = trending_text.replace('\n', ';')
for line in trending_text.strip().split(';'):
if not line.strip():
continue
            try:
                # Split on commas (maxsplit=4 keeps the confidence field intact);
                # the function expression itself must not contain commas
                parts = line.strip().split(',', 4)
if len(parts) != 5:
continue
start, end, feat = map(int, parts[:3])
function_str = parts[3].strip()
confidence = float(parts[4])
# Convert the function string to a callable
try:
func = self.function_parser.string_to_function(function_str)
trending_params.append((start, end, feat, func, confidence))
except ValueError as e:
print(f"Error parsing function '{function_str}': {e}")
continue
except (ValueError, IndexError):
continue
self.trending_controls = trending_params # Store the parsed parameters
return trending_params
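    # Illustrative example: "200,250,0,sin(2*pi*x),0.05;0,100,1,x,0.3" parses to
    # [(200, 250, 0, <sine callable>, 0.05), (0, 100, 1, <linear callable>, 0.3)].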
def create_gradio_interface(editor: TimeSeriesEditor):
with gr.Blocks() as app:
gr.Markdown("# Time Series Editor")
gr.Markdown("## Instruction: Scroll Down + Click [Update Figure] [~10s]")
metrics_display = gr.JSON(label="Metrics", value={})
with gr.Row():
            with gr.Column(scale=1):
                # Scaling Parameters Section
                gr.Markdown("## Scaling Parameters")
with gr.Accordion("Open for More Detail", open=False):
revenue_scale = gr.Number(
label="Revenue Scale ($ per 0.1 in model)",
value=1000000
)
download_scale = gr.Number(
label="Download Scale (downloads per 0.1 in model)",
value=100000
)
au_scale = gr.Number(
label="Active Users Scale (users per 0.1 in model)",
value=10000
)
show_normalized = gr.Checkbox(
label="Show Normalized Values (0-1 scale)",
value=True
)
update_scaling_btn = gr.Button("Update Scaling")
                # Time Series Section
                gr.Markdown("## Time Series Control Panel")
with gr.Group():
gr.Markdown("### Fixed Point Control")
data_points_df = gr.Dataframe(
headers=["time", "feature", "value"],
datatype=["number", "number", "number"],
# label="Anchor Point Control",
value=[[0, 0, 0.04], [2, 0, 0.58], [6, 0, 0.27], [58, 0, 0.8], [60, 0, 0.5]],
col_count=(3, "fixed"), # Fix number of columns
interactive=True
)
add_data_point_btn = gr.Button("Add Data Point")
def add_data_point(df):
new_row = pd.DataFrame([[None, 0, None]],
columns=["time", "feature", "value"])
return pd.concat([df, new_row], ignore_index=True)
add_data_point_btn.click(
fn=add_data_point,
inputs=[data_points_df],
outputs=[data_points_df]
)
with gr.Group():
gr.Markdown("### Group of Anchor Point Control with Confidence")
point_groups_df = gr.Dataframe(
headers=["start", "end", "interval", "feature", "value", "weight"],
datatype=["number", "number", "number", "number", "number", "number"],
# label="Group of Anchor Point Control",
value=[[0, 50, 10, 0, 0.5, 0.1], [100, 150, 50, 0, 0.1, 0.5]],
col_count=(6, "fixed"), # Fix number of columns
interactive=True
)
add_point_group_btn = gr.Button("Add Point Group")
def add_point_group(df):
new_row = pd.DataFrame([[None, None, None, 0, None, None]],
columns=["start", "end", "interval", "feature", "value", "weight"])
return pd.concat([df, new_row], ignore_index=True)
add_point_group_btn.click(
fn=add_point_group,
inputs=[point_groups_df],
outputs=[point_groups_df]
)
with gr.Group():
gr.Markdown("### Trending Control")
gr.Markdown("""
Enter trending control parameters in the format:
```
start_time,end_time,feature,function,confidence
```
                    Examples:
                    - Linear trend: `0,100,0,x,0.3`
                    - Sine wave: `0,100,0,sin(2*pi*x),0.3`
                    - Exponential: `0,100,0,exp(-x),0.3`
Separate multiple trends with semicolons.
""")
enable_trending_control = gr.Checkbox(label="Enable Trending Control", value=False)
enable_trending_control_with_diff = gr.Checkbox(label="Consider Last Generated", value=False)
trending_control = gr.Textbox(
label="Trending Control Parameters",
lines=2,
placeholder="Enter parameters: start_time,end_time,feature,function,condifdence; separated by semicolons",
value="200,250,0,sin(2*pi*x),0.05"
)
# Area Control Parameters
with gr.Group(visible=False):
gr.Markdown("### Area Control")
enable_area_control = gr.Checkbox(label="Enable Area Control", value=False)
area_selections = gr.Textbox(
label="Area Selections (format: start_time,end_time,feature,target_value)",
lines=2,
placeholder="Enter areas: start,end,feature,target; separated by semicolons",
)
# AUC Parameters
gr.Markdown("### Statistics Control")
enable_auc = gr.Checkbox(label="Enable Total Sum Control", value=True)
auc_input = gr.Number(label="Target Sum Value", value=-150)
auc_weight_input = gr.Number(label="Sum Weight", value=10.0)
# Peak Parameters
with gr.Group(visible=False):
gr.Markdown("### Peak Control")
enable_peaks = gr.Checkbox(label="Enable Peak Control", value=False)
peak_points_input = gr.Textbox(label="Peak Points (comma-separated)", value="100,200")
peak_alpha_input = gr.Number(label="Peak Alpha", value=10)
peak_weight_input = gr.Number(label="Peak Weight", value=1.0)
update_model_btn = gr.Button("Update Figure")
gr.Markdown("## Extend Edit", visible=False)
with gr.Tab("Range Shift", visible=False):
# gr.Markdown("### Direct Edit Control")
enable_direct_area = gr.Checkbox(label="Enable Direct Edits", value=False) # range shift
direct_areas = gr.Textbox(
label="Direct Edit Areas (format: start_time,end_time,feature,delta)",
lines=2,
placeholder="Enter areas: start,end,feature,delta; separated by semicolons",
value="150,200,0,-0.1"
)
update_additional_btn = gr.Button("Update Additional Edit")
                # Frequency Controls (hidden group)
with gr.Group(visible=False):
gr.Markdown("Adjust multipliers for different frequency bands (0-2)")
freq_bands = [
gr.Slider(
minimum=0, maximum=2, step=0.1, value=1.0,
label=f"Band {i+1}: {'Very High' if i==0 else 'High' if i==1 else 'Mid' if i==2 else 'Low' if i==3 else 'Very Low'} Freq",
) for i in range(5)
]
gr.Markdown("### Feature Index Reference:")
for idx, name in enumerate(editor.feature_names):
gr.Markdown(f"- {idx}: {name}")
with gr.Column(scale=1.2):
gr.Markdown("""
### Plot Legend
- **Points with Error Bars**: Observed values where:
- Point position = observed value
- Error bar size = uncertainty (inversely proportional to weight)
- **Green Line**: Model prediction
- **Vertical Red Lines**: Peak points (if enabled)
""")
plots = [gr.Plot() for _ in range(editor.feature_dim)]
# - **Shaded Area**: General prediction uncertainty
def update_scaling_callback(revenue_scale, download_scale, au_scale, show_normalized):
figs, metrics = editor.update_scaling(
revenue_scale,
download_scale,
au_scale,
show_normalized
)
return [*figs, metrics]
def update_model_callback(
data_points_df,
point_groups_df,
enable_area_control,
area_selections,
enable_auc,
auc,
auc_weight,
enable_peaks,
peak_points,
peak_alpha,
peak_weight,
enable_trending,
enable_trending_with_diff,
trending_params
):
figs, _, _, metrics = editor.update_model(
plots,
data_points_df,
point_groups_df,
enable_area_control,
area_selections,
enable_auc,
auc,
enable_peaks,
peak_points,
peak_alpha,
auc_weight,
peak_weight,
enable_trending,
enable_trending_with_diff,
trending_params
)
return [*figs, metrics]
# Update the click handler
update_model_btn.click(
fn=update_model_callback,
inputs=[
data_points_df,
point_groups_df,
enable_area_control,
area_selections,
enable_auc,
auc_input,
auc_weight_input,
enable_peaks,
peak_points_input,
peak_alpha_input,
peak_weight_input,
enable_trending_control,
enable_trending_control_with_diff,
trending_control
],
outputs=[*plots, metrics_display]
)
def update_additional_callback(enable_direct_area, direct_areas):
figs, metrics = editor.update_additional_edit(
enable_direct_area,
direct_areas
)
return [*figs, metrics]
def update_freq_band(band_idx, value):
figs, metrics = editor.update_frequency_bands(band_idx, value)
return [*figs, metrics]
update_scaling_btn.click(
fn=update_scaling_callback,
inputs=[
revenue_scale,
download_scale,
au_scale,
show_normalized
],
outputs=[*plots, metrics_display]
)
update_additional_btn.click(
fn=update_additional_callback,
inputs=[enable_direct_area, direct_areas],
outputs=[*plots, metrics_display]
)
        # Add event handlers for the frequency band sliders; a hidden Number
        # component carries each band's index into the callback
        for i, slider in enumerate(freq_bands):
slider.change(
fn=update_freq_band,
inputs=[gr.Number(value=i, visible=False), slider],
outputs=[*plots, metrics_display]
)
app.load(
fn=update_model_callback,
inputs=[
data_points_df,
point_groups_df,
enable_area_control,
area_selections,
enable_auc,
auc_input,
auc_weight_input,
enable_peaks,
peak_points_input,
peak_alpha_input,
peak_weight_input,
enable_trending_control,
enable_trending_control_with_diff,
trending_control
],
outputs=[*plots, metrics_display]
)
return app
class FunctionParser:
def __init__(self):
# Define available mathematical functions and constants
self.math_functions = {
'sin': np.sin,
'cos': np.cos,
'tan': np.tan,
'exp': np.exp,
'log': np.log,
'sqrt': np.sqrt,
'abs': np.abs,
'pow': np.power,
'pi': np.pi,
'e': np.e,
'asin': np.arcsin,
'acos': np.arccos,
'atan': np.arctan,
'sinh': np.sinh,
'cosh': np.cosh,
'tanh': np.tanh
}
def validate_expression(self, expression: str) -> bool:
"""
Validate the mathematical expression for basic syntax errors.
"""
# Check for balanced parentheses
if expression.count('(') != expression.count(')'):
raise ValueError("Unbalanced parentheses in expression")
# Check for invalid characters
valid_chars = set('0123456789.+-*/()^ xXepi,')
valid_chars.update(''.join(self.math_functions.keys()))
if not all(c in valid_chars or c.isspace() for c in expression.lower()):
raise ValueError("Expression contains invalid characters")
return True
def preprocess_expression(self, expression: str) -> str:
"""
Preprocess the expression to handle various input formats.
"""
# Remove whitespace
expression = expression.replace(' ', '')
# Convert ^ to ** for exponentiation
expression = expression.replace('^', '**')
        # Make implicit multiplication explicit, e.g. "2x" -> "2*x", "(x+1)y" -> "(x+1)*y"
        # (note: this also rewrites scientific notation like "1e3", which is unsupported)
        expression = re.sub(r'(\d+)([a-zA-Z])', r'\1*\2', expression)
        expression = re.sub(r'(\))([\w])', r'\1*\2', expression)
        # Lowercase so "X" and "SIN" match the lowercase function table
        expression = expression.lower()
return expression
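    # Illustrative example: "sin(X) + 2x^2" preprocesses to "sin(x)+2*x**2",
    # which string_to_function can then evaluate over the whitelisted namespace.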
def string_to_function(self, expression: str) -> Callable[[Union[float, np.ndarray]], Union[float, np.ndarray]]:
"""
Convert a string mathematical expression to a callable function.
Args:
expression (str): Mathematical expression (e.g., "sin(x) + x^2")
Returns:
Callable: A function that takes x as input and returns the evaluated result
Example:
>>> f = string_to_function("sin(x) + x^2")
>>> f(0.5)
            0.7294...
"""
# Validate and preprocess the expression
self.validate_expression(expression)
processed_expr = self.preprocess_expression(expression)
        # Create the evaluation namespace; restricting __builtins__ limits the
        # expression to the whitelisted math functions above
        namespace = self.math_functions.copy()
        namespace['__builtins__'] = {}
        try:
            # Build the callable from the preprocessed expression
            func = eval(f"lambda x: {processed_expr}", namespace)
# Test the function with a simple input
test_value = 1.0
try:
func(test_value)
except Exception as e:
raise ValueError(f"Invalid function: {str(e)}")
return func
except SyntaxError as e:
raise ValueError(f"Invalid expression syntax: {str(e)}")
except Exception as e:
raise ValueError(f"Error creating function: {str(e)}")
@staticmethod
def demonstrate_usage():
"""
Demonstrate various uses of the function parser.
"""
parser = FunctionParser()
# Test cases
test_expressions = [
"x^2 + 2*x + 1",
"sin(x) + cos(x)",
"exp(-x^2)",
"log(x + 1)",
"sqrt(1 - x^2)",
]
print("Testing various mathematical expressions:")
x_test = 0.5
for expr in test_expressions:
try:
print(f"\nExpression: {expr}")
func = parser.string_to_function(expr)
result = func(x_test)
print(f"f({x_test}) = {result}")
# Test with numpy array
x_array = np.linspace(0, 1, 5)
results = func(x_array)
print(f"f(array) = {results}")
except Exception as e:
print(f"Error: {str(e)}")
# Example usage:
if __name__ == "__main__":
    import os

    os.environ["WANDB_ENABLED"] = "false"
    print(os.getcwd())
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Using device: {device}")
from models.Tiffusion import tiffusion
model = tiffusion.Tiffusion(
seq_length=365,
feature_size=3,
n_layer_enc=6,
n_layer_dec=4,
d_model=128,
timesteps=500,
sampling_timesteps=200,
loss_type='l1',
beta_schedule='cosine',
n_heads=8,
mlp_hidden_times=4,
attn_pd=0.0,
resid_pd=0.0,
kernel_size=1,
padding_size=0,
control_signal=[]
).to(device)
model.load_state_dict(torch.load("./weight/checkpoint-10.pt", map_location=device, weights_only=True)["model"])
coef = 1.0e-2
stepsize = 5.0e-2
sampling_steps = 100 # Adjustable between 100-500 for speed/accuracy tradeoff
seq_length = 365
feature_dim = 3
print(f"seq_length: {seq_length}, feature_dim: {feature_dim}")
editor = TimeSeriesEditor(seq_length, feature_dim, model)
editor.coef = coef
editor.stepsize = stepsize
editor.sampling_steps = sampling_steps
app = create_gradio_interface(editor)
app.launch(show_api=False)