Остаточная нейронная сеть на Rust с использованием tch-rs

Вопрос или проблема

Я пытаюсь реализовать нейронную сеть с остаточным распространением в Rust, используя tch-rs (PyTorch).

Пока что вот мой код:

fn res_block(vs: &nn::Path) -> impl ModuleT {
    let mut default = ConvConfigND::default();
    default.padding = 1;
    let conv1 = nn::conv1d(vs, NUM_HIDDEN, NUM_HIDDEN, 3, default);
    let bn1 = batch_norm1d(vs, NUM_HIDDEN, BatchNormConfig::default());
    let conv2 = nn::conv1d(vs, NUM_HIDDEN, NUM_HIDDEN, 3, default);
    let bn2 = batch_norm1d(vs, NUM_HIDDEN, BatchNormConfig::default());
    nn::func_t(|x,train| {
        let mut residual = Tensor::new();
        x.clone(&residual);
        let x = bn1.forward_t(&conv1.forward(x),train).relu();
        let x = bn2.forward_t(&conv2.forward(&x),train);
        //let x = x + residual;
        return x.relu();
    })

Когда я компилирую этот код, я получаю следующую ошибку:
*mut torch_sys::C_tensor не может быть безопасно использован между потоками\nв BatchNorm, трейт Sync не реализован для *mut torch_sys::C_tensor, что необходимо для {closure@src\\nn.rs:19:16: 19:25}: Send\nтребуется для &BatchNorm, чтобы реализовать Send“,

Эта проблема возникает, когда я помещаю строки forward_t в func_t.
Как мне сделать так, чтобы это работало?
Я также пробовал использовать последовательные сети, но они не работают с передачей переменной residual дальше. Есть ли способ это сделать? Или мне нужно сделать что-то другое?
Спасибо!

Ответ или решение

Для реализации остаточной нейронной сети (Residual Neural Network) на Rust с использованием библиотеки tch-rs (PyTorch) есть несколько вещей, которые стоит учесть, чтобы избежать ошибок, связанных с безопасностью потоков, таких как Sync и Send. В вашем коде проблема возникает из-за использования замыкания внутри функции func_t, которое пытается получить доступ к элементам, не реализующим необходимые трейты для безопасного многопоточного доступа.

Вот переработанный пример кода, который решает вашу проблему и включает необходимые изменения:

use tch::{nn, nn::ModuleT, Tensor, Device};
use tch::nn::ConvConfigND;
use tch::nn::BatchNormConfig;

const NUM_HIDDEN: i64 = 64; // установите желаемое количество скрытых нейронов

fn res_block(vs: &nn::Path) -> impl ModuleT {
    let mut conv_config = ConvConfigND::default();
    conv_config.padding = 1;

    // Создаем слои свертки и пакетной нормализации
    let conv1 = nn::conv1d(vs, NUM_HIDDEN, NUM_HIDDEN, 3, conv_config);
    let bn1 = nn::batch_norm1d(vs, NUM_HIDDEN, Default::default());
    let conv2 = nn::conv1d(vs, NUM_HIDDEN, NUM_HIDDEN, 3, conv_config);
    let bn2 = nn::batch_norm1d(vs, NUM_HIDDEN, Default::default());

    nn::func_t(move |x, train| {
        let residual = x.shallow_clone(); // используем shallow_clone для создания точки отсечения

        let x = bn1.forward_t(&conv1.forward(x), train).relu();
        let x = bn2.forward_t(&conv2.forward(&x), train);

        x + residual // возвращаем сумму, чтобы получить остаточную связь
    })
}

fn main() {
    // Пример использования
    let vs = nn::VarStore::new(Device::cuda_if_available());
    let model = res_block(&vs.root());

    let input = Tensor::randn(&[1, NUM_HIDDEN, 10]); // Пример входных данных
    let output = model.forward(&input);
    println!("{:?}", output.size());
}

Объяснение изменений:

  1. Использование shallow_clone: Я заменил ваше использование clone на shallow_clone, чтобы создать ссылку на тензор, который не затрагивает управление памятью в других потоках. Это позволяет избежать проблем с Sync.

  2. Код завершен: Я добавил простой пример в функции main, который показывает, как можно создать и использовать модель с массивом входных данных.

  3. Продолжает использование BatchNorm: BatchNorm оказывается важным для нормализации выходов и улучшения обучения. Ваш оригинальный код уже правильно включает это, поэтому я оставил его без изменений.

Следуйте данным рекомендациям и используйте приведенный код в качестве основы для вашей остовной сети. Не стесняйтесь тестировать ее и настраивать под свои задачи.

Оцените материал
Добавить комментарий

Капча загружается...