Error(s) in loading state_dict for ZoeDepth

#2
by Andyrasika - opened
!git clone https://github.com/isl-org/ZoeDepth.git
%cd ZoeDepth
import torch
import matplotlib
import matplotlib.cm
import numpy as np

from zoedepth.utils.misc import get_image_from_url, colorize
from PIL import Image
import matplotlib.pyplot as plt



# Force a fresh download of the intel-isl/MiDaS hub repo before building ZoeDepth
# (presumably so ZoeDepth's MiDaS backbone code comes from an up-to-date cache — verify).
torch.hub.help("intel-isl/MiDaS", "DPT_BEiT_L_384", force_reload=True)  # Triggers fresh download of MiDaS repo
# Build ZoeD_N from the local clone's hubconf.py and load its pretrained checkpoint.
# NOTE: with some newer timm releases this call fails with
#   RuntimeError: Error(s) in loading state_dict for ZoeDepth:
#   Unexpected key(s) in state_dict: "core.core.pretrained.model.blocks.N.attn.relative_position_index", ...
# Installing an older timm (e.g. `python3 -m pip install timm==0.6.12`) has been reported to fix it.
zoe = torch.hub.load(".", "ZoeD_N", source="local", pretrained=True)
# Move the model to the GPU; this line requires a CUDA device.
zoe = zoe.to("cuda")


def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
    """Colorize a depth map with a matplotlib colormap and return a PIL image.

    Args:
        value: Depth values as a ``torch.Tensor`` (detached and moved to CPU)
            or a numpy array. Singleton dimensions are squeezed away.
        vmin: Lower normalization bound. Defaults to the 2nd percentile of
            the valid values (0.0 if there are no valid values).
        vmax: Upper normalization bound. Defaults to the 85th percentile of
            the valid values (0.0 if there are no valid values).
        cmap: Colormap name, passed to matplotlib's colormap lookup.
        invalid_val: Sentinel marking invalid pixels; used only when
            ``invalid_mask`` is None.
        invalid_mask: Boolean mask of invalid pixels. Assumed to match the
            squeezed shape of ``value`` — not checked here.
        background_color: RGBA color painted over invalid pixels (it then goes
            through the same gamma step as every other pixel).
        gamma_corrected: NOTE(review): currently unused — a 2.2 power curve is
            always applied below. Kept so existing call sites keep working.
        value_transform: Optional callable applied to the normalized values
            before the colormap lookup.

    Returns:
        A ``PIL.Image.Image`` built from the colormapped RGBA array.
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()

    value = value.squeeze()
    if invalid_mask is None:
        invalid_mask = value == invalid_val
    mask = np.logical_not(invalid_mask)

    # Normalize. np.percentile raises on an empty array, so when every pixel
    # is invalid fall back to 0.0 for any bound the caller did not supply.
    valid = value[mask]
    if vmin is None:
        vmin = np.percentile(valid, 2) if valid.size else 0.0
    if vmax is None:
        vmax = np.percentile(valid, 85) if valid.size else 0.0
    if vmin != vmax:
        value = (value - vmin) / (vmax - vmin)  # vmin..vmax
    else:
        # Avoid 0-division
        value = value * 0.

    # Mark invalid values as NaN; they are overwritten with background_color
    # after the colormap lookup.
    value[invalid_mask] = np.nan

    # matplotlib.cm.get_cmap was deprecated in matplotlib 3.7 and removed in
    # 3.9; prefer the colormap registry and fall back on older versions.
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None and hasattr(registry, "get_cmap"):
        cmapper = registry.get_cmap(cmap)
    else:
        cmapper = matplotlib.cm.get_cmap(cmap)

    if value_transform:
        value = value_transform(value)
    img = cmapper(value, bytes=True)  # (nxmx4)
    img[invalid_mask] = background_color

    # Gamma curve (power 2.2), applied unconditionally — see gamma_corrected.
    img = img / 255
    img = np.power(img, 2.2)
    img = img * 255
    img = img.astype(np.uint8)
    return Image.fromarray(img)


def get_zoe_depth_map(image, model=None):
    """Predict depth for a PIL image with ZoeDepth and return it colorized.

    Args:
        image: Input image, passed to the model's ``infer_pil``.
        model: ZoeDepth model to run. Defaults to the module-level ``zoe``
            model loaded by this script.

    Returns:
        The colorized depth map produced by ``colorize(depth, cmap="gray_r")``.
    """
    if model is None:
        # The model this script loads is bound to the module-level name `zoe`.
        model = zoe
    # Mixed-precision inference; requires a CUDA-capable setup.
    with torch.autocast("cuda", enabled=True):
        depth = model.infer_pil(image)
    return colorize(depth, cmap="gray_r")

ERROR:
```

RuntimeError Traceback (most recent call last)
Cell In[11], line 11
7 from PIL import Image
8 import matplotlib.pyplot as plt
---> 11 zoe = torch.hub.load(".", "ZoeD_N", source="local", pretrained=True)
13 torch.hub.help("intel-isl/MiDaS", "DPT_BEiT_L_384", force_reload=True) # Triggers fresh download of MiDaS repo
14 zoe = torch.hub.load(".", "ZoeD_N", source="local", pretrained=True)

File /usr/local/lib/python3.10/dist-packages/torch/hub.py:542, in load(repo_or_dir, model, source, trust_repo, force_reload, verbose, skip_validation, *args, **kwargs)
538 if source == 'github':
539 repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, trust_repo, "load",
540 verbose=verbose, skip_validation=skip_validation)
--> 542 model = _load_local(repo_or_dir, model, *args, **kwargs)
543 return model

File /usr/local/lib/python3.10/dist-packages/torch/hub.py:572, in _load_local(hubconf_dir, model, *args, **kwargs)
569 hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
571 entry = _load_entry_from_hubconf(hub_module, model)
--> 572 model = entry(*args, **kwargs)
574 sys.path.remove(hubconf_dir)
576 return model

File /workspace/ZoeDepth/./hubconf.py:69, in ZoeD_N(pretrained, midas_model_type, config_mode, **kwargs)
66 pretrained_resource = "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_N.pt"
68 config = get_config("zoedepth", config_mode, pretrained_resource=pretrained_resource, **kwargs)
---> 69 model = build_model(config)
70 return model

File ~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/builder.py:51, in build_model(config)
48 except AttributeError as e:
49 raise ValueError(
50 f"Model {config.model} has no get_version function.") from e
---> 51 return get_version(config.version_name).build_from_config(config)

File ~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/zoedepth/zoedepth_v1.py:250, in ZoeDepth.build_from_config(config)
248 @staticmethod
249 def build_from_config(config):
--> 250 return ZoeDepth.build(**config)

File ~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/zoedepth/zoedepth_v1.py:245, in ZoeDepth.build(midas_model_type, pretrained_resource, use_pretrained_midas, train_midas, freeze_midas_bn, **kwargs)
243 if pretrained_resource:
244 assert isinstance(pretrained_resource, str), "pretrained_resource must be a string"
--> 245 model = load_state_from_resource(model, pretrained_resource)
246 return model

File ~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/model_io.py:84, in load_state_from_resource(model, resource)
82 if resource.startswith('url::'):
83 url = resource.split('url::')[1]
---> 84 return load_state_dict_from_url(model, url, progress=True)
86 elif resource.startswith('local::'):
87 path = resource.split('local::')[1]

File ~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/model_io.py:61, in load_state_dict_from_url(model, url, **kwargs)
59 def load_state_dict_from_url(model, url, **kwargs):
60 state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu', **kwargs)
---> 61 return load_state_dict(model, state_dict)

File ~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/model_io.py:49, in load_state_dict(model, state_dict)
45 k = 'module.' + k
47 state[k] = v
---> 49 model.load_state_dict(state)
50 print("Loaded successfully")
51 return model

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1671, in Module.load_state_dict(self, state_dict, strict)
1666 error_msgs.insert(
1667 0, 'Missing key(s) in state_dict: {}. '.format(
1668 ', '.join('"{}"'.format(k) for k in missing_keys)))
1670 if len(error_msgs) > 0:
-> 1671 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
1672 self.__class__.__name__, "\n\t".join(error_msgs)))
1673 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for ZoeDepth:
Unexpected key(s) in state_dict: "core.core.pretrained.model.blocks.0.attn.relative_position_index", "core.core.pretrained.model.blocks.1.attn.relative_position_index", "core.core.pretrained.model.blocks.2.attn.relative_position_index", "core.core.pretrained.model.blocks.3.attn.relative_position_index", "core.core.pretrained.model.blocks.4.attn.relative_position_index", "core.core.pretrained.model.blocks.5.attn.relative_position_index", "core.core.pretrained.model.blocks.6.attn.relative_position_index", "core.core.pretrained.model.blocks.7.attn.relative_position_index", "core.core.pretrained.model.blocks.8.attn.relative_position_index", "core.core.pretrained.model.blocks.9.attn.relative_position_index", "core.core.pretrained.model.blocks.10.attn.relative_position_index", "core.core.pretrained.model.blocks.11.attn.relative_position_index", "core.core.pretrained.model.blocks.12.attn.relative_position_index", "core.core.pretrained.model.blocks.13.attn.relative_position_index", "core.core.pretrained.model.blocks.14.attn.relative_position_index", "core.core.pretrained.model.blocks.15.attn.relative_position_index", "core.core.pretrained.model.blocks.16.attn.relative_position_index", "core.core.pretrained.model.blocks.17.attn.relative_position_index", "core.core.pretrained.model.blocks.18.attn.relative_position_index", "core.core.pretrained.model.blocks.19.attn.relative_position_index", "core.core.pretrained.model.blocks.20.attn.relative_position_index", "core.core.pretrained.model.blocks.21.attn.relative_position_index", "core.core.pretrained.model.blocks.22.attn.relative_position_index", "core.core.pretrained.model.blocks.23.attn.relative_position_index".


```

Installing the right version of timm fixed it for me: `python3 -m pip install timm==0.6.12`

Sign up or log in to comment