caixiaoshun commited on
Commit
5153277
·
verified ·
1 Parent(s): 042e3bf

Upload 6 files

Browse files
Files changed (6) hide show
  1. src/config.py +34 -0
  2. src/dataset.py +60 -0
  3. src/model.py +216 -0
  4. src/sample.py +67 -0
  5. src/train.py +298 -0
  6. src/train_tokenizer.py +41 -0
src/config.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class Config:
2
+ def __init__(self):
3
+ self.encoder_layer = 6
4
+ self.decoder_layer = 6
5
+ self.embed_dim = 512
6
+ self.num_heads = 8
7
+ self.drop_out = 0.1
8
+ self.max_len = 256
9
+ self.vocab_size = 30_000
10
+ self.wmt_zh_en_path = "data/wmt_zh_en_training_corpus.csv"
11
+ self.tokenizer_file = "checkpoints/tokenizer.json"
12
+ self.batch_size = 64
13
+ self.compile = False
14
+ self.seed = 42
15
+ self.val_ratio = 0.1
16
+ self.num_workers = 4
17
+ self.pin_memory = True
18
+ self.tensorboard_dir = "log/tensorboard"
19
+ self.checkpoint_dir = "log/checkpoint"
20
+
21
+ self.base_lr = 3e-4
22
+ self.betas = (0.9, 0.98)
23
+ self.eps = 1e-9
24
+ self.weight_decay = 0.1
25
+ self.warmup_ratio = 0.005
26
+ self.start_factor = 1e-3
27
+ self.end_factor = 1.0
28
+ self.eta_min = 3e-6
29
+ self.max_epochs = 1
30
+
31
+ self.data_cache_dir = "data/cache.pickle"
32
+ self.use_cache = True
33
+
34
+ self.every_n_train_steps = 10000
src/dataset.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ import csv
3
+ from tokenizers import Tokenizer
4
+ import torch
5
+ import os
6
+ import pickle
7
+ from src.config import Config
8
+
9
+
10
+ class TranslateDataset(Dataset):
11
+ def __init__(self, config: Config):
12
+ super().__init__()
13
+ self.config = config
14
+ self.tokenizer: Tokenizer = Tokenizer.from_file(config.tokenizer_file)
15
+ self.pad_id = self.tokenizer.token_to_id("[PAD]")
16
+ self.pairs = []
17
+ if os.path.exists(config.data_cache_dir) and config.use_cache:
18
+ with open(config.data_cache_dir, "rb") as f:
19
+ self.pairs = pickle.load(f)
20
+ else:
21
+ with open(self.config.wmt_zh_en_path, mode="r", encoding="utf-8") as f:
22
+ reader = csv.DictReader(f)
23
+ for line in reader:
24
+ self.pairs.append((line["0"], line["1"]))
25
+ if config.use_cache:
26
+ with open(config.data_cache_dir, "wb") as cache_f:
27
+ pickle.dump(self.pairs, cache_f)
28
+
29
+ def __len__(self):
30
+ return len(self.pairs)
31
+
32
+ def encode(self, text):
33
+ ids = self.tokenizer.encode(text).ids
34
+
35
+ if len(ids) > self.config.max_len:
36
+ ids = ids[: self.config.max_len]
37
+
38
+ pad_len = self.config.max_len - len(ids)
39
+
40
+ if pad_len > 0:
41
+ ids = ids + [self.pad_id] * pad_len
42
+ pad_mask = [False if i == self.pad_id else True for i in ids]
43
+ return torch.tensor(ids, dtype=torch.long), torch.tensor(
44
+ pad_mask, dtype=torch.bool
45
+ )
46
+
47
+ def __getitem__(self, idx):
48
+ zh, en = self.pairs[idx]
49
+
50
+ zh_id, zh_pad = self.encode(zh)
51
+
52
+ en_id, en_pad = self.encode(en)
53
+
54
+ return dict(
55
+ src=zh_id,
56
+ src_pad_mask=zh_pad,
57
+ tgt=en_id[:-1],
58
+ tgt_pad_mask=en_pad[:-1],
59
+ label=en_id[1:],
60
+ )
src/model.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn as nn
2
+ import torch
3
+ from src.config import Config
4
+
5
+
6
+
7
+ class MultiHeadAttention(nn.Module):
8
+ def __init__(self, embed_dim, num_heads, drop_out=0.1):
9
+ super().__init__()
10
+
11
+ assert embed_dim % num_heads == 0
12
+
13
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
14
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
15
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
16
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
17
+ self.dropout = nn.Dropout(drop_out)
18
+ self.scale = (embed_dim // num_heads) ** 0.5
19
+ self.embed_dim = embed_dim
20
+ self.num_heads = num_heads
21
+
22
+ def forward(
23
+ self,
24
+ q: torch.Tensor,
25
+ k: torch.Tensor,
26
+ v: torch.Tensor,
27
+ mask: torch.Tensor = None,
28
+ pad_mask: torch.Tensor = None,
29
+ ):
30
+ bs = q.shape[0]
31
+ q_len = q.shape[1]
32
+ k_len = k.shape[1]
33
+
34
+ Q: torch.Tensor = self.q_proj(q)
35
+ K: torch.Tensor = self.k_proj(k)
36
+ V: torch.Tensor = self.v_proj(v)
37
+
38
+ q_state = Q.view(bs, q_len, self.num_heads, -1).transpose(1, 2)
39
+ k_state = K.view(bs, k_len, self.num_heads, -1).transpose(1, 2)
40
+ v_state = V.view(bs, k_len, self.num_heads, -1).transpose(1, 2)
41
+
42
+ attn = q_state @ k_state.transpose(
43
+ -1, -2
44
+ ) # [bs, head, q_len, dim] @ [bs, head, dim, k_len] = [bs, head, q_len, k_len]
45
+ attn: torch.Tensor = attn / self.scale
46
+
47
+ if mask is not None:
48
+ attn = attn.masked_fill(~mask, -1e8)
49
+
50
+ if pad_mask is not None:
51
+ attn = attn.masked_fill(~pad_mask.unsqueeze(1).unsqueeze(2), -1e8)
52
+
53
+ attn = torch.softmax(attn, dim=-1)
54
+
55
+ attn = self.dropout(attn)
56
+
57
+ out = (
58
+ attn @ v_state
59
+ ) # [bs, head, q_len, k_len] @ [bs, head, k_len, dim] = [bs, head, q_len, dim]
60
+
61
+ out = out.transpose(1, 2).contiguous().view(bs, q_len, -1)
62
+
63
+ out = self.out_proj(out)
64
+
65
+ return out
66
+
67
+
68
+ class FFN(nn.Module):
69
+ def __init__(self, embed_dim, drop_out=0.1):
70
+ super().__init__()
71
+ self.mlp = nn.Sequential(
72
+ nn.Linear(embed_dim, embed_dim * 4),
73
+ nn.ReLU(),
74
+ nn.Linear(embed_dim * 4, embed_dim),
75
+ nn.Dropout(drop_out),
76
+ )
77
+
78
+ def forward(self, x):
79
+ return self.mlp(x)
80
+
81
+
82
+ class EncoderLayer(nn.Module):
83
+ def __init__(self, embed_dim, num_heads, drop_out=0.1):
84
+ super().__init__()
85
+ self.mha = MultiHeadAttention(
86
+ embed_dim=embed_dim, num_heads=num_heads, drop_out=drop_out
87
+ )
88
+ self.ffn = FFN(embed_dim=embed_dim, drop_out=drop_out)
89
+ self.norm1 = nn.LayerNorm(embed_dim)
90
+ self.norm2 = nn.LayerNorm(embed_dim)
91
+
92
+ def forward(self, x: torch.Tensor, pad_mask=None):
93
+ x = x + self.mha(x, x, x, pad_mask=pad_mask)
94
+ x = self.norm1(x)
95
+
96
+ x = x + self.ffn(x)
97
+
98
+ x = self.norm2(x)
99
+
100
+ return x
101
+
102
+
103
+ class Encoder(nn.Module):
104
+ def __init__(self, config: Config):
105
+ super().__init__()
106
+ self.layers = nn.ModuleList([])
107
+ for _ in range(config.encoder_layer):
108
+ self.layers.append(
109
+ EncoderLayer(config.embed_dim, config.num_heads, config.drop_out)
110
+ )
111
+
112
+ def forward(self, x, pad_mask=None):
113
+ for layer in self.layers:
114
+ x = layer(x, pad_mask)
115
+ return x
116
+
117
+
118
+ class DecoderLayer(nn.Module):
119
+ def __init__(self, embed_dim, num_heads, drop_out=0.1):
120
+ super().__init__()
121
+ self.self_attn = MultiHeadAttention(
122
+ embed_dim=embed_dim, num_heads=num_heads, drop_out=drop_out
123
+ )
124
+ self.cross_attn = MultiHeadAttention(
125
+ embed_dim=embed_dim, num_heads=num_heads, drop_out=drop_out
126
+ )
127
+ self.ffn = FFN(embed_dim=embed_dim, drop_out=drop_out)
128
+ self.norm0 = nn.LayerNorm(embed_dim)
129
+ self.norm1 = nn.LayerNorm(embed_dim)
130
+ self.norm2 = nn.LayerNorm(embed_dim)
131
+
132
+ def forward(self, x: torch.Tensor, memory, src_pad_mask=None, tgt_pad_mask=None):
133
+
134
+ x_len = x.shape[1]
135
+ mask = torch.ones(size=(1, 1, x_len, x_len), device=x.device, dtype=torch.bool).tril()
136
+
137
+ x = x + self.self_attn(x, x, x, mask=mask, pad_mask=tgt_pad_mask)
138
+
139
+ x = self.norm0(x)
140
+
141
+ x = x + self.cross_attn(x, memory, memory, pad_mask=src_pad_mask)
142
+ x = self.norm1(x)
143
+
144
+ x = x + self.ffn(x)
145
+
146
+ x = self.norm2(x)
147
+
148
+ return x
149
+
150
+
151
+ class Decoder(nn.Module):
152
+ def __init__(self, config: Config):
153
+ super().__init__()
154
+ self.layers = nn.ModuleList([])
155
+ for _ in range(config.decoder_layer):
156
+ self.layers.append(
157
+ DecoderLayer(config.embed_dim, config.num_heads, config.drop_out)
158
+ )
159
+
160
+ def forward(self, x: torch.Tensor, memory, src_pad_mask=None, tgt_pad_mask=None):
161
+ for layer in self.layers:
162
+ x = layer(x, memory, src_pad_mask=src_pad_mask, tgt_pad_mask=tgt_pad_mask)
163
+ return x
164
+
165
+
166
+ class PositionEmbedding(nn.Module):
167
+ def __init__(self, config: Config):
168
+ super().__init__()
169
+ pe = torch.zeros(config.max_len, config.embed_dim)
170
+ pos = torch.arange(0, config.max_len, 1).float().unsqueeze(1)
171
+ _2i = torch.arange(0, config.embed_dim, 2)
172
+ pe[:, 0::2] = torch.sin(pos / (10000 ** (_2i / config.embed_dim)))
173
+ pe[:, 1::2] = torch.cos(pos / (10000 ** (_2i / config.embed_dim)))
174
+
175
+ pe = pe.unsqueeze(0)
176
+ self.register_buffer("pe", pe)
177
+
178
+ def forward(self, x):
179
+ x_len = x.shape[1]
180
+ return x + self.pe[:, :x_len].to(dtype=x.dtype)
181
+
182
+
183
+ class TranslateModel(nn.Module):
184
+ def __init__(self, config: Config):
185
+ super().__init__()
186
+ self.position_embedding = PositionEmbedding(config=config)
187
+ self.encoder = Encoder(config=config)
188
+ self.decoder = Decoder(config=config)
189
+ self.embedding = nn.Embedding(config.vocab_size, config.embed_dim)
190
+ self.head = nn.Linear(config.embed_dim, config.vocab_size)
191
+ self.drop = nn.Dropout(config.drop_out)
192
+
193
+ def forward(
194
+ self,
195
+ src: torch.Tensor,
196
+ tgt: torch.Tensor,
197
+ src_pad_mask=None,
198
+ tgt_pad_mask=None,
199
+ ):
200
+
201
+ ## encoder
202
+
203
+ src_embedding = self.embedding(src)
204
+ src_embedding = self.position_embedding(src_embedding)
205
+ memory = self.encoder.forward(src_embedding, src_pad_mask)
206
+
207
+ tgt_embedding = self.embedding(tgt)
208
+ tgt_embedding = self.position_embedding(tgt_embedding)
209
+
210
+ output = self.decoder.forward(tgt_embedding, memory, src_pad_mask, tgt_pad_mask)
211
+
212
+ output = self.drop(output)
213
+
214
+ output = self.head(output)
215
+
216
+ return output
src/sample.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from tokenizers import Tokenizer
3
+ import torch
4
+ from src.config import Config
5
+ from src.model import TranslateModel
6
+
7
+
8
+
9
+
10
+ def get_args():
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument("--ckpt_path", default="checkpoints/translate-step=290000.ckpt")
13
+ parser.add_argument("--zh", default="早上好")
14
+ return parser.parse_args()
15
+
16
+ class Inference:
17
+ def __init__(self,config:Config, ckpt_path):
18
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ self.tokenizer:Tokenizer = Tokenizer.from_file(config.tokenizer_file)
20
+ self.model:TranslateModel = TranslateModel(config)
21
+ ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
22
+ state_dict = {}
23
+ for k, v in ckpt.items():
24
+ new_k = k[len("net._orig_mod."):]
25
+ state_dict[new_k] = v
26
+ self.model.load_state_dict(state_dict, strict=True)
27
+ self.model.eval()
28
+ self.model = self.model.to(self.device)
29
+ self.config = config
30
+
31
+
32
+ @torch.no_grad()
33
+ def sampler(self, src:str)->str:
34
+ src = self.tokenizer.encode(src).ids
35
+ tgt = [self.tokenizer.token_to_id("[SOS]")]
36
+ max_len = self.config.max_len
37
+ EOS = self.tokenizer.token_to_id("[EOS]")
38
+
39
+ src = torch.tensor(src, dtype=torch.long).to(self.device).unsqueeze(0)
40
+ tgt = torch.tensor(tgt, dtype=torch.long).to(self.device).unsqueeze(0)
41
+
42
+ for _ in range(1, max_len):
43
+ logits = self.model.forward(src, tgt) # [1, len, vocab]
44
+ logits = logits[:,-1,:]
45
+ logits = torch.softmax(logits, dim=-1)
46
+ index = torch.argmax(logits, dim=-1)
47
+ tgt = torch.cat((tgt, index.unsqueeze(0)), dim=-1)
48
+ if index.detach().cpu().item() == EOS:
49
+ break
50
+
51
+ tgt = tgt.detach().cpu().squeeze(0).tolist()
52
+ tgt_str = self.tokenizer.decode(tgt)
53
+ return tgt_str
54
+
55
+
56
+ def main():
57
+ args = get_args()
58
+ config = Config()
59
+ inference = Inference(config, args.ckpt_path)
60
+ zh = args.zh
61
+ result = inference.sampler(zh)
62
+ print(f"中文:{zh}")
63
+ print(f"English:{result}")
64
+
65
+
66
+ if __name__ == "__main__":
67
+ main()
src/train.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import argparse
4
+ from torch import nn
5
+ from torch.utils.data import DataLoader, random_split
6
+ from torchmetrics import MeanMetric, MaxMetric
7
+ import lightning as L
8
+ from torchmetrics.classification.accuracy import Accuracy
9
+ from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
10
+ from tokenizers import Tokenizer
11
+ from src.config import Config
12
+ from src.model import TranslateModel
13
+ from src.dataset import TranslateDataset
14
+ from lightning.pytorch.loggers import TensorBoardLogger
15
+ from lightning.pytorch.callbacks import RichProgressBar
16
+ from lightning.pytorch.callbacks import ModelCheckpoint
17
+
18
+
19
+ import argparse
20
+
21
+
22
+ def parser_args():
23
+ parser = argparse.ArgumentParser(description="Training configuration")
24
+
25
+ parser.add_argument("--encoder_layer", type=int, default=6, help="Number of encoder layers")
26
+ parser.add_argument("--decoder_layer", type=int, default=6, help="Number of decoder layers")
27
+ parser.add_argument("--embed_dim", type=int, default=512, help="Embedding dimension size")
28
+ parser.add_argument("--num_heads", type=int, default=8, help="Number of attention heads")
29
+ parser.add_argument("--drop_out", type=float, default=0.1, help="Dropout rate")
30
+ parser.add_argument("--max_len", type=int, default=256, help="Maximum sequence length")
31
+ parser.add_argument("--vocab_size", type=int, default=30000, help="Vocabulary size")
32
+
33
+ parser.add_argument("--wmt_zh_en_path", type=str,
34
+ default="data/wmt_zh_en_training_corpus.csv",
35
+ help="Path to WMT zh-en training corpus")
36
+ parser.add_argument("--tokenizer_file", type=str,
37
+ default="checkpoints/tokenizer.json",
38
+ help="Path to tokenizer file")
39
+ parser.add_argument("--batch_size", type=int, default=64, help="Batch size for training")
40
+ parser.add_argument("--compile", action="store_true", help="Enable torch.compile if available")
41
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
42
+ parser.add_argument("--val_ratio", type=float, default=0.1, help="Validation data ratio")
43
+ parser.add_argument("--num_workers", type=int, default=4, help="Number of data loading workers")
44
+ parser.add_argument("--pin_memory", action="store_true", help="Use pinned memory in dataloader")
45
+
46
+ parser.add_argument("--tensorboard_dir", type=str, default="log/tensorboard",
47
+ help="Directory for tensorboard logs")
48
+ parser.add_argument("--checkpoint_dir", type=str, default="log/checkpoint",
49
+ help="Directory for saving checkpoints")
50
+
51
+ # -------- optimizer / scheduler 参数 --------
52
+ parser.add_argument("--base_lr", type=float, default=3e-4, help="Base learning rate")
53
+ parser.add_argument("--betas", type=float, nargs=2, default=(0.9, 0.98), help="AdamW betas")
54
+ parser.add_argument("--eps", type=float, default=1e-9, help="AdamW epsilon")
55
+ parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
56
+
57
+ parser.add_argument("--warmup_ratio", type=float, default=0.0005,
58
+ help="Warmup ratio (fraction of total steps)")
59
+ parser.add_argument("--start_factor", type=float, default=1e-3,
60
+ help="Linear warmup start factor (relative LR scale)")
61
+ parser.add_argument("--end_factor", type=float, default=1.0,
62
+ help="Linear warmup end factor (relative LR scale)")
63
+ parser.add_argument("--eta_min", type=float, default=3e-6,
64
+ help="Minimum LR in cosine annealing")
65
+
66
+ parser.add_argument("--max_epochs", type=int, default=10, help="Number of epochs to train")
67
+
68
+ # -------- dataset cache 参数 --------
69
+ parser.add_argument("--data_cache_dir", type=str, default="data/cache.pickle",
70
+ help="Path to cache file for dataset")
71
+ parser.add_argument("--use_cache", action="store_true",
72
+ help="Enable dataset caching with pickle")
73
+
74
+ parser.add_argument("--every_n_train_steps", type=int, default=10000,
75
+ help="Save checkpoint every N training steps")
76
+
77
+ return parser.parse_args()
78
+
79
+
80
+
81
+ def merge_args_config(config: Config, args):
82
+ for k, v in vars(args).items():
83
+ setattr(config, k, v)
84
+ return config
85
+
86
+
87
+ class TranslateLitModule(L.LightningModule):
88
+ def __init__(self, config: Config):
89
+ super().__init__()
90
+
91
+ tokenizer: Tokenizer = Tokenizer.from_file(config.tokenizer_file)
92
+ self.pad_id = tokenizer.token_to_id("[PAD]")
93
+ self.net = TranslateModel(config=config)
94
+ self.train_loss = MeanMetric()
95
+ self.train_acc = Accuracy(
96
+ task="multiclass", num_classes=config.vocab_size, ignore_index=self.pad_id
97
+ )
98
+ self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_id)
99
+
100
+ self.val_loss = MeanMetric()
101
+ self.val_acc = Accuracy(
102
+ task="multiclass", num_classes=config.vocab_size, ignore_index=self.pad_id
103
+ )
104
+
105
+ self.test_loss = MeanMetric()
106
+ self.test_acc = Accuracy(
107
+ task="multiclass", num_classes=config.vocab_size, ignore_index=self.pad_id
108
+ )
109
+
110
+ self.val_acc_best = MaxMetric()
111
+
112
+ self.config = config
113
+
114
+ def forward(self, batch) -> torch.Tensor:
115
+ pred = self.net.forward(
116
+ src=batch["src"],
117
+ tgt=batch["tgt"],
118
+ src_pad_mask=batch["src_pad_mask"],
119
+ tgt_pad_mask=batch["tgt_pad_mask"],
120
+ )
121
+ return pred
122
+
123
+ def on_train_start(self) -> None:
124
+ self.train_loss.reset()
125
+ self.train_acc.reset()
126
+
127
+ self.val_loss.reset()
128
+ self.val_acc.reset()
129
+ self.val_acc_best.reset()
130
+
131
+ self.test_loss.reset()
132
+ self.test_acc.reset()
133
+
134
+ def model_step(self, batch):
135
+
136
+ logits = self.forward(batch)
137
+
138
+ B, L, C = logits.shape
139
+
140
+ loss = self.criterion(logits.reshape(-1, C), batch["label"].reshape(-1))
141
+ preds = torch.argmax(logits, dim=-1)
142
+ return loss, preds.reshape(-1), batch["label"].reshape(-1)
143
+
144
+ def training_step(self, batch, batch_idx):
145
+ loss, preds, targets = self.model_step(batch)
146
+
147
+ # update and log metrics
148
+ self.train_loss(loss)
149
+ self.train_acc(preds, targets)
150
+ self.log(
151
+ "train/loss", self.train_loss, on_step=True, on_epoch=True, prog_bar=True
152
+ )
153
+ self.log(
154
+ "train/acc", self.train_acc, on_step=True, on_epoch=True, prog_bar=True
155
+ )
156
+
157
+ # return loss or backpropagation will fail
158
+ return loss
159
+
160
+ def on_train_epoch_end(self) -> None:
161
+ pass
162
+
163
+ def validation_step(self, batch, batch_idx: int) -> None:
164
+ loss, preds, targets = self.model_step(batch)
165
+
166
+ # update and log metrics
167
+ self.val_loss(loss)
168
+ self.val_acc(preds, targets)
169
+ self.log("val/loss", self.val_loss, on_step=True, on_epoch=True, prog_bar=True)
170
+ self.log("val/acc", self.val_acc, on_step=True, on_epoch=True, prog_bar=True)
171
+
172
+ def on_validation_epoch_end(self) -> None:
173
+ acc = self.val_acc.compute() # get current val acc
174
+ self.val_acc_best(acc) # update best so far val acc
175
+ self.log(
176
+ "val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True
177
+ )
178
+
179
+ def test_step(self, batch, batch_idx: int) -> None:
180
+ loss, preds, targets = self.model_step(batch)
181
+
182
+ # update and log metrics
183
+ self.test_loss(loss)
184
+ self.test_acc(preds, targets)
185
+ self.log(
186
+ "test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True
187
+ )
188
+ self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)
189
+
190
+ def on_test_epoch_end(self) -> None:
191
+ """Lightning hook that is called when a test epoch ends."""
192
+ pass
193
+
194
+ def setup(self, stage: str) -> None:
195
+ if self.config.compile and stage == "fit":
196
+ self.net = torch.compile(self.net)
197
+
198
+ def configure_optimizers(self):
199
+
200
+ optimizer = torch.optim.AdamW(
201
+ self.parameters(),
202
+ lr=self.config.base_lr,
203
+ betas=self.config.betas,
204
+ eps=self.config.eps,
205
+ weight_decay=self.config.weight_decay,
206
+ )
207
+
208
+
209
+ total_steps = self.trainer.estimated_stepping_batches
210
+
211
+ warmup_steps = max(1, int(self.config.warmup_ratio * total_steps))
212
+ cosine_steps = max(1, total_steps - warmup_steps)
213
+
214
+ warmup = LinearLR(
215
+ optimizer,
216
+ start_factor=self.config.start_factor,
217
+ end_factor=self.config.end_factor,
218
+ total_iters=warmup_steps,
219
+ )
220
+
221
+
222
+ cosine = CosineAnnealingLR(
223
+ optimizer,
224
+ T_max=cosine_steps,
225
+ eta_min=self.config.eta_min,
226
+ )
227
+
228
+ scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])
229
+
230
+ return {
231
+ "optimizer": optimizer,
232
+ "lr_scheduler": {
233
+ "scheduler": scheduler,
234
+ "interval": "step", # 每个 step 调整
235
+ "frequency": 1
236
+ }
237
+ }
238
+
239
+
240
+ def prepare_dataloader(dataset, config: Config, shuffle=True):
241
+ dataloader = DataLoader(
242
+ dataset,
243
+ batch_size=config.batch_size,
244
+ shuffle=shuffle,
245
+ num_workers=config.num_workers,
246
+ pin_memory=config.pin_memory,
247
+ )
248
+ return dataloader
249
+
250
+
251
+ def prepare_dataset(config: Config):
252
+ full_ds = TranslateDataset(config=config)
253
+ val_ratio = config.val_ratio
254
+ val_len = int(len(full_ds) * val_ratio)
255
+ train_len = len(full_ds) - val_len
256
+ train_ds, val_ds = random_split(
257
+ full_ds,
258
+ [train_len, val_len],
259
+ generator=torch.Generator().manual_seed(config.seed),
260
+ )
261
+ return prepare_dataloader(train_ds, config), prepare_dataloader(
262
+ val_ds, config, False
263
+ )
264
+
265
+
266
+ def prepare_callback(config: Config):
267
+ logger = TensorBoardLogger(save_dir=config.tensorboard_dir, name="runs")
268
+ rich_progress_bar = RichProgressBar()
269
+ checkpoint = ModelCheckpoint(
270
+ dirpath=config.checkpoint_dir,
271
+ filename="translate-{step:05d}",
272
+ save_weights_only=True,
273
+ every_n_train_steps=config.every_n_train_steps,
274
+ save_top_k=-1,
275
+ )
276
+ return logger, [rich_progress_bar, checkpoint]
277
+
278
+
279
+ def main():
280
+
281
+ args = parser_args()
282
+ config = merge_args_config(Config(), args)
283
+
284
+ L.seed_everything(config.seed)
285
+
286
+ train_loader, val_loader = prepare_dataset(config)
287
+
288
+ model = TranslateLitModule(config=config)
289
+
290
+ logger, callbacks = prepare_callback(config)
291
+
292
+ trainer = L.Trainer(callbacks=callbacks, logger=logger, max_epochs=config.max_epochs)
293
+
294
+ trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)
295
+
296
+
297
+ if __name__ == "__main__":
298
+ main()
src/train_tokenizer.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ from tokenizers import models, Tokenizer, normalizers, pre_tokenizers, decoders, trainers, processors
3
+
4
+
5
+ def text_iterator(file_path):
6
+ with open(file_path, "r", encoding="utf-8") as f:
7
+ reader = csv.DictReader(f)
8
+ for row in reader:
9
+ text = row['0'] + " " + row['1']
10
+ yield text
11
+
12
+ tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
13
+
14
+ tokenizer.normalizer = normalizers.Sequence([
15
+ normalizers.NFKC()
16
+ ])
17
+
18
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
19
+
20
+ tokenizer.decoder = decoders.ByteLevel()
21
+
22
+ trainer = trainers.BpeTrainer(
23
+ vocab_size=30_000,
24
+ min_frequency=2,
25
+ special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"],
26
+ )
27
+
28
+ tokenizer.train_from_iterator(text_iterator("data/wmt_zh_en_training_corpus.csv"), trainer=trainer)
29
+
30
+
31
+
32
+ tokenizer.post_processor = processors.TemplateProcessing(
33
+ single="[SOS] $A [EOS]",
34
+ pair="[SOS] $A [EOS] $B [EOS]",
35
+ special_tokens=[
36
+ ("[SOS]", tokenizer.token_to_id("[SOS]")),
37
+ ("[EOS]", tokenizer.token_to_id("[EOS]")),
38
+ ],
39
+ )
40
+
41
+ tokenizer.save("checkpoints/tokenizer.json")