수식이 나오지 않는다면 새로고침(F5)을 해주세요
모바일은 수식이 나오지 않습니다.
📌 PRUNING
Tree 모델의 과적합을 막고 계산량을 줄이기 위해서 PRUNING이라는 것을 할 수 있습니다.
Pruning은 가지치기라고 합니다. 하부 tree를 제거하여 깊이가 줄어드는 방법을 이용합니다.
아래와 같은 가정을 해봅시다.
$T$ : 나무의 개수
$|T|$ : 나무의 터미널 노드 개수(맨 마지막 노드)
$T_1, T_2, ..., T_k$ : 나무의 터미널 노드들
$r(T_i)$ : $T_i$노드에 할당된 class
$L(i,j)$ : 손실 행렬
$P(T_i) = \sum_{j=1}^C \pi_jP(T_i,|j)$ : 여기서 $\pi_j$는 클래스 j에 대한 사전 확률
위 가정으로 $T_i$의 risk는 아래와 같습니다.
$$
R(T_i) = \sum_{j=1}^C p(j|T_i)L(j,r(T_i))
$$
수학이 싫으신 분들께 쉽게 설명하자면! $j$ 클래스에 속할 확률이 높은데 $i (\neq j)$ 클래스로 잘못 예측해서 손실이 늘어났다? 이들의 곱은 그만큼 크겠죠. 결국 risk값인 $R(T_i)$가 커질 것입니다.
식을 보시면 $p(j|T_i)$는 $T_i$노드에서 클래스 j에 속할 조건부 확률로. 주어진 데이터에 대한 클래스 j의 확률을 나타냅니다.
또한 $L(j, r(T_i))$는 손실 행렬 $L$에서 클래스 j와 $T_i$노드에서 예측된 클래스 $r(T_i)$ 간의 손실을 나타냅니다. 이는 예측이 실제 클래스와 얼마나 일치하는지에 따라 발생하는 손실을 계산하는 것 입니다.
따라서 $R(T_i)$는 $T_i$ 노드에서 예측된 클래스와 실제 클래스 간의 손실을 계산하고, 각 클래스에 대한 조건부 확률과 손실을 고려하여 해당 노드의 전반적인 risk를 계산합니다. risk인 만큼 값이 낮을 수록 예측이 정확하다고 볼 수 있습니다.
쉽게 설명하면 $T_i$노드에서 각 클래스의 손실값을 모두 더한 값입니다. 해당 노드에서의 전반적인 리스크를 계산하는 것입니다.
이번엔 $T$의 risk입니다.
$$
R(T) = \sum_{i=1}^kP(T_i)R(T_i)
$$
전체 트리 $T$에 대한 리스크식입니다. 각 터미널 노드 $T_i$의 리스크를 해당 노드가 선택될 확률 $P(T_i)$와 곱하여 합산한 값입니다. 이를 통해 전체 트리의 예측 성능을 나타내는 리스크 측정값이 됩니다.
📌 complexity parameter
이 때, 프루닝을 위해 복잡성 매개변수인 $\alpha$를 사용하여 트리의 복잡성을 조절합니다.
$[0,\infty)$범위내의 숫자인 $\alpha$를 complexity parameter라고 합니다. 트리 $T$의 비용을 다음과 같이 정의 합니다. $|T|$가 높아지면 $R(T)$는 낮아집니다.
$$
R_\alpha(T) = R(T) + \alpha |T|
$$
위처럼 리스크를 설정했을 경우 $\alpha$값을 높인다면 트리의 터미널 노드의 개수를 적게 만들어 나무의 복잡도를 낮춥니다. 이를 통해 과적합을 피할 수 있습니다. 반대로 낮춘다면 트리를 복잡하게 유지하므로 과적합의 위험이 있을 수 있습니다.
◼️ R로 확인하기
- 데이터 불러오기, 결측치 제거, train : test = 5 : 5로 나누기
데이터는 아래 첨부파일을 사용하였습니다.
library(tree)
hdata = read.csv('Heart.csv', header=T, stringsAsFactors=T)
str(hdata)
sum(is.na(hdata))
hdata = na.omit(hdata)
n = dim(hdata)[1]
hdata$X = NULL
set.seed(42)
train = sample(n, n/2)
htrain = hdata[train,]
htest = hdata[-train,]
- 10-fold cv 파이프라인 만들기
library(caret)
control = trainControl(method='cv', number =10, savePredictions= 'final', classProbs=T)
- model train 하기
library(rpart)
model = train(AHD~., data= htrain, method = "rpart", trControl=control,
tuneGrid= data.frame(cp = seq(0, 0.5, by = 0.05)))
여기서 중요한 점은 cp입니다. 위에서 언급했던 complexity parameter로 가지치기를 조정합니다.
- 결과 확인
model
범위를 너무 크게 잡긴 했습니다만 우선 0.05일 때가 가장 높은 정확도를 보입니다.
- test set 예측 정확도 확인
pred = predict(model, htest)
confusionMatrix(pred, htest$AHD)
test set 대상 예측 정확도는 0.73정도가 나옵니다.
'⚙️ Machine Learning > Machine learning' 카테고리의 다른 글
3. Bagging(배깅) : Random Forest은 뭐가 다를까? (0) | 2023.10.12 |
---|---|
2. Bagging(배깅) : Out of bag error estimation (0) | 2023.10.12 |
1. Bagging(배깅) : 왜 여러 모델을 쓰는가? (0) | 2023.10.10 |
CART 2. 분류나무(Classification tree)[지니 계수, 엔트로피] (0) | 2023.10.09 |
CART 1. 회귀나무(Regression tree)[재귀적 이진 분할 알고리즘] (1) | 2023.10.06 |