import os
from typing import override

import torch
import torchvision.transforms as T
from PIL import Image

from src.config import Config, CustomPreprocessing
from src.heads.head import HeadOutput
from src.model.base import BaseDeepakeDetectionModel
from src.model.fsfm import models_vit, models_vit_fs_adapter
from src.utils import logger


def download_model_if_needed(checkpoint_path: str, link: str):
    """Fetch the checkpoint from `link` if it is not already on disk."""
    if not os.path.exists(checkpoint_path):
        logger.print_warning_once(f"Checkpoint '{checkpoint_path}' not found, downloading...")
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        # Shells out to wget, so the binary must be on PATH.
        os.system(f"wget {link} -O {checkpoint_path}")


class FSFM(BaseDeepakeDetectionModel):
    def __init__(self, config: Config):
        super().__init__(config, verbose=True)
        self.initialize_model(config.checkpoint)
        # Inference only: eval mode disables dropout and stochastic depth.
        self.model.eval()
    def initialize_model(self, checkpoint_path: str):
        if checkpoint_path == "weights/FS-VFM/FS-VFM-ViT-L-Adapter.pth":
            # ViT-L backbone with the FS-Adapter variant.
            link = "https://hf.co/Wolowolo/fsfm-3c/resolve/main/finetuned_models/FS-VFM_extensions/finetune_fs-adapter/cross_dataset_DfD_and_DiFF/ViT-L_VF2_600e/FT_on_FF%2B%2B_c23_32frames/checkpoint-min_val_loss.pth?download=true"
            download_model_if_needed(checkpoint_path, link)
            self.model = models_vit_fs_adapter.vit_large_patch16(num_classes=2, drop_path_rate=0.1, global_pool=True)
        elif checkpoint_path == "weights/FS-VFM/FS-VFM-ViT-L.pth":
            # Plain fine-tuned ViT-L (no adapter).
            link = "https://hf.co/Wolowolo/fsfm-3c/resolve/main/finetuned_models/FS-VFM_extensions/cross_dataset_DFD_and_DiFF/ViT-L_VF2_600e/FT_on_FF%2B%2B_c23_32frames/checkpoint-min_val_loss.pth?download=true"
            download_model_if_needed(checkpoint_path, link)
            self.model = models_vit.vit_large_patch16(
                num_classes=2,
                drop_path_rate=0.1,
                global_pool=True,
            )
        else:
            raise ValueError(f"Unknown FS-VFM checkpoint path: {checkpoint_path}")
    def forward(self, inputs: torch.Tensor) -> HeadOutput:
        outputs = self.model(inputs)
        # Swap the two class logits so the order is [real, fake]
        # (the checkpoint's native order is [fake, real]),
        # e.g. [[z_fake, z_real]] -> [[z_real, z_fake]].
        outputs = outputs[..., [1, 0]]
        return HeadOutput(logits_labels=outputs)
    def test_step(self, batch, batch_idx):
        batch = self.get_batch(batch)
        outputs = self.forward(batch.images)
        probs = outputs.logits_labels.softmax(dim=1)
        # Save outputs for metrics calculation
        self.test_step_outputs.labels.update(batch.labels)
        self.test_step_outputs.probs.update(probs.detach())
        self.test_step_outputs.idx.update(batch.idx)
    def load_checkpoint(self, checkpoint_path: str):
        """Load the model checkpoint."""
        logger.print_info(f"Loading checkpoint from {checkpoint_path}")
        # weights_only=False permits arbitrary pickled objects in the
        # checkpoint, so only load files from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        incompatible_keys = self.model.load_state_dict(checkpoint["model"], strict=False)
        self.print_checkpoint_keys(incompatible_keys)
    def get_preprocessing(self):
        if self.config.custom_preprocessing is None:
            logger.print_warning_once("This model may expect a zoomed-in facial crop. Make sure to tune it.")

        # Note: an int passed to T.Resize scales the *shorter* edge to 224,
        # so the custom preprocessing is expected to yield a square face crop.
        transform = T.Compose(
            [
                T.Resize(224, interpolation=T.InterpolationMode.BILINEAR),
                T.ToTensor(),
                T.Normalize(
                    [0.5482207536697388, 0.42340534925460815, 0.3654651641845703],
                    [0.2789176106452942, 0.2438540756702423, 0.23493893444538116],
                ),
            ]
        )

        # Define the transform before returning the closure: in the original
        # layout it sat after the return statement and was never reached.
        def preprocess(image: Image.Image) -> torch.Tensor:
            image = self.custom_preprocessing(image)
            return transform(image)

        return preprocess


if __name__ == "__main__":
    config = Config(
        checkpoint="weights/FS-VFM/FS-VFM-ViT-L.pth",
        custom_preprocessing=CustomPreprocessing(zoom_factor=1.3),
    )
    model = FSFM(config)
    model.load_checkpoint(config.checkpoint)

    image = Image.open("datasets/FF/DF/001_870/000.png")
    # image = Image.open("datasets/FF/real/001/000.png")
    preprocessed_image = model.get_preprocessing()(image)  # Preprocess into a normalized tensor
    batch = preprocessed_image.unsqueeze(0)  # Add batch dimension
    with torch.no_grad():  # Inference only; skip building the autograd graph
        outputs = model.forward(batch)
    print(outputs.logits_labels.softmax(dim=-1))
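    # Hedged readout sketch: with forward()'s [real, fake] ordering, index 1
    # is the fake probability; the 0.5 threshold here is an assumption for
    # this demo, not a calibrated value from the project.
    p_fake = outputs.logits_labels.softmax(dim=-1)[0, 1].item()
    print(f"Prediction: {'fake' if p_fake > 0.5 else 'real'} (p_fake={p_fake:.3f})")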