Backends

Future development of pyscfad may add support for multiple numpy-like backends through a single, backend-agnostic implementation. This work is ongoing: as of version 0.1, only the JAX backend is fully tested. Even so, it is recommended to call numpy functions through the pyscfad.numpy module rather than importing a backend directly.

Switching backends

The backend is selected with the environment variable PYSCFAD_BACKEND. The JAX backend is used by default; the numpy and torch backends can also be selected, e.g.,

export PYSCFAD_BACKEND='torch'

With the numpy backend, pyscfad behaves like pyscf, which can be useful for methods that pyscfad implements but pyscf does not. The torch backend has limited functionality. An example of performing a Hartree-Fock calculation with an input Fock matrix can be found here.
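The selection mechanism can be pictured with the toy dispatcher below. It is a hypothetical sketch, not pyscfad's actual code: the function name `load_backend` and the mapping from backend names to modules are assumptions made for illustration. It only shows the general idea of reading PYSCFAD_BACKEND, defaulting to JAX when the variable is unset, and importing the matching array module lazily so that only the selected backend has to be installed.

```python
import importlib
import os

# Hypothetical mapping from PYSCFAD_BACKEND values to array modules
# (an assumption for illustration; pyscfad's real mapping may differ).
_BACKEND_MODULES = {
    "jax": "jax.numpy",
    "numpy": "numpy",
    "torch": "torch",
}


def load_backend(name=None):
    """Return (backend_name, module) for the requested backend.

    If ``name`` is None, read PYSCFAD_BACKEND from the environment,
    falling back to "jax" when the variable is unset.
    """
    if name is None:
        name = os.environ.get("PYSCFAD_BACKEND", "jax")
    name = name.strip().lower()
    if name not in _BACKEND_MODULES:
        raise ValueError(
            f"unknown backend {name!r}; expected one of {sorted(_BACKEND_MODULES)}"
        )
    # Import lazily so that only the selected backend needs to be installed.
    return name, importlib.import_module(_BACKEND_MODULES[name])


if __name__ == "__main__":
    os.environ["PYSCFAD_BACKEND"] = "numpy"
    name, xp = load_backend()
    print(name, xp.ones((2, 2)).sum())
```

A dispatcher of this kind reads the variable when the backend is loaded, so in such a design the variable has to be set before that first load.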

Numpy

The numpy functions are registered in the pyscfad.numpy module, which wraps the numpy-like backends. It is recommended to call numpy functions as follows:

from pyscfad import numpy as np

a = np.ones((4,4))
print(type(a))

w, v = np.linalg.eigh(a)
print(w)
<class 'jaxlib.xla_extension.ArrayImpl'>
[-9.89816667e-16 -3.42450962e-16 -1.23259516e-32  4.00000000e+00]
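The printed eigenvalues are 0, 0, 0 and 4 up to floating-point round-off. The 4x4 all-ones matrix equals u uᵀ with u = (1, 1, 1, 1), so it has a single eigenvalue uᵀu = 4 with eigenvector u and a threefold eigenvalue 0 for every vector orthogonal to u. The tiny nonzero values are round-off, so such results should be compared with a tolerance rather than exactly. The check below uses plain NumPy, which is assumed to give the same results as pyscfad.numpy under the numpy backend.

```python
import numpy as np

n = 4
a = np.ones((n, n))  # equals u u^T with u = (1, ..., 1)

w, v = np.linalg.eigh(a)  # eigenvalues in ascending order

# Exact spectrum: a threefold eigenvalue 0 and a single eigenvalue n.
expected = np.array([0.0] * (n - 1) + [float(n)])
print(np.allclose(w, expected, atol=1e-12))

# The eigenvector of the largest eigenvalue is parallel to u (up to sign).
u = np.ones(n) / np.sqrt(n)
print(np.isclose(abs(v[:, -1] @ u), 1.0))
```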

Scipy

pyscfad does not currently provide a scipy wrapper. However, the pyscfad.scipy module contains some custom scipy functions that may be useful. For example, pyscfad.scipy.linalg.eigh extends jax.scipy.linalg.eigh to allow differentiable generalized eigendecompositions, and pyscfad.scipy.linalg.svd extends jax.scipy.linalg.svd to allow differentiation when the full matrices are returned (full_matrices=True).
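A generalized eigendecomposition solves H C = S C diag(e) for a symmetric H and a symmetric positive-definite S, the form of the Roothaan-Hall equations with S as the overlap matrix. The sketch below reduces it to a standard symmetric eigenproblem using the Cholesky factor of S and then checks the result. It is plain NumPy and is not necessarily how pyscfad.scipy.linalg.eigh works internally; the function name `generalized_eigh` and the random test matrices are hypothetical. The last lines only illustrate what "full matrices" means for the SVD: with full_matrices=True the left factor is square, and its extra columns span the orthogonal complement of the column space without being uniquely determined, which is one reason differentiating the full SVD needs special care.

```python
import numpy as np


def generalized_eigh(h, s):
    """Solve h @ c = s @ c @ diag(e) by Cholesky reduction.

    Assumes h is symmetric and s is symmetric positive definite.
    Returns eigenvalues e in ascending order and s-orthonormal
    eigenvectors c (c.T @ s @ c = I).
    """
    l = np.linalg.cholesky(s)        # s = l @ l.T
    l_inv = np.linalg.inv(l)
    h_tilde = l_inv @ h @ l_inv.T    # equivalent standard symmetric problem
    e, y = np.linalg.eigh(h_tilde)
    c = l_inv.T @ y                  # back-transform the eigenvectors
    return e, c


rng = np.random.default_rng(0)
n = 5
x = rng.standard_normal((n, n))
h = x + x.T                          # symmetric, Fock-like matrix
m = rng.standard_normal((n, n))
s = m @ m.T + n * np.eye(n)          # symmetric positive definite, overlap-like

e, c = generalized_eigh(h, s)
print(np.allclose(h @ c, s @ c @ np.diag(e)))
print(np.allclose(c.T @ s @ c, np.eye(n)))

# Full vs. reduced SVD of a tall matrix: only the shapes of U differ.
a = rng.standard_normal((5, 3))
u_full, sv, vt = np.linalg.svd(a, full_matrices=True)
u_thin, _, _ = np.linalg.svd(a, full_matrices=False)
print(u_full.shape, u_thin.shape)
```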

Other operations

The pyscfad.ops module provides useful operations, most of which are wrappers around JAX functions made compatible with the other backends. Examples include jit, vmap and stop_gradient.
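To make the idea of backend-compatible wrappers concrete, the sketch below shows what plain-NumPy fallbacks for jit, vmap and stop_gradient could look like. These are hypothetical stand-ins, not the code in pyscfad.ops. Without JAX there is nothing to compile and no gradient to stop, so jit and stop_gradient return their input unchanged, and vmap is emulated with a Python loop over the leading axis of every argument, a much narrower contract than jax.vmap with its in_axes and out_axes options.

```python
import numpy as np


def jit(fun):
    """NumPy fallback: no tracing or compilation, return fun unchanged."""
    return fun


def stop_gradient(x):
    """NumPy fallback: there is no autodiff graph, so x is returned as is."""
    return x


def vmap(fun):
    """NumPy fallback: map fun over the leading axis of every argument.

    Unlike jax.vmap, this supports neither in_axes nor out_axes and
    simply stacks the per-element results along a new leading axis.
    """
    def batched(*args):
        sizes = {np.shape(a)[0] for a in args}
        if len(sizes) != 1:
            raise ValueError("all arguments must share the same leading dimension")
        return np.stack([fun(*xs) for xs in zip(*args)])
    return batched


@jit
def energy(x, y):
    # Toy scalar function of two vectors; the second term is "frozen".
    return np.dot(x, y) + stop_gradient(np.sum(x))


xs = np.arange(6.0).reshape(3, 2)
ys = np.ones((3, 2))
print(vmap(energy)(xs, ys))
```

Keeping the JAX-style call signatures means code written against these names reads the same on every backend; only the JAX versions actually compile, vectorize and cut gradients.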