File size: 748 Bytes
08a9c32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from datasets import load_dataset
import torch
import config
from model import load_image

from diffusers.utils.loading_utils import load_image as dl_to_pil

def my_collate(batch):
    try:
      img = [item['image'] for item in batch]
      img = torch.cat([load_image(pil_image=i, image_file=None) for i in img])
      text = ['''<|user|><img><IMG_CONTEXT></img><|end|><|assistant|>'''+item['prompt'] for item in batch]
    except Exception as e:
      print(e)
      return None
    return {'image':img, 'text':text}

ds = load_dataset("stylebreeder/stylebreeder", split='2M_sample', streaming=True).shuffle(seed=7, buffer_size=1)
dataloader = torch.utils.data.DataLoader(ds, num_workers=32, collate_fn=my_collate, batch_size=config.batch_size)