bu1 committed
Commit 4111940 · verified · 1 Parent(s): 0024a9c

Upload model

Files changed (2)
  1. model.safetensors +1 -1
  2. modeling_transformer.py +36 -2
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7a813c542acf675006e2699e2a1070a166e99f9c1cb225ff613a76c0fda09df2
+oid sha256:3c6565fcbb6c375fc4b9c112ed3d73602e2741f45d6f7bea766a81c603e2f0be
 size 250204
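Only the LFS pointer for model.safetensors changes (same byte size, new content hash), i.e. the weights were re-uploaded. As a sanity check, a minimal sketch (assuming a local clone with LFS files materialized, so model.safetensors contains the actual weights rather than the pointer text) to confirm the file matches the new oid:

```python
# Sanity-check sketch (assumption: model.safetensors is the real weights file
# from a local git-lfs clone, not the small pointer file itself).
import hashlib

EXPECTED_OID = "3c6565fcbb6c375fc4b9c112ed3d73602e2741f45d6f7bea766a81c603e2f0be"

def sha256_of(path, chunk_size=1 << 20):
    """Hash the file in 1 MiB chunks and return the hex digest."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

assert sha256_of("model.safetensors") == EXPECTED_OID, "hash mismatch with LFS pointer"
```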
modeling_transformer.py CHANGED
@@ -71,7 +71,7 @@ def masked_softmax(X, valid_lens):
             valid_lens = valid_lens.reshape(-1)
         # On the last axis, replace masked elements with a very large negative value so that their softmax output is 0
         X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
-                          value=-1e6)
+                          value=-1e4)
         return nn.functional.softmax(X.reshape(shape), dim=-1)
 
 class DotProductAttention(nn.Module):
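The only change in masked_softmax is the fill value for masked positions, from -1e6 to -1e4. The commit does not say why, but a likely reason is half-precision safety: float16 cannot represent magnitudes beyond roughly 65504, so -1e6 overflows to -inf when scores are cast to fp16, whereas -1e4 stays finite and is still negative enough to drive the softmax output of masked positions to ~0. A small illustrative sketch (the fp16 motivation is an assumption, not stated in this commit):

```python
# Illustration only; the fp16 motivation for -1e4 is an assumption.
import torch

# float16 cannot represent magnitudes beyond ~65504, so -1e6 overflows to -inf
print(torch.tensor(-1e6).to(torch.float16))  # tensor(-inf, dtype=torch.float16)
print(torch.tensor(-1e4).to(torch.float16))  # tensor(-10000., dtype=torch.float16)

# -1e4 is still large enough for a masked logit to receive ~0 probability
scores = torch.tensor([2.0, -1e4, 0.5])
print(torch.softmax(scores, dim=-1))         # ~[0.82, 0.00, 0.18]
```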
@@ -411,4 +411,38 @@ class transformerModel(PreTrainedModel):
     def forward(self, enc_X, dec_X, *args):
         enc_outputs = self.encoder(enc_X, *args)
         dec_state = self.decoder.init_state(enc_outputs, *args)
-        return self.decoder(dec_X, dec_state)
+        return self.decoder(dec_X, dec_state)
+
+def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps,
+                    device, save_attention_weights=False):
+    """Prediction with a sequence-to-sequence model.
+
+    Defined in :numref:`sec_seq2seq_training`"""
+    # Set net to evaluation mode for prediction
+    net.eval()
+    src_tokens = src_vocab[src_sentence.lower().split(' ')] + [
+        src_vocab['<eos>']]
+    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
+    src_tokens = d2l.truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])
+    # Add the batch axis
+    enc_X = torch.unsqueeze(
+        torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)
+    enc_outputs = net.encoder(enc_X, enc_valid_len)
+    dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)
+    # Add the batch axis
+    dec_X = torch.unsqueeze(torch.tensor(
+        [tgt_vocab['<bos>']], dtype=torch.long, device=device), dim=0)
+    output_seq, attention_weight_seq = [], []
+    for _ in range(num_steps):
+        Y, dec_state = net.decoder(dec_X, dec_state)
+        # Use the token with the highest predicted probability as the decoder input at the next time step
+        dec_X = Y.argmax(dim=2)
+        pred = dec_X.squeeze(dim=0).type(torch.int32).item()
+        # Save attention weights (to be discussed later)
+        if save_attention_weights:
+            attention_weight_seq.append(net.decoder.attention_weights)
+        # Once the end-of-sequence token is predicted, generation of the output sequence is complete
+        if pred == tgt_vocab['<eos>']:
+            break
+        output_seq.append(pred)
+    return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq
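The new predict_seq2seq helper decodes greedily: it feeds the argmax token from each step back in as the next decoder input and stops at <eos>. A usage sketch; net, src_vocab, tgt_vocab, and the example sentence are assumptions (a trained transformerModel and d2l-style vocabularies, none of which are defined in this commit):

```python
# Usage sketch. Assumptions: `net` is a trained transformerModel and
# `src_vocab` / `tgt_vocab` are d2l-style vocabularies; none of these objects
# are part of this commit.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_steps = 10  # maximum number of generated target tokens

translation, attn_weights = predict_seq2seq(
    net, 'hello world .', src_vocab, tgt_vocab, num_steps,
    device, save_attention_weights=True)
print(translation)        # space-joined predicted target tokens
print(len(attn_weights))  # one entry per decoding step when weights are saved
```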