Pytorch: Извлечение среза 3D тензора вдоль размерности 2 с использованием 1D тензора в качестве индекса

Вопрос или проблема

Я хочу сделать срез 3D тензора в PyTorch. Размер 3D тензора src_tensor равен (batch, max_len, hidden_dim), и у меня есть 1D вектор индексов indices с размером (batch,). Я хочу сделать срез вдоль второго измерения src_tensor. Я могу добиться этой функциональности с помощью следующего кода:

import torch
nums = 30
l = [i for i in range(nums)]
src_tensor = torch.Tensor(l).reshape((3,5,2))
indices = [1,2,3]
slice_tensor = torch.zeros((3,2,2)) 
for i in range(3):
    p1,p2 = indices[i],indices[i]+1
    slice_tensor[i,:,:]=src_tensor[i,[p1,p2],:]
print(src_tensor)
print(indices)
print(slice_tensor)
"""
tensor([[[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.],
         [ 6.,  7.],
         [ 8.,  9.]],

        [[10., 11.],
         [12., 13.],
         [14., 15.],
         [16., 17.],
         [18., 19.]],

        [[20., 21.],
         [22., 23.],
         [24., 25.],
         [26., 27.],
         [28., 29.]]])
[1, 2, 3]
tensor([[[ 2.,  3.],
         [ 4.,  5.]],

        [[14., 15.],
         [16., 17.]],

        [[26., 27.],
         [28., 29.]]])
"""

Мой вопрос заключается в том, можно ли упростить приведенный выше код, например, исключив цикл for.

Ответ или решение

Чтобы упростить процесс нарезки 3D тензора в PyTorch, используя 1D тензор индексов, можно использовать функции torch.gather или torch.index_select, что позволяет избежать использования цикла for. Рассмотрим ваш подход к задаче и предложим более эффективное решение.

Исходная задача

У вас есть 3D тензор src_tensor размером (batch, max_len, hidden_dim) и 1D тензор indices размером (batch,), который вы хотите использовать для нарезки этого тензора по второй размерности (max_len).

Пример кода

Ваш код можно улучшить следующим образом:

import torch

# Исходные данные
nums = 30
src_tensor = torch.arange(0, nums, dtype=torch.float32).reshape((3, 5, 2))
indices = torch.tensor([1, 2, 3])

# Использование torch.gather для нарезки
# Необходимо изменить размерность indices для использования с gather
indices = indices.unsqueeze(1).expand(-1, 2)  # Расширяем индексы до (3, 2)
slice_tensor = src_tensor.gather(1, indices.unsqueeze(-1).expand(-1, -1, 2))  # Используем gather

print("Исходный тензор:")
print(src_tensor)
print("Индексы:")
print(indices)
print("Нарезанный тензор:")
print(slice_tensor)

Объяснение кода

  1. Создание tензоров: Мы сначала создаем исходный тензор src_tensor и 1D тензор indices, где indices – это индексы, по которым будет выполнена нарезка.

  2. Преобразование индексов: Используем метод unsqueeze для добавления дополнительной размерности к indices, а затем с expand создаем необходимую форму (batch, 2).

  3. Использование torch.gather: Метод gather позволяет извлекать элементы по индексам. Обратите внимание, что мы расширяем индексы до третьей размерности, чтобы соотнести их с размерностью hidden_dim тензора.

Преимущества этого подхода

  • Отсутствие цикла: Использование torch.gather позволяет избежать явных циклов, тем самым улучшая производительность.
  • Ясность кода: Решение более чистое и понятное, так как напрямую использует встроенные функции PyTorch для работы с тензорами.
  • Улучшенная производительность: Благодаря векторизации операций в PyTorch, код работает быстрее, особенно на больших данных.

Заключение

Таким образом, мы смогли упростить ваш код, сохранив его функциональность и улучшив читаемость. Применение функций, встроенных в библиотеку, не только облегчает кодирование, но и предоставляет возможность использования оптимизированных решений, что критически важно при работе с большими объемами данных.

Оцените материал
Добавить комментарий

Капча загружается...