import os
from tempfile import TemporaryDirectory
from typing import Dict, Mapping, Optional, Sequence, Union

import pandas as pd
from datasets import load_dataset as hf_load_dataset
from tqdm import tqdm

from .operator import SourceOperator
from .stream import MultiStream, Stream

try:
    import ibm_boto3
    from ibm_botocore.client import ClientError

    ibm_boto3_available = True
except ImportError:
    ibm_boto3_available = False


class Loader(SourceOperator):
    """Base class for loaders: source operators that read raw data and expose it as a MultiStream."""


class LoadHF(Loader):
    """Loads a dataset via `datasets.load_dataset` and exposes each split as a stream."""

    path: str
    name: Optional[str] = None
    data_dir: Optional[str] = None
    data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None
    streaming: bool = True
    cached: bool = False

    def process(self):
        dataset = hf_load_dataset(
            self.path, name=self.name, data_dir=self.data_dir, data_files=self.data_files, streaming=self.streaming
        )

        return MultiStream.from_iterables(dataset)
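
# A minimal usage sketch for LoadHF (illustrative only; the dataset path, config name,
# and split below are assumptions, and process() is called directly here rather than
# through the surrounding pipeline):
#
#     loader = LoadHF(path="glue", name="mrpc")
#     multi_stream = loader.process()
#     first_instance = next(iter(multi_stream["train"]))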


class LoadCSV(Loader):
    """Loads CSV files with pandas in chunks and yields each row as a dictionary."""

    files: Dict[str, str]
    chunksize: int = 1000

    def load_csv(self, file):
        for chunk in pd.read_csv(file, chunksize=self.chunksize):
            for index, row in chunk.iterrows():
                yield row.to_dict()

    def process(self):
        return MultiStream(
            {name: Stream(generator=self.load_csv, gen_kwargs={"file": file}) for name, file in self.files.items()}
        )
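
# A minimal usage sketch for LoadCSV (illustrative only; the stream name and file path
# are assumptions):
#
#     loader = LoadCSV(files={"train": "/path/to/train.csv"})
#     multi_stream = loader.process()
#     for row in multi_stream["train"]:
#         ...  # each row is a dict keyed by the CSV column names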


class LoadFromIBMCloud(Loader):
    """Downloads data files from an IBM Cloud Object Storage (COS) bucket and loads them with `datasets`.

    The endpoint URL and credentials are read from the environment variables named by the
    corresponding `*_env` fields.
    """

    endpoint_url_env: str
    aws_access_key_id_env: str
    aws_secret_access_key_env: str
    bucket_name: str
    data_dir: str
    data_files: Sequence[str]

    def _download_from_cos(self, cos, bucket_name, item_name, local_file):
        print(f"Downloading {item_name} from {bucket_name} COS to {local_file}")
        try:
            response = cos.Object(bucket_name, item_name).get()
            size = response["ContentLength"]
        except Exception as e:
            raise Exception(f"Unable to access {item_name} in {bucket_name} in COS", e)

        progress_bar = tqdm(total=size, unit="iB", unit_scale=True)

        def download_progress(chunk):
            progress_bar.update(chunk)

        try:
            cos.Bucket(bucket_name).download_file(item_name, local_file, Callback=download_progress)
            print("\nDownload successful")
        except Exception as e:
            raise Exception(f"Unable to download {item_name} in {bucket_name}", e)

    def prepare(self):
        super().prepare()
        self.endpoint_url = os.getenv(self.endpoint_url_env)
        self.aws_access_key_id = os.getenv(self.aws_access_key_id_env)
        self.aws_secret_access_key = os.getenv(self.aws_secret_access_key_env)

    def verify(self):
        super().verify()
        assert (
            ibm_boto3_available
        ), "Please install ibm_boto3 in order to use the LoadFromIBMCloud loader (`pip install ibm-cos-sdk`)"
        assert self.endpoint_url is not None, f"Please set the {self.endpoint_url_env} environment variable"
        assert self.aws_access_key_id is not None, f"Please set the {self.aws_access_key_id_env} environment variable"
        assert (
            self.aws_secret_access_key is not None
        ), f"Please set the {self.aws_secret_access_key_env} environment variable"

    def process(self):
        cos = ibm_boto3.resource(
            "s3",
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            endpoint_url=self.endpoint_url,
        )

        with TemporaryDirectory() as temp_directory:
            for data_file in self.data_files:
                self._download_from_cos(
                    cos,
                    self.bucket_name,
                    self.data_dir + "/" + data_file,  # object key inside the bucket
                    os.path.join(temp_directory, data_file),  # local download target
                )
            dataset = hf_load_dataset(temp_directory, streaming=False)

        return MultiStream.from_iterables(dataset)
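
# A minimal usage sketch for LoadFromIBMCloud (illustrative only; the environment
# variable names, bucket name, directory, and file names are all assumptions, and the
# sketch assumes prepare() and verify() are invoked before process()):
#
#     loader = LoadFromIBMCloud(
#         endpoint_url_env="COS_ENDPOINT_URL",
#         aws_access_key_id_env="COS_ACCESS_KEY_ID",
#         aws_secret_access_key_env="COS_SECRET_ACCESS_KEY",
#         bucket_name="my-bucket",
#         data_dir="my_dataset",
#         data_files=["train.csv", "test.csv"],
#     )
#     multi_stream = loader.process()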