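"""Lightning module for deepfake detection: a pretrained image encoder (CLIP / ViT-PE / DINO)
feeding a linear-probe head, with optional LoRA adapters (PEFT), a configurable loss, and an
optimizer / LR-schedule setup driven by the run config."""
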
from typing import Callable

import torch
from lightning import seed_everything
from PIL import Image
from torch import optim

from src import config as C
from src.config import Config, Head
from src.heads import head
from src.loss import Loss, LossInputs, LossOutputs
from src.losses import unifalign
from src.model.base import BaseDeepakeDetectionModel, Batch
from src.utils import logger
from src.utils.decorators import TryExcept

class GenD(BaseDeepakeDetectionModel):
    def __init__(self, config: Config, verbose: bool = False):
        super().__init__(config, verbose)
        self.config = config
        self.save_hyperparameters(config.model_dump())
        self.is_debug_mode = "tmp" in config.run_name
        if verbose:
            logger.print(config)
        seed_everything(self.config.seed, workers=True, verbose=verbose)
        self._init_specific_attributes(verbose)

    def _init_specific_attributes(self, verbose: bool = False):
        self._init_feature_extractor()
        self._init_head()
        self._freeze_parameters()
        self._init_peft()
        self._init_loss()
        if verbose:
            self.print_trainable_parameters()

    def print_trainable_parameters(self):
        logger.print("\n🔥 [red bold]Trainable parameters:")
        for name, param in self.named_parameters():
            if param.requires_grad:
                logger.print(f"[red]- {name} shape = {tuple(param.shape)}")
        all_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        logger.print(
            f"Total parameters: {all_params}, trainable: {trainable_params}, %: {trainable_params / all_params * 100:.4f}"
        )
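
    # Backbone selection: the config's `backbone` string is matched against known encoder
    # families (CLIP, ViT-PE / Perception Encoder, DINO), and the corresponding wrapper from
    # src.encoders is imported lazily inside its branch so unused encoder dependencies are
    # never loaded.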
    def _init_feature_extractor(self):
        logger.print("\n[blue]Initializing image encoder...")
        backbone = self.config.backbone
        backbone_lowercase = backbone.lower()
        if "clip" in backbone_lowercase:
            from src.encoders.clip_encoder import CLIPEncoder

            self.feature_extractor = CLIPEncoder(backbone)
        elif "vit_pe" in backbone_lowercase:
            from src.encoders.perception_encoder import PerceptionEncoder

            self.feature_extractor = PerceptionEncoder(backbone, self.config.backbone_args.img_size)
        elif "dino" in backbone_lowercase:
            from src.encoders.dino_encoder import DINOEncoder

            if self.config.backbone_args is not None:
                merge_cls_token_with_patches = self.config.backbone_args.merge_cls_token_with_patches
            else:
                merge_cls_token_with_patches = None
            self.feature_extractor = DINOEncoder(backbone, merge_cls_token_with_patches)
        else:
            raise ValueError(f"Unknown backbone: {backbone}")
        logger.print(self.feature_extractor)
        # self.feature_extractor.eval()
        # self.feature_extractor.to(self.device)
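
    # PEFT: when `config.peft_v2.lora` is set, the encoder is wrapped with `peft.get_peft_model`
    # using a LoRA config built from the run config. Parameters that were already trainable
    # before wrapping (e.g. explicitly unfrozen layers) are re-enabled afterwards, because the
    # PEFT wrapper typically freezes everything except the adapter weights.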
    def _init_peft(self):
        if self.config.peft_v2 is not None:
            from peft import get_peft_model

            if self.config.peft_v2.lora is not None:
                from peft import LoraConfig

                peft_config = LoraConfig(
                    target_modules=self.config.peft_v2.lora.target_modules,
                    r=self.config.peft_v2.lora.rank,
                    lora_alpha=self.config.peft_v2.lora.alpha,
                    lora_dropout=self.config.peft_v2.lora.dropout,
                    bias=self.config.peft_v2.lora.bias,
                    use_rslora=self.config.peft_v2.lora.use_rslora,
                    use_dora=self.config.peft_v2.lora.use_dora,
                )
            else:
                raise ValueError("Unknown PEFT configuration")
            backbone = self.feature_extractor
            training_parameters = {name for name, param in backbone.named_parameters() if param.requires_grad}
            self.feature_extractor = get_peft_model(self.feature_extractor, peft_config)
            for name, param in backbone.named_parameters():
                if name in training_parameters:
                    param.requires_grad = True

    def _init_head(self):
        logger.print("\n[blue]Initializing head...")
        features_dim = self.feature_extractor.get_features_dim()
        match self.config.head:
            case Head.Linear:
                self.model = head.LinearProbe(features_dim, self.config.num_classes)
            case Head.NLinear:
                self.model = head.LinearProbe(features_dim, self.config.num_classes, True)
            case _:
                raise ValueError(f"Unknown head: {self.config.head}")
        # self.model.eval()
        # self.model.to(self.device)
        logger.print(self.model)

    def _freeze_parameters(self):
        # Freeze the feature extractor when requested by the config
        self.feature_extractor.requires_grad_(not self.config.freeze_feature_extractor)
        # Re-enable gradients for any layers explicitly listed as unfrozen
        if len(self.config.unfreeze_layers) > 0:
            for name, param in self.named_parameters():
                if any(layer in name for layer in self.config.unfreeze_layers):
                    param.requires_grad = True

    def _init_loss(self):
        self.criterion = Loss(self.config.loss)

    def get_preprocessing(self) -> Callable[[Image.Image], torch.Tensor]:
        def preprocessing(image: Image.Image) -> torch.Tensor:
            image = self.custom_preprocessing(image)
            image = self.feature_extractor.preprocess(image)
            return image

        return preprocessing

    def forward(self, inputs: torch.Tensor) -> head.HeadOutput:
        features = self.feature_extractor(inputs)
        outputs = self.model.forward(features)
        return outputs

    def log_loss(self, loss: LossOutputs, stage: str, batch_size: int):
        common = {"prog_bar": self.is_debug_mode, "on_epoch": True, "on_step": False, "batch_size": batch_size}
        if loss.total is not None:
            self.log(f"{stage}/loss", loss.total, **common)
        if loss.ce_labels is not None:
            self.log(f"{stage}/loss_ce", loss.ce_labels, **common)
    def log_aliunif(self, outputs: head.HeadOutput, labels: torch.Tensor, stage: str, batch_size: int):
        alignment = unifalign.alignment(outputs.l2_embeddings, labels)
        uniformity = unifalign.uniformity(outputs.l2_embeddings)
        common = {"prog_bar": self.is_debug_mode, "on_epoch": True, "on_step": False, "batch_size": batch_size}
        self.log(f"{stage}/alignment", alignment, **common)
        self.log(f"{stage}/uniformity", uniformity, **common)

    def get_probs(self, outputs: head.HeadOutput):
        if self.config.inference_strategy == C.InferenceStrategy.SOFTMAX:
            return outputs.logits_labels.softmax(1)
        raise NotImplementedError("Unknown inference strategy")

    def get_batch(self, batch: dict) -> Batch:
        return Batch.from_dict(batch)

    def on_train_start(self):
        logger.print(f"[blue]Logs: {self.logger.log_dir}")
        self.log("num_train_files", len(self.trainer.datamodule.train_dataset))
        self.log("num_val_files", len(self.trainer.datamodule.val_dataset))

    def on_train_epoch_start(self):
        # Log learning rate for the current epoch
        self.log("lr", self.trainer.optimizers[0].param_groups[0]["lr"])

    def training_step(self, batch, batch_idx):
        batch = self.get_batch(batch)
        # outputs = self.forward(batch.images)
        features = self.feature_extractor(batch.images)
        outputs = self.model.forward(features)
        loss_inputs = LossInputs(
            logits_labels=outputs.logits_labels,
            labels=batch.labels,
            l2_embeddings=outputs.l2_embeddings,
        )
        loss = self.criterion(loss_inputs)
        probs = self.get_probs(outputs)  # Get probabilities based on the inference strategy
        # Log metrics
        self.log_loss(loss, "train", batch_size=len(batch.images))
        self.log_aliunif(outputs, batch.labels, "train", batch_size=len(batch.images))
        # Save outputs for metrics calculation
        self.train_step_outputs.labels.update(batch.labels)
        self.train_step_outputs.probs.update(probs.detach())
        self.train_step_outputs.idx.update(batch.idx)
        return loss.total

    def on_train_epoch_end(self):
        if self.logger.log_dir is None:
            # TODO: figure out why logger.log_dir can be None
            return
        # Log weight norms of the linear head
        with TryExcept(verbose=False):
            self.log("model/linear-W-norm", self.model.linear.weight.norm().item())
            self.log("model/linear-b-norm", self.model.linear.bias.norm().item())
        dataset = self.trainer.datamodule.train_dataset
        self.log_all_metrics(self.train_step_outputs, "train", dataset)

    def validation_step(self, batch, batch_idx):
        batch = self.get_batch(batch)
        outputs = self.forward(batch.images)
        loss_inputs = LossInputs(
            logits_labels=outputs.logits_labels,
            labels=batch.labels,
            l2_embeddings=outputs.l2_embeddings,
        )
        loss = self.criterion(loss_inputs)
        probs = self.get_probs(outputs)
        self.log_loss(loss, "val", len(batch.images))
        self.log_aliunif(outputs, batch.labels, "val", len(batch.images))
        # Save outputs for metrics calculation
        self.val_step_outputs.labels.update(batch.labels)
        self.val_step_outputs.probs.update(probs.detach())
        self.val_step_outputs.idx.update(batch.idx)

    def test_step(self, batch, batch_idx):
        batch = self.get_batch(batch)
        outputs = self.forward(batch.images)
        loss_inputs = LossInputs(
            logits_labels=outputs.logits_labels,
            labels=batch.labels,
            l2_embeddings=outputs.l2_embeddings,
        )
        loss = self.criterion(loss_inputs)
        probs = self.get_probs(outputs)
        self.log_loss(loss, "test", len(batch.images))
        self.log_aliunif(outputs, batch.labels, "test", len(batch.images))
        # Save outputs for metrics calculation
        self.test_step_outputs.labels.update(batch.labels)
        self.test_step_outputs.probs.update(probs.detach())
        self.test_step_outputs.idx.update(batch.idx)

    def on_validation_epoch_end(self):
        if self.logger.log_dir is None:
            # TODO: figure out why logger.log_dir can be None
            return
        dataset = self.trainer.datamodule.val_dataset
        self.log_all_metrics(self.val_step_outputs, "val", dataset)
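
    # Optimizer / LR-schedule assembly: parameters are split into weight-decay and
    # no-weight-decay groups, an AdamW or SGD optimizer is built from the config, and an
    # optional cosine or cyclic schedule is chained with a linear warmup via SequentialLR.
    # The resulting scheduler is stepped every training step ("interval": "step").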
    def configure_optimizers(self):
        self.trainer.fit_loop.setup_data()  # needed because we access the train dataloader below
        config = self.config
        # Split parameters into weight-decay and no-weight-decay groups
        decay_params = []
        no_decay_params = []
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            if "bias" in name or "norm" in name:
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": config.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]
        # Configure optimizer
        if config.optimizer == C.Optimizer.AdamW:
            optimizer = optim.AdamW(
                optimizer_grouped_parameters,
                lr=config.lr,
                weight_decay=config.weight_decay,
                betas=config.betas,
            )
        elif config.optimizer == C.Optimizer.SGD:
            optimizer = optim.SGD(
                optimizer_grouped_parameters,
                lr=config.lr,
                momentum=config.betas[0],
                weight_decay=config.weight_decay,
            )
        else:
            raise ValueError(f"Unknown optimizer: {config.optimizer}")
        optimizers = {"optimizer": optimizer}
        scheduler = None
        # Configure LR scheduler
        if config.lr_scheduler == "cosine":
            #! be careful when running experiments with limit_train_batches
            if config.limit_train_batches is not None:
                logger.print_warning_once("lr scheduling and limit_train_batches are not compatible")
            T_max = config.max_epochs * len(self.trainer.train_dataloader)
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=config.min_lr)
        elif config.lr_scheduler == "cyclic":
            cycle_length_in_epochs = int(config.num_epochs_in_cycle * len(self.trainer.train_dataloader))
            scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=cycle_length_in_epochs, T_mult=1, eta_min=config.min_lr
            )
        # Configure warmup
        if config.warmup_epochs > 0:
            total_warmup_steps = int(config.warmup_epochs * len(self.trainer.train_dataloader))
            warmup = optim.lr_scheduler.LinearLR(
                optimizer, start_factor=config.min_lr / config.lr, total_iters=total_warmup_steps
            )
            if scheduler is not None:
                scheduler = optim.lr_scheduler.SequentialLR(
                    optimizer, [warmup, scheduler], milestones=[total_warmup_steps]
                )
            else:
                scheduler = warmup
        if scheduler is not None:
            optimizers["lr_scheduler"] = {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            }
        return optimizers
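
# Example usage (sketch only; the exact `Config` fields are defined in `src.config` and are
# project-specific, so the arguments below are placeholders, not the actual schema):
#
#   from lightning import Trainer
#   from src.config import Config
#
#   config = Config(...)                   # backbone, head, loss, optimizer, lr, ... per src.config
#   model = GenD(config, verbose=True)
#   trainer = Trainer(max_epochs=config.max_epochs)
#   trainer.fit(model, datamodule=...)     # any LightningDataModule exposing train/val datasets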