Вопрос или проблема
Я пытаюсь реализовать нейронную сеть с остаточным распространением в Rust, используя tch-rs (PyTorch).
Пока что вот мой код:
fn res_block(vs: &nn::Path) -> impl ModuleT {
let mut default = ConvConfigND::default();
default.padding = 1;
let conv1 = nn::conv1d(vs, NUM_HIDDEN, NUM_HIDDEN, 3, default);
let bn1 = batch_norm1d(vs, NUM_HIDDEN, BatchNormConfig::default());
let conv2 = nn::conv1d(vs, NUM_HIDDEN, NUM_HIDDEN, 3, default);
let bn2 = batch_norm1d(vs, NUM_HIDDEN, BatchNormConfig::default());
nn::func_t(|x,train| {
let mut residual = Tensor::new();
x.clone(&residual);
let x = bn1.forward_t(&conv1.forward(x),train).relu();
let x = bn2.forward_t(&conv2.forward(&x),train);
//let x = x + residual;
return x.relu();
})
Когда я компилирую этот код, я получаю следующую ошибку:
*mut torch_sys::C_tensor
не может быть безопасно использован между потоками\nв BatchNorm
, трейт Sync
не реализован для *mut torch_sys::C_tensor
, что необходимо для {closure@src\\nn.rs:19:16: 19:25}: Send
\nтребуется для &BatchNorm
, чтобы реализовать Send
“,
Эта проблема возникает, когда я помещаю строки forward_t в func_t.
Как мне сделать так, чтобы это работало?
Я также пробовал использовать последовательные сети, но они не работают с передачей переменной residual дальше. Есть ли способ это сделать? Или мне нужно сделать что-то другое?
Спасибо!
Ответ или решение
Для реализации остаточной нейронной сети (Residual Neural Network) на Rust с использованием библиотеки tch-rs
(PyTorch) есть несколько вещей, которые стоит учесть, чтобы избежать ошибок, связанных с безопасностью потоков, таких как Sync
и Send
. В вашем коде проблема возникает из-за использования замыкания внутри функции func_t
, которое пытается получить доступ к элементам, не реализующим необходимые трейты для безопасного многопоточного доступа.
Вот переработанный пример кода, который решает вашу проблему и включает необходимые изменения:
use tch::{nn, nn::ModuleT, Tensor, Device};
use tch::nn::ConvConfigND;
use tch::nn::BatchNormConfig;
const NUM_HIDDEN: i64 = 64; // установите желаемое количество скрытых нейронов
fn res_block(vs: &nn::Path) -> impl ModuleT {
let mut conv_config = ConvConfigND::default();
conv_config.padding = 1;
// Создаем слои свертки и пакетной нормализации
let conv1 = nn::conv1d(vs, NUM_HIDDEN, NUM_HIDDEN, 3, conv_config);
let bn1 = nn::batch_norm1d(vs, NUM_HIDDEN, Default::default());
let conv2 = nn::conv1d(vs, NUM_HIDDEN, NUM_HIDDEN, 3, conv_config);
let bn2 = nn::batch_norm1d(vs, NUM_HIDDEN, Default::default());
nn::func_t(move |x, train| {
let residual = x.shallow_clone(); // используем shallow_clone для создания точки отсечения
let x = bn1.forward_t(&conv1.forward(x), train).relu();
let x = bn2.forward_t(&conv2.forward(&x), train);
x + residual // возвращаем сумму, чтобы получить остаточную связь
})
}
fn main() {
// Пример использования
let vs = nn::VarStore::new(Device::cuda_if_available());
let model = res_block(&vs.root());
let input = Tensor::randn(&[1, NUM_HIDDEN, 10]); // Пример входных данных
let output = model.forward(&input);
println!("{:?}", output.size());
}
Объяснение изменений:
-
Использование
shallow_clone
: Я заменил ваше использованиеclone
наshallow_clone
, чтобы создать ссылку на тензор, который не затрагивает управление памятью в других потоках. Это позволяет избежать проблем сSync
. -
Код завершен: Я добавил простой пример в функции
main
, который показывает, как можно создать и использовать модель с массивом входных данных. - Продолжает использование
BatchNorm
:BatchNorm
оказывается важным для нормализации выходов и улучшения обучения. Ваш оригинальный код уже правильно включает это, поэтому я оставил его без изменений.
Следуйте данным рекомендациям и используйте приведенный код в качестве основы для вашей остовной сети. Не стесняйтесь тестировать ее и настраивать под свои задачи.