#!/usr/bin/env python3
Decision Tree Regression
ram 가격 데이터를 이용해 Decision Tree Regression을 알아보겠습니다.
먼저 ram 가격 데이터를 시각화하여 구조를 파악해보면
import pandas as pd
import os
import matplotlib.pyplot as plt
import mglearn
import numpy as np
path = os.path.join(mglearn.datasets.DATA_PATH, 'ram_price.csv')
ram_prices = pd.read_csv(path)
ram_prices = ram_prices.iloc[:, 1:] # iloc: 정수로 columns접근, index number : 몇번째 인덱스 까지의 형태
plt.semilogy(ram_prices['date'], ram_prices['price'])
plt.xlabel('year', size=15)
plt.ylabel('price ($/Mbyte)', size=15)
plt.show()
log scale로 그린 램 가격 동향
시간이 지날수록 램 가격은 점점 감소하는 것을 알 수 있습니다.
DecisionTreeRegressor는 train set범위 밖의 데이터에 대해서는 예측을 할 수가 없는데
다음의 코드는 이 것을 확인할 수 있습니다.
from sklearn.tree import DecisionTreeRegressor
from sklearn.linear_model import LinearRegression
data_train = ram_prices[ram_prices['date'] < 2000]
data_test = ram_prices[ram_prices['date'] >= 2000]
x_train = data_train['date'][:, np.newaxis] # train data를 1열로 만듭니다.
y_train = np.log(data_train['price'])
tree = DecisionTreeRegressor().fit(x_train, y_train)
lr = LinearRegression().fit(x_train, y_train)
# test는 모든 데이터에 대해 적용합니다.
x_all = ram_prices['date'].reshape(-1, 1) # x_all를 1열로 만듭니다.
pred_tree = tree.predict(x_all)
price_tree = np.exp(pred_tree) # log값 되돌리기
pred_lr = lr.predict(x_all)
price_lr = np.exp(pred_lr) # log값 되돌리기
plt.semilogy(ram_prices['date'], price_tree, label='tree predict', ls='--', dashes=(2,1))
plt.semilogy(ram_prices['date'], price_lr, label='linear reg. predict', ls=':')
plt.semilogy(data_train['date'], data_train['price'], label='train data', alpha=0.4)
plt.semilogy(data_test['date'], data_test['price'], label='test data')
plt.legend(loc=1)
plt.xlabel('year', size=15)
plt.ylabel('price ($/Mbyte)', size=15)
plt.show()
ram 가격 데이터로 만든 linear model과 regression tree의 예측값 비교
linear model은 직선으로 데이터를 근사합니다.
tree model은 train set을 완벽하게 fitting했습니다. 이는 tree의 complexity를 지정하지 않아서 전체 데이터를 fitting 하기 때문입니다.
그러나 tree model은 train 데이터를 넘어가버리면 마지막 포인트를 이용해 예측하는 것이 전부입니다.
Decision Tree의 주요 단점은 pre-pruning를 함에도 불구하고 overfitting되는 경향이 있어 일반화 성능이 좋지 않습니다.
이에 대안으로 ensemble 방법을 많이 사용합니다.
참고 자료:
[1]Introduction to Machine Learning with Python, Sarah Guido
'python 머신러닝 -- 지도학습 > Regression' 카테고리의 다른 글
Lasso (0) | 2018.03.13 |
---|---|
Ridge (0) | 2018.03.12 |
LinearRegression (0) | 2018.03.12 |
k_NN(k-최근접 이웃) Regression (1) | 2018.03.12 |