Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from huggingface_hub import hf_hub_download | |
| from torch.utils.data import DataLoader | |
| from torchvision import models | |
| from config import HF_BACKBONE_REPO, HF_TOKEN | |
# Memoized backbone instance; assigned on the first load_backbone() call.
_BACKBONE = None
# In-memory copy of the extracted features; assigned by
# extract_all_features() or get_cached_features().
_FEATURES_CACHE = None
# Shared between all Gradio workers (same process group)
_DISK_CACHE_PATH = "/tmp/charcoal_features.npz"
def load_backbone(device: torch.device) -> nn.Module:
    """Return the frozen ResNet-18 backbone, moved to ``device``.

    The network is built only on the first call: the weight file
    ``resnet18_charcoal_backbone.pt`` is downloaded from the
    ``HF_BACKBONE_REPO`` model repository, loaded into a ResNet-18 whose
    ``fc`` layer has been replaced by ``nn.Identity``, and stored in the
    module-level ``_BACKBONE``. Later calls reuse that same instance.

    Args:
        device: Device the returned module is moved to.

    Returns:
        The shared backbone module, with ``requires_grad`` disabled on all
        of its parameters.
    """
    global _BACKBONE

    if _BACKBONE is None:
        weights_file = hf_hub_download(
            repo_id=HF_BACKBONE_REPO,
            filename="resnet18_charcoal_backbone.pt",
            token=HF_TOKEN,
            repo_type="model",
        )

        net = models.resnet18()
        # Replace the classification head with a pass-through layer.
        net.fc = nn.Identity()

        # NOTE(review): torch.load unpickles the downloaded file; consider
        # weights_only=True if the installed torch version supports it.
        state_dict = torch.load(weights_file, map_location="cpu")
        net.load_state_dict(state_dict)

        # The backbone is never trained here, so gradients are switched off.
        for param in net.parameters():
            param.requires_grad = False

        _BACKBONE = net

    return _BACKBONE.to(device)
def extract_all_features(batch_size: int = 64):
    """Run the backbone over every dataset split and cache the features.

    Each split returned by ``prepare_splits()`` is wrapped with the
    evaluation transform, fed through the backbone without gradient
    tracking, and collected into NumPy arrays. The result is stored in the
    module-level ``_FEATURES_CACHE`` and also written to
    ``_DISK_CACHE_PATH`` so that other Gradio workers can load it.

    The disk file is written to a temporary path first and then moved into
    place with ``os.replace``, so a worker that checks for the cache while
    this function is still writing never sees a partially written archive.

    Args:
        batch_size: Number of samples per batch given to the ``DataLoader``.

    Returns:
        A tuple ``(cache, class_names, counts)`` where ``cache`` maps each
        split name to ``{"X": features, "y": labels}``, ``class_names`` is
        the value of ``get_class_names()``, and ``counts`` maps each split
        name to its number of labels.

    Raises:
        KeyError: If ``prepare_splits()`` does not provide the ``train``,
            ``validation`` and ``test`` splits needed for the disk cache.
    """
    global _FEATURES_CACHE
    # Kept as a function-level import, as in the original
    # (presumably to avoid a circular import with data_utils — TODO confirm).
    from data_utils import prepare_splits, get_class_names, HFDatasetWrapper, get_eval_transform

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    backbone = load_backbone(device)
    backbone.eval()

    splits = prepare_splits()
    class_names = get_class_names()

    cache = {}
    counts = {}
    for split_name, split_data in splits.items():
        dataset = HFDatasetWrapper(split_data, get_eval_transform())
        # shuffle=False keeps features and labels in dataset order.
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        X_parts, y_parts = [], []
        with torch.no_grad():
            for images, labels in loader:
                features = backbone(images.to(device))
                X_parts.append(features.cpu().numpy())
                y_parts.append(labels.numpy())
        cache[split_name] = {
            "X": np.concatenate(X_parts, axis=0),
            "y": np.concatenate(y_parts, axis=0),
        }
        counts[split_name] = len(cache[split_name]["y"])

    # Save to disk so that every Gradio worker can access the features.
    # Writing to a per-process temp file and renaming it makes the final
    # path appear atomically; passing an open file object also stops
    # np.savez from appending ".npz" to the temporary name.
    tmp_path = f"{_DISK_CACHE_PATH}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "wb") as tmp_file:
            np.savez(
                tmp_file,
                train_X=cache["train"]["X"], train_y=cache["train"]["y"],
                validation_X=cache["validation"]["X"], validation_y=cache["validation"]["y"],
                test_X=cache["test"]["X"], test_y=cache["test"]["y"],
            )
        os.replace(tmp_path, _DISK_CACHE_PATH)
    finally:
        # Only present if the write or rename failed; never leave it behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    _FEATURES_CACHE = cache
    return cache, class_names, counts
def get_cached_features():
    """Return the cached feature arrays, or ``None`` if none are available.

    The in-memory ``_FEATURES_CACHE`` is returned when it is already set.
    Otherwise the ``.npz`` archive at ``_DISK_CACHE_PATH`` is loaded, stored
    in ``_FEATURES_CACHE`` and returned.

    Returns:
        A dict mapping ``"train"``, ``"validation"`` and ``"test"`` to
        ``{"X": ..., "y": ...}`` arrays, or ``None`` when there is neither an
        in-memory cache nor a cache file on disk.
    """
    global _FEATURES_CACHE
    if _FEATURES_CACHE is not None:
        return _FEATURES_CACHE

    # Try to load from disk (another worker may already have extracted).
    # Attempting the load directly avoids the gap between an existence check
    # and the actual open.
    try:
        # np.load returns an archive object holding an open file; the
        # ``with`` block closes it once the arrays have been read out.
        with np.load(_DISK_CACHE_PATH) as data:
            loaded = {
                split: {"X": data[f"{split}_X"], "y": data[f"{split}_y"]}
                for split in ("train", "validation", "test")
            }
    except FileNotFoundError:
        return None

    _FEATURES_CACHE = loaded
    return _FEATURES_CACHE