Pytorch DataSet.__getitem__() вызван с index, большим чем __len__()

Вопрос или проблема

У меня есть следующий набор данных torch (я заменил фактический код чтения данных из файлов на генерацию случайных чисел, чтобы сделать его минимально воспроизводимым):

from torch.utils.data import Dataset
import torch 

class TempDataset(Dataset):
    def __init__(self, window_size=200):
        
        self.window = window_size

        self.x = torch.randn(4340, 10, dtype=torch.float32) # None
        self.y = torch.randn(4340, 3, dtype=torch.float32) 

        self.len = len(self.x) - self.window + 1 # = 4340 - 200 + 1 = 4141 
                                                # Следовательно, индекс начала последнего окна = 4140 
                                                # А последнее окно будет включать элементы с 4140 до 4339, т.е. всего 200 элементов

    def __len__(self):
        return self.len

    def __getitem__(self, index):

        # Насколько я понимаю, приведенное ниже условие if НИКОГДА не должно быть истинным, так как последний индекс, с которым
        # вызывается __getitem__, должен быть self.len - 1
        if index == self.len: 
            print('self.__len__(): ', self.__len__())
            print('Попытка получить элемент по индексу: ', index)
            
    return self.x[index: index + self.window], self.y[index + self.window - 1]

ds = TempDataset(window_size=200)
print('len: ', len(ds))
counter = 0 # еще не считали записи
for x, y in ds:
    counter += 1 # предыдущая строка прочитала еще одну запись из набора данных
print('counter: ', counter)

Это выводит:

len:  4141
self.__len__():  4141
Попытка получить элемент по индексу:  4141
counter:  4141

Насколько я понимаю, __getitem__() вызывается с index в диапазоне от 0 до __len__()-1. Если это так, то почему он попытался вызвать __getitem__() с индексом 4141, когда длина данных составляет 4141?

Еще одна вещь, которую я заметил, это то, что, несмотря на вызов с index = 4141, он, похоже, не возвращает никаких элементов, из-за чего counter остается на 4141.

Что мои глаза (или мозг) здесь упускают?

PS: Хотя это и не окажет никакого эффекта, чтобы подтвердить, я также попробовал обернуть DataSet в DataLoader от torch, и он все равно ведет себя так же.

Я думаю, проблема в том, что ваш TempDataset не реализует __iter__ или __next__, и Dataset также не делает этого. Лично я не уверен, что оно вообще выдает в качестве итератора в этом случае, но, полагаю, это не должно работать. Вам следует создать подкласс IterableDataset, если вы хотите использовать его таким образом, и реализовать __iter__.

Во-первых

Я хотел бы, чтобы вы уточнили, для чего нужен этот код? Зачем вы используете скользящие окна и вызываете набор данных для какой цели, так как это явно не ясно из заданного вами вопроса.

Во-вторых :-

Вы не вызвали index в цикле for в конце кода, и интерпретатор не может распаковать так много тензоров. Я предлагаю использовать enumerate(), что поможет вам достичь желаемого результата.

Если вы не используете enumerate, тогда будет следующая ошибка:-

введите описание изображения здесь

Поэтому, если вы уделите больше внимания этому:

for index,(x, y)in enumerate(ds):
  print(f'Индекс:{index}, форма : {x.shape}, форма y:{y} ')
  counter += 1 
  print('counter: ', counter)

Вывод будет следующим :- ПОСЛЕ ИЗМЕНЕНИЯ

Ответ или решение

Вопрос о том, почему метод __getitem__() вашего класса TempDataset вызывается с индексом, превышающим значение, возвращаемое методом __len__(), является очень интересным и может не сразу быть очевиден. Давайте проанализируем детали.

  1. Объяснение проблемы:
    Вы создали класс TempDataset, который наследуется от torch.utils.data.Dataset. В вашем методе __len__() возвращается значение, равное 4141, что корректно отражает количество подокон. Тем не менее, возникает ситуация, когда метод __getitem__() вызывается с индексом 4141. Это связано с особенностями итерации по объектам класса Dataset, когда вы используете его в цикле for.

  2. Итерация по объекту Dataset:
    Когда вы используете for x, y in ds:, Python пытается получить итератор для вашего объекта ds. Построение итерации происходит через метод __iter__(), который по умолчанию не определен в классе Dataset и, следовательно, подразумевает вызов range(len(ds)). Когда Python вызывает метод __getitem__(), он не обрабатывает условие index < len(ds) и может попытаться получить доступ к элементу, находящемуся за пределами границ.

  3. Функция, возвращающая значения для индексации:
    Ваша проверка внутри __getitem__() для сравнения индекса с len не срабатывает, так как Python не ожидает, что вы обработаете случаи выхода за пределы массива. Следовательно, ваш код видит это как попытку доступа к элементу.

  4. Рекомендации по решению проблемы:
    Для избежания такой ситуации рекомендуется явно обрабатывать индекс в методе __getitem__(). Вы можете добавить это условие:

    def __getitem__(self, index):
       if index >= self.len:
           raise IndexError("Index out of bounds")
       return self.x[index: index + self.window], self.y[index + self.window - 1]

    Кроме того, рассмотрите возможность реализации метода __iter__() для вашего класса, чтобы управлять итерацией более очевидным образом.

  5. Использование enumerate:
    Рекомендация использовать enumerate() для итерации также имеет смысл. Это не только поможет избежать путаницы с индексированием, но и улучшит производительность кода, так как будет более явно видно, как индексы соответствуют вашим данным. Правильный код будет выглядеть следующим образом:

    for index, (x, y) in enumerate(ds):
       print(f'Index: {index}, x shape: {x.shape}, y shape: {y.shape}')
  6. Исходные данные и контекст:
    Перед использованием такого подхода необходимо убедиться, что ваши данные действительно содержат необходимую информацию, как это подразумевается в проекте. Концепция скользящего окна, которую вы используете, предполагает, что вы обрабатываете временные ряды или подобные данные.

Таким образом, если вы хотите избежать ошибок доступа к элементам за пределами массива, добавление обработки индекса в метод __getitem__() и использование итератора с enumerate() – это правильные шаги к эффективной отладке вашего кода.

Оцените материал
Добавить комментарий

Капча загружается...