AI/머신러닝

[머신러닝] Cross Validation(교차 검증)

caramel-bottle 2023. 12. 28.

K-Fold Cross Validation ( K-겹 교차 검증 )

https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html

 

sklearn.model_selection.train_test_split

Examples using sklearn.model_selection.train_test_split: Release Highlights for scikit-learn 0.24 Release Highlights for scikit-learn 0.23 Release Highlights for scikit-learn 0.22 Comparison of Cal...

scikit-learn.org

https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html

 

sklearn.model_selection.KFold

Examples using sklearn.model_selection.KFold: Feature agglomeration vs. univariate selection Comparing Random Forests and Histogram Gradient Boosting models Gradient Boosting Out-of-Bag estimates N...

scikit-learn.org

 

K-Fold란?

train_test_split에서 발생하는 데이터의 섞임에 따라 성능이 좌우되는 문제를 해결하는 교차 검증 기술

 

교차 검증이란?

고정된 test set으로 검증을 한다면 그 test set만 잘 맞추는 모델이 될 수 있다.

 

이 모델에 실제 데이터를 넣어 예측을 수행하면 결과가 좋지 않을 수 있다.

 

이것을 과적합(Overfitting)되었다고 한다.

 

이처럼 test set에 따라 성능이 크게 달라질 수 있기 때문에 평균적으로 믿을만한 모델을 만들 필요가 있다.

 

이를 해결하고자 하는 것이 교차검증(Cross Validation)이다.


사용방법

데이터셋 준비

실험용 데이터셋을 준비한다.

2023.12.27 - [분류 전체보기] - [머신러닝] 로지스틱 회귀 (hr 데이터셋)

 

[머신러닝] 로지스틱 회귀 (hr 데이터셋)

로지스틱 회귀 로지스틱 회귀는 둘 중 하나를 결정하는 문제(이진 분류)를 풀기 위한 대표적인 알고리즘이다. hr 데이터셋 hr 데이터셋은 직원정보와 승진여부에 대한 데이터이다. 직원 데이터를

caramelbottle.tistory.com

# Human Resources
# 데이터 전처리가 끝난 df임
hr_df.info()

output>>

<class 'pandas.core.frame.DataFrame'>
Int64Index: 48660 entries, 0 to 54807
Data columns (total 59 columns):
 #   Column                        Non-Null Count  Dtype  
---  ------                        --------------  -----  
 0   employee_id                   48660 non-null  int64  
 1   no_of_trainings               48660 non-null  int64  
 2   age                           48660 non-null  int64  
 3   previous_year_rating          48660 non-null  float64
 4   length_of_service             48660 non-null  int64  
 5   awards_won?                   48660 non-null  int64  
 6   avg_training_score            48660 non-null  int64  
 7   is_promoted                   48660 non-null  int64  
 8   department_Analytics          48660 non-null  uint8  
 9   department_Finance            48660 non-null  uint8  
 10  department_HR                 48660 non-null  uint8  
 11  department_Legal              48660 non-null  uint8  
 12  department_Operations         48660 non-null  uint8  
 13  department_Procurement        48660 non-null  uint8  
 14  department_R&D                48660 non-null  uint8  
 15  department_Sales & Marketing  48660 non-null  uint8  
 16  department_Technology         48660 non-null  uint8  
 17  region_region_1               48660 non-null  uint8  
 18  region_region_10              48660 non-null  uint8  
 19  region_region_11              48660 non-null  uint8  
 20  region_region_12              48660 non-null  uint8  
 21  region_region_13              48660 non-null  uint8  
 22  region_region_14              48660 non-null  uint8  
 23  region_region_15              48660 non-null  uint8  
 24  region_region_16              48660 non-null  uint8  
 25  region_region_17              48660 non-null  uint8  
 26  region_region_18              48660 non-null  uint8  
 27  region_region_19              48660 non-null  uint8  
 28  region_region_2               48660 non-null  uint8  
 29  region_region_20              48660 non-null  uint8  
 30  region_region_21              48660 non-null  uint8  
 31  region_region_22              48660 non-null  uint8  
 32  region_region_23              48660 non-null  uint8  
 33  region_region_24              48660 non-null  uint8  
 34  region_region_25              48660 non-null  uint8  
 35  region_region_26              48660 non-null  uint8  
 36  region_region_27              48660 non-null  uint8  
 37  region_region_28              48660 non-null  uint8  
 38  region_region_29              48660 non-null  uint8  
 39  region_region_3               48660 non-null  uint8  
 40  region_region_30              48660 non-null  uint8  
 41  region_region_31              48660 non-null  uint8  
 42  region_region_32              48660 non-null  uint8  
 43  region_region_33              48660 non-null  uint8  
 44  region_region_34              48660 non-null  uint8  
 45  region_region_4               48660 non-null  uint8  
 46  region_region_5               48660 non-null  uint8  
 47  region_region_6               48660 non-null  uint8  
 48  region_region_7               48660 non-null  uint8  
 49  region_region_8               48660 non-null  uint8  
 50  region_region_9               48660 non-null  uint8  
 51  education_Bachelor's          48660 non-null  uint8  
 52  education_Below Secondary     48660 non-null  uint8  
 53  education_Master's & above    48660 non-null  uint8  
 54  gender_f                      48660 non-null  uint8  
 55  gender_m                      48660 non-null  uint8  
 56  recruitment_channel_other     48660 non-null  uint8  
 57  recruitment_channel_referred  48660 non-null  uint8  
 58  recruitment_channel_sourcing  48660 non-null  uint8  
dtypes: float64(1), int64(7), uint8(51)
memory usage: 5.7 MB

KFold()

KFold()를 사용하여 n_splits를 설정한다. n_splits는 나눌 validate set 개수를 설정하는 것이다.

from sklearn.model_selection import KFold

kf = KFold(n_splits=5)
kf

output>>

KFold(n_splits=5, random_state=None, shuffle=False)

 

 

K-Fold가 어떻게 나뉘었는지 확인하기 위한 반복문

for train_index, test_index in kf.split(range(len(hr_df))):
    print(train_index, test_index)
    print(len(train_index), len(test_index))

output>>

[ 9732  9733  9734 ... 48657 48658 48659] [   0    1    2 ... 9729 9730 9731]
38928 9732

[    0     1     2 ... 48657 48658 48659] [ 9732  9733  9734 ... 19461 19462 19463]
38928 9732

[    0     1     2 ... 48657 48658 48659] [19464 19465 19466 ... 29193 29194 29195]
38928 9732

[    0     1     2 ... 48657 48658 48659] [29196 29197 29198 ... 38925 38926 38927]
38928 9732

[    0     1     2 ... 38925 38926 38927] [38928 38929 38930 ... 48657 48658 48659]
38928 9732

 

같은 크기로 5개의 경우를 확인할 수 있다.


KFold() feature

# random_state=2023, shuffle=True
kf = KFold(n_splits=5, random_state=2023, shuffle=True)
kf

output>>

KFold(n_splits=5, random_state=2023, shuffle=True)

 

 

K-Fold가 어떻게 나뉘었는지 확인하기 위한 반복문

for train_index, test_index in kf.split(range(len(hr_df))):
    print(train_index, test_index)
    print(len(train_index), len(test_index))

output>>

[    1     2     3 ... 48656 48658 48659] [    0     7    13 ... 48634 48645 48657]
38928 9732

[    0     1     3 ... 48656 48657 48659] [    2    12    16 ... 48642 48644 48658]
38928 9732

[    0     1     2 ... 48657 48658 48659] [    5     6    17 ... 48652 48653 48655]
38928 9732

[    0     1     2 ... 48657 48658 48659] [    4     8    10 ... 48650 48654 48656]
38928 9732

[    0     2     4 ... 48656 48657 48658] [    1     3     9 ... 48648 48649 48659]
38928 9732

 

매 분할마다 데이터를 섞게 된다.

 

경우에 따라 적절한 하이퍼 파라미터를 설정하여 분할을 할 수 있다.


fit & predict

acc_list = []

for train_index, test_index in kf.split(range(len(hr_df))):
    X = hr_df.drop('is_promoted', axis=1)
    y = hr_df['is_promoted']

    X_train = X.iloc[train_index]
    X_test = X.iloc[test_index]
    y_train = y.iloc[train_index]
    y_test = y.iloc[test_index]

    lr = LogisticRegression()
    lr.fit(X_train, y_train)
    pred = lr.predict(X_test)
    acc_list.append(accuracy_score(y_test, pred))
    
acc_list

output>>

[0.9169749280723387,
 0.9110152075626798,
 0.9126592683929305,
 0.913481298808056,
 0.9110152075626798]

KFold의 평균

# 모든 결과의 평균
np.array(acc_list).mean()

output>>

0.9130291820797372

결론

Cross Validation을 사용하는 이유는 결과를 좋게 하기 위함이 아니라 믿을만한 검증을 하기 위해서임

시간이 그만큼 오래걸림

하지만 LOOCV보다는 훨씬 빠름

데이터가 많으면 KFold가 비교적 좋음


참고

https://wooono.tistory.com/105

 

[ML] 교차검증 (CV, Cross Validation) 이란?

교차 검증이란? 보통은 train set 으로 모델을 훈련, test set으로 모델을 검증한다. 여기에는 한 가지 약점이 존재한다. 고정된 test set을 통해 모델의 성능을 검증하고 수정하는 과정을 반복하면, 결

wooono.tistory.com

 

댓글