Spaces:
Sleeping
Sleeping
| import gzip | |
| import io | |
| import os | |
| import torch | |
| from typing import Optional | |
| from cryptography.hazmat.primitives.ciphers.aead import AESGCM | |
def _parse_key(key_str: str) -> bytes:
    """Turn a user-supplied key string into 32 raw key bytes.

    Two encodings are accepted, tried in this order after stripping leading
    and trailing whitespace:

    1. A hex string that decodes to exactly 32 bytes (64 hex digits).
    2. Any other string whose UTF-8 encoding is exactly 32 bytes.

    Note that the second check is on encoded *byte* length, so non-ASCII
    characters count for more than one.

    Raises:
        ValueError: if neither interpretation yields exactly 32 bytes.
    """
    cleaned = key_str.strip()

    # Text that is not valid hex is not an error here; it just falls through
    # to the raw-string interpretation below.
    try:
        decoded = bytes.fromhex(cleaned)
    except ValueError:
        decoded = None
    if decoded is not None and len(decoded) == 32:
        return decoded

    raw = cleaned.encode("utf-8")
    if len(raw) == 32:
        return raw

    raise ValueError("Key must be either a 64-character hex string or a 32-character raw string.")
def _get_key(key: Optional[str] = None, env_var: str = "MODEL_KEY") -> bytes:
    """Resolve the decryption key as 32 raw bytes.

    An explicit ``key`` argument wins. Otherwise the key string is read from
    the environment variable named by ``env_var``. Either way the string is
    validated and decoded by ``_parse_key``.

    Raises:
        RuntimeError: if ``key`` is None and the environment variable is
            unset or empty.
        ValueError: propagated from ``_parse_key`` for a malformed key.
    """
    if key is None:
        key = os.environ.get(env_var)
        # An empty variable is treated the same as a missing one.
        if not key:
            raise RuntimeError("Missing key. Provide key=... or set environment variable {}.".format(env_var))
    return _parse_key(key)
def decrypt_and_decompress_to_bytes(path: str, key: Optional[str] = None, env_var: str = "MODEL_KEY") -> bytes:
    """Read an encrypted file, authenticate/decrypt it, and gunzip the result.

    Expected file layout: a 12-byte nonce, followed by the AES-GCM ciphertext
    with its 16-byte authentication tag appended (the form ``AESGCM.decrypt``
    consumes). No associated data is used. The decrypted payload must be
    gzip-compressed.

    Args:
        path: Path of the encrypted file.
        key: Optional key string (64 hex digits or 32 raw bytes). When None,
            the key is taken from the environment variable ``env_var``.
        env_var: Environment variable consulted when ``key`` is None.

    Returns:
        The decompressed plaintext bytes.

    Raises:
        RuntimeError: if no key is available.
        ValueError: if the key is malformed, or the file is too short to
            contain a nonce plus an authentication tag.
        cryptography.exceptions.InvalidTag: if authentication fails (wrong
            key, or a corrupted/tampered file).
        Errors from ``gzip.decompress`` (e.g. ``gzip.BadGzipFile``) if the
        decrypted payload is not valid gzip data.
    """
    nonce_size = 12  # the file starts with a 12-byte nonce
    tag_size = 16    # AESGCM appends a 16-byte authentication tag to the ciphertext

    key_bytes = _get_key(key=key, env_var=env_var)
    aesgcm = AESGCM(key_bytes)
    with open(path, "rb") as f:
        data = f.read()
    # A valid file holds at least the nonce and the tag (an empty payload
    # still carries a tag). Anything shorter can never authenticate, so
    # reject it here with a clear error instead of a generic InvalidTag.
    if len(data) < nonce_size + tag_size:
        raise ValueError("Encrypted file is too short or invalid.")
    nonce = data[:nonce_size]
    ciphertext = data[nonce_size:]
    compressed = aesgcm.decrypt(nonce, ciphertext, None)
    return gzip.decompress(compressed)
def secure_torch_load(path: str, *args, key: Optional[str] = None, env_var: str = "MODEL_KEY", **kwargs):
    """Decrypt and decompress ``path``, then deserialize it with ``torch.load``.

    ``key``/``env_var`` select the decryption key, as in
    ``decrypt_and_decompress_to_bytes``. All other positional and keyword
    arguments are forwarded unchanged to ``torch.load``.

    NOTE(review): ``torch.load`` deserializes via pickle. Decryption only
    proves the file was produced by someone holding the key; it does not make
    arbitrary pickled objects safe. Consider passing ``weights_only=True``
    (its default differs across PyTorch versions — confirm for the version
    in use).
    """
    decrypted = decrypt_and_decompress_to_bytes(path, key=key, env_var=env_var)
    # torch.load needs a file-like object; serve the plaintext from memory so
    # it never touches disk.
    return torch.load(io.BytesIO(decrypted), *args, **kwargs)