Слабая модель нейронной сети
у меня было проектное задание по учебе, создать нейронную сеть на любую тему. Выбрал я датасеты по предсказанию сердечных заболеваний. Дата сет вроде хороший. Модель я написал, но не уверен в её качестве, из-за низкого процента точности. Проблема в том что я не знаю как его улучшить, то ли опыта не хватает, то ли просто мозги не работают. Прошу помощи у знающих людей
from keras.models import Sequential
from keras.layers import Dense
import pandas as pd
import keras
import matplotlib.pyplot as plt
# from scikeras.wrappers import KerasClassifier
# from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
data = pd.read_csv("heart_attack_prediction_dataset.csv")
input_names = ['Age', 'Sex', 'Cholesterol', 'Heart_Rate', 'Diabetes', 'Family_History', 'Smoking', 'Obesity',
'Alcohol_Consumption', 'Previous_Heart_Problems']
output_names = ['Heart_Attack_Risk']
encoders = {'Male': 0,
'Female': 1}
X = data[input_names].replace(encoders).values
Y = data[output_names].values
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)
scaler = MinMaxScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.transform(x_test)
model = Sequential()
model.add(Dense(10, input_dim=10, activation='relu'))
model.add(Dense(20, activation='relu'))
model.add(Dense(10, activation='relu'))
model.add(Dense(7, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
optimizer = keras.optimizers.Adam(learning_rate=0.0001)
model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=['accuracy'])
fit_results = model.fit(x_train, y_train, epochs=250, batch_size=32, validation_data=(x_test, y_test))
plt.title('Losses train/validation')
plt.plot(fit_results.history['loss'], label='Train')
plt.plot(fit_results.history['val_loss'], label='Validation')
plt.legend()
plt.show()
plt.title('Accuracies train/validation')
plt.plot(fit_results.history['accuracy'], label='Train')
plt.plot(fit_results.history['val_accuracy'], label='Validation')
plt.legend()
plt.show()
scores = model.evaluate(x_test, y_test)
print("\n%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
Датасет: https://disk.yandex.ru/d/IfvyQawL2QSbeg Размер датасета - 8764 записи и 26 атрибутов из которых я беру только 10 входных и 1 таргет
Ответы (1 шт):
Так, ну я посмотрел, в этих данных просто нет сигнала. И если смотреть метрику не accuracy, а AUC, то она будет около 0.5, то есть это на уровне случайного угадывания. Вот вам корреляции ваших данных с целевой переменной:
import seaborn as sns
df = data[input_names+output_names].replace(encoders).corr()[output_names].reset_index()
sns.barplot(x='Heart Attack Risk', y='index', data=df)
plt.plot((0.02, 0.02), (-1, 9), '--', color='gray')
plt.text(0.03, -0.5, "0.02 (2%)")
plt.plot((-0.02, -0.02), (-1, 11), '--', color='gray')
plt.text(-0.01, 11, "-0.02 (-2%)")
plt.title('Корреляции c Heart Attack Risk')
Корреляция Heart Attack Risk самой с собой равная 1 оставлена тут для понимания мизерности масштаба остальных корреляций.
Наблюдаемая корреляция фич с таргетом в диапазоне +/-2%, это вообще ни о чём и укладывается в статпогрешность. Либо в этих данных вообще нет сигнала, либо нужно исследовать остальные фичи (которые вы отбросили) и пытаться что-то из них вытащить.
Именно поэтому "know your data" - сначала изучите ваши данные, а потом уже стройте модели. Иначе вы не будете понимать, что вообще происходит. Ну и, конечно же, "garbage in - garbage out", если у вас непонятно что на входе, то и на выходе непонятно что. Сначала данные нужно изучить и причесать. Но может оказаться, что и после всех усилий сигнала в данных не обнаружится. И тут никакая нейросеть не поможет. Если сигнала нет, то нейросети его тоже неоткуда будет взять.
Также может оказаться, что датасет пришёл к вам в каком-то искажённом виде, данные испорчены каким-то образом и сигнала нет поэтому. Это тоже нужно выяснять. И такое бывает.
P.S. Вообще единственная сильная корреляция в данном датасете на числовых данных - это между Age и Smoking - аж 39%, это очень странно. )) Текстовые и прочие сложные данные я не смотрел пока, может там есть какой сигнал, но тоже не факт.
P.P.S. Скормил этот датасет в catboost, даже он с обработкой текстовых фич как категориальных дал AUC только 0.52, то есть даже при использовании всех признаков всё плохо. Может из фич и можно что-то вытащить, если как-то хитро их обработать, но вероятность этого очень мала на мой взгляд.



