# CephVIT / secure_torch_load.py
# Author: farrell236 (commit 325d063, "Upload 4 files")
import gzip
import io
import os
import torch
from typing import Optional
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
def _parse_key(key_str: str) -> bytes:
    """Turn a textual key into the 32 raw bytes needed for AES-256.

    Surrounding whitespace is stripped, then two encodings are tried in order:

    1. hexadecimal text that decodes to exactly 32 bytes (64 hex digits);
    2. text whose UTF-8 encoding is exactly 32 bytes.

    Raises:
        ValueError: if neither interpretation produces 32 bytes.
    """
    stripped = key_str.strip()

    try:
        decoded = bytes.fromhex(stripped)
    except ValueError:
        # Not valid hex text; fall through to the raw-string interpretation.
        decoded = None
    if decoded is not None and len(decoded) == 32:
        return decoded

    # Note: a 32-hex-digit string decodes to only 16 bytes, so it lands here
    # and is used as a 32-byte raw key instead.
    raw = stripped.encode("utf-8")
    if len(raw) != 32:
        raise ValueError("Key must be either a 64-character hex string or a 32-character raw string.")
    return raw
def _get_key(key: Optional[str] = None, env_var: str = "MODEL_KEY") -> bytes:
    """Resolve the AES key bytes from an explicit argument or the environment.

    Args:
        key: Key text passed directly by the caller. When not None it is
            used as-is (even if empty) and the environment is not consulted.
        env_var: Name of the environment variable read when ``key`` is None.

    Returns:
        The 32-byte key produced by ``_parse_key``.

    Raises:
        RuntimeError: if ``key`` is None and ``env_var`` is unset or empty.
        ValueError: propagated from ``_parse_key`` for malformed key text.
    """
    if key is not None:
        candidate = key
    else:
        candidate = os.environ.get(env_var)
        # An empty string is treated the same as a missing variable.
        if not candidate:
            raise RuntimeError("Missing key. Provide key=... or set environment variable {}.".format(env_var))
    return _parse_key(candidate)
def decrypt_and_decompress_to_bytes(path: str, key: Optional[str] = None, env_var: str = "MODEL_KEY") -> bytes:
    """Decrypt an AES-GCM encrypted, gzip-compressed file and return the plaintext.

    File layout read by this function: a 12-byte nonce followed by the
    AES-GCM ciphertext with its authentication tag appended (the form
    ``AESGCM.decrypt`` expects). No associated data is used.

    Args:
        path: Path of the encrypted file. It is read fully into memory.
        key: Optional key text (see ``_get_key``). If None, ``env_var`` is used.
        env_var: Environment variable consulted when ``key`` is None.

    Returns:
        The gunzipped plaintext bytes.

    Raises:
        RuntimeError: no key was supplied and ``env_var`` is unset or empty.
        ValueError: the key text is malformed, or the file is too short to hold
            a nonce and an authentication tag.
        cryptography.exceptions.InvalidTag: wrong key, or tampered or corrupt data.
        gzip.BadGzipFile / EOFError: the decrypted payload is not valid gzip.
    """
    nonce_size = 12  # nonce length used by this file format (prefix of the file)
    tag_size = 16  # AES-GCM authentication tag appended to the ciphertext

    key_bytes = _get_key(key=key, env_var=env_var)
    aesgcm = AESGCM(key_bytes)
    with open(path, "rb") as f:
        data = f.read()

    # Even an empty plaintext produces a full-length tag, so anything shorter
    # than nonce + tag cannot be a valid file. Reject it up front instead of
    # letting decrypt() fail with a less descriptive error.
    if len(data) < nonce_size + tag_size:
        raise ValueError("Encrypted file is too short or invalid.")

    nonce = data[:nonce_size]
    ciphertext = data[nonce_size:]
    compressed = aesgcm.decrypt(nonce, ciphertext, None)
    return gzip.decompress(compressed)
def secure_torch_load(path: str, *args, key: Optional[str] = None, env_var: str = "MODEL_KEY", **kwargs):
    """Load a torch object from an encrypted, gzip-compressed checkpoint file.

    The file is decrypted and decompressed by
    ``decrypt_and_decompress_to_bytes``. The resulting bytes are passed to
    ``torch.load`` through an in-memory ``io.BytesIO`` buffer, so this
    function writes no decrypted copy to disk.

    Args:
        path: Path of the encrypted checkpoint.
        *args: Extra positional arguments forwarded to ``torch.load``.
        key: Optional key text. If None, the key is read from ``env_var``.
        env_var: Environment variable consulted when ``key`` is None.
        **kwargs: Extra keyword arguments forwarded to ``torch.load``
            (for example ``map_location`` or ``weights_only``).

    Returns:
        Whatever ``torch.load`` returns for the decrypted payload.

    NOTE(review): ``torch.load`` unpickles its input. Unless ``weights_only=True``
    is in effect, a malicious checkpoint can run arbitrary code. Successful
    decryption only shows that the file was produced by someone who holds the key.
    """
    payload = decrypt_and_decompress_to_bytes(path, key=key, env_var=env_var)
    return torch.load(io.BytesIO(payload), *args, **kwargs)