# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from collections import namedtuple
import dataclasses
import functools
import inspect
import typing as tp
from flax import struct
from flax import typing
from flax.nnx import filterlib, graphlib, reprlib, variablelib
from flax.nnx.statelib import State
from flax.nnx.pytreelib import Pytree
from flax.typing import Missing, PathParts
import jax
# Generic type variables used in annotations throughout this module.
A = tp.TypeVar('A')
B = tp.TypeVar('B')
# Type variable constrained to callables.
F = tp.TypeVar('F', bound=tp.Callable)
# Readability aliases; they carry no runtime behavior.
Index = int
KeyPath = tuple[tp.Hashable, ...]
Prefix = tp.Any
Leaf = tp.Any
class OrderedDict(reprlib.MappingReprMixin, dict[A, B]):
  """A ``dict`` subclass that mixes in ``reprlib.MappingReprMixin``.

  The class body adds no behavior of its own; the type is registered as a
  JAX pytree right after its definition.
  """

  pass
def _ordered_dict_flatten_with_keys(d: OrderedDict):
children = [(jax.tree_util.DictKey(k), v) for k, v in d.items()]
return children, tuple(d.keys())
# Register OrderedDict as a JAX pytree with key paths. Flattening yields the
# values keyed by DictKey plus the tuple of keys as auxiliary data; unflattening
# zips the keys back together with the (possibly transformed) values.
jax.tree_util.register_pytree_with_keys(
  OrderedDict,
  _ordered_dict_flatten_with_keys,
  lambda keys, values: OrderedDict(zip(keys, values)),
)
# Cache of generated tuple classes, keyed by the ordered field names.
_labeled_tuples_cache: dict[tuple[str, ...], type[tp.Any]] = {}


def labeled(**kwargs):
  """Pack keyword arguments into a named tuple that also supports ``t['name']``.

  One tuple class is created per distinct ordered set of field names and then
  reused, so repeated calls with the same keywords return instances of the
  same type.

  Args:
    **kwargs: the fields of the tuple, in the order they should appear.

  Returns:
    A ``LabeledTuple`` instance supporting attribute access, integer/slice
    indexing, string indexing and a ``keys()`` method returning field names.
  """
  field_names = tuple(kwargs)
  tuple_cls = _labeled_tuples_cache.get(field_names)
  if tuple_cls is None:
    base = namedtuple('LabeledTuple', field_names)

    class LabeledTuple(base):
      def keys(self):
        return self._fields

      def __getitem__(self, key):
        # String keys are routed to attribute lookup; everything else keeps
        # the regular tuple indexing semantics.
        if not isinstance(key, str):
          return super().__getitem__(key)
        return getattr(self, key)

    tuple_cls = LabeledTuple
    _labeled_tuples_cache[field_names] = tuple_cls
  return tuple_cls(**kwargs)
class PrefixMapping(abc.ABC):
  """Abstract interface for prefixes that are computed per Variable.

  ``check_consistent_aliasing`` calls ``map_prefix(path, variable)`` on such
  objects instead of using the prefix object itself.
  """

  @abc.abstractmethod
  def map_prefix(
    self,
    path: typing.PathParts,
    variable: variablelib.Variable,
    /,
  ) -> tp.Any: ...
def check_consistent_aliasing(
  node: tp.Any,
  prefix: tp.Any,
  /,
  *,
  node_prefixes: dict[int, list[tuple[PathParts, tp.Any]]] | None = None,
):
  """Check that every Variable reachable from ``node`` has a single prefix.

  Args:
    node: object traversed with ``graphlib.iter_graph(node, graph=True)``.
    prefix: the prefix for ``node``. If it is a ``PrefixMapping`` the prefix of
      each Variable is ``prefix.map_prefix(path, variable)``.
    node_prefixes: accumulator mapping ``id(variable)`` to the list of
      ``(path, prefix)`` pairs seen so far. It is mutated in place, so passing
      the same dict to several calls checks consistency across them.

  Raises:
    ValueError: if a Variable cannot be updated (``_can_update`` is falsy) or
      if any recorded Variable has more than one distinct prefix.
  """
  if node_prefixes is None:
    node_prefixes = {}

  # Kept only so the error message can name the Variable type.
  variables_by_id: dict[int, tp.Any] = {}

  # Phase 1: record (path, prefix) for every Variable occurrence.
  for path, value in graphlib.iter_graph(node, graph=True):
    if not (
      graphlib.is_graph_node(value) or isinstance(value, graphlib.Variable)
    ):
      continue
    if isinstance(value, Pytree):
      value._check_valid_context(
        lambda: f'Trying to extract graph node from different trace level, got {value!r}'
      )
    if not isinstance(value, graphlib.Variable):
      continue
    if not value._can_update:
      raise ValueError(
        f'Cannot extract graph node from different trace level, got {value!r}'
      )
    variable_prefix = (
      prefix.map_prefix(path, value)
      if isinstance(prefix, PrefixMapping)
      else prefix
    )
    value_id = id(value)
    variables_by_id[value_id] = value
    node_prefixes.setdefault(value_id, []).append((path, variable_prefix))

  # Phase 2: report every Variable that was seen with differing prefixes.
  node_msgs = []
  for node_id, occurrences in node_prefixes.items():
    distinct = {pfx for _, pfx in occurrences}
    if len(distinct) <= 1:
      continue
    path_prefix_repr = '\n'.join(
      f' {"/".join(map(str,path)) if path else "<root>"}: {pfx}'
      for path, pfx in occurrences
    )
    if node_id in variables_by_id:
      node_type_name = type(variables_by_id[node_id]).__name__
    else:
      node_type_name = f'Node ID: {node_id}'
    node_msgs.append(f'Node: {node_type_name}\n{path_prefix_repr}')

  if node_msgs:
    raise ValueError(
      'Inconsistent aliasing detected. The following nodes have different prefixes:\n'
      + '\n'.join(node_msgs)
    )
def check_consistent_aliasing2(
  node: tp.Any,
  prefix: tp.Any,
  /,
  *,
  base_path: tuple[tp.Any, ...] = (),
  node_prefixes: dict[int, list[tuple[PathParts, tp.Any]]],
):
  """Record and validate the prefix of every Variable reachable from ``node``.

  Args:
    node: object traversed with ``graphlib.iter_graph(node, graph=True)``.
    prefix: the prefix for ``node``. If it is a ``TreeState`` its wrapped
      ``prefix_fn`` is called with ``(local_path, variable)`` to obtain the
      prefix of each Variable.
    base_path: path prepended to each local path when recording occurrences.
    node_prefixes: accumulator mapping ``id(variable)`` to the recorded
      ``(path, prefix)`` pairs; mutated in place and shared across calls.

  Raises:
    ValueError: if ``prefix`` is a ``TreeState`` without a callable
      ``prefix_fn``, or if any recorded Variable has more than one distinct
      prefix.
  """
  # Kept only so the error message can name the Variable type.
  variables_by_id: dict[int, tp.Any] = {}

  for local_path, value in graphlib.iter_graph(node, graph=True):
    path = base_path + local_path
    if not isinstance(value, variablelib.Variable):
      continue
    value_id = id(value)
    variables_by_id[value_id] = value
    if isinstance(prefix, TreeState):
      # The TreeState carries the function that produced it (wrapped in an
      # Opaque); evaluate it with the path local to this node.
      prefix_fn = prefix.prefix_fn.value
      if not callable(prefix_fn):
        raise ValueError(
          'When passing a TreeState object as a prefix (e.g. for'
          ' `in_axes`), it must have been produced by `nnx.prefix()` or'
          ' contain a callable in `TreeState.metadata` with signature'
          ' `(path: tuple[Any, ...], value: Variable) -> Any`. Got'
          f' metadata of type {type(prefix_fn).__name__}.'
        )
      leaf_prefix = prefix_fn(local_path, value)
    else:
      leaf_prefix = prefix
    node_prefixes.setdefault(value_id, []).append((path, leaf_prefix))

  node_msgs = []
  for node_id, occurrences in node_prefixes.items():
    distinct = {pfx for _, pfx in occurrences}
    if len(distinct) <= 1:
      continue
    path_prefix_repr = '\n'.join(
      f' {"/".join(map(str,path)) if path else "<root>"}: {pfx}'
      for path, pfx in occurrences
    )
    if node_id in variables_by_id:
      node_type_name = type(variables_by_id[node_id]).__name__
    else:
      node_type_name = f'Node ID: {node_id}'
    node_msgs.append(f'Node: {node_type_name}\n{path_prefix_repr}')

  if node_msgs:
    raise ValueError(
      'Inconsistent aliasing detected. The following nodes have different prefixes:\n'
      + '\n'.join(node_msgs)
    )
# -----------------------------
# to_tree/from_tree
# -----------------------------
def broadcast_prefix(
  prefix_tree: tp.Any,
  full_tree: tp.Any,
  prefix_is_leaf: tp.Callable[[tp.Any], bool] | None = None,
  tree_is_leaf: tp.Callable[[tp.Any], bool] | None = None,
) -> list[tp.Any]:
  """Repeat each prefix leaf once per leaf of the matching subtree.

  Args:
    prefix_tree: a tree prefix of ``full_tree``.
    full_tree: the tree whose leaves determine the length of the result.
    prefix_is_leaf: optional extra leaf predicate used while mapping over
      ``prefix_tree``; Variables and graph nodes are always treated as leaves.
    tree_is_leaf: optional leaf predicate used when counting the leaves of
      each subtree of ``full_tree``.

  Returns:
    A flat list with one prefix value per leaf of ``full_tree``.
  """
  # If prefix_tree is not a tree prefix of full_tree, the mapping below can
  # raise a ValueError.
  broadcasted: list[tp.Any] = []

  def _count_leaves(subtree):
    treedef = jax.tree_util.tree_structure(subtree, is_leaf=tree_is_leaf)
    return treedef.num_leaves

  def _emit(prefix_value, subtree):
    broadcasted.extend([prefix_value] * _count_leaves(subtree))

  def _stop_at(x):
    return (
      isinstance(x, variablelib.Variable)
      or graphlib.is_graph_node(x)
      or (prefix_is_leaf is not None and prefix_is_leaf(x))
    )

  jax.tree.map(_emit, prefix_tree, full_tree, is_leaf=_stop_at)
  return broadcasted
def broadcast_prefix2(
prefix_tree: tp.Any,
full_tree: tp.Any,
is_leaf: tp.Callable[[tp.Any], bool] | None = None,
prefix_leaf: tp.Callable[[tp.Any], bool] | None = None,
) -> tuple[list[KeyPath], list[tp.Any]]:
_prefix_leaf: tp.Callable[[tp.Any], bool] | None
if prefix_leaf is not None and is_leaf is not None:
_prefix_leaf = lambda x: prefix_leaf(x) or is_leaf(x)
elif prefix_leaf is not None:
_prefix_leaf = prefix_leaf
else:
_prefix_leaf = is_leaf
paths: list[KeyPath] = []
leaves: list[tp.Any] = []
num_leaves = lambda t: jax.tree.structure(t, is_leaf=is_leaf).num_leaves
def add_leaves(path, x, subtree):
n = num_leaves(subtree)
paths.extend([path] * n)
leaves.extend([x] * n)
jax.tree.map_with_path(add_leaves, prefix_tree, full_tree, is_leaf=_prefix_leaf)
return paths, leaves
def broadcast_prefix_map(
  f: tp.Callable[..., tp.Any],
  prefix_tree: tp.Any,
  full_tree: tp.Any,
  *rest: tp.Any,
  is_leaf: tp.Callable[[tp.Any], bool] | None = None,
  prefix_leaf: tp.Callable[[tp.Any], bool] | None = None,
) -> tp.Any:
  """Map ``f`` over the leaves of ``full_tree`` together with their prefix.

  ``f`` is called as ``f(path, prefix_leaf, full_leaf, *rest_leaves)`` once per
  leaf of ``full_tree``, in flattening order.

  Args:
    f: the function to apply.
    prefix_tree: a tree prefix of ``full_tree``.
    full_tree: the tree that defines the output structure.
    *rest: additional trees flattened up to the structure of ``full_tree``.
    is_leaf: leaf predicate for ``full_tree``.
    prefix_leaf: additional leaf predicate for ``prefix_tree``.

  Returns:
    A tree with the structure of ``full_tree`` holding the results of ``f``.
  """
  _, prefix_leaves = broadcast_prefix2(
    prefix_tree, full_tree, is_leaf=is_leaf, prefix_leaf=prefix_leaf
  )
  keyed_leaves, treedef = jax.tree.flatten_with_path(full_tree, is_leaf=is_leaf)
  rest_leaves = [treedef.flatten_up_to(other) for other in rest]
  mapped = [
    f(path, p_leaf, full_leaf, *extras)
    for (path, full_leaf), p_leaf, *extras in zip(
      keyed_leaves, prefix_leaves, *rest_leaves
    )
  ]
  return jax.tree.unflatten(treedef, mapped)
class GraphDefState(struct.PyTreeNode):
  """Pairs a ``GraphDef`` with a ``State``.

  ``graphdef`` is declared as a non-pytree field and ``state`` as a pytree
  field.
  """

  graphdef: graphlib.GraphDef[tp.Any] = struct.field(pytree_node=False)
  state: State = struct.field(pytree_node=True)
# Type variable for the state containers: a State, a GraphFlatState or a list.
S = tp.TypeVar(
  'S', bound=State | graphlib.GraphFlatState | list[tp.Any]
)
class NodeStates(struct.PyTreeNode):
  """Pytree holding an optional ``GraphDef`` and one or more states.

  ``metadata`` is declared as a non-pytree field; ``_graphdef`` and ``states``
  are regular fields.
  """

  _graphdef: graphlib.GraphDef[tp.Any] | None
  states: tuple[tp.Any, ...]
  metadata: tp.Any = struct.field(pytree_node=False)

  @property
  def graphdef(self) -> graphlib.GraphDef[tp.Any]:
    """The stored ``GraphDef``; raises ``ValueError`` when there is none."""
    stored = self._graphdef
    if stored is None:
      raise ValueError('No graphdef available')
    return stored

  @property
  def state(self) -> tp.Any:
    """The single stored state; raises ``ValueError`` unless exactly one."""
    if len(self.states) == 1:
      return self.states[0]
    raise ValueError(
      f'Expected exactly one GraphDefState, got {len(self.states)}'
    )

  @classmethod
  def from_split(
    cls,
    graphdef: graphlib.GraphDef[tp.Any] | None,
    state: tp.Any,
    /,
    *states: tp.Any,
    metadata: tp.Any = None,
  ):
    """Build from a ``graphdef`` and at least one state."""
    all_states = (state,) + states
    return cls(_graphdef=graphdef, states=all_states, metadata=metadata)

  @classmethod
  def from_states(
    cls,
    state: tp.Any,
    *states: tp.Any,
  ):
    """Build from at least one state, with no graphdef and no metadata."""
    all_states = (state,) + states
    return cls(_graphdef=None, states=all_states, metadata=None)

  @classmethod
  def from_prefixes(
    cls,
    prefixes: tp.Iterable[tp.Any],
    /,
    *,
    metadata: tp.Any = None,
  ):
    """Build from an iterable of prefixes stored in the ``states`` slot."""
    prefix_states = tuple(prefixes)
    return cls(_graphdef=None, states=prefix_states, metadata=metadata)
def default_split_fn(
  ctx: graphlib.SplitContext, path: KeyPath, prefix: Prefix, leaf: Leaf
) -> tp.Any:
  """Default ``split_fn`` of ``to_tree``: split ``leaf`` into a ``NodeStates``.

  ``path`` and ``prefix`` are accepted for signature compatibility and unused.
  """
  split_result = ctx.split(leaf)
  return NodeStates.from_split(*split_result)
def to_tree(
  tree,
  /,
  *,
  prefix: tp.Any = Missing,
  split_fn: tp.Callable[
    [graphlib.SplitContext, KeyPath, Prefix, Leaf], tp.Any
  ] = default_split_fn,
  map_non_graph_nodes: bool = False,
  ctxtag: tp.Hashable | None = None,
  check_aliasing: bool = True,
) -> tp.Any:
  """Replace graph nodes and Variables in ``tree`` using ``split_fn``.

  Args:
    tree: a pytree that may contain graph nodes and Variables; these are
      treated as leaves.
    prefix: optional tree prefix of ``tree``. When ``Missing`` or ``None`` a
      fast path is used that skips prefix broadcasting and aliasing checks.
    split_fn: called as ``split_fn(split_ctx, path, prefix, leaf)`` for each
      graph node / Variable leaf (with an empty path on the fast path).
    map_non_graph_nodes: if ``True``, ``split_fn`` is applied to every other
      leaf as well.
    ctxtag: tag forwarded to ``graphlib.split_context``.
    check_aliasing: if ``True``, ``check_consistent_aliasing`` is run for each
      graph node / Variable leaf (slow path only).

  Returns:
    A pytree with the same structure as ``tree`` and the converted leaves.
  """
  if prefix is Missing or prefix is None:
    # fast path, no need for prefix broadcasting or consistent aliasing checks
    with graphlib.split_context(ctxtag) as split_ctx:
      return jax.tree.map(
        lambda x: split_fn(split_ctx, (), prefix, x)
        if map_non_graph_nodes
        or graphlib.is_graph_node(x)
        or isinstance(x, variablelib.Variable)
        else x,
        tree,
        is_leaf=lambda x: isinstance(x, variablelib.Variable)
        or graphlib.is_graph_node(x),
      )
  # One prefix value per leaf of `tree`, in flattening order.
  leaf_prefixes = broadcast_prefix(
    prefix,
    tree,
    prefix_is_leaf=lambda x: x is None
    or isinstance(x, variablelib.Variable)
    or graphlib.is_graph_node(x),
    tree_is_leaf=lambda x: isinstance(x, variablelib.Variable)
    or graphlib.is_graph_node(x),
  )
  leaf_keys, treedef = jax.tree_util.tree_flatten_with_path(
    tree,
    is_leaf=lambda x: isinstance(x, variablelib.Variable)
    or graphlib.is_graph_node(x),
  )
  assert len(leaf_keys) == len(leaf_prefixes)
  leaves_out = []
  # Shared across all leaves so aliasing is checked for the whole tree.
  node_prefixes: dict[int, list[tuple[PathParts, tp.Any]]] = {}
  with graphlib.split_context(ctxtag) as split_ctx:
    for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
      if graphlib.is_graph_node(leaf) or isinstance(leaf, variablelib.Variable):
        if check_aliasing:
          check_consistent_aliasing(
            leaf, leaf_prefix, node_prefixes=node_prefixes
          )
        tree_node = split_fn(split_ctx, keypath, leaf_prefix, leaf)
        leaves_out.append(tree_node)
      else:
        if map_non_graph_nodes:
          leaf = split_fn(split_ctx, keypath, leaf_prefix, leaf)
        leaves_out.append(leaf)
  pytree_out = jax.tree.unflatten(treedef, leaves_out)
  return pytree_out
@dataclasses.dataclass(frozen=True, slots=True)
class Opaque(tp.Generic[A]):
value: A
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return 0
# Registered as a JAX dataclass pytree: `__state__` is the only data field,
# `graphdef` and `prefix_fn` are metadata.
@functools.partial(
  jax.tree_util.register_dataclass,
  data_fields=['__state__'],
  meta_fields=['graphdef', 'prefix_fn'],
)
@dataclasses.dataclass(frozen=True, slots=True)
class TreeState:
  """Tree representation of a graph node: a ``GraphDef`` plus its state.

  ``prefix_fn`` optionally carries, wrapped in ``Opaque``, the function used
  to compute prefixes for the Variables of this node.
  """

  graphdef: graphlib.GraphDef[tp.Any] | None
  __state__: tp.Any
  prefix_fn: Opaque[tp.Callable[[PathParts, tp.Any], tp.Any] | None] = Opaque(
    None
  )

  @property
  def state(self) -> tp.Any:
    """The value stored in the ``__state__`` field."""
    return self.__state__

  def replace(self, **kwargs):
    """Return a copy with the given fields replaced."""
    return dataclasses.replace(self, **kwargs)
def to_tree2(
  tree,
  /,
  *,
  prefix: tp.Any = Missing,
  check_aliasing: bool = True,
  prefix_fn: tp.Callable[[PathParts, tp.Any], tp.Any] | None = None,
) -> tp.Any:
  """Convert graph nodes and Variables in ``tree`` to ``TreeState`` objects.

  ``to_tree2`` has two main tasks:

  1. Convert all graph nodes to ``TreeState`` (a tree representation).
  2. Check all Variables are aliased consistently given the prefix tree,
     e.g. vmap's in/out_axes arguments.

  Each ``TreeState`` contains the ``GraphDef`` and state of one object,
  generated with ``graphlib.flatten``. ``broadcast_prefix`` is used to
  calculate the prefix for each node and ``check_consistent_aliasing2``
  traverses each node's subgraph and checks for Variable aliasing.

  Args:
    tree: a pytree that may contain graph nodes and Variables.
    prefix: optional tree prefix of ``tree``. When ``Missing`` or ``None`` no
      prefix broadcasting or aliasing check is performed.
    check_aliasing: whether to run ``check_consistent_aliasing2`` per node.
    prefix_fn: stored (wrapped in ``Opaque``) on every produced ``TreeState``.

  Returns:
    A pytree with the structure of ``tree`` where graph nodes and Variables
    are replaced by ``TreeState`` objects.
  """
  # Shared by all `flatten` calls below, so nodes are flattened jointly in the
  # order they are visited.
  ref_index: graphlib.RefMap = graphlib.RefMap()
  def _to_node_states(leaf):
    if not (graphlib.is_graph_node(leaf) or isinstance(leaf, variablelib.Variable)):
      return leaf
    graphdef, flat_state = graphlib.flatten(
      leaf, ref_index=ref_index, graph=True
    )
    (state,) = graphlib._to_nested_state(graphdef, (flat_state,))
    return TreeState(graphdef, state, prefix_fn=Opaque(prefix_fn))
  is_leaf = lambda x: (
    isinstance(x, variablelib.Variable) or graphlib.is_graph_node(x)
  )
  if prefix is Missing or prefix is None:
    return jax.tree.map(_to_node_states, tree, is_leaf=is_leaf)
  # One prefix value per leaf of `tree`, in flattening order.
  leaf_prefixes = broadcast_prefix(
    prefix,
    tree,
    prefix_is_leaf=lambda x: x is None
    or isinstance(x, TreeState)
    or is_leaf(x),
    tree_is_leaf=is_leaf,
  )
  leaf_paths, treedef = jax.tree_util.tree_flatten_with_path(tree, is_leaf=is_leaf)
  assert len(leaf_paths) == len(leaf_prefixes)
  leaves_out = []
  # Shared across all leaves so aliasing is checked for the whole tree.
  node_prefixes: dict[int, list[tuple[PathParts, tp.Any]]] = {}
  for (keypath, leaf), leaf_prefix in zip(leaf_paths, leaf_prefixes):
    if is_leaf(leaf):
      if check_aliasing:
        base_path = graphlib.jax_to_nnx_path(keypath)
        check_consistent_aliasing2(
          leaf, leaf_prefix, base_path=base_path, node_prefixes=node_prefixes
        )
      leaves_out.append(_to_node_states(leaf))
    else:
      leaves_out.append(leaf)
  return jax.tree.unflatten(treedef, leaves_out)
def from_tree2(tree: tp.Any, /, recreate_variables: bool = True) -> tp.Any:
  """Rebuild graph nodes from the ``TreeState`` leaves of ``tree``.

  Args:
    tree: a pytree possibly containing ``TreeState`` objects.
    recreate_variables: forwarded to ``graphlib.unflatten``.

  Returns:
    A pytree with the same structure where each ``TreeState`` is replaced by
    the result of ``graphlib.unflatten``; other leaves are returned unchanged.
  """
  # Shared by every unflatten call in this traversal.
  index_ref = graphlib.IndexMap()

  def _stop_at(x):
    return (
      isinstance(x, TreeState)
      or graphlib.is_graph_node(x)
      or isinstance(x, variablelib.Variable)
    )

  def _restore(x):
    if isinstance(x, TreeState):
      flat_state = graphlib._merge_to_flat_state((x.state,))
      return graphlib.unflatten(
        x.graphdef,
        flat_state,
        index_ref=index_ref,
        recreate_variables=recreate_variables,
      )
    return x

  return jax.tree.map(_restore, tree, is_leaf=_stop_at)
def merge_tree_node(
  ctx: graphlib.MergeContext, path: KeyPath, prefix: Prefix, leaf: Leaf
) -> tp.Any:
  """Default ``merge_fn`` of ``from_tree``: merge a ``NodeStates`` leaf.

  Raises:
    ValueError: if ``leaf`` is not a ``NodeStates``.
  """
  if isinstance(leaf, NodeStates):
    return ctx.merge(leaf.graphdef, *leaf.states)
  raise ValueError(f'Expected TreeNode, got {type(leaf)} at path {path}')
def is_tree_node(x):
  """Return ``True`` if ``x`` is a ``NodeStates`` instance."""
  return isinstance(x, NodeStates)
def from_tree(
  tree: tp.Any,
  /,
  *,
  prefix: tp.Any = Missing,
  merge_fn: tp.Callable[
    [graphlib.MergeContext, KeyPath, Prefix, Leaf], tp.Any
  ] = merge_tree_node,
  is_node_leaf: tp.Callable[[Leaf], bool] = is_tree_node,
  is_leaf: tp.Callable[[Leaf], bool] = is_tree_node,
  map_non_graph_nodes: bool = False,
  is_inner: bool | None = None,
  ctxtag: tp.Hashable | None = None,
) -> tp.Any:
  """Rebuild objects in ``tree`` by applying ``merge_fn`` to its node leaves.

  Args:
    tree: the pytree to convert.
    prefix: optional tree prefix of ``tree``. When ``Missing`` or ``None`` a
      fast path is used that skips prefix broadcasting.
    merge_fn: called as ``merge_fn(merge_ctx, path, prefix, leaf)`` on each
      selected leaf (with an empty path on the fast path).
    is_node_leaf: selects the leaves ``merge_fn`` is applied to; Variables are
      always selected.
    is_leaf: leaf predicate used when traversing/flattening ``tree``.
    map_non_graph_nodes: if ``True``, ``merge_fn`` is applied to every leaf.
    is_inner: forwarded to ``graphlib.merge_context``.
    ctxtag: forwarded to ``graphlib.merge_context``.

  Returns:
    A pytree with the same structure as ``tree`` and the merged leaves.
  """
  if prefix is Missing or prefix is None:
    # fast path, no need for prefix broadcasting or consistent aliasing checks
    with graphlib.merge_context(ctxtag, is_inner) as merge_ctx:
      # NOTE: despite its name this helper merges; it does not split.
      def maybe_split(x):
        if (
          map_non_graph_nodes
          or is_node_leaf(x)
          or isinstance(x, variablelib.Variable)
        ):
          return merge_fn(merge_ctx, (), prefix, x)
        return x
      return jax.tree.map(maybe_split, tree, is_leaf=is_leaf)
  # One prefix value per leaf of `tree`, in flattening order.
  leaf_prefixes = broadcast_prefix(
    prefix,
    tree,
    prefix_is_leaf=lambda x: x is None or is_leaf(x),
    tree_is_leaf=is_leaf,
  )
  leaf_keys, treedef = jax.tree_util.tree_flatten_with_path(
    tree, is_leaf=is_leaf
  )
  assert len(leaf_keys) == len(leaf_prefixes)
  leaves_out = []
  with graphlib.merge_context(ctxtag, is_inner) as merge_ctx:
    for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
      if (
        map_non_graph_nodes
        or is_node_leaf(leaf)
        or isinstance(leaf, variablelib.Variable)
      ):
        leaf = merge_fn(merge_ctx, keypath, leaf_prefix, leaf)
      leaves_out.append(leaf)
  pytree_out = jax.tree.unflatten(treedef, leaves_out)
  return pytree_out
def clear_non_graph_nodes(tree):
  """Return a copy of ``tree`` keeping only graph nodes and Variables.

  Every other leaf is replaced by ``None``; graph nodes and Variables are
  treated as leaves and returned as-is.
  """

  def _is_graph_leaf(x):
    return isinstance(x, variablelib.Variable) or graphlib.is_graph_node(x)

  def _keep_or_clear(x):
    if graphlib.is_graph_node(x) or isinstance(x, variablelib.Variable):
      return x
    return None

  return jax.tree.map(_keep_or_clear, tree, is_leaf=_is_graph_leaf)
class Mask(tp.NamedTuple):
pass
def mask_at(t: tuple, index: int | None) -> tuple[tp.Any, tuple]:
if index is None or not isinstance(t, tuple):
return None, t
x = t[index]
new_t = tuple(
Mask() if i == index else x
for i, x in enumerate(t)
)
return x, new_t
def slice_at(t: tuple, index: int | None) -> tuple[tp.Any, tuple]:
if index is None:
return None, t
return t[index], t[:index] + t[index + 1 :]
def replace_at(t: tuple, index: int | None, value: tp.Any) -> tuple:
if index is None:
return t
xs = list(t)
xs[index] = value
return tuple(xs)
def find(t: tuple, value: tp.Any) -> int | None:
if not isinstance(t, tuple):
return None
return next((i for i, x in enumerate(t) if x == value), None)
@jax.tree_util.register_static
@dataclasses.dataclass(frozen=True, slots=True)
class ExtractIndex:
  """Placeholder left in a tree by ``extract``.

  ``index`` is the position of the removed leaf in the extracted list.
  """

  index: int
def extract(
  f: tp.Callable[[jax.tree_util.KeyPath, tp.Any, tp.Any], bool],
  prefix: tp.Any,
  tree: tp.Any,
  *,
  is_leaf: tp.Callable[[tp.Any], bool] | None = None,
  prefix_leaf: tp.Callable[[tp.Any], bool] | None = None,
) -> tuple[tp.Any, list[tp.Any]]:
  """Pull the leaves selected by ``f`` out of ``tree``.

  Args:
    f: predicate called as ``f(path, prefix_leaf, leaf)``; leaves for which it
      returns a truthy value are extracted.
    prefix: tree prefix of ``tree`` broadcast to its leaves.
    tree: the tree to extract from.
    is_leaf: leaf predicate for ``tree``.
    prefix_leaf: additional leaf predicate for ``prefix``.

  Returns:
    ``(new_tree, extracted)`` where every extracted leaf in ``new_tree`` is
    replaced by ``ExtractIndex(i)`` and ``extracted[i]`` is the original leaf.
  """
  extracted: list[tp.Any] = []

  def _maybe_extract(path, prefix_value, leaf):
    if not f(path, prefix_value, leaf):
      return leaf
    extracted.append(leaf)
    # The placeholder records the slot the leaf was just stored in.
    return ExtractIndex(len(extracted) - 1)

  new_tree = broadcast_prefix_map(
    _maybe_extract, prefix, tree, is_leaf=is_leaf, prefix_leaf=prefix_leaf
  )
  return new_tree, extracted
def insert(tree: tp.Any, extracted: list[tp.Any]) -> tp.Any:
  """Replace each ``ExtractIndex`` in ``tree`` with ``extracted[index]``.

  All other leaves are returned unchanged.
  """

  def _is_placeholder(x):
    return isinstance(x, ExtractIndex)

  def _restore(leaf):
    return extracted[leaf.index] if _is_placeholder(leaf) else leaf

  return jax.tree.map(_restore, tree, is_leaf=_is_placeholder)
def snapshot(args: A) -> tuple[A, A]:
  """Return two structural copies of ``args``.

  The first copy is mapped with Variables treated as leaves, so it keeps the
  original Variable objects. The second is mapped without that leaf predicate.

  Returns:
    ``(current, snapshot)``.
  """
  identity = lambda x: x
  current = jax.tree.map(
    identity, args, is_leaf=lambda x: isinstance(x, variablelib.Variable)
  )
  frozen = jax.tree.map(identity, args)
  return current, frozen
def copy_var_structure(tree: A) -> A:
  """Rebuild the containers of ``tree`` while keeping its Variables as-is.

  Variables are treated as leaves and passed through the identity function.
  """

  def _is_variable(x):
    return isinstance(x, variablelib.Variable)

  def _identity(x):
    return x

  return jax.tree.map(_identity, tree, is_leaf=_is_variable)
def check_no_aliases(
  fn_name: str, /, *, check: tp.Iterable[str] = (), **kwargs
) -> dict[jax.tree_util.KeyPath, variablelib.Variable]:
  """Collect all Variables in ``kwargs`` and reject duplicates.

  Args:
    fn_name: name of the transform, used in error messages.
    check: names of the keyword arguments whose Variables must additionally
      be updatable (``_can_update`` truthy).
    **kwargs: named trees to scan for Variables.

  Returns:
    A dict mapping the key path of each Variable (rooted at its keyword
    name) to the Variable.

  Raises:
    ValueError: if a Variable under a checked keyword cannot be updated, or if
      the same Variable object appears at more than one path.
  """
  container = labeled(**kwargs)
  first_seen_at: dict[int, jax.tree_util.KeyPath] = {}
  all_variables: dict[jax.tree_util.KeyPath, variablelib.Variable] = {}

  keyed_leaves = jax.tree.leaves_with_path(
    container, is_leaf=lambda x: isinstance(x, variablelib.Variable)
  )
  for path, leaf in keyed_leaves:
    if not isinstance(leaf, variablelib.Variable):
      continue

    # The first key is the attribute of the labeled tuple, i.e. the kwarg name.
    root_key = path[0]
    assert isinstance(root_key, jax.tree_util.GetAttrKey)
    kwarg_name = root_key.name
    if kwarg_name in check and not leaf._can_update:
      path_str = jax.tree_util.keystr(path)
      raise ValueError(
        f'Cannot return captured Variable of type {type(leaf).__name__} '
        f'from nnx.{fn_name}.\n'
        f'Found at path: {path_str}'
      )

    var_id = id(leaf)
    if var_id in first_seen_at:
      path_str = jax.tree_util.keystr(path)
      seen_path_str = jax.tree_util.keystr(first_seen_at[var_id])
      raise ValueError(
        f'Duplicate {leaf}\nfound at paths:\n\n'
        f' - {seen_path_str}\n'
        f' - {path_str}\n\n'
        f'nnx.{fn_name} with graph_updates=False does not support '
        'Variable aliasing (duplicate inputs, duplicate outputs, or '
        'input Variables returned as outputs). '
        f'Consider the following options:\n\n'
        f'1. Remove the duplicate Variables.\n'
        f'2. Create new Variables via nnx.clone() and use those instead.\n'
        f'3. Enable graph mode and graph updates by passing graph=True and '
        f'graph_updates=True to {fn_name}\n\n'
        f' nnx.{fn_name}(..., graph=True, graph_updates=True)\n\n'
        f'4. Use nnx.compat.{fn_name} (sets graph and graph_updates to True '
        f'automatically)\n\n'
        f' nnx.compat.{fn_name}(...)'
      )

    first_seen_at[var_id] = path
    all_variables[path] = leaf

  return all_variables
def check_prefix(
  prefix: tp.Any,
  prefix_name: str,
  fn_name: str,
  graph: bool,
  graph_updates: bool,
  none_leaf: bool = True,
):
  """Validate a prefix tree and collect its distinct leaves.

  Args:
    prefix: the prefix tree to validate (e.g. ``in_axes``).
    prefix_name: name of the argument, used in error messages.
    fn_name: name of the transform, used in error messages.
    graph: whether the transform runs in graph mode.
    graph_updates: whether the transform propagates graph updates.
    none_leaf: whether ``None`` is collected as a prefix leaf.

  Returns:
    An ``OrderedDict`` mapping each distinct prefix leaf to itself, in
    traversal order.

  Raises:
    ValueError: if the prefix contains a Variable; a ``PrefixMapping`` unless
      both ``graph`` and ``graph_updates`` are true; a graph node while
      ``graph`` is true; or a ``TreeState`` unless ``graph`` is true and
      ``graph_updates`` is false.
  """
  unique_prefixes: OrderedDict[tp.Any, tp.Any] = OrderedDict()

  def _check_prefix(path, leaf):
    if isinstance(leaf, variablelib.Variable):
      raise ValueError(
        f'Found Variable of type {type(leaf).__name__} '
        f'at path {jax.tree_util.keystr(path)} in `{prefix_name}` '
        f'for nnx.{fn_name}. Variables prefixes are not supported. '
        f'Pass a prefix for the entire Variable instead of passing a '
        f'Variable with a prefix for its value.'
      )
    if isinstance(leaf, PrefixMapping) and not (graph and graph_updates):
      raise ValueError(
        f'`{prefix_name}` cannot contain `{type(leaf).__name__}` objects '
        f'when `graph=False` or `graph_updates=False`. '
        f'Consider the following options:\n\n'
        f'1. Remove `{type(leaf).__name__}` objects from `{prefix_name}`.\n'
        f'2. Enable graph mode and graph updates by passing graph=True and '
        f'graph_updates=True to {fn_name} e.g.\n\n'
        f' nnx.{fn_name}(..., graph=True, graph_updates=True)\n\n'
        f'3. Use nnx.compat.{fn_name} instead e.g.\n\n'
        f' nnx.compat.{fn_name}(...)'
      )
    if graphlib.is_graph_node(leaf) and graph:
      raise ValueError(
        f'Found graph node of type {type(leaf).__name__} '
        f'at path {jax.tree_util.keystr(path)} in `{prefix_name}` '
        f'for nnx.{fn_name}. Graph nodes are not allowed as prefixes when '
        f'graph=True. '
        f'Consider the following options:\n\n'
        f'1. Remove graph nodes from `{prefix_name}`.\n'
        f'2. Enable tree mode by passing graph=False to {fn_name} e.g.\n\n'
        f' nnx.{fn_name}(..., graph=False)\n\n'
        f'3. If you are using nnx.prefix to create the prefix, pass graph=True:\n\n'
        f' prefix = nnx.prefix(..., graph=True)'
      )
    if isinstance(leaf, TreeState) and (not graph or graph_updates):
      # NOTE(review): option 1 below suggests graph_updates=True, but this
      # branch also fires whenever graph_updates is true -- confirm the
      # intended advice.
      msg = (
        f'Found `TreeState` object at path {jax.tree_util.keystr(path)} in '
        f'`{prefix_name}` for nnx.{fn_name}. `TreeState` objects are only '
        f'allowed as prefixes when `graph=True` and `graph_updates=False`. '
        f'Consider the following options:\n\n'
        f'1. Enable graph mode and graph updates by passing graph=True and '
        f'graph_updates=True to {fn_name} e.g.\n\n'
        f' nnx.{fn_name}(..., graph=True, graph_updates=True)\n\n'
        f'2. Use nnx.compat.{fn_name} instead e.g.\n\n'
        f' nnx.compat.{fn_name}(...)'
      )
      if graph_updates:
        msg += (
          f'\n\n3. If you are using nnx.prefix to create the prefix, pass graph=False:\n\n'
          f' prefix = nnx.prefix(..., graph=False)'
        )
      raise ValueError(msg)

  # First pass: validation only. The objects being validated are treated as
  # leaves so they reach `_check_prefix` intact.
  jax.tree.map_with_path(
    _check_prefix,
    prefix,
    is_leaf=lambda x: x is None
    or isinstance(x, variablelib.Variable)
    or graphlib.is_graph_node(x)
    or isinstance(x, PrefixMapping)
    or isinstance(x, TreeState),
  )

  def _collect_prefix(_, leaf):
    unique_prefixes[leaf] = leaf

  # Second pass: collect distinct leaves (None counts only if `none_leaf`).
  jax.tree.map_with_path(
    _collect_prefix, prefix, is_leaf=lambda x: x is None and none_leaf
  )
  return unique_prefixes
def variable_changed(post: variablelib.Variable, pre: variablelib.Variable) -> bool:
  """Report whether ``post`` differs from ``pre``.

  The two are considered different if their tree structures differ or if any
  pair of corresponding leaves is not the same object (identity comparison).
  """
  post_leaves, post_td = jax.tree.flatten(post)
  pre_leaves, pre_td = jax.tree.flatten(pre)
  if post_td != pre_td:  # type: ignore[operator]
    return True
  for post_leaf, pre_leaf in zip(post_leaves, pre_leaves):
    if post_leaf is not pre_leaf:
      return True
  return False
# Signature of the predicate used by `get_updates`:
# (path, prefix_leaf, current_variable, snapshot_variable) -> keep?
KeepFn = tp.Callable[
  [PathParts, tp.Any, variablelib.Variable, variablelib.Variable], bool
]
class Updates(
  tp.Sequence[tuple[jax.tree_util.KeyPath, variablelib.Variable]],
  reprlib.Representable,
):
  """Ordered collection of ``(key path, Variable)`` pairs.

  Paths and Variables are stored in two parallel lists. Indexing with an
  ``int`` returns the pair at that position; indexing with a key path returns
  the Variable recorded for the first matching path.
  """

  __slots__ = ('_keys', '_values')
  _keys: list[jax.tree_util.KeyPath]
  _values: list[variablelib.Variable]

  def __init__(
    self,
    items: tp.Iterable[
      tuple[jax.tree_util.KeyPath, variablelib.Variable]
    ] = (),
  ):
    self._keys = []
    self._values = []
    for path, variable in items:
      self._keys.append(path)
      self._values.append(variable)

  def append(self, key: jax.tree_util.KeyPath, value: variablelib.Variable):
    """Add a ``(key, value)`` pair at the end."""
    self._keys.append(key)
    self._values.append(value)

  @property
  def paths(self) -> list[jax.tree_util.KeyPath]:
    """The list of key paths (the internal list, not a copy)."""
    return self._keys

  @property
  def leaves(self) -> list[variablelib.Variable]:
    """The list of Variables (the internal list, not a copy)."""
    return self._values

  @tp.overload
  def __getitem__(
    self, key: int
  ) -> tuple[jax.tree_util.KeyPath, variablelib.Variable]:
    ...

  @tp.overload
  def __getitem__(
    self, key: slice
  ) -> tp.Sequence[tuple[jax.tree_util.KeyPath, variablelib.Variable]]:
    ...

  @tp.overload  # type: ignore[override]
  def __getitem__(self, key: tuple[tp.Hashable, ...]) -> variablelib.Variable:
    ...

  def __getitem__(
    self, key: int | slice | jax.tree_util.KeyPath
  ):
    if isinstance(key, slice):
      raise NotImplementedError('Slicing is not supported for Updates.')
    if isinstance(key, int):
      return self._keys[key], self._values[key]
    # Anything else is treated as a key path; list.index raises ValueError
    # when the path is absent.
    position = self._keys.index(key)
    return self._values[position]

  def __len__(self):
    return len(self._keys)

  def __iter__(self):
    return zip(self._keys, self._values)

  def __nnx_repr__(self):
    yield reprlib.Object(type=type(self), kv_sep=': ', start='({', end='})')
    for path, variable in self:
      yield reprlib.Attr(
        jax.tree_util.keystr(path),
        variable,
        use_raw_key=True,
      )
def _updates_flatten_with_keys(x: Updates):
  """Flatten ``x`` with key paths for JAX pytree registration.

  Children are the Variables keyed by ``FlattenedIndexKey(position)``; the
  list of paths is returned as auxiliary data.
  """
  keyed_children = []
  for position, variable in enumerate(x._values):
    keyed_children.append((jax.tree_util.FlattenedIndexKey(position), variable))
  return keyed_children, x._keys
def _updates_flatten(x: Updates):
  """Flatten ``x``: the Variables are the children, the paths the aux data."""
  return x._values, x._keys
def _updates_unflatten(keys, values) -> Updates:
  """Rebuild an ``Updates`` from auxiliary ``keys`` and child ``values``.

  ``__init__`` is bypassed and both sequences are copied into fresh lists, so
  appending to the result cannot mutate the auxiliary data held by the
  treedef or the ``Updates`` object it was flattened from.
  """
  updates = object.__new__(Updates)
  # Copy instead of aliasing: `keys` is the aux-data object shared with the
  # treedef and with the flattened source object.
  updates._keys = list(keys)
  updates._values = list(values)
  return updates
# Register Updates as a JAX pytree. The keyed flatten is used when key paths
# are requested; `flatten_func` supplies the plain flatten.
jax.tree_util.register_pytree_with_keys(
  Updates,
  _updates_flatten_with_keys,
  _updates_unflatten,
  flatten_func=_updates_flatten,
)
def get_updates(
  current_tree: A,
  snapshot_tree: A,
  *,
  prefix: tp.Any = None,
  keep_fn: KeepFn | None = None,
  known_prefixes: tp.Iterable[tp.Any] = (None,),
):
  """Group the Variables of ``current_tree`` that should be kept, by prefix.

  Args:
    current_tree: tree holding the current Variables.
    snapshot_tree: tree with the same structure holding the earlier snapshot.
    prefix: tree prefix of ``current_tree``; ``None`` is treated as a leaf.
    keep_fn: predicate ``(path, prefix_leaf, current, snapshot) -> bool``.
      Defaults to ``variable_changed(current, snapshot)``.
    known_prefixes: the prefix values to create (initially empty) groups for.

  Returns:
    An ``OrderedDict`` mapping each known prefix to an ``Updates`` with the
    kept ``(path, Variable)`` pairs. Variables with a truthy ``hijax`` or
    ``ref`` attribute are never included.
  """
  if keep_fn is None:
    keep_fn = lambda _, _pfx, cur, snap: variable_changed(cur, snap)

  updates = OrderedDict((pfx, Updates()) for pfx in known_prefixes)

  def _collect(path, prefix_value, current, previous):
    if not isinstance(current, variablelib.Variable):
      return
    if current.hijax or current.ref:
      return
    if keep_fn(path, prefix_value, current, previous):
      updates[prefix_value].append(path, current)

  broadcast_prefix_map(
    _collect,
    prefix,
    current_tree,
    snapshot_tree,
    is_leaf=lambda x: isinstance(x, variablelib.Variable),
    prefix_leaf=lambda x: x is None,
  )
  return updates
def apply_updates(
  variables: dict[jax.tree_util.KeyPath, variablelib.Variable],
  updates: OrderedDict[tp.Any, Updates],
):
  """Apply every update to the Variable registered at the same path.

  Args:
    variables: mapping from key path to the Variable to update in place.
    updates: groups of ``(path, update)`` pairs; the group keys are ignored.

  Raises:
    RuntimeError: if an update's path is not present in ``variables``.
  """
  for flat_state in updates.values():
    for path, update in flat_state:
      if path not in variables:
        path_str = jax.tree_util.keystr(path)
        raise RuntimeError(
          f'Variable not found at path {path_str}. This is a bug in NNX, '
          f'please report it. Variable: {update}'
        )
      target = variables[path]
      assert isinstance(target, variablelib.Variable)
      target.update_from_state(update)
def treemap_copy_args(f: F) -> F:
  """Decorator that passes ``f`` structural copies of its arguments.

  Positional and keyword arguments are rebuilt with an identity
  ``jax.tree.map`` before each call, so ``f`` receives new containers holding
  the same leaves.
  """

  @functools.wraps(f)
  def wrapper(*args, **kwargs):
    copied_args, copied_kwargs = jax.tree.map(lambda leaf: leaf, (args, kwargs))
    return f(*copied_args, **copied_kwargs)

  return wrapper  # type: ignore[return-value]
def check_same_variables(inputs, outputs, transform_name: str = ''):
  """Check that each input Variable is the very same object in ``outputs``.

  ``Mask`` instances and Variables are treated as leaves when walking the two
  trees in parallel.

  Raises:
    ValueError: if a Variable in ``inputs`` is paired with a different object
      in ``outputs``.
  """

  def _stop_at(x):
    return isinstance(x, (Mask, variablelib.Variable))

  def _check(in_leaf, out_leaf):
    if not isinstance(in_leaf, variablelib.Variable):
      return
    if in_leaf is out_leaf:
      return
    raise ValueError(
      f'{transform_name} Variable identity must be preserved '
      'across iterations.'
    )

  jax.tree.map(_check, inputs, outputs, is_leaf=_stop_at)
def update_carry_variables(init_val, val_out):
  """Write the values of ``val_out`` back into the Variables of ``init_val``.

  Returns:
    A tree shaped like ``init_val`` in which each Variable position holds the
    original (now updated) Variable and every other position holds the
    corresponding leaf of ``val_out``.
  """

  def _write_back(in_leaf, out_leaf):
    if not isinstance(in_leaf, variablelib.Variable):
      return out_leaf
    in_leaf.update_from_state(out_leaf)
    return in_leaf

  def _is_variable(x):
    return isinstance(x, variablelib.Variable)

  return jax.tree.map(_write_back, init_val, val_out, is_leaf=_is_variable)
def prefix(
  node,
  filter_map: tp.Mapping[filterlib.Filter, tp.Any] | tp.Callable[..., tp.Any],
  /,
  *,
  graph: bool | None = None,
):
  """Replaces leaves in a graph node with prefix values.

  ``prefix`` replaces each leaf in ``node`` with a prefix value computed by
  ``filter_map(path, leaf)``. In graph mode (``graph=True``), the node is
  first converted to a tree and the prefix is applied to the resulting
  structure so it can be used directly as axes arguments for transforms like
  ``nnx.vmap``.

  Example usage::

    from flax import nnx
    import jax.numpy as jnp

    d = {'a': nnx.Param(jnp.array(2)), 'b': nnx.BatchStat(jnp.arange(5))}
    prefix = nnx.prefix(d, lambda path, x: 0 if 'b' in path else None)

    @nnx.vmap(in_axes=(prefix,))
    def f(d):
      return d['a'] * d['b']

    f(d)  # Array([0, 2, 4, 6, 8])

  ``filter_map`` can also be a mapping from :class:`Filter` to prefix values.
  Filters are checked in order and the first match determines the prefix::

    d = {'a': nnx.Param(jnp.array(2)), 'b': nnx.BatchStat(jnp.arange(5))}
    prefix = nnx.prefix(d, {nnx.Param: None, nnx.BatchStat: 0})

  Calculating prefixes for graph mode transforms is a bit more involved, as
  the graph nodes are first converted to trees in an order-dependent manner.
  This means prefixes should be calculated jointly between all graph nodes in
  the transform, in the same order they appear in the arguments. For example::

    import jax
    import jax.numpy as jnp

    @nnx.vmap
    def create_model(rngs):
      return nnx.Linear(2, 3, rngs=rngs)

    model = create_model(nnx.Rngs(0).split(4))
    px1, px2 = nnx.prefix((model, model), {nnx.Param: 0}, graph=True)

    @nnx.vmap(in_axes=(px1, px2, None), graph=True)
    def forward(m1, m2, x):
      assert m1 is m2
      return m1(x) + m2(x)

    y = forward(model, model, jnp.ones(2))
    assert y.shape == (4, 3)

  The prefixes might be invalid if not all graph nodes involved in the
  transform are passed to ``nnx.prefix``.

  Args:
    node: A graph node object.
    filter_map: A callable ``(path, leaf) -> prefix`` that computes the prefix
      for each leaf, or a mapping from :class:`Filter` to prefix values (filters
      are checked in order; the first match determines the prefix).
    graph: If ``True``, uses graph-mode which supports the full NNX feature set
      including shared references. If ``False``, uses tree-mode which treats
      Modules as regular JAX pytrees, avoiding the overhead of the graph
      protocol. If ``None``, the current value of ``graphlib.set_graph_mode``
      is used.

  Returns:
    A new tree with prefix values replacing the leaves.

  Raises:
    ValueError: if ``filter_map`` is a mapping and none of its filters matches
      a leaf.
  """
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()

  if isinstance(filter_map, tp.Mapping):
    predicates = tuple(
      (filterlib.to_predicate(f), value) for f, value in filter_map.items()
    )
    filters = list(filter_map.keys())

    def prefix_fn(path, leaf):
      # First matching filter wins.
      for predicate, _prefix in predicates:
        if predicate(path, leaf):
          return _prefix
      raise ValueError(
        f'No filter matched leaf at path {path!r} with value {leaf!r}. '
        f'Filters: {filters}'
      )
  else:
    prefix_fn = filter_map

  is_leaf = lambda x: isinstance(x, variablelib.Variable)

  if graph:
    # Convert graph nodes to TreeState objects that remember `prefix_fn`.
    node = to_tree2(node, prefix_fn=prefix_fn)

  def _apply_prefix(jax_path, leaf):
    path = graphlib.jax_to_nnx_path(jax_path)
    if graph:
      # remove __state__ resulting from TreeState from path
      # to match the path you get on graph=False
      path = tuple(k for k in path if k != '__state__')
    return prefix_fn(path, leaf)

  return jax.tree.map_with_path(_apply_prefix, node, is_leaf=is_leaf)
def to_masked(tree, all_updates: OrderedDict[tp.Any, Updates]):
  """Build a tree shaped like ``tree`` holding only the updated Variables.

  All groups in ``all_updates`` are merged into one path-to-Variable mapping.
  Each leaf of ``tree`` (``None`` counts as a leaf) is replaced by the entry
  for its key path, or by ``None`` when its path has no entry.
  """
  by_path: OrderedDict[tp.Any, tp.Any] = OrderedDict()
  for group in all_updates.values():
    by_path.update(group)

  def _lookup(path, _):
    return by_path.get(path, None)

  def _none_is_leaf(x):
    return x is None

  return jax.tree.map_with_path(_lookup, tree, is_leaf=_none_is_leaf)
def filter_kwargs(f, **kwargs):
  """Keep only the keyword arguments that ``f`` can accept by name.

  If ``f`` declares a ``**kwargs`` parameter, all keyword arguments are
  returned unchanged. Otherwise only those matching a positional-or-keyword
  or keyword-only parameter of ``f`` are kept.
  """
  parameters = inspect.signature(f).parameters

  for param in parameters.values():
    if param.kind == inspect.Parameter.VAR_KEYWORD:
      return kwargs

  by_name_kinds = (
    inspect.Parameter.POSITIONAL_OR_KEYWORD,
    inspect.Parameter.KEYWORD_ONLY,
  )
  accepted = set()
  for name, param in parameters.items():
    if param.kind in by_name_kinds:
      accepted.add(name)

  return {key: value for key, value in kwargs.items() if key in accepted}