|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | """ | 
					
						
						|  | Utility to export a training checkpoint to a lightweight release checkpoint. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | from pathlib import Path | 
					
						
						|  | import typing as tp | 
					
						
						|  |  | 
					
						
						|  | from omegaconf import OmegaConf, DictConfig | 
					
						
						|  | import torch | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _clean_lm_cfg(cfg: DictConfig): | 
					
						
						|  | OmegaConf.set_struct(cfg, False) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | cfg['transformer_lm']['card'] = 2048 | 
					
						
						|  | cfg['transformer_lm']['n_q'] = 4 | 
					
						
						|  |  | 
					
						
						|  | bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters', | 
					
						
						|  | 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop'] | 
					
						
						|  | for name in bad_params: | 
					
						
						|  | del cfg['transformer_lm'][name] | 
					
						
						|  | OmegaConf.set_struct(cfg, True) | 
					
						
						|  | return cfg | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]): | 
					
						
						|  | sig = Path(checkpoint_path).parent.name | 
					
						
						|  | assert len(sig) == 8, "Not a valid Dora signature" | 
					
						
						|  | pkg = torch.load(checkpoint_path, 'cpu') | 
					
						
						|  | new_pkg = { | 
					
						
						|  | 'best_state': pkg['ema']['state']['model'], | 
					
						
						|  | 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), | 
					
						
						|  | } | 
					
						
						|  | out_file = Path(out_folder) / f'{sig}.th' | 
					
						
						|  | torch.save(new_pkg, out_file) | 
					
						
						|  | return out_file | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]): | 
					
						
						|  | sig = Path(checkpoint_path).parent.name | 
					
						
						|  | assert len(sig) == 8, "Not a valid Dora signature" | 
					
						
						|  | pkg = torch.load(checkpoint_path, 'cpu') | 
					
						
						|  | new_pkg = { | 
					
						
						|  | 'best_state': pkg['fsdp_best_state']['model'], | 
					
						
						|  | 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])) | 
					
						
						|  | } | 
					
						
						|  | out_file = Path(out_folder) / f'{sig}.th' | 
					
						
						|  | torch.save(new_pkg, out_file) | 
					
						
						|  | return out_file | 
					
						
						|  |  |