# Diff_LoRA / examples / run_glue_experiment.py
# Uploaded by nozomuteruyo14
# Commit: "Create examples/run_glue_experiment.py" (5b16bbc, verified)
import os
import re
import math
import time
import sys
import io
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import login
# Authenticate with the Hugging Face Hub only when a token is supplied via the
# HF_TOKEN environment variable, so no credential is stored in source. Public
# datasets and checkpoints (e.g. GLUE, bert-base-uncased) should not need it.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
# Modified Logger class to write output both to terminal and file
class Logger(io.TextIOBase):
    """Text stream that mirrors every write to a terminal stream and a log file.

    Intended to replace ``sys.stdout`` so that printed experiment output is also
    persisted. The log file is opened in ``"w"`` mode (truncated) with UTF-8
    encoding and is flushed after every write.

    Args:
        filename: Path of the log file.
        stream: Terminal stream to mirror to; defaults to the ``sys.stdout``
            that was active when this class was defined.
    """

    def __init__(self, filename="experiment_log_GLUE.txt", stream=sys.stdout):
        self.terminal = stream
        self.log = open(filename, "w", encoding="utf8")

    def writable(self):
        """Report that this stream accepts writes (io.IOBase defaults to False)."""
        return True

    def write(self, message):
        """Write ``message`` to both destinations; return the number of characters written."""
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()  # Flush after each write so the log survives a crash
        return len(message)

    def flush(self):
        """Flush the terminal and, while it is still open, the log file."""
        self.terminal.flush()
        # getattr: close() may run from __del__ after a failed __init__.
        log = getattr(self, "log", None)
        if log is not None and not log.closed:
            log.flush()

    def close(self):
        """Flush both streams and close the log file; the terminal stays open."""
        if self.closed:
            return
        try:
            super().close()  # io.IOBase.close() calls self.flush(), then marks closed
        finally:
            log = getattr(self, "log", None)
            if log is not None:
                log.close()

    @property
    def encoding(self):
        """Encoding of the log file ("utf8")."""
        return self.log.encoding
# From this point on, everything printed is mirrored into
# experiment_log_GLUE.txt in addition to the terminal.
_stdout_tee = Logger(filename="experiment_log_GLUE.txt")
sys.stdout = _stdout_tee
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
TrainingArguments,
Trainer,
DataCollatorWithPadding,
)
from datasets import load_dataset, DownloadConfig
import evaluate
from sklearn.metrics import f1_score
# Import DiffLoRA module from the diff_lora package
from diff_lora.model import replace_linear_with_diff_lora
###############################################
# Mappings for GLUE Tasks
###############################################
# Per-task GLUE specification: (text column(s), number of labels).
# A tuple of column names marks a sentence-pair task; stsb is a regression
# task and therefore has a single output label.
_GLUE_TASK_SPECS = {
    "mnli": (("premise", "hypothesis"), 3),
    "sst2": ("sentence", 2),
    "cola": ("sentence", 2),
    "qqp": (("question1", "question2"), 2),
    "qnli": (("question", "sentence"), 2),
    "rte": (("sentence1", "sentence2"), 2),
    "mrpc": (("sentence1", "sentence2"), 2),
    "stsb": (("sentence1", "sentence2"), 1),
}

# Text column(s) to tokenize for each task.
text_column_mapping = {task: columns for task, (columns, _) in _GLUE_TASK_SPECS.items()}
# Number of model outputs for each task.
num_labels_mapping = {task: n_labels for task, (_, n_labels) in _GLUE_TASK_SPECS.items()}
###############################################
# Experiment Function for a Single GLUE Task
###############################################
# Methods accepted by run_glue_experiment (validated before any download).
_SUPPORTED_METHODS = ("lora", "diff_lora", "adalora", "vb_lora", "olora", "full_finetuning")
_UNKNOWN_METHOD_MSG = ("Unknown method. Choose from 'lora', 'diff_lora', 'adalora', "
                       "'vb_lora', 'olora', or 'full_finetuning'.")
# Linear sub-module names targeted by every PEFT adapter method below.
_ADAPTER_TARGET_MODULES = ("query", "value", "dense")


def _encode_dataset(dataset, tokenizer, text_cols):
    """Tokenize every split of a GLUE dataset.

    ``text_cols`` is a single column name, or a ``(first, second)`` tuple for
    sentence-pair tasks. Pairs are passed to the tokenizer as two sequences so
    that it inserts its separator token and segment ids, instead of receiving
    one space-joined string.
    """

    def preprocess_function(examples):
        if isinstance(text_cols, tuple):
            first, second = text_cols
            return tokenizer(examples[first], examples[second], truncation=True)
        return tokenizer(examples[text_cols], truncation=True)

    return dataset.map(preprocess_function, batched=True)


def _inject_adapters(model, method, diff_r_ratio, total_steps):
    """Prepare ``model`` for ``method`` and return the model to train.

    For ``"full_finetuning"`` the model is returned unchanged. For every other
    method, all existing parameters are frozen first and adapters are injected
    afterwards. PEFT methods return the model produced by ``get_peft_model``;
    DiffLoRA is applied to ``model`` in place.

    Args:
        model: Freshly loaded sequence-classification model.
        method: One of ``_SUPPORTED_METHODS``.
        diff_r_ratio: DiffLoRA rank as a fraction of the baseline rank 8
            (the resulting rank is at least 1).
        total_steps: Total number of optimizer steps of the run (AdaLoRA).

    Raises:
        ValueError: If ``method`` is not recognised.
    """
    if method == "full_finetuning":
        print("Performing full fine-tuning: All parameters are trainable.")
        print("Proceeding with full fine-tuning (no adapter injection).")
        return model

    # Freeze base model parameters; the adapters injected below are the trainable part.
    for param in model.parameters():
        param.requires_grad = False

    baseline_r = 8
    target_modules = list(_ADAPTER_TARGET_MODULES)

    if method == "lora":
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=baseline_r,
            lora_alpha=16,
            target_modules=target_modules,
            lora_dropout=0.1,
            bias="none",
            task_type="SEQ_CLS",
        )
        model = get_peft_model(model, lora_config)
        print("Injected standard LoRA adapters via PEFT.")
    elif method == "diff_lora":
        adapter_r = max(1, int(baseline_r * diff_r_ratio))
        # NOTE(review): the PEFT branches use task_type="SEQ_CLS", which (per PEFT
        # docs) keeps the classification head trainable. Nothing here re-enables the
        # head frozen above — confirm replace_linear_with_diff_lora does, otherwise
        # the head stays at its random initialization.
        replace_linear_with_diff_lora(model, r"(query|value|dense)", adapter_r)
        print(f"Injected fused DiffLoRA adapters with rank {adapter_r} (ratio={diff_r_ratio}).")
    elif method == "adalora":
        from peft import AdaLoraConfig, get_peft_model
        adalora_config = AdaLoraConfig(
            r=baseline_r,
            lora_alpha=16,
            target_modules=target_modules,
            lora_dropout=0.1,
            bias="none",
            task_type="SEQ_CLS",
            # Recent PEFT releases reject AdaLoRA configs without a positive total_step.
            total_step=total_steps,
        )
        # NOTE(review): AdaLoRA only reallocates ranks when
        # model.base_model.update_and_allocate(step) is called during training; the
        # stock Trainer does not appear to call it, so ranks presumably stay fixed.
        model = get_peft_model(model, adalora_config)
        print("Injected AdaLoRA adapters via PEFT.")
    elif method == "vb_lora":
        from peft import VBLoRAConfig, get_peft_model
        vb_lora_config = VBLoRAConfig(
            r=baseline_r,
            task_type="SEQ_CLS",
            target_modules=target_modules,
            num_vectors=256,
            vector_length=256,
            topk=2,
            vblora_dropout=0.1,
            bias="none",
        )
        model = get_peft_model(model, vb_lora_config)
        print("Injected VB-LoRA adapters via PEFT.")
    elif method == "olora":
        from peft import LoraConfig, get_peft_model
        olora_config = LoraConfig(
            r=baseline_r,
            lora_alpha=16,
            target_modules=target_modules,
            lora_dropout=0.1,
            bias="none",
            task_type="SEQ_CLS",
            init_lora_weights="olora",
        )
        model = get_peft_model(model, olora_config)
        print("Injected OLoRA adapters via PEFT.")
    else:
        raise ValueError(_UNKNOWN_METHOD_MSG)
    return model


def _build_training_args(output_dir, lr, batch_size, num_train_epochs, seed):
    """Build the TrainingArguments shared by all methods (evaluate and save every epoch).

    Newer transformers releases renamed ``evaluation_strategy`` to
    ``eval_strategy``, so whichever keyword the installed version accepts is used.
    """
    import inspect

    accepted = inspect.signature(TrainingArguments).parameters
    strategy_kw = "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"
    return TrainingArguments(
        output_dir=output_dir,
        save_strategy="epoch",
        learning_rate=lr,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_train_epochs,
        weight_decay=0.01,
        logging_steps=10000,
        load_best_model_at_end=True,
        report_to="none",
        disable_tqdm=True,
        seed=seed,  # without this the Trainer always re-seeds with its default (42)
        **{strategy_kw: "epoch"},
    )


def _tokenizer_trainer_kwarg():
    """Return the Trainer keyword for the tokenizer ("processing_class" in newer releases)."""
    import inspect

    accepted = inspect.signature(Trainer.__init__).parameters
    return "processing_class" if "processing_class" in accepted else "tokenizer"


def _make_compute_metrics(task, metric):
    """Return the ``compute_metrics`` callback for ``task``.

    Metric keys are returned without the ``eval_`` prefix; the Trainer adds it,
    yielding e.g. ``eval_accuracy`` / ``eval_f1`` / ``eval_combined_score``.
    """
    is_regression = task == "stsb"

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        if is_regression:
            # reshape(-1) (not squeeze()) keeps a 1-D array even for a single example.
            predictions = logits.reshape(-1)
            result = metric.compute(predictions=predictions, references=labels)
            result["combined_score"] = (result["pearson"] + result["spearmanr"]) / 2
            return result
        predictions = logits.argmax(axis=-1)
        if task == "qqp":
            return {
                "accuracy": float((predictions == labels).mean()),
                "f1": float(f1_score(labels, predictions)),
            }
        return metric.compute(predictions=predictions, references=labels)

    return compute_metrics


def _final_glue_metric(trainer, task, encoded_dataset):
    """Evaluate ``trainer`` and return ``(display_string, scalar)`` for ``task``.

    MNLI reports matched/mismatched accuracy and QQP reports accuracy/F1 as
    ``"a/b"``; their scalar is the mean of the two parts. STS-B uses the
    Pearson/Spearman mean, CoLA Matthews correlation, all others accuracy.
    """
    if task == "mnli":
        acc_matched = trainer.evaluate(eval_dataset=encoded_dataset["validation_matched"]).get("eval_accuracy", 0.0)
        acc_mismatched = trainer.evaluate(eval_dataset=encoded_dataset["validation_mismatched"]).get("eval_accuracy", 0.0)
        return f"{acc_matched:.4f}/{acc_mismatched:.4f}", (acc_matched + acc_mismatched) / 2
    if task == "qqp":
        eval_result = trainer.evaluate()
        acc = eval_result.get("eval_accuracy", 0.0)
        f1 = eval_result.get("eval_f1", 0.0)
        return f"{acc:.4f}/{f1:.4f}", (acc + f1) / 2
    metric_key = {"stsb": "eval_combined_score", "cola": "eval_matthews_correlation"}.get(task, "eval_accuracy")
    val = trainer.evaluate().get(metric_key, 0.0)
    return f"{val:.4f}", val


def run_glue_experiment(method: str, model_name: str, task: str,
                        num_train_epochs: int = 3, batch_size: int = 32,
                        lr: float = 2e-5, seed: int = 42, diff_r_ratio: float = 1.0):
    """Fine-tune ``model_name`` on one GLUE ``task`` with ``method`` and report the result.

    Args:
        method: Adaptation method; one of ``_SUPPORTED_METHODS``.
        model_name: Hugging Face model id or local path.
        task: GLUE task name (a key of ``text_column_mapping``).
        num_train_epochs: Number of training epochs.
        batch_size: Per-device train and eval batch size.
        lr: Learning rate.
        seed: Seed applied to torch before model creation and to the Trainer.
        diff_r_ratio: DiffLoRA rank as a fraction of the baseline rank 8.

    Returns:
        dict with ``task``, ``model_name``, ``method``, ``metric_str`` (display
        string; ``"a/b"`` for MNLI m/mm and QQP Acc/F1), ``metric_num`` (scalar;
        mean of the two parts for MNLI/QQP), ``training_time`` in seconds, and
        ``trainable_params``.

    Raises:
        ValueError: If ``method`` is not supported (checked before any download).
    """
    import gc

    print("\n==============================")
    print(f"Task: {task} | Model: {model_name} | Method: {method}")
    print("==============================\n")
    if method not in _SUPPORTED_METHODS:
        raise ValueError(_UNKNOWN_METHOD_MSG)
    torch.manual_seed(seed)

    # Load dataset. MNLI has no plain "validation" split; the matched one drives model selection.
    download_config = DownloadConfig(max_retries=10)
    dataset = load_dataset("glue", task, download_config=download_config)
    eval_split = "validation_matched" if task == "mnli" else "validation"
    metric = evaluate.load("glue", task)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    encoded_dataset = _encode_dataset(dataset, tokenizer, text_column_mapping[task])
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    train_dataset = encoded_dataset["train"]

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels_mapping[task]
    )
    # Optimizer steps assuming a single device and no gradient accumulation.
    steps_per_epoch = math.ceil(len(train_dataset) / batch_size)
    total_steps = max(1, math.ceil(steps_per_epoch * num_train_epochs))
    model = _inject_adapters(model, method, diff_r_ratio, total_steps)

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Trainable params: {trainable_params} / {total_params} ({100 * trainable_params / total_params:.2f}%)")

    # "/" in hub ids (e.g. "org/model") would otherwise create nested directories.
    safe_model_name = model_name.replace("/", "_")
    training_args = _build_training_args(
        output_dir=f"./outputs/results_{safe_model_name}_{task}_{method}",
        lr=lr,
        batch_size=batch_size,
        num_train_epochs=num_train_epochs,
        seed=seed,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=encoded_dataset[eval_split],
        data_collator=data_collator,
        compute_metrics=_make_compute_metrics(task, metric),
        **{_tokenizer_trainer_kwarg(): tokenizer},
    )

    print("Starting training...")
    start_time = time.time()
    trainer.train()
    training_time = time.time() - start_time
    print(f"Training completed in {training_time:.2f} seconds.")

    final_metric_str, final_metric_num = _final_glue_metric(trainer, task, encoded_dataset)
    print(f"\n=== FINAL RESULTS for {task} | {model_name} | {method} ===")
    print(f"Metric: {final_metric_str}")
    print(f"Training Time: {training_time:.2f} seconds\n")

    # Drop this run's model/trainer so sequential experiments in one process
    # do not accumulate (GPU) memory.
    del trainer, model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return {
        "task": task,
        "model_name": model_name,
        "method": method,
        "metric_str": final_metric_str,
        "metric_num": final_metric_num,
        "training_time": training_time,
        "trainable_params": trainable_params,
    }
###############################################
# Main: Run Experiments over GLUE Tasks for Multiple Methods
###############################################
# Column indicator shown next to each task in the summary table.
_GLUE_INDICATOR_NAMES = {
    "mnli": "m/mm",
    "sst2": "Acc",
    "cola": "Mcc",
    "qqp": "Acc/F1",
    "qnli": "Acc",
    "rte": "Acc",
    "mrpc": "Acc",
    "stsb": "Corr",
}


def print_glue_summary(all_results, tasks):
    """Print one row per "model | method" with each task's metric and their mean.

    ``all_results`` holds dicts shaped like ``run_glue_experiment``'s return
    value; a failed run may carry ``metric_str="N/A"`` and ``metric_num=None``.
    Tasks with no result are shown as "N/A". The average is taken over the
    numeric ``metric_num`` values of ``tasks``. Two-part metrics (MNLI m/mm,
    QQP Acc/F1) already contribute the mean of their parts.
    """
    from collections import defaultdict

    display = defaultdict(dict)  # "model | method" -> task -> metric string
    numeric = defaultdict(dict)  # "model | method" -> task -> metric number (or None)
    for res in all_results:
        key = f"{res['model_name']} | {res['method']}"
        display[key][res["task"]] = res["metric_str"]
        numeric[key][res["task"]] = res["metric_num"]

    print("\n===== Summary of GLUE Results =====")
    header = "Model | Method || " + " | ".join([f"{task} ({_GLUE_INDICATOR_NAMES[task]})" for task in tasks]) + " || Average"
    print(header)
    print("-" * len(header))
    for key, metrics in display.items():
        display_values = [metrics.get(task, "N/A") for task in tasks]
        scores = [numeric[key][task] for task in tasks if numeric[key].get(task) is not None]
        overall_avg = sum(scores) / len(scores) if scores else 0.0
        print(f"{key} || " + " | ".join(display_values) + f" || {overall_avg:.4f}")


if __name__ == "__main__":
    import traceback

    # Desired order and corresponding indicators:
    # [mnli (m/mm), sst2 (Acc), cola (Mcc), qqp (Acc/F1), qnli (Acc), rte (Acc), mrpc (Acc), stsb (Corr)]
    tasks = ["mnli", "sst2", "cola", "qqp", "qnli", "rte", "mrpc", "stsb"]
    methods = ["lora", "diff_lora", "adalora", "vb_lora", "olora", "full_finetuning"]
    model_names = ["bert-base-uncased"]
    all_results = []
    for model_name in model_names:
        for method in methods:
            for task in tasks:
                try:
                    result = run_glue_experiment(
                        method=method,
                        model_name=model_name,
                        task=task,
                        num_train_epochs=3,
                        batch_size=32,
                        lr=2e-5,
                        seed=42,
                        diff_r_ratio=1.0,
                    )
                except Exception:
                    # Top-level boundary: log the failure (into the log file too) and
                    # keep going, so one broken run does not discard the rest of the sweep.
                    print(f"Experiment failed for {task} | {model_name} | {method}:")
                    traceback.print_exc(file=sys.stdout)
                    result = {
                        "task": task,
                        "model_name": model_name,
                        "method": method,
                        "metric_str": "N/A",
                        "metric_num": None,
                        "training_time": None,
                        "trainable_params": None,
                    }
                all_results.append(result)

    print_glue_summary(all_results, tasks)