– or, Forward mode Automatic Differentiation demystified
All introductions to Automatic Differentiation that I have seen seem to present the technique mysteriously. It’s actually very simple. I’ll describe how.
Since the very first time I learned calculus I have subconsciously understood “differentiating” to mean symbolically manipulating an expression to get another expression which represents its derivative. I understood “calculating the derivative” to mean “differentiating” followed by substituting in a value for a variable and reducing the expression to get the value of the derivative. “Differentiating” (2x + 1)^2 would give 4(2x + 1), and “calculating the derivative” of it at 8 involves substituting 8 for x to get 68.
Then I learned about Automatic Differentiation (AD). It claims to be able to calculate the derivative without differentiating! Given my subconscious understanding of what those terms meant, this seemed bizarre, impossible and magical. It is claimed that AD doesn’t symbolically manipulate the expression at all. Apparently it gets to 68 without going via 4(2x + 1). How on earth can this be possible?
It turns out that AD is extremely cool but it is not mysterious. When I tried to understand it I was led astray by strange descriptions like
Forward-mode AD is implemented by a nonstandard interpretation of the program in which real numbers are replaced by dual numbers, constants are lifted to dual numbers with a zero epsilon coefficient, and the numeric primitives are lifted to operate on dual numbers. This nonstandard interpretation is generally implemented using one of two strategies: source code transformation or operator overloading.
If you try to read introductions to AD you’ll probably come across a lot of passages like this, which describe the rigmarole of shoehorning AD into a bog-standard proceduralish language. They put an additional veil over the intrinsic meaning of AD, rather than lifting one. You’ll probably also get the impression that symbolic expressions are bad.
AD is actually a very simple concept, and yes it can work on symbolic expressions too.
If we have a type E of expressions of one variable X then we can easily write an evaluator for it.
{-# LANGUAGE LambdaCase #-}

-- Expressions in one variable, X
data E = X | One | Zero | Negate E
       | Sum E E | Product E E | Exp E
       deriving Show

-- Evaluate an expression at the point x
eval :: Double -> E -> Double
eval x = ev where
  ev = \case
    X            -> x
    One          -> 1
    Zero         -> 0
    Negate e     -> -ev e
    Sum e e'     -> ev e + ev e'
    Product e e' -> ev e * ev e'
    Exp e        -> exp (ev e)
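For example, we can encode the 2x + 1 from the introduction. E has no numeric literals, so 2x becomes X `Sum` X (the name twoXPlusOne is my own, introduced just for illustration):

-- 2x + 1, written with the constructors we have
twoXPlusOne :: E
twoXPlusOne = Sum (Sum X X) One

> eval 8 twoXPlusOne
17.0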
It’s also easy to write a differentiator (implementing the usual rules of calculus)
diff :: E -> E
diff = \case
  X            -> One
  One          -> Zero
  Zero         -> Zero
  Negate e     -> Negate (diff e)
  Sum e e'     -> Sum (diff e) (diff e')
  Product e e' -> Product e (diff e') `Sum` Product (diff e) e'
  Exp e        -> Exp e `Product` diff e
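Squaring twoXPlusOne recovers the (2x + 1)^2 from the introduction, and we can check that diff really does get us to 68 (square is again an illustrative name of mine):

-- (2x + 1)^2
square :: E
square = Product twoXPlusOne twoXPlusOne

> eval 8 (diff square)
68.0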
We can also write some simple expressions.
f :: E -> E
f x = Exp (x `minus` One)
  where a `minus` b = a `Sum` Negate b

smallExpression :: E
smallExpression = iterate f X !! 3

bigExpression :: E
bigExpression = iterate f X !! 1000
Working with small expressions is easy
> smallExpression
Exp (Sum (Exp (Sum (Exp (Sum X (Negate One))) (Negate One))) (Negate One))
> diff smallExpression
Product (Exp (Sum (Exp (Sum (Exp (Sum X (Negate One))) (Negate One)))
(Negate One))) (Sum (Product (Exp (Sum (Exp (Sum X (Negate One)))
(Negate One))) (Sum (Product (Exp (Sum X (Negate One))) (Sum One
(Negate Zero))) (Negate Zero))) (Negate Zero))
We can even differentiate them with diff and then evaluate their derivatives with eval
> mapM_ (print . flip eval (diff smallExpression)) [0.0009, 1, 1.0001]
0.12254834896191881
1.0
1.0003000600100016
but calculating the derivatives of big expressions this way is slow; in fact it is quadratic in the size of the expression, because diff duplicates subexpressions, so the derivative of an expression with n nodes can itself have on the order of n^2 nodes.
> mapM_ (print . flip eval (diff bigExpression)) [0.00009, 1, 1.00001]
3.2478565715995278e-6
1.0
1.0100754777229357
-- Trust me, that was slow
We can calculate derivatives efficiently (in time linear in the size of the expression) by using the key idea of AD, which is to evaluate both the expression and its derivative at the same time. I’m not going to go into why this is faster (it’s to do with avoiding redundant calculation), but suffice it to say that this is indeed the one key idea of forward mode AD.
It’s not even hard to write. For each term, we evaluate the subterms and the derivatives of the subterms, and then combine them using the usual rules of calculus. They’re exactly the same rules implemented by eval and diff above, but encoded slightly differently.
-- Evaluate an expression and its derivative at the point x, in a single pass
diffEval :: Double -> E -> (Double, Double)
diffEval x = ev where
  ev = \case
    X            -> (x, 1)
    One          -> (1, 0)
    Zero         -> (0, 0)
    Negate e     -> let (ex, ed) = ev e
                    in (-ex, -ed)
    Sum e e'     -> let (ex, ed)   = ev e
                        (ex', ed') = ev e'
                    in (ex + ex', ed + ed')
    Product e e' -> let (ex, ed)   = ev e
                        (ex', ed') = ev e'
                    in (ex * ex', ex * ed' + ed * ex')
    Exp e        -> let (ex, ed) = ev e
                    in (exp ex, exp ex * ed)
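Running it on the square expression from earlier confirms that we get the value and the derivative together, in one pass:

> diffEval 8 square
(289.0,68.0)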
Take, for example, the branch for Exp
Exp e -> let (ex, ed) = ev e
         in (exp ex, exp ex * ed)
It says that the value of the exponential of e at x is exp ex, where ex is the value of e at x (this is utterly trivial), and that to calculate the derivative of the exponential of e at x we take exp ex * ed, where ed is the value of the derivative of e at x (this is just the chain rule combined with the fact that exp is its own derivative). That last sentence is a very long-winded way of giving one tautology and one rule of calculus! The former is what is done in the Exp branch of eval and the latter is what is done in the Exp branch of diff, only here the latter is numeric rather than symbolic. Basically, there’s nothing going on here. Once we have our key idea, everything else falls out for free; this whole paragraph is a long way of saying nothing at all.
diffEval gives the same results as diff followed by eval
> mapM_ (print . snd . flip diffEval smallExpression) [0.0009, 1, 1.0001]
0.12254834896191881
1.0
1.0003000600100016
but is much quicker on large expressions
> mapM_ (print . snd . flip diffEval bigExpression) [0.00009, 1, 1.00001]
3.2478565715995278e-6
1.0
1.0100754777229357
-- Trust me, it was fast
Even if, when you are reading an introduction to AD, you manage to distinguish a nugget of theory amongst the grime of the implementation details, you will probably still believe you need to write your numerical calculations to work on “dual numbers”, something like
data D = Dual { value :: Double, derivative :: Double }
But as we’ve seen above, that too is an implementation detail. Working with symbolic expressions is fine. In fact, what diffEval does is give an interpretation of symbolic expressions E into dual numbers D.
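For concreteness, here is a minimal sketch of that interpretation. The Num instance and the name toDual are my own illustration, not something the original development requires; abs and signum are stubbed out because we never use them:

-- A sketch (mine): lift the numeric primitives to dual numbers
instance Num D where
  Dual x x' + Dual y y' = Dual (x + y) (x' + y')
  Dual x x' * Dual y y' = Dual (x * y) (x * y' + x' * y)
  negate (Dual x x')    = Dual (negate x) (negate x')
  fromInteger n         = Dual (fromInteger n) 0  -- constants have zero derivative
  abs    = error "abs: not needed for this sketch"
  signum = error "signum: not needed for this sketch"

-- Interpret a symbolic expression as a dual number at the point x
toDual :: Double -> E -> D
toDual x = ev where
  ev = \case
    X            -> Dual x 1  -- the variable has derivative 1
    One          -> Dual 1 0
    Zero         -> Dual 0 0
    Negate e     -> negate (ev e)
    Sum e e'     -> ev e + ev e'
    Product e e' -> ev e * ev e'
    Exp e        -> let Dual v d = ev e
                    in Dual (exp v) (exp v * d)

Modulo names, toDual is exactly diffEval with the pair wrapped in a constructor.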
Mysterious comment: the distinction between E and D is exactly the same as the distinction between a free monad and a hand-rolled monad that contains the effects interpreted in a particular way.
Forward mode AD is a very cool idea and at its heart is very simple. There are surely many important details that come later when you want to optimize your AD implementation or extend it to higher dimensions, but for the basics all you need is one key idea, and that is to calculate the value of the derivative at the same time as the value of the expression.
Jared Tobin wrote a nice little extension using catamorphisms. In fact, if you use this technique then you can implement eval and diff separately but still get good performance when you compose them! I may write about this later …