• 首页 首页 icon
  • 工具库 工具库 icon
    • IP查询 IP查询 icon
  • 内容库 内容库 icon
    • 快讯库 快讯库 icon
    • 精品库 精品库 icon
    • 问答库 问答库 icon
  • 更多 更多 icon
    • 服务条款 服务条款 icon

sklearn的系统学习——随机森林调参含案例和完整python代码

武飞扬头像
weiAweiww
帮助4

目录

一、 调参核心问题

二、 随机森林调参方向

 三、随机森林调参方法

 1、绘制学习曲线

 2、网格搜索

四、 详细代码


       对于调参,首先需要明白调参的核心问题是什么,然后理清思路,再进行调参。调参并非是一件容易的事情,很多大牛靠的是多年积累的经验和清晰的处理思路,那对于我们而言,也应对调参思路和方向有一个认识,然后就是不断地尝试。

一、调参核心问题

1、调参的目的是什么?

2、模型在未知数据上的准确率受什么因素影响?

泛化误差:衡量模型在未知数据上的准确率(准确率越高,泛化误差越小),受模型复杂度的影响。

模型复杂度与准确率的关系,就像压力值与考试成绩的关系,压力越大或者没有压力成绩往往越低,只有压力适当时,成绩才会更高。同理,模型越复杂或越简单往往结果也会不尽人意,那我们的目标就清楚了,就是将模型不至于太复杂也不至于太简单。比如,当为模型增加复杂度时,准确率提升,泛化误差降低,那说明此时模型有些简单,反之,如果降低模型复杂度,反而准确率提升,那说明此时模型较为复杂,适当调整简单即可。

对于树模型或者树的集成模型,树的深度越深,枝叶越多,模型越复杂。往往树模型或者树的集成模型普遍较为复杂,我们需要做的就是降低复杂度,进而提升准确率。

二、 随机森林调参方向

       降低复杂度,对复杂度影响巨大的参数挑选出来,研究他们的单调性,然后专注调整那些最大限度能让复杂度降低的参数,对于那些不单调的参数或者反而让复杂度升高的参数,视情况而定,大多时候甚至可以退避。(表中从上往下,建议调参的程度逐渐减小)

学新通

 三、随机森林调参方法

 1、绘制学习曲线

       有些参数没有参照,很难说清楚范围,这种情况用学习曲线看趋势,从曲线跑出的结果中选取一个更小的区间,再跑曲线,以此类推(建议打印输出最大值及其取的值)。

  1.  
    #调参第一步:n_estimators
  2.  
    cross = []
  3.  
    for i in range(0,200,10):
  4.  
    rf = RandomForestClassifier(n_estimators=i 1, n_jobs=-1,random_state=42)
  5.  
    cross_score = cross_val_score(rf, xtest, ytest, cv=5).mean()
  6.  
    cross.append(cross_score)
  7.  
    plt.plot(range(1,201,10),cross)
  8.  
    plt.xlabel('n_estimators')
  9.  
    plt.ylabel('acc')
  10.  
    plt.show()
  11.  
    print((cross.index(max(cross))*10) 1,max(cross))

学新通 学新通

2、网格搜索

       有一些参数有一定范围,或者我们知道他们的取值和随着他们的取值模型的准确率会如何变化。在这里值得说明的一点是,网格搜索,如果一次性在参数列表中写出多个参数及对应值,它不会抛弃任何一个我们设置的参数值,会尽力组合,而有时候效果可能不太好,且费时。那建议的操作是,可以一次设定一到两个参数及其值。

  1.  
    from sklearn.model_selection import GridSearchCV
  2.  
    #调整max_depth
  3.  
    param_grid = {'max_depth' : np.arange(1,20,1)}
  4.  
    #一般根据数据大小进行尝试,像该数据集 可从1-101-20开始
  5.  
    rf = RandomForestClassifier(n_estimators=11,random_state=42)
  6.  
    GS = GridSearchCV(rf,param_grid,cv=5)
  7.  
    GS.fit(data.data,data.target)
  8.  
    GS.best_params_ #最佳参数组合
  9.  
    GS.best_score_ #最佳得分

四、详细代码

        代码建议在jupyter notebook分段运行,因为最起码能保证划分的测试集和训练集不会变化,这样调参才有意义。

  1.  
    from sklearn.datasets import load_breast_cancer
  2.  
    from sklearn.model_selection import train_test_split
  3.  
    from sklearn.model_selection import cross_val_score
  4.  
    from sklearn.model_selection import GridSearchCV
  5.  
    import pandas as pd
  6.  
    import numpy as np
  7.  
    import matplotlib.pyplot as plt
  8.  
    from sklearn.ensemble import RandomForestClassifier
  9.  
     
  10.  
    data = load_breast_cancer() #乳腺癌案例
  11.  
    print(data.data.shape)
  12.  
     
  13.  
    xtrain,xtest,ytrain,ytest = train_test_split(data.data,data.target,test_size=0.3)
  14.  
    # GridSearchCV
  15.  
    rf = RandomForestClassifier(n_estimators=100,random_state=42)
  16.  
    rf.fit(xtrain,ytrain)
  17.  
    score = rf.score(xtest,ytest)
  18.  
    cross_s = cross_val_score(rf,xtest,ytest,cv=5).mean()
  19.  
    print('rf:',score)
  20.  
    print('cv:',cross_s)
  21.  
     
  22.  
    #调参第一步:n_estimators
  23.  
    cross = []
  24.  
    for i in range(0,200,10):
  25.  
    rf = RandomForestClassifier(n_estimators=i 1, n_jobs=-1,random_state=42)
  26.  
    cross_score = cross_val_score(rf, xtest, ytest, cv=5).mean()
  27.  
    cross.append(cross_score)
  28.  
    plt.plot(range(1,201,10),cross)
  29.  
    plt.xlabel('n_estimators')
  30.  
    plt.ylabel('acc')
  31.  
    plt.show()
  32.  
    print((cross.index(max(cross))*10) 1,max(cross))
  33.  
    # n_estimators缩小范围
  34.  
    cross = []
  35.  
    for i in range(0,25):
  36.  
    rf = RandomForestClassifier(n_estimators=i 1, n_jobs=-1,random_state=42)
  37.  
    cross_score = cross_val_score(rf, xtest, ytest, cv=5).mean()
  38.  
    cross.append(cross_score)
  39.  
    plt.plot(range(1,26),cross)
  40.  
    plt.xlabel('n_estimators')
  41.  
    plt.ylabel('acc')
  42.  
    plt.show()
  43.  
    print(cross.index(max(cross)) 1,max(cross))
  44.  
     
  45.  
    #调整max_depth
  46.  
    param_grid = {'max_depth' : np.arange(1,20,1)}
  47.  
    #一般根据数据大小进行尝试,像该数据集 可从1-101-20开始
  48.  
    rf = RandomForestClassifier(n_estimators=11,random_state=42)
  49.  
    GS = GridSearchCV(rf,param_grid,cv=5)
  50.  
    GS.fit(data.data,data.target)
  51.  
    GS.best_params_
  52.  
    GS.best_score_
  53.  
     
  54.  
    #调整max_features
  55.  
    param_grid = {'max_features' : np.arange(5,30,1)}
  56.  
    rf = RandomForestClassifier(n_estimators=11,random_state=42)
  57.  
    GS = GridSearchCV(rf,param_grid,cv=5)
  58.  
    GS.fit(data.data,data.target)
  59.  
    GS.best_params_
  60.  
    GS.best_score_
  61.  
     
  62.  
    #调整min_samples_leaf
  63.  
    param_grid = {'min_samples_leaf' : np.arange(1,1 10,1)}
  64.  
    #一般是从其最小值开始向上增加10或者20
  65.  
    # 面对高维度高样本数据,如果不放心,也可以直接 50,对于大型数据可能需要增加200-300
  66.  
    # 如果调整的时候发现准确率怎么都上不来,那可以放心大胆调一个很大的数据,大力限制模型的复杂度
  67.  
    rf = RandomForestClassifier(n_estimators=11,random_state=42)
  68.  
    GS = GridSearchCV(rf,param_grid,cv=5)
  69.  
    GS.fit(data.data,data.target)
  70.  
    GS.best_params_
  71.  
    GS.best_score_
  72.  
     
  73.  
    #调整min_samples_split
  74.  
    param_grid = {'min_samples_split' : np.arange(2,2 20,1)}
  75.  
    #一般是从其最小值开始向上增加10或者20
  76.  
    # 面对高维度高样本数据,如果不放心,也可以直接 50,对于大型数据可能需要增加200-300
  77.  
    # 如果调整的时候发现准确率怎么都上不来,那可以放心大胆调一个很大的数据,大力限制模型的复杂度
  78.  
    rf = RandomForestClassifier(n_estimators=11,random_state=42)
  79.  
    GS = GridSearchCV(rf,param_grid,cv=5)
  80.  
    GS.fit(data.data,data.target)
  81.  
    GS.best_params_
  82.  
    GS.best_score_
  83.  
     
  84.  
    #调整criterion
  85.  
    param_grid = {'criterion' :['gini','entropy']}
  86.  
    #一般是从其最小值开始向上增加10或者20
  87.  
    # 面对高维度高样本数据,如果不放心,也可以直接 50,对于大型数据可能需要增加200-300
  88.  
    # 如果调整的时候发现准确率怎么都上不来,那可以放心大胆调一个很大的数据,大力限制模型的复杂度
  89.  
    rf = RandomForestClassifier(n_estimators=11,random_state=42)
  90.  
    GS = GridSearchCV(rf,param_grid,cv=5)
  91.  
    GS.fit(data.data,data.target)
  92.  
    GS.best_params_
  93.  
    GS.best_score_
学新通

希望大家有所收获,欢迎留言~

这篇好文章是转载于:学新通技术网

  • 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
  • 本站站名: 学新通技术网
  • 本文地址: /boutique/detail/tanhhgijef
系列文章
更多 icon
同类精品
更多 icon
继续加载