tensorflow(三):简单神经网络实现手写体识别MNIST

    技术2025-11-26  18

    一、代码

    # Two-layer fully connected MNIST classifier written against the
    # TensorFlow 1.x graph/session API.
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data

    # Load the MNIST dataset from the "MNIST_data" directory, one-hot labels.
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

    # Mini-batch size.
    batch_size = 100
    # Number of full mini-batches per epoch (integer division drops any remainder).
    n_batch = mnist.train.num_examples // batch_size

    # Placeholders: 784-feature input vectors and 10-class one-hot labels.
    x = tf.placeholder(tf.float32, [None, 784])
    y = tf.placeholder(tf.float32, [None, 10])
    # Dropout keep probability. NOTE: declared but not used anywhere in the
    # graph below; kept so any feed_dict that supplies it still works.
    keep_prob = tf.placeholder(tf.float32)

    # Hidden layer: 784 -> 2000 units, ReLU activation.
    W_1 = tf.Variable(tf.truncated_normal([784, 2000], stddev=0.1))
    b_1 = tf.Variable(tf.zeros([2000]) + 0.1)
    L_1 = tf.nn.relu(tf.matmul(x, W_1) + b_1)

    # Output layer: 2000 -> 10 units.
    W_2 = tf.Variable(tf.truncated_normal([2000, 10], stddev=0.1))
    b_2 = tf.Variable(tf.zeros([10]) + 0.1)
    # Keep the unnormalized scores separate from the probabilities:
    # softmax_cross_entropy_with_logits applies softmax internally, so it must
    # be given raw logits. Passing the already-softmaxed `prediction` would
    # apply softmax twice, flattening the gradients and slowing training.
    logits = tf.matmul(L_1, W_2) + b_2
    prediction = tf.nn.softmax(logits)

    # Mean cross-entropy over the batch.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
    # Momentum optimizer: learning rate 0.2, momentum 0.9.
    train_step = tf.train.MomentumOptimizer(0.2, 0.9).minimize(loss)

    # Built after minimize() so variables created by the optimizer are also
    # covered by the initializer.
    init = tf.global_variables_initializer()

    # Per-example correctness as a boolean tensor; argmax along axis 1 gives
    # the index of the largest entry (predicted class vs. true class).
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
    # Accuracy = fraction of correct predictions.
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Limit this process to a third of GPU memory. The options must be handed
    # to the Session through a ConfigProto; otherwise they have no effect.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
    config = tf.ConfigProto(gpu_options=gpu_options)

    with tf.Session(config=config) as sess:
        sess.run(init)
        for epoch in range(50):
            for _ in range(n_batch):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
            # Evaluate on the full test set once per epoch.
            acc = sess.run(accuracy,
                           feed_dict={x: mnist.test.images, y: mnist.test.labels})
            print("Iter" + str(epoch) + ",Testing Accuracy " + str(acc))

    二、结果

     

    Iter0,Testing Accuracy 0.561
    Iter1,Testing Accuracy 0.7655
    Iter2,Testing Accuracy 0.8711
    Iter3,Testing Accuracy 0.8737
    Iter4,Testing Accuracy 0.9675
    Iter5,Testing Accuracy 0.9727
    Iter6,Testing Accuracy 0.9741
    Iter7,Testing Accuracy 0.9777
    Iter8,Testing Accuracy 0.9783
    Iter9,Testing Accuracy 0.9794
    Iter10,Testing Accuracy 0.98
    Iter11,Testing Accuracy 0.981
    Iter12,Testing Accuracy 0.9815
    Iter13,Testing Accuracy 0.9822
    Iter14,Testing Accuracy 0.9812
    Iter15,Testing Accuracy 0.9801
    Iter16,Testing Accuracy 0.9811
    Iter17,Testing Accuracy 0.9814
    Iter18,Testing Accuracy 0.9811
    Iter19,Testing Accuracy 0.9807
    Iter20,Testing Accuracy 0.9808
    Iter21,Testing Accuracy 0.9798
    Iter22,Testing Accuracy 0.981
    Iter23,Testing Accuracy 0.9809
    Iter24,Testing Accuracy 0.9811
    Iter25,Testing Accuracy 0.9812
    Iter26,Testing Accuracy 0.9818
    Iter27,Testing Accuracy 0.9821
    Iter28,Testing Accuracy 0.9814
    Iter29,Testing Accuracy 0.9812
    Iter30,Testing Accuracy 0.9817
    Iter31,Testing Accuracy 0.9817
    Iter32,Testing Accuracy 0.9816
    Iter33,Testing Accuracy 0.9811
    Iter34,Testing Accuracy 0.9815
    Iter35,Testing Accuracy 0.9819
    Iter36,Testing Accuracy 0.9821
    Iter37,Testing Accuracy 0.982
    Iter38,Testing Accuracy 0.9822
    Iter39,Testing Accuracy 0.9824
    Iter40,Testing Accuracy 0.9823
    Iter41,Testing Accuracy 0.982
    Iter42,Testing Accuracy 0.9823
    Iter43,Testing Accuracy 0.9826
    Iter44,Testing Accuracy 0.9826
    Iter45,Testing Accuracy 0.9824
    Iter46,Testing Accuracy 0.9825
    Iter47,Testing Accuracy 0.9824
    Iter48,Testing Accuracy 0.9825
    Iter49,Testing Accuracy 0.9823

     

    Processed: 0.021, SQL: 9