File size: 1,009 Bytes
62a2f1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from typing import Set
import spconv
def _spconv_major_minor(version) -> tuple:
    """Return the leading ``(major, minor)`` integers of a dotted version string.

    Only the leading digits of the first two dot-separated components are
    used, so suffixes such as ``rc1`` or ``+cu118`` do not raise. A missing
    or non-numeric component counts as ``0``.
    """
    numbers = []
    for part in str(version).split(".")[:2]:
        digits = ""
        for char in part:
            if char not in "0123456789":
                break
            digits += char
        numbers.append(int(digits) if digits else 0)
    while len(numbers) < 2:
        numbers.append(0)
    return tuple(numbers)


# Turn off spconv's direct-table flag for releases 2.2 and newer (the flag is
# only touched from that version on). The version is compared as an integer
# (major, minor) pair: slicing the string and parsing "minor.patch" as a float
# misjudges releases such as 2.2.0 / 2.2.1 and fails on suffixed versions.
if _spconv_major_minor(spconv.__version__) >= (2, 2):
    spconv.constants.SPCONV_USE_DIRECT_TABLE = False

# Prefer the ``spconv.pytorch`` namespace and fall back to the top-level
# package when that submodule cannot be imported. Only import failures trigger
# the fallback, so interrupts and unrelated errors are no longer swallowed.
try:
    import spconv.pytorch as spconv
except ImportError:
    import spconv as spconv
import torch.nn as nn
def find_all_spconv_keys(model: nn.Module, prefix="") -> Set[str]:
    """
    Collect the dotted ``<path>.weight`` names of all sparse convolutions
    found among the (recursively visited) children of ``model``; these are
    the keys whose weights need to be transposed.

    Args:
        model: module whose children are walked recursively.
        prefix: dotted path of ``model`` itself; ``""`` for the root.

    Returns:
        The set of collected weight keys (empty when none are found).
    """
    weight_keys: Set[str] = set()

    for child_name, submodule in model.named_children():
        # Build the dotted path of this child relative to the root module.
        if prefix == "":
            path = child_name
        else:
            path = f"{prefix}.{child_name}"

        if isinstance(submodule, spconv.conv.SparseConvolution):
            path = f"{path}.weight"
            weight_keys.add(path)

        # NOTE: for a sparse convolution the recursion below receives the
        # ".weight"-suffixed path as its prefix.
        weight_keys |= find_all_spconv_keys(submodule, prefix=path)

    return weight_keys
def replace_feature(out, new_features):
    """
    Return ``out`` carrying ``new_features``, whichever way ``out`` supports.

    If ``out`` provides a ``replace_feature`` method (spconv 2.x behaviour),
    its result is returned. Otherwise ``out.features`` is overwritten in
    place and ``out`` itself is returned.

    Args:
        out: object whose features should be replaced.
        new_features: the new feature value to install.

    Returns:
        The result of ``out.replace_feature(new_features)`` when available,
        else the mutated ``out``.
    """
    # hasattr is a direct attribute lookup; the former ``out.__dir__()``
    # membership test built the full attribute list on every call.
    if hasattr(out, "replace_feature"):
        # spconv 2.x behaviour
        return out.replace_feature(new_features)

    out.features = new_features
    return out
|