Simple mean-field calculations
Energy derivatives w.r.t. molecular parameters
The most straightforward application of pyscfad is to compute energy derivatives w.r.t. the parameters of the Mole object.
Currently, three parameters are supported: the nuclear coordinates Mole.coords, and the exponents Mole.exp and contraction coefficients Mole.ctr_coeff of the basis functions. A typical energy derivative calculation involves the following steps.
1. Define the Mole object
The Mole object constructor follows the same syntax as that of pyscf. In addition, one can control whether to trace (i.e., compute derivatives w.r.t.) each of the parameters mentioned above; by default, all of them are traced.
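from pyscfad import gto
mol = gto.Mole()
mol.atom = "H 0 0 0; H 0 0 0.74"  # H2; coordinates in Angstrom (the pyscf default unit)
mol.basis = "6-31G*"
mol.verbose = 0
# trace_* flags choose which parameters are differentiated (all True by default)
mol.build(trace_coords=True, trace_exp=True, trace_ctr_coeff=True)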
2. Define the energy function
The energy function takes the Mole object as input and returns the energy, which is a scalar. In this example, we compute the Hartree-Fock energy.
from pyscfad import scf

def hf_energy(mol):
    mf = scf.RHF(mol)
    ehf = mf.kernel()  # run the SCF and return the total RHF energy
    return ehf
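The energy function must return a single scalar because jax.grad (and jax.value_and_grad) only differentiate scalar-valued functions. The toy sketch below, which uses a made-up function `toy_energy` in place of `hf_energy` and a plain JAX array in place of a Mole object, shows the contrast with a vector-valued function:

```python
import jax
import jax.numpy as jnp


def toy_energy(x):
    # Hypothetical scalar-valued stand-in for hf_energy.
    return jnp.sum(x ** 2)


def toy_vector(x):
    # Vector-valued function: it has a Jacobian, but no gradient in jax.grad's sense.
    return x ** 2


x = jnp.array([1.0, 2.0, 3.0])
g = jax.grad(toy_energy)(x)  # d/dx sum(x^2) = 2x
print(g)

try:
    jax.grad(toy_vector)(x)
    rejected = False
except TypeError as err:
    # jax.grad raises TypeError when the output is not a scalar.
    rejected = True
    print("jax.grad rejected a non-scalar output:", err)
```

For vector-valued outputs, jax.jacfwd or jax.jacrev return the full Jacobian instead.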
3. Compute the gradient
We use jax as the backend to trace the computational graph and perform the gradient calculation; see, e.g., jax.value_and_grad.
import jax
ehf, grad = jax.value_and_grad(hf_energy)(mol)
print(f'RHF energy (in Eh): {ehf}')
RHF energy (in Eh): -1.1267553171969316
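jax.value_and_grad returns a (value, gradient) pair and reuses the forward evaluation, so the energy does not have to be computed a second time just to print it. The sketch below checks this on a hypothetical toy function rather than on pyscfad:

```python
import jax
import jax.numpy as jnp


def toy_energy(x):
    # Hypothetical smooth scalar function; its analytic gradient is cos(x) + x.
    return jnp.sum(jnp.sin(x)) + 0.5 * jnp.dot(x, x)


x = jnp.array([0.1, 0.2, 0.3])

# One call gives both the function value and its gradient.
value, grad = jax.value_and_grad(toy_energy)(x)

# The same results from two separate calls (the function is evaluated twice).
value_ref = toy_energy(x)
grad_ref = jax.grad(toy_energy)(x)

print(value, grad)
```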
The gradients w.r.t. each parameter are stored as attributes of grad, which is also a Mole object.
print(grad)
<pyscfad.gto.mole.Mole object at 0x7f364c1b7e90>
print(f'Nuclear gradient:\n{grad.coords}')
Nuclear gradient:
[[ 0. 0. -0.00756136]
[ 0. 0. 0.00756136]]
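Two quick sanity checks apply to any nuclear gradient: for an isolated molecule the gradient summed over atoms vanishes (translational invariance), which is why the two rows above are equal and opposite, and the autodiff result should agree with central finite differences. The sketch below runs both checks on a toy Morse potential, `morse_energy`, whose parameters are made up for illustration and which only stands in for `hf_energy`:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Finite differences are only meaningful in double precision.
jax.config.update("jax_enable_x64", True)


def morse_energy(coords, d_e=0.17, a=1.0, r_e=1.4):
    # Toy Morse potential for a diatomic (made-up parameters, atomic units).
    r = jnp.sqrt(jnp.sum((coords[1] - coords[0]) ** 2))
    return d_e * (1.0 - jnp.exp(-a * (r - r_e))) ** 2


coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.5]])
analytic = np.asarray(jax.grad(morse_energy)(coords))

# Central finite differences, one Cartesian component at a time.
h = 1e-5
numeric = np.zeros(coords.shape)
for idx in np.ndindex(*coords.shape):
    step = np.zeros(coords.shape)
    step[idx] = h
    e_plus = float(morse_energy(coords + step))
    e_minus = float(morse_energy(coords - step))
    numeric[idx] = (e_plus - e_minus) / (2 * h)

print(analytic)
print("max |autodiff - finite difference|:", np.abs(analytic - numeric).max())
print("gradient summed over atoms:", analytic.sum(axis=0))
```

Here the bond (1.5) is longer than r_e, so the z-gradient is negative on the first atom and positive on the second.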
print(f'Energy gradient w.r.t. basis function exponents:\n{grad.exp}')
Energy gradient w.r.t. basis function exponents:
[-8.02030941e-05 1.27267947e-03 1.29202851e-02 -3.61927384e-02]
print(f'Energy gradient w.r.t. basis function contraction coefficients:\n{grad.ctr_coeff}')
Energy gradient w.r.t. basis function contraction coefficients:
[ 2.36262161e-03 4.68735066e-03 -5.34074485e-03 8.42659276e-13]
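The gradient comes back as a Mole object because JAX differentiates with respect to pytrees: the gradient of a scalar function of a pytree is a pytree of the same structure, holding ∂E/∂(leaf) in place of each leaf. The sketch below shows this with a hypothetical `ToyParams` NamedTuple (NamedTuples are pytrees out of the box) standing in for the traced Mole attributes and a made-up energy; we assume, without having checked it here, that pyscfad registers Mole in an analogous way.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyParams(NamedTuple):
    """Hypothetical stand-in for the traced attributes of a Mole object."""
    coords: jax.Array     # nuclear coordinates, shape (natm, 3)
    exp: jax.Array        # basis-function exponents
    ctr_coeff: jax.Array  # contraction coefficients


def toy_energy(p: ToyParams) -> jax.Array:
    # Arbitrary smooth scalar; NOT a Hartree-Fock energy.
    r = jnp.sqrt(jnp.sum((p.coords[1] - p.coords[0]) ** 2))
    return (r - 1.4) ** 2 + jnp.sum(p.ctr_coeff ** 2 * p.exp)


params = ToyParams(
    coords=jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.5]]),
    exp=jnp.array([3.4, 0.6, 0.2, 0.1]),
    ctr_coeff=jnp.array([0.15, 0.5, 0.9, 1.0]),
)

grad = jax.grad(toy_energy)(params)
# The gradient is again a ToyParams; each field has the shape of the input field.
print(type(grad).__name__)  # ToyParams
print(grad.coords.shape)    # (2, 3)
print(grad.exp, grad.ctr_coeff)
```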
4. Higher-order derivatives
Higher-order derivatives can also be computed, although with a much higher memory footprint. Two functions, jax.jacfwd and jax.jacrev, compute the Jacobian with forward- and reverse-mode differentiation, respectively; applying one of them to jax.grad(hf_energy) gives the Hessian.
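As a toy check (a made-up function `toy_energy` of six variables, not pyscfad), forward-over-reverse (`jax.jacfwd(jax.grad(f))`), reverse-over-reverse (`jax.jacrev(jax.grad(f))`) and JAX's own `jax.hessian` should all give the same symmetric matrix:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def toy_energy(x):
    # Hypothetical smooth scalar function of six variables.
    return jnp.sum(jnp.cos(x)) + jnp.prod(x)


x = jnp.linspace(0.1, 0.6, 6)

hess_fwd = jax.jacfwd(jax.grad(toy_energy))(x)  # forward-over-reverse
hess_rev = jax.jacrev(jax.grad(toy_energy))(x)  # reverse-over-reverse
hess_lib = jax.hessian(toy_energy)(x)           # JAX's built-in helper

print(hess_fwd.shape)  # (6, 6)
print(jnp.allclose(hess_fwd, hess_rev), jnp.allclose(hess_fwd, hess_lib))
```

Roughly, jax.jacfwd costs one forward pass per input and jax.jacrev one backward pass per output; for the square Jacobian of a gradient, the JAX autodiff cookbook suggests forward-over-reverse as typically the most efficient, and jax.hessian is built that way.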
hessian = jax.jacfwd(jax.grad(hf_energy))(mol)
print(f'Energy Hessians\n'
      f'∂^2E/∂R^2: {hessian.coords.coords.shape}\n'
      f'∂^2E/∂R∂ε: {hessian.coords.exp.shape}\n'
      f'∂^2E/∂R∂c: {hessian.coords.ctr_coeff.shape}\n')
Energy Hessians
∂^2E/∂R^2: (2, 3, 2, 3)
∂^2E/∂R∂ε: (2, 3, 4)
∂^2E/∂R∂c: (2, 3, 4)
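The nesting `hessian.coords.exp` follows JAX's convention for Jacobians of pytree-to-pytree functions: the outer level has the structure of the output (the gradient, itself Mole-shaped), each of its leaves is replaced by a pytree with the structure of the input, and so the shape of `hessian.a.b` is `a.shape + b.shape`. The sketch below reproduces the shapes above with a hypothetical `ToyParams` NamedTuple and a made-up energy, then flattens the coordinate block into an ordinary (6, 6) symmetric matrix:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


class ToyParams(NamedTuple):
    coords: jax.Array  # shape (2, 3)
    exp: jax.Array     # shape (4,)


def toy_energy(p):
    # Made-up energy that couples coordinates and exponents.
    r = jnp.sqrt(jnp.sum((p.coords[1] - p.coords[0]) ** 2))
    return (r - 1.4) ** 2 + r * jnp.sum(p.exp ** 2)


params = ToyParams(
    coords=jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.5]]),
    exp=jnp.array([3.4, 0.6, 0.2, 0.1]),
)

hess = jax.jacfwd(jax.grad(toy_energy))(params)
print(hess.coords.coords.shape)  # (2, 3, 2, 3)
print(hess.coords.exp.shape)     # (2, 3, 4)

# Flatten the (atom, xyz) index pairs to get a conventional 6 x 6 Hessian.
h_rr = hess.coords.coords.reshape(6, 6)
```

The mixed blocks `hess.coords.exp` and `hess.exp.coords` hold the same second derivatives with their axes in a different order.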
Note
Only first-order derivatives w.r.t. Mole.exp and Mole.ctr_coeff are available at the moment.
Third-order derivatives w.r.t. nuclear coordinates can be computed similarly.
third_order_deriv = jax.jacfwd(jax.jacfwd(jax.grad(hf_energy)))(mol)
print(f'∂^3E/∂R^3: {third_order_deriv.coords.coords.coords.shape}')
∂^3E/∂R^3: (2, 3, 2, 3, 2, 3)
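Each extra jax.jacfwd appends one more copy of the input shape, and because mixed partial derivatives commute, the flattened third-derivative tensor is symmetric under any permutation of its three indices. The sketch below checks both properties on a toy Morse potential (made-up parameters) standing in for the RHF energy:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def morse_energy(coords, d_e=0.17, a=1.0, r_e=1.4):
    # Toy Morse potential for a diatomic (made-up parameters, atomic units).
    r = jnp.sqrt(jnp.sum((coords[1] - coords[0]) ** 2))
    return d_e * (1.0 - jnp.exp(-a * (r - r_e))) ** 2


coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.5]])

# grad -> (2, 3); one jacfwd -> (2, 3, 2, 3); two jacfwd -> (2, 3, 2, 3, 2, 3)
third = jax.jacfwd(jax.jacfwd(jax.grad(morse_energy)))(coords)
print(third.shape)  # (2, 3, 2, 3, 2, 3)

# Flatten each (atom, xyz) pair into one index of length 6.
t = third.reshape(6, 6, 6)
```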