cha_kabuのNotebooks

個人的な機械学習関連勉強のアウトプット置き場です。素人の勉強録なので、こちらに辿り着いた稀有な方、情報はあまり信じない方が身のためです。

ゼロから作るDeep Learning3 フレームワーク編を読む その⑦ステップ41

はじめに

シリーズ記事の続編で前記事はこちら↓です。 cha-kabu.hatenablog.com

ですが、今回の内容的にはこちら↓の記事との情報の関連性が高いです。 cha-kabu.hatenablog.com

ステップ41 行列の積

行列積を計算するMatMulクラス、そしてそのインスタンス化とforward処理を行ってくれるmatmul関数を実装します。実装コード自体は今までの延長で難しいことは無いのですが、行列積の逆伝播を理解するのが数学苦手にはしんどいです…魔のステップ37に引き続き、適宜書籍よりも低いレベルから情報をまとめていきます。

参考にしたサイト様

以下の先人たちのまとめを参考にさせて頂きました。本記事で紹介する逆伝播の勾配算出方法はどのページとも異なるのですが、とても参考になりました。なお、本記事に誤った情報があった場合は当然ながらすべて私の理解力の無さが原因であり、引用先の皆様の責任は一切ありません。

頭の中に思い浮かべた時には

Pythonと機械学習

Qiita 【機械学習】誤差逆伝播法のコンパクトな説明

Qiita 行列の和と積の誤差逆伝播法の証明

諸注意

まず、以降の説明で混乱しないために以下の用語を抑えておいてください。

出力1つ 出力複数
入力1つ (スカラ値)関数 ベクトル値関数
入力複数 多変数(スカラ値)関数 多変数ベクトル値関数

ステップ37のまとめでも、一応多変数ベクトル値関数の微分までをまとめました。今回は行列積なので行列の行列微分ではありません。というのも未だに良く分かってないのですが、どうやら「行列の行列微分」ってもの自体が存在しないみたいなんですよね(参考:おしえてgoo)。

「行列積の微分なんだから行列を行列で微分できないとダメじゃん!」と思っていたのですが、先にネタバレをしてしまうと、「行列積の微分そのもの」に注目するのではなく、「最終出力\displaystyle{L}(スカラ)を微分するとはどういうことか?」から考えていって逆伝播時に下流に流す勾配を求めます。図にすると以下の感じです。(matmulは行列積を行う関数、太字の大文字は行列を表します。)

f:id:cha_kabu:20201117000600p:plain

今までの逆伝播実装では、その時点での勾配を直接(数式的に)求め、上流から流れてくる勾配と掛け合わせて下流に流す実装を行ってきました。\displaystyle{\frac{\partial L}{\partial x}=\frac{\partial L}{\partial y}\frac{\partial y}{\partial x}}\displaystyle{\frac{\partial y}{\partial x}}を直接求める方法です。一方、今回はこれまでのものとは違い、逆算的に(?)\displaystyle{\frac{\partial L}{\partial \bf{X}}}はどうやって表すことができるか?に向かって数式を組み立てていく方法になります。

その際重要になってくるのが連鎖律です。スカラ値関数の連鎖律は今までも多用してきたので問題ないと思いますが、合成関数に多変数関数やベクトル値関数が含まれる場合の連鎖律についてまとめた後、逆伝播計算について考えていきます。

なお、以降ベクトルや行列を使った表現が出てきますがそれぞれ幾何的なイメージ(ベクトルは矢印、とか)は持たない方が納得しやすいと思います。単純に、「ベクトルや行列の形でまとめて書いた方がスッキリ書けるからそうしているだけ」で、表記のために使っていると割り切らないと頭がごちゃごちゃになっていきます。自分は「行列とベクトルの積は線形写像だからこの計算の意味は…」とか考え始めてドツボに嵌りました。

目次的なもの

諸注意で記載した通り、連鎖律→逆伝播を実際にやってみる流れでまとめていきます。

  1. 多変数関数の連鎖律

  2. 逆伝播を実際にやってみる

多変数関数の連鎖律

まずは連鎖律についてまとめます。序盤のステップでも連鎖律自体は使われており、見た目同じなので混乱は少ないと思います。ただし、序盤で出ていた連鎖律は単変数関数のものでした。多変数関数になると形は似ているのですが、「総和をとる」点が異なってきます。

簡単な具体例

まずは証明抜きに簡単な二変数関数の連鎖律を例を見ていきます。\displaystyle{z}\displaystyle{u}\displaystyle{v}の関数、\displaystyle{u}\displaystyle{x}の関数、\displaystyle{v}\displaystyle{y}の関数とします。数式で表すと以下の通りです。

\displaystyle{



z = f(u,v)\\
u = g(x)\\
v = h(y)

}

この時、\displaystyle{z}\displaystyle{x}偏微分した値は以下の数式で表すことができます。

\displaystyle{



\dfrac{\partial z}{\partial x}=\dfrac{\partial z}{\partial u}\dfrac{\partial u}{\partial x}+\dfrac{\partial z}{\partial v}\dfrac{\partial v}{\partial x}

}

更に具体的な例で本当に正しそうか見てみます。例えばu=x^{2}+x+1、v=2y+1として、\displaystyle{z=u+v=(x^ {2}+x+1)+(2y+1)}となる場合を考えます。この\displaystyle{z}を数式的に\displaystyle{x}偏微分するのは簡単ですね。\displaystyle{2x+1}になります。この結果を先ほどの連鎖律を使って出した結果と見合わせます。

f:id:cha_kabu:20201117000623p:plain

確かにあっていそうです。

一般化

先ほどの例は二変数に限ったものでした。これを三変数、四変数、…の時にも適用できるように一般化すると、以下の様に表せられます。

\displaystyle{



fがu_{1},u_{2},...,u_{m}の関数で、u_{1},u_{2},...,u_{m}がxの関数の時、\\
\begin{align}
\dfrac{\partial f}{\partial x}=\sum_{j=1}^{m}\dfrac{\partial f}{\partial u_{j}}\dfrac{\partial u_{j}}{\partial x} \tag{1}
\end{align}

}

変数がいくつになっても良い様に、多くの変数を\displaystyle{u_{1},u_{2},...,u_{m}}で表しているのが文字に慣れていないと混乱しますが、先ほどの二変数の例で言うと\displaystyle{u}\displaystyle{u_{1}}に、\displaystyle{v}\displaystyle{u_{2}}に対応しています。先ほどの例から\displaystyle{v}が消えて\displaystyle{u}が増えたわけではないのでご注意ください。別々のアルファベットで表そうとしても最大26文字で使い切ってしまうので、仕方なく添え字で区別し、添え字を使うことで総和記号を使って短く書けるようになっているだけです。

こちらの証明については難しいので諦めた長くなるのでこちら九州大学講義ノートなどでご確認ください。

なお、式\displaystyle{(1)}を見て「二変数関数の例の\displaystyle{y}にあたる様な、\displaystyle{x}と対になる変数はどこにいったの?」と混乱してしまう人もいるかも知れません。これは式\displaystyle{(1)}では\displaystyle{x}についての偏微分のみを問題にしており、変数\displaystyle{y}は単なる定数(\displaystyle{x}偏微分すると0)と見なせるからです。\displaystyle{y}について同じことをやりたければ、式\displaystyle{(1)}\displaystyle{x}\displaystyle{y}と置き換えるだけです。

特殊形

さて、式\displaystyle{(1)}はある特殊な条件下では実はもっと簡単な形で書くことができます。特殊な条件とは、\displaystyle{u_{1}}\displaystyle{x}の関数、\displaystyle{u_{2}}\displaystyle{y}の関数、…といった様にそれぞれの\displaystyle{u}が一つの変数しか持たないときです。具体例で見た方が分かりやすいと思うので、\displaystyle{u_{1},u_{2},u_{3}}がそれぞれ\displaystyle{x,y,z}の一つだけを変数に持つ場合を考えます。

f:id:cha_kabu:20201117000647p:plain

結局残るのは黒字部分になります。そう考えると結局0になる部分も含めて総和をとる必要はないので式\displaystyle{(1)}\displaystyle{\Sigma}は無くすことができ、以下の様に書き換えられます。

\displaystyle{



fがu_{1},u_{2},...,u_{m}の関数で、u_{x}がxだけを変数に持つ関数の時、\\
\begin{align}\dfrac{\partial f}{\partial x}=\dfrac{\partial f}{\partial u_{x}}\dfrac{\partial u_{x}}{\partial x} \tag{1'}\end{align}

}

\displaystyle{u_{x}}は一般的な表記ではないかと思いますが、ここでは「\displaystyle{x}だけを変数にもつ関数」の意味で使用しています。

こちらの特殊形はあくまで特殊形で、公式というよりかは式\displaystyle{(1)}の条件を指定しただけのものなのですが、後で出てきますので覚えておいてください。

式の表現を変えてみる

ここまでのところで多変数関数の連鎖律について学びましたのでもう具体的に逆伝播を考えても良いのですが、少し脱線して「これまでの式って書き換えることができるよね」という話です。逆伝播を考える際にも出てくると言えば出てきますが、文字いっぱいで辛ければ飛ばしてください。

\displaystyle{(1)}を見直してみると、ベクトルや行列、その内積を使って表記ができることに気づきます。式を見るよりも展開して書いてみた方が分かりやすいかと思いますので、書き下してみます。なお、こちらでは書籍に合わせて出力を行ベクトルとして表し、かつ転置の記号は付けていません。しかしDeepLearning以外の文脈で多変数関数の連鎖律を調べると、出力を列ベクトルで表記していることがほとんどです。結果が縦に並んでいるか横に並んでいるかだけの違いで本質的には同じですが、数式はそれらのものとは異なりますのでご注意ください。

また、変数を\displaystyle{x,y,z,...}で表していると数に限りがあるので、ここからは\displaystyle{x_{1},x_{2},x_{3}...}の形で表すことにします。

f:id:cha_kabu:20201117000710p:plain

画像内最後の数式に注目頂きたいのですが、この式は参照として記載している多変数ベクトル値関数を全微分した式ととても似た形をしています。ただし、(出力を行ベクトルで表したので)右辺の成分が全微分のときのヤコビ行列×変化量の形から、変化量×ヤコビ行列の形に逆転しています。出力を列ベクトルに合わせるとこの逆転は元に戻るのですが、同時にヤコビ行列(緑字部分)は転置の形になります。

出力も多変数だったら(多変数ベクトル値関数の連鎖律)

これまで見てきた数式展開では、出力はスカラ値であることを前提としていました。出力も多変数、すなわち\displaystyle{f}が多変数ベクトル値関数の場合にどうなるかを見てみます。こちらの数式は書き下すと大変なので、コンパクトにまとめます。

f:id:cha_kabu:20201117000740p:plain

黄色のところは参考までに記載しています。赤の行列を\displaystyle{\bf{W}}、青を\displaystyle{\bf{V}}、緑を\displaystyle{\bf{U}}と呼ぶならば、\displaystyle{\bf{W}_{11}}成分の計算方法はの1行目\displaystyle{\bf{V}}成分と緑の行列の1列目成\displaystyle{\bf{U}}分の内積で表せられます。

逆伝播を実際にやってみる

少し脱線しましたがいよいよ行列積の逆伝播について考えていきます。

本記事で今まで使用してきた記号体系と変わってしまいますが、書籍と表記を合わせて以下の逆伝播を考えます。ただし行列とベクトルの表記については一部書籍に従わず、行列は大文字の太字、ベクトルは小文字の太字で表すことにします。

※書籍では\displaystyle{\bf{X}}がベクトルの時のことを先に考えていますが、行列について考えれば網羅できるので本記事では省略します。

f:id:cha_kabu:20201117000757p:plain

目標は図の中の\displaystyle{\frac{\partial L}{\partial \bf{X}}}\displaystyle{\frac{\partial L}{\partial \bf{W}}}がどう計算できるかを考えることです。\displaystyle{\frac{\partial L}{\partial \bf{Y}}}については、今回上流の計算が無いので実際の値は分かりませんが、実際に逆伝播するときには上流の計算は終わっているはずので既知と仮定します(\displaystyle{\frac{\partial L}{\partial \bf{X}}}\displaystyle{\frac{\partial L}{\partial \bf{W}}}を表す計算式の中に残っても良い)。また\displaystyle{L}は最終出力のスカラです。大文字ですが行列ではないのでご注意ください。

前準備①\displaystyle{\bf{Y}}について

逆伝播の前に、順伝播の時に計算される\displaystyle{\bf{Y}}とは何なのか、具体的に見ておきます。順伝播の図から\displaystyle{\bf{X}}\displaystyle{\bf{W}}の積であることは明らかですが、その各要素の計算は以下の様になります(意味ありげに書いていますが、ただの行列積です)。

\displaystyle{



\begin{pmatrix}
y_{11} & \cdots & y_{1h} & \cdots & y_{1H} \\
\vdots & \ddots & \vdots & \ddots & \vdots \\
y_{n1} & \cdots & y_{nh} & \cdots & y_{nH} \\
\vdots & \ddots & \vdots & \ddots & \vdots \\
y_{N1} & \cdots & y_{Nh} & \cdots & y_{NH} \\
\end{pmatrix}
=
\begin{pmatrix}
x_{11} & \cdots & x_{1d} & \cdots & x_{1D} \\
\vdots & \ddots & \vdots & \ddots & \vdots \\
x_{n1} & \cdots & x_{nd} & \cdots & x_{nD} \\
\vdots & \ddots & \vdots & \ddots & \vdots \\
x_{N1} & \cdots & x_{Nd} & \cdots & x_{ND} \\
\end{pmatrix}
\begin{pmatrix}
w_{11} & \cdots & w_{1h} & \cdots & w_{1H} \\
\vdots & \ddots & \vdots & \ddots & \vdots \\
w_{d1} & \cdots & w_{dh} & \cdots & w_{dH} \\
\vdots & \ddots & \vdots & \ddots & \vdots \\
w_{D1} & \cdots & w_{Dh} & \cdots & w_{DH} \\
\end{pmatrix}

}

行列の積の定義から明らかですが、例えば\displaystyle{y_{11}}の計算は\displaystyle{\bf{X}}の1行目成分と\displaystyle{\bf{W}}の一列目成分の内積で計算され、以下の様に表されます。

\displaystyle{



y_{11}=x_{11}w_{11}+x_{12}w_{21}+…x_{1d}w_{d1}+…x_{1D}w_{D1}

}

これを一般化すると、\displaystyle{\bf{Y}}\displaystyle{nh}成分\displaystyle{y_{nh}}は以下の様に計算されます。

\displaystyle{



y_{nh}=\sum_{d=1}^{D}x_{nd}w_{dh}

}

\displaystyle{n,h}\displaystyle{\Sigma}の中にある変数なので段々数を増やしてループしたくなりますが、今回は\displaystyle{\Sigma}の影響は受けない添え字ですのでご注意ください。\displaystyle{y_{nh}}をどれかに定めるとそれと一緒に決定的に決まる変数です。\displaystyle{\Sigma}の影響を受けるのは\displaystyle{x}の列番号と\displaystyle{w}の行番号で、それぞれ常に同じ値\displaystyle{d}となります。また、仮定された行列のサイズより\displaystyle{d}の最大値は\displaystyle{D}です。

前準備②この記事だけで使う記法の確認

後々の説明を考え、他ではあまり見ないオリジナルの記法を用いたいと思います。行列\displaystyle{\bf{A}}の1行目成分(行ベクトル)を、\displaystyle{\bf{a_1}}の形で「太小文字に行番号を表す添え字」で表します。今回出てくる記号について具体的に書くと以下図の通りです。

f:id:cha_kabu:20201117000820p:plain

前準備はここまでです。それでは\displaystyle{\frac{\partial L}{\partial \bf{X}}}\displaystyle{\frac{\partial L}{\partial \bf{W}}}をどうやって求めるのか、個別に見ていきましょう。

\displaystyle{\bf{X}}方向に流れる勾配

\displaystyle{\frac{\partial L}{\partial \bf{X}}}を求めていきます。スカラの行列微分ですので、サイズは\displaystyle{\bf{X}}と同じ\displaystyle{(N \times D)}になります。この後どうやってこれを求めていくかですが、\displaystyle{\bf{X}}の一要素\displaystyle{x_{nd}}での偏微分\displaystyle{\frac{\partial L}{\partial x_{nd}}}について考えていき、後でそれを\displaystyle{\frac{\partial L}{\partial \bf{X}}}に拡張します。

\displaystyle{x_{nd}}と添え字が変数の状態だと分かりづらくなってしまうので、下図の通り\displaystyle{x_{11},x_{12},x_{21}}について具体的に見ていって、後から一般化したいと思います。

f:id:cha_kabu:20201117000829p:plain

まず、3つを代表して\displaystyle{x_{11}}\displaystyle{L}偏微分する―\displaystyle{x_{11}}が少し動くと\displaystyle{L}はどう変わるのかを求める―際の連鎖律を考えます。連鎖律を考える際、途中で\displaystyle{\partial \bf{Y}}が何かしらの形で中継地点として出てくるのは想像できるかと思いますが、\displaystyle{x_{11}}\displaystyle{\bf{Y}}にどのような影響を与えているのでしょうか?

ここで前準備のところで導出した、\displaystyle{y_{nh}=\sum_{d=1}^ {D}x_{nd}w_{dh}}を思い出してください。\displaystyle{x_{nd}}\displaystyle{d}(列方向成分)については総和をとるので多くの\displaystyle{y}に関わりそうですが、\displaystyle{x}の行方向成分が\displaystyle{n}であれば、\displaystyle{y}の行方向成分も\displaystyle{n}で決まります。つまり、\displaystyle{x_{11}}が影響を与えるのは、\displaystyle{\bf{Y}}の1行目成分\displaystyle{\bf{y}_{1}}のみということが分かります。そう考えると連鎖律の特殊形が使え(総和記号が不要)、\displaystyle{\frac{\partial L}{\partial x_{11}}}他は以下の様に表すことができます。

\displaystyle{



\dfrac{\partial L}{\partial x_{11}}=\dfrac{\partial L}{\partial {\bf{y}}_{1}}\dfrac{\partial {\bf{y}}_{1}}{\partial x_{11}}\\

\dfrac{\partial L}{\partial x_{12}}=\dfrac{\partial L}{\partial {\bf{y}}_{1}}\dfrac{\partial {\bf{y}}_{1}}{\partial x_{12}}\\

\dfrac{\partial L}{\partial x_{21}}=\dfrac{\partial L}{\partial {\bf{y}}_{2}}\dfrac{\partial {\bf{y}}_{2}}{\partial x_{21}}

}

続いて引き続き\displaystyle{x_{11}}を代表に、連鎖律の際右辺\displaystyle{\frac{\partial \bf{y}_{1}}{\partial x_{11}}}が何なのかについて考えます。\displaystyle{\bf{Y}}の1行目成分\displaystyle{\bf{y}_{1}}\displaystyle{x_{11}}偏微分していますが、そもそも\displaystyle{\bf{y}_{1}}はどのように求められるものだったでしょうか。各要素\displaystyle{y_{nh}}\displaystyle{n=1}の時に限って考えることになるので、\displaystyle{y_{1h}=\sum_{d=1}^ {D}x_{1d}w_{dh}}で求められます。そしてそれぞれ\displaystyle{x_{11}}について偏微分するとペアになる\displaystyle{w_{1h}}だけが残ることになり、\displaystyle{\frac{\partial \bf{y}_{1}}{\partial x_{11}}=\bf{w}_{1}}になります。少しややこしいのでこのことを書き下してみます。

f:id:cha_kabu:20201117000839p:plain

ということで、他も同様に考えて各連鎖律の式は以下の様にアップデートできます。

\displaystyle{



\dfrac{\partial L}{\partial x_{11}}=\dfrac{\partial L}{\partial \bf{y}_{1}}{\bf{w}}_{1}\\

\dfrac{\partial L}{\partial x_{12}}=\dfrac{\partial L}{\partial \bf{y}_{1}}{\bf{w}}_{2}\\

\dfrac{\partial L}{\partial x_{21}}=\dfrac{\partial L}{\partial \bf{y}_{2}}{\bf{w}}_{1}

}

良い感じに\displaystyle{\frac{\partial L}{\partial \bf{X}}}の各要素が分かってきたので、具体的に行列に落とし込んでみます。繰り返しになりますが\displaystyle{\frac{\partial L}{\partial \bf{X}}}はスカラ\displaystyle{L}を行列\displaystyle{\bf{X}}偏微分しているので、スケールは\displaystyle{(N \times D)}であることを思い出しておいてください。

f:id:cha_kabu:20201117000850p:plain

なんだか凶悪な行列に変化しましたね…しかし、実はこれでもうほぼ完成です!変形後の行列の各要素を見てみます。\displaystyle{\bf{w}_d}の次元数は、\displaystyle{\bf{W}}の列数に該当するので\displaystyle{H}です。そして\displaystyle{\frac{\partial L}{\partial \bf{y}_{n}}}偏微分になっているので分かりにくいですがスカラのベクトル微分なので要はベクトルで、次元数は\displaystyle{\bf{Y}}の列数に該当するのでこちらも\displaystyle{H}です。となると、凶悪に見える各要素はただの次元数\displaystyle{H}のベクトルどうしの内積でしかなく、行列積で表せそうです!パズルの様にどういう行列積で表せられるか考え(慣れてたら一瞬で分かるのでしょうが…)、書いてみます。

f:id:cha_kabu:20201117000858p:plain

だいぶスッキリしました!\displaystyle{\frac{\partial L}{\partial \bf{Y}}}のサイズは\displaystyle{(N \times H)}\displaystyle{\bf{W}^ {T}}のサイズは\displaystyle{(H \times D)}なので行列積が成り立つ条件も満たしています。さらに\displaystyle{\frac{\partial L}{\partial \bf{Y}}}は上流から流れてくる勾配なので既知、\displaystyle{\bf{W}^ {T}}も順伝播の時に使用するものを転置しただけなので既知でしたので、無事\displaystyle{\bf{X}}方向の下流に計算可能な勾配を流すことができることが分かりました。

\displaystyle{\bf{W}}方向に流れる微分

続いて\displaystyle{\frac{\partial L}{\partial \bf{W}}}を求めますが、求め方としては\displaystyle{\frac{\partial L}{\partial \bf{X}}}と同様の考えで求めることができます。しかしまた同じような話を書いてもしょうが無いので、少しズルをします。先ほど求めた通り、\displaystyle{\frac{\partial L}{\partial \bf{X}}=\frac{\partial L}{\partial \bf{Y}}{\bf{W}}^ {T}}です。また、\displaystyle{\bf{Y=XW}}であり、転置の性質より\displaystyle{\bf{Y^ {T}=W^ {T}X^ {T}}}です。これを求めたものに対応させて置き換えて求めます。図で書くと以下の様な感じです。

f:id:cha_kabu:20201117000907p:plain

こちらも既知の成分で表すことができました!

最後に

以上でステップ41のまとめ終了です。ステップ37と合わせて1週間くらい時間かかりました…個人的な一番の学びは、「テンソル微分は意味は考えずに"そう表せるからそうしてるだけ"と考えた方が分かりやすい」です。なんちゃら写像やらなんりゃら座標やらの知識を中途半端に掻い摘んで沼に嵌っていきました…次のステップからはスピード上げていきたいところです。