ニューラルネットワークのミニバッチ、オンライン学習

 2019/04/23    ニューラルネットワーク    

ニューラルネットワークではどれだけのデータを用いて重みパラメータの更新を行えばいいのでしょう。

学習するためのデータの扱い方によって、バッチ学習、ミニバッチ学習、オンライン学習といった学習手法が存在します。

オンライン学習は確率的勾配降下法、またはSGD(Stochastic Gradient Descent)と呼ばれることもあります。

この記事ではこれらの学習手法について説明します。

ニューラルネットワークの学習

ニューラルネットワークの重みパラメータを学習させるためにはデータが必要です。この学習に用いるデータを訓練データと呼びます。

ニューラルネットワークでは勾配法を用いた以下の作業を繰り返していました。

 

1 ある地点\( x_0 \)における傾きを求める。

2 新たに探索地点\( x_1 \)を傾きと学習率\( \eta \)を用いて更新する。

3 傾きが0となる\( x \)を見つけるまで1~2の作業を繰り返す。

 

すべての訓練データを用いて学習するのが最も効率的な方法なのでしょうか。用いるデータ量について工夫したらどうなるでしょうか。

用いるデータ数を減らせば損失関数の値や傾きを求めるための計算量が減ります。

バッチ学習

パラメータの更新にすべての訓練データを用いるのがバッチ学習です。以下のようにパラメータ更新を行います。

 

1 全ての訓練データを用いて地点\( x_0 \)における傾きを求める。

2 新たな探索地点\( x_1 \)を傾きと学習率\( \eta \)を用いて更新する。

3 傾きが0となる\( x \)が見つかるまで、1-2の更新を繰り返す。

 

バッチ学習の性質として学習結果が安定しやすいことがあげられます。これは全データを用いた損失関数の変化を考えるためです。

一方で新たな学習データを追加するたびに、全データを用いて再度計算を行わなければならないという欠点があります。

このバッチ学習は全データ数が少ないときに有効な学習手法です。

ミニバッチ学習

\( N \)個の訓練データのなかから一部、\( n \)個を取り出し、パラメータの更新をするのがミニバッチ学習です。取り出した訓練データをミニバッチと呼びます。また取り出すデータ数\( n \)をミニバッチサイズと呼びます。以下のように更新を行います。

 

1 N個のデータからランダムに\( n \)個を取り出す。

2 \( n \)個のデータを用いて地点\( x_0 \)における傾きを求める。

3 新たな探索地点\( x_1 \)を傾きと学習率\( \eta \)を用いて求める。

4 新たに\( n \)個のデータを取り出して2-3の更新を行う。

5 1-4の更新を繰り返す。

 

ミニバッチ学習の性質として学習の停滞に陥りにくいことがあげれらます。学習の停滞は局所地点における傾きが0となることによって、パラメータの更新が行われなくなる現象を指します。

学習が停滞しづらい理由として以下のように考えることができます。

ミニバッチ学習は用いるデータ数が少ないため、パラメータの変化に対して損失関数が敏感に反応します。そのため、傾きが0となる局所地点が少なく結果として学習が停滞しづらくなります。

ミニバッチ学習はニューラルネットワークでよく用いられる学習手法です。

オンライン学習

訓練データの一つを取り出してパラメータの更新をするのがオンライン学習です。以下のように更新を行います。

 

1 N個のデータからランダムに1つのデータを取り出す。

2 \( 1 \)個のデータを用いて地点\( x_0 \)における傾きを求める。

3 新たな探索地点\( x_1 \)を傾きと学習率\( \eta \)を用いて更新する。

4 新たに一つのデータを取り出して2-3の更新を行う。

5 1-4の更新を繰り返す。

 

オンライン学習はミニバッチ学習と同様、局所解に陥りにくいという性質があります。

一方で結果が不安定になりやすいという欠点があります。不安定になる原因として一つ一つのデータに対して更新を行うため、はずれ値にも反応しやすいことが挙げられます。

まとめ

今回はバッチ、ミニバッチ、オンライン学習といった学習手法について説明してました。

各学習手法によって損失関数を求める、更新に用いるためのデータサイズが異なると理解するとよいでしょう。

 

  • 人気の投稿とページ

  • コメントを残す

    メールアドレスが公開されることはありません。 * が付いている欄は必須項目です