Norod78 commited on
Commit
b587ec2
1 Parent(s): 4760c8f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +14 -2
README.md CHANGED
@@ -86,6 +86,9 @@ if input_ids != None:
86
 
87
  print("Updated max_len = " + str(max_len))
88
 
 
 
 
89
  sample_outputs = model.generate(
90
  input_ids,
91
  do_sample=True,
@@ -95,9 +98,18 @@ sample_outputs = model.generate(
95
  num_return_sequences=sample_output_num
96
  )
97
 
98
- print(100 * '-' + "\nOutput:\n" + 100 * '-')
99
  for i, sample_output in enumerate(sample_outputs):
100
- print("\n{}: {}".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))
 
 
 
 
 
 
 
 
 
101
  print("\n" + 100 * '-')
102
 
103
  ```
 
86
 
87
  print("Updated max_len = " + str(max_len))
88
 
89
+ stop_token = "<|endoftext|>"
90
+ new_lines = "\n\n\n"
91
+
92
  sample_outputs = model.generate(
93
  input_ids,
94
  do_sample=True,
 
98
  num_return_sequences=sample_output_num
99
  )
100
 
101
+ print(100 * '-' + "\n\t\tOutput\n" + 100 * '-')
102
  for i, sample_output in enumerate(sample_outputs):
103
+
104
+ text = tokenizer.decode(sample_output, skip_special_tokens=True)
105
+
106
+ # Remove all text after the stop token
107
+ text = text[: text.find(stop_token) if stop_token else None]
108
+
109
+ # Remove all text after 3 newlines
110
+ text = text[: text.find(new_lines) if new_lines else None]
111
+
112
+ print("\n{}: {}".format(i, text))
113
  print("\n" + 100 * '-')
114
 
115
  ```