Spaces:
Sleeping
Sleeping
| import difflib | |
| import torch | |
def get_layer(l_name, library=torch.nn):
    """Return the layer object handler named ``l_name`` from ``library``.

    The lookup is case insensitive, e.g. with the default library
    ``get_layer("elu")`` returns ``torch.nn.ELU``.

    Args:
        l_name (str): Case-insensitive name of the layer in ``library``
            (e.g. ``'elu'``).
        library (module): Module (or any object) whose attributes are
            searched for one matching ``l_name``. Defaults to ``torch.nn``.

    Returns:
        layer_handler (object): The matching attribute of ``library``
            (e.g. ``torch.nn.ELU``).

    Raises:
        NotImplementedError: If no attribute of ``library`` matches
            ``l_name`` (the message lists the closest names), or if several
            attributes match it case-insensitively, i.e. names that differ
            only in letter case (the message lists all of them).
    """
    # Group the library's attribute names by their lowercase form so that
    # case-insensitive lookups and suggestions map back to the real names.
    # Search ``library`` itself, not a hard-coded torch.nn.
    names_by_lower = {}
    for attr_name in dir(library):
        names_by_lower.setdefault(attr_name.lower(), []).append(attr_name)

    key = l_name.lower()
    match = names_by_lower.get(key, [])
    if not match:
        # Compare lowercase to lowercase so letter case does not skew the
        # similarity ratio, then report the original spellings.
        close_matches = [
            name
            for lowered in difflib.get_close_matches(key, list(names_by_lower))
            for name in names_by_lower[lowered]
        ]
        raise NotImplementedError(
            "Layer with name {} not found in {}.\n Closest matches: {}".format(
                l_name, str(library), close_matches
            )
        )
    if len(match) > 1:
        # Several attributes differ only in letter case: refuse to guess.
        raise NotImplementedError(
            "Multiple matches for layer with name {} found in {}.\n "
            "All matches: {}".format(l_name, str(library), match)
        )
    return getattr(library, match[0])