Почему подкласс keras.Model ведет себя иначе, чем последовательный API?

Вопрос или проблема

Я реализовал WGAN на основе этого блога. Я попытался реализовать define_critic() как подкласс keras.Model с целью получить тот же результат.

class CNNBlock(layers.Layer):

    def __init__(self, channels, kernel_size=(4, 4)):
        super(CNNBlock, self).__init__()

        w_init = initializers.RandomNormal(stddev=0.02)
        const = ClipConstraint(0.01)

        self.conv = layers.Conv2D(channels, kernel_size, strides=(2, 2), padding="same",
                                  kernel_initializer=w_init, kernel_constraint=const)
        self.bn = layers.BatchNormalization()
        self.relu = layers.LeakyReLU(alpha=0.2)

    def call(self, input_tensor, training=False):
        x = self.conv(input_tensor)
        x = self.bn(x)
        return self.relu(x)

class CriticModel(Model):

    def __init__(self, channel):
        super(CriticModel, self).__init__()
        self.cnn1 = CNNBlock(channel[0])
        self.cnn2 = CNNBlock(channel[1])

    def call(self, input_tensor, training=False, mask=None):
        x = self.cnn1(input_tensor, training=training)
        x = self.cnn2(x, training=training)
        x = layers.Flatten()(x)
        x = layers.Dense(1)(x)
        return x

    def get_config(self):
        super(CriticModel, self).get_config()

Тем не менее, когда я запускаю код, я получаю ошибку при вызове c_model.train_on_batch(X_real, y_real).

Ошибка

ValueError: функция, декорированная tf.function, попыталась создать переменные при не первом вызове.

Я не понимаю, почему мое изменение приводит к этой ошибке.

Весь код можно найти здесь

Я создаю два новых объекта в call(): layers.Flatten и layers.Dense. Когда я создаю объекты в конструкторе, модель работает как ожидается.

class CriticModel(Model):

    def __init__(self, channel):
        super(CriticModel, self).__init__()
        self.cnn1 = CNNBlock(channel[0])
        self.cnn2 = CNNBlock(channel[1])
        self.flat = layers.Flatten()
        self.dense = layers.Dense(1)

    def call(self, input_tensor, training=False, mask=None):
        x = self.cnn1(input_tensor, training=training)
        x = self.cnn2(x, training=training)
        # x = layers.Flatten()(x) <-- Это создает новый объект при каждом прямом проходе
        # x = layers.Dense(1)(x)  <-- Это создает новый объект при каждом прямом проходе
        x = self.flat(x)
        x = self.dense(x)
        return x

    def get_config(self):
        super(CriticModel, self).get_config()
```

Ответ или решение

Почему поведение подкласса keras.Model отличается от последовательного API?

Вопрос, обсуждаемый в данной ситуации, касается различий между подклассом keras.Model и последовательным API keras.Sequential в Keras. Это различие особенно важно при реализации сложных архитектур, таких как Wasserstein Generative Adversarial Network (WGAN), где порядок и способ создания моделей играют критическую роль.

1. Создание переменных

Основная причина ошибки, с которой вы столкнулись, состоит в том, как Keras управляет переменными и их созданием в контексте функции call(). При использовании layers.Flatten() и layers.Dense(1) внутри метода call(), каждый вызов этой функции создает новые экземпляры этих слоев. Это приводит к ошибке:

ValueError: tf.function-decorated function tried to create variables on non-first call.

Keras ожидает, что на протяжении жизненного цикла модели переменные будут созданы только один раз. Когда вы вызываете слой в первый раз, он инициализирует свои внутренние переменные (например, веса). Если вы пытаетесь вызвать этот слой повторно, Keras ожидает, что эти переменные уже существуют. Однако, если вы создали новый экземпляр слоя в каждом вызове, то Keras не может выполнить это ожидание, поэтому появляется ошибка.

2. Правильное создание слоев

Вместо создания новых экземпляров в методе call(), вы должны инстанцировать ваши слои (Flatten, Dense) в конструкторе класса. Таким образом, вы сохраняете ссылки на уже созданные объекты, и Keras будет использовать их повторно при каждом вызове call(). Это гарантирует, что переменные создаются только один раз, и вы избежите состояния, когда Keras "смотрит" на новый слой, который не был инициирован в первом вызове.

Ваш исправленный код выглядит следующим образом:

class CriticModel(Model):

    def __init__(self, channel):
        super(CriticModel, self).__init__()
        self.cnn1 = CNNBlock(channel[0])
        self.cnn2 = CNNBlock(channel[1])
        self.flat = layers.Flatten()  # Создаем один экземпляр
        self.dense = layers.Dense(1)   # Создаем один экземпляр

    def call(self, input_tensor, training=False, mask=None):
        x = self.cnn1(input_tensor, training=training)
        x = self.cnn2(x, training=training)
        x = self.flat(x)  # Используем существующий экземпляр
        x = self.dense(x)  # Используем существующий экземпляр
        return x

    def get_config(self):
        super(CriticModel, self).get_config()

3. Структура и гибкость модели

Использование подклассов keras.Model дает вам больше гибкости в архитектуре, чем keras.Sequential, который является линейным стеком. Подклассы позволяют создавать более сложные и взаимосвязанные структуры, включая условия и циклы при формировании выходных данных.

4. Заключение

Использование подкласса keras.Model предоставляет большой потенциал для создания сложных моделей, но требует осторожного подхода к созданию слоев и управлению переменными. Создание экземпляров слоев в конструкторе, а не в методе call(), является ключевым моментом при избежании ошибок и недостатков, связанных с областями видимости переменных.

Изучение этих аспектов может значительно улучшить вашу работу с Keras и оптимизацию ваших моделей, что особенно важно в исследовательских проектах, таких как WGAN.

Оцените материал
Добавить комментарий

Капча загружается...