# Automatic Differentiation for nonlinear operators

Authors:
- Francesco Picetti (picettifrancesco@gmail.com)
- Ettore Biondi

In this notebook, we build a new kind of operator that leverages a new entry in the PyTorch ecosystem: functorch. It allows for the computation of the Jacobian-vector product of a function defined with PyTorch primitives.

import torch

import occamypy as o

# fix all random seeds for reproducibility of the tutorial
o.backend.set_seed_everywhere()

# functorch is an optional dependency: fail fast with an actionable message
try:
    from functorch import jvp
except ModuleNotFoundError as err:
    raise ModuleNotFoundError(
        "This submodule requires functorch to be installed. Do it with:\n"
        "\tpip install functorch"
    ) from err
class OperatorAD(o.Operator):
    """
    Generic operator whose forward function is user-defined and whose
    adjoint is computed with automatic differentiation (functorch jvp).

    NOTE(review): the adjoint multiplies data by the Jacobian-vector
    product of fwd_fn with a tangent of ones; this equals the true
    adjoint only for element-wise functions (diagonal Jacobian) —
    confirm before using with functions that mix components.
    """

    def __init__(self, domain: o.VectorTorch, range: o.VectorTorch, fwd_fn,
                 background: o.VectorTorch = None, name: str = None):
        """
        Generic operator whose forward is defined, and adjoint is computed
        with automatic differentiation.

        Args:
            domain: operator domain vector
            range: operator range vector
            fwd_fn: torch-compatible forward function
            background: vector in which the Jacobian will be computed
                (defaults to the domain vector itself)
            name: function name for print purposes
        """
        # register domain/range with the base class so that
        # checkDomainRange and operator algebra (A.H, A * x) work
        super().__init__(domain, range)
        self.fwd_fn = fwd_fn

        # tangent vector of ones: jvp(fwd_fn, x0, ones) yields the
        # diagonal of the Jacobian for element-wise functions
        self.domain_tensor = torch.ones_like(domain[:])
        # linearization point in which the Jacobian is evaluated
        self.background = background if background is not None else domain

        self.name = "OperatorAD" if name is None else name

    def forward(self, add, model, data):
        """Compute data = [data +] fwd_fn(model)."""
        self.checkDomainRange(model, data)
        if not add:
            data.zero()
        data[:] += self.fwd_fn(model[:])
        return

    def adjoint(self, add, model, data):
        """Compute model = [model +] (∂fwd_fn/∂x at background) * data."""
        self.checkDomainRange(model, data)

        # jvp returns the tuple (fwd_fn(background), J @ tangent);
        # only the tangent output (the gradient) is needed here
        _, grad = jvp(self.fwd_fn, (self.background[:],), (self.domain_tensor,))

        if not add:
            model.zero()
        model[:] += grad * data[:]
        return

    def set_background(self, in_content: o.VectorTorch):
        """
        Set the background vector in which the Jacobian will be computed.

        Args:
            in_content: background vector
        """
        self.background.copy(in_content)


## Case 1: linear function $f(x)=x$

# domain vector for the linear example
x = o.VectorTorch(torch.tensor([1.0, 2.0, 3.0]))
print("x =", x)

Forward function

# forward function: identity
def f(x):
    return x

# analytical gradient of the identity: constant 1
def g(x):
    return 1.

Instantiate the operator

# build the AD-backed operator for f(x) = x
A = OperatorAD(domain=x, range=x, fwd_fn=f, name="x")
print(f"A is {A.name}")

Forward computation: $\mathbf{y}=f(\mathbf{x}) = \mathbf{A} \mathbf{x}$

y = A * x
# the operator output must match a direct call to f
ref = f(x[:])
assert torch.allclose(y[:], ref)
print("y = f(x) =", y)

Adjoint computation: $\mathbf{z} = \frac{\partial f}{\partial \mathbf{x}}\Bigr|_{\mathbf{x}=\mathbf{x}_0} \mathbf{y}$

1. set the linearization domain where the gradient is computed (the default is the domain vector passed to `__init__`)
# choose the linearization point x0 at which ∂f/∂x is evaluated
A.set_background(x)
print("The gradient ∂f/∂x will be computed at x0 =", A.background)
2. compute the gradient and multiply it with data
z = A.H * y
# compare against the analytical gradient at the background point
ref = g(A.background[:]) * y[:]
assert torch.allclose(z[:], ref)
print("z = ∂f/∂x0 * y =", z)

Let's use another linearization point $\mathbf{x}_0$

# linearize around a random point instead of x
x0 = x.clone().rand()
A.set_background(x0)
print("The gradient ∂f/∂x0 will be computed on x0 =", A.background)
z = A.H * y
ref = g(A.background[:]) * y[:]
assert torch.allclose(z[:], ref)
print("z = ∂f/∂x0 * y =", z)

As $f(x)=x$ is linear, we can compute the dot product test for operator $\mathbf{A}$

# f(x)=x is linear, so forward and adjoint must pass the dot-product test
A.dotTest(True)

## Case 2: nonlinear function $f(x) = x \sin(x)$

# domain vector for the nonlinear example
x = o.VectorTorch(torch.tensor([1.0, 2.0, 3.0]))
print("x =", x)

Forward function

# forward function f(x) = x sin(x)
def f(x):
    return x * torch.sin(x)

# analytical gradient: f'(x) = sin(x) + x cos(x)
def g(x):
    return torch.sin(x) + x * torch.cos(x)

Instantiate the operator

# build the AD-backed operator for f(x) = x sin(x)
A = OperatorAD(domain=x, range=x, fwd_fn=f, name="x sin(x)")
print(f"A is {A.name}")

Forward computation: $\mathbf{y}=f(\mathbf{x}) = \mathbf{A} \cdot \mathbf{x}$

 y = A * x
assert torch.allclose(y[:], f(x[:]))
print("y = f(x) =", y)

Adjoint computation: $\mathbf{z} = \frac{\partial f}{\partial \mathbf{x}}\Bigr|_{\mathbf{x}=\mathbf{x}_0} \cdot \mathbf{y}$

print("The gradient ∂f/∂x0 will be computed on x0 =", A.background)
z = A.H * y
assert torch.allclose(z[:], g(A.background[:]) * y[:])
# label fixed: was "∂f1/∂x0", inconsistent with the other cells
print("z = ∂f/∂x0 * y =", z)

Let's use another linearization point $\mathbf{x}_0$

# linearize around a Gaussian-random point instead of x
x0 = x.clone().randn()
A.set_background(x0)
print("The gradient ∂f/∂x0 will be computed on x0 =", A.background)
z = A.H * y
ref = g(A.background[:]) * y[:]
assert torch.allclose(z[:], ref)
print("z = ∂f/∂x0 * y =", z)

Finally, we can wrap $\mathbf{A}$ into a NonlinearOperator and compute the linearization test

# wrap A as a NonlinearOperator and run the linearization test (with plot)
B = o.NonlinearOperator(A)
_, _ = B.linTest(x, plot=True)