import os
import ssl
import shutil
import tempfile
import hashlib
from tqdm import tqdm
from torch.hub import get_dir
from urllib.request import urlopen, Request
from segmentation_models_pytorch.encoders import (
resnet_encoders, dpn_encoders, vgg_encoders, senet_encoders,
densenet_encoders, inceptionresnetv2_encoders, inceptionv4_encoders,
efficient_net_encoders, mobilenet_encoders, xception_encoders,
timm_efficientnet_encoders, timm_resnest_encoders, timm_res2net_encoders,
timm_regnet_encoders, timm_sknet_encoders, timm_mobilenetv3_encoders,
timm_gernet_encoders
)
from segmentation_models_pytorch.encoders.timm_universal import TimmUniversalEncoder
def initialize_encoders():
    """Build and return the registry of all available encoder configurations.

    Returns:
        dict: Mapping of encoder name to its configuration dict, aggregated
        from every encoder module bundled with segmentation_models_pytorch.
    """
    registry = {}
    for encoder_dict in (
        resnet_encoders, dpn_encoders, vgg_encoders, senet_encoders,
        densenet_encoders, inceptionresnetv2_encoders, inceptionv4_encoders,
        efficient_net_encoders, mobilenet_encoders, xception_encoders,
        timm_efficientnet_encoders, timm_resnest_encoders, timm_res2net_encoders,
        timm_regnet_encoders, timm_sknet_encoders, timm_mobilenetv3_encoders,
        timm_gernet_encoders,
    ):
        registry.update(encoder_dict)
    # Encoders introduced in smp >= 0.3.3 are added best-effort; older
    # installs (or a missing `packaging`) simply skip them.
    try:
        import segmentation_models_pytorch
        from packaging import version
        if version.parse(segmentation_models_pytorch.__version__) >= version.parse("0.3.3"):
            from segmentation_models_pytorch.encoders.mix_transformer import mix_transformer_encoders
            from segmentation_models_pytorch.encoders.mobileone import mobileone_encoders
            registry.update(mix_transformer_encoders)
            registry.update(mobileone_encoders)
    except ImportError:
        pass
    return registry
def download_weights(url, destination, hash_prefix=None, show_progress=True):
    """Download a file from ``url`` to ``destination`` with optional progress
    bar and SHA-256 prefix verification.

    Args:
        url: Source URL of the weights file.
        destination: Local path the finished file is moved to.
        hash_prefix: If given, the download's SHA-256 hex digest must start
            with this prefix.
        show_progress: Whether to display a tqdm progress bar.

    Raises:
        RuntimeError: If the downloaded data does not match ``hash_prefix``.
    """
    # NOTE(review): this disables TLS certificate verification process-wide.
    # Kept to preserve existing behavior, but it should be scoped to this
    # request (or removed) — confirm whether any target host actually needs it.
    ssl._create_default_https_context = ssl._create_unverified_context
    req = Request(url, headers={"User-Agent": "torch.hub"})
    response = urlopen(req)
    content_length = response.headers.get("Content-Length")
    # BUG FIX: headers.get() returns a string, so int(content_length[0])
    # parsed only the first digit and fed tqdm a bogus total size.
    file_size = int(content_length) if content_length else None
    destination = os.path.expanduser(destination)
    # Download into a temp file in the destination directory so the final
    # move is a same-filesystem rename and a partial download never
    # clobbers an existing file.
    temp_file = tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(destination))
    try:
        hasher = hashlib.sha256() if hash_prefix else None
        with tqdm(total=file_size, disable=not show_progress,
                  unit='B', unit_scale=True, unit_divisor=1024) as pbar:
            while True:
                buffer = response.read(8192)
                if not buffer:
                    break
                temp_file.write(buffer)
                if hasher:
                    hasher.update(buffer)
                pbar.update(len(buffer))
        temp_file.close()
        if hasher and hash_prefix:
            digest = hasher.hexdigest()
            if digest[:len(hash_prefix)] != hash_prefix:
                raise RuntimeError(f'Invalid hash value (expected "{hash_prefix}", got "{digest}")')
        shutil.move(temp_file.name, destination)
    finally:
        temp_file.close()
        # Remove the temp file if anything failed before the move succeeded.
        if os.path.exists(temp_file.name):
            os.remove(temp_file.name)
def initialize_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
    """Initialize and return a configured encoder instance.

    Args:
        name: Encoder name; a ``"tu-"`` prefix selects a timm universal encoder.
        in_channels: Number of input channels (forwarded to timm encoders).
        depth: Number of encoder stages.
        weights: Pretrained-weights identifier, or None for random init.
        output_stride: Output stride (timm universal encoders only).
        **kwargs: Extra arguments forwarded to the encoder constructor.

    Returns:
        The instantiated encoder.

    Raises:
        KeyError: If ``name`` or ``weights`` is not a known option.
    """
    encoders = initialize_encoders()
    # "tu-" prefixed names are delegated to the timm universal wrapper.
    if name.startswith("tu-"):
        name = name[3:]
        return TimmUniversalEncoder(
            name=name,
            in_channels=in_channels,
            depth=depth,
            output_stride=output_stride,
            pretrained=weights is not None,
            **kwargs
        )
    try:
        encoder_config = encoders[name]
    except KeyError:
        raise KeyError(f"Invalid encoder name '{name}'. Available encoders: {list(encoders.keys())}")
    encoder_class = encoder_config["encoder"]
    # BUG FIX: copy before updating — the params dict is shared with the
    # encoder registry, so mutating it in place leaked `depth` (and any
    # later overrides) into every subsequent lookup of the same encoder.
    encoder_params = dict(encoder_config["params"])
    encoder_params.update(depth=depth)
    if weights:
        try:
            weights_config = encoder_config["pretrained_settings"][weights]
        except KeyError:
            raise KeyError(
                f"Invalid weights '{weights}' for encoder '{name}'. "
                f"Available options: {list(encoder_config['pretrained_settings'].keys())}"
            )
        cache_dir = os.path.join(get_dir(), 'checkpoints')
        os.makedirs(cache_dir, exist_ok=True)
        weights_file = os.path.basename(weights_config["url"])
        weights_path = os.path.join(cache_dir, weights_file)
        if not os.path.exists(weights_path):
            print(f'Downloading {weights_file}...')
            # BUG FIX: keep the original https URL; the previous
            # .replace("https", "http") downgraded the transfer to plain
            # HTTP, exposing the weights file to tampering in transit.
            download_weights(
                weights_config["url"],
                weights_path
            )
        # NOTE(review): the weights file is downloaded but never loaded
        # into the encoder here — presumably the caller loads it from the
        # torch.hub checkpoint cache; confirm against call sites.
    return encoder_class(**encoder_params)