Использование ResNet50 с SE-блоком на несбалансированных данных – pytorch

Вопрос или проблема

Я работал с набором данных ультразвуковых изображений рака груди, содержащим 432 доброкачественных случаев, 210 злокачественных случаев и 133 нормальных случаев. Сначала я использовал предобученную модель ResNet-50, которая дала следующие результаты:

  • Оценки F1 для тестирования: Класс 0 (доброкачественный): 0.53, Класс 1 (злокачественный): 0.67, Класс 2 (нормальный): 0.80
  • Макро F1 оценка для тестирования: 0.6652
    введите описание изображения здесь
    Эти оценки указывают на умеренную эффективность, что понимаемо, учитывая относительно небольшой размер выборки, ограничивающий эффективность модели. Затем я наткнулся на статью, обсуждающую, как интеграция блока Squeeze-and-Excitation (SE) в ResNet-50 может улучшить его производительность, поэтому я решил попробовать это.
transform_my_train = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.RandomHorizontalFlip(),
        T.RandomVerticalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
transform_my=T.Compose([ T.Resize(256),
    T.CenterCrop(224), T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_df, temp_df = train_test_split(df, test_size=0.2, stratify=df['label'], random_state=42)  # 80% обучающие, 20% временные
val_df, test_df = train_test_split(temp_df, test_size=0.5, stratify=temp_df['label'], random_state=42)  # 50% валидация, 50% тест
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)
print("TRAIN Dataset: {}".format(train_df.shape))
print("VALIDATION Dataset: {}".format(val_df.shape))
print("TEST Dataset: {}".format(test_df.shape))
train_df_2=Dataset(images=train_df['file_path'].values,labels=train_df['label'].values,transform=transform_my_train)
val_df_2=Dataset(images=val_df['file_path'].values,labels=val_df['label'].values,transform=transform_my)
test_df_2=Dataset(images=test_df['file_path'].values,labels=test_df['label'].values,transform=transform_my)

train_params = {'batch_size': 4,
                'shuffle': True,
                'num_workers': 0,
                'pin_memory':True
                }

test_params = {'batch_size': 4,
                'shuffle': True,
                'num_workers': 0,
               'pin_memory':True
                }
Validation_params = {'batch_size': 4,
                'shuffle': True,
                'num_workers': 0,
                     'pin_memory':True
                }
training_loader = DataLoader(train_df_2, **train_params)
testing_loader = DataLoader(test_df_2, **test_params)
validation_loader = DataLoader(val_df_2, **Validation_params)

import timm

# Загрузка предобученной модели
model = timm.create_model('seresnet50', pretrained=True)
class resnet50_se(nn.Module):
  def __init__(self):
    super(resnet50_se, self).__init__()
    self.resnet =model
    self.dropout = torch.nn.Dropout(0.3)
    self.fc1=torch.nn.Linear(1000,3)
  def forward(self,x):
    x=self.resnet(x)
    x=F.relu(x)
    x=self.dropout(x)
    x = F.softmax(self.fc1(x), dim=1)
    return x
from torch.optim.lr_scheduler import CyclicLR
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
# Создание функции потерь и оптимизатора
loss_function = torch.nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(params =my_resnet50_se.parameters(), lr=0.01,weight_decay=1e-3 )
#sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, 0.001, epochs=50, steps_per_epoch=len(training_loader))
Epoch 11/50, Обучаемая потеря: 1.0832, Валидационная потеря: 1.0850
Обучаемые оценки F1: {0: 0.0, 1: 0.0, 2: 0.7186858316221766}, Валидационные оценки F1: {0: 0.0, 1: 0.0, 2: 0.7213114754098361}
Обучаемый макро F1: 0.2396, Валидационный макро F1: 0.2404 

введите описание изображения здесь
Как вы можете видеть, производительность модели значительно снизилась, особенно в ее способности обучаться на недостаточно представленых классах. Я был удивлён этим результатом, так как изначально ожидал, что добавление блока SE только улучшит модель. Может ли снижение производительности быть связано с несбалансированным набором данных? Поскольку изображений с доброкачественными образованиями больше, чем нормальных и злокачественных, возможно, что во время фазы возбуждения каналы, ответственные за меньшинство классов (нормальные и злокачественные), получили меньше внимания, что привело к худшему представительству для этих классов.

Ответ или решение

Использование ResNet50 с блоком SE на несбалансированных данных — PyTorch

Ваша ситуация с использованием модели ResNet50 с добавленным блоком Squeeze-and-Excitation (SE) на несбалансированном наборе данных изображений молочной железы является интересным примером вызовов, с которыми могут столкнуться исследователи и инженеры в области глубокого обучения. Давайте рассмотрим несколько ключевых аспектов влияния несбалансированности данных и возможные пути оптимизации вашего подхода.

1. Анализ текущих результатов

Вы упомянули, что первоначальные результаты вашего проекта, основанные на ResNet50, показали следующие значения F1 для классов:

  • Benign (0) : 0.53
  • Malignant (1) : 0.67
  • Normal (2) : 0.80

Эти результаты указывают на то, что модель показывает сравнительно неплохую точность для класса "нормальный", тогда как точность для менее представленных классов выше, чем 50% и 60%. Постепенно переход к интеграции блока SE привел к снижению производительности:

  • Обучение Macro F1 : 0.2396
  • Валидация Macro F1 : 0.2404

Данный спад в качестве указывает на необходимость более глубокого анализа и проработки модели.

2. Влияние несбалансированности данных

Ваши мысли о том, что несбалансированность данных может негативно повлиять на процесс обучения, абсолютно правы. Когда классы представлены неравномерно, модель может стать предвзятой к более представленным классам, игнорируя меньшие классы. Это может привести к тому, что каналы в блоке SE, представляющие малочисленные классы (нормальные и злокачественные), получают меньше внимания.

Предлагаемое решение:

  • Использование весов класса: Вы уже определили веса классов в вашей функции потерь CrossEntropyLoss, что является хорошей практикой. Регулируя веса, можно компенсировать влияние несбалансированности данных. Если вы еще не пробовали, рассмотрите возможность адаптивного изменения весов во время обучения.

3. Оптимизация архитектуры модели

Интеграция SE блока может быть не единственным методом улучшения итогов. Рекомендуется рассмотреть следующие практики:

  • Изменение структуры модели: Попробуйте модифицировать последний слой вашей сети. Например, вы можете увеличить размерность вывода последнего слоя до того, чтобы снизить вероятность потери информации для малочисленных классов.

  • Увеличение данных (Data Augmentation): Пополните набор данных, применяя такие техники, как аугментация размеров, повороты, сдвиги и добавление белого шума, чтобы сбалансировать набор данных и улучшить обобщающую способность модели.

4. Рекомендации по обучению

  • Изменение стратегии обучения: Вместо Adam, попробуйте другой оптимизатор, например, SGD с моментумом, который может оказаться более стойким к нестабильным данным.

  • Настройка гиперпараметров: Поскольку вы начали с довольно высокой скорости обучения (lr=0.01), попробуйте уменьшить её, возможно, до 0.001 или ниже, чтобы избежать резких колебаний в процессе обучения.

  • Использование различных методов валидации: Соучастие и K-кратная кросс-валидация могут оказать положительное влияние на предсказания моделей на малочисленных классах.

Заключение

Ваша работа с ResNet50 и блоком SE в контексте несбалансированных данных подчеркивает важные аспекты машинного обучения, требующие внимания. Исследование причин снижения производительности и поиск путей для улучшения модели — неотъемлемая часть работы с глубинными нейронными сетями. Регулярно тестируйте и адаптируйте свою структуру, весовые коэффициенты и методы оптимизации, чтобы повысить эффективность модели.

Для дальнейшего улучшения результатов, вероятно, придется комбинировать несколько стратегий, чтобы справиться с вызовами, связанными с несбалансированностью данных в вашей задаче. Успех в вашем исследовании!

Оцените материал
Добавить комментарий

Капча загружается...