|
import numpy as np |
|
import random |
|
import string |
|
from torch.utils.data import Dataset, Subset |
|
|
|
class DummyData(Dataset):
    """Synthetic dataset that fabricates a random sample on every access.

    Each item is a dict with two entries:

    * ``"jpg"`` -- a NumPy array of shape ``size`` drawn from the standard
      normal distribution via ``np.random.randn``.
    * ``"txt"`` -- a string of 10 random lowercase ASCII letters.

    Nothing is stored: samples are generated on the fly, so requesting the
    same index twice yields different data unless the global ``numpy`` and
    ``random`` generators are re-seeded.
    """

    def __init__(self, length, size):
        """Store the dataset configuration.

        Args:
            length: Value reported by ``len(dataset)``.
            size: Iterable of ints giving the shape of the ``"jpg"`` array;
                it is unpacked into ``np.random.randn``.
        """
        self.length = length
        self.size = size

    def __len__(self):
        """Return the configured number of samples."""
        return self.length

    def __getitem__(self, i):
        """Generate one random sample.

        Args:
            i: Requested index. It is not used; every call produces a new
                random sample and no bounds check is performed.

        Returns:
            A dict ``{"jpg": ndarray, "txt": str}``.
        """
        x = np.random.randn(*self.size)
        # Reuse the alphabet bound to a local instead of looking it up again,
        # and use a throwaway loop name so the index parameter is not reused.
        letters = string.ascii_lowercase
        y = ''.join(random.choice(letters) for _ in range(10))
        return {"jpg": x, "txt": y}
|
|
|
|
|
class DummyDataWithEmbeddings(Dataset):
    """Synthetic dataset pairing a random array with a random embedding.

    Every access fabricates a new sample; nothing is cached. Items are dicts
    with a ``"jpg"`` array of shape ``size`` and a ``"txt"`` array of shape
    ``emb_size`` cast to ``float32``, both drawn with ``np.random.randn``.
    """

    def __init__(self, length, size, emb_size):
        """Remember the reported length and the two array shapes.

        Args:
            length: Value returned by ``len(dataset)``.
            size: Iterable of ints, unpacked as the shape of ``"jpg"``.
            emb_size: Iterable of ints, unpacked as the shape of ``"txt"``.
        """
        self.length, self.size, self.emb_size = length, size, emb_size

    def __len__(self):
        """Return the configured number of samples."""
        return self.length

    def __getitem__(self, i):
        """Build one random sample; the index ``i`` is not consulted.

        Returns:
            A dict ``{"jpg": ndarray, "txt": float32 ndarray}``.
        """
        sample = {}
        # The array is drawn before the embedding, so seeded runs consume
        # the global NumPy generator in the same order.
        sample["jpg"] = np.random.randn(*self.size)
        embedding = np.random.randn(*self.emb_size)
        sample["txt"] = embedding.astype(np.float32)
        return sample
|
|
|
|