pllava-7b-demo / tasks /shared_utils.py
cathyxl
added
f239efc
import copy
import logging
import os
import os.path as osp
from os.path import join
import torch
from torch.utils.data import ConcatDataset, DataLoader
from utils.optimizer import create_optimizer
from utils.scheduler import create_scheduler
# Logger for this module, keyed by the module's import name.
logger = logging.getLogger(name=__name__)
def get_media_types(datasources):
    """Get the media type of every dataloader or dataset in ``datasources``.

    Each element may be a ``DataLoader`` (its ``.dataset`` is inspected) or
    a dataset. For a ``ConcatDataset``, the media type of its first
    sub-dataset is used; nested ``ConcatDataset`` objects are unwrapped the
    same way until a non-concatenated dataset is reached.

    Args:
        datasources (Iterable): Dataloaders and/or datasets. Mixed inputs
            are allowed. An empty iterable yields an empty list.

    Returns:
        List: The ``media_type`` of each element, in input order.
    """
    media_types = []
    for source in datasources:
        # Decide per element rather than from the first element only, so
        # mixed lists of dataloaders and datasets are handled and an empty
        # input does not raise IndexError.
        dataset = source.dataset if isinstance(source, DataLoader) else source
        # Only the first sub-dataset of a ConcatDataset is consulted;
        # presumably all sub-datasets share one media type (TODO confirm).
        # Loop to unwrap ConcatDatasets nested inside ConcatDatasets.
        while isinstance(dataset, ConcatDataset):
            dataset = dataset.datasets[0]
        media_types.append(dataset.media_type)
    return media_types