|
from pathlib import PurePath |
|
from typing import Sequence |
|
|
|
import torch |
|
from torch import nn |
|
|
|
import yaml |
|
|
|
|
|
class InvalidModelError(RuntimeError):
    """Raised when a model cannot be resolved, configured, or loaded.

    Used in this module for unknown model names, missing experiment configs,
    and experiments that have no released pretrained weights.
    """
|
|
|
|
|
_WEIGHTS_URL = { |
|
'parseq-tiny': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq_tiny-e7a21b54.pt', |
|
'parseq': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq-bb5792a6.pt', |
|
'abinet': 'https://github.com/baudm/parseq/releases/download/v1.0.0/abinet-1d1e373e.pt', |
|
'trba': 'https://github.com/baudm/parseq/releases/download/v1.0.0/trba-cfaed284.pt', |
|
'vitstr': 'https://github.com/baudm/parseq/releases/download/v1.0.0/vitstr-26d0fcf4.pt', |
|
'crnn': 'https://github.com/baudm/parseq/releases/download/v1.0.0/crnn-679d0e31.pt', |
|
} |
|
|
|
|
|
def _get_config(experiment: str, **kwargs):
    """Emulate Hydra's config composition for ``experiment`` without running Hydra.

    Sources are merged in this order, later ones overriding earlier ones:

    1. the ``model`` section of ``configs/main.yaml``
    2. the ``model`` section of ``configs/charset/94_full.yaml``
    3. ``configs/model/<name>.yaml``, where ``<name>`` is the ``override /model``
       value of the experiment's first ``defaults`` entry
    4. the experiment's own ``model`` section, if it has one
    5. ``kwargs``

    Args:
        experiment: Basename (without ``.yaml``) of a file in ``configs/experiment/``.
        **kwargs: Final overrides applied on top of the merged config.

    Returns:
        The merged config dict, with ``lr`` coerced to ``float``.

    Raises:
        FileNotFoundError: If any of the config files does not exist.
    """
    root = PurePath(__file__).parents[2]

    def load(relpath):
        # Explicit UTF-8 so parsing does not depend on the platform's locale codepage.
        # NOTE(review): yaml.Loader can construct arbitrary Python objects. That is only
        # acceptable because these are the repository's own config files; switch to
        # yaml.safe_load if the configs never rely on Python-specific tags.
        with open(root / relpath, 'r', encoding='utf-8') as f:
            return yaml.load(f, yaml.Loader)

    config = load('configs/main.yaml')['model']
    config.update(load('configs/charset/94_full.yaml')['model'])

    exp = load(f'configs/experiment/{experiment}.yaml')
    # Assumes the experiment's first `defaults` entry is the model override, as in Hydra's
    # `override /model: <name>` convention. TODO confirm for every experiment file.
    model = exp['defaults'][0]['override /model']
    config.update(load(f'configs/model/{model}.yaml'))

    if 'model' in exp:
        config.update(exp['model'])
    config.update(kwargs)

    # `lr` may arrive as a string: PyYAML's YAML 1.1 resolver reads exponent notation
    # without a dot (e.g. `7e-4`) as str, and kwargs may carry a string too.
    config['lr'] = float(config['lr'])
    return config
|
|
|
|
|
def _get_model_class(key):
    """Return the model class whose short name appears in ``key``.

    ``key`` is matched by plain substring search over the whole string. In this module
    it is either an experiment name or a full checkpoint path, directories included.
    Candidates are tried in the order abinet, crnn, parseq, trba, trbc, vitstr, and
    the first hit wins.

    Each class is imported lazily, so only the selected model's module is loaded.

    Raises:
        InvalidModelError: If none of the known model names occurs in ``key``.
    """
    if 'abinet' in key:
        from .abinet.system import ABINet
        return ABINet
    if 'crnn' in key:
        from .crnn.system import CRNN
        return CRNN
    if 'parseq' in key:
        from .parseq.system import PARSeq
        return PARSeq
    if 'trba' in key:
        from .trba.system import TRBA
        return TRBA
    if 'trbc' in key:
        from .trba.system import TRBC
        return TRBC
    if 'vitstr' in key:
        from .vitstr.system import ViTSTR
        return ViTSTR
    raise InvalidModelError(f"Unable to find model class for '{key}'")
|
|
|
|
|
def get_pretrained_weights(experiment):
    """Fetch the released state dict for ``experiment``.

    Loading goes through ``torch.hub.load_state_dict_from_url``, which caches
    downloads. ``check_hash=True`` verifies the file against the hash prefix in
    its filename, and all tensors are mapped to CPU.

    Raises:
        InvalidModelError: If no pretrained weights are registered for ``experiment``.
    """
    weights_url = _WEIGHTS_URL.get(experiment)
    if weights_url is None:
        raise InvalidModelError("No pretrained weights found for '{}'".format(experiment)) from None
    return torch.hub.load_state_dict_from_url(url=weights_url, map_location='cpu', check_hash=True)
|
|
|
|
|
def create_model(experiment: str, pretrained: bool = False, **kwargs):
    """Build the model described by ``experiment``'s config.

    Args:
        experiment: Experiment name. It selects both the config file and, by
            substring match, the model class.
        pretrained: If true, load the released weights for ``experiment`` into the model.
        **kwargs: Config overrides forwarded to ``_get_config``.

    Returns:
        The constructed model instance.

    Raises:
        InvalidModelError: If the config files, the model class, or (when
            ``pretrained``) the released weights cannot be found.
    """
    try:
        cfg = _get_config(experiment, **kwargs)
    except FileNotFoundError:
        raise InvalidModelError("No configuration found for '{}'".format(experiment)) from None

    model_cls = _get_model_class(experiment)
    model = model_cls(**cfg)
    if not pretrained:
        return model

    model.load_state_dict(get_pretrained_weights(experiment))
    return model
|
|
|
|
|
def load_from_checkpoint(checkpoint_path: str, **kwargs):
    """Load a model from a checkpoint file or from the released weights.

    When ``checkpoint_path`` has the form ``'pretrained=<model_id>'``, ``<model_id>``
    is built via ``create_model`` with its pretrained weights loaded. Otherwise the
    model class is chosen by substring-matching the path. Loading is then delegated
    to that class's ``load_from_checkpoint``; presumably this is the PyTorch Lightning
    classmethod, which should be confirmed.

    Args:
        checkpoint_path: Path to a checkpoint, or ``'pretrained=<model_id>'``.
        **kwargs: Forwarded to ``create_model`` (as config overrides) or to the
            class's ``load_from_checkpoint``.

    Returns:
        The loaded model.
    """
    marker = 'pretrained='
    if not checkpoint_path.startswith(marker):
        model_cls = _get_model_class(checkpoint_path)
        return model_cls.load_from_checkpoint(checkpoint_path, **kwargs)

    # Everything after the first '=' is the model id.
    model_id = checkpoint_path[len(marker):]
    return create_model(model_id, True, **kwargs)
|
|
|
|
|
def parse_model_args(args):
    """Parse ``name:type=value`` strings into a keyword-argument dict.

    Supported types are ``int``, ``float``, ``str`` and ``bool``. The value is
    everything after the first ``=``, so it may itself contain ``=``. The name is
    everything before the first ``:``.

    Booleans accept ``true``/``false``, ``1``/``0``, ``yes``/``no`` and ``on``/``off``,
    compared case-insensitively and ignoring surrounding whitespace. Any other text
    is rejected rather than silently treated as False.

    Example::

        parse_model_args(['refine_iters:int=2', 'decode_ar:bool=false'])
        # -> {'refine_iters': 2, 'decode_ar': False}

    Args:
        args: Iterable of argument strings.

    Returns:
        Dict mapping each name to its converted value. Later duplicates override
        earlier ones.

    Raises:
        ValueError: If an argument lacks ``=`` or ``:type``, a bool value is not
            recognized, or an int/float value cannot be converted.
        KeyError: If the type name is not one of the supported types.
    """
    def to_bool(value):
        lowered = value.strip().lower()
        if lowered in ('true', '1', 'yes', 'on'):
            return True
        if lowered in ('false', '0', 'no', 'off'):
            return False
        raise ValueError(f"invalid boolean value {value!r}; expected true/false, 1/0, yes/no or on/off")

    converters = {t.__name__: t for t in (int, float, str)}
    converters['bool'] = to_bool

    kwargs = {}
    for arg in args:
        spec, has_value, value = arg.partition('=')
        name, has_type, type_name = spec.partition(':')
        if not (has_value and has_type):
            raise ValueError(f"malformed model argument {arg!r}; expected 'name:type=value'")
        try:
            convert = converters[type_name]
        except KeyError:
            raise KeyError(
                f"unsupported type {type_name!r} in model argument {arg!r}; "
                f"expected one of {sorted(converters)}"
            ) from None
        kwargs[name] = convert(value)
    return kwargs
|
|
|
|
|
def init_weights(module: nn.Module, name: str = '', exclude: Sequence[str] = ()):
    """Initialize the weights using the typical initialization schemes used in SOTA models.

    Only ``module``'s own parameters are touched; children are not visited here.
    Nothing is done if ``name`` starts with any prefix in ``exclude``.

    Schemes:
      * ``nn.Linear``: weight ~ truncated normal (std 0.02); bias zeroed.
      * ``nn.Embedding``: weight ~ truncated normal (std 0.02); the ``padding_idx`` row
        is zeroed.
      * ``nn.Conv2d``: Kaiming-normal weight (``fan_out``, ReLU gain); bias zeroed.
      * ``nn.LayerNorm`` / ``nn.BatchNorm2d`` / ``nn.GroupNorm``: weight set to 1 and
        bias to 0, for whichever of the two the layer actually has.
      * Any other module type is left unchanged.

    Args:
        module: The module whose parameters are initialized in place.
        name: Qualified name of ``module``; used only for the ``exclude`` check.
        exclude: Name prefixes whose modules are skipped.
    """
    if any(map(name.startswith, exclude)):
        return
    if isinstance(module, nn.Linear):
        nn.init.trunc_normal_(module.weight, std=.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.trunc_normal_(module.weight, std=.02)
        if module.padding_idx is not None:
            # The padding embedding stays zero to match nn.Embedding's own convention.
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
        # Norm layers built without affine parameters (elementwise_affine=False / affine=False,
        # or LayerNorm(bias=False)) have weight and/or bias set to None, and nn.init.ones_/zeros_
        # would fail on None.
        if module.weight is not None:
            nn.init.ones_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
|
|