Upload model

- model.safetensors +1 -1
- modeling_transformer.py +36 -2
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:3c6565fcbb6c375fc4b9c112ed3d73602e2741f45d6f7bea766a81c603e2f0be
 size 250204
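The model.safetensors change above is only an updated Git LFS pointer: the sha256 oid changes while the recorded size of 250204 bytes stays the same. As a side note, here is a minimal sketch of checking a locally downloaded copy of the weights against that oid; the local file path is a hypothetical placeholder and is not part of this commit.

import hashlib

# Expected digest taken from the LFS pointer in this commit
expected = "3c6565fcbb6c375fc4b9c112ed3d73602e2741f45d6f7bea766a81c603e2f0be"

# "model.safetensors" is assumed to be the locally downloaded file
with open("model.safetensors", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

print("oid matches:", digest == expected)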
modeling_transformer.py
CHANGED
@@ -71,7 +71,7 @@ def masked_softmax(X, valid_lens):
             valid_lens = valid_lens.reshape(-1)
         # Replace masked elements on the last axis with a very large negative value, so that their softmax output is 0
         X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
-                          value=-
+                          value=-1e4)
         return nn.functional.softmax(X.reshape(shape), dim=-1)
 
 class DotProductAttention(nn.Module):
@@ -411,4 +411,38 @@ class transformerModel(PreTrainedModel):
     def forward(self, enc_X, dec_X, *args):
         enc_outputs = self.encoder(enc_X, *args)
         dec_state = self.decoder.init_state(enc_outputs, *args)
-        return self.decoder(dec_X, dec_state)
+        return self.decoder(dec_X, dec_state)
+
+def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps,
+                    device, save_attention_weights=False):
+    """Prediction with a sequence-to-sequence model
+
+    Defined in :numref:`sec_seq2seq_training`"""
+    # Set net to evaluation mode for prediction
+    net.eval()
+    src_tokens = src_vocab[src_sentence.lower().split(' ')] + [
+        src_vocab['<eos>']]
+    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
+    src_tokens = d2l.truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])
+    # Add the batch axis
+    enc_X = torch.unsqueeze(
+        torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)
+    enc_outputs = net.encoder(enc_X, enc_valid_len)
+    dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)
+    # Add the batch axis
+    dec_X = torch.unsqueeze(torch.tensor(
+        [tgt_vocab['<bos>']], dtype=torch.long, device=device), dim=0)
+    output_seq, attention_weight_seq = [], []
+    for _ in range(num_steps):
+        Y, dec_state = net.decoder(dec_X, dec_state)
+        # Use the token with the highest prediction probability as the decoder input at the next time step
+        dec_X = Y.argmax(dim=2)
+        pred = dec_X.squeeze(dim=0).type(torch.int32).item()
+        # Save attention weights (to be covered later)
+        if save_attention_weights:
+            attention_weight_seq.append(net.decoder.attention_weights)
+        # Once the end-of-sequence token is predicted, generation of the output sequence is complete
+        if pred == tgt_vocab['<eos>']:
+            break
+        output_seq.append(pred)
+    return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq
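For context, a minimal usage sketch of the newly added predict_seq2seq helper follows. It assumes a trained transformerModel instance named net plus the d2l-style src_vocab and tgt_vocab objects from the training setup, none of which ship with this commit; the example source sentence and the num_steps value are likewise placeholders.

import torch

# Assumptions (not part of this commit): `net` is a trained transformerModel,
# and `src_vocab` / `tgt_vocab` are the vocabularies used during training.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_steps = 10  # hypothetical maximum length of the generated sequence

net = net.to(device)
translation, attention_weight_seq = predict_seq2seq(
    net, 'go .', src_vocab, tgt_vocab, num_steps, device,
    save_attention_weights=True)
print(translation)

Decoding here is greedy: at each step the argmax token is fed back as the next decoder input, and generation stops at the first <eos> token or after num_steps steps.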