Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Update logger
Browse files
    	
        model.py
    CHANGED
    
    | 
         @@ -80,10 +80,10 @@ formatter = logging.Formatter( 
     | 
|
| 80 | 
         
             
                '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
         
     | 
| 81 | 
         
             
                datefmt='%Y-%m-%d %H:%M:%S')
         
     | 
| 82 | 
         
             
            stream_handler = logging.StreamHandler(stream=sys.stdout)
         
     | 
| 83 | 
         
            -
            stream_handler.setLevel(logging. 
     | 
| 84 | 
         
             
            stream_handler.setFormatter(formatter)
         
     | 
| 85 | 
         
             
            logger = logging.getLogger(__name__)
         
     | 
| 86 | 
         
            -
            logger.setLevel(logging. 
     | 
| 87 | 
         
             
            logger.propagate = False
         
     | 
| 88 | 
         
             
            logger.addHandler(stream_handler)
         
     | 
| 89 | 
         | 
| 
         @@ -254,7 +254,7 @@ class Model: 
     | 
|
| 254 | 
         
             
                    self.style = style
         
     | 
| 255 | 
         
             
                    self.args = argparse.Namespace(**(vars(self.args) | get_recipe(style)))
         
     | 
| 256 | 
         
             
                    self.query_template = self.args.query_template
         
     | 
| 257 | 
         
            -
                    logger. 
     | 
| 258 | 
         | 
| 259 | 
         
             
                    self.strategy.temperature = self.args.temp_all_gen
         
     | 
| 260 | 
         | 
| 
         @@ -296,7 +296,7 @@ class Model: 
     | 
|
| 296 | 
         
             
                    start = time.perf_counter()
         
     | 
| 297 | 
         | 
| 298 | 
         
             
                    text = self.query_template.format(text)
         
     | 
| 299 | 
         
            -
                    logger. 
     | 
| 300 | 
         
             
                    seq = tokenizer.encode(text)
         
     | 
| 301 | 
         
             
                    logger.info(f'{len(seq)=}')
         
     | 
| 302 | 
         
             
                    if len(seq) > 110:
         
     | 
| 
         @@ -342,7 +342,7 @@ class Model: 
     | 
|
| 342 | 
         
             
                        output_list.append(coarse_samples)
         
     | 
| 343 | 
         
             
                        remaining -= self.max_batch_size
         
     | 
| 344 | 
         
             
                    output_tokens = torch.cat(output_list, dim=0)
         
     | 
| 345 | 
         
            -
                    logger. 
     | 
| 346 | 
         | 
| 347 | 
         
             
                    elapsed = time.perf_counter() - start
         
     | 
| 348 | 
         
             
                    logger.info(f'Elapsed: {elapsed}')
         
     | 
| 
         @@ -360,7 +360,7 @@ class Model: 
     | 
|
| 360 | 
         
             
                    logger.info('--- generate_images ---')
         
     | 
| 361 | 
         
             
                    start = time.perf_counter()
         
     | 
| 362 | 
         | 
| 363 | 
         
            -
                    logger. 
     | 
| 364 | 
         
             
                    res = []
         
     | 
| 365 | 
         
             
                    if self.only_first_stage:
         
     | 
| 366 | 
         
             
                        for i in range(len(tokens)):
         
     | 
| 
         @@ -414,6 +414,9 @@ class AppModel(Model): 
     | 
|
| 414 | 
         
             
                    self, text: str, translate: bool, style: str, seed: int,
         
     | 
| 415 | 
         
             
                    only_first_stage: bool, num: int
         
     | 
| 416 | 
         
             
                ) -> tuple[str | None, np.ndarray | None, list[np.ndarray] | None]:
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 417 | 
         
             
                    if translate:
         
     | 
| 418 | 
         
             
                        text = translated_text = self.translator(text)
         
     | 
| 419 | 
         
             
                    else:
         
     | 
| 
         | 
|
| 80 | 
         
             
                '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
         
     | 
| 81 | 
         
             
                datefmt='%Y-%m-%d %H:%M:%S')
         
     | 
| 82 | 
         
             
            stream_handler = logging.StreamHandler(stream=sys.stdout)
         
     | 
| 83 | 
         
            +
            stream_handler.setLevel(logging.INFO)
         
     | 
| 84 | 
         
             
            stream_handler.setFormatter(formatter)
         
     | 
| 85 | 
         
             
            logger = logging.getLogger(__name__)
         
     | 
| 86 | 
         
            +
            logger.setLevel(logging.INFO)
         
     | 
| 87 | 
         
             
            logger.propagate = False
         
     | 
| 88 | 
         
             
            logger.addHandler(stream_handler)
         
     | 
| 89 | 
         | 
| 
         | 
|
| 254 | 
         
             
                    self.style = style
         
     | 
| 255 | 
         
             
                    self.args = argparse.Namespace(**(vars(self.args) | get_recipe(style)))
         
     | 
| 256 | 
         
             
                    self.query_template = self.args.query_template
         
     | 
| 257 | 
         
            +
                    logger.debug(f'{self.query_template=}')
         
     | 
| 258 | 
         | 
| 259 | 
         
             
                    self.strategy.temperature = self.args.temp_all_gen
         
     | 
| 260 | 
         | 
| 
         | 
|
| 296 | 
         
             
                    start = time.perf_counter()
         
     | 
| 297 | 
         | 
| 298 | 
         
             
                    text = self.query_template.format(text)
         
     | 
| 299 | 
         
            +
                    logger.debug(f'{text=}')
         
     | 
| 300 | 
         
             
                    seq = tokenizer.encode(text)
         
     | 
| 301 | 
         
             
                    logger.info(f'{len(seq)=}')
         
     | 
| 302 | 
         
             
                    if len(seq) > 110:
         
     | 
| 
         | 
|
| 342 | 
         
             
                        output_list.append(coarse_samples)
         
     | 
| 343 | 
         
             
                        remaining -= self.max_batch_size
         
     | 
| 344 | 
         
             
                    output_tokens = torch.cat(output_list, dim=0)
         
     | 
| 345 | 
         
            +
                    logger.debug(f'{output_tokens.shape=}')
         
     | 
| 346 | 
         | 
| 347 | 
         
             
                    elapsed = time.perf_counter() - start
         
     | 
| 348 | 
         
             
                    logger.info(f'Elapsed: {elapsed}')
         
     | 
| 
         | 
|
| 360 | 
         
             
                    logger.info('--- generate_images ---')
         
     | 
| 361 | 
         
             
                    start = time.perf_counter()
         
     | 
| 362 | 
         | 
| 363 | 
         
            +
                    logger.debug(f'{self.only_first_stage=}')
         
     | 
| 364 | 
         
             
                    res = []
         
     | 
| 365 | 
         
             
                    if self.only_first_stage:
         
     | 
| 366 | 
         
             
                        for i in range(len(tokens)):
         
     | 
| 
         | 
|
| 414 | 
         
             
                    self, text: str, translate: bool, style: str, seed: int,
         
     | 
| 415 | 
         
             
                    only_first_stage: bool, num: int
         
     | 
| 416 | 
         
             
                ) -> tuple[str | None, np.ndarray | None, list[np.ndarray] | None]:
         
     | 
| 417 | 
         
            +
                    logger.info(
         
     | 
| 418 | 
         
            +
                        f'{text=}, {translate=}, {style=}, {seed=}, {only_first_stage=}, {num=}'
         
     | 
| 419 | 
         
            +
                    )
         
     | 
| 420 | 
         
             
                    if translate:
         
     | 
| 421 | 
         
             
                        text = translated_text = self.translator(text)
         
     | 
| 422 | 
         
             
                    else:
         
     |