import os

import torch

os.environ["WANDB_ENABLED"] = "false"

from engine.solver import Trainer
from data.build_dataloader import build_dataloader, build_dataloader_cond
from utils.metric_utils import visualization, save_pdf
from utils.io_utils import load_yaml_config, instantiate_from_config
from models.model_utils import unnormalize_to_zero_to_one
from scipy.signal import find_peaks, peak_prominences

# disable user warnings
import warnings

warnings.simplefilter("ignore", UserWarning)

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
import pickle
from pathlib import Path


def load_cached_results(cache_dir):
    results = {"unconditional": None, "sum_controlled": {}, "anchor_controlled": {}}
    for cache_file in cache_dir.glob("*.pkl"):
        with open(cache_file, "rb") as f:
            key = cache_file.stem
            if key == "unconditional":
                results["unconditional"] = pickle.load(f)
            elif key.startswith("sum_"):
                param = key[4:]  # Remove 'sum_' prefix
                results["sum_controlled"][param] = pickle.load(f)
            elif key.startswith("anchor_"):
                param = key[7:]  # Remove 'anchor_' prefix
                results["anchor_controlled"][param] = pickle.load(f)
    return results


def save_result(cache_dir, key, subkey, data):
    if subkey:
        filename = f"{key}_{subkey}.pkl"
    else:
        filename = f"{key}.pkl"
    with open(cache_dir / filename, "wb") as f:
        pickle.dump(data, f)


class Arguments:
    def __init__(self, config_path, gpu=0) -> None:
        self.config_path = config_path
        # self.config_path = "./config/control/revenue-baseline-sine.yaml"
        self.save_dir = (
            "../../../data/" + os.path.basename(self.config_path).split(".")[0]
        )
        self.gpu = gpu
        os.makedirs(self.save_dir, exist_ok=True)
        self.mode = "infill"
        self.missing_ratio = 0.95
        self.milestone = 10


def beautiful_text(key, highlight):
    """Format a result key such as 'auc_-100_weight_10' as a plot label,
    bolding either the AUC value or the weight depending on `highlight`."""
    if "auc" in key:
        auc = key.split("_")[1]
        weight = key.split("_")[3]
        if highlight is None:
            return f"AUC: $\\mathbf{{{auc}}}$ Weight: {weight}"
        else:
            return f"AUC: {auc} Weight: $\\mathbf{{{weight}}}$"
    if "anchor" in key:
        anchor = key.split("_")[1]
        weight = key.split("_")[3]
        return f"Anchor: {anchor} Weight: {weight}"
    return key


def get_alpha(idx, n_plots):
    """Generate an alpha value between 0.5 and 0.9 based on the plot index
    (0.8 when there is a single plot)."""
    return 0.5 + (0.4 * idx / (n_plots - 1)) if n_plots > 1 else 0.8


def create_color_gradient(
    sorting_value=None, start_color="#FFFF00", end_color="#00008B"
):
    """Create color gradient using matplotlib color interpolation."""

    def color_fader(c1, c2, mix=0):
        """Fade from color c1 to c2 with mix ratio."""
        c1 = np.array(mpl.colors.to_rgb(c1))
        c2 = np.array(mpl.colors.to_rgb(c2))
        return mpl.colors.to_hex((1 - mix) * c1 + mix * c2)

    if sorting_value is not None:
        # Normalize values between 0-1
        values = np.array(list(sorting_value.values()))
        normalized = (values - values.min()) / (values.max() - values.min())
        # Create color mapping
        return {
            key: color_fader(start_color, end_color, mix=norm_val)
            for key, norm_val in zip(sorting_value.keys(), normalized)
        }
    else:
        # Return middle point color
        return color_fader(start_color, end_color, mix=0.5)


def create_color_gradient(
    sorting_value=None,
    start_color="#FFFF00",
    middle_color="#00FF00",
    end_color="#00008B",
):
    """Create color gradient using matplotlib interpolation with a middle color.

    Note: this definition shadows the two-color version above; only this one is in effect.
    """

    def color_fader(c1, c2, mix=0):
        """Fade from color c1 to c2 with mix ratio."""
        c1 = np.array(mpl.colors.to_rgb(c1))
        c2 = np.array(mpl.colors.to_rgb(c2))
        return mpl.colors.to_hex((1 - mix) * c1 + mix * c2)
    if sorting_value is not None:
        values = np.array(list(sorting_value.values()))
        normalized = (values - values.min()) / (values.max() - values.min())

        colors = {}
        for key, norm_val in zip(sorting_value.keys(), normalized):
            if norm_val <= 0.5:
                # Interpolate between start and middle
                mix = norm_val * 2  # Scale 0-0.5 to 0-1
                colors[key] = color_fader(start_color, middle_color, mix)
            else:
                # Interpolate between middle and end
                mix = (norm_val - 0.5) * 2  # Scale 0.5-1 to 0-1
                colors[key] = color_fader(middle_color, end_color, mix)
        return colors
    else:
        return middle_color  # Return middle color directly


# for config_path in [
#     # "./config/modified/sines.yaml",
#     # "./config/modified/revenue-baseline-365.yaml",
#     "./config/modified/energy.yaml",
#     "./config/modified/fmri.yaml",
# ]:

import argparse


def parse_args():
    parser = argparse.ArgumentParser(description="Controlled Sampling")
    parser.add_argument(
        "--config_path", type=str, default="./config/modified/energy.yaml"
    )
    parser.add_argument("--gpu", type=int, default=0)
    return parser.parse_args()


def run(run_args):
    args = Arguments(run_args.config_path, run_args.gpu)
    configs = load_yaml_config(args.config_path)
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    torch.cuda.set_device(args.gpu)

    dl_info = build_dataloader(configs, args)
    model = instantiate_from_config(configs["model"]).to(device)
    trainer = Trainer(config=configs, args=args, model=model, dataloader=dl_info)
    trainer.load("10")  # args.milestone
    dataset = dl_info["dataset"]

    test_dl_info = build_dataloader_cond(configs, args)
    test_dataloader, test_dataset = test_dl_info["dataloader"], test_dl_info["dataset"]

    coef = configs["dataloader"]["test_dataset"]["coefficient"]
    stepsize = configs["dataloader"]["test_dataset"]["step_size"]
    sampling_steps = configs["dataloader"]["test_dataset"]["sampling_steps"]
    seq_length, feature_dim = test_dataset.window, test_dataset.var_num
    dataset_name = os.path.basename(args.config_path).split(".")[0].split("-")[0]
    mapper = {
        "sines": "sines",
        "revenue": "revenue",
        "energy": "energy",
        "fmri": "fMRI",
    }
    gap = seq_length // 5

    if seq_length in [96, 192, 384]:
        ori_data = np.load(
            os.path.join(
                "../../../data/train/",
                str(seq_length),
                dataset_name,
                "samples",
                f'{mapper[dataset_name].replace("sines", "sine")}_norm_truth_{seq_length}_train.npy',
            )
        )
        masks = np.load(
            os.path.join(
                "../../../data/train/",
                str(seq_length),
                dataset_name,
                "samples",
                f'{mapper[dataset_name].replace("sines", "sine")}_masking_{seq_length}.npy',
            )
        )
    else:
        ori_data = np.load(
            os.path.join(
                "../../../data/train/",
                dataset_name,
                "samples",
                f"{mapper[dataset_name]}_norm_truth_{seq_length}_train.npy",
            )
        )
        masks = np.load(
            os.path.join(
                "../../../data/train/",
                dataset_name,
                "samples",
                f"{mapper[dataset_name]}_masking_{seq_length}.npy",
            )
        )

    sample_num, _, _ = masks.shape
    # observed = ori_data[:sample_num] * masks
    ori_data = ori_data[:sample_num]

    sampling_size = min(1000, len(test_dataset), sample_num)
    batch_size = 500
    print(f"Sampling size: {sampling_size}, Batch size: {batch_size}")

    ### Cache file path
    cache_dir = Path(f"../../../data/cache/{dataset_name}_{seq_length}")
    if "csdi" in args.config_path:
        cache_dir = Path(f"../../../data/cache/csdi/{dataset_name}_{seq_length}")
    cache_dir.mkdir(exist_ok=True)
    results = load_cached_results(cache_dir)
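    # Cache layout (as implied by save_result / load_cached_results above): every result
    # is pickled to <cache_dir>/<key>[_<subkey>].pkl, for example
    #   unconditional.pkl                  -> results["unconditional"]
    #   sum_auc_-100_weight_10.pkl         -> results["sum_controlled"]["auc_-100_weight_10"]
    #   anchor_anchor_-0.8_weight_0.5.pkl  -> results["anchor_controlled"]["anchor_-0.8_weight_0.5"]
    # Deleting a file forces that configuration to be re-sampled on the next run.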
    ### Unconditional sampling
    if results["unconditional"] is None:
        print("Generating unconditional data...")
        results["unconditional"] = trainer.control_sample(
            num=sampling_size,
            size_every=batch_size,
            shape=[seq_length, feature_dim],
            model_kwargs={
                "gradient_control_signal": {},
                "coef": coef,
                "learning_rate": stepsize,
            },
        )
        save_result(cache_dir, "unconditional", "", results["unconditional"])

    ### Different AUC values
    auc_weights = [10]
    auc_values = [-100, 20, 50, 150]  # -200, -150, -100, -50, 0, 20, 30, 50, 100, 150
    for auc in auc_values:
        for weight in auc_weights:
            key = f"auc_{auc}_weight_{weight}"
            if key not in results["sum_controlled"]:
                print(f"Generating sum controlled data - AUC: {auc}, Weight: {weight}")
                results["sum_controlled"][key] = trainer.control_sample(
                    num=sampling_size,
                    size_every=batch_size,
                    shape=[seq_length, feature_dim],
                    model_kwargs={
                        "gradient_control_signal": {"auc": auc, "auc_weight": weight},
                        "coef": coef,
                        "learning_rate": stepsize,
                    },
                )
                save_result(cache_dir, "sum", key, results["sum_controlled"][key])

    ### Different AUC weights
    auc_weights = [1, 10, 50, 100]
    auc_values = [-100]
    for auc in auc_values:
        for weight in auc_weights:
            key = f"auc_{auc}_weight_{weight}"
            if key not in results["sum_controlled"]:
                print(f"Generating sum controlled data - AUC: {auc}, Weight: {weight}")
                results["sum_controlled"][key] = trainer.control_sample(
                    num=sampling_size // 2,
                    size_every=batch_size,
                    shape=[seq_length, feature_dim],
                    model_kwargs={
                        "gradient_control_signal": {"auc": auc, "auc_weight": weight},
                        "coef": coef,
                        "learning_rate": stepsize,
                    },
                )
                save_result(cache_dir, "sum", key, results["sum_controlled"][key])

    ### Different AUC segments
    auc_weights = [10]
    auc_values = [150]
    auc_average = 10
    auc_segments = ((gap, 2 * gap), (2 * gap, 3 * gap), (3 * gap, 4 * gap))
    auc = auc_values[0]
    weight = auc_weights[0]
    for segment in auc_segments:
        key = f"auc_{auc}_weight_{weight}_segment_{segment[0]}_{segment[1]}"
        if key not in results["sum_controlled"]:
            print(
                f"Generating sum controlled data - AUC: {auc}, Weight: {weight}, Segment: {segment}"
            )
            results["sum_controlled"][key] = trainer.control_sample(
                num=sampling_size,
                size_every=batch_size,
                shape=[seq_length, feature_dim],
                model_kwargs={
                    "gradient_control_signal": {
                        "auc": auc_average * (segment[1] - segment[0]),  # / seq_length,
                        "auc_weight": weight,
                        "segment": [segment],
                    },
                    "coef": coef,
                    "learning_rate": stepsize,
                },
            )
            save_result(cache_dir, "sum", key, results["sum_controlled"][key])

    # Different anchors
    anchor_values = [-0.8, 0.6, 1.0]
    anchor_weights = [0.01, 0.1, 0.5, 1.0]
    for peak in anchor_values:
        for weight in anchor_weights:
            key = f"anchor_{peak}_weight_{weight}"
            if key not in results["anchor_controlled"]:
                # Anchor every `gap` steps (offset by gap // 2) on the first feature:
                # `mask` carries the anchor confidence/weight, `target` the anchor value.
                mask = np.zeros((seq_length, feature_dim), dtype=np.float32)
                mask[gap // 2 :: gap, 0] = weight
                target = np.zeros((seq_length, feature_dim), dtype=np.float32)
                target[gap // 2 :: gap, 0] = peak
                print(f"Anchor controlled data - Peak: {peak}, Weight: {weight}")
                results["anchor_controlled"][key] = trainer.control_sample(
                    num=sampling_size,
                    size_every=batch_size,
                    shape=[seq_length, feature_dim],
                    model_kwargs={
                        "gradient_control_signal": {},  # e.g. {"auc": -50, "auc_weight": 10.0}
                        "coef": coef,
                        "learning_rate": stepsize,
                    },
                    target=target,
                    partial_mask=mask,
                )
                save_result(cache_dir, "anchor", key, results["anchor_controlled"][key])

                # plot mask, target, and generated sequence
                # plt.figure(figsize=(6, 3))
                # plt.scatter(
                #     range(gap // 2, seq_length, gap), [weight] * 5, label="Mask"
                # )
                # plt.scatter(
                #     range(gap // 2, seq_length, gap), [peak] * 5, label="Target"
                # )
                # plt.plot(
                #     results["anchor_controlled"][key][0, :, 0],
                #     label="Generated Sequence",
                # )
                # plt.title(f"Anchor Controlled Data - Peak: {peak}, Weight: {weight}")
                # plt.legend()
                # plt.show()
    if dataset.auto_norm:
        for key, data in results.items():
            if isinstance(data, dict):
                for subkey, subdata in data.items():
                    results[key][subkey] = unnormalize_to_zero_to_one(subdata)
            else:
                results[key] = unnormalize_to_zero_to_one(data)

    results["ori_data"] = ori_data

    # Truncate all results to sampling_size
    for key, data in results.items():
        if isinstance(data, dict):
            for subkey, subdata in data.items():
                results[key][subkey] = subdata[:sampling_size]
        else:
            results[key] = data[:sampling_size]

    return results, dataset_name, seq_length


def ploting(results, dataset_name, seq_length):
    gap = seq_length // 5
    ds_name_display = {
        "sines": "Synthetic Sine Waves",
        "revenue": "Revenue",
        "energy": "ETTh",
        "fmri": "fMRI",
    }

    ori_data = results["ori_data"]
    # Store the results in variables for compatibility with existing code
    unconditional_data = results["unconditional"]
    sum_controled_data = results["sum_controlled"]  # e.g. 'auc_0_weight_10.0' (default values)
    anchor_controled_data = results["anchor_controlled"]  # e.g. 'anchor_0.8_weight_0.1' (default values)

    ### Visualization
    def kernel_subplots(data, output_label="", highlight=None):
        # Mean area under the curve (sum over time, averaged over samples) per distribution
        def get_auc(data_array):
            return data_array.sum(-1).mean()

        # Get AUC values
        auc_orig = get_auc(data["ori_data"])
        auc_uncond = get_auc(data["Unconditional"])

        # Setup subplots
        keys = [k for k in data.keys() if k not in ["ori_data", "Unconditional"]]
        l = len(keys)
        n_cols = min(4, len(keys))
        n_rows = (len(keys) + n_cols - 1) // n_cols
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 4 * n_rows))
        fig.set_dpi(300)

        if n_rows == 1:
            axes = axes.reshape(1, -1)

        for idx, key in enumerate(keys):
            row, col = idx // n_cols, idx % n_cols
            ax = axes[row, col]

            # Plot distributions
            sns.distplot(
                data["ori_data"],
                hist=False,
                kde=True,
                kde_kws={"linewidth": 2, "alpha": 0.9 - get_alpha(idx, l) * 0.5},
                color="red",
                ax=ax,
                label=f"Original\n$\overline{{Area}}={auc_orig:.3f}$",
            )
            sns.distplot(
                data["Unconditional"],
                hist=False,
                kde=True,
                kde_kws={
                    "linewidth": 2,
                    "linestyle": "--",
                    "alpha": 0.9 - get_alpha(idx, l) * 0.5,
                },
                color="#15B01A",  # FF4500 GREEN:15B01A
                ax=ax,
                label=f"Unconditional\n$\overline{{Area}}= {auc_uncond:.3f}$",
            )
            auc_control = get_auc(data[key])
            sns.distplot(
                data[key],
                hist=False,
                kde=True,
                kde_kws={"linewidth": 2, "alpha": get_alpha(idx, l), "linestyle": "--"},
                color="#9A0EEA",
                ax=ax,
                label=f"{beautiful_text(key, highlight)}\n$\overline{{Area}}= {auc_control:.3f}$",
            )
            ax.legend()

            # Set labels only for first column and last row
            if col == 0:
                ax.set_ylabel("Density")
            else:
                ax.set_ylabel("")
            if row == n_rows - 1:
                ax.set_xlabel("Value")
            else:
                ax.set_xlabel("")

        plt.tight_layout()
        plt.show()
        save_pdf(fig, f"./figures/{output_label}_kde.pdf")
        plt.close()

    # Sum control
    samples = 1000
    data = {
        "ori_data": ori_data[:samples, :, :1],
        "Unconditional": unconditional_data[:samples, :, :1],
    }
    for key in [
        # "auc_-200_weight_10",
        "auc_-100_weight_10",
        # "auc_0_weight_10",
        "auc_20_weight_10",
        # "auc_30_weight_10",
        "auc_50_weight_10",
        # "auc_100_weight_10",
        "auc_150_weight_10",
    ]:
        data[key] = sum_controled_data[key][:samples, :, :1]
        print(
            key,
            " ==> ",
            sum_controled_data[key][:samples, :, :1].sum()
            / sum_controled_data[key][:samples, :, :1].shape[0],
        )
    # visualization_control(
    #     data=data,
    #     analysis="kernel",
    #     compare=ori_data.shape[0],
    #     output_label="revenue",
    # )
    # Updated
    # kernel_subplots(
    #     data=data,
    #     output_label=f"{ds_name_display[dataset_name]} Dataset with Summation Control",
    # )

    data = {
        "ori_data": ori_data[:samples, :, :1],
        "Unconditional": unconditional_data[:samples, :, :1],
    }
    for key in [
        "auc_-100_weight_1",
        "auc_-100_weight_10",
        "auc_-100_weight_50",
        "auc_-100_weight_100",
    ]:
        data[key] = sum_controled_data[key][:samples, :, :1]
        # print sum
        print(
            key,
            " ==> ",
            sum_controled_data[key][:samples, :, :1].sum()
            / sum_controled_data[key][:samples, :, :1].shape[0],
        )

    kernel_subplots(
        data=data,
        output_label=f"{ds_name_display[dataset_name]} Dataset with Summation Control",
        highlight="weight",
    )

    # Anchor control
    data = {
        "ori_data": ori_data[:samples, :, :1],
        "Unconditional": unconditional_data[:samples, :, :1],
    }
    # anchor_values = [-0.8, 0.6, 1.0]
    # anchor_weights = [0.01, 0.1, 0.5, 1.0]
    for key in [
        "anchor_-0.8_weight_0.01",
        "anchor_-0.8_weight_0.1",
        "anchor_-0.8_weight_0.5",
        "anchor_-0.8_weight_1.0",
        "anchor_0.6_weight_0.01",
        "anchor_0.6_weight_0.1",
        "anchor_0.6_weight_0.5",
        "anchor_0.6_weight_1.0",
        "anchor_1.0_weight_0.01",
        "anchor_1.0_weight_0.1",
        "anchor_1.0_weight_0.5",
        "anchor_1.0_weight_1.0",
    ]:
        data[key] = anchor_controled_data[key][:samples, :, :1]
        # print(key, " ==> ", anchor_controled_data[key][:samples, :, :1].max())

    def visualization_control_anchor_subplots(
        data, seq_length, analysis="anchor", compare=100, output_label=""
    ):
        # Extract unique anchors and weights
        anchors = sorted(
            set([float(k.split("_")[1]) for k in data.keys() if "anchor" in k])
        )
        weights = sorted(
            set([float(k.split("_")[3]) for k in data.keys() if "weight" in k])
        )

        # Create subplot grid
        n_rows = len(anchors)
        n_cols = len(weights)
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 4 * n_rows))
        fig.set_dpi(300)
        gap = seq_length // 5

        for i, anchor in enumerate(anchors):
            for j, weight in enumerate(weights):
                ax = axes[i][j]
                key = f"anchor_{anchor}_weight_{weight}"

                # Plot distributions
                sns.distplot(
                    data["ori_data"],
                    hist=False,
                    kde=True,
                    kde_kws={"linewidth": 2},
                    color="red",
                    ax=ax,
                    label="Original",
                )
                sns.distplot(
                    data["Unconditional"],
                    hist=False,
                    kde=True,
                    kde_kws={"linewidth": 2, "linestyle": "--"},
                    color="#15B01A",
                    ax=ax,
                    label="Unconditional",
                )
                if key in data:
                    sns.distplot(
                        data[key],
                        hist=False,
                        kde=True,
                        kde_kws={"linewidth": 2, "linestyle": "--"},
                        color="#9A0EEA",
                        ax=ax,
                        label=f"Controlled\n$Target={anchor}, Conf={weight}$",
                    )
                    # Mark the anchored time steps (every `gap` steps, offset by gap // 2)
                    anchor_points = np.arange(gap // 2, seq_length, gap)
                    for anchor_point in anchor_points:
                        ax.axvline(
                            x=anchor_point / seq_length,
                            color="black",
                            linestyle="--",
                            alpha=0.5,
                        )

                # Labels and titles
                if i == n_rows - 1:
                    ax.set_xlabel("Value")
                if j == 0:
                    ax.set_ylabel("Density")
                ax.set_title(f"Anchor={anchor}, Weight={weight}")
                ax.legend()

        plt.tight_layout()
        plt.show()
        # save_pdf(fig, f"./figures/{output_label}_anchor_kde.pdf")
        plt.close()

    # Anchor Control Distribution
    visualization_control_anchor_subplots(
        data=data,
        seq_length=seq_length,
        analysis="anchor",
        compare=ori_data.shape[0],
        output_label=f"{ds_name_display[dataset_name]} Dataset with Anchor Control",
    )
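    # Worked example of the matching rule used below (assuming seq_length = 96):
    # gap = 96 // 5 = 19, so anchors were placed at steps 9, 28, 47, 66, 85.
    # evaluate_anchor_detection() then counts a target as matched if find_peaks()
    # returns a peak within +/- (window_size // 2) = +/- ((gap // 2) // 2) = 4 steps of it.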
    def evaluate_anchor_detection(
        data, target_anchors, window_size=7, min_distance=5, prominence_threshold=0.1
    ):
        """
        Evaluate anchor detection accuracy by comparing detected anchors with target anchors.

        Parameters:
        data: numpy array of shape (batch_size, seq_length, features)
            The generated sequences to analyze
        target_anchors: list
            Indices where anchors should occur (e.g., every 7 steps for weekly anchors)
        window_size: int
            Size of the window around a target within which a detected anchor counts as a match
        """
        batch_size, seq_length, features = data.shape
        accuracy_metrics = {}

        # Create figure for visualization
        fig, axes = plt.subplots(2, 2, figsize=(10, 5))
        axes = axes.flatten()

        # Plot the first 4 sequences, first feature (revenue); evaluate all sequences
        overall_matched = 0
        overall_targets = 0
        overall_detected = 0

        for i in range(4):
            sequence = data[i, :, 0]  # batch i, all timepoints, revenue feature

            # Find anchors using scipy
            anchors, properties = find_peaks(
                sequence, distance=min_distance, prominence=prominence_threshold
            )

            # Plot original sequence and detected anchors
            axes[i].plot(sequence, label="Generated")
            # Plot target anchor positions
            target_positions = target_anchors
            axes[i].plot(
                target_positions,
                sequence[target_positions],
                "o",
                label="Target" if i == 1 else "",
            )
            axes[i].plot(
                anchors, sequence[anchors], "x", label="Detected" if i == 1 else ""
            )
            axes[i].set_title(f"Sequence {i+1}")
            if i == 1:
                axes[i].legend(bbox_to_anchor=(1.05, 1), loc="upper left")
            axes[i].grid(True)

            # Count matches within window for this sequence
            matched_anchors = 0
            for target in target_positions:
                # Check if any detected anchor is within the window of the target
                matches = np.any(
                    (anchors >= target - window_size // 2)
                    & (anchors <= target + window_size // 2)
                )
                if matches:
                    matched_anchors += 1

            overall_matched += matched_anchors
            overall_targets += len(target_positions)
            overall_detected += len(anchors)

        for i in range(4, batch_size):
            anchors, properties = find_peaks(
                data[i, :, 0], distance=min_distance, prominence=prominence_threshold
            )
            matched_anchors = 0
            for target in target_anchors:
                matches = np.any(
                    (anchors >= target - window_size // 2)
                    & (anchors <= target + window_size // 2)
                )
                if matches:
                    matched_anchors += 1
            overall_matched += matched_anchors
            overall_targets += len(target_anchors)
            overall_detected += len(anchors)

        # Calculate overall metrics
        accuracy = overall_matched / overall_targets
        precision = overall_matched / overall_detected if overall_detected > 0 else 0
        accuracy_metrics = {
            "accuracy": accuracy,
            "precision": precision,
            "total_targets": overall_targets,
            "detected_anchors": overall_detected,
            "matched_anchors": overall_matched,
        }

        plt.tight_layout()
        plt.show()
        return accuracy_metrics, anchors

    # Evaluate anchor detection for different control settings
    anchor_accuracies = {}
    for key, data in anchor_controled_data.items():
        print(f"\nEvaluating {key}")
        # Anchors were placed every `gap` steps with an offset of gap // 2
        metrics, anchors = evaluate_anchor_detection(
            data,
            target_anchors=range(gap // 2, seq_length, gap),
            window_size=max(1, gap // 2),
            min_distance=max(1, gap - 1),
        )
        anchor_accuracies[key] = metrics
        print(f"Accuracy: {metrics['accuracy']:.3f}")
        print(f"Precision: {metrics['precision']:.3f}")
        print(
            f"Matched anchors: {metrics['matched_anchors']} / {metrics['total_targets']}"
        )
        print("=" * 50)


if __name__ == "__main__":
    args = parse_args()
    run(args)
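# Hypothetical follow-up (not part of the original entry point): `run` returns exactly
# what `ploting` expects, so the figures and the anchor-detection report can be produced
# in the same invocation, e.g.:
#
#     results, dataset_name, seq_length = run(parse_args())
#     ploting(results, dataset_name, seq_length)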