Spaces:
Runtime error
Runtime error
| import logging | |
| import torch | |
| from saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule | |
| def get_training_model_class(kind): | |
| if kind == 'default': | |
| return DefaultInpaintingTrainingModule | |
| raise ValueError(f'Unknown trainer module {kind}') | |
| def make_training_model(config): | |
| kind = config.training_model.kind | |
| kwargs = dict(config.training_model) | |
| kwargs.pop('kind') | |
| kwargs['use_ddp'] = config.trainer.kwargs.get('accelerator', None) == 'ddp' | |
| logging.info(f'Make training model {kind}') | |
| cls = get_training_model_class(kind) | |
| return cls(config, **kwargs) | |
| def load_checkpoint(train_config, path, map_location='cuda', strict=True): | |
| model: torch.nn.Module = make_training_model(train_config) | |
| state = torch.load(path, map_location=map_location) | |
| model.load_state_dict(state['state_dict'], strict=strict) | |
| model.on_load_checkpoint(state) | |
| return model | |