pyscfad.util.pytree_node

pyscfad.util.pytree_node(leaf_names, num_args=0, exclude_aux_name=())

Class decorator that registers the decorated class as a JAX pytree.

See the JAX documentation for the definition of pytrees.
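For illustration, the sketch below registers a class by hand with `jax.tree_util.register_pytree_node_class`. It shows the kind of flatten/unflatten pair that a pytree-registering decorator has to provide. It is not pyscfad's implementation, and the class `ToyMol` and its attributes `coords` and `basis` are hypothetical. Once the class is registered, `jax.grad` can differentiate with respect to an instance and returns the gradient as an instance of the same class.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyMol:
    """Hypothetical class: `coords` is a leaf, `basis` is static data."""

    def __init__(self, coords, basis="sto-3g"):
        self.coords = coords  # traced by JAX (differentiable leaf)
        self.basis = basis    # carried along as auxiliary (static) data

    def tree_flatten(self):
        # Return (children, aux_data); aux_data must be hashable.
        return (self.coords,), (self.basis,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild an instance from the static data and the leaves.
        return cls(children[0], basis=aux_data[0])


def energy(mol):
    # Toy scalar function of the leaf attribute.
    return jnp.sum(mol.coords ** 2)


mol = ToyMol(jnp.array([1.0, 2.0]))
g = jax.grad(energy)(mol)  # gradient is returned as a ToyMol
```

The gradient object keeps `basis` unchanged, because static data passes through flatten/unflatten untouched while only `coords` is differentiated.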

Parameters:
leaf_names : list or tuple

Attributes of the class that are traced as pytree leaves.

num_args : int, optional

Number of leading entries in leaf_names that correspond to positional arguments of __init__. This is useful when __init__ takes positional arguments whose names differ from the corresponding attribute names. Default is 0.
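A minimal plain-Python sketch of why num_args matters, assuming that on unflattening the first num_args leaves are passed to __init__ positionally and the remaining ones by keyword, using their names in leaf_names. This is a reading of the parameter description, not pyscfad's code; the class `Mol` and the helper `rebuild` are hypothetical.

```python
class Mol:
    """Hypothetical class: positional argument `xyz` is stored as `coords`."""

    def __init__(self, xyz, charge=0):
        self.coords = xyz
        self.charge = charge


def rebuild(cls, leaf_names, num_args, leaves):
    """Hypothetical unflatten step: first num_args leaves go positionally,
    the rest are passed as keyword arguments named after the attributes."""
    args = leaves[:num_args]
    kwargs = dict(zip(leaf_names[num_args:], leaves[num_args:]))
    return cls(*args, **kwargs)


leaf_names = ("coords", "charge")

# num_args=1: calls Mol([0.0, 1.0], charge=2), which works.
mol = rebuild(Mol, leaf_names, 1, [[0.0, 1.0], 2])

# num_args=0: calls Mol(coords=..., charge=...); __init__ has no `coords`.
failure = None
try:
    rebuild(Mol, leaf_names, 0, [[0.0, 1.0], 2])
except TypeError as err:
    failure = err
```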

exclude_aux_name : tuple, default=()

Names of static (auxiliary) attributes that are excluded when comparing pytrees. Note that jax.jit recompiles a function whenever its input pytrees have different static attribute values.
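The recompilation behavior mentioned above can be observed with a hand-registered pytree whose static data is a string label. The sketch counts how often `jax.jit` traces the function: a Python side effect inside a jitted function runs only during tracing. The class `Params`, its attributes, and the counter are hypothetical, and the sketch does not use pyscfad or exclude_aux_name; it only illustrates why leaving volatile attributes in the comparison can trigger extra compilations.

```python
import jax
import jax.numpy as jnp


class Params:
    """Hypothetical container: `w` is a leaf, `label` is static aux data."""

    def __init__(self, w, label):
        self.w = w
        self.label = label


def _flatten(p):
    # Leaves are traced; aux data takes part in the treedef comparison.
    return (p.w,), p.label


def _unflatten(label, children):
    return Params(children[0], label)


jax.tree_util.register_pytree_node(Params, _flatten, _unflatten)

trace_count = 0


@jax.jit
def total(p):
    global trace_count
    trace_count += 1  # executes only while JAX traces the function
    return jnp.sum(p.w)


x = jnp.arange(3.0)
total(Params(x, "a"))        # first call: traced and compiled
total(Params(x + 1.0, "a"))  # same static value and shapes: cache hit
total(Params(x, "b"))        # different static value: traced again
```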

Notes

The __init__ method of the class cannot have positional arguments that are not included in leaf_names. If num_args is greater than 0, the positional arguments in leaf_names must appear in the same order as in the signature of __init__.
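Both constraints can be illustrated with a plain-Python round trip, assuming leaves are read in leaf_names order and the first num_args of them are passed to __init__ positionally. The class `Pair` and the helper `roundtrip` are hypothetical and are not part of pyscfad. A mismatched order swaps values without raising an error, while a required positional argument missing from leaf_names cannot be supplied at all.

```python
class Pair:
    """Hypothetical class whose __init__ takes (a, b) positionally."""

    def __init__(self, a, b):
        self.a = a
        self.b = b


obj = Pair(1, 2)
num_args = 2


def roundtrip(leaf_names):
    # Flatten: read the leaves in leaf_names order.
    leaves = [getattr(obj, name) for name in leaf_names]
    # Unflatten: pass the first num_args leaves positionally.
    return Pair(*leaves[:num_args])


good = roundtrip(("a", "b"))  # order matches __init__: a=1, b=2
bad = roundtrip(("b", "a"))   # order mismatched: a=2, b=1, no error raised

# `b` is a required positional argument but is not listed as a leaf.
missing = None
try:
    Pair(*[getattr(obj, name) for name in ("a",)])
except TypeError as err:
    missing = err
```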