AI嵌入式系统:算法优化与实现
上QQ阅读APP看书,第一时间看更新

3.1 高斯朴素贝叶斯分类器

3.1.1 原理概述

朴素贝叶斯算法是相对简单的机器学习算法,它使用贝叶斯公式,从多个统计独立的观测量计算某个随机变量Y的取不同值的可能性(注意,我们这里用大写字母表示随机变量,用小写字母表示它的具体取值)。下面通过一个例子对高斯朴素贝叶斯模型进行简要说明,关于这一模型的详细理论分析可以参考本章参考文献部分。

考虑鸢尾花的分类问题,这种花有三种类型,分别为Setosa、Versicolour和Virginica,可以用取值分别为0,1,2的变量Y表示。我们希望通过测量鸢尾花的花萼长度X来区分它的类别,这可以通过计算条件概率实现:

上述概率代表了测得一朵鸢尾花的花萼长度为x时,它属于类别y的可能性。上述概率可以用贝叶斯公式表示,即

其中概率表示对于类型是y的鸢尾花,花萼长度为x的可能性。我们通常用高斯分布来描述这个条件概率,即

即X是由Y的取值y决定的高斯随机变量,X的均值和方差分别为

上面的例子是从一个观测量X计算出鸢尾花属于不同类别的概率。如果有多个不同的观测量,就能够更精确地判别鸢尾花的类别。我们可以测量鸢尾花的花萼长度、花萼宽度、花瓣长度、花瓣宽度这4个属性的具体数值,分别用表示它们,我们进一步假设这4个属性相互独立(统计独立),于是可以得到从这些观测量计算Y的条件概率,即

上述观测量相互独立以及高斯分布的模型就是“高斯朴素贝叶斯模型”。

3.1.2 模型训练和推理

下面我们基于Python的机器学习软件包Scikit-Learn说明如何训练高斯朴素贝叶斯模型。这里不会涉及模型训练的数学解释,仅仅是介绍训练所使用的Python代码。

我们还是以鸢尾花卉分类问题为例。Fisher于1936年收集整理了三种鸢尾花的花萼长度、花萼宽度、花瓣长度、花瓣宽度的测量值,这些数据能够从Scikit-Learn中直接获得。数据包括了3个类别共150朵鸢尾花测量数据,每朵花的测量值包括4个数值,每个数值对应前面所给出的一个属性。Python程序通过Scikit-Learn库的API读取鸢尾花数据,具体代码如下:


from sklearn import datasets
iris = datasets.load_iris()

运行之后变量iris中就存储了150朵鸢尾花测量数据和对应的花的类型数据。通过下面的命令能够分别打印出对每一朵鸢尾花的测量结果。


print(iris.data)
print(iris.target)

iris.data是尺寸为150×4的矩阵,每一行对应一朵花的测量数据,iris.target是存放了150个整数元素的数组,其中元素取值0、1、2分别对应Setosa、Versicolor、Virginica这三种类型。

下面是iris.data和iris.target的数据内容片段:


iris.data:
[[6.4 2.9 4.3 1.3]
 [6.5 3.  5.5 1.8]
 [5.  2.3 3.3 1. ]
 [6.3 3.3 6.  2.5]
 [5.5 2.5 4.  1.3]
 [5.4 3.7 1.5 0.2]
           …
 [6.7 3.1 5.6 2.4]
 [4.9 3.6 1.4 0.1]]
iris.target:
[1 2 1 2 1 0 … 1 2]

下面的代码利用加载的iris数据进行训练,得到高斯朴素贝叶斯模型:


from sklearn.naive_bayes import GaussianNB
model = GaussianNB()  # 构建高斯朴素贝叶斯模型
model.fit(iris.data, iris.target)

高斯朴素贝叶斯模型参数存储在变量model内,其中高斯分布的方差存放在model.sigma_中,而高斯分布的均值存放在model.theta_中。

完成模型训练后,使用下面的代码实现模型的推理,即对类别未知的数据进行分类:


y_pred = model.predict(new_data)

其中new_data是存放需要分类的花的测量数据,每一行对应一朵花的4个测量值,程序中y_pred是列向量,它的元素对应了new_data中对应行的鸢尾花分类结果。

注意,在上述训练过程中,先验概率是从训练数据中统计得到的(用每种类别在训练数据集中出现的比例作为先验概率的估计值),提供的iris训练数据中三类花的数量相同,因此先验概率。如果需要使用其他先验概率,那么可以在构建模型的时候提供先验数据作为输入参数,即


model= GaussianNB(priors)

上面代码中priors是用户提供的三类花的先验概率数组。