# SegformerPlusPlus / modeling_my_segformer.py
from transformers import PreTrainedModel
import torch
import torch.nn as nn

from segformer_plusplus.utils import resize
from segformer_plusplus.utils.activation import ConvModule
from segformer_plusplus.model.backbone.mit import MixVisionTransformer  # backbone import
from mix_vision_transformer_config import MySegformerConfig  # config import


# Simplified SegFormer-style decode head
class SegformerHead(nn.Module):
def __init__(self,
                 in_channels=[64, 128, 256, 512],  # adjust to match the backbone's output channels
in_index=[0, 1, 2, 3],
channels=256,
dropout_ratio=0.1,
                 out_channels=19,  # adjust to the number of classes
norm_cfg=None,
align_corners=False,
interpolate_mode='bilinear'):
super().__init__()
self.in_channels = in_channels
self.in_index = in_index
self.channels = channels
self.dropout_ratio = dropout_ratio
self.out_channels = out_channels
self.norm_cfg = norm_cfg
self.align_corners = align_corners
self.interpolate_mode = interpolate_mode
self.act_cfg = dict(type='ReLU')
self.conv_seg = nn.Conv2d(channels, out_channels, kernel_size=1)
self.dropout = nn.Dropout2d(dropout_ratio) if dropout_ratio > 0 else None
num_inputs = len(in_channels)
assert num_inputs == len(in_index)
self.convs = nn.ModuleList()
for i in range(num_inputs):
self.convs.append(
ConvModule(
in_channels=in_channels[i],
out_channels=channels,
kernel_size=1,
stride=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=self.act_cfg))
self.fusion_conv = ConvModule(
in_channels=channels * num_inputs,
out_channels=channels,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg)
def cls_seg(self, feat):
if self.dropout is not None:
feat = self.dropout(feat)
return self.conv_seg(feat)
    def forward(self, inputs):
        # Select the feature maps named by in_index (mirrors mmseg's
        # _transform_inputs; previously in_index was stored but never used),
        # project each to `channels`, and resize everything to the
        # resolution of the first selected map.
        inputs = [inputs[i] for i in self.in_index]
        outs = []
        for idx, x in enumerate(inputs):
            conv = self.convs[idx]
            outs.append(
                resize(
                    input=conv(x),
                    size=inputs[0].shape[2:],
                    mode=self.interpolate_mode,
                    align_corners=self.align_corners))
        out = self.fusion_conv(torch.cat(outs, dim=1))
        out = self.cls_seg(out)
        return out
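
# A quick, self-contained sketch of the tensor shapes the head expects.
# Assumptions (not taken from the model above): four MiT-style feature maps
# at strides 4/8/16/32 with MiT-B2-style widths [64, 128, 320, 512].
def _demo_segformer_head():
    head = SegformerHead(in_channels=[64, 128, 320, 512], out_channels=19)
    # Dummy features for a 512x512 input: the stride-4 map is 128x128,
    # and each later stage halves the spatial size
    feats = [torch.randn(1, c, 128 // 2 ** i, 128 // 2 ** i)
             for i, c in enumerate([64, 128, 320, 512])]
    out = head(feats)
    # Every map is resized to the stride-4 resolution before fusion,
    # so the logits come out at 1/4 of the input size
    assert out.shape == (1, 19, 128, 128)
    return out
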
class MySegformerForSemanticSegmentation(PreTrainedModel):
config_class = MySegformerConfig
base_model_prefix = "my_segformer"
def __init__(self, config):
super().__init__(config)
        # Important: pass the entire list, not just the first element
self.backbone = MixVisionTransformer(
            embed_dims=config.embed_dims,  # the FULL list, e.g. [64, 128, 320, 512]
num_stages=config.num_stages,
num_layers=config.num_layers,
num_heads=config.num_heads,
patch_sizes=config.patch_sizes,
strides=config.strides,
sr_ratios=config.sr_ratios,
mlp_ratio=config.mlp_ratio,
qkv_bias=config.qkv_bias,
drop_rate=config.drop_rate,
attn_drop_rate=config.attn_drop_rate,
drop_path_rate=config.drop_path_rate,
out_indices=config.out_indices
)
        # Make sure in_channels is a list (a single int would break the head)
        in_channels = config.embed_dims
        if isinstance(in_channels, int):
            in_channels = [in_channels]
        self.segmentation_head = SegformerHead(
            in_channels=in_channels,  # e.g. [64, 128, 320, 512]
            in_index=list(config.out_indices),  # e.g. [0, 1, 2, 3]
            out_channels=getattr(config, 'num_classes', 19),
            dropout_ratio=0.1,
            align_corners=False
        )
self.post_init()
    def forward(self, x):
        # The backbone returns a list of multi-scale feature maps
        features = self.backbone(x)  # e.g. List[Tensor], one per stage
        # Pass the features on to the segmentation head
        output = self.segmentation_head(features)  # logits, (B, num_classes, H/4, W/4)
        return output
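

# Minimal smoke test, a sketch under assumptions: MySegformerConfig must
# provide usable defaults for every field read in __init__ above; adjust
# the config to your backbone variant. Checks that a dummy batch flows
# through the model end to end.
if __name__ == "__main__":
    config = MySegformerConfig()
    model = MySegformerForSemanticSegmentation(config)
    model.eval()
    dummy = torch.randn(1, 3, 512, 512)  # (batch, channels, height, width)
    with torch.no_grad():
        logits = model(dummy)
    # Upsample the stride-4 logits to the input size for per-pixel labels
    print(logits.shape)  # e.g. torch.Size([1, 19, 128, 128])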