Python KNN分类算法学习
本文实例为大家分享了Python KNN分类算法的具体代码,供大家参考,具体内容如下 1、KNN分类算法 KNN分类算法(K-Nearest-Neighbors Classification),又叫K近邻算法,是一个概念极其简单,而分类效果又很优秀的分类算法。 上图中要确定测试样本绿色属于蓝色还是红色。 显然,当K=3时,将以1:2的投票结果分类于红色;而K=5时,将以3:2的投票结果分类于蓝色。 KNN算法简单有效,但没有优化的暴力法效率容易达到瓶颈。如样本个数为N,特征维度为D的时候,该算法时间复杂度呈O(DN)增长。 所以通常KNN的实现会把训练数据构建成K-D Tree(K-dimensional tree),构建过程很快,甚至不用计算D维欧氏距离,而搜索速度高达O(D*log(N))。 不过当D维度过高,会产生所谓的”维度灾难“,最终效率会降低到与暴力法一样。 因此通常D>20以后,最好使用更高效率的Ball-Tree,其时间复杂度为O(D*log(N))。 人们经过长期的实践发现KNN算法虽然简单,但能处理大规模的数据分类,尤其适用于样本分类边界不规则的情况。最重要的是该算法是很多高级机器学习算法的基础。 当然,KNN算法也存在一切问题。比如如果训练数据大部分都属于某一类,投票算法就有很大问题了。这时候就需要考虑设计每个投票者票的权重了。 2、测试数据 测试数据的格式仍然和前面使用的身高体重数据一致。不过数据增加了一些: 3、Python代码 scikit-learn提供了优秀的KNN算法支持。使用Python代码如下: # -*- coding: utf-8 -*- import numpy as np from sklearn import neighbors from sklearn.metrics import precision_recall_curve from sklearn.metrics import classification_report from sklearn.cross_validation import train_test_split import matplotlib.pyplot as plt ''''' 数据读入 ''' data = [] labels = [] with open("data1.txt") as ifile: for line in ifile: tokens = line.strip().split(' ') data.append([float(tk) for tk in tokens[:-1]]) labels.append(tokens[-1]) x = np.array(data) labels = np.array(labels) y = np.zeros(labels.shape) ''''' 标签转换为0/1 ''' y[labels=='fat']=1 ''''' 拆分训练数据与测试数据 ''' x_train,x_test,y_train,y_test = train_test_split(x,y,test_size = 0.2) ''''' 创建网格以方便绘制 ''' h = .01 x_min,x_max = x[:,0].min() - 0.1,x[:,0].max() + 0.1 y_min,y_max = x[:,1].min() - 1,1].max() + 1 xx,yy = np.meshgrid(np.arange(x_min,x_max,h),np.arange(y_min,y_max,h)) ''''' 训练KNN分类器 ''' clf = neighbors.KNeighborsClassifier(algorithm='kd_tree') clf.fit(x_train,y_train) '''''测试结果的打印''' answer = clf.predict(x) print(x) print(answer) print(y) print(np.mean( answer == y)) '''''准确率与召回率''' precision,recall,thresholds = precision_recall_curve(y_train,clf.predict(x_train)) answer = clf.predict_proba(x)[:,1] print(classification_report(y,answer,target_names = ['thin','fat'])) ''''' 将整个测试空间的分类结果用不同颜色区分开''' answer = clf.predict_proba(np.c_[xx.ravel(),yy.ravel()])[:,1] z = answer.reshape(xx.shape) plt.contourf(xx,yy,z,cmap=plt.cm.Paired,alpha=0.8) ''''' 绘制训练样本 ''' plt.scatter(x_train[:,0],x_train[:,1],c=y_train,cmap=plt.cm.Paired) plt.xlabel(u'身高') plt.ylabel(u'体重') plt.show() 4、结果分析 其输出结果如下: KNN分类器在众多分类算法中属于最简单的之一,需要注意的地方不多。有这几点要说明: 以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持编程小技巧。 (编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |