Randolphzeng committed on
Commit
cf79171
1 Parent(s): 6188ee5

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +36 -56
README.md CHANGED
@@ -26,70 +26,50 @@ A deep VAE model pretrained on Wudao dataset. Both encoder and decoder are based
26
 
27
  ## 模型信息 Model Information
28
 
29
- 参考论文:[Fuse It More Deeply! A Variational Transformer with Layer-Wise Latent Variable Inference for Text Generation](https://arxiv.org/abs/2207.06130)
30
-
 
31
 
32
 
33
  ## 使用 Usage
34
 
35
  ```python
36
  # Check out the latest Fengshenbang-LM repository and run the following script from the Fengshenbang-LM root directory
37
- import sys
38
  import torch
39
- import argparse
40
  from torch.nn.utils.rnn import pad_sequence
41
- from fengshen.models.deepVAE.vae_pl_module import DeepVAEModule
42
-
43
-
44
-
45
- if __name__ == "__main__":
46
- # TODO: Update this path to the downloaded directory
47
- checkpoint_path = '..../Randeng-DELLA-226M-Chinese'
48
- gpt2_model_path = '..../Randeng-DELLA-226M-Chinese'
49
-
50
- args_parser = argparse.ArgumentParser()
51
- args_parser.add_argument("--checkpoint_path", type=str, default=checkpoint_path)
52
- args_parser.add_argument("--gpt2_model_path", type=str, default=gpt2_model_path)
53
- args_parser.add_argument("--latent_dim", type=int, default=256)
54
- args_parser.add_argument("--beta_kl_constraints_start", type=float, default=1e-5)
55
- args_parser.add_argument("--beta_kl_constraints_stop", type=float, default=1.)
56
- args_parser.add_argument("--beta_n_cycles", type=int, default=10)
57
- args_parser.add_argument("--latent_lmf_rank", type=int, default=4)
58
- args_parser.add_argument("--CVAE", action='store_true')
59
- args_parser.add_argument("--share_param", action='store_false',
60
- help="specify this argument if we want to share dec's and enc's params")
61
-
62
- args, unknown_args = args_parser.parse_known_args()
63
-
64
- # load model
65
- model, tokenizer = DeepVAEModule.load_model(args, labels_dict=None)
66
- # VAE generation
67
- sentence = "本模型是在通用数据集下预训练的VAE模型,如要获得最佳效果请在特定领域微调后使用。"
68
- tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentence))
69
- decoder_target = [tokenizer.bos_token_id] + tokenized_text + [tokenizer.eos_token_id]
70
- inputs = []
71
- inputs.append(torch.tensor(decoder_target, dtype=torch.long))
72
- inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
73
-
74
- max_length = 256
75
- top_p = 0.5
76
- top_k = 0
77
- temperature = .7
78
- repetition_penalty = 1.0
79
- sample = False
80
- device = 0
81
- model = model.eval()
82
- model = model.to(device)
83
-
84
- outputs = model.inference(inputs.to(device), top_p=top_p, top_k=top_k, max_length=max_length, sample=sample,
85
- temperature=temperature, repetition_penalty=repetition_penalty)
86
-
87
- for gen_sent, orig_sent in zip(outputs, inputs):
88
- print('orig_sent:', tokenizer.decode(orig_sent).replace(' ', ''))
89
- print('gen_sent:', tokenizer.decode(gen_sent).replace(' ', ''))
90
- print("-"*20)
91
-
92
-
93
 
94
 
95
  ```
 
26
 
27
  ## 模型信息 Model Information
28
 
29
+ 参考论文 Reference Paper:[Fuse It More Deeply! A Variational Transformer with Layer-Wise Latent Variable Inference for Text Generation](https://arxiv.org/abs/2207.06130)
30
+ 本模型使用了Della论文里的循环潜在向量架构,但对于解码器生成并未采用原论文的low-rank-tensor-product来进行信息融合,而是使用了简单的线性变换后逐位逐词添加的方式。该方式对于开放域数据集的预训练稳定性有较大正向作用。
31
+ Note that although we adopted the layer-wise recurrent latent variable structure as in the paper, we did not use the low-rank-tensor-product to fuse the latent vectors into the decoder hidden states. Instead, we applied a simple linear transformation to the latent vectors and then added them to the hidden states independently.
32
 
33
 
34
  ## 使用 Usage
35
 
36
  ```python
37
  # Check out the latest Fengshenbang-LM repository and run the following script from the Fengshenbang-LM root directory
38
+
39
  import torch
 
40
  from torch.nn.utils.rnn import pad_sequence
41
+ from fengshen.models.deepVAE.deep_vae import Della
42
+ from transformers.models.bert.tokenization_bert import BertTokenizer
43
+
44
+ tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Randeng-DELLA-226M-Chinese")
45
+ vae_model = Della.from_pretrained("IDEA-CCNL/Randeng-DELLA-226M-Chinese")
46
+
47
+ special_tokens_dict = {'bos_token': '<BOS>', 'eos_token': '<EOS>'}
48
+ tokenizer.add_special_tokens(special_tokens_dict)
49
+ sentence = "本模型是在通用数据集下预训练的VAE模型,如要获得最佳效果请在特定领域微调后使用。"
50
+ tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentence))
51
+ decoder_target = [tokenizer.bos_token_id] + tokenized_text + [tokenizer.eos_token_id]
52
+ inputs = []
53
+ inputs.append(torch.tensor(decoder_target, dtype=torch.long))
54
+ inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
55
+
56
+ max_length = 256
57
+ top_p = 0.5
58
+ top_k = 0
59
+ temperature = .7
60
+ repetition_penalty = 1.0
61
+ sample = False
62
+ device = 0
63
+ model = vae_model.eval()
64
+ model = model.to(device)
65
+
66
+ outputs = model.model.inference(inputs.to(device), top_p=top_p, top_k=top_k, max_length=max_length, sample=sample,
67
+ temperature=temperature, repetition_penalty=repetition_penalty)
68
+
69
+ for gen_sent, orig_sent in zip(outputs, inputs):
70
+ print('orig_sent:', tokenizer.decode(orig_sent).replace(' ', ''))
71
+ print('gen_sent:', tokenizer.decode(gen_sent).replace(' ', ''))
72
+ print("-"*20)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
 
75
  ```