skytnt committed on
Commit
e47df31
1 Parent(s): 1eb9075

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -4
README.md CHANGED
@@ -31,10 +31,10 @@ model = GPT2LMHeadModel.from_pretrained("skytnt/gpt2-japanese-lyric-small")
31
 
32
 
33
  def gen_lyric(prompt_text: str):
34
- prompt_text = prompt_text.replace("\n", "\\n ")
35
  prompt_tokens = tokenizer.tokenize(prompt_text)
36
  prompt_token_ids = tokenizer.convert_tokens_to_ids(prompt_tokens)
37
- prompt_tensor = torch.LongTensor(prompt_token_ids)
38
  prompt_tensor = prompt_tensor.view(1, -1)
39
  # model forward
40
  output_sequences = model.generate(
@@ -55,8 +55,7 @@ def gen_lyric(prompt_text: str):
55
  generated_sequence = output_sequences.tolist()[0]
56
  generated_tokens = tokenizer.convert_ids_to_tokens(generated_sequence)
57
  generated_text = tokenizer.convert_tokens_to_string(generated_tokens)
58
- generated_text = "\n".join([s.strip() for s in generated_text.split('\\n')]).replace(' ', '\u3000').replace(
59
- '</s>', '\n\n---end---')
60
  return generated_text
61
 
62
 
 
31
 
32
 
33
  def gen_lyric(prompt_text: str):
34
+ prompt_text = "<s>" + prompt_text
35
  prompt_tokens = tokenizer.tokenize(prompt_text)
36
  prompt_token_ids = tokenizer.convert_tokens_to_ids(prompt_tokens)
37
+ prompt_tensor = torch.LongTensor(prompt_token_ids).to(device)
38
  prompt_tensor = prompt_tensor.view(1, -1)
39
  # model forward
40
  output_sequences = model.generate(
 
55
  generated_sequence = output_sequences.tolist()[0]
56
  generated_tokens = tokenizer.convert_ids_to_tokens(generated_sequence)
57
  generated_text = tokenizer.convert_tokens_to_string(generated_tokens)
58
+ generated_text = "\n".join([s.strip() for s in generated_text.split('\\n')]).replace(' ', '\u3000').replace('<s>', '').replace('</s>', '\n\n---end---')
 
59
  return generated_text
60
 
61