BathSalt-1 commited on
Commit
11bec05
·
verified ·
1 Parent(s): fe04db1

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +31 -0
train.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from daedalus_mobile import DaedalusMobile
3
+ from tokenizer import DaedalusTokenizer
4
+ from config import config
5
+
6
+ def train(model, device, train_loader, optimizer):
7
+ model.train()
8
+ total_loss = 0
9
+ for batch in train_loader:
10
+ input_ids, attention_mask, labels = batch
11
+ input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
12
+ optimizer.zero_grad()
13
+ loss = model.train_step((input_ids, attention_mask, labels))
14
+ loss.backward()
15
+ optimizer.step()
16
+ total_loss += loss.item()
17
+ return total_loss / len(train_loader)
18
+
19
+ def main():
20
+ device = torch.device(config.device)
21
+ model = DaedalusMobile(config)
22
+ model.to(device)
23
+ tokenizer = DaedalusTokenizer(config)
24
+ train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=config.batch_size, shuffle=True)
25
+ optimizer = model.configure_optimizers()
26
+ for epoch in range(config.epochs):
27
+ loss = train(model, device, train_loader, optimizer)
28
+ print(f'Epoch {epoch+1}, Loss: {loss:.4f}')
29
+
30
+ if __name__ == '__main__':
31
+ main()