深度学习Keras入门(3)--创建简单手写体识别模型

    技术2022-07-15  50

    我们简单介绍一下,如何使用Keras创建一个简单手写识别模型

    Step 1: 导入数据

            从keras的库中导入手写体识别数据集mnist模块

    from keras.datasets import mnist

    Step 2: 数据划分

    (train_images, train_labels),(test_images, test_labels) = mnist.load_data()

    Step 3: 探索数据         

    train_images.shape test_images.shape

    Step 4: 模型初始化

    from keras import models network = models.Sequential()

               Sequential()函数代表建立序列模型

    Step 5: 构建两个全连接层模型结构

    network.add(layers.Dense(512, activation = "relu", input_shape=(28*28,))) network.add(layers.Dense(10, activation = "softmax"))

             第一个全连接层512个神经元,输入对应一个图表的像素数量(28*28)

            第一个全连接层10个神经元和10个数据手写体对应

           relu和softmax分别为激活函数类型

    Step 6: 确定网络训练方式

    network.compile(optimizer = "rmsprop", loss = "categorical_crossentropy", metrics = ["accuracy"])

     optimizer确定优化方法,loss损失函数形式,metrics模型度量指标

    Step 6: 数据集与处理

    train_images = train_images.reshape(60000, 28*28) train_images = train_images.astype("float32")/255 test_images = test_images.reshape(10000, 28*28) test_images = test_images.astype("float32")/255 from keras.utils import to_categorical train_labels = to_categorical(train_labels) test_labels = to_categorical(test_labels)

    把数据集的三维矩阵转换为二维,并进行0-1规范化,0-255是像素原始值取值范围

    to_categorical()把类别标签由数值转换为类别型

    Step 7: 模型训练

    network.fit(train_images, train_labels, epochs = 5, batch_size =128) Epoch 1/5 60000/60000 [==============================] - 6s 99us/step - loss: 0.2570 - accuracy: 0.9254 Epoch 2/5 60000/60000 [==============================] - 6s 105us/step - loss: 0.1050 - accuracy: 0.9688 Epoch 3/5 60000/60000 [==============================] - 7s 109us/step - loss: 0.0686 - accuracy: 0.9795 Epoch 4/5 60000/60000 [==============================] - 6s 100us/step - loss: 0.0494 - accuracy: 0.9849 Epoch 5/5 60000/60000 [==============================] - 6s 98us/step - loss: 0.0368 - accuracy: 0.9890

        训练五轮精度98.90

    Step 8: 预测和验证

    test_loss, test_accuracy = network.evaluate(test_images, test_labels) print(test_loss, test_accuracy) 0.06696232722438872 0.9804999828338623

    预测精度98.04

    搞定,效果还不错

    Processed: 0.020, SQL: 9