Вопрос или проблема
Я работаю над задачей сегментации, где я планировал использовать U-Net
для входного изображения формы (224,224,3)
, выходом должна быть маска изображения формы (224,224,1)
Маска изображения содержит два уникальных значения – черный [0
] и белый [1
]
выходной слой из UNet имеет тензор формы (None, 224, 224, 1)
Я использовал функцию активации softmax
для выходного слоя
Формы и типы для обучающих данных
print(trainX.shape) # (200, 224, 224, 3)
print(testX.shape) # (50, 224, 224, 3)
print(trainY.shape) # (200, 224, 224, 1)
print(testY.shape) # (50, 224, 224, 1)
print(trainX.dtype) # float64
print(testX.dtype) # float64
print(trainY.dtype) # int16
print(testY.dtype) # int16
В маске изображения черных пикселей гораздо больше, чем белых пикселей. Для балансировки черных и белых пикселей я планировал использовать веса классов [для черного - 0.53083749, для белого - 8.60701406
] в обучении. Поэтому я написал эту функцию
def lossFunc(true, pred):
weightsList = K.constant([0.53083749, 8.60701406])
true = K.reshape(true, [-1])
pred = K.squeeze(pred, axis=3)
sample_weightsList = K.gather(weightsList, true)
loss = keras.losses.sparse_categorical_crossentropy(true,pred)
loss*sample_weightsList
return loss
Но когда я начал обучение, я получил эту ошибку
InvalidArgumentError Traceback (most recent call last)
<timed exec> in <module>
/usr/local/lib/python3.10/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
68 # Чтобы получить полный стек вызовов, вызовите:
69 # `tf.debugging.disable_traceback_filtering()`
---> 70 raise e.with_traceback(filtered_tb) from None
71 finally:
72 del filtered_tb
/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
50 try:
51 ctx.ensure_initialized()
---> 52 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
53 inputs, attrs, num_outputs)
54 except core._NotOkStatusException as e:
InvalidArgumentError: Не зарегистрирован ни один OpKernel для поддержки Op 'GatherV2', используемого {{node lossFunc/GatherV2}} с этими атрибутами: [Tparams=DT_FLOAT, Tindices=DT_INT16, batch_dims=0, Taxis=DT_INT32]
Зарегистрированные устройства: [CPU, GPU]
Зарегистрированные ядра:
device="XLA_CPU_JIT"; Taxis in [DT_INT32, DT_INT64]; Tindices in [DT_INT32, DT_INT16, DT_INT64]; Tparams in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, 930109355527764061, DT_HALF, DT_UINT32, DT_UINT64, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN]
device="XLA_GPU_JIT"; Taxis in [DT_INT32, DT_INT64]; Tindices in [DT_INT32, DT_INT16, DT_INT64]; Tparams in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, 930109355527764061, DT_HALF, DT_UINT32, DT_UINT64, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN]
device="CPU"; Tparams in [DT_QINT16]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_QINT16]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_QUINT16]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_QUINT16]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_QINT32]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_QINT32]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_QUINT8]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_QUINT8]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_QINT8]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_QINT8]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_VARIANT]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_VARIANT]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_RESOURCE]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_RESOURCE]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_STRING]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_STRING]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_BOOL]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_BOOL]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_COMPLEX128]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_COMPLEX128]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_COMPLEX64]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_COMPLEX64]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_DOUBLE]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_DOUBLE]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_FLOAT]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_FLOAT]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_BFLOAT16]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_BFLOAT16]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_HALF]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_HALF]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_INT32]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_INT32]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_INT8]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_INT8]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_UINT8]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_UINT8]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_INT16]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_INT16]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_UINT16]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_UINT16]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_UINT32]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_UINT32]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_INT64]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_INT64]; Tindices in [DT_INT32]
device="CPU"; Tparams in [DT_UINT64]; Tindices in [DT_INT64]
device="CPU"; Tparams in [DT_UINT64]; Tindices in [DT_INT32]
device="GPU"; Tparams in [DT_BOOL]; Tindices in [DT_INT64]
device="GPU"; Tparams in [DT_BOOL]; Tindices in [DT_INT32]
device="GPU"; Tparams in [DT_COMPLEX128]; Tindices in [DT_INT64]
device="GPU"; Tparams in [DT_COMPLEX128]; Tindices in [DT_INT32]
device="GPU"; Tparams in [DT_COMPLEX64]; Tindices in [DT_INT64]
device="GPU"; Tparams in [DT_COMPLEX64]; Tindices in [DT_INT32]
device="GPU"; Tparams in [DT_DOUBLE]; Tindices in [DT_INT64]
device="GPU"; Tparams in [DT_DOUBLE]; Tindices in [DT_INT32]
device="GPU"; Tparams in [DT_FLOAT]; Tindices in [DT_INT64]
device="GPU"; Tparams in [DT_FLOAT]; Tindices in [DT_INT32]
device="GPU"; Tparams in [DT_BFLOAT16]; Tindices in [DT_INT64]
device="GPU"; Tparams in [DT_BFLOAT16]; Tindices in [DT_INT32]
device="GPU"; Tparams in [DT_HALF]; Tindices in [DT_INT64]
device="GPU"; Tparams in [DT_HALF]; Tindices in [DT_INT32]
device="GPU"; Tparams in [DT_INT64]; Tindices in [DT_INT64]
device="GPU"; Tparams in [DT_INT64]; Tindices in [DT_INT32]
device="GPU"; Tparams in [DT_INT32]; Tindices in [DT_INT64]
device="GPU"; Tparams in [DT_INT32]; Tindices in [DT_INT32]
[[lossFunc/GatherV2]] [Op:__inference_train_function_6870]
В чем проблема и как это решить
Спасибо
Согласно документации, функция gather
требует индексы int32
или int64
, в то время как вы, похоже, предоставляете int16
. В основном попробуйте привести их к int32
:
def lossFunc(true, pred):
weightsList = K.constant([0.53083749, 8.60701406])
true = K.cast(K.reshape(true, [-1]), "int32") # или tf.cast(..., tf.int32)
pred = K.squeeze(pred, axis=3)
sample_weightsList = K.gather(weightsList, true)
loss = keras.losses.sparse_categorical_crossentropy(true,pred)
loss*sample_weightsList
return loss
Ответ или решение
Создание пользовательской функции потерь в TensorFlow для сетей U-Net — важный шаг при работе с задачами сегментации, особенно когда ваши данные имеют несбалансированные классы. В вашем случае, задача состоит в объемной сегментации изображений, где выходным слоем является маска с двумя уникальными значениями: черным (0) и белым (1). Рассмотрим, как исправить ошибку, возникающую при использовании пользовательской функции потерь, а также оптимизируем код, чтобы улучшить его читаемость и эффективность.
Ошибка, возникающая в вашем коде
Как вы заметили, ошибка возникла из-за несовпадения типов данных. Функция tf.gather
требует индексы типа int32
или int64
, в то время как ваши true
метки имеют тип данных int16
. Это приводит к сбою в процессе обучения модели.
Как исправить ошибку
Вам необходимо использовать явное преобразование типа данных для меток true
. Это можно сделать с помощью функции tf.cast
. Давайте рассмотрим исправленный код вашей функции потерь:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
def lossFunc(true, pred):
# Задать веса классов как константу
weightsList = K.constant([0.53083749, 8.60701406])
# Преобразовать метки true в одномерный вид и изменить тип на int32
true = K.cast(K.reshape(true, [-1]), "int32")
# Удалить размерность в предсказании
pred = K.squeeze(pred, axis=3)
# Получить веса для выборки
sample_weightsList = K.gather(weightsList, true)
# Вычислить потери с использованием функции потерь 'sparse_categorical_crossentropy'
loss = keras.losses.sparse_categorical_crossentropy(true, pred)
# Применить веса к потерям
loss *= sample_weightsList
return loss
Объяснение кода
-
Определение весов классов: Мы задаем веса классов как константу. Это поможет сместить акцент на меньший класс бензином для эффективного обучения модели.
-
Преобразование меток в int32: Явное преобразование меток в
int32
позволяет избежать ошибок во время вызоваtf.gather
. -
Предобработка предсказаний: Удаление лишней размерности из предсказаний помогает упростить вычисления.
-
Применение выбранных весов: Мы используем функцию потерь
sparse_categorical_crossentropy
, подходящую для случаев, когда метки классов представлены в виде целых чисел. На финальном этапе мы умножаем потери на веса, чтобы учесть несбалансированность классов.
Проверка модели
После внесенных изменений вам следует убедиться, что ваша модель искажается правильно во время обучения. Для этого запустите процедуру обучения с использованием новой пользовательской функции потерь и проверьте метрики. Важно следить за изменениями в значении функции потерь, чтобы удостовериться, что модель корректно обрабатывает обучающие данные.
Заметки по улучшению
- Кросс-валидация: Разбейте ваши данные на кросс-валидационные наборы, чтобы объективно оценить производительность вашей модели.
- Аугментация данных: Учитывая несбалансированность классов, рассмотрите возможность применения аугментации данных для увеличения числа примеров белого класса.
- Мониторинг обучения: Используйте инструменты, такие как TensorBoard, чтобы отслеживать метрики во время тренировки.
Таким образом, ваша пользовательская функция потерь теперь корректно обрабатывает несбалансированные данные, и вы сможете эффективно обучить модель U-Net для решения задачи сегментации. Успехов в вашем проекте!