Deep Learning Paper: LEDNet: A Lightweight Encoder-Decoder Network for Real-Time Semantic Segmentation, with a PyTorch Implementation
PDF: https://arxiv.org/pdf/1905.02423.pdf
PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks
1 Overview
LEDNet's asymmetric encoder-decoder architecture greatly reduces the number of network parameters and speeds up inference.
The channel split and shuffle operations in its residual blocks provide strong feature representation.
On the decoder side, a feature-pyramid attention mechanism is used to design the APN, which further reduces the complexity of the whole network.
The model has fewer than 1M parameters and runs at over 71 FPS on a single GTX 1080Ti GPU.
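As a quick sanity check on the parameter claim, the count can be read directly off the model (a minimal sketch, assuming the LEDnet class from section 2-3 has been defined):

# Count the trainable parameters of the LEDnet model defined in section 2-3.
model = LEDnet(num_classes=20)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {num_params / 1e6:.2f} M")  # expected to come out under 1 M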
2 LEDNet
LEDNet consists of two parts: an encoder network and a decoder network.
Encoder: the asymmetric design of LEDNet reduces the parameter count and speeds up inference.
The channel split and shuffle operations in the residual blocks shrink the network while improving feature representation.
The skip connection lets the convolutions learn a residual function, which eases training, and the split-and-shuffle operations strengthen information exchange between channels while keeping the computational cost close to that of 1D factorized convolutions (see the sketch below).
Decoder: a feature-pyramid attention mechanism is used to design the attention pyramid network (APN), which extracts rich features and uses attention to estimate the semantic label of each pixel.
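To make the factorized-convolution saving concrete, compare weight counts: a full 3x3 convolution over C channels uses 9*C*C weights, while a 3x1 followed by a 1x3 uses only 6*C*C. A standalone sketch (the channel count is hypothetical):

import torch.nn as nn

C = 64  # hypothetical channel count
full = nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False)           # 9*C*C weights
factorized = nn.Sequential(
    nn.Conv2d(C, C, kernel_size=(3, 1), padding=(1, 0), bias=False),   # 3*C*C weights
    nn.Conv2d(C, C, kernel_size=(1, 3), padding=(0, 1), bias=False),   # 3*C*C weights
)

def count(m):
    return sum(p.numel() for p in m.parameters())

print(count(full), count(factorized))  # 36864 vs 24576: a one-third reduction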
2-1 SS-nbt module
The SS-nbt module adopts a split-transform-merge strategy: the input channels are split in half, each half is transformed by a stack of factorized (and optionally dilated) convolutions, and the two halves are merged by concatenation plus a residual connection, followed by a channel shuffle.
# Helper blocks (ConvReLU, ConvBNReLU, HalfSplit, ChannelShuffle) are defined in section 2-3.
class SS_nbt(nn.Module):
    def __init__(self, channels, dilation=1, groups=4):
        super(SS_nbt, self).__init__()
        mid_channels = channels // 2
        self.half_split = HalfSplit(dim=1)
        # Branch 1: factorized 3x1/1x3 pair, then a dilated factorized pair.
        self.first_bottleneck = nn.Sequential(
            ConvReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(3, 1), stride=1, padding=(1, 0)),
            ConvBNReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(1, 3), stride=1, padding=(0, 1)),
            ConvReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(3, 1), stride=1,
                     dilation=(dilation, 1), padding=(dilation, 0)),
            ConvBNReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(1, 3), stride=1,
                       dilation=(1, dilation), padding=(0, dilation)),
        )
        # Branch 2: same structure with the 1x3/3x1 order reversed.
        self.second_bottleneck = nn.Sequential(
            ConvReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(1, 3), stride=1, padding=(0, 1)),
            ConvBNReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(3, 1), stride=1, padding=(1, 0)),
            ConvReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(1, 3), stride=1,
                     dilation=(1, dilation), padding=(0, dilation)),
            ConvBNReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(3, 1), stride=1,
                       dilation=(dilation, 1), padding=(dilation, 0)),
        )
        self.channelShuffle = ChannelShuffle(groups)

    def forward(self, x):
        # Split channels in half, transform each half, merge, add the residual, then shuffle.
        x1, x2 = self.half_split(x)
        x1 = self.first_bottleneck(x1)
        x2 = self.second_bottleneck(x2)
        out = torch.cat([x1, x2], dim=1)
        return self.channelShuffle(out + x)
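The final shuffle is just a reshape-permute-reshape that interleaves channels across the two branches, so information mixes between them in the next block. A standalone demonstration on a tiny tensor (sizes are hypothetical):

import torch

N, C, H, W, g = 1, 8, 1, 1, 4  # hypothetical: 8 channels, 4 shuffle groups
x = torch.arange(C, dtype=torch.float32).view(N, C, H, W)
shuffled = x.view(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).contiguous().view(N, C, H, W)
print(shuffled.flatten().tolist())  # [0.0, 2.0, 4.0, 6.0, 1.0, 3.0, 5.0, 7.0]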
2-2 APN
# ConvBNReLU and Conv1x1BNReLU are defined in section 2-3.
class APN(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(APN, self).__init__()
        # Pyramid branch: 3x3 / 5x5 / 7x7 convolutions at 1/2, 1/4 and 1/8 resolution.
        self.conv1_1 = ConvBNReLU(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=2, padding=1)
        self.conv1_2 = Conv1x1BNReLU(in_channels=in_channels, out_channels=out_channels)
        self.conv2_1 = ConvBNReLU(in_channels=in_channels, out_channels=in_channels, kernel_size=5, stride=2, padding=2)
        self.conv2_2 = Conv1x1BNReLU(in_channels=in_channels, out_channels=out_channels)
        self.conv3 = nn.Sequential(
            ConvBNReLU(in_channels=in_channels, out_channels=in_channels, kernel_size=7, stride=2, padding=3),
            Conv1x1BNReLU(in_channels=in_channels, out_channels=out_channels),
        )
        # Present in the reference implementation but never called in forward().
        self.conv1 = nn.Sequential(
            ConvBNReLU(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=2, padding=1),
            Conv1x1BNReLU(in_channels=in_channels, out_channels=out_channels),
        )
        # Identity branch: 1x1 projection of the input.
        self.branch2 = Conv1x1BNReLU(in_channels=in_channels, out_channels=out_channels)
        # Global-context branch: global average pooling followed by a 1x1 convolution.
        self.branch3 = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=1),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        _, _, h, w = x.shape
        # Top-down pyramid: downsample three times, then fuse level by level.
        x1 = self.conv1_1(x)
        x2 = self.conv2_1(x1)
        x3 = self.conv3(x2)
        x3 = F.interpolate(x3, size=(h // 4, w // 4), mode='bilinear', align_corners=True)
        x2 = self.conv2_2(x2) + x3
        x2 = F.interpolate(x2, size=(h // 2, w // 2), mode='bilinear', align_corners=True)
        x1 = self.conv1_2(x1) + x2
        out1 = F.interpolate(x1, size=(h, w), mode='bilinear', align_corners=True)
        out2 = self.branch2(x)
        out3 = self.branch3(x)
        out3 = F.interpolate(out3, size=(h, w), mode='bilinear', align_corners=True)
        # The pyramid output acts as pixel-wise attention over the 1x1 branch,
        # plus the upsampled global-context branch.
        return out1 * out2 + out3
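A quick shape check of the module above (a sketch; it assumes the helper blocks from section 2-3, and the 128-channel input matches the encoder output in the LEDnet configuration):

apn = APN(in_channels=128, out_channels=20)
feat = torch.randn(1, 128, 64, 128)  # hypothetical 1/8-resolution encoder feature map
print(apn(feat).shape)  # torch.Size([1, 20, 64, 128]): per-pixel class scores at 1/8 resolution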
2-3 LEDNet Architecture
import torch
import torch.nn as nn
import torch.nn.functional as F


def ConvBNReLU(in_channels, out_channels, kernel_size, stride, padding, dilation=(1, 1), groups=1):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                  stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU6(inplace=True)
    )


def ConvBN(in_channels, out_channels, kernel_size, stride, padding, dilation=(1, 1), groups=1):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                  stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False),
        nn.BatchNorm2d(out_channels)
    )


def ConvReLU(in_channels, out_channels, kernel_size, stride, padding, dilation=(1, 1), groups=1):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                  stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False),
        nn.ReLU6(inplace=True)
    )


def Conv1x1BNReLU(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU6(inplace=True)
    )


def Conv1x1BN(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, bias=False),
        nn.BatchNorm2d(out_channels)
    )


class HalfSplit(nn.Module):
    """Split the input into two halves along the given dimension (channels by default)."""
    def __init__(self, dim=1):
        super(HalfSplit, self).__init__()
        self.dim = dim

    def forward(self, x):
        splits = torch.chunk(x, 2, dim=self.dim)
        return splits[0], splits[1]


class ChannelShuffle(nn.Module):
    """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]"""
    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        N, C, H, W = x.size()
        g = self.groups
        return x.view(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).contiguous().view(N, C, H, W)


class SS_nbt(nn.Module):
    """Split-shuffle non-bottleneck block: split channels, transform each half with
    factorized (and dilated) convolutions, merge, add the residual, then shuffle."""
    def __init__(self, channels, dilation=1, groups=4):
        super(SS_nbt, self).__init__()
        mid_channels = channels // 2
        self.half_split = HalfSplit(dim=1)
        self.first_bottleneck = nn.Sequential(
            ConvReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(3, 1), stride=1, padding=(1, 0)),
            ConvBNReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(1, 3), stride=1, padding=(0, 1)),
            ConvReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(3, 1), stride=1,
                     dilation=(dilation, 1), padding=(dilation, 0)),
            ConvBNReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(1, 3), stride=1,
                       dilation=(1, dilation), padding=(0, dilation)),
        )
        self.second_bottleneck = nn.Sequential(
            ConvReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(1, 3), stride=1, padding=(0, 1)),
            ConvBNReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(3, 1), stride=1, padding=(1, 0)),
            ConvReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(1, 3), stride=1,
                     dilation=(1, dilation), padding=(0, dilation)),
            ConvBNReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(3, 1), stride=1,
                       dilation=(dilation, 1), padding=(dilation, 0)),
        )
        self.channelShuffle = ChannelShuffle(groups)

    def forward(self, x):
        x1, x2 = self.half_split(x)
        x1 = self.first_bottleneck(x1)
        x2 = self.second_bottleneck(x2)
        out = torch.cat([x1, x2], dim=1)
        return self.channelShuffle(out + x)


class DownSampling(nn.Module):
    """Downsampling unit: concatenate a strided 3x3 convolution with max pooling."""
    def __init__(self, in_channels, out_channels):
        super(DownSampling, self).__init__()
        mid_channels = out_channels - in_channels
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, stride=2, padding=1)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x1 = self.conv(x)
        x2 = self.maxpool(x)
        output = torch.cat([x1, x2], 1)
        return self.relu(self.bn(output))


class Encoder(nn.Module):
    def __init__(self, groups=4):
        super(Encoder, self).__init__()
        planes = [32, 64, 128]
        self.downSampling1 = DownSampling(in_channels=3, out_channels=planes[0])
        self.ssBlock1 = self._make_layer(channels=planes[0], dilation=1, groups=groups, block_num=3)
        self.downSampling2 = DownSampling(in_channels=planes[0], out_channels=planes[1])
        self.ssBlock2 = self._make_layer(channels=planes[1], dilation=1, groups=groups, block_num=2)
        self.downSampling3 = DownSampling(in_channels=planes[1], out_channels=planes[2])
        # Last stage: growing dilation rates (1, 2, 5, 9, 2, 5, 9, 17) enlarge the receptive field.
        self.ssBlock3 = nn.Sequential(
            SS_nbt(channels=planes[2], dilation=1, groups=groups),
            SS_nbt(channels=planes[2], dilation=2, groups=groups),
            SS_nbt(channels=planes[2], dilation=5, groups=groups),
            SS_nbt(channels=planes[2], dilation=9, groups=groups),
            SS_nbt(channels=planes[2], dilation=2, groups=groups),
            SS_nbt(channels=planes[2], dilation=5, groups=groups),
            SS_nbt(channels=planes[2], dilation=9, groups=groups),
            SS_nbt(channels=planes[2], dilation=17, groups=groups),
        )

    def _make_layer(self, channels, dilation, groups, block_num):
        layers = []
        for _ in range(block_num):
            layers.append(SS_nbt(channels, dilation=dilation, groups=groups))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.downSampling1(x)
        x = self.ssBlock1(x)
        x = self.downSampling2(x)
        x = self.ssBlock2(x)
        x = self.downSampling3(x)
        out = self.ssBlock3(x)
        return out


class APN(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(APN, self).__init__()
        # Pyramid branch: 3x3 / 5x5 / 7x7 convolutions at 1/2, 1/4 and 1/8 resolution.
        self.conv1_1 = ConvBNReLU(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=2, padding=1)
        self.conv1_2 = Conv1x1BNReLU(in_channels=in_channels, out_channels=out_channels)
        self.conv2_1 = ConvBNReLU(in_channels=in_channels, out_channels=in_channels, kernel_size=5, stride=2, padding=2)
        self.conv2_2 = Conv1x1BNReLU(in_channels=in_channels, out_channels=out_channels)
        self.conv3 = nn.Sequential(
            ConvBNReLU(in_channels=in_channels, out_channels=in_channels, kernel_size=7, stride=2, padding=3),
            Conv1x1BNReLU(in_channels=in_channels, out_channels=out_channels),
        )
        # Present in the reference implementation but never called in forward().
        self.conv1 = nn.Sequential(
            ConvBNReLU(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=2, padding=1),
            Conv1x1BNReLU(in_channels=in_channels, out_channels=out_channels),
        )
        self.branch2 = Conv1x1BNReLU(in_channels=in_channels, out_channels=out_channels)
        self.branch3 = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=1),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        _, _, h, w = x.shape
        x1 = self.conv1_1(x)
        x2 = self.conv2_1(x1)
        x3 = self.conv3(x2)
        x3 = F.interpolate(x3, size=(h // 4, w // 4), mode='bilinear', align_corners=True)
        x2 = self.conv2_2(x2) + x3
        x2 = F.interpolate(x2, size=(h // 2, w // 2), mode='bilinear', align_corners=True)
        x1 = self.conv1_2(x1) + x2
        out1 = F.interpolate(x1, size=(h, w), mode='bilinear', align_corners=True)
        out2 = self.branch2(x)
        out3 = self.branch3(x)
        out3 = F.interpolate(out3, size=(h, w), mode='bilinear', align_corners=True)
        return out1 * out2 + out3


class Decoder(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(Decoder, self).__init__()
        self.apn = APN(in_channels=in_channels, out_channels=num_classes)

    def forward(self, x):
        _, _, h, w = x.shape
        apn_x = self.apn(x)
        # The encoder downsamples by 8x, so upsample back to the original resolution.
        out = F.interpolate(apn_x, size=(h * 8, w * 8), mode='bilinear', align_corners=True)
        return out


class LEDnet(nn.Module):
    def __init__(self, num_classes=20):
        super(LEDnet, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(in_channels=128, num_classes=num_classes)

    def forward(self, x):
        x = self.encoder(x)
        out = self.decoder(x)
        return out


if __name__ == '__main__':
    model = LEDnet(num_classes=20)
    print(model)

    x = torch.randn(1, 3, 1024, 512)
    output = model(x)
    print(output.shape)  # torch.Size([1, 20, 1024, 512])
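To reproduce the speed claim from the overview, a rough throughput measurement can be appended after the shape check (a minimal sketch; the result depends on the GPU and input resolution):

import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LEDnet(num_classes=20).to(device).eval()
x = torch.randn(1, 3, 1024, 512, device=device)

with torch.no_grad():
    for _ in range(10):  # warm-up iterations
        model(x)
    if device.type == 'cuda':
        torch.cuda.synchronize()  # make sure timing covers the actual GPU work
    start = time.time()
    for _ in range(100):
        model(x)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    print(f"~{100 / (time.time() - start):.1f} FPS at 1024x512")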
3 Experiments