|
|
from typing import Tuple |
|
|
from torch.utils.data.datapipes._decorator import functional_datapipe |
|
|
from torch.utils.data.datapipes.datapipe import IterDataPipe |
|
|
|
|
|
__all__ = ["StreamReaderIterDataPipe", ] |
|
|
|
|
|
|
|
|
@functional_datapipe('read_from_stream')
class StreamReaderIterDataPipe(IterDataPipe[Tuple[str, bytes]]):
    r"""
    Given IO streams and their label names, yields the data read from each
    stream together with its label name in a tuple
    (functional name: ``read_from_stream``).

    Args:
        datapipe: Iterable DataPipe provides label/URL and byte stream
        chunk: Number of bytes to be read from stream per iteration.
            If ``None``, all bytes will be read until the EOF.

    Note:
        Every stream is closed as soon as iteration over it ends, whether it
        was read to EOF, ``read`` raised, or the consumer stopped iterating
        early and the iterator was closed.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper, StreamReader
        >>> from io import StringIO
        >>> dp = IterableWrapper([("alphabet", StringIO("abcde"))])
        >>> list(StreamReader(dp, chunk=1))
        [('alphabet', 'a'), ('alphabet', 'b'), ('alphabet', 'c'), ('alphabet', 'd'), ('alphabet', 'e')]
    """

    def __init__(self, datapipe, chunk=None):
        # Source of (label, stream) pairs; consumed lazily in ``__iter__``.
        self.datapipe = datapipe
        # Size passed to ``stream.read``; ``None`` reads the whole stream at once.
        self.chunk = chunk

    def __iter__(self):
        for furl, stream in self.datapipe:
            # ``finally`` guarantees the stream is released on every exit path:
            # EOF, an exception from ``read``, or the generator being closed
            # while suspended at ``yield`` (consumer stopped early).
            try:
                while True:
                    d = stream.read(self.chunk)
                    if not d:
                        # Empty result (``b""`` / ``""``) signals EOF.
                        break
                    yield (furl, d)
            finally:
                stream.close()
|
|
|