EMA
- class flax.nnx.EMA(self, params, decay, *, only=Ellipsis, graph=None)
Exponential Moving Average (EMA) of parameters.
Maintains a shadow copy of model Variables that is updated as an exponentially weighted moving average on each call to update(). This is commonly used to stabilize training and to improve evaluation performance by using the averaged parameters at inference time.
Example usage:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> import optax
...
>>> model = nnx.Linear(2, 2, rngs=nnx.Rngs(0))
>>> optimizer = nnx.Optimizer(model, optax.sgd(0.1), wrt=nnx.Param)
>>> ema = nnx.EMA(model, decay=0.9)
>>> ema_model = ema.apply_to(model)
...
>>> def loss_fn(model, x, y):
...     return jnp.mean((model(x) - y) ** 2)
...
>>> @nnx.jit
... def train_step(model, optimizer, ema, x, y):
...     grads = nnx.grad(loss_fn)(model, x, y)
...     optimizer.update(model, grads)
...     ema.update(model)
...
>>> @nnx.jit
... def eval_step(model, x, y):
...     return loss_fn(model, x, y)
...
>>> x, y = jnp.ones((1, 2)), jnp.ones((1, 2))
>>> train_step(model, optimizer, ema, x, y)
>>> loss = eval_step(ema_model, x, y)
In this example, ema.update computes the moving average and updates the internal state of ema. ema.apply_to creates a new model instance (ema_model) that shares its Variables with ema. Therefore, ema_model automatically reflects the updates performed by ema.update and can be used directly in eval_step.
- decay
The decay rate for the exponential moving average.
- filter
The filter used to select which variables to track.
- params
A pytree of variables holding the current moving average values.
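The same pattern as the doctest above can be sketched in plain JAX, which makes the roles of decay and the shadow params pytree explicit. This is a minimal sketch, not how nnx.EMA is implemented: the dict-based parameters, the hypothetical helpers linear and train_step, and the hand-written SGD step (standing in for nnx.Optimizer with optax.sgd) are all assumptions for illustration, and the shadow copy is assumed to start equal to the model's parameters. Every leaf is tracked, which corresponds to the default only=Ellipsis.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # matches optax.sgd(0.1) in the doctest
DECAY = 0.9          # matches nnx.EMA(model, decay=0.9) in the doctest


# Hypothetical plain-pytree stand-in for nnx.Linear(2, 2): y = x @ kernel + bias.
def linear(params, x):
    return x @ params["kernel"] + params["bias"]


def loss_fn(params, x, y):
    return jnp.mean((linear(params, x) - y) ** 2)


@jax.jit
def train_step(params, ema_params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    # Plain SGD step on the "live" parameters.
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    # EMA step on the shadow copy, blending in the freshly updated parameters.
    ema_params = jax.tree_util.tree_map(
        lambda e, p: DECAY * e + (1.0 - DECAY) * p, ema_params, params
    )
    return params, ema_params


key = jax.random.PRNGKey(0)
params = {"kernel": jax.random.normal(key, (2, 2)), "bias": jnp.zeros((2,))}
# Assumption: the shadow copy starts as a copy of the current parameters.
ema_params = jax.tree_util.tree_map(jnp.copy, params)

x, y = jnp.ones((1, 2)), jnp.ones((1, 2))
for _ in range(5):
    params, ema_params = train_step(params, ema_params, x, y)

# Evaluate with the averaged parameters, analogous to eval_step(ema_model, ...).
ema_loss = loss_fn(ema_params, x, y)
```

Because the shadow pytree is a separate value that is only read at evaluation time, the training update itself is unaffected by the averaging.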
- __init__(params, decay, *, only=Ellipsis, graph=None)
Initializes the EMA module.
- Parameters:
params – Any object, typically an NNX module/node, whose parameters will be tracked.
decay – The decay rate for the moving average.
only – A filter indicating which variables should be included in the EMA tracking. Defaults to matching everything. Note that EMA only tracks nnx.Variable instances.
graph – If True, uses graph mode, which supports the full NNX feature set including shared references. If False, uses tree mode, which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. If None (default), the value is determined by the current nnx.set_graph_mode context.
- update(updates)
Updates the EMA parameters towards the given new parameters.
The update rule for each parameter is:
ema = decay * ema + (1 - decay) * update
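Unrolling this rule shows how much weight each past update receives. Writing \(\beta\) for decay, \(e_t\) for the EMA value after the \(t\)-th call, and \(u_t\) for the update passed to that call (notation introduced here for illustration):

```latex
e_t = \beta\, e_{t-1} + (1 - \beta)\, u_t
    = \beta^{t} e_0 + (1 - \beta) \sum_{k=1}^{t} \beta^{\,t-k} u_k .
% The weights on u_1, ..., u_t sum to (1 - \beta)(1 - \beta^t)/(1 - \beta) = 1 - \beta^t,
% so for a constant update u_k = u:
e_t = \beta^{t} e_0 + (1 - \beta^{t})\, u .
```

Since update \(u_k\) is weighted by \((1 - \beta)\beta^{t-k}\), the average is dominated by roughly the last \(1/(1 - \beta)\) updates: about 10 for decay=0.9 and about 1000 for decay=0.999.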
- Parameters:
updates – The new parameters or module to blend into the current EMA. This should have the same structure as the params object passed during initialization.
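A minimal sketch of this update rule applied leaf-by-leaf over a pytree, written in plain JAX rather than NNX. ema_update is a hypothetical helper, not part of Flax; its structure check mirrors the requirement that updates match the structure of the tracked params.

```python
import jax
import jax.numpy as jnp


def ema_update(ema_params, new_params, decay):
    """Return decay * ema + (1 - decay) * new for every leaf of the pytrees."""
    # Mirror the documented requirement: both pytrees must share one structure.
    if jax.tree_util.tree_structure(ema_params) != jax.tree_util.tree_structure(new_params):
        raise ValueError("ema_params and new_params must have the same pytree structure")
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, new_params
    )


ema = {"kernel": jnp.zeros((2, 2)), "bias": jnp.zeros((2,))}
new = {"kernel": jnp.ones((2, 2)), "bias": jnp.ones((2,))}

# One step with decay=0.9 moves every leaf 10% of the way from 0 toward 1.
ema = ema_update(ema, new, decay=0.9)
```

Repeating the call with the same new values moves each leaf toward 1 geometrically, matching the constant-update case of the rule.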