# project/app/utils.py
# Author: kabylake · commit 7bd11ed · 389 bytes
from typing import List
import torch
def torch_device() -> str:
    """Return the name of the best available torch device.

    Preference order: the current CUDA device (as ``"cuda:<index>"``),
    then Apple's Metal backend (``"mps"``), then ``"cpu"``.

    Returns:
        A device string suitable for ``torch.device(...)`` or ``tensor.to(...)``.
    """
    if torch.cuda.is_available():
        return f"cuda:{torch.cuda.current_device()}"
    # ``torch.backends.mps`` only exists in newer PyTorch releases; look it up
    # defensively so older installs fall back to CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
def split(iterable: List, chunk_size: int):
    """Yield consecutive slices of ``iterable`` holding ``chunk_size`` items each.

    The final slice is shorter when ``len(iterable)`` is not a multiple of
    ``chunk_size``. Any object supporting ``len()`` and slicing works
    (e.g. lists, tuples, strings); each chunk has whatever type slicing
    that object produces.

    Args:
        iterable: Sized, sliceable sequence to split.
        chunk_size: Maximum number of items per chunk; must be at least 1.

    Yields:
        The slices ``iterable[start:start + chunk_size]``, in order.

    Raises:
        ValueError: If ``chunk_size`` is less than 1. Because this is a
            generator, the error surfaces on the first iteration.
    """
    if chunk_size < 1:
        # range() raises for a zero step but silently produces nothing for a
        # negative one; reject both explicitly so bad sizes are not masked.
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")
    for start in range(0, len(iterable), chunk_size):
        yield iterable[start:start + chunk_size]