Есть ли что-то неправильно в использовании session.run()?

Вопрос или проблема

Я использую Tensorflow.net, Numsharp и TensorFlow.Keras в C# WinForms на .NET Framework 4.8 и пытаюсь создать нейронную сеть для обучения с подкреплением (Q-learning). Школа заставляет нас использовать WinForms, и это действительно больно.
В чем неправилен мой вызов session.run() в этом случае?

public void Train(NDArray currentState, int action, float reward, NDArray nextState, bool done)
{
    var graph = session.graph.as_default();

    simMenu.LogMessage($"Обучение начато с currentState: {currentState}, action: {action}, reward: {reward}, nextState: {nextState}, done: {done}");

    try
    {

        currentState = np.reshape(currentState, new int[] { 1, stateSize });
        currentState = currentState.astype(np.float32);


        // Получить Q-значения для текущего состояния
        var qValuesCurrent = session.run(output, new FeedItem(state, currentState));
        
        float currentQValue = qValuesCurrent.ToArray<float>()[action];

        // Получить Q-значения для следующего состояния
        nextState = np.reshape(nextState, new int[] { 1, stateSize });
        nextState = nextState.astype(np.float32);
        
        var qValuesNext = session.run(output, new FeedItem(state, nextState)); // здесь возникает ошибка
     
        float maxFutureQValue = np.max(qValuesNext.ToArray<float>());

        // Расчет целевого Q-значения
        float targetQValue = reward;
        if (!done)
        {
            targetQValue += discountFactor * maxFutureQValue;
        }

        // Обновление Q-значения с использованием формулы Q-обучения
        var qValuesArray = qValuesCurrent.ToArray<float>();
        qValuesArray[action] = currentQValue + learningRate * (targetQValue - currentQValue);
        var updatedQValues = np.array(qValuesArray).reshape(1, actionSize);

        // Обучение модели
        session.run(optimizer, new FeedItem(state, currentState), new FeedItem(qTarget, updatedQValues));
      
    }
    catch (Exception ex)
    {
        simMenu.LogMessage($"Произошла ошибка во время обучения: {ex.Message}");
        throw;
    }

При запуске ошибка выглядит просто так: System.NotImplementedException: ”, и лог-сообщения таковы:

Сессия и граф успешно инициализированы.
Обучение начато с currentState: [140.0723, 30.92329, 30.92329, 51.61395, 40.36087], action: 2, reward: 1, nextState: [140.1722, 30.92329, 30.92329, 51.61395, 40.50206], done: False Произошла ошибка во время обучения:

смотрите переменные в момент ошибки

Вся функция представлена здесь –

using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;
using System;
using Coursework;

public class QLearningAgent
{
    private readonly int stateSize;        // Размер состояния
    private readonly int actionSize;       // Количество действий
    private readonly float learningRate;   // Скорость обучения
    private readonly float discountFactor = 0.99f; // Фактор скидки для будущих вознаграждений
    private readonly float explorationProbability; // Эпсилон для исследования
    private readonly Session session;      // Сессия TensorFlow
    private readonly Tensor state;         // Плейсхолдер для текущего состояния
    private readonly Tensor qTarget; // Плейсхолдер для целевых Q-значений
    private readonly SimMenu simMenu;

    private Tensor output;                 // Выходной слой для Q-значений
    private Operation optimizer;
    
    public QLearningAgent(int stateSize, int actionSize, float learningRate, SimMenu simMenu, float explorationProbability = 0.1f)
    {

        this.stateSize = stateSize;
        this.actionSize = actionSize;
        this.learningRate = learningRate;
        this.explorationProbability = explorationProbability;
        this.simMenu = simMenu;

        // Конструкция графа TensorFlow
        var graph = tf.Graph().as_default();

        // Инициализация сессии
        session = tf.Session(graph);
        session.run(tf.global_variables_initializer()); // session.run() работает нормально здесь

        this.state = tf.placeholder(tf.float32, shape: new int[] { -1, stateSize });
        this.qTarget = tf.placeholder(tf.float32, shape: new int[] { -1, actionSize });

        // Определение структуры нейронной сети
        var hidden = tf.keras.layers.Dense(24, activation: "relu").Apply(state); // Скрытый слой
        this.output = tf.keras.layers.Dense(actionSize, activation: "linear").Apply(hidden); // Выходной слой для Q-значений

        // Потеря и оптимизатор
        var loss = tf.reduce_mean(tf.square(output - qTarget));
        this.optimizer = tf.train.AdamOptimizer(learningRate).minimize(loss);

        simMenu.LogMessage("Сессия и граф успешно инициализированы.");
    }

    public int GetAction(NDArray currentState)
    {
        
        // Выбор действий с использованием метода ε-жадности
        if (new Random().NextDouble() < explorationProbability)
        {
            // Исследование: выбор случайного действия
            return new Random().Next(actionSize);
        }
        else
        {
            // Использование: выбор действия с наивысшим Q-значением

            var graph = session.graph.as_default();

            // Получить Q-значения для текущего состояния
            var feedDict = new FeedItem(state, currentState);
            var qValues = session.run(output, feedDict); // Session.run() тоже не работает здесь, однако она переходит к Train() перед GetAction()
            var qValuesArray = qValues.ToArray<float>(); // Преобразовать Tensorflow.NumPy.NDArray в массив float
            return np.argmax(np.array(qValuesArray)); // Вернуть индекс действия с наивысшим Q-значением

        }
    }

    public void Train(NDArray currentState, int action, float reward, NDArray nextState, bool done)
    {
        var graph = session.graph.as_default();

        simMenu.LogMessage($"Обучение начато с currentState: {currentState}, action: {action}, reward: {reward}, nextState: {nextState}, done: {done}");

        try
        {

            currentState = np.reshape(currentState, new int[] { 1, stateSize });
            currentState = currentState.astype(np.float32);


            // Получить Q-значения для текущего состояния
            var qValuesCurrent = session.run(output, new FeedItem(state, currentState));
            
            float currentQValue = qValuesCurrent.ToArray<float>()[action];

            // Получить Q-значения для следующего состояния
            nextState = np.reshape(nextState, new int[] { 1, stateSize });
            nextState = nextState.astype(np.float32);
            
            var qValuesNext = session.run(output, new FeedItem(state, nextState));
         
            float maxFutureQValue = np.max(qValuesNext.ToArray<float>());

            // Расчет целевого Q-значения
            float targetQValue = reward;
            if (!done)
            {
                targetQValue += discountFactor * maxFutureQValue;
            }

            // Обновление Q-значения с использованием формулы Q-обучения
            var qValuesArray = qValuesCurrent.ToArray<float>();
            qValuesArray[action] = currentQValue + learningRate * (targetQValue - currentQValue);
            var updatedQValues = np.array(qValuesArray).reshape(1, actionSize);

            // Обучение модели
            session.run(optimizer, new FeedItem(state, currentState), new FeedItem(qTarget, updatedQValues));
          
        }
        catch (Exception ex)
        {
            simMenu.LogMessage($"Произошла ошибка во время обучения: {ex.Message}");
            throw;
        }
    }
}


Я пытался проверить все типы и совместимость state и nextState, я также пытался изменить их размер, чтобы убедиться, что они являются NDArray одинаковой длины и типа. Это должно принимать 5 входов.

Мое описание проблемы может быть не лучшим или самым полным, так как это мой первый пост, спасибо, что нашли время прочитать. Если вам нужна дополнительная информация, просто спросите!

Ответ или решение

Проблема с использованием метода session.run() в TensorFlow.Net

Использование библиотеки TensorFlow в C# с библиотеками TensorFlow.Net и NumSharp может быть сложной задачей, особенно в контексте реализации нейронной сети для Q-обучения. Ваша проблема заключается в ошибке System.NotImplementedException, возникающей при вызове метода session.run() в функции Train. Давайте проанализируем ваш код и попробуем выявить возможные источники данных ошибок.

1. Область графа

Вы правильно используете as_default() для установки текущего графа в методе Train. Однако отсутствие корректной инициализации графа может привести к возникновению ошибок. Убедитесь, что все операции, связанные с вашим графом, выполняются внутри метода Train, и все переменные и тензоры корректно инициализированы перед выполнением сессии.

2. Подходящие размеры тензоров

Одной из распространенных проблем является несовпадение размеров передаваемых тензоров. В вашем случае, вы пишете:

nextState = np.reshape(nextState, new int[] { 1, stateSize });
nextState = nextState.astype(np.float32);

Убедитесь, что переменная stateSize истинно соответствует количеству элементов в nextState. Проверьте, что nextState имеет правильную форму (1, stateSize) и является массивом float32.

3. Проверка тензора qValuesNext

Ошибка может возникнуть из-за того, что выходной тензор output может не возвращать ожидаемое значение из графа. Посмотрите на определение вашего тензора. Убедитесь, что слои, создаваемые с помощью tf.keras.layers, правильно подключены. Вместо вызова session.run(output, ...), можно сначала проверить сам output:

simMenu.LogMessage($"Output tensor shape: {output.shape}");

Теперь, перед вызовом session.run(), вы сможете убедиться, что output имеет правильную форму.

4. Использование метода numpy.max

Использование метода np.max() для нахождения максимального значения также может представлять собой потенциальную проблему, если qValuesNext может быть пустым. Убедитесь, что для qValuesNext не возникает исключения, связанного с отсутствием значений:

var qValuesNext = session.run(output, new FeedItem(state, nextState));
if (qValuesNext.size == 0)
{
    throw new InvalidOperationException("qValuesNext is empty.");
}

5. Логирование ошибок

Ваше текущее сообщение об ошибке не предоставляет достаточно информации. Я рекомендую расширить блок catch, чтобы вывести стек вызовов и больше деталей о происходящем:

catch (Exception ex)
{
    simMenu.LogMessage($"An error occurred during training: {ex.Message}, Stack Trace: {ex.StackTrace}");
    throw;
}

Вывод

Ключом к устранению проблемы с session.run() в вашем Q-обучении является тщательная проверка всех входных параметров и корректности их форм. Примените предложенные решения и ведите тщательное логирование, чтобы получить больше информации в случае возникновения ошибки. Если в дальнейшем возникнут дополнительные затруднения, обратитесь к документации TensorFlow.Net или профессиональным сообществам за поддержкой. Ваш успех в этом проекте будет зависеть от тщательной отладки и понимания работы нейросетей в среде C#.

Оцените материал
Добавить комментарий

Капча загружается...