from typing import override

import cv2
import numpy as np
import torch
import yaml
from PIL import Image
from torchvision import transforms as T

from src.config import Config
from src.heads.head import HeadOutput
from src.model.base import BaseDeepakeDetectionModel, OutputsForMetrics
from src.model.forada.ds import DS
from src.utils import logger


class ForAda(BaseDeepakeDetectionModel):
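    """Wrapper around the Forensics Adapter (ForAda) deepfake detection model.

    Builds the DS adapter from the YAML config stored next to this file and
    runs inference through the base detection-model interface.
    """
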
    def __init__(self, config: Config):
        super().__init__(config, verbose=True)

        # Load the model YAML config relative to the current file; use a
        # separate name so the framework `config` argument is not shadowed.
        config_path = __file__.replace(
            "forensics_adapter.py", "forensics_adapter_model/config.yaml"
        )
        with open(config_path, "r") as f:
            model_config = yaml.safe_load(f)
        self.model = DS(
            clip_name=model_config["clip_model_name"],
            adapter_vit_name=model_config["vit_name"],
            num_quires=model_config["num_quires"],  # (sic) spelling follows the upstream config key
            fusion_map=model_config["fusion_map"],
            mlp_dim=model_config["mlp_dim"],
            mlp_out_dim=model_config["mlp_out_dim"],
            head_num=model_config["head_num"],
        )
        self.eval()  # inference-only wrapper: start in eval mode

    @override
    def forward(self, inputs: torch.Tensor) -> HeadOutput:
        outputs = self.model({"image": inputs}, inference=True)
        return HeadOutput(logits_labels=outputs["logits"])

    @override
    def on_test_epoch_start(self):
        self.test_step_outputs = OutputsForMetrics()
        # Move the wrapped model to the trainer's device before evaluation.
        self.model.to(self.trainer.strategy.root_device)

    @override
    def test_step(self, batch, batch_idx):
        batch = self.get_batch(batch)
        outputs = self.forward(batch.images)
        probs = outputs.logits_labels.softmax(dim=1)

        # Save outputs for metrics calculation at the end of the epoch.
        self.test_step_outputs.labels.update(batch.labels)
        self.test_step_outputs.probs.update(probs.detach())
        self.test_step_outputs.idx.update(batch.idx)

    def load_checkpoint(self, checkpoint_path: str):
        """Load the model checkpoint."""
        logger.print_info(f"Loading checkpoint from {checkpoint_path}")
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        incompatible_keys = self.model.load_state_dict(state_dict, strict=False)
        self.print_checkpoint_keys(incompatible_keys)

    def get_preprocessing(self):
        def preprocess(image: Image.Image) -> torch.Tensor:
            return preprocessing(image)

        return preprocess
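

# CLIP normalization statistics (OpenAI CLIP mean/std), applied after ToTensor.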
_preprocess = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        ),
    ]
)


def preprocessing(image: Image.Image) -> torch.Tensor:
    """Resize to 256x256 and apply CLIP-style normalization."""
    image = np.array(image)
    image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_CUBIC)
    image = np.array(image, dtype=np.uint8)
    image = _preprocess(image)
    # image = F.interpolate(
    #     image.unsqueeze(0),
    #     size=(224, 224),
    #     mode="bilinear",
    #     align_corners=False,
    # )[0]
    return image


if __name__ == "__main__":
    #! Run as module:
    #! python -m src.model.forensics_adapter
    from PIL import Image

    from src.config import Config
    from src.model.forensics_adapter import ForAda

    config = Config()
    model = ForAda(config)
    model.load_checkpoint("weights/forensics_adapter/ForensicsAdapter.pth")

    path = "datasets/FF/real/000/000.png"
    image = Image.open(path)  # Load image
    preprocessed_image = model.get_preprocessing()(image)  # Convert to tensor
    batch = preprocessed_image.unsqueeze(0)  # Add batch dimension

    outputs = model(batch)
    print(outputs.logits_labels)  # Logits
    print(outputs.logits_labels.softmax(dim=1))  # Class probabilities
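
    # Minimal inference sketch (assumption: gradients are never needed here):
    # wrapping the forward pass in torch.no_grad() skips autograd bookkeeping.
    # with torch.no_grad():
    #     outputs = model(batch)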