xiaosena committed
Commit 474a5df · verified · 1 Parent(s): c8e2afa

Model training source code

Files changed (1)
  1. fanyi.py +356 -0
fanyi.py ADDED
@@ -0,0 +1,356 @@
+ import warnings
+ warnings.filterwarnings("ignore")
+ import torch
+ import torch.nn as nn
+ import math
+ from transformers import MarianTokenizer
+ from datasets import load_dataset
+ from typing import List
+ from torch import Tensor
+ from torch.nn import Transformer
+ from torch.nn.utils.rnn import pad_sequence
+ from torch.utils.data import DataLoader, Dataset
+ from timeit import default_timer as timer
+ import urllib.request
+ import os
+ from torch.cuda.amp import GradScaler, autocast
+ import logging
+
+ logging.getLogger("datasets").setLevel(logging.ERROR)
+
+ print("CUDA available:", torch.cuda.is_available())
+ print("PyTorch version:", torch.__version__)
+ if torch.cuda.is_available():
+     print("CUDA version:", torch.version.cuda)
+
+ # Select the device
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ print("Using device:", DEVICE)
+ if torch.cuda.is_available():
+     print(f"GPU: {torch.cuda.get_device_name(0)}")
+     print(f"GPU memory allocated: {torch.cuda.memory_allocated(0)/1024**2:.2f} MB")
+
+ # Initialize the tokenizer. Only the MarianMT tokenizer is used here,
+ # not the model's pretrained translation weights.
+ tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-de-en')
+
+ # Indices of the special tokens
+ PAD_IDX = tokenizer.pad_token_id
+ BOS_IDX = tokenizer.bos_token_id
+ EOS_IDX = tokenizer.eos_token_id
+ UNK_IDX = tokenizer.unk_token_id
+ # MarianTokenizer defines no BOS token; fall back to the pad token,
+ # which Marian models also use as the decoder start token.
+ if BOS_IDX is None:
+     BOS_IDX = PAD_IDX
+
+ # Vocabulary sizes
+ SRC_VOCAB_SIZE = tokenizer.vocab_size
+ TGT_VOCAB_SIZE = tokenizer.vocab_size
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000):
+         super(PositionalEncoding, self).__init__()
+         den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
+         pos = torch.arange(0, maxlen).reshape(maxlen, 1)
+         pos_embedding = torch.zeros((maxlen, emb_size))
+         pos_embedding[:, 0::2] = torch.sin(pos * den)
+         pos_embedding[:, 1::2] = torch.cos(pos * den)
+         pos_embedding = pos_embedding.unsqueeze(-2)
+         self.dropout = nn.Dropout(dropout)
+         self.register_buffer('pos_embedding', pos_embedding)
+
+     def forward(self, token_embedding: Tensor):
+         return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])
+
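+ # Shape note: nn.Transformer defaults to batch_first=False, so token embeddings
+ # arrive as (seq_len, batch_size, emb_size); pos_embedding has shape
+ # (maxlen, 1, emb_size) and broadcasts across the batch dimension.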
+ class TokenEmbedding(nn.Module):
+     def __init__(self, vocab_size: int, emb_size):
+         super(TokenEmbedding, self).__init__()
+         self.embedding = nn.Embedding(vocab_size, emb_size)
+         self.emb_size = emb_size
+
+     def forward(self, tokens: Tensor):
+         return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
+
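+ # The sqrt(emb_size) factor follows "Attention Is All You Need", which scales
+ # embeddings by sqrt(d_model) so they match the scale of the positional encodings.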
+ class Seq2SeqTransformer(nn.Module):
+     def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
+                  emb_size: int, nhead: int, src_vocab_size: int,
+                  tgt_vocab_size: int, dim_feedforward: int = 512, dropout: float = 0.1):
+         super(Seq2SeqTransformer, self).__init__()
+         self.transformer = Transformer(d_model=emb_size,
+                                        nhead=nhead,
+                                        num_encoder_layers=num_encoder_layers,
+                                        num_decoder_layers=num_decoder_layers,
+                                        dim_feedforward=dim_feedforward,
+                                        dropout=dropout)
+         self.generator = nn.Linear(emb_size, tgt_vocab_size)
+         self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
+         self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
+         self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)
+
+     def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
+                 tgt_mask: Tensor, src_padding_mask: Tensor,
+                 tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
+         src_emb = self.positional_encoding(self.src_tok_emb(src))
+         tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
+         outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
+                                 src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
+         return self.generator(outs)
+
+     def encode(self, src: Tensor, src_mask: Tensor):
+         return self.transformer.encoder(self.positional_encoding(self.src_tok_emb(src)), src_mask)
+
+     def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
+         return self.transformer.decoder(self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask)
+
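+ # Tensor shapes through forward: src is (src_len, batch) of token ids, src_emb is
+ # (src_len, batch, emb_size), the transformer output is (tgt_len, batch, emb_size),
+ # and the generator projects it to (tgt_len, batch, tgt_vocab_size) logits.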
+ def generate_square_subsequent_mask(sz):
+     mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
+     mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+     return mask
+
+ def create_mask(src, tgt):
+     src_seq_len = src.shape[0]
+     tgt_seq_len = tgt.shape[0]
+
+     tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
+     src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)
+
+     src_padding_mask = (src == PAD_IDX).transpose(0, 1)
+     tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
+     return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
+
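+ # For example, generate_square_subsequent_mask(3) produces the causal mask
+ #     [[0., -inf, -inf],
+ #      [0.,   0., -inf],
+ #      [0.,   0.,   0.]]
+ # so position i may only attend to positions <= i. The padding masks are boolean
+ # tensors of shape (batch, seq_len) that are True wherever the token is PAD.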
+ def download_multi30k():
+     base_url = "https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/"
+
+     # Create the data directory
+     os.makedirs("multi30k", exist_ok=True)
+
+     # Download the training, validation and test splits
+     splits = ['train', 'val', 'test']
+     languages = ['de', 'en']
+
+     for split in splits:
+         for lang in languages:
+             filename = f"{split}.{lang}"
+             url = f"{base_url}{filename}"
+             path = f"multi30k/{filename}"
+
+             if not os.path.exists(path):
+                 print(f"Downloading {filename}...")
+                 urllib.request.urlretrieve(url, path)
+
+ def load_data():
+     # Load the German-English pairs of the WMT14 dataset
+     dataset = load_dataset("wmt14", "de-en", cache_dir=".cache")
+
+     # To keep training manageable, only use a subset of the data
+     train_size = 29000  # roughly the size of the Multi30k training set
+     val_size = 1000
+     test_size = 1000
+
+     # Slice the dataset into plain lists of sentences
+     data = {
+         'train': {
+             'de': [item['de'] for item in dataset['train']['translation'][:train_size]],
+             'en': [item['en'] for item in dataset['train']['translation'][:train_size]]
+         },
+         'val': {
+             'de': [item['de'] for item in dataset['validation']['translation'][:val_size]],
+             'en': [item['en'] for item in dataset['validation']['translation'][:val_size]]
+         },
+         'test': {
+             'de': [item['de'] for item in dataset['test']['translation'][:test_size]],
+             'en': [item['en'] for item in dataset['test']['translation'][:test_size]]
+         }
+     }
+
+     return data
+
+ # A custom Dataset wrapping the raw sentence lists
+ class TranslationDataset(Dataset):
+     def __init__(self, de_texts, en_texts):
+         self.de_texts = de_texts
+         self.en_texts = en_texts
+
+     def __len__(self):
+         return len(self.de_texts)
+
+     def __getitem__(self, idx):
+         return {
+             'de': self.de_texts[idx],
+             'en': self.en_texts[idx]
+         }
+
+ print("Loading dataset...")
+ _cached_data = load_data()  # cache the data globally
+
+ def get_dataloader(split='train', batch_size=32):
+     # Use the cached data instead of reloading it
+     data = _cached_data[split]
+
+     # Build the Dataset object
+     dataset = TranslationDataset(data['de'], data['en'])
+
+     return DataLoader(
+         dataset,
+         batch_size=batch_size,
+         shuffle=(split == 'train')
+     )
+
+ # Model hyperparameters, chosen to keep GPU memory usage modest
+ BATCH_SIZE = 32          # reduced from 64
+ EMB_SIZE = 512           # unchanged
+ NHEAD = 8                # unchanged
+ FFN_HID_DIM = 512        # back to 512 (was briefly 1024)
+ NUM_ENCODER_LAYERS = 3   # back to 3 (was briefly 4)
+ NUM_DECODER_LAYERS = 3   # back to 3 (was briefly 4)
+ NUM_EPOCHS = 18          # unchanged
+
+ # Instantiate the model
+ transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
+                                  NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
+ transformer = transformer.to(DEVICE)
+
+ # Initialize the parameters
+ for p in transformer.parameters():
+     if p.dim() > 1:
+         nn.init.xavier_uniform_(p)
+
+ # Loss function and optimizer
+ loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
+ optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
+
+ # Gradient scaler for mixed-precision training
+ scaler = GradScaler()
+
+ def train_epoch(model, optimizer):
+     try:
+         model.train()
+         losses = 0
+         train_dataloader = get_dataloader('train', BATCH_SIZE)
+
+         for batch in train_dataloader:
+             src_texts = batch['de']
+             tgt_texts = batch['en']
+
+             # Automatic mixed precision for the forward pass
+             with autocast():
+                 src_tokens = tokenizer(src_texts, padding=True, return_tensors='pt')
+                 tgt_tokens = tokenizer(tgt_texts, padding=True, return_tensors='pt')
+
+                 src = src_tokens['input_ids'].transpose(0, 1).to(DEVICE)
+                 tgt = tgt_tokens['input_ids'].transpose(0, 1).to(DEVICE)
+
+                 tgt_input = tgt[:-1, :]
+                 src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
+
+                 logits = model(src, tgt_input, src_mask, tgt_mask,
+                                src_padding_mask, tgt_padding_mask, src_padding_mask)
+
+                 tgt_out = tgt[1:, :]
+                 loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
+
+             optimizer.zero_grad()
+             scaler.scale(loss).backward()
+             scaler.step(optimizer)
+             scaler.update()
+             losses += loss.item()
+
+         return losses / len(train_dataloader)
+     except KeyboardInterrupt:
+         print("\nTraining interrupted! Saving the current model state...")
+         # Save a checkpoint; the epoch/loss values live in the enclosing training
+         # loop and may not exist yet if the very first epoch was interrupted.
+         checkpoint = {
+             'model_state_dict': model.state_dict(),
+             'optimizer_state_dict': optimizer.state_dict(),
+             'epoch': globals().get('epoch'),
+             'train_loss': globals().get('train_loss'),
+             'val_loss': globals().get('val_loss')
+         }
+         torch.save(checkpoint, 'transformer_translation.pth')
+         print("Checkpoint saved to transformer_translation.pth")
+         raise
+
+ def evaluate(model):
+     model.eval()
+     losses = 0
+     val_dataloader = get_dataloader('val', BATCH_SIZE)
+
+     # No gradients are needed while computing the validation loss
+     with torch.no_grad():
+         for batch in val_dataloader:
+             src_texts = batch['de']
+             tgt_texts = batch['en']
+
+             src_tokens = tokenizer(src_texts, padding=True, return_tensors='pt')
+             tgt_tokens = tokenizer(tgt_texts, padding=True, return_tensors='pt')
+
+             src = src_tokens['input_ids'].transpose(0, 1).to(DEVICE)
+             tgt = tgt_tokens['input_ids'].transpose(0, 1).to(DEVICE)
+
+             tgt_input = tgt[:-1, :]
+             src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
+
+             logits = model(src, tgt_input, src_mask, tgt_mask,
+                            src_padding_mask, tgt_padding_mask, src_padding_mask)
+
+             tgt_out = tgt[1:, :]
+             loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
+             losses += loss.item()
+
+     return losses / len(val_dataloader)
+
+ def greedy_decode(model, src, src_mask, max_len, start_symbol):
+     src = src.to(DEVICE)
+     src_mask = src_mask.to(DEVICE)
+
+     memory = model.encode(src, src_mask)
+     ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
+
+     for i in range(max_len - 1):
+         memory = memory.to(DEVICE)
+         tgt_mask = (generate_square_subsequent_mask(ys.size(0))
+                     .type(torch.bool)).to(DEVICE)
+         out = model.decode(ys, memory, tgt_mask)
+         out = out.transpose(0, 1)
+         prob = model.generator(out[:, -1])
+         _, next_word = torch.max(prob, dim=1)
+         next_word = next_word.item()
+
+         ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
+         if next_word == EOS_IDX:
+             break
+     return ys
+
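+ # Greedy decoding: at every step the highest-probability token is appended and the
+ # decoder is re-run on the growing prefix, stopping at EOS or after max_len tokens.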
+ def translate(model: torch.nn.Module, src_sentence: str):
+     model.eval()
+     tokens = tokenizer(src_sentence, return_tensors='pt', padding=True)
+     src = tokens['input_ids'].transpose(0, 1).to(DEVICE)
+     src_mask = (torch.zeros(src.shape[0], src.shape[0])).type(torch.bool).to(DEVICE)
+
+     tgt_tokens = greedy_decode(model, src, src_mask, max_len=src.shape[0] + 5, start_symbol=BOS_IDX).flatten()
+     return tokenizer.decode(tgt_tokens.tolist(), skip_special_tokens=True)
+
+ # Free cached GPU memory before training
+ if torch.cuda.is_available():
+     torch.cuda.empty_cache()
+
+
+ # Train the model
+ for epoch in range(1, NUM_EPOCHS + 1):
+     start_time = timer()
+     train_loss = train_epoch(transformer, optimizer)
+     end_time = timer()
+     val_loss = evaluate(transformer)
+     print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "
+           f"Epoch time = {(end_time - start_time):.3f}s")
+
+ # Save the model
+ path = 'transformer_translation.pth'
+ torch.save(transformer.state_dict(), path)
+ print("Model saved!")
+
+ # Reload the model
+ transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
+                                  NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
+ transformer.load_state_dict(torch.load(path))
+ transformer = transformer.to(DEVICE)
+ print("Model loaded!")
+
+ # Test translation
+ print(translate(transformer, "Eine Gruppe von Freunden spielt Billiade."))
+
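+ # Note: if training is interrupted, 'transformer_translation.pth' holds the checkpoint
+ # dict written in train_epoch rather than a bare state_dict. A minimal sketch for
+ # resuming from it (not part of the training run above) could be:
+ #
+ #     ckpt = torch.load(path)
+ #     if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
+ #         transformer.load_state_dict(ckpt['model_state_dict'])
+ #         optimizer.load_state_dict(ckpt['optimizer_state_dict'])
+ #     else:
+ #         transformer.load_state_dict(ckpt)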