# NOTE: relative order kept as in the original -- the wildcard import from
# `utils` must not be able to shadow the explicit imports that follow it.
from src.configs.model_configs import AnalysisConfig
from utils import *
import plotly.graph_objects as go
import numpy as np
from tqdm import tqdm
import json
import math
|
|
|
|
# Model identifiers to plot; each one names a sub-directory of utils/data/
# holding a jsd_stats.json file. (A module-level `global` statement is a
# no-op, so none is needed here.)
MODELS = ["llama3", "llama2", "qwen", "mistral", "gemma"]
|
|
|
|
# Number of layers at each end of the network that are treated as "edge"
# layers: their JSD is damped before plotting and they are excluded from the
# range used to scale bar opacity. Used in both places so the two stay in sync.
NUM_EDGE_LAYERS = 5
# Multiplier applied to the (finite) JSD of edge layers.
# NOTE(review): the plotted edge-layer bars are therefore not raw JSD values;
# consider stating this in the figure caption.
EDGE_LAYER_SCALE = 0.02
# Base bar colour for finite values; only the alpha channel varies per bar.
BASE_COLOR = "#00695C"
# Colour assigned to layers whose JSD is missing or not finite.
NON_FINITE_COLOR = "rgba(255, 0, 0, 0.8)"


def _is_edge_layer(index, total_layers, num_edge_layers=NUM_EDGE_LAYERS):
    """Return True if *index* lies within *num_edge_layers* of either end."""
    return index < num_edge_layers or index >= total_layers - num_edge_layers


def _load_layer_values(path, num_edge_layers=NUM_EDGE_LAYERS,
                       edge_scale=EDGE_LAYER_SCALE):
    """Load per-layer JSD values from a JSON file, ordered by layer index.

    The file is read as a dict keyed by the strings "0" .. "N-1" (a missing
    key raises KeyError). Finite values in the first/last *num_edge_layers*
    layers are multiplied by *edge_scale*. Non-finite values (inf, -inf, NaN)
    and JSON nulls become None so they are excluded from the colour range.

    Returns:
        list of float | None, one entry per layer.
    """
    with open(path, "r", encoding="utf-8") as f:
        raw = json.load(f)

    total_layers = len(raw)
    values = []
    for i in range(total_layers):
        val = raw[str(i)]
        if val is None or not math.isfinite(val):
            values.append(None)
        elif _is_edge_layer(i, total_layers, num_edge_layers):
            values.append(val * edge_scale)
        else:
            values.append(val)
    return values


def _bar_colors(values, num_edge_layers=NUM_EDGE_LAYERS, base_color=BASE_COLOR):
    """Return one RGBA colour string per value.

    Opacity is 0.2 + 0.8 * (min-max normalised value), clamped to [0, 1]. The
    min/max come from finite middle-layer values only. Edge layers can
    therefore fall outside that range and be clamped. None values get
    NON_FINITE_COLOR.
    """
    total_layers = len(values)
    reference = [
        val for i, val in enumerate(values)
        if val is not None and not _is_edge_layer(i, total_layers, num_edge_layers)
    ]
    min_val = min(reference) if reference else 0
    max_val = max(reference) if reference else 1

    # Parse the base colour once instead of once per bar.
    hex_color = base_color.lstrip("#")
    r, g, b = (int(hex_color[k:k + 2], 16) for k in (0, 2, 4))

    colors = []
    for val in values:
        if val is None:
            colors.append(NON_FINITE_COLOR)
            continue
        normalized = (val - min_val) / (max_val - min_val) if max_val != min_val else 0.5
        intensity = max(0.0, min(1.0, 0.2 + (0.8 * normalized)))
        colors.append(f"rgba({r}, {g}, {b}, {intensity})")
    return colors


def _build_figure(values, colors, model_name):
    """Build the per-layer JSD bar chart titled with *model_name*.

    NOTE(review): a None y-value draws no bar, so NON_FINITE_COLOR never
    becomes visible. If flagging inf layers matters, plot a placeholder height
    or add an annotation for them.
    """
    fig = go.Figure(data=[
        go.Bar(
            x=list(range(len(values))),
            y=values,
            marker_color=colors,
            marker_line_color="rgba(0, 105, 92, 0.2)",
            marker_line_width=0.5,
        )
    ])
    fig.update_layout(
        title=dict(
            text=f"{model_name.capitalize()} Jensen-Shannon Divergence",
            x=0.5,
            font=dict(size=28, color="#2E4057"),
        ),
        xaxis=dict(
            title="Layer Index",
            title_font=dict(size=22, color="#2E4057"),
            tickfont=dict(size=18),
            type="category",
        ),
        yaxis=dict(
            title="JS Divergence",
            title_font=dict(size=22, color="#2E4057"),
            tickfont=dict(size=18),
        ),
        plot_bgcolor="#FFFEF7",
        paper_bgcolor="white",
        font=dict(family="Arial, sans-serif"),
        showlegend=False,
        margin=dict(t=80, b=60, l=80, r=40),
        height=600,
        width=1000,
        bargap=0,
    )
    return fig


for model in tqdm(MODELS):
    model_dir = f"utils/data/{model}"
    values = _load_layer_values(f"{model_dir}/jsd_stats.json")
    config = AnalysisConfig(model)
    colors = _bar_colors(values)
    fig = _build_figure(values, colors, config.model_name)
    # The explicit width/height here set the size of the exported file,
    # overriding the 1000x600 layout size used for on-screen display.
    fig.write_image(f"{model_dir}/{model}_jsd_stats.pdf", width=1200, height=400, scale=2)