検索
勾配ブースティングの中を見てみる
- M.R 
- 2021年1月1日
- 読了時間: 2分
更新日:2021年2月27日
概要
これまでDecisionTreeClassifierとRandomForestClassifierのソースを見てきました。今回はRandomForestClassifierと同じくDesicionTreeClassifierのensembleである勾配ブースティング(GradientBoostingClassifier)のソースを見てみます。
詳細
GradientBoostingClassifierのfitメソッドは継承元であるBaseGradientBosstingのものを使用します。
def fit(self, X, y, sample_weight=None, monitor=None):
//略
    # fit the boosting 
    stagesn_stages=self._fit_stages(X, y, raw_predictions, 
          sample_weight, self._rng, X_val, y_val,sample_weight_val, 
          begin_at_stage, monitor)パラメータの初期設定や値チェックなどがありますが基本的に_fit_stages()を呼び出しているだけです。
def _fit_stages(self, X, y, raw_predictions, sample_weight, 
                random_state,X_val, y_val, 
                sample_weight_val,begin_at_stage=0, monitor=None):
                
    //略
    
    for i in range(begin_at_stage, self.n_estimators):
    
        //略    
        # fit next stage of trees
        raw_predictions=self._fit_stage(i, X, y, raw_predictions, 
                             sample_weight, sample_mask,random_state, 
                             X_csc, X_csr)_fit_stagesでは木の数(n_estimators)だけ_fit_stageメソッドを実行し、木を学習させていきます。
def _fit_stage(self, i, X, y, raw_predictions, sample_weight, 
               sample_mask,random_state, X_csc=None, X_csr=None):
    //略
    for k in range(loss.K):
        if loss.is_multi_class:
            y=np.array(original_y==k, dtype=np.float64)
            
            //1 勾配を計算
        residual=loss.negative_gradient(y, raw_predictions_copy, 
                                     k=k,sample_weight=sample_weight)
            
            //2 決定木を作成
        tree=DecisionTreeRegressor(          
            criterion=self.criterion,
            splitter='best',         
            max_depth=self.max_depth,
            min_samples_split=self.min_samples_split,
            min_samples_leaf=self.min_samples_leaf,                     
            min_weight_fraction_leaf=self.min_weight_fraction_leaf,                       
            min_impurity_decrease=self.min_impurity_decrease,
            min_impurity_split=self.min_impurity_split,
            max_features=self.max_features,
            max_leaf_nodes=self.max_leaf_nodes,
            random_state=random_state,
            ccp_alpha=self.ccp_alpha)
                         
        if self.subsample<1.0:
            # no inplace multiplication!                              
            sample_weight=
            sample_weight*sample_mask.astype(np.float64)
                
         X=X_csr if X_csr is not None else X 
        //3 木を学習
        tree.fit(X, residual, 
                 sample_weight=sample_weight,check_input=False)
        //4 勾配をもとに木を更新 
        loss.update_terminal_regions(tree.tree_, X, y, residual, 
                     raw_predictions, sample_weight,sample_mask, 
                     learning_rate=self.learning_rate, k=k)
                         
        //5 木をensembleに追加
        self.estimators_[i, k] =tree
            
    return raw_predictions_fit_stageメソッドを見てみます。
loss.Kはカテゴリ数です。
1では前の木の予測結果と実際の値の差分をとっています。
2で新しく木を作り、1で計算した誤差を3で学習させます。学習した結果を4で木に反映させ、5でenembleに追加しています。
(なんでカテゴリ毎にループして木を作り直すのかがいまいち分かっていません、、、)
sklearn/ensemble/_gb_losses.py#L610
def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,residual, raw_predictions, sample_weight):
    """Make a single Newton-Raphson step.
    our node estimate is given by:            
         sum(w * (y - prob)) / sum(w * prob * (1 - prob))
    we take advantage that: y - prob = residual 
    """
    terminal_region=np.where(terminal_regions==leaf[0]
    residual=residual.take(terminal_region, axis=0)
    y=y.take(terminal_region, axis=0)
    sample_weight=sample_weight.take(terminal_region, axis=0)
    
    numerator=np.sum(sample_weight*residual)
    denominator=np.sum(sample_weight*(y-residual) * (1-y+residual))
    # prevents overflow and division by zero
    if abs(denominator) <1e-150:
        tree.value[leaf, 0, 0] =0.0
    else:
        tree.value[leaf, 0, 0] =numerator/denominator4の_update_terminal_regionでは実際に勾配をもとに木のnodeの値を更新しています。
参考文献
最後に
勾配ブースティングに関してはコードよりもなぜそういう計算をするのか、が難しいですね。






コメント