from typing import IO, Generator, Optional, Tuple, Union
from pathlib import Path, PurePosixPath

import io
import os
import posixpath
import re

import requests
import requests.adapters
from urllib3.util.retry import Retry

from azure.identity import DefaultAzureCredential
from azure.storage.blob import ContainerClient, BlobClient


__all__ = [
    'download_blob', 'upload_blob',
    'download_blob_with_cache',
    'open_blob', 'open_blob_with_cache',
    'blob_file_exists',
    'AzureBlobPath', 'SmartPath',
]

DEFAULT_CREDENTIAL = DefaultAzureCredential()

BLOB_CACHE_DIR = './.blobcache'


def download_blob(blob: Union[str, BlobClient]) -> bytes:
    """Download a blob (given as a URL or a BlobClient) and return its content as bytes."""
    if isinstance(blob, str):
        blob_client = BlobClient.from_blob_url(blob)
    else:
        blob_client = blob
    return blob_client.download_blob().read()


def upload_blob(blob: Union[str, BlobClient], data: Union[str, bytes]):
    if isinstance(blob, str):
        blob_client = BlobClient.from_blob_url(blob)
    else:
        blob_client = blob
    blob_client.upload_blob(data, overwrite=True)


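# A minimal usage sketch for download_blob / upload_blob; the URL (account,
# container, blob name, SAS token) is a placeholder, not a real endpoint:
#
#     url = 'https://myaccount.blob.core.windows.net/mycontainer/hello.txt?<sas-token>'
#     upload_blob(url, b'hello world')
#     assert download_blob(url) == b'hello world'

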
def download_blob_with_cache(container: Union[str, ContainerClient], blob_name: str, cache_dir: str = 'blobcache') -> bytes:
    """
    Download a blob file from a container and return its content as bytes.
    If the file is already present in the cache, it is read from there.
    """
    cache_path = Path(cache_dir) / blob_name
    if cache_path.exists():
        return cache_path.read_bytes()
    if isinstance(container, str):
        container = ContainerClient.from_container_url(container)
    data = download_blob(container.get_blob_client(blob_name))
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    cache_path.write_bytes(data)
    return data


def open_blob(container: Union[str, ContainerClient], blob_name: str) -> io.BytesIO:
    """
    Open a blob file for reading from a container and return its content as a BytesIO object.
    """
    if isinstance(container, str):
        container = ContainerClient.from_container_url(container)
    return io.BytesIO(download_blob(container.get_blob_client(blob_name)))


def open_blob_with_cache(container: Union[str, ContainerClient], blob_name: str, cache_dir: str = 'blobcache') -> io.BytesIO:
    """
    Open a blob file for reading from a container and return its content as a BytesIO object.
    If the file is already present in the cache, it is read from there.
    """
    return io.BytesIO(download_blob_with_cache(container, blob_name, cache_dir=cache_dir))


def blob_file_exists(container: Union[str, ContainerClient], blob_name: str) -> bool:
    """
    Check if a blob file exists in a container.
    """
    if isinstance(container, str):
        container = ContainerClient.from_container_url(container)
    blob_client = container.get_blob_client(blob_name)
    return blob_client.exists()


def is_blob_url(url: str) -> bool:
    """Return True if the URL points at an Azure Blob Storage endpoint."""
    return re.match(r'https://[^/]+\.blob\.core\.windows\.net/', url) is not None


def split_blob_url(url: str) -> Tuple[str, str, str]:
    """Split a blob URL into (container URL, blob path, SAS query string)."""
    match = re.match(r'(https://[^/]+\.blob\.core\.windows\.net/[^/?]+)(/([^?]*))?(\?.+)?', url)
    if match:
        container, _, path, sas = match.groups()
        return container, path or '', sas or ''
    raise ValueError(f'Not a valid blob URL: {url}')


def join_blob_path(url: str, *others: str) -> str:
    """Append path segments to a blob URL, preserving any SAS query string."""
    container, path, sas = split_blob_url(url)
    # posixpath.join keeps forward slashes regardless of the local platform.
    return container + '/' + posixpath.join(path, *others) + sas


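# Illustrative behaviour of the URL helpers above (account and container names
# are made up):
#
#     split_blob_url('https://acct.blob.core.windows.net/cont/dir/file.txt?sig=abc')
#     # -> ('https://acct.blob.core.windows.net/cont', 'dir/file.txt', '?sig=abc')
#
#     join_blob_path('https://acct.blob.core.windows.net/cont?sig=abc', 'dir', 'file.txt')
#     # -> 'https://acct.blob.core.windows.net/cont/dir/file.txt?sig=abc'

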
class AzureBlobStringWriter(io.StringIO):
    """In-memory text buffer that uploads its content to a blob when closed."""

    def __init__(self, blob_client: BlobClient, encoding: str = 'utf-8', **kwargs):
        self._encoding = encoding
        self.blob_client = blob_client
        self.kwargs = kwargs
        super().__init__()

    def close(self):
        self.blob_client.upload_blob(self.getvalue().encode(self._encoding), blob_type='BlockBlob', overwrite=True, **self.kwargs)
        super().close()


class AzureBlobBytesWriter(io.BytesIO):
    """In-memory bytes buffer that uploads its content to a blob when closed."""

    def __init__(self, blob_client: BlobClient, **kwargs):
        super().__init__()
        self.blob_client = blob_client
        self.kwargs = kwargs

    def close(self):
        self.blob_client.upload_blob(self.getvalue(), blob_type='BlockBlob', overwrite=True, **self.kwargs)
        super().close()


def open_azure_blob(blob: Union[str, BlobClient], mode: str = 'r', encoding: str = 'utf-8', newline: Optional[str] = None, cache_blob: bool = False, **kwargs) -> IO:
    """Open a blob for reading or writing ('r', 'rb', 'w', 'wb'), optionally caching downloads locally."""
    if isinstance(blob, str):
        blob_client = BlobClient.from_blob_url(blob)
    elif isinstance(blob, BlobClient):
        blob_client = blob
    else:
        raise ValueError(f'Must be a blob URL or a BlobClient object: {blob}')

    if cache_blob:
        cache_path = Path(BLOB_CACHE_DIR, blob_client.account_name, blob_client.container_name, blob_client.blob_name)

    if mode == 'r' or mode == 'rb':
        if cache_blob:
            if cache_path.exists():
                data = cache_path.read_bytes()
            else:
                data = blob_client.download_blob(**kwargs).read()
                cache_path.parent.mkdir(parents=True, exist_ok=True)
                cache_path.write_bytes(data)
        else:
            data = blob_client.download_blob(**kwargs).read()
        if mode == 'r':
            return io.StringIO(data.decode(encoding), newline=newline)
        else:
            return io.BytesIO(data)
    elif mode == 'w':
        return AzureBlobStringWriter(blob_client, encoding=encoding, **kwargs)
    elif mode == 'wb':
        return AzureBlobBytesWriter(blob_client, **kwargs)
    else:
        raise ValueError(f'Unsupported mode: {mode}')


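# A usage sketch for open_azure_blob; the URL is a placeholder and the SAS token
# embedded in it is assumed to grant read/write access:
#
#     url = 'https://myaccount.blob.core.windows.net/mycontainer/notes.txt?<sas-token>'
#     with open_azure_blob(url, 'w') as f:
#         f.write('hello')
#     with open_azure_blob(url, 'r', cache_blob=True) as f:
#         text = f.read()   # cached under ./.blobcache for subsequent reads

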
def smart_open(path_or_url: Union[Path, str], mode: str = 'r', encoding: str = 'utf-8') -> IO:
    """Open either an Azure blob URL or a local file with the same call."""
    if is_blob_url(str(path_or_url)):
        return open_azure_blob(str(path_or_url), mode, encoding)
    if 'b' in mode:
        return open(path_or_url, mode)
    return open(path_or_url, mode, encoding=encoding)


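# smart_open dispatches on its argument: blob URLs go through open_azure_blob,
# everything else through the built-in open(). The paths below are illustrative:
#
#     with smart_open('https://myaccount.blob.core.windows.net/cont/a.txt?<sas>') as f:
#         text = f.read()
#     with smart_open('/tmp/a.txt', 'w') as f:
#         f.write(text)

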
class AzureBlobPath(PurePosixPath):
    """
    Implementation of a pathlib.Path-like interface for Azure Blob Storage.
    """
    container_client: ContainerClient

    # Compatibility shim: older pathlib implementations expose _parse_args,
    # newer ones expose _parse_path.
    _parse_path = PurePosixPath._parse_args if hasattr(PurePosixPath, '_parse_args') else PurePosixPath._parse_path

    def __new__(cls, *args, **kwargs):
        """Override PurePosixPath.__new__; parts are parsed in __init__."""
        return object.__new__(cls)

    def __init__(self, root: Union[str, 'AzureBlobPath', ContainerClient], *others: Union[str, PurePosixPath], pool_maxsize: int = 256, retries: int = 3):
        if isinstance(root, AzureBlobPath):
            self.container_client = root.container_client
            parts = root.parts + others
        elif isinstance(root, str):
            url = root
            container, path, sas = split_blob_url(url)
            session = self._get_session(pool_maxsize=pool_maxsize, retries=retries)
            if sas:
                self.container_client = ContainerClient.from_container_url(container + sas, session=session)
            else:
                self.container_client = ContainerClient.from_container_url(container, credential=DEFAULT_CREDENTIAL, session=session)
            parts = (path, *others)
        elif isinstance(root, ContainerClient):
            self.container_client = root
            parts = others
        else:
            raise ValueError(f'Invalid root: {root}')

        if hasattr(PurePosixPath, '_parse_args'):
            drv, root, parts = PurePosixPath._parse_args(parts)
            self._drv = drv
            self._root = root
            self._parts = parts
        else:
            super().__init__(*parts)

    def _get_session(self, pool_maxsize: int = 1024, retries: int = 3) -> requests.Session:
        session = requests.Session()
        retry_strategy = Retry(
            total=retries,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["HEAD", "GET", "PUT", "DELETE"],
            backoff_factor=1,
            raise_on_status=False,
            read=retries,
            connect=retries,
            redirect=retries,
        )
        adapter = requests.adapters.HTTPAdapter(pool_connections=pool_maxsize, pool_maxsize=pool_maxsize, max_retries=retry_strategy)
        session.mount('http://', adapter)
        session.mount('https://', adapter)
        return session

    def _from_parsed_parts(self, drv, root, parts):
        """For compatibility with Python 3.10."""
        return AzureBlobPath(self.container_client, drv, root, *parts)

    def with_segments(self, *pathsegments):
        return AzureBlobPath(self.container_client, *pathsegments)

    @property
    def path(self) -> str:
        return '/'.join(self.parts)

    @property
    def blob_client(self) -> BlobClient:
        return self.container_client.get_blob_client(self.path)

    @property
    def url(self) -> str:
        if len(self.parts) == 0:
            return self.container_client.url
        return self.container_client.get_blob_client(self.path).url

    @property
    def container_name(self) -> str:
        return self.container_client.container_name

    @property
    def account_name(self) -> str:
        return self.container_client.account_name

    def __str__(self):
        return self.url

    def __repr__(self):
        return self.url

    def open(self, mode: str = 'r', encoding: str = 'utf-8', cache_blob: bool = False, **kwargs) -> IO:
        return open_azure_blob(self.blob_client, mode, encoding, cache_blob=cache_blob, **kwargs)

    def __truediv__(self, other: Union[str, Path]) -> 'AzureBlobPath':
        return self.joinpath(other)

    def mkdir(self, parents: bool = False, exist_ok: bool = False):
        # Blob storage has no real directories, so this is a no-op.
        pass

    def iterdir(self) -> Generator['AzureBlobPath', None, None]:
        path = self.path
        if not path.endswith('/'):
            path += '/'
        for item in self.container_client.walk_blobs(name_starts_with=path):
            yield AzureBlobPath(self.container_client, item.name)

    def glob(self, pattern: str) -> Generator['AzureBlobPath', None, None]:
        # Translate the glob pattern into a regular expression: escape regex
        # metacharacters, then map '**' to '.*' and '*' to '[^/]*'.
        special_chars = ".^$+{}[]()|/"
        for char in special_chars:
            pattern = pattern.replace(char, "\\" + char)
        pattern = pattern.replace('**', './/.')
        pattern = pattern.replace('*', '[^/]*')
        pattern = pattern.replace('.//.', '.*')
        pattern = "^" + pattern + "$"
        reg = re.compile(pattern)

        for item in self.container_client.list_blobs(name_starts_with=self.path):
            if reg.match(posixpath.relpath(item.name, self.path)):
                yield AzureBlobPath(self.container_client, item.name)

    def exists(self) -> bool:
        return self.blob_client.exists()

    def read_bytes(self, cache_blob: bool = False) -> bytes:
        with self.open('rb', cache_blob=cache_blob) as f:
            return f.read()

    def read_text(self, encoding: str = 'utf-8', cache_blob: bool = False) -> str:
        with self.open('r', encoding=encoding, cache_blob=cache_blob) as f:
            return f.read()

    def write_bytes(self, data: bytes):
        self.blob_client.upload_blob(data, overwrite=True)

    def write_text(self, data: str, encoding: str = 'utf-8'):
        self.blob_client.upload_blob(data.encode(encoding), overwrite=True)

    def unlink(self):
        self.blob_client.delete_blob()

    def new_client(self) -> 'AzureBlobPath':
        return AzureBlobPath(self.container_client.url, self.path)


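# A usage sketch for AzureBlobPath; the account and container names are
# placeholders, and authentication falls back to DefaultAzureCredential when
# the URL carries no SAS token:
#
#     root = AzureBlobPath('https://myaccount.blob.core.windows.net/mycontainer')
#     (root / 'results' / 'run1.json').write_text('{"ok": true}')
#     for p in root.glob('results/*.json'):
#         print(p.url, p.read_text())

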
class SmartPath(Path, AzureBlobPath):
    """
    Supports both local file paths and Azure Blob Storage URLs.
    """

    def __new__(cls, first: Union[Path, str], *others: Union[str, PurePosixPath]) -> Union[Path, AzureBlobPath]:
        if is_blob_url(str(first)):
            return AzureBlobPath(str(first), *others)
        return Path(first, *others)


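# SmartPath picks the right implementation from its argument; illustrative only:
#
#     SmartPath('https://myaccount.blob.core.windows.net/cont/a.txt')  # -> AzureBlobPath
#     SmartPath('/tmp/a.txt')                                          # -> pathlib.Path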
|