import os
import shutil
import subprocess
import tempfile

import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint, BasePredictionWriter
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from pytorch_lightning.utilities.deepspeed import (
    convert_zero_checkpoint_to_fp32_state_dict,
)
from safetensors.torch import save_file


def convert_deepspeed_checkpoint(
    cfg: DictConfig, checkpoint_callback: ModelCheckpoint, output_dir: str
):
    """
    Convert a DeepSpeed ZeRO checkpoint to fp32 safetensors format.
    All frozen parameters are removed from the exported state dict.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Consolidate the sharded ZeRO checkpoint into a single fp32 .pth file.
    convert_zero_checkpoint_to_fp32_state_dict(
        checkpoint_callback.best_model_path,
        os.path.join(output_dir, "fp32_state_dict.pth"),
    )
    # Allow-list `set` so the consolidated checkpoint loads under
    # weights_only=True.
    with torch.serialization.safe_globals([set]):
        ckpt = torch.load(
            os.path.join(output_dir, "fp32_state_dict.pth"),
            map_location="cpu",
            weights_only=True,
        )
    # Drop parameters belonging to frozen submodules: they were never
    # updated during training, so the export only carries trained weights.
    for param in list(ckpt["state_dict"].keys()):
        if getattr(cfg.model, "freeze_backbone", False) and param.startswith(
            "loupe.backbone"
        ):
            ckpt["state_dict"].pop(param)
        if getattr(cfg.model, "freeze_cls", False) and param.startswith(
            "loupe.classifier"
        ):
            ckpt["state_dict"].pop(param)
        if getattr(cfg.model, "freeze_seg", False) and param.startswith(
            "loupe.segmentor"
        ):
            ckpt["state_dict"].pop(param)
    save_file(ckpt["state_dict"], os.path.join(output_dir, "model.safetensors"))
    # Save the resolved config and the Hydra command-line overrides next to
    # the weights so the run can be reproduced.
    OmegaConf.save(config=cfg, f=os.path.join(output_dir, "config.yaml"))
    OmegaConf.save(
        config=hydra.core.hydra_config.HydraConfig.get().overrides.task,
        f=os.path.join(output_dir, "overrides.yaml"),
    )
    print(f"Model converted to FP32 and saved to {output_dir}.")
    # Remove the intermediate .pth file and the original ZeRO checkpoint dir.
    os.remove(os.path.join(output_dir, "fp32_state_dict.pth"))
    shutil.rmtree(checkpoint_callback.best_model_path)
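

# Hedged usage sketch (an illustration, not part of the original module):
# how the export from convert_deepspeed_checkpoint could be reloaded later.
# `output_dir` is whatever directory was passed to the converter.
def _example_load_exported(output_dir: str) -> dict:
    """Reload the fp32 safetensors export as a plain state dict."""
    from safetensors.torch import load_file

    state_dict = load_file(os.path.join(output_dir, "model.safetensors"))
    # When loading into the full model, strict=False would be needed because
    # frozen parameters were stripped at export time, e.g.:
    #   model.load_state_dict(state_dict, strict=False)
    return state_dict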


def prepare_output_dir(pred_path, mask_dir):
    """Reset the prediction file and mask directory from a previous run."""
    if os.path.isfile(pred_path):
        os.remove(pred_path)
    if os.path.isdir(mask_dir):
        print(f"Removing existing directory: {mask_dir}...")
        try:
            # Syncing an empty directory with --delete clears mask_dir;
            # on directories holding many small files this tends to be
            # faster than removing entries one by one.
            with tempfile.TemporaryDirectory() as empty_dir:
                result = subprocess.run(
                    ["rsync", "-a", "--delete", empty_dir + "/", mask_dir + "/"],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    text=True,
                )
                if result.returncode != 0:
                    raise RuntimeError(f"rsync failed: {result.stderr}")
        except (FileNotFoundError, RuntimeError) as e:
            # rsync missing (FileNotFoundError) or failed: leave old files
            # in place and let new results overwrite them.
            print(
                f"rsync not available or failed ({e}), overwriting previous results..."
            )
    os.makedirs(mask_dir, exist_ok=True)
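

# For comparison, a plain-Python equivalent of the rsync trick above (an
# illustrative sketch, not used by this module): shutil.rmtree is simpler
# but can be noticeably slower on directories with huge numbers of files.
def _rmtree_fallback(mask_dir: str) -> None:
    """Remove mask_dir recursively, then recreate it empty."""
    shutil.rmtree(mask_dir, ignore_errors=True)
    os.makedirs(mask_dir, exist_ok=True)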


class CustomWriter(BasePredictionWriter):
    """Prediction writer that streams results to disk as batches finish."""

    def __init__(self, cfg: DictConfig, write_interval):
        super().__init__(write_interval)
        output_dir = cfg.stage.pred_output_dir
        self.mask_dir = os.path.join(output_dir, "masks")
        self.pred_path = os.path.join(output_dir, "predictions.txt")
        prepare_output_dir(self.pred_path, self.mask_dir)

    def write_on_batch_end(
        self,
        trainer,
        pl_module,
        prediction,
        batch_indices,
        batch,
        batch_idx,
        dataloader_idx,
    ):
        cls_probs, pred_masks = prediction["cls_probs"], prediction["pred_masks"]
        # Append one "<name>,<classification probability>" line per sample.
        with open(self.pred_path, "a") as f:
            for name, cls_prob in zip(batch["name"], cls_probs):
                f.write(f"{name},{cls_prob:.4f}\n")
        # Save each predicted mask image under its sample name.
        for name, pred_mask in zip(batch["name"], pred_masks):
            pred_mask.save(os.path.join(self.mask_dir, name))
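

# Hedged wiring sketch (an assumption, not part of the original module):
# how CustomWriter might be attached to a Lightning Trainer for inference.
# `model` and `dataloader` are hypothetical placeholders; write_interval
# "batch" matches the write_on_batch_end hook implemented above.
def _example_predict(cfg: DictConfig, model, dataloader) -> None:
    import pytorch_lightning as pl

    writer = CustomWriter(cfg, write_interval="batch")
    trainer = pl.Trainer(accelerator="auto", logger=False, callbacks=[writer])
    # return_predictions=False avoids accumulating all outputs in memory,
    # since the writer already persists them to disk.
    trainer.predict(model, dataloaders=dataloader, return_predictions=False)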