Custom methods
Since the purpose of pyscfad is to provide a framework for developing new methods that are automatically differentiable, it also offers several utilities to simplify such development. These utilities are built on automatic differentiation (AD) tools such as JAX. Developing a new method typically involves the following steps.
Class definition
We assume the new method is defined in a custom class. JAX function transformations apply to functions whose inputs and outputs are pytrees. Although not required, it is often convenient to make the class itself a pytree, so that its instances can be passed directly to the functions being transformed. This conversion can be achieved by subclassing the PytreeNode class.
from pyscfad import numpy as np
from pyscfad.pytree import PytreeNode

class PowerSum(PytreeNode):
    _dynamic_attr = {'array'}

    def __init__(self, array, order=2):
        self.array = array
        self.order = order

    def kernel(self):
        return np.sum(self.array**self.order)
In the example above, we define a class whose kernel function raises each element of the input array to the power order and then sums the results. Note the class attribute _dynamic_attr in the definition, which lists the names of the dynamic attributes of the class. These attributes are the leaves of the pytree, i.e., the variables that are traced in the computational graph. All other attributes of the object are static, meaning that they are kept as constants during the computation.
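The following minimal sketch shows the dynamic/static split in plain JAX. It does not use pyscfad and is not how PytreeNode is implemented. The class ToyPowerSum is hypothetical and relies only on jax.tree_util.register_pytree_node_class, which expects tree_flatten and tree_unflatten methods. In this sketch, the attributes named in _dynamic_attr are returned as children (the traced leaves), and order is returned as auxiliary data, which JAX treats as a constant.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyPowerSum:
    """Hypothetical plain-JAX analogue of PowerSum (not pyscfad code)."""

    _dynamic_attr = ('array',)  # a tuple keeps the leaf order fixed

    def __init__(self, array, order=2):
        self.array = array
        self.order = order

    def tree_flatten(self):
        # Dynamic attributes become pytree leaves, which JAX traces.
        children = tuple(getattr(self, k) for k in self._dynamic_attr)
        # Static attributes go into the auxiliary data and act as constants.
        aux_data = (self.order,)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = object.__new__(cls)  # skip __init__; restore attributes directly
        for name, value in zip(cls._dynamic_attr, children):
            setattr(obj, name, value)
        (obj.order,) = aux_data
        return obj

    def kernel(self):
        return jnp.sum(self.array**self.order)


a = ToyPowerSum(jnp.eye(2), order=4)
leaves, treedef = jax.tree_util.tree_flatten(a)
print(len(leaves))  # 1: only `array` is a leaf

# The gradient is a ToyPowerSum whose leaf holds d(kernel)/d(array).
g = jax.grad(ToyPowerSum.kernel)(a)
```

Because order is stored in the auxiliary data, jax.grad differentiates only with respect to array. The returned object keeps the original order value.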
Function transformation
With the class registered as a pytree, JAX function transformations can be applied to functions that take a class instance as input.
import jax
a = PowerSum(np.eye(2), order=4)
grad = jax.jit(jax.grad(PowerSum.kernel))(a)
print(grad.array)
[[4. 0.]
 [0. 4.]]
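You can check the printed values by hand. For a = PowerSum(np.eye(2), order=4), the kernel computes the sum of fourth powers of the matrix elements, and each partial derivative depends only on its own element:

```latex
f(\mathbf{A}) = \sum_{i,j} A_{ij}^{4},
\qquad
\frac{\partial f}{\partial A_{ij}} = 4 A_{ij}^{3},
\qquad
\left.\frac{\partial f}{\partial \mathbf{A}}\right|_{\mathbf{A}=\mathbf{I}} = 4\,\mathbf{I}.
```

The gradient is therefore 4 on the diagonal and 0 elsewhere.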
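In this example, both jax.jit and jax.grad are applied to the kernel function, which takes a (an instance of PowerSum) as input. Because a is a pytree, jax.grad returns an object with the same pytree structure as a. This is why the gradient with respect to a.array can be read from grad.array. Note that the static attribute order must not be modified inside the function being transformed (kernel in this example), because static attributes are treated as constants during the computation. Otherwise, the behavior is unpredictable.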
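A rough analogy helps explain what "static" means here. This analogy is our assumption, not a description of pyscfad internals. Static attributes behave much like arguments marked with static_argnums in jax.jit: their values are baked in when the function is traced, and a new value triggers a fresh trace instead of flowing through the computation. The sketch below uses a hypothetical function power_sum and a trace counter to make the retracing visible.

```python
import jax
import jax.numpy as jnp

n_traces = 0  # counts how many times JAX traces `power_sum`


def power_sum(array, order):
    global n_traces
    n_traces += 1  # Python side effect: runs only while JAX is tracing
    return jnp.sum(array**order)


# Differentiate w.r.t. `array` only; treat `order` as a compile-time constant.
grad_fn = jax.jit(jax.grad(power_sum), static_argnums=1)

x = jnp.eye(2)
g4 = grad_fn(x, 4)        # first call: traces with order=4
g4_again = grad_fn(x, 4)  # same static value: reuses the compiled function
g2 = grad_fn(x, 2)        # new static value: traces again
print(n_traces)  # 2
```

Because the static value is fixed at trace time, changing it from inside the traced function cannot affect the compiled computation. This is one reason to treat static attributes as read-only within kernel.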
Subclassing
A subclass of PytreeNode can itself be further subclassed, and only the newly added dynamic attributes need to be registered in the new subclass. For example, to subclass PowerSum and add a new dynamic attribute array1, we can simply do the following.
class MultiplyPowerSum(PowerSum):
    _dynamic_attr = {'array1'}

    def __init__(self, array, array1, order=2):
        super().__init__(array, order=order)
        self.array1 = array1

    def kernel(self):
        return np.sum((self.array*self.array1)**self.order)
Now, both array and array1 will be correctly traced.
a = MultiplyPowerSum(np.eye(2), np.eye(2)*2, order=4)
grad = jax.jit(jax.grad(MultiplyPowerSum.kernel))(a)
print(grad.array)
print(grad.array1)
[[64.  0.]
 [ 0. 64.]]
[[32.  0.]
 [ 0. 32.]]
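One way a base class could support registering only the new dynamic attributes is to take the union of _dynamic_attr over the class's method resolution order (MRO). The sketch below illustrates this idea under that assumption and is not pyscfad's implementation. ToyPytreeNode is a hypothetical base class that registers each subclass with jax.tree_util.register_pytree_node in __init_subclass__. It treats every remaining instance attribute as static auxiliary data, which must therefore be hashable.

```python
import jax
import jax.numpy as jnp


class ToyPytreeNode:
    """Hypothetical minimal base class; NOT pyscfad's PytreeNode."""

    _dynamic_attr = set()

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register each subclass with JAX so its instances are pytrees.
        jax.tree_util.register_pytree_node(cls, cls._flatten, cls._unflatten)

    @classmethod
    def _all_dynamic_attrs(cls):
        # Union of the `_dynamic_attr` sets declared on every class in the MRO,
        # sorted so that the order of the leaves is deterministic.
        names = set()
        for klass in cls.__mro__:
            names |= set(vars(klass).get('_dynamic_attr', ()))
        return tuple(sorted(names))

    def _flatten(self):
        keys = self._all_dynamic_attrs()
        leaves = tuple(getattr(self, k) for k in keys)
        # Every other instance attribute is static (hashable auxiliary data).
        static = tuple(sorted((k, v) for k, v in vars(self).items()
                              if k not in keys))
        return leaves, static

    @classmethod
    def _unflatten(cls, static, leaves):
        obj = object.__new__(cls)  # bypass __init__ when rebuilding
        for k, v in zip(cls._all_dynamic_attrs(), leaves):
            setattr(obj, k, v)
        for k, v in static:
            setattr(obj, k, v)
        return obj


class PowerSum(ToyPytreeNode):
    _dynamic_attr = {'array'}

    def __init__(self, array, order=2):
        self.array = array
        self.order = order

    def kernel(self):
        return jnp.sum(self.array**self.order)


class MultiplyPowerSum(PowerSum):
    _dynamic_attr = {'array1'}  # only the new dynamic attribute is declared

    def __init__(self, array, array1, order=2):
        super().__init__(array, order=order)
        self.array1 = array1

    def kernel(self):
        return jnp.sum((self.array * self.array1)**self.order)


print(MultiplyPowerSum._all_dynamic_attrs())  # ('array', 'array1')
a = MultiplyPowerSum(jnp.eye(2), jnp.eye(2) * 2, order=4)
g = jax.jit(jax.grad(MultiplyPowerSum.kernel))(a)
```

Merging along the MRO means each class declares only its own dynamic attributes, yet instances of MultiplyPowerSum expose both array and array1 as leaves. The resulting gradients match the values printed above.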