Image Classification
Transformers
Safetensors
PyTorch
rsp_swin
remote-sensing
swin-transformer
custom_code
Instructions to use BiliSakura/RSP-Swin-T with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use BiliSakura/RSP-Swin-T with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-classification", model="BiliSakura/RSP-Swin-T", trust_remote_code=True)
pipe("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/hub/parrots.png")
```

```python
# Load model directly
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained("BiliSakura/RSP-Swin-T", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """Model classes for RSP models compatible with transformers""" | |
| import sys | |
| import os | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| from transformers import PreTrainedModel | |
| from safetensors.torch import load_file | |
| # Import local modular model | |
| from modular_swin import SwinTransformer | |
# Import other models from sibling directories if needed
import importlib.util

_parent_dir = Path(__file__).parent.parent


def _load_module_from_path(module_name, file_path):
    """Execute the Python file at ``file_path`` and return it as a module.

    The module is built from a file-location spec and executed directly; it
    is not inserted into ``sys.modules`` by this helper.

    Args:
        module_name: Name to give the loaded module.
        file_path: Path of the ``.py`` file to execute.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec/loader can be created for the file.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    # spec_from_file_location may return None (and a spec may lack a loader);
    # fail with a readable message instead of an AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module {module_name!r} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# Import ResNet from RSP-ResNet-50 (optional: names are None when the
# sibling directory is not present next to this model's directory)
_resnet_path = _parent_dir / "RSP-ResNet-50" / "modular_resnet.py"
if _resnet_path.exists():
    resnet_module = _load_module_from_path("modular_resnet_resnet", _resnet_path)
    ResNet = resnet_module.ResNet
    Bottleneck = resnet_module.Bottleneck
else:
    ResNet = None
    Bottleneck = None

# Import ViTAE from RSP-ViTAEv2-S (optional, same convention as above)
_vitae_path = _parent_dir / "RSP-ViTAEv2-S" / "modular_vitae_window_noshift.py"
if _vitae_path.exists():
    vitae_module = _load_module_from_path(
        "modular_vitae_window_noshift_vitae", _vitae_path
    )
    ViTAE_Window_NoShift_12_basic_stages4_14 = (
        vitae_module.ViTAE_Window_NoShift_12_basic_stages4_14
    )
else:
    ViTAE_Window_NoShift_12_basic_stages4_14 = None

# Import configuration - try a regular import first, then fall back to
# loading configuration_rsp.py from the directory containing this file
try:
    from configuration_rsp import RSPResNetConfig, RSPSwinConfig, RSPViTAEConfig
except ImportError:
    config_path = Path(__file__).parent / "configuration_rsp.py"
    config_module = _load_module_from_path("configuration_rsp", config_path)
    RSPResNetConfig = config_module.RSPResNetConfig
    RSPSwinConfig = config_module.RSPSwinConfig
    RSPViTAEConfig = config_module.RSPViTAEConfig
class RSPResNetForImageClassification(PreTrainedModel):
    """RSP ResNet model for image classification.

    Wraps the module-level ``ResNet`` backbone in a ``PreTrainedModel`` so it
    can be used through the ``transformers`` API. The backbone is stored in
    ``self.model``.
    """

    config_class = RSPResNetConfig

    def __init__(self, config):
        """Build the ResNet backbone described by ``config``.

        Args:
            config: Configuration object; ``block``, ``layers`` and
                ``num_labels`` are read from it.

        Raises:
            ValueError: If ``config.block`` is not ``"Bottleneck"``.
            ImportError: If the ResNet implementation was not loaded
                (module-level ``ResNet`` / ``Bottleneck`` are ``None``).
        """
        super().__init__(config)
        if config.block != "Bottleneck":
            raise ValueError(f"Unsupported block type: {config.block}")
        # The backbone is an optional import; without this guard a missing
        # file would be misreported as an unsupported block type.
        if ResNet is None or Bottleneck is None:
            raise ImportError(
                "ResNet backbone is unavailable: modular_resnet.py was not "
                "found in a sibling 'RSP-ResNet-50' directory"
            )
        self.model = ResNet(
            block=Bottleneck,
            layers=config.layers,
            num_classes=config.num_labels,
        )

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """Run the backbone and optionally compute a cross-entropy loss.

        Args:
            pixel_values: Input images (B, C, H, W).
            labels: Optional labels for loss computation.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            ``{"logits": logits}`` when ``labels`` is ``None``; otherwise a
            dict with both ``"logits"`` and ``"loss"``.

        Raises:
            ValueError: If ``pixel_values`` is ``None``.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")
        logits = self.model(pixel_values)
        if labels is None:
            return {"logits": logits}
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        return {"logits": logits, "loss": loss}

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the model from a local directory containing ``model.safetensors``.

        Declared as a classmethod so that ``Class.from_pretrained(path)``
        binds ``cls`` to the class rather than to the path argument.

        NOTE: the weights are looked up on the local filesystem only, so
        ``pretrained_model_name_or_path`` must be a directory path.

        Args:
            pretrained_model_name_or_path: Directory holding the checkpoint.
            *model_args: Accepted for API compatibility and ignored.
            **kwargs: ``config`` is honoured if given; other keys are ignored.

        Returns:
            The model with weights loaded, in evaluation mode.

        Raises:
            FileNotFoundError: If ``model.safetensors`` is not in the directory.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = RSPResNetConfig.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)

        safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"Model weights not found at {safetensors_path}")

        # Checkpoint keys may carry a 'model.' prefix; the weights are loaded
        # into the inner backbone, so strip it where present.
        prefix = "model."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): tensor
            for key, tensor in load_file(str(safetensors_path)).items()
        }
        # strict=False: key mismatches are tolerated rather than raised.
        model.model.load_state_dict(state_dict, strict=False)
        # PreTrainedModel.from_pretrained returns models in evaluation mode;
        # keep this override consistent with that contract.
        model.eval()
        return model
class RSPSwinForImageClassification(PreTrainedModel):
    """RSP Swin Transformer model for image classification.

    Wraps the local ``SwinTransformer`` in a ``PreTrainedModel`` so it can be
    used through the ``transformers`` API. The backbone is stored in
    ``self.model``.
    """

    config_class = RSPSwinConfig

    def __init__(self, config):
        """Build the SwinTransformer backbone described by ``config``.

        Args:
            config: Configuration object; the image/patch geometry, channel
                count, label count and Swin hyper-parameters passed below
                are read from it.
        """
        super().__init__(config)
        self.model = SwinTransformer(
            img_size=config.image_size,
            patch_size=config.patch_size,
            in_chans=config.num_channels,
            num_classes=config.num_labels,
            embed_dim=config.embed_dim,
            depths=config.depths,
            num_heads=config.num_heads,
            window_size=config.window_size,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            ape=config.ape,
            patch_norm=config.patch_norm,
        )

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """Run the backbone and optionally compute a cross-entropy loss.

        Args:
            pixel_values: Input images (B, C, H, W).
            labels: Optional labels for loss computation.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            ``{"logits": logits}`` when ``labels`` is ``None``; otherwise a
            dict with both ``"logits"`` and ``"loss"``.

        Raises:
            ValueError: If ``pixel_values`` is ``None``.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")
        logits = self.model(pixel_values)
        if labels is None:
            return {"logits": logits}
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        return {"logits": logits, "loss": loss}

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the model from a local directory containing ``model.safetensors``.

        Declared as a classmethod so that ``Class.from_pretrained(path)``
        binds ``cls`` to the class rather than to the path argument.

        NOTE: the weights are looked up on the local filesystem only, so
        ``pretrained_model_name_or_path`` must be a directory path.

        Args:
            pretrained_model_name_or_path: Directory holding the checkpoint.
            *model_args: Accepted for API compatibility and ignored.
            **kwargs: ``config`` is honoured if given; other keys are ignored.

        Returns:
            The model with weights loaded, in evaluation mode.

        Raises:
            FileNotFoundError: If ``model.safetensors`` is not in the directory.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = RSPSwinConfig.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)

        safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"Model weights not found at {safetensors_path}")

        # Checkpoint keys may carry a 'model.' prefix; the weights are loaded
        # into the inner backbone, so strip it where present.
        prefix = "model."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): tensor
            for key, tensor in load_file(str(safetensors_path)).items()
        }
        # strict=False: key mismatches are tolerated rather than raised.
        model.model.load_state_dict(state_dict, strict=False)
        # PreTrainedModel.from_pretrained returns models in evaluation mode;
        # keep this override consistent with that contract.
        model.eval()
        return model
class RSPViTAEForImageClassification(PreTrainedModel):
    """RSP ViTAE model for image classification.

    Wraps the module-level ``ViTAE_Window_NoShift_12_basic_stages4_14``
    factory in a ``PreTrainedModel`` so it can be used through the
    ``transformers`` API. The backbone is stored in ``self.model``.
    """

    config_class = RSPViTAEConfig

    def __init__(self, config):
        """Build the ViTAE backbone described by ``config``.

        Args:
            config: Configuration object; ``image_size`` and ``num_labels``
                are read from it.

        Raises:
            ImportError: If the ViTAE implementation was not loaded (the
                module-level factory is ``None``).
        """
        super().__init__(config)
        # The backbone is an optional import; fail clearly instead of with
        # "'NoneType' object is not callable".
        if ViTAE_Window_NoShift_12_basic_stages4_14 is None:
            raise ImportError(
                "ViTAE backbone is unavailable: modular_vitae_window_noshift.py "
                "was not found in a sibling 'RSP-ViTAEv2-S' directory"
            )
        # Per the original author's note, the factory already supplies the
        # architecture defaults (stages, embedding/token dims, downsample
        # ratios, depths) and accepts **kwargs, so only img_size, num_classes
        # and window_size are overridden here. The factory is defined outside
        # this file - TODO confirm against its definition.
        self.model = ViTAE_Window_NoShift_12_basic_stages4_14(
            pretrained=False,
            img_size=config.image_size,
            num_classes=config.num_labels,
            window_size=7,
        )

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """Run the backbone and optionally compute a cross-entropy loss.

        Args:
            pixel_values: Input images (B, C, H, W).
            labels: Optional labels for loss computation.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            ``{"logits": logits}`` when ``labels`` is ``None``; otherwise a
            dict with both ``"logits"`` and ``"loss"``.

        Raises:
            ValueError: If ``pixel_values`` is ``None``.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")
        logits = self.model(pixel_values)
        if labels is None:
            return {"logits": logits}
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        return {"logits": logits, "loss": loss}

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the model from a local directory containing ``model.safetensors``.

        Declared as a classmethod so that ``Class.from_pretrained(path)``
        binds ``cls`` to the class rather than to the path argument.

        NOTE: the weights are looked up on the local filesystem only, so
        ``pretrained_model_name_or_path`` must be a directory path.

        Args:
            pretrained_model_name_or_path: Directory holding the checkpoint.
            *model_args: Accepted for API compatibility and ignored.
            **kwargs: ``config`` is honoured if given; other keys are ignored.

        Returns:
            The model with weights loaded, in evaluation mode.

        Raises:
            FileNotFoundError: If ``model.safetensors`` is not in the directory.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = RSPViTAEConfig.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)

        safetensors_path = Path(pretrained_model_name_or_path) / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"Model weights not found at {safetensors_path}")

        # Checkpoint keys may carry a 'model.' prefix; the weights are loaded
        # into the inner backbone, so strip it where present.
        prefix = "model."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): tensor
            for key, tensor in load_file(str(safetensors_path)).items()
        }
        # strict=False: key mismatches are tolerated rather than raised.
        model.model.load_state_dict(state_dict, strict=False)
        # PreTrainedModel.from_pretrained returns models in evaluation mode;
        # keep this override consistent with that contract.
        model.eval()
        return model