Automatic differentiation – forwards and reverse
This article demonstrates how to perform source transformations on a program to generate forward mode and reverse mode derivative programs (automatic differentiation, or “AD”). My aim is to write the shortest possible article that communicates all the essential features of a source-to-source AD system with a particular focus on making the reverse mode transformation clear.
The goal of brevity means that a lot of possible commentary has been omitted. If you find this makes some part of the article hard to understand then please contact me and I’ll do my best to clarify. In particular this article contains hardly any mathematical content at all. I hope that the reader who is familiar with multivariate calculus will be able to obtain an intuitive understanding of how AD relates to mathematical techniques he or she is already familiar with. A more in-depth description of the relationship will have to wait for another article.
Let’s consider the following pseudocode program that performs some elementary arithmetic through a sequence of assignment statements.
p = 7 * x
r = 1 / y
q = p * x * 5
v = 2 * p * q + 3 * r
x and y are not defined in the program so I’m going to informally consider them to be “inputs”; v is not used anywhere so I’m going to consider it to be the “output”. (I won’t burden the article by formalising these notions here.)
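To make things concrete, here is a direct transcription of the program into Python (my choice of language for the sketches in this article; the pseudocode itself is language-neutral), evaluated at sample inputs:

def program(x, y):
    p = 7 * x
    r = 1 / y
    q = p * x * 5
    v = 2 * p * q + 3 * r
    return v

print(program(x=2.0, y=3.0))  # 490*x**3 + 3/y = 3921.0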
We’ll do a small amount of preparation to our original program which will preserve its behaviour and get it into a form in which it is straightforward to apply the automatic differentiation (AD) algorithms. It is possible to apply AD algorithms without doing these transformations first but then the AD algorithms would have to do equivalent operations implicitly. Doing these transformations first is a kind of separation of concerns.
Let’s use prefix functions instead of infix operators. Infix operators are more familiar for arithmetic but the AD algorithms will be clearer to present if we use prefix functions. Additionally I want every function to have exactly one argument (although that argument may be a tuple). Single-argument style will make the reverse mode transformation much clearer (although it does not make any difference for forward mode). For example, x1 + x2 would become add (x1, x2). Our program becomes
p = mul (7, x)
r = div (1, y)
q = mul (mul (p, x), 5)
v = add (mul (mul (2, p), q), mul (3, r))
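As a sketch of what the single-argument style looks like in a real language, here is a Python rendering (mine, not part of the transformation); note that each function takes exactly one argument, which is a pair:

def add(args):
    (x1, x2) = args
    return x1 + x2

def mul(args):
    (x1, x2) = args
    return x1 * x2

def div(args):
    (x1, x2) = args
    return x1 / x2

x, y = 2.0, 3.0
p = mul((7, x))
r = div((1, y))
q = mul((mul((p, x)), 5))
v = add((mul((mul((2, p)), q)), mul((3, r))))
print(v)  # 3921.0, same as the infix version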
Next let’s convert to a form where every function is applied to (tuples of) variables and constants only, i.e. where there are no nested sub-expressions (besides potentially nested tuples). We assign each nested sub-expression to an intermediate variable. For example
a = add (add (b, c), d)
would become
i = add (b, c)
a = add (i, d)
The choice of i is arbitrary; it just has to be a variable that’s not used elsewhere in our program. This form without nested subexpressions is a lot like ANF from the field of functional compiler construction. It’s also a lot like the SSA (static single assignment) form used in compiler intermediate representations. After removing nested subexpressions, our program becomes
p = mul (7, x)
r = div (1, y)
i1 = mul (p, x)
q = mul (i1, 5)
i2 = mul (2, p)
i3 = mul (i2, q)
i4 = mul (3, r)
v = add (i3, i4)
We have performed all the transformations needed to prepare our program and we are ready to proceed to differentiation. We will differentiate the program line-by-line, that is, both the forward mode and reverse mode differentiation algorithms will generate one line of derivative code for each line of input code. But what is the derivative of an assignment statement? For forward mode, the derivatives correspond quite closely to what you might be familiar with from a first multivariate calculus course.
If a line of our pseudocode program is
y = add (x1, x2)
then the derivative line is
dy = add (dx1, dx2)
If a line of our pseudocode program is
y = mul (x1, x2)
then the derivative line is
dy = add (mul (x2, dx1), mul (x1, dx2))
If a line of our pseudocode program is
y = div (x1, x2)
then the derivative line is
dy = add (div (dx1, x2), negate (mul (div (x1, mul (x2, x2)), dx2)))
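These rules can be sanity-checked numerically. The following sketch (my own check, not part of the transformation) compares the mul and div rules against finite differences:

# Check the rule for y = mul (x1, x2): dy = x2*dx1 + x1*dx2
x1, x2, dx1, dx2, eps = 1.5, -2.0, 0.3, 0.7, 1e-6
rule = x2 * dx1 + x1 * dx2
finite_diff = ((x1 + eps * dx1) * (x2 + eps * dx2) - x1 * x2) / eps
print(rule, finite_diff)  # agree to roughly 1e-6

# Check the rule for y = div (x1, x2): dy = dx1/x2 - (x1/(x2*x2))*dx2
rule = dx1 / x2 - (x1 / (x2 * x2)) * dx2
finite_diff = ((x1 + eps * dx1) / (x2 + eps * dx2) - x1 / x2) / eps
print(rule, finite_diff)  # agree to roughly 1e-6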
The forward mode transformation applies the appropriate differentiation rule to each line in the input program (listed on the left) to obtain a derivative line (listed on the right). To each line we apply exactly one rule and the form of the rule does not depend on any of the other lines.
p = mul (7, x) | dp = mul (7, dx)
r = div (1, y) | dr = negate (div (dy, mul (y, y)))
i1 = mul (p, x) | di1 = add (mul (dp, x), mul (p, dx))
q = mul (i1, 5) | dq = mul (di1, 5)
i2 = mul (2, p) | di2 = mul (2, dp)
i3 = mul (i2, q) | di3 = add (mul (di2, q), mul (i2, dq))
i4 = mul (3, r) | di4 = mul (3, dr)
v = add (i3, i4) | dv = add (di3, di4)
If we form a new program consisting of the sequence of assignments on the left followed by the sequence of assignments on the right then we have a program that calculates the forward derivative! The “inputs” of this program are x, y, dx and dy. The “outputs” are v and dv.
(The derivatives of constants are zero and I’ve left terms that are zero out for simplicity.)
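Transcribed into Python (my rendering, using infix operators for readability), the forward derivative program looks like this, and its output can be checked against a finite difference:

def forward(x, y, dx, dy):
    # Assignments from the left column (the original program)
    p = 7 * x
    r = 1 / y
    i1 = p * x
    q = i1 * 5
    i2 = 2 * p
    i3 = i2 * q
    i4 = 3 * r
    v = i3 + i4
    # Assignments from the right column (the derivative program)
    dp = 7 * dx
    dr = -(dy / (y * y))
    di1 = dp * x + p * dx
    dq = di1 * 5
    di2 = 2 * dp
    di3 = di2 * q + i2 * dq
    di4 = 3 * dr
    dv = di3 + di4
    return (v, dv)

x, y, dx, dy, eps = 2.0, 3.0, 1.0, 0.0, 1e-6
(v, dv) = forward(x, y, dx, dy)
(v_shifted, _) = forward(x + eps * dx, y + eps * dy, dx, dy)
print(dv, (v_shifted - v) / eps)  # both approximately 1470*x**2 = 5880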
In fact we can be a little more clever. We can interleave the assignments, so an assignment from the left is immediately followed by its corresponding assignment from the right, that is
p = mul (7, x)
dp = mul (7, dx)
r = div (1, y)
dr = negate (div (dy, mul (y, y)))
i1 = mul (p, x)
di1 = add (mul (dp, x), mul (p, dx))
q = mul (i1, 5)
dq = mul (di1, 5)
i2 = mul (2, p)
di2 = mul (2, dp)
i3 = mul (i2, q)
di3 = add (mul (di2, q), mul (i2, dq))
i4 = mul (3, r)
di4 = mul (3, dr)
v = add (i3, i4)
dv = add (di3, di4)
This interleaving demonstrates an important property of the automatic derivative: it uses space proportional to the space usage of the original program. Specifically, as soon as we no longer need a variable that was assigned in the original program we no longer need the corresponding d version either.
We can also see another important property of the forward derivative: it runs in time proportional to the run time of the original program (assuming that the derivative of every primitive runs in time proportional to the run time of the primitive itself).
Now that we’ve shown how to generate the forward mode derivative we can move on to the reverse mode derivative. Reverse mode requires two additional ideas:
We need to convert our original program to “explicit duplication” form: if a variable is used more than once then we make that explicit in the structure of the program. This is unusual but straightforward.
We need to use a form of the derivative that will be unfamiliar to most readers. It will appear quite bizarre when seeing it for the first time but it is crucial to implementing the reverse mode derivative.
Before applying the reverse mode AD transformation we will convert to “explicit duplication” form. Again, the transformation is not strictly required but if we omit it then the differentiation pass will have to do it implicitly. We take the ANF form of the program and insert explicit duplications (dup) for any variable that is used more than once. Recall that after removing nested subexpressions our program was
p = mul (7, x)
r = div (1, y)
i1 = mul (p, x)
q = mul (i1, 5)
i2 = mul (2, p)
i3 = mul (i2, q)
i4 = mul (3, r)
v = add (i3, i4)
We can see that x and p appear on the right hand side (i.e. are consumed) twice each. Therefore, they will need explicit duplication, so that each variable in the resulting program is used only once. With explicit duplication the program looks like
(x1, x2) = dup x
p = mul (7, x1)
(p1, p2) = dup p
r = div (1, y)
i1 = mul (p1, x2)
q = mul (i1, 5)
i2 = mul (2, p2)
i3 = mul (i2, q)
i4 = mul (3, r)
v = add (i3, i4)
(If a variable were used \(n\) times then we would have to insert \(n-1\) dups for it. In our example no variable is used more than twice.)
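In Python terms, dup is nothing more exotic than a function returning two copies of its argument, and the explicit-duplication program (my transcription) computes exactly the same value as before:

def dup(x):
    return (x, x)

def program_dup(x, y):
    (x1, x2) = dup(x)
    p = 7 * x1
    (p1, p2) = dup(p)
    r = 1 / y
    i1 = p1 * x2
    q = i1 * 5
    i2 = 2 * p2
    i3 = i2 * q
    i4 = 3 * r
    v = i3 + i4
    return v

print(program_dup(2.0, 3.0))  # 3921.0, unchanged from the original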
Notice that now not only is every variable defined exactly once, but every variable is also used exactly once (except the inputs and outputs, x, y and v – I won’t say more here about how exactly these seemingly special cases fit into the story). This property is important for a reason which will be explained when we come to generate the reverse mode program.
The line-by-line differentiation rules for generating the reverse mode need another article to explain thoroughly, but in this article I hope to provide some basic intuition via examples and the informal notion that the reverse mode program calculates how sensitive the output is to different variables. For example, if the variable y appears in the original program then the variable d_dy will appear in the reverse mode program and measures “how sensitive the output is to small changes in y”. (I’ll abbreviate this to “d_dy is the sensitivity to y”.)
If a line of our pseudocode program is
y = add (x1, x2)
then the derivatives are
d_dx1 = d_dy
d_dx2 = d_dy
because the sensitivity to x1 is the same as the sensitivity to y (and likewise for x2). This is written on a single line as
(d_dx1, d_dx2) = dup (d_dy)
If a line of our pseudocode program is
y = mul (x1, x2)
then the derivative line is
(d_dx1, d_dx2) = (mul (x2, d_dy), mul (x1, d_dy))
because the sensitivity to x1 is x2 times the sensitivity to y (and similarly for x2).
If a line of our pseudocode program is
(x1, x2) = dup (x)
then the derivative line is
d_dx = add (d_dx1, d_dx2)
because the sensitivity to x is the sensitivity to x1 plus the sensitivity to x2.
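As a sketch, here are the three reverse rules written as Python functions that map the sensitivity of a line’s output to the sensitivities of its inputs (the function names are my own):

def reverse_add(d_dy):
    # y = add (x1, x2): each input receives the full sensitivity
    return (d_dy, d_dy)

def reverse_mul(d_dy, x1, x2):
    # y = mul (x1, x2): each input's sensitivity is the *other*
    # input's (primal) value times the sensitivity to y
    return (x2 * d_dy, x1 * d_dy)

def reverse_dup(d_dx1, d_dx2):
    # (x1, x2) = dup (x): the sensitivities of the two copies add
    return d_dx1 + d_dx2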
Like forward mode before it, the reverse mode transformation applies the appropriate differentiation rule to each line in the input program (listed on the left) to obtain a derivative line (listed on the right). To each line we apply exactly one rule and the form of the rule does not depend on any of the other lines. (In the table I write the mul rule in a condensed form: (d_dx1, d_dx2) = mul (d_dy, (x2, x1)) means the same as (d_dx1, d_dx2) = (mul (x2, d_dy), mul (x1, d_dy)).)
(x1, x2) = dup x | d_dx = add (d_dx1, d_dx2)
p = mul (7, x1) | (_, d_dx1) = mul (d_dp, (x1, 7))
(p1, p2) = dup p | d_dp = add (d_dp1, d_dp2)
r = div (1, y) | d_dy = negate (div (d_dr, mul (y, y)))
i1 = mul (p1, x2) | (d_dp1, d_dx2) = mul (d_di1, (x2, p1))
q = mul (i1, 5) | (d_di1, _) = mul (d_dq, (5, i1))
i2 = mul (2, p2) | (_, d_dp2) = mul (d_di2, (p2, 2))
i3 = mul (i2, q) | (d_di2, d_dq) = mul (d_di3, (q, i2))
i4 = mul (3, r) | (_, d_dr) = mul (d_di4, (r, 3))
v = add (i3, i4) | (d_di3, d_di4) = dup (d_dv)
If we form a new program consisting of the sequence of assignments on the left followed by the sequence of assignments on the right in reverse order then we have a program that calculates the reverse derivative! The “inputs” of this program are x, y and d_dv. The “outputs” are v, d_dx and d_dy.
(Note that, similar to how in forward mode we omitted derivatives of constants because they are zero, in reverse mode we omit the calculation of derivatives with respect to constants because they have no effect on the rest of the program.)
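Here is the reverse derivative program transcribed into Python (my rendering): the left column first, then the right column in reverse order. With d_dv = 1, the resulting d_dx agrees with the dv that forward mode computes for dx = 1, dy = 0:

def reverse(x, y, d_dv):
    # Left column: the original program in explicit-duplication form
    (x1, x2) = (x, x)
    p = 7 * x1
    (p1, p2) = (p, p)
    r = 1 / y
    i1 = p1 * x2
    q = i1 * 5
    i2 = 2 * p2
    i3 = i2 * q
    i4 = 3 * r
    v = i3 + i4
    # Right column, read bottom to top
    (d_di3, d_di4) = (d_dv, d_dv)
    d_dr = 3 * d_di4
    (d_di2, d_dq) = (q * d_di3, i2 * d_di3)
    d_dp2 = 2 * d_di2
    d_di1 = 5 * d_dq
    (d_dp1, d_dx2) = (x2 * d_di1, p1 * d_di1)
    d_dy = -(d_dr / (y * y))
    d_dp = d_dp1 + d_dp2
    d_dx1 = 7 * d_dp
    d_dx = d_dx1 + d_dx2
    return (v, d_dx, d_dy)

(v, d_dx, d_dy) = reverse(2.0, 3.0, 1.0)
print(d_dx)  # 1470*x**2 = 5880.0, matching forward mode
print(d_dy)  # -3/y**2 = -1/3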
The explicit duplication property is important because in the code generated by the reverse mode transformation, usages in the original program become definitions in the reverse mode program; correspondingly, definitions in the original program become usages in the reverse mode program. Therefore it is important that there is exactly one of each: a variable cannot have two definitions! For example, consider the original source assignment (x1, x2) = dup x. x is “used” in this line, and x1 and x2 are “defined”. It gives rise to the assignment d_dx = add (d_dx1, d_dx2) in the reverse mode program. d_dx is “defined” in this line and d_dx1 and d_dx2 are “used”.
Again we see an important property of the automatic derivative: it runs in time proportional to the run time of the original program (assuming that the derivative of every primitive runs in time proportional to the run time of the primitive itself).
If we look carefully we can also see another property of the reverse derivative: it might use space proportional to the run time of the original program! Notice that the value of x1 needs to be kept around throughout the lifetime of the program so that we can calculate d_dx. Once we’ve calculated a value we can’t just use it and throw it away, like we could in forward mode. (There’s a technique called “checkpointing” to address this which prefers to rerun computations rather than store their results, decreasing space usage but increasing run time. Siskind and Pearlmutter have a useful introduction to checkpointing.)
This is a description of how to differentiate what are called “straight-line” programs, that is, it does not cover recursion, loops, or conditionals. Arrays are not explicitly treated here either, although they fit naturally into this framework. Dealing properly with those concepts would require extending the presentation given here into a significantly longer article.
An in-depth analysis of the sense in which the resulting programs are the “derivative” of the input program will have to wait for another article.
Source-to-source forward and reverse mode automatic differentiation can be expressed as follows
Apply simple transformations to get your program into a form where it can be differentiated line-by-line.
Apply the differentiation rule for each line separately.
The forward mode differentiation rules are quite close to what you might already be familiar with. The reverse mode rules are probably not, which might go some way to explaining why the reverse mode derivative has a reputation for being very mysterious.
If you have any questions then please contact me.
I learned the “explicit duplication form” from Tom Minka’s talk From automatic differentiation to message passing. Thanks to Mark Saroufim and Pashmina Cameron for helpful feedback.