Вопрос или проблема
Полный набор данных для моей модели CNN слишком велик для 12 ГБ vRAM, которые есть у GPU, и model.fit() вызывает ошибку OOM, когда все данные загружаются. Чтобы решить эту проблему OOM, генератор данных реализован следующим образом:
def my_generator(image_patches, user_ids, binary_flags, batch_size): #<<==функция-генератор. images_patches - это все патчи, созданные из всех тренировочных или тестовых изображений.
for patch, user_id, binary_flag in zip(image_patches, user_ids, binary_flags):
# Нормализуем изображения патчи (например, делим на 255.0)
patch = np.array(patch) / 255.0
# One-hot кодируем user_ids и binary_flags
user_id_one_hot = tf.keras.utils.to_categorical(user_id, num_classes=len(USERID_LIST)) # возвращаем в формате np.array
binary_flag_one_hot = tf.keras.utils.to_categorical(binary_flag, num_classes=2)
yield patch, user_id_one_hot, binary_flag_one_hot. #<<== возвращается только один образец данных
Вот функция return_dataset(), которая определена для использования tf.data.Dataset.from_generator():
def return_dataset(data, window_size=(224, 224), step_size=112, shuffle_buffer_size=1000, prefetch_buffer_size=1):
df = pd.DataFrame(data)
image_paths="./data/"+df['image_id'].values. #<<==пути к изображениям для всех изображений в наборе данных
label_user_ids = df['label_user_id'].values
label_binary_flags = df['label_binary_flag'].values
temp_dataset = preprocess_image_patches(image_paths, label_user_ids, label_binary_flags, window_size, step_size)
#temp_dataset = tuple(temp_dataset.map(augment_data_image))
image_patches, user_ids, binary_flags = temp_dataset. #<<== image_patches - это все патчи изображений, сгенерированные из image_path.
# Создаем набор данных TensorFlow
patch_dataset = tf.data.Dataset.from_generator(
lambda: my_generator(image_patches, user_ids, binary_flags), #(s for s in temp_dataset),
output_types=(tf.float32, tf.int32, tf.int32),
output_shapes=((224, 224, 3), (len(USERID_LIST),), (2,))
)
# Перемешиваем, увеличиваем, формируем пакеты и предварительно загружаем набор данных
patch_dataset = patch_dataset.shuffle(buffer_size=shuffle_buffer_size) # Перемешиваем данные
patch_dataset = patch_dataset.batch(batch_size=BATCH_SIZE) # Формируем пакеты
patch_dataset = patch_dataset.prefetch(buffer_size=tf.data.AUTOTUNE) # Оптимизация для производительности
dataset_size = patch_dataset.reduce(0, lambda x, _: x + 1). #проверяем размер patch_dataset
print(f"patch_dataset size in return_dataset() :", dataset_size.numpy()). ##<<==размер patch_dataset возвращается как размер всего набора данных, что противоречит цели использования from_generator()
yield patch_dataset. #это все еще полный набор данных, что вызвало OOM
Несмотря на то, что my_generator()
возвращает только один образец данных/патч за раз, почему dataset_size.numpy()
выше все еще является полным размером набора данных, а не набором данных BATCH_SIZE
, как указано в .batch()?
Ответ или решение
Проблема, с которой вы столкнулись, связана с тем, как работает tf.data.Dataset.from_generator()
, особенно когда вы используете его в сочетании с reduce
. Давайте разберемся, почему ваш код возвращает полный размер набора данных, даже когда my_generator()
генерирует по одному экземпляру за раз.
Понимание работы tf.data.Dataset
-
Генератор: Когда вы создаете набор данных с помощью
tf.data.Dataset.from_generator()
, он будет вызывать ваш генератор,my_generator()
, для получения данных. Ваш генератор возвращает по одному элементу за раз, что отлично подходит для работы с большими наборами данных, которые не помещаются в память. -
Палитра Данных: Ваша функция
return_dataset()
создает временный набор данныхtemp_dataset
, который затем разбивается наimage_patches
,user_ids
, иbinary_flags
. На основе этих массивов создается tf.data.Dataset с использованием генератора. -
Подсчет элементов: Ваша строка с подсчетом элементов:
dataset_size = patch_dataset.reduce(0, lambda x, _: x + 1)
фактически проходит по всем элементам в
patch_dataset
для получения общего количества примеров.reduce
работает, превращая каждый элемент в одно значение, при этом он должен итерироваться через все данные, чтобы подсчитать количество элементов.Кроме того, поскольку
patch_dataset
создаётся черезfrom_generator
, он требует, чтобы генератор производил все элементы, чтобы правильно оценить размер набора данных, а это приводит к попытке загрузки всех данных, что вызывает ошибку "Out of Memory".
Решение проблемы
Чтобы избежать проблемы с OOM (исчерпанием памяти), вам следует избегать подсчета размера набора данных с использованием reduce
, так как это приведет к итерации по всем данным в наборе.
Вот шаги, которые вы можете предпринять:
-
Не считайте размер набора данных: Уберите строчку с
dataset_size
из функцииreturn_dataset
. Вместо этого, просто возвращайтеpatch_dataset
без проверки его размера.return patch_dataset
-
Используйте в fit: Когда вы будете использовать
model.fit()
, передавайтеpatch_dataset
напрямую, он будет загружать данные в режиме "по требованию", что предотвратит попытку загрузить весь набор данных в объёмной памяти сразу. -
Расширьте загрузку данных: Если вам нужен контроль над процессом загрузки данных (например, для мониторинга или дебага), рассмотрите возможность использования методов
.take()
или.skip()
для итераций без загрузки всего набора данных сразу.
Пример кода
Ваш модифицированный код может выглядеть так:
def return_dataset(data, window_size=(224, 224), step_size=112, shuffle_buffer_size=1000, prefetch_buffer_size=1):
# Ваши предыдущие шаги...
# Создание TensorFlow набора данных
patch_dataset = tf.data.Dataset.from_generator(
lambda: my_generator(image_patches, user_ids, binary_flags),
output_types=(tf.float32, tf.int32, tf.int32),
output_shapes=((224, 224, 3), (len(USERID_LIST),), (2,))
)
# Примените перемешивание, батчи и предвыборку
patch_dataset = patch_dataset.shuffle(buffer_size=shuffle_buffer_size)
patch_dataset = patch_dataset.batch(batch_size=BATCH_SIZE)
patch_dataset = patch_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
return patch_dataset # Возвратите набор данных без его подсчета
Следуя этим шагам, вы сможете избежать ошибок памяти и использовать данные эффективно благодаря ленивой загрузке.