Вопрос или проблема
У меня есть модель XGBoost с следующими параметрами
xgbc_final = XGBClassifier(objective="multi:softprob",
num_class = 2,max_depth = 60,
n_estimators = 512,
reg_lambda = 0.1214,
alpha = 0.9131,
gamma = 0,
colsample_bytree = 0.7,
colsample_bylevel = 0.8,
colsample_bynode = 0.7,
subsample = 0.6,
learning_rate = .01,
min_child_weight = 14,
random_state = 2020,
eval_metric="auc",
verbosity = 1)
Здесь у меня только n_estimator = 512
, но я заметил, что когда я пытаюсь вывести дерево решений с индексом больше 511, я все равно получаю график
plot_tree(xgbc_final, num_trees=900)
Я ожидал ошибку для n_estimator больше 511 (если деревья индексируются с 0)
Может кто-то объяснить, почему он выдает деревья для чисел больше 512?
Это происходит потому, что вы делаете многоклассовую классификацию, а XGBoost реализует это, строя параллельные модели для каждого класса. Таким образом, общее количество деревьев на самом деле равно $512\cdot (\text{количество классов})$.
Я не знаю, в каком порядке эти деревья появляются. plot_tree
вызывает to_graphviz
, который вызывает (после, возможно, model.get_booster
) model.get_dump
, что вызывает некоторые c-level детали, которые я не буду отслеживать.
Ответ или решение
Ваша ситуация с XGBoost и количеством деревьев действительно может вызывать путаницу, особенно в контексте многоклассовой классификации. Позвольте объяснить подробнее, почему вы можете видеть деревья, выходящие за пределы указанного вами параметра n_estimators
.
Ваша модель определена с использованием параметра n_estimators=512
. Однако, когда вы используете XGBoost для многоклассовой классификации (как в вашем случае, с objective="multi:softprob"
и num_class=2
), фактическое количество деревьев, создаваемых моделью, на самом деле может превышать это значение. Это связано с тем, что XGBoost создает отдельные наборы деревьев для каждой категории, что приводит к увеличению общего числа деревьев.
Итак, в вашем случае, с двумя классами, XGBoost фактически построит 512 деревьев для каждого класса, что в итоге даст вам 1024 дерева. Это объясняет, почему вы можете запрашивать деревья с индексами больше 511 и все равно получать графики.
Когда вы вызываете команду plot_tree(xgbc_final, num_trees=900)
, XGBoost находит и отображает дерево с индексом 900, которое относится ко второму классу. Сам порядок, в котором эти деревья располагаются, может неясен, так как он зависит от внутренней реализации, как вы уже правильно заметили.
Резюмируя:
- Число деревьев: В многоклассовой классификации количество деревьев определяется как
n_estimators * num_class
, поэтому в вашем случае фактически создано 1024 дерева (512 для каждого из 2 классов). - Отображение деревьев: Вы можете запрашивать деревья с индексами, превышающими
n_estimators
, при условии, что они относятся к различным классам.
Таким образом, ваше наблюдение справедливо, и это нормальное поведение для XGBoost в контексте многоклассовой классификации.