Пользовательская функция потерь в Tensorflow для UNet

Вопрос или проблема

Я работаю над задачей сегментации, где я планировал использовать U-Net

для входного изображения формы (224,224,3), выходом должна быть маска изображения формы (224,224,1)

Маска изображения содержит два уникальных значения – черный [0] и белый [1]

выходной слой из UNet имеет тензор формы (None, 224, 224, 1)

Я использовал функцию активации softmax для выходного слоя

Формы и типы для обучающих данных

print(trainX.shape) # (200, 224, 224, 3)
print(testX.shape) # (50, 224, 224, 3)
print(trainY.shape) # (200, 224, 224, 1)
print(testY.shape) # (50, 224, 224, 1)

print(trainX.dtype) # float64
print(testX.dtype) # float64
print(trainY.dtype) # int16
print(testY.dtype) # int16

В маске изображения черных пикселей гораздо больше, чем белых пикселей. Для балансировки черных и белых пикселей я планировал использовать веса классов [для черного - 0.53083749, для белого - 8.60701406] в обучении. Поэтому я написал эту функцию

def lossFunc(true, pred):
  weightsList = K.constant([0.53083749, 8.60701406])
  true = K.reshape(true, [-1])
  pred = K.squeeze(pred, axis=3)
  sample_weightsList = K.gather(weightsList, true)
  loss = keras.losses.sparse_categorical_crossentropy(true,pred)
  loss*sample_weightsList

  return loss

Но когда я начал обучение, я получил эту ошибку

InvalidArgumentError                      Traceback (most recent call last)
<timed exec> in <module>

/usr/local/lib/python3.10/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     68             # Чтобы получить полный стек вызовов, вызовите:
     69             # `tf.debugging.disable_traceback_filtering()`
---> 70             raise e.with_traceback(filtered_tb) from None
     71         finally:
     72             del filtered_tb

/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     50   try:
     51     ctx.ensure_initialized()
---> 52     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     53                                         inputs, attrs, num_outputs)
     54   except core._NotOkStatusException as e:

InvalidArgumentError: Не зарегистрирован ни один OpKernel для поддержки Op 'GatherV2', используемого {{node lossFunc/GatherV2}} с этими атрибутами: [Tparams=DT_FLOAT, Tindices=DT_INT16, batch_dims=0, Taxis=DT_INT32]
Зарегистрированные устройства: [CPU, GPU]
Зарегистрированные ядра:
  device="XLA_CPU_JIT"; Taxis in [DT_INT32, DT_INT64]; Tindices in [DT_INT32, DT_INT16, DT_INT64]; Tparams in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, 930109355527764061, DT_HALF, DT_UINT32, DT_UINT64, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN]
  device="XLA_GPU_JIT"; Taxis in [DT_INT32, DT_INT64]; Tindices in [DT_INT32, DT_INT16, DT_INT64]; Tparams in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, 930109355527764061, DT_HALF, DT_UINT32, DT_UINT64, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN]
  device="CPU"; Tparams in [DT_QINT16]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_QINT16]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_QUINT16]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_QUINT16]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_QINT32]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_QINT32]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_QUINT8]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_QUINT8]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_QINT8]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_QINT8]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_VARIANT]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_VARIANT]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_RESOURCE]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_RESOURCE]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_STRING]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_STRING]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_BOOL]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_BOOL]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_COMPLEX128]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_COMPLEX128]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_COMPLEX64]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_COMPLEX64]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_DOUBLE]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_DOUBLE]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_FLOAT]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_FLOAT]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_BFLOAT16]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_BFLOAT16]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_HALF]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_HALF]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_INT32]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_INT32]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_INT8]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_INT8]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_UINT8]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_UINT8]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_INT16]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_INT16]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_UINT16]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_UINT16]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_UINT32]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_UINT32]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_INT64]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_INT64]; Tindices in [DT_INT32]
  device="CPU"; Tparams in [DT_UINT64]; Tindices in [DT_INT64]
  device="CPU"; Tparams in [DT_UINT64]; Tindices in [DT_INT32]
  device="GPU"; Tparams in [DT_BOOL]; Tindices in [DT_INT64]
  device="GPU"; Tparams in [DT_BOOL]; Tindices in [DT_INT32]
  device="GPU"; Tparams in [DT_COMPLEX128]; Tindices in [DT_INT64]
  device="GPU"; Tparams in [DT_COMPLEX128]; Tindices in [DT_INT32]
  device="GPU"; Tparams in [DT_COMPLEX64]; Tindices in [DT_INT64]
  device="GPU"; Tparams in [DT_COMPLEX64]; Tindices in [DT_INT32]
  device="GPU"; Tparams in [DT_DOUBLE]; Tindices in [DT_INT64]
  device="GPU"; Tparams in [DT_DOUBLE]; Tindices in [DT_INT32]
  device="GPU"; Tparams in [DT_FLOAT]; Tindices in [DT_INT64]
  device="GPU"; Tparams in [DT_FLOAT]; Tindices in [DT_INT32]
  device="GPU"; Tparams in [DT_BFLOAT16]; Tindices in [DT_INT64]
  device="GPU"; Tparams in [DT_BFLOAT16]; Tindices in [DT_INT32]
  device="GPU"; Tparams in [DT_HALF]; Tindices in [DT_INT64]
  device="GPU"; Tparams in [DT_HALF]; Tindices in [DT_INT32]
  device="GPU"; Tparams in [DT_INT64]; Tindices in [DT_INT64]
  device="GPU"; Tparams in [DT_INT64]; Tindices in [DT_INT32]
  device="GPU"; Tparams in [DT_INT32]; Tindices in [DT_INT64]
  device="GPU"; Tparams in [DT_INT32]; Tindices in [DT_INT32]

     [[lossFunc/GatherV2]] [Op:__inference_train_function_6870]

В чем проблема и как это решить

Спасибо

Согласно документации, функция gather требует индексы int32 или int64, в то время как вы, похоже, предоставляете int16. В основном попробуйте привести их к int32:

def lossFunc(true, pred):
  weightsList = K.constant([0.53083749, 8.60701406])
  true = K.cast(K.reshape(true, [-1]), "int32")  # или tf.cast(..., tf.int32)
  pred = K.squeeze(pred, axis=3)
  sample_weightsList = K.gather(weightsList, true)
  loss = keras.losses.sparse_categorical_crossentropy(true,pred)
  loss*sample_weightsList

  return loss

Ответ или решение

Создание пользовательской функции потерь в TensorFlow для сетей U-Net — важный шаг при работе с задачами сегментации, особенно когда ваши данные имеют несбалансированные классы. В вашем случае, задача состоит в объемной сегментации изображений, где выходным слоем является маска с двумя уникальными значениями: черным (0) и белым (1). Рассмотрим, как исправить ошибку, возникающую при использовании пользовательской функции потерь, а также оптимизируем код, чтобы улучшить его читаемость и эффективность.

Ошибка, возникающая в вашем коде

Как вы заметили, ошибка возникла из-за несовпадения типов данных. Функция tf.gather требует индексы типа int32 или int64, в то время как ваши true метки имеют тип данных int16. Это приводит к сбою в процессе обучения модели.

Как исправить ошибку

Вам необходимо использовать явное преобразование типа данных для меток true. Это можно сделать с помощью функции tf.cast. Давайте рассмотрим исправленный код вашей функции потерь:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K

def lossFunc(true, pred):
    # Задать веса классов как константу
    weightsList = K.constant([0.53083749, 8.60701406])

    # Преобразовать метки true в одномерный вид и изменить тип на int32
    true = K.cast(K.reshape(true, [-1]), "int32")

    # Удалить размерность в предсказании
    pred = K.squeeze(pred, axis=3)

    # Получить веса для выборки
    sample_weightsList = K.gather(weightsList, true)

    # Вычислить потери с использованием функции потерь 'sparse_categorical_crossentropy'
    loss = keras.losses.sparse_categorical_crossentropy(true, pred)

    # Применить веса к потерям
    loss *= sample_weightsList

    return loss

Объяснение кода

  1. Определение весов классов: Мы задаем веса классов как константу. Это поможет сместить акцент на меньший класс бензином для эффективного обучения модели.

  2. Преобразование меток в int32: Явное преобразование меток в int32 позволяет избежать ошибок во время вызова tf.gather.

  3. Предобработка предсказаний: Удаление лишней размерности из предсказаний помогает упростить вычисления.

  4. Применение выбранных весов: Мы используем функцию потерь sparse_categorical_crossentropy, подходящую для случаев, когда метки классов представлены в виде целых чисел. На финальном этапе мы умножаем потери на веса, чтобы учесть несбалансированность классов.

Проверка модели

После внесенных изменений вам следует убедиться, что ваша модель искажается правильно во время обучения. Для этого запустите процедуру обучения с использованием новой пользовательской функции потерь и проверьте метрики. Важно следить за изменениями в значении функции потерь, чтобы удостовериться, что модель корректно обрабатывает обучающие данные.

Заметки по улучшению

  1. Кросс-валидация: Разбейте ваши данные на кросс-валидационные наборы, чтобы объективно оценить производительность вашей модели.
  2. Аугментация данных: Учитывая несбалансированность классов, рассмотрите возможность применения аугментации данных для увеличения числа примеров белого класса.
  3. Мониторинг обучения: Используйте инструменты, такие как TensorBoard, чтобы отслеживать метрики во время тренировки.

Таким образом, ваша пользовательская функция потерь теперь корректно обрабатывает несбалансированные данные, и вы сможете эффективно обучить модель U-Net для решения задачи сегментации. Успехов в вашем проекте!

Оцените материал
Добавить комментарий

Капча загружается...