Не удается подогнать ИНС к “простому” набору данных?

Вопрос или проблема

Мне действительно трудно смоделировать набор данных, который я получил, проводя эксперименты. Конкретно это временные ряды (онлайн) данных измерений, а целевые параметры – это кинетические параметры, которые я подгонял под каждый дискретный временной интервал (1 час) химической реакции. Я думал, что это будет довольно простая задача моделирования, поскольку публикации на такие темы существуют. Однако речь идет о другой химической реакции. Я делаю что-то не так? Итак, Var1 – это одно измерение. Var2 – это производная измерения Var1 в той же точке времени. Var3 – это другое измерение. А Var4 – производная Var3 в этой же точке времени. В следующем я предоставлю вам мой набор данных и код, который я использовал. Любая помощь будет очень и очень признательна.

Var1 Var2 Var3 Var4 Obj1 Obj2
362.5236061 -0.07435406 864530.2255 3682.064903 0.003213373 0.00428076
359.7557849 -0.025538678 736590.687 -1075.97495 0.003263493 0.001824691
358.4991647 -0.017450826 691352.2705 -588.4522644 0.003275051 0.000967166
357.6539991 -0.010385667 661583.842 -303.3271618 0.003279432 0.000558512
357.0186877 -0.008255159 640921.1745 -194.7817104 0.003281502 0.000336697
356.5327836 -0.005731424 626451.1975 -198.922334 0.003282606 0.000207903
356.1503876 -0.003917789 615852.341 -94.75128942 0.00328324 0.000130188
355.8705829 -0.004881314 608682.1075 -101.6350643 0.00328363 8.21961E-05
355.64497 -0.003110438 603802.061 -160.5375622 0.003283885 5.21238E-05
355.4639077 -0.001994076 600991.2295 -135.4780436 0.003283367 3.4352E-05
356.5001926 -0.17556667 19813.98608 -115.8315456 0.005954796 0.060327557
354.7419647 -0.010554267 18572.43765 -15.1086585 0.019552834 0.002063382
354.5410356 0.004333489 18395.97176 3.58709367 0.019797517 0.000207129
354.4705883 0.003261012 18328.458 0.260837045 0.019477993 0.000139907
354.4471981 -0.000602105 18251.73567 -5.069623077 0.01871281 0.000274773
354.3714636 -0.002022128 18129.30001 -7.523852436 0.015549203 0.000622012
354.3472383 -0.005895534 18006.22191 -4.295397602 0.007628973 0.001178864
354.3640378 -8.50794E-05 17989.67701 0.778188728 0.003754402 0.001628867
354.3513409 -0.002118818 17925.92188 -0.650045175 0.003502226 0.002346723
354.3589349 0.000630516 17827.10342 -7.597788125 0.006913801 0.005766069
358.5681349 -0.137364094 1417431.912 5475.904179 0.002386724 0.021356606
354.8548339 -0.025447599 1433291.357 -501.5377722 0.002640114 0.008975807
353.9699445 -0.006276057 1399737.379 -532.2281408 0.002680502 0.005460136
353.5810013 -0.004313607 1353541.897 -762.2193935 0.002694162 0.003776971
353.4113006 -0.002831483 1301753.769 -706.8966195 0.002700392 0.002793216
353.2983542 -0.001158295 1265871.981 -267.6625715 0.002703748 0.002151943
353.2363759 -0.000551958 1241390.818 -382.8823426 0.002705757 0.001704076
353.2065872 0.000162355 1221418.773 -296.6559633 0.002707054 0.001376099
353.1937995 0.001616808 1203285.87 -232.2787229 0.002707935 0.001127516
353.2084878 -0.001193965 1190199.653 -200.9625616 0.002708213 0.000935661
355.2607728 -0.06549112 428587.8465 -2711.557955 0.012031713 0.03116071
353.3381078 -0.009040949 384908.2786 -206.9562854 0.013995544 0.010595812
353.0920602 -0.00379961 373721.5884 -179.0413032 0.014326813 0.005225963
352.9546877 -6.51482E-05 362706.5069 -157.7725837 0.014457437 0.002872939
352.8822109 -0.000143343 351736.6721 -170.2497273 0.014639103 0.001475783
352.8337743 -0.001013735 340706.7499 -123.3086883 0.014273536 0.001056589
352.7950174 0.000967585 331108.5719 -150.6339713 0.012875985 0.001497312
352.7563879 -0.004171553 323477.3783 -105.4632373 0.009130093 0.002423721
352.7234116 -0.000953956 317607.4407 -61.80249755 0.005652496 0.003252655
352.6931665 -0.003292633 312764.6088 -72.34633083 0.007929598 0.006297833
354.9099373 -0.064489957 59228.19355 -34.66483122 0.001684243 0.036859361
354.2851724 -0.002153598 52131.66616 -51.55228823 0.002411635 0.009851665
354.0427482 -0.002159137 50261.48344 -42.90076455 0.002465327 0.00553448
353.9577172 0.002098824 48680.74202 -6.426667598 0.00248066 0.00368941
353.9119757 -0.003922988 47473.56546 -15.18383365 0.002487109 0.002667652
353.8844547 0.002037318 46285.15086 -16.07529822 0.002490416 0.002022766
353.8668704 0.002827529 45195.30073 -16.9351117 0.002492325 0.001582153
353.8736059 0.000766154 44332.88679 2.868984782 0.002493515 0.001264637
353.8714587 -0.000399668 43442.88279 -27.44262532 0.002494291 0.001026988
353.8811012 0.000340549 42578.14691 -14.50903174 0.002494568 0.000848156
360.3026948 -0.251951117 43670.5716 -660.0508025 0.009775108 0.005432146
355.0732977 -0.034455945 33782.5738 -51.43104459 0.010053779 0.010858642
353.9743005 -0.005542832 31973.23072 -18.1358307 0.013353613 0.039248918
353.5721478 -0.003969615 31006.18569 -14.10908244 0.002593539 0.069104028
353.3813316 -0.003120327 30241.61401 -10.49708263 0.002215485 0.00242661
353.2839077 0.001319001 29525.55267 -10.14276558 0.001530724 0.001549412
353.2021507 -0.000136787 28906.70627 -4.715726494 0.001462896 0.00147424
353.2302605 0.000632397 28429.48104 -9.911424404 0.00166997 0.001680189
353.1941322 -0.001467403 27921.07523 -7.234047644 0.002364115 0.002377429
353.1655174 -0.00474924 27495.34164 -5.85646953 0.00584086 0.005873753
362.8602921 -0.068471633 660881.783 -4939.335052 0.00689823 0.001411131
357.3865145 -0.068253677 502808.1665 -889.3749967 0.0069007 0.001716404
354.6802533 -0.029288028 477658.981 -168.2014051 0.006904063 0.002104605
353.5374195 -0.01293931 472322.0689 31.03798463 0.006908756 0.002607386
353.0292191 -0.00594159 470954.9579 -25.26509872 0.006915607 0.003274263
352.7288209 -0.004889982 471085.8684 15.12284176 0.006926112 0.004187288
352.5306925 -0.004590968 471796.343 10.16912988 0.006943328 0.005493162
352.4127263 -0.000399474 472833.1245 38.27497096 0.006974109 0.007482901
352.3050464 -0.002530674 474542.7729 8.638551421 0.007033087 0.010827543
352.2001627 -1.2048E-05 475979.5012 0.169241821 0.007224841 0.017501587
361.7847108 -0.009007496 1153715.769 -6689.862256 0.001742378 0.017626545
360.7925504 -0.010893197 1050694.028 -797.7488604 0.002226841 0.003056319
360.1930918 -0.008834233 1011533.294 -567.8726227 0.00224652 0.001467164
359.6481841 -0.007971995 977920.9475 -620.7601626 0.002251805 0.000840495
359.1329689 -0.007196525 948079.119 -399.1770228 0.002253953 0.00052013
358.6224796 -0.008396382 919272.018 -301.0694647 0.002255019 0.00033549
358.1804212 -0.007504499 894862.8745 -297.3461779 0.002255614 0.000221761
357.7788613 -0.005287814 872575.8815 -229.7678176 0.002255972 0.000148833
357.4293511 -0.006047173 853865.9935 -327.1636538 0.002256198 0.000100853
357.1290048 -0.0048727 838230.665 -192.8233643 0.002256036 6.99439E-05
365.3614792 0.002578826 44113.43769 -43.16217324 0.000398913 0.001770589
365.2331841 -2.34215E-05 40861.15028 -52.40644436 0.000399046 0.002032988
365.1497403 -0.0019704 37532.09081 -62.30731611 0.000399248 0.002377068
365.007455 -0.000260381 33650.23947 -70.26406207 0.000399571 0.002847905
364.8196841 -0.004124577 29738.58254 -59.4030107 0.000400135 0.003531402
364.5005972 -0.009202236 26285.55899 -51.59209241 0.00040126 0.004614346
364.0713779 -0.010897665 23471.97984 -48.61360046 0.000404049 0.006596842
363.5073076 -0.007957968 21186.9511 -29.74440468 0.000414847 0.011468264
362.8063233 -0.008804873 19258.75186 -23.06773732 0.001310172 0.11032268
362.1130332 -0.009603384 17884.01664 -16.23392788 0.002541599 0.014220414
362.4537589 0.004774992 17948.11168 86.82277459 0.002219716 0.000346881
360.4861797 -0.033102583 14929.18175 -30.84452397 0.002219825 0.000404354
358.6382289 -0.028423561 13268.46278 -27.72950817 0.00221997 0.000473142
357.2304768 -0.02438991 12194.22824 -18.15947438 0.00222015 0.000556153
356.4330087 -0.008169779 11495.86992 -8.621806364 0.00222038 0.000657343
356.0182099 -0.009758736 10957.41406 -10.03130092 0.002220691 0.000782214
355.846049 -0.002704126 10481.66979 -0.71428435 0.002221109 0.000938664
355.7989276 0.00182447 10151.41159 -3.584484876 0.002221684 0.001138447
355.7820492 -0.000605059 9810.787665 -3.940115176 0.002222513 0.001399867
355.6510044 -0.001372932 9489.771635 -7.047168089 0.002223796 0.001753097
365.1639843 -0.007325483 24666.60769 -167.4591472 0.004003787 0.000410637
363.4189555 -0.029042867 21137.35526 -44.1262589 0.004003784 0.00041147
361.7973844 -0.025640341 18928.67821 -30.06259729 0.004003782 0.000412307
360.1569707 -0.027864248 17150.04544 -31.9610189 0.004003793 0.000413147
358.6535614 -0.020867489 15610.98791 -19.18569009 0.004003785 0.000413987
357.4473904 -0.010541334 14416.2981 -14.20753164 0.004003786 0.000414829
356.6848728 -0.008416044 13557.59372 -10.36728612 0.004003787 0.000415673
356.3000018 0.001319542 12913.51496 -7.101190757 0.004003789 0.000416519
356.0608057 0.001852254 12399.57296 -6.193103147 0.004003784 0.000417367
355.922531 -0.010983914 11948.73856 -7.193361714 0.004003804 0.000418222

Я попробовал различные структуры моделей, и ничто, похоже, не помогает. Вот код, который я использовал для подгонки ANN:

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from copy import deepcopy
from scipy.integrate import odeint

# Установка случайного семени для воспроизводимости
rnd = np.random.RandomState(MAIN.rnd_seed)
torch.manual_seed(MAIN.rnd_seed)

# Загрузка данных и кинетических параметров
data = pd.read_csv("data.csv")

# Стандартизация данных
mean = np.mean(stacked_data, 0, keepdims=True)
stdv = np.std(stacked_data, 0, keepdims=True)

def standardize(data):
    return {k: (x - mean) / stdv for k, x in data.items()}

stan_data = standardize(data)

# Установка индексов для входа (X) и промежуточных (U)
x_index = ['Var1', 'Var2', 'Var3', 'Var4']
u_index = ['Obj1', 'Obj2']

# %% <Определение модели ANN>
class ANN(nn.Module):
    def __init__(self, hyprams):
        super().__init__()
        self.input_size = hyprams['input_size']
        self.hidden_size = hyprams['hidden_size']
        self.output_size = hyprams['output_size']
        self.activation = nn.Tanh()
        sizes = [self.input_size, *self.hidden_size, self.output_size]
        self.linears = nn.ModuleList([nn.Linear(i, j) for i, j in zip(sizes[:-1], sizes[1:])])

    def forward(self, x):
        x = x.float()
        for linear in self.linears[:-1]:
            x = linear(x)
            x = self.activation(x)
        x = self.linears[-1](x)
        return x * 2

    def init_params(self):
        for linear in self.linears:
            torch.nn.init.xavier_normal_(linear.weight.data)
            torch.nn.init.zeros_(linear.bias)

# %% <Определение тренера ANN>
class ANN_trainer:
    def __init__(self, hyprams, x_index, u_index):
        self.dtype = torch.float
        self.hyprams = hyprams
        self.x_index = x_index
        self.u_index = u_index

    def fit(self, data_train, data_valid):
        f = lambda x1: np.row_stack(np.transpose(x1, axes=(2, 0, 1)))
        g = lambda x2: np.row_stack([f(x1) for x1 in x2.values()])
        data_train = torch.from_numpy(g(data_train)).type(self.dtype)
        self.x_train = data_train[:, self.x_index]
        self.y_train = data_train[:, self.u_index]
        data_valid = torch.from_numpy(g(data_valid)).type(self.dtype)
        self.x_valid = data_valid[:, self.x_index]
        self.y_valid = data_valid[:, self.u_index]

        self.hyprams['input_size'] = len(self.x_index)
        self.hyprams['output_size'] = len(self.u_index)

        self.ANN = ANN(self.hyprams)
        self.ANN.init_params()

        self.optimizer = optim.Adam(self.ANN.parameters(), lr=self.hyprams['learning_rate'])
        self.loss_fn = nn.MSELoss()

        self.training_history = np.zeros((self.hyprams['epochs'], 3))
        for epoch in range(self.hyprams['epochs']):
            self.optimizer.zero_grad()
            y_train_pred = self.ANN(self.x_train)
            training_loss = self.loss_fn(self.y_train, y_train_pred)
            training_loss.backward()
            self.optimizer.step()

            y_valid_pred = self.ANN(self.x_valid)
            validation_loss = self.loss_fn(self.y_valid, y_valid_pred)
            self.training_history[epoch, 0] = epoch
            self.training_history[epoch, 1] = training_loss
            self.training_history[epoch, 2] = validation_loss

    def parity_plot(self, y_true, y_pred, title="Диаграмма соответствия", xlabel="Истинные значения", ylabel="Предсказанные значения"):
        plt.figure(figsize=(6, 6))
        plt.scatter(y_true, y_pred, edgecolors=(0, 0, 0))
        plt.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'k--', lw=2)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.grid(True)
        plt.show()

    def fit_aggregate(self, data_trainval):
        self.aggregate_ANN = []
        training_history = []
        for k in data_trainval.keys():
            data_train = {i: x for i, x in data_trainval.items() if i != k}
            data_valid = {k: data_trainval[k]}
            self.fit(data_train, data_valid)
            self.aggregate_ANN.append(deepcopy(self.ANN))
            training_history.append(self.training_history)
        self.training_history = np.mean(training_history, axis=0)

    def predict(self, x):
        y_aggregate = [ANN(x) for ANN in self.aggregate_ANN]
        y_aggregate = torch.dstack(y_aggregate)
        y_mean = torch.mean(y_aggregate, dim=2)
        y_std = torch.std(y_aggregate, dim=2)
        return y_mean, y_std

# %% <Определение гибридной модели>
class HybridModel(ANN_trainer):
    def __init__(self, hyprams, x_index, u_index):
        super().__init__(hyprams, x_index, u_index)

# %% <Обучение и тестирование гибридной модели>
hyprams = {'hidden_size': [6], 'learning_rate': 0.003, 'epochs': 2000}
train_split = MAIN.train_test_split['train']
test_split = MAIN.train_test_split['test']

data_train = {x: stan_data[x] for x in train_split}
data_test = {x: stan_data[x] for x in test_split}

hyb_model = HybridModel(hyprams, x_index, u_index)
hyb_model.fit_aggregate(data_train)

# Диаграмма соответствия
y_mean, y_std = hyb_model.predict(hyb_model.x_valid)
hyb_model.parity_plot(hyb_model.y_valid.detach().numpy(), y_mean.detach().numpy(), title="Итоговая диаграмма соответствия после предсказания")

print("Завершено")

```

Ответ или решение

Анализ проблемы: Невозможность подогнать ИНС к "простому" набору данных

Приветствую вас! Я вижу, что у вас возникли трудности с подгонкой искусственной нейронной сети (ИНС) к вашему набору данных, основанному на временн́ых рядах и кинетических параметрах. Давайте детально проанализируем вашу ситуацию и предоствим рекомендации.

1. Концепция и структура данных

Ваши данные включают четыре переменные (Var1, Var2, Var3 и Var4), которые являются измерениями и их производными, а также две целевые величины (Obj1 и Obj2). Это довольно стандартная структура данных для задач регрессии в контексте моделирования химических процессов, однако важно убедиться в следующих аспектах:

  • Качество данных: Проанализируйте набор данных на наличие выбросов и пропусков. Перед подгонкой модели убедитесь, что данные чистые и непрерывные.
  • Согласованность временных меток: Убедитесь, что временные ряды имеют одинаковую частоту и точно синхронизированы.

2. Стандартизация данных

Вы реализовали стандартизацию данных. Это важный шаг, который помогает улучшить производительность нейросети. Однако важно понимать, что не всегда общие статистические показатели, такие как среднее и стандартное отклонение, подходят для всех подмножеств данных:

  • Локальная стандартизация: Вместо применения глобальной нормализации можно рассмотреть возможность локальной стандартизации для каждого временного отрезка.

  • Сравнение распрделений: Проверьте распределение обучающих и тестовых наборов данных. Их различия могут негативно сказаться на обучении.

3. Структура нейронной сети

Обратить внимание на архитектуру вашей сети также очень важно. Вы используете единственный скрытый слой с 6 нейронами и функцию активации Тангент гиперболический. Есть несколько предложений по улучшению:

  • Увеличение скрытых слоев: Попробуйте добавить больше скрытых слоев или увеличить количество нейронов в существующем слое, поскольку модель с одним слоем может не быть достаточно мощной для захвата сложных зависимостей в данных.

  • Использование других функций активации: Попробуйте использовать ReLU или Leaky ReLU для избежания проблемы исчезающего градиента. Эти функции часто показывают лучшие результаты при обучении глубоких сетей.

4. Гиперпараметры и оптимизация

Ваши гиперпараметры, такие как скорость обучения (learning rate) и количество эпох, могут быть некорректно подобраны. Рекомендуется:

  • Использование поиска по сетке: Примените метод поиска по сетке для нахождения оптимальных значений гиперпараметров (например, learning_rate, hidden_size, epochs).

  • Регуляризация: Чтобы избежать переобучения, попробуйте добавить регуляризацию (например, Dropout или L2-регуляризацию) для улучшения обобщающей способности вашей модели.

5. Интерпретация результатов

Когда у вас уже есть модель, необходимо анализировать результаты:

  • Визуализация потерь: Постройте график потерь для обучающего и валидационного наборов данных, чтобы оценить прогресс обучения.

  • Паритетные графики: Испытайте паритетные графики, чтобы исследовать отклонения между предсказанными и истинными значениями. Это поможет выявить системные ошибки модели.

Заключение

Принимая во внимание все предложенные шаги, вы сможете улучшить подгонку вашей нейронной сети к "простым" данным. Помните, что проектирование модели является итеративным процессом, требующим множества экспериментов и анализов.

Успехов в ваших дальнейших экспериментах с регрессионной моделью! Если у вас возникнут дополнительные вопросы, не стесняйтесь обращаться.

Оцените материал
Добавить комментарий

Капча загружается...