– or, Forward mode Automatic Differentiation demystified

All introductions to Automatic Differentiation that I have seen seem to present the technique mysteriously. It’s actually very simple. I’ll describe how.

Since the very first time I learned calculus I have subconsciously understood “differentiating” to mean symbolically manipulating an expression to get another expression which represents its derivative. I understood “calculating the derivative” to mean “differentiating” followed by substituting in a value for a variable and reducing the expression to get the value of the derivative. “Differentiating” (2x + 1)^2 would give 4(2x + 1), and “calculating the derivative” of it at 8 involves substituting 8 for x to get 68.

Then I learned about Automatic Differentiation (AD). It claims to be able to calculate the derivative without differentiating! Given my subconscious understanding of what those terms meant this seemed bizarre, impossible and magical. It is claimed it doesn’t symbolically manipulate the expression at all. Apparently it gets to 68 without going via 4(2x + 1). How on earth can this be possible?

It turns out that AD *is* extremely cool but it *is not* mysterious. When I tried to understand it I was led astray by strange descriptions like

> Forward-mode AD is implemented by a nonstandard interpretation of the program in which real numbers are replaced by dual numbers, constants are lifted to dual numbers with a zero epsilon coefficient, and the numeric primitives are lifted to operate on dual numbers. This nonstandard interpretation is generally implemented using one of two strategies: source code transformation or operator overloading.

If you try to read introductions to AD you’ll probably come across a lot of passages like this which describe the rigmarole of shoehorning AD into a bog-standard proceduralish language, but they put an additional veil on the *intrinsic* meaning of AD, rather than lifting one. You’ll probably also get the impression that symbolic expressions are bad.

AD is actually a very simple concept, and yes it can work on symbolic expressions too.

If we have a type `E` of expressions of one variable `X` then we can easily write an evaluator for it.

```
{-# LANGUAGE LambdaCase #-}

data E = X | One | Zero | Negate E
       | Sum E E | Product E E | Exp E
       deriving Show

eval :: Double -> E -> Double
eval x = ev where
  ev = \case
    X            -> x
    One          -> 1
    Zero         -> 0
    Negate e     -> -ev e
    Sum e e'     -> ev e + ev e'
    Product e e' -> ev e * ev e'
    Exp e        -> exp (ev e)
```

It’s also easy to write a differentiator (implementing the usual rules of calculus)

```
diff :: E -> E
diff = \case
  X            -> One
  One          -> Zero
  Zero         -> Zero
  Negate e     -> Negate (diff e)
  Sum e e'     -> Sum (diff e) (diff e')
  Product e e' -> Product e (diff e') `Sum` Product (diff e) e'
  Exp e        -> Exp e `Product` diff e
```
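To tie this back to the example from the introduction, here is a small sketch that builds (2x + 1)^2 as a value of type `E`, differentiates it with `diff` and evaluates the result at 8. The names `two`, `inner`, `square` and `derivativeAt8` are illustrative helpers introduced for this sketch, not part of the code above; `E`, `eval` and `diff` are repeated only so that the block stands alone.

```haskell
{-# LANGUAGE LambdaCase #-}

data E = X | One | Zero | Negate E
       | Sum E E | Product E E | Exp E
       deriving Show

eval :: Double -> E -> Double
eval x = ev where
  ev = \case
    X            -> x
    One          -> 1
    Zero         -> 0
    Negate e     -> -ev e
    Sum e e'     -> ev e + ev e'
    Product e e' -> ev e * ev e'
    Exp e        -> exp (ev e)

diff :: E -> E
diff = \case
  X            -> One
  One          -> Zero
  Zero         -> Zero
  Negate e     -> Negate (diff e)
  Sum e e'     -> Sum (diff e) (diff e')
  Product e e' -> Product e (diff e') `Sum` Product (diff e) e'
  Exp e        -> Exp e `Product` diff e

-- Illustrative helpers below (assumptions of this sketch).

-- The constant 2, since E only has One and Zero as literals.
two :: E
two = Sum One One

-- 2x + 1
inner :: E
inner = Sum (Product two X) One

-- (2x + 1)^2, written as a product because E has no power constructor.
square :: E
square = Product inner inner

-- Symbolic differentiation followed by evaluation at x = 8.
derivativeAt8 :: Double
derivativeAt8 = eval 8 (diff square)
```

Evaluating `derivativeAt8` in GHCi gives `68.0`, the same number as substituting 8 into 4(2x + 1). Note that `diff square` itself is a much clumsier expression than 4(2x + 1), because `diff` does no simplification.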

and we can write simple expressions.

```
f :: E -> E
f x = Exp (x `minus` One)
  where a `minus` b = a `Sum` Negate b

smallExpression :: E
smallExpression = iterate f X !! 3

bigExpression :: E
bigExpression = iterate f X !! 1000
```

Working with small expressions is easy

```
> smallExpression
Exp (Sum (Exp (Sum (Exp (Sum X (Negate One))) (Negate One))) (Negate One))
> diff smallExpression
Product (Exp (Sum (Exp (Sum (Exp (Sum X (Negate One))) (Negate
One))) (Negate One))) (Sum (Product (Exp (Sum (Exp (Sum X (Negate
One))) (Negate One))) (Sum (Product (Exp (Sum X (Negate One)))
(Sum One (Negate Zero))) (Negate Zero))) (Negate Zero))
```

We can even differentiate them with `diff` and then evaluate their derivatives with `eval`

```
> mapM_ (print . flip eval (diff smallExpression)) [0.0009, 1, 1.0001]
0.12254834896191881
1.0
1.0003000600100016
```

but calculating the derivatives of big expressions this way is slow, in fact quadratic in the size of the expression.

```
> mapM_ (print . flip eval (diff bigExpression)) [0.00009, 1, 1.00001]
3.2478565715995278e-6
1.0
1.0100754777229357
-- Trust me, that was slow
```
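One way to see where the slowness comes from is to count constructors before and after differentiating. The sketch below does that; `size` and `sizes` are hypothetical helpers introduced for this illustration, and `E`, `diff` and `f` are repeated so that the block stands alone.

```haskell
{-# LANGUAGE LambdaCase #-}

data E = X | One | Zero | Negate E
       | Sum E E | Product E E | Exp E
       deriving Show

diff :: E -> E
diff = \case
  X            -> One
  One          -> Zero
  Zero         -> Zero
  Negate e     -> Negate (diff e)
  Sum e e'     -> Sum (diff e) (diff e')
  Product e e' -> Product e (diff e') `Sum` Product (diff e) e'
  Exp e        -> Exp e `Product` diff e

f :: E -> E
f x = Exp (x `minus` One)
  where a `minus` b = a `Sum` Negate b

-- Hypothetical helper: the number of constructors in an expression.
size :: E -> Int
size = \case
  X            -> 1
  One          -> 1
  Zero         -> 1
  Negate e     -> 1 + size e
  Sum e e'     -> 1 + size e + size e'
  Product e e' -> 1 + size e + size e'
  Exp e        -> 1 + size e

-- (size before differentiating, size after) for n applications of f.
sizes :: Int -> (Int, Int)
sizes n = (size e, size (diff e))
  where e = iterate f X !! n
```

`sizes 3` is `(13, 40)`, matching the two expressions printed earlier, and `sizes 1000` is `(4001, 2007001)`. Each application of `f` adds four constructors to the expression, but the `Exp` rule in `diff` places a full copy of the expression so far next to its derivative, so the differentiated expression grows quadratically. `eval` has no way to notice that these copies are identical, so it recomputes each of them.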

We can calculate derivatives efficiently (in time linear in the size of the expression) by using the *key idea* of AD. The *key idea* is to evaluate both the expression *and* its derivative *at the same time*. I’m not going to go into why this is faster (it’s to do with avoiding redundant calculation) but suffice it to know that this is indeed the unique, key idea of AD.

It’s not even hard to write. For each term, we evaluate the subterms and the derivative of the subterms, and then combine them using the usual rules of calculus. They’re exactly the same rules implemented by `eval` and `diff` above, but encoded slightly differently.

```
diffEval :: Double -> E -> (Double, Double)
diffEval x = ev where
  ev = \case
    X            -> (x, 1)
    One          -> (1, 0)
    Zero         -> (0, 0)
    Negate e     -> let (ex, ed) = ev e
                    in (-ex, -ed)
    Sum e e'     -> let (ex, ed)   = ev e
                        (ex', ed') = ev e'
                    in (ex + ex', ed + ed')
    Product e e' -> let (ex, ed)   = ev e
                        (ex', ed') = ev e'
                    in (ex * ex', ex * ed' + ed * ex')
    Exp e        -> let (ex, ed) = ev e
                    in (exp ex, exp ex * ed)
```

Take, for example, the branch for `Exp`

```
Exp e -> let (ex, ed) = ev e
         in (exp ex, exp ex * ed)
```

It says that the value of the exponential of `e` at `x` is `exp ex` where `ex` is the value of `e` at `x` (this is utterly trivial) and to calculate the derivative of the exponential of `e` at `x` we take `exp ex * ed`, where `ed` is the value of the derivative of `e` at `x` (this is just the definition of the derivative of the exponential function). That last sentence is a very long winded way of giving one tautology and one definition of the derivative of `exp`! The former is what is done in the `Exp` branch of `eval` and the latter is what is done in the `Exp` branch of `diff`, only here the latter is numeric rather than symbolic. Basically, there’s nothing going on here. Once we have our key idea, everything else falls out for free; this whole paragraph is a long way of saying nothing at all.

`diffEval` gives the same results as `diff` followed by `eval`

```
> mapM_ (print . snd . flip diffEval smallExpression) [0.0009, 1, 1.0001]
0.12254834896191881
1.0
1.0003000600100016
```

but is much quicker on large expressions

```
> mapM_ (print . snd . flip diffEval bigExpression) [0.00009, 1, 1.00001]
3.2478565715995278e-6
1.0
1.0100754777229357
-- Trust me, it was fast
```
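As a sanity check that does not rely on `diff` at all, the derivative reported by `diffEval` can be compared against a central finite-difference estimate built from `eval` alone. This is only a sketch: `centralDiff` and `gap` are hypothetical helpers, the step size `1e-6` is an arbitrary choice, and `E`, `eval`, `diffEval` and `f` are repeated so that the block stands alone.

```haskell
{-# LANGUAGE LambdaCase #-}

data E = X | One | Zero | Negate E
       | Sum E E | Product E E | Exp E
       deriving Show

eval :: Double -> E -> Double
eval x = ev where
  ev = \case
    X            -> x
    One          -> 1
    Zero         -> 0
    Negate e     -> -ev e
    Sum e e'     -> ev e + ev e'
    Product e e' -> ev e * ev e'
    Exp e        -> exp (ev e)

diffEval :: Double -> E -> (Double, Double)
diffEval x = ev where
  ev = \case
    X            -> (x, 1)
    One          -> (1, 0)
    Zero         -> (0, 0)
    Negate e     -> let (ex, ed) = ev e
                    in (-ex, -ed)
    Sum e e'     -> let (ex, ed)   = ev e
                        (ex', ed') = ev e'
                    in (ex + ex', ed + ed')
    Product e e' -> let (ex, ed)   = ev e
                        (ex', ed') = ev e'
                    in (ex * ex', ex * ed' + ed * ex')
    Exp e        -> let (ex, ed) = ev e
                    in (exp ex, exp ex * ed)

f :: E -> E
f x = Exp (x `minus` One)
  where a `minus` b = a `Sum` Negate b

-- Hypothetical helper: central finite-difference estimate with step h.
centralDiff :: Double -> Double -> E -> Double
centralDiff h x e = (eval (x + h) e - eval (x - h) e) / (2 * h)

-- Absolute gap between the AD derivative and the numerical estimate.
gap :: Double -> E -> Double
gap x e = abs (snd (diffEval x e) - centralDiff 1e-6 x e)
```

For an expression such as `iterate f X !! 3`, `gap` should come out tiny at each of the sample points used above. The finite-difference estimate is only approximate, whereas `diffEval` applies the rules of calculus exactly (up to floating-point rounding), which is one reason AD is preferred over numerical differencing.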

Even if, when you are reading an introduction to AD, you manage to distinguish a nugget of theory amongst the grime of the implementation details, you will probably still believe you need to write your numerical calculations to work on “dual numbers”, something like

`data D = Dual { value :: Double, derivative :: Double }`

But as we’ve seen above, that too is an implementation detail. Working with symbolic expressions is fine. In fact what `diffEval` does is give an interpretation of symbolic expressions `E` into dual numbers `D`.
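To make that interpretation concrete, here is a minimal sketch of the operator-overloading style. It uses the `D` type from above; the `Num` instance, the helper `expD`, the function `interpret`, the value `example` and the derived `Show` and `Eq` instances are all assumptions made for this illustration. `interpret` follows the same rules as `diffEval`, except that the arithmetic on pairs now lives in the instance rather than in the interpreter.

```haskell
{-# LANGUAGE LambdaCase #-}

data E = X | One | Zero | Negate E
       | Sum E E | Product E E | Exp E
       deriving Show

data D = Dual { value :: Double, derivative :: Double }
  deriving (Show, Eq)

-- Arithmetic on dual numbers: the second component follows the
-- sum and product rules of calculus.
instance Num D where
  Dual a a' + Dual b b' = Dual (a + b) (a' + b')
  Dual a a' * Dual b b' = Dual (a * b) (a * b' + a' * b)
  negate (Dual a a')    = Dual (negate a) (negate a')
  fromInteger n         = Dual (fromInteger n) 0  -- constants have derivative 0
  abs (Dual a a')       = Dual (abs a) (a' * signum a)
  signum (Dual a _)     = Dual (signum a) 0

-- Exponential on dual numbers (chain rule).
expD :: D -> D
expD (Dual a a') = Dual (exp a) (exp a * a')

-- Interpret a symbolic expression as a dual number at the point x.
interpret :: Double -> E -> D
interpret x = go where
  go = \case
    X            -> Dual x 1
    One          -> 1
    Zero         -> 0
    Negate e     -> negate (go e)
    Sum e e'     -> go e + go e'
    Product e e' -> go e * go e'
    Exp e        -> expD (go e)

-- exp (x - 1), one step of the running example.
example :: E
example = Exp (Sum X (Negate One))
```

`interpret 1 example` is a `Dual` whose `value` and `derivative` are both 1. The interpreter now looks almost identical to `eval`; the only differences are the result type and the starting pair `Dual x 1` for the variable.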

Mysterious comment: The distinction between `E` and `D` is exactly the same as the distinction between a free monad and a hand-rolled monad that contains the effects interpreted in a particular way.

AD is a very cool idea and at its heart is very simple. There are surely many important details that come later when you want to optimize your AD implementation or extend it to higher dimensions, but for the basics all you need is one key idea, and that is to calculate the value of the derivative at the same time as the value of the expression.

Jared Tobin wrote a nice little extension using catamorphisms. In fact if you use this technique then you can implement `eval` and `diff` separately but still get good performance when you compose them! I may write about this later …
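The following is not taken from that extension. It is only a minimal sketch of what a catamorphism over expressions could look like, assuming a hand-rolled base functor `EF` and fixed-point type `Fix` rather than any particular recursion-schemes library. All names in it (`EF`, `Fix`, `cata`, `evalAlg`, `diffEvalAlg`, `exampleF`) are hypothetical. It shows `eval` and `diffEval` expressed as folds; it does not show the trick of composing a separate `eval` and `diff` efficiently.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- One layer of an expression; r marks the positions of subexpressions.
data EF r = XF | OneF | ZeroF | NegateF r
          | SumF r r | ProductF r r | ExpF r
          deriving Functor

-- Tie the recursive knot: Fix EF plays the role of E.
newtype Fix f = Fix (f (Fix f))

-- The catamorphism: fold bottom-up, one layer at a time.
cata :: Functor f => (f a -> a) -> Fix f -> a
cata alg (Fix layer) = alg (fmap (cata alg) layer)

-- Algebra corresponding to eval: subresults are already numbers.
evalAlg :: Double -> EF Double -> Double
evalAlg x layer = case layer of
  XF           -> x
  OneF         -> 1
  ZeroF        -> 0
  NegateF a    -> negate a
  SumF a b     -> a + b
  ProductF a b -> a * b
  ExpF a       -> exp a

-- Algebra corresponding to diffEval: subresults are (value, derivative) pairs.
diffEvalAlg :: Double -> EF (Double, Double) -> (Double, Double)
diffEvalAlg x layer = case layer of
  XF                       -> (x, 1)
  OneF                     -> (1, 0)
  ZeroF                    -> (0, 0)
  NegateF (a, a')          -> (negate a, negate a')
  SumF (a, a') (b, b')     -> (a + b, a' + b')
  ProductF (a, a') (b, b') -> (a * b, a * b' + a' * b)
  ExpF (a, a')             -> (exp a, exp a * a')

-- exp (x - 1) in the fixed-point representation.
exampleF :: Fix EF
exampleF = Fix (ExpF (Fix (SumF (Fix XF) (Fix (NegateF (Fix OneF))))))
```

With these definitions, `cata (evalAlg x)` plays the role of `eval x` and `cata (diffEvalAlg x)` plays the role of `diffEval x`; the recursion is written once in `cata`, and each algebra only describes what to do with a single layer.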