|
"""Benchmark LoRA variants (LoRA, DiffLoRA, AdaLoRA, VB-LoRA, OLoRA) and full
fine-tuning on the GLUE benchmark with BERT-style encoders. All console output
is mirrored to experiment_log_GLUE.txt."""

import io
import sys
import time
from collections import defaultdict

import torch
from huggingface_hub import login

# Authenticate with the Hugging Face Hub; replace the placeholder with your
# own access token (or read it from an environment variable) before running.
login("Your_API_Key")
|
|
|
|
|
class Logger(io.TextIOBase):
    """Tee everything written to stdout into a log file as well."""

    def __init__(self, filename="experiment_log_GLUE.txt", stream=sys.stdout):
        self.terminal = stream
        self.log = open(filename, "w", encoding="utf8")

    def write(self, message):
        # Write to both destinations and flush the file immediately so the
        # log survives a crash; return the character count per the io contract.
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()
        return len(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    @property
    def encoding(self):
        return self.log.encoding


# From here on, every print() goes to both the terminal and the log file.
sys.stdout = Logger("experiment_log_GLUE.txt")
|
|
|
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)
from datasets import load_dataset, DownloadConfig
import evaluate

from diff_lora.model import replace_linear_with_diff_lora
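# `replace_linear_with_diff_lora(model, pattern, r)` is assumed, per this
# repo's diff_lora.model, to walk the module tree in place and swap each
# nn.Linear whose qualified name matches the regex `pattern` for a rank-`r`
# DiffLoRA-augmented linear layer.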
|
|
|
|
|
|
|
|
|
|
|
|
|
# GLUE task -> input text column(s); tuples denote sentence-pair tasks.
text_column_mapping = {
    "mnli": ("premise", "hypothesis"),
    "sst2": "sentence",
    "cola": "sentence",
    "qqp": ("question1", "question2"),
    "qnli": ("question", "sentence"),
    "rte": ("sentence1", "sentence2"),
    "mrpc": ("sentence1", "sentence2"),
    "stsb": ("sentence1", "sentence2"),
}

# GLUE task -> number of labels; STS-B is a regression task (single output).
num_labels_mapping = {
    "mnli": 3,
    "sst2": 2,
    "cola": 2,
    "qqp": 2,
    "qnli": 2,
    "rte": 2,
    "mrpc": 2,
    "stsb": 1,
}
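# Sanity check (added for safety): both tables must cover the same tasks,
# otherwise run_glue_experiment would fail later with a KeyError.
assert set(text_column_mapping) == set(num_labels_mapping), "task tables out of sync"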
|
|
|
|
|
|
|
|
|
|
|
def run_glue_experiment(method: str, model_name: str, task: str,
                        num_train_epochs: int = 3, batch_size: int = 32,
                        lr: float = 2e-5, seed: int = 42, diff_r_ratio: float = 1.0):
    """Fine-tune `model_name` on one GLUE task with the given adaptation method
    and return a dict of results (metric, training time, trainable params)."""
    print("\n==============================")
    print(f"Task: {task} | Model: {model_name} | Method: {method}")
    print("==============================\n")
    torch.manual_seed(seed)

    # Retry downloads to survive transient network failures.
    download_config = DownloadConfig(max_retries=10)
    dataset = load_dataset("glue", task, download_config=download_config)
|
    # MNLI ships two validation splits; default to the matched one here.
    eval_split = "validation_matched" if task == "mnli" else "validation"

    metric = evaluate.load("glue", task)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    text_cols = text_column_mapping[task]
|
|
|
|
|
    def preprocess_function(examples):
        if isinstance(text_cols, tuple):
            # Encode sentence pairs as real pairs so the tokenizer inserts
            # [SEP] and sets token_type_ids, rather than naively joining the
            # two texts with a space.
            return tokenizer(examples[text_cols[0]], examples[text_cols[1]], truncation=True)
        return tokenizer(examples[text_cols], truncation=True)

    encoded_dataset = dataset.map(preprocess_function, batched=True)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
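    # DataCollatorWithPadding pads each batch to its own longest sequence at
    # collation time, which is cheaper than padding every example to a global
    # maximum length during preprocessing.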
|
|
|
|
|
    is_regression = (task == "stsb")
    num_labels = num_labels_mapping[task]

    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
|
|
|
|
|
    if method == "full_finetuning":
        print("Performing full fine-tuning: all parameters are trainable.")
    else:
        # Freeze the backbone; only the injected adapter weights will train.
        for param in model.parameters():
            param.requires_grad = False

    baseline_r = 8
    adapter_r = max(1, int(baseline_r * diff_r_ratio))
|
|
|
|
|
    if method == "lora":
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=baseline_r,
            lora_alpha=16,
            target_modules=["query", "value", "dense"],
            lora_dropout=0.1,
            bias="none",
            task_type="SEQ_CLS",
        )
        model = get_peft_model(model, lora_config)
        print("Injected standard LoRA adapters via PEFT.")
    elif method == "diff_lora":
        target_pattern = r"(query|value|dense)"
        replace_linear_with_diff_lora(model, target_pattern, adapter_r)
        print(f"Injected fused DiffLoRA adapters with rank {adapter_r} (ratio={diff_r_ratio}).")
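        # Unlike the PEFT branches, DiffLoRA injection mutates the model in
        # place rather than returning a wrapped PeftModel.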
|
    elif method == "adalora":
        from peft import AdaLoraConfig, get_peft_model
        # Note: AdaLoRA adapts ranks during training via its init_r/target_r
        # schedule; recent PEFT releases may also expect total_step to be set.
        adalora_config = AdaLoraConfig(
            r=baseline_r,
            lora_alpha=16,
            target_modules=["query", "value", "dense"],
            lora_dropout=0.1,
            bias="none",
            task_type="SEQ_CLS",
        )
        model = get_peft_model(model, adalora_config)
        print("Injected AdaLoRA adapters via PEFT.")
|
    elif method == "vb_lora":
        from peft import VBLoRAConfig, get_peft_model
        vb_lora_config = VBLoRAConfig(
            r=baseline_r,
            task_type="SEQ_CLS",
            target_modules=["query", "value", "dense"],
            num_vectors=256,
            vector_length=256,
            topk=2,
            vblora_dropout=0.1,
            bias="none",
        )
        model = get_peft_model(model, vb_lora_config)
        print("Injected VB-LoRA adapters via PEFT.")
|
    elif method == "olora":
        from peft import LoraConfig, get_peft_model
        olora_config = LoraConfig(
            r=baseline_r,
            lora_alpha=16,
            target_modules=["query", "value", "dense"],
            lora_dropout=0.1,
            bias="none",
            task_type="SEQ_CLS",
            init_lora_weights="olora",
        )
        model = get_peft_model(model, olora_config)
        print("Injected OLoRA adapters via PEFT.")
    elif method == "full_finetuning":
        print("Proceeding with full fine-tuning (no adapter injection).")
    else:
        raise ValueError("Unknown method. Choose from 'lora', 'diff_lora', 'adalora', 'vb_lora', 'olora', or 'full_finetuning'.")
|
|
|
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Trainable params: {trainable_params} / {total_params} ({100 * trainable_params / total_params:.2f}%)")
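    # For bert-base-uncased (~110M parameters), the adapter methods typically
    # train only a small fraction of the weights (on the order of 1%), versus
    # 100% for full fine-tuning.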
|
|
|
|
|
    training_args = TrainingArguments(
        output_dir=f"./outputs/results_{model_name}_{task}_{method}",
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=lr,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_train_epochs,
        weight_decay=0.01,
        logging_steps=10000,
        load_best_model_at_end=True,
        report_to="none",
        disable_tqdm=True,
    )
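    # With load_best_model_at_end=True and no metric_for_best_model given, the
    # Trainer restores the checkpoint with the lowest eval_loss at the end.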
|
|
|
|
|
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        # Regression (STS-B) predicts a scalar; classification takes argmax.
        predictions = logits.squeeze() if is_regression else logits.argmax(axis=-1)
        # The GLUE metric already returns the right keys per task, including
        # both accuracy and F1 for QQP and MRPC.
        result = metric.compute(predictions=predictions, references=labels)
        if task == "stsb":
            result["combined_score"] = (result["pearson"] + result["spearmanr"]) / 2
        return result
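    # The Trainer prefixes every key from compute_metrics with "eval_" (unless
    # already prefixed), hence the "eval_..." lookups further below.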
|
|
|
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=encoded_dataset["train"],
        eval_dataset=encoded_dataset[eval_split],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    print("Starting training...")
    start_time = time.time()
    trainer.train()
    training_time = time.time() - start_time
    print(f"Training completed in {training_time:.2f} seconds.")
|
|
|
|
|
    # Final evaluation; composite tasks report two numbers joined with "/".
    if task == "mnli":
        eval_result_matched = trainer.evaluate(eval_dataset=encoded_dataset["validation_matched"])
        eval_result_mismatched = trainer.evaluate(eval_dataset=encoded_dataset["validation_mismatched"])
        acc_matched = eval_result_matched.get("eval_accuracy", 0.0)
        acc_mismatched = eval_result_mismatched.get("eval_accuracy", 0.0)
        final_metric_str = f"{acc_matched:.4f}/{acc_mismatched:.4f}"
        final_metric_num = (acc_matched + acc_mismatched) / 2
    elif task == "qqp":
        eval_result = trainer.evaluate()
        acc = eval_result.get("eval_accuracy", 0.0)
        f1 = eval_result.get("eval_f1", 0.0)
        final_metric_str = f"{acc:.4f}/{f1:.4f}"
        final_metric_num = (acc + f1) / 2
    elif task == "stsb":
        val = trainer.evaluate().get("eval_combined_score", 0.0)
        final_metric_str = f"{val:.4f}"
        final_metric_num = val
    elif task == "cola":
        val = trainer.evaluate().get("eval_matthews_correlation", 0.0)
        final_metric_str = f"{val:.4f}"
        final_metric_num = val
    else:
        val = trainer.evaluate().get("eval_accuracy", 0.0)
        final_metric_str = f"{val:.4f}"
        final_metric_num = val

    print(f"\n=== FINAL RESULTS for {task} | {model_name} | {method} ===")
    print(f"Metric: {final_metric_str}")
    print(f"Training Time: {training_time:.2f} seconds\n")
|
|
|
    return {
        "task": task,
        "model_name": model_name,
        "method": method,
        "metric_str": final_metric_str,
        "metric_num": final_metric_num,
        "training_time": training_time,
        "trainable_params": trainable_params,
    }
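# Illustrative one-off usage (e.g. for a quick smoke test):
#   run_glue_experiment(method="lora", model_name="bert-base-uncased",
#                       task="mrpc", num_train_epochs=1)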
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Sweep every (model, method, task) combination with a fixed recipe.
    tasks = ["mnli", "sst2", "cola", "qqp", "qnli", "rte", "mrpc", "stsb"]
    methods = ["lora", "diff_lora", "adalora", "vb_lora", "olora", "full_finetuning"]
    model_names = ["bert-base-uncased"]

    all_results = []
    for model_name in model_names:
        for method in methods:
            for task in tasks:
                result = run_glue_experiment(
                    method=method,
                    model_name=model_name,
                    task=task,
                    num_train_epochs=3,
                    batch_size=32,
                    lr=2e-5,
                    seed=42,
                    diff_r_ratio=1.0,
                )
                all_results.append(result)
|
|
|
|
|
    # Pivot results into {model | method: {task: metric}} for the summary
    # table (defaultdict is imported at the top of the file).
    summary = defaultdict(dict)
    for res in all_results:
        key = f"{res['model_name']} | {res['method']}"
        summary[key][res["task"]] = res["metric_str"]
|
|
|
|
|
    # Per-task metric labels for the summary header.
    indicator_names = {
        "mnli": "m/mm",
        "sst2": "Acc",
        "cola": "MCC",
        "qqp": "Acc/F1",
        "qnli": "Acc",
        "rte": "Acc",
        "mrpc": "Acc",
        "stsb": "Corr",
    }
|
|
|
    print("\n===== Summary of GLUE Results =====")
    header = "Model | Method || " + " | ".join(
        [f"{task} ({indicator_names[task]})" for task in tasks]
    ) + " || Average"
    print(header)
    print("-" * len(header))
    for key, metrics in summary.items():
        avg_list = []
        display_values = []
        for task in tasks:
            val = metrics.get(task, "N/A")
            display_values.append(val)
            try:
                # Composite entries like "acc/f1" are averaged into one number.
                parts = [float(p) for p in val.split("/")]
                avg_list.append(sum(parts) / len(parts))
            except ValueError:
                pass  # skip non-numeric entries such as "N/A"
        overall_avg = sum(avg_list) / len(avg_list) if avg_list else 0.0
        print(f"{key} || " + " | ".join(display_values) + f" || {overall_avg:.4f}")
|
|