Vulnerability Details

1.1. Orbax CheckpointManager Path Traversal (CVE-XXXX-XXXX)

Description: The orbax.checkpoint.CheckpointManager.restore method, when used with a crafted step argument, is susceptible to a path traversal vulnerability. This allows an attacker to manipulate the path resolution logic, causing Orbax to attempt to load checkpoint data from arbitrary locations on the filesystem outside the intended base directory.

Impact: This vulnerability can lead to Arbitrary File Read (Information Disclosure). By supplying a step argument containing path traversal sequences (e.g., ../../), an attacker can make the CheckpointManager read files outside the base directory (e.g., /etc/passwd or private keys). The read succeeds when the target can be interpreted as a valid checkpoint item (e.g., JSON for JsonCheckpointer or a PyTree for PyTreeCheckpointer); in that case the data can be exfiltrated or used to further compromise the system. If an attacker can also place a malicious checkpoint at the traversed location, this could escalate to Arbitrary Code Execution (ACE) when the application processes the restored data in an unsafe way.

Reproduction Steps (PoC: orbax_disclosure_v2.py):

  1. Prepare the Environment: Ensure flax, jax, orbax-checkpoint, msgpack, and numpy are installed.
  2. Create a Sensitive File: Create a dummy sensitive file at a known location, for example, /tmp/secret_item/data with content like PWNED_DATA.
    import os

    secret_dir = "/tmp/secret_item"
    # Create the directory if it does not already exist.
    os.makedirs(secret_dir, exist_ok=True)
    with open(os.path.join(secret_dir, "data"), "w") as f:
        f.write("PWNED_DATA")
    
  3. Initialize Orbax CheckpointManager: Instantiate orbax.checkpoint.CheckpointManager with a base directory (e.g., /tmp/orbax_base) and a PyTree handler for an item named 'data'.
    import orbax.checkpoint

    base_dir = "/tmp/orbax_base"
    os.makedirs(base_dir, exist_ok=True)
    mngr = orbax.checkpoint.CheckpointManager(
        base_dir,
        item_names=('data',),
        # item_handlers expects CheckpointHandler instances, not Checkpointers.
        item_handlers={'data': orbax.checkpoint.PyTreeCheckpointHandler()},
    )
    
  4. Trigger Path Traversal: Call the restore method with a crafted step argument whose path traversal sequences point at the directory containing the sensitive file.
    traversal_path = "../../tmp/secret_item"
    res = mngr.restore(traversal_path)
    print(f"[+] Successfully read data from outside base_dir: {res}")
    
  5. Observe Information Disclosure: The res variable will contain the content of /tmp/secret_item/data, demonstrating successful information disclosure.
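The traversal works because the step value becomes a path component under the base directory, and the ".." components are resolved when the path is opened. The standard-library sketch below illustrates that resolution. It assumes, without reference to Orbax's actual internals, that the step directory is formed by a plain join of base_dir and the step string; naive_step_dir is a hypothetical helper.

```python
import posixpath


def naive_step_dir(base_dir: str, step: str) -> str:
    # Hypothetical helper: join the base directory and the step name as-is,
    # then collapse "." and ".." lexically, approximating how the OS
    # resolves them (symlinks aside).
    return posixpath.normpath(posixpath.join(base_dir, step))


print(naive_step_dir("/tmp/orbax_base", "100"))                    # /tmp/orbax_base/100
print(naive_step_dir("/tmp/orbax_base", "../../tmp/secret_item"))  # /tmp/secret_item
```

Two ".." components are enough here because /tmp/orbax_base is two levels below the filesystem root.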

Mitigation: Implement robust path sanitization and validation for all inputs used in constructing filesystem paths within CheckpointManager.restore. The step argument should be strictly validated to be a simple integer or a sanitized string that does not contain directory traversal characters.
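A minimal sketch of the suggested validation, assuming the caller builds the step directory itself; safe_step_dir is a hypothetical helper, not an Orbax API. It rejects anything other than a non-negative integer and, as a second layer of defense, checks that the normalized result still lies inside the base directory.

```python
import posixpath


def safe_step_dir(base_dir: str, step) -> str:
    # Accept only real non-negative integers; bool is a subclass of int,
    # so it is excluded explicitly. Strings such as "../.." are rejected here.
    if isinstance(step, bool) or not isinstance(step, int) or step < 0:
        raise ValueError(f"step must be a non-negative integer, got {step!r}")
    base = posixpath.normpath(base_dir)
    candidate = posixpath.normpath(posixpath.join(base, str(step)))
    # Defense in depth: the resolved path must remain under base_dir.
    if posixpath.commonpath([base, candidate]) != base:
        raise ValueError(f"resolved path {candidate!r} escapes {base!r}")
    return candidate


print(safe_step_dir("/tmp/orbax_base", 100))  # /tmp/orbax_base/100
```

The containment check is purely lexical. If the base directory may contain attacker-controlled symlinks, resolving both paths with os.path.realpath before comparing them would be a stricter variant.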

1.2. Flax Insecure Pytree Reconstruction / Type Confusion (CVE-XXXX-XXXX)

Description: The flax.serialization.from_state_dict function, responsible for reconstructing Python objects from a serialized state dictionary, exhibits a type confusion vulnerability. When a leaf node in the target object (the template for reconstruction) is of a type not explicitly registered in Flax's _STATE_DICT_REGISTRY, the function directly returns the corresponding value from the untrusted state_dict without type validation or conversion. This bypasses expected type safety mechanisms.
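The control flow described above can be modeled in a few lines of plain Python. This is a simplified, hypothetical reimplementation for illustration only, not Flax's code: _REGISTRY, register, and from_state_dict_model are made-up names, and the real dict handler performs additional key checks that are omitted here.

```python
from typing import Any, Callable, Dict, Tuple

# Hypothetical stand-in for Flax's private registry:
# maps a type to its (to_state_dict, from_state_dict) callables.
_REGISTRY: Dict[type, Tuple[Callable, Callable]] = {}


def register(ty: type, to_fn: Callable, from_fn: Callable) -> None:
    _REGISTRY[ty] = (to_fn, from_fn)


def from_state_dict_model(target: Any, state: Any) -> Any:
    ty = type(target)
    if ty not in _REGISTRY:
        # Unregistered leaf: the untrusted value is returned unchanged,
        # with no check that it resembles `target`.
        return state
    _, from_fn = _REGISTRY[ty]
    return from_fn(target, state)


def _dict_from_state(target: dict, state: dict) -> dict:
    # Recurse into every key the template defines.
    return {k: from_state_dict_model(v, state[k]) for k, v in target.items()}


register(dict, lambda d: d, _dict_from_state)


class Secret:
    pass


restored = from_state_dict_model({"key": Secret()}, {"key": "INJECTED_MALICIOUS_STRING"})
print(type(restored["key"]).__name__)  # str
```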

Impact: This vulnerability allows Arbitrary Type Injection. An attacker who controls the state_dict can place arbitrary Python objects (e.g., strings, integers, or even instances of custom classes, if the deserialization path can construct them) at any leaf whose template type is unregistered, including positions where the application expects a value of a specific type, such as a JAX array leaf. While this does not directly lead to ACE, the type confusion can be chained with other vulnerabilities or with application logic that makes assumptions about the type of the data being processed. For instance, if a downstream function expects a numerical array but receives a string, the result may be unexpected behavior, a crash, or further exploitation.

Reproduction Steps (PoC: exploit_partial_v2.py):

  1. Prepare the Environment: Ensure flax, jax, orbax-checkpoint, msgpack, and numpy are installed.
  2. Define an Unregistered Type: Create a simple Python class that is not registered with Flax's serialization mechanism.
    class Secret: pass
    
  3. Create a Target Template: Define a PyTree (e.g., a dictionary) where one of the leaves is an instance of the unregistered type.
    target = {'key': Secret()}
    
  4. Craft a Malicious Payload: Create a state_dict where the value corresponding to the unregistered type's key is an arbitrary, attacker-controlled value (e.g., a string).
    payload = {'key': 'INJECTED_MALICIOUS_STRING'}
    
  5. Trigger Insecure Reconstruction: Call flax.serialization.from_state_dict with the target template and the malicious payload.
    import flax.serialization

    restored = flax.serialization.from_state_dict(target, payload)
    print(f"[+] Restored with injection: {restored}")
    
  6. Observe Type Confusion: The restored['key'] will now contain 'INJECTED_MALICIOUS_STRING' instead of an instance of Secret, demonstrating successful type injection.

Mitigation: Implement strict type checking within flax.serialization.from_state_dict for all leaf nodes, regardless of whether their type is explicitly registered. If an unregistered type is encountered, it should either raise an error or be handled with a default safe deserialization mechanism, preventing arbitrary type injection.
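One way to apply this mitigation without patching Flax is to validate the untrusted state dict against the template before calling from_state_dict. The sketch below is a hypothetical helper (check_leaf_types is not a Flax API). It walks nested dicts only and requires each leaf to be an instance of the template leaf's type. This is deliberately strict: an int will not pass for a float leaf, and a NumPy array will not pass for a JAX array leaf, so real code would need per-type rules and handling for other containers.

```python
from typing import Any


def check_leaf_types(target: Any, state: Any, path: str = "") -> None:
    """Raise TypeError if `state` does not match the dict structure and leaf types of `target`."""
    where = path or "<root>"
    if isinstance(target, dict):
        if not isinstance(state, dict):
            raise TypeError(f"{where}: expected dict, got {type(state).__name__}")
        if set(target) != set(state):
            diff = sorted(map(str, set(target) ^ set(state)))
            raise TypeError(f"{where}: key mismatch {diff}")
        for k, v in target.items():
            check_leaf_types(v, state[k], f"{path}/{k}")
    elif not isinstance(state, type(target)):
        # Leaf: the untrusted value must be an instance of the template's type.
        raise TypeError(f"{where}: expected {type(target).__name__}, got {type(state).__name__}")


class Secret:
    pass


check_leaf_types({"key": Secret()}, {"key": Secret()})  # passes silently
try:
    check_leaf_types({"key": Secret()}, {"key": "INJECTED_MALICIOUS_STRING"})
except TypeError as e:
    print(e)  # /key: expected Secret, got str
```

If the check passes, the same template and state dict can then be handed to flax.serialization.from_state_dict.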
