# obichimav's picture
# Upload 42 files
# 8e5d8c7 verified
import os
import ssl
import shutil
import tempfile
import hashlib
from tqdm import tqdm
from torch.hub import get_dir
from urllib.request import urlopen, Request
from segmentation_models_pytorch.encoders import (
resnet_encoders, dpn_encoders, vgg_encoders, senet_encoders,
densenet_encoders, inceptionresnetv2_encoders, inceptionv4_encoders,
efficient_net_encoders, mobilenet_encoders, xception_encoders,
timm_efficientnet_encoders, timm_resnest_encoders, timm_res2net_encoders,
timm_regnet_encoders, timm_sknet_encoders, timm_mobilenetv3_encoders,
timm_gernet_encoders
)
from segmentation_models_pytorch.encoders.timm_universal import TimmUniversalEncoder
def initialize_encoders():
    """Collect the encoder registries shipped with segmentation_models_pytorch.

    Every per-family registry imported at module level is merged into one
    fresh dict, in order, so later families win on duplicate names. When the
    installed segmentation_models_pytorch reports version 0.3.3 or newer, the
    ``mix_transformer`` and ``mobileone`` registries are merged in as well.
    Any ImportError raised while probing for those optional families
    (including a missing ``packaging`` package) is ignored.

    Returns:
        dict: A new mapping of encoder name -> encoder configuration.
    """
    registries = (
        resnet_encoders, dpn_encoders, vgg_encoders, senet_encoders,
        densenet_encoders, inceptionresnetv2_encoders, inceptionv4_encoders,
        efficient_net_encoders, mobilenet_encoders, xception_encoders,
        timm_efficientnet_encoders, timm_resnest_encoders, timm_res2net_encoders,
        timm_regnet_encoders, timm_sknet_encoders, timm_mobilenetv3_encoders,
        timm_gernet_encoders,
    )
    merged = {}
    for registry in registries:
        merged.update(registry)

    try:
        import segmentation_models_pytorch as smp
        from packaging import version
        # The optional families below only exist from smp 0.3.3 onward.
        if version.parse(smp.__version__) < version.parse("0.3.3"):
            return merged
        # Import both before merging either, so a partial failure adds nothing.
        from segmentation_models_pytorch.encoders.mix_transformer import mix_transformer_encoders
        from segmentation_models_pytorch.encoders.mobileone import mobileone_encoders
    except ImportError:
        return merged

    merged.update(mix_transformer_encoders)
    merged.update(mobileone_encoders)
    return merged
def download_weights(url, destination, hash_prefix=None, show_progress=True, verify_ssl=True):
    """Download ``url`` to ``destination`` with a progress bar and optional hash check.

    The payload is streamed into a temporary file in the destination's
    directory. It is moved into place only after the download, and the
    SHA-256 prefix check when requested, have succeeded. A partial or
    mismatching download is therefore deleted and never appears at
    ``destination``.

    Args:
        url: URL to fetch.
        destination: Target file path; ``~`` is expanded.
        hash_prefix: Optional leading hex characters that the SHA-256 digest
            of the downloaded bytes must start with.
        show_progress: Whether to display a tqdm progress bar.
        verify_ssl: Verify TLS certificates (the default). Pass ``False``
            only for hosts with broken certificates. The relaxed context is
            applied to this single request, never process-wide.

    Raises:
        RuntimeError: If the digest does not start with ``hash_prefix``.
    """
    # Scope any certificate relaxation to this request. Monkey-patching ssl's
    # global default would silently disable verification for every later
    # HTTPS connection in the process.
    context = None if verify_ssl else ssl._create_unverified_context()
    destination = os.path.expanduser(destination)
    req = Request(url, headers={"User-Agent": "torch.hub"})

    # The context manager closes the connection on success and on error.
    with urlopen(req, context=context) as response:
        # headers.get() returns the whole header value as a str; indexing it
        # with [0] would keep only the first digit.
        content_length = response.headers.get("Content-Length")
        file_size = int(content_length) if content_length else None

        temp_file = tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(destination))
        try:
            hasher = hashlib.sha256() if hash_prefix else None
            with tqdm(total=file_size, disable=not show_progress,
                      unit='B', unit_scale=True, unit_divisor=1024) as pbar:
                while True:
                    buffer = response.read(8192)
                    if not buffer:
                        break
                    temp_file.write(buffer)
                    if hasher is not None:
                        hasher.update(buffer)
                    pbar.update(len(buffer))
            temp_file.close()
            # hasher exists exactly when a non-empty hash_prefix was given.
            if hasher is not None:
                digest = hasher.hexdigest()
                if digest[:len(hash_prefix)] != hash_prefix:
                    raise RuntimeError(f'Invalid hash value (expected "{hash_prefix}", got "{digest}")')
            shutil.move(temp_file.name, destination)
        finally:
            # After a successful move the temp path no longer exists. On any
            # failure the partial file is removed here.
            temp_file.close()
            if os.path.exists(temp_file.name):
                os.remove(temp_file.name)
def initialize_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
    """Build a configured encoder by name, optionally with pretrained weights.

    Names prefixed with ``"tu-"`` are delegated to ``TimmUniversalEncoder``,
    which handles input channels, depth, output stride and pretraining itself.
    Every other name is looked up in the registry from ``initialize_encoders()``.
    For those encoders:

    - pretrained weights, if requested, are downloaded into the torch hub
      ``checkpoints`` cache and then loaded into the encoder;
    - the input layer is adapted to ``in_channels``;
    - the encoder is switched to dilated mode when ``output_stride != 32``.

    Args:
        name: Encoder name, e.g. a registry key or ``"tu-<timm model name>"``.
        in_channels: Number of input image channels.
        depth: Number of encoder stages to build.
        weights: Key into the encoder's ``pretrained_settings``, or a falsy
            value for random initialization.
        output_stride: Desired output stride; 32 leaves the encoder undilated.
        **kwargs: Forwarded to ``TimmUniversalEncoder`` only; ignored for
            registry encoders.

    Returns:
        The configured encoder instance.

    Raises:
        KeyError: If ``name`` or ``weights`` is not available.
    """
    if name.startswith("tu-"):
        return TimmUniversalEncoder(
            name=name[3:],
            in_channels=in_channels,
            depth=depth,
            output_stride=output_stride,
            pretrained=weights is not None,
            **kwargs
        )

    encoders = initialize_encoders()
    try:
        encoder_config = encoders[name]
    except KeyError:
        raise KeyError(f"Invalid encoder name '{name}'. Available encoders: {list(encoders.keys())}") from None

    encoder_class = encoder_config["encoder"]

    # Resolve, and if needed download, the weights before building the model,
    # so an invalid weights key fails fast.
    weights_path = None
    if weights:
        try:
            weights_config = encoder_config["pretrained_settings"][weights]
        except KeyError:
            raise KeyError(
                f"Invalid weights '{weights}' for encoder '{name}'. "
                f"Available options: {list(encoder_config['pretrained_settings'].keys())}"
            ) from None
        url = weights_config["url"]
        cache_dir = os.path.join(get_dir(), 'checkpoints')
        os.makedirs(cache_dir, exist_ok=True)
        weights_file = os.path.basename(url)
        weights_path = os.path.join(cache_dir, weights_file)
        if not os.path.exists(weights_path):
            print(f'Downloading {weights_file}...')
            # Keep the original scheme. Rewriting https->http exposes the
            # download to tampering and corrupts URLs that merely contain "https".
            download_weights(url, weights_path)

    # Copy rather than update in place: the registry's "params" dicts are
    # shared module-level objects, so mutating them would leak this call's
    # depth into every later call.
    encoder_params = dict(encoder_config["params"], depth=depth)
    encoder = encoder_class(**encoder_params)

    if weights_path is not None:
        import torch  # local import: only the pretrained path needs it
        encoder.load_state_dict(torch.load(weights_path, map_location="cpu"))

    # Mirror the "tu-" path: honour in_channels and output_stride here too.
    encoder.set_in_channels(in_channels, pretrained=weights_path is not None)
    if output_stride != 32:
        encoder.make_dilated(output_stride)
    return encoder