Бинарные классификаторы. Практика. Часть 2.

==**Улучшение классификатора**==

При рассмотрении графика изменения значений ROC-кривой в предыдущем пункте, можно заметить, что количество деревьев равное 150 недостаточно для того, чтобы полностью использовать все возможности алгоритма и потому есть смысл нарастить этот параметр и тем самым попытаться улучшить классификатор. Для этого необходимо сделать две вещи

# создать сетку желаемых параметров
# передать эту сетку в обучающий алгоритм.

Делается это следующим образом

{{{ lang=rsplus
modelGrid <- expand.grid(iter = c(100, 200, 300, 400, 500), maxdepth = c(1, 2, 3, 4), nu = 0.1) model_tuned <- train( target ~ ., data = train, method = "ada", trControl = ctrl, tuneGrid = modelGrid, metric = 'ROC' ) }}} Передача сетки производится при помощи параметра 'tuneGrid' в команде 'train'. Сама же сетка формируется при помощи команды 'expand.grid', которая составляет все возможные комбинации входных параметров {{{ lang=rsplus > head(modelGrid, n = 10)
iter maxdepth nu
1 100 1 0.1
2 200 1 0.1
3 300 1 0.1
4 400 1 0.1
5 500 1 0.1
6 100 2 0.1
7 200 2 0.1
8 300 2 0.1
9 400 2 0.1
10 500 2 0.1
}}}

Результат вычислений выглядит следующим образом

[[image:ada_tuned.jpeg|link=source]]

Как видно из графика, после того, как количество используемых деревьев переходит 300 единиц, алгоритм выходит на плато и по всей видимости уже не может улучшать свою работу. Самым выигрышным набором гиперпараметров становится комбинация с iter = 400, maxdepth = 3, nu = 0.1 и значением ROC = 0.944. Соответствующий график ROC-кривой на проверочной выборке выглядит следующим образом

[[image:ada_tuned_roc.jpeg|link=source]]

В данной модели обнаруживается следующий интересный факт. Если попытаться посчитать предсказание модели на проверочной выборке, то результат получится хуже, чем в случае, когда модель строилась без каких-либо дополнительных настроек

{{{ lang=rsplus
> testClasses <- predict(model_tuned, newdata = test) > confusionMatrix(data = testClasses, test$target)

Confusion Matrix and Statistics
Reference
Prediction M R
M 25 4
R 2 20

Accuracy : 0.8824
95% CI : (0.7613, 0.9556)
No Information Rate : 0.5294
P-Value [Acc > NIR] : 8.488e-08

Kappa : 0.7628
Mcnemar’s Test P-Value : 0.6831

Sensitivity : 0.9259
Specificity : 0.8333
Pos Pred Value : 0.8621
Neg Pred Value : 0.9091
Prevalence : 0.5294
Detection Rate : 0.4902
Detection Prevalence : 0.5686
Balanced Accuracy : 0.8796

‘Positive’ Class : M
}}}

Тем не менее, это не противоречит тому факту, что ROC для такой модели выше. Из результатов видно, что алгоритм чаще ошибается занося ‘R’ в класс ‘M’. Таким образом, если перекалибровать модель, есть возможность понизить количество ошибок такого типа.

{{{ lang=rsplus
> tc_calibrated <- factor(ifelse(testProbs_tuned$M >= 0.7, ‘M’, ‘R’))
> confusionMatrix(data = tc_calibrated, test$target)

Confusion Matrix and Statistics
Reference
Prediction M R
M 25 2
R 2 22

Accuracy : 0.9216
95% CI : (0.8112, 0.9782)
No Information Rate : 0.5294
P-Value [Acc > NIR] : 1.407e-09

Kappa : 0.8426
Mcnemar’s Test P-Value : 1

Sensitivity : 0.9259
Specificity : 0.9167
Pos Pred Value : 0.9259
Neg Pred Value : 0.9167
Prevalence : 0.5294
Detection Rate : 0.4902
Detection Prevalence : 0.5294
Balanced Accuracy : 0.9213

‘Positive’ Class : M
}}}

Сдвиг порога обнаружения металла до 70% позволяет повысить //accuracy// до 92%!

==**Сравнение с другим классификатором (Random Forest)**==

Для демонстрации удобства работы с пакетом ‘caret’, построим решение той же задачи при помощи классификатора [[https://ru.wikipedia.org/wiki/Random_forest|RandomForest]].

{{{ lang=rsplus
modelGrid <- expand.grid(mtry = seq(2, 60, 1)) ctrl <- trainControl( method = "repeatedcv", repeats = 10, number = 10, classProbs = T, verboseIter = T, summaryFunction = twoClassSummary ) model_rf <- train( target ~ ., data = train, method = "rf", tuneGrid = modelGrid, trControl = ctrl, metric = 'ROC' ) }}} Настройка данного алгоритма выглядит следующим образом [[image:rf.jpeg|link=source]] И из графика видно, что увеличение числа предикторов ухудшает точность работы алгоритма, и наилучший результат получается при числе предикторов равном 4 с ROC=0.935, что немного хуже, чем в случае AdaBoost. Более низкое качество показывается и на проверочной выборке {{{ lang=rsplus > testClasses <- predict(model_rf, newdata = test) > confusionMatrix(data = testClasses, test$target)

Confusion Matrix and Statistics
Reference
Prediction M R
M 26 9
R 1 15

Accuracy : 0.8039
95% CI : (0.6688, 0.9018)
No Information Rate : 0.5294
P-Value [Acc > NIR] : 4.341e-05

Kappa : 0.5991
Mcnemar’s Test P-Value : 0.02686

Sensitivity : 0.9630
Specificity : 0.6250
Pos Pred Value : 0.7429
Neg Pred Value : 0.9375
Prevalence : 0.5294
Detection Rate : 0.5098
Detection Prevalence : 0.6863
Balanced Accuracy : 0.7940

‘Positive’ Class : M
}}}

Помимо простого визуального сравнения работы разных классификаторов, есть и точные численные оценки, которые можно производить следующим образом

{{{ lang=rsplus
> summary(resamps <- resamples(list(ada = model_tuned, rf = model_rf))) Call: summary.resamples(object = resamps <- resamples(list(ada = model_tuned, rf = model_rf))) Models: ada, rf Number of resamples: 100 ROC Min. 1st Qu. Median Mean 3rd Qu. Max. NA's ada 0.7500 0.9092 0.9531 0.9435 0.9883 1 0 rf 0.7143 0.9009 0.9524 0.9353 0.9842 1 0 Sens ... > summary(diffs <- diff(resamps)) Call: summary.diff.resamples(object = diffs <- diff(resamps)) p-value adjustment: bonferroni Upper diagonal: estimates of the difference Lower diagonal: p-value for H0: difference = 0 ROC ada rf ada 0.008202 rf 0.3594 ... }}} Первая команда сравнивает набранные статистики по ROC метрике между разными методами. Как можно видеть, RandomForest в целом смещён влево по сравнению с AdaBoost. Вторая команда производит сравнение распределений по теории гипотез. Из этого сравнения видно, что AdaBoost немного лучше RandomForest (на 0.008), но значимость нулевой гипотезы (что модели равносильны) равна 36%, то есть разница есть, но она может быть статистической ошибкой.

Tagged , , , , , , . Bookmark the permalink.

Leave a Reply

Your email address will not be published. Required fields are marked *