Tim77777767
Adjustments to the modeling so that the head is now imported directly instead of being implemented by hand
66c5431
import torch
import torch.nn as nn
from transformers import PreTrainedModel

from segformer_plusplus.model.backbone.mit import MixVisionTransformer  # backbone
from segformer_plusplus.model.head.segformer_head import SegformerHead  # head, imported directly
from mix_vision_transformer_config import MySegformerConfig  # config


class MySegformerForSemanticSegmentation(PreTrainedModel):
    config_class = MySegformerConfig
    base_model_prefix = "my_segformer"

    def __init__(self, config):
        super().__init__(config)

        # Backbone (MixVisionTransformer)
        self.backbone = MixVisionTransformer(
            embed_dims=config.embed_dims,        # e.g. [64, 128, 320, 512]
            num_stages=config.num_stages,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            patch_sizes=config.patch_sizes,
            strides=config.strides,
            sr_ratios=config.sr_ratios,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            drop_rate=config.drop_rate,
            attn_drop_rate=config.attn_drop_rate,
            drop_path_rate=config.drop_path_rate,
            out_indices=config.out_indices
        )

        # Head imported directly instead of re-implemented
        in_channels = config.embed_dims
        if isinstance(in_channels, int):
            in_channels = [in_channels]

        self.segmentation_head = SegformerHead(
            in_channels=in_channels,                           # embedding dims produced by the backbone
            in_index=list(config.out_indices),                 # which feature maps the head consumes
            out_channels=getattr(config, "num_classes", 19),   # number of classes
            dropout_ratio=0.1,
            align_corners=False
        )

        self.post_init()

    def forward(self, x):
        # backbone -> list of feature map tensors
        features = self.backbone(x)
        # head -> logits
        logits = self.segmentation_head(features)
        return {"logits": logits}