naxalpha committed
Commit ad01999
1 Parent(s): 9f1ebfc
Files changed (1)
  1. app.py +46 -38
app.py CHANGED
@@ -13,17 +13,19 @@ from gated_state_spaces_pytorch import GatedStateSpacesLM
 from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper
 
 from c4x import C4X
+from accelerate import Accelerator
 
 
-if __name__ == '__main__':
-    wandb.init(
-        project="gated-state-space",
-        entity="naxalpha",
+def main():
+    accelerator = Accelerator(
+        gradient_accumulation_steps=4,
     )
 
-    # gpt_2 = GPT2LMHeadModel.from_pretrained('gpt2-xl')
-    # gpt_2.requires_grad_(False)
-    # gpt_2 = gpt_2.cuda()
+    if accelerator.is_main_process:
+        wandb.init(
+            project="gated-state-space",
+            entity="naxalpha",
+        )
 
     f_emb = 1600
     model = AutoregressiveWrapper(
@@ -32,56 +34,57 @@ if __name__ == '__main__':
             dim=f_emb,
             depth=24,
         ),
-    )
-    wandb.watch(model)
-
-    # emb = gpt_2.state_dict()['transformer.wte.weight']
-
+    )
     model.net.token_emb.weight.requires_grad_(False)
-    # model.net.token_emb.weight.copy_(emb)
-
     model.net.to_logits.weight.requires_grad_(False)
-    # model.net.to_logits.weight.copy_(emb)
-
     model.net.to_logits = nn.Sequential(
         nn.LayerNorm(f_emb),
         model.net.to_logits,
     )
+    model = model.to(accelerator.device)
+
+    if accelerator.is_main_process:
+        wandb.watch(model)
 
     model.load_state_dict(torch.load('model.pt'))
-    model = model.cuda()
     optim = AdamW(model.parameters(), 2e-5)
 
-    bs = 8
+    bs = 16
     kk = 128
     dsx = C4X(kk+1)
     dlx = DataLoader(
         dsx,
         batch_size=bs,
-        num_workers=16,
+        num_workers=8,
     )
 
     k = 4
-    prog = tqdm(dlx)
+    prog = tqdm(dlx, disable=not accelerator.is_main_process)
+
+    model, optim, dlx = accelerator.prepare(model, optim, dlx)
+
     optim.zero_grad()
-
     for i, batch in enumerate(prog):
-        batch = batch.cuda()
-        los = model(batch)
-
-        (los / k).backward()
-        if (i+1) % k == 0:
-            clip_grad_norm_(
-                model.parameters(),
-                max_norm=1.,
-            )
+        batch = batch.to(accelerator.device)
+        with accelerator.accumulate(model):
+            with accelerator.autocast():
+                los = model(batch)
+            accelerator.backward(los)
+            if accelerator.sync_gradients:
+                accelerator.clip_grad_norm_(
+                    model.parameters(),
+                    1.0,
+                )
            optim.step()
            optim.zero_grad()
 
-        if i % 1000 == 0:
+        if i % 1000 == 0 and accelerator.is_main_process:
+            print('generating...')
+            accelerator.wait_for_everyone()
+            unwrapped_model = accelerator.unwrap_model(model)
             b, n = 4, 512
-            init = torch.tensor([[50256]]*b).cuda()
-            prd = model.generate(init, n)
+            init = torch.tensor([[50256]]*b).to(accelerator.device)
+            prd = unwrapped_model.generate(init, n)
             prd = [dsx.decode(p) for p in prd]
             try:
                 wandb.log(dict(
@@ -92,9 +95,14 @@ if __name__ == '__main__':
                 )), step=i)
             except Exception as ex:
                 print('Failed to log to W&B...', ex)
-            torch.save(model.state_dict(), 'model.pt')
+            accelerator.save(unwrapped_model.state_dict(), 'model.pt')
+
+        if i % 10 == 0 and accelerator.is_main_process:
+            print('logging...')
+            wandb.log(dict(
+                loss=los.item(),
+            ), step=i)
+            prog.set_postfix(loss=los.item())
 
-        wandb.log(dict(
-            loss=los.item(),
-        ), step=i)
-        prog.set_postfix(loss=los.item())
+if __name__ == '__main__':
+    main()
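
For reference, below is a minimal sketch of the Accelerate training pattern this commit adopts: gradient accumulation, autocast, gradient clipping on sync steps, main-process-only logging, and unwrapping the model before saving. The toy model, dataset, and hyperparameters are placeholders, not the repo's GatedStateSpacesLM / C4X setup. One small detail worth noting in the committed code: `prog` wraps `dlx` before `accelerator.prepare(...)` replaces it, so the loop still iterates the unprepared DataLoader; the sketch below builds the progress bar from the prepared one.

# Minimal sketch of the Accelerate pattern used in this commit.
# The model, data, and hyperparameters are stand-ins, not the repo's.
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator
from tqdm import tqdm


def main():
    accelerator = Accelerator(gradient_accumulation_steps=4)

    # Toy regression model and data standing in for GatedStateSpacesLM / C4X.
    model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 1))
    data = TensorDataset(torch.randn(1024, 128), torch.randn(1024, 1))
    loader = DataLoader(data, batch_size=16)
    optim = AdamW(model.parameters(), lr=2e-5)
    loss_fn = nn.MSELoss()

    # prepare() moves everything to the right device and, on multi-GPU runs,
    # wraps the model for DDP and shards the dataloader across processes.
    model, optim, loader = accelerator.prepare(model, optim, loader)

    prog = tqdm(loader, disable=not accelerator.is_main_process)
    optim.zero_grad()
    for i, (x, y) in enumerate(prog):
        with accelerator.accumulate(model):
            with accelerator.autocast():
                loss = loss_fn(model(x), y)
            accelerator.backward(loss)
            # Only clip and report on the steps where gradients actually sync.
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()
            optim.zero_grad()

        if i % 10 == 0 and accelerator.is_main_process:
            prog.set_postfix(loss=loss.item())

    # Save an unwrapped copy of the weights once all processes are in sync.
    accelerator.wait_for_everyone()
    unwrapped = accelerator.unwrap_model(model)
    accelerator.save(unwrapped.state_dict(), 'model.pt')


if __name__ == '__main__':
    main()

If a script like this is run outside the Space runtime, it would typically be configured once with `accelerate config` and then started with `accelerate launch app.py`, which also handles multi-GPU and mixed-precision settings.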