Randolphzeng committed
Commit 486312f
1 Parent(s): d5b4c41

Update README.md

Files changed (1)
  1. README.md +4 -4
README.md CHANGED

@@ -51,12 +51,13 @@ vae_model = Della.from_pretrained("IDEA-CCNL/Randeng-DELLA-CVAE-226M-NER-Chinese
 special_tokens_dict = {'bos_token': '<BOS>', 'eos_token': '<EOS>', 'additional_special_tokens': ['<ENT>', '<ENS>']}
 tokenizer.add_special_tokens(special_tokens_dict)
 
-model = vae_model.model
+device = 0
+model = vae_model.model.to(device)
 ent_token_type_id = tokenizer.additional_special_tokens_ids[0]
 ent_token_sep_id = tokenizer.additional_special_tokens_ids[1]
 bos_token_id, eos_token_id = tokenizer.bos_token_id, tokenizer.eos_token_id
 decoder_target, decoder_entities = [], []
-entity_list = [('体验中心', '地点/地理位置'), ('昨天', '时间')]
+entity_list = [('深圳', '地点/地理位置'), ('昨天', '时间')]
 
 for ent in entity_list:
     entity_name = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(ent[0]))
@@ -72,7 +73,7 @@ prior_z_list, prior_output_list = model.get_cond_prior_vecs(encoder_outputs.hidd
 outputs = model.decoder.generate(input_ids=inputs.to(device), layer_latent_vecs=prior_z_list, labels=None,
                                  label_ignore=model.pad_token_id, num_return_sequences=32, max_new_tokens=256,
                                  eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id,
-                                 no_repeat_ngram_size=-1, do_sample=True, top_p=0.5)
+                                 no_repeat_ngram_size=-1, do_sample=True, top_p=0.8)
 
 print(tokenizer.decode(inputs[0]))
 gen_sents = []
@@ -84,7 +85,6 @@ for idx in range(len(outputs)):
     gen_sents.append(gen_sent)
 for s in gen_sents:
     print(s)
-
 ```
 
 ## 引用 Citation
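
Taken together, the hunks above make two changes to the README's generation example: they define `device` and move the model onto it, which the existing `inputs.to(device)` call in the second hunk already assumes, and they raise `top_p` from 0.5 to 0.8 while swapping the example entity from '体验中心' to '深圳'. For reference, below is a sketch of how the changed region reads after this commit, assembled only from the hunks shown here; everything outside the hunks is unchanged by the commit and is elided with `...` rather than guessed.

```python
# Sketch assembled from the diff hunks above (post-commit state).
# Lines outside the shown hunks are elided with "..." and are unchanged by this commit.
special_tokens_dict = {'bos_token': '<BOS>', 'eos_token': '<EOS>',
                       'additional_special_tokens': ['<ENT>', '<ENS>']}
tokenizer.add_special_tokens(special_tokens_dict)

device = 0                              # added: CUDA device index
model = vae_model.model.to(device)      # added: matches the inputs.to(device) call below
ent_token_type_id = tokenizer.additional_special_tokens_ids[0]
ent_token_sep_id = tokenizer.additional_special_tokens_ids[1]
bos_token_id, eos_token_id = tokenizer.bos_token_id, tokenizer.eos_token_id
decoder_target, decoder_entities = [], []
entity_list = [('深圳', '地点/地理位置'), ('昨天', '时间')]  # updated example entity

for ent in entity_list:
    entity_name = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(ent[0]))
    # ... (rest of this loop and the prior computation are not shown in the diff)

outputs = model.decoder.generate(input_ids=inputs.to(device), layer_latent_vecs=prior_z_list, labels=None,
                                 label_ignore=model.pad_token_id, num_return_sequences=32, max_new_tokens=256,
                                 eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id,
                                 no_repeat_ngram_size=-1, do_sample=True, top_p=0.8)  # raised from 0.5

print(tokenizer.decode(inputs[0]))
gen_sents = []
# ... (the loop that decodes each returned sequence into gen_sent is not shown in the diff)
for s in gen_sents:
    print(s)
```

With `do_sample=True`, a larger `top_p` keeps more of the token distribution in nucleus sampling, so the 32 returned sequences should come out more varied than with the previous value of 0.5.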