GrimSqueaker committed on
Commit f6fb767 · verified · 1 Parent(s): 9c95323

Upload electra_pretrain.py with huggingface_hub

Files changed (1)
  1. electra_pretrain.py +539 -0
electra_pretrain.py ADDED
@@ -0,0 +1,539 @@
"""
ELECTRA-style discriminative pre-training for ModernProteinLM.

Generator (small): ~25% of discriminator size, trained with MLM.
Discriminator (main model): trained to detect replaced tokens (RTD objective).

Key improvements over standard ELECTRA:
1. Curriculum masking: start at 30%, decay to 5%
2. Span masking: mask contiguous regions (protein structural motifs)
3. Generator distillation: generator temperature annealing
4. No NSP, no dropout (following ESM-2)
"""

import os
import math
import random
from typing import Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import (
    PreTrainedTokenizerFast,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)
from datasets import load_dataset, concatenate_datasets
import numpy as np
from tqdm import tqdm

from modeling_modern_protein import ModernProteinLM, ModernProteinLMConfig


class ProteinTokenizer:
    """Simple protein tokenizer matching ESM-2 vocab."""

    ALL_AA = "LAGVSERTIDPQKNFYWMHCXBUZO"

    def __init__(self):
        # Vocab layout (ESM-2 style):
        # 0: <cls>, 1: <pad>, 2: <eos>, 3: <unk>
        # 4-28: amino acids (including ambiguous codes X, B, U, Z, O)
        # 29: <mask>, 30: <sep>, 31-32: filler specials for ESM compatibility
        self.vocab = {
            "<cls>": 0, "<pad>": 1, "<eos>": 2, "<unk>": 3,
            "L": 4, "A": 5, "G": 6, "V": 7, "S": 8, "E": 9, "R": 10,
            "T": 11, "I": 12, "D": 13, "P": 14, "Q": 15, "K": 16, "N": 17,
            "F": 18, "Y": 19, "W": 20, "M": 21, "H": 22, "C": 23, "X": 24,
            "B": 25, "U": 26, "Z": 27, "O": 28, "<mask>": 29,
            "<sep>": 30,  # additional sep
        }
        # Pad to 33 for ESM compatibility
        while len(self.vocab) < 33:
            self.vocab[f"<special_{len(self.vocab)}>"] = len(self.vocab)

        self.id_to_token = {v: k for k, v in self.vocab.items()}
        self.mask_token_id = 29
        self.pad_token_id = 1
        self.cls_token_id = 0
        self.eos_token_id = 2

    def encode(self, sequence: str, max_length: int = 1024, add_special_tokens: bool = True):
        tokens = []
        if add_special_tokens:
            tokens.append(self.cls_token_id)

        for aa in sequence.upper():
            if aa in self.vocab:
                tokens.append(self.vocab[aa])
            else:
                tokens.append(self.vocab["<unk>"])

        if add_special_tokens:
            tokens.append(self.eos_token_id)

        # Truncate or pad
        if len(tokens) > max_length:
            tokens = tokens[:max_length]
            if add_special_tokens:
                tokens[-1] = self.eos_token_id  # keep <eos> as the final token after truncation

        attention_mask = [1] * len(tokens)
        while len(tokens) < max_length:
            tokens.append(self.pad_token_id)
            attention_mask.append(0)

        return {
            "input_ids": tokens,
            "attention_mask": attention_mask,
        }

    def batch_encode(self, sequences: List[str], max_length: int = 1024):
        results = [self.encode(seq, max_length) for seq in sequences]
        return {
            "input_ids": torch.tensor([r["input_ids"] for r in results], dtype=torch.long),
            "attention_mask": torch.tensor([r["attention_mask"] for r in results], dtype=torch.long),
        }

    def decode(self, token_ids):
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        return "".join([self.id_to_token.get(t, "<unk>") for t in token_ids])


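# Quick sanity-check sketch for the tokenizer above (the ids follow directly
# from the vocab table in __init__):
#
#   tok = ProteinTokenizer()
#   enc = tok.encode("MKTAYIAK", max_length=12)
#   enc["input_ids"]       -> [0, 21, 16, 11, 5, 19, 12, 5, 16, 2, 1, 1]
#   enc["attention_mask"]  -> [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
#   tok.decode(enc["input_ids"][:10]) -> "<cls>MKTAYIAK<eos>"
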
def create_span_mask(length, mask_ratio=0.30, mean_span_length=3, min_span_length=1):
    """Create a boolean span mask: contiguous spans of roughly mean_span_length tokens."""
    if length <= 0:
        return []

    num_to_mask = max(1, int(length * mask_ratio))
    mask = [False] * length

    attempts = 0
    masked = 0
    while masked < num_to_mask and attempts < num_to_mask * 10:
        attempts += 1
        span_len = max(min_span_length, min(mean_span_length + random.randint(-1, 1), num_to_mask - masked))
        start = random.randint(0, max(0, length - span_len))

        # Skip spans that overlap an already-masked region
        if any(mask[start:start + span_len]):
            continue

        for i in range(start, min(start + span_len, length)):
            mask[i] = True
            masked += 1

    return mask


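# Illustrative behaviour (not a fixed output, since spans are sampled at random):
#   create_span_mask(20, mask_ratio=0.30, mean_span_length=3)
# marks ~6 of 20 positions True, grouped into contiguous spans of roughly 2-4
# tokens, e.g. [F,F,T,T,T,F,F,F,F,F,F,T,T,T,F,F,F,F,F,F].
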
class ProteinDataset(Dataset):
    def __init__(self, sequences, tokenizer, max_length=1024, mask_ratio=0.30,
                 mean_span_length=3, curriculum_start_ratio=0.30, curriculum_end_ratio=0.05,
                 total_steps=100000, current_step=0):
        self.sequences = sequences
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mean_span_length = mean_span_length
        self.curriculum_start_ratio = curriculum_start_ratio
        self.curriculum_end_ratio = curriculum_end_ratio
        self.total_steps = total_steps
        self.current_step = current_step

    def get_current_mask_ratio(self):
        """Linear decay from start to end ratio."""
        progress = min(1.0, self.current_step / self.total_steps)
        return self.curriculum_start_ratio + (self.curriculum_end_ratio - self.curriculum_start_ratio) * progress

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        encoded = self.tokenizer.encode(seq, max_length=self.max_length)
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

        # Find actual sequence length (before padding)
        seq_len = sum(attention_mask)
        # Exclude special tokens from masking
        effective_len = seq_len - 2 if seq_len > 2 else seq_len

        # Apply span masking
        mask_ratio = self.get_current_mask_ratio()
        span_mask = create_span_mask(effective_len, mask_ratio, self.mean_span_length)

        # Create masked input and labels
        masked_input = input_ids.copy()
        labels = [-100] * len(input_ids)  # -100 = ignore in loss
        replaced = [False] * len(input_ids)  # For discriminator

        for i in range(1, 1 + effective_len):  # Skip CLS
            if span_mask[i - 1]:
                labels[i] = input_ids[i]
                replaced[i] = True
                # 80% mask, 10% random, 10% keep
                r = random.random()
                if r < 0.8:
                    masked_input[i] = self.tokenizer.mask_token_id
                elif r < 0.9:
                    masked_input[i] = random.randint(4, 28)  # Random AA
                # else: keep original

        return {
            "input_ids": torch.tensor(masked_input, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
            "replaced": torch.tensor(replaced, dtype=torch.bool),
            "original_ids": torch.tensor(input_ids, dtype=torch.long),
        }


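# Worked example of the masking curriculum above (follows directly from
# get_current_mask_ratio with the defaults start=0.30, end=0.05, total_steps=100000):
#   step      0: 0.30 + (0.05 - 0.30) * 0.00 = 0.300
#   step  50000: 0.30 + (0.05 - 0.30) * 0.50 = 0.175
#   step 100000: 0.30 + (0.05 - 0.30) * 1.00 = 0.050
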
class GeneratorModel(nn.Module):
    """Small generator model for ELECTRA."""

    def __init__(self, vocab_size, hidden_size=256, num_layers=4, num_heads=4, intermediate_size=1024):
        super().__init__()
        config = ModernProteinLMConfig(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_layers,
            num_attention_heads=num_heads,
            intermediate_size=intermediate_size,
            tie_word_embeddings=True,
        )
        self.model = ModernProteinLM(config)

    def forward(self, input_ids, attention_mask, labels):
        return self.model(input_ids, attention_mask, labels=labels)


class DiscriminatorModel(ModernProteinLM):
    """Discriminator with an additional classification head for RTD."""

    def __init__(self, config):
        super().__init__(config)
        self.discriminator_head = nn.Linear(config.hidden_size, 1)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = super().forward(input_ids, attention_mask, return_dict=True)
        hidden = outputs.hidden_states[-1]  # (B, T, H)

        # Discriminator logits: real vs fake
        disc_logits = self.discriminator_head(hidden).squeeze(-1)  # (B, T)

        disc_loss = None
        if labels is not None:
            # labels: 1 = real, 0 = fake (replaced), -100 = ignore (padding)
            loss_fct = nn.BCEWithLogitsLoss()
            active_loss = labels != -100
            active_logits = disc_logits[active_loss]
            active_labels = labels[active_loss].float()
            disc_loss = loss_fct(active_logits, active_labels)

        return {
            "loss": disc_loss,
            "logits": disc_logits,
            "hidden_states": outputs.hidden_states,
        }


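# Training flow implemented by the trainer below (standard ELECTRA recipe):
#   1. The generator is trained with MLM on the masked input.
#   2. Masked positions are re-filled with tokens sampled from the generator,
#      producing a "corrupted" sequence.
#   3. The discriminator is trained to classify each token as original vs. replaced.
# The two losses are combined as
#   total_loss = generator_weight * L_MLM + discriminator_weight * L_RTD,
# with discriminator_weight = 50.0 by default, following the ELECTRA paper.
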
class ELECTRAProteinTrainer:
    def __init__(
        self,
        generator: GeneratorModel,
        discriminator: DiscriminatorModel,
        tokenizer,
        train_dataset,
        eval_dataset,
        output_dir="./electra_protein",
        lr=5e-4,
        batch_size=32,
        max_steps=100000,
        warmup_steps=10000,
        weight_decay=0.01,
        grad_clip=1.0,
        generator_weight=1.0,
        discriminator_weight=50.0,
        device="cuda",
    ):
        self.generator = generator.to(device)
        self.discriminator = discriminator.to(device)
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.output_dir = output_dir
        self.device = device
        self.max_steps = max_steps
        self.grad_clip = grad_clip
        self.generator_weight = generator_weight
        self.discriminator_weight = discriminator_weight

        os.makedirs(output_dir, exist_ok=True)

        # Optimizers
        self.gen_optimizer = torch.optim.AdamW(
            generator.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-6, weight_decay=weight_decay
        )
        self.disc_optimizer = torch.optim.AdamW(
            discriminator.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-6, weight_decay=weight_decay
        )

        # Schedulers
        self.gen_scheduler = get_cosine_schedule_with_warmup(
            self.gen_optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps
        )
        self.disc_scheduler = get_cosine_schedule_with_warmup(
            self.disc_optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps
        )

        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True
        )
        self.eval_loader = DataLoader(
            eval_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True
        )

        self.global_step = 0
        self.best_eval_loss = float("inf")

    def train(self):
        self.generator.train()
        self.discriminator.train()

        self.pbar = tqdm(total=self.max_steps, desc="Training")

        # Loop over epochs until max_steps is reached (a single pass over the
        # loader is far fewer steps than max_steps).
        while self.global_step < self.max_steps:
            # Advance the masking curriculum; picked up by the freshly spawned loader workers.
            self.train_dataset.current_step = self.global_step

            for batch in self.train_loader:
                if self.global_step >= self.max_steps:
                    break

                self._train_step(batch)
                self.global_step += 1
                self.pbar.update(1)

                if self.global_step % 1000 == 0:
                    eval_loss = self.evaluate()
                    if eval_loss < self.best_eval_loss:
                        self.best_eval_loss = eval_loss
                        self.save_checkpoint("best")
                    self.generator.train()
                    self.discriminator.train()

                if self.global_step % 5000 == 0:
                    self.save_checkpoint(f"step_{self.global_step}")

        self.pbar.close()
        self.save_checkpoint("final")

    def _train_step(self, batch):
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        mlm_labels = batch["labels"].to(self.device)
        replaced_positions = batch["replaced"].to(self.device)
        original_ids = batch["original_ids"].to(self.device)

        # ====== GENERATOR STEP ======
        gen_outputs = self.generator(input_ids, attention_mask, mlm_labels)
        gen_loss = gen_outputs.loss

        # Sample from generator to create corrupted input for discriminator
        with torch.no_grad():
            gen_logits = gen_outputs.logits  # (B, T, V)
            gen_probs = F.softmax(gen_logits, dim=-1)
            sampled_ids = torch.multinomial(
                gen_probs.view(-1, gen_probs.size(-1)), 1
            ).view(gen_probs.shape[:-1])

        # Replace masked positions with generator samples
        corrupted_input = original_ids.clone()
        mask_positions = mlm_labels != -100
        corrupted_input[mask_positions] = sampled_ids[mask_positions]

        # ====== DISCRIMINATOR STEP ======
        # Create discriminator labels: 1 = original, 0 = replaced
        disc_labels = torch.ones_like(original_ids, dtype=torch.float)  # (B, T)
        disc_labels[replaced_positions] = 0.0
        # Ignore padding
        disc_labels[attention_mask == 0] = -100

        disc_outputs = self.discriminator(corrupted_input, attention_mask, disc_labels)
        disc_loss = disc_outputs["loss"]

        # ====== BACKWARD ======
        # Combined loss with weighting
        total_loss = self.generator_weight * gen_loss + self.discriminator_weight * disc_loss

        total_loss.backward()

        torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.grad_clip)
        torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.grad_clip)

        self.gen_optimizer.step()
        self.disc_optimizer.step()
        self.gen_scheduler.step()
        self.disc_scheduler.step()

        self.gen_optimizer.zero_grad()
        self.disc_optimizer.zero_grad()

        if self.global_step % 100 == 0:
            self.pbar.set_postfix({
                "gen_loss": f"{gen_loss.item():.4f}",
                "disc_loss": f"{disc_loss.item():.4f}",
                "lr": f"{self.gen_scheduler.get_last_lr()[0]:.2e}",
            })

    def evaluate(self):
        self.generator.eval()
        self.discriminator.eval()

        total_gen_loss = 0
        total_disc_loss = 0
        total_samples = 0

        with torch.no_grad():
            for batch in self.eval_loader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                mlm_labels = batch["labels"].to(self.device)
                replaced_positions = batch["replaced"].to(self.device)
                original_ids = batch["original_ids"].to(self.device)

                gen_outputs = self.generator(input_ids, attention_mask, mlm_labels)
                total_gen_loss += gen_outputs.loss.item() * input_ids.size(0)

                # Build the same generator-corrupted input as in training, so the
                # discriminator is evaluated on what it actually sees during training
                gen_probs = F.softmax(gen_outputs.logits, dim=-1)
                sampled_ids = torch.multinomial(
                    gen_probs.view(-1, gen_probs.size(-1)), 1
                ).view(gen_probs.shape[:-1])
                corrupted_input = original_ids.clone()
                mask_positions = mlm_labels != -100
                corrupted_input[mask_positions] = sampled_ids[mask_positions]

                disc_labels = torch.ones_like(original_ids, dtype=torch.float)
                disc_labels[replaced_positions] = 0.0
                disc_labels[attention_mask == 0] = -100

                disc_outputs = self.discriminator(corrupted_input, attention_mask, disc_labels)
                total_disc_loss += disc_outputs["loss"].item() * input_ids.size(0)
                total_samples += input_ids.size(0)

        avg_gen = total_gen_loss / total_samples
        avg_disc = total_disc_loss / total_samples

        print(f"Eval - Gen Loss: {avg_gen:.4f}, Disc Loss: {avg_disc:.4f}")
        return avg_gen + avg_disc

    def save_checkpoint(self, name):
        path = os.path.join(self.output_dir, name)
        os.makedirs(path, exist_ok=True)

        torch.save({
            "generator": self.generator.state_dict(),
            "discriminator": self.discriminator.state_dict(),
            "gen_optimizer": self.gen_optimizer.state_dict(),
            "disc_optimizer": self.disc_optimizer.state_dict(),
            "step": self.global_step,
        }, os.path.join(path, "checkpoint.pt"))

        # Save discriminator config (main model)
        self.discriminator.config.save_pretrained(path)

        print(f"Saved checkpoint to {path}")


def load_protein_sequences(dataset_name="lamm-mit/protein_secondary_structure_from_PDB", split="train", max_seqs=None):
    """Load protein sequences from an HF dataset."""
    ds = load_dataset(dataset_name, split=split, streaming=True)
    sequences = []

    for i, example in enumerate(ds):
        if max_seqs and i >= max_seqs:
            break
        # Try common column names
        seq = None
        for key in ["input", "primary", "sequences", "sequence", "protein", "text"]:
            if key in example:
                seq = example[key]
                break
        if seq and len(seq) > 10:
            sequences.append(seq)

    return sequences


def main():
    # Config
    DISC_CONFIG = ModernProteinLMConfig(
        vocab_size=33,
        hidden_size=576,
        num_hidden_layers=28,
        num_attention_heads=9,
        intermediate_size=2304,
        use_geglu=True,
        tie_word_embeddings=True,
        max_position_embeddings=1026,
        position_embedding_type="rotary",
        rope_theta=10000.0,
    )

    # Generator: ~25% of discriminator size
    GEN_CONFIG = ModernProteinLMConfig(
        vocab_size=33,
        hidden_size=320,
        num_hidden_layers=8,
        num_attention_heads=8,
        intermediate_size=1280,
        use_geglu=True,
        tie_word_embeddings=True,
    )

    tokenizer = ProteinTokenizer()

    # Load data
    print("Loading protein sequences...")
    train_seqs = load_protein_sequences("lamm-mit/protein_secondary_structure_from_PDB", "train", max_seqs=50000)
    # NOTE: these eval sequences are the first 5,000 of the same split, so they
    # overlap with the training set; use a held-out split for a real run.
    eval_seqs = load_protein_sequences("lamm-mit/protein_secondary_structure_from_PDB", "train", max_seqs=5000)

    print(f"Loaded {len(train_seqs)} train, {len(eval_seqs)} eval sequences")

    train_dataset = ProteinDataset(
        train_seqs, tokenizer, max_length=1024,
        curriculum_start_ratio=0.30, curriculum_end_ratio=0.05,
        total_steps=100000,
    )
    eval_dataset = ProteinDataset(
        eval_seqs, tokenizer, max_length=1024,
        curriculum_start_ratio=0.30, curriculum_end_ratio=0.05,
        total_steps=100000, current_step=100000,  # Fixed at end ratio for eval
    )

    # Models
    generator = GeneratorModel(
        vocab_size=33,
        hidden_size=GEN_CONFIG.hidden_size,
        num_layers=GEN_CONFIG.num_hidden_layers,
        num_heads=GEN_CONFIG.num_attention_heads,
        intermediate_size=GEN_CONFIG.intermediate_size,
    )
    discriminator = DiscriminatorModel(DISC_CONFIG)

    # Count parameters
    gen_params = sum(p.numel() for p in generator.parameters())
    disc_params = sum(p.numel() for p in discriminator.parameters())
    print(f"Generator params: {gen_params/1e6:.1f}M")
    print(f"Discriminator params: {disc_params/1e6:.1f}M")

    trainer = ELECTRAProteinTrainer(
        generator=generator,
        discriminator=discriminator,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        output_dir="./modern_protein_electra",
        lr=5e-4,
        batch_size=16,
        max_steps=100000,
        warmup_steps=10000,
        weight_decay=0.01,
        grad_clip=1.0,
        generator_weight=1.0,
        discriminator_weight=50.0,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    print("Starting ELECTRA pre-training...")
    trainer.train()


if __name__ == "__main__":
    main()