機械学習実践 前処理 python

import numpy as np
import pandas as pd
import pandas_profiling as pdp
import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
import pickle
from sklearn.model_selection import train_test_split,GridSearchCV
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score, roc_curve, roc_auc_score
from sklearn.ensemble import RandomForestClassifier
import xgboost as xgb
import warnings
warnings.simplefilter('ignore')
csvDataPath = 'UCI_Credit_Card_light.csv'
#csvDataPath = 'UCI_Credit_Card.csv'
data = pd.read_csv(csvDataPath, engine='python', encoding='utf-8')
y_target = data['default.payment.next.month']
col = data.columns.tolist()
col.remove('default.payment.next.month')
x_explanatory = data[col]
x_train,x_test,y_train, y_test = train_test_split(
    x_explanatory, y_target,stratify = y_target,random_state=1)

pickle.load()を用いてモデルの呼び出し

 with open('rf_model.pkl', mode='rb') as f:
    rf_model = pickle.load(f)
 with open('rf_model.pkl', mode='rb') as f:
    rf_model = pickle.load(f)

ランダムフォレストで.feature_importances_の可視化

アウトカムとの 関係の特徴量が重要との結果を求める(この重要度は,相対的な数値になります.)

importances = rf_model.feature_importances_
importances
array([0.07731317, 0.04914963, 0.01228686, 0.01803668, 0.01151149,
       0.06112387, 0.06713772, 0.03412625, 0.02216912, 0.02397789,
       0.02902746, 0.01426677, 0.05257647, 0.05166613, 0.05090452,
       0.04768364, 0.05031091, 0.04936703, 0.05790126, 0.04958181,
       0.04416273, 0.04355233, 0.0448326 , 0.03733364])

【Point】np.argsort()は引数に与えた配列をソートして、結果に対応する元の配列のindexを返します。 ※ソート結果は昇順

indices = np.argsort(importances)
indices
array([ 4,  2, 11,  3,  8,  9, 10,  7, 23, 21, 20, 22, 15,  1, 17, 19, 16,
       14, 13, 12, 18,  5,  6,  0], dtype=int64)

np.argsort()plt.barhを用いて.feature_importances_の可視化

features = x_train.columns featuresはXのコラム名
importances = rf_model.feature_importances_
indices = np.argsort(importances)

plt.figure(figsize=(12,10))
plt.barh(range(len(indices)), importances[indices],  align='center')


plt.barh:棒グラフを作る
range(len(indices))でindiceの長さを、range()の中に入れることが出来る。
→これってグラフの数の指定?
importances[indices]でindicesで作った順番通りのインデックスをimportances[]の中に入れることで、値が表示できる。
→描画するインデックスの順番

そのあとの内容は意味が分からないが、覚えてしまおう。
plt.yticks(range(len(indices)), features[indices])
plt.show()

XGBで.feature_importances_の可視化

features = x_train.columns
importances = xgb_model.feature_importances_
indices = np.argsort(importances)

plt.figure(figsize=(12,10))
plt.barh(range(len(indices)), importances[indices],  align='center')
plt.yticks(range(len(indices)), features[indices])
plt.show()

XGBで.feature_importances_を別のアプローチで可視化

XGBのパッケージには.plot_importance()という便利なメソッドが用意されています。
importance_typeという引数を変更することで、重要度の算出アルゴリズムが変更できます。
['weight','gain','cover']の中から選択可能で、前項の可視化結果と統一させる場合はgainを選択します。
weight : 生成された全ての木の中に何度その特徴量が存在するか
gain : 評価基準をどれだけ改善させたることができたかの値
cover : 葉に分類された訓練データの二次勾配の合計値

fig, ax = plt.subplots(figsize=(12, 10))
xgb.plot_importance(xgb_model,
                    ax=ax,
                    importance_type='gain'
                    #show_values=False
                    )
method = ['weight', 'gain','cover']
for val in method:
    fig, ax = plt.subplots(figsize=(12, 10))
    xgb.plot_importance(xgb_model,
                        ax=ax,
                        importance_type=val
                        #show_values=False
                        )

前節の考察

PAY_0: 2005年9月の返済状況
つまり直近の返済状況が支払いに強く影響していることがわかりました。

IDに関しては全てユニークな値をもっているので、特徴量としては不適切です。
本来は学習には使用しない列なので、参考値としましょう。

PAY_6BILL_AMT6PAY_AMT6など予測時期よりも遠い過去となる情報には
値特徴量としての効果が薄くなっていることも読み取れます。これは直感的にも納得できる結果ですね。

AGESEXMARRIAGEなど、直感的には影響が出てきそうな特徴量は、
gainのアルゴリズムでは影響は少ないとの結果がでています。
XGBで他のアルゴリズムを使用した場合では大幅に特徴量として有効化しているので、
調査のしがいがありそうです。

データの前処理

データの再確認

pdp.ProfileReport(data)

データのクリーニング

ID列は全てユニークな値を持つので、説明変数として使用できません。
.dropを使ってID列を削除します。

data = data.drop('ID',axis=1)

EDUCATION:(1 =大学院、2 =大学、3 =高校、4 =その他、5 =不明、6 =不明)
となっていますが、

  • 本来存在しない0が含まれている
  • 4~6の持つ意味合いがほぼ同じ

なので解釈性向上のためにカテゴリを1つにまとめて解釈します。

data.EDUCATION.value_counts()
2    448
1    395
3    150
5      3
4      2
6      2
Name: EDUCATION, dtype: int64

EDUCATION:(0=???、5 =不明、6 =不明)のindexを抜き出し、そのindexEDUCATION列を値4で上書きします!

index = (data.EDUCATION == 5) | (data.EDUCATION == 6) | (data.EDUCATION == 0)
data.loc[index, 'EDUCATION'] = 4
data.EDUCATION.value_counts()
2    448
1    395
3    150
4      7
Name: EDUCATION, dtype: int64

MARRIAGE:結婚暦(1 =結婚、2 =独身、3 =その他)
となっていますが、0の値が含まれているのでこの値は 3 =その他 と解釈します。

data.MARRIAGE.value_counts()
2    570
1    408
3     19
0      3
Name: MARRIAGE, dtype: int64
index = (data.MARRIAGE == 0)
data.loc[index, 'MARRIAGE'] = 3
data.MARRIAGE.value_counts()
2    570
1    408
3     22
Name: MARRIAGE, dtype: int64

PAY_0BILL_AMT1PAY_AMT1PAY_0のみ数値のラベルが0から始まり、
1が抜けているのでデータを扱いやすくするために、列名をPAY_1へと修正します。

  • 列名の変更には.renameを使用
data  = data.rename(columns={'PAY_0':'PAY_1'})

さらに各列の値を確認します
参考 : PAY_1:2005年9月の返済状況
(-1 = 定期的な支払い、1 = 1か月の支払い遅延、2 = 2か月の支払い遅延、…8 = 8か月の支払い遅延、9 = 9か月以上の支払い遅延)

colNameList = ['PAY_'+str(i+1) for i in range(6)]
colNameList
['PAY_1', 'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6']
colNameList = ['PAY_'+str(i+1) for i in range(6)]
for col in colNameList:
    print(data[col].value_counts())
 0    473
-1    214
 1    136
 2     84
-2     79
 3      6
 4      4
 8      4
Name: PAY_1, dtype: int64
 0    525
-1    205
-2    130
 2    125
 3      8
 7      4
 5      1
 4      1
 1      1
Name: PAY_2, dtype: int64
 0    513
-1    208
-2    137
 2    129
 4      4
 6      4
 7      2
 3      1
 1      1
 5      1
Name: PAY_3, dtype: int64
 0    539
-1    204
-2    156
 2     87
 3      5
 5      5
 4      2
 7      2
Name: PAY_4, dtype: int64
 0    544
-1    192
-2    161
 2     90
 3      5
 4      5
 7      2
 5      1
Name: PAY_5, dtype: int64
 0    500
-1    216
-2    172
 2     99
 3      8
 6      3
 4      1
 7      1
Name: PAY_6, dtype: int64

定期的な支払いの値に本来含まれないはずの0や-2の値が含まれています。
こちらについても[-2, -1, 0]の値に関しては解釈性向上のために0として扱うことにします。

colNameList = ['PAY_'+str(i+1) for i in range(6)]
for col in colNameList:
    index = (data[col] == -2) | (data[col] == -1) 直入れで大丈夫
    data.loc[index, col] = 0
    print(data[col].value_counts())
0    766
1    136
2     84
3      6
4      4
8      4
Name: PAY_1, dtype: int64
0    860
2    125
3      8
7      4
5      1
4      1
1      1
Name: PAY_2, dtype: int64
0    858
2    129
4      4
6      4
7      2
3      1
1      1
5      1
Name: PAY_3, dtype: int64
0    899
2     87
3      5
5      5
4      2
7      2
Name: PAY_4, dtype: int64
0    897
2     90
3      5
4      5
7      2
5      1
Name: PAY_5, dtype: int64
0    888
2     99
3      8
6      3
4      1
7      1
Name: PAY_6, dtype: int64
colNameList = ['PAY_'+str(i+1) for i in range(6)]
for col in colNameList:
    print(data[col].value_counts())
0    766
1    136
2     84
3      6
4      4
8      4
Name: PAY_1, dtype: int64
0    860
2    125
3      8
7      4
5      1
4      1
1      1
Name: PAY_2, dtype: int64
0    858
2    129
4      4
6      4
7      2
3      1
1      1
5      1
Name: PAY_3, dtype: int64
0    899
2     87
3      5
5      5
4      2
7      2
Name: PAY_4, dtype: int64
0    897
2     90
3      5
4      5
7      2
5      1
Name: PAY_5, dtype: int64
0    888
2     99
3      8
6      3
4      1
7      1
Name: PAY_6, dtype: int64

データの確認

data.head()
LIMIT_BALSEXEDUCATIONMARRIAGEAGEPAY_1PAY_2PAY_3PAY_4PAY_5BILL_AMT4BILL_AMT5BILL_AMT6PAY_AMT1PAY_AMT2PAY_AMT3PAY_AMT4PAY_AMT5PAY_AMT6default.payment.next.month
020000.02212422000000068900.00.001
1120000.022226020003272345532610100010001000.00.020001
290000.022234000001433114948155491518150010001000.01000.050000
350000.022137000002831428959295472000201912001100.01069.010000
450000.01215700000209401914619131200036681100009000.0689.06790
data.columns
Index(['LIMIT_BAL', 'SEX', 'EDUCATION', 'MARRIAGE', 'AGE', 'PAY_1', 'PAY_2',
       'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6', 'BILL_AMT1', 'BILL_AMT2',
       'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6', 'PAY_AMT1',
       'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6',
       'default.payment.next.month'],
      dtype='object')

コメント

タイトルとURLをコピーしました