XGBoost выводит деревья за пределами n_estimator

Вопрос или проблема

У меня есть модель XGBoost с следующими параметрами

xgbc_final = XGBClassifier(objective="multi:softprob",
                           num_class = 2,max_depth = 60,
                           n_estimators = 512,
                           reg_lambda = 0.1214,
                           alpha = 0.9131,
                           gamma = 0,
                           colsample_bytree = 0.7,
                           colsample_bylevel = 0.8,
                           colsample_bynode = 0.7,
                           subsample = 0.6,
                           learning_rate = .01,
                           min_child_weight = 14,
                           random_state = 2020,
                           eval_metric="auc",
                           verbosity = 1)

Здесь у меня только n_estimator = 512, но я заметил, что когда я пытаюсь вывести дерево решений с индексом больше 511, я все равно получаю график

plot_tree(xgbc_final, num_trees=900)

Я ожидал ошибку для n_estimator больше 511 (если деревья индексируются с 0)

Может кто-то объяснить, почему он выдает деревья для чисел больше 512?

Это происходит потому, что вы делаете многоклассовую классификацию, а XGBoost реализует это, строя параллельные модели для каждого класса. Таким образом, общее количество деревьев на самом деле равно $512\cdot (\text{количество классов})$.

Я не знаю, в каком порядке эти деревья появляются. plot_tree вызывает to_graphviz, который вызывает (после, возможно, model.get_booster) model.get_dump, что вызывает некоторые c-level детали, которые я не буду отслеживать.

Ответ или решение

Ваша ситуация с XGBoost и количеством деревьев действительно может вызывать путаницу, особенно в контексте многоклассовой классификации. Позвольте объяснить подробнее, почему вы можете видеть деревья, выходящие за пределы указанного вами параметра n_estimators.

Ваша модель определена с использованием параметра n_estimators=512. Однако, когда вы используете XGBoost для многоклассовой классификации (как в вашем случае, с objective="multi:softprob" и num_class=2), фактическое количество деревьев, создаваемых моделью, на самом деле может превышать это значение. Это связано с тем, что XGBoost создает отдельные наборы деревьев для каждой категории, что приводит к увеличению общего числа деревьев.

Итак, в вашем случае, с двумя классами, XGBoost фактически построит 512 деревьев для каждого класса, что в итоге даст вам 1024 дерева. Это объясняет, почему вы можете запрашивать деревья с индексами больше 511 и все равно получать графики.

Когда вы вызываете команду plot_tree(xgbc_final, num_trees=900), XGBoost находит и отображает дерево с индексом 900, которое относится ко второму классу. Сам порядок, в котором эти деревья располагаются, может неясен, так как он зависит от внутренней реализации, как вы уже правильно заметили.

Резюмируя:

  1. Число деревьев: В многоклассовой классификации количество деревьев определяется как n_estimators * num_class, поэтому в вашем случае фактически создано 1024 дерева (512 для каждого из 2 классов).
  2. Отображение деревьев: Вы можете запрашивать деревья с индексами, превышающими n_estimators, при условии, что они относятся к различным классам.

Таким образом, ваше наблюдение справедливо, и это нормальное поведение для XGBoost в контексте многоклассовой классификации.

Оцените материал
Добавить комментарий

Капча загружается...