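"""Gradio app for generating and visualizing Fast Fisher Grafting (FFG) masks.

Loads a base / fine-tuned model pair together with exported Adam second-moment
(exp_avg_sq) states, builds a parameter mask with FFG, magnitude, or Fish-Mask
grafting, and renders summary plots and statistics in an interactive UI.
Companion demo for "Harnessing Optimization Dynamics for Curvature-Informed
Model Merging" (https://arxiv.org/abs/2509.11167).
"""
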
import gradio as gr
import torch
import os
import sys
import json
import gc
from pathlib import Path
import matplotlib
matplotlib.use('Agg')  # headless backend; figures are rendered into PNG buffers
import matplotlib.pyplot as plt
from typing import Dict, List, Optional, Tuple
import numpy as np
from PIL import Image
import io

# Make the experiment suite importable when this app sits next to it.
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'ffg_experiment_suite'))

from src.models import load_assets
from src.grafting import fast_fisher_graft, magnitude_graft, fish_mask_graft
from src.analysis import (_create_sparsity_distribution_plot,
                          _set_publication_fonts,
                          load_masks_from_run)

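# Pre-configured model pairs. The "optimizer_states" entry uses a
# "<hf_repo_id>:<path_in_repo>" spec pointing at the exported Adam exp_avg_sq
# tensors that load_assets() consumes alongside the base and fine-tuned weights.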
AVAILABLE_MODELS = [
    {
        "name": "Math Reasoning",
        "base": "meta-llama/Meta-Llama-3.1-8B",
        "finetuned": "pmahdavi/Llama-3.1-8B-math-reasoning",
        "optimizer_states": "pmahdavi/Llama-3.1-8B-math-reasoning:export/exp_avg_sq.safetensors"
    },
    {
        "name": "Coding",
        "base": "meta-llama/Meta-Llama-3.1-8B",
        "finetuned": "pmahdavi/Llama-3.1-8B-coding-tulu3-ebs128-lr5e6-wsdcr0p4",
        "optimizer_states": "pmahdavi/Llama-3.1-8B-coding-tulu3-ebs128-lr5e6-wsdcr0p4:export_full_state_checkpoint-1100/exp_avg_sq.safetensors"
    },
    {
        "name": "Instruction Following",
        "base": "meta-llama/Meta-Llama-3.1-8B",
        "finetuned": "pmahdavi/Llama-3.1-8B-precise-if",
        "optimizer_states": "pmahdavi/Llama-3.1-8B-precise-if:export/exp_avg_sq.safetensors"
    },
    {
        "name": "General",
        "base": "meta-llama/Meta-Llama-3.1-8B",
        "finetuned": "pmahdavi/Llama-3.1-8B-general",
        "optimizer_states": "pmahdavi/Llama-3.1-8B-general:export/exp_avg_sq.safetensors"
    },
    {
        "name": "Knowledge Recall",
        "base": "meta-llama/Meta-Llama-3.1-8B",
        "finetuned": "pmahdavi/Llama-3.1-8B-knowledge-recall",
        "optimizer_states": "pmahdavi/Llama-3.1-8B-knowledge-recall:export/exp_avg_sq.safetensors"
    }
]


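# Wraps model loading, mask generation, and plotting so the Gradio callbacks stay thin.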
class FFGMaskExplorer:
    def __init__(self):
        self.current_masks = None
        self.current_stats = None

    def generate_masks(self, model_selection: str, sparsity_ratio: float,
                       grafting_method: str, device_type: str, progress=gr.Progress()):
        """Generate masks for a single model configuration."""
        # Resolve the selected preset from AVAILABLE_MODELS.
        model_config = None
        for model in AVAILABLE_MODELS:
            if model["name"] == model_selection:
                model_config = model
                break

        if not model_config:
            return None, None, "Model not found!"

        progress(0.1, desc="Loading models...")

        try:
            config = {
                "base_model_id": model_config["base"],
                "finetuned_model_id": model_config["finetuned"],
                "optimizer_states_file": model_config["optimizer_states"],
                "device": device_type.lower(),
                "dtype": "bfloat16"
            }

            pretrained_model, finetuned_model, optimizer_v_state, tokenizer = load_assets(config)

            progress(0.5, desc="Generating masks...")

            # FFG and Fish-Mask score parameters using the optimizer second moments;
            # magnitude grafting needs only the base and fine-tuned weights.
            if grafting_method == "Fast Fisher (FFG)":
                grafted_model, stats_dict, masks_dict = fast_fisher_graft(
                    pretrained_model, finetuned_model, optimizer_v_state, sparsity_ratio
                )
            elif grafting_method == "Magnitude":
                grafted_model, stats_dict, masks_dict = magnitude_graft(
                    pretrained_model, finetuned_model, sparsity_ratio
                )
            elif grafting_method == "Fish-Mask":
                grafted_model, stats_dict, masks_dict = fish_mask_graft(
                    pretrained_model, finetuned_model, optimizer_v_state, sparsity_ratio
                )
            else:
                raise ValueError(f"Unknown grafting method: {grafting_method}")

            self.current_masks = masks_dict
            self.current_stats = stats_dict

            progress(0.8, desc="Creating visualizations...")

            viz_images = self.create_basic_visualizations(masks_dict, stats_dict)

            # Release model and optimizer memory before handing results back to the UI.
            del pretrained_model, finetuned_model, grafted_model
            if optimizer_v_state is not None:
                del optimizer_v_state
            gc.collect()
            torch.cuda.empty_cache()

            progress(1.0, desc="Complete!")

            stats_text = self.format_statistics(stats_dict)

            return viz_images, stats_text, "Success!"

        except Exception as e:
            gc.collect()
            torch.cuda.empty_cache()
            return None, None, f"Error: {str(e)}"

    def create_basic_visualizations(self, masks_dict: Dict, stats_dict: Dict) -> List[Image.Image]:
        """Create summary plots (parameter counts, layer sparsity, sample masks) from the masks."""
        images = []

        _set_publication_fonts(scale_factor=1.0)

        # Plot 1: total vs. kept vs. pruned parameter counts.
        fig, ax = plt.subplots(figsize=(10, 6))
        stats_data = {
            'Total Parameters': stats_dict['total_params'],
            'Kept Parameters': stats_dict['kept_params'],
            'Pruned Parameters': stats_dict['total_params'] - stats_dict['kept_params']
        }

        ax.bar(stats_data.keys(), stats_data.values(), color=['blue', 'green', 'red'])
        ax.set_ylabel('Number of Parameters')
        ax.set_title(f'Grafting Statistics (Sparsity: {stats_dict["final_sparsity"]:.2%})')

        for i, (key, value) in enumerate(stats_data.items()):
            ax.text(i, value, f'{value:,}', ha='center', va='bottom')

        plt.tight_layout()
        images.append(self.fig_to_image(fig))
        plt.close(fig)

        # Plot 2: per-layer sparsity (truncated to the first 50 layers for readability).
        fig, ax = plt.subplots(figsize=(14, 8))

        layer_sparsities = []
        layer_names = []

        for name, mask in masks_dict.items():
            if mask is not None and mask.numel() > 0:
                sparsity = 1.0 - mask.float().mean().item()
                layer_sparsities.append(sparsity)

                short_name = name.replace('model.layers.', 'L').replace('.weight', '')
                if len(short_name) > 20:
                    short_name = short_name[:17] + '...'
                layer_names.append(short_name)

        if len(layer_names) > 50:
            layer_names = layer_names[:50]
            layer_sparsities = layer_sparsities[:50]

        ax.barh(range(len(layer_names)), layer_sparsities, color='skyblue')
        ax.set_yticks(range(len(layer_names)))
        ax.set_yticklabels(layer_names, fontsize=8)
        ax.set_xlabel('Sparsity Ratio')
        ax.set_title('Layer-wise Sparsity Distribution')
        ax.set_xlim(0, 1)

        plt.tight_layout()
        images.append(self.fig_to_image(fig))
        plt.close(fig)

        # Plot 3: heatmaps of up to four sample masks (1 = kept, 0 = pruned).
        num_samples = min(4, len(masks_dict))
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        axes = axes.flatten()

        for idx, (name, mask) in enumerate(list(masks_dict.items())[:num_samples]):
            if mask is None or mask.numel() == 0:
                continue

            ax = axes[idx]

            mask_np = mask.cpu().float().numpy()

            if mask.ndim == 2:
                im = ax.imshow(mask_np, cmap='RdBu_r', aspect='auto', vmin=0, vmax=1)
            else:
                if mask.ndim == 1:
                    # Reshape 1-D masks into a roughly square grid, zero-padding the tail.
                    size = int(np.sqrt(mask.numel()))
                    if size * size == mask.numel():
                        mask_np = mask_np.reshape(size, size)
                    else:
                        target_size = size + 1
                        padded = np.zeros(target_size * target_size)
                        padded[:mask.numel()] = mask_np.flatten()
                        mask_np = padded.reshape(target_size, target_size)
                    im = ax.imshow(mask_np, cmap='RdBu_r', aspect='auto', vmin=0, vmax=1)
                else:
                    # Flatten higher-dimensional masks to 2-D and crop for display.
                    mask_np = mask_np.reshape(mask_np.shape[0], -1)[:min(mask_np.shape[0], 512), :min(mask_np.shape[1], 512)]
                    im = ax.imshow(mask_np, cmap='RdBu_r', aspect='auto', vmin=0, vmax=1)

            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

            short_name = name.replace('model.layers.', 'L').replace('.weight', '')
            if len(short_name) > 30:
                short_name = short_name[:27] + '...'
            ax.set_title(short_name, fontsize=10)
            ax.set_xlabel('Dimension 1')
            ax.set_ylabel('Dimension 0')

        for idx in range(num_samples, len(axes)):
            axes[idx].axis('off')

        plt.suptitle('Sample Mask Visualizations (1=kept, 0=pruned)', fontsize=14)
        plt.tight_layout()
        images.append(self.fig_to_image(fig))
        plt.close(fig)

        return images

    def fig_to_image(self, fig) -> Image.Image:
        """Convert a matplotlib figure to a PIL Image."""
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
        buf.seek(0)
        image = Image.open(buf)
        image.load()  # read the PNG now so the image no longer depends on the buffer
        return image

    def format_statistics(self, stats_dict: Dict) -> str:
        """Format statistics dictionary as readable Markdown."""
        lines = [
            "### Grafting Statistics",
            f"- **Total Parameters**: {stats_dict['total_params']:,}",
            f"- **Kept Parameters**: {stats_dict['kept_params']:,}",
            f"- **Pruned Parameters**: {stats_dict['total_params'] - stats_dict['kept_params']:,}",
            f"- **Final Sparsity**: {stats_dict['final_sparsity']:.4f}",
            f"- **Threshold**: {stats_dict.get('threshold', 'N/A')}"
        ]
        return "\n".join(lines)


explorer = FFGMaskExplorer()

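# Build the Gradio interface: model/sparsity/method controls on the left,
# the visualization gallery and statistics tabs on the right.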
with gr.Blocks(title="FFG Mask Explorer", theme=gr.themes.Base()) as app:
    gr.Markdown("""
    # 🔬 FFG Mask Explorer

    Interactive tool for generating and visualizing Fast Fisher Grafting (FFG) masks on fine-tuned language models.
    Based on the paper: [Harnessing Optimization Dynamics for Curvature-Informed Model Merging](https://arxiv.org/abs/2509.11167)

    ### How to use:
    1. Select a pre-configured model or enter custom model IDs
    2. Choose the sparsity ratio (fraction of parameters to KEEP)
    3. Select a grafting method
    4. Click Generate to create masks and visualizations
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Model Configuration")

            model_dropdown = gr.Dropdown(
                choices=[m["name"] for m in AVAILABLE_MODELS],
                value=AVAILABLE_MODELS[0]["name"],
                label="Select Pre-configured Model",
                interactive=True
            )

            with gr.Accordion("Custom Model (Advanced)", open=False):
                # NOTE: these fields are placeholders; on_generate below only uses
                # the pre-configured dropdown selection.
                custom_base = gr.Textbox(
                    label="Base Model ID",
                    placeholder="e.g., meta-llama/Meta-Llama-3.1-8B"
                )
                custom_finetuned = gr.Textbox(
                    label="Fine-tuned Model ID",
                    placeholder="e.g., username/model-name"
                )
                custom_optimizer = gr.Textbox(
                    label="Optimizer States Path",
                    placeholder="e.g., username/model:export/exp_avg_sq.safetensors"
                )

            sparsity_slider = gr.Slider(
                minimum=0.01,
                maximum=0.9,
                value=0.4,
                step=0.01,
                label="Sparsity Ratio (fraction to KEEP)",
                info="0.4 means keeping 40% of parameters"
            )

            method_radio = gr.Radio(
                choices=["Fast Fisher (FFG)", "Magnitude", "Fish-Mask"],
                value="Fast Fisher (FFG)",
                label="Grafting Method",
                info="FFG uses optimizer second moments for importance"
            )

            device_radio = gr.Radio(
                choices=["CUDA", "CPU"],
                value="CUDA",
                label="Device",
                info="CUDA recommended for faster processing"
            )

            generate_btn = gr.Button("Generate Masks", variant="primary", size="lg")

        with gr.Column(scale=2):
            gr.Markdown("### Results")

            status_text = gr.Textbox(label="Status", interactive=False, value="Ready")

            with gr.Tabs():
                with gr.TabItem("Visualizations"):
                    gallery = gr.Gallery(
                        label="Generated Visualizations",
                        show_label=False,
                        elem_id="gallery",
                        columns=2,
                        rows=2,
                        object_fit="contain",
                        height="auto"
                    )

                with gr.TabItem("Statistics"):
                    stats_markdown = gr.Markdown("*Generate masks to see statistics*")

            with gr.Row():
                download_masks_btn = gr.Button("💾 Download Masks", size="sm", interactive=False)
                download_viz_btn = gr.Button("Download Visualizations", size="sm", interactive=False)

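    # Button callback: run mask generation and, on success, enable the download buttons.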
    def on_generate(model_selection, sparsity, method, device, progress=gr.Progress()):
        images, stats, status = explorer.generate_masks(
            model_selection, sparsity, method, device, progress
        )

        if images:
            return (
                images,
                stats,
                status,
                gr.Button(interactive=True),
                gr.Button(interactive=True)
            )
        else:
            return (
                None,
                "*Generation failed*",
                status,
                gr.Button(interactive=False),
                gr.Button(interactive=False)
            )

    generate_btn.click(
        fn=on_generate,
        inputs=[model_dropdown, sparsity_slider, method_radio, device_radio],
        outputs=[gallery, stats_markdown, status_text, download_masks_btn, download_viz_btn]
    )

    gr.Markdown("""
    ---
    ### About FFG
    Fast Fisher Grafting (FFG) uses the second moments from the Adam optimizer to identify important parameters
    in fine-tuned models. This allows for more informed pruning than magnitude-based methods.

    ### Citation
    ```bibtex
    @misc{mahdavinia2025harnessingoptimizationdynamicscurvatureinformed,
        title={Harnessing Optimization Dynamics for Curvature-Informed Model Merging},
        author={Pouria Mahdavinia and Hamed Mahdavi and Niloofar Mireshghallah and Mehrdad Mahdavi},
        year={2025},
        eprint={2509.11167},
        archivePrefix={arXiv},
        primaryClass={cs.LG},
        url={https://arxiv.org/abs/2509.11167},
    }
    ```
    """)

if __name__ == "__main__":
    app.launch()