Вопрос или проблема
Я пытаюсь понять, как получить таблицу вероятностей действий вместо уникального предсказания в stable baselines 3 SAC, чтобы переопределить метод ‘predict’ и отфильтровать недопустимые действия. Я полагаю, что правильный способ сделать это — удалить недопустимые действия из таблицы и изменить размер вероятностей.
Предсказания моделей могут быть детерминированными или нет, но я не могу найти функцию, которая возвращает таблицу вероятностей действий (такая функция была доступна в SB2, если я правильно помню…)
Я читал эти темы, но они не о SAC:
https://github.com/DLR-RM/stable-baselines3/issues/1347
https://github.com/DLR-RM/stable-baselines3/issues/568
https://github.com/DLR-RM/stable-baselines3/issues/385
С наилучшими пожеланиями.
Преобразуйте ваше наблюдение в тензор нужной формы, если это еще не сделано
obs_tensor = torch.tensor(obs.reshape((1,-1)), dtype=torch.float32)
with torch.no_grad():
distribution = model.policy.get_distribution(obs_tensor)
Объект distribution
имеет атрибуты logits
и probs
:
print(distribution.distribution.probs)
Ответ или решение
Чтобы получить вероятности действий вместо единственного предсказания в Stable Baselines 3 (SB3) для алгоритма Soft Actor-Critic (SAC), вам необходимо обращаться к политике модели и использовать соответствующие функции для получения распределения вероятностей. В отличие от предыдущих версий, как SB2, в SB3 нет встроенной функции, которая напрямую возвращает таблицу вероятностей действий. Однако, можно использовать get_distribution
, чтобы получить необходимые данные.
Вот шаги, которые вам нужно выполнить:
-
Преобразуйте наблюдение в тензор: Убедитесь, что ваше наблюдение (объект среды) имеет правильную форму. Если у вас есть одномерный массив, вам нужно изменить его форму на двумерный тензор.
obs_tensor = torch.tensor(obs.reshape((1, -1)), dtype=torch.float32)
-
Получите распределение вероятностей: Используйте метод
get_distribution
вашей политики для получения объекта распределения. Это будет содержать логи вероятностей и сами вероятности.with torch.no_grad(): distribution = model.policy.get_distribution(obs_tensor)
-
Получите вероятности действий: После получения распределения вы можете извлечь вероятности действий из объекта распределения.
action_probs = distribution.distribution.probs print(action_probs)
-
Фильтрация недопустимых действий: Если у вас есть недопустимые действия, вы можете отфильтровать их. Напоминаю, что для корректного изменения вероятностей вы должны масштабировать их так, чтобы их сумма оставалась равной 1. Ниже приведен пример кода для этого.
valid_actions_mask = ... # Boolean mask для недопустимых действий valid_action_probs = action_probs[valid_actions_mask] # Масштабируем вероятности, чтобы сумма была равной 1 valid_action_probs /= valid_action_probs.sum()
-
Переопределение метода predict: Если вы хотите переопределить метод
predict
вашей модели, чтобы использовать фильтрованные вероятности, вам нужно реализовать логику в соответствии с вышеуказанными шагами.
Вот пример переопределенного метода predict
:
class CustomSAC(SAC):
def predict(self, obs, deterministic=False):
obs_tensor = torch.tensor(obs.reshape((1, -1)), dtype=torch.float32)
with torch.no_grad():
distribution = self.policy.get_distribution(obs_tensor)
action_probs = distribution.distribution.probs
# Ваш код для фильтрации недопустимых действий
valid_actions_mask = ... # Определите, какие действия допустимы
valid_action_probs = action_probs[valid_actions_mask]
# Масштабируем вероятности
valid_action_probs /= valid_action_probs.sum()
# Здесь метод выбора действия на основе вероятностей
if deterministic:
# Выбор действия с самой высокой вероятностью
action = valid_action_probs.argmax().item()
else:
# Образуем вероятностный выбор
action = torch.multinomial(valid_action_probs, num_samples=1).item()
return action
Таким образом, это позволит вам получать вероятности действий и фильтровать их для SAC в Stable Baselines 3.