欢迎您访问 最编程 本站为您分享编程语言代码,编程技术文章!
您现在的位置是: 首页

用Python详细解析k近邻(kNN)算法的实战指南

最编程 2024-07-28 08:57:40
...

jupyter notebook: kNN

《机器学习实战》kNN算法源码详解。

kNN algorithm

kNN算法步骤:

  1. 计算训练集中的每个点到当前点的距离
  2. 按照距离递增排序
  3. 选取距离当前点最近的k个点
  4. 计算这k个点所在类别的出现频率
  5. 返回出现频率最高的类别,作为当前点的类别

源码

from numpy import *
import operator as op

def createDataSet():
    group = array([[1.0, 1.1],      #numpy.array
                  [1.0, 1.0],
                  [0  , 0],
                  [0  , 0.1]])
    labels = ['A', 'A', 'B', 'B']   #list
    return group,labels

def classify0(inX, dataSet, labels, k):             #这里用到好多numpy的方法;numpy是向量化的、广播的
    dataSetSize = dataSet.shape[0]                  #计算训练集样本数
    
    #计算距离
    diffMat = tile(inX, (dataSetSize,1)) - dataSet  #tile将当前点的特诊向量扩展为与训练集特征矩阵相同的维度
    sqDiffMat = diffMat**2                          #求平方
    sqDistances = sqDiffMat.sum(axis=1)             #求矩阵每一行的和
    distances = sqDistances**0.5                    #求平方根
    
    #排序
    sortedDistIndicies = distances.argsort()        #返回排序后的样本的序号
    
    #选取最近的k个点
    classCount = {}                                 #空的dictionary
    for i in range(k):                              #循环范围: i=0~(k-1)
        voteIlabel = labels[sortedDistIndicies[i]]  #获取排名第i个的样本的标签
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1;                     #该类别的样本数加1

    #按照各类别的样本数量降序排序
    sortedClassCount = sorted(classCount.items(), key=op.itemgetter(1), reverse=True)  #即按照字典classCount的value降序排序

    #返回样本数量最多的类别(在距离最近的k个样本中)
    return sortedClassCount[0][0]  #上一步sorted返回类型为一个list,list的元素为tuple,tuple的第一个元素是类别,第二个元素为该类别对应的样本数

测试

细节解析

  • shape

shape[x] 计算x轴的长度,轴即一个维度
个人理解:
    看array的初始化格式,()里是一个list,list里还可以是list。例如,group的最外层list是[[],[],[],[]], 有四个元素,即shape[0]等于4;最外层即第一层的list的元素类型也是list,例如[1, 1.1],则第二层的list的元素个数为2,所以shape[1]等于2。
    以此类推,如果第二层list的元素也是list,则有shape[2]。
    总结,array是一层层嵌套的list。shape[0]即第一层list的元素个数,shape[1]即第二层list的元素个数,shape[D-1]为第D层list的元素个数。
    
请问,n*m维度的数组,shape[0]=?
答案:shape[0]=n, shape[1]=m。(不要搞混概念哦。n*m数组是2D的,即只有两层list,第一层n个元素,每个元素是list类型,第二层m个元素)
同理,n*m*k维度的数组是3D的,shape[2]=k。

示例:

  • tile

  • sum

  • dict

    • dict.get()

    • dict.items()

  • argsort()

  • sorted

  • operator.itemgetter()