Как визуализировать тестовые выборки классификатора логической регрессии

Всем привет! Работаю с iris dataset.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split, cross_validate
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix



iris = datasets.load_iris()
X = iris.data
y = iris.target
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
model = LogisticRegression()
model.fit(x_train, y_train)
predictions = model.predict(x_test)

Я обучил выборку, нашел предсказанные значения, но не совсем понимаю, как получить графики тестовой выборки, а именно с исходными метками, а также с метками, полученными при классификации.


Ответы (1 шт):

Автор решения: CrazyElf

Ну вот например так. Хотя у меня цвета немного не согласованы, но вы дальше сами копайтесь уже:

import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import seaborn as sns

data = load_iris(as_frame=True)
X = data['data']
y = data['target']
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) #, shuffle=True)
model = LogisticRegression(max_iter=200)
model.fit(x_train, y_train)
predictions = model.predict(x_train)
df = x_train.copy()
y2name = dict(enumerate(data['target_names']))
df['y'] = y_train.map(y2name)
df['preds'] = pd.Series(predictions).map(y2name).values

sns.scatterplot(x='petal length (cm)', y='petal width (cm)', data=df, hue='y', alpha=0.5)
plt.title('Настоящие значения')

plt.figure()

sns.scatterplot(x='petal length (cm)', y='petal width (cm)', data=df, hue='y', alpha=0.1, legend=False)
sns.scatterplot(x='petal length (cm)', y='petal width (cm)', data=df[df['y']!=df['preds']], hue='preds')
plt.title('Ошибочные предсказания')

введите сюда описание изображения

введите сюда описание изображения

→ Ссылка