I ran into a subtle leaked-tracer bug while trying to vmap my training function across different learning rate values. The training function contains the end-to-end training process for classic PPO, and we use optax for optimization. In the original code, we created a linearly decaying learning rate schedule and passed it into a TrainState (Flax's TrainState, which wraps the optax optimizer). That TrainState was initialized outside of the training function and then passed to it as an argument.
In order to vmap across learning rates, I had to move the TrainState initialization into the training function so that we could create a different schedule for each learning rate value. When I ran the code, I got this error:
Exception: Leaked trace BatchTrace
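For context, the restructured training function looked roughly like the sketch below. This is a minimal, hypothetical reconstruction rather than the real code: it assumes Flax's train_state.TrainState and optax.linear_schedule, and the parameters, apply_fn, and PPO update loop are placeholders. The key point is that the schedule (and therefore the optimizer held by the TrainState) is now created inside the function being vmapped.

import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

def train_fn(lr):
    # Each vmapped learning rate gets its own decaying schedule.
    schedule = optax.linear_schedule(
        init_value=lr, end_value=0.0, transition_steps=1_000
    )
    tx = optax.adam(learning_rate=schedule)
    params = {"w": jnp.zeros(3)}  # placeholder parameters
    state = train_state.TrainState.create(
        apply_fn=lambda params, x: x,  # placeholder apply_fn
        params=params,
        tx=tx,
    )
    # ... the PPO rollout and update loop would go here ...
    return state

learning_rates = jnp.array([1e-4, 3e-4, 1e-3])
with jax.checking_leaks():
    # The returned TrainState holds tx, whose schedule closed over lr,
    # so the leak checker reports a leaked BatchTrace here.
    jax.vmap(train_fn)(learning_rates)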
Here is a self-contained snippet that reproduces the error. Instead of using Flax's TrainState, we register a small custom class as a pytree node to mimic its behavior.
import jax
import jax.numpy as jnp
from jax import tree_util

class TrainState:
    def __init__(self, fn):
        self.fn = fn

# Register TrainState as a pytree with no children: the whole object is
# stuffed into aux_data, mimicking how the real TrainState treats the
# optimizer (and its schedule) as static, non-array data.
tree_util.register_pytree_node(
    TrainState,
    lambda container: ((), container),
    lambda aux_data, children: aux_data,
)

def train_fn(lr: jax.Array):
    def learning_schedule():
        return lr * 10.0  # closes over lr
    return TrainState(fn=learning_schedule)

def train_fn_b(lr: jax.Array):
    return lr * 10.0  # returns an array directly, no closure involved

# An array of values to map over
input_values = jnp.array([1.0, 5.0, 10.0])

# vmap the creator function
vmapped_creator = jax.vmap(train_fn)

with jax.checking_leaks():
    # Because TrainState is registered as a pytree with no children, vmap
    # doesn't reject the output with a TypeError; instead, JAX's leak
    # detection finds the tracer hidden inside the container and raises
    # the error below.
    vmapped_creator(input_values)
Exception: Leaked trace BatchTrace. Leaked tracer(s):
Traced<ShapedArray(float32[])>with<BatchTrace> with
val = Array([ 1., 5., 10.], dtype=float32)
batch_dim = 0
This BatchTracer with object id 4816883200 was created on line:
/var/folders/q4/2lsmb6qd1ks8137720rg8fz80000gn/T/ipykernel_27631/527263614.py:34:4 (<module>)
<BatchTracer 4816883200> is referred to by <function 4822615680> (learning_schedule) closed-over variable lr
<function 4822615680> is referred to by <TrainState 4822757392>.fn
When we vmap train_fn, JAX passes a BatchTracer through train_fn during the tracing stage. Within train_fn, lr (at this point a BatchTracer) is captured in the closure of learning_schedule. When we initialize TrainState with learning_schedule, we are effectively “storing” a BatchTracer inside the TrainState object. That would be fine if the TrainState stayed inside the function, but we return it, so the BatchTracer outlives the trace that created it: it is leaked.
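To make the contrast concrete, here is a hypothetical variant (LeafState is my own throwaway name, not part of the original code) where the per-example value is stored as a pytree child rather than hidden in a closure. Because JAX can see the leaf, vmap unbatches it and nothing leaks:

import jax
import jax.numpy as jnp
from jax import tree_util

class LeafState:
    def __init__(self, lr):
        self.lr = lr

# lr is registered as a *child* (a leaf JAX can see), not as aux_data.
tree_util.register_pytree_node(
    LeafState,
    lambda s: ((s.lr,), None),
    lambda aux_data, children: LeafState(children[0]),
)

def make_state(lr: jax.Array):
    return LeafState(lr * 10.0)

out = jax.vmap(make_state)(jnp.array([1.0, 5.0, 10.0]))
out.lr  # Array([ 10.,  50., 100.], dtype=float32)

The difference is purely about where the value lives: as a pytree leaf, JAX can convert the Tracer back into a concrete batched array; inside a closure, it is invisible to the machinery that does that conversion.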
“But tracing a function will return Tracers anyway!” you might say. Yes, the outputs of a traced function are Tracers, but those are Tracers produced by operations JAX understands, like * 5.0 or + 1.0 applied to the input Tracer, and JAX knows how to turn those outputs back into concrete arrays once tracing finishes. Stashing a Tracer inside an arbitrary Python object or closure is not one of those supported operations: JAX never sees it in the output pytree, so it can never convert it back.
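The unused train_fn_b in the snippet above is exactly that supported path: its output is a Tracer produced directly from the input, which vmap knows how to turn back into a batched array.

# Using the definitions from the snippet above:
jax.vmap(train_fn_b)(input_values)
# Array([ 10.,  50., 100.], dtype=float32)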
One gotcha is that we don’t get this error at all if we don’t wrap the call in jax.checking_leaks(). JAX doesn’t run these tracer-leak checks by default (they add overhead), so you have to opt in.
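If you want the check enabled globally instead of per call, the context manager has (as far as I know) a config-flag equivalent, jax_check_tracer_leaks. A quick sketch:

import jax

# Equivalent to wrapping every traced call in jax.checking_leaks():
jax.config.update("jax_check_tracer_leaks", True)

# Alternatively, set the JAX_CHECK_TRACER_LEAKS=1 environment variable
# before starting the process.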