# pllava-7b-demo / dataset / base_dataset.py
# Commit f239efc by cathyxl: "added"
import logging
import os
import json
import random
from torch.utils.data import Dataset
import time
from dataset.utils import load_image_from_path
# petrel_client is optional: when it is importable, dataset instances create a
# Client and hand it to the image/video readers; otherwise media are read
# without one.
try:
    from petrel_client.client import Client
except ImportError:
    has_client = False
else:
    has_client = True

logger = logging.getLogger(__name__)
class ImageVideoBaseDataset(Dataset):
    """Base class that implements the image and video loading methods.

    Subclasses are expected to fill in the attributes initialised in
    ``__init__`` (``anno_list``, ``transform``, ``video_reader``,
    ``num_tries``, ...) and to implement ``__getitem__`` / ``__len__``.
    Video loading additionally reads ``num_frames`` and ``sample_type`` from
    the instance, and optionally ``max_num_frames`` and ``clip_transform``.
    """

    media_type = "video"

    # Directory holding the WebVid index files. Any video path containing
    # "webvid" is resolved to ``<webvid_hdfs_dir>/<self.keys_indexfile[video_id]>``.
    # Override on a subclass or instance to point at a different location.
    webvid_hdfs_dir = (
        "hdfs://harunava/home/byte_ailab_us_cvg/user/weimin.wang/"
        "videogen_data/webvid_data/10M_full_train"
    )

    def __init__(self):
        assert self.media_type in ["image", "video", "only_video"]
        self.data_root = None
        # list(dict); each dict contains {"image": str}, the image or video path.
        self.anno_list = None
        self.transform = None
        self.video_reader = None
        # Must be set to an int by subclasses before videos are loaded.
        self.num_tries = None

        self.client = None
        if has_client:
            # Passed to the image/video readers as ``client=``.
            self.client = Client('~/petreloss.conf')

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def get_anno(self, index):
        """Obtain the annotation for one media item (video or image).

        Args:
            index (int): The media index.

        Returns:
            dict: The entry ``self.anno_list[index]``. When ``self.data_root``
            is set, a shallow copy is returned in which ``"image"`` (the
            filename; videos also use ``"image"``) is joined onto
            ``data_root``. Other keys such as ``"caption"`` pass through.
        """
        anno = self.anno_list[index]
        if self.data_root is not None:
            # Build a new dict instead of writing back into anno_list:
            # an in-place join would prepend data_root again on every
            # access to the same index (e.g. once per epoch).
            anno = {**anno, "image": os.path.join(self.data_root, anno["image"])}
        return anno

    def load_and_transform_media_data(self, index, data_path):
        """Load one media item with the image or video loader, per ``media_type``.

        ``self.clip_transform`` is used when the subclass defines it.
        Otherwise it defaults to ``False``, meaning ``self.transform`` is applied.
        """
        clip_transform = getattr(self, "clip_transform", False)
        if self.media_type == "image":
            return self.load_and_transform_media_data_image(
                index, data_path, clip_transform=clip_transform
            )
        return self.load_and_transform_media_data_video(
            index, data_path, clip_transform=clip_transform
        )

    def load_and_transform_media_data_image(self, index, data_path, clip_transform=False):
        """Load one image and, unless ``clip_transform`` is set, apply ``self.transform``.

        Returns:
            tuple: ``(image, index)``.
        """
        image = load_image_from_path(data_path, client=self.client)
        if not clip_transform:
            image = self.transform(image)
        return image, index

    def _read_video(self, data_path, max_num_frames, clip):
        """Call ``self.video_reader`` for one path; the result is returned unchanged.

        For paths containing "webvid", the reader receives
        ``(index_file, video_id, num_frames, sample_type, ...)``. The index
        file is looked up by the file's stem in ``self.keys_indexfile`` under
        ``webvid_hdfs_dir``. All other paths are passed to the reader directly
        as ``(data_path, num_frames, sample_type, ...)``.
        """
        if "webvid" in data_path:
            video_id, _ = os.path.splitext(os.path.basename(data_path))
            ind_file = os.path.join(self.webvid_hdfs_dir, self.keys_indexfile[video_id])
            return self.video_reader(
                ind_file, video_id, self.num_frames, self.sample_type,
                max_num_frames=max_num_frames, client=self.client, clip=clip,
            )
        return self.video_reader(
            data_path, self.num_frames, self.sample_type,
            max_num_frames=max_num_frames, client=self.client, clip=clip,
        )

    def load_and_transform_media_data_video(
        self, index, data_path, return_fps=False, clip=None, clip_transform=False
    ):
        """Load and transform one video, retrying with random replacements.

        Up to ``self.num_tries`` attempts are made. When reading fails, a
        warning is logged, a random index is drawn, and that index's
        annotation (via ``get_anno``) supplies the next path to try.

        Args:
            index (int): Index of the requested media.
            data_path (str): Path of the video to read.
            return_fps (bool): If True, also return per-frame timestamps.
            clip: Forwarded unchanged to ``self.video_reader``.
            clip_transform (bool): If True, ``self.transform`` is skipped.

        Returns:
            tuple: ``(frames, index)``, or ``(frames, index, sec)`` when
            ``return_fps`` is True. ``sec`` holds ``frame_index / fps`` for
            each sampled frame, rounded to one decimal and formatted as a
            string. ``index`` identifies the video actually returned, which
            differs from the argument if a replacement was used.

        Raises:
            RuntimeError: If all ``self.num_tries`` attempts fail.
        """
        max_num_frames = getattr(self, "max_num_frames", -1)
        for _ in range(self.num_tries):
            try:
                frames, frame_indices, fps = self._read_video(data_path, max_num_frames, clip)
            except Exception as e:
                logger.warning(
                    f"Caught exception {e} when loading video {data_path}, "
                    f"randomly sample a new video as replacement"
                )
                index = random.randint(0, len(self) - 1)
                ann = self.get_anno(index)
                data_path = ann["image"]
                continue

            # shared aug for video frames
            if not clip_transform:
                frames = self.transform(frames)
            if return_fps:
                sec = [str(round(f / fps, 1)) for f in frame_indices]
                return frames, index, sec
            return frames, index

        raise RuntimeError(
            f"Failed to fetch video after {self.num_tries} tries. "
            f"This might indicate that you have many corrupted videos."
        )