Моя сеть для классификации диалектов не работает

Вопрос или проблема

Я написал следующий код для классификации диалектов на основе набора данных TIMIT, используя .wav файлы. По какой-то причине моя модель не обучается и классифицирует все в один и тот же класс. Обязательно ли конвертировать в спектрограмму? Если да, то как это можно сделать? Может кто-то, пожалуйста, помочь мне понять, почему это происходит? Заранее спасибо.



import os
import torchaudio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random
import deeplake
import librosa
import numpy as np
import logging

not_all_files = True
# Настроить логирование
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

logger = logging.getLogger(__name__)

class VoiceClassifier(nn.Module):
    def __init__(self):
        super(VoiceClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=(1, 3), padding=(0, 1))
        self.conv2 = nn.Conv2d(20, 64, kernel_size=(1, 3), padding=(0, 1))
        self.fc1 = nn.Linear(64 * 50000, 64)
        self.fc2 = nn.Linear(64, 2)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        logger.info(f"x1.shape: {x.shape}")
        x = self.relu(self.conv1(x))
        logger.info(f"x2.shape: {x.shape}")
        x = self.relu(self.conv2(x))
        logger.info(f"x3.shape: {x.shape}")
        batch_size = x.shape[0]
        x = x.view(batch_size, -1)
        logger.info(f"x4.shape: {x.shape}")
        x = self.relu(self.fc1(x))
        logger.info(f"x5.shape: {x.shape}")
        x = self.fc2(x)
        logger.info(f"x6.shape: {x.shape}")
        x = self.softmax(x)
        logger.info(f"x7.shape: {x.shape}")
        return x

# Определить пользовательский набор данных
class TIMITDataset(Dataset):
    def __init__(self, root_dir, dialects, transform=None):
        self.root_dir = root_dir
        self.dialects = dialects
        self.transform = transform
        self.samples = []
        self.max_length = 50000

        # Случайно выбрать два диалекта
        selected_dialects = random.sample(dialects, 2)

        # Загрузить аудиофайлы и их метки
        for dialect, label in zip(selected_dialects, [0, 1]):
            dialect_dir = os.path.join(root_dir, dialect)
            for filename in os.listdir(dialect_dir):
                if not_all_files:
                    if filename.endswith(".WAV"):
                        file_path = os.path.join(dialect_dir, filename)
                        self.samples.append((file_path, label))
                else:
                    if os.path.isdir(os.path.join(dialect_dir, filename)):
                        sub_file_path = os.path.join(dialect_dir, filename)
                        for file in os.listdir(sub_file_path):
                            if file.endswith(".WAV"):
                                f_path = os.path.join(sub_file_path, file)
                                self.samples.append((f_path, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        file_path, label = self.samples[idx]
        waveform, sample_rate = librosa.load(file_path, sr=None)
        # Дополнить или обрезать волну до фиксированной длины
        if self.max_length is not None:
            waveform = self.pad_or_truncate(waveform, self.max_length)

        waveform = torch.from_numpy(waveform).unsqueeze(0)
        if self.transform:
            waveform = self.transform(waveform)
        return waveform, label

    def pad_or_truncate(self, waveform, max_length):
        length = waveform.shape[-1]
        if length > max_length:
            waveform = waveform[:max_length]
        else:
            padding = max_length - length
            offset = padding // 2
            waveform = np.pad(waveform, (offset, padding - offset), mode="constant")
        return waveform

# Путь к директории, где вы извлекли набор данных TIMIT
dataset_path = "/home/paul/Documents/Projects/voice_model/data"

# Список доступных диалектов в корпусе TIMIT
dialects = ["dr5", "dr6"]

# Создание набора данных и загрузчиков данных
train_dataset = TIMITDataset(os.path.join(dataset_path, "TRAIN"), dialects)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

test_dataset = TIMITDataset(os.path.join(dataset_path, "TEST"), dialects)
test_loader = DataLoader(test_dataset, shuffle=True)

# Инициализация модели, функции потерь и оптимизатора
model = VoiceClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Цикл обучения
num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs = inputs.unsqueeze(1)  # Добавить размерность канала
        optimizer.zero_grad()

        logger.info(f"len of dataset: {train_dataset.__len__()}")
        logger.info(f"input shape: {inputs.shape}")
        outputs = model(inputs)
        logger.info(f"output shape: {outputs.shape}")
        logger.info(f"output: {outputs}\n")

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Эпоха {epoch + 1}, Потеря: {running_loss / len(train_loader)}")

# Оценка
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.unsqueeze(1)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Точность: {100 * correct / total}%")

```

Ответ или решение

Ваш код на Python для классификации диалектов на основе датасета TIMIT имеет несколько критических моментов, которые могут объяснять, почему ваша модель не обучается и классифицирует всё в один класс. Давайте разберем возможные проблемы и пути их решения.

1. Форматирование входных данных

Да, преобразование ваших .wav файлов в спектрограммы (или другие формы представления сигналов) может значительно улучшить обучение модели. Модели глубокого обучения обычно лучше работают с изображениями или признаковыми картами, так как они могут выявлять локальные паттерны. Вот как вы можете преобразовать аудиофайлы в спектрограммы:

def get_spectrogram(waveform, n_fft=1024, win_length=None, hop_length=512):
    if win_length is None:
        win_length = n_fft
    spectrogram = torchaudio.transforms.Spectrogram(n_fft=n_fft, win_length=win_length, hop_length=hop_length)(waveform)
    return spectrogram

Ваша функция __getitem__ в классе TIMITDataset должна быть обновлена:

def __getitem__(self, idx):
    file_path, label = self.samples[idx]
    waveform, sample_rate = librosa.load(file_path, sr=None)
    if self.max_length is not None:
        waveform = self.pad_or_truncate(waveform, self.max_length)

    waveform = torch.from_numpy(waveform).unsqueeze(0)
    waveform = self.get_spectrogram(waveform)  # Преобразование в спектрограмму
    if self.transform:
        waveform = self.transform(waveform)
    return waveform, label

2. Изменение архитектуры модели

Архитектура вашей модели может быть не оптимальна для работы с аудиоданными. Стандартные архитектуры CNN могут не дать хороших результатов из-за особенностей обработки временных рядов. Попробуйте использовать 1D свёрточные слои:

self.conv1 = nn.Conv1d(1, 20, kernel_size=3)
self.conv2 = nn.Conv1d(20, 64, kernel_size=3)
# Измените размер входных данных fc1 соответственно
self.fc1 = nn.Linear(64 * (размер(после_свёртки)), 64)

3. Увеличение размерности данных

Если у вас малый объем данных (что свойственно аудиоданным), стоит рассмотреть возможность аугментации данных – например, изменяя скорость воспроизведения, добавляя шум, изменяя высоту тона и т.д. Это поможет улучшить обобщающую способность вашей модели.

4. Неправильная инициализация модели

Убедитесь, что веса вашей модели инициализированы корректно. Можно добавить процедуру инициализации весов:

def init_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)

model.apply(init_weights)

5. Настройка гиперпараметров

Обратите внимание на размер батча и скорость обучения. Попробуйте изменить скорость обучения или увеличить размер батча. Также убедитесь, что ваши метки корректно закодированы. Возможно, вам понадобится изменить последние слои вашей модели, если классов больше или меньше, чем два.

6. Отладка

Чтобы понять, что происходит в вашей модели, добавьте больше информации для отладки. Проверьте, как изменяется функция потерь и какова точность на валидационной выборке. Это может дать вам дополнительную информацию о проблемах с обучением.

Заключение

Преобразование аудиосигналов в спектрограммы, изменение архитектуры вашей модели, применение аугментации данных и адекватная настройка гиперпараметров могут существенно помочь в нахождении решения проблемы. Постепенно принимайте эти изменения и тестируйте их, чтобы определить, что влияет на производительность модели. Удачи!

Оцените материал
Добавить комментарий

Капча загружается...