Запуск Tensorflow MobileNet из Java

Вопрос или проблема

Я пытаюсь запустить Tensorflow для распознавания изображений (классификация) на Java (JSE, не Android).

Я использую код из здесь, и здесь.

Он работает для моделей Inceptionv3 и для моделей, переобученных на Inceptionv3.

Но для моделей MobileNet он не работает (например, следуя этой статье).

Код работает, но дает неправильные результаты (неправильные метки классификации). Какой код/настройки необходимы для использования MobileNet из Java?

Код, который работает для Inceptionv3:

try (Tensor image = Tensor.create(imageBytes)) {
    float[] labelProbabilities = executeInceptionGraph(graphDef, image);
    int bestLabelIdx = maxIndex(labelProbabilities);
    result.setText("");
    result.setText(String.format(
        "ЛУЧШЕЕ СОВПАДЕНИЕ: %s (%.2f%% вероятно)",
        labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
        System.out.println(
            String.format(
                "ЛУЧШЕЕ СОВПАДЕНИЕ: %s (%.2f%% вероятно)",
                labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
}

Это работает для моделей Inceptionv3, но не для MobileNet.
Возникает ошибка: “Ожидается, что args[0] будет float, но предоставлена строка”.

Для MobileNet мы попробовали код:

try (Graph g = new Graph()) {
    GraphBuilder b = new GraphBuilder(g);
    // Некоторые константы, специфичные для предобученной модели по адресу:
    // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
    //
    // - Модель была обучена на изображениях, масштабирующихся до 224x224 пикселей.
    // - Цвета, представленные как R, G, B по 1 байту каждый, были преобразованы в
    //   float, используя (значение - Среднее)/Масштаб.
    final int H = 224;
    final int W = 224;
    final float mean = 128f;
    final float scale = 1f;
    // Поскольку граф создается один раз за выполнение, мы можем использовать константу для
    // входного изображения. Если бы граф использовался повторно для нескольких входных изображений, 
    // было бы более уместно использовать плейсхолдер.
    final Output<String> input = b.constant("input", imageBytes);
    final Output<Float> output = b.div(
        b.sub(
            b.resizeBilinear(
                b.expandDims(
                    b.cast(b.decodeJpeg(input, 3), Float.class),
                    b.constant("make_batch", 0)),
                    b.constant("size", new int[] {H, W})),
                    b.constant("mean", mean)),
                    b.constant("scale", scale));
                    try (Session s = new Session(g)) {
                        return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
                    }
}

Это работает, но дает неправильные метки.

Не давайте использовать в java 🙂

У меня была такая же проблема, попробуйте изменить значение масштаба, после этого у меня были такие же метки как в java, так и в python.

        // - Модель была обучена на изображениях, масштабирующихся до 224x224 пикселей.
        // - Цвета, представленные как R, G, B по 1 байту каждый, были преобразованы в
        // float, используя (значение - Среднее)/Масштаб.
        final int H = 224;
        final int W = 224;
        final float mean = 117f;
        final float scale = 255f;

Ответ или решение

Для запуска модели TensorFlow MobileNet на Java (JSE, а не Android) могут возникнуть определённые сложности, особенно если данный код ранее успешно работал с другими моделями, такими как InceptionV3. Важно понимать, как правильно подготовить модель и данные, чтобы добиться корректной классификации изображений.

1. Проверка подготовки изображения

При использовании MobileNet, как и у других моделей, вы должны удостовериться, что входные данные подготавливаются правильно. Модель MobileNet требует, чтобы входное изображение было изменено по размеру до 224×224 пикселей. Однако также важно учесть параметры нормализации изображений, которые были использованы при обучении модели.

2. Нормализация данных

Одной из ключевых отличий, которые могут повлиять на результаты классификации, является способ нормализации входных данных. В коде, который вы используете, обратите внимание на следующие параметры:

final float mean = 128f; // значение, вычитанное из пикселей
final float scale = 1f;  // коэффициент масштабирования

Начиная с этой точки, рекомендуется изменить значения mean и scale, так как они могут существенно повлиять на точность.

Попробуйте использовать следующие параметры:

final float mean = 117f;  // значение, вычитанное из пикселей для MobileNet
final float scale = 255f;  // масштаб итоговых значений

Эти значения соответствуют предобученной модели MobileNet и должны помочь улучшить результаты.

3. Конструкция графа TensorFlow

Убедитесь также, что граф TensorFlow строится корректно. Основные шаги, которые вы уже предприняли, выглядят правильно. Ваш код расширяет размерность изображения, декодирует JPEG и выполняет нормализацию входных данных:

final Output<String> input = b.constant("input", imageBytes);
final Output<Float> output = b.div(
    b.sub(
        b.resizeBilinear(
            b.expandDims(
                b.cast(b.decodeJpeg(input, 3), Float.class),
                b.constant("make_batch", 0)
            ),
            b.constant("size", new int[]{H, W})
        ),
        b.constant("mean", mean)
    ),
    b.constant("scale", scale)
);

4. Выполнение сессии

При получении результатов не забудьте извлечь их правильно:

try (Session s = new Session(g)) {
    return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
}

Это должно работать, если модель MobileNet загружена правильно и соответствует ожидаемой архитектуре.

5. Тестирование и отладка

После внесения всех изменений обязательно проведите тестирование с различными изображениями, которые были ситуативно классифицированы в других вариантах. Если результаты всё ещё неверные:

  • Проверьте, правильно ли загружается модель MobileNet.
  • Убедитесь, что лейблы для ваших результатов соответствуют тем, которые использовались для обучения сети.
  • Попробуйте использовать другие изображения с известными результатами, чтобы определить, сохраняется ли ошибка.

Заключение

Правильная работа с MobileNet в Java требует чёткой настройки подготовки изображений и моделей. Попробовав указанные изменения в параметрах нормализации, вы должны получить более точные результаты классификации. Если же проблемы сохраняются, возможно, стоит рассмотреть возможность использования более низкоуровневых библиотек для отладки модели, таких как TensorFlow Python API, для изоляции проблемы и её дальнейшего решения.

Оцените материал
Добавить комментарий

Капча загружается...