Неверная функция потерь превосходит правильную функцию потерь?

Вопрос или проблема

Набор данных

У меня есть набор данных в следующем формате:

834 входных данных: 2 числа с плавающей запятой в диапазоне 0-1 и 832 целых числа, созданные с помощью one-hot-кодирования 64 значений (13 классов на значение).

4096 выходов: Каждый выход – это one-hot-кодированный класс, так что первое число было бы первой категорией, второе число – второй категорией, и так далее. Есть только один правильный класс выхода, поэтому только одно значение будет равно 1.

Поскольку из 4096 выходных категорий не все возможны для каждого входа, я создал пользовательскую функцию потерь, в которую передаю две части выхода.

  • 4096 выходных значений. [0,0,0,0,...,1,0,0,0,0,...]
  • Список из 4096 целых чисел, описывающий, какие выходы возможны. Это перемножается с предсказаниями ИИ, так как мне не важно, что ИИ предлагает для невозможных выходов. [0,0,1,0,0,...,0,1,1,0,0]

Модель

Модель выглядит следующим образом:

def MakeModel():
    model = Sequential()
    model.add(Dense(256, input_dim=834, activation='relu'))
    model.add(BatchNormalization())
    for _ in range(2):
        model.add(Dense(256, activation='relu'))
        model.add(BatchNormalization())
    model.add(Dense(4096, activation='softmax'))
    model.compile(loss=customLoss, optimizer=Adam(amsgrad=True), metrics=['accuracy'])
    return model

Функция потерь

Пользовательская функция потерь:

def customLoss(dataOut,aiOut):
    actualOut     = dataOut[:, 0:4096]
    possibleMoves = dataOut[:, 4096:8192]
    
    aiOutPossible = possibleMoves*aiOut     #Это выходной результат ИИ, включающий только возможные ходы
    
    loss = tf.keras.backend.binary_crossentropy(actualOut, aiOutPossible)
    #loss = tf.keras.backend.categorical_crossentropy(actualOut, aiOutPossible)
    
    return loss

Если я правильно понимаю, правильной функцией потерь была бы categorical_crossentropy, так как я использую 4096 one-hot-кодированных выходов.

Результаты

К сожалению, результаты очень плохие, с точностью валидации около ~0.17%:

225/225 - 14s - loss: 6.0649 - accuracy: 0.0182 - val_loss: 4.2419 - val_accuracy: 0.0215
222/222 - 13s - loss: 3.7911 - accuracy: 0.0160 - val_loss: 3.2553 - val_accuracy: 0.0164
201/201 - 12s - loss: 3.0921 - accuracy: 0.0194 - val_loss: 2.9757 - val_accuracy: 0.0189
221/221 - 13s - loss: 2.9000 - accuracy: 0.0170 - val_loss: 2.8644 - val_accuracy: 0.0190
222/222 - 13s - loss: 2.8031 - accuracy: 0.0174 - val_loss: 2.8231 - val_accuracy: 0.0180
221/221 - 13s - loss: 2.7412 - accuracy: 0.0186 - val_loss: 2.7795 - val_accuracy: 0.0191
214/214 - 12s - loss: 2.7455 - accuracy: 0.0158 - val_loss: 2.7329 - val_accuracy: 0.0158
193/193 - 12s - loss: 2.6967 - accuracy: 0.0181 - val_loss: 2.7116 - val_accuracy: 0.0188
217/217 - 12s - loss: 2.6593 - accuracy: 0.0199 - val_loss: 2.6849 - val_accuracy: 0.0170
227/227 - 13s - loss: 2.6688 - accuracy: 0.0191 - val_loss: 2.6879 - val_accuracy: 0.0202
202/202 - 12s - loss: 2.6455 - accuracy: 0.0211 - val_loss: 2.6561 - val_accuracy: 0.0220
217/217 - 13s - loss: 2.6129 - accuracy: 0.0211 - val_loss: 2.6610 - val_accuracy: 0.0220
224/224 - 13s - loss: 2.6278 - accuracy: 0.0207 - val_loss: 2.6387 - val_accuracy: 0.0173
221/221 - 13s - loss: 2.5982 - accuracy: 0.0188 - val_loss: 2.6365 - val_accuracy: 0.0161
211/211 - 13s - loss: 2.5818 - accuracy: 0.0198 - val_loss: 2.6107 - val_accuracy: 0.0195
221/221 - 13s - loss: 2.6160 - accuracy: 0.0190 - val_loss: 2.6113 - val_accuracy: 0.0204
218/218 - 13s - loss: 2.5962 - accuracy: 0.0191 - val_loss: 2.5933 - val_accuracy: 0.0177
224/224 - 15s - loss: 2.5523 - accuracy: 0.0197 - val_loss: 2.5937 - val_accuracy: 0.0174
224/224 - 15s - loss: 2.5810 - accuracy: 0.0159 - val_loss: 2.5831 - val_accuracy: 0.0174
220/220 - 15s - loss: 2.5465 - accuracy: 0.0184 - val_loss: 2.5631 - val_accuracy: 0.0210
232/232 - 17s - loss: 2.5596 - accuracy: 0.0178 - val_loss: 2.5787 - val_accuracy: 0.0182
217/217 - 15s - loss: 2.5375 - accuracy: 0.0191 - val_loss: 2.5591 - val_accuracy: 0.0182
226/226 - 14s - loss: 2.5161 - accuracy: 0.0193 - val_loss: 2.5489 - val_accuracy: 0.0183
218/218 - 13s - loss: 2.5139 - accuracy: 0.0183 - val_loss: 2.5449 - val_accuracy: 0.0175
221/221 - 13s - loss: 2.4935 - accuracy: 0.0199 - val_loss: 2.5453 - val_accuracy: 0.0194
230/230 - 14s - loss: 2.5067 - accuracy: 0.0176 - val_loss: 2.5356 - val_accuracy: 0.0198
209/209 - 12s - loss: 2.4901 - accuracy: 0.0187 - val_loss: 2.5324 - val_accuracy: 0.0169
221/221 - 13s - loss: 2.4924 - accuracy: 0.0171 - val_loss: 2.5150 - val_accuracy: 0.0181
233/233 - 14s - loss: 2.4844 - accuracy: 0.0174 - val_loss: 2.5139 - val_accuracy: 0.0177
219/219 - 13s - loss: 2.4908 - accuracy: 0.0167 - val_loss: 2.5540 - val_accuracy: 0.0167
212/212 - 13s - loss: 2.4907 - accuracy: 0.0191 - val_loss: 2.5199 - val_accuracy: 0.0190
227/227 - 13s - loss: 2.4772 - accuracy: 0.0160 - val_loss: 2.5036 - val_accuracy: 0.0175
226/226 - 14s - loss: 2.4818 - accuracy: 0.0169 - val_loss: 2.5041 - val_accuracy: 0.0177
219/219 - 13s - loss: 2.4773 - accuracy: 0.0168 - val_loss: 2.4996 - val_accuracy: 0.0157
217/217 - 13s - loss: 2.4706 - accuracy: 0.0171 - val_loss: 2.5008 - val_accuracy: 0.0175
228/228 - 14s - loss: 2.4687 - accuracy: 0.0167 - val_loss: 2.4900 - val_accuracy: 0.0170
222/222 - 13s - loss: 2.4495 - accuracy: 0.0174 - val_loss: 2.4872 - val_accuracy: 0.0156
219/219 - 13s - loss: 2.4442 - accuracy: 0.0167 - val_loss: 2.4824 - val_accuracy: 0.0152
217/217 - 13s - loss: 2.4519 - accuracy: 0.0162 - val_loss: 2.4799 - val_accuracy: 0.0170
218/218 - 14s - loss: 2.4606 - accuracy: 0.0147 - val_loss: 2.4775 - val_accuracy: 0.0170
220/220 - 14s - loss: 2.4382 - accuracy: 0.0173 - val_loss: 2.4724 - val_accuracy: 0.0145
209/209 - 13s - loss: 2.4238 - accuracy: 0.0170 - val_loss: 2.4657 - val_accuracy: 0.0154
212/212 - 13s - loss: 2.4480 - accuracy: 0.0148 - val_loss: 2.4657 - val_accuracy: 0.0140
226/226 - 14s - loss: 2.4373 - accuracy: 0.0157 - val_loss: 2.4677 - val_accuracy: 0.0177
231/231 - 14s - loss: 2.4427 - accuracy: 0.0152 - val_loss: 2.4690 - val_accuracy: 0.0155
216/216 - 13s - loss: 2.4252 - accuracy: 0.0157 - val_loss: 2.4670 - val_accuracy: 0.0165
225/225 - 15s - loss: 2.4437 - accuracy: 0.0147 - val_loss: 2.4518 - val_accuracy: 0.0148
226/226 - 16s - loss: 2.4178 - accuracy: 0.0144 - val_loss: 2.4503 - val_accuracy: 0.0158
235/235 - 15s - loss: 2.4281 - accuracy: 0.0146 - val_loss: 2.4495 - val_accuracy: 0.0142
218/218 - 15s - loss: 2.4193 - accuracy: 0.0144 - val_loss: 2.4502 - val_accuracy: 0.0137
216/216 - 15s - loss: 2.4175 - accuracy: 0.0144 - val_loss: 2.4530 - val_accuracy: 0.0149
232/232 - 14s - loss: 2.4210 - accuracy: 0.0142 - val_loss: 2.4441 - val_accuracy: 0.0145
226/226 - 14s - loss: 2.4304 - accuracy: 0.0140 - val_loss: 2.4549 - val_accuracy: 0.0160
219/219 - 13s - loss: 2.4300 - accuracy: 0.0165 - val_loss: 2.4584 - val_accuracy: 0.0163
223/223 - 14s - loss: 2.4165 - accuracy: 0.0146 - val_loss: 2.4426 - val_accuracy: 0.0152
217/217 - 14s - loss: 2.4247 - accuracy: 0.0150 - val_loss: 2.4391 - val_accuracy: 0.0144
216/216 - 14s - loss: 2.4271 - accuracy: 0.0146 - val_loss: 2.4360 - val_accuracy: 0.0156

Однако, при использовании той же задачи с binary_crossentropy:

225/225 - 25s - loss: 0.0018 - accuracy: 0.0348 - val_loss: 0.0017 - val_accuracy: 0.0383
222/222 - 24s - loss: 0.0016 - accuracy: 0.0624 - val_loss: 0.0015 - val_accuracy: 0.0633
201/201 - 22s - loss: 0.0014 - accuracy: 0.0778 - val_loss: 0.0014 - val_accuracy: 0.0805
221/221 - 23s - loss: 0.0013 - accuracy: 0.0830 - val_loss: 0.0013 - val_accuracy: 0.0929
222/222 - 23s - loss: 0.0013 - accuracy: 0.0968 - val_loss: 0.0013 - val_accuracy: 0.0933
221/221 - 23s - loss: 0.0013 - accuracy: 0.1051 - val_loss: 0.0013 - val_accuracy: 0.1024
214/214 - 29s - loss: 0.0012 - accuracy: 0.1001 - val_loss: 0.0012 - val_accuracy: 0.1007
193/193 - 21s - loss: 0.0012 - accuracy: 0.1060 - val_loss: 0.0012 - val_accuracy: 0.1079
217/217 - 23s - loss: 0.0012 - accuracy: 0.1165 - val_loss: 0.0012 - val_accuracy: 0.1156
227/227 - 25s - loss: 0.0012 - accuracy: 0.1172 - val_loss: 0.0012 - val_accuracy: 0.1183
202/202 - 21s - loss: 0.0012 - accuracy: 0.1177 - val_loss: 0.0012 - val_accuracy: 0.1176
217/217 - 25s - loss: 0.0011 - accuracy: 0.1308 - val_loss: 0.0012 - val_accuracy: 0.1180
224/224 - 24s - loss: 0.0012 - accuracy: 0.1238 - val_loss: 0.0011 - val_accuracy: 0.1273
221/221 - 20s - loss: 0.0011 - accuracy: 0.1209 - val_loss: 0.0011 - val_accuracy: 0.1253
211/211 - 24s - loss: 0.0011 - accuracy: 0.1285 - val_loss: 0.0011 - val_accuracy: 0.1278
221/221 - 23s - loss: 0.0011 - accuracy: 0.1230 - val_loss: 0.0011 - val_accuracy: 0.1159
218/218 - 23s - loss: 0.0011 - accuracy: 0.1308 - val_loss: 0.0011 - val_accuracy: 0.1325
224/224 - 24s - loss: 0.0011 - accuracy: 0.1333 - val_loss: 0.0011 - val_accuracy: 0.1343
224/224 - 21s - loss: 0.0011 - accuracy: 0.1249 - val_loss: 0.0011 - val_accuracy: 0.1305
220/220 - 21s - loss: 0.0011 - accuracy: 0.1359 - val_loss: 0.0011 - val_accuracy: 0.1371
232/232 - 22s - loss: 0.0011 - accuracy: 0.1318 - val_loss: 0.0011 - val_accuracy: 0.1369
217/217 - 21s - loss: 0.0011 - accuracy: 0.1384 - val_loss: 0.0011 - val_accuracy: 0.1361
226/226 - 21s - loss: 0.0011 - accuracy: 0.1357 - val_loss: 0.0011 - val_accuracy: 0.1353
218/218 - 21s - loss: 0.0011 - accuracy: 0.1386 - val_loss: 0.0011 - val_accuracy: 0.1398
221/221 - 23s - loss: 0.0011 - accuracy: 0.1439 - val_loss: 0.0011 - val_accuracy: 0.1412
230/230 - 25s - loss: 0.0011 - accuracy: 0.1391 - val_loss: 0.0011 - val_accuracy: 0.1383
209/209 - 21s - loss: 0.0011 - accuracy: 0.1454 - val_loss: 0.0011 - val_accuracy: 0.1395
221/221 - 20s - loss: 0.0011 - accuracy: 0.1462 - val_loss: 0.0011 - val_accuracy: 0.1452
233/233 - 24s - loss: 0.0011 - accuracy: 0.1443 - val_loss: 0.0011 - val_accuracy: 0.1422
219/219 - 23s - loss: 0.0011 - accuracy: 0.1458 - val_loss: 0.0011 - val_accuracy: 0.1420
212/212 - 23s - loss: 0.0011 - accuracy: 0.1471 - val_loss: 0.0011 - val_accuracy: 0.1477
227/227 - 27s - loss: 0.0011 - accuracy: 0.1438 - val_loss: 0.0011 - val_accuracy: 0.1479
226/226 - 22s - loss: 0.0011 - accuracy: 0.1457 - val_loss: 0.0011 - val_accuracy: 0.1443
219/219 - 20s - loss: 0.0011 - accuracy: 0.1499 - val_loss: 0.0011 - val_accuracy: 0.1481
217/217 - 23s - loss: 0.0011 - accuracy: 0.1502 - val_loss: 0.0011 - val_accuracy: 0.1460
228/228 - 22s - loss: 0.0011 - accuracy: 0.1495 - val_loss: 0.0011 - val_accuracy: 0.1458
222/222 - 23s - loss: 0.0011 - accuracy: 0.1538 - val_loss: 0.0011 - val_accuracy: 0.1516
219/219 - 25s - loss: 0.0011 - accuracy: 0.1553 - val_loss: 0.0011 - val_accuracy: 0.1551
217/217 - 21s - loss: 0.0011 - accuracy: 0.1535 - val_loss: 0.0011 - val_accuracy: 0.1525

Вот хороший график, показывающий точность при использовании binary_crossentropy на протяжении большего числа эпох:
Точность бинарной классификации

Итак, основные вопросы:

  1. Какую правильную классификацию использовать?
  2. Почему моя точность валидации так низка и что я могу сделать, чтобы улучшить её?

Если классы взаимоисключающие, то в идеале следует использовать категориальную перекрёстную энтропию для функции потерь. Однако, бинарная потеря также может работать в некоторой степени, поскольку она всё равно будет поощрять низкую вероятность для неправильных классов и высокую для правильных.

Вы, судя по всему, не нормализуете значение aiOutPossible перед использованием categorical_crossentropy. Определённо следует это сделать, потому что без этого шага вероятности классов не будут суммироваться до 1, что значительно исказит расчёты.

Я думаю, что что-то вроде:

aiOutPossible = aiOutPossible / aiOutPossible.sum()

перед вызовом tf.keras.backend.categorical_crossentropy должно решить эту проблему в вашей пользовательской функции потерь.

Вы также можете обнаружить, что в процессе обучения вам не нужна эта пользовательская функция потерь – дополнительная выгода от фокусировки только на релевантных выходах может быть компенсирована отсутствием штрафов за высокие/уверенные оценки для невозможных выходов. Поэтому вы можете также рассмотреть вариант обучения без фильтра и пользовательской потери, используя фильтр только при оценке тестовых данных и в производстве. Это позволит использовать некоторые встроенные вычисления градиентов в TensorFlow, которые более численно стабильны (используйте from_logits=true). Я не могу с уверенностью сказать, поможет ли это на самом деле, но, на мой взгляд, стоит попробовать, так как это упростит ваш учебный процесс.

Ответ или решение

Неверная функция потерь лучше выполняет функцию потерь точной классификации: анализ и рекомендации

Вопрос о том, почему неправильная функция потерь — binary_crossentropy — выдаёт более высокую точность по сравнению с методически правильной — categorical_crossentropy — при решении задачи классификации с взаимоисключающими классами, требует детального анализа. Рассмотрим основные аспекты проблемы.

Контекст задачи и текущие результаты

Вы работаете с моделью, которая обрабатывает 834 входных признака и предсказывает один из 4096 возможных классов, который является one-hot-кодированным. Из этого множества классов только один верен для каждого входа. Однако число возможных классов для каждого входа ограничено, что и спровоцировало создание специальной функции потерь.

Через использование binary_crossentropy результаты значительно превосходят использование categorical_crossentropy, которое, на первый взгляд, кажется более подходящим из-за one-hot-кодированного формата целевого признака.

Технические причины для выбора функции потерь

  1. Переработка данных выхода и вероятностное распределение: При использовании categorical_crossentropy важно, чтобы выходные данные представляли собой вероятностное распределение, сумма которого равняется единице. Если Вы не нормализуете aiOutPossible, то можете получить некорректные результаты.

  2. Маскирование невозможных классов: Ваш подход с добавлением маски возможных движений уменьшает воздействие невозможных выводов, но может нарушать математическую корректность вероятностного распределения при расчёте categorical_crossentropy.

Рекомендации по улучшению модели

  1. Нормализация выводов: В первую очередь, перед использованием categorical_crossentropy выполняйте нормализацию выводов для получения валидного вероятностного распределения:

    aiOutPossible = aiOutPossible / tf.reduce_sum(aiOutPossible, axis=-1, keepdims=True)
  2. Обратитесь к TensorFlow функциональности: Использование опции from_logits=True в TensorFlow позволяет использовать встроенные функции, которые могут быть более стабильными и производительными.

  3. Сканирование структуры модели: Убедитесь, что структура модели оптимальна для понимания ваших данных; возможно, стоит попробовать изменить количество слоёв или нейронов.

  4. Проводите эксперименты без фильтрации невозможных классов: Попробуйте обучить модель без кастомной функции потерь и фильтрации — оставьте это только для тестирования и внедрения. Это может улучшить способность модели к общему обучению.

Заключение

Проблема с выбором функции потерь достаточно сложная и зависит от множества факторов. Ключевой момент — это корректное построение вероятностного распределения, что может потребовать максимальной внимательности в разработке кастомных функций потерь. Также исследования и эксперименты без фильтрации невозможных классов могут сделать ваш подход более гибким и надёжным.

Оцените материал
Добавить комментарий

Капча загружается...