Spaces:
Runtime error
Runtime error
import copy | |
import logging | |
import os | |
import os.path as osp | |
from os.path import join | |
import torch | |
from torch.utils.data import ConcatDataset, DataLoader | |
from utils.optimizer import create_optimizer | |
from utils.scheduler import create_scheduler | |
# Module-level logger, named after this module so records can be filtered by
# module path through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
def get_media_types(datasources):
    """Get the media types for all the dataloaders or datasets.

    Args:
        datasources (List): List of dataloaders or datasets. Each element may
            be a ``DataLoader`` (its ``.dataset`` is inspected) or a dataset
            object directly; the two kinds may be mixed in one list.

    Returns: List. The media_types, one per entry of ``datasources`` and in
        the same order. An empty input yields an empty list.
    """
    media_types = []
    for source in datasources:
        # Unwrap each DataLoader individually instead of deciding from the
        # first element only, so empty and mixed inputs are handled.
        dataset = source.dataset if isinstance(source, DataLoader) else source
        # A ConcatDataset is asked for its first sub-dataset's media_type only;
        # presumably all sub-datasets share one media type — verify with callers.
        if isinstance(dataset, ConcatDataset):
            dataset = dataset.datasets[0]
        media_types.append(dataset.media_type)
    return media_types