Spaces:
Running
Running
| from typing import List | |
| import numpy as np | |
| def collate_outputs(outputs: List[dict]): | |
| """ | |
| used to collate default train_step and validation_step outputs. If you want something different then you gotta | |
| extend this | |
| we expect outputs to be a list of dictionaries where each of the dict has the same set of keys | |
| """ | |
| collated = {} | |
| for k in outputs[0].keys(): | |
| if np.isscalar(outputs[0][k]): | |
| collated[k] = [o[k] for o in outputs] | |
| elif isinstance(outputs[0][k], np.ndarray): | |
| collated[k] = np.vstack([o[k][None] for o in outputs]) | |
| elif isinstance(outputs[0][k], list): | |
| collated[k] = [item for o in outputs for item in o[k]] | |
| else: | |
| raise ValueError(f'Cannot collate input of type {type(outputs[0][k])}. ' | |
| f'Modify collate_outputs to add this functionality') | |
| return collated |