| from utils import * |
| from src.configs.safetynet_config import SafetyNetConfig |
| from utils.safetynet.vae_ae_train import Attention_DataProcessing, Train, Test, Detector_Stats |
| from src.configs.spylab_model_config import spylab_create_config |
| from src.configs.anthropic_model_config import anthropic_create_config |
|
|
| import plotly.graph_objects as go |
|
|
|
|
|
|
class Visualization:
    """Plotting helpers for comparing backdoor detectors (AE, VAE, PCA,
    Mahalanobis, Beatrix, CROW) via split violin plots of their loss/score
    distributions on normal vs. harmful prompts.

    All methods are static; the class is only a namespace.
    """

    @staticmethod
    def data_processing_for_crow(
            other_layer_idx,
            vanilla_path="utils/data/llama2/ae_vae/vanilla/cosine_analysis.json",
            harmful_path="utils/data/llama2/ae_vae/backdoored/cosine_analysis.json"):
        """Load CROW cosine-analysis stats and center them on the harmful mean.

        Parameters
        ----------
        other_layer_idx : str
            'prev' selects the (previous, current) layer pair (index 0),
            'next' selects the (current, next) pair (index 1).
        vanilla_path : str
            JSON file with per-pair "normal"/"harmful" values of the clean model.
        harmful_path : str
            Same layout for the backdoored model.

        Returns
        -------
        dict
            "normal_losses" (80% train split) and "val_losses" (20% val split)
            are the vanilla normal-prompt values centered on the vanilla
            harmful-prompt mean; "harmful_losses" are the backdoored
            normal-prompt values centered on the backdoored harmful mean.

        Raises
        ------
        ValueError
            If other_layer_idx is neither 'prev' nor 'next' (previously this
            fell through with layer_idx unbound and raised NameError).
        """
        # Each JSON entry stores values for two layer pairs:
        # index 0 -> (previous, current), index 1 -> (current, next).
        if other_layer_idx == 'prev':
            layer_idx = 0
        elif other_layer_idx == "next":
            layer_idx = 1
        else:
            raise ValueError(
                f"other_layer_idx must be 'prev' or 'next', got {other_layer_idx!r}")

        with open(vanilla_path, "r") as f:
            vanilla_data = json.load(f)

        with open(harmful_path, "r") as f_:
            backdoor_data = json.load(f_)

        mean_harmful_vanilla = np.mean(np.array(vanilla_data["harmful"][layer_idx]))
        mean_harmful_backdoor = np.mean(np.array(backdoor_data["harmful"][layer_idx]))

        # Center the normal-prompt values on the harmful-prompt mean so both
        # distributions share a common origin.
        vanilla_data_stats = [i - float(mean_harmful_vanilla) for i in vanilla_data["normal"][layer_idx]]

        # 80/20 train/validation split of the centered vanilla stats.
        split_idx = int(len(vanilla_data_stats) * 0.8)
        vanilla_data_stats_train = vanilla_data_stats[:split_idx]
        vanilla_data_stats_val = vanilla_data_stats[split_idx:]
        backdoor_data_stats = [i - float(mean_harmful_backdoor) for i in backdoor_data["normal"][layer_idx]]

        return {
            "normal_losses": vanilla_data_stats_train,
            "val_losses": vanilla_data_stats_val,
            "harmful_losses": backdoor_data_stats
        }

    @staticmethod
    def plot_all_layers_violin(model_name, model_type, save_path, config: "SafetyNetConfig", max_layers=32):
        """Create a split violin plot of loss distributions for all available layers.

        For each layer, the negative side of the violin shows the normal
        (train) losses; the positive side overlays harmful and validation
        losses.

        Parameters
        ----------
        model_name : str
            Directory key under utils/data (e.g. 'llama2').
        model_type : str
            Selects the ``{model_type}_loss`` JSON files per layer.
        save_path : str
            Prefix for the emitted PDF.
        config : SafetyNetConfig
            Used only for the plot title (``config.model_name``).
        max_layers : int
            Upper bound (exclusive) on layer indices probed.

        Returns
        -------
        plotly.graph_objects.Figure or None
            None when no per-layer data file exists.
        """
        fig = go.Figure()
        layers_data = {}

        for layer_idx in range(max_layers):
            data_path = f"utils/data/{model_name}/{model_type}_loss/layer_{layer_idx}_{model_type}_loss.json"
            # Skip layers that were never exported; previously a missing file
            # raised FileNotFoundError and the empty-data guard below was dead.
            if not os.path.exists(data_path):
                continue
            with open(data_path, "r") as f:
                layers_data[layer_idx] = json.load(f)

        if not layers_data:
            print("No layer data found!")
            return None

        for i, (layer_idx, data) in tqdm(enumerate(layers_data.items()), total=len(layers_data)):
            x_pos = f'L{layer_idx}'

            # Normal (train) losses on the left half of the violin.
            fig.add_trace(go.Violin(
                y=data["normal_losses"],
                x=[x_pos] * len(data["normal_losses"]),
                name='Normal (Train)',
                side='negative',
                fillcolor="#4DB6AC",
                line_color='#00695C',
                box_visible=True,
                meanline_visible=True,
                points=False,
                width=0.7,
                legendgroup='normal_train',
                showlegend=(i == 0)
            ))

            # Harmful losses on the right half.
            fig.add_trace(go.Violin(
                y=data["harmful_losses"],
                x=[x_pos] * len(data["harmful_losses"]),
                name='Harmful',
                side='positive',
                fillcolor="#BA68C8",
                line_color='#6A1B9A',
                box_visible=True,
                meanline_visible=True,
                points=False,
                width=0.5,
                legendgroup='harmful',
                showlegend=(i == 0)
            ))

            # Validation losses, also on the right half but narrower so both
            # positive-side violins remain visible.
            fig.add_trace(go.Violin(
                y=data["val_losses"],
                x=[x_pos] * len(data["val_losses"]),
                name='Normal (Val)',
                side='positive',
                fillcolor="#3498DB",
                line_color='#2874A6',
                box_visible=True,
                meanline_visible=True,
                points=False,
                width=0.3,
                legendgroup='normal_val',
                showlegend=(i == 0)
            ))

        fig.update_layout(
            title=dict(
                text=f'{config.model_name} Loss Distribution Across All Layers ({model_type.upper()})',
                x=0.5,
                y=0.98,
                xanchor='center',
                yanchor='top',
                font=dict(
                    family="Times New Roman",
                    size=30,
                    color="black"
                )
            ),
            xaxis_title='Layer Index',
            yaxis_title='Reconstruction Loss',
            width=max(800, len(layers_data) * 60),
            height=500,
            showlegend=True,
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=0.95,
                xanchor="center",
                x=0.5,
                font=dict(size=25, family="Times New Roman")
            ),
            plot_bgcolor='#FFFEF7',
            paper_bgcolor='white',
            font=dict(family="Times New Roman", size=20),
            margin=dict(
                t=70,
                b=20,
                l=20,
                r=0
            )
        )
        fig.update_xaxes(
            showgrid=True,
            gridcolor='rgba(128, 128, 128, 0.2)',
            showline=False,
            # Tilt tick labels once they would start to collide.
            tickangle=45 if len(layers_data) > 10 else 0,
            tickfont=dict(
                family="Times New Roman",
                size=25,
                color="black"
            ),
            title_font=dict(
                family="Times New Roman",
                size=22,
                color="black"
            )
        )

        fig.update_yaxes(
            showgrid=True,
            gridcolor='rgba(128, 128, 128, 0.2)',
            showline=False,
            range=[0, None],
            tickfont=dict(
                family="Times New Roman",
                size=25,
                color="black"
            ),
            title_font=dict(
                family="Times New Roman",
                size=32,
                color="black"
            )
        )

        fig.write_image(f"{save_path}_all_layers_violin.pdf", height=1000, width=1500, scale=3)
        return fig

    @staticmethod
    def plot_detectors_comparison(model_name,
                                  detector_types,
                                  other_layer_idx,
                                  current_layer_idx,
                                  save_path,
                                  config: "SafetyNetConfig",
                                  model_type,
                                  args
                                  ):
        """Compare detector types (AE, PCA, Mahalanobis, Beatrix, CROW) at one
        layer using min-max-normalized score distributions.

        For every detector this loads normal/harmful/validation scores, adds a
        split violin to the figure, and computes comprehensive detection
        metrics via ``Detector_Stats``. The figure and a JSON of metrics are
        written next to ``save_path``.

        Parameters
        ----------
        model_name : str
            Directory key for the per-detector data files.
        detector_types : list[str]
            Detector identifiers; names ending in "crow" use the CROW cosine
            data instead of per-layer loss files.
        other_layer_idx : str
            'prev' or 'next'; chooses the CROW layer pair and the output names.
        current_layer_idx : int
            Layer whose detectors are compared.
        save_path : str
            Prefix for the emitted PDF and metrics JSON.
        config : SafetyNetConfig
            Used for the plot title.
        model_type : str
            Model variant (e.g. 'vanilla' / backdoor type) used in paths/titles.
        args : argparse.Namespace
            Must provide ``dataset`` ('mad', 'spylab' or 'anthropic') and
            ``model_type``.

        Returns
        -------
        plotly.graph_objects.Figure

        Raises
        ------
        ValueError
            For an unknown dataset or other_layer_idx (previously these left
            ``data_path``/``accuracy_path`` unbound and raised NameError/
            UnboundLocalError far from the cause).
        """
        if other_layer_idx not in ("prev", "next"):
            raise ValueError(
                f"other_layer_idx must be 'prev' or 'next', got {other_layer_idx!r}")
        if args.dataset not in ("mad", "spylab", "anthropic"):
            raise ValueError(
                f"dataset must be 'mad', 'spylab' or 'anthropic', got {args.dataset!r}")

        fig = go.Figure()
        results = {}

        for i, detector_type in enumerate(detector_types):

            if detector_type in ("crow", "obfuscated_sim_crow", "obfuscated_ae_crow"):
                # CROW scores come from the cosine-analysis JSONs rather than
                # per-layer loss files.
                if args.dataset == "mad":
                    data = Visualization.data_processing_for_crow(
                        other_layer_idx=other_layer_idx,
                        harmful_path=f"utils/data/llama2/ae_vae/{model_type}/cosine_analysis.json"
                    )
                elif args.dataset == "spylab":
                    data = Visualization.data_processing_for_crow(
                        other_layer_idx=other_layer_idx,
                        vanilla_path=f"utils/spylab_data/llama2/vanilla/cosine_analysis.json",
                        harmful_path=f"utils/spylab_data/llama2/{model_type}/cosine_analysis.json"
                    )
                else:  # anthropic
                    data = Visualization.data_processing_for_crow(
                        other_layer_idx=other_layer_idx,
                        vanilla_path=f"safetynet/utils/anthropic_data/{model_name}/vanilla/cosine_analysis.json",
                        harmful_path=f"safetynet/utils/anthropic_data/{model_name}/{model_type}/cosine_analysis.json"
                    )
            else:
                if args.dataset == "mad":
                    data_path = f"utils/data/{model_name}/{detector_type}_loss/layer_{current_layer_idx}_{detector_type}_loss.json"
                elif args.dataset == "spylab":
                    data_path = f"utils/spylab_data/{model_name}/{args.model_type}_{detector_type}_loss/layer_{current_layer_idx}_{args.model_type}_{detector_type}_loss.json"
                else:  # anthropic
                    data_path = f"safetynet/utils/anthropic_data/{model_name}/{args.model_type}_{detector_type}_loss/layer_{current_layer_idx}_{args.model_type}_{detector_type}_loss.json"
                print(data_path)
                with open(data_path, "r") as f:
                    data = json.load(f)

            # Both branches produce the same schema, so extract once here.
            normal_losses = np.array(data["normal_losses"])
            harmful_losses = np.array(data["harmful_losses"])
            val_losses = np.array(data["val_losses"])

            # Detection metrics (accuracy/F1/AUROC/confusion matrix, etc.)
            # are delegated to Detector_Stats.
            stats = Detector_Stats()
            detector_results = stats.compute_comprehensive_metrics(normal_losses, val_losses, harmful_losses)
            detector_results['confusion_matrix_overall'] = detector_results['confusion_matrix_overall'].tolist()
            results[detector_type] = detector_results

            # Min-max normalize all scores of this detector to [0, 1] so
            # detectors with different scales share one y-axis.
            all_losses = np.concatenate([normal_losses, harmful_losses, val_losses])
            min_loss = np.min(all_losses)
            max_loss = np.max(all_losses)
            loss_range = max_loss - min_loss
            if loss_range == 0:
                loss_range = 1  # degenerate constant distribution; avoid div-by-zero

            normal_norm = (normal_losses - min_loss) / loss_range
            harmful_norm = (harmful_losses - min_loss) / loss_range
            val_norm = (val_losses - min_loss) / loss_range

            # Strip any "obfuscated_*_" prefix for the x-axis label.
            detector = detector_type.split("_")[-1]

            if detector == "crow":
                # CROW compares a layer pair, so label both layer indices.
                if other_layer_idx == "prev":
                    x_pos = f"CROW {current_layer_idx-1}-{current_layer_idx}"
                else:
                    x_pos = f"CROW {current_layer_idx}-{current_layer_idx+1}"
            else:
                x_pos = detector.upper()

            loss_data = [
                ('Normal (Train)', normal_norm),
                ('Harmful', harmful_norm),
                ('Normal (Val)', val_norm)
            ]

            for j, (loss_type, losses) in enumerate(loss_data):
                fig.add_trace(go.Violin(
                    y=losses,
                    x=[x_pos] * len(losses),
                    name=loss_type,
                    side='negative' if j == 0 else 'positive',
                    fillcolor='#BA68C8' if j == 1 else ('#3498DB' if j == 2 else '#4DB6AC'),
                    line_color='#6A1B9A' if j == 1 else ('#2874A6' if j == 2 else '#00695C'),
                    box_visible=True,
                    meanline_visible=True,
                    points=False,
                    width=0.7 if j == 0 else (0.5 if j == 1 else 0.3),
                    legendgroup=loss_type.lower().replace(' ', '_'),
                    showlegend=(i == 0),
                    hovertemplate=f'<b>{loss_type}</b><br>' +
                                  'Normalized: %{y:.3f}<br>' +
                                  f'Original Range: [{min_loss:.3f}, {max_loss:.3f}]<br>' +
                                  '<extra></extra>'
                ))

            print(f"CURRENTLY PROCESSING {detector_type}")

        pprint(results)

        fig.update_layout(
            title=dict(
                text=f'{model_type.upper()} {config.model_name} Detector Comparison at Layer {current_layer_idx}',
                x=0.5, y=0.96, xanchor='center', yanchor='top',
                font=dict(family="Times New Roman", size=12, color="black")
            ),
            xaxis_title='Detector Type',
            yaxis_title='Distribution of Distance (0-1 Scale)',
            width=max(600, len(detector_types) * 120),
            height=500,
            showlegend=True,
            legend=dict(
                orientation="h", yanchor="bottom", y=0.97, xanchor="center", x=0.5,
                font=dict(size=10, family="Times New Roman")
            ),
            plot_bgcolor='#FFFEF7',
            paper_bgcolor='white',
            font=dict(family="Times New Roman", size=10),
            margin=dict(t=50, b=20, l=20, r=0)
        )

        axis_style = dict(
            showgrid=True,
            gridcolor='rgba(128, 128, 128, 0.3)',
            showline=False,
            tickfont=dict(family="Times New Roman", size=10, color="black")
        )

        fig.update_xaxes(**axis_style, title_font=dict(family="Times New Roman", size=12, color="black"))
        fig.update_yaxes(
            **axis_style,
            range=[-0.1, 1.1],  # small margin around the normalized [0, 1] scores
            title_font=dict(family="Times New Roman", size=12, color="black")
        )

        # Output names encode the CROW layer pair used.
        if other_layer_idx == "prev":
            layer_pair = f"{current_layer_idx-1}_{current_layer_idx}"
        else:
            layer_pair = f"{current_layer_idx}_{current_layer_idx+1}"

        fig.write_image(f"{save_path}_{model_type}_detectors_comparison_layer_{layer_pair}.pdf",
                        height=300, width=500, scale=3)
        accuracy_path = f"{save_path}_{model_type}_accuracy_layer_{layer_pair}.json"

        def numpy_to_python(obj):
            """Recursively convert numpy scalars/arrays to JSON-serializable types."""
            if isinstance(obj, np.integer):
                return int(obj)
            elif isinstance(obj, np.floating):
                return float(obj)
            elif isinstance(obj, np.ndarray):
                return obj.tolist()
            elif isinstance(obj, dict):
                return {key: numpy_to_python(val) for key, val in obj.items()}
            elif isinstance(obj, list):
                return [numpy_to_python(item) for item in obj]
            return obj

        results = numpy_to_python(results)

        # Merge with an existing metrics file so repeated runs accumulate.
        if os.path.exists(accuracy_path):
            with open(accuracy_path, 'r') as f:
                existing_results = json.load(f)
            existing_results.update(results)
            results = existing_results

        with open(accuracy_path, 'w') as f:
            json.dump(results, f, indent=2)

        return fig
|
|
| |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Multi-layer Attention Analysis')
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--model_type', type=str, required=True)
    # 'choices' rejects bad values up front; previously an unknown value left
    # config/other paths unbound and failed later with NameError.
    parser.add_argument("--other_layer_idx", type=str, required=True, choices=["prev", "next"],
                        help="crow should be taken for previous and current layer or next and current layers? give 'prev' or 'next' as argument ")
    parser.add_argument("--dataset", required=True, choices=["mad", "spylab", "anthropic"],
                        help="mad, spylab, or anthropic")
    args = parser.parse_args()

    # Each dataset has its own config factory.
    if args.dataset == "mad":
        config = SafetyNetConfig(args.model_name)
    elif args.dataset == "spylab":
        config = spylab_create_config(args.model_name)
    else:  # anthropic (guaranteed by argparse choices)
        config = anthropic_create_config(args.model_name)

    # Hand-picked layer per model used for the detector comparison.
    layer_by_model = {
        'qwen': 21,
        'mistral': 12,
        'llama3': 13,
        'llama2': 15,
        'gemma': 18,
    }
    try:
        current_layer_idx = layer_by_model[args.model_name]
    except KeyError:
        # Previously an unknown model left current_layer_idx unbound.
        raise ValueError(f"Unsupported model_name: {args.model_name!r}") from None

    save_path = f"{config.output_dir}/{args.model_name}"

    Visualization.plot_detectors_comparison(
        args.model_name,
        ['ae', 'pca', 'mahalanobis', 'beatrix', 'crow'],
        current_layer_idx=current_layer_idx,
        other_layer_idx=args.other_layer_idx,
        save_path=save_path,
        config=config,
        model_type=args.model_type,
        args=args
    )
|
|