머신러닝에서 우리의 목표는 training dataset $\mathcal{D}_{train}$을 이용하여 모델을 학습하고, 학습된 모델을 이용하여 관측되지 않았던 새로운 데이터에 대해 예측을 수행하는 것이다. 머신러닝 연구 및 개발에서는 $\mathcal{D}_{train}$에 포함되지 않는 데이터를 모아서 test dataset $\mathcal{D}_{test}$를 구성하고, $\mathcal{D}_{test}$에 대해 모델의 성능을 측정함으로써 새로운 데이터에 대한 모델의 성능을 평가한다. $\mathcal{D}_{test}$는 학습 과정에서 참조할 수 없기 때문에 머신러닝 모델은 $\mathcal{D}_{train}$만을 가지고 $\mathcal{D}_{test}$에 대한 정확한 예측이 가능하도록 학습되어야한다.
기본적으로 머신 러닝 모델은 $\mathcal{D}_{train}$에 대한 손실 함수 (loss function)가 작아지도록 학습을 진행하기 때문에 분류 문제의 경우 그림 1과 같이 학습이 진행될수록 모델의 decision boundary는 $\mathcal{D}_{train}$에 적합 (fitting)된다. 그러나 모델이 $\mathcal{D}_{train}$에 너무 과하게 적합되면 이것은 모델이 데이터에 내재된 어떠한 구조나 패턴을 일반화한 것이 아니라, $\mathcal{D}_{train}$에만 한정되는 정보를 그대로 외운 것과 같을 수 있기 때문에 이것이 올바른 학습 결과인지는 생각해볼 필요가 있다.
만약 $\mathcal{D}_{train}$과 $\mathcal{D}_{test}$의 데이터 분포가 정확히 일치한다면, 모델이 $\mathcal{D}_{train}$에 적합될 수록 모델의 예측 정확도는 증가할 것이다. 그러나 대부분의 문제에서 우리는 $\mathcal{D}_{train}$가 데이터 전체를 대표할 수 있을만큼의 충분한 학습 데이터를 확보할 수 없다. 이에 따라 $\mathcal{D}_{train}$은 $\mathcal{D}_{test}$의 데이터 분포를 완벽히 나타낼 수 없다. 실제 응용에서는 $\mathcal{D}_{train}$과 $\mathcal{D}_{test}$의 데이터 분포가 완전히 다른 경우도 매우 많다. 이러한 경우에는 그림 2와 같이 $\mathcal{D}_{test}$의 데이터에 대해 $\mathcal{D}_{train}$에 조금은 덜 적합된 모델이 $\mathcal{D}_{test}$에 대해 더 높은 정확도를 보일 수가 있다.
그림 2b와 같이 $\mathcal{D}_{train}$에 과하게 적합되어 예측 모델이 $\mathcal{D}_{train}$에서는 더 낮은 예측 오차를 보이지만 실제 $\mathcal{D}_{test}$에서는 더 높은 예측 오차를 보이는 것을 과적합 문제 (overfitting problem)이라고 한다. 과적합 문제는 대부분의 응용에서 발생하며 학습 데이터가 적거나 문제가 어려울수록 과적합의 정도가 심해진다.
그림 3은 예측 오차의 관점에서 과적합 문제를 설명한다. 예측 오차의 변화에 따라 그림 3의 그래프는 아래와 같은 두 개의 구간으로 나눌 수 있다.
- 구간 A: 학습 오차와 테스트 오차가 같이 감소하는 구간 (과소적합, underfitting).
- 구간 B: 학습 오차는 감소하지만, 테스트 오차는 증가하는 구간 (과대적합, overfitting).
머신러닝에서는 구간 A에 있는 모델을 과소적합되었다고 하고, 구간 B에 있는 모델은 과대적합되었다고 한다. 우리의 목적은 학습을 통해 예측모델의 과소적합된 부분을 제거해나가면서 과대적합이 발생하기 직전에 학습을 멈추는 것이다. 머신러닝에서는 과대적합을 방지하기 위한 여러 방법이 연구되었으며, 일반적으로 validation set $\mathcal{D}_{val}$을 이용하여 과대적합이 일어났는지를 판별한다.
Validation dataset $\mathcal{D}_{val}$은 모델의 학습 과정에 참조되어 과대적합이 발생했는지를 판별하기 위해 사용되는 별도의 데이터셋이다. 일반적으로 실제 머신러닝 개발에서는 $\mathcal{D}_{val}$을 별도로 구축하기보다는 $\mathcal{D}_{train}$에서 일부 데이터를 추출하여 구축한다.
그림 4는 $\mathcal{D}_{train}$과 $\mathcal{D}_{val}$을 이용한 모델의 학습 과정이며, 학습은 아래의 5단계를 통해 진행된다.
- 단계 (1): 주어진 데이터셋을 $\mathcal{D}_{train}$, $\mathcal{D}_{val}$, $\mathcal{D}_{test}$로 나눈다. 일반적으로 각 데이터셋의 비율은 60:20:20으로 설정한다.
- 단계 (2): $\mathcal{D}_{train}$에 대한 예측 오차를 최소화하도록 모델을 학습시킨다.
- 단계 (3): $\mathcal{D}_{val}$에 대한 모델의 예측 오차를 계산한다.
- 단계 (4): 만약 $\mathcal{D}_{val}$에 대한 예측 오차가 증가했다면 학습을 종료한다. 그렇지 않을 경우에는 (2)로 돌아가서 학습을 계속 진행한다.
- 단계 (5): $\mathcal{D}_{test}$에 대한 예측 오차를 측정함으로써 학습된 모델의 최종 성능을 평가한다.
기존의 $\mathcal{D}_{train}$만 가지고 수행하였던 모델 학습과 그림 4의 가장 큰 차이점은 $\mathcal{D}_{val}$에 대한 예측 오차를 평가하는 단계 (3)과 (4)가 추가된 것이다. 이러한 두 과정을 통해 현재의 모델이 학습 과정에서 참조하지 않았던 데이터를 얼마나 정확하게 예측하는지를 평가하고, 평가 결과를 학습의 종료 조건으로 이용함으로써 과대적합을 간접적으로 방지한다.
그림 5는 $\mathcal{D}_{val}$을 이용하여 과대적합을 방지하는 방법을 예측 오차의 관점에서 설명한다. 위의 그래프에서 묘사된 것과 같이 $\mathcal{D}_{val}$에 대한 예측 오차 (validation loss)가 증가하는 시점부터 과대적합이 발생했다고 판단하고, 이 시점에서 모델 학습을 중단한다. 그러나 $\mathcal{D}_{val}$ 또한 $\mathcal{D}_{test}$의 데이터 분포를 완벽히 표현하지는 못 하기 때문에 $\mathcal{D}_{val}$에 대한 예측 오차가 최소가 되는 시점이 $\mathcal{D}_{test}$에 대한 예측 오차가 최소가 되는 시점과 정확히 일치하지는 않을 수도 있다.
처음 머신 러닝을 접하면 validation dataset의 개념을 정확히 이해하지 못 하거나, validation dataset과 test dataset의 개념을 혼동하기 쉽다. 아래는 혼동하기 쉬운 training dataset, validation dataset, test dataset의 용도를 요약한 것이다.
Validation dataset과 test dataset의 개념 중 혼동하기 쉬운 몇 가지를 요약하면 아래와 같다.
- Validation dataset은 학습 과정에서 참조할 수 있지만, test dataset은 그렇지 않다.
- Training dataset은 모델의 인자값 (인공신경망에서는 가중치)을 결정하는데 이용되지만, validation dataset은 이용되지 않는다. Validation dataset은 오직 언제 학습을 멈출지를 판단하기 위해 이용된다.
- Validation dataset과 test dataset은 모두 모델의 성능 평가에 이용된다. 그러나 validation dataset에 대한 성능은 학습을 중단하는 시점을 결정하기 위해 이용되고, test dataset에 대한 성능은 모델의 최종 정확도를 평가하기 위해 이용된다.
'지능형시스템 > 머신러닝' 카테고리의 다른 글
[머신 러닝/딥 러닝] Metric Learning의 개념과 Deep Metric Learning (1) | 2020.09.06 |
---|---|
[머신 러닝/딥 러닝] CGCNN: 인공지능과 소재 개발 (Artificial Intelligence and Material Discovery) (4) | 2020.03.20 |
[머신 러닝] 앙상블 (Ensemble) 방법의 이해 (0) | 2019.12.24 |
[머신 러닝/딥 러닝] 인공신경망 (Artificial Neural Network, ANN)의 종류와 구조 및 개념 (3) | 2019.12.23 |
[머신 러닝/딥 러닝] 그래프 합성곱 신경망 (Graph Convolutional Network, GCN) (19) | 2019.11.28 |