Automatic Differentiation for nonlinear operators


@Author: Francesco Picetti - picettifrancesco@gmail.com

In this notebook, we build a new kind of operator that leverages functorch, a recent entry in the PyTorch ecosystem. It lets us compute the Jacobian-vector product of any function defined with PyTorch primitives; a minimal example follows the imports below.

import torch
import occamypy as o

o.backend.set_seed_everywhere()

try:
    from functorch import jvp
except ModuleNotFoundError:
    raise ModuleNotFoundError("This submodule requires functorch to be installed. Do it with:\n\tpip install functorch")
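
As a quick sketch of what jvp computes (the variable names below are purely illustrative): given a function, a point, and a tangent direction, it returns both the function value and the Jacobian-vector product at that point along that direction.

t0 = torch.tensor([1., 2., 3.])
v = torch.ones_like(t0)
value, jvp_out = jvp(lambda t: t ** 2, (t0,), (v,))
print(value)    # f(t0) = t0**2          -> tensor([1., 4., 9.])
print(jvp_out)  # J(t0) @ v = 2 * t0 * v -> tensor([2., 4., 6.])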
class OperatorAD(o.Operator):

    def __init__(self, domain: o.VectorTorch, range: o.VectorTorch, fwd_fn, background: o.VectorTorch = None, name: str = None):
        """
        Generic operator whose forward function is user-defined and whose adjoint is computed via automatic differentiation
        
        Args:
            domain: operator domain vector
            range: operator range vector
            fwd_fn: torch-compatible forward function
            background: vector at which the Jacobian will be computed
            name: function name for print purpose
        """
        self.fwd_fn = fwd_fn
        
        # store the vector to be multiplied with the Jacobian
        self.domain_tensor = torch.ones_like(domain[:])
        # store the domain vector at which the function will be linearized
        self.background = background if background is not None else domain

        super(OperatorAD, self).__init__(domain=domain, range=range)
        self.name = "OperatorAD" if name is None else name
        
    def forward(self, add, model, data):
        self.checkDomainRange(model, data)
        
        if not add:
            data.zero()
        data[:] += self.fwd_fn(model[:])
        return
    
    def adjoint(self, add, model, data):
        self.checkDomainRange(model, data)
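        # the adjoint multiplies the data by the Jacobian of fwd_fn evaluated at the
        # background: jvp with a ones tangent returns J(background) @ 1, which for an
        # element-wise fwd_fn is its pointwise derivative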

        grad = jvp(self.fwd_fn, (self.background[:],), (self.domain_tensor,))[1]
        
        if not add:
            model.zero()
        model[:] += grad * data[:]
        return
    
    def set_background(self, in_content: o.VectorTorch):
        """
        Set the background vector at which the Jacobian will be computed

        Args:
            in_content: background vector
        """
        self.background.copy(in_content)

Case 1: linear function $f(x) = x$

x = o.VectorTorch(torch.tensor([1., 2., 3.]))
print("x =", x)
x = tensor([1., 2., 3.])

Forward function

f = lambda x: x

Analytical gradient

g = lambda x: 1.

Instantiate the operator

A = OperatorAD(domain=x, range=x, fwd_fn=f, name="x")
print("A is %s" % A.name)
A is x

Forward computation: $\mathbf{y} = f(\mathbf{x}) = \mathbf{A} \mathbf{x}$

y = A * x
assert torch.allclose(y[:], f(x[:]))
print("y = f(x) =", y)
y = f(x) = tensor([1., 2., 3.])

Adjoint computation: $\mathbf{z} = \frac{\partial f}{\partial \mathbf{x}}\Bigr|_{\mathbf{x}=\mathbf{x}_0} \mathbf{y}$

  1. set the linearization domain where the gradient is computed (default: the domain vector passed to __init__)
A.set_background(x)
print("The gradient ∂f/∂x will be computed at x0 =", A.background)
The gradient ∂f/∂x will be computed at x0 = tensor([1., 2., 3.])
  2. compute the gradient and multiply it with the data
z = A.H * y
assert torch.allclose(z[:], g(A.background[:]) * y[:])
print("z = ∂f/∂x0 * y =", z)
z = ∂f/∂x0 * y = tensor([1., 2., 3.])

Let's use another linearization point $\mathbf{x}_0$

x0 = x.clone().rand()
A.set_background(x0)
print("The gradient ∂f/∂x0 will be computed on x0 =", A.background)
The gradient ∂f/∂x0 will be computed on x0 = tensor([-0.0075,  0.5364, -0.8230])
z = A.H * y
assert torch.allclose(z[:], g(A.background[:]) * y[:])
print("z = ∂f/∂x0 * y =", z)
z = ∂f/∂x0 * y = tensor([1., 2., 3.])

As $f(x) = x$ is linear, we can compute the dot-product test for the operator $\mathbf{A}$.
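The dot-product test checks that forward and adjoint are mutually consistent, i.e. that $\langle \mathbf{A}\mathbf{x}, \mathbf{y} \rangle = \langle \mathbf{x}, \mathbf{A}^H \mathbf{y} \rangle$ holds up to numerical precision for random vectors $\mathbf{x}$ and $\mathbf{y}$.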

A.dotTest(True)
Dot-product tests of forward and adjoint operators
--------------------------------------------------
Applying forward operator add=False
 Runs in: 0.00019025802612304688 seconds
Applying adjoint operator add=False
 Runs in: 0.0007958412170410156 seconds
Dot products add=False: domain=3.146009e-01 range=3.146009e-01 
Absolute error: 0.000000e+00
Relative error: 0.000000e+00 

Applying forward operator add=True
 Runs in: 3.981590270996094e-05 seconds
Applying adjoint operator add=True
 Runs in: 0.0004761219024658203 seconds
Dot products add=True: domain=6.292018e-01 range=6.292018e-01 
Absolute error: 0.000000e+00
Relative error: 0.000000e+00 

-------------------------------------------------

Case 2: nonlinear function $f(x) = x \sin(x)$

x = o.VectorTorch(torch.tensor([1., 2., 3.]))
print("x =", x)
x = tensor([1., 2., 3.])

Forward function

f = lambda x: x * torch.sin(x)

Analytical gradient

g = lambda x: x * torch.cos(x) + torch.sin(x)
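
This follows from the product rule: $\frac{\mathrm{d}}{\mathrm{d}x}\bigl[x \sin(x)\bigr] = \sin(x) + x \cos(x)$.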

Instantiate the operator

A = OperatorAD(domain=x, range=x, fwd_fn=f, name="x sin(x)")
print("A is %s" % A.name)
A is x sin(x)

Forward computation: $\mathbf{y} = f(\mathbf{x}) = \mathbf{A} \cdot \mathbf{x}$

y = A * x
assert torch.allclose(y[:], f(x[:]))
print("y = f(x) =", y)
y = f(x) = tensor([0.8415, 1.8186, 0.4234])

Adjoint computation: $\mathbf{z} = \frac{\partial f}{\partial \mathbf{x}}\Bigr|_{\mathbf{x}=\mathbf{x}_0} \cdot \mathbf{y}$

print("The gradient ∂f/∂x0 will be computed on x0 =", A.background)
The gradient ∂f/∂x0 will be computed on x0 = tensor([1., 2., 3.])
z = A.H * y
assert torch.allclose(z[:], g(A.background[:]) * y[:])
print("z = ∂f1/∂x0 * y =", z)
z = ∂f1/∂x0 * y = tensor([ 1.1627,  0.1400, -1.1976])

Let's use another linearization point $\mathbf{x}_0$

x0 = x.clone().randn()
A.set_background(x0)
print("The gradient ∂f/∂x0 will be computed on x0 =", A.background)
The gradient ∂f/∂x0 will be computed on x0 = tensor([ 0.6472,  0.2490, -0.3354])
z = A.H * y
assert torch.allclose(z[:], g(A.background[:]) * y[:])
print("z = ∂f/∂x0 * y =", z)
z = ∂f/∂x0 * y = tensor([ 0.9419,  0.8870, -0.2734])

Finally, we can wrap $\mathbf{A}$ into a NonlinearOperator and compute the linearization test

B = o.NonlinearOperator(A)
_, _ = B.linTest(x, plot=True)
[Figure: linearization test of the nonlinear operator B]
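
As a rough, hand-rolled counterpart of that test (a minimal sketch, not the linTest internals), the linearization error for a perturbation $\epsilon\,\delta\mathbf{x}$ should decay roughly quadratically with $\epsilon$ for a smooth $f$:

x0 = x[:].clone()              # linearization point (plain torch tensor)
dx = torch.randn_like(x0)      # random perturbation direction
_, jdx = jvp(f, (x0,), (dx,))  # J(x0) @ dx via functorch
for eps in (1e-1, 1e-2, 1e-3):
    err = torch.norm(f(x0 + eps * dx) - f(x0) - eps * jdx)
    print(f"eps={eps:.0e}  linearization error = {err:.2e}")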