自動微分を簡単な例で理解する

自動微分は、デリバティブプライシングではXVAの感応度など、多くのインプットパラメーターについての微分を一斉に求めるのに使われている。機械学習でも、誤差逆伝播法で明示的に用いられている。これら応用例はどちらもリバースモードの自動微分を用いている。

自動微分には2種類ある。

(1)フォワードモード

(2)リバースモード

応用上はリバースモードが重要だが、フォワードモードとリバースモードを対比することでリバースモードの理解が深まるため、以下では2つを両方取り上げて、簡単な例で説明してみる。

例として、以下の関数を考える。

f(u, v) = uv

u(x, y) = x + 3y

v(x, y) = 2x + y

つまり、uとvは、(x, y)の関数になっており、関数fは、関数uと関数vの関数になっている。応用上は、もっと関数の入れ子構造(ネスト)が多く、一つ一つの関数ももっと複雑なものだが、ここでは簡単な関数を例にとって説明する。

(1)フォワードモード

記号の準備:

$u, $vなどと$を付けると、それはxについての微分か、あるいは、yについての微分、のどちらかを表すものとする。すなわち、$uは、uをxで微分したものか、あるいは、uをyで微分したものの、どちらかを表す。

このとき、$xは、xについての微分であれば1、yについての微分であれば0である。なぜなら、xをxで微分すれば1、xをyで微分すれば0であるからだ。

一方で、$yは、xについての微分であれば0、yについての微分であれば1である。すると、

($x = 1 かつ $y = 0) もしくは ($x =0 かつ $y = 1)

となる。左の( )はxについての微分を表すケース、右の( )はyについての微分を表すケースである。このことから、記号$は、xについての微分とyについての微分を切り替える役割を担っていることがわかる。この記号を導入することにより、xについての微分とyについての微分を1つの式にまとめて書くことができる。

次に、$uと$vがどうなるかを考えてみよう。u, vはどちらもx, yについての関数であるから、チェーンルールにより、

$u = (∂u/∂x) $x + (∂u/∂y) $y = 1 $x + 3 $y = $x + 3$y

$v = (∂v/∂x) $x + (∂v/∂y) $y = 2 $x + 1 $y = 2$x + $y

と書ける。同様に、$fはどうなるかを見てみる。fはu, vについての関数であるから、チェーンルールにより、

$f = (∂f/∂u) $u + (∂f/∂v) $v = v $u + u $v

と書ける。ここで、$u, $vはすでに上で求めており、$x, $yを使って表せている。$x, $yは、($x = 1 かつ $y = 0) もしくは ($x =0 かつ $y = 1)となることがわかっている。

以上から、関数fのxについての微分、関数fのyについての微分を求めるには、$fを計算すればよいが、それは以下の順序で計算されることがわかる。

($x, $y) -> ($u, $v) -> $f

つまり、f(u(x, y), v(x, y))と入れ子になっている関数の、内側から外側に向かって計算していく。これがフォワードモードの自動微分である。

しかし注意点として、今回の例の場合、出発点の$xと$yがとりうる値が、2通りに分かれている、という点である。つまりfのxについての微分と、fのyについての微分を両方得るには、($x, $y) -> ($u, $v) -> $f とたどっていく計算を、2回行わないといけない。すなわち、($x = 1 かつ $y = 0)の場合の計算と、($x =0 かつ $y = 1)の計算をそれぞれ行わないといけない。

したがって一般に、インプットのパラメーターが多く、アウトプットの数が少ないケースでは、フォワードモードを使うことはない。むしろ、インプットのパラメーターが少なく、アウトプットの数が多いケースに向いている。

応用上出くわすケースはたいてい、インプットのパラメーターが多く、アウトプットの数が1つとかで少ないケースである。このように多くのパラメーターについての微分を一斉に計算するのに自動微分が用いられることが多いが、その場合は以下に示すリバースモードを用いることになる。

(2)リバースモード

記号の準備:

&u, &vなどと&を付けると、それは最も外側の関数fをその変数について微分したものを表すとする。つまり、&uは、fをuで微分したもの、&vは、fをvで微分したものになる。フォワードモードでは記号がインプット (x, y) 目線であったのに対し、リバースモードでは記号がアウトプット (f) 目線であることに注意。

この記号を用いて、今度はfから順番に見ていくと、

&f = df/df = 1

である。次に&u, &vをチェーンルールで求めると、

&u = (∂f/∂u) (df/df) = v &f

&v = (∂f/∂v) (df/df) = u &f

となる。最後に、&x, &yを見てみると、チェーンルールにより、

&x = (∂f/∂x) = (∂f/∂u)(∂u/∂x) + (∂f/∂v)(∂v/∂x) = &u 1 + &v 2

= &u + 2&v

&y = (∂f/∂y) = (∂f/∂u)(∂u/∂y) + (∂f/∂v)(∂v/∂y) = &u 3 + &v 1

= 3&u + &v

以上から、関数fのxについての微分、関数fのyについての微分を求めるには、&x, &yを計算すればよいが、それは以下の順序で計算されることがわかる。

&f -> (&u, &v) -> (&x, &y)

つまり、f(u(x, y), v(x, y))と入れ子になっている関数の、外側から内側に向かって計算していく。これがリバースモードの自動微分である。リバースモードは通常とは逆方向に計算が流れていくことに注意。&u, &vが求まれば、それらから&x, &yの両方が一斉に出るのが特徴である。

ここで重要なのは、インプットパラメーターが増えても、アウトプットが f ひとつだけであれば、&x, &yはいずれもfを微分したものであるから、フォワードモードのように場合分けして同様の計算を複数行う必要がない、ということだ。このように、リバースモードはインプットパラメーターが多く、アウトプットが少ないケースに用いられる。