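“交叉验证法”（cross validation）先将数据集 $D$ 划分为 $k$ 个大小相似的互斥子集，即 $D = D_1 \cup D_2 \cup \cdots \cup D_k$，$D_i \cap D_j = \varnothing\ (i \neq j)$。每个子集 $D_i$ 都尽可能保持数据分布的一致性，即从 $D$ 中通过分层采样得到。然后，每次用 $k-1$ 个子集的并集作为训练集，余下的那个子集作为测试集；这样就可获得 $k$ 组训练/测试集，从而可进行 $k$ 次训练和测试，最终返回的是这 $k$ 个测试结果的均值。显然，交叉验证法评估结果的稳定性和保真性在很大程度上取决于 $k$ 的取值，为强调这一点，通常把交叉验证法称为“$k$ 折交叉验证”（k-fold cross validation）。$k$ 最常用的取值是 10，此时称为 10 折交叉验证；其他常用的 $k$ 值有 5、20 等。图 2.2 给出了 10 折交叉验证的示意图。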
下面用 Python 实现五折交叉验证：数据集取自存放 XML 标注文件的文件夹，先将各标注文件名（去掉扩展名）存入列表，再按比例划分——每一折取约 1/5 作为测试集，其余约 4/5 作为训练集（即 8:2）。
# Build k-fold cross-validation split files from a folder of XML annotations.
#
# For fold i (1..k) two files are written to the output directory:
#   trainval{i}.txt - names used for training in that fold
#   test{i}.txt     - names held out for testing in that fold
# Each line holds one annotation file name with its extension removed.
import os

annotations_dir = r'C:\Users\55192\Desktop\Annotations_hanfeng2'  # folder with the annotation files


def list_annotation_names(folder, ext='.xml'):
    """Return the sorted names (extension removed) of files in ``folder`` ending in ``ext``.

    The suffix match is case-insensitive. Files with any other extension
    (e.g. ``desktop.ini``) are skipped instead of being truncated into bogus
    sample names. The result is sorted because ``os.listdir`` returns entries
    in arbitrary order, which would otherwise make the folds non-reproducible.
    """
    wanted = ext.lower()
    names = []
    for filename in os.listdir(folder):
        stem, suffix = os.path.splitext(filename)
        if suffix.lower() == wanted:
            names.append(stem)
    names.sort()
    return names


def write_txt(txt, o, encoding='utf-8'):
    """Write every item of ``o`` to the text file ``txt``, one item per line.

    Args:
        txt: path of the file to (over)write.
        o: iterable of items; each is formatted with ``str.format`` and
            followed by a newline.
        encoding: text encoding; explicit so output does not depend on the
            platform's locale default.
    """
    with open(txt, 'w', encoding=encoding) as f:
        f.writelines('{}\n'.format(item) for item in o)


def read_txt(txt, encoding='utf-8'):
    """Return the lines of the text file ``txt`` with their trailing newline removed."""
    with open(txt, 'r', encoding=encoding) as f:
        return [line.rstrip('\n') for line in f]


def split_folds(names, k=5):
    """Partition ``names`` into ``k`` contiguous folds and return ``(train, test)`` pairs.

    Fold ``i`` (0-based) tests on ``names[i*n//k:(i+1)*n//k]`` and trains on
    all remaining positions. Consequences:
      * every element is in exactly one test set;
      * fold sizes differ by at most one (no remainder dumped on the last fold);
      * ``len(train) + len(test) == len(names)`` even if names repeat, because
        the complement is taken by position, not by value.

    Raises:
        ValueError: if ``k < 2`` or there are fewer names than folds (some
            test sets would be empty).
    """
    if k < 2:
        raise ValueError('k must be at least 2, got {}'.format(k))
    n = len(names)
    if n < k:
        raise ValueError(
            'need at least {} names to build {} folds, got {}'.format(k, k, n))
    folds = []
    for i in range(k):
        start, end = i * n // k, (i + 1) * n // k
        test = list(names[start:end])
        train = list(names[:start]) + list(names[end:])
        folds.append((train, test))
    return folds


def main(folder=annotations_dir, k=5, out_dir='.'):
    """Write the k train/test split files for the annotations found in ``folder``.

    Prints the annotation count and the per-fold sizes. As a final sanity check,
    it re-reads each written file pair and prints the train+test line counts;
    every value should equal the annotation count.

    Returns:
        The list of ``(train, test)`` pairs produced by ``split_folds``.
    """
    names = list_annotation_names(folder)
    print(len(names))
    folds = split_folds(names, k)
    totals = []
    for i, (train, test) in enumerate(folds, start=1):
        train_path = os.path.join(out_dir, 'trainval{}.txt'.format(i))
        test_path = os.path.join(out_dir, 'test{}.txt'.format(i))
        write_txt(train_path, train)
        write_txt(test_path, test)
        print('fold {}: {} train / {} test'.format(i, len(train), len(test)))
        # Re-read from disk to confirm nothing was lost when writing.
        totals.append(len(read_txt(test_path)) + len(read_txt(train_path)))
    print(*totals)
    return folds


if __name__ == '__main__':
    main()