Вопрос или проблема
Я работаю над моделью PyTorch, которая включает двойной цикл для вычислений. Проблема в том, что я получаю ошибку, связанную с операциями в месте, когда пытаюсь выполнить обратное распространение.
Вот упрощенный пример моего кода:
import torch
length = 1000
dx = 100
num_rows = 5
num_cols = int(length / dx) + 1
C = torch.rand(num_rows, requires_grad=True)
D = torch.rand(num_rows, requires_grad=True)
qobs = torch.rand(num_rows, requires_grad=False)
Caux = 1 + C + D
C0 = (-1 + C + D) / Caux
C1 = (1 + C - D) / Caux
C2 = (1 - C + D) / Caux
qprop = torch.zeros((num_rows, num_cols))
qprop[:, 0] = qobs
for j in range(1, num_cols):
for n in range(1, num_rows):
term1 = C2[n] * qprop[n-1, j-1]
term2 = C1[n] * qprop[n, j-1]
term3 = C0[n] * qprop[n-1, j]
qprop[n, j] = term1 + term2 + term3
loss = qprop.sum()
loss.backward()
Ответ или решение
Вопрос, с которым вы столкнулись, связан с использованием in-place операций в цикле PyTorch, что может привести к проблемам с обратным распространением градиентов. В вашем коде проблема в том, что переменная qprop
обновляется в цикле, а это может привести к конфликту с градиентами.
Чтобы избежать in-place операций и позволить выполнять обратное распространение, вы можете использовать временные переменные для хранения промежуточных результатов. Вместо прямого обновления qprop
, создайте новый тензор для хранения результатов. Вот как можно переписать ваш код:
import torch
length = 1000
dx = 100
num_rows = 5
num_cols = int(length / dx) + 1
C = torch.rand(num_rows, requires_grad=True)
D = torch.rand(num_rows, requires_grad=True)
qobs = torch.rand(num_rows, requires_grad=False)
Caux = 1 + C + D
C0 = (-1 + C + D) / Caux
C1 = (1 + C - D) / Caux
C2 = (1 - C + D) / Caux
qprop = torch.zeros((num_rows, num_cols))
qprop[:, 0] = qobs
for j in range(1, num_cols):
new_qprop_j = torch.zeros(num_rows) # Создаем новый тензор для хранения значений в текущем столбце
for n in range(1, num_rows):
term1 = C2[n] * qprop[n-1, j-1]
term2 = C1[n] * qprop[n, j-1]
term3 = C0[n] * qprop[n-1, j]
new_qprop_j[n] = term1 + term2 + term3
qprop[:, j] = new_qprop_j # Обновляем значение qprop для текущего столбца
loss = qprop.sum()
loss.backward()
Объяснение изменений:
-
Новый тензор для хранения текущих значений: Мы создаем переменную
new_qprop_j
, которая будет хранить результаты вычислений для текущего столбцаj
. Это помогает избежать перезаписиqprop
в процессе его вычисления, что и вызывает ошибку при обратном распространении. -
Использование временной переменной: Все промежуточные результаты в цикле записываются в
new_qprop_j
, и только после завершения внутренних вычислений происходит обновление оригинального тензораqprop
. -
Сохранение диаграммы вычислений: Теперь вы можете безопасно выполнять обратное распространение без конфликтов с градиентами, так как переменные не перезаписываются во время вычислений.
Следуя этим рекомендациям, вы сможете избежать ошибок, связанных с in-place операциям в вашем коде на PyTorch.