|
from datasets import load_dataset |
|
import torch |
|
import config |
|
from model import load_image |
|
|
|
from diffusers.utils.loading_utils import load_image as dl_to_pil |
|
|
|
def my_collate(batch):
    """Collate dataset rows into one batch of image tensors and chat prompts.

    Each element of ``batch`` is indexed with the ``'image'`` and ``'prompt'``
    keys. Every image is passed to ``load_image`` (as ``pil_image``, with
    ``image_file=None``) and the results are joined with ``torch.cat`` using
    its default dimension. Every prompt is appended to a fixed chat template
    that carries the image placeholder.

    Args:
        batch: Sequence of row mappings produced by the dataset.

    Returns:
        ``{'image': <concatenated tensor>, 'text': <list of str>}`` on
        success, or ``None`` when anything raises while the batch is being
        built (the exception is printed). Code iterating the DataLoader will
        therefore receive ``None`` in place of a failed batch.
    """
    # Chat-format prefix placed in front of every prompt.
    template = '''<|user|><img><IMG_CONTEXT></img><|end|><|assistant|>'''

    try:
        images = [item['image'] for item in batch]
        # NOTE(review): load_image presumably returns a tensor with a leading
        # batch dimension, since the results are concatenated, not stacked.
        pixels = torch.cat(
            [load_image(pil_image=image, image_file=None) for image in images]
        )
        text = [template + item['prompt'] for item in batch]
    except Exception as e:
        # Deliberate best-effort: one bad row (missing key, undecodable
        # image, empty batch) drops the whole batch instead of aborting the
        # data-loading worker.
        print(e)
        return None

    return {'image': pixels, 'text': text}
|
|
|
# Stream the '2M_sample' split of the stylebreeder dataset instead of
# downloading it up front, with a fixed shuffle seed.
# NOTE(review): buffer_size=1 leaves the shuffle buffer with a single element,
# so it presumably cannot reorder individual examples -- confirm that this is
# intended and not a placeholder for a larger buffer.
ds = load_dataset(
    "stylebreeder/stylebreeder", split='2M_sample', streaming=True
).shuffle(seed=7, buffer_size=1)

# Batches are assembled by my_collate, which returns None for a batch it
# could not build, so the consumer of this loader has to tolerate None.
# NOTE(review): num_workers=32 is hard-coded; verify it suits the number of
# dataset shards and the CPU cores available.
dataloader = torch.utils.data.DataLoader(
    ds,
    num_workers=32,
    collate_fn=my_collate,
    batch_size=config.batch_size,
)
|
|