09_mnist_01_minibatch
In [1]:
import tensorflow as tf
import warnings
warnings.filterwarnings("ignore")


Preparing the MNIST data

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./mnist/data/", one_hot=True)
WARNING:tensorflow:From <ipython-input-2-4dcbd946c02b>:2: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ./mnist/data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ./mnist/data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From /anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting ./mnist/data/t10k-images-idx3-ubyte.gz
Extracting ./mnist/data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


model setting

  • image = 28x28 pixels, flattened to 784 features
  • label = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 (one-hot encoded, 10 classes)
In [3]:
global_step = tf.Variable(0, trainable=False, name="global_step")
X = tf.placeholder(tf.float32, [None, 784], name="X")
Y = tf.placeholder(tf.float32, [None, 10], name="Y")

minibatch

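Training on several images at once works better than training on them one at a time, provided there are enough computing resources to support it.
So the data is usually split into chunks of a suitable size for training; each chunk is called a minibatch.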

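  • In the placeholder shape [None, 784], the first dimension is the number of images fed in at once, i.e. the minibatch size
  • It can be fixed to a specific size, but if the number of examples will vary between runs, setting it to None lets TensorFlow infer it from the data that is fed in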
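
To make the role of None concrete, here is a minimal pure-Python sketch of the shape check. The helper name fits_placeholder is hypothetical and is not part of TensorFlow; it only mimics the idea that a None dimension accepts any size while fixed dimensions must match exactly.

```python
def fits_placeholder(spec, shape):
    """Return True if a concrete shape is compatible with a placeholder spec.

    A None entry in spec acts as a wildcard for that dimension.
    (Hypothetical helper for illustration, not a TensorFlow function.)
    """
    if len(spec) != len(shape):
        return False
    return all(s is None or s == d for s, d in zip(spec, shape))


spec = [None, 784]  # same shape as the X placeholder

print(fits_placeholder(spec, (100, 784)))     # a training minibatch of 100 images
print(fits_placeholder(spec, (10000, 784)))   # the whole test set in one feed
print(fits_placeholder(spec, (100, 28, 28)))  # unflattened images do not fit
```

This is why the same X placeholder can take 100-image batches during training and the full test set during evaluation.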

784 (input, number of features) -> 256 (first hidden layer) -> 256 (second hidden layer) -> 10 (output, one per class 0-9)
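
As a rough check on the size of this network, the sketch below counts the entries in each weight matrix and bias vector implied by the layer sizes above. It is plain arithmetic, independent of TensorFlow, and the variable names are chosen only for this illustration.

```python
layer_sizes = [784, 256, 256, 10]  # input -> hidden 1 -> hidden 2 -> output

total = 0
for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
    weights = n_in * n_out  # entries in a weight matrix of shape [n_in, n_out]
    biases = n_out          # one bias per output unit
    total += weights + biases
    print("{} -> {}: {} weights + {} biases".format(n_in, n_out, weights, biases))

print("total:", total)
```

Most of the values sit in the first layer, since it maps all 784 pixels to 256 hidden units.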

In [4]:
W1 = tf.Variable(tf.random_normal([784, 256], mean=0, stddev=1), name="var1")
W2 = tf.Variable(tf.random_normal([256, 256], mean=0, stddev=1), name="var2")
W3 = tf.Variable(tf.random_normal([256, 10],  mean=0, stddev=1), name="var3")

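# Biases are wrapped in tf.Variable so the optimizer can update them;
# a bare tf.zeros tensor is a constant and would never be trained.
b1 = tf.Variable(tf.zeros([256]), name="bias1")
b2 = tf.Variable(tf.zeros([256]), name="bias2")
b3 = tf.Variable(tf.zeros([10]),  name="bias3")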
In [5]:
with tf.name_scope("layer1"):
    L1 = tf.add(tf.matmul(X, W1), b1)
    L1 = tf.nn.relu(L1)

with tf.name_scope("layer2"):
    L2 = tf.add(tf.matmul(L1, W2), b2)
    L2 = tf.nn.relu(L2)
    
with tf.name_scope("layer3"):
    model = tf.add(tf.matmul(L2, W3), b3)    
In [6]:
with tf.name_scope("opt"):
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y, logits=model))
    optimizer = tf.train.AdamOptimizer(learning_rate=0.01).minimize(cost, global_step=global_step)
    
    tf.summary.scalar("cost", cost)
In [7]:
init = tf.global_variables_initializer()
sess = tf.Session()
saver = tf.train.Saver(tf.global_variables())
sess.run(init)


merged = tf.summary.merge_all()
writer = tf.summary.FileWriter("./logs/mnist_01", sess.graph)

batch_size = 100
total_batch = int(mnist.train.num_examples / batch_size)

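MNIST contains tens of thousands of examples, so training uses minibatches.

  • The minibatch size is set to 100
  • mnist.train.num_examples is divided by the batch size, and the result (the total number of minibatches) is stored in total_batch

Training over the full MNIST training set is then repeated 15 times.
One full pass over the training data is called an epoch.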
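
The bookkeeping above can be sketched in plain Python. The figure of 55,000 training examples is an assumption (it is the usual size of mnist.train with the default validation split), and iter_minibatches is a hypothetical helper: mnist.train.next_batch does its own slicing and shuffling, so this only illustrates how batches tile one epoch.

```python
num_examples = 55000  # assumed size of mnist.train with the default split
batch_size = 100
num_epochs = 15

total_batch = num_examples // batch_size  # minibatches per epoch
total_steps = total_batch * num_epochs    # optimizer updates over the whole run


def iter_minibatches(n, size):
    """Yield (start, stop) index pairs covering range(n) in chunks of `size`."""
    for start in range(0, n, size):
        yield start, min(start + size, n)


batches = list(iter_minibatches(num_examples, batch_size))
print(total_batch, total_steps, batches[0], batches[-1])
```

Under that assumption each epoch consists of 550 updates, so global_step ends at 8250 after 15 epochs.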

In [8]:
for epoch in range(15):
    total_cost = 0
    
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        
        # Fetch the summary in the same run as the update to avoid a second forward pass
        _, cost_val, summary = sess.run([optimizer, cost, merged], feed_dict={X:batch_xs, Y:batch_ys})
        total_cost += cost_val
        
        writer.add_summary(summary, global_step=sess.run(global_step))
        
    print("Epoch: {}, Avg.cost = {:.3f}".format(epoch+1, total_cost / total_batch))
Epoch: 1, Avg.cost = 53.018
Epoch: 2, Avg.cost = 8.887
Epoch: 3, Avg.cost = 4.864
Epoch: 4, Avg.cost = 3.215
Epoch: 5, Avg.cost = 2.577
Epoch: 6, Avg.cost = 2.208
Epoch: 7, Avg.cost = 2.240
Epoch: 8, Avg.cost = 1.909
Epoch: 9, Avg.cost = 1.490
Epoch: 10, Avg.cost = 1.574
Epoch: 11, Avg.cost = 1.440
Epoch: 12, Avg.cost = 1.100
Epoch: 13, Avg.cost = 0.913
Epoch: 14, Avg.cost = 1.022
Epoch: 15, Avg.cost = 0.742
In [9]:
# import os

# if not os.path.isdir("./model/mnist_01"):
#     os.mkdir("./model/mnist_01")

# saver.save(sess, "./model/mnist_01/dnn/ckpt", global_step=global_step)
# print("optimize complete!")
In [10]:
is_correct = tf.equal(tf.argmax(model, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))
print("accuracy: {:.4f}".format(sess.run(accuracy, feed_dict={X:mnist.test.images, 
                                                     Y:mnist.test.labels})))
accuracy: 0.9612
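
The accuracy cell compares the index of the largest logit with the index of the 1 in the one-hot label, then averages the matches. A minimal pure-Python sketch of the same idea follows; the data is made up and uses 3 classes instead of 10, and the function names are chosen only for this illustration.

```python
def argmax(values):
    """Index of the largest entry (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(logits, one_hot_labels):
    """Fraction of rows whose predicted class equals the labelled class."""
    hits = [argmax(p) == argmax(y) for p, y in zip(logits, one_hot_labels)]
    return sum(hits) / len(hits)


# Three made-up examples with 3 classes each (the notebook uses 10)
logits = [[2.0, 0.1, -1.0], [0.3, 0.2, 0.9], [1.5, 1.7, 0.0]]
labels = [[1, 0, 0], [0, 1, 0], [0, 1, 0]]
print(accuracy(logits, labels))
```

Softmax is not needed here because it preserves the ordering of the logits, so the argmax is the same either way.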
In [11]:
from IPython.display import display, HTML
display(HTML("<style> .container{width:100% !important;}</style>"))
