Предсказание реального времени с помощью ResNet

Вопрос или проблема

Я обучил модель resnet50 на знаках рук от 0 до 5, и я пытаюсь развернуть ее для предсказания классов в реальном времени через веб-камеру ноутбука.

Хотя модель имеет 98% точности, и я почти уверен, что ошибка не возникает из-за плохой тренировки модели, значения в реальном времени застревают на 1 или 2 классах из 5, они всегда предсказывают номер 0 и номер 2.

Вот код:

import torch
import torch.nn as nn
import cv2
import numpy as np
from torchvision import models, transforms
from PIL import Image  # Импортируйте PIL для преобразования изображения

# Определите архитектуру модели и загрузите веса
class ResNet50Modified(nn.Module):
    def __init__(self, num_classes=6):
        super(ResNet50Modified, self).__init__()
        self.model = models.resnet50(pretrained=True)  # Используйте pretrained=True для повышения производительности
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, X):
        return self.model(X)

# Загрузите обученную модель
model = ResNet50Modified(num_classes=6)
# Загрузите состояние модели для CPU
model.load_state_dict(torch.load("resnet50_modified1.pth", map_location=torch.device('cpu')))
model.eval()

# Определите преобразования, чтобы соответствовать предварительной обработке тренировки
preprocess = transforms.Compose([
    transforms.Resize((64, 64)),  # Измените размер до входного размера модели
    transforms.ToTensor(),  # Преобразуйте в тензор
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Нормализуйте в соответствии со стандартами ResNet
])

# Метки для знаков
class_names = ['Class_0', 'Class_1', 'Class_2', 'Class_3', 'Class_4', 'Class_5']  # Замените на реальные имена знаков

# Откройте веб-камеру для предсказания в реальном времени
cap = cv2.VideoCapture(0)

if not cap.isOpened():
    print("Ошибка: Не удалось открыть веб-камеру.")
    exit()

while True:
    ret, frame = cap.read()
    if not ret:
        print("Ошибка: Не удалось прочитать кадр.")
        break

    # Преобразуйте кадр из BGR (OpenCV) в RGB (PIL)
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Преобразуйте массив NumPy в изображение PIL
    pil_image = Image.fromarray(frame_rgb)

    # Предобработайте кадр
    input_image = preprocess(pil_image)  # Используйте изображение PIL для предварительной обработки
    input_image = input_image.unsqueeze(0)  # Добавьте размерность пакета

    # Предсказание с использованием модели
    with torch.no_grad():
        outputs = model(input_image)
        
        # Примените softmax для получения вероятностей
        probabilities = torch.softmax(outputs, dim=1)
        
        # Получите предсказанный класс и уверенность
        _, predicted = torch.max(probabilities, 1)
        confidence = probabilities[0][predicted].item() * 100  # Преобразуйте в процент
        label = class_names[predicted.item()]

    # Отобразите результат с уровнем уверенности
    cv2.putText(frame, f"Предсказано: {label}, Уверенность: {confidence:.2f}%", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    cv2.imshow("Обнаружение знаков", frame)

    # Выход при нажатии 'q'
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()

Уровни уверенности всегда высоки, но с неправильными метками, даже если я показываю на веб-камеру 5 пальцев, она застревает на нуле.

Мне кажется, что проблема связана с обработкой кадров, есть ли у кого-нибудь идеи по этому поводу?

Спасибо за вашу помощь

Ответ или решение

Проблема с реализацией предсказания в реальном времени с использованием модели ResNet50 может быть связана с несколькими аспектами работы вашей программы, хотя сама модель показывает хорошую точность на тестовом наборе. Давайте рассмотрим возможные причины и предложим решения для получения более корректных предсказаний.

1. Обработка кадров

Первое, что стоит проверить, это как именно вы обрабатываете кадры из веб-камеры. Если качество входного изображения или его предварительная обработка не соответствуют тем, что использовались во время тренировки модели, то это может привести к неверным предсказаниям.

Рекомендации:

  • Резкость изображения: Проверьте, достаточно ли качественные кадры от веб-камеры. Для лучшего результата можно использовать фильтры для повышения резкости, если картинка размытая.
  • Размер входа: Убедитесь, что вы изменяете размер изображения корректно. Ваша модель ожидала входных данных размером 224×224, тогда как вы используете 64×64. Измените размер на 224×224 в вашем коде (в строке transforms.Resize((64, 64))).

2. Аугментация данных

Если во время обучения модель не видела достаточного разнообразия в изображениях (например, разные условия освещения, различные углы и т.д.), это может привести к недостаточной обобщающей способности. Основной причиной, по которой модель может «залипать» на двух классах, является недостаток информации в обучающих данных.

Рекомендации:

  • Проверьте и увеличьте набор данных с использованием методов аугментации данных. Это может включать поворот, изменение яркости и контрастности, обрезку.

3. Проверка выходных данных

Несмотря на то, что некоторые классы получают высокие вероятности, вы должны убедиться, что у вас есть достаточная статистика по каждому классу.

Рекомендации:

  • Проведите тестирование модели на валидационных данных, используя ту же предобработку, чтобы проанализировать, какие именно классы наиболее часто путаются.

4. Графический интерфейс и отображение результатов

Следует убедиться, что текст, который выводится на экран, основан на актуальных данных и температура его не "зависает" на значениях, которые не меняются. Проверьте, правильно ли отображаются предсказания.

Пример исправленного кода

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # Изменяем на 224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

Заключение

Реализация детекции жестов с использованием ResNet50 в реальном времени является увлекательной задачей, однако требует тщательной проверки каждого этапа процесса. Обратите внимание на обработку изображений, разнообразие данных для обучения и правильность отображения выводов на экран. Устранение этих проблем должно значительно улучшить точность ваших предсказаний. Если после внесения исправлений проблема останется, возможно, стоит посмотреть на другие архитектуры моделей или методы классификации, которые лучше будут работать с вашими данными.

Оцените материал
Добавить комментарий

Капча загружается...