"""Utility functions for Chiluka."""
import torch
from munch import Munch
def length_to_mask(lengths, max_len=None):
    """Build a boolean padding mask from a 1-D tensor of sequence lengths.

    Args:
        lengths: 1-D tensor of shape ``(batch,)`` holding the length of each
            sequence in the batch.
        max_len: Optional width of the mask. Defaults to ``lengths.max()``
            (or ``0`` when ``lengths`` is empty). Pass a value to pad every
            batch to the same fixed width.

    Returns:
        Bool tensor of shape ``(batch, max_len)`` on ``lengths.device``.
        ``mask[i, j]`` is ``True`` when ``j + 1 > lengths[i]``, i.e. position
        ``j`` is padding (for integer lengths: ``j >= lengths[i]``). It is
        ``False`` for valid positions.
    """
    if max_len is None:
        # ``Tensor.max()`` raises on an empty tensor; an empty batch simply
        # produces a mask with zero columns.
        max_len = lengths.max() if lengths.numel() > 0 else 0
    # Create the position index directly on the same device as ``lengths``
    # (avoids building it on the CPU and copying it over), then match dtype.
    positions = torch.arange(max_len, device=lengths.device).to(dtype=lengths.dtype)
    # Broadcast (1, max_len) against (batch, 1) -> (batch, max_len).
    return torch.gt(positions.unsqueeze(0) + 1, lengths.unsqueeze(1))
def recursive_munch(d):
    """Return a version of ``d`` in which every nested dict is a ``Munch``.

    Dicts are rebuilt as ``Munch`` objects (so their keys can be read with
    attribute access, e.g. ``cfg.key``), and lists are rebuilt element by
    element. Both are converted recursively. Any other value is returned
    unchanged.
    """
    if isinstance(d, list):
        return list(map(recursive_munch, d))
    if not isinstance(d, dict):
        return d
    converted = Munch()
    for key, value in d.items():
        converted[key] = recursive_munch(value)
    return converted