Почему моя программа не работает для тензоров более высокой размерности?

Вопрос или проблема

Я пытаюсь написать модель в pytorch. У меня есть 3 класса A, B, C. Каждый класс — это список тензоров. Я хочу брать образцы из этих классов. У каждого класса есть также резервный класс, когда класс пуст, он берёт образец из него. Например, когда класс A становится пустым, моя модель берёт образец из резервного списка и маркирует его как A.
Я написал программу об этом. Проблема в том, что когда A, B и C — это списки тензоров одной размерности, например, A= [1,2,3,4], программа работает. Но когда мои входные данные имеют более высокую размерность, например, A= [torch.tensor([1,1]), torch.tensor([2,2]), torch.tensor([3,3]), torch.tensor([4,4])] и reserve_A= [torch.tensor([5,5]), torch.tensor([6,6]), torch.tensor([7,7]), torch.tensor([8,8])], возникает ошибка. Код для вызова из резервного списка и присвоения такой же метки, как A, выглядит следующим образом:

CLASS_CONFIG = {
    # Формат: {class_id: {'main': range, 'reserve': range}}
    0: {  # Класс A
        'main': [torch.tensor([1,1]), torch.tensor([2,2]), torch.tensor([3,3]), torch.tensor([4,4]), torch.tensor([5,5])],
        'reserve': [torch.tensor([16,16]), torch.tensor([17,17]), torch.tensor([18,18]), torch.tensor([19,19]), torch.tensor([20,20])]
    },
    1: {  # Класс B
        'main': [torch.tensor([6,6]), torch.tensor([7,7]), torch.tensor([8,8]), torch.tensor([9,9]), torch.tensor([10,10])],
        'reserve': [torch.tensor([21,21]), torch.tensor([22,22]), torch.tensor([23,23]), torch.tensor([24,24]), torch.tensor([25,25])]
    },
    2: {  # Класс C
        'main': [torch.tensor([11,11]), torch.tensor([12,12]), torch.tensor([13,13]), torch.tensor([14,14]), torch.tensor([15,15])],
        'reserve': [torch.tensor([26,26]), torch.tensor([27,27]), torch.tensor([28,28]), torch.tensor([29,29]), torch.tensor([30,30])]
    }
}

и затем:

    def get_class(student_num):
        """Определить исходный класс студента"""
        for cls, config in CLASS_CONFIG.items():
            if student_num in config['main'] or student_num in config['reserve']:
                return cls
        raise ValueError(f"Некорректный номер студента: {student_num}")

Но я получаю ошибку:

Cell In[1], line 159, in GroupFormer.get_class(student_num)
    157 """Определить исходный класс студента"""
    158 for cls, config in CLASS_CONFIG.items():
--> 159     if student_num in config['main'] or student_num in config['reserve']:
    160         return cls
    161 raise ValueError(f"Некорректный номер студента: {student_num}")

RuntimeError: Boolean value of Tensor with more than one value is ambiguous```

Это странно, потому что когда мои входные данные являются списком из 1*1 тензоров, я не получаю такую ошибку. Но на более высоком уровне возникает эта ошибка.
Кто-нибудь может помочь мне исправить эту ошибку? Любой совет будет оценён.

Ответ или решение

Теория

Ваша проблема возникла из-за того, как Python и PyTorch управляют объектами и сравнением векторов. Когда вы работаете с объектами PyTorch, такими как тензоры, они представляют более сложные данные, которые требуют специальных операций для сравнения. В случае одномерных тензоров (например, [1, 2, 3, 4]), вы работаете с простыми целыми числами, которые могут быть легко сравнены между собой с помощью стандартных операторов сравнения Python, таких как in. Однако, когда вы переходите к более высоким измерениям (например, torch.tensor([1,1])), эти операторы стандартного сравнения больше не работают так, как это было при сравнении скаляров.

Основная причиной ошибки, которую вы наблюдаете, является в том, как Python обрабатывает булевы значения более сложных объектов, таких как тензоры с несколькими элементами. Ошибка "Boolean value of Tensor with more than one value is ambiguous" указывает на то, что PyTorch не может произвести операцию сравнения на тензоре, содержащем более одного элемента, без явного указания, как именно следует проводить это сравнение (например, используя методы, такие как torch.all или torch.any).

Пример

Рассмотрим следующий фрагмент кода, который вызывает ошибку:

def get_class(student_num):
    """Determine student's original class"""
    for cls, config in CLASS_CONFIG.items():
        if student_num in config['main'] or student_num in config['reserve']:
            return cls
    raise ValueError(f"Invalid student number: {student_num}")

Здесь student_num является тензором, в то время как config['main'] и config['reserve'] — это списки тензоров. Использование оператора in с тензорами в качестве операндов приводит к вышеуказанной ошибке, так как Python не может понять, каким образом должен проводиться этот поиск.

Применение

Чтобы устранить эту ошибку, вы должны изменить метод сравнения, использованный в вашей функции. Вместо использования оператора in непосредственно с тензорами, необходимо использовать специальные функции из PyTorch или Python, которые могут корректно интерпретировать многомерные сравнения.

Вот один из возможных способов решения проблемы:

import torch

def tensors_are_equal(tensor1, tensor2):
    """Проверяет равенство двух тензоров."""
    return torch.equal(tensor1, tensor2)

def get_class(student_num):
    """Определяет оригинальный класс студента."""
    for cls, config in CLASS_CONFIG.items():
        for main_tensor in config['main']:
            if tensors_are_equal(student_num, main_tensor):
                return cls
        for reserve_tensor in config['reserve']:
            if tensors_are_equal(student_num, reserve_tensor):
                return cls
    raise ValueError(f"Invalid student number: {student_num}")

В этой версии кода введен вспомогательный метод tensors_are_equal, который использует torch.equal, чтобы проверить, идентичны ли два тензора. Теперь в случае поиска по спискам main и reserve сначала производится итерация по элементам каждого списка, и для проверки равенства используется вспомогательная функция. Это гарантирует корректное сравнение многомерных тензоров.

Кроме того, возможно, будет полезно провести оптимизацию, если количество примеров в списках велико, чтобы избежать излишних проверок. Это можно делать через предварительную обработку структуры данных, например, хранить элементы в виде хеш-таблицы, только если эффективность является критической в рамках более узкого сценария.

Используя это решение, вы обеспечите корректную работу программы вне зависимости от сложности ваших тензоров, а также устраните проблему, связанную с неоднозначным сравнением в PyTorch.

Оцените материал
Добавить комментарий

Капча загружается...