"""Find language-specific neurons from per-language activation statistics.

For every language in ``LANGS`` the script loads an activation-statistics file,
ranks neurons by the entropy of their (language-normalised) activation
probability, keeps low-entropy neurons whose activation probability is high
enough for a given language, and saves a bar chart with the number of such
neurons per layer and language.
"""

import os

import numpy as np
import torch

# Quantile of all activation probabilities a neuron must exceed (for at least
# one language) to be considered at all; others get infinite entropy.
FILTER_RATE = 0.95
# Fraction of all neurons, ranked by ascending entropy, kept as candidates.
TOP_RATE = 0.5
# Quantile of all activation probabilities a candidate must exceed for a
# language to be counted as specific to that language.
ACTIVATION_BAR_RATIO = 0.95

LANGS = ["en", "eu"]
BASE_PATH = "new_activations"
# Run tag embedded in every activation file name.
RUN_TAG = "l2-7b-eu"


def _kth_value(values, rate):
    """Return the ``int(numel * rate)``-th smallest entry of 1-D ``values``.

    ``torch.kthvalue`` is 1-based, so the index is clamped to ``[1, numel]``;
    this avoids a crash when ``numel * rate`` rounds down to 0.
    """
    numel = values.numel()
    k = min(max(int(numel * rate), 1), numel)
    return values.kthvalue(k).values


def load_activations(langs, base_path, run_tag=RUN_TAG):
    """Load and stack the activation statistics of every language.

    Each file ``{base_path}/activation.{lang}.train.{run_tag}.pt`` must contain
    a dict with a scalar ``"n"`` (presumably the number of tokens seen -- confirm
    against the producer) and a 2-D ``"over_zero"`` tensor of per-neuron counts
    laid out as ``(num_layers, intermediate_size)``.

    Returns:
        n: float32 tensor of shape ``(len(langs),)``.
        over_zero: tensor of shape ``(num_layers, intermediate_size, len(langs))``.
        paths: the file paths that were read, in ``langs`` order.

    Raises:
        ValueError: if ``langs`` is empty.
    """
    if not langs:
        raise ValueError("at least one language is required")

    counts, over_zero, paths = [], [], []
    for lang in langs:
        path = os.path.join(base_path, f"activation.{lang}.train.{run_tag}.pt")
        # map_location: files saved from a GPU still load on CPU-only machines.
        # NOTE: torch.load unpickles its input -- only load trusted files.
        data = torch.load(path, map_location="cpu")
        counts.append(float(data["n"]))
        over_zero.append(data["over_zero"])
        paths.append(path)
    return torch.tensor(counts), torch.stack(over_zero, dim=-1), paths


def parse_run_info(path):
    """Return ``(model_name, checkpoint)`` labels derived from an activation path.

    ``model_name`` is the name of the directory holding the file; ``checkpoint``
    is the last dot-separated field of the file name once the extension is
    removed, e.g. ``new_activations/activation.en.train.l2-7b-eu.pt`` gives
    ``("new_activations", "l2-7b-eu")``.
    """
    model_name = os.path.basename(os.path.dirname(path))
    stem = os.path.splitext(os.path.basename(path))[0]
    return model_name, stem.split(".")[-1]


def select_language_specific_neurons(
    over_zero,
    n,
    langs,
    *,
    filter_rate=FILTER_RATE,
    top_rate=TOP_RATE,
    activation_bar_ratio=ACTIVATION_BAR_RATIO,
):
    """Select low-entropy, strongly activated neurons for each language.

    Args:
        over_zero: activation counts shaped ``(num_layers, intermediate_size, lang_num)``.
        n: per-language normaliser broadcastable against ``over_zero``
            (shape ``(lang_num,)``).
        langs: language names, one per entry of the last ``over_zero`` axis.
        filter_rate / top_rate / activation_bar_ratio: see the module constants.

    Returns:
        Dict mapping each language to a list with one LongTensor per layer,
        holding that layer's selected neuron indices in ascending-entropy order.

    Raises:
        ValueError: if ``len(langs)`` does not match the language axis.
    """
    num_layers, intermediate_size, lang_num = over_zero.size()
    if len(langs) != lang_num:
        raise ValueError(f"expected {lang_num} language names, got {len(langs)}")

    activation_probs = over_zero / n

    # Entropy of each neuron's activation distribution across languages; the
    # where() gives 0 * log(0) the conventional value 0. Neurons that never fire
    # produce NaN here but are always caught by the validity filter below.
    normed = activation_probs / activation_probs.sum(dim=-1, keepdim=True)
    log_prop = torch.where(normed > 0, normed.log(), torch.zeros_like(normed))
    entropy = -(normed * log_prop).sum(dim=-1)

    # Exclude neurons that are rarely active in every language.
    flat_probs = activation_probs.flatten()
    thresh = _kth_value(flat_probs, filter_rate)
    valid_mask = (activation_probs > thresh).any(dim=-1)
    entropy[~valid_mask] = float("inf")

    # Lowest-entropy candidates; drop filtered-out (infinite-entropy) neurons so
    # they can never be selected, whatever activation_bar_ratio is.
    flat_entropy = entropy.flatten()
    topk = int(flat_entropy.numel() * top_rate)
    _, idx = flat_entropy.topk(topk, largest=False)
    idx = idx[flat_entropy[idx].isfinite()]

    layer_idx = idx // intermediate_size
    neuron_idx = idx % intermediate_size

    bar = _kth_value(flat_probs, activation_bar_ratio)
    lang_mask = (activation_probs[layer_idx, neuron_idx] > bar).T  # (lang_num, k)

    # Boolean indexing keeps the ascending-entropy order within each layer.
    return {
        lang: [neuron_idx[lang_mask[i] & (layer_idx == layer)] for layer in range(num_layers)]
        for i, lang in enumerate(langs)
    }


def plot_neuron_counts(final_mask, langs, num_layers, model_name, checkpoint, width=0.35):
    """Save a grouped bar chart of language-specific neuron counts per layer.

    Returns the path of the PNG written to the current directory.
    """
    # Imported lazily so the analysis helpers work without a plotting stack.
    import matplotlib.pyplot as plt

    x = np.arange(num_layers)
    plt.figure(figsize=(12, 6))
    try:
        bar_groups = []
        for offset, lang in enumerate(langs):
            counts = [len(layer) for layer in final_mask[lang]]
            bar_groups.append(plt.bar(x + offset * width, counts, width=width, label=lang))

        # Annotate every bar of every language, not only the last group drawn.
        for group in bar_groups:
            for rect in group:
                height = rect.get_height()
                plt.text(rect.get_x() + rect.get_width() / 2.0, height, f'{int(height)}',
                         ha='center', va='bottom', fontsize=9)

        plt.xlabel("Layer Index")
        plt.ylabel("Number of Neurons")
        plt.title(f"Number of Language-Specific Neurons per Layer\nModel: {model_name}, Checkpoint: {checkpoint}")
        # Centre each tick under its group of len(langs) bars.
        plt.xticks(x + width * (len(langs) - 1) / 2, x)
        plt.legend()
        plt.grid(alpha=0.3, axis='y')
        plt.tight_layout()

        out_path = f"{model_name}_{checkpoint}_neurons_bar.png"
        plt.savefig(out_path, dpi=300)
    finally:
        plt.close()
    return out_path


def main():
    """Run the selection for ``LANGS`` and write the per-layer bar chart."""
    # Show whole tensors (not summaries) if any are printed while debugging.
    torch.set_printoptions(profile="full")

    n, over_zero, paths = load_activations(LANGS, BASE_PATH)
    model_name, checkpoint = parse_run_info(paths[0])
    final_mask = select_language_specific_neurons(over_zero, n, LANGS)
    plot_neuron_counts(final_mask, LANGS, over_zero.size(0), model_name, checkpoint)


if __name__ == "__main__":
    main()