正則化ができるRidge、Lasso、Elastic Netは、正則化の重みαで結果が違います。
αは、実際に試してみないと、どの値が良いかわかりません。
RidgeCV, LassoCV, ElasticNetCVを使うと、簡単に探索できます。
これらのモデルでは、最も良いαはalpha_データ属性で確認できます。
まずは、RidgeCVを使ってみましょう。
# 出力時に小数点以下3桁に
%precision 3
import pandas as pd
from sklearn.linear_model import RidgeCV
from sklearn.metrics import r2_score
df = pd.read_csv("input/boston.csv") # ボストン市の住宅価格データ一式
X, y = df.iloc[:, :-1], df.iloc[:, -1] # サンプルデータ
%precision 3
とすると、出力時に小数点以下3桁で表示します。
ただし、NumPyの多次元配列の要素の場合は、この指定が有効になりません。その場合は、float(...)
でfloatに変換しています。
# αの探索
ridge = RidgeCV(scoring='r2')
ridge.fit(X, y)
ridge.alpha_
0.1
scoring=’r2’とし、評価関数として決定係数を指定します。
alphasは、デフォルトで、(0.1, 1.0, 10.0)なので、この3つから探索します。結果は、alpha_で確認できます。0.1となったので、さらに詳しく調べます。
結果的に、αは0に近い方が良いことがわかります。
# さらにαの探索
ridge = RidgeCV([0.0001, 0.001, 0.01, 0.1], scoring='r2')
ridge.fit(X, y)
ridge.alpha_
0.01
# 係数の確認
ridge.coef_
array([-1.07954381e-01, 4.64363781e-02, 2.00760695e-02, 2.68500996e+00, -1.76521422e+01, 3.81076688e+00, 5.90349860e-04, -1.47388018e+00, 3.05780593e-01, -1.23436529e-02, -9.51477563e-01, 9.31764990e-03, -5.24884869e-01])
coef_で係数を確認します。
# 決定係数の確認
float(r2_score(y, ridge.predict(X)))
0.741
r2_score(y, ridge.predict(X))で決定係数を確認します。
RidgeCV
RidgeCV
を使うと、指定されたαの中から最もスコアがよいものを探索してくれます(CV
は、クロスバリデーションです)。
RidgeCV
の主なパラメーターは以下のようになります。
オプション | デフォルト | 説明 |
---|---|---|
alphas | (0.1, 1.0, 10.0) | 探索対象 |
fit_intercept | True | 切片が0でないか |
scoring | None | 評価関数の指定 |
cv | None | クロスバリデーションのパラメーター |
cvオプションで整数を指定すると、クロスバリデーションの分割数を指定できます。
cv=Noneで一般クロスバリデーション(Generalized Cross-Validation)になります。一般クロスバリデーションは効率的なので、cv=Noneの利用はお勧めです。
クロスバリデーションするので、トレーニングデータとテストデータに分割する必要はありません。
RidgeCV
のscoring
オプションで指定できる文字列は、sklearn.metrics.SCORERS
で確認できます。
from sklearn.metrics import SCORERS
SCORERS.keys()
dict_keys(['explained_variance', 'r2', 'max_error', 'neg_median_absolute_error',
'neg_mean_absolute_error', 'neg_mean_absolute_percentage_error',
'neg_mean_squared_error', 'neg_mean_squared_log_error', 'neg_root_mean_squared_error',
'neg_mean_poisson_deviance', 'neg_mean_gamma_deviance', 'accuracy', 'top_k_accuracy',
'roc_auc', 'roc_auc_ovr', 'roc_auc_ovo', 'roc_auc_ovr_weighted', 'roc_auc_ovo_weighted',
'balanced_accuracy', 'average_precision', 'neg_log_loss', 'neg_brier_score',
'adjusted_rand_score', 'rand_score', 'homogeneity_score', 'completeness_score',
'v_measure_score', 'mutual_info_score', 'adjusted_mutual_info_score',
'normalized_mutual_info_score', 'fowlkes_mallows_score', 'precision', 'precision_macro',
'precision_micro', 'precision_samples', 'precision_weighted', 'recall', 'recall_macro',
'recall_micro', 'recall_samples', 'recall_weighted', 'f1', 'f1_macro', 'f1_micro', 'f1_samples',
'f1_weighted', 'jaccard', 'jaccard_macro', 'jaccard_micro', 'jaccard_samples', 'jaccard_weighted'])
コメント