pyscfad.util.pytree_node
- pyscfad.util.pytree_node(leaf_names, num_args=0, exclude_aux_name=())
Class decorator that registers the underlying class as a pytree.
See the JAX documentation for the definition of pytrees.
- Parameters:
- leaf_names : list or tuple
Attributes of the class that are traced as pytree leaves.
- num_args : int, optional
Number of positional arguments in leaf_names. This is useful when the __init__ method of the class has positional arguments that are named differently than the actual attribute names. Default value is 0.
- exclude_aux_name : tuple, default=()
A set of static attribute names that are not used for comparing the pytrees. Note that jax.jit recompiles the function for input pytrees with different static attribute values.
Notes
The __init__ method of the class can’t have positional arguments that are not included in leaf_names. If num_args is greater than 0, the sequence of positional arguments in leaf_names must follow that in the __init__ method.
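Example
The following is a minimal pure-Python sketch of how a decorator with this signature could split an object into leaves and static data. It is not pyscfad's implementation: pytree_node_sketch, Molecule, and all attribute names are hypothetical, and the handling of exclude_aux_name (excluded attributes are simply dropped from the static data) is an assumption made for illustration. The sketch only attaches tree_flatten and tree_unflatten methods in the form that jax.tree_util.register_pytree_node_class expects; it does not call JAX.

```python
def pytree_node_sketch(leaf_names, num_args=0, exclude_aux_name=()):
    """Hypothetical stand-in for a pytree class decorator (not pyscfad's code)."""
    leaf_names = tuple(leaf_names)

    def decorator(cls):
        def tree_flatten(self):
            # Leaves: the attributes listed in leaf_names, in that order.
            children = tuple(getattr(self, name, None) for name in leaf_names)
            # Static data: all remaining attributes except the excluded ones
            # (assumption: excluded names are dropped entirely in this sketch).
            aux = tuple(sorted(
                ((k, v) for k, v in vars(self).items()
                 if k not in leaf_names and k not in exclude_aux_name),
                key=lambda item: item[0],
            ))
            return children, aux

        @classmethod
        def tree_unflatten(klass, aux, children):
            # The first num_args leaves are passed to __init__ positionally,
            # so their order in leaf_names must match the __init__ signature.
            obj = klass(*children[:num_args])
            # The remaining leaves and the static data are set as attributes.
            for name, value in zip(leaf_names[num_args:], children[num_args:]):
                setattr(obj, name, value)
            for name, value in aux:
                setattr(obj, name, value)
            return obj

        cls.tree_flatten = tree_flatten
        cls.tree_unflatten = tree_unflatten
        return cls

    return decorator


@pytree_node_sketch(["coords", "charge"], num_args=1, exclude_aux_name=("verbose",))
class Molecule:
    def __init__(self, xyz, basis="sto3g"):
        self.coords = xyz    # attribute name differs from the argument name
        self.charge = 0.0    # leaf that is not an __init__ argument
        self.basis = basis   # static attribute
        self.verbose = 0     # static attribute excluded from the aux data


mol = Molecule([0.0, 0.0, 0.74], basis="ccpvdz")
mol.charge = 1.0
mol.verbose = 4

children, aux = mol.tree_flatten()
rebuilt = Molecule.tree_unflatten(aux, children)
```

Here num_args=1 is needed because the positional argument xyz is stored under a different attribute name, coords. In this sketch, two objects that differ only in verbose produce identical static data, which mirrors the idea that excluded attributes do not take part in pytree comparison.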