Вопрос или проблема
Я тренирую модель в pytorch и хотел бы иметь возможность программно изменять некоторые компоненты архитектуры модели, чтобы проверить, что работает лучше, без каких-либо условных операторов в forward()
. Рассмотрим игрушечный пример:
import torch
class Model(torch.nn.Model):
def __init__(self, layers: str, d_in: int, d_out: int):
super().__init__()
self.layers = layers
linears = torch.nn.ModuleList([
torch.nn.Linear(d_in, d_out),
torch.nn.Linear(d_in, d_out),
])
def forward(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
if self.layers == "parallel":
x1 = self.linears[0](x1)
x2 = self.linears[0](x2)
x = x1 + x2
elif self.layers == "sequential":
x = x1 + x2
x = self.linears[0](x)
x = self.linears[0](x)
return x
Моя первая интуиция заключалась в том, чтобы предоставить внешние функции, например,
def parallel(x1, x2):
x1 = self.linears[0](x1)
x2 = self.linears[0](x2)
return x1 + x2
и передать их модели, как
class Model(torch.nn.Model):
def __init__(self, layers: str, d_in: int, d_out: int, fn: Callable):
super().__init__()
self.layers = layers
linears = torch.nn.ModuleList([
torch.nn.Linear(d_in, d_out),
torch.nn.Linear(d_in, d_out),
])
self.fn = fn
def forward(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
x = self.fn(x1, x2)
но, конечно, область видимости функции не знает о self.linears
, и я также хотел бы избежать необходимости передавать каждый элемент архитектуры функции.
Я слишком много хочу? Должен ли я “укусить кислое яблоко”, как говорится по-немецки, и иметь более широкие сигнатуры функций или использовать условные операторы, или что-то ещё? Или есть решение моей проблемы?
Ответ или решение
Когда речь заходит о программном изменении компонентов архитектуры модели в PyTorch, существует множество подходов, позволяющих сделать это эффективно, без необходимости использования условных операторов в методе forward()
. В вашем случае вы хотите обойтись без конструкции if
, чтобы ваши функции оставались чистыми и модульными. Вот несколько стратегий, которые можно использовать для достижения этой цели.
1. Параметризованные функции
Один из методов — создание параметризованных функций, которые могут принимать данные модели как аргументы. Это обеспечивает четкое разделение логики, сохраняя при этом доступ к компонентам модели.
Пример реализации:
import torch
from typing import Callable
class Model(torch.nn.Module):
def __init__(self, layers: str, d_in: int, d_out: int, fn: Callable):
super().__init__()
self.layers = layers
self.linears = torch.nn.ModuleList([
torch.nn.Linear(d_in, d_out),
torch.nn.Linear(d_in, d_out),
])
self.fn = fn
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
return self.fn(x1, x2, self.linears)
def parallel(x1, x2, linears):
x1 = linears[0](x1)
x2 = linears[0](x2)
return x1 + x2
def sequential(x1, x2, linears):
x = x1 + x2
x = linears[0](x)
return linears[1](x)
# Пример использования:
model_parallel = Model("parallel", 10, 5, parallel)
model_sequential = Model("sequential", 10, 5, sequential)
2. Стратегия dict
для архитектур
Еще один подход — использование словаря (или другого контейнера), который будет хранить ссылки на различные функции обработки. Это позволяет вам выбрать нужную функцию во время инициализации и избавиться от условий внутри forward()
.
class Model(torch.nn.Module):
def __init__(self, layers: str, d_in: int, d_out: int):
super().__init__()
self.linears = torch.nn.ModuleList([
torch.nn.Linear(d_in, d_out),
torch.nn.Linear(d_in, d_out),
])
self.layer_functions = {
"parallel": self.parallel,
"sequential": self.sequential,
}
self.fn = self.layer_functions.get(layers)
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
return self.fn(x1, x2)
def parallel(self, x1, x2):
x1 = self.linears[0](x1)
x2 = self.linears[0](x2)
return x1 + x2
def sequential(self, x1, x2):
x = x1 + x2
x = self.linears[0](x)
return self.linears[1](x)
3. Использование специальных классов
Если ваши архитектурные изменения становятся довольно сложными, вы можете рассмотреть возможность создания отдельных классов для каждой архитектуры. Эти классы могут наследоваться от базового класса модели и определять свой собственный метод forward()
.
class BaseModel(torch.nn.Module):
def __init__(self, d_in: int, d_out: int):
super().__init__()
self.linears = torch.nn.ModuleList([
torch.nn.Linear(d_in, d_out),
torch.nn.Linear(d_in, d_out),
])
class ParallelModel(BaseModel):
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
x1 = self.linears[0](x1)
x2 = self.linears[0](x2)
return x1 + x2
class SequentialModel(BaseModel):
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
x = x1 + x2
x = self.linears[0](x)
return self.linears[1](x)
# Пример использования:
model_parallel = ParallelModel(10, 5)
model_sequential = SequentialModel(10, 5)
Заключение
Такой подход позволяет вам динамически модифицировать архитектуру модели без необходимости использовать условные конструкции внутри метода forward()
. Работая с функциями как с первоклассными объектами или с помощью наследования, вы получаете большую гибкость и удобочитаемость кода. Это не только улучшает структуру вашего кода, но также способствует дальнейшей его поддержке и масштабированию.