Слабая модель нейронной сети

у меня было проектное задание по учебе, создать нейронную сеть на любую тему. Выбрал я датасеты по предсказанию сердечных заболеваний. Дата сет вроде хороший. Модель я написал, но не уверен в её качестве, из-за низкого процента точности. Проблема в том что я не знаю как его улучшить, то ли опыта не хватает, то ли просто мозги не работают. Прошу помощи у знающих людей

from keras.models import Sequential 
from keras.layers import Dense 
import pandas as pd 
import keras 
import matplotlib.pyplot as plt 
# from scikeras.wrappers import KerasClassifier 
# from sklearn.model_selection import cross_val_score 
from sklearn.model_selection import train_test_split 
from sklearn.preprocessing import MinMaxScaler 
 
data = pd.read_csv("heart_attack_prediction_dataset.csv") 
input_names = ['Age', 'Sex', 'Cholesterol', 'Heart_Rate', 'Diabetes', 'Family_History', 'Smoking', 'Obesity', 
               'Alcohol_Consumption', 'Previous_Heart_Problems'] 
output_names = ['Heart_Attack_Risk'] 
 
encoders = {'Male': 0, 
            'Female': 1} 
 
X = data[input_names].replace(encoders).values 
Y = data[output_names].values 
 
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42) 
 
scaler = MinMaxScaler() 
x_train = scaler.fit_transform(x_train) 
x_test = scaler.transform(x_test) 
 
model = Sequential() 
model.add(Dense(10, input_dim=10, activation='relu')) 
model.add(Dense(20, activation='relu')) 
model.add(Dense(10, activation='relu')) 
model.add(Dense(7, activation='relu')) 
model.add(Dense(1, activation='sigmoid')) 
 
optimizer = keras.optimizers.Adam(learning_rate=0.0001) 
model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=['accuracy']) 
 
fit_results = model.fit(x_train, y_train, epochs=250, batch_size=32, validation_data=(x_test, y_test)) 
 
plt.title('Losses train/validation') 
plt.plot(fit_results.history['loss'], label='Train') 
plt.plot(fit_results.history['val_loss'], label='Validation') 
plt.legend() 
plt.show() 
 
plt.title('Accuracies train/validation') 
plt.plot(fit_results.history['accuracy'], label='Train') 
plt.plot(fit_results.history['val_accuracy'], label='Validation') 
plt.legend() 
plt.show() 
 
scores = model.evaluate(x_test, y_test) 
print("\n%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))

введите сюда описание изображения введите сюда описание изображения введите сюда описание изображения

Датасет: https://disk.yandex.ru/d/IfvyQawL2QSbeg Размер датасета - 8764 записи и 26 атрибутов из которых я беру только 10 входных и 1 таргет


Ответы (1 шт):

Автор решения: CrazyElf

Так, ну я посмотрел, в этих данных просто нет сигнала. И если смотреть метрику не accuracy, а AUC, то она будет около 0.5, то есть это на уровне случайного угадывания. Вот вам корреляции ваших данных с целевой переменной:

import seaborn as sns

df = data[input_names+output_names].replace(encoders).corr()[output_names].reset_index()
sns.barplot(x='Heart Attack Risk', y='index', data=df)
plt.plot((0.02, 0.02), (-1, 9), '--', color='gray')
plt.text(0.03, -0.5, "0.02 (2%)")
plt.plot((-0.02, -0.02), (-1, 11), '--', color='gray')
plt.text(-0.01, 11, "-0.02 (-2%)")
plt.title('Корреляции c Heart Attack Risk')

введите сюда описание изображения

Корреляция Heart Attack Risk самой с собой равная 1 оставлена тут для понимания мизерности масштаба остальных корреляций.

Наблюдаемая корреляция фич с таргетом в диапазоне +/-2%, это вообще ни о чём и укладывается в статпогрешность. Либо в этих данных вообще нет сигнала, либо нужно исследовать остальные фичи (которые вы отбросили) и пытаться что-то из них вытащить.

Именно поэтому "know your data" - сначала изучите ваши данные, а потом уже стройте модели. Иначе вы не будете понимать, что вообще происходит. Ну и, конечно же, "garbage in - garbage out", если у вас непонятно что на входе, то и на выходе непонятно что. Сначала данные нужно изучить и причесать. Но может оказаться, что и после всех усилий сигнала в данных не обнаружится. И тут никакая нейросеть не поможет. Если сигнала нет, то нейросети его тоже неоткуда будет взять.

Также может оказаться, что датасет пришёл к вам в каком-то искажённом виде, данные испорчены каким-то образом и сигнала нет поэтому. Это тоже нужно выяснять. И такое бывает.

P.S. Вообще единственная сильная корреляция в данном датасете на числовых данных - это между Age и Smoking - аж 39%, это очень странно. )) Текстовые и прочие сложные данные я не смотрел пока, может там есть какой сигнал, но тоже не факт.

P.P.S. Скормил этот датасет в catboost, даже он с обработкой текстовых фич как категориальных дал AUC только 0.52, то есть даже при использовании всех признаков всё плохо. Может из фич и можно что-то вытащить, если как-то хитро их обработать, но вероятность этого очень мала на мой взгляд.

→ Ссылка