ataeff committed on
Commit 0932565 · verified · 1 Parent(s): cfcec39

Delete train_resonate.py with huggingface_hub

Files changed (1)
  1. train_resonate.py +0 -215
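
For context, the commit message indicates the file was removed through the huggingface_hub client. A minimal sketch of how such a deletion is typically issued (the repo_id below is a placeholder, not taken from this commit):

    from huggingface_hub import HfApi

    # Hypothetical reproduction of this deletion commit; repo_id is a placeholder.
    api = HfApi()
    api.delete_file(
        path_in_repo="train_resonate.py",
        repo_id="your-username/your-repo",  # placeholder, the actual repo is not shown here
        commit_message="Delete train_resonate.py with huggingface_hub",
    )
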
train_resonate.py DELETED
@@ -1,215 +0,0 @@
- """
- Train Gemma-3 270M-IT with LoRA for /resonate/ format.
- Freeze embed_tokens (63% of model = all 140 languages preserved).
- LoRA rank 16 on Q+V projections only — minimal intervention.
- """
-
- import json, os, sys, time, random, math
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
- from peft import LoraConfig, get_peft_model, TaskType
-
- # --- config ---
- MODEL = 'unsloth/gemma-3-270m-it'
- RANK = 16
- ALPHA = 32
- LR = 2e-4
- EPOCHS = 3
- BATCH = 4
- GRAD_ACCUM = 4  # effective batch 16
- MAX_LEN = 1024
- EVAL_EVERY = 100
- SAVE_DIR = 'gemma3-resonate'
-
- # --- load data ---
- print('[data] Loading...')
- data = []
- for path in ['resonance_yent_full.jsonl', 'resonance_gold_10.jsonl']:
-     if os.path.exists(path):
-         with open(path) as f:
-             for line in f:
-                 d = json.loads(line)
-                 data.append(d)
- print(f'[data] {len(data)} examples')
- random.seed(42)
- random.shuffle(data)
-
- split = int(len(data) * 0.95)
- train_data = data[:split]
- val_data = data[split:]
- print(f'[data] train={len(train_data)}, val={len(val_data)}')
-
- # --- load model ---
- print('[model] Loading Gemma-3 270M-IT...')
- tokenizer = AutoTokenizer.from_pretrained(MODEL)
- model = AutoModelForCausalLM.from_pretrained(MODEL, dtype=torch.bfloat16).cuda()
-
- n_total = sum(p.numel() for p in model.parameters())
- n_embed = sum(p.numel() for n, p in model.named_parameters() if 'embed_tokens' in n)
- print(f'[model] {n_total/1e6:.1f}M total, {n_embed/1e6:.1f}M in embed_tokens ({n_embed*100/n_total:.0f}%)')
-
- # --- LoRA config ---
- lora_config = LoraConfig(
-     task_type=TaskType.CAUSAL_LM,
-     r=RANK,
-     lora_alpha=ALPHA,
-     lora_dropout=0.05,
-     target_modules=['q_proj', 'v_proj'],
-     bias='none',
- )
-
- model = get_peft_model(model, lora_config)
- trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
- frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
- print(f'[lora] trainable={trainable/1e6:.2f}M ({trainable*100/n_total:.1f}%), frozen={frozen/1e6:.1f}M')
-
- # --- prepare data ---
- def format_example(d):
-     msgs = d['messages']
-     text = ''
-     for m in msgs:
-         if m['role'] == 'user':
-             text += f"<start_of_turn>user\n{m['content']}<end_of_turn>\n"
-         elif m['role'] == 'assistant':
-             text += f"<start_of_turn>model\n{m['content']}<end_of_turn>\n"
-     return text
-
- def tokenize_with_labels(text):
-     toks = tokenizer(text, truncation=True, max_length=MAX_LEN, return_tensors='pt')
-     input_ids = toks['input_ids'][0]
-     labels = input_ids.clone()
-     # mask user turn — only train on model output
-     model_marker = '<start_of_turn>model\n'
-     idx = text.find(model_marker)
-     if idx > 0:
-         prefix = text[:idx + len(model_marker)]
-         prefix_toks = tokenizer(prefix, add_special_tokens=False)['input_ids']
-         mask_len = min(len(prefix_toks), len(labels))
-         labels[:mask_len] = -100
-     return input_ids, labels
-
- print('[data] Tokenizing...')
- train_tokens = []
- for d in train_data:
-     text = format_example(d)
-     ids, labels = tokenize_with_labels(text)
-     if len(ids) > 10:
-         train_tokens.append((ids, labels))
-
- val_tokens = []
- for d in val_data:
-     text = format_example(d)
-     ids, labels = tokenize_with_labels(text)
-     if len(ids) > 10:
-         val_tokens.append((ids, labels))
-
- print(f'[data] {len(train_tokens)} train, {len(val_tokens)} val tokenized')
- if train_tokens:
-     avg_len = sum(len(t[0]) for t in train_tokens) / len(train_tokens)
-     print(f'[data] avg length: {avg_len:.0f} tokens')
-
- # --- training ---
- optimizer = torch.optim.AdamW(
-     [p for p in model.parameters() if p.requires_grad],
-     lr=LR, weight_decay=0.01
- )
-
- total_steps = len(train_tokens) * EPOCHS // (BATCH * GRAD_ACCUM)
- warmup_steps = int(total_steps * 0.1)
- print(f'[train] {total_steps} steps, {warmup_steps} warmup, {EPOCHS} epochs')
-
- def get_lr(step):
-     if step < warmup_steps:
-         return LR * step / max(warmup_steps, 1)
-     progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
-     return LR * 0.5 * (1 + math.cos(math.pi * progress))
-
- model.train()
- step = 0
- best_val_loss = float('inf')
- os.makedirs(SAVE_DIR, exist_ok=True)
- t0 = time.time()
-
- for epoch in range(EPOCHS):
-     random.shuffle(train_tokens)
-     epoch_loss = 0
-     epoch_count = 0
-     optimizer.zero_grad()
-
-     for i, (ids, labels) in enumerate(train_tokens):
-         ids = ids.unsqueeze(0).cuda()
-         labels = labels.unsqueeze(0).cuda()
-
-         outputs = model(input_ids=ids, labels=labels)
-         loss = outputs.loss / GRAD_ACCUM
-         loss.backward()
-
-         epoch_loss += outputs.loss.item()
-         epoch_count += 1
-
-         if (i + 1) % GRAD_ACCUM == 0:
-             step += 1
-             lr = get_lr(step)
-             for g in optimizer.param_groups:
-                 g['lr'] = lr
-
-             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-             optimizer.step()
-             optimizer.zero_grad()
-
-             if step % 50 == 0:
-                 avg = epoch_loss / epoch_count
-                 elapsed = time.time() - t0
-                 print(f' ep{epoch+1} step {step}/{total_steps} | train loss {avg:.4f} | lr {lr:.6f} | {elapsed:.0f}s', flush=True)
-
-             if step % EVAL_EVERY == 0 and val_tokens:
-                 model.eval()
-                 val_loss = 0
-                 with torch.no_grad():
-                     for vid, vlbl in val_tokens[:50]:
-                         vid = vid.unsqueeze(0).cuda()
-                         vlbl = vlbl.unsqueeze(0).cuda()
-                         out = model(input_ids=vid, labels=vlbl)
-                         val_loss += out.loss.item()
-                 val_loss /= min(50, len(val_tokens))
-                 print(f' >>> VAL loss {val_loss:.4f} (best {best_val_loss:.4f})', flush=True)
-
-                 if val_loss < best_val_loss:
-                     best_val_loss = val_loss
-                     model.save_pretrained(f'{SAVE_DIR}/best')
-                     tokenizer.save_pretrained(f'{SAVE_DIR}/best')
-                     print(f' >>> SAVED best', flush=True)
-
-                 model.train()
-
-     avg = epoch_loss / max(epoch_count, 1)
-     print(f'[epoch {epoch+1}] avg loss {avg:.4f}', flush=True)
-
- model.save_pretrained(f'{SAVE_DIR}/final')
- tokenizer.save_pretrained(f'{SAVE_DIR}/final')
- print(f'[done] best val loss: {best_val_loss:.4f}')
-
- # --- test generation ---
- print('\n[gen] Testing on 5 languages...')
- model.eval()
-
- prompts = [
-     'What is the meaning of life?',
-     'Explain recursion simply.',
-     'Dis-moi quelque chose en francais',
-     'Was denkst du ueber die Zukunft?',
-     'Why do programmers mass delete repos at 3am?',
- ]
-
- for p in prompts:
-     text = f'<start_of_turn>user\n{p}<end_of_turn>\n<start_of_turn>model\n'
-     ids = tokenizer(text, return_tensors='pt').input_ids.cuda()
-     with torch.no_grad():
-         out = model.generate(ids, max_new_tokens=200, do_sample=True, temperature=0.7, top_k=40)
-     gen = tokenizer.decode(out[0], skip_special_tokens=True)
-     answer = gen.split('model\n')[-1] if 'model\n' in gen else gen[-300:]
-     print(f'\n>>> {p}')
-     print(answer[:300])
-     print('---')
-
- print(f'\n[done] Total time: {time.time()-t0:.0f}s')
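
For reference, the adapter checkpoints this script wrote (gemma3-resonate/best and gemma3-resonate/final) would normally be reloaded on top of the frozen base model for inference. A minimal sketch, assuming the paths and base model used above:

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from peft import PeftModel

    # Load the frozen Gemma-3 270M-IT base, then attach the saved LoRA adapter.
    base = AutoModelForCausalLM.from_pretrained('unsloth/gemma-3-270m-it', torch_dtype=torch.bfloat16).cuda()
    tokenizer = AutoTokenizer.from_pretrained('gemma3-resonate/best')
    model = PeftModel.from_pretrained(base, 'gemma3-resonate/best')
    model.eval()

    # Optionally fold the LoRA weights into the base for standalone deployment.
    merged = model.merge_and_unload()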