Как избежать операций на месте в цикле PyTorch, чтобы разрешить обратное распространение ошибки?

Вопрос или проблема

Я работаю над моделью PyTorch, которая включает двойной цикл для вычислений. Проблема в том, что я получаю ошибку, связанную с операциями в месте, когда пытаюсь выполнить обратное распространение.

Вот упрощенный пример моего кода:

import torch

length = 1000
dx = 100
num_rows = 5  
num_cols = int(length / dx) + 1

C = torch.rand(num_rows, requires_grad=True)
D = torch.rand(num_rows, requires_grad=True)
qobs = torch.rand(num_rows, requires_grad=False)

Caux = 1 + C + D
C0 = (-1 + C + D) / Caux
C1 = (1 + C - D) / Caux
C2 = (1 - C + D) / Caux

qprop = torch.zeros((num_rows, num_cols))
qprop[:, 0] = qobs

for j in range(1, num_cols):
    for n in range(1, num_rows):
        term1 = C2[n] * qprop[n-1, j-1]
        term2 = C1[n] * qprop[n, j-1]
        term3 = C0[n] * qprop[n-1, j]
        qprop[n, j] = term1 + term2 + term3

loss = qprop.sum()
loss.backward()

Ответ или решение

Вопрос, с которым вы столкнулись, связан с использованием in-place операций в цикле PyTorch, что может привести к проблемам с обратным распространением градиентов. В вашем коде проблема в том, что переменная qprop обновляется в цикле, а это может привести к конфликту с градиентами.

Чтобы избежать in-place операций и позволить выполнять обратное распространение, вы можете использовать временные переменные для хранения промежуточных результатов. Вместо прямого обновления qprop, создайте новый тензор для хранения результатов. Вот как можно переписать ваш код:

import torch

length = 1000
dx = 100
num_rows = 5  
num_cols = int(length / dx) + 1

C = torch.rand(num_rows, requires_grad=True)
D = torch.rand(num_rows, requires_grad=True)
qobs = torch.rand(num_rows, requires_grad=False)

Caux = 1 + C + D
C0 = (-1 + C + D) / Caux
C1 = (1 + C - D) / Caux
C2 = (1 - C + D) / Caux

qprop = torch.zeros((num_rows, num_cols))
qprop[:, 0] = qobs

for j in range(1, num_cols):
    new_qprop_j = torch.zeros(num_rows)  # Создаем новый тензор для хранения значений в текущем столбце
    for n in range(1, num_rows):
        term1 = C2[n] * qprop[n-1, j-1]
        term2 = C1[n] * qprop[n, j-1]
        term3 = C0[n] * qprop[n-1, j]
        new_qprop_j[n] = term1 + term2 + term3
    qprop[:, j] = new_qprop_j  # Обновляем значение qprop для текущего столбца

loss = qprop.sum()
loss.backward()

Объяснение изменений:

  1. Новый тензор для хранения текущих значений: Мы создаем переменную new_qprop_j, которая будет хранить результаты вычислений для текущего столбца j. Это помогает избежать перезаписи qprop в процессе его вычисления, что и вызывает ошибку при обратном распространении.

  2. Использование временной переменной: Все промежуточные результаты в цикле записываются в new_qprop_j, и только после завершения внутренних вычислений происходит обновление оригинального тензора qprop.

  3. Сохранение диаграммы вычислений: Теперь вы можете безопасно выполнять обратное распространение без конфликтов с градиентами, так как переменные не перезаписываются во время вычислений.

Следуя этим рекомендациям, вы сможете избежать ошибок, связанных с in-place операциям в вашем коде на PyTorch.

Оцените материал
Добавить комментарий

Капча загружается...