
Experiment: Penguin Classification with a Decision Tree (Python)

Yarhanry

Principle

The core idea of a decision tree is to partition the data with a tree structure, which mirrors the instinctive way humans break a problem into successive questions.
Advantages:
1. Highly interpretable: the model yields human-readable decision rules.
2. It reveals the relative importance of features.
3. Low computational complexity.
Disadvantages:
1. Prone to overfitting, so pruning is needed (see the sketch after this list).
2. It cannot natively exploit continuous features; split points have to be chosen for them.
3. High variance: a slight change in the data distribution can produce a completely different tree.
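Since pruning is the usual remedy for overfitting, here is a minimal sketch of the two standard knobs on scikit-learn's DecisionTreeClassifier (max_depth for pre-pruning, ccp_alpha for cost-complexity post-pruning); the parameter values are illustrative, not tuned:

from sklearn.tree import DecisionTreeClassifier

# Pre-pruning: cap growth before it happens.
pre_pruned = DecisionTreeClassifier(criterion='entropy', max_depth=4, min_samples_leaf=5)

# Post-pruning: grow fully, then collapse weak subtrees (larger alpha = heavier pruning).
post_pruned = DecisionTreeClassifier(criterion='entropy', ccp_alpha=0.01)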

Experimental Data

The penguin dataset contains 8 variables: 7 feature variables and 1 target variable, over 344 samples. The target is the penguin species, one of Adélie, Chinstrap, and Gentoo. The seven features are the island, culmen length, culmen depth, flipper length, body mass, sex, and age.

Preparation

1. The dataset contains NaN values, which we generally take to represent missing data introduced during collection or processing. Here we fill them with -1; other common strategies include median or mean imputation.
2. Every feature must be numeric (integer or real). The Species column (and likewise Island and Sex) is categorical, so each category must be mapped to a number before the steps below.
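A minimal pandas sketch of these two steps (the file path and mapping dicts are illustrative; explicit dicts make the codes independent of row order, unlike the unique()-based transition function used later):

import pandas as pd

df = pd.read_csv('penguins.csv')  # illustrative path
df = df.fillna(-1)                # step 1: fill missing values with -1
# step 2: map categorical columns to numeric codes
df['Sex'] = df['Sex'].map({'MALE': 0, 'FEMALE': 1}).fillna(-1)
df['Island'] = df['Island'].map({'Torgersen': 0, 'Biscoe': 1, 'Dream': 2})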

Requirements

1. Report the model's classification accuracy and the confusion matrix on the test set.
2. Visualize the decision tree.

Procedure

Data layout

For the scikit-learn implementation the file is saved as shown in the figure, with Species read as the first column.
[Figure: CSV with Species as the first column]
For the hand-written implementation the file is saved as shown in the figure, with Species read as the last column.
[Figure: CSV with Species as the last column]

Code

Type conversion

def transition(x):  # map string categories to numeric codes
    # Species: 0 = Adelie Penguin (Pygoscelis adeliae)
    #          1 = Gentoo penguin (Pygoscelis papua)
    #          2 = Chinstrap penguin (Pygoscelis antarctica)
    # Island:  0 = Torgersen, 1 = Biscoe, 2 = Dream
    # Sex:     0 = MALE, 1 = FEMALE
    # NOTE: these codes assume the first-appearance order returned by
    # unique(); an explicit dict mapping would be independent of row order.
    if x == data['Species'].unique()[0]:
        return 0
    if x == data['Species'].unique()[1]:
        return 1
    if x == data['Species'].unique()[2]:
        return 2
    if x == data['Island'].unique()[0]:
        return 0
    if x == data['Island'].unique()[1]:
        return 1
    if x == data['Island'].unique()[2]:
        return 2
    if x == data['Sex'].unique()[0]:
        return 0
    if x == data['Sex'].unique()[1]:
        return 1
    if x == data['Sex'].unique()[2]:  # the third distinct value (the -1 fill)
        return -1.0

Visualizing the decision tree

def getNumLeafs(myTree):  # count the leaf nodes of the tree
    numLeafs = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]) is dict:  # an inner node: recurse
            numLeafs += getNumLeafs(secondDict[key])
        else:  # a leaf
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):  # depth (number of levels) of the tree
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]) is dict:
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # nodeTxt: node label; centerPt: text position; parentPt: arrow start; nodeType: box style
    arrow_args = dict(arrowstyle="<-")  # arrow format
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',  # draw the node
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):
    # label the edge between parent and child; cntrPt/parentPt locate the midpoint
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid,
                        yMid,
                        txtString,
                        va="center",
                        ha="center",
                        rotation=30)


def plotTree(myTree, parentPt, nodeTxt):  # parentPt: parent coordinate; nodeTxt: edge label
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # style for inner nodes
    leafNode = dict(boxstyle="round4", fc="0.8")  # style for leaf nodes
    numLeafs = getNumLeafs(myTree)
    firstStr = next(iter(myTree))
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)  # center of this subtree
    plotMidText(cntrPt, parentPt, nodeTxt)  # label the incoming edge
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt,
                     leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
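createPlot is the entry point: it stores the tree's total width and depth as attributes on plotTree, then starts the recursion. One-line usage, with myTree coming from createTree in the driver code below:

createPlot(myTree)  # opens a matplotlib window showing the annotated tree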

Creating the dataset

def createDataSet():  # build the training list from the DataFrame
    global data
    data = data[[
        'Island',
        'Culmen Length (mm)',
        'Culmen Depth (mm)',
        'Flipper Length (mm)',
        'Body Mass (g)',
        'Sex',
        'Age',
        'Species',
    ]]
    data = data.fillna(-1)  # fill missing values with -1
    data['Species'] = data['Species'].apply(transition)  # map categories to numeric codes
    data['Island'] = data['Island'].apply(transition)
    data['Sex'] = data['Sex'].apply(transition)
    dataSet = []  # all penguins, one list per row
    for i in range(len(data)):  # 344 rows in this dataset
        dataSet.append(list(data.iloc[i, :]))
    labels = [
        'Island', 'Culmen Length (mm)', 'Culmen Depth (mm)',
        'Flipper Length (mm)', 'Body Mass (g)', 'Sex', 'Age'
    ]  # feature labels
    return dataSet, labels  # return the dataset and the feature labels

Choosing the best feature to split on

For a continuous-valued feature we must pick a split point (values below it go to one branch, the rest to the other) that minimizes the conditional entropy of the split: sort the distinct values, take the midpoint of every adjacent pair as a candidate split, compute the entropy each candidate produces, and keep the candidate with the smallest entropy.
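For reference, the quantities the code computes are the Shannon entropy of a label distribution and the ID3-style information gain of a split; these match calcEnt and chooseBestFeatureToSplit below:

H(D) = -\sum_k p_k \log_2 p_k

\mathrm{Gain}(D, a) = H(D) - \sum_v (|D^v| / |D|) \, H(D^v)

For a continuous feature with sorted distinct values v_1 < v_2 < ... < v_m, the candidate split points are the midpoints t_j = (v_j + v_{j+1}) / 2.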

def chooseBestFeatureToSplit(dataSet, labelProperty):  # pick the best feature (and split point)
    numFeatures = len(labelProperty)
    baseEntropy = calcEnt(dataSet)  # entropy of the whole dataset
    bestInfoGain = 0.0  # best information gain so far
    bestFeature = -1  # index of the best feature
    bestPartValue = None  # best split value for a continuous feature
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)  # distinct values of this feature
        newEntropy = 0.0
        bestPartValuei = None
        if labelProperty[i] == 0:  # discrete feature
            for value in uniqueVals:  # conditional entropy over the value partition
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                newEntropy += prob * calcEnt(subDataSet)
        else:  # continuous feature
            sortedUniqueVals = list(uniqueVals)
            sortedUniqueVals.sort()
            minEntropy = float('inf')  # track the split point with minimum entropy
            for j in range(len(sortedUniqueVals) - 1):  # candidate split points
                partValue = (float(sortedUniqueVals[j]) +
                             float(sortedUniqueVals[j + 1])) / 2
                # for each candidate split, compute the conditional entropy
                dataSetLeft = splitDataSet_c(dataSet, i, partValue, 'L')
                dataSetRight = splitDataSet_c(dataSet, i, partValue, 'R')
                probLeft = len(dataSetLeft) / float(len(dataSet))
                probRight = len(dataSetRight) / float(len(dataSet))
                Entropy = probLeft * calcEnt(
                    dataSetLeft) + probRight * calcEnt(dataSetRight)
                if Entropy < minEntropy:  # keep the smallest entropy
                    minEntropy = Entropy
                    bestPartValuei = partValue
            newEntropy = minEntropy
        infoGain = baseEntropy - newEntropy  # information gain
        if infoGain > bestInfoGain:  # keep the feature with the largest gain
            bestInfoGain = infoGain
            bestFeature = i
            bestPartValue = bestPartValuei
    return bestFeature, bestPartValue
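Usage against the dataset built above; the 0/1 flags mark Island and Sex as discrete and the rest as continuous, matching the driver code at the end:

dataSet, labels = createDataSet()
bestFeat, bestPartValue = chooseBestFeatureToSplit(dataSet, [0, 1, 1, 1, 1, 0, 1])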

Building the decision tree

# Post-pruning: replace certain subtrees with a leaf node whose class is
# decided by majority vote over the samples in that subtree.
# createTree(samples, feature labels, feature types (0 = discrete, 1 = continuous))
def createTree(dataSet, labels, labelProperty):
    classList = [example[-1] for example in dataSet]  # class labels
    if classList.count(classList[0]) == len(classList):  # only one class left: return it
        return classList[0]
    if len(dataSet[0]) == 1:  # all features used up: return the majority class
        return majorityCnt(classList)
    bestFeat, bestPartValue = chooseBestFeatureToSplit(
        dataSet, labelProperty)  # pick the best feature to split on
    if bestFeat == -1:  # no feature gives any gain: return the majority class
        return majorityCnt(classList)
    if labelProperty[bestFeat] == 0:  # discrete feature
        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel: {}}
        labelsNew = copy.copy(labels)
        labelPropertyNew = copy.copy(labelProperty)
        del labelsNew[bestFeat]  # a chosen discrete feature is not reused
        del labelPropertyNew[bestFeat]
        featValues = [example[bestFeat] for example in dataSet]
        uniqueValue = set(featValues)  # all values of this feature
        for value in uniqueValue:  # recurse on each value's subset
            subLabels = labelsNew[:]
            subLabelProperty = labelPropertyNew[:]
            myTree[bestFeatLabel][value] = createTree(
                splitDataSet(dataSet, bestFeat, value), subLabels,
                subLabelProperty)
    else:  # continuous feature: build a left and a right subtree
        bestFeatLabel = labels[bestFeat] + '<' + str(bestPartValue)
        myTree = {bestFeatLabel: {}}
        subLabels = labels[:]
        subLabelProperty = labelProperty[:]
        # left subtree (values below the split)
        valueLeft = 'Yes'
        myTree[bestFeatLabel][valueLeft] = createTree(
            splitDataSet_c(dataSet, bestFeat, bestPartValue, 'L'), subLabels,
            subLabelProperty)
        # right subtree (values above the split)
        valueRight = 'No'
        myTree[bestFeatLabel][valueRight] = createTree(
            splitDataSet_c(dataSet, bestFeat, bestPartValue, 'R'), subLabels,
            subLabelProperty)
    return myTree
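The returned tree is a nested dict keyed by feature names, with 'feature<split' strings for continuous splits and 'Yes'/'No' branches. A made-up fragment, purely to illustrate the shape (the real keys and leaves depend on the training data):

# hypothetical structure only
myTree = {
    'Flipper Length (mm)<206.5': {
        'Yes': {'Culmen Depth (mm)<16.35': {'Yes': 2, 'No': 0}},
        'No': 1,
    }
}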

Making predictions

def classify(inputTree, featLabels, featLabelProperties, testVec):
    firstStr = list(inputTree.keys())[0]  # root node
    firstLabel = firstStr
    lessIndex = str(firstStr).find('<')
    if lessIndex > -1:  # a continuous feature: strip the '<split' suffix
        firstLabel = str(firstStr)[:lessIndex]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstLabel)  # feature tested at the root
    classLabel = None
    for key in secondDict.keys():  # walk each branch
        if featLabelProperties[featIndex] == 0:  # discrete feature
            if testVec[featIndex] == key:
                if type(secondDict[key]).__name__ == 'dict':
                    classLabel = classify(secondDict[key], featLabels,
                                          featLabelProperties, testVec)
                else:
                    classLabel = secondDict[key]
        else:  # continuous feature
            partValue = float(str(firstStr)[lessIndex + 1:])
            if testVec[featIndex] < partValue:  # go into the left subtree
                if type(secondDict['Yes']).__name__ == 'dict':
                    classLabel = classify(secondDict['Yes'], featLabels,
                                          featLabelProperties, testVec)
                else:
                    classLabel = secondDict['Yes']
            else:
                if type(secondDict['No']).__name__ == 'dict':
                    classLabel = classify(secondDict['No'], featLabels,
                                          featLabelProperties, testVec)
                else:
                    classLabel = secondDict['No']

    return classLabel
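A single prediction, assuming myTree, labels, and labelProperties from the driver code at the end; the feature vector here is hypothetical:

# [Island, Culmen Length, Culmen Depth, Flipper Length, Body Mass, Sex, Age]
sample = [0, 39.1, 18.7, 181.0, 3750.0, 0, 2007]
pred = classify(myTree, labels, labelProperties, sample)  # a species code 0/1/2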

Full implementation with scikit-learn

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn import tree
import graphviz


def transition(x):  # map string categories to numeric codes
    # Species: 0 = Adelie Penguin (Pygoscelis adeliae)
    #          1 = Gentoo penguin (Pygoscelis papua)
    #          2 = Chinstrap penguin (Pygoscelis antarctica)
    # Island:  0 = Torgersen, 1 = Biscoe, 2 = Dream
    # Sex:     0 = MALE, 1 = FEMALE
    # NOTE: these codes assume the first-appearance order returned by unique()
    if x == data['Species'].unique()[0]:
        return 0
    if x == data['Species'].unique()[1]:
        return 1
    if x == data['Species'].unique()[2]:
        return 2
    if x == data['Island'].unique()[0]:
        return 0
    if x == data['Island'].unique()[1]:
        return 1
    if x == data['Island'].unique()[2]:
        return 2
    if x == data['Sex'].unique()[0]:
        return 0
    if x == data['Sex'].unique()[1]:
        return 1
    if x == data['Sex'].unique()[2]:  # the third distinct value (the -1 fill)
        return -1


data = pd.read_csv('D:/1_penguin.csv')
data = data[[
    'Species', 'Island', 'Culmen Length (mm)', 'Culmen Depth (mm)',
    'Flipper Length (mm)', 'Body Mass (g)', 'Sex', 'Age'
]]
data = data.fillna(-1)  # fill missing values with -1
data['Species'] = data['Species'].apply(transition)  # map categories to numeric codes
data['Island'] = data['Island'].apply(transition)
data['Sex'] = data['Sex'].apply(transition)

goal = data[data['Species'].isin([0, 1, 2])][['Species']]  # three-class target
feature = data[data['Species'].isin([0, 1, 2])][[
    'Island', 'Culmen Length (mm)', 'Culmen Depth (mm)', 'Flipper Length (mm)',
    'Body Mass (g)', 'Sex', 'Age'
]]
x_train, x_test, y_train, y_test = train_test_split(feature,
                                                    goal,
                                                    test_size=0.2,
                                                    random_state=100)
Tree = DecisionTreeClassifier(
    criterion='entropy')  # a decision tree using information gain (entropy) for splits
Tree.fit(x_train, y_train)  # train the model
data2 = tree.export_graphviz(Tree, out_file=None)  # export the tree in Graphviz dot format
graph = graphviz.Source(data2)
graph.render("penguins")  # renders the tree (penguins.pdf by default) to the working directory

train_pre = Tree.predict(x_train)  # predicted class labels
test_pre = Tree.predict(x_test)
# evaluate with accuracy: the fraction of samples predicted correctly
print('The accuracy of the decision tree on the training set is:',
      accuracy_score(y_train, train_pre))
print('The accuracy of the decision tree on the test set is:',
      accuracy_score(y_test, test_pre))

# confusion matrix on the test set (true labels first, predictions second,
# so rows are true classes and columns predicted classes, matching the axis labels)
confu_matrix = confusion_matrix(y_test, test_pre)
plt.figure(figsize=(8, 6))
sns.heatmap(confu_matrix, annot=True, cmap='Blues')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()
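If installing the Graphviz binaries is inconvenient, sklearn.tree.plot_tree draws the same tree with matplotlib alone (plt and tree are already imported above). A minimal sketch; the class-name order assumes the 0/1/2 coding from transition:

plt.figure(figsize=(12, 8))
tree.plot_tree(Tree,
               feature_names=list(feature.columns),
               class_names=['Adelie', 'Gentoo', 'Chinstrap'],
               filled=True)
plt.show()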

Hand-written implementation (handling both continuous and discrete values)

from math import log
import pandas as pd
import numpy as np
import operator
import seaborn as sns
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import copy


def getNumLeafs(myTree):  # count the leaf nodes of the tree
    numLeafs = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]) is dict:  # an inner node: recurse
            numLeafs += getNumLeafs(secondDict[key])
        else:  # a leaf
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):  # depth (number of levels) of the tree
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]) is dict:
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # nodeTxt: node label; centerPt: text position; parentPt: arrow start; nodeType: box style
    arrow_args = dict(arrowstyle="<-")  # arrow format
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',  # draw the node
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):
    # label the edge between parent and child; cntrPt/parentPt locate the midpoint
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid,
                        yMid,
                        txtString,
                        va="center",
                        ha="center",
                        rotation=30)


def plotTree(myTree, parentPt, nodeTxt):  # parentPt: parent coordinate; nodeTxt: edge label
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # style for inner nodes
    leafNode = dict(boxstyle="round4", fc="0.8")  # style for leaf nodes
    numLeafs = getNumLeafs(myTree)
    firstStr = next(iter(myTree))
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)  # center of this subtree
    plotMidText(cntrPt, parentPt, nodeTxt)  # label the incoming edge
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt,
                     leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


def transition(x):  # map string categories to numeric codes
    # Species: 0 = Adelie Penguin (Pygoscelis adeliae)
    #          1 = Gentoo penguin (Pygoscelis papua)
    #          2 = Chinstrap penguin (Pygoscelis antarctica)
    # Island:  0 = Torgersen, 1 = Biscoe, 2 = Dream
    # Sex:     0 = MALE, 1 = FEMALE
    # NOTE: these codes assume the first-appearance order returned by unique()
    if x == data['Species'].unique()[0]:
        return 0
    if x == data['Species'].unique()[1]:
        return 1
    if x == data['Species'].unique()[2]:
        return 2
    if x == data['Island'].unique()[0]:
        return 0
    if x == data['Island'].unique()[1]:
        return 1
    if x == data['Island'].unique()[2]:
        return 2
    if x == data['Sex'].unique()[0]:
        return 0
    if x == data['Sex'].unique()[1]:
        return 1
    if x == data['Sex'].unique()[2]:  # the third distinct value (the -1 fill)
        return -1.0


def createDataSet():  # build the training list from the DataFrame
    global data
    data = data[[
        'Island',
        'Culmen Length (mm)',
        'Culmen Depth (mm)',
        'Flipper Length (mm)',
        'Body Mass (g)',
        'Sex',
        'Age',
        'Species',
    ]]
    data = data.fillna(-1)  # fill missing values with -1
    data['Species'] = data['Species'].apply(transition)  # map categories to numeric codes
    data['Island'] = data['Island'].apply(transition)
    data['Sex'] = data['Sex'].apply(transition)
    dataSet = []  # all penguins, one list per row
    for i in range(len(data)):  # 344 rows in this dataset
        dataSet.append(list(data.iloc[i, :]))
    labels = [
        'Island', 'Culmen Length (mm)', 'Culmen Depth (mm)',
        'Flipper Length (mm)', 'Body Mass (g)', 'Sex', 'Age'
    ]  # feature labels
    return dataSet, labels  # return the dataset and the feature labels


def calcEnt(dataSet):  # Shannon entropy of the dataset's class labels
    numEntries = len(dataSet)
    labelCounts = {}  # occurrences of each class label
    for featVec in dataSet:  # tally the label of each sample (last column)
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    inforEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries  # probability of this label
        inforEnt -= prob * log(prob, 2)
    return inforEnt  # entropy of the dataset
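# Quick sanity check (safe to delete): with two classes split 50/50 the
# entropy is exactly 1 bit, e.g. calcEnt([[0, 'a'], [1, 'b']]) == 1.0,
# since each row's last element is its class label.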


def splitDataSet(dataSet, axis, value):
    # keep the rows whose feature `axis` equals `value`, dropping that column
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]  # remove the axis-th feature
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)  # bug fix: append the reduced row, not the full one
    return retDataSet  # the matching subset with the chosen column removed


def splitDataSet_c(dataSet,
                   axis,
                   value,
                   LorR='L'):  # LorR: take rows left of (<) or right of (>) the split value
    retDataSet = []
    if LorR == 'L':
        for featVec in dataSet:
            if float(featVec[axis]) < value:
                retDataSet.append(featVec)
    else:
        for featVec in dataSet:
            if float(featVec[axis]) > value:
                retDataSet.append(featVec)
    return retDataSet
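# Note: both branches compare strictly, so a value exactly equal to the split
# point lands in neither subset; this is harmless here because split points
# are midpoints of two distinct sorted values and therefore never coincide
# with a training value, and classify() sends equality down the 'No' branch.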


def chooseBestFeatureToSplit(dataSet, labelProperty):  # pick the best feature (and split point)
    numFeatures = len(labelProperty)
    baseEntropy = calcEnt(dataSet)  # entropy of the whole dataset
    bestInfoGain = 0.0  # best information gain so far
    bestFeature = -1  # index of the best feature
    bestPartValue = None  # best split value for a continuous feature
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)  # distinct values of this feature
        newEntropy = 0.0
        bestPartValuei = None
        if labelProperty[i] == 0:  # discrete feature
            for value in uniqueVals:  # conditional entropy over the value partition
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                newEntropy += prob * calcEnt(subDataSet)
        else:  # continuous feature
            sortedUniqueVals = list(uniqueVals)
            sortedUniqueVals.sort()
            minEntropy = float('inf')  # track the split point with minimum entropy
            for j in range(len(sortedUniqueVals) - 1):  # candidate split points
                partValue = (float(sortedUniqueVals[j]) +
                             float(sortedUniqueVals[j + 1])) / 2
                # for each candidate split, compute the conditional entropy
                dataSetLeft = splitDataSet_c(dataSet, i, partValue, 'L')
                dataSetRight = splitDataSet_c(dataSet, i, partValue, 'R')
                probLeft = len(dataSetLeft) / float(len(dataSet))
                probRight = len(dataSetRight) / float(len(dataSet))
                Entropy = probLeft * calcEnt(
                    dataSetLeft) + probRight * calcEnt(dataSetRight)
                if Entropy < minEntropy:  # keep the smallest entropy
                    minEntropy = Entropy
                    bestPartValuei = partValue
            newEntropy = minEntropy
        infoGain = baseEntropy - newEntropy  # information gain
        if infoGain > bestInfoGain:  # keep the feature with the largest gain
            bestInfoGain = infoGain
            bestFeature = i
            bestPartValue = bestPartValuei
    return bestFeature, bestPartValue


def majorityCnt(classList):  # the most common class label
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1),
                              reverse=True)
    return sortedClassCount[0][0]


# Post-pruning: replace certain subtrees with a leaf node whose class is
# decided by majority vote over the samples in that subtree.
# createTree(samples, feature labels, feature types (0 = discrete, 1 = continuous))
def createTree(dataSet, labels, labelProperty):
    classList = [example[-1] for example in dataSet]  # class labels
    if classList.count(classList[0]) == len(classList):  # only one class left: return it
        return classList[0]
    if len(dataSet[0]) == 1:  # all features used up: return the majority class
        return majorityCnt(classList)
    bestFeat, bestPartValue = chooseBestFeatureToSplit(
        dataSet, labelProperty)  # pick the best feature to split on
    if bestFeat == -1:  # no feature gives any gain: return the majority class
        return majorityCnt(classList)
    if labelProperty[bestFeat] == 0:  # discrete feature
        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel: {}}
        labelsNew = copy.copy(labels)
        labelPropertyNew = copy.copy(labelProperty)
        del labelsNew[bestFeat]  # a chosen discrete feature is not reused
        del labelPropertyNew[bestFeat]
        featValues = [example[bestFeat] for example in dataSet]
        uniqueValue = set(featValues)  # all values of this feature
        for value in uniqueValue:  # recurse on each value's subset
            subLabels = labelsNew[:]
            subLabelProperty = labelPropertyNew[:]
            myTree[bestFeatLabel][value] = createTree(
                splitDataSet(dataSet, bestFeat, value), subLabels,
                subLabelProperty)
    else:  # continuous feature: build a left and a right subtree
        bestFeatLabel = labels[bestFeat] + '<' + str(bestPartValue)
        myTree = {bestFeatLabel: {}}
        subLabels = labels[:]
        subLabelProperty = labelProperty[:]
        # left subtree (values below the split)
        valueLeft = 'Yes'
        myTree[bestFeatLabel][valueLeft] = createTree(
            splitDataSet_c(dataSet, bestFeat, bestPartValue, 'L'), subLabels,
            subLabelProperty)
        # right subtree (values above the split)
        valueRight = 'No'
        myTree[bestFeatLabel][valueRight] = createTree(
            splitDataSet_c(dataSet, bestFeat, bestPartValue, 'R'), subLabels,
            subLabelProperty)
    return myTree


def classify(inputTree, featLabels, featLabelProperties, testVec):
    firstStr = list(inputTree.keys())[0]  # root node
    firstLabel = firstStr
    lessIndex = str(firstStr).find('<')
    if lessIndex > -1:  # a continuous feature: strip the '<split' suffix
        firstLabel = str(firstStr)[:lessIndex]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstLabel)  # feature tested at the root
    classLabel = None
    for key in secondDict.keys():  # walk each branch
        if featLabelProperties[featIndex] == 0:  # discrete feature
            if testVec[featIndex] == key:
                if type(secondDict[key]).__name__ == 'dict':
                    classLabel = classify(secondDict[key], featLabels,
                                          featLabelProperties, testVec)
                else:
                    classLabel = secondDict[key]
        else:  # continuous feature
            partValue = float(str(firstStr)[lessIndex + 1:])
            if testVec[featIndex] < partValue:  # go into the left subtree
                if type(secondDict['Yes']).__name__ == 'dict':
                    classLabel = classify(secondDict['Yes'], featLabels,
                                          featLabelProperties, testVec)
                else:
                    classLabel = secondDict['Yes']
            else:
                if type(secondDict['No']).__name__ == 'dict':
                    classLabel = classify(secondDict['No'], featLabels,
                                          featLabelProperties, testVec)
                else:
                    classLabel = secondDict['No']

    return classLabel


data = pd.read_csv('D:/企鹅数据.csv')  # do not pass index_col=0, or the columns can no longer be read by name; the leading row-number column does not affect access to the data
feature_name = [
    'Island', 'Culmen Length (mm)', 'Culmen Depth (mm)', 'Flipper Length (mm)',
    'Body Mass (g)', 'Sex', 'Age'
]
dataSet, labels = createDataSet()
labelProperties = [0, 1, 1, 1, 1, 0, 1]  # feature types: 0 = discrete, 1 = continuous
myTree = createTree(dataSet, labels, labelProperties)
feature = data[[
    'Island', 'Culmen Length (mm)', 'Culmen Depth (mm)', 'Flipper Length (mm)',
    'Body Mass (g)', 'Sex', 'Age'
]]
goal = data['Species']
x_train, x_test, y_train, y_test = train_test_split(feature,
                                                    goal,
                                                    test_size=0.2)
x_test = np.array(x_test)
x_test = x_test.tolist()
test_pre = []  # predicted class labels
for i in range(len(x_test)):
    result = classify(myTree, labels, labelProperties, x_test[i])
    test_pre.append(result)
print("预测准确率为:", accuracy_score(y_test, test_pre))
confu_matrix = confusion_matrix(test_pre, y_test)  # 查看混淆矩阵
plt.figure(figsize=(8, 6))
sns.heatmap(confu_matrix, annot=True, cmap='Blues')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()
createPlot(myTree)  # visualize the decision tree
print(myTree)

Results

scikit-learn implementation

[result screenshots]

Hand-written implementation

Because the data contains both discrete and continuous features, this method can handle both types.
[result screenshot]
Results after post-pruning:
[result screenshots]
