Simo76 committed on
Commit d72fbc5 · 1 Parent(s): d8f43b7

Update stable_task_test.py

Files changed (1)
  1. experiments/stable_task_test.py +168 -114
experiments/stable_task_test.py CHANGED
@@ -1,26 +1,36 @@
 """
-Unified-LoRA – Stable Task Parity Test
-========================================
+Orbital LoRA – Stable Task Parity Test
+
 
 MRPC only, 120 steps, 3 seeds.
 Validates that the controller causes zero degradation on stable training.
 
+
 Usage:
-    pip install transformers datasets evaluate
-    python stable_task_test.py
+    pip install transformers datasets evaluate
+    python stable_task_test.py
 """
 
+
 import time, random, math, numpy as np, torch, torch.nn as nn
 import torch.nn.functional as F, evaluate
 from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 from torch.utils.data import DataLoader
 
+
 import sys, os
-sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from controller import NestedLoRALinear, OrbitalController, inject_nested_lora, set_rank
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+
+from nested_lora import NestedLoRALinear, inject_nested_lora
+from orbital_controller import OrbitalController
+from controller import set_rank
+
+
+# ── CONFIG ──────────────────────────────────────────
+
 
-# ── CONFIG ──────────────────────────────────────────
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 MODEL = "distilbert-base-uncased"
 BATCH = 8
@@ -28,18 +38,24 @@ STEPS = 120
 LR = 5e-5
 SEEDS = [0, 1, 2]
 
+
 MAX_RANK = 16
 WARMUP = 15
 STABLE_WINDOW = 8
 
-# ── DATA ────────────────────────────────────────────
+
+# ── DATA ────────────────────────────────────────────
+
+
 print("Loading data...")
 tok = AutoTokenizer.from_pretrained(MODEL)
 ds = load_dataset("glue", "mrpc")
 
+
 def tok_fn(x):
-    return tok(x["sentence1"], x["sentence2"],
-               truncation=True, padding="max_length", max_length=128)
+    return tok(x["sentence1"], x["sentence2"],
+               truncation=True, padding="max_length", max_length=128)
+
 
 ds = ds.map(tok_fn, batched=True)
 ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
@@ -47,126 +63,164 @@ train_loader = DataLoader(ds["train"], batch_size=BATCH, shuffle=True)
 val_loader = DataLoader(ds["validation"], batch_size=BATCH)
 metric = evaluate.load("glue", "mrpc")
 
-# ── HELPERS ─────────────────────────────────────────
+
+# ── HELPERS ─────────────────────────────────────────
+
+
 def build_model():
-    base = AutoModelForSequenceClassification.from_pretrained(
-        MODEL, num_labels=2, ignore_mismatched_sizes=True
-    )
-    return inject_nested_lora(base, MAX_RANK).to(DEVICE)
+    base = AutoModelForSequenceClassification.from_pretrained(
+        MODEL, num_labels=2, ignore_mismatched_sizes=True
+    )
+    return inject_nested_lora(base, MAX_RANK).to(DEVICE)
+
 
 def eval_model(model):
-    model.eval()
-    preds, labels = [], []
-    with torch.no_grad():
-        for batch in val_loader:
-            x = batch["input_ids"].to(DEVICE)
-            m = batch["attention_mask"].to(DEVICE)
-            y = batch["label"].to(DEVICE)
-            logits = model(input_ids=x, attention_mask=m).logits
-            preds.extend(logits.argmax(dim=-1).cpu().numpy())
-            labels.extend(y.cpu().numpy())
-    return metric.compute(predictions=preds, references=labels)["f1"]
+    model.eval()
+    preds, labels = [], []
+    with torch.no_grad():
+        for batch in val_loader:
+            x = batch["input_ids"].to(DEVICE)
+            m = batch["attention_mask"].to(DEVICE)
+            y = batch["label"].to(DEVICE)
+            logits = model(input_ids=x, attention_mask=m).logits
+            preds.extend(logits.argmax(dim=-1).cpu().numpy())
+            labels.extend(y.cpu().numpy())
+    return metric.compute(predictions=preds, references=labels)["f1"]
+
 
 def eff_rank(usage):
-    tot = sum(usage.values())
-    return sum(k * v for k, v in usage.items()) / tot if tot > 0 else 0
+    tot = sum(usage.values())
+    return sum(k * v for k, v in usage.items()) / tot if tot > 0 else 0
+
+
+# ── TRAIN BASELINE ──────────────────────────────────
+
 
-# ── TRAIN BASELINE ──────────────────────────────────
 def train_baseline(model):
-    opt = torch.optim.AdamW(model.parameters(), lr=LR)
-    set_rank(model, 16)
-    it = iter(train_loader)
-
-    for step in range(STEPS):
-        try:
-            batch = next(it)
-        except StopIteration:
-            it = iter(train_loader); batch = next(it)
-
-        x = batch["input_ids"].to(DEVICE)
-        m = batch["attention_mask"].to(DEVICE)
-        y = batch["label"].to(DEVICE)
-
-        loss = model(input_ids=x, attention_mask=m, labels=y).loss
-        loss.backward()
-        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-        opt.step()
-        opt.zero_grad()
-
-    return model
-
-# ── TRAIN UNIFIED ───────────────────────────────────
-def train_unified(model):
-    ctrl = OrbitalController(warmup=WARMUP, stable_window=STABLE_WINDOW)
-    opt = torch.optim.AdamW(model.parameters(), lr=LR)
-    usage = {4: 0, 8: 0, 16: 0}
-    rank_trace = []
-    it = iter(train_loader)
-
-    for step in range(STEPS):
-        try:
-            batch = next(it)
-        except StopIteration:
-            it = iter(train_loader); batch = next(it)
-
-        x = batch["input_ids"].to(DEVICE)
-        m = batch["attention_mask"].to(DEVICE)
-        y = batch["label"].to(DEVICE)
-
-        loss = model(input_ids=x, attention_mask=m, labels=y).loss
-        new_rank = ctrl.step(loss.item())
-        set_rank(model, new_rank)
-
-        usage[new_rank] += 1
-        rank_trace.append(new_rank)
-
-        loss.backward()
-        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-        opt.step()
-        opt.zero_grad()
-
-    return model, usage, rank_trace, ctrl
-
-# ── RUN ─────────────────────────────────────────────
+    opt = torch.optim.AdamW(model.parameters(), lr=LR)
+    set_rank(model, 16)
+    it = iter(train_loader)
+
+
+    for step in range(STEPS):
+        try:
+            batch = next(it)
+        except StopIteration:
+            it = iter(train_loader); batch = next(it)
+
+        x = batch["input_ids"].to(DEVICE)
+        m = batch["attention_mask"].to(DEVICE)
+        y = batch["label"].to(DEVICE)
+
+        loss = model(input_ids=x, attention_mask=m, labels=y).loss
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+        opt.step()
+        opt.zero_grad()
+
+    return model
+
+
+
+# ── TRAIN ORBITAL ───────────────────────────────────
+
+
+def train_orbital(model):
+    ctrl = OrbitalController(warmup=WARMUP, stable_window=STABLE_WINDOW)
+    opt = torch.optim.AdamW(model.parameters(), lr=LR)
+    usage = {4: 0, 8: 0, 16: 0}
+    rank_trace = []
+    it = iter(train_loader)
+
+
+    for step in range(STEPS):
+        try:
+            batch = next(it)
+        except StopIteration:
+            it = iter(train_loader); batch = next(it)
+
+        x = batch["input_ids"].to(DEVICE)
+        m = batch["attention_mask"].to(DEVICE)
+        y = batch["label"].to(DEVICE)
+
+        loss = model(input_ids=x, attention_mask=m, labels=y).loss
+        loss.backward()
+
+        new_rank = ctrl.step(loss.item())
+        new_rank = max(4, min(16, new_rank))
+        set_rank(model, new_rank)
+
+        usage[new_rank] += 1
+        rank_trace.append(new_rank)
+
+        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+        opt.step()
+        opt.zero_grad()
+
+    return model, usage, rank_trace, ctrl
+
+
+
+# ── RUN ─────────────────────────────────────────────
+
+
 print(f"\nDevice: {DEVICE}")
 print(f"Task: MRPC, {STEPS} steps")
 print("=" * 55)
 
+
 results = []
 
+
 for seed in SEEDS:
-    print(f"\n{'─' * 50}\n SEED {seed}\n{'─' * 50}")
-
-    torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
-    base_model = build_model()
-    base_model = train_baseline(base_model)
-    f1_base = eval_model(base_model)
-    del base_model; torch.cuda.empty_cache()
-
-    torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
-    uni_model = build_model()
-    uni_model, usage, trace, ctrl = train_unified(uni_model)
-    f1_uni = eval_model(uni_model)
-
-    er = eff_rank(usage)
-    saving = 1 - er / 16
-    transitions = sum(1 for i in range(1, len(trace)) if trace[i] != trace[i-1])
-
-    print(f"\n BASELINE F1 = {f1_base:.3f} (rank=16 fixed)")
-    print(f" UNIFIED F1 = {f1_uni:.3f} (eff_rank={er:.1f}, saving={saving*100:.0f}%)")
-    print(f" delta F1 = {f1_uni - f1_base:+.3f}")
-    print(f" Usage: r4={usage[4]} r8={usage[8]} r16={usage[16]} transitions={transitions}")
-
-    results.append({
-        'seed': seed, 'f1_base': f1_base, 'f1_uni': f1_uni,
-        'delta': f1_uni - f1_base, 'eff_rank': er,
-    })
-    del uni_model; torch.cuda.empty_cache()
-
-# ── SUMMARY ─────────────────────────────────────────
+    print(f"\n{'─' * 50}\n SEED {seed}\n{'─' * 50}")
+
+
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+    base_model = build_model()
+    base_model = train_baseline(base_model)
+    f1_base = eval_model(base_model)
+    del base_model; torch.cuda.empty_cache()
+
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+    uni_model = build_model()
+    uni_model, usage, trace, ctrl = train_orbital(uni_model)
+    f1_uni = eval_model(uni_model)
+
+    er = eff_rank(usage)
+    saving = 1 - er / 16
+    transitions = sum(1 for i in range(1, len(trace)) if trace[i] != trace[i-1])
+
+    print(f"\n BASELINE F1 = {f1_base:.3f} (rank=16 fixed)")
+    print(f" ORBITAL F1 = {f1_uni:.3f} (eff_rank={er:.1f}, saving={saving*100:.0f}%)")
+    print(f" delta F1 = {f1_uni - f1_base:+.3f}")
+    print(f" Usage: r4={usage[4]} r8={usage[8]} r16={usage[16]} transitions={transitions}")
+
+    results.append({
+        'seed': seed, 'f1_base': f1_base, 'f1_uni': f1_uni,
+        'delta': f1_uni - f1_base, 'eff_rank': er,
+    })
+    del uni_model; torch.cuda.empty_cache()
+
+
+
+# ── SUMMARY ─────────────────────────────────────────
+
+
 print(f"\n{'=' * 55}\n SUMMARY\n{'=' * 55}")
 f1b = [r['f1_base'] for r in results]
 f1u = [r['f1_uni'] for r in results]
+
+
 print(f"\n Baseline F1: {np.mean(f1b):.3f} +/- {np.std(f1b):.3f}")
-print(f" Unified F1: {np.mean(f1u):.3f} +/- {np.std(f1u):.3f}")
+print(f" Orbital F1: {np.mean(f1u):.3f} +/- {np.std(f1u):.3f}")
 print(f" delta F1: {np.mean([r['delta'] for r in results]):+.3f}")
+
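For context on the interfaces this test drives but which are not part of this diff: inject_nested_lora(model, max_rank) wraps the base model with nested LoRA layers, OrbitalController(warmup=..., stable_window=...) turns the stream of training losses into a rank choice via ctrl.step(loss), and set_rank(model, r) applies that choice around the optimizer step. The sketch below is only an illustrative stand-in under those assumptions; the real implementations live in nested_lora.py, orbital_controller.py, and controller.py, and every detail here that the diff does not show (the active_rank attribute, the 0.02 and 0.10 stability thresholds) is hypothetical.

# Hypothetical stand-in for the controller API exercised above.
# NOT the repository's implementation; it only mirrors the call
# signatures stable_task_test.py relies on.
from collections import deque

import numpy as np


class OrbitalController:
    """Maps recent loss behaviour to a LoRA rank in {4, 8, 16} (sketch)."""

    def __init__(self, warmup=15, stable_window=8, ranks=(4, 8, 16)):
        self.warmup = warmup                    # steps kept at full rank before adapting
        self.stable_window = stable_window      # how many recent losses to inspect
        self.ranks = sorted(ranks)
        self.losses = deque(maxlen=stable_window)
        self.t = 0

    def step(self, loss):
        """Record one training loss and return the rank to use for this step."""
        self.t += 1
        self.losses.append(float(loss))
        if self.t <= self.warmup or len(self.losses) < self.stable_window:
            return self.ranks[-1]               # stay at MAX_RANK during warmup
        rel_std = np.std(self.losses) / (np.mean(self.losses) + 1e-8)
        if rel_std < 0.02:                      # loss plateaued: cheapest rank
            return self.ranks[0]
        if rel_std < 0.10:                      # mildly moving: middle rank
            return self.ranks[1]
        return self.ranks[-1]                   # unstable: full rank


def set_rank(model, r):
    """Point every injected nested-LoRA module at active rank r (sketch)."""
    for module in model.modules():
        if hasattr(module, "active_rank"):      # attribute name is an assumption
            module.active_rank = r

On a stable task like MRPC, the expectation the test encodes is that a controller of this shape spends most steps below rank 16 (hence the eff_rank and saving = 1 - eff_rank/16 printouts) while the F1 delta against the fixed rank-16 baseline stays near zero.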