# Using the SegFormer++ Model via PyTorch Hub

This document explains how to use a pre-trained SegFormer++ model and its associated data transformations by loading them directly from the GitHub repository via PyTorch Hub. This streamlines model access and makes it easy to integrate the model into your projects with a single call.

## Prerequisites

Before running the script, ensure you have PyTorch installed. You also need to install the following dependencies, which are required by the model and its entry points:

```bash
pip install tomesd omegaconf numpy rich yapf addict tqdm packaging torchvision
```

## How It Works

The provided Python script demonstrates a full workflow, from loading the model and transformations to running inference on a dummy image.

## Step 1: Loading the Model

You can load the model with a single call to `torch.hub.load`.
The parameters are:
- `pretrained`: If set to `True`, loads the model with pre-trained ImageNet weights.
- `backbone`: Specifies the backbone architecture (e.g., 'b5' for MiT-B5). Other options include 'b0', 'b1', 'b2', 'b3', and 'b4'.
- `tome_strategy`: Defines the token merging strategy. Options include 'bsm_hq' (high quality), 'bsm_fast' (faster), and 'n2d_2x2' (non-overlapping 2x2).
- `checkpoint_url`: A URL to a specific checkpoint file. This lets you load our trained model weights, which are linked in the README. Make sure that the weights match the chosen model size and number of classes.
- `out_channels`: The number of output classes for segmentation (e.g., 19 for Cityscapes).

```python
import torch

model = torch.hub.load(
    'KieDani/SegformerPlusPlus',
    'segformer_plusplus',
    pretrained=True,
    backbone='b5',
    tome_strategy='bsm_hq',
    checkpoint_url='https://mediastore.rz.uni-augsburg.de/get/yzE65lzm6N/',  # URL to checkpoints, optional
    out_channels=19,
)
model.eval()  # Set the model to evaluation mode
```
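
For illustration, the same entry point can be configured differently. The call below is only a sketch of how the documented options combine (a smaller 'b0' backbone with the faster 'bsm_fast' strategy); since no `checkpoint_url` is passed, the segmentation head is expected to start from random initialization rather than our trained weights.

```python
# Example configuration using the smaller MiT-B0 backbone and the faster
# token merging strategy. Without `checkpoint_url`, only the ImageNet
# pre-trained backbone weights are loaded, so the segmentation head is untrained.
small_model = torch.hub.load(
    'KieDani/SegformerPlusPlus',
    'segformer_plusplus',
    pretrained=True,
    backbone='b0',
    tome_strategy='bsm_fast',
    out_channels=19,
)
small_model.eval()
```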

## Step 2: Loading Data Transformations

The `data_transforms` entry point returns a `torchvision.transforms.Compose` object, which encapsulates the standard preprocessing steps required by the model (resizing and normalization).

```python
# Load the data transformations
transform = torch.hub.load(
    'KieDani/SegformerPlusPlus',
    'data_transforms',
)
```
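
The exact resize target and normalization statistics are defined inside the repository. If you want to verify them, you can simply print the returned object; a `torchvision.transforms.Compose` lists the transforms it contains.

```python
# Print the preprocessing pipeline to see the individual transforms
# (e.g., the resize size and normalization values) it applies.
print(transform)
```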



## Step 3: Preparing the Image and Running Inference



After loading the model and transformations, you can apply them to an input image. The script creates a dummy image for this example, but in a real-world scenario, you would load an image from your file system.



```python
from PIL import Image

# In a real-world scenario, you would load your image here:
# image = Image.open('path_to_your_image.jpg').convert('RGB')
dummy_image = Image.new('RGB', (1300, 1300), color='red')

# Apply the transformations
input_tensor = transform(dummy_image).unsqueeze(0)  # Add a batch dimension

# Run inference
with torch.no_grad():
    output = model(input_tensor)

# Process the output tensor to get the final segmentation map
segmentation_map = torch.argmax(output.squeeze(0), dim=0)
```

The final `segmentation_map` is a tensor where each pixel value represents the predicted class (0 to 18 for the 19 Cityscapes classes).
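
If you want a quick visual check of the result, one simple (unofficial) option is to spread the class indices over the grayscale range and save them with PIL; the scaling factor below is only an illustration, not a color scheme used by the repository.

```python
import numpy as np
from PIL import Image

# Spread the 19 class indices (0-18) over the 0-255 grayscale range so the
# predicted regions are visible when saved. Purely for quick inspection.
seg = segmentation_map.cpu().numpy()
preview = Image.fromarray((seg * (255 // 18)).astype(np.uint8), mode='L')
preview.save('segmentation_preview.png')
```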



## Full Script



Below is the complete, runnable script for your reference.



```python
import torch
from PIL import Image

# --- IMPORTANT: TorchHub Dependencies ---
# Install the dependencies via:
# pip install tomesd omegaconf numpy rich yapf addict tqdm packaging torchvision

# Load the SegFormer++ model with predefined parameters.
print("Loading SegFormer++ Model...")
model = torch.hub.load(
    'KieDani/SegformerPlusPlus',
    'segformer_plusplus',
    pretrained=True,
    backbone='b5',
    tome_strategy='bsm_hq',
    checkpoint_url='https://mediastore.rz.uni-augsburg.de/get/yzE65lzm6N/',
    out_channels=19,
)
model.eval()
print("Model loaded successfully.")

# Load the data transformations via the 'data_transforms' entry point.
print("Loading data transformations...")
transform = torch.hub.load(
    'KieDani/SegformerPlusPlus',
    'data_transforms',
)
print("Transformations loaded successfully.")

# --- Example for Image Preparation and Inference ---
# Create a dummy image, as we don't need a real image file.
# In a real scenario, you would load an image from disk, e.g.:
# image = Image.open('path_to_your_image.jpg').convert('RGB')
print("Creating a dummy image for demonstration...")
dummy_image = Image.new('RGB', (1300, 1300), color='red')
print("Original image size:", dummy_image.size)

# Apply the transformations loaded from the Hub to the image.
print("Applying transformations to the image...")
input_tensor = transform(dummy_image).unsqueeze(0)  # Adds a batch dimension
print("Transformed image tensor size:", input_tensor.shape)

# Run inference.
print("Running inference...")
with torch.no_grad():
    output = model(input_tensor)

# The output tensor has the shape [1, num_classes, height, width].
# Remove the batch dimension (1).
output_tensor = output.squeeze(0)
print(f"\nInference completed. Output tensor size: {output_tensor.shape}")

# To get the final segmentation map, take the argmax over the class dimension.
segmentation_map = torch.argmax(output_tensor, dim=0)
print(f"Size of the generated segmentation map: {segmentation_map.shape}")
```