# semantic-segmentation/SemanticModel/.ipynb_checkpoints/encoder_management-checkpoint.py
# Encoder registry and pretrained-weight download utilities for SemanticSegmentationModel.
import os
import ssl
import shutil
import tempfile
import hashlib

from tqdm import tqdm
from torch.hub import get_dir
from urllib.request import urlopen, Request

from segmentation_models_pytorch.encoders import (
    resnet_encoders, dpn_encoders, vgg_encoders, senet_encoders,
    densenet_encoders, inceptionresnetv2_encoders, inceptionv4_encoders,
    efficient_net_encoders, mobilenet_encoders, xception_encoders,
    timm_efficientnet_encoders, timm_resnest_encoders, timm_res2net_encoders,
    timm_regnet_encoders, timm_sknet_encoders, timm_mobilenetv3_encoders,
    timm_gernet_encoders
)
from segmentation_models_pytorch.encoders.timm_universal import TimmUniversalEncoder


def initialize_encoders():
    """Initialize dictionary of available encoders."""
    available_encoders = {}
    encoder_modules = [
        resnet_encoders, dpn_encoders, vgg_encoders, senet_encoders,
        densenet_encoders, inceptionresnetv2_encoders, inceptionv4_encoders,
        efficient_net_encoders, mobilenet_encoders, xception_encoders,
        timm_efficientnet_encoders, timm_resnest_encoders, timm_res2net_encoders,
        timm_regnet_encoders, timm_sknet_encoders, timm_mobilenetv3_encoders,
        timm_gernet_encoders
    ]
    for module in encoder_modules:
        available_encoders.update(module)

    # Newer encoder families ship with segmentation_models_pytorch >= 0.3.3;
    # register them only when that version (and packaging) is available.
    try:
        import segmentation_models_pytorch
        from packaging import version
        if version.parse(segmentation_models_pytorch.__version__) >= version.parse("0.3.3"):
            from segmentation_models_pytorch.encoders.mix_transformer import mix_transformer_encoders
            from segmentation_models_pytorch.encoders.mobileone import mobileone_encoders
            available_encoders.update(mix_transformer_encoders)
            available_encoders.update(mobileone_encoders)
    except ImportError:
        pass

    return available_encoders
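

# Illustrative sketch, not part of the original module: a quick way to inspect
# the registry returned by initialize_encoders(). The helper name and limit
# parameter below are assumptions introduced only for this example.
def _example_list_encoders(limit=5):
    """Return the total count and the first few registered encoder names."""
    encoders = initialize_encoders()
    # Keys are encoder names (e.g. "resnet34"); values hold the encoder class,
    # its constructor params, and any pretrained-weight settings.
    return len(encoders), sorted(encoders)[:limit]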


def download_weights(url, destination, hash_prefix=None, show_progress=True):
    """Downloads model weights with progress tracking and optional hash verification."""
    # Disables certificate verification for all HTTPS requests in this process;
    # this avoids SSL errors when fetching weights but weakens transport security.
    ssl._create_default_https_context = ssl._create_unverified_context

    req = Request(url, headers={"User-Agent": "torch.hub"})
    response = urlopen(req)
    content_length = response.headers.get("Content-Length")
    file_size = int(content_length) if content_length else None

    destination = os.path.expanduser(destination)
    temp_file = tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(destination))
    try:
        hasher = hashlib.sha256() if hash_prefix else None
        with tqdm(total=file_size, disable=not show_progress,
                  unit='B', unit_scale=True, unit_divisor=1024) as pbar:
            while True:
                buffer = response.read(8192)
                if not buffer:
                    break
                temp_file.write(buffer)
                if hasher:
                    hasher.update(buffer)
                pbar.update(len(buffer))
        temp_file.close()

        # Verify that the digest starts with the expected prefix before moving
        # the file into place (mirrors torch.hub's hash_prefix convention).
        if hasher and hash_prefix:
            digest = hasher.hexdigest()
            if digest[:len(hash_prefix)] != hash_prefix:
                raise RuntimeError(f'Invalid hash value (expected "{hash_prefix}", got "{digest}")')
        shutil.move(temp_file.name, destination)
    finally:
        temp_file.close()
        if os.path.exists(temp_file.name):
            os.remove(temp_file.name)
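

# Illustrative sketch, not part of the original module: fetching a checkpoint
# into torch.hub's cache with hash verification. The helper name is an
# assumption for this example; real URLs and hash prefixes would come from an
# encoder's pretrained_settings entry.
def _example_fetch_checkpoint(url, hash_prefix=None):
    """Download `url` into torch.hub's checkpoints directory unless already cached."""
    cache_dir = os.path.join(get_dir(), "checkpoints")
    os.makedirs(cache_dir, exist_ok=True)
    destination = os.path.join(cache_dir, os.path.basename(url))
    if not os.path.exists(destination):
        download_weights(url, destination, hash_prefix=hash_prefix)
    return destination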


def initialize_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
    """Initializes and returns a configured encoder."""
    encoders = initialize_encoders()

    # Names prefixed with "tu-" are routed to timm via the universal encoder wrapper.
    if name.startswith("tu-"):
        name = name[3:]
        return TimmUniversalEncoder(
            name=name,
            in_channels=in_channels,
            depth=depth,
            output_stride=output_stride,
            pretrained=weights is not None,
            **kwargs
        )

    try:
        encoder_config = encoders[name]
    except KeyError:
        raise KeyError(f"Invalid encoder name '{name}'. Available encoders: {list(encoders.keys())}")

    encoder_class = encoder_config["encoder"]
    encoder_params = encoder_config["params"]
    encoder_params.update(depth=depth)

    if weights:
        try:
            weights_config = encoder_config["pretrained_settings"][weights]
        except KeyError:
            raise KeyError(
                f"Invalid weights '{weights}' for encoder '{name}'. "
                f"Available options: {list(encoder_config['pretrained_settings'].keys())}"
            )

        # Pre-download the weights into torch.hub's checkpoint cache so that a
        # later load of these pretrained settings finds the file locally.
        cache_dir = os.path.join(get_dir(), 'checkpoints')
        os.makedirs(cache_dir, exist_ok=True)
        weights_file = os.path.basename(weights_config["url"])
        weights_path = os.path.join(cache_dir, weights_file)
        if not os.path.exists(weights_path):
            print(f'Downloading {weights_file}...')
            download_weights(
                weights_config["url"].replace("https", "http"),  # fetch over plain HTTP
                weights_path
            )

    return encoder_class(**encoder_params)
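

# Minimal usage sketch, assuming segmentation_models_pytorch is installed.
# The encoder name "resnet34" is a standard SMP identifier used here only as
# an example; weights=None keeps the demo offline (no download is triggered).
if __name__ == "__main__":
    registry = initialize_encoders()
    print(f"{len(registry)} encoders registered")

    # Build an encoder without pretrained weights.
    encoder = initialize_encoder("resnet34", in_channels=3, depth=5, weights=None)
    print(type(encoder).__name__)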