# ThinkSound/data_utils/v2a_utils/vggsound_text.py
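"""Text-only VGGSound dataset utilities.

``VGGSound`` reads video ids, captions, and chain-of-thought ("cot") captions
(``caption_cot``) from a CSV split file and yields them as plain strings; no
audio or video decoding is performed here.
"""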
import logging
import os
from pathlib import Path
from typing import Optional, Union
import pandas as pd
import torch
import torchaudio
from torch.utils.data.dataset import Dataset
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder
from torchvision.utils import save_image
log = logging.getLogger()
_CLIP_SIZE = 384
_CLIP_FPS = 8.0
_SYNC_SIZE = 224
_SYNC_FPS = 25.0
class VGGSound(Dataset):
def __init__(
self,
root: Union[str, Path],
*,
tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
start_row: Optional[int] = None,
end_row: Optional[int] = None,
save_dir: str = 'data/vggsound/video_latents_text/train'
):
self.root = Path(root)
# videos = sorted(os.listdir(self.root))
# videos = set([Path(v).stem for v in videos]) # remove extensions
videos = []
self.labels = []
self.cots = []
self.videos = []
missing_videos = []
        # read the caption CSV for this split
df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records')
# ζŽ§εˆΆε€„η†ηš„θ‘ŒθŒƒε›΄
if start_row is not None and end_row is not None:
df_list = df_list[start_row:end_row]
for record in df_list:
id = record['id']
# if os.path.exists(f'{save_dir}/{id}.pth'):
# continue
# try:
# torch.load(f'{save_dir}/{id}.pth')
# continue
# except:
# print(f'error load file: {save_dir}/{id}.pth')
# os.system(f'rm -f {save_dir}/{id}.pth')
label = record['caption']
# if id in videos:
self.labels.append(label)
self.cots.append(record['caption_cot'])
# self.labels[id] = label
self.videos.append(id)
# else:
# missing_videos.append(id)
log.info(f'{len(videos)} videos found in {root}')
log.info(f'{len(self.videos)} videos found in {tsv_path}')
log.info(f'{len(missing_videos)} videos missing in {root}')
    def sample(self, idx: int) -> dict[str, str]:
video_id = self.videos[idx]
label = self.labels[idx]
cot = self.cots[idx]
data = {
'id': video_id,
'caption': label,
'caption_cot': cot
}
return data
    def __getitem__(self, idx: int) -> Optional[dict[str, str]]:
try:
return self.sample(idx)
except Exception as e:
log.error(f'Error loading video {self.videos[idx]}: {e}')
return None
def __len__(self):
return len(self.labels)
# dataset = VGGSound(
#     root="data/vggsound/video/test",
#     tsv_path="data/vggsound/split_txt/temp.csv",
#     start_row=0,
#     end_row=None,
#     save_dir="data/vggsound/video_latents_text/test",
# )
# dataset[0]
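# Minimal usage sketch (assumption: the paths below are placeholders; the CSV
# must contain 'id', 'caption', and 'caption_cot' columns). Since __getitem__
# returns None when a sample fails, the collate_fn drops those entries before
# batching.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    def drop_none_collate(batch):
        # Filter out samples that failed to load (returned as None).
        batch = [b for b in batch if b is not None]
        return {
            'id': [b['id'] for b in batch],
            'caption': [b['caption'] for b in batch],
            'caption_cot': [b['caption_cot'] for b in batch],
        }

    dataset = VGGSound(
        root='data/vggsound/video/train',  # placeholder path
        tsv_path='dataset/vggsound/split_txt/train_caption.csv',
        start_row=0,
        end_row=16,
    )
    loader = DataLoader(dataset, batch_size=4, collate_fn=drop_none_collate)
    for batch in loader:
        print(batch['id'])
        break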