Как загрузить предварительно обученную модель VGG16 с помощью TensorFlow и сгенерировать матрицу ошибок.

Вопрос или проблема

Я натренировал модель VGG16 (из tensorflow.keras.applications) на 15 категориях изображений в течение 100 эпох. После обучения я сохранил модель в файл с именем ‘best_model.h5’, но, к сожалению, забыл включить код в свой скрипт для генерации матрицы ошибок. Теперь, чтобы получить матрицу ошибок, возможно ли как-то использовать сохраненный файл модели, который у меня есть (best_model.h5), не проходя через всю процедуру повторного обучения модели с нуля?

Вот код, который я использовал для обучения своей модели

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint

print("Настройка директорий для обучающих и валидационных данных...")
train_dir="/scratch/user/nabarunkar/Disaster Dataset/train"
validation_dir="/scratch/user/nabarunkar/Disaster Dataset/validation"

print("Предварительная обработка изображений для увеличения данных...")
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode="nearest"
)

validation_datagen = ImageDataGenerator(rescale=1./255)

print("Загрузка изображений из директорий...")

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),  # VGG16 ожидает вход 224x224
    batch_size=32,
    class_mode="categorical"  # Убедитесь, что это установлено как 'categorical'
)

validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode="categorical"  # Убедитесь, что это установлено как 'categorical'
)

print("Создание модели VGG16...")
base_model = VGG16(weights="imagenet", include_top=False, input_shape=(224, 224, 3))

print("Заморозка базовой модели...")
base_model.trainable = False

print("Добавление пользовательских слоев классификации сверху...")
model = Sequential([
    base_model,
    Flatten(),
    Dense(256, activation='relu'),
    Dropout(0.5),
    Dense(15, activation='softmax') # 15 категорий
])

print("Компиляция модели...")
model.compile(optimizer=Adam(learning_rate=0.0001),
              loss="categorical_crossentropy",
              metrics=['accuracy'])

print("Обратные вызовы...")
checkpoint = ModelCheckpoint(
    'best_model2.h5',  # Путь для сохранения лучшей модели
    monitor="val_loss",  # Метрика для мониторинга
    save_best_only=True,  # Сохраняйте модель только в том случае, если отслеживаемая метрика улучшается
    mode="min",  # Сохранять, когда метрика минимальна (для потерь)
    verbose=1  # Печать сообщений при сохранении модели
)

print("Обучение модели...")
history = model.fit(
    train_generator,
    epochs=100,
    validation_data=validation_generator,
    callbacks=[checkpoint]
)

print("Оценка модели на валидационном наборе...")
val_loss, val_accuracy = model.evaluate(validation_generator)
print(f'Валидационная точность: {val_accuracy:.2f}')

Я попытался запустить этот фрагмент кода, чтобы загрузить сохраненную модель и сгенерировать матрицу ошибок:

import numpy as np
from tensorflow.keras.models import load_model
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import re

def extract_numbers(class_name):
    """Извлечение числовой части из названий классов, таких как 'cat01building'"""
    numbers = re.findall(r'\d+', class_name)
    return numbers[0] if numbers else class_name

def generate_confusion_matrix(model_path, validation_generator, class_names):
    # Загрузка сохраненной модели
    model = load_model('./best_model2.h5')

    # Получение предсказаний
    # Сброс генератора, чтобы начать с начала
    validation_generator.reset()

    # Получение предсказаний для валидационного набора
    predictions = model.predict(validation_generator)
    predicted_classes = np.argmax(predictions, axis=1)

    # Получение истинных меток
    true_classes = validation_generator.classes

    # Генерация матрицы ошибок
    cm = confusion_matrix(true_classes, predicted_classes)

    # Упрощение названий классов до только номеров
    simplified_labels = [extract_numbers(name) for name in class_names]

    # Построение матрицы ошибок
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt="d", cmap='Blues',
                xticklabels=simplified_labels,
                yticklabels=simplified_labels)
    plt.title('Матрица ошибок')
    plt.ylabel('Истинная метка')
    plt.xlabel('Предсказанная метка')
    plt.xticks(rotation=45)
    plt.yticks(rotation=45)
    plt.tight_layout()
    plt.show()

    # Печать отчета о классификации
    print("\nОтчет о классификации:")
    print(classification_report(true_classes, predicted_classes, 
                              target_names=class_names))  

train_dir="/content/dataset/train"
validation_dir="/content/dataset/validation"

train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode="nearest"
)

validation_datagen = ImageDataGenerator(rescale=1./255)

# Загрузка изображений из директорий

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),  # VGG16 ожидает вход 224x224
    batch_size=32,
    class_mode="categorical"  # Убедитесь, что это установлено как 'categorical'
)

validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode="categorical"  # Убедитесь, что это установлено как 'categorical'
)

class names = ['cat07', 'cat08', 'cat09', 'cat10', 'cat11', 'cat12', 'cat13', 'cat14', 'cat15', 'cat16', 'cat17', 'cat18', 'cat19', 'cat20', 'cat21', 'cat22', 'cat23']

# Генерация матрицы ошибок
generate_confusion_matrix('best_model2.h5', validation_generator, class_names)```

но, что бы я ни пробовал, я продолжаю получать ту же ошибку, как показано ниже:


> AttributeError: Исключение, возникшее при вызове Flatten.call().

'list' объект не имеет атрибута 'shape'

Аргументы, полученные от Flatten.call():
  • args=(['<KerasTensor shape=(None, 7, 7, 512), dtype=float32, sparse=False, name=keras_tensor_450>'],)
  • kwargs=<class 'inspect._empty'>

Ответ или решение

Чтобы загрузить предобученную модель VGG16 с использованием TensorFlow и сгенерировать матрицу путаницы, можно осуществить следующие шаги. Мы начнём с описания процесса загрузки модели, подготовки данных и вычисления матрицы путаницы на основе ваших результатов.

Шаг 1: Загрузка сохраненной модели

Вы можете загрузить сохраненную модель с помощью следующего кода:

from tensorflow.keras.models import load_model

# Загрузка модели
model = load_model('best_model.h5')

Шаг 2: Подготовка данных для прогнозирования

Для генерации матрицы путаницы вам понадобятся истинные метки и предсказания модели. Для этого используйте генератор валидации, который вы уже создали. Убедитесь, что вы используете тот же генератор validation_generator, который использовался при обучении:

# Получение истинных меток
true_classes = validation_generator.classes

# Сброс генератора
validation_generator.reset()

# Получение предсказаний
predictions = model.predict(validation_generator)
predicted_classes = np.argmax(predictions, axis=1)

Шаг 3: Генерация матрицы путаницы

Теперь, когда у вас есть предсказанные и истинные классы, можно создать матрицу путаницы. Вы также можете использовать библиотеку seaborn для визуализации этой матрицы:

from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt

# Генерация матрицы путаницы
cm = confusion_matrix(true_classes, predicted_classes)

# Визуализация матрицы путаницы
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('Матрица путаницы')
plt.ylabel('Истинные метки')
plt.xlabel('Предсказанные метки')
plt.show()

# Отчет о классификации
class_names = validation_generator.class_indices.keys()  # Получаем имена классов
print(classification_report(true_classes, predicted_classes, target_names=class_names))

Шаг 4: Устранение ошибок

Ошибка, которую вы наблюдаете, может быть связана с неправильной обработкой данных в процессе предсказания. Проверьте, что вы правильно загружаете данные и используете правильные форматы. Возможно, вам нужно изменить способ создания validation_generator, чтобы он возвращал данные в ожидаемом формате.

Также убедитесь, что class_names содержит правильное количество классов, соответствующее вашим данным. Правильное значение class_names должно быть:

class_names = validation_generator.class_indices.keys()

Шаг 5: Завершение

Следуя указанным выше шагам, вы сможете загрузить вашу предобученную модель VGG16 и создать матрицу путаницы, используя заранее подготовленные классы. Это сэкономит ваше время и упростит процесс анализа результатов модели без необходимости повторного обучения.

Если у вас возникнут вопросы или потребуется дополнительная помощь, не стесняйтесь задать их!

Оцените материал
Добавить комментарий

Капча загружается...