In another article about demystifying Automatic Differentiation (AD) I explained how to use the key idea of AD to calculate derivatives in time linear in the size of a symbolic expression, after mentioning that the naive approach is quadratic. But why is the naive approach quadratic? In this article I’ll explain.
Recall the diffEval function which implements the key idea of AD.
    diffEval :: Double -> E -> (Double, Double)
    diffEval x = ev where
      ev = \case
        X            -> (x, 1)
        One          -> (1, 0)
        Zero         -> (0, 0)
        Negate e     -> let (ex, ed) = ev e
                        in (-ex, -ed)
        Sum e e'     -> let (ex, ed)   = ev e
                            (ex', ed') = ev e'
                        in (ex + ex', ed + ed')
        Product e e' -> let (ex, ed)   = ev e
                            (ex', ed') = ev e'
                        in (ex * ex', ex * ed' + ed * ex')
        Exp e        -> let (ex, ed) = ev e
                        in (exp ex, exp ex * ed)
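The expression type E is not defined in this article. As a rough sketch, the following declaration is what the constructors matched above suggest; it is inferred from those patterns rather than copied from the earlier article, and the name example is purely illustrative. The \case syntax used throughout needs GHC's LambdaCase extension.

```haskell
{-# LANGUAGE LambdaCase #-}

-- Assumed shape of the expression type, inferred from the constructors
-- that diffEval pattern-matches on (not taken from the earlier article).
data E
  = X            -- the single variable we differentiate with respect to
  | One          -- the constant 1
  | Zero         -- the constant 0
  | Negate E
  | Sum E E
  | Product E E
  | Exp E
  deriving Show

-- Hypothetical example expression: x * exp x
example :: E
example = Product X (Exp X)
```

With a type like this in scope, diffEval 0 example should return the pair (0.0, 1.0), since x * exp x is 0 at x = 0 and its derivative, exp x + x * exp x, is 1 there.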
If you look at how each expression node is handled you’ll notice that there is one recursive call per child node, leading to linear run time. An alternative implementation could have been this:
    diffEvalSlow :: Double -> E -> (Double, Double)
    diffEvalSlow x = ev where
      ev = \case
        X            -> (x, 1)
        One          -> (1, 0)
        Zero         -> (0, 0)
        Negate e     -> let ex = fst (ev e)
                            ed = snd (ev e)
                        in (-ex, -ed)
        Sum e e'     -> let ex  = fst (ev e)
                            ed  = snd (ev e)
                            ex' = fst (ev e')
                            ed' = snd (ev e')
                        in (ex + ex', ed + ed')
        Product e e' -> let ex  = fst (ev e)
                            ed  = snd (ev e)
                            ex' = fst (ev e')
                            ed' = snd (ev e')
                        in (ex * ex', ex * ed' + ed * ex')
        Exp e        -> let ex = fst (ev e)
                            ed = snd (ev e)
                        in (exp ex, exp ex * ed)
diffEvalSlow has two recursive calls per child node, leading to quadratic run time.
    > mapM_ (print . snd . flip diffEval bigExpression) [0.00009, 1, 1.00001]
    3.2478565715995278e-6
    1.0
    1.0100754777229357
    -- ^^ Trust me, it was fast
    > mapM_ (print . snd . flip diffEvalSlow bigExpression) [0.00009, 1, 1.00001]
    3.2478565715995278e-6
    1.0
    1.0100754777229357
    -- ^^ Trust me, it was slow
There is a similar reason for the slowness of the naive approach to calculating the derivative of symbolic expressions. The naive approach is to first diff the expression, and then eval it.
diff and eval are as follows:
    eval :: Double -> E -> Double
    eval x = ev where
      ev = \case
        X            -> x
        One          -> 1
        Zero         -> 0
        Negate e     -> -ev e
        Sum e e'     -> ev e + ev e'
        Product e e' -> ev e * ev e'
        Exp e        -> exp (ev e)

    diff :: E -> E
    diff = \case
      X            -> One
      One          -> Zero
      Zero         -> Zero
      Negate e     -> Negate (diff e)
      Sum e e'     -> Sum (diff e) (diff e')
      Product e e' -> Product e (diff e') `Sum` Product (diff e) e'
      Exp e        -> Exp e `Product` diff e
You can see that, like diffEval, both of these functions have only one recursive call per child node, leading to linear run time in the size of their input. So what’s going on? Why is composing them quadratic?
Have a look at the Product case:
    Product e e' -> Product e (diff e') `Sum` Product (diff e) e'
There are only two recursive calls to diff, but it doubles the number of child nodes from two to four. The key observation is:
The size of the output of diff can be quadratic in the size of the input!
eval then runs over this quadratic-sized output, so the run time of eval composed with diff can be quadratic in the size of the original input.
The Exp case doubles the number of nodes too, as would potential cases for Sin, Cos or any other function where the derivative is given by the chain rule.
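To make the growth concrete, here is a small sketch that counts nodes before and after differentiation. The data declaration for E is an assumption inferred from the constructors used in this article, and size, productChain and expChain are hypothetical helpers introduced only for this illustration; diff is the definition given above.

```haskell
{-# LANGUAGE LambdaCase #-}

-- Assumed expression type, inferred from the constructors used in this article.
data E = X | One | Zero | Negate E | Sum E E | Product E E | Exp E

-- Symbolic derivative, as defined in the article.
diff :: E -> E
diff = \case
  X            -> One
  One          -> Zero
  Zero         -> Zero
  Negate e     -> Negate (diff e)
  Sum e e'     -> Sum (diff e) (diff e')
  Product e e' -> Product e (diff e') `Sum` Product (diff e) e'
  Exp e        -> Exp e `Product` diff e

-- Hypothetical helper: number of nodes in an expression tree,
-- i.e. the number of nodes eval would have to visit.
size :: E -> Int
size = \case
  X            -> 1
  One          -> 1
  Zero         -> 1
  Negate e     -> 1 + size e
  Sum e e'     -> 1 + size e + size e'
  Product e e' -> 1 + size e + size e'
  Exp e        -> 1 + size e

-- Hypothetical inputs: x * (x * (... * x)) built from n Products,
-- and exp (exp (... x)) built from n Exps.
productChain, expChain :: Int -> E
productChain n = iterate (Product X) X !! n
expChain n     = iterate Exp X !! n
```

In GHCi, map (size . productChain) [10, 20, 40] gives [21,41,81], while map (size . diff . productChain) [10, 20, 40] gives [151,501,1801]. For expChain the same inputs have sizes [11,21,41] and their derivatives have sizes [76,251,901]. Doubling the input roughly quadruples the output in both cases. Note that size counts a subexpression once for every place it occurs, which matches the work eval does, since eval does not remember results for repeated subexpressions.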
The key idea of AD is to evaluate e and its derivative in one recursive call, leading to linear run time, rather than doubling the amount of work which needs to be done at each level, which leads to quadratic run time.