Backends
Future development of pyscfad may add support for multiple numpy-like backends
through a single, backend-agnostic implementation. This work is still ongoing.
As of version 0.1, only the JAX backend is fully tested.
Nevertheless, it is recommended to call numpy functions
through the pyscfad.numpy module, so that the same code can run on other backends.
Switching backends
The numpy-like backend is selected with the environment variable PYSCFAD_BACKEND. By default, the JAX backend is used.
The numpy and torch backends can also be selected, e.g.,
export PYSCFAD_BACKEND='torch'
With the numpy backend, pyscfad behaves like pyscf, which may be useful for methods that are available in pyscfad but not in pyscf. The torch backend has limited functionality. An example of a Hartree-Fock calculation with an input Fock matrix can be found here.
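The backend can also be chosen from within Python by setting the environment variable before pyscfad is imported. This is a minimal sketch that assumes the backend is read once, at import time. The helper resolve_backend is hypothetical and not part of pyscfad; it only mirrors the rule stated above (JAX by default, numpy or torch on request), and the value string "jax" for the default is our assumption.

```python
import os

# Hypothetical helper, NOT part of pyscfad: it only mirrors the documented
# rule that PYSCFAD_BACKEND selects the backend and JAX is the default.
# The name "jax" for the default value is an assumption.
SUPPORTED_BACKENDS = ("jax", "numpy", "torch")

def resolve_backend(environ=None):
    """Return the backend name requested through PYSCFAD_BACKEND."""
    if environ is None:
        environ = os.environ
    name = environ.get("PYSCFAD_BACKEND", "jax")
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(f"unsupported PYSCFAD_BACKEND value: {name!r}")
    return name

# Equivalent of `export PYSCFAD_BACKEND='numpy'`; assumed to take effect only
# if it runs before the first `import pyscfad` in the process.
os.environ["PYSCFAD_BACKEND"] = "numpy"
print(resolve_backend())
```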
Numpy
The numpy functions are registered in the pyscfad.numpy module,
which wraps the numpy-like backends.
It is recommended to call numpy functions as follows:
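from pyscfad import numpy as np   # numpy-like wrapper over the active backend
a = np.ones((4,4))                # 4x4 matrix of ones
print(type(a))                    # array type of the backend (JAX by default)
w, v = np.linalg.eigh(a)          # eigenvalues w (ascending) and eigenvectors v
print(w)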
<class 'jaxlib.xla_extension.ArrayImpl'>
[-9.89816667e-16 -3.42450962e-16 -1.23259516e-32 4.00000000e+00]
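The three tiny printed eigenvalues are numerically zero: the exact spectrum of the 4×4 all-ones matrix is {0, 0, 0, 4}. The point of a numpy-like wrapper is that one function body can run on different array libraries. The sketch below imitates this by passing the array module explicitly (numpy or jax.numpy) instead of going through pyscfad.numpy, so it does not depend on pyscfad internals; the function name largest_eigenvalue is illustrative only.

```python
import numpy as onp
import jax
import jax.numpy as jnp

def largest_eigenvalue(xp, a):
    # `xp` is any numpy-like module exposing linalg.eigh (numpy or jax.numpy)
    w, _ = xp.linalg.eigh(a)
    return w[-1]  # eigh returns eigenvalues in ascending order

a = onp.array([[2.0, 1.0],
               [1.0, 3.0]])

lam_np = largest_eigenvalue(onp, a)                # plain numpy
lam_jax = largest_eigenvalue(jnp, jnp.asarray(a))  # JAX

# Only the JAX version is differentiable. For a nondegenerate eigenvalue,
# d(lambda_max)/dA = v v^T, with v the corresponding unit eigenvector.
grad = jax.grad(lambda m: largest_eigenvalue(jnp, m))(jnp.asarray(a))

print(lam_np, lam_jax)
print(grad)
```

Passing the module as an argument is just the simplest way to show backend independence in a self-contained snippet; a wrapper module such as pyscfad.numpy hides that choice behind a single import.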
Scipy
pyscfad does not provide a scipy wrapper at the moment.
However, the pyscfad.scipy module contains some custom
scipy functions that may be useful.
For example, pyscfad.scipy.linalg.eigh extends
jax.scipy.linalg.eigh to allow for differentiable generalized eigendecompositions.
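For symmetric A and symmetric positive-definite B, the generalized problem reads A x = w B x. One textbook way to solve it with differentiable building blocks is Cholesky reduction: factor B = L L^T, solve the standard problem for L^{-1} A L^{-T}, and back-transform the eigenvectors. The sketch below applies this technique with plain JAX; it is not necessarily how pyscfad.scipy.linalg.eigh is implemented, and generalized_eigh is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cholesky, solve_triangular

def generalized_eigh(a, b):
    """Solve A x = w B x for symmetric A and symmetric positive-definite B.

    Hypothetical helper using Cholesky reduction. Each step (cholesky,
    solve_triangular, eigh) has a derivative rule in JAX, so the function
    can be differentiated, with eigenvector derivatives requiring
    nondegenerate eigenvalues.
    """
    l = cholesky(b, lower=True)                 # B = L L^T
    tmp = solve_triangular(l, a, lower=True)    # L^{-1} A
    c = solve_triangular(l, tmp.T, lower=True)  # L^{-1} A L^{-T} (A symmetric)
    w, y = jnp.linalg.eigh(c)                   # standard problem C y = w y
    x = solve_triangular(l.T, y, lower=False)   # back-transform: x = L^{-T} y
    return w, x

a = jnp.array([[2.0, 1.0], [1.0, 3.0]])
b = jnp.array([[2.0, 0.5], [0.5, 1.0]])

w, x = generalized_eigh(a, b)
print(w)

# sum(w) = trace(B^{-1} A), so its gradient with respect to A is B^{-1}
g = jax.grad(lambda m: generalized_eigh(m, b)[0].sum())(a)
print(g)
```

The returned eigenvectors are B-orthonormal (x^T B x = I), which is the usual normalization for generalized symmetric eigenproblems.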
Similarly, pyscfad.scipy.linalg.svd extends
jax.scipy.linalg.svd to allow for differentiation when the full matrices of singular vectors are returned.
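For context, in the JAX versions we are familiar with, the SVD derivative rule is implemented only for the reduced decomposition (full_matrices=False), and differentiating through the full U and V raises NotImplementedError. The sketch below stays within that supported case, using plain JAX rather than pyscfad: it differentiates the nuclear norm (the sum of singular values) and compares against the known analytic gradient U V^T built from the reduced factors.

```python
import jax
import jax.numpy as jnp

def nuclear_norm(a):
    # Singular values only (compute_uv=False), summed
    s = jnp.linalg.svd(a, compute_uv=False)
    return jnp.sum(s)

a = jnp.array([[3.0, 1.0, 0.0],
               [1.0, 2.0, 1.0],
               [0.0, 1.0, 1.0],
               [1.0, 0.0, 2.0]])  # 4 x 3, full column rank

# Reduced factors: u is 4 x 3, vt is 3 x 3 (full_matrices=False)
u, s, vt = jnp.linalg.svd(a, full_matrices=False)

grad = jax.grad(nuclear_norm)(a)
# Analytic result for a full-rank matrix: d(sum of s)/dA = U V^T
print(jnp.max(jnp.abs(grad - u @ vt)))
```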
Other operations
The pyscfad.ops module provides useful operations,
most of which wrap JAX functions in a form that is also compatible with the other backends.
For instance, it contains jit, vmap, and stop_gradient.
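With the JAX backend, these operations are expected to behave like their JAX counterparts. The sketch below shows what each one does using jax directly (jax.jit, jax.vmap and jax.lax.stop_gradient), since it does not assume the exact signatures exposed by pyscfad.ops.

```python
import jax
import jax.numpy as jnp

@jax.jit  # trace once and compile the function with XLA
def energy(x):
    return jnp.sum(x ** 2)

# vmap maps `energy` over the leading (batch) axis without a Python loop
batched_energy = jax.vmap(energy)
xs = jnp.arange(6.0).reshape(3, 2)  # three 2-vectors: [0,1], [2,3], [4,5]
print(batched_energy(xs))           # sums of squares: 1, 13, 41

# stop_gradient treats its argument as a constant during differentiation:
# d/dx [x * stop_gradient(x)] = x, not 2x
f = lambda x: x * jax.lax.stop_gradient(x)
print(jax.grad(f)(3.0))             # 3.0
```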