在一些计算机视觉任务中,需要对模型的输出做一些后处理以优化视觉效果,连通域就是一种常见的后处理方式。尤其对于分割任务,有时的输出mask会存在一些假阳(小的无用轮廓),通过3D连通域找出面积较小的独立轮廓并去除可以有效地提升视觉效果。skimage是python的一个图像处理库,功能强大并且使用方便。这里主要使用了skimage里的measure模块,官方也提供了相关方法的说明及使用。
模型输出了一个n*w*h的三维数组,数组中包含了模型对某一类(或几类)的分割结果,从形态学角度看在3D上每一类分割结果都应该是连续的,因此可以使用3D连通域去除那些不连通的小区域。
连通域(Connected Component)指的是图像中具有相同像素值且位置相邻的前景像素点组成的图像区域(Region)。连通域一般有两种,分别为4连通和8连通,下面借用skimage的示意图。
1-connectivity 2-connectivity diagonal connection close-up [ ] [ ] [ ] [ ] [ ] | \ | / | <- hop 2 [ ]--[x]--[ ] [ ]--[x]--[ ] [x]--[ ] | / | \ hop 1 [ ] [ ] [ ] [ ]连通域的实现有两种方式,一是two-pass算法,通过并查集(union-find)实现;二是Seed-Filling 算法,基于图形学中的区域生长,可以递归实现。
input: int类型的ndarray,需要标记的图像; neighbors: 废弃的参数,被connectivity替代; background: int类型,可选,将像素为该值的视为背景并设置为0,默认为0; return_num: bool,可选,是否返回标记的区域数量; connectivity: int,可选,最大为输入的ndims,选择是几连通(对于输入为2D来说,1即为4连通,2为8连通)。
返回值:labels: 形状和类型与input一致,连通的区域使用相同的整数标记 nums: 标签数,等于最大标签索引,仅在return_num为True时返回。
label_image: (N, M) ndarray,经过标记的图像,0当作背景忽略; intensity_image: (N, M) ndarray,与标签图像大小相同的强度(即输入)图像。 默认为无; cache: bool,可选,确定是否缓存计算的区域属性。 对于缓存的属性,计算速度要快得多,而内存消耗却增加了。 coordinates: 已淘汰,不推荐使用。
返回值:properties:属性列表,列表中每一个对象对应一个标记的区域,可以访问对象的属性获取区域的属性。
常用的属性: area:int, 区域面积;slice: tuple,对应的区域。详细见官方文档。有同学提供了基于SimpleITK库的实现方法,SimpleITK专门针对3维数据处理,因此速度比skimage要快一些。这儿直接附上代码,侵删。
import SimpleITK as sitk def connected_domain_2(image, mask=True): cca = sitk.ConnectedComponentImageFilter() cca.SetFullyConnected(True) _input = sitk.GetImageFromArray(image.astype(np.uint8)) output_ex = cca.Execute(_input) stats = sitk.LabelShapeStatisticsImageFilter() stats.Execute(output_ex) num_label = cca.GetObjectCount() num_list = [i for i in range(1, num_label+1)] area_list = [] for l in range(1, num_label +1): area_list.append(stats.GetNumberOfPixels(l)) num_list_sorted = sorted(num_list, key=lambda x: area_list[x-1])[::-1] largest_area = area_list[num_list_sorted[0] - 1] final_label_list = [num_list_sorted[0]] for idx, i in enumerate(num_list_sorted[1:]): if area_list[i-1] >= (largest_area//10): final_label_list.append(i) else: break output = sitk.GetArrayFromImage(output_ex) for one_label in num_list: if one_label in final_label_list: continue x, y, z, w, h, d = stats.GetBoundingBox(one_label) one_mask = (output[z: z + d, y: y + h, x: x + w] != one_label) output[z: z + d, y: y + h, x: x + w] *= one_mask if mask: output = (output > 0).astype(np.uint8) else: output = ((output > 0)*255.).astype(np.uint8) return output