Вопрос или проблема
Я реализовал WGAN на основе этого блога. Я попытался реализовать define_critic()
как подкласс keras.Model с целью получить тот же результат.
class CNNBlock(layers.Layer):
def __init__(self, channels, kernel_size=(4, 4)):
super(CNNBlock, self).__init__()
w_init = initializers.RandomNormal(stddev=0.02)
const = ClipConstraint(0.01)
self.conv = layers.Conv2D(channels, kernel_size, strides=(2, 2), padding="same",
kernel_initializer=w_init, kernel_constraint=const)
self.bn = layers.BatchNormalization()
self.relu = layers.LeakyReLU(alpha=0.2)
def call(self, input_tensor, training=False):
x = self.conv(input_tensor)
x = self.bn(x)
return self.relu(x)
class CriticModel(Model):
def __init__(self, channel):
super(CriticModel, self).__init__()
self.cnn1 = CNNBlock(channel[0])
self.cnn2 = CNNBlock(channel[1])
def call(self, input_tensor, training=False, mask=None):
x = self.cnn1(input_tensor, training=training)
x = self.cnn2(x, training=training)
x = layers.Flatten()(x)
x = layers.Dense(1)(x)
return x
def get_config(self):
super(CriticModel, self).get_config()
Тем не менее, когда я запускаю код, я получаю ошибку при вызове c_model.train_on_batch(X_real, y_real)
.
Ошибка
ValueError: функция, декорированная tf.function, попыталась создать переменные при не первом вызове.
Я не понимаю, почему мое изменение приводит к этой ошибке.
Весь код можно найти здесь
Я создаю два новых объекта в call()
: layers.Flatten
и layers.Dense
. Когда я создаю объекты в конструкторе, модель работает как ожидается.
class CriticModel(Model):
def __init__(self, channel):
super(CriticModel, self).__init__()
self.cnn1 = CNNBlock(channel[0])
self.cnn2 = CNNBlock(channel[1])
self.flat = layers.Flatten()
self.dense = layers.Dense(1)
def call(self, input_tensor, training=False, mask=None):
x = self.cnn1(input_tensor, training=training)
x = self.cnn2(x, training=training)
# x = layers.Flatten()(x) <-- Это создает новый объект при каждом прямом проходе
# x = layers.Dense(1)(x) <-- Это создает новый объект при каждом прямом проходе
x = self.flat(x)
x = self.dense(x)
return x
def get_config(self):
super(CriticModel, self).get_config()
```
Ответ или решение
Почему поведение подкласса keras.Model отличается от последовательного API?
Вопрос, обсуждаемый в данной ситуации, касается различий между подклассом keras.Model
и последовательным API keras.Sequential
в Keras. Это различие особенно важно при реализации сложных архитектур, таких как Wasserstein Generative Adversarial Network (WGAN), где порядок и способ создания моделей играют критическую роль.
1. Создание переменных
Основная причина ошибки, с которой вы столкнулись, состоит в том, как Keras управляет переменными и их созданием в контексте функции call()
. При использовании layers.Flatten()
и layers.Dense(1)
внутри метода call()
, каждый вызов этой функции создает новые экземпляры этих слоев. Это приводит к ошибке:
ValueError: tf.function-decorated function tried to create variables on non-first call.
Keras ожидает, что на протяжении жизненного цикла модели переменные будут созданы только один раз. Когда вы вызываете слой в первый раз, он инициализирует свои внутренние переменные (например, веса). Если вы пытаетесь вызвать этот слой повторно, Keras ожидает, что эти переменные уже существуют. Однако, если вы создали новый экземпляр слоя в каждом вызове, то Keras не может выполнить это ожидание, поэтому появляется ошибка.
2. Правильное создание слоев
Вместо создания новых экземпляров в методе call()
, вы должны инстанцировать ваши слои (Flatten
, Dense
) в конструкторе класса. Таким образом, вы сохраняете ссылки на уже созданные объекты, и Keras будет использовать их повторно при каждом вызове call()
. Это гарантирует, что переменные создаются только один раз, и вы избежите состояния, когда Keras "смотрит" на новый слой, который не был инициирован в первом вызове.
Ваш исправленный код выглядит следующим образом:
class CriticModel(Model):
def __init__(self, channel):
super(CriticModel, self).__init__()
self.cnn1 = CNNBlock(channel[0])
self.cnn2 = CNNBlock(channel[1])
self.flat = layers.Flatten() # Создаем один экземпляр
self.dense = layers.Dense(1) # Создаем один экземпляр
def call(self, input_tensor, training=False, mask=None):
x = self.cnn1(input_tensor, training=training)
x = self.cnn2(x, training=training)
x = self.flat(x) # Используем существующий экземпляр
x = self.dense(x) # Используем существующий экземпляр
return x
def get_config(self):
super(CriticModel, self).get_config()
3. Структура и гибкость модели
Использование подклассов keras.Model
дает вам больше гибкости в архитектуре, чем keras.Sequential
, который является линейным стеком. Подклассы позволяют создавать более сложные и взаимосвязанные структуры, включая условия и циклы при формировании выходных данных.
4. Заключение
Использование подкласса keras.Model
предоставляет большой потенциал для создания сложных моделей, но требует осторожного подхода к созданию слоев и управлению переменными. Создание экземпляров слоев в конструкторе, а не в методе call()
, является ключевым моментом при избежании ошибок и недостатков, связанных с областями видимости переменных.
Изучение этих аспектов может значительно улучшить вашу работу с Keras и оптимизацию ваших моделей, что особенно важно в исследовательских проектах, таких как WGAN.