import os
import ssl
import shutil
import tempfile
import hashlib
from urllib.request import urlopen, Request

import torch
from tqdm import tqdm
from torch.hub import get_dir

from segmentation_models_pytorch.encoders import (
    resnet_encoders,
    dpn_encoders,
    vgg_encoders,
    senet_encoders,
    densenet_encoders,
    inceptionresnetv2_encoders,
    inceptionv4_encoders,
    efficient_net_encoders,
    mobilenet_encoders,
    xception_encoders,
    timm_efficientnet_encoders,
    timm_resnest_encoders,
    timm_res2net_encoders,
    timm_regnet_encoders,
    timm_sknet_encoders,
    timm_mobilenetv3_encoders,
    timm_gernet_encoders,
)
from segmentation_models_pytorch.encoders.timm_universal import TimmUniversalEncoder


def initialize_encoders():
    """Build and return a dictionary of the available encoders.

    The registries imported at the top of this module are merged into a
    single new dict. When ``segmentation_models_pytorch`` reports version
    0.3.3 or newer, the mix-transformer and MobileOne registries are merged
    in as well.

    Returns:
        dict: Encoder name -> encoder configuration. The outer dict is new
        on every call, but the configuration values are the same objects
        held by the imported registries (the merge is shallow).
    """
    available_encoders = {}

    encoder_modules = [
        resnet_encoders,
        dpn_encoders,
        vgg_encoders,
        senet_encoders,
        densenet_encoders,
        inceptionresnetv2_encoders,
        inceptionv4_encoders,
        efficient_net_encoders,
        mobilenet_encoders,
        xception_encoders,
        timm_efficientnet_encoders,
        timm_resnest_encoders,
        timm_res2net_encoders,
        timm_regnet_encoders,
        timm_sknet_encoders,
        timm_mobilenetv3_encoders,
        timm_gernet_encoders,
    ]
    for module in encoder_modules:
        available_encoders.update(module)

    # Optional registries: skipped when `packaging` or the encoder modules
    # themselves cannot be imported.
    try:
        import segmentation_models_pytorch
        from packaging import version

        if version.parse(segmentation_models_pytorch.__version__) >= version.parse("0.3.3"):
            from segmentation_models_pytorch.encoders.mix_transformer import mix_transformer_encoders
            from segmentation_models_pytorch.encoders.mobileone import mobileone_encoders

            available_encoders.update(mix_transformer_encoders)
            available_encoders.update(mobileone_encoders)
    except ImportError:
        pass

    return available_encoders


def download_weights(url, destination, hash_prefix=None, show_progress=True):
    """Download ``url`` to ``destination`` with optional hash verification.

    The payload is streamed into a temporary file created in the destination
    directory and moved into place only after the download (and the hash
    check, if requested) succeeded, so a failed download never leaves a
    partial file at ``destination``.

    Args:
        url: Address of the file to download.
        destination: Target path; ``~`` is expanded.
        hash_prefix: Optional leading hex digits of the expected SHA-256
            digest. When given, the download is rejected on mismatch.
        show_progress: Whether to display a tqdm progress bar.

    Raises:
        RuntimeError: If ``hash_prefix`` is given and does not match the
            SHA-256 digest of the downloaded data.
    """
    # NOTE(review): this disables HTTPS certificate verification for the
    # whole process, not just this download. Kept for backward compatibility
    # (it looks like a deliberate workaround), but callers should pass
    # `hash_prefix` whenever the integrity of the file matters.
    ssl._create_default_https_context = ssl._create_unverified_context

    request = Request(url, headers={"User-Agent": "torch.hub"})
    # The context manager closes the HTTP response on every exit path.
    with urlopen(request) as response:
        # `headers.get` returns the header value as a string (or None), so
        # the whole value is converted -- indexing it would keep only the
        # first digit of the size.
        content_length = response.headers.get("Content-Length")
        file_size = int(content_length) if content_length else None

        destination = os.path.expanduser(destination)
        # Same directory as the destination so the final move stays on one
        # filesystem.
        temp_file = tempfile.NamedTemporaryFile(
            delete=False, dir=os.path.dirname(destination)
        )
        try:
            hasher = hashlib.sha256() if hash_prefix else None

            with tqdm(
                total=file_size,
                disable=not show_progress,
                unit='B',
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                # read() returns b"" once the stream is exhausted.
                for buffer in iter(lambda: response.read(8192), b""):
                    temp_file.write(buffer)
                    if hasher is not None:
                        hasher.update(buffer)
                    pbar.update(len(buffer))

            # Flush and release the handle before the file is moved.
            temp_file.close()

            if hasher is not None:
                digest = hasher.hexdigest()
                if digest[:len(hash_prefix)] != hash_prefix:
                    raise RuntimeError(
                        f'Invalid hash value (expected "{hash_prefix}", got "{digest}")'
                    )

            shutil.move(temp_file.name, destination)
        finally:
            # close() is a no-op when already closed; the temp file only
            # still exists here if the download or verification failed.
            temp_file.close()
            if os.path.exists(temp_file.name):
                os.remove(temp_file.name)


def initialize_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
    """Create an encoder by name, optionally with pretrained weights.

    Args:
        name: Encoder name. Names starting with ``"tu-"`` are handed to
            ``TimmUniversalEncoder`` with the prefix stripped; all other
            names are looked up in the dict from ``initialize_encoders()``.
        in_channels: Number of input channels.
        depth: Encoder depth, passed to the encoder constructor.
        weights: Key into the encoder's ``pretrained_settings``, or ``None``
            for no pretrained weights.
        output_stride: Requested output stride.
        **kwargs: Extra arguments, forwarded only to ``TimmUniversalEncoder``.

    Returns:
        The constructed encoder instance.

    Raises:
        KeyError: If ``name`` is not a known encoder, or ``weights`` is not
            a pretrained setting available for that encoder.
    """
    encoders = initialize_encoders()

    if name.startswith("tu-"):
        name = name[3:]
        return TimmUniversalEncoder(
            name=name,
            in_channels=in_channels,
            depth=depth,
            output_stride=output_stride,
            pretrained=weights is not None,
            **kwargs
        )

    try:
        encoder_config = encoders[name]
    except KeyError:
        raise KeyError(
            f"Invalid encoder name '{name}'. Available encoders: {list(encoders.keys())}"
        ) from None

    encoder_class = encoder_config["encoder"]
    # Copy the parameters: the config dict is shared with the imported
    # registry, so updating it in place would leak `depth` into later calls.
    encoder_params = dict(encoder_config["params"], depth=depth)

    weights_path = None
    if weights:
        try:
            weights_config = encoder_config["pretrained_settings"][weights]
        except KeyError:
            raise KeyError(
                f"Invalid weights '{weights}' for encoder '{name}'. "
                f"Available options: {list(encoder_config['pretrained_settings'].keys())}"
            ) from None

        cache_dir = os.path.join(get_dir(), 'checkpoints')
        os.makedirs(cache_dir, exist_ok=True)

        weights_url = weights_config["url"]
        weights_file = os.path.basename(weights_url)
        weights_path = os.path.join(cache_dir, weights_file)

        if not os.path.exists(weights_path):
            print(f'Downloading {weights_file}...')
            # NOTE(review): the download is deliberately downgraded to plain
            # HTTP and performed without a hash check. Only the scheme is
            # rewritten, so an "https" elsewhere in the URL stays intact.
            if weights_url.startswith("https://"):
                weights_url = "http://" + weights_url[len("https://"):]
            download_weights(weights_url, weights_path)

    encoder = encoder_class(**encoder_params)

    if weights_path is not None:
        # Actually load the requested checkpoint; without this the returned
        # encoder would keep its random initialisation.
        state_dict = torch.load(weights_path, map_location="cpu")
        encoder.load_state_dict(state_dict)

    # Apply non-default settings the same way the upstream `get_encoder`
    # does. Assumes the registry classes provide `set_in_channels` and
    # `make_dilated` (segmentation_models_pytorch's EncoderMixin) -- TODO
    # confirm for custom registries. Defaults are skipped, so the default
    # call path does not depend on those methods.
    if in_channels != 3:
        encoder.set_in_channels(in_channels, pretrained=bool(weights))
    if output_stride != 32:
        encoder.make_dilated(output_stride)

    return encoder