K-近邻分类器(KNN)是一种在线分类器,也就是说在分类的时候直接从训练样本中找出与待分类样本最接近的K个样本,以判断待分类样本的类别。初学者容易把KNN和Kmeans搞混,KNN是一种最简单的有监督分类方法,而Kmeans是一种无监督的聚类方法,Kmeans不直接得到样本的类别,而是根据样本本身特性将他们分别聚成几个簇。
KNN的思想:首先,计算新样本与训练样本之间的距离,找到距离最近的K个近邻,统计这K个近邻中个数最多的类别,然后把新样本归为该类别,通常K是不大于20的整数。经验上,通常取,N为训练样本的数目,有时候为了简单可以取10为默认值。
KNN是基于一个假设而建立的,即近邻的对象具有类似的预测值。如何判断两个对象是否近邻,可以通过距离函数和K值确定。通常使用最多的距离函数是欧氏距离,计算新样本与每个训练样本之间的欧氏距离,然后根据距离大小进行排序,取出K个最近的样本作为最近邻样本。K值的选取是根据每类样本中的数目和分散程度进行的,对不同的应用可以选取不同的K值。
KNN的优点:
(1)简单,分类效果好,应用广
(2)在样本比较少,且对分类速度没有太高要求时,KNN比较好用
KNN的缺点:
(1)因为要存储所有训练样本,所以需要较大的存储空间
(2)每次都要计算所有样本与新样本的距离,计算量大,样本多时速度会变慢
(3)距离函数的选取具有主观性,分类效果完全依赖于距离函数
KNN在使用中需要注意的问题:
(1)寻找适当的训练数据集
训练数据集应是对历史数据的一个很好的覆盖,这样才能保证KNN有利于预测,选择训练数据集的原则是使各类样本的数量大体一致,以及选取的历史数据要有代表性。常用的方法是按照类别把历史数据分组,然后再由每组中选取一些有代表性的样本组成训练集。这样既降低了训练集的大小,又保持了较高的准确度。
(2)确定距离函数
距离函数决定了哪些样本是待分类样本的K个最近邻,它的选取取决于实际的数据和决策问题。如果样本是空间中点,最常用的是欧氏距离,其他还有绝对距离、平方差和标准差等。
(3)决定最终类别的方法
通常使用多数法,即在K个最近邻中选择出现次数最高的类别作为新样本的最终类别,如果频率最高的类别不止一个,就选择最近邻的类别。也可以使用权重法,比较复杂,它对K个最近邻设置权重,距离越大,权重越小,然后计算每个类别的权重和,最大的那个就是新样本的类别。
(4)K值的选取
经验上,通常取,N为训练样本的数目,有时候为了简单可以取10为默认值。
以下是《模式识别与人工智能(基于matlab)》的一段代码
首先实现knn方法:
function [label_test] = knn(k, data_train, label_train, data_test) % data_train:(m, N1) % data_test:(m, N2) % label_train:(1, N1) % 其中m为特征数,N1为训练样本数,N2为测试样本数 error(nargchk(4,4,nargin)); %计算新样本与训练样本的距离 dist = l2_distance(data_train, data_test); % dist的shape为(N1, N2) %对距离进行排序 [sorted_dist, nearest] = sort(dist); % sorted_dist:排序后的距离矩阵,shape:(N1, N2) % nearest:排序后的下标矩阵,shape:(N1, N2) % 选出K个最近邻样本的下标 nearest = nearest(1:k,:); % nearest shape:(k, N2) % 根据近邻样本的下标找到该样本对应的类别 label_test = label_train(nearest); % label_test shape:(k, N2)欧氏距离函数的实现:
function d = l2_distance(X,Y) % 计算出x,y之间的欧式距离 if (nargin < 2) [D N] = size(X); lengths = sum(X.^2,1); d = repmat(lengths,[N 1]) + repmat(lengths',[1 N]); d = d - 2* X'*X; else XX = sum(X.^2,1); YY = sum(Y.^2,1); d = repmat(XX', [1 size(Y,2)]) + repmat(YY, [size(X,2) 1]); d = d - 2*X'*Y; end使用knn进行分类:
clear; clc; DATA = load('D.mat'); %% 绘制训练数据图 first = DATA.train_data(DATA.train_label==1,:,:); second = DATA.train_data(DATA.train_label==2,:,:); third = DATA.train_data(DATA.train_label==3,:,:); fourth = DATA.train_data(DATA.train_label==4,:,:); figure scatter3(first(:,1),first(:,2),first(:,3),'*'); hold on scatter3(second(:,1),second(:,2),second(:,3),'p'); scatter3(third(:,1),third(:,2),third(:,3),'s'); scatter3(fourth(:,1),fourth(:,2),fourth(:,3),'o'); title('训练数据');legend('第1类','第2类','第3类','第4类'); %% KNN寻优 acc = zeros(10,1); % 用来存储K值分别为1-10时的分类准确率 for k = 1:10 % KNN 算法 label_test = knn(k, DATA.train_data', DATA.train_label', DATA.test_data'); % 计算最终结果 if k ==1 testResults = label_test; else [maxCount,idx] = max(label_test); % 应该不是用这个函数吧? testResults = maxCount; % 得到在K个近邻中出现最多次的类别 end % 存储各分类结果 RESULTS(k,:) = testResults; % 计算正确率 count = 0; for i=1:30 if (testResults(i) == DATA.test_label(i)) count = count+1; end end acc(k) = count/30; end disp('精度:') disp(acc); %% 求出最优 K [~,K] = max(acc); disp('最佳的K值为:'); disp(K); %% 绘制测试数据分类图,并在命令行窗口显示分类 % 使用最优K进行一次测试 label_test = knn(K, DATA.train_data', DATA.train_label', DATA.test_data'); if K ==1 testResults = label_test else [maxCount,idx] = max(label_test); testResults = maxCount end %% 绘制测试数据图 first = DATA.test_data(testResults==1,:,:); second = DATA.test_data(testResults==2,:,:); third = DATA.test_data(testResults==3,:,:); fourth = DATA.test_data(testResults==4,:,:); figure; scatter3(first(:,1),first(:,2),first(:,3),'*'); hold on scatter3(second(:,1),second(:,2),second(:,3),'p'); scatter3(third(:,1),third(:,2),third(:,3),'s'); scatter3(fourth(:,1),fourth(:,2),fourth(:,3),'o'); title('测试数据');legend('第1类','第2类','第3类','第4类');训练数据图以及测试数据的分类结果图:
命令行窗口的输出:
精度: 0.9667 0.9333 0.9333 0.9333 0.8000 0.8000 0.7667 0.6333 0.6333 0.6333
最佳的K值为: 1
testResults =
1 至 15 列
3 3 1 3 4 2 2 3 4 1 3 3 1 2 4
16 至 30 列
2 4 3 4 2 2 3 3 1 1 4 1 3 3 3
>>