hysts HF staff commited on
Commit
88e3286
1 Parent(s): 5d7586b

Update logger

Browse files
Files changed (1) hide show
  1. model.py +9 -6
model.py CHANGED
@@ -80,10 +80,10 @@ formatter = logging.Formatter(
80
  '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
81
  datefmt='%Y-%m-%d %H:%M:%S')
82
  stream_handler = logging.StreamHandler(stream=sys.stdout)
83
- stream_handler.setLevel(logging.DEBUG)
84
  stream_handler.setFormatter(formatter)
85
  logger = logging.getLogger(__name__)
86
- logger.setLevel(logging.DEBUG)
87
  logger.propagate = False
88
  logger.addHandler(stream_handler)
89
 
@@ -254,7 +254,7 @@ class Model:
254
  self.style = style
255
  self.args = argparse.Namespace(**(vars(self.args) | get_recipe(style)))
256
  self.query_template = self.args.query_template
257
- logger.info(f'{self.query_template=}')
258
 
259
  self.strategy.temperature = self.args.temp_all_gen
260
 
@@ -296,7 +296,7 @@ class Model:
296
  start = time.perf_counter()
297
 
298
  text = self.query_template.format(text)
299
- logger.info(f'{text=}')
300
  seq = tokenizer.encode(text)
301
  logger.info(f'{len(seq)=}')
302
  if len(seq) > 110:
@@ -342,7 +342,7 @@ class Model:
342
  output_list.append(coarse_samples)
343
  remaining -= self.max_batch_size
344
  output_tokens = torch.cat(output_list, dim=0)
345
- logger.info(f'{output_tokens.shape=}')
346
 
347
  elapsed = time.perf_counter() - start
348
  logger.info(f'Elapsed: {elapsed}')
@@ -360,7 +360,7 @@ class Model:
360
  logger.info('--- generate_images ---')
361
  start = time.perf_counter()
362
 
363
- logger.info(f'{self.only_first_stage=}')
364
  res = []
365
  if self.only_first_stage:
366
  for i in range(len(tokens)):
@@ -414,6 +414,9 @@ class AppModel(Model):
414
  self, text: str, translate: bool, style: str, seed: int,
415
  only_first_stage: bool, num: int
416
  ) -> tuple[str | None, np.ndarray | None, list[np.ndarray] | None]:
 
 
 
417
  if translate:
418
  text = translated_text = self.translator(text)
419
  else:
 
80
  '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
81
  datefmt='%Y-%m-%d %H:%M:%S')
82
  stream_handler = logging.StreamHandler(stream=sys.stdout)
83
+ stream_handler.setLevel(logging.INFO)
84
  stream_handler.setFormatter(formatter)
85
  logger = logging.getLogger(__name__)
86
+ logger.setLevel(logging.INFO)
87
  logger.propagate = False
88
  logger.addHandler(stream_handler)
89
 
 
254
  self.style = style
255
  self.args = argparse.Namespace(**(vars(self.args) | get_recipe(style)))
256
  self.query_template = self.args.query_template
257
+ logger.debug(f'{self.query_template=}')
258
 
259
  self.strategy.temperature = self.args.temp_all_gen
260
 
 
296
  start = time.perf_counter()
297
 
298
  text = self.query_template.format(text)
299
+ logger.debug(f'{text=}')
300
  seq = tokenizer.encode(text)
301
  logger.info(f'{len(seq)=}')
302
  if len(seq) > 110:
 
342
  output_list.append(coarse_samples)
343
  remaining -= self.max_batch_size
344
  output_tokens = torch.cat(output_list, dim=0)
345
+ logger.debug(f'{output_tokens.shape=}')
346
 
347
  elapsed = time.perf_counter() - start
348
  logger.info(f'Elapsed: {elapsed}')
 
360
  logger.info('--- generate_images ---')
361
  start = time.perf_counter()
362
 
363
+ logger.debug(f'{self.only_first_stage=}')
364
  res = []
365
  if self.only_first_stage:
366
  for i in range(len(tokens)):
 
414
  self, text: str, translate: bool, style: str, seed: int,
415
  only_first_stage: bool, num: int
416
  ) -> tuple[str | None, np.ndarray | None, list[np.ndarray] | None]:
417
+ logger.info(
418
+ f'{text=}, {translate=}, {style=}, {seed=}, {only_first_stage=}, {num=}'
419
+ )
420
  if translate:
421
  text = translated_text = self.translator(text)
422
  else: