Выход промежуточного слоя Keras (модель внимания)

Вопрос или проблема

У меня есть модель с таким резюме:


___________________________
Слой (тип)                     Форма выхода          Кол-во параметров     Подключено к                     
====================================================================================================
input_1 (InputLayer)             (None, 30, 37)        0                                            
____________________________________________________________________________________________________
s0 (InputLayer)                  (None, 128)           0                                            
____________________________________________________________________________________________________
bidirectional_1 (Bidirectional)  (None, 30, 128)       52224       input_1[0][0]                    
____________________________________________________________________________________________________
repeat_vector_1 (RepeatVector)   (None, 30, 128)       0           s0[0][0]                         
                                                                   lstm_1[0][0]                     
                                                                   lstm_1[1][0]                     
                                                                   lstm_1[2][0]                     
                                                                   lstm_1[3][0]                     
                                                                   lstm_1[4][0]                     
                                                                   lstm_1[5][0]                     
                                                                   lstm_1[6][0]                     
                                                                   lstm_1[7][0]                     
                                                                   lstm_1[8][0]                     
____________________________________________________________________________________________________
concatenate_1 (Concatenate)      (None, 30, 256)       0           bidirectional_1[0][0]            
                                                                   repeat_vector_1[0][0]            
                                                                   bidirectional_1[0][0]            
                                                                   repeat_vector_1[1][0]            
                                                                   bidirectional_1[0][0]            
                                                                   repeat_vector_1[2][0]            
                                                                   bidirectional_1[0][0]            
                                                                   repeat_vector_1[3][0]            
                                                                   bidirectional_1[0][0]            
                                                                   repeat_vector_1[4][0]            
                                                                   bidirectional_1[0][0]            
                                                                   repeat_vector_1[5][0]            
                                                                   bidirectional_1[0][0]            
                                                                   repeat_vector_1[6][0]            
                                                                   bidirectional_1[0][0]            
                                                                   repeat_vector_1[7][0]            
                                                                   bidirectional_1[0][0]            
                                                                   repeat_vector_1[8][0]            
                                                                   bidirectional_1[0][0]            
                                                                   repeat_vector_1[9][0]            
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 30, 1)         257         concatenate_1[0][0]              
                                                                   concatenate_1[1][0]              
                                                                   concatenate_1[2][0]              
                                                                   concatenate_1[3][0]              
                                                                   concatenate_1[4][0]              
                                                                   concatenate_1[5][0]              
                                                                   concatenate_1[6][0]              
                                                                   concatenate_1[7][0]              
                                                                   concatenate_1[8][0]              
                                                                   concatenate_1[9][0]              
____________________________________________________________________________________________________
attention_weights (Activation)   (None, 30, 1)         0           dense_1[0][0]                    
                                                                   dense_1[1][0]                    
                                                                   dense_1[2][0]                    
                                                                   dense_1[3][0]                    
                                                                   dense_1[4][0]                    
                                                                   dense_1[5][0]                    
                                                                   dense_1[6][0]                    
                                                                   dense_1[7][0]                    
                                                                   dense_1[8][0]                    
                                                                   dense_1[9][0]                    
____________________________________________________________________________________________________
dot_1 (Dot)                      (None, 1, 128)        0           attention_weights[0][0]          
                                                                   bidirectional_1[0][0]            
                                                                   attention_weights[1][0]          
                                                                   bidirectional_1[0][0]            
                                                                   attention_weights[2][0]          
                                                                   bidirectional_1[0][0]            
                                                                   attention_weights[3][0]          
                                                                   bidirectional_1[0][0]            
                                                                   attention_weights[4][0]          
                                                                   bidirectional_1[0][0]            
                                                                   attention_weights[5][0]          
                                                                   bidirectional_1[0][0]            
                                                                   attention_weights[6][0]          
                                                                   bidirectional_1[0][0]            
                                                                   attention_weights[7][0]          
                                                                   bidirectional_1[0][0]            
                                                                   attention_weights[8][0]          
                                                                   bidirectional_1[0][0]            
                                                                   attention_weights[9][0]          
                                                                   bidirectional_1[0][0]            
____________________________________________________________________________________________________
c0 (InputLayer)                  (None, 128)           0                                            
____________________________________________________________________________________________________
lstm_1 (LSTM)                    [(None, 128), (None,  131584      dot_1[0][0]                      
                                                                   s0[0][0]                         
                                                                   c0[0][0]                         
                                                                   dot_1[1][0]                      
                                                                   lstm_1[0][0]                     
                                                                   lstm_1[0][2]                     
                                                                   dot_1[2][0]                      
                                                                   lstm_1[1][0]                     
                                                                   lstm_1[1][2]                     
                                                                   dot_1[3][0]                      
                                                                   lstm_1[2][0]                     
                                                                   lstm_1[2][2]                     
                                                                   dot_1[4][0]                      
                                                                   lstm_1[3][0]                     
                                                                   lstm_1[3][2]                     
                                                                   dot_1[5][0]                      
                                                                   lstm_1[4][0]                     
                                                                   lstm_1[4][2]                     
                                                                   dot_1[6][0]                      
                                                                   lstm_1[5][0]                     
                                                                   lstm_1[5][2]                     
                                                                   dot_1[7][0]                      
                                                                   lstm_1[6][0]                     
                                                                   lstm_1[6][2]                     
                                                                   dot_1[8][0]                      
                                                                   lstm_1[7][0]                     
                                                                   lstm_1[7][2]                     
                                                                   dot_1[9][0]                      
                                                                   lstm_1[8][0]                     
                                                                   lstm_1[8][2]                     
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 11)            1419        lstm_1[0][0]                     
                                                                   lstm_1[1][0]                     
                                                                   lstm_1[2][0]                     
                                                                   lstm_1[3][0]                     
                                                                   lstm_1[4][0]                     
                                                                   lstm_1[5][0]                     
                                                                   lstm_1[6][0]                     
                                                                   lstm_1[7][0]                     
                                                                   lstm_1[8][0]                     
                                                                   lstm_1[9][0]                     
====================================================================================================
Общее количество параметров: 185,484
Обучаемые параметры: 185,484
Необучаемые параметры: 0
____________________________________________________________________________________________________

Модель далее упрощена следующим образом:

введите описание изображения здесь

И блок “attention” упрощен следующим образом:

введите описание изображения здесь

Входом является нечеткая дата, например, “17 ноября 1979 года” (ограниченная 30 символами), а выходом является представление “ГГГГ-мм-дд” в 10 символах.

Я хотел бы построить график значений слоя attention_weights.

Мне хотелось бы увидеть, на какую часть “Saturday, 17th November, 1979” сеть “смотрит”, когда она предсказывает каждый из ГГГГ, мм и дд. Я ожидаю, что она полностью игнорирует день недели (“Saturday”).

Я пытался следовать документации Keras для получения выхода промежуточного слоя.

Тем не менее, узел attention имеет 10 входов, поэтому мне нужно взять каждый из них:

f = K.function(model.inputs, [model.get_layer('attention_weights').get_output_at(t) for t in range(10)])
r = f([source, np.zeros((1,128)), np.zeros((1,128))])

С source, например, “17 ноября 1979 года”, закодировано как

[[[ 0.  0.  0.  0.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  1.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.]
  [ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  1.  0.  0.  0.
    0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  1.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  1.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  1.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  1.  0.  0.  0.  0.  0.  0.  0.
    0.]
  [ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.]
  [ 0.  0.  0.  0.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  1.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  1.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  1.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  1.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    1.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    1.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    1.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    1.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    1.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    1.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    1.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    1.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    1.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    1.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    1.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
    1.]]]

r это матрица формы (10,1,30,1) и карту внимания я строю таким образом:

attention_map = np.zeros((10, 30))
for t in range(10):
    for t_prime in range(30):
        attention_map[t][t_prime] = r[t][0,t_prime,0]

…но все значения одинаковые! Я ожидал некое различие.

Я также пробовал добавлять K.learning_phase(), но безуспешно. Что я делаю не так?

Проблема заключалась в том, что я пытался построить карту внимания модели, загруженной из сохраненной модели.

Вывод при сохранении модели был следующим:

/home/opyate/anaconda3/lib/python3.6/site-packages/keras/engine/topology.py:2361:
UserWarning: Слой lstm_1 был передан с несовместимыми ключевыми
аргументами: {‘initial_state’: [, ]}.
Они не будут включены в сериализованную модель (и, следовательно, будут
отсутствовать при десериализации). str(node.arguments) + ‘. Они не
будут включены ‘
/home/opyate/anaconda3/lib/python3.6/site-packages/keras/engine/topology.py:2361:
UserWarning: Слой lstm_1 был передан с несовместимыми ключевыми
аргументами: {‘initial_state’: [, ]}. Они не будут включены в
сериализованную модель (и, следовательно, будут отсутствовать при
десериализации).
str(node.arguments) + ‘. Они не будут включены ‘
/home/opyate/anaconda3/lib/python3.6/site-packages/keras/engine/topology.py:2361:
UserWarning: Слой lstm_1 был передан с несовместимыми ключевыми
аргументами: {‘initial_state’: [,
]}.
Они не будут включены в сериализованную модель (и, следовательно, будут
отсутствовать при десериализации). str(node.arguments) + ‘. Они не
будут включены ‘
/home/opyate/anaconda3/lib/python3.6/site-packages/keras/engine/topology.py:2361:
UserWarning: Слой lstm_1 был передан с несовместимыми ключевыми
аргументами: {‘initial_state’: [,
]}.
Они не будут включены в сериализованную модель (и, следовательно, будут
отсутствовать при десериализации). str(node.arguments) + ‘. Они не
будут включены ‘
/home/opyate/anaconda3/lib/python3.6/site-packages/keras/engine/topology.py:2361:
UserWarning: Слой lstm_1 был передан с несовместимыми ключевыми
аргументами: {‘initial_state’: [,
]}.
Они не будут включены в сериализованную модель (и, следовательно, будут
отсутствовать при десериализации). str(node.arguments) + ‘. Они не
будут включены ‘
/home/opyate/anaconda3/lib/python3.6/site-packages/keras/engine/topology.py:2361:
UserWarning: Слой lstm_1 был передан с несовместимыми ключевыми
аргументами: {‘initial_state’: [,
]}.
Они не будут включены в сериализованную модель (и, следовательно, будут
отсутствовать при десериализации). str(node.arguments) + ‘. Они не
будут включены ‘
/home/opyate/anaconda3/lib/python3.6/site-packages/keras/engine/topology.py:2361:
UserWarning: Слой lstm_1 был передан с несовместимыми ключевыми
аргументами: {‘initial_state’: [,
]}.
Они не будут включены в сериализованную модель (и, следовательно, будут
отсутствовать при десериализации). str(node.arguments) + ‘. Они не
будут включены ‘
/home/opyate/anaconda3/lib/python3.6/site-packages/keras/engine/topology.py:2361:
UserWarning: Слой lstm_1 был передан с несовместимыми ключевыми
аргументами: {‘initial_state’: [,
]}.
Они не будут включены в сериализованную модель (и, следовательно, будут
отсутствовать при десериализации). str(node.arguments) + ‘. Они не
будут включены ‘

Тем не менее, если я создаю модель из кода и просто загружаю сохраненные веса, это работает.

Предположительно, UserWarningы при сохранении модели связаны с моей проблемой.

.

Ответ или решение

Для того чтобы извлечь значения из промежуточного слоя "attention_weights" в модели Keras, вы уже проделали большую работу и столкнулись с проблемой, связанной с инициализацией состояния LSTM слоя, которая изначально не сохранялась должным образом. Давайте разберем этот процесс более детально и предоставим шаги для успешного получения значений из слоя внимания.

Понимание проблемы и ее решение

  1. Инициализация и сохранение модели:
    Когда вы сохраняли модель, возникли предупреждения, что аргументы initial_state LSTM слоя не являются сериализуемыми и поэтому отсутствуют при десериализации. В результате, при загрузке модели из файла она не содержала необходимой начальной информации для корректной работы слоя LSTM.

  2. Загрузка весов вместо модели:
    Ваше решение с загрузкой только весов модели вместо полной загруженной модели оказалось правильным. Это позволило обойти проблему несохраненных аргументов initial_state. При создании модели из кода и последующей загрузке весов все состояния инициируются корректно.

Пошаговый процесс извлечения значений из слоя внимания:

Шаг 1: Подготовка модели

  • Убедитесь, что модель создана из кода и инициализирована правильно, без использования сохраненной структуры модели, которая теряет ключевые части при десериализации.

Шаг 2: Создание функции для извлечения значений

  • Используйте Keras-функционал для получения значений из промежуточного слоя. Это можно сделать через K.function, как вы уже пробовали, но убедитесь, что модель загружена в правильном контексте.

    from keras import backend as K
    
    f = K.function(model.inputs, [model.get_layer('attention_weights').output])
    attention_output = f([source, np.zeros((1,128)), np.zeros((1,128))])

Шаг 3: Построение карты внимания

  • После получения attention_output, вы можете преобразовать его в карту внимания для визуализации.

    import numpy as np
    
    attention_map = np.zeros((10, 30))
    for t in range(10):
        for t_prime in range(30):
            attention_map[t][t_prime] = attention_output[t][0, t_prime, 0]

Шаг 4: Визуализация карты

  • Используйте библиотеки для визуализации, такие как Matplotlib, для отображения карты внимания, чтобы анализировать, какие части ввода модель "внимательно осматривает" при предсказании каждого из компонентов даты:

    import matplotlib.pyplot as plt
    
    plt.imshow(attention_map, cmap='hot', interpolation='nearest')
    plt.title("Attention Map")
    plt.xlabel("Input Sequence")
    plt.ylabel("Output Sequence")
    plt.show()

Обобщение

Для успешного извлечения и анализа значений из промежуточного слоя внимания в Keras, важно обеспечить правильную инициализацию модели в коде и загрузку только весов, если возникают проблемы с сериализацией. Это позволяет сохранить целостность модели и корректность работы всех ее частей, включая слои, зависящие от начальных состояний, такие как LSTM.

Оцените материал
Добавить комментарий

Капча загружается...