2019/04/23
2020/04/14
機械学習における損失関数の役割や種類
この記事では損失関数について解説します。
損失関数を用いることでモデルの精度を評価し、モデルのパラメータを求めることが可能になります。
今回は損失関数の役割、また用途に応じた様々な損失関数について解説します。
目次
損失関数とは
よりよい予測モデルを作成するためにはどうすればいいのでしょうか。まずモデルの予測精度を評価する必要があります。その次に評価に基づいて適切なモデルのパラメータを求めなくてはなりません。
ではどのようにモデルの予測精度を評価するのでしょうか。
そんなときに用いるのが損失関数です。損失関数は予測と実際の値のズレの大きさを表す関数でモデルの予測精度を評価します。損失関数の値が小さければより正確なモデルと言えます。
この損失関数の評価をもとにモデルのパラメータを算出します。ニューラルネットワークをはじめとした機械学習モデルは損失関数の値が最小となるようなパラメータを様々な方法で求めます。
損失関数がどのように用いられるかについて、線形回帰モデルを例に考えていきましょう。
線形回帰の損失関数
例えば、立地、築年数などの説明変数から家の価格を予測する線形回帰モデルを作りたいとします。どのような予測モデルが理想的でしょうか。実際の値と予測値のズレが小さいモデルが理想的ですね。
実値と予測値のズレを表す関数として以下の平均二乗誤差(meen squared error)が考えられます。
\( MSE( y_i, \hat{y_i}) = \displaystyle \frac{ 1 }{ n } \sum_{i = 1}^{ n } (y_i – \hat{y_i})^2 \)
\( y_iは実値、\hat{y_i} \)は予測値を指す。
線形回帰モデルでは上記の平均二乗誤差を用いてモデルの精度を評価していました。平均二乗誤差、つまり予測と実値のズレを最小にするパラメータ(回帰係数)を最小二乗法(偏微分)で用いていましたね。
この平均二乗誤差も損失関数の一種です。
様々な損失関数
予測する問題や用いるモデル、パラメータの求め方などによって用いる損失関数は変わります。特に回帰問題と分類問題とでは用いる損失関数が大きく異なります。
今回は回帰問題、分類問題それぞれに用いる代表的な損失関数について順に説明します。
まずは回帰問題に用いる損失関数について紹介していきます。
平均二乗誤差 / Mean Squared Error
以下の数式で表されるのが平均二乗誤差です。
\( MSE( y_i, \hat{y_i}) = \displaystyle \frac{ 1 }{ n } \sum_{i = 1}^{ n } (y_i – \hat{y_i})^2 \)
\( y_iは実値、\hat{y_i} \)は予測値を指す。
回帰問題において平均二乗誤差は最もメジャーな損失関数といえるでしょう。線形回帰モデルやニューラルネットワーク、決定木といった様々なモデルにおいて用いられます。
平均二乗誤差の性質として外れ値に対して敏感であることが挙げられます。ですので、外れ値を含むデータに平均二乗誤差を用いてモデルを構築すると、予測結果が不安定になります。
平均絶対誤差 / Mean Absolute Error
以下の数式で表される関数が平均絶対誤差です。
\( MAE( y_i, \hat{y_i}) = \displaystyle \frac{ 1 }{ n } \sum_{i = 1}^{ n } | y_i – \hat{y_i} | \)
\( y_iは実値、\hat{y_i} \)は予測値を指す。
平均絶対誤差の性質として外れ値に強いことが挙げられます。
平均二乗対数誤差 / Mean Squared Logarithmic Error
以下の数式で表される関数が平均二乗対数誤差です。
\( MSLE( y_i, \hat{y_i}) = \displaystyle \frac{ 1 }{ n } \sum_{i = 1}^{ n } \{ log(1 + y_i) – log( \hat{1 + y_i} ) \}^2 \)
\( y_iは実値、\hat{y_i} \)は予測値を指す。
平均二乗対数誤差を用いたモデルは予測が実値を上回りやすくなるという傾向があります。
例えば、来客人数を予測するようなモデルにおいては平均対数誤差が用いらます。来客数を控え目に予想して、多くの客が来てしまうと手が足りなくて困るなんてことが起きますからね。
交差エントロピー誤差 / cross entropy error
この損失関数は分類問題に用います。以下の数式で表されるのが交差エントロピー誤差です。
\( E = -\displaystyle \sum_{ k }^{ } t_k \log y_k \)
\( t_k \)は実際のカテゴリーを0,1を用いて表す(正解であれば1、不正解に対しては0)。\( y_k \)は予測確率。
交差エントロピー誤差もモデルの予測と実値のズレを評価します。ただ今まで説明した損失関数よりも計算が複雑ですね。例を用いて交差エントロピー誤差について考えてみましょう。
今回は写真が犬、猫、馬のどれであるか分類するモデルを例にします。
モデルで写真を判定してみると犬である確率は20%、猫である確率30%、馬である確率が50%であると予測しました。この写真が馬であったとしましょう。この時、クロスエントロピー誤差はいくつになるでしょうか。以下のように求まります。
\( E = – ( 0 \times {\log 0.2} + 0 \times {\log 0.3} + 1 \times {\log 0.5} ) = – \log 0.5 \)
交差エントロピー誤差は実際のカテゴリーに対する予測確率のみを評価します。ですので、犬、猫の予測確率に対しては0を掛けます。
もう少し詳しく見てましょう。どのようなときに交差エントロピー誤差が小さくなるのでしょうか。
上記の例で馬である確率を50%、100%と予測した場合、それぞれについて考えてみましょう。
交差エントロピー誤差をそれぞれ計算すると\( – \log 0.5 = 0.301 \)、\( – \log 1 = 0 \)となります。正解のカテゴリーをうまく予測すると交差エントロピー誤差も小さくなると納得できたのではないでしょうか。
交差エントロピー誤差はニューラルネットワークなどでも用いられる損失関数です。
まとめ
損失関数の役割について説明しました。適切な損失関数を設定することで理想的なモデルが構築できるようになります。
今回紹介した損失関数以外にも様々な損失関数が存在します。必要に応じて、それらの損失関数についても理解を深めるとよいでしょう。
Recommended