Вопрос или проблема
Существует модифицированная модель efficientnet TF, которую я пытаюсь смоделировать в pytorch. Я внес изменения в архитектуру модели в pytorch, выгрузил веса модели TF и загрузил их обратно в новую модель pytorch. Выгрузка весов в TF осуществляется с помощью этого кода:
model = tf.saved_model.load(model_path)
ws = []
for i in range(len(model.variables)):
ws.append((i, model.variables[i].name, model.variables[i].numpy()))
with open("manually_dumped_contentnet_weights.pkl", "wb") as ofile:
pickle.dump(ws, ofile)
Формы весов в pytorch, похоже, соответствуют архитектуре и импортированным весам (после необходимых транспозиций между conv2d и depth-wise conv2d). Я могу запускать модель без каких-либо ошибок. Но выходные данные сильно отличаются от выходных данных модели TF.
Я заметил, что в коде TF модель не загружается напрямую, а загружается в сессии tf:
with Session(graph=Graph(), config=ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
saved_model.loader.load(sess, [saved_model.tag_constants.SERVING], model_path)
patch_feature, patch_label = sess.run(output_nodes,feed_dict={input_node: patch})
Теперь я задаюсь вопросом, была ли моя первоначальная попытка выгрузить веса модели выполнена некорректно. Или я упустил что-то еще.
Транспозиции, которые я выполнил при загрузке данных, это (3,2,0,1) для conv2d и (2,3,0,1) для depth-wise conv2d:
def reload_conv2d(layer, weights):
### weights - это кортеж, где каждый элемент состоит из кортежа: (1) номер индекса, (2) имя слоя, из которого были выгружены веса в TF, и (3) веса
count = 0
if (
"/conv2d/kernel" not in weights[0][1]
and "/conv2d_1/kernel" not in weights[0][1]
and "depthwise_conv2d/depthwise_kernel" not in weights[0][1]
and "final_conv2d/final_conv2d" not in weights[0][1]
):
raise ValueError(
f"нужен conv2d/kernel на первом индексе, но получено {weights[0][1]}"
)
transpose_shape = (2,3,0,1) if "depthwise" in weights[0][1] else (3, 2, 0, 1)
transposed_weights = torch.from_numpy(weights[0][2].transpose(transpose_shape[0], transpose_shape[1], transpose_shape[2], transpose_shape[3]))
layer.weight.data = transposed_weights
count += 1
if layer.bias is not None or layer.bias:
if (
"/conv2d/bias" not in weights[1][1]
and "/conv2d_1/bias" not in weights[1][1]
):
raise ValueError(
f"нужен conv2d/bias на втором индексе, но получено {weights[1][1]}"
)
layer.bias.data = (
torch.from_numpy(weights[1][2])
if type(weights[1][2]) == np.ndarray
else torch.from_numpy(weights[1][2])
)
count += 1
return layer, count
Почему модели pytorch и TF дают совершенно разные результаты на одном и том же входе? Это связано с выгрузкой весов, или загрузкой весов… или, может быть, изменением архитектуры модели? Тот факт, что входные веса TF (после изменения модели и транспозиции) были загружены без проблем и я могу запускать модель без ошибок, не помогает в отладке.
Ответ или решение
Передача весов модели из TensorFlow в PyTorch может быть сложной задачей. Давайте проанализируем вашу ситуацию, учитывая все основные шаги и возможные причины, по которым выходы ваших моделей могут сильно различаться.
1. Создание модели и выгрузка весов из TensorFlow
Вы правильно начали с загрузки модели TensorFlow и выгрузки весов с помощью кода, который предоставляет доступ к переменным модели. Использование метода model.variables
позволяет получить все веса модели, что является хорошей практикой. Убедитесь, что сохранили веса с их названиями и формами, как это сделано в вашем коде. Тем не менее, важно экспериментировать с процессом выгрузки весов и проверять, все ли нужные параметры были сохранены.
2. Трансформация весов для PyTorch
Когда вы импортируете веса в PyTorch, вам, безусловно, нужно обратить внимание на порядки размерностей. Стандартная практика заключается в том, что в TensorFlow (NCHW) порядок размерностей для весов сверточных слоев (Conv2D) обычно представлен в формате (высота, ширина, входные_каналы, выходные_каналы)
, в то время как в PyTorch используется формат (выходные_каналы, входные_каналы, высота, ширина)
.
Ваши трансформации весов на первый взгляд выглядят корректно (например, (3, 2, 0, 1)
для обычных сверток и (2, 3, 0, 1)
для глубинно-сверточных слоев). Но стоит дополнительно проверить:
- Действительно ли размеры входных и выходных каналов совпадают между соответствующими слоями обеих моделей.
- То, что сдвиги и насыщенности (если применялись) были правильно применены.
3. Проверка архитектуры модели
Определите, внесли ли вы какие-либо изменения в архитектуру вашей модели при переводе из TensorFlow в PyTorch. Сравните точные слои, их количество, порядок следования и функции активации в обеих версиях. Если в одной модели используются различные функции активации или параметры (например, Batch Normalization), это может повлиять на производительность.
4. Проблемы при загрузке весов
Обратите внимание и на этап загрузки весов в PyTorch:
- Убедитесь, что вы корректно загружаете все веса, а не только веса сверток. Могут быть слои, такие как BatchNorm, которые также требуют загрузки своих весов и смещений.
- Проверьте, что вы передаете правильные данные (в формате) в ваши слои, и что вы не меняете размер входных данных между TensorFlow и PyTorch.
5. Проверка вывода
Если вы уверены, что все веса загружены и транспонированы правильно, посмотрите на сами выходные данные моделей для определенных наборов входов, например, тестов на одних и тех же изображениях:
- Проверьте, совпадают ли размеры выходов между TensorFlow и PyTorch.
- Выполните отладку одного из промежуточных слоев на обеих моделях, чтобы определить, где именно начинают происходить расхождения.
6. Проблемы с инициализацией
Кроме того, стоит проверить настройки инициализации, такие как состояние оптимизаторов и начальные значения по умолчанию, которые могут различаться в TensorFlow и PyTorch. Если вы используете перенормировку или регуляризацию, убедитесь, что параметры также совпадают.
Заключение
Проблема с разными выходами между моделями в PyTorch и TensorFlow может быть вызвана несколькими факторами, такими как неверная выгрузка или загрузка весов, изменения в архитектуре модели и порядок трансформации весов. Рекомендуется тщательно сравнить как архитектуры, так и порядок слоев, трансформации весов и промежуточные результаты, чтобы выявить источник расхождения. Исправление этих проблем, как правило, требует кросс-проверки на каждом из этапов передачи весов и сверки результатов.