Ridge回帰のαの探索

正則化ができるRidge、Lasso、Elastic Netは、正則化の重みαで結果が違います。
αは、実際に試してみないと、どの値が良いかわかりません。
RidgeCV, LassoCV, ElasticNetCVを使うと、簡単に探索できます。
これらのモデルでは、最も良いαはalpha_データ属性で確認できます。

まずは、RidgeCVを使ってみましょう。

# 出力時に小数点以下3桁に
%precision 3
import pandas as pd
from sklearn.linear_model import RidgeCV
from sklearn.metrics import r2_score

df = pd.read_csv("input/boston.csv")  # ボストン市の住宅価格データ一式
X, y = df.iloc[:, :-1], df.iloc[:, -1]  # サンプルデータ

%precision 3とすると、出力時に小数点以下3桁で表示します。
ただし、NumPyの多次元配列の要素の場合は、この指定が有効になりません。その場合は、float(...)でfloatに変換しています。

# αの探索
ridge = RidgeCV(scoring='r2')
ridge.fit(X, y)
ridge.alpha_
0.1

scoring=’r2’とし、評価関数として決定係数を指定します。
alphasは、デフォルトで、(0.1, 1.0, 10.0)なので、この3つから探索します。結果は、alpha_で確認できます。0.1となったので、さらに詳しく調べます
結果的に、αは0に近い方が良いことがわかります。

# さらにαの探索
ridge = RidgeCV([0.0001, 0.001, 0.01, 0.1], scoring='r2')
ridge.fit(X, y)
ridge.alpha_
0.01
# 係数の確認
ridge.coef_
array([-1.07954381e-01,  4.64363781e-02,  2.00760695e-02,  2.68500996e+00,
       -1.76521422e+01,  3.81076688e+00,  5.90349860e-04, -1.47388018e+00,
        3.05780593e-01, -1.23436529e-02, -9.51477563e-01,  9.31764990e-03,
       -5.24884869e-01])

coef_で係数を確認します。

# 決定係数の確認
float(r2_score(y, ridge.predict(X)))
0.741

r2_score(y, ridge.predict(X))で決定係数を確認します。

RidgeCV

RidgeCVを使うと、指定されたαの中から最もスコアがよいものを探索してくれます(CVは、クロスバリデーションです)。

RidgeCVの主なパラメーターは以下のようになります。

オプションデフォルト説明
alphas(0.1, 1.0, 10.0)探索対象
fit_interceptTrue切片が0でないか
scoringNone評価関数の指定
cvNoneクロスバリデーションのパラメーター

cvオプションで整数を指定すると、クロスバリデーションの分割数を指定できます。
cv=Noneで一般クロスバリデーション(Generalized Cross-Validation)になります。一般クロスバリデーションは効率的なので、cv=Noneの利用はお勧めです。

クロスバリデーションするので、トレーニングデータとテストデータに分割する必要はありません。

RidgeCVscoringオプションで指定できる文字列は、sklearn.metrics.SCORERSで確認できます。

from sklearn.metrics import SCORERS
SCORERS.keys()
dict_keys(['explained_variance', 'r2', 'max_error', 'neg_median_absolute_error',
'neg_mean_absolute_error', 'neg_mean_absolute_percentage_error',
'neg_mean_squared_error', 'neg_mean_squared_log_error', 'neg_root_mean_squared_error',
'neg_mean_poisson_deviance', 'neg_mean_gamma_deviance', 'accuracy', 'top_k_accuracy',
'roc_auc', 'roc_auc_ovr', 'roc_auc_ovo', 'roc_auc_ovr_weighted', 'roc_auc_ovo_weighted',
'balanced_accuracy', 'average_precision', 'neg_log_loss', 'neg_brier_score',
'adjusted_rand_score', 'rand_score', 'homogeneity_score', 'completeness_score',
'v_measure_score', 'mutual_info_score', 'adjusted_mutual_info_score',
'normalized_mutual_info_score', 'fowlkes_mallows_score', 'precision', 'precision_macro',
'precision_micro', 'precision_samples', 'precision_weighted', 'recall', 'recall_macro',
'recall_micro', 'recall_samples', 'recall_weighted', 'f1', 'f1_macro', 'f1_micro', 'f1_samples',
'f1_weighted', 'jaccard', 'jaccard_macro', 'jaccard_micro', 'jaccard_samples', 'jaccard_weighted'])

コメント

タイトルとURLをコピーしました