# gss / app.py
# Hugging Face Space file view - author: naxautify; commit "init" (e6333f5); size 4.04 kB.
# pip install accelerate datasets transformers huggingface_hub wandb gated_state_spaces_pytorch
import os
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import wandb
from tqdm import tqdm
from transformers import BloomForCausalLM, BloomTokenizerFast
from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper
# from c4x import C4X
from pile_hf import ThePile, ThePileTokenized
from accelerate import Accelerator
def main():
    """Train a Gated State Spaces language model with Hugging Face Accelerate.

    The function performs these steps in order:

    1. Create the ``Accelerator`` (W&B logging, 8192 gradient-accumulation
       steps) and start the ``"gated-state-space"`` tracker run.
    2. Obtain the word-embedding state dict of ``bigscience/bloomz-1b7``,
       caching it in ``emb.pt`` so the full model is only loaded once.
    3. Build the ``GatedStateSpacesLM`` wrapped in ``AutoregressiveWrapper``,
       load the embedding weights into both the frozen token embedding and
       the frozen output projection, and resume from ``model3.pt``.
    4. Train with AdamW and ``CosineAnnealingWarmRestarts``. Every 1000
       iterations a generated sample is logged and the weights are written
       to ``model4.pt``; every 10 iterations loss and learning rate are
       logged.

    Takes no arguments and returns ``None``. All file names and
    hyperparameters are hard-coded in the body.
    """
    accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=8192,
    )
    accelerator.init_trackers("gated-state-space")

    # --- Embedding weights (cached on disk after the first run) -----------
    emb_fn = "emb.pt"
    model_name = "bigscience/bloomz-1b7"
    if not os.path.isfile(emb_fn):
        bloom = BloomForCausalLM.from_pretrained(model_name)
        wte = bloom.transformer.word_embeddings.state_dict()
        torch.save(wte, emb_fn)
    else:
        # Restore onto CPU regardless of the device the tensors were saved
        # from; the model is moved to accelerator.device further down.
        wte = torch.load(emb_fn, map_location="cpu")

    # --- Model ------------------------------------------------------------
    f_emb = 2048
    n_vocab = 250880
    model = AutoregressiveWrapper(
        GatedStateSpacesLM(
            num_tokens=n_vocab,
            dim=f_emb,
            depth=24,
        ),
    )

    # The input embedding is initialised from the cached weights and frozen.
    model.net.token_emb.requires_grad_(False)
    model.net.token_emb.load_state_dict(wte)

    # The output projection is loaded from the same state dict and frozen;
    # only the LayerNorm placed in front of it stays trainable in this head.
    to_logits = nn.Linear(f_emb, n_vocab, bias=False)
    to_logits.requires_grad_(False)
    to_logits.load_state_dict(wte)

    model.net.to_logits = nn.Sequential(
        nn.LayerNorm(f_emb),
        to_logits,
    )

    # Resume from the previous checkpoint (loaded on CPU, then moved).
    model.load_state_dict(torch.load("model3.pt", map_location="cpu"))
    model = model.to(accelerator.device)

    if accelerator.is_main_process:
        wandb.watch(model)

    # --- Optimiser and learning-rate schedule -----------------------------
    optim = AdamW(model.parameters(), 1e-4)
    sch = CosineAnnealingWarmRestarts(
        optim,
        T_0=1000,
        T_mult=2,
        eta_min=1e-7,
    )

    # --- Data -------------------------------------------------------------
    bs = 1
    kk = 2048
    tok: BloomTokenizerFast = BloomTokenizerFast.from_pretrained(model_name)
    dsx = ThePileTokenized(
        ThePile("train"),
        tokenizer=tok,
        max_length=kk,
        repeat_factor=4 / 3,
    )
    dlx = DataLoader(
        dsx,
        batch_size=bs,
        num_workers=12,
    )

    model = accelerator.prepare(model)
    optim, dlx, sch = accelerator.prepare(optim, dlx, sch)

    # Build the progress bar from the *prepared* loader. Previously tqdm
    # wrapped the loader before prepare() rebound `dlx`, so the loop iterated
    # the unprepared object and the prepared loader was never used.
    prog = tqdm(dlx, disable=not accelerator.is_main_process)

    # --- Training loop ----------------------------------------------------
    optim.zero_grad()
    for i, batch in enumerate(prog):
        batch = batch.to(accelerator.device)
        with accelerator.accumulate(model):
            with accelerator.autocast():
                los = model(batch)
            accelerator.backward(los)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()
            optim.zero_grad()
            if not accelerator.optimizer_step_was_skipped:
                sch.step()

        # Periodic sample generation and checkpoint.
        if i % 1000 == 0:
            unwrapped_model = accelerator.unwrap_model(model)
            b, n = 1, 512
            # Prompt is the single token id 2 per sequence
            # (NOTE(review): presumably a special start/end token - confirm).
            init = torch.tensor([[2]] * b).to(accelerator.device)
            prd = unwrapped_model.generate(init, n)
            prd = [tok.decode(p) for p in prd]
            try:
                accelerator.log(
                    dict(
                        text=wandb.Html(
                            "<hr>".join(p.replace("\n", "<br>") for p in prd)
                        )
                    ),
                    step=i,
                )
            except Exception as ex:
                # Best effort: a failed sample upload must not stop training.
                accelerator.print("Failed to log to W&B...", ex)
            sd = unwrapped_model.state_dict()
            # sd.pop('net.to_logits.weight')
            accelerator.save(sd, "model4.pt")

        # Periodic scalar logging.
        if i % 10 == 0:
            # Read the loss once and reuse it for both the tracker and the
            # progress bar.
            loss_value = los.item()
            accelerator.log(
                dict(
                    loss=loss_value,
                    lr=optim.param_groups[0]["lr"],
                ),
                step=i,
            )
            prog.set_postfix(loss=loss_value)

    # Close the trackers opened by init_trackers() once the data runs out.
    accelerator.end_training()
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()