#!/usr/bin/env python3
k_NN(k-최근접 이웃) Classifier[2]
k-NN algorithm은 가장 가까운 데이터 포인트중 n_neighbors만큼 이웃으로 삼아 예측으로 사용합니다.
다음은 k-NN algorithm을 보여줍니다.
import matplotlib.pyplot as plt
import mglearn
mglearn.plots.plot_knn_classification(n_neighbors=1)
plt.xlabel('feature 0')
plt.ylabel('feature 1')
plt.title('n_neighbors=1')
plt.legend(feature_list, loc=(1.01, 0.4))
plt.show()
n_neighbors=1 일 때 algorithm
mglearn.plots.plot_knn_classification(n_neighbors=3)
plt.xlabel('feature 0')
plt.ylabel('feature 1')
plt.title('n_neighbors=3')
plt.legend(feature_list, loc=(1.01, 0.4))
plt.show()
n_neighbors=1 일 때 algorithm
이번에는 n_neighbors의 숫자에 따른 결정 경계를 그려보겠습니다.
먼저 임의로 데이터셋을 만들어보겠습니다.
from mglearn.datasets import make_forge
x, y = make_forge()
_, axes = plt.subplots(1, 3)
for i, ax in zip([1, 3, 9], axes.ravel()):
knn = KNeighborsClassifier(i)
knn.fit(x, y)
mglearn.plots.plot_2d_separator(knn, x, fill=True,
eps=0.5, alpha=0.5, ax=ax)
mglearn.discrete_scatter(x[:, 0], x[:, 1],
y=y, ax=ax)
ax.set_title('k={}'.format(i))
ax.set_xlabel('feature 0')
ax.set_ylabel('feature 1')
axes[0].legend(loc=(0.01, 1.01))
plt.show()
n_neighbors 값에 따른 k-NN모델이 만든 decision boundary
이 그림을 보면 제일 왼쪽인 k=1일 경우에 모든 데이터를 가장 정확하게 분류했으며 그 경계는 복잡합니다.
k의 값이 커질수록 경계선은 점점 수평에 가까워지며 경계는 완만해짐을 알 수 있습니다.
실제wine 데이터를 분석해보겠습니다.
분석하기전에 wine데이터의 전체적인 구조를 pandas의 scatter_matrix를 이용해서
확인해보겠습니다. 이 작업은 데이터틔 outlier가 있는지 쉽게 확인할 수 있는 장점이 있습니다.
from sklearn.datasets import load_wine
import pandas as pd
wine = load_wine()
print(wine.keys())
# dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names'])
for i in range(len(wine.feature_names)):
print('{}:{}'.format(i, wine.feature_names[i]))
# 0:alcohol
# 1:malic_acid
# 2:ash
# 3:alcalinity_of_ash
# 4:magnesium
# 5:total_phenols
# 6:flavanoids
# 7:nonflavanoid_phenols
# 8:proanthocyanins
# 9:color_intensity
# 10:hue
# 11:od280/od315_of_diluted_wines
# 12:proline
wine_df = pd.DataFrame(wine.data, columns=range(len(wine.feature_names)))
pd.scatter_matrix(wine_df, # dataframe
c=wine.target, # color
hist_kwds={'bins':30}, # hist kwords
s=10, # size
alpha=0.5, # alpha
marker='o') # marker
plt.show()
wine의 scatter_plot
그래프가 복잡해 보일지 모르겠지만 전체적인 구조를 한 눈에 파악할 수 있습니다.
데이터를 파악했으니 k-NN 최근접 이웃 algorithm, n_neighbors=3으로 데이터를 분석해보겠습니다.
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test =\
train_test_split(wine.data, wine.target, stratify=wine.target, # stratify = 계층화
random_state=0, test_size=0.3)
knn = KNeighborsClassifier(n_neighbors=3, n_jobs=-1) # jobs = 사용할 코어의 수, -1 => 모든 코어
knn.fit(x_train, y_train)
print('{:.3f}'.format(knn.score(x_test, y_test)))
# 0.648
약 KNeighborsClassifier(n_neighbors=3)으로 와인의 품종을 테스트한 결과 약 70%의 정확도를 나타내는 것을
볼 수 있습니다.
이번엔 이 데이터로 모델의 복잡도와 일반화사이의 관계를 알아보겠습니다.
위의 임의로 만든 데이터로 알아본 결과 모델의 복잡도를 알아보았는데 실제 데이터셋으로 모델의 복잡도와 정확도를 알아보겠습니다.
from sklearn.datasets import load_wine
import matplotlib.pyplot as plt
wine = load_wine()
x_train, x_test, y_train, y_test = \
train_test_split(wine.data, wine.target, stratify=wine.target,
test_size=0.3, random_state=0)
train_list = []
test_list = []
n_range = range(1, 20)
for i in n_range:
knn = KNeighborsClassifier(n_neighbors=i, n_jobs=-1)
knn.fit(x_train, y_train)
tr_score = knn.score(x_train, y_train)
te_score = knn.score(x_test, y_test)
train_list.append(tr_score)
test_list.append(te_score)
plt.plot(n_range, train_list, color='red', ls='--', lw=2, label='train accuracy') # ls = linestyle, lw= linewidth
plt.plot(n_range, test_list, color='green', lw=2, label='test accuracy')
plt.xticks(n_range)
plt.xlabel('n_neighbors')
plt.ylabel('accuracy')
plt.legend()
plt.show()
n_neighbors 변화에 따른 훈련 정확도와 테스트 정확도
이웃의 수(n_neighbors)가 작을수록 모델이 복잡해지고 모델의 정확도는 100%입니다.
그러나 예측 정확도는 0.7정도인 것을 알수있으며
이웃의 수를 증가시켜도 예측 정확도는 많이 늘어나지는 않습니다.
따라서 이 데이터에는 k-NN 최근접 algorithm이 적합하지 않음을 알 수 있습니다.
참고 자료:
[1]Introduction to Machine Learning with Python, Sarah Guido
'python 머신러닝 -- 지도학습 > Classifier' 카테고리의 다른 글
Random Forest (0) | 2018.03.15 |
---|---|
Decision Tree (0) | 2018.03.14 |
Multi Linear Classification (0) | 2018.03.14 |
Logistic Regression (1) | 2018.03.13 |
k_NN(k-최근접 이웃) Classifier[1] (0) | 2018.03.11 |