pszemraj committed on
Commit
a738f02
1 Parent(s): a7e67dd

🔊 add logs

Browse files

Signed-off-by: peter szemraj <peterszemraj@gmail.com>

Files changed (3) hide show
  1. constrained_generation.py +7 -1
  2. converse.py +1 -0
  3. utils.py +13 -0
constrained_generation.py CHANGED
@@ -4,6 +4,7 @@
4
 
5
  import copy
6
  import logging
 
7
  logging.basicConfig(level=logging.INFO)
8
  import time
9
  from pathlib import Path
@@ -11,6 +12,7 @@ from pathlib import Path
11
  import yake
12
  from transformers import AutoTokenizer, PhrasalConstraint
13
 
 
14
  def get_tokenizer(model_name="gpt2", verbose=False):
15
  """
16
  get_tokenizer - returns a tokenizer object
@@ -164,6 +166,8 @@ def constrained_generation(
164
  -------
165
  response : str, generated text
166
  """
 
 
167
  st = time.perf_counter()
168
  tokenizer = tokenizer or copy.deepcopy(pipeline.tokenizer)
169
  tokenizer.add_prefix_space = True
@@ -228,7 +232,9 @@ def constrained_generation(
228
  force_words_ids=force_words_ids if force_flexible is not None else None,
229
  max_length=None,
230
  max_new_tokens=max_generated_tokens,
231
- min_length=min_generated_tokens + prompt_length if full_text else min_generated_tokens,
 
 
232
  num_beams=num_beams,
233
  no_repeat_ngram_size=no_repeat_ngram_size,
234
  num_return_sequences=num_return_sequences,
 
4
 
5
  import copy
6
  import logging
7
+
8
  logging.basicConfig(level=logging.INFO)
9
  import time
10
  from pathlib import Path
 
12
  import yake
13
  from transformers import AutoTokenizer, PhrasalConstraint
14
 
15
+
16
  def get_tokenizer(model_name="gpt2", verbose=False):
17
  """
18
  get_tokenizer - returns a tokenizer object
 
166
  -------
167
  response : str, generated text
168
  """
169
+ logging.debug(f" constraining generation with {locals()}")
170
+
171
  st = time.perf_counter()
172
  tokenizer = tokenizer or copy.deepcopy(pipeline.tokenizer)
173
  tokenizer.add_prefix_space = True
 
232
  force_words_ids=force_words_ids if force_flexible is not None else None,
233
  max_length=None,
234
  max_new_tokens=max_generated_tokens,
235
+ min_length=min_generated_tokens + prompt_length
236
+ if full_text
237
+ else min_generated_tokens,
238
  num_beams=num_beams,
239
  no_repeat_ngram_size=no_repeat_ngram_size,
240
  num_return_sequences=num_return_sequences,
converse.py CHANGED
@@ -186,6 +186,7 @@ def gen_response(
186
  str, the generated text
187
 
188
  """
 
189
  input_len = len(pipeline.tokenizer(query).input_ids)
190
  if max_length + input_len > 1024:
191
  max_length = max(1024 - input_len, 8)
 
186
  str, the generated text
187
 
188
  """
189
+ logging.debug(f"input args - gen_response() : {locals()}")
190
  input_len = len(pipeline.tokenizer(query).input_ids)
191
  if max_length + input_len > 1024:
192
  max_length = max(1024 - input_len, 8)
utils.py CHANGED
@@ -7,6 +7,7 @@ from pathlib import Path
7
  import pprint as pp
8
  import re
9
  import shutil # zipfile formats
 
10
  from datetime import datetime
11
  from os.path import basename
12
  from os.path import getsize, join
@@ -383,3 +384,15 @@ def cleantxt_wrap(ugly_text, all_lower=False):
383
  return clean(ugly_text, lower=all_lower)
384
  else:
385
  return ugly_text
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import pprint as pp
8
  import re
9
  import shutil # zipfile formats
10
+ import logging
11
  from datetime import datetime
12
  from os.path import basename
13
  from os.path import getsize, join
 
384
  return clean(ugly_text, lower=all_lower)
385
  else:
386
  return ugly_text
387
+
388
+
389
+ def setup_logging(loglevel):
390
+ """Setup basic logging
391
+
392
+ Args:
393
+ loglevel (int): minimum loglevel for emitting messages
394
+ """
395
+ logformat = "[%(asctime)s] %(levelname)s:%(name)s:%(message)s"
396
+ logging.basicConfig(
397
+ level=loglevel, stream=sys.stdout, format=logformat, datefmt="%Y-%m-%d %H:%M:%S"
398
+ )