resnet18由4个基础的block组成,每个block有两层网络,block的结构如下。右边的箭头表示将x与f(x)相加,这样可以让经过一个block之后的效果至少不会比没有经过的差。这也是为什么resnet网络能够达到这么深的原因。
1.每个block由两层网络组成,每层网络由卷积层、batchnormalization层和激活函数组成。 2.当stride等于1时,经过block的数据的size没有变化,直接相加。不等于1时size变化,需要一次下采样来让x的size与f(x)保持一致。
class BasicBlock(layers.Layer):
    """Basic two-layer residual block: output = relu(f(x) + shortcut(x)).

    Each of the two sub-layers is Conv2D -> BatchNorm (ReLU after the
    first). When ``stride != 1`` the spatial size of f(x) shrinks, so the
    shortcut branch applies a strided 1x1 convolution to keep x and f(x)
    the same size before they are added. The residual addition guarantees
    the block can do no worse than the identity, which is what lets ResNet
    go this deep.
    """

    def __init__(self, filters_num, stride=1):
        super(BasicBlock, self).__init__()
        # First layer: conv + BN + ReLU; may downsample when stride > 1.
        self.conv1 = layers.Conv2D(filters_num, (3, 3), strides=stride, padding='same')
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.Activation('relu')
        # Second layer: conv + BN, always stride 1 (size unchanged).
        self.conv2 = layers.Conv2D(filters_num, (3, 3), strides=1, padding='same')
        self.bn2 = layers.BatchNormalization()
        if stride != 1:
            # f(x) shrank: downsample x with a strided 1x1 conv so the
            # two branches have matching size and channel count.
            self.downsample = Sequential()
            self.downsample.add(layers.Conv2D(filters_num, (1, 1), strides=stride))
        else:
            # stride == 1: size is unchanged, shortcut is the identity.
            self.downsample = lambda x: x

    def call(self, inputs, training=None):
        """Forward pass. ``training`` selects BatchNorm behaviour."""
        out = self.conv1(inputs)
        # BUGFIX: forward the training flag — otherwise BatchNorm never
        # distinguishes training (batch statistics) from inference
        # (moving averages).
        out = self.bn1(out, training=training)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out, training=training)
        # Shortcut branch (identity or 1x1 downsample).
        identity = self.downsample(inputs)
        # f(x) + x; the ReLU may go before or after the addition —
        # here it is applied after.
        output = layers.add([out, identity])
        output = tf.nn.relu(output)
        return output

# Notes on the ResNet assembly below:
# 1. The stem is a data-processing layer that converts the input to 64
#    channels before passing it on.
# 2. Only the first BasicBlock of each group can change the channel count.
# 3. layer_dims is a 1-D sequence; each value is how many BasicBlocks the
#    corresponding group contains.
class ResNet(keras.Model):
    """ResNet assembled from BasicBlocks.

    Args:
        layer_dims: 1-D sequence; entry i is the number of BasicBlocks in
            residual group i (e.g. [2, 2, 2, 2] for resnet18).
        num_classes: output size of the final Dense layer (logits).
    """

    def __init__(self, layer_dims, num_classes=100):
        super(ResNet, self).__init__()
        # Stem / data-processing layer: bring the input up to 64 channels.
        self.stem = Sequential([layers.Conv2D(64, (3, 3), strides=(1, 1)),
                                layers.BatchNormalization(),
                                layers.Activation('relu'),
                                layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding='same')])
        # Four residual groups; groups 2-4 halve the spatial size.
        self.layer1 = self.build_resblock(64, layer_dims[0])
        self.layer2 = self.build_resblock(128, layer_dims[1], strides=(2, 2))
        self.layer3 = self.build_resblock(256, layer_dims[2], strides=(2, 2))
        self.layer4 = self.build_resblock(512, layer_dims[3], strides=(2, 2))
        # Global average pooling: [b, h, w, 512] -> [b, 512].
        self.avgpool = layers.GlobalAvgPool2D()
        # Fully connected classifier head (logits).
        self.fc = layers.Dense(num_classes)

    def call(self, inputs, training=None, mask=None):
        """Forward pass returning class logits."""
        # BUGFIX: forward the training flag so every BatchNormalization
        # inside the sub-models switches correctly between training and
        # inference behaviour.
        out = self.stem(inputs, training=training)
        out = self.layer1(out, training=training)
        out = self.layer2(out, training=training)
        out = self.layer3(out, training=training)
        out = self.layer4(out, training=training)
        out = self.avgpool(out)
        out = self.fc(out)
        return out

    def build_resblock(self, filter_num, block, strides=1):
        """Stack ``block`` BasicBlocks into one residual group.

        :param filter_num: channel count for every block in the group
        :param block: how many BasicBlocks to stack
        :param strides: stride of the FIRST block only — the size change,
            if any, happens there; all later blocks use stride 1
        :return: a Sequential containing the stacked blocks
        """
        res_block = Sequential()
        # First block of the group is the only one that may resize.
        res_block.add(BasicBlock(filter_num, strides))
        for _ in range(1, block):
            res_block.add(BasicBlock(filter_num, stride=1))
        return res_block

# `depth` below is the number of classes; the data used is tf2's
# built-in cifar100.
def preprocess(x, y):
    """Scale images to [0, 1] float32 and cast labels to int32."""
    x = tf.cast(x, dtype=tf.float32) / 255
    y = tf.cast(y, dtype=tf.int32)
    return x, y


def get_db(batch_size=128, shuffle=10000):
    """Load CIFAR-100 and return (train, test) tf.data pipelines.

    Labels arrive as shape [N, 1]; they are squeezed to [N] and one-hot
    encoded to depth 100 before batching.

    :param batch_size: samples per batch in both pipelines
    :param shuffle: shuffle-buffer size for the training pipeline
    :return: (db_train, db_test) tf.data.Dataset objects
    """
    (x, y), (x_test, y_test) = datasets.cifar100.load_data()
    y = tf.squeeze(y)
    y_test = tf.squeeze(y_test)
    y = tf.one_hot(y, depth=100)
    y_test = tf.one_hot(y_test, depth=100)
    db_train = tf.data.Dataset.from_tensor_slices((x, y))
    db_train = db_train.map(preprocess).shuffle(shuffle).batch(batch_size=batch_size)
    db_t = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    # FIX: do not shuffle the evaluation set — aggregate metrics are
    # order-independent, so shuffling only wastes time and makes the
    # evaluation order non-reproducible.
    db_t = db_t.map(preprocess).batch(batch_size=batch_size)
    return db_train, db_t

# Note on checkpoint auto-saving (used below): passing `period` raises a
# deprecation warning recommending `save_freq`. `save_freq` counts
# BATCHES between saves (or takes 'epoch'), whereas `period` counted
# epochs between saves.
if __name__ == '__main__':
    # Load the CIFAR-100 train/test pipelines.
    db, db_test = get_db()
    model = resnet18()
    model.build(input_shape=(None, 32, 32, 3))
    # Print the model summary.
    model.summary()
    # Compile the model. FIX: `learning_rate` replaces the deprecated
    # `lr` keyword of the Keras optimizers.
    model.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                  loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    # Auto-save weights with a tf.keras.callbacks.ModelCheckpoint.
    checkpoint_path = "weight/resnet18-{epoch:04d}.ckpt"
    # FIX: `period=1` is deprecated (the run log shows the warning);
    # save_freq='epoch' keeps the same once-per-epoch behaviour. Note an
    # integer save_freq would count BATCHES, not epochs.
    cp_callback = callbacks.ModelCheckpoint(
        checkpoint_path,
        verbose=1,
        save_weights_only=True,
        save_freq='epoch')
    # Train the model.
    model.fit(db, epochs=10, validation_data=db_test, validation_freq=1,
              callbacks=[cp_callback])
    # Evaluate the trained model.
    score = model.evaluate(db_test)
    print('Test score:', score[0])
    print('Test accuracy:', score[1])
    # Save the final weights.
    model.save_weights('resnet18.ckpt')
    # Rebuild a fresh model, load the saved weights and evaluate again to
    # verify that saving/loading round-trips correctly.
    test_model = resnet18()
    test_model.build(input_shape=(None, 32, 32, 3))
    test_model.load_weights('resnet18.ckpt')
    test_model.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                       loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                       metrics=['accuracy'])
    score = test_model.evaluate(db_test)
    print('Test score:', score[0])
    print('Test accuracy:', score[1])

# Output (training/evaluation log follows):
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= sequential (Sequential) (None, 30, 30, 64) 2048 _________________________________________________________________ sequential_1 (Sequential) (None, 30, 30, 64) 148736 _________________________________________________________________ sequential_2 (Sequential) (None, 15, 15, 128) 526976 _________________________________________________________________ sequential_4 (Sequential) (None, 8, 8, 256) 2102528 _________________________________________________________________ sequential_6 (Sequential) (None, 4, 4, 512) 8399360 _________________________________________________________________ global_average_pooling2d (Gl multiple 0 _________________________________________________________________ dense (Dense) multiple 51300 ================================================================= Total params: 11,230,948 Trainable params: 11,223,140 Non-trainable params: 7,808 _________________________________________________________________ WARNING:tensorflow:`period` argument is deprecated. Please use `save_freq` to specify the frequency in number of batches seen. Epoch 1/10 2020-07-04 10:48:25.869669: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcublas.so.10 2020-07-04 10:48:26.270991: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudnn.so.7 391/391 [==============================] - ETA: 0s - loss: 3.3788 - accuracy: 0.1892 391/391 [==============================] - 30s 76ms/step - loss: 3.3788 - accuracy: 0.1892 - val_loss: 5.8297 - val_accuracy: 0.0498 Epoch 2/10 390/391 [============================>.] 
- ETA: 0s - loss: 2.4293 - accuracy: 0.3623 391/391 [==============================] - 28s 72ms/step - loss: 2.4293 - accuracy: 0.3623 - val_loss: 3.5513 - val_accuracy: 0.2126 Epoch 3/10 390/391 [============================>.] - ETA: 0s - loss: 1.9221 - accuracy: 0.4764 391/391 [==============================] - 28s 73ms/step - loss: 1.9217 - accuracy: 0.4766 - val_loss: 2.8024 - val_accuracy: 0.3319 Epoch 4/10 390/391 [============================>.] - ETA: 0s - loss: 1.5254 - accuracy: 0.5679 391/391 [==============================] - 28s 73ms/step - loss: 1.5254 - accuracy: 0.5679 - val_loss: 2.0562 - val_accuracy: 0.4607 Epoch 5/10 390/391 [============================>.] - ETA: 0s - loss: 1.1593 - accuracy: 0.6600 391/391 [==============================] - 28s 73ms/step - loss: 1.1590 - accuracy: 0.6600 - val_loss: 2.3511 - val_accuracy: 0.4156 Epoch 6/10 390/391 [============================>.] - ETA: 0s - loss: 0.7823 - accuracy: 0.7653 391/391 [==============================] - 28s 72ms/step - loss: 0.7825 - accuracy: 0.7652 - val_loss: 2.2684 - val_accuracy: 0.4614 Epoch 7/10 390/391 [============================>.] - ETA: 0s - loss: 0.4402 - accuracy: 0.8674 391/391 [==============================] - 28s 73ms/step - loss: 0.4401 - accuracy: 0.8674 - val_loss: 2.3768 - val_accuracy: 0.4819 Epoch 8/10 390/391 [============================>.] - ETA: 0s - loss: 0.2215 - accuracy: 0.9355 391/391 [==============================] - 28s 73ms/step - loss: 0.2215 - accuracy: 0.9355 - val_loss: 2.3600 - val_accuracy: 0.4959 Epoch 9/10 390/391 [============================>.] - ETA: 0s - loss: 0.1309 - accuracy: 0.9646 391/391 [==============================] - 28s 72ms/step - loss: 0.1310 - accuracy: 0.9646 - val_loss: 2.6867 - val_accuracy: 0.4810 Epoch 10/10 390/391 [============================>.] 
- ETA: 0s - loss: 0.1630 - accuracy: 0.9509 391/391 [==============================] - 28s 72ms/step - loss: 0.1631 - accuracy: 0.9508 - val_loss: 3.1384 - val_accuracy: 0.4432 79/79 [==============================] - 1s 19ms/step - loss: 3.1384 - accuracy: 0.4432 Test score: 3.1384365558624268 Test accuracy: 0.4431999921798706 79/79 [==============================] - 2s 19ms/step - loss: 3.1384 - accuracy: 0.4432 Test score: 3.138436794281006 Test accuracy: 0.4431999921798706