Вопрос или проблема
У меня есть следующий набор данных torch (я заменил фактический код чтения данных из файлов на генерацию случайных чисел, чтобы сделать его минимально воспроизводимым):
from torch.utils.data import Dataset
import torch
class TempDataset(Dataset):
def __init__(self, window_size=200):
self.window = window_size
self.x = torch.randn(4340, 10, dtype=torch.float32) # None
self.y = torch.randn(4340, 3, dtype=torch.float32)
self.len = len(self.x) - self.window + 1 # = 4340 - 200 + 1 = 4141
# Следовательно, индекс начала последнего окна = 4140
# А последнее окно будет включать элементы с 4140 до 4339, т.е. всего 200 элементов
def __len__(self):
return self.len
def __getitem__(self, index):
# Насколько я понимаю, приведенное ниже условие if НИКОГДА не должно быть истинным, так как последний индекс, с которым
# вызывается __getitem__, должен быть self.len - 1
if index == self.len:
print('self.__len__(): ', self.__len__())
print('Попытка получить элемент по индексу: ', index)
return self.x[index: index + self.window], self.y[index + self.window - 1]
ds = TempDataset(window_size=200)
print('len: ', len(ds))
counter = 0 # еще не считали записи
for x, y in ds:
counter += 1 # предыдущая строка прочитала еще одну запись из набора данных
print('counter: ', counter)
Это выводит:
len: 4141
self.__len__(): 4141
Попытка получить элемент по индексу: 4141
counter: 4141
Насколько я понимаю, __getitem__()
вызывается с index
в диапазоне от 0
до __len__()-1
. Если это так, то почему он попытался вызвать __getitem__()
с индексом 4141, когда длина данных составляет 4141?
Еще одна вещь, которую я заметил, это то, что, несмотря на вызов с index = 4141
, он, похоже, не возвращает никаких элементов, из-за чего counter
остается на 4141.
Что мои глаза (или мозг) здесь упускают?
PS: Хотя это и не окажет никакого эффекта, чтобы подтвердить, я также попробовал обернуть DataSet
в DataLoader
от torch, и он все равно ведет себя так же.
Я думаю, проблема в том, что ваш TempDataset
не реализует __iter__
или __next__
, и Dataset
также не делает этого. Лично я не уверен, что оно вообще выдает в качестве итератора в этом случае, но, полагаю, это не должно работать. Вам следует создать подкласс IterableDataset
, если вы хотите использовать его таким образом, и реализовать __iter__
.
Во-первых
Я хотел бы, чтобы вы уточнили, для чего нужен этот код? Зачем вы используете скользящие окна и вызываете набор данных для какой цели, так как это явно не ясно из заданного вами вопроса.
Во-вторых :-
Вы не вызвали index
в цикле for в конце кода
, и интерпретатор не может распаковать так много тензоров. Я предлагаю использовать enumerate()
, что поможет вам достичь желаемого результата.
Если вы не используете enumerate, тогда будет следующая ошибка:-
Поэтому, если вы уделите больше внимания этому:
for index,(x, y)in enumerate(ds):
print(f'Индекс:{index}, форма : {x.shape}, форма y:{y} ')
counter += 1
print('counter: ', counter)
Ответ или решение
Вопрос о том, почему метод __getitem__()
вашего класса TempDataset
вызывается с индексом, превышающим значение, возвращаемое методом __len__()
, является очень интересным и может не сразу быть очевиден. Давайте проанализируем детали.
-
Объяснение проблемы:
Вы создали классTempDataset
, который наследуется отtorch.utils.data.Dataset
. В вашем методе__len__()
возвращается значение, равное 4141, что корректно отражает количество подокон. Тем не менее, возникает ситуация, когда метод__getitem__()
вызывается с индексом 4141. Это связано с особенностями итерации по объектам классаDataset
, когда вы используете его в циклеfor
. -
Итерация по объекту Dataset:
Когда вы используетеfor x, y in ds:
, Python пытается получить итератор для вашего объектаds
. Построение итерации происходит через метод__iter__()
, который по умолчанию не определен в классеDataset
и, следовательно, подразумевает вызовrange(len(ds))
. Когда Python вызывает метод__getitem__()
, он не обрабатывает условиеindex < len(ds)
и может попытаться получить доступ к элементу, находящемуся за пределами границ. -
Функция, возвращающая значения для индексации:
Ваша проверка внутри__getitem__()
для сравнения индекса сlen
не срабатывает, так как Python не ожидает, что вы обработаете случаи выхода за пределы массива. Следовательно, ваш код видит это как попытку доступа к элементу. -
Рекомендации по решению проблемы:
Для избежания такой ситуации рекомендуется явно обрабатывать индекс в методе__getitem__()
. Вы можете добавить это условие:def __getitem__(self, index): if index >= self.len: raise IndexError("Index out of bounds") return self.x[index: index + self.window], self.y[index + self.window - 1]
Кроме того, рассмотрите возможность реализации метода
__iter__()
для вашего класса, чтобы управлять итерацией более очевидным образом. -
Использование enumerate:
Рекомендация использоватьenumerate()
для итерации также имеет смысл. Это не только поможет избежать путаницы с индексированием, но и улучшит производительность кода, так как будет более явно видно, как индексы соответствуют вашим данным. Правильный код будет выглядеть следующим образом:for index, (x, y) in enumerate(ds): print(f'Index: {index}, x shape: {x.shape}, y shape: {y.shape}')
-
Исходные данные и контекст:
Перед использованием такого подхода необходимо убедиться, что ваши данные действительно содержат необходимую информацию, как это подразумевается в проекте. Концепция скользящего окна, которую вы используете, предполагает, что вы обрабатываете временные ряды или подобные данные.
Таким образом, если вы хотите избежать ошибок доступа к элементам за пределами массива, добавление обработки индекса в метод __getitem__()
и использование итератора с enumerate()
– это правильные шаги к эффективной отладке вашего кода.