File size: 5,118 Bytes
8e5d8c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import ssl
import shutil
import tempfile
import hashlib
from tqdm import tqdm
from torch.hub import get_dir
from urllib.request import urlopen, Request

from segmentation_models_pytorch.encoders import (
    resnet_encoders, dpn_encoders, vgg_encoders, senet_encoders,
    densenet_encoders, inceptionresnetv2_encoders, inceptionv4_encoders,
    efficient_net_encoders, mobilenet_encoders, xception_encoders,
    timm_efficientnet_encoders, timm_resnest_encoders, timm_res2net_encoders,
    timm_regnet_encoders, timm_sknet_encoders, timm_mobilenetv3_encoders,
    timm_gernet_encoders
)

from segmentation_models_pytorch.encoders.timm_universal import TimmUniversalEncoder

def initialize_encoders():
    """Build a fresh registry of the encoders known to segmentation_models_pytorch.

    Every statically imported ``*_encoders`` mapping is merged into one new
    dict, in a fixed order, so later mappings win on duplicate names. When the
    installed library reports version 0.3.3 or newer, the ``mix_transformer``
    and ``mobileone`` encoder families are merged in as well. If either of
    those modules, or the ``packaging`` helper, cannot be imported, the extra
    families are skipped.

    Returns:
        dict: mapping of encoder name to its configuration entry.
    """
    registry = {}
    for encoder_group in (
        resnet_encoders, dpn_encoders, vgg_encoders, senet_encoders,
        densenet_encoders, inceptionresnetv2_encoders, inceptionv4_encoders,
        efficient_net_encoders, mobilenet_encoders, xception_encoders,
        timm_efficientnet_encoders, timm_resnest_encoders, timm_res2net_encoders,
        timm_regnet_encoders, timm_sknet_encoders, timm_mobilenetv3_encoders,
        timm_gernet_encoders,
    ):
        registry.update(encoder_group)

    # Optional families: both are imported before either is merged, so a
    # failure on the second import leaves the registry without both.
    try:
        import segmentation_models_pytorch as smp
        from packaging import version

        installed_version = version.parse(smp.__version__)
        if installed_version >= version.parse("0.3.3"):
            from segmentation_models_pytorch.encoders.mix_transformer import mix_transformer_encoders
            from segmentation_models_pytorch.encoders.mobileone import mobileone_encoders
            registry.update(mix_transformer_encoders)
            registry.update(mobileone_encoders)
    except ImportError:
        pass

    return registry

def download_weights(url, destination, hash_prefix=None, show_progress=True, *, verify_ssl=False):
    """Download ``url`` to ``destination`` with a progress bar and optional hash check.

    The payload is streamed into a temporary file created in the destination's
    directory. It is moved to ``destination`` only after the download succeeds,
    and after the hash check passes if one was requested. A failed or rejected
    download therefore never leaves a partial file at ``destination``.

    Args:
        url: URL to fetch.
        destination: Target file path; a leading ``~`` is expanded.
        hash_prefix: If given, the SHA-256 hex digest of the payload must
            start with this string.
        show_progress: Show a tqdm progress bar while downloading.
        verify_ssl: Verify TLS certificates for this request. Defaults to
            ``False`` to keep this helper's historical behaviour. The bypass
            applies to this request only.

    Raises:
        RuntimeError: If ``hash_prefix`` is given and the digest does not match.
    """
    # Pass an explicit context instead of overwriting
    # ssl._create_default_https_context. That global monkeypatch disabled
    # certificate checks for every later HTTPS request in the whole process.
    # NOTE(review): consider defaulting verify_ssl to True. With verification
    # off and no hash_prefix, a tampered payload goes undetected.
    context = None if verify_ssl else ssl._create_unverified_context()
    request = Request(url, headers={"User-Agent": "torch.hub"})
    destination = os.path.expanduser(destination)

    # `with` guarantees the HTTP response is closed on every path.
    with urlopen(request, context=context) as response:
        # Content-Length arrives as a single string such as "12345". Parse the
        # whole value; indexing [0] would keep only its first digit.
        content_length = response.headers.get("Content-Length")
        try:
            file_size = int(content_length) if content_length else None
        except ValueError:
            file_size = None  # malformed header: show an indeterminate bar

        # Create the temp file next to the destination so the final move
        # stays within one directory.
        temp_file = tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(destination))
        try:
            hasher = hashlib.sha256() if hash_prefix else None

            with tqdm(total=file_size, disable=not show_progress,
                      unit='B', unit_scale=True, unit_divisor=1024) as pbar:
                while True:
                    buffer = response.read(8192)
                    if not buffer:
                        break
                    temp_file.write(buffer)
                    if hasher is not None:
                        hasher.update(buffer)
                    pbar.update(len(buffer))

            # Close (flushing buffered data) before verifying and moving.
            temp_file.close()

            if hasher is not None:
                digest = hasher.hexdigest()
                if not digest.startswith(hash_prefix):
                    raise RuntimeError(f'Invalid hash value (expected "{hash_prefix}", got "{digest}")')

            shutil.move(temp_file.name, destination)
        finally:
            # Closing twice is harmless. The temp file still exists here only
            # if something above failed before the move.
            temp_file.close()
            if os.path.exists(temp_file.name):
                os.remove(temp_file.name)

def initialize_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
    """Create an encoder by name, optionally initialised with pretrained weights.

    Names starting with ``"tu-"`` are delegated to ``TimmUniversalEncoder``
    with the prefix stripped. All other names are looked up in the registry
    returned by ``initialize_encoders()``.

    Args:
        name: Registry key, or ``"tu-<timm model name>"``.
        in_channels: Number of input channels of the encoder stem.
        depth: Number of encoder stages; forwarded as the ``depth`` param.
        weights: Key into the encoder's ``pretrained_settings``, or ``None``
            for random initialisation.
        output_stride: Requested output stride; values other than 32 ask the
            encoder to switch to dilated mode.
        **kwargs: Forwarded to ``TimmUniversalEncoder`` only.

    Returns:
        The constructed encoder.

    Raises:
        KeyError: If ``name`` or ``weights`` is not recognised.
    """
    if name.startswith("tu-"):
        return TimmUniversalEncoder(
            name=name[3:],
            in_channels=in_channels,
            depth=depth,
            output_stride=output_stride,
            pretrained=weights is not None,
            **kwargs
        )

    encoders = initialize_encoders()
    try:
        encoder_config = encoders[name]
    except KeyError:
        raise KeyError(f"Invalid encoder name '{name}'. Available encoders: {list(encoders.keys())}") from None

    # Build a new params dict. The registry shares its inner dicts with the
    # library's module-level encoder tables, so update() would mutate them.
    encoder_params = dict(encoder_config["params"], depth=depth)
    encoder = encoder_config["encoder"](**encoder_params)

    pretrained = bool(weights)
    if pretrained:
        try:
            weights_config = encoder_config["pretrained_settings"][weights]
        except KeyError:
            raise KeyError(
                f"Invalid weights '{weights}' for encoder '{name}'. "
                f"Available options: {list(encoder_config['pretrained_settings'].keys())}"
            ) from None

        cache_dir = os.path.join(get_dir(), 'checkpoints')
        os.makedirs(cache_dir, exist_ok=True)

        weights_url = weights_config["url"]
        weights_file = os.path.basename(weights_url)
        weights_path = os.path.join(cache_dir, weights_file)

        if not os.path.exists(weights_path):
            print(f'Downloading {weights_file}...')
            # Use the URL unchanged. Rewriting "https" to "http" downgraded the
            # transfer to plaintext and also corrupted any URL that contains
            # "https" later in its path.
            download_weights(weights_url, weights_path)

        import torch

        # Previously the file was downloaded but never loaded, so "pretrained"
        # encoders were randomly initialised.
        # NOTE(review): torch.load unpickles the file; only load trusted
        # weights (consider weights_only=True where the torch version allows).
        state_dict = torch.load(weights_path, map_location="cpu")
        encoder.load_state_dict(state_dict)

    # Honour in_channels / output_stride, which were silently ignored for
    # registry encoders. This mirrors smp's own get_encoder: the stem is
    # adapted after loading so pretrained stem weights can be reused.
    encoder.set_in_channels(in_channels, pretrained=pretrained)
    if output_stride != 32:
        encoder.make_dilated(output_stride)

    return encoder