File size: 867 Bytes
0b7b08a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Placeholder bindings for names this snippet expects to exist
# (real versions would come from torch / transformers / project imports).
# All are bound to None here so the skeleton below is syntactically valid.
(
    build_model,
    ZeroRedundancyOptimizer,
    GradScaler,
    laion_loader,
    pile_loader,
    autocast,
    zero_embedding_gradient,
    torch,
    lr_scheduler,
    get_cosine_schedule_with_warmup,
) = (None,) * 10


# Mixed-precision (AMP) training skeleton: DDP-wrapped model, ZeRO-sharded
# optimizer state, cosine LR schedule with warmup, and a GradScaler for
# fp16 loss scaling.  NOTE(review): the `...` argument lists are elided in
# this snippet and the callees are placeholder Nones above — this block is
# illustrative, not runnable as-is.
ddp_model = build_model(...)
optimizer = ZeroRedundancyOptimizer(...)
lr_scheduler = get_cosine_schedule_with_warmup(...)
scaler = GradScaler()

# Iterate both dataloaders in lockstep; zip() stops at the shorter loader,
# so any surplus batches from the longer dataset are dropped each epoch.
for batch_laion, batch_pile in zip(laion_loader, pile_loader):
    # Forward under autocast (mixed precision); backward runs OUTSIDE the
    # autocast context, on the scaler-scaled loss, per the AMP recipe.
    with autocast():
        loss_laion = ddp_model(batch_laion)
    scaler.scale(loss_laion).backward()
    # Second forward/backward on the same parameters accumulates gradients
    # from both datasets into one optimizer step.
    with autocast():
        loss_pile = ddp_model(batch_pile)
    scaler.scale(loss_pile).backward()

    # Opaque project helper — presumably zeroes gradients of (frozen?)
    # embedding rows before clipping; TODO confirm its exact effect.
    zero_embedding_gradient()
    # Unscale in place so clip_grad_norm_ sees true-magnitude gradients.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 1.0)

    # scaler.step skips the optimizer step if unscaled grads contain
    # inf/NaN; update() then shrinks the scale for the next iteration.
    scaler.step(optimizer)
    scaler.update()
    # NOTE(review): lr_scheduler.step() advances even on iterations where
    # scaler.step skipped the optimizer — consider guarding if exact
    # schedule alignment matters.
    lr_scheduler.step()
    optimizer.zero_grad()