from typing import Iterator, List

import torch


def torch_device() -> str:
    """Return the name of the preferred torch device available at call time.

    Preference order:
      1. ``"cuda:<index>"`` for the current CUDA device, if CUDA is available.
      2. ``"mps"`` if the MPS backend reports itself as available.
      3. ``"cpu"`` otherwise.

    Returns:
        The device name as a string.
    """
    if torch.cuda.is_available():
        return f"cuda:{torch.cuda.current_device()}"

    # ``torch.backends.mps`` is not present on every torch build; treat a
    # missing attribute as "MPS unavailable" rather than raising
    # AttributeError on those installations.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"

    return "cpu"


def split(iterable: List, chunk_size: int) -> Iterator[List]:
    """Yield consecutive slices of ``iterable`` with at most ``chunk_size`` items.

    Every slice except possibly the last has exactly ``chunk_size`` items;
    the last one holds whatever remains. An empty input yields nothing.

    Args:
        iterable: A sized, sliceable sequence (it is accessed via ``len()``
            and slicing).
        chunk_size: Maximum number of items per slice; must be >= 1.

    Yields:
        Slices of ``iterable``, in order.

    Raises:
        ValueError: If ``chunk_size`` is less than 1. Because this is a
            generator, the error is raised when iteration starts, not when
            ``split`` is called.
    """
    # A negative step would make ``range`` empty and silently drop all data,
    # and a zero step fails with an unrelated-looking ``range`` error.
    if chunk_size < 1:
        raise ValueError(
            f"chunk_size must be a positive integer, got {chunk_size}"
        )
    for start in range(0, len(iterable), chunk_size):
        yield iterable[start:start + chunk_size]