• 首页 首页 icon
  • 工具库 工具库 icon
    • IP查询 IP查询 icon
  • 内容库 内容库 icon
    • 快讯库 快讯库 icon
    • 精品库 精品库 icon
    • 问答库 问答库 icon
  • 更多 更多 icon
    • 服务条款 服务条款 icon

sklearn决策树Decision Trees模型

武飞扬头像
qq_27390023
帮助1

决策树(DT)是一种用于分类和回归的非参数化监督学习方法。其目的是创建一个模型,通过学习从数据特征推断出的简单决策规则来预测目标变量的值。一棵树可以被看作是一个分片的常数近似。

决策树的一些优点:

  • 易于理解和解释。树可以被视觉化。
  • 需要很少的数据准备。其他技术通常需要数据规范化,需要创建虚拟变量并删除空白值。但是请注意,这个模块不支持缺失值。
  • 使用树的成本(即预测数据)与用于训练树的数据点的数量成对数关系。
  • 能够处理数字和分类数据。但是scikit-learn的实现暂时不支持分类变量。其他技术通常专门用于分析只有一种类型变量的数据集。更多信息请参见算法。
  • 能够处理多输出问题。
  • 使用白盒模型。如果一个给定的情况在模型中是可以观察到的,那么对该情况的解释就很容易用布尔逻辑来解释。相比之下,在一个黑箱模型中(如在人工神经网络中),结果可能更难解释。
  • 有可能使用统计测试来验证一个模型。这使得核算模型的可靠性成为可能。
  • 即使其假设在某种程度上违反了数据产生的真实模型,也能表现良好。

决策树的缺点:

  • 决策树学习者可以创建过于复杂的树,不能很好地概括数据。这就是所谓的过度拟合。诸如修剪、设置叶子节点所需的最小样本数或设置树的最大深度等机制对于避免这一问题是必要的。
  • 决策树可能是不稳定的,因为数据的微小变化可能会导致生成一个完全不同的树。这个问题可以通过在一个集合体中使用决策树而得到缓解。
  • 决策树的预测既不是平滑的,也不是连续的,而是如上图所示的片状常数近似值。因此,它们不善于推断。
  • 众所周知,学习最优决策树的问题在几个方面的最优性下是NP-complete的,甚至对于简单的概念也是如此。因此,实用的决策树学习算法是基于启发式算法,如贪婪算法,在每个节点上做出局部最优决策。这种算法不能保证返回全局最优的决策树。这一点可以通过在集合学习器中训练多棵树来缓解,在集合学习器中,特征和样本都是随机抽样的,并有替换。
  • 有一些概念很难学习,因为决策树不容易表达,比如XOR、奇偶性或多路复用器问题。
  • 如果某些类占主导地位,决策树学习者会创建有偏见的树。因此,建议在用决策树拟合之前,平衡数据集。

1. 分类问题

  1.  
    ### 1. classification
  2.  
    from sklearn.datasets import load_iris
  3.  
    from sklearn import tree
  4.  
    from sklearn.tree import export_text
  5.  
    from sklearn.model_selection import train_test_split
  6.  
     
  7.  
    iris = load_iris()
  8.  
    X, y = iris.data, iris.target
  9.  
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
  10.  
    clf = tree.DecisionTreeClassifier(random_state=0, max_depth=2)
  11.  
    clf = clf.fit(X_train, y_train)
  12.  
    r = export_text(clf, feature_names=iris['feature_names'])
  13.  
    print(r)
  14.  
    print(clf.predict(X_test))
  15.  
    #print(clf.predict_proba(X_test)) # probability of each class
  16.  
    print(clf.score(X_test,y_test))
  17.  
     
  18.  
    # plot
  19.  
    import matplotlib.pyplot as plt
  20.  
    from sklearn.tree import plot_tree
  21.  
    plt.figure()
  22.  
    clf = tree.DecisionTreeClassifier().fit(iris.data, iris.target)
  23.  
    plot_tree(clf, filled=True)
  24.  
    plt.title("Decision tree trained on all the iris features")
  25.  
    plt.show()
学新通

2. 回归问题

  1.  
    ### 2. regression
  2.  
    # Import the necessary modules and libraries
  3.  
    import numpy as np
  4.  
    from sklearn.tree import DecisionTreeRegressor
  5.  
    import matplotlib.pyplot as plt
  6.  
     
  7.  
    # Create a random dataset
  8.  
    rng = np.random.RandomState(1)
  9.  
    X = np.sort(5 * rng.rand(80, 1), axis=0)
  10.  
    y = np.sin(X).ravel()
  11.  
    y[::5] = 3 * (0.5 - rng.rand(16))
  12.  
     
  13.  
    # Fit regression model
  14.  
    regr_1 = DecisionTreeRegressor(max_depth=2)
  15.  
    regr_2 = DecisionTreeRegressor(max_depth=5)
  16.  
    regr_1.fit(X, y)
  17.  
    regr_2.fit(X, y)
  18.  
     
  19.  
    # Predict
  20.  
    X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
  21.  
    y_1 = regr_1.predict(X_test)
  22.  
    y_2 = regr_2.predict(X_test)
  23.  
     
  24.  
    # Plot the results
  25.  
    plt.figure()
  26.  
    plt.scatter(X, y, s=20, edgecolor="black", c="darkorange", label="data")
  27.  
    plt.plot(X_test, y_1, color="cornflowerblue", label="max_depth=2", linewidth=2)
  28.  
    plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2)
  29.  
    plt.xlabel("data")
  30.  
    plt.ylabel("target")
  31.  
    plt.title("Decision Tree Regression")
  32.  
    plt.legend()
  33.  
    plt.show()
学新通

参考:

https://scikit-learn.org/stable/modules/tree.html

这篇好文章是转载于:学新通技术网

  • 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
  • 本站站名: 学新通技术网
  • 本文地址: /boutique/detail/tanhfjciba
系列文章
更多 icon
同类精品
更多 icon
继续加载