Elron commited on
Commit
74ba290
1 Parent(s): cef1000

Upload operators.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. operators.py +58 -3
operators.py CHANGED
@@ -36,6 +36,7 @@ import importlib
36
  import operator
37
  import os
38
  import uuid
 
39
  from abc import abstractmethod
40
  from collections import Counter
41
  from copy import deepcopy
@@ -54,6 +55,8 @@ from typing import (
54
  Union,
55
  )
56
 
 
 
57
  from .artifact import Artifact, fetch_artifact
58
  from .dataclass import NonPositionalField
59
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
@@ -62,12 +65,13 @@ from .operator import (
62
  MultiStreamOperator,
63
  PagedStreamOperator,
64
  SequentialOperator,
 
65
  SingleStreamOperator,
66
  SingleStreamReducer,
 
67
  StreamingOperator,
68
  StreamInitializerOperator,
69
  StreamInstanceOperator,
70
- StreamSource,
71
  )
72
  from .random_utils import new_random_generator
73
  from .stream import Stream
@@ -89,7 +93,7 @@ class FromIterables(StreamInitializerOperator):
89
  return MultiStream.from_iterables(iterables)
90
 
91
 
92
- class IterableSource(StreamSource):
93
  """Creates a MultiStream from a dict of named iterables.
94
 
95
  It is a callable.
@@ -105,7 +109,7 @@ class IterableSource(StreamSource):
105
 
106
  iterables: Dict[str, Iterable]
107
 
108
- def __call__(self) -> MultiStream:
109
  return MultiStream.from_iterables(self.iterables)
110
 
111
 
@@ -1784,3 +1788,54 @@ class LengthBalancer(DeterministicBalancer):
1784
  if total_len < val:
1785
  return i
1786
  return i + 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  import operator
37
  import os
38
  import uuid
39
+ import zipfile
40
  from abc import abstractmethod
41
  from collections import Counter
42
  from copy import deepcopy
 
55
  Union,
56
  )
57
 
58
+ import requests
59
+
60
  from .artifact import Artifact, fetch_artifact
61
  from .dataclass import NonPositionalField
62
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
 
65
  MultiStreamOperator,
66
  PagedStreamOperator,
67
  SequentialOperator,
68
+ SideEffectOperator,
69
  SingleStreamOperator,
70
  SingleStreamReducer,
71
+ SourceOperator,
72
  StreamingOperator,
73
  StreamInitializerOperator,
74
  StreamInstanceOperator,
 
75
  )
76
  from .random_utils import new_random_generator
77
  from .stream import Stream
 
93
  return MultiStream.from_iterables(iterables)
94
 
95
 
96
+ class IterableSource(SourceOperator):
97
  """Creates a MultiStream from a dict of named iterables.
98
 
99
  It is a callable.
 
109
 
110
  iterables: Dict[str, Iterable]
111
 
112
+ def process(self) -> MultiStream:
113
  return MultiStream.from_iterables(self.iterables)
114
 
115
 
 
1788
  if total_len < val:
1789
  return i
1790
  return i + 1
1791
+
1792
+
1793
+ class DownloadError(Exception):
1794
+ def __init__(
1795
+ self,
1796
+ message,
1797
+ ):
1798
+ self.__super__(message)
1799
+
1800
+
1801
+ class UnexpectedHttpCodeError(Exception):
1802
+ def __init__(self, http_code):
1803
+ self.__super__(f"unexpected http code {http_code}")
1804
+
1805
+
1806
+ class DownloadOperator(SideEffectOperator):
1807
+ """Operator for downloading a file from a given URL to a specified local path.
1808
+
1809
+ Attributes:
1810
+ source (str): URL of the file to be downloaded.
1811
+ target (str): Local path where the downloaded file should be saved.
1812
+ """
1813
+
1814
+ source: str
1815
+ target: str
1816
+
1817
+ def process(self):
1818
+ try:
1819
+ response = requests.get(self.source, allow_redirects=True)
1820
+ except Exception as e:
1821
+ raise DownloadError(f"Unabled to download {self.source}") from e
1822
+ if response.status_code != 200:
1823
+ raise UnexpectedHttpCodeError(response.status_code)
1824
+ with open(self.target, "wb") as f:
1825
+ f.write(response.content)
1826
+
1827
+
1828
+ class ExtractZipFile(SideEffectOperator):
1829
+ """Operator for extracting files from a zip archive.
1830
+
1831
+ Attributes:
1832
+ zip_file (str): Path of the zip file to be extracted.
1833
+ target_dir (str): Directory where the contents of the zip file will be extracted.
1834
+ """
1835
+
1836
+ zip_file: str
1837
+ target_dir: str
1838
+
1839
+ def process(self):
1840
+ with zipfile.ZipFile(self.zip_file) as zf:
1841
+ zf.extractall(self.target_dir)