from typing import IO, Generator, Tuple, Union, overload
from pathlib import Path, PosixPath, PurePosixPath
import io
import os
import re
import requests
import fnmatch
from azure.identity import DefaultAzureCredential
from azure.storage.blob import ContainerClient, BlobClient
import requests.adapters
import requests.packages
from urllib3.util.retry import Retry

__all__ = [
    'download_blob', 'upload_blob',
    'download_blob_with_cache',
    'open_blob', 'open_blob_with_cache',
    'blob_file_exists',
    'AzureBlobPath', 'SmartPath'
]

DEFAULT_CREDENTIAL = DefaultAzureCredential()
BLOB_CACHE_DIR = './.blobcache'

def download_blob(blob: Union[str, BlobClient]) -> bytes:
    if isinstance(blob, str):
        blob_client = BlobClient.from_blob_url(blob)
    else:
        blob_client = blob
    return blob_client.download_blob().read()

def upload_blob(blob: Union[str, BlobClient], data: Union[str, bytes]):
    if isinstance(blob, str):
        blob_client = BlobClient.from_blob_url(blob)
    else:
        blob_client = blob
    blob_client.upload_blob(data, overwrite=True)

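# Example usage (illustrative sketch only; the account, container, and SAS token
# below are placeholders, not values defined in this module):
#
#   url = 'https://<account>.blob.core.windows.net/<container>/data/example.bin?<sas-token>'
#   upload_blob(url, b'hello blob')
#   assert download_blob(url) == b'hello blob'
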
def download_blob_with_cache(container: Union[str, ContainerClient], blob_name: str, cache_dir: str = 'blobcache') -> bytes:
    """
    Download a blob file from a container and return its content as bytes.
    If the file is already present in the cache, it is read from there.
    """
    cache_path = Path(cache_dir) / blob_name
    if cache_path.exists():
        return cache_path.read_bytes()
    if isinstance(container, str):
        container = ContainerClient.from_container_url(container)
    data = download_blob(container.get_blob_client(blob_name))
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    cache_path.write_bytes(data)
    return data

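# Example usage (sketch; the container URL and blob name are hypothetical):
#
#   container_url = 'https://<account>.blob.core.windows.net/<container>?<sas-token>'
#   data = download_blob_with_cache(container_url, 'models/weights.bin')
#   # A second call reads 'blobcache/models/weights.bin' from disk instead of Azure.
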
def open_blob(container: Union[str, ContainerClient], blob_name: str) -> io.BytesIO:
    """
    Open a blob file for reading from a container and return its content as a BytesIO object.
    """
    if isinstance(container, str):
        container = ContainerClient.from_container_url(container)
    return io.BytesIO(download_blob(container.get_blob_client(blob_name)))

def open_blob_with_cache(container: Union[str, ContainerClient], blob_name: str, cache_dir: str = 'blobcache') -> io.BytesIO:
    """
    Open a blob file for reading from a container and return its content as a BytesIO object.
    If the file is already present in the cache, it is read from there.
    """
    return io.BytesIO(download_blob_with_cache(container, blob_name, cache_dir=cache_dir))

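# Example usage (sketch; names are placeholders):
#
#   with open_blob_with_cache(container_url, 'configs/run.yaml') as f:
#       text = f.read().decode('utf-8')
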
def blob_file_exists(container: Union[str, ContainerClient], blob_name: str) -> bool:
    """
    Check if a blob file exists in a container.
    """
    if isinstance(container, str):
        container = ContainerClient.from_container_url(container)
    blob_client = container.get_blob_client(blob_name)
    return blob_client.exists()

def is_blob_url(url: str) -> bool:
    return re.match(r'https://[^/]+\.blob\.core\.windows\.net/', url) is not None

def split_blob_url(url: str) -> Tuple[str, str, str]:
    match = re.match(r'(https://[^/]+\.blob\.core\.windows\.net/[^/?]+)(/([^?]*))?(\?.+)?', url)
    if match:
        container, _, path, sas = match.groups()
        return container, path or '', sas or ''
    raise ValueError(f'Not a valid blob URL: {url}')

def join_blob_path(url: str, *others: str) -> str:
    container, path, sas = split_blob_url(url)
    return container + '/' + os.path.join(path, *others) + sas

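# Example (sketch; the URL is a made-up placeholder):
#
#   url = 'https://<account>.blob.core.windows.net/data?<sas-token>'
#   container, path, sas = split_blob_url(url)
#   # container == 'https://<account>.blob.core.windows.net/data', path == '', sas == '?<sas-token>'
#   join_blob_path(url, 'raw', 'scene.json')
#   # -> 'https://<account>.blob.core.windows.net/data/raw/scene.json?<sas-token>'
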
class AzureBlobStringWriter(io.StringIO):
    """A StringIO buffer that uploads its contents to the blob when closed."""
    def __init__(self, blob_client: BlobClient, encoding: str = 'utf-8', **kwargs):
        self._encoding = encoding
        self.blob_client = blob_client
        self.kwargs = kwargs
        super().__init__()

    def close(self):
        self.blob_client.upload_blob(self.getvalue().encode(self._encoding), blob_type='BlockBlob', overwrite=True, **self.kwargs)
        super().close()


class AzureBlobBytesWriter(io.BytesIO):
    """A BytesIO buffer that uploads its contents to the blob when closed."""
    def __init__(self, blob_client: BlobClient, **kwargs):
        super().__init__()
        self.blob_client = blob_client
        self.kwargs = kwargs

    def close(self):
        self.blob_client.upload_blob(self.getvalue(), blob_type='BlockBlob', overwrite=True, **self.kwargs)
        super().close()

def open_azure_blob(blob: Union[str, BlobClient], mode: str = 'r', encoding: str = 'utf-8', newline: str = None, cache_blob: bool = False, **kwargs) -> IO:
    if isinstance(blob, str):
        blob_client = BlobClient.from_blob_url(blob)
    elif isinstance(blob, BlobClient):
        blob_client = blob
    else:
        raise ValueError(f'Must be a blob URL or a BlobClient object: {blob}')

    if cache_blob:
        cache_path = Path(BLOB_CACHE_DIR, blob_client.account_name, blob_client.container_name, blob_client.blob_name)

    if mode == 'r' or mode == 'rb':
        if cache_blob:
            if cache_path.exists():
                data = cache_path.read_bytes()
            else:
                data = blob_client.download_blob(**kwargs).read()
                cache_path.parent.mkdir(parents=True, exist_ok=True)
                cache_path.write_bytes(data)
        else:
            data = blob_client.download_blob(**kwargs).read()
        if mode == 'r':
            return io.StringIO(data.decode(encoding), newline=newline)
        else:
            return io.BytesIO(data)
    elif mode == 'w':
        return AzureBlobStringWriter(blob_client, encoding, **kwargs)
    elif mode == 'wb':
        return AzureBlobBytesWriter(blob_client, **kwargs)
    else:
        raise ValueError(f'Unsupported mode: {mode}')

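# Example usage (sketch; the blob URL is a placeholder):
#
#   blob_url = 'https://<account>.blob.core.windows.net/<container>/logs/run.txt?<sas-token>'
#   with open_azure_blob(blob_url, 'w') as f:       # buffered locally, uploaded on close()
#       f.write('hello\n')
#   with open_azure_blob(blob_url, 'r', cache_blob=True) as f:
#       print(f.read())                             # cached under ./.blobcache on first read
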
def smart_open(path_or_url: Union[Path, str], mode: str = 'r', encoding: str = 'utf-8') -> IO:
    if is_blob_url(str(path_or_url)):
        return open_azure_blob(str(path_or_url), mode, encoding)
    if 'b' in mode:
        return open(path_or_url, mode)
    return open(path_or_url, mode, encoding=encoding)

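# Example usage (sketch; both paths are placeholders):
#
#   with smart_open('local/notes.txt', 'w') as f:   # regular file on disk
#       f.write('hello')
#   with smart_open('https://<account>.blob.core.windows.net/<container>/notes.txt?<sas>', 'r') as f:
#       print(f.read())                             # same call, backed by Azure Blob Storage
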
class AzureBlobPath(PurePosixPath):
    """
    A pathlib.Path-like interface for Azure Blob Storage.
    """
    container_client: ContainerClient
    _parse_path = PurePosixPath._parse_args if hasattr(PurePosixPath, '_parse_args') else PurePosixPath._parse_path

    def __new__(cls, *args, **kwargs):
        """Override PurePath's __new__; parts are parsed in __init__."""
        return object.__new__(cls)

    def __init__(self, root: Union[str, 'AzureBlobPath', ContainerClient], *others: Union[str, PurePosixPath], pool_maxsize: int = 256, retries: int = 3):
        if isinstance(root, AzureBlobPath):
            self.container_client = root.container_client
            parts = root.parts + others
        elif isinstance(root, str):
            url = root
            container, path, sas = split_blob_url(url)
            session = self._get_session(pool_maxsize=pool_maxsize, retries=retries)
            if sas:
                self.container_client = ContainerClient.from_container_url(container + sas, session=session)
            else:
                self.container_client = ContainerClient.from_container_url(container, credential=DEFAULT_CREDENTIAL, session=session)
            parts = (path, *others)
        elif isinstance(root, ContainerClient):
            self.container_client = root
            parts = others
        else:
            raise ValueError(f'Invalid root: {root}')

        if hasattr(PurePosixPath, '_parse_args'):
            # For compatibility with older Python versions (e.g. 3.10), where PurePath still uses _parse_args
            drv, root, parts = PurePosixPath._parse_args(parts)
            self._drv = drv
            self._root = root
            self._parts = parts
        else:
            super().__init__(*parts)

    def _get_session(self, pool_maxsize: int = 1024, retries: int = 3) -> requests.Session:
        session = requests.Session()
        retry_strategy = Retry(
            total=retries,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["HEAD", "GET", "PUT", "DELETE"],
            backoff_factor=1,
            raise_on_status=False,
            read=retries,
            connect=retries,
            redirect=retries,
        )
        adapter = requests.adapters.HTTPAdapter(pool_connections=pool_maxsize, pool_maxsize=pool_maxsize, max_retries=retry_strategy)
        session.mount('http://', adapter)
        session.mount('https://', adapter)
        return session

    def _from_parsed_parts(self, drv, root, parts):
        "For compatibility with Python 3.10"
        return AzureBlobPath(self.container_client, drv, root, *parts)

    def with_segments(self, *pathsegments):
        return AzureBlobPath(self.container_client, *pathsegments)

    # These accessors are used as attributes throughout the class (self.path,
    # self.blob_client, self.url), so they are exposed as properties.
    @property
    def path(self) -> str:
        return '/'.join(self.parts)

    @property
    def blob_client(self) -> BlobClient:
        return self.container_client.get_blob_client(self.path)

    @property
    def url(self) -> str:
        if len(self.parts) == 0:
            return self.container_client.url
        return self.container_client.get_blob_client(self.path).url

    @property
    def container_name(self) -> str:
        return self.container_client.container_name

    @property
    def account_name(self) -> str:
        return self.container_client.account_name

    def __str__(self):
        return self.url

    def __repr__(self):
        return self.url

    def open(self, mode: str = 'r', encoding: str = 'utf-8', cache_blob: bool = False, **kwargs) -> IO:
        return open_azure_blob(self.blob_client, mode, encoding, cache_blob=cache_blob, **kwargs)

    def __truediv__(self, other: Union[str, Path]) -> 'AzureBlobPath':
        return self.joinpath(other)

    def mkdir(self, parents: bool = False, exist_ok: bool = False):
        # Blob storage has no real directories, so there is nothing to create.
        pass

    def iterdir(self) -> Generator['AzureBlobPath', None, None]:
        path = self.path
        if path and not path.endswith('/'):
            path += '/'
        for item in self.container_client.walk_blobs(name_starts_with=path):
            yield AzureBlobPath(self.container_client, item.name)

    def glob(self, pattern: str) -> Generator['AzureBlobPath', None, None]:
        # Translate the glob pattern into a regular expression: escape regex
        # metacharacters, then map '**' to '.*' and '*' to '[^/]*'.
        special_chars = ".^$+{}[]()|/"
        for char in special_chars:
            pattern = pattern.replace(char, "\\" + char)
        pattern = pattern.replace('**', './/.')
        pattern = pattern.replace('*', '[^/]*')
        pattern = pattern.replace('.//.', '.*')
        pattern = "^" + pattern + "$"
        reg = re.compile(pattern)
        for item in self.container_client.list_blobs(name_starts_with=self.path):
            if reg.match(os.path.relpath(item.name, self.path)):
                yield AzureBlobPath(self.container_client, item.name)

    def exists(self) -> bool:
        return self.blob_client.exists()

    def read_bytes(self, cache_blob: bool = False) -> bytes:
        with self.open('rb', cache_blob=cache_blob) as f:
            return f.read()

    def read_text(self, encoding: str = 'utf-8', cache_blob: bool = False) -> str:
        with self.open('r', encoding=encoding, cache_blob=cache_blob) as f:
            return f.read()

    def write_bytes(self, data: bytes):
        self.blob_client.upload_blob(data, overwrite=True)

    def write_text(self, data: str, encoding: str = 'utf-8'):
        self.blob_client.upload_blob(data.encode(encoding), overwrite=True)

    def unlink(self):
        self.blob_client.delete_blob()

    def new_client(self) -> 'AzureBlobPath':
        return AzureBlobPath(self.container_client.url, self.path)

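# Example usage (sketch; the container URL is a placeholder, and the credential
# or SAS token must grant access for these calls to succeed):
#
#   root = AzureBlobPath('https://<account>.blob.core.windows.net/<container>?<sas-token>')
#   cfg = root / 'experiments' / 'config.json'
#   if cfg.exists():
#       print(cfg.read_text())
#   for p in root.glob('experiments/*/*.ckpt'):
#       print(p.url)
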
class SmartPath(Path, AzureBlobPath):
    """
    Supports both local file paths and Azure Blob Storage URLs.
    """
    def __new__(cls, first: Union[Path, str], *others: Union[str, PurePosixPath]) -> Union[Path, AzureBlobPath]:
        if is_blob_url(str(first)):
            return AzureBlobPath(str(first), *others)
        return Path(first, *others)

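# Example usage (sketch; both paths are placeholders):
#
#   local = SmartPath('outputs/result.json')    # plain pathlib.Path
#   remote = SmartPath('https://<account>.blob.core.windows.net/<container>/result.json?<sas>')
#   for p in (local, remote):
#       print(p.read_text())                    # same interface either way
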