10.mnist_dropout
In [1]:
from IPython.core.display import display, HTML
display(HTML("<style> .container{width:100% !important;}</style>"))
In [2]:
import tensorflow as tf
import warnings
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (12, 8)
In [3]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./mnist/data/", one_hot=True)
WARNING:tensorflow:From <ipython-input-3-4dcbd946c02b>:2: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ./mnist/data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ./mnist/data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From /anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting ./mnist/data/t10k-images-idx3-ubyte.gz
Extracting ./mnist/data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
In [4]:
global_step = tf.Variable(0, trainable=False, name="global_step")
X = tf.placeholder(tf.float32, shape=[None, 784], name="X") # None, 784
Y = tf.placeholder(tf.float32, shape=[None, 10], name="Y")

W1 = tf.Variable(tf.random_normal([784, 256], mean=0, stddev=0.01), name="W1")
W2 = tf.Variable(tf.random_normal([256, 256], mean=0, stddev=0.01), name="W2")
W3 = tf.Variable(tf.random_normal([256,  10], mean=0, stddev=0.011), name="W3")

b1 = tf.zeros([256], name="bias1")
b2 = tf.zeros([256], name="bias2")
b3 = tf.zeros([10],  name="bias3")


dropout

  • 학습시 전체 신경망 중 일부만 사용하도록 함 -> 과적합 방지
  • 시간이 오래걸리는 편

In [5]:
keep_prob = tf.placeholder(tf.float32)

with tf.name_scope("layer1"):
    L1 = tf.add(tf.matmul(X, W1), b1)
    L1 = tf.nn.relu(L1)
    L1 = tf.nn.dropout(L1, keep_prob)
    
with tf.name_scope("layer2"):
    L2 = tf.add(tf.matmul(L1, W2), b2)
    L2 = tf.nn.relu(L2)
    L2 = tf.nn.dropout(L2, keep_prob)
    
with tf.name_scope("layer3"):
    model = tf.add(tf.matmul(L2, W3), b3)
In [6]:
with tf.name_scope("optimizer"):
    cost = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y, logits=model))
    opt = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost, global_step=global_step)
    tf.summary.scalar("cost", cost)
In [7]:
init = tf.global_variables_initializer()
sess = tf.Session()
# saver = tf.train.Saver(tf.global_variables())
sess.run(init)

merged = tf.summary.merge_all()
writer = tf.summary.FileWriter("./logs/mnist_dropout", sess.graph)
In [8]:
batch_size = 50
total_batch = int(mnist.train.num_examples/batch_size)
cost_epoch = []
In [9]:
%%time
for epoch in range(30):
    total_cost = 0
    
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        
        _, cost_val = sess.run([opt, cost], feed_dict={X:batch_xs, Y:batch_ys, keep_prob: 0.8})
        total_cost += cost_val
        cost_epoch.append(total_cost)
        
        summary = sess.run(merged, feed_dict={X:batch_xs, Y:batch_ys, keep_prob: 0.8})
        writer.add_summary(summary, global_step=sess.run(global_step))
        
    print("epoch: {}, Avg.cost: {}".format(epoch+1, total_cost / total_batch))
epoch: 1, Avg.cost: 0.3532606958614832
epoch: 2, Avg.cost: 0.1427451760597019
epoch: 3, Avg.cost: 0.10226461782120168
epoch: 4, Avg.cost: 0.08699281872800467
epoch: 5, Avg.cost: 0.06855666186279533
epoch: 6, Avg.cost: 0.05921383855340537
epoch: 7, Avg.cost: 0.0536975436815357
epoch: 8, Avg.cost: 0.04580990582659565
epoch: 9, Avg.cost: 0.04084625356087186
epoch: 10, Avg.cost: 0.040573723167723404
epoch: 11, Avg.cost: 0.035842695366584604
epoch: 12, Avg.cost: 0.03263294939398871
epoch: 13, Avg.cost: 0.03360669748346316
epoch: 14, Avg.cost: 0.030501310914848794
epoch: 15, Avg.cost: 0.028370174647349235
epoch: 16, Avg.cost: 0.02699218331828392
epoch: 17, Avg.cost: 0.026614617982999005
epoch: 18, Avg.cost: 0.027732884158863685
epoch: 19, Avg.cost: 0.02651893331764189
epoch: 20, Avg.cost: 0.024510102366322662
epoch: 21, Avg.cost: 0.024103802091576653
epoch: 22, Avg.cost: 0.021529521410293455
epoch: 23, Avg.cost: 0.024205624244715927
epoch: 24, Avg.cost: 0.021746395409784947
epoch: 25, Avg.cost: 0.02059082699589949
epoch: 26, Avg.cost: 0.02283201359495644
epoch: 27, Avg.cost: 0.021406652638101233
epoch: 28, Avg.cost: 0.022226517286706812
epoch: 29, Avg.cost: 0.019306987923368567
epoch: 30, Avg.cost: 0.020735127189004873
CPU times: user 4min 42s, sys: 1min, total: 5min 43s
Wall time: 3min 23s
In [10]:
plt.plot(cost_epoch, "g")
plt.title("cost_epoch")
plt.show()
In [11]:
is_correct = tf.equal(tf.argmax(model, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))
print("accuracy: {}".format(sess.run(accuracy, feed_dict={X: mnist.test.images,
                                                        Y: mnist.test.labels,
                                                        keep_prob: 1})))
accuracy: 0.9835000038146973
In [12]:
### tensorboard graph

from IPython.display import clear_output, Image, display, HTML

def strip_consts(graph_def, max_const_size=32):
    """Strip large constant values from graph_def."""
    strip_def = tf.GraphDef()
    for n0 in graph_def.node:
        n = strip_def.node.add() 
        n.MergeFrom(n0)
        if n.op == 'Const':
            tensor = n.attr['value'].tensor
            size = len(tensor.tensor_content)
            if size > max_const_size:
                tensor.tensor_content = "<stripped %d bytes>"%size
    return strip_def

def show_graph(graph_def, max_const_size=32):
    """Visualize TensorFlow graph."""
    if hasattr(graph_def, 'as_graph_def'):
        graph_def = graph_def.as_graph_def()
    strip_def = strip_consts(graph_def, max_const_size=max_const_size)
    code = """
        <script>
          function load() {{
            document.getElementById("{id}").pbtxt = {data};
          }}
        </script>
        <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
        <div style="height:600px">
          <tf-graph-basic id="{id}"></tf-graph-basic>
        </div>
    """.format(data=repr(str(strip_def)), id='graph'+str(np.random.rand()))

    iframe = """
        <iframe seamless style="width:1200px;height:620px;border:0" srcdoc="{}"></iframe>
    """.format(code.replace('"', '&quot;'))
    display(HTML(iframe))
In [13]:
show_graph(tf.get_default_graph().as_graph_def())
In [14]:
# ### tensorboard

# def TB(cleanup=False):
#     import webbrowser
#     webbrowser.open('http://127.0.0.1:6006')

#     !tensorboard --logdir="./logs/mnist_dropout/"

# TB()

'Deep_Learning' 카테고리의 다른 글

12.mnist_cnn  (0) 2018.12.12
11.mnist_matplotlib_dropout_tensorgraph  (0) 2018.12.10
00.write_csv  (0) 2018.12.09
09.mnist_01_minibatch  (0) 2018.12.09
08.tensorboard03_example  (0) 2018.12.09

+ Recent posts