naxautify committed
Commit e6333f5
1 Parent(s): 46c7fa3
Files changed (8)
  1. .gitignore +3 -0
  2. app.py +139 -0
  3. c4x.py +61 -0
  4. model2.pt +3 -0
  5. model3.pt +3 -0
  6. model4.pt +3 -0
  7. pile.py +107 -0
  8. pile_hf.py +50 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ wandb
+ __pycache__
+ .ipynb_checkpoints
app.py ADDED
@@ -0,0 +1,139 @@
+ # pip install accelerate datasets transformers huggingface_hub wandb gated_state_spaces_pytorch
+ import os
+
+ import torch
+ import torch.nn as nn
+ from torch.optim import AdamW
+ from torch.utils.data import DataLoader
+ from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
+
+ import wandb
+ from tqdm import tqdm
+ from transformers import BloomForCausalLM, BloomTokenizerFast
+ from gated_state_spaces_pytorch import GatedStateSpacesLM
+ from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper
+
+ # from c4x import C4X
+ from pile_hf import ThePile, ThePileTokenized
+ from accelerate import Accelerator
+
+
+ def main():
+     accelerator = Accelerator(
+         log_with="wandb",
+         gradient_accumulation_steps=8192,
+     )
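+     # With batch_size=1 and 2048-token sequences, 8192 accumulation steps
+     # amount to roughly 16.8M tokens per optimizer update (per process).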
+     accelerator.init_trackers("gated-state-space")
+
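+     # Cache Bloom's input embedding table on disk so later runs skip
+     # downloading and loading the full bloomz-1b7 checkpoint.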
+     emb_fn = "emb.pt"
+     model_name = "bigscience/bloomz-1b7"
+     if not os.path.isfile(emb_fn):
+         bloom = BloomForCausalLM.from_pretrained(model_name)
+         wte = bloom.transformer.word_embeddings.state_dict()
+         torch.save(wte, emb_fn)
+     else:
+         wte = torch.load(emb_fn)
+
+     f_emb = 2048
+     n_vocab = 250880
+     model = AutoregressiveWrapper(
+         GatedStateSpacesLM(
+             num_tokens=n_vocab,
+             dim=f_emb,
+             depth=24,
+         ),
+     )
+
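+     # Tie the frozen Bloom embeddings to both ends of the model: the same
+     # (n_vocab, f_emb) weight serves as input embedding and output projection.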
+     model.net.token_emb.requires_grad_(False)
+     model.net.token_emb.load_state_dict(wte)
+
+     to_logits = nn.Linear(f_emb, n_vocab, bias=False)
+     to_logits.requires_grad_(False)
+     to_logits.load_state_dict(wte)
+
+     model.net.to_logits = nn.Sequential(
+         nn.LayerNorm(f_emb),
+         to_logits,
+     )
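+     # Resume from the previous run's checkpoint; this run saves to model4.pt below.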
+     model.load_state_dict(torch.load("model3.pt"))
+     model = model.to(accelerator.device)
+
+     if accelerator.is_main_process:
+         wandb.watch(model)
+
+     optim = AdamW(model.parameters(), 1e-4)
+     sch = CosineAnnealingWarmRestarts(
+         optim,
+         T_0=1000,
+         T_mult=2,
+         eta_min=1e-7,
+     )
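+     # Cosine annealing with warm restarts: the period doubles each cycle,
+     # so restarts land at scheduler steps 1000, 3000, 7000, ...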
+
+     bs = 1
+     kk = 2048
+     tok: BloomTokenizerFast = BloomTokenizerFast.from_pretrained(model_name)
+     dsx = ThePileTokenized(
+         ThePile("train"),
+         tokenizer=tok,
+         max_length=kk,
+         repeat_factor=4 / 3,
+     )
+     dlx = DataLoader(
+         dsx,
+         batch_size=bs,
+         num_workers=12,
+     )
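+     # repeat_factor=4/3 advances the token buffer by 2048 / (4/3) = 1536
+     # tokens per sample, so consecutive windows overlap by 512 tokens (25%).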
+
+     prog = tqdm(dlx, disable=not accelerator.is_main_process)
+
+     model = accelerator.prepare(model)
+     optim, dlx, sch = accelerator.prepare(optim, dlx, sch)
+
93
+ for i, batch in enumerate(prog):
94
+ batch = batch.to(accelerator.device)
95
+ with accelerator.accumulate(model):
96
+ with accelerator.autocast():
97
+ los = model(batch)
98
+ accelerator.backward(los)
99
+ if accelerator.sync_gradients:
100
+ accelerator.clip_grad_norm_(model.parameters(), 1.0)
101
+ optim.step()
102
+ optim.zero_grad()
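+             # Accelerate marks a step as skipped while gradients are still
+             # accumulating (or when AMP sees inf/NaN grads), so the LR
+             # schedule only advances on real optimizer steps.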
+             if not accelerator.optimizer_step_was_skipped:
+                 sch.step()
+
+         if i % 1000 == 0:
+             unwrapped_model = accelerator.unwrap_model(model)
+             b, n = 1, 512
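+             # Token id 2 is Bloom's </s>; a single EOS token seeds generation.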
+             init = torch.tensor([[2]] * b).to(accelerator.device)
+             prd = unwrapped_model.generate(init, n)
+             prd = [tok.decode(p) for p in prd]
+             try:
+                 accelerator.log(
+                     dict(
+                         text=wandb.Html(
+                             "<hr>".join(p.replace("\n", "<br>") for p in prd)
+                         )
+                     ),
+                     step=i,
+                 )
+             except Exception as ex:
+                 accelerator.print("Failed to log to W&B...", ex)
+             sd = unwrapped_model.state_dict()
+             # sd.pop('net.to_logits.weight')
+             accelerator.save(sd, "model4.pt")
+
+         if i % 10 == 0:
+             accelerator.log(
+                 dict(
+                     loss=los.item(),
+                     lr=optim.param_groups[0]["lr"],
+                 ),
+                 step=i,
+             )
+             prog.set_postfix(loss=los.item())
+
+
+ if __name__ == "__main__":
+     main()
c4x.py ADDED
@@ -0,0 +1,61 @@
+ # Stream the C4 dataset from Hugging Face with the Bloom tokenizer
+ # for PyTorch language-model training.
+ import torch
+ import random
+ from datasets import load_dataset
+ from transformers import BloomTokenizerFast
+ from torch.utils.data import Dataset, get_worker_info
+
+
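+ # Re-iterate a (re-iterable) stream forever; used to cycle the C4 split.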
+ def cycled(itr):
+     while True:
+         for itm in itr:
+             yield itm
+
+ class C4X(Dataset):
+
+     def __init__(self, seq_len=512, split='train'):
+         self.seq = seq_len
+         self.ds = load_dataset(
+             'c4',
+             name='en',
+             split=split,
+             streaming=True,
+         )
+         self.tok = BloomTokenizerFast.from_pretrained('bigscience/bloomz-1b7')
+         self.init = False
+
+     def __len__(self):
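+         # Nominal length for the map-style Dataset API; the stream is unbounded.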
+         return 1_000_000_000
+
+     def _init(self):
+         if self.init:
+             return
+         wi = get_worker_info()
+         # get_worker_info() returns None when num_workers=0.
+         seed = wi.seed if wi is not None else None
+         self.ds = cycled(
+             self.ds.shuffle(
+                 seed=seed,
+                 buffer_size=10_000,
+             )
+         )
+         self.init = True
+
+     def _get_next(self):
+         self._init()
+         obj = next(self.ds)['text']
+         tkn = self.tok.encode(obj)
+         return tkn
+
+     def _get_full(self):
+         obj = []
+         while len(obj) < self.seq:
+             obj += self._get_next()
+             obj.append(self.tok.eos_token_id)
+         s = random.randint(0, len(obj) - self.seq)
+         return obj[s:s + self.seq]
+
+     def __getitem__(self, _):
+         return torch.tensor(self._get_full())
+
+     def decode(self, tkns):
+         return self.tok.decode(tkns)
model2.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:176c772feff0cf8504a46f872f6a32ae4269632b3e805e9437438f29268b795b
+ size 7609367025
model3.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c89f900da2bae9f79193ba785df8be4118d99135ffe66848e60f1ee6627b4bac
+ size 7609367025
model4.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c89f900da2bae9f79193ba785df8be4118d99135ffe66848e60f1ee6627b4bac
+ size 7609367025
pile.py ADDED
@@ -0,0 +1,107 @@
+ import json
+ import random
+ from typing import Literal
+
+ import requests
+ import zstandard as zstd
+ from torch.utils.data import IterableDataset, get_worker_info
+
+
+ Subset = Literal["train", "val", "test"]
+ URLs = {
+     "val": [
+         "https://the-eye.eu/public/AI/pile/val.jsonl.zst",
+     ],
+     "test": [
+         "https://the-eye.eu/public/AI/pile/test.jsonl.zst",
+     ],
+     "train": [
+         "https://the-eye.eu/public/AI/pile/train/00.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/01.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/02.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/03.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/04.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/05.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/06.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/07.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/08.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/09.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/10.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/11.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/12.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/13.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/14.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/15.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/16.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/17.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/18.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/19.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/20.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/21.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/22.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/23.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/24.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/25.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/26.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/27.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/28.jsonl.zst",
+         "https://the-eye.eu/public/AI/pile/train/29.jsonl.zst",
+     ],
+ }
+
+
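+ # Read one "\n"-terminated line from a binary stream, returning the decoded
+ # line plus the leftover bytes, since a read chunk may span line boundaries.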
+ def _read_line_from_stream(reader, initial_line=b"", buffer_size=4096):
+     # Accumulate raw bytes and decode only complete lines, so multi-byte
+     # UTF-8 sequences split across read() chunks are handled correctly.
+     line = initial_line
+     while True:
+         c = reader.read(buffer_size)
+         if not c:
+             raise StopIteration
+         line += c
+         if b"\n" in line:
+             break
+     line, rest = line.split(b"\n", 1)
+     return line.decode("utf-8"), rest
+
+
+ def _line_streamer(reader, buffer_size=4096):
+     # End-of-stream is signalled by StopIteration from the helper above;
+     # catching it explicitly keeps this generator PEP 479-safe.
+     rest = b""
+     while True:
+         try:
+             line, rest = _read_line_from_stream(
+                 reader,
+                 rest,
+                 buffer_size,
+             )
+             yield line
+         except StopIteration:
+             break
+
+
+ class ThePile(IterableDataset):
+     TEXT_BUFFER_SIZE = 4096
+
+     def __init__(self, subset: Subset):
+         self.subset = subset
+
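+     # Each worker shuffles the shard order with its own seed, then streams
+     # and decompresses shards over HTTP without writing them to disk.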
+     def __iter__(self):
+         urls = URLs[self.subset].copy()
+         while True:
+             wi = get_worker_info()
+             seed = wi.id if wi is not None else None
+             rnd = random.Random(seed)
+             rnd.shuffle(urls)
+             for url in urls:
+                 r = requests.get(url, stream=True)
+                 with zstd.ZstdDecompressor().stream_reader(r.raw) as reader:
+                     for line in _line_streamer(reader, self.TEXT_BUFFER_SIZE):
+                         data = json.loads(line)
+                         yield data
+
+
+ if __name__ == "__main__":
+     from tqdm import tqdm
+
+     dataset = ThePile("train")
+     for data in tqdm(dataset, smoothing=0.01):
+         pass
+     # Average: ~2000 samples/sec/worker
pile_hf.py ADDED
@@ -0,0 +1,50 @@
+ import torch
+ from torch.utils.data import IterableDataset
+
+ from transformers import PreTrainedTokenizerBase
+
+ from pile import ThePile
+
+
+ class ThePileTokenized(IterableDataset):
+     def __init__(
+         self,
+         base_dataset: ThePile,
+         tokenizer: PreTrainedTokenizerBase,
+         max_length: int = 1024,
+         repeat_factor: float = 1.0,
+     ):
+         self.pile = base_dataset
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+         self.repeat_factor = repeat_factor
+
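+     # Concatenate documents (each prefixed with EOS) into one token stream
+     # and emit fixed-length windows; repeat_factor > 1 overlaps windows.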
+     def __iter__(self):
+         ds = iter(self.pile)
+         buffer = []
+         while True:
+             tokens = self.tokenizer.encode(next(ds)["text"])
+             buffer += [self.tokenizer.eos_token_id] + tokens
+             while len(buffer) > self.max_length:
+                 yield torch.tensor(buffer[: self.max_length])
+                 buffer = buffer[int(self.max_length / self.repeat_factor):]
+
+
+ if __name__ == "__main__":
+     from tqdm import tqdm
+     from torch.utils.data import DataLoader
+     from transformers import GPT2Tokenizer
+
+     dataset = ThePileTokenized(
+         ThePile("train"),
+         GPT2Tokenizer.from_pretrained("gpt2"),
+         max_length=2048,
+         repeat_factor=4 / 3,
+     )
+     dataloader = DataLoader(
+         dataset,
+         batch_size=1,
+     )
+     for batch in tqdm(dataloader, smoothing=0.01):
+         pass
+     # ~6 iters/s for 1 worker