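# Hosting-page residue kept for provenance (not part of the program):
# uploader "haakohu", commit message "initial", commit 5d756f1.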
import torch
import tops
import numpy as np
import io
import webdataset as wds
import os
from ..utils import png_decoder, get_num_workers, collate_fn
def kp_decoder(x):
    """Decode ``.npy`` bytes into a ``(7, 3)`` float keypoint tensor.

    Args:
        x: Raw bytes of a NumPy ``.npy`` file holding 14 values
            (7 keypoints with 2 coordinates each).

    Returns:
        A float tensor of shape ``(7, 3)``. The first two columns are the
        coordinates clamped to ``[0, 1]``; the third column is all ones.
    """
    raw = np.load(io.BytesIO(x))
    # Keypoints are stored normalized for webdataset; clamp anything that
    # falls outside [0, 1].
    coords = torch.from_numpy(raw).float().view(7, 2).clamp(min=0, max=1)
    # NOTE(review): the constant third column presumably marks every keypoint
    # as visible/valid -- confirm against the consumers of "keypoints".
    ones_column = torch.ones((7, 1))
    return torch.cat((coords, ones_column), dim=-1)
def bbox_decoder(x):
    """Decode ``.npy`` bytes into a flat float tensor with four elements.

    Args:
        x: Raw bytes of a NumPy ``.npy`` file holding exactly 4 values.

    Returns:
        A float tensor of shape ``(4,)``.
    """
    # NOTE(review): element order is presumably (x0, y0, x1, y1), matching
    # how BBoxToMask unpacks the box -- confirm against the dataset writer.
    box_array = np.load(io.BytesIO(x))
    box = torch.from_numpy(box_array).float()
    return box.view(4)
class BBoxToMask:
    """Sample transform that adds a boolean ``"mask"`` built from the bounding box.

    The mask has shape ``(1, imsize, imsize)``; it is ``False`` inside the
    bounding box and ``True`` everywhere else.
    """

    def __call__(self, sample):
        """Add ``sample["mask"]`` in place and return the same sample dict.

        Args:
            sample: Mapping with
                ``"image.png"`` -- an array/tensor whose last dimension is
                used as the image size (the mask is built square, so the
                image is assumed square -- TODO confirm), and
                ``"bounding_box.npy"`` -- a NumPy array ``(x0, y0, x1, y1)``
                in coordinates normalized by the image size.

        Returns:
            The same ``sample`` object, with the ``"mask"`` key set.
        """
        imsize = sample["image.png"].shape[-1]
        bbox = sample["bounding_box.npy"] * imsize
        # Clip to the image extent before slicing: a slightly negative
        # coordinate would otherwise be read as a negative (wrap-around)
        # index and select the wrong region -- typically nothing at all.
        pixel_box = np.clip(np.round(bbox), 0, imsize).astype(np.int64)
        x0, y0, x1, y1 = pixel_box
        mask = torch.ones((1, imsize, imsize), dtype=torch.bool)
        mask[:, y0:y1, x0:x1] = 0
        sample["mask"] = mask
        return sample
def get_dataloader_fdf_wds(
    path,
    batch_size: int,
    num_workers: int,
    transform: torch.nn.Module,
    gpu_transform: torch.nn.Module,
    infinite: bool,
    shuffle: bool,
    partial_batches: bool,
    sample_shuffle=10_000,
    tar_shuffle=100,
    channels_last=False,
):
    """Build a prefetching WebDataset loader over tar shards.

    Pipeline order: shard list -> optional shard shuffle -> split by node and
    worker -> optional sample shuffle -> tar extraction -> decoding ->
    ``BBoxToMask`` -> batching -> key renaming -> optional ``transform``.

    Side effect: sets the ``RANK`` and ``WORLD_SIZE`` environment variables
    from ``tops.rank()`` / ``tops.world_size()``.

    Args:
        path: Shard specification; converted with ``str()`` and handed to the
            WebDataset shard source.
        batch_size: Number of samples per batch produced inside the pipeline.
        num_workers: Requested worker count, resolved through
            ``get_num_workers``.
        transform: Callable mapped over each *batch* (it runs after batching
            and renaming). Skipped when ``None``.
        gpu_transform: Passed to ``tops.DataPrefetcher``.
        infinite: If true, shards are resampled (``wds.ResampledShards``) and
            the pipeline is repeated for 1,000,000 epochs; otherwise a plain
            ``wds.SimpleShardList`` is iterated.
        shuffle: If true, shuffle shards (buffer ``tar_shuffle``) and samples
            (buffer ``sample_shuffle``).
        partial_batches: Forwarded as ``partial`` to ``wds.batched``.
        sample_shuffle: Sample-level shuffle buffer size.
        tar_shuffle: Shard-level shuffle buffer size.
        channels_last: Forwarded to ``tops.DataPrefetcher``.

    Returns:
        The ``tops.DataPrefetcher`` wrapping the ``wds.WebLoader``.
    """
    # Need to set this for split_by_node to work.
    os.environ["RANK"] = str(tops.rank())
    os.environ["WORLD_SIZE"] = str(tops.world_size())

    if infinite:
        pipeline = [wds.ResampledShards(str(path))]
    else:
        pipeline = [wds.SimpleShardList(str(path))]
    if shuffle:
        # Shard-level shuffle, applied before shards are split across ranks.
        pipeline.append(wds.shuffle(tar_shuffle))
    pipeline.extend([
        wds.split_by_node,
        wds.split_by_worker,
    ])
    if shuffle:
        pipeline.append(wds.shuffle(sample_shuffle))

    decoder = [
        wds.handle_extension("image.png", png_decoder),
        wds.handle_extension("keypoints.npy", kp_decoder),
    ]
    # Each entry is [new_key, old_key].
    # NOTE(review): keys not listed here (e.g. "bounding_box.npy") are
    # presumably dropped by wds.rename_keys -- confirm against its defaults.
    rename_keys = [
        ["img", "image.png"],
        ["keypoints", "keypoints.npy"],
        ["__key__", "__key__"],
        ["mask", "mask"]
    ]
    pipeline.extend([
        wds.tarfile_to_samples(),
        wds.decode(*decoder),
    ])
    # Runs per sample, before batching, while the original keys still exist.
    pipeline.append(wds.map(BBoxToMask()))
    pipeline.extend([
        wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
        wds.rename_keys(*rename_keys),
    ])
    if transform is not None:
        pipeline.append(wds.map(transform))
    pipeline = wds.DataPipeline(*pipeline)
    if infinite:
        pipeline = pipeline.repeat(nepochs=1000000)

    # assumes get_num_workers returns an int -- TODO confirm
    n_workers = get_num_workers(num_workers)
    loader = wds.WebLoader(
        # Batching already happened inside the pipeline, hence batch_size=None.
        pipeline, batch_size=None, shuffle=False,
        num_workers=n_workers,
        # torch's DataLoader raises ValueError when persistent_workers=True
        # is combined with num_workers=0, so enable it only with workers.
        persistent_workers=n_workers > 0,
    )
    loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
    return loader