import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchaudio


class OmniDataset(Dataset):
    def __init__(self, data_list, vision_tokenizer, audio_tokenizer):
        """
        data_list: List of dicts with {'text': str, 'img_path': str, 'audio_path': str}

        The tokenizers are stored so that batch-time code (e.g. a collate_fn)
        can turn the raw modalities returned by __getitem__ into token tensors.
        """
        self.data = data_list
        self.v_tok = vision_tokenizer
        self.a_tok = audio_tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Load and resize the image to 224x224, a common vision-encoder input size.
        image = Image.open(item['img_path']).convert("RGB").resize((224, 224))

        # torchaudio.load returns (waveform, sample_rate); waveform has shape
        # (channels, num_samples). The sample rate is returned alongside the
        # waveform so downstream code can resample to the model's expected rate.
        waveform, sr = torchaudio.load(item['audio_path'])

        text = item['text']

        return {
            "text": text,
            "image": image,
            "audio": waveform,
            "sample_rate": sr,
        }
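
# --- Usage sketch (not from the original; names below are hypothetical) ---
# Because __getitem__ returns a PIL image and a variable-length waveform, the
# DataLoader's default collate function cannot batch these samples directly.
# The sketch below assumes each stored tokenizer is a callable that maps its
# modality to a 1-D LongTensor of token ids; `omni_collate`, `pad_id`, and
# that tokenizer behavior are all assumptions, not part of the original code.

def omni_collate(batch, vision_tokenizer, audio_tokenizer, pad_id=0):
    texts = [sample["text"] for sample in batch]

    # Tokenize each modality per sample; sequence lengths may differ.
    image_tokens = [vision_tokenizer(sample["image"]) for sample in batch]
    audio_tokens = [audio_tokenizer(sample["audio"]) for sample in batch]

    def pad_stack(seqs):
        # Right-pad 1-D token sequences to the batch maximum and stack.
        max_len = max(seq.size(0) for seq in seqs)
        out = torch.full((len(seqs), max_len), pad_id, dtype=torch.long)
        for i, seq in enumerate(seqs):
            out[i, : seq.size(0)] = seq
        return out

    return {
        "text": texts,  # raw strings; a text tokenizer would plug in here
        "image_tokens": pad_stack(image_tokens),
        "audio_tokens": pad_stack(audio_tokens),
    }


# Hypothetical wiring: `records` is a data_list as described in the docstring,
# and `vit_tok` / `codec_tok` are placeholder tokenizer callables.
# dataset = OmniDataset(records, vit_tok, codec_tok)
# loader = DataLoader(
#     dataset,
#     batch_size=8,
#     shuffle=True,
#     collate_fn=lambda b: omni_collate(b, dataset.v_tok, dataset.a_tok),
# )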