File size: 419 Bytes
8fb5471
401a1f1
77a9fb3
401a1f1
4740821
77a9fb3
8fb5471
 
 
a306ac2
 
 
8fb5471
9b0a562
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from typing import Union

from datasets import DatasetDict

from .artifact import fetch_artifact
from .operator import StreamSource


def load_dataset(source: Union[StreamSource, str]) -> DatasetDict:
    """Build a ``DatasetDict`` from a stream source or an artifact identifier.

    Args:
        source: Either a ``StreamSource`` instance, or a string that is passed
            to ``fetch_artifact`` to obtain the source object.

    Returns:
        The ``DatasetDict`` produced by calling ``source()`` and converting
        the result with ``.to_dataset()``.

    Raises:
        TypeError: If ``source`` is neither a ``StreamSource`` nor a string.
    """
    # Explicit check rather than ``assert``: assertions are stripped under
    # ``python -O``, which would silently disable this input validation.
    if not isinstance(source, (StreamSource, str)):
        raise TypeError(
            "source must be a StreamSource or a string, "
            f"got {type(source).__name__}"
        )
    if isinstance(source, str):
        # fetch_artifact returns a pair; only the first element (the artifact
        # itself) is used. NOTE(review): presumably a StreamSource-like
        # callable — not type-checked here; confirm against fetch_artifact.
        source, _ = fetch_artifact(source)
    return source().to_dataset()